diff --git a/Cargo.lock b/Cargo.lock index 7ce5459..bf461d3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -552,6 +552,7 @@ dependencies = [ "http-body", "http-body-util", "prost", + "roto-codegen", "roto-runtime", "roto-tonic", "tokio", diff --git a/codegen/src/generator.rs b/codegen/src/generator.rs index 81f2b78..55c6b1b 100644 --- a/codegen/src/generator.rs +++ b/codegen/src/generator.rs @@ -540,12 +540,12 @@ pub fn generate_rust_code( } } - let rust_file_name = format!("{}.rs", proto_name.replace(".proto", "")); + let rust_file_name = format!("{}.rs", std::path::Path::new(proto_name).file_stem().unwrap().to_str().unwrap()); let mut output = String::new(); output.push_str("// @generated by protoc-gen-roto — do not edit\n"); output.push_str("#[allow(unused_imports)]\n\n"); - output.push_str("use roto_runtime::{ProtoAccessor, ProtoBuilder, Result, RotoError, read_varint, RepeatedFieldIterator};\n"); + output.push_str("use roto_runtime::{ProtoAccessor, ProtoBuilder, Result, RotoError, read_varint, RepeatedFieldIterator, RotoMessage};\n"); output.push_str("use std::str;\n"); output.push_str("use bytes::{Bytes, BytesMut, Buf, BufMut};\n"); output.push_str("use tonic::{Request, Response, Status};\n"); @@ -710,9 +710,9 @@ fn write_service(svc_proto: &ServiceDescriptorProto, output: &mut String) { output.push_str(&format!("impl Service> for {} {{\n", server_name)); output.push_str(" type Response = http::Response;\n"); output.push_str(" type Error = std::convert::Infallible;\n"); - output.push_str(" type Future = Pin> + Send>>;\n\n"); + output.push_str(" type Future = Pin> + Send>>;\n\n"); - output.push_str(" fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> {\n"); + output.push_str(" fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> {\n"); output.push_str(" Poll::Ready(Ok(()))\n"); output.push_str(" }\n\n"); @@ -720,6 +720,7 @@ fn write_service(svc_proto: &ServiceDescriptorProto, output: &mut String) { output.push_str(" let inner = self.inner.clone();\n"); output.push_str(" let pool = self.pool.clone();\n"); output.push_str(" Box::pin(async move {\n"); + output.push_str(" let path = req.uri().path().to_string();\n"); output.push_str(" let body = req.into_body();\n"); output.push_str(" let mut buf = pool.get();\n"); output.push_str(" let mut stream = body;\n"); @@ -740,7 +741,6 @@ fn write_service(svc_proto: &ServiceDescriptorProto, output: &mut String) { output.push_str(" return Ok(http::Response::builder().status(200).body(res_body).unwrap());\n"); output.push_str(" }\n\n"); output.push_str(" let payload = bytes_vec.slice(5..);\n"); - output.push_str(" let path = req.uri().path();\n"); output.push_str(" let mut routed = false;\n\n"); let mut methods = Vec::new(); @@ -762,14 +762,14 @@ fn write_service(svc_proto: &ServiceDescriptorProto, output: &mut String) { output.push_str(" let res_body = BoxBody::new(StatusBody(Some(Bytes::from_static(&[0, 0, 0, 0, 0]))));\n"); output.push_str(" return Ok(http::Response::builder().status(200).body(res_body).unwrap());\n"); output.push_str(" }\n"); - output.push_str(" }};\n\n"); + output.push_str(" };\n\n"); output.push_str(&format!(" let response = match inner.{}(Request::new(request_msg)).await {{\n", method_name)); output.push_str(" Ok(res) => res,\n"); output.push_str(" Err(e) => {\n"); output.push_str(" let res_body = BoxBody::new(StatusBody(Some(Bytes::from_static(&[0, 0, 0, 0, 0]))));\n"); output.push_str(" return Ok(http::Response::builder().status(200).body(res_body).unwrap());\n"); output.push_str(" }\n"); - output.push_str(" }};\n\n"); + output.push_str(" };\n\n"); output.push_str(" let response_msg = response.into_inner();\n"); output.push_str(" let response_bytes = response_msg.bytes();\n"); output.push_str(" let mut res_buf = pool.get();\n"); diff --git a/examples/hello_world/Cargo.toml b/examples/hello_world/Cargo.toml index 9fb2af8..cd2f19b 100644 --- a/examples/hello_world/Cargo.toml +++ b/examples/hello_world/Cargo.toml @@ -27,3 +27,4 @@ http-body = "1.0" [build-dependencies] tonic-build = "0.12" +roto-codegen = { path = "../../codegen" } diff --git a/examples/hello_world/build.rs b/examples/hello_world/build.rs index 35bd65f..8b39c10 100644 --- a/examples/hello_world/build.rs +++ b/examples/hello_world/build.rs @@ -4,9 +4,9 @@ fn main() { let dest_path = std::path::Path::new(&out_dir).join("hello.rs"); // Find the protoc-gen-roto binary - // In a real scenario, this should be passed as an environment variable or found in PATH - // For this example, we'll try to find it in the target directory - let target_dir = std::env::current_dir().unwrap().join("../../target/debug"); + // Since we added roto-codegen to build-dependencies, it will be built. + let manifest_dir = std::env::var("CARGO_MANIFEST_DIR").unwrap(); + let target_dir = std::path::Path::new(&manifest_dir).join("../../target/debug"); let plugin_path = target_dir.join("protoc-gen-roto"); if !plugin_path.exists() { diff --git a/examples/hello_world/proto/hello.rs b/examples/hello_world/proto/hello.rs new file mode 100644 index 0000000..78f898d --- /dev/null +++ b/examples/hello_world/proto/hello.rs @@ -0,0 +1,310 @@ +// @generated by protoc-gen-roto — do not edit +#[allow(unused_imports)] + +use roto_runtime::{ProtoAccessor, ProtoBuilder, Result, RotoError, read_varint, RepeatedFieldIterator}; +use std::str; +use bytes::{Bytes, BytesMut, Buf, BufMut}; +use tonic::{Request, Response, Status}; +use tokio_stream::Stream; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; +use std::future::Future; +use tonic::body::BoxBody; +use tower::Service; +use futures_util::StreamExt; +use http_body_util::BodyExt; +use http_body::Body; +use roto_tonic::{BufferPool, StatusBody}; + + +pub struct HelloRequest<'a> { + accessor: roto_runtime::ProtoAccessor<'a>, + name_offset: Option, +} + +impl<'a> HelloRequest<'a> { + pub fn new(data: &'a [u8]) -> roto_runtime::Result { + let accessor = roto_runtime::ProtoAccessor::new(data)?; + let mut name_offset = None; + for item in accessor.fields() { + let (offset, tag, _) = item?; + if tag.field_number == 1 { name_offset = Some(offset); } + } + + Ok(Self { + accessor, +name_offset, + }) + } + + pub fn name(&self) -> roto_runtime::Result<&'a str> { + let offset = self.name_offset.ok_or(roto_runtime::RotoError::FieldNotFound)?; + let (bytes, _) = self.accessor.get_value_at(offset)?; + str::from_utf8(bytes).map_err(|_| roto_runtime::RotoError::WireFormatViolation) + } + + pub fn name_or_default(&self) -> roto_runtime::Result<&'a str> { + self.name().or(Ok("")) + } + + pub fn has_name(&self) -> bool { self.name_offset.is_some() } + + pub fn raw_fields(&self) -> roto_runtime::RawFieldIterator<'a> { + self.accessor.raw_fields() + } + +} + +pub struct HelloRequestBuilder<'b> { + builder: roto_runtime::ProtoBuilder<'b>, + name_written: bool, +} + +impl<'b> HelloRequestBuilder<'b> { + pub fn builder(buf: &mut [u8]) -> HelloRequestBuilder<'_> { + HelloRequestBuilder { + builder: roto_runtime::ProtoBuilder::new(buf), + name_written: false, + } + } + + pub fn name(mut self, value: &str) -> roto_runtime::Result { + self.builder.write_string(1, value)?; + self.name_written = true; + Ok(self) + } + + pub fn with(mut self, msg: &HelloRequest<'_>) -> roto_runtime::Result { + for item in msg.raw_fields() { + let (field_number, raw_bytes) = item?; + let is_written = match field_number { + 1 => self.name_written, + _ => false, + }; + if !is_written { + self.builder.write_raw(raw_bytes)?; + } + } + Ok(self) + } + + pub fn finish(self) -> roto_runtime::Result<&'b mut [u8]> { + self.builder.finish() + } +} + +pub struct OwnedHelloRequest { + pub data: bytes::Bytes, +} + +impl roto_runtime::RotoOwned for OwnedHelloRequest { + type Reader<'a> = HelloRequest<'a>; + fn reader(&self) -> HelloRequest<'_> { + HelloRequest::new(&self.data).expect("failed to create reader") + } +} + +impl roto_runtime::RotoMessage for OwnedHelloRequest { + fn decode(buf: bytes::Bytes) -> roto_runtime::Result { + Ok(OwnedHelloRequest { data: buf }) + } + + fn bytes(&self) -> bytes::Bytes { + self.data.clone() + } +} + +pub struct HelloResponse<'a> { + accessor: roto_runtime::ProtoAccessor<'a>, + message_offset: Option, +} + +impl<'a> HelloResponse<'a> { + pub fn new(data: &'a [u8]) -> roto_runtime::Result { + let accessor = roto_runtime::ProtoAccessor::new(data)?; + let mut message_offset = None; + for item in accessor.fields() { + let (offset, tag, _) = item?; + if tag.field_number == 1 { message_offset = Some(offset); } + } + + Ok(Self { + accessor, +message_offset, + }) + } + + pub fn message(&self) -> roto_runtime::Result<&'a str> { + let offset = self.message_offset.ok_or(roto_runtime::RotoError::FieldNotFound)?; + let (bytes, _) = self.accessor.get_value_at(offset)?; + str::from_utf8(bytes).map_err(|_| roto_runtime::RotoError::WireFormatViolation) + } + + pub fn message_or_default(&self) -> roto_runtime::Result<&'a str> { + self.message().or(Ok("")) + } + + pub fn has_message(&self) -> bool { self.message_offset.is_some() } + + pub fn raw_fields(&self) -> roto_runtime::RawFieldIterator<'a> { + self.accessor.raw_fields() + } + +} + +pub struct HelloResponseBuilder<'b> { + builder: roto_runtime::ProtoBuilder<'b>, + message_written: bool, +} + +impl<'b> HelloResponseBuilder<'b> { + pub fn builder(buf: &mut [u8]) -> HelloResponseBuilder<'_> { + HelloResponseBuilder { + builder: roto_runtime::ProtoBuilder::new(buf), + message_written: false, + } + } + + pub fn message(mut self, value: &str) -> roto_runtime::Result { + self.builder.write_string(1, value)?; + self.message_written = true; + Ok(self) + } + + pub fn with(mut self, msg: &HelloResponse<'_>) -> roto_runtime::Result { + for item in msg.raw_fields() { + let (field_number, raw_bytes) = item?; + let is_written = match field_number { + 1 => self.message_written, + _ => false, + }; + if !is_written { + self.builder.write_raw(raw_bytes)?; + } + } + Ok(self) + } + + pub fn finish(self) -> roto_runtime::Result<&'b mut [u8]> { + self.builder.finish() + } +} + +pub struct OwnedHelloResponse { + pub data: bytes::Bytes, +} + +impl roto_runtime::RotoOwned for OwnedHelloResponse { + type Reader<'a> = HelloResponse<'a>; + fn reader(&self) -> HelloResponse<'_> { + HelloResponse::new(&self.data).expect("failed to create reader") + } +} + +impl roto_runtime::RotoMessage for OwnedHelloResponse { + fn decode(buf: bytes::Bytes) -> roto_runtime::Result { + Ok(OwnedHelloResponse { data: buf }) + } + + fn bytes(&self) -> bytes::Bytes { + self.data.clone() + } +} + +#[tonic::async_trait] +pub trait HelloWorldService: Send + Sync + 'static { + async fn hello_world(&self, request: Request) -> std::result::Result, Status>; +} + +pub struct HelloWorldServiceServer { + inner: Arc, + pool: Arc, +} + +impl HelloWorldServiceServer { + pub fn new(inner: Arc, pool: Arc) -> Self { + Self { inner, pool } + } +} + +impl tonic::server::NamedService for HelloWorldServiceServer { + const NAME: &'static str = "HelloWorldService"; +} + +impl Service> for HelloWorldServiceServer { + type Response = http::Response; + type Error = std::convert::Infallible; + type Future = Pin> + Send>>; + + fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: http::Request) -> Self::Future { + let inner = self.inner.clone(); + let pool = self.pool.clone(); + Box::pin(async move { + let body = req.into_body(); + let mut buf = pool.get(); + let mut stream = body; + while let Some(frame_result) = stream.frame().await { + let frame = frame_result.map_err(|e| { + panic!("Body frame error: {}", e); + })?; + if let Some(data) = frame.data_ref() { + buf.put(data.clone()); + } + } + + let total_len = buf.len(); + let bytes_vec = buf.split_to(total_len).freeze(); + pool.put(buf); + if bytes_vec.len() < 5 { + let res_body = BoxBody::new(StatusBody(Some(Bytes::from_static(&[0, 0, 0, 0, 0])))); + return Ok(http::Response::builder().status(200).body(res_body).unwrap()); + } + + let payload = bytes_vec.slice(5..); + let path = req.uri().path(); + let mut routed = false; + + if path == "/HelloWorldService/hello_world" { + let request_msg = match OwnedHelloRequest::decode(payload) { + Ok(msg) => msg, + Err(e) => { + let res_body = BoxBody::new(StatusBody(Some(Bytes::from_static(&[0, 0, 0, 0, 0])))); + return Ok(http::Response::builder().status(200).body(res_body).unwrap()); + } + }}; + + let response = match inner.hello_world(Request::new(request_msg)).await { + Ok(res) => res, + Err(e) => { + let res_body = BoxBody::new(StatusBody(Some(Bytes::from_static(&[0, 0, 0, 0, 0])))); + return Ok(http::Response::builder().status(200).body(res_body).unwrap()); + } + }}; + + let response_msg = response.into_inner(); + let response_bytes = response_msg.bytes(); + let mut res_buf = pool.get(); + res_buf.put_u8(0); + let len = response_bytes.len() as u32; + res_buf.put_slice(&len.to_be_bytes()); + res_buf.put_slice(&response_bytes); + let frame_len = res_buf.len(); + let frame = res_buf.split_to(frame_len).freeze(); + pool.put(res_buf); + let res_body = BoxBody::new(StatusBody(Some(frame))); + routed = true; + return Ok(http::Response::builder().status(200).header("content-type", "application/grpc").body(res_body).unwrap()); + } + if !routed { + let res_body = BoxBody::new(StatusBody(Some(Bytes::from_static(&[0, 0, 0, 0, 0])))); + return Ok(http::Response::builder().status(200).body(res_body).unwrap()); + } + Ok(http::Response::builder().status(200).body(BoxBody::new(StatusBody(None))).unwrap()) + }) + } +} diff --git a/roto-tonic/src/lib.rs b/roto-tonic/src/lib.rs index 46168e3..d795267 100644 --- a/roto-tonic/src/lib.rs +++ b/roto-tonic/src/lib.rs @@ -101,7 +101,7 @@ impl BufferPool { } } -pub struct StatusBody(pub(crate) Option); +pub struct StatusBody(pub Option); impl Body for StatusBody { type Data = Bytes;