diff --git a/codegen/src/generator.rs b/codegen/src/generator.rs index 00cf6ff..aa4db23 100644 --- a/codegen/src/generator.rs +++ b/codegen/src/generator.rs @@ -33,7 +33,7 @@ pub fn to_snake_case(s: &str) -> String { result } -fn map_type_to_rust_accessor(field_type: i32, label: i32, is_map: bool) -> (String, String) { +fn map_type_to_rust_accessor(field_type: i32, label: i32, is_map: bool) -> (String, String, String) { if label == 3 { // LABEL_REPEATED let iterator_type = if is_map { @@ -44,6 +44,7 @@ fn map_type_to_rust_accessor(field_type: i32, label: i32, is_map: bool) -> (Stri return ( iterator_type.to_string(), "".to_string(), // Not used for repeated fields in the same way + "".to_string(), // Not used for repeated fields ); } @@ -51,37 +52,53 @@ fn map_type_to_rust_accessor(field_type: i32, label: i32, is_map: bool) -> (Stri 9 => ( "&'a str".to_string(), "str::from_utf8(bytes).map_err(|_| crate::RotoError::WireFormatViolation)".to_string(), + "\"\"".to_string(), ), // TYPE_STRING 1 => ( "f64".to_string(), "Ok(f64::from_le_bytes(bytes.try_into().map_err(|_| crate::RotoError::WireFormatViolation)?))".to_string(), + "0.0".to_string(), ), // TYPE_DOUBLE 2 => ( "f32".to_string(), "Ok(f32::from_le_bytes(bytes.try_into().map_err(|_| crate::RotoError::WireFormatViolation)?))".to_string(), + "0.0".to_string(), ), // TYPE_FLOAT 3 | 5 | 15 | 17 => ( "i32".to_string(), "roto_runtime::read_varint(bytes).map(|(v, _)| v as i32).map_err(|_| roto_runtime::RotoError::WireFormatViolation)".to_string(), + "0".to_string(), ), // INT/SINT/SFIXED 32 4 | 6 | 13 => ( "u32".to_string(), "roto_runtime::read_varint(bytes).map(|(v, _)| v as u32).map_err(|_| roto_runtime::RotoError::WireFormatViolation)".to_string(), + "0".to_string(), ), // UINT/FIXED 32 16 | 18 => ( "i64".to_string(), "roto_runtime::read_varint(bytes).map(|(v, _)| v as i64).map_err(|_| roto_runtime::RotoError::WireFormatViolation)".to_string(), + "0".to_string(), ), // SINT/SFIXED 64 7 | 14 => ( "u64".to_string(), "roto_runtime::read_varint(bytes).map(|(v, _)| v as u64).map_err(|_| roto_runtime::RotoError::WireFormatViolation)".to_string(), + "0".to_string(), ), // UINT/FIXED 64 8 => ( "bool".to_string(), "roto_runtime::read_varint(bytes).map(|(v, _)| v != 0).map_err(|_| roto_runtime::RotoError::WireFormatViolation)".to_string(), + "false".to_string(), ), // TYPE_BOOL - 11 | 12 => ("&'a [u8]".to_string(), "Ok(bytes)".to_string()), // MESSAGE/BYTES - _ => ("&'a [u8]".to_string(), "Ok(bytes)".to_string()), + 11 | 12 => ( + "&'a [u8]".to_string(), + "Ok(bytes)".to_string(), + "&[]".to_string(), + ), // MESSAGE/BYTES + _ => ( + "&'a [u8]".to_string(), + "Ok(bytes)".to_string(), + "&[]".to_string(), + ), } } @@ -248,7 +265,7 @@ fn write_message(msg_proto: &DescriptorProto, output: &mut String) { output.push_str(" })\n }\n\n"); for (field_name, tag, f_type, f_label, _oneof_index, is_map) in &fields_info { - let (rust_type, logic) = map_type_to_rust_accessor(*f_type, *f_label, *is_map); + let (rust_type, logic, default_val) = map_type_to_rust_accessor(*f_type, *f_label, *is_map); let safe_name = if field_name == "type" { format!("r#{}", field_name) } else { @@ -290,6 +307,17 @@ fn write_message(msg_proto: &DescriptorProto, output: &mut String) { output.push_str(" let (bytes, _) = self.accessor.get_value_at(offset)?;\n"); output.push_str(&format!(" {}\n", logic)); output.push_str(" }\n\n"); + + output.push_str(&format!( + " pub fn {}_or_default(&self) -> roto_runtime::Result<{}> {{\n", + safe_name, rust_type + )); + output.push_str(&format!( + " self.{}().or(Ok({}))\n", + safe_name, default_val + )); + output.push_str(" }\n\n"); + output.push_str(&format!( " pub fn has_{}(&self) -> bool {{ self.{}_offset.is_some() }}\n\n", field_name, field_name @@ -444,7 +472,7 @@ fn write_message(msg_proto: &DescriptorProto, output: &mut String) { output.push_str(&format!("pub enum {}<'a> {{\n", pascal_oneof_name)); for (field_name, _tag, f_type, f_label, f_oneof_index, _is_map) in &fields_info { if *f_oneof_index == Some(oneof_index as i32) { - let (rust_type, _) = map_type_to_rust_accessor(*f_type, *f_label, *_is_map); + let (rust_type, _, _) = map_type_to_rust_accessor(*f_type, *f_label, *_is_map); let safe_field_name = if field_name == "type" { format!("r#{}", field_name) } else {