Add _or_default accessors to generated messages

Update map_type_to_rust_accessor to provide default values for each
type, which are used to generate helper methods that return a default
value when a field is missing.
This commit is contained in:
2026-05-08 23:20:40 -07:00
parent 7e368feddf
commit 6e045fd808
+33 -5
View File
@@ -33,7 +33,7 @@ pub fn to_snake_case(s: &str) -> String {
result result
} }
fn map_type_to_rust_accessor(field_type: i32, label: i32, is_map: bool) -> (String, String) { fn map_type_to_rust_accessor(field_type: i32, label: i32, is_map: bool) -> (String, String, String) {
if label == 3 { if label == 3 {
// LABEL_REPEATED // LABEL_REPEATED
let iterator_type = if is_map { let iterator_type = if is_map {
@@ -44,6 +44,7 @@ fn map_type_to_rust_accessor(field_type: i32, label: i32, is_map: bool) -> (Stri
return ( return (
iterator_type.to_string(), iterator_type.to_string(),
"".to_string(), // Not used for repeated fields in the same way "".to_string(), // Not used for repeated fields in the same way
"".to_string(), // Not used for repeated fields
); );
} }
@@ -51,37 +52,53 @@ fn map_type_to_rust_accessor(field_type: i32, label: i32, is_map: bool) -> (Stri
9 => ( 9 => (
"&'a str".to_string(), "&'a str".to_string(),
"str::from_utf8(bytes).map_err(|_| crate::RotoError::WireFormatViolation)".to_string(), "str::from_utf8(bytes).map_err(|_| crate::RotoError::WireFormatViolation)".to_string(),
"\"\"".to_string(),
), // TYPE_STRING ), // TYPE_STRING
1 => ( 1 => (
"f64".to_string(), "f64".to_string(),
"Ok(f64::from_le_bytes(bytes.try_into().map_err(|_| crate::RotoError::WireFormatViolation)?))".to_string(), "Ok(f64::from_le_bytes(bytes.try_into().map_err(|_| crate::RotoError::WireFormatViolation)?))".to_string(),
"0.0".to_string(),
), // TYPE_DOUBLE ), // TYPE_DOUBLE
2 => ( 2 => (
"f32".to_string(), "f32".to_string(),
"Ok(f32::from_le_bytes(bytes.try_into().map_err(|_| crate::RotoError::WireFormatViolation)?))".to_string(), "Ok(f32::from_le_bytes(bytes.try_into().map_err(|_| crate::RotoError::WireFormatViolation)?))".to_string(),
"0.0".to_string(),
), // TYPE_FLOAT ), // TYPE_FLOAT
3 | 5 | 15 | 17 => ( 3 | 5 | 15 | 17 => (
"i32".to_string(), "i32".to_string(),
"roto_runtime::read_varint(bytes).map(|(v, _)| v as i32).map_err(|_| roto_runtime::RotoError::WireFormatViolation)".to_string(), "roto_runtime::read_varint(bytes).map(|(v, _)| v as i32).map_err(|_| roto_runtime::RotoError::WireFormatViolation)".to_string(),
"0".to_string(),
), // INT/SINT/SFIXED 32 ), // INT/SINT/SFIXED 32
4 | 6 | 13 => ( 4 | 6 | 13 => (
"u32".to_string(), "u32".to_string(),
"roto_runtime::read_varint(bytes).map(|(v, _)| v as u32).map_err(|_| roto_runtime::RotoError::WireFormatViolation)".to_string(), "roto_runtime::read_varint(bytes).map(|(v, _)| v as u32).map_err(|_| roto_runtime::RotoError::WireFormatViolation)".to_string(),
"0".to_string(),
), // UINT/FIXED 32 ), // UINT/FIXED 32
16 | 18 => ( 16 | 18 => (
"i64".to_string(), "i64".to_string(),
"roto_runtime::read_varint(bytes).map(|(v, _)| v as i64).map_err(|_| roto_runtime::RotoError::WireFormatViolation)".to_string(), "roto_runtime::read_varint(bytes).map(|(v, _)| v as i64).map_err(|_| roto_runtime::RotoError::WireFormatViolation)".to_string(),
"0".to_string(),
), // SINT/SFIXED 64 ), // SINT/SFIXED 64
7 | 14 => ( 7 | 14 => (
"u64".to_string(), "u64".to_string(),
"roto_runtime::read_varint(bytes).map(|(v, _)| v as u64).map_err(|_| roto_runtime::RotoError::WireFormatViolation)".to_string(), "roto_runtime::read_varint(bytes).map(|(v, _)| v as u64).map_err(|_| roto_runtime::RotoError::WireFormatViolation)".to_string(),
"0".to_string(),
), // UINT/FIXED 64 ), // UINT/FIXED 64
8 => ( 8 => (
"bool".to_string(), "bool".to_string(),
"roto_runtime::read_varint(bytes).map(|(v, _)| v != 0).map_err(|_| roto_runtime::RotoError::WireFormatViolation)".to_string(), "roto_runtime::read_varint(bytes).map(|(v, _)| v != 0).map_err(|_| roto_runtime::RotoError::WireFormatViolation)".to_string(),
"false".to_string(),
), // TYPE_BOOL ), // TYPE_BOOL
11 | 12 => ("&'a [u8]".to_string(), "Ok(bytes)".to_string()), // MESSAGE/BYTES 11 | 12 => (
_ => ("&'a [u8]".to_string(), "Ok(bytes)".to_string()), "&'a [u8]".to_string(),
"Ok(bytes)".to_string(),
"&[]".to_string(),
), // MESSAGE/BYTES
_ => (
"&'a [u8]".to_string(),
"Ok(bytes)".to_string(),
"&[]".to_string(),
),
} }
} }
@@ -248,7 +265,7 @@ fn write_message(msg_proto: &DescriptorProto, output: &mut String) {
output.push_str(" })\n }\n\n"); output.push_str(" })\n }\n\n");
for (field_name, tag, f_type, f_label, _oneof_index, is_map) in &fields_info { for (field_name, tag, f_type, f_label, _oneof_index, is_map) in &fields_info {
let (rust_type, logic) = map_type_to_rust_accessor(*f_type, *f_label, *is_map); let (rust_type, logic, default_val) = map_type_to_rust_accessor(*f_type, *f_label, *is_map);
let safe_name = if field_name == "type" { let safe_name = if field_name == "type" {
format!("r#{}", field_name) format!("r#{}", field_name)
} else { } else {
@@ -290,6 +307,17 @@ fn write_message(msg_proto: &DescriptorProto, output: &mut String) {
output.push_str(" let (bytes, _) = self.accessor.get_value_at(offset)?;\n"); output.push_str(" let (bytes, _) = self.accessor.get_value_at(offset)?;\n");
output.push_str(&format!(" {}\n", logic)); output.push_str(&format!(" {}\n", logic));
output.push_str(" }\n\n"); output.push_str(" }\n\n");
output.push_str(&format!(
" pub fn {}_or_default(&self) -> roto_runtime::Result<{}> {{\n",
safe_name, rust_type
));
output.push_str(&format!(
" self.{}().or(Ok({}))\n",
safe_name, default_val
));
output.push_str(" }\n\n");
output.push_str(&format!( output.push_str(&format!(
" pub fn has_{}(&self) -> bool {{ self.{}_offset.is_some() }}\n\n", " pub fn has_{}(&self) -> bool {{ self.{}_offset.is_some() }}\n\n",
field_name, field_name field_name, field_name
@@ -444,7 +472,7 @@ fn write_message(msg_proto: &DescriptorProto, output: &mut String) {
output.push_str(&format!("pub enum {}<'a> {{\n", pascal_oneof_name)); output.push_str(&format!("pub enum {}<'a> {{\n", pascal_oneof_name));
for (field_name, _tag, f_type, f_label, f_oneof_index, _is_map) in &fields_info { for (field_name, _tag, f_type, f_label, f_oneof_index, _is_map) in &fields_info {
if *f_oneof_index == Some(oneof_index as i32) { if *f_oneof_index == Some(oneof_index as i32) {
let (rust_type, _) = map_type_to_rust_accessor(*f_type, *f_label, *_is_map); let (rust_type, _, _) = map_type_to_rust_accessor(*f_type, *f_label, *_is_map);
let safe_field_name = if field_name == "type" { let safe_field_name = if field_name == "type" {
format!("r#{}", field_name) format!("r#{}", field_name)
} else { } else {