diff --git a/src/generator.rs b/src/generator.rs index 561ffc3..7b36d4d 100644 --- a/src/generator.rs +++ b/src/generator.rs @@ -19,6 +19,21 @@ pub fn to_pascal_case(s: &str) -> String { .collect() } +pub fn to_snake_case(s: &str) -> String { + let mut result = String::new(); + for (i, c) in s.chars().enumerate() { + if c.is_uppercase() { + if i > 0 { + result.push('_'); + } + result.push(c.to_ascii_lowercase()); + } else { + result.push(c); + } + } + result +} + fn map_type_to_rust_accessor(field_type: i32, label: i32) -> (String, String) { if label == 3 { // LABEL_REPEATED @@ -31,23 +46,23 @@ fn map_type_to_rust_accessor(field_type: i32, label: i32) -> (String, String) { match field_type { 9 => ( "&'a str".to_string(), - "str::from_utf8(bytes).map_err(|_| RotoError::WireFormatViolation)".to_string(), + "str::from_utf8(bytes).map_err(|_| crate::RotoError::WireFormatViolation)".to_string(), ), // TYPE_STRING 1 => ( "f64".to_string(), - "Ok(f64::from_le_bytes(bytes.try_into().map_err(|_| RotoError::WireFormatViolation)?))".to_string(), + "Ok(f64::from_le_bytes(bytes.try_into().map_err(|_| crate::RotoError::WireFormatViolation)?))".to_string(), ), // TYPE_DOUBLE 2 => ( "f32".to_string(), - "f32::from_le_bytes(bytes.try_into().map_err(|_| RotoError::WireFormatViolation)?)".to_string(), + "f32::from_le_bytes(bytes.try_into().map_err(|_| crate::RotoError::WireFormatViolation)?)".to_string(), ), // TYPE_FLOAT 3 | 5 | 15 | 17 => ( "i32".to_string(), - "crate::read_varint(bytes).map(|(v, _)| v as i32).map_err(|_| RotoError::WireFormatViolation)".to_string(), + "crate::read_varint(bytes).map(|(v, _)| v as i32).map_err(|_| crate::RotoError::WireFormatViolation)".to_string(), ), // INT/SINT/SFIXED 32 4 | 6 | 13 => ( "u32".to_string(), - "crate::read_varint(bytes).map(|(v, _)| v as u32).map_err(|_| RotoError::WireFormatViolation)".to_string(), + "crate::read_varint(bytes).map(|(v, _)| v as u32).map_err(|_| crate::RotoError::WireFormatViolation)".to_string(), ), // UINT/FIXED 32 16 | 18 => ( "i64".to_string(), @@ -66,6 +81,195 @@ fn map_type_to_rust_accessor(field_type: i32, label: i32) -> (String, String) { } } +fn write_enum(enum_proto: &EnumDescriptorProto, output: &mut String) { + let enum_name = to_pascal_case(enum_proto.name().unwrap()); + + output.push_str(&format!( + "#[derive(Debug, Clone, Copy, PartialEq, Eq)]\n#[repr(i32)]\npub enum {} {{\n", + enum_name + )); + + let mut values = enum_proto.enum_value(); + let mut variant_count = 0; + let mut zero_variant_name = None; + while let Some(val_res) = values.next() { + let (val_data, _) = val_res.expect("Failed to iterate enum"); + let accessor = ProtoAccessor::new(val_data).expect("Failed to parse EnumValueDescriptorProto"); + let (name_bytes, _) = accessor.get_value(1).expect("Enum value name missing"); + let name = str::from_utf8(name_bytes).expect("Enum value name invalid utf8"); + let (num_bytes, _) = accessor.get_value(2).expect("Enum value number missing"); + let (num, _) = crate::read_varint(num_bytes).expect("Enum value number invalid varint"); + + let pascal_name = to_pascal_case(name); + if num == 0 { + zero_variant_name = Some(pascal_name.clone()); + } + output.push_str(&format!(" {} = {},\n", pascal_name, num)); + variant_count += 1; + } + + if zero_variant_name.is_none() { + output.push_str(" Unknown = 0,\n"); + zero_variant_name = Some("Unknown".to_string()); + } + + output.push_str("}\n\n"); + + output.push_str(&format!( + "impl {} {{\n pub fn from_i32(value: i32) -> Self {{\n match value {{\n", + enum_name + )); + + let mut values = enum_proto.enum_value(); + while let Some(val_res) = values.next() { + let (val_data, _) = val_res.expect("Failed to read enum value"); + let accessor = ProtoAccessor::new(val_data).expect("Failed to parse EnumValueDescriptorProto"); + let (name_bytes, _) = accessor.get_value(1).expect("Enum value name missing"); + let name = str::from_utf8(name_bytes).expect("Enum value name invalid utf8"); + let (num_bytes, _) = accessor.get_value(2).expect("Enum value number missing"); + let (num, _) = crate::read_varint(num_bytes).expect("Enum value number invalid varint"); + + output.push_str(&format!(" {} => {}::{},\n", num, enum_name, to_pascal_case(name))); + } + + output.push_str(&format!(" _ => {}::{},\n", enum_name, zero_variant_name.as_ref().unwrap())); + output.push_str(" }\n }\n}\n\n"); +} + +fn write_message(msg_proto: &DescriptorProto, output: &mut String) { + let msg_name = to_pascal_case(msg_proto.name().unwrap()); + + output.push_str(&format!( + "pub struct {}<'a> {{\n accessor: crate::ProtoAccessor<'a>,\n", + msg_name + )); + + let mut fields_info = Vec::new(); + for field_res in msg_proto.field() { + let (field_data, _) = field_res.expect("Failed to iterate field"); + let field_proto = FieldDescriptorProto::new(field_data).expect("Failed to parse FieldDescriptorProto"); + let field_name = field_proto.name().unwrap(); + + let tag = field_proto.number().unwrap(); + let f_type = field_proto.field_type().unwrap() as i32; + let f_label = field_proto.label().unwrap() as i32; + + fields_info.push((field_name.to_string(), tag, f_type, f_label)); + + if f_label == 3 { + output.push_str(&format!(" {}_start: Option,\n", field_name)); + output.push_str(&format!(" {}_end: Option,\n", field_name)); + } else { + output.push_str(&format!(" {}_offset: Option,\n", field_name)); + } + } + output.push_str("}\n\n"); + + output.push_str(&format!("impl<'a> {}<'a> {{\n", msg_name)); + output.push_str(" pub fn new(data: &'a [u8]) -> crate::Result {\n"); + output.push_str(" let accessor = crate::ProtoAccessor::new(data)?;\n"); + + for (name, _, _, label) in &fields_info { + if *label == 3 { + output.push_str(&format!(" let mut {}_start = None;\n", name)); + output.push_str(&format!(" let mut {}_end = None;\n", name)); + } else { + output.push_str(&format!(" let mut {}_offset = None;\n", name)); + } + } + + output.push_str(" for item in accessor.fields() {\n"); + output.push_str(" let (offset, tag, _) = item?;\n"); + + for (name, tag, _, label) in &fields_info { + if *label == 3 { + output.push_str(&format!(" if tag.field_number == {} {{\n", tag)); + output.push_str(&format!(" if {}_start.is_none() {{ {}_start = Some(offset); }}\n", name, name)); + output.push_str(&format!(" {}_end = Some(offset);\n", name)); + output.push_str(" }\n"); + } else { + output.push_str(&format!(" if tag.field_number == {} {{ {}_offset = Some(offset); }}\n", tag, name)); + } + } + output.push_str(" }\n\n"); + + output.push_str(" Ok(Self {\n"); + for (name, _, _, label) in &fields_info { + if *label == 3 { + output.push_str(&format!("{}_start, {}_end,\n", name, name)); + } else { + output.push_str(&format!("{}_offset,\n", name)); + } + } + output.push_str(" })\n }\n\n"); + + for (field_name, tag, f_type, f_label) in fields_info { + let (rust_type, logic) = map_type_to_rust_accessor(f_type, f_label); + let safe_name = if field_name == "type" { format!("r#{}", field_name) } else { field_name.clone() }; + + if f_label == 3 { + output.push_str(&format!(" pub fn {}(&self) -> {} {{\n", safe_name, rust_type)); + output.push_str(&format!(" match (self.{}_start, self.{}_end) {{\n", field_name, field_name)); + output.push_str(&format!(" (Some(start), Some(end)) => self.accessor.iter_repeated_range({}, start, end),\n", tag)); + output.push_str(&format!(" _ => self.accessor.iter_repeated({}),\n", tag)); + output.push_str(" }\n }\n\n"); + } else { + output.push_str(&format!(" pub fn {}(&self) -> crate::Result<{}> {{\n", safe_name, rust_type)); + output.push_str(&format!(" let offset = self.{}_offset.ok_or(crate::RotoError::FieldNotFound)?;\n", field_name)); + output.push_str(" let (bytes, _) = self.accessor.get_value_at(offset)?;\n"); + output.push_str(&format!(" {}\n", logic)); + output.push_str(" }\n\n"); + } + } + output.push_str("}\n\n"); + + // Builder + output.push_str(&format!( + "pub struct {}Builder<'b> {{\n builder: crate::ProtoBuilder<'b>,\n}}\n\nimpl<'b> {}Builder<'b> {{\n", + msg_name, msg_name + )); + output.push_str(&format!( + " pub fn builder(buf: &mut [u8]) -> {}Builder<'_> {{\n {}Builder {{\n builder: crate::ProtoBuilder::new(buf),\n }}\n }}\n\n", + msg_name, msg_name + )); + + for field_res in msg_proto.field() { + let (field_data, _) = field_res.expect("Failed to iterate field"); + let field_proto = FieldDescriptorProto::new(field_data).expect("Failed to parse FieldDescriptorProto"); + let field_name = field_proto.name().unwrap(); + let safe_name = if field_name == "type" { format!("r#{}", field_name) } else { field_name.to_string() }; + let tag = field_proto.number().unwrap(); + let f_type = field_proto.field_type().unwrap() as i32; + let (rust_type, method) = map_type_to_rust_builder(f_type); + output.push_str(&format!( + " pub fn {}(mut self, value: {}) -> crate::Result {{\n self.builder.{}({}, value)?;\n Ok(self)\n }}\n\n", + safe_name, rust_type, method, tag + )); + } + output.push_str(&format!(" pub fn finish(self) -> crate::Result<&'b mut [u8]> {{\n self.builder.finish()\n }}\n}}\n\n")); + + let mut nested_enums = Vec::new(); + for e_res in msg_proto.enum_type() { + if let Ok((e, _)) = e_res { nested_enums.push(e); } + } + let mut nested_msgs = Vec::new(); + for m_res in msg_proto.nested_type() { + if let Ok((m, _)) = m_res { nested_msgs.push(m); } + } + + if !nested_enums.is_empty() || !nested_msgs.is_empty() { + let mod_name = to_snake_case(msg_proto.name().unwrap()); + output.push_str(&format!("pub mod {} {{\n", mod_name)); + for e_data in nested_enums { + write_enum(&EnumDescriptorProto::new(e_data).expect("Failed to parse nested EnumDescriptorProto"), output); + } + for m_data in nested_msgs { + write_message(&DescriptorProto::new(m_data).expect("Failed to parse nested DescriptorProto"), output); + } + output.push_str("}\n\n"); + } +} + fn map_type_to_rust_builder(field_type: i32) -> (String, String) { match field_type { 9 => ("&str".to_string(), "write_string".to_string()), @@ -113,211 +317,13 @@ pub fn generate_rust_code( // Enums for enum_res in file_proto.enum_type() { let (enum_data, _) = enum_res.expect("Failed to iterate enum"); - let enum_proto = EnumDescriptorProto::new(enum_data).expect("Failed to parse EnumDescriptorProto"); - let enum_name = to_pascal_case(enum_proto.name().unwrap()); - - output.push_str(&format!( - "#[derive(Debug, Clone, Copy, PartialEq, Eq)]\n#[repr(i32)]\npub enum {} {{\n", - enum_name - )); - - let mut values = enum_proto.value(); - let mut variant_count = 0; - let mut zero_variant_name = None; - while let Some(val_res) = values.next() { - let (val_data, _) = val_res.expect("Failed to iterate enum"); - let accessor = ProtoAccessor::new(val_data).expect("Failed to parse EnumValueDescriptorProto"); - let (name_bytes, _) = accessor.get_value(1).expect("Enum value name missing"); - let name = str::from_utf8(name_bytes).expect("Enum value name invalid utf8"); - let (num_bytes, _) = accessor.get_value(2).expect("Enum value number missing"); - let (num, _) = crate::read_varint(num_bytes).expect("Enum value number invalid varint"); - - let pascal_name = to_pascal_case(name); - if num == 0 { - zero_variant_name = Some(pascal_name.clone()); - } - output.push_str(&format!(" {} = {},\n", pascal_name, num)); - variant_count += 1; - } - - if zero_variant_name.is_none() { - output.push_str(" Unknown = 0,\n"); - zero_variant_name = Some("Unknown".to_string()); - } - - output.push_str("}\n\n"); - - output.push_str(&format!( - "impl {} {{\n pub fn from_i32(value: i32) -> Self {{\n match value {{\n", - enum_name - )); - - let mut values = enum_proto.value(); - while let Some(val_res) = values.next() { - let (val_data, _) = val_res.expect("Failed to read enum value"); - let accessor = ProtoAccessor::new(val_data).expect("Failed to parse EnumValueDescriptorProto"); - let (name_bytes, _) = accessor.get_value(1).expect("Enum value name missing"); - let name = str::from_utf8(name_bytes).expect("Enum value name invalid utf8"); - let (num_bytes, _) = accessor.get_value(2).expect("Enum value number missing"); - let (num, _) = crate::read_varint(num_bytes).expect("Enum value number invalid varint"); - - output.push_str(&format!(" {} => {}::{},\n", num, enum_name, to_pascal_case(name))); - } - - output.push_str(&format!(" _ => {}::{},\n", enum_name, zero_variant_name.as_ref().unwrap())); - output.push_str(" }\n }\n}\n\n"); + write_enum(&EnumDescriptorProto::new(enum_data).expect("Failed to parse EnumDescriptorProto"), &mut output); } // Messages for msg_res in file_proto.message_type() { let (msg_data, _) = msg_res.expect("Failed to iterate message"); - let msg_proto = DescriptorProto::new(msg_data).expect("Failed to parse DescriptorProto"); - let msg_name = to_pascal_case(msg_proto.name().unwrap()); - - // Accessor Struct Definition - output.push_str(&format!( - "pub struct {}<'a> {{\n accessor: ProtoAccessor<'a>,\n", - msg_name - )); - - let mut fields_info = Vec::new(); - for field_res in msg_proto.field() { - let (field_data, _) = field_res.expect("Failed to iterate field"); - let field_proto = FieldDescriptorProto::new(field_data).expect("Failed to parse FieldDescriptorProto"); - let field_name = field_proto.name().unwrap(); - - let tag = field_proto.number().unwrap(); - let f_type = field_proto.r#type().unwrap() as i32; - let f_label = field_proto.label().unwrap() as i32; - - fields_info.push((field_name.to_string(), tag, f_type, f_label)); - - if f_label == 3 { - output.push_str(&format!(" {}_start: Option,\n", field_name)); - output.push_str(&format!(" {}_end: Option,\n", field_name)); - } else { - output.push_str(&format!(" {}_offset: Option,\n", field_name)); - } - } - output.push_str("}\n\n"); - - // Accessor Implementation - output.push_str(&format!("impl<'a> {}<'a> {{\n", msg_name)); - - // new() method - output.push_str(" pub fn new(data: &'a [u8]) -> Result {\n"); - output.push_str(" let accessor = ProtoAccessor::new(data)?;\n"); - - for (name, _, _, label) in &fields_info { - if *label == 3 { - output.push_str(&format!(" let mut {}_start = None;\n", name)); - output.push_str(&format!(" let mut {}_end = None;\n", name)); - } else { - output.push_str(&format!(" let mut {}_offset = None;\n", name)); - } - } - - output.push_str(" for item in accessor.fields() {\n"); - output.push_str(" let (offset, tag, _) = item?;\n"); - - for (name, tag, _, label) in &fields_info { - if *label == 3 { - output.push_str(&format!( - " if tag.field_number == {} {{\n", tag - )); - output.push_str(&format!(" if {}_start.is_none() {{ {}_start = Some(offset); }}\n", name, name)); - output.push_str(&format!(" {}_end = Some(offset);\n", name)); - output.push_str(" }\n"); - } else { - output.push_str(&format!( - " if tag.field_number == {} {{ {}_offset = Some(offset); }}\n", tag, name - )); - } - } - output.push_str(" }\n\n"); - - output.push_str(" Ok(Self {\n"); - output.push_str(" accessor,\n"); - for (name, _, _, label) in &fields_info { - if *label == 3 { - output.push_str(&format!("{}_start, {}_end,\n", name, name)); - } else { - output.push_str(&format!("{}_offset,\n", name)); - } - } - output.push_str(" })\n }\n\n"); - - // Field Accessors - for (field_name, tag, f_type, f_label) in fields_info { - let (rust_type, logic) = map_type_to_rust_accessor(f_type, f_label); - let safe_name = if field_name == "type" { format!("r#{}", field_name) } else { field_name.clone() }; - - if f_label == 3 { - output.push_str(&format!( - " pub fn {}(&self) -> {} {{\n", - safe_name, rust_type - )); - output.push_str(&format!( - " match (self.{}_start, self.{}_end) {{\n", - field_name, field_name - )); - output.push_str(&format!( - " (Some(start), Some(end)) => self.accessor.iter_repeated_range({}, start, end),\n", - tag - )); - output.push_str(&format!( - " _ => self.accessor.iter_repeated({}),\n", - tag - )); - output.push_str(" }\n }\n\n"); - } else { - output.push_str(&format!( - " pub fn {}(&self) -> Result<{}> {{\n", - safe_name, rust_type - )); - output.push_str(&format!( - " let offset = self.{}_offset.ok_or(RotoError::FieldNotFound)?;\n", - field_name - )); - output.push_str(&format!( - " let (bytes, _) = self.accessor.get_value_at(offset)?;\n", - )); - output.push_str(&format!(" {}\n", logic)); - output.push_str(" }\n\n"); - } - } - output.push_str("}\n\n"); - - // Builder - output.push_str(&format!( - "pub struct {}Builder<'b> {{\n builder: ProtoBuilder<'b>,\n}}\n\nimpl<'b> {}Builder<'b> {{\n", - msg_name, msg_name - )); - output.push_str(&format!( - " pub fn builder(buf: &mut [u8]) -> {}Builder<'_> {{\n {}Builder {{\n builder: ProtoBuilder::new(buf),\n }}\n }}\n\n", - msg_name, msg_name - )); - - for field_res in msg_proto.field() { - let (field_data, _) = field_res.expect("Failed to iterate field"); - let field_proto = FieldDescriptorProto::new(field_data).expect("Failed to parse FieldDescriptorProto"); - let field_name = field_proto.name().unwrap(); - let safe_name = if field_name == "type" { format!("r#{}", field_name) } else { field_name.to_string() }; - - let tag = field_proto.number().unwrap(); - let f_type = field_proto.r#type().unwrap() as i32; - - let (rust_type, method) = map_type_to_rust_builder(f_type); - - output.push_str(&format!( - " pub fn {}(mut self, value: {}) -> Result {{\n self.builder.{}({}, value)?;\n Ok(self)\n }}\n\n", - safe_name, rust_type, method, tag - )); - } - - output.push_str(&format!( - " pub fn finish(self) -> Result<&'b mut [u8]> {{\n self.builder.finish()\n }}\n}}\n\n" - )); + write_message(&DescriptorProto::new(msg_data).expect("Failed to parse DescriptorProto"), &mut output); } generated_files.push((rust_file_name, output)); } diff --git a/tests/build_generated_code.rs b/tests/build_generated_code.rs index fb25bdb..3797c4c 100644 --- a/tests/build_generated_code.rs +++ b/tests/build_generated_code.rs @@ -30,8 +30,8 @@ fn test_generated_code_builds() { let set = roto::proto_gen::google::protobuf::descriptor::FileDescriptorSet::new(&set_buf) .expect("Failed to create FileDescriptorSet"); - let generated_code = roto::generator::generate_rust_code(&set); - assert!(!generated_code.is_empty(), "Generated code should not be empty"); + let generated_files = roto::generator::generate_rust_code(&set, None, false); + assert!(!generated_files.is_empty(), "Generated code should not be empty"); // 2. Setup a temporary Cargo project to verify the code builds let root = std::env::current_dir().expect("Failed to get current directory"); @@ -62,7 +62,12 @@ fn test_generated_code_builds() { // 4. Write the generated code to src/lib.rs // The generated code uses `use crate::{...}`, but it's now in a separate crate. // Replace `crate` with `roto` to reference the types in the dependency. - let final_code = generated_code.replace("use crate::", "use roto::"); + let mut all_code = String::new(); + for (_, content) in generated_files { + all_code.push_str(&content); + all_code.push_str("\n"); + } + let final_code = all_code.replace("use crate::", "use roto::"); let lib_path = temp_project_dir.join("src/lib.rs"); fs::write(lib_path, final_code).expect("Failed to write generated code to src/lib.rs");