use crate::ProtoAccessor; use crate::google::protobuf::descriptor::{ DescriptorProto, EnumDescriptorProto, FieldDescriptorProto, FileDescriptorProto, FileDescriptorSet, }; use std::collections::{HashMap, HashSet}; use std::str; pub fn to_pascal_case(s: &str) -> String { s.split('_') .map(|word| { let mut chars = word.chars(); match chars.next() { None => String::new(), Some(f) => f.to_uppercase().collect::() + chars.as_str(), } }) .collect() } pub fn to_snake_case(s: &str) -> String { let mut result = String::new(); for (i, c) in s.chars().enumerate() { if c.is_uppercase() { if i > 0 { result.push('_'); } result.push(c.to_ascii_lowercase()); } else { result.push(c); } } result } fn map_type_to_rust_accessor(field_type: i32, label: i32) -> (String, String) { if label == 3 { // LABEL_REPEATED return ( "crate::RepeatedFieldIterator<'a>".to_string(), "".to_string(), // Not used for repeated fields in the same way ); } match field_type { 9 => ( "&'a str".to_string(), "str::from_utf8(bytes).map_err(|_| crate::RotoError::WireFormatViolation)".to_string(), ), // TYPE_STRING 1 => ( "f64".to_string(), "Ok(f64::from_le_bytes(bytes.try_into().map_err(|_| crate::RotoError::WireFormatViolation)?))".to_string(), ), // TYPE_DOUBLE 2 => ( "f32".to_string(), "f32::from_le_bytes(bytes.try_into().map_err(|_| crate::RotoError::WireFormatViolation)?)".to_string(), ), // TYPE_FLOAT 3 | 5 | 15 | 17 => ( "i32".to_string(), "crate::read_varint(bytes).map(|(v, _)| v as i32).map_err(|_| crate::RotoError::WireFormatViolation)".to_string(), ), // INT/SINT/SFIXED 32 4 | 6 | 13 => ( "u32".to_string(), "crate::read_varint(bytes).map(|(v, _)| v as u32).map_err(|_| crate::RotoError::WireFormatViolation)".to_string(), ), // UINT/FIXED 32 16 | 18 => ( "i64".to_string(), "crate::read_varint(bytes).map(|(v, _)| v as i64).map_err(|_| crate::RotoError::WireFormatViolation)".to_string(), ), // SINT/SFIXED 64 7 | 14 => ( "u64".to_string(), "crate::read_varint(bytes).map(|(v, _)| v as u64).map_err(|_| crate::RotoError::WireFormatViolation)".to_string(), ), // UINT/FIXED 64 8 => ( "bool".to_string(), "crate::read_varint(bytes).map(|(v, _)| v != 0).map_err(|_| crate::RotoError::WireFormatViolation)".to_string(), ), // TYPE_BOOL 11 | 12 => ("&'a [u8]".to_string(), "Ok(bytes)".to_string()), // MESSAGE/BYTES _ => ("&'a [u8]".to_string(), "Ok(bytes)".to_string()), } } fn write_enum(enum_proto: &EnumDescriptorProto, output: &mut String) { let enum_name = to_pascal_case(enum_proto.name().unwrap()); output.push_str(&format!( "#[derive(Debug, Clone, Copy, PartialEq, Eq)]\n#[repr(i32)]\npub enum {} {{\n", enum_name )); let mut values = enum_proto.value(); let mut zero_variant_name = None; while let Some(val_res) = values.next() { let (val_data, _) = val_res.expect("Failed to iterate enum"); let accessor = ProtoAccessor::new(val_data).expect("Failed to parse EnumValueDescriptorProto"); let (name_bytes, _) = accessor.get_value(1).expect("Enum value name missing"); let name = str::from_utf8(name_bytes).expect("Enum value name invalid utf8"); let (num_bytes, _) = accessor.get_value(2).expect("Enum value number missing"); let (num, _) = crate::read_varint(num_bytes).expect("Enum value number invalid varint"); let pascal_name = to_pascal_case(name); if num == 0 { zero_variant_name = Some(pascal_name.clone()); } output.push_str(&format!(" {} = {},\n", pascal_name, num)); } if zero_variant_name.is_none() { output.push_str(" Unknown = 0,\n"); zero_variant_name = Some("Unknown".to_string()); } output.push_str("}\n\n"); output.push_str(&format!( "impl {} {{\n pub fn from_i32(value: i32) -> Self {{\n match value {{\n", enum_name )); let mut values = enum_proto.value(); while let Some(val_res) = values.next() { let (val_data, _) = val_res.expect("Failed to read enum value"); let accessor = ProtoAccessor::new(val_data).expect("Failed to parse EnumValueDescriptorProto"); let (name_bytes, _) = accessor.get_value(1).expect("Enum value name missing"); let name = str::from_utf8(name_bytes).expect("Enum value name invalid utf8"); let (num_bytes, _) = accessor.get_value(2).expect("Enum value number missing"); let (num, _) = crate::read_varint(num_bytes).expect("Enum value number invalid varint"); output.push_str(&format!( " {} => {}::{},\n", num, enum_name, to_pascal_case(name) )); } output.push_str(&format!( " _ => {}::{},\n", enum_name, zero_variant_name.as_ref().unwrap() )); output.push_str(" }\n }\n}\n\n"); } fn write_message(msg_proto: &DescriptorProto, output: &mut String) { let msg_name = to_pascal_case(msg_proto.name().unwrap()); let mut fields_info = Vec::new(); for field_res in msg_proto.field() { let (field_data, _) = field_res.expect("Failed to iterate field"); let field_proto = FieldDescriptorProto::new(field_data).expect("Failed to parse FieldDescriptorProto"); let field_name = field_proto.name().unwrap(); let tag = field_proto.number().unwrap(); let f_type = field_proto.r#type().unwrap() as i32; let f_label = field_proto.label().unwrap() as i32; fields_info.push((field_name.to_string(), tag, f_type, f_label)); } output.push_str(&format!("pub struct {}<'a> {{\n", msg_name)); output.push_str(" accessor: crate::ProtoAccessor<'a>,\n"); for (field_name, _tag, _f_type, f_label) in &fields_info { if *f_label == 3 { output.push_str(&format!(" {}_start: Option,\n", field_name)); output.push_str(&format!(" {}_end: Option,\n", field_name)); } else { output.push_str(&format!(" {}_offset: Option,\n", field_name)); } } output.push_str("}\n\n"); output.push_str(&format!("impl<'a> {}<'a> {{\n", msg_name)); output.push_str(" pub fn new(data: &'a [u8]) -> crate::Result {\n"); output.push_str(" let accessor = crate::ProtoAccessor::new(data)?;\n"); if !fields_info.is_empty() { for (name, _, _, label) in &fields_info { if *label == 3 { output.push_str(&format!(" let mut {}_start = None;\n", name)); output.push_str(&format!(" let mut {}_end = None;\n", name)); } else { output.push_str(&format!(" let mut {}_offset = None;\n", name)); } } output.push_str(" for item in accessor.fields() {\n"); output.push_str(" let (offset, tag, _) = item?;\n"); for (name, tag, _, label) in &fields_info { if *label == 3 { output.push_str(&format!(" if tag.field_number == {} {{\n", tag)); output.push_str(&format!( " if {}_start.is_none() {{ {}_start = Some(offset); }}\n", name, name )); output.push_str(&format!(" {}_end = Some(offset);\n", name)); output.push_str(" }\n"); } else { output.push_str(&format!( " if tag.field_number == {} {{ {}_offset = Some(offset); }}\n", tag, name )); } } output.push_str(" }\n\n"); } output.push_str(" Ok(Self {\n"); output.push_str(" accessor,\n"); for (name, _, _, label) in &fields_info { if *label == 3 { output.push_str(&format!("{}_start, {}_end,\n", name, name)); } else { output.push_str(&format!("{}_offset,\n", name)); } } output.push_str(" })\n }\n\n"); for (field_name, tag, f_type, f_label) in fields_info { let (rust_type, logic) = map_type_to_rust_accessor(f_type, f_label); let safe_name = if field_name == "type" { format!("r#{}", field_name) } else { field_name.clone() }; if f_label == 3 { output.push_str(&format!( " pub fn {}(&self) -> {} {{\n", safe_name, rust_type )); output.push_str(&format!( " match (self.{}_start, self.{}_end) {{\n", field_name, field_name )); output.push_str(&format!(" (Some(start), Some(end)) => self.accessor.iter_repeated_range({}, start, end),\n", tag)); output.push_str(&format!( " _ => self.accessor.iter_repeated({}),\n", tag )); output.push_str(" }\n }\n\n"); } else { output.push_str(&format!( " pub fn {}(&self) -> crate::Result<{}> {{\n", safe_name, rust_type )); output.push_str(&format!( " let offset = self.{}_offset.ok_or(crate::RotoError::FieldNotFound)?;\n", field_name )); output.push_str(" let (bytes, _) = self.accessor.get_value_at(offset)?;\n"); output.push_str(&format!(" {}\n", logic)); output.push_str(" }\n\n"); } } output.push_str("}\n\n"); // Builder output.push_str(&format!( "pub struct {}Builder<'b> {{\n builder: crate::ProtoBuilder<'b>,\n}}\n\nimpl<'b> {}Builder<'b> {{\n", msg_name, msg_name )); output.push_str(&format!( " pub fn builder(buf: &mut [u8]) -> {}Builder<'_> {{\n {}Builder {{\n builder: crate::ProtoBuilder::new(buf),\n }}\n }}\n\n", msg_name, msg_name )); for field_res in msg_proto.field() { let (field_data, _) = field_res.expect("Failed to iterate field"); let field_proto = FieldDescriptorProto::new(field_data).expect("Failed to parse FieldDescriptorProto"); let field_name = field_proto.name().unwrap(); let safe_name = if field_name == "type" { format!("r#{}", field_name) } else { field_name.to_string() }; let tag = field_proto.number().unwrap(); let f_type = field_proto.r#type().unwrap() as i32; let (rust_type, method) = map_type_to_rust_builder(f_type); output.push_str(&format!( " pub fn {}(mut self, value: {}) -> crate::Result {{\n self.builder.{}({}, value)?;\n Ok(self)\n }}\n\n", safe_name, rust_type, method, tag )); } output.push_str(&format!(" pub fn finish(self) -> crate::Result<&'b mut [u8]> {{\n self.builder.finish()\n }}\n}}\n\n")); let mut nested_enums = Vec::new(); for e_res in msg_proto.enum_type() { if let Ok((e, _)) = e_res { nested_enums.push(e); } } let mut nested_msgs = Vec::new(); for m_res in msg_proto.nested_type() { if let Ok((m, _)) = m_res { nested_msgs.push(m); } } if !nested_enums.is_empty() || !nested_msgs.is_empty() { let mod_name = to_snake_case(msg_proto.name().unwrap()); output.push_str(&format!("pub mod {} {{\n", mod_name)); for e_data in nested_enums { write_enum( &EnumDescriptorProto::new(e_data) .expect("Failed to parse nested EnumDescriptorProto"), output, ); } for m_data in nested_msgs { write_message( &DescriptorProto::new(m_data).expect("Failed to parse nested DescriptorProto"), output, ); } output.push_str("}\n\n"); } } fn map_type_to_rust_builder(field_type: i32) -> (String, String) { match field_type { 9 => ("&str".to_string(), "write_string".to_string()), 5 | 17 => ("i32".to_string(), "write_int32".to_string()), 3 | 4 | 8 | 13 | 14 | 18 => ("u64".to_string(), "write_varint".to_string()), 7 | 15 => ("u32".to_string(), "write_fixed32".to_string()), 6 | 16 => ("u64".to_string(), "write_fixed64".to_string()), 11 | 12 => ("&[u8]".to_string(), "write_bytes".to_string()), _ => ("&[u8]".to_string(), "write_bytes".to_string()), } } pub fn generate_rust_code( set: &FileDescriptorSet, files_to_generate: Option<&[String]>, generate_mod_files: bool, ) -> Vec<(String, String)> { let mut generated_files = Vec::new(); for file_res in set.file() { let (file_data, _) = file_res.expect("Failed to iterate file"); let file_proto = FileDescriptorProto::new(file_data).expect("Failed to parse FileDescriptorProto"); let proto_name = file_proto.name().expect("File proto name missing"); if let Some(filter) = files_to_generate { if !filter.contains(&proto_name.to_string()) { continue; } } let rust_file_name = format!("{}.rs", proto_name.replace(".proto", "")); let mut output = String::new(); output.push_str("// @generated by protoc-gen-roto — do not edit\n\n"); output.push_str("use crate::{ProtoAccessor, ProtoBuilder, Result, RotoError, read_varint, RepeatedFieldIterator};\n"); output.push_str("use std::str;\n\n"); for dep_res in file_proto.dependency() { let (dep_data, _) = dep_res.expect("Failed to iterate dependency"); let dep_name = str::from_utf8(dep_data).expect("Dependency name invalid utf8"); let dep_mod_path = dep_name.replace(".proto", "").replace('/', "::"); output.push_str(&format!("use crate::{};\n", dep_mod_path)); } output.push_str("\n"); // Enums for enum_res in file_proto.enum_type() { let (enum_data, _) = enum_res.expect("Failed to iterate enum"); write_enum( &EnumDescriptorProto::new(enum_data).expect("Failed to parse EnumDescriptorProto"), &mut output, ); } // Messages for msg_res in file_proto.message_type() { let (msg_data, _) = msg_res.expect("Failed to iterate message"); write_message( &DescriptorProto::new(msg_data).expect("Failed to parse DescriptorProto"), &mut output, ); } generated_files.push((rust_file_name, output)); } if !generate_mod_files { return generated_files; } let mut all_paths: Vec = generated_files.iter().map(|(p, _)| p.clone()).collect(); all_paths.sort(); let mut mod_files: HashMap> = HashMap::new(); for path in &all_paths { let parts: Vec<&str> = path.split('/').collect(); let mut current_dir = String::new(); for i in 0..parts.len() - 1 { if !current_dir.is_empty() { current_dir.push('/'); } current_dir.push_str(parts[i]); let mod_path = format!("{}/mod.rs", current_dir); let sub_mod = parts[i + 1].replace(".rs", ""); mod_files.entry(mod_path).or_default().insert(sub_mod); } } let mut root_mods = HashSet::new(); for path in &all_paths { let parts: Vec<&str> = path.split('/').collect(); root_mods.insert(parts[0].replace(".rs", "")); } let mut root_mod_content = String::new(); root_mod_content.push_str("// @generated by protoc-gen-roto — do not edit\n\n"); let mut sorted_root_mods: Vec<_> = root_mods.into_iter().collect(); sorted_root_mods.sort(); for m in sorted_root_mods { root_mod_content.push_str(&format!("pub mod {};\n", m)); } generated_files.push(("mod.rs".to_string(), root_mod_content)); for (mod_path, sub_mods) in mod_files { let mut content = String::new(); content.push_str("// @generated by protoc-gen-roto — do not edit\n\n"); let mut sorted_subs: Vec<_> = sub_mods.into_iter().collect(); sorted_subs.sort(); for sub in sorted_subs { content.push_str(&format!("pub mod {};\n", sub)); } generated_files.push((mod_path, content)); } generated_files }