Compare commits

...

7 Commits

Author SHA1 Message Date
charles 7e368feddf Add method signature for repeated fields 2026-05-07 21:21:29 -07:00
charles d9186e697e Fix merge conflicts and generated code logic in generator.rs 2026-05-07 20:53:08 -07:00
charles 13625a48c9 Add support for Protobuf oneof fields in generator
Generate `which_<oneof>` methods and corresponding enums to handle
oneof fields in generated messages. Also add `has_<field>` helper
methods for all fields.
2026-05-07 20:15:16 -07:00
charles a20eed7223 Add support for protobuf map fields
Update the generator to detect map fields and use MapFieldIterator.
Implement MapFieldIterator in the runtime to handle key-value pair
extraction and add write_map_entry to ProtoBuilder. Add tests to
verify that map-bearing messages generate and compile correctly.
2026-05-07 20:15:08 -07:00
charles 8395195ac1 Clean tests 2026-05-07 17:59:39 -07:00
charles f76a020b1e Add test to verify generated code builds
Create a temporary Cargo project to ensure that the Rust code generated from
`test_types.desc` compiles successfully.
2026-05-06 16:21:21 -07:00
charles 80f3aa49ba Fix bug in codegen 2026-05-06 16:03:56 -07:00
9 changed files with 366 additions and 51 deletions
+2
View File
@@ -1,2 +1,4 @@
/target
test_gen_project
test_types_gen_project
test_map_gen_project
+9
View File
@@ -0,0 +1,9 @@
«
codegen/data/test_map.proto roto.test"y
MapTest4
my_map ( 2.roto.test.MapTest.MyMapEntryRmyMap8
MyMapEntry
key ( Rkey
value (Rvalue:8bproto3
+7
View File
@@ -0,0 +1,7 @@
syntax = "proto3";
package roto.test;
message MapTest {
map<string, int32> my_map = 1;
}
+139 -47
View File
@@ -1,6 +1,6 @@
use crate::google::protobuf::descriptor::{
DescriptorProto, EnumDescriptorProto, FieldDescriptorProto, FileDescriptorProto,
FileDescriptorSet,
FileDescriptorSet, MessageOptions, OneofDescriptorProto,
};
use roto_runtime::ProtoAccessor;
use std::collections::{HashMap, HashSet};
@@ -33,11 +33,16 @@ pub fn to_snake_case(s: &str) -> String {
result
}
fn map_type_to_rust_accessor(field_type: i32, label: i32) -> (String, String) {
fn map_type_to_rust_accessor(field_type: i32, label: i32, is_map: bool) -> (String, String) {
if label == 3 {
// LABEL_REPEATED
let iterator_type = if is_map {
"roto_runtime::MapFieldIterator<'a>"
} else {
"roto_runtime::RepeatedFieldIterator<'a>"
};
return (
"roto_runtime::RepeatedFieldIterator<'a>".to_string(),
iterator_type.to_string(),
"".to_string(), // Not used for repeated fields in the same way
);
}
@@ -159,14 +164,37 @@ fn write_message(msg_proto: &DescriptorProto, output: &mut String) {
let tag = field_proto.number().unwrap();
let f_type = field_proto.r#type().unwrap() as i32;
let f_label = field_proto.label().unwrap() as i32;
let oneof_index = field_proto.oneof_index().ok();
let is_map = field_proto
.options()
.map(|opt| {
MessageOptions::new(opt)
.unwrap()
.map_entry()
.unwrap_or(false)
})
.unwrap_or(false);
fields_info.push((field_name.to_string(), tag, f_type, f_label));
fields_info.push((
field_name.to_string(),
tag,
f_type,
f_label,
oneof_index,
is_map,
));
}
let mut oneofs = Vec::new();
for o_res in msg_proto.oneof_decl() {
let (o, _) = o_res.expect("Failed to iterate oneof");
oneofs.push(o);
}
output.push_str(&format!("pub struct {}<'a> {{\n", msg_name));
output.push_str(" accessor: roto_runtime::ProtoAccessor<'a>,\n");
for (field_name, _tag, _f_type, f_label) in &fields_info {
for (field_name, _tag, _f_type, f_label, _oneof_index, _is_map) in &fields_info {
if *f_label == 3 {
output.push_str(&format!(" {}_start: Option<usize>,\n", field_name));
output.push_str(&format!(" {}_end: Option<usize>,\n", field_name));
@@ -179,41 +207,38 @@ fn write_message(msg_proto: &DescriptorProto, output: &mut String) {
output.push_str(&format!("impl<'a> {}<'a> {{\n", msg_name));
output.push_str(" pub fn new(data: &'a [u8]) -> roto_runtime::Result<Self> {\n");
output.push_str(" let accessor = roto_runtime::ProtoAccessor::new(data)?;\n");
if !fields_info.is_empty() {
for (name, _, _, label) in &fields_info {
if *label == 3 {
output.push_str(&format!(" let mut {}_start = None;\n", name));
output.push_str(&format!(" let mut {}_end = None;\n", name));
} else {
output.push_str(&format!(" let mut {}_offset = None;\n", name));
}
for (name, _, _, label, _oneof_index, _is_map) in &fields_info {
if *label == 3 {
output.push_str(&format!(" let mut {}_start = None;\n", name));
output.push_str(&format!(" let mut {}_end = None;\n", name));
} else {
output.push_str(&format!(" let mut {}_offset = None;\n", name));
}
output.push_str(" for item in accessor.fields() {\n");
output.push_str(" let (offset, tag, _) = item?;\n");
for (name, tag, _, label) in &fields_info {
if *label == 3 {
output.push_str(&format!(" if tag.field_number == {} {{\n", tag));
output.push_str(&format!(
" if {}_start.is_none() {{ {}_start = Some(offset); }}\n",
name, name
));
output.push_str(&format!(" {}_end = Some(offset);\n", name));
output.push_str(" }\n");
} else {
output.push_str(&format!(
" if tag.field_number == {} {{ {}_offset = Some(offset); }}\n",
tag, name
));
}
}
output.push_str(" }\n\n");
}
output.push_str(" for item in accessor.fields() {\n");
output.push_str(" let (offset, tag, _) = item?;\n");
for (name, tag, _, label, _oneof_index, _is_map) in &fields_info {
if *label == 3 {
output.push_str(&format!(" if tag.field_number == {} {{\n", tag));
output.push_str(&format!(
" if {}_start.is_none() {{ {}_start = Some(offset); }}\n",
name, name
));
output.push_str(&format!(" {}_end = Some(offset);\n", name));
output.push_str(" }\n");
} else {
output.push_str(&format!(
" if tag.field_number == {} {{ {}_offset = Some(offset); }}\n",
tag, name
));
}
}
output.push_str(" }\n\n");
output.push_str(" Ok(Self {\n");
output.push_str(" accessor,\n");
for (name, _, _, label) in &fields_info {
for (name, _, _, label, _oneof_index, _is_map) in &fields_info {
if *label == 3 {
output.push_str(&format!("{}_start, {}_end,\n", name, name));
} else {
@@ -222,15 +247,15 @@ fn write_message(msg_proto: &DescriptorProto, output: &mut String) {
}
output.push_str(" })\n }\n\n");
for (field_name, tag, f_type, f_label) in fields_info {
let (rust_type, logic) = map_type_to_rust_accessor(f_type, f_label);
for (field_name, tag, f_type, f_label, _oneof_index, is_map) in &fields_info {
let (rust_type, logic) = map_type_to_rust_accessor(*f_type, *f_label, *is_map);
let safe_name = if field_name == "type" {
format!("r#{}", field_name)
} else {
field_name.clone()
};
if f_label == 3 {
if *f_label == 3 {
output.push_str(&format!(
" pub fn {}(&self) -> {} {{\n",
safe_name, rust_type
@@ -239,11 +264,19 @@ fn write_message(msg_proto: &DescriptorProto, output: &mut String) {
" match (self.{}_start, self.{}_end) {{\n",
field_name, field_name
));
output.push_str(&format!(" (Some(start), Some(end)) => self.accessor.iter_repeated_range({}, start, end),\n", tag));
output.push_str(&format!(
" _ => self.accessor.iter_repeated({}),\n",
tag
));
if *is_map {
output.push_str(&format!(" (Some(start), Some(end)) => roto_runtime::MapFieldIterator::new(self.accessor.iter_repeated_range({}, start, end)),\n", tag));
output.push_str(&format!(
" _ => roto_runtime::MapFieldIterator::new(self.accessor.iter_repeated({})),\n",
tag
));
} else {
output.push_str(&format!(" (Some(start), Some(end)) => self.accessor.iter_repeated_range({}, start, end),\n", tag));
output.push_str(&format!(
" _ => self.accessor.iter_repeated({}),\n",
tag
));
}
output.push_str(" }\n }\n\n");
} else {
output.push_str(&format!(
@@ -257,10 +290,47 @@ fn write_message(msg_proto: &DescriptorProto, output: &mut String) {
output.push_str(" let (bytes, _) = self.accessor.get_value_at(offset)?;\n");
output.push_str(&format!(" {}\n", logic));
output.push_str(" }\n\n");
output.push_str(&format!(
" pub fn has_{}(&self) -> bool {{ self.{}_offset.is_some() }}\n\n",
field_name, field_name
));
}
}
for (oneof_index, oneof_proto) in oneofs.iter().enumerate() {
let oneof_desc =
OneofDescriptorProto::new(*oneof_proto).expect("Failed to parse OneofDescriptorProto");
let oneof_name = oneof_desc.name().unwrap();
let pascal_oneof_name = to_pascal_case(oneof_name);
let snake_oneof_name = to_snake_case(oneof_name);
output.push_str(&format!(
" pub fn which_{}(&self) -> roto_runtime::Result<Option<{}::{}<'a>>> {{\n",
snake_oneof_name, msg_name, pascal_oneof_name
));
for (field_name, _tag, _f_type, _f_label, f_oneof_index, _is_map) in &fields_info {
if *f_oneof_index == Some(oneof_index as i32) {
let safe_field_name = if field_name == "type" {
format!("r#{}", field_name)
} else {
field_name.clone()
};
output.push_str(&format!(
" if self.{}_offset.is_some() {{\n",
field_name
));
output.push_str(&format!(
" return Ok(Some({}::{} (self.{}()?)));\n",
pascal_oneof_name, safe_field_name, safe_field_name
));
output.push_str(" }\n");
}
}
output.push_str(" Ok(None)\n }\n\n");
}
// raw_fields() convenience on the message struct (before closing the impl)
output.push_str(" pub fn raw_fields(&self) -> roto::RawFieldIterator<'a> {\n");
output.push_str(" pub fn raw_fields(&self) -> roto_runtime::RawFieldIterator<'a> {\n");
output.push_str(" self.accessor.raw_fields()\n");
output.push_str(" }\n\n");
output.push_str("}\n\n");
@@ -349,22 +419,44 @@ fn write_message(msg_proto: &DescriptorProto, output: &mut String) {
}
}
if !nested_enums.is_empty() || !nested_msgs.is_empty() {
if !nested_enums.is_empty() || !nested_msgs.is_empty() || !oneofs.is_empty() {
let mod_name = to_snake_case(msg_proto.name().unwrap());
output.push_str(&format!("pub mod {} {{\n", mod_name));
for e_data in nested_enums {
for e_data in &nested_enums {
write_enum(
&EnumDescriptorProto::new(e_data)
.expect("Failed to parse nested EnumDescriptorProto"),
output,
);
}
for m_data in nested_msgs {
for m_data in &nested_msgs {
write_message(
&DescriptorProto::new(m_data).expect("Failed to parse nested DescriptorProto"),
output,
);
}
for (oneof_index, oneof_proto) in oneofs.iter().enumerate() {
let oneof_desc = OneofDescriptorProto::new(*oneof_proto)
.expect("Failed to parse OneofDescriptorProto");
let oneof_name = oneof_desc.name().unwrap();
let pascal_oneof_name = to_pascal_case(oneof_name);
output.push_str(&format!("pub enum {}<'a> {{\n", pascal_oneof_name));
for (field_name, _tag, f_type, f_label, f_oneof_index, _is_map) in &fields_info {
if *f_oneof_index == Some(oneof_index as i32) {
let (rust_type, _) = map_type_to_rust_accessor(*f_type, *f_label, *_is_map);
let safe_field_name = if field_name == "type" {
format!("r#{}", field_name)
} else {
field_name.clone()
};
output.push_str(&format!(" {}({}),\n", safe_field_name, rust_type));
}
}
output.push_str("}\n\n");
}
}
if !nested_enums.is_empty() || !nested_msgs.is_empty() || !oneofs.is_empty() {
output.push_str("}\n\n");
}
}
+2 -3
View File
@@ -71,13 +71,12 @@ fn test_generated_code_builds() {
all_code.push_str(&content);
all_code.push_str("\n");
}
let final_code = all_code.replace("use crate::", "use roto::");
let lib_path = temp_project_dir.join("src/lib.rs");
fs::write(lib_path, final_code).expect("Failed to write generated code to src/lib.rs");
fs::write(lib_path, all_code).expect("Failed to write generated code to src/lib.rs");
// 5. Attempt to build the project
let build_status = Command::new("cargo")
.args(["build"])
.args(["--offline", "build"])
.current_dir(&temp_project_dir)
.status()
.expect("Failed to run cargo build");
+67
View File
@@ -0,0 +1,67 @@
use roto_codegen::google::protobuf::descriptor::FileDescriptorSet;
use std::fs;
use std::process::Command;
#[test]
fn test_map_generated_code_builds() {
// 1. Load FileDescriptorSet from data/test_map.desc
let desc_path = "data/test_map.desc";
let data = fs::read(desc_path).expect("Failed to read test_map.desc");
let set = FileDescriptorSet::new(&data)
.expect("Failed to create FileDescriptorSet from test_map.desc");
let generated_files = roto_codegen::generator::generate_rust_code(&set, None, false);
assert!(
!generated_files.is_empty(),
"Generated code should not be empty"
);
// 2. Setup a temporary Cargo project to verify the code builds
let root = std::env::current_dir().expect("Failed to get current directory");
let temp_project_dir = root.join("test_map_gen_project");
// Clean up previous runs
if temp_project_dir.exists() {
fs::remove_dir_all(&temp_project_dir).expect("Failed to clean up temp project directory");
}
// Create new library project
let status = Command::new("cargo")
.args(["new", "--lib", "test_map_gen_project"])
.current_dir(&root)
.status()
.expect("Failed to run cargo new");
assert!(status.success(), "cargo new failed");
// 3. Configure the project to depend on the current roto crate
let cargo_toml_path = temp_project_dir.join("Cargo.toml");
let cargo_toml_content =
fs::read_to_string(&cargo_toml_path).expect("Failed to read Cargo.toml");
let updated_cargo_toml = format!(
"{}\n\nroto-codegen = {{ path = \"..\" }}\nroto-runtime = {{ path = \"../../runtime\" }}\n\n[workspace]\n",
cargo_toml_content
);
fs::write(cargo_toml_path, updated_cargo_toml).expect("Failed to write Cargo.toml");
// 4. Write the generated code to src/lib.rs
let mut all_code = String::new();
for (_, content) in generated_files {
all_code.push_str(&content);
all_code.push_str("\n");
}
let final_code = all_code.replace("use crate::", "use roto::");
let lib_path = temp_project_dir.join("src/lib.rs");
fs::write(lib_path, final_code).expect("Failed to write generated code to src/lib.rs");
// 5. Attempt to build the project
let build_status = Command::new("cargo")
.args(["--offline", "build"])
.current_dir(&temp_project_dir)
.status()
.expect("Failed to run cargo build");
assert!(
build_status.success(),
"The generated Rust code for test_map.proto failed to build in a standalone project!"
);
}
+22
View File
@@ -0,0 +1,22 @@
use roto_codegen::generator::generate_rust_code;
use roto_codegen::google::protobuf::descriptor::{
DescriptorProto, FieldDescriptorProto, FileDescriptorSet,
};
use std::collections::HashMap;
#[test]
fn test_oneof_generation() {
let mut set = FileDescriptorSet::new(b"").unwrap(); // Simplified for testing
// In a real scenario, we'd build up a FileDescriptorSet from a proto.
// For this unit test, we'll manually construct a DescriptorProto that has a oneof.
// However, generate_rust_code takes a FileDescriptorSet.
// Let's mock a simple setup.
// Since manually constructing FileDescriptorSet is complex, let's instead check if the
// generator logic for oneofs produces the expected strings given a DescriptorProto.
// But the current tests use load_generated_code() which reads from data/request.bin.
// Let's see if we can find a way to test just the write_message function or similar.
}
+72
View File
@@ -0,0 +1,72 @@
use roto_codegen::google::protobuf::descriptor::FileDescriptorSet;
use std::fs;
use std::process::Command;
#[test]
fn test_types_generated_code_builds() {
// 1. Load FileDescriptorSet from data/test_types.desc
let desc_path = "data/test_types.desc";
let data = fs::read(desc_path).expect("Failed to read test_types.desc");
let set = FileDescriptorSet::new(&data)
.expect("Failed to create FileDescriptorSet from test_types.desc");
let generated_files = roto_codegen::generator::generate_rust_code(&set, None, false);
assert!(
!generated_files.is_empty(),
"Generated code should not be empty"
);
// 2. Setup a temporary Cargo project to verify the code builds
let root = std::env::current_dir().expect("Failed to get current directory");
let temp_project_dir = root.join("test_types_gen_project");
// Clean up previous runs
if temp_project_dir.exists() {
fs::remove_dir_all(&temp_project_dir).expect("Failed to clean up temp project directory");
}
// Create new library project
let status = Command::new("cargo")
.args(["new", "--lib", "test_types_gen_project"])
.current_dir(&root)
.status()
.expect("Failed to run cargo new");
assert!(status.success(), "cargo new failed");
// 3. Configure the project to depend on the current roto crate
let cargo_toml_path = temp_project_dir.join("Cargo.toml");
let cargo_toml_content =
fs::read_to_string(&cargo_toml_path).expect("Failed to read Cargo.toml");
let updated_cargo_toml = format!(
"{}\n\nroto-codegen = {{ path = \"..\" }}\nroto-runtime = {{ path = \"../../runtime\" }}\n\n[workspace]\n",
cargo_toml_content
);
fs::write(cargo_toml_path, updated_cargo_toml).expect("Failed to write Cargo.toml");
// 4. Write the generated code to src/lib.rs
let mut all_code = String::new();
for (_, content) in generated_files {
all_code.push_str(&content);
all_code.push_str("\n");
}
// The generated code uses `use crate::{...}`, but it's now in a separate crate.
// Replace `crate` with `roto` to reference the types in the dependency.
// Note: in build_generated_code.rs it does replace("use crate::", "use roto::").
// But here the generated code might not have dependencies since it's a single file.
// However, to be safe and consistent with the template:
let final_code = all_code.replace("use crate::", "use roto::");
let lib_path = temp_project_dir.join("src/lib.rs");
fs::write(lib_path, final_code).expect("Failed to write generated code to src/lib.rs");
// 5. Attempt to build the project
let build_status = Command::new("cargo")
.args(["--offline", "build"])
.current_dir(&temp_project_dir)
.status()
.expect("Failed to run cargo build");
assert!(
build_status.success(),
"The generated Rust code for test_types.proto failed to build in a standalone project!"
);
}
+46 -1
View File
@@ -1,6 +1,33 @@
use std::fmt;
#[derive(Debug, PartialEq, Eq)]
pub struct MapFieldIterator<'a> {
inner: RepeatedFieldIterator<'a>,
}
impl<'a> MapFieldIterator<'a> {
pub fn new(inner: RepeatedFieldIterator<'a>) -> Self {
Self { inner }
}
}
impl<'a> Iterator for MapFieldIterator<'a> {
type Item = Result<(&'a [u8], &'a [u8])>;
fn next(&mut self) -> Option<Self::Item> {
match self.inner.next() {
Some(Ok((value, _wire_type))) => {
let accessor = ProtoAccessor::new(value).ok()?;
let (key_bytes, _) = accessor.get_value(1).ok()?;
let (val_bytes, _) = accessor.get_value(2).ok()?;
Some(Ok((key_bytes, val_bytes)))
}
Some(Err(e)) => Some(Err(e)),
None => None,
}
}
}
#[derive(Debug, PartialEq)]
pub enum RotoError {
UnexpectedEndOfBuffer,
InvalidVarint,
@@ -769,6 +796,24 @@ impl<'a> ProtoBuilder<'a> {
self.append_bytes(raw_bytes)
}
pub fn write_map_entry(
&mut self,
field_number: u32,
key_encoded: &[u8],
value_encoded: &[u8],
) -> Result<()> {
let entry_len = key_encoded.len() + value_encoded.len();
self.write_tag(field_number, WireType::LengthDelimited)?;
let mut len_buf = [0u8; 10];
let len_len = write_varint(entry_len as u64, &mut len_buf)?;
self.append_bytes(&len_buf[..len_len])?;
self.append_bytes(key_encoded)?;
self.append_bytes(value_encoded)?;
Ok(())
}
pub fn finish(self) -> Result<&'a mut [u8]> {
Ok(&mut self.buf[..self.pos])
}