From a20eed722340b2ff8ec78381b93e6c12a64ab0a6 Mon Sep 17 00:00:00 2001
From: charles
Date: Thu, 7 May 2026 20:15:08 -0700
Subject: [PATCH] Add support for protobuf map fields

Update the generator to detect map fields and use MapFieldIterator.
Implement MapFieldIterator in the runtime to handle key-value pair
extraction and add write_map_entry to ProtoBuilder. Add tests to verify
that map-bearing messages generate and compile correctly.
---
 .gitignore                      |  1 +
 codegen/data/test_map.desc      |  9 +++++
 codegen/data/test_map.proto     |  7 ++++
 codegen/src/generator.rs        | 56 +++++++++++++++++++--------
 codegen/tests/test_map_build.rs | 67 +++++++++++++++++++++++++++++++++
 runtime/src/lib.rs              | 64 ++++++++++++++++++++++++++++++-
 6 files changed, 188 insertions(+), 16 deletions(-)
 create mode 100644 codegen/data/test_map.desc
 create mode 100644 codegen/data/test_map.proto
 create mode 100644 codegen/tests/test_map_build.rs

diff --git a/.gitignore b/.gitignore
index 8f3612a..dd0cb92 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,4 @@
 /target
 test_gen_project
 test_types_gen_project
