From 05e4c275bbcaeaa09bbbbed4f02eb9cabab7ad9f Mon Sep 17 00:00:00 2001 From: charles Date: Mon, 4 May 2026 19:03:56 -0700 Subject: [PATCH] Add raw field iterator and with builder method - Implement RawFieldIterator and ProtoAccessor::raw_fields that yield (field_number, raw_bytes) pairs for each field - Extend Builder with per-field _written flags and add a with() method to copy unseen fields from a source message - Add ProtoBuilder::write_raw to copy pre-encoded field bytes - Add tests for raw-field iteration, verbatim copying, and with() --- src/generator.rs | 73 ++++++++++++++---- src/lib.rs | 158 ++++++++++++++++++++++++++++++++++++++ tests/test_with_method.rs | 80 +++++++++++++++++++ 3 files changed, 297 insertions(+), 14 deletions(-) create mode 100644 tests/test_with_method.rs diff --git a/src/generator.rs b/src/generator.rs index b76e0b2..b8d4d47 100644 --- a/src/generator.rs +++ b/src/generator.rs @@ -257,36 +257,81 @@ fn write_message(msg_proto: &DescriptorProto, output: &mut String) { output.push_str(" }\n\n"); } } + // raw_fields() convenience on the message struct (before closing the impl) + output.push_str(" pub fn raw_fields(&self) -> crate::RawFieldIterator<'a> {\n"); + output.push_str(" self.accessor.raw_fields()\n"); + output.push_str(" }\n\n"); output.push_str("}\n\n"); - // Builder - output.push_str(&format!( - "pub struct {}Builder<'b> {{\n builder: crate::ProtoBuilder<'b>,\n}}\n\nimpl<'b> {}Builder<'b> {{\n", - msg_name, msg_name - )); - output.push_str(&format!( - " pub fn builder(buf: &mut [u8]) -> {}Builder<'_> {{\n {}Builder {{\n builder: crate::ProtoBuilder::new(buf),\n }}\n }}\n\n", - msg_name, msg_name - )); - + // Collect builder field info so we can use it multiple times below. + // Tuple: (field_name, safe_name, tag, rust_type, write_method) + let mut builder_fields: Vec<(String, String, u32, String, String)> = Vec::new(); for field_res in msg_proto.field() { let (field_data, _) = field_res.expect("Failed to iterate field"); let field_proto = FieldDescriptorProto::new(field_data).expect("Failed to parse FieldDescriptorProto"); - let field_name = field_proto.name().unwrap(); + let field_name = field_proto.name().unwrap().to_string(); let safe_name = if field_name == "type" { format!("r#{}", field_name) } else { - field_name.to_string() + field_name.clone() }; let tag = field_proto.number().unwrap(); let f_type = field_proto.r#type().unwrap() as i32; let (rust_type, method) = map_type_to_rust_builder(f_type); + builder_fields.push((field_name, safe_name, tag, rust_type, method)); + } + + // Builder struct — one `_written: bool` flag per field + output.push_str(&format!("pub struct {}Builder<'b> {{\n", msg_name)); + output.push_str(" builder: crate::ProtoBuilder<'b>,\n"); + for (field_name, _, _, _, _) in &builder_fields { + output.push_str(&format!(" {}_written: bool,\n", field_name)); + } + output.push_str(&format!("}}\.\n\nimpl<'b> {}Builder<'b> {{\n", msg_name)); + + // Constructor — initialise every flag to false + output.push_str(&format!( + " pub fn builder(buf: &mut [u8]) -> {}Builder<'_> {{\n {}Builder {{\n", + msg_name, msg_name + )); + output.push_str(" builder: crate::ProtoBuilder::new(buf),\n"); + for (field_name, _, _, _, _) in &builder_fields { + output.push_str(&format!(" {}_written: false,\n", field_name)); + } + output.push_str(" }\n }\n\n"); + + // Per-field setters — mark field as written + for (field_name, safe_name, tag, rust_type, method) in &builder_fields { output.push_str(&format!( - " pub fn {}(mut self, value: {}) -> crate::Result {{\n self.builder.{}({}, value)?;\n Ok(self)\n }}\n\n", - safe_name, rust_type, method, tag + " pub fn {}(mut self, value: {}) -> crate::Result {{\n self.builder.{}({}, value)?;\n self.{}_written = true;\n Ok(self)\n }}\n\n", + safe_name, rust_type, method, tag, field_name )); } + + // with() — copies unseen fields from an existing message + output.push_str(&format!( + " pub fn with(mut self, msg: &{}<'_>) -> crate::Result {{\n", + msg_name + )); + output.push_str(" for item in msg.raw_fields() {\n"); + output.push_str(" let (field_number, raw_bytes) = item?;\n"); + output.push_str(" let is_written = match field_number {\n"); + for (field_name, _, tag, _, _) in &builder_fields { + output.push_str(&format!( + " {} => self.{}_written,\n", + tag, field_name + )); + } + output.push_str(" _ => false,\n"); + output.push_str(" };\n"); + output.push_str(" if !is_written {\n"); + output.push_str(" self.builder.write_raw(raw_bytes)?;\n"); + output.push_str(" }\n"); + output.push_str(" }\n"); + output.push_str(" Ok(self)\n"); + output.push_str(" }\n\n"); + output.push_str(&format!(" pub fn finish(self) -> crate::Result<&'b mut [u8]> {{\n self.builder.finish()\n }}\n}}\n\n")); let mut nested_enums = Vec::new(); diff --git a/src/lib.rs b/src/lib.rs index 92e1fff..e289ddf 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -236,6 +236,17 @@ impl<'a> ProtoAccessor<'a> { ) -> RepeatedFieldIterator<'a> { RepeatedFieldIterator::new_range(self.data, field_number, start, end) } + + /// Returns an iterator that yields `(field_number, raw_bytes)` for every + /// field in the message. `raw_bytes` is the complete on-wire encoding + /// (tag + value, including any length prefix), suitable for passing + /// directly to `ProtoBuilder::write_raw`. + pub fn raw_fields(&self) -> RawFieldIterator<'a> { + RawFieldIterator { + data: self.data, + cursor: 0, + } + } } pub struct FieldIterator<'a> { @@ -346,6 +357,48 @@ impl<'a> Iterator for RepeatedFieldIterator<'a> { } } +/// An iterator that yields `(field_number, raw_bytes)` for every field in a +/// protobuf message, where `raw_bytes` is the complete on-wire encoding of the +/// field: tag varint + value bytes (including the length prefix for +/// length-delimited fields). This is the slice needed by +/// `ProtoBuilder::write_raw` to copy a field verbatim. +pub struct RawFieldIterator<'a> { + data: &'a [u8], + cursor: usize, +} + +impl<'a> Iterator for RawFieldIterator<'a> { + type Item = Result<(u32, &'a [u8])>; + + fn next(&mut self) -> Option { + if self.cursor >= self.data.len() { + return None; + } + let field_start = self.cursor; + let (tag, tag_len) = match Tag::decode(&self.data[self.cursor..]) { + Ok(t) => t, + Err(e) => { + self.cursor = self.data.len(); + return Some(Err(e)); + } + }; + let cursor_after_tag = self.cursor + tag_len; + if cursor_after_tag > self.data.len() { + self.cursor = self.data.len(); + return Some(Err(RotoError::UnexpectedEndOfBuffer)); + } + let value_len = match skip_value(tag.wire_type, &self.data[cursor_after_tag..]) { + Ok(l) => l, + Err(e) => { + self.cursor = self.data.len(); + return Some(Err(e)); + } + }; + self.cursor = cursor_after_tag + value_len; + Some(Ok((tag.field_number, &self.data[field_start..self.cursor]))) + } +} + #[cfg(test)] mod tests { use super::*; @@ -455,6 +508,104 @@ mod tests { assert_eq!(result, Err(RotoError::BufferOverflow)); } + #[test] + fn test_raw_field_iterator_yields_correct_bytes() { + // Build: field 1 = string "hi", field 2 = int32 42 + let mut buf = [0u8; 64]; + let mut builder = ProtoBuilder::new(&mut buf); + builder.write_string(1, "hi").unwrap(); + builder.write_int32(2, 42).unwrap(); + let data = builder.finish().unwrap().to_vec(); + + let acc = ProtoAccessor::new(&data).unwrap(); + let raw: Vec<_> = acc.raw_fields().collect(); + assert_eq!(raw.len(), 2); + + // Field 1: tag = (1 << 3) | 2 = 0x0A, len varint = 0x02, "hi" = [0x68, 0x69] + let (fn1, bytes1) = raw[0].as_ref().unwrap(); + assert_eq!(*fn1, 1); + assert_eq!(*bytes1, [0x0A, 0x02, b'h', b'i']); + + // Field 2: tag = (2 << 3) | 0 = 0x10, varint 42 = 0x2A + let (fn2, bytes2) = raw[1].as_ref().unwrap(); + assert_eq!(*fn2, 2); + assert_eq!(*bytes2, [0x10, 0x2A]); + } + + #[test] + fn test_write_raw_copies_field_verbatim() { + // Build source: field 1 = string "hello", field 2 = int32 99 + let mut src_buf = [0u8; 64]; + let mut src_builder = ProtoBuilder::new(&mut src_buf); + src_builder.write_string(1, "hello").unwrap(); + src_builder.write_int32(2, 99).unwrap(); + let src_data = src_builder.finish().unwrap().to_vec(); + + // Copy every raw field verbatim into a new buffer + let src_acc = ProtoAccessor::new(&src_data).unwrap(); + let mut dst_buf = [0u8; 64]; + let mut dst_builder = ProtoBuilder::new(&mut dst_buf); + for item in src_acc.raw_fields() { + let (_, raw_bytes) = item.unwrap(); + dst_builder.write_raw(raw_bytes).unwrap(); + } + let dst_data = dst_builder.finish().unwrap(); + + // The copy must be byte-identical to the source + assert_eq!(dst_data, src_data.as_slice()); + } + + #[test] + fn test_with_pattern_copies_unseen_fields() { + // Build an existing source message with 3 fields + let mut src_buf = [0u8; 128]; + let mut src_builder = ProtoBuilder::new(&mut src_buf); + src_builder.write_string(1, "original").unwrap(); + src_builder.write_int32(2, 99).unwrap(); + src_builder.write_varint(3, 1u64).unwrap(); // bool + let src_data = src_builder.finish().unwrap().to_vec(); + let src_acc = ProtoAccessor::new(&src_data).unwrap(); + + // Simulate what a generated `with` method does: + // field 1 was explicitly written; fields 2 and 3 come from source. + let field1_written = true; + let field2_written = false; + let field3_written = false; + + let mut dst_buf = [0u8; 128]; + let mut dst_builder = ProtoBuilder::new(&mut dst_buf); + dst_builder.write_string(1, "updated").unwrap(); + + for item in src_acc.raw_fields() { + let (field_number, raw_bytes) = item.unwrap(); + let is_written = match field_number { + 1 => field1_written, + 2 => field2_written, + 3 => field3_written, + _ => false, + }; + if !is_written { + dst_builder.write_raw(raw_bytes).unwrap(); + } + } + let dst_data = dst_builder.finish().unwrap(); + let dst_acc = ProtoAccessor::new(dst_data).unwrap(); + + // Field 1: overridden value + let (val1, _) = dst_acc.get_value(1).unwrap(); + assert_eq!(val1, b"updated"); + + // Field 2: copied from source + let (val2, _) = dst_acc.get_value(2).unwrap(); + let (v2, _) = read_varint(val2).unwrap(); + assert_eq!(v2 as i32, 99); + + // Field 3: copied from source + let (val3, _) = dst_acc.get_value(3).unwrap(); + let (v3, _) = read_varint(val3).unwrap(); + assert_eq!(v3, 1u64); + } + #[test] fn test_protoc_binary_compatibility() { let data = include_bytes!("../data/test_data.pb"); @@ -618,6 +769,13 @@ impl<'a> ProtoBuilder<'a> { self.append_bytes(value) } + /// Appends a pre-encoded field (tag + value bytes) verbatim into the + /// buffer. Use this together with `ProtoAccessor::raw_fields` to copy + /// fields from an existing message into a builder without re-encoding them. + pub fn write_raw(&mut self, raw_bytes: &[u8]) -> Result<()> { + self.append_bytes(raw_bytes) + } + pub fn finish(self) -> Result<&'a mut [u8]> { Ok(&mut self.buf[..self.pos]) } diff --git a/tests/test_with_method.rs b/tests/test_with_method.rs new file mode 100644 index 0000000..7c8dfd9 --- /dev/null +++ b/tests/test_with_method.rs @@ -0,0 +1,80 @@ +use roto::generator::generate_rust_code; +use roto::google::protobuf::compiler::plugin::CodeGeneratorRequest; +use roto::google::protobuf::descriptor::FileDescriptorSet; +use std::fs; + +fn load_generated_code() -> String { + let data = fs::read("data/request.bin").expect("Failed to read data/request.bin"); + let request = CodeGeneratorRequest::new(&data).expect("Failed to parse CodeGeneratorRequest"); + + let mut set_buf = Vec::new(); + for file_res in request.proto_file() { + let (file_data, _) = file_res.expect("Failed to iterate proto_file"); + set_buf.push(10u8); + let len = file_data.len() as u64; + let mut len_buf = [0u8; 10]; + let len_size = roto::write_varint(len, &mut len_buf).unwrap(); + set_buf.extend_from_slice(&len_buf[..len_size]); + set_buf.extend_from_slice(file_data); + } + let set = FileDescriptorSet::new(&set_buf).expect("Failed to create FileDescriptorSet"); + + generate_rust_code(&set, None, false) + .into_iter() + .map(|(_, content)| content) + .collect() +} + +#[test] +fn test_builder_structs_have_written_flags() { + let code = load_generated_code(); + assert!( + code.contains("_written: bool"), + "Builder structs should contain `_written: bool` fields for each proto field" + ); +} + +#[test] +fn test_builder_constructor_initialises_written_flags_to_false() { + let code = load_generated_code(); + assert!( + code.contains("_written: false"), + "Builder constructors should initialise every `_written` flag to false" + ); +} + +#[test] +fn test_builder_setters_mark_field_as_written() { + let code = load_generated_code(); + assert!( + code.contains("_written = true"), + "Each builder setter should set its `_written` flag to true" + ); +} + +#[test] +fn test_builder_has_with_method() { + let code = load_generated_code(); + assert!( + code.contains("pub fn with("), + "Each builder impl should expose a `with` method" + ); +} + +#[test] +fn test_message_structs_have_raw_fields_method() { + let code = load_generated_code(); + assert!( + code.contains("pub fn raw_fields("), + "Each message struct impl should expose a `raw_fields` method" + ); +} + +#[test] +fn test_with_method_uses_write_raw() { + let code = load_generated_code(); + assert!( + code.contains("write_raw(raw_bytes)"), + "The `with` method should call `write_raw` to copy field bytes" + ); +}