diff --git a/src/generator.rs b/src/generator.rs index f70a1de..dd17d52 100644 --- a/src/generator.rs +++ b/src/generator.rs @@ -21,7 +21,7 @@ fn map_type_to_rust_accessor(field_type: i32, label: i32) -> (String, String) { // LABEL_REPEATED return ( "crate::RepeatedFieldIterator<'a>".to_string(), - "self.0.iter_repeated(%d)".to_string(), + "".to_string(), // Not used for repeated fields in the same way ); } @@ -142,15 +142,13 @@ pub fn generate_rust_code(set: &FileDescriptorSet) -> String { let msg_proto = DescriptorProto::new(msg_data).expect("Failed to parse DescriptorProto"); let msg_name = to_pascal_case(msg_proto.name().unwrap()); - // Accessor + // Accessor Struct Definition output.push_str(&format!( - "pub struct {}<'a>(ProtoAccessor<'a>);\n\nimpl<'a> {}<'a> {{\n", - msg_name, msg_name - )); - output.push_str(&format!( - " pub fn new(data: &'a [u8]) -> Result {{\n Ok(Self(ProtoAccessor::new(data)?))\n }}\n\n" + "pub struct {}<'a> {{\n accessor: ProtoAccessor<'a>,\n", + msg_name )); + let mut fields_info = Vec::new(); for field_res in msg_proto.field() { let (field_data, _) = field_res.expect("Failed to iterate field"); let field_proto = FieldDescriptorProto::new(field_data).expect("Failed to parse FieldDescriptorProto"); @@ -159,18 +157,99 @@ pub fn generate_rust_code(set: &FileDescriptorSet) -> String { let f_type = field_proto.field_type().unwrap() as i32; let f_label = field_proto.label().unwrap() as i32; + fields_info.push((field_name.to_string(), tag, f_type, f_label)); + + if f_label == 3 { + output.push_str(&format!(" {}_start: Option,\n", field_name)); + output.push_str(&format!(" {}_end: Option,\n", field_name)); + } else { + output.push_str(&format!(" {}_offset: Option,\n", field_name)); + } + } + output.push_str("}\n\n"); + + // Accessor Implementation + output.push_str(&format!("impl<'a> {}<'a> {{\n", msg_name)); + + // new() method + output.push_str(" pub fn new(data: &'a [u8]) -> Result {\n"); + output.push_str(" let accessor = ProtoAccessor::new(data)?;\n"); + + for (name, _, _, label) in &fields_info { + if *label == 3 { + output.push_str(&format!(" let mut {}_start = None;\n", name)); + output.push_str(&format!(" let mut {}_end = None;\n", name)); + } else { + output.push_str(&format!(" let mut {}_offset = None;\n", name)); + } + } + + output.push_str(" for item in accessor.fields() {\n"); + output.push_str(" let (offset, tag, _) = item?;\n"); + + for (name, tag, _, label) in &fields_info { + if *label == 3 { + output.push_str(&format!( + " if tag.field_number == {} {{\n", tag + )); + output.push_str(&format!(" if {}_start.is_none() {{ {}_start = Some(offset); }}\n", name, name)); + output.push_str(&format!(" {}_end = Some(offset);\n", name)); + output.push_str(" }\n"); + } else { + output.push_str(&format!( + " if tag.field_number == {} {{ {}_offset = Some(offset); }}\n", tag, name + )); + } + } + output.push_str(" }\n\n"); + + output.push_str(" Ok(Self {\n"); + output.push_str(" accessor,\n"); + for (name, _, _, label) in &fields_info { + if *label == 3 { + output.push_str(&format!("{}_start, {}_end,\n", name, name)); + } else { + output.push_str(&format!("{}_offset,\n", name)); + } + } + output.push_str(" })\n }\n\n"); + + // Field Accessors + for (field_name, tag, f_type, f_label) in fields_info { let (rust_type, logic) = map_type_to_rust_accessor(f_type, f_label); if f_label == 3 { output.push_str(&format!( - " pub fn {}(&self) -> {} {{\n {}\n }}\n\n", - field_name, rust_type, logic.replace("%d", &tag.to_string()) + " pub fn {}(&self) -> {} {{\n", + field_name, rust_type )); + output.push_str(&format!( + " match (self.{}_start, self.{}_end) {{\n", + field_name, field_name + )); + output.push_str(&format!( + " (Some(start), Some(end)) => self.accessor.iter_repeated_range({}, start, end),\n", + tag + )); + output.push_str(&format!( + " _ => self.accessor.iter_repeated({}),\n", + tag + )); + output.push_str(" }\n }\n\n"); } else { output.push_str(&format!( - " pub fn {}(&self) -> Result<{}> {{\n let (bytes, _) = self.0.get_value({})?;\n {}\n }}\n\n", - field_name, rust_type, tag, logic + " pub fn {}(&self) -> Result<{}> {{\n", + field_name, rust_type )); + output.push_str(&format!( + " let offset = self.{}_offset.ok_or(RotoError::FieldNotFound)?;\n", + field_name + )); + output.push_str(&format!( + " let (bytes, _) = self.accessor.get_value_at(offset)?;\n", + )); + output.push_str(&format!(" {}\n", logic)); + output.push_str(" }\n\n"); } } output.push_str("}\n\n"); diff --git a/src/lib.rs b/src/lib.rs index a3e1d41..38f4edf 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -185,7 +185,7 @@ impl<'a> ProtoAccessor<'a> { pub fn get_value(&self, field_number: u32) -> Result<(&'a [u8], WireType)> { let mut last_value = None; for item in self.fields() { - let (tag, value) = item?; + let (_offset, tag, value) = item?; if tag.field_number == field_number { last_value = Some((value, tag.wire_type)); } @@ -197,6 +197,32 @@ impl<'a> ProtoAccessor<'a> { pub fn iter_repeated(&self, field_number: u32) -> RepeatedFieldIterator<'a> { RepeatedFieldIterator::new(self.data, field_number) } + + /// Returns the value and wire type of a field at a specific offset. + pub fn get_value_at(&self, offset: usize) -> Result<(&'a [u8], WireType)> { + if offset >= self.data.len() { + return Err(RotoError::UnexpectedEndOfBuffer); + } + let (tag, tag_len) = Tag::decode(&self.data[offset..])?; + let cursor_after_tag = offset + tag_len; + if cursor_after_tag > self.data.len() { + return Err(RotoError::UnexpectedEndOfBuffer); + } + let value_len = skip_value(tag.wire_type, &self.data[cursor_after_tag..])?; + let (value_offset, actual_value_len) = match tag.wire_type { + WireType::LengthDelimited => { + let (_, varint_len) = read_varint(&self.data[cursor_after_tag..])?; + (cursor_after_tag + varint_len, value_len - varint_len) + } + _ => (cursor_after_tag, value_len), + }; + Ok((&self.data[value_offset..value_offset + actual_value_len], tag.wire_type)) + } + + /// Returns an iterator that scans a specific range of the buffer for all occurrences of the specified field. + pub fn iter_repeated_range(&self, field_number: u32, start: usize, end: usize) -> RepeatedFieldIterator<'a> { + RepeatedFieldIterator::new_range(self.data, field_number, start, end) + } } pub struct FieldIterator<'a> { @@ -205,7 +231,7 @@ pub struct FieldIterator<'a> { } impl<'a> Iterator for FieldIterator<'a> { - type Item = Result<(Tag, &'a [u8])>; + type Item = Result<(usize, Tag, &'a [u8])>; fn next(&mut self) -> Option { if self.cursor >= self.data.len() { @@ -250,23 +276,36 @@ impl<'a> Iterator for FieldIterator<'a> { self.cursor = cursor_after_tag + value_len; - Some(Ok((tag, &self.data[value_offset..value_offset + actual_value_len]))) + Some(Ok((self.cursor - tag_len - value_len, tag, &self.data[value_offset..value_offset + actual_value_len]))) } } pub struct RepeatedFieldIterator<'a> { iterator: FieldIterator<'a>, field_number: u32, + end_offset: Option, } impl<'a> RepeatedFieldIterator<'a> { - fn new(data: &'a [u8], field_number: u32) -> Self { + pub fn new(data: &'a [u8], field_number: u32) -> Self { Self { iterator: FieldIterator { data, cursor: 0, }, field_number, + end_offset: None, + } + } + + pub fn new_range(data: &'a [u8], field_number: u32, start: usize, end: usize) -> Self { + Self { + iterator: FieldIterator { + data, + cursor: start, + }, + field_number, + end_offset: Some(end), } } } @@ -277,7 +316,12 @@ impl<'a> Iterator for RepeatedFieldIterator<'a> { fn next(&mut self) -> Option { while let Some(item) = self.iterator.next() { match item { - Ok((tag, value)) if tag.field_number == self.field_number => { + Ok((offset, tag, value)) if tag.field_number == self.field_number => { + if let Some(end) = self.end_offset { + if offset > end { + return None; + } + } return Some(Ok((value, tag.wire_type))); } Ok(_) => continue,