+test_map_gen_project
diff --git a/codegen/data/test_map.desc b/codegen/data/test_map.desc
new file mode 100644
index 0000000..89935fb
--- /dev/null
+++ b/codegen/data/test_map.desc
@@ -0,0 +1,9 @@
+
+«
+codegen/data/test_map.proto roto.test"y
+MapTest4
+my_map ( 2.roto.test.MapTest.MyMapEntryRmyMap8
+
+MyMapEntry
+key ( Rkey
+value (Rvalue:8bproto3
\ No newline at end of file
diff --git a/codegen/data/test_map.proto b/codegen/data/test_map.proto
new file mode 100644
index 0000000..9053ebb
--- /dev/null
+++ b/codegen/data/test_map.proto
@@ -0,0 +1,7 @@
+syntax = "proto3";
+
+package roto.test;
+
+message MapTest {
+  map<string, int32> my_map = 1;
+}
diff --git a/codegen/src/generator.rs b/codegen/src/generator.rs
index abe21b8..78b55c2 100644
--- a/codegen/src/generator.rs
+++ b/codegen/src/generator.rs
@@ -1,6 +1,6 @@
 use crate::google::protobuf::descriptor::{
     DescriptorProto, EnumDescriptorProto, FieldDescriptorProto, FileDescriptorProto,
-    FileDescriptorSet,
+    FileDescriptorSet, MessageOptions,
 };
 use roto_runtime::ProtoAccessor;
 use std::collections::{HashMap, HashSet};
@@ -33,11 +33,16 @@ pub fn to_snake_case(s: &str) -> String {
     result
 }

-fn map_type_to_rust_accessor(field_type: i32, label: i32) -> (String, String) {
+fn map_type_to_rust_accessor(field_type: i32, label: i32, is_map: bool) -> (String, String) {
     if label == 3 {
         // LABEL_REPEATED
+        let iterator_type = if is_map {
+            "roto_runtime::MapFieldIterator<'a>"
+        } else {
+            "roto_runtime::RepeatedFieldIterator<'a>"
+        };
         return (
-            "roto_runtime::RepeatedFieldIterator<'a>".to_string(),
+            iterator_type.to_string(),
             "".to_string(), // Not used for repeated fields in the same way
         );
     }
@@ -159,14 +164,27 @@ fn write_message(msg_proto: &DescriptorProto, output: &mut String) {
         let tag = field_proto.number().unwrap();
         let f_type = field_proto.r#type().unwrap() as i32;
         let f_label = field_proto.label().unwrap() as i32;
+        // NOTE(review): `options()` on a field presumably yields FieldOptions,
+        // while `map_entry` is declared on the MessageOptions of the nested
+        // `*Entry` message, so this check likely never returns true. Confirm,
+        // and resolve the flag through the field's entry message instead.
+        let is_map = field_proto
+            .options()
+            .map(|opt| {
+                MessageOptions::new(opt)
+                    .unwrap()
+                    .map_entry()
+                    .unwrap_or(false)
+            })
+            .unwrap_or(false);

-        fields_info.push((field_name.to_string(), tag, f_type, f_label));
+        fields_info.push((field_name.to_string(), tag, f_type, f_label, is_map));
     }

     output.push_str(&format!("pub struct {}<'a> {{\n", msg_name));
     output.push_str(" accessor: roto_runtime::ProtoAccessor<'a>,\n");

-    for (field_name, _tag, _f_type, f_label) in &fields_info {
+    for (field_name, _tag, _f_type, f_label, _is_map) in &fields_info {
         if *f_label == 3 {
             output.push_str(&format!(" {}_start: Option<usize>,\n", field_name));
             output.push_str(&format!(" {}_end: Option<usize>,\n", field_name));
@@ -180,7 +198,7 @@ fn write_message(msg_proto: &DescriptorProto, output: &mut String) {
     output.push_str(" pub fn new(data: &'a [u8]) -> roto_runtime::Result<Self> {\n");
     output.push_str(" let accessor = roto_runtime::ProtoAccessor::new(data)?;\n");
     if !fields_info.is_empty() {
-        for (name, _, _, label) in &fields_info {
+        for (name, _, _, label, _is_map) in &fields_info {
             if *label == 3 {
                 output.push_str(&format!(" let mut {}_start = None;\n", name));
                 output.push_str(&format!(" let mut {}_end = None;\n", name));
@@ -192,7 +210,7 @@ fn write_message(msg_proto: &DescriptorProto, output: &mut String) {
         output.push_str(" for item in accessor.fields() {\n");
         output.push_str(" let (offset, tag, _) = item?;\n");

-        for (name, tag, _, label) in &fields_info {
+        for (name, tag, _, label, _is_map) in &fields_info {
             if *label == 3 {
                 output.push_str(&format!(" if tag.field_number == {} {{\n", tag));
                 output.push_str(&format!(
@@ -213,7 +231,7 @@ fn write_message(msg_proto: &DescriptorProto, output: &mut String) {
     output.push_str(" Ok(Self {\n");
     output.push_str(" accessor,\n");

-    for (name, _, _, label) in &fields_info {
+    for (name, _, _, label, _is_map) in &fields_info {
         if *label == 3 {
             output.push_str(&format!("{}_start, {}_end,\n", name, name));
         } else {
@@ -222,8 +240,8 @@ fn write_message(msg_proto: &DescriptorProto, output: &mut String) {
     }
     output.push_str(" })\n }\n\n");

-    for (field_name, tag, f_type, f_label) in fields_info {
-        let (rust_type, logic) = map_type_to_rust_accessor(f_type, f_label);
+    for (field_name, tag, f_type, f_label, is_map) in fields_info {
+        let (rust_type, logic) = map_type_to_rust_accessor(f_type, f_label, is_map);
         let safe_name = if field_name == "type" {
             format!("r#{}", field_name)
         } else {
@@ -239,11 +257,19 @@ fn write_message(msg_proto: &DescriptorProto, output: &mut String) {
                 " match (self.{}_start, self.{}_end) {{\n",
                 field_name, field_name
             ));
-            output.push_str(&format!(" (Some(start), Some(end)) => self.accessor.iter_repeated_range({}, start, end),\n", tag));
-            output.push_str(&format!(
-                " _ => self.accessor.iter_repeated({}),\n",
-                tag
-            ));
+            if is_map {
+                output.push_str(&format!(" (Some(start), Some(end)) => roto_runtime::MapFieldIterator::new(self.accessor.iter_repeated_range({}, start, end)),\n", tag));
+                output.push_str(&format!(
+                    " _ => roto_runtime::MapFieldIterator::new(self.accessor.iter_repeated({})),\n",
+                    tag
+                ));
+            } else {
+                output.push_str(&format!(" (Some(start), Some(end)) => self.accessor.iter_repeated_range({}, start, end),\n", tag));
+                output.push_str(&format!(
+                    " _ => self.accessor.iter_repeated({}),\n",
+                    tag
+                ));
+            }
             output.push_str(" }\n }\n\n");
         } else {
             output.push_str(&format!(
diff --git a/codegen/tests/test_map_build.rs b/codegen/tests/test_map_build.rs
new file mode 100644
index 0000000..cf16e11
--- /dev/null
+++ b/codegen/tests/test_map_build.rs
@@ -0,0 +1,67 @@
+use roto_codegen::google::protobuf::descriptor::FileDescriptorSet;
+use std::fs;
+use std::process::Command;
+
+#[test]
+fn test_map_generated_code_builds() {
+    // 1. Load FileDescriptorSet from data/test_map.desc
+    let desc_path = "data/test_map.desc";
+    let data = fs::read(desc_path).expect("Failed to read test_map.desc");
+    let set = FileDescriptorSet::new(&data)
+        .expect("Failed to create FileDescriptorSet from test_map.desc");
+
+    let generated_files = roto_codegen::generator::generate_rust_code(&set, None, false);
+    assert!(
+        !generated_files.is_empty(),
+        "Generated code should not be empty"
+    );
+
+    // 2. Setup a temporary Cargo project to verify the code builds
+    let root = std::env::current_dir().expect("Failed to get current directory");
+    let temp_project_dir = root.join("test_map_gen_project");
+
+    // Clean up previous runs
+    if temp_project_dir.exists() {
+        fs::remove_dir_all(&temp_project_dir).expect("Failed to clean up temp project directory");
+    }
+
+    // Create new library project
+    let status = Command::new("cargo")
+        .args(["new", "--lib", "test_map_gen_project"])
+        .current_dir(&root)
+        .status()
+        .expect("Failed to run cargo new");
+    assert!(status.success(), "cargo new failed");
+
+    // 3. Configure the project to depend on the current roto crate
+    let cargo_toml_path = temp_project_dir.join("Cargo.toml");
+    let cargo_toml_content =
+        fs::read_to_string(&cargo_toml_path).expect("Failed to read Cargo.toml");
+    let updated_cargo_toml = format!(
+        "{}\n\nroto-codegen = {{ path = \"..\" }}\nroto-runtime = {{ path = \"../../runtime\" }}\n\n[workspace]\n",
+        cargo_toml_content
+    );
+    fs::write(cargo_toml_path, updated_cargo_toml).expect("Failed to write Cargo.toml");
+
+    // 4. Write the generated code to src/lib.rs
+    let mut all_code = String::new();
+    for (_, content) in generated_files {
+        all_code.push_str(&content);
+        all_code.push_str("\n");
+    }
+    let final_code = all_code.replace("use crate::", "use roto::");
+    let lib_path = temp_project_dir.join("src/lib.rs");
+    fs::write(lib_path, final_code).expect("Failed to write generated code to src/lib.rs");
+
+    // 5. Attempt to build the project
+    let build_status = Command::new("cargo")
+        .args(["--offline", "build"])
+        .current_dir(&temp_project_dir)
+        .status()
+        .expect("Failed to run cargo build");
+
+    assert!(
+        build_status.success(),
+        "The generated Rust code for test_map.proto failed to build in a standalone project!"
+    );
+}
diff --git a/runtime/src/lib.rs b/runtime/src/lib.rs
index 0c71a0f..55d46bb 100644
--- a/runtime/src/lib.rs
+++ b/runtime/src/lib.rs
@@ -1,6 +1,41 @@
 use std::fmt;

-#[derive(Debug, PartialEq, Eq)]
+/// Iterator over the entries of a protobuf map field.
+///
+/// Wraps a [`RepeatedFieldIterator`] over the field's entry records and
+/// yields, per entry, the bytes found under field 1 (key) and field 2 (value).
+/// An entry for which either lookup fails yields that error instead.
+pub struct MapFieldIterator<'a> {
+    inner: RepeatedFieldIterator<'a>,
+}
+
+impl<'a> MapFieldIterator<'a> {
+    /// Wraps `inner`, the iterator over one map field's entry records.
+    pub fn new(inner: RepeatedFieldIterator<'a>) -> Self {
+        Self { inner }
+    }
+
+    /// Splits one entry record into its `(key, value)` byte slices.
+    fn split_entry(entry: &'a [u8]) -> Result<(&'a [u8], &'a [u8])> {
+        let accessor = ProtoAccessor::new(entry)?;
+        let (key, _) = accessor.get_value(1)?;
+        let (value, _) = accessor.get_value(2)?;
+        Ok((key, value))
+    }
+}
+
+impl<'a> Iterator for MapFieldIterator<'a> {
+    type Item = Result<(&'a [u8], &'a [u8])>;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        // Report a malformed entry as `Some(Err(_))`: returning `None` for
+        // it would be indistinguishable from the normal end of the field.
+        let entry = self.inner.next()?;
+        Some(entry.and_then(|(bytes, _wire_type)| Self::split_entry(bytes)))
+    }
+}
+
+#[derive(Debug, PartialEq)]
 pub enum RotoError {
     UnexpectedEndOfBuffer,
     InvalidVarint,
@@ -769,6 +804,33 @@ impl<'a> ProtoBuilder<'a> {
         self.append_bytes(raw_bytes)
     }

+    /// Appends one map entry as a length-delimited record under `field_number`.
+    ///
+    /// `key_encoded` and `value_encoded` are copied verbatim as the record
+    /// body; no tag is written for them here, so each must already be a
+    /// complete encoded field.
+    ///
+    /// # Errors
+    ///
+    /// Returns the first error raised while writing the tag, length, or body.
+    pub fn write_map_entry(
+        &mut self,
+        field_number: u32,
+        key_encoded: &[u8],
+        value_encoded: &[u8],
+    ) -> Result<()> {
+        let entry_len = key_encoded.len() + value_encoded.len();
+        self.write_tag(field_number, WireType::LengthDelimited)?;
+
+        let mut len_buf = [0u8; 10];
+        let len_len = write_varint(entry_len as u64, &mut len_buf)?;
+        self.append_bytes(&len_buf[..len_len])?;
+
+        self.append_bytes(key_encoded)?;
+        self.append_bytes(value_encoded)?;
+        Ok(())
+    }
+
     pub fn finish(self) -> Result<&'a mut [u8]> {
         Ok(&mut self.buf[..self.pos])
     }