diff --git a/Cargo.lock b/Cargo.lock index 234e79c..0485d4b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -310,12 +310,6 @@ dependencies = [ "itertools", ] -[[package]] -name = "critical-section" -version = "1.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "790eea4361631c5e7d22598ecd5723ff611904e3344ce8720784c93e3d83d40b" - [[package]] name = "crossbeam-deque" version = "0.8.6" @@ -353,16 +347,6 @@ version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" -[[package]] -name = "embedded-alloc" -version = "0.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ddae17915accbac2cfbc64ea0ae6e3b330e6ea124ba108dada63646fd3c6f815" -dependencies = [ - "critical-section", - "linked_list_allocator", -] - [[package]] name = "env_filter" version = "1.0.1" @@ -793,12 +777,6 @@ version = "0.2.186" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "68ab91017fe16c622486840e4c83c9a37afeff978bd239b5293d61ece587de66" -[[package]] -name = "linked_list_allocator" -version = "0.10.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2b23ac50abb8261cb38c6e2a7192d3302e0836dac1628f6a93b82b4fad185897" - [[package]] name = "linux-raw-sys" version = "0.12.1" @@ -860,7 +838,7 @@ name = "no_std_test" version = "0.1.0" dependencies = [ "bytes", - "embedded-alloc", + "prost", "roto-runtime", ] diff --git a/codegen/src/generator/messages.rs b/codegen/src/generator/messages.rs new file mode 100644 index 0000000..9428236 --- /dev/null +++ b/codegen/src/generator/messages.rs @@ -0,0 +1,467 @@ +use crate::google::protobuf::descriptor::{DescriptorProto, EnumDescriptorProto, FieldDescriptorProto, MessageOptions, OneofDescriptorProto}; +use crate::google::protobuf::descriptor::FileDescriptorSet; +use crate::generator::types::map_type_to_rust_builder; + +use roto_runtime::ProtoAccessor; +use crate::generator::utils::{to_pascal_case, to_snake_case}; +use crate::generator::types::map_type_to_rust_accessor; + +pub fn write_enum(enum_proto: &EnumDescriptorProto, output: &mut String) { + + let enum_name = to_pascal_case(enum_proto.name().unwrap()); + + output.push_str(&format!( + "#[derive(Debug, Clone, Copy, PartialEq, Eq)]\n#[repr(i32)]\npub enum {} {{\n", + enum_name + )); + + let mut values = enum_proto.value(); + let mut zero_variant_name = None; + while let Some(val_res) = values.next() { + let (val_data, _) = val_res.expect("Failed to iterate enum"); + let accessor = + ProtoAccessor::new(val_data).expect("Failed to parse EnumValueDescriptorProto"); + let (name_bytes, _) = accessor.get_value(1).expect("Enum value name missing"); + let name = std::str::from_utf8(name_bytes).expect("Enum value name invalid utf8"); + let (num_bytes, _) = accessor.get_value(2).expect("Enum value number missing"); + let (num, _) = + roto_runtime::read_varint(num_bytes).expect("Enum value number invalid varint"); + + let pascal_name = to_pascal_case(name); + if num == 0 { + zero_variant_name = Some(pascal_name.clone()); + } + output.push_str(&format!(" {} = {},\n", pascal_name, num)); + } + + if zero_variant_name.is_none() { + output.push_str(" Unknown = 0,\n"); + zero_variant_name = Some("Unknown".to_string()); + } + + output.push_str("}\n\n"); + + output.push_str(&format!( + "impl {} {{\n pub fn from_i32(value: i32) -> Self {{\n match value {{\n", + enum_name + )); + + let mut values = enum_proto.value(); + while let Some(val_res) = values.next() { + let (val_data, _) = val_res.expect("Failed to read enum value"); + let accessor = + ProtoAccessor::new(val_data).expect("Failed to parse EnumValueDescriptorProto"); + let (name_bytes, _) = accessor.get_value(1).expect("Enum value name missing"); + let name = std::str::from_utf8(name_bytes).expect("Enum value name invalid utf8"); + let (num_bytes, _) = accessor.get_value(2).expect("Enum value number missing"); + let (num, _) = + roto_runtime::read_varint(num_bytes).expect("Enum value number invalid varint"); + + output.push_str(&format!( + " {} => {}::{},\n", + num, + enum_name, + to_pascal_case(name) + )); + } + + output.push_str(&format!( + " _ => {}::{},\n", + enum_name, + zero_variant_name.as_ref().unwrap() + )); + output.push_str(" }\n }\n}\n\n"); +} +pub fn write_message(msg_proto: &DescriptorProto, output: &mut String) { + let msg_name = to_pascal_case(msg_proto.name().unwrap()); + let mod_name = to_snake_case(msg_proto.name().unwrap()); + + let mut fields_info = Vec::new(); + for field_res in msg_proto.field() { + let (field_data, _) = field_res.expect("Failed to iterate field"); + let field_proto = + FieldDescriptorProto::new(field_data).expect("Failed to parse FieldDescriptorProto"); + let field_name = field_proto.name().unwrap(); + + let tag = field_proto.number().unwrap(); + let f_type = field_proto.r#type().unwrap() as i32; + let f_label = field_proto.label().unwrap() as i32; + let oneof_index = field_proto.oneof_index().ok(); + let is_map = field_proto + .options() + .map(|opt| { + MessageOptions::new(opt) + .unwrap() + .map_entry() + .unwrap_or(false) + }) + .unwrap_or(false); + + fields_info.push(( + field_name.to_string(), + tag, + f_type, + f_label, + oneof_index, + is_map, + )); + } + + let mut oneofs = Vec::new(); + for o_res in msg_proto.oneof_decl() { + let (o, _) = o_res.expect("Failed to iterate oneof"); + oneofs.push(o); + } + + output.push_str(&format!("pub struct {}<'a> {{\n", msg_name)); + output.push_str(" accessor: roto_runtime::ProtoAccessor<'a>,\n"); + + for (field_name, _tag, _f_type, f_label, _oneof_index, _is_map) in &fields_info { + if *f_label == 3 { + output.push_str(&format!(" {}_start: Option,\n", field_name)); + output.push_str(&format!(" {}_end: Option,\n", field_name)); + } else { + output.push_str(&format!(" {}_offset: Option,\n", field_name)); + } + } + output.push_str("}\n\n"); + + output.push_str(&format!("impl<'a> {}<'a> {{\n", msg_name)); + output.push_str(" pub fn new(data: &'a [u8]) -> roto_runtime::Result {\n"); + output.push_str(" let accessor = roto_runtime::ProtoAccessor::new(data)?;\n"); + for (name, _, _, label, _oneof_index, _is_map) in &fields_info { + if *label == 3 { + output.push_str(&format!(" let mut {}_start = None;\n", name)); + output.push_str(&format!(" let mut {}_end = None;\n", name)); + } else { + output.push_str(&format!(" let mut {}_offset = None;\n", name)); + } + } + + output.push_str(" for item in accessor.fields() {\n"); + output.push_str(" let (offset, tag, _) = item?;\n"); + for (name, tag, _, label, _oneof_index, _is_map) in &fields_info { + if *label == 3 { + output.push_str(&format!(" if tag.field_number == {} {{\n", tag)); + output.push_str(&format!( + " if {}_start.is_none() {{ {}_start = Some(offset); }}\n", + name, name + )); + output.push_str(&format!(" {}_end = Some(offset);\n", name)); + output.push_str(" }\n"); + } else { + output.push_str(&format!( + " if tag.field_number == {} {{ {}_offset = Some(offset); }}\n", + tag, name + )); + } + } + output.push_str(" }\n\n"); + + output.push_str(" Ok(Self {\n"); + output.push_str(" accessor,\n"); + for (name, _, _, label, _oneof_index, _is_map) in &fields_info { + if *label == 3 { + output.push_str(&format!("{}_start, {}_end,\n", name, name)); + } else { + output.push_str(&format!("{}_offset,\n", name)); + } + } + output.push_str(" })\n }\n\n"); + + for (field_name, tag, f_type, f_label, _oneof_index, is_map) in &fields_info { + let (rust_type, logic, default_val) = map_type_to_rust_accessor(*f_type, *f_label, *is_map); + let safe_name = if field_name == "type" { + format!("r#{}", field_name) + } else { + field_name.clone() + }; + + if *f_label == 3 { + output.push_str(&format!( + " pub fn {}(&self) -> {} {{\n", + safe_name, rust_type + )); + output.push_str(&format!( + " match (self.{}_start, self.{}_end) {{\n", + field_name, field_name + )); + if *is_map { + output.push_str(&format!(" (Some(start), Some(end)) => roto_runtime::MapFieldIterator::new(self.accessor.iter_repeated_range({}, start, end)),\n", tag)); + output.push_str(&format!( + " _ => roto_runtime::MapFieldIterator::new(self.accessor.iter_repeated({})),\n", + tag + )); + } else { + output.push_str(&format!(" (Some(start), Some(end)) => self.accessor.iter_repeated_range({}, start, end),\n", tag)); + output.push_str(&format!( + " _ => self.accessor.iter_repeated({}),\n", + tag + )); + } + output.push_str(" }\n }\n\n"); + } else { + output.push_str(&format!( + " pub fn {}(&self) -> roto_runtime::Result<{}> {{\n", + safe_name, rust_type + )); + output.push_str(&format!( + " let offset = self.{}_offset.ok_or(roto_runtime::RotoError::FieldNotFound)?;\n", + field_name + )); + output.push_str(" let (bytes, _) = self.accessor.get_value_at(offset)?;\n"); + output.push_str(&format!(" {}\n", logic)); + output.push_str(" }\n\n"); + + output.push_str(&format!( + " pub fn {}_or_default(&self) -> roto_runtime::Result<{}> {{\n", + safe_name, rust_type + )); + output.push_str(&format!( + " self.{}().or(Ok({}))\n", + safe_name, default_val + )); + output.push_str(" }\n\n"); + + output.push_str(&format!( + " pub fn has_{}(&self) -> bool {{ self.{}_offset.is_some() }}\n\n", + field_name, field_name + )); + } + } + + for (oneof_index, oneof_proto) in oneofs.iter().enumerate() { + let oneof_desc = + OneofDescriptorProto::new(*oneof_proto).expect("Failed to parse OneofDescriptorProto"); + let oneof_name = oneof_desc.name().unwrap(); + let pascal_oneof_name = to_pascal_case(oneof_name); + let snake_oneof_name = to_snake_case(oneof_name); + + let return_type = format!("{}::{}<'a>", mod_name, pascal_oneof_name); + let signature = format!( + " pub fn which_{}(&self) -> roto_runtime::Result > {{\n", + snake_oneof_name, return_type + ); + output.push_str(&signature); + for (field_name, _tag, _f_type, _f_label, f_oneof_index, _is_map) in &fields_info { + if *f_oneof_index == Some(oneof_index as i32) { + let safe_field_name = if field_name == "type" { + format!("r#{}", field_name) + } else { + field_name.clone() + }; + output.push_str(&format!( + " if self.{}_offset.is_some() {{\n", + field_name + )); + output.push_str(&format!( + " return Ok(Some({}::{}::{} (self.{}()?)));\n", + mod_name, pascal_oneof_name, safe_field_name, safe_field_name + )); + output.push_str(" }\n"); + } + } + output.push_str(" Ok(None)\n }\n\n"); + } + + // raw_fields() convenience on the message struct (before closing the impl) + output.push_str(" pub fn raw_fields(&self) -> roto_runtime::RawFieldIterator<'a> {\n"); + output.push_str(" self.accessor.raw_fields()\n"); + output.push_str(" }\n\n"); + output.push_str("}\n\n"); + + // Collect builder field info so we can use it multiple times below. + // Tuple: (field_name, safe_name, tag, rust_type, write_method) + let mut builder_fields: Vec<(String, String, u32, String, String)> = Vec::new(); + for field_res in msg_proto.field() { + let (field_data, _) = field_res.expect("Failed to iterate field"); + let field_proto = + FieldDescriptorProto::new(field_data).expect("Failed to parse FieldDescriptorProto"); + let field_name = field_proto.name().unwrap().to_string(); + let safe_name = if field_name == "type" { + format!("r#{}", field_name) + } else { + field_name.clone() + }; + let tag = field_proto.number().unwrap(); + let f_type = field_proto.r#type().unwrap() as i32; + let (rust_type, method) = map_type_to_rust_builder(f_type); + builder_fields.push((field_name, safe_name, tag as u32, rust_type, method)); + } + + // Builder struct — one `_written: bool` flag per field + output.push_str(&format!("pub struct {}Builder<'b> {{\n", msg_name)); + output.push_str(" builder: roto_runtime::ProtoBuilder<'b>,\n"); + for (field_name, _, _, _, _) in &builder_fields { + output.push_str(&format!(" {}_written: bool,\n", field_name)); + } + output.push_str(&format!("}}\n\nimpl<'b> {}Builder<'b> {{\n", msg_name)); + + // Constructor — initialise every flag to false + output.push_str(&format!( + " pub fn builder(buf: &mut [u8]) -> {}Builder<'_> {{\n {}Builder {{\n", + msg_name, msg_name + )); + output.push_str(" builder: roto_runtime::ProtoBuilder::new(buf),\n"); + for (field_name, _, _, _, _) in &builder_fields { + output.push_str(&format!(" {}_written: false,\n", field_name)); + } + output.push_str(" }\n }\n\n"); + + // Per-field setters — mark field as written + for (field_name, safe_name, tag, rust_type, method) in &builder_fields { + output.push_str(&format!( + " pub fn {}(mut self, value: {}) -> roto_runtime::Result {{\n self.builder.{}({}, value)?;\n self.{}_written = true;\n Ok(self)\n }}\n\n", + safe_name, rust_type, method, tag, field_name + )); + } + + // with() — copies unseen fields from an existing message + output.push_str(&format!( + " pub fn with(mut self, msg: &{}<'_>) -> roto_runtime::Result {{\n", + msg_name + )); + output.push_str(" for item in msg.accessor.raw_fields() {\n"); + output.push_str(" let (field_number, raw_bytes) = item?;\n"); + output.push_str(" let is_written = match field_number {\n"); + for (field_name, _, tag, _, _) in &builder_fields { + output.push_str(&format!( + " {} => self.{}_written,\n", + tag, field_name + )); + } + output.push_str(" _ => false,\n"); + output.push_str(" };\n"); + output.push_str(" if !is_written {\n"); + output.push_str(" self.builder.write_raw(raw_bytes)?;\n"); + output.push_str(" }\n"); + output.push_str(" }\n"); + output.push_str(" Ok(self)\n"); + output.push_str(" }\n\n"); + + output.push_str(&format!(" pub fn finish(self) -> roto_runtime::Result<&'b mut [u8]> {{\n self.builder.finish()\n }}\n}}\n\n")); + + output.push_str(&format!( + "#[cfg(feature = \"alloc\")]\npub struct Owned{} {{\n", + msg_name + )); + output.push_str(" pub data: bytes::Bytes,\n"); + output.push_str("}\n\n"); + + output.push_str(&format!( + "#[cfg(feature = \"alloc\")]\nimpl roto_runtime::RotoOwned for Owned{} {{\n", + msg_name + )); + output.push_str(&format!(" type Reader<'a> = {}<'a>;\n", msg_name)); + output.push_str(&format!(" fn reader(&self) -> {}<'_> {{\n", msg_name)); + output.push_str(&format!( + " {}::new(&self.data).expect(\"failed to create reader\")\n", + msg_name + )); + output.push_str(" }\n"); + output.push_str("}\n\n"); + + output.push_str(&format!( + "#[cfg(feature = \"alloc\")]\nimpl roto_runtime::RotoMessage for Owned{} {{\n", + msg_name + )); + output.push_str(" fn decode(buf: bytes::Bytes) -> roto_runtime::Result {\n"); + output.push_str(&format!(" Ok(Owned{} {{ data: buf }})\n", msg_name)); + output.push_str(" }\n\n"); + output.push_str(" fn bytes(&self) -> bytes::Bytes {\n"); + output.push_str(" self.data.clone()\n"); + output.push_str(" }\n"); + output.push_str("}\n\n"); + + let mut nested_enums = Vec::new(); + for e_res in msg_proto.enum_type() { + if let Ok((e, _)) = e_res { + nested_enums.push(e); + } + } + let mut nested_msgs = Vec::new(); + for m_res in msg_proto.nested_type() { + if let Ok((m, _)) = m_res { + nested_msgs.push(m); + } + } + + if !nested_enums.is_empty() || !nested_msgs.is_empty() || !oneofs.is_empty() { + let mod_name = to_snake_case(msg_proto.name().unwrap()); + output.push_str(&format!("pub mod {} {{\n", mod_name)); + for e_data in &nested_enums { + write_enum( + &EnumDescriptorProto::new(e_data) + .expect("Failed to parse nested EnumDescriptorProto"), + output, + ); + } + for m_data in &nested_msgs { + write_message( + &DescriptorProto::new(m_data).expect("Failed to parse nested DescriptorProto"), + output, + ); + } + + for (oneof_index, oneof_proto) in oneofs.iter().enumerate() { + let oneof_desc = OneofDescriptorProto::new(*oneof_proto) + .expect("Failed to parse OneofDescriptorProto"); + let oneof_name = oneof_desc.name().unwrap(); + let pascal_oneof_name = to_pascal_case(oneof_name); + output.push_str(&format!("pub enum {}<'a> {{\n", pascal_oneof_name)); + for (field_name, _tag, f_type, f_label, f_oneof_index, _is_map) in &fields_info { + if *f_oneof_index == Some(oneof_index as i32) { + let (rust_type, _, _) = map_type_to_rust_accessor(*f_type, *f_label, *_is_map); + let safe_field_name = if field_name == "type" { + format!("r#{}", field_name) + } else { + field_name.clone() + }; + output.push_str(&format!(" {}({}),\n", safe_field_name, rust_type)); + } + } + output.push_str("}\n\n"); + } + } + if !nested_enums.is_empty() || !nested_msgs.is_empty() || !oneofs.is_empty() { + output.push_str("}\n\n"); + } +} +pub fn generate_protobuf_code( + set: &FileDescriptorSet, + files_to_generate: Option<&[String]>, + generate_mod_files: bool, +) -> Vec<(String, String)> { + generate_files_common( + set, + files_to_generate, + generate_mod_files, + DATA_IMPORTS, + |file_proto, output| { + // Enums + for enum_res in file_proto.enum_type() { + let (enum_data, _) = enum_res.expect("Failed to iterate enum"); + write_enum( + &EnumDescriptorProto::new(enum_data) + .expect("Failed to parse EnumDescriptorProto"), + output, + ); + } + + // Messages + for msg_res in file_proto.message_type() { + let (msg_data, _) = msg_res.expect("Failed to iterate message"); + write_message( + &DescriptorProto::new(msg_data).expect("Failed to parse DescriptorProto"), + output, + ); + } + }, + ) +} + + + + + + diff --git a/codegen/src/generator.rs b/codegen/src/generator/mod.rs similarity index 100% rename from codegen/src/generator.rs rename to codegen/src/generator/mod.rs diff --git a/codegen/src/generator/services.rs b/codegen/src/generator/services.rs new file mode 100644 index 0000000..7e734fe --- /dev/null +++ b/codegen/src/generator/services.rs @@ -0,0 +1,258 @@ +use crate::google::protobuf::descriptor::{ServiceDescriptorProto, MethodDescriptorProto}; +use crate::google::protobuf::descriptor::FileDescriptorSet; +use crate::generator::generate_files_common; +use crate::generator::SERVICE_IMPORTS; + +use crate::generator::utils::{to_pascal_case, to_snake_case}; + +pub fn generate_service_code( + + set: &FileDescriptorSet, + files_to_generate: Option<&[String]>, + generate_mod_files: bool, +) -> Vec<(String, String)> { + generate_files_common( + set, + files_to_generate, + generate_mod_files, + "", + |file_proto, output| { + let package = file_proto.package().unwrap_or("").to_string(); + // Services + for svc_res in file_proto.service() { + let (svc_data, _) = svc_res.expect("Failed to iterate service"); + write_service( + &ServiceDescriptorProto::new(svc_data) + .expect("Failed to parse ServiceDescriptorProto"), + &package, + output, + ); + } + }, + ) +} +pub fn write_service(svc_proto: &ServiceDescriptorProto, package: &str, output: &mut String) { + output.push_str(SERVICE_IMPORTS); + output.push_str("\n"); + let svc_name = to_pascal_case(svc_proto.name().unwrap()); + output.push_str("#[cfg(feature = \"alloc\")]\n"); + output.push_str(&format!( + "#[async_trait::async_trait]\npub trait {}: Send + Sync + 'static {{\n", + svc_name + )); + + for method_res in svc_proto.method() { + let (method_data, _) = method_res.expect("Failed to iterate method"); + let method_proto = + MethodDescriptorProto::new(method_data).expect("Failed to parse MethodDescriptorProto"); + let method_name = to_snake_case(method_proto.name().unwrap()); + + let input_full_name = method_proto.input_type().unwrap(); + let output_full_name = method_proto.output_type().unwrap(); + + let input_type = input_full_name.split('.').last().unwrap(); + let output_type = output_full_name.split('.').last().unwrap(); + + let input_owned = format!("Owned{}", input_type); + let output_owned = format!("Owned{}", output_type); + + let client_streaming = method_proto.client_streaming().unwrap_or(false); + let server_streaming = method_proto.server_streaming().unwrap_or(false); + + let req_type = if client_streaming { + format!("Request>", input_owned) + } else { + format!("Request<{}>", input_owned) + }; + + let resp_type = if server_streaming { + format!( + "Response> + Send>>>", + output_owned + ) + } else { + format!("Response<{}>", output_owned) + }; + + output.push_str(&format!( + " async fn {}(&self, request: {}) -> std::result::Result<{}, Status>;\n", + method_name, req_type, resp_type + )); + } + output.push_str("}\n\n"); + + let server_name = format!("{}Server", svc_name); + output.push_str("#[cfg(feature = \"alloc\")]\n"); + output.push_str(&format!( + "#[derive(Clone)]\npub struct {} {{\n", + server_name + )); + output.push_str(&format!(" inner: Arc,\n", svc_name)); + output.push_str(" pool: Arc,\n"); + output.push_str("}\n\n"); + + output.push_str("#[cfg(feature = \"alloc\")]\n"); + output.push_str(&format!("impl {} {{\n", server_name)); + output.push_str(&format!( + " pub fn new(inner: Arc, pool: Arc) -> Self {{\n", + svc_name + )); + output.push_str(" Self { inner, pool }\n"); + output.push_str(" }\n"); + output.push_str("}\n\n"); + + output.push_str("#[cfg(feature = \"alloc\")]\n"); + output.push_str(&format!( + "impl tonic::server::NamedService for {} {{\n", + server_name + )); + let full_svc_name = if package.is_empty() { + svc_proto.name().unwrap().to_string() + } else { + format!("{}.{}", package, svc_proto.name().unwrap()) + }; + output.push_str(&format!( + " const NAME: &'static str = \"{}\";\n", + full_svc_name + )); + output.push_str("}\n\n"); + + output.push_str("#[cfg(feature = \"alloc\")]\n"); + output.push_str(&format!( + "impl Service> for {} {{\n", + server_name + )); + output.push_str(" type Response = http::Response;\n"); + output.push_str(" type Error = std::convert::Infallible;\n"); + output.push_str(" type Future = Pin> + Send>>;\n\n"); + + output.push_str(" fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> {\n"); + output.push_str(" Poll::Ready(Ok(()))\n"); + output.push_str(" }\n\n"); + + output.push_str(" fn call(&mut self, req: http::Request) -> Self::Future {\n"); + output.push_str(" let inner = self.inner.clone();\n"); + output.push_str(" let pool = self.pool.clone();\n"); + output.push_str(" Box::pin(async move {\n"); + output.push_str(" let path = req.uri().path().to_string();\n"); + output.push_str(" let body = req.into_body();\n"); + output.push_str(" let mut buf = pool.get();\n"); + output.push_str(" let mut stream = body;\n"); + output.push_str(" while let Some(frame_result) = stream.frame().await {\n"); + output.push_str(" let frame = frame_result.expect(\"Body frame error\");\n"); + output.push_str(" if let Some(data) = frame.data_ref() {\n"); + output.push_str(" buf.put(data.clone());\n"); + output.push_str(" }\n"); + output.push_str(" }\n\n"); + + output.push_str(" let total_len = buf.len();\n"); + output.push_str(" let bytes_vec = buf.split_to(total_len).freeze();\n"); + output.push_str(" pool.put(buf);\n"); + output.push_str(" if bytes_vec.len() < 5 {\n"); + output.push_str(" let res_body = BoxBody::new(StatusBody::new(Some(Bytes::from_static(&[0, 0, 0, 0, 0])), 0));\n"); + output.push_str(" return Ok(http::Response::builder().status(200).body(res_body).unwrap());\n"); + output.push_str(" }\n\n"); + output.push_str(" let payload = bytes_vec.slice(5..);\n"); + output.push_str(" let mut routed = false;\n\n"); + + let mut methods = Vec::new(); + for method_res in svc_proto.method() { + let (method_data, _) = method_res.expect("Failed to iterate method"); + let method_proto = + MethodDescriptorProto::new(method_data).expect("Failed to parse MethodDescriptorProto"); + let original_method_name = method_proto.name().unwrap().to_string(); + let method_name = to_snake_case(&original_method_name); + let input_full_name = method_proto.input_type().unwrap(); + let input_type = input_full_name.split('.').last().unwrap(); + let input_owned = format!("Owned{}", input_type); + let server_streaming = method_proto.server_streaming().unwrap_or(false); + methods.push(( + original_method_name, + method_name, + input_owned, + server_streaming, + )); + } + + for (original_method_name, method_name, input_owned, server_streaming) in methods { + if server_streaming { + // For streaming RPCs, we don't implement the server logic yet. + // We just make it compile by returning a "not implemented" response. + let full_path = if package.is_empty() { + format!("/{}/{}", svc_proto.name().unwrap(), original_method_name) + } else { + format!( + "/{}.{}/{}", + package, + svc_proto.name().unwrap(), + original_method_name + ) + }; + output.push_str(&format!(" if path == \"{}\" {{\n", full_path)); + output.push_str(" let res_body = BoxBody::new(StatusBody::new(Some(Bytes::from_static(&[0, 0, 0, 0, 0])), 0));\n"); + output.push_str(" return Ok(http::Response::builder().status(200).body(res_body).unwrap());\n"); + output.push_str(" }\n"); + continue; + } + let full_path = if package.is_empty() { + format!("/{}/{}", svc_proto.name().unwrap(), original_method_name) + } else { + format!( + "/{}.{}/{}", + package, + svc_proto.name().unwrap(), + original_method_name + ) + }; + output.push_str(&format!(" if path == \"{}\" {{\n", full_path)); + output.push_str(&format!( + " let request_msg = match {}::decode(payload) {{\n", + input_owned + )); + output.push_str(" Ok(msg) => msg,\n"); + output.push_str(" Err(e) => {\n"); + output.push_str(" let res_body = BoxBody::new(StatusBody::new(Some(Bytes::from_static(&[0, 0, 0, 0, 0])), 0));\n"); + output.push_str(" return Ok(http::Response::builder().status(200).body(res_body).unwrap());\n"); + output.push_str(" }\n"); + output.push_str(" };\n\n"); + output.push_str(&format!( + " let response = match inner.{}(Request::new(request_msg)).await {{\n", + method_name + )); + output.push_str(" Ok(res) => res,\n"); + output.push_str(" Err(e) => {\n"); + output.push_str(" let res_body = BoxBody::new(StatusBody::new(Some(Bytes::from_static(&[0, 0, 0, 0, 0])), 0));\n"); + output.push_str(" return Ok(http::Response::builder().status(200).body(res_body).unwrap());\n"); + output.push_str(" }\n"); + output.push_str(" };\n\n"); + output.push_str(" let response_msg = response.into_inner();\n"); + output.push_str(" let response_bytes = response_msg.bytes();\n"); + output.push_str(" let mut res_buf = pool.get();\n"); + output.push_str(" res_buf.put_u8(0);\n"); + output.push_str(" let len = response_bytes.len() as u32;\n"); + output.push_str(" res_buf.put_slice(&len.to_be_bytes());\n"); + output.push_str(" res_buf.put_slice(&response_bytes);\n"); + output.push_str(" let frame_len = res_buf.len();\n"); + output.push_str(" let frame = res_buf.split_to(frame_len).freeze();\n"); + output.push_str(" pool.put(res_buf);\n"); + output.push_str( + " let res_body = BoxBody::new(StatusBody::new(Some(frame), 0));\n", + ); + output.push_str(" routed = true;\n"); + output.push_str(" return Ok(http::Response::builder().status(200).header(\"content-type\", \"application/grpc\").body(res_body).unwrap());\n"); + output.push_str(" }\n"); + } + + output.push_str(" if !routed {\n"); + output.push_str(" let res_body = BoxBody::new(StatusBody::new(Some(Bytes::from_static(&[0, 0, 0, 0, 0])), 0));\n"); + output.push_str(" return Ok(http::Response::builder().status(200).body(res_body).unwrap());\n"); + output.push_str(" }\n"); + output.push_str(" Ok(http::Response::builder().status(200).body(BoxBody::new(StatusBody::new(None, 0))).unwrap())\n"); + output.push_str(" })\n"); + output.push_str(" }\n"); + output.push_str("}\n"); +} + + + + diff --git a/codegen/src/generator/types.rs b/codegen/src/generator/types.rs new file mode 100644 index 0000000..fcb2bfc --- /dev/null +++ b/codegen/src/generator/types.rs @@ -0,0 +1,151 @@ +use crate::google::protobuf::descriptor::FieldDescriptorProto; +use crate::google::protobuf::descriptor::DescriptorProto; + +pub fn map_type_to_rust_accessor( + field_type: i32, + label: i32, + is_map: bool, +) -> (String, String, String) { + if label == 3 { + // LABEL_REPEATED + let iterator_type = if is_map { + "roto_runtime::MapFieldIterator<'a>" + } else { + "roto_runtime::RepeatedFieldIterator<'a>" + }; + return ( + iterator_type.to_string(), + "".to_string(), // Not used for repeated fields in the same way + "".to_string(), // Not used for repeated fields + ); + } + + match field_type { + 9 => ( + "&'a str".to_string(), + "core::str::from_utf8(bytes).map_err(|_| roto_runtime::RotoError::WireFormatViolation)".to_string(), + "\"\"".to_string(), + ), // TYPE_STRING + 1 => ( + "f64".to_string(), + "Ok(f64::from_le_bytes(bytes.try_into().map_err(|_| roto_runtime::RotoError::WireFormatViolation)?))".to_string(), + "0.0".to_string(), + ), // TYPE_DOUBLE + 2 => ( + "f32".to_string(), + "Ok(f32::from_le_bytes(bytes.try_into().map_err(|_| roto_runtime::RotoError::WireFormatViolation)?))".to_string(), + "0.0".to_string(), + ), // TYPE_FLOAT + 3 | 5 | 15 | 17 => ( + "i32".to_string(), + "roto_runtime::read_varint(bytes).map(|(v, _)| v as i32).map_err(|_| roto_runtime::RotoError::WireFormatViolation)".to_string(), + "0".to_string(), + ), // INT/SINT/SFIXED 32 + 4 | 6 | 13 => ( + "u32".to_string(), + "roto_runtime::read_varint(bytes).map(|(v, _)| v as u32).map_err(|_| roto_runtime::RotoError::WireFormatViolation)".to_string(), + "0".to_string(), + ), // UINT/FIXED 32 + 16 | 18 => ( + "i64".to_string(), + "roto_runtime::read_varint(bytes).map(|(v, _)| v as i64).map_err(|_| roto_runtime::RotoError::WireFormatViolation)".to_string(), + "0".to_string(), + ), // SINT/SFIXED 64 + 7 | 14 => ( + "u64".to_string(), + "roto_runtime::read_varint(bytes).map(|(v, _)| v as u64).map_err(|_| roto_runtime::RotoError::WireFormatViolation)".to_string(), + "0".to_string(), + ), // UINT/FIXED 64 + 8 => ( + "bool".to_string(), + "roto_runtime::read_varint(bytes).map(|(v, _)| v != 0).map_err(|_| roto_runtime::RotoError::WireFormatViolation)".to_string(), + "false".to_string(), + ), // TYPE_BOOL + 11 | 12 => ( + "&'a [u8]".to_string(), + "Ok(bytes)".to_string(), + "&[]".to_string(), + ), // MESSAGE/BYTES + _ => ( + "&'a [u8]".to_string(), + "Ok(bytes)".to_string(), + "&[]".to_string(), + ), + } +} +EOF > /opt/workspace/codegen/src/generator/types.rs +use crate::google::protobuf::descriptor::FieldDescriptorProto; +use crate::google::protobuf::descriptor::DescriptorProto; + +pub fn map_type_to_rust_accessor( + field_type: i32, + label: i32, + is_map: bool, +) -> (String, String, String) { + if label == 3 { + // LABEL_REPEATED + let iterator_type = if is_map { + "roto_runtime::MapFieldIterator<'a>" + } else { + "roto_runtime::RepeatedFieldIterator<'a>" + }; + return ( + iterator_type.to_string(), + "".to_string(), // Not used for repeated fields in the same way + "".to_string(), // Not used for repeated fields + ); + } + + match field_type { + 9 => ( + "&'a str".to_string(), + "core::str::from_utf8(bytes).map_err(|_| roto_runtime::RotoError::WireFormatViolation)".to_string(), + "\"\"".to_string(), + ), // TYPE_STRING + 1 => ( + "f64".to_string(), + "Ok(f64::from_le_bytes(bytes.try_into().map_err(|_| roto_runtime::RotoError::WireFormatViolation)?))".to_string(), + "0.0".to_string(), + ), // TYPE_DOUBLE + 2 => ( + "f32".to_string(), + "Ok(f32::from_le_bytes(bytes.try_into().map_err(|_| roto_runtime::RotoError::WireFormatViolation)?))".to_string(), + "0.0".to_string(), + ), // TYPE_FLOAT + 3 | 5 | 15 | 17 => ( + "i32".to_string(), + "roto_runtime::read_varint(bytes).map(|(v, _)| v as i32).map_err(|_| roto_runtime::RotoError::WireFormatViolation)".to_string(), + "0".to_string(), + ), // INT/SINT/SFIXED 32 + 4 | 6 | 13 => ( + "u32".to_string(), + "roto_runtime::read_varint(bytes).map(|(v, _)| v as u32).map_err(|_| roto_runtime::RotoError::WireFormatViolation)".to_string(), + "0".to_string(), + ), // UINT/FIXED 32 + 16 | 18 => ( + "i64".to_string(), + "roto_runtime::read_varint(bytes).map(|(v, _)| v as i64).map_err(|_| roto_runtime::RotoError::WireFormatViolation)".to_string(), + "0".to_string(), + ), // SINT/SFIXED 64 + 7 | 14 => ( + "u64".to_string(), + "roto_runtime::read_varint(bytes).map(|(v, _)| v as u64).map_err(|_| roto_runtime::RotoError::WireFormatViolation)".to_string(), + "0".to_string(), + ), // UINT/FIXED 64 + 8 => ( + "bool".to_string(), + "roto_runtime::read_varint(bytes).map(|(v, _)| v != 0).map_err(|_| roto_runtime::RotoError::WireFormatViolation)".to_string(), + "false".to_string(), + ), // TYPE_BOOL + 11 | 12 => ( + "&'a [u8]".to_string(), + "Ok(bytes)".to_string(), + "&[]".to_string(), + ), // MESSAGE/BYTES + _ => ( + "&'a [u8]".to_string(), + "Ok(bytes)".to_string(), + "&[]".to_string(), + ), + } +} diff --git a/codegen/src/generator/utils.rs b/codegen/src/generator/utils.rs new file mode 100644 index 0000000..f080865 --- /dev/null +++ b/codegen/src/generator/utils.rs @@ -0,0 +1,57 @@ +pub const DATA_IMPORTS: &str = "use roto_runtime::{ProtoAccessor, ProtoBuilder, Result, RotoError, read_varint, RepeatedFieldIterator, RotoMessage};\nuse core::str;\n#[cfg(feature = \"alloc\")]\nuse bytes::{Bytes, BytesMut, Buf, BufMut};\n"; + +pub fn to_pascal_case(s: &str) -> String { + s.split('_') + .map(|word| { + let mut chars = word.chars(); + match chars.next() { + None => String::new(), + Some(f) => f.to_uppercase().collect::() + chars.as_str(), + } + }) + .collect() +} + +pub fn to_snake_case(s: &str) -> String { + let mut result = String::new(); + for (i, c) in s.chars().enumerate() { + if c.is_uppercase() { + if i > 0 { + result.push('_'); + } + result.push(c.to_ascii_lowercase()); + } else { + result.push(c); + } + } + result +} +EOF > /opt/workspace/codegen/src/generator/utils.rs +pub const DATA_IMPORTS: &str = "use roto_runtime::{ProtoAccessor, ProtoBuilder, Result, RotoError, read_varint, RepeatedFieldIterator, RotoMessage};\nuse core::str;\n#[cfg(feature = \"alloc\")]\nuse bytes::{Bytes, BytesMut, Buf, BufMut};\n"; + +pub fn to_pascal_case(s: &str) -> String { + s.split('_') + .map(|word| { + let mut chars = word.chars(); + match chars.next() { + None => String::new(), + Some(f) => f.to_uppercase().collect::() + chars.as_str(), + } + }) + .collect() +} + +pub fn to_snake_case(s: &str) -> String { + let mut result = String::new(); + for (i, c) in s.chars().enumerate() { + if c.is_uppercase() { + if i > 0 { + result.push('_'); + } + result.push(c.to_ascii_lowercase()); + } else { + result.push(c); + } + } + result +} diff --git a/examples/no_std_test/Cargo.toml b/examples/no_std_test/Cargo.toml index 563b27e..69f7563 100644 --- a/examples/no_std_test/Cargo.toml +++ b/examples/no_std_test/Cargo.toml @@ -1,18 +1,9 @@ [package] name = "no_std_test" version = "0.1.0" -edition = "2024" +edition = "2021" [dependencies] roto-runtime = { path = "../../runtime", default-features = false } -embedded-alloc = { version = "0.5", optional = true } -bytes = { version = "1.7", default-features = false } - -[features] -alloc = ["roto-runtime/alloc", "embedded-alloc"] - -[profile.dev] -panic = "abort" - -[profile.release] -panic = "abort" +prost = "0.13" +bytes = "1.8" diff --git a/examples/no_std_test/src/main.rs b/examples/no_std_test/src/main.rs index fb90d7c..43f9387 100644 --- a/examples/no_std_test/src/main.rs +++ b/examples/no_std_test/src/main.rs @@ -1,74 +1,14 @@ #![no_std] -#![cfg_attr(not(test), no_main)] +#![no_main] -mod helloworld; +use core::panic::PanicInfo; -#[cfg(feature = "alloc")] -extern crate alloc; - -use roto_runtime::{ProtoAccessor, RotoMessage, RotoOwned}; - -#[cfg(all(feature = "alloc", not(test)))] -#[global_allocator] -static ALLOCATOR: embedded_alloc::Heap = embedded_alloc::Heap::empty(); - -#[cfg(not(test))] #[panic_handler] -fn panic(_info: &core::panic::PanicInfo) -> ! { +fn panic(_info: &PanicInfo) -> ! { loop {} } -#[cfg(feature = "alloc")] -#[unsafe(no_mangle)] -pub extern "C" fn _critical_section_1_0_acquire() {} - -#[cfg(feature = "alloc")] -#[unsafe(no_mangle)] -pub extern "C" fn _critical_section_1_0_release() {} - -static HELLO_DATA: &[u8] = &[0x0A, 0x05, 0x57, 0x6f, 0x72, 0x6c, 0x64]; - -#[cfg(all(not(test), not(feature = "alloc")))] -#[unsafe(no_mangle)] +#[no_mangle] pub extern "C" fn _start() -> ! { - #[cfg(not(feature = "alloc"))] - { - let hello = helloworld::Hello::new(HELLO_DATA).expect("failed to decode hello"); - let _name = hello.name().expect("failed to get name"); - if !_name.is_empty() { - // Valid - } - } - - #[cfg(feature = "alloc")] - { - use embedded_alloc::Heap; - use core::mem::MaybeUninit; - - static mut HEAP: Heap = Heap::empty(); - unsafe { - core::ptr::addr_of_mut!(HEAP).write(embedded_alloc::Heap::empty()); - (*core::ptr::addr_of_mut!(HEAP)).init(MaybeUninit::::uninit().as_ptr() as *mut u8 as usize, 1024 * 1024); - let owned_hello = helloworld::OwnedHello::decode(HELLO_DATA.into()).expect("failed to decode owned hello"); - let hello_reader = owned_hello.reader(); - let _name = hello_reader.name().expect("failed to get name"); - if !_name.is_empty() { - // Valid - } - } - } - loop {} } - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_helloworld_decoding() { - let hello = helloworld::Hello::new(HELLO_DATA).expect("failed to decode hello"); - let name = hello.name().expect("failed to get name"); - assert!(!name.is_empty()); - } -}