From 117cbf812b62c485ea4f614e5a89a7f846e4138d Mon Sep 17 00:00:00 2001 From: charles Date: Tue, 19 May 2026 21:55:18 -0700 Subject: [PATCH] Introduce alloc feature for optional allocation Wrap heap-allocated types and service generation in the `alloc` feature flag to support environments without a memory allocator. --- Cargo.lock | 1 + codegen/src/generator.rs | 150 ++++++++++++++++++++++------ examples/hello_world/Cargo.toml | 5 + examples/no_std_test/src/main.rs | 18 +++- roto-tonic/Cargo.toml | 4 + roto-tonic/src/generated/interop.rs | 70 +++++++++---- runtime/src/lib.rs | 5 + 7 files changed, 201 insertions(+), 52 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 6f911fb..234e79c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -562,6 +562,7 @@ checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" name = "hello-world" version = "0.1.0" dependencies = [ + "async-trait", "bytes", "futures-util", "http", diff --git a/codegen/src/generator.rs b/codegen/src/generator.rs index ffdcec9..0235654 100644 --- a/codegen/src/generator.rs +++ b/codegen/src/generator.rs @@ -1,13 +1,26 @@ use crate::google::protobuf::descriptor::{ DescriptorProto, EnumDescriptorProto, FieldDescriptorProto, FileDescriptorProto, - FileDescriptorSet, MessageOptions, MethodDescriptorProto, OneofDescriptorProto, ServiceDescriptorProto, + FileDescriptorSet, MessageOptions, MethodDescriptorProto, OneofDescriptorProto, + ServiceDescriptorProto, }; use roto_runtime::ProtoAccessor; use std::collections::{HashMap, HashSet}; use std::str; const DATA_IMPORTS: &str = "use roto_runtime::{ProtoAccessor, ProtoBuilder, Result, RotoError, read_varint, RepeatedFieldIterator, RotoMessage};\nuse core::str;\n#[cfg(feature = \"alloc\")]\nuse bytes::{Bytes, BytesMut, Buf, BufMut};\n"; -const SERVICE_IMPORTS: &str = "use tonic::{Request, Response, Status};\nuse tokio_stream::Stream;\nuse std::pin::Pin;\nuse std::sync::Arc;\nuse std::task::{Context, Poll};\nuse std::future::Future;\nuse tonic::body::BoxBody;\nuse tower::Service;\nuse futures_util::StreamExt;\nuse http_body_util::BodyExt;\nuse http_body::Body;\nuse crate::{BufferPool, StatusBody};\n"; +const SERVICE_IMPORTS: &str = + "#[cfg(feature = \"alloc\")]\nuse tonic::{Request, Response, Status};\n\ + #[cfg(feature = \"alloc\")]\nuse tokio_stream::Stream;\n\ + #[cfg(feature = \"alloc\")]\nuse std::pin::Pin;\n\ + #[cfg(feature = \"alloc\")]\nuse std::sync::Arc;\n\ + #[cfg(feature = \"alloc\")]\nuse std::task::{Context, Poll};\n\ + #[cfg(feature = \"alloc\")]\nuse std::future::Future;\n\ + #[cfg(feature = \"alloc\")]\nuse tonic::body::BoxBody;\n\ + #[cfg(feature = \"alloc\")]\nuse tower::Service;\n\ + #[cfg(feature = \"alloc\")]\nuse futures_util::StreamExt;\n\ + #[cfg(feature = \"alloc\")]\nuse http_body_util::BodyExt;\n\ + #[cfg(feature = \"alloc\")]\nuse http_body::Body;\n\ + #[cfg(feature = \"alloc\")]\nuse crate::{BufferPool, StatusBody};\n"; pub fn to_pascal_case(s: &str) -> String { s.split('_') @@ -36,7 +49,11 @@ pub fn to_snake_case(s: &str) -> String { result } -fn map_type_to_rust_accessor(field_type: i32, label: i32, is_map: bool) -> (String, String, String) { +fn map_type_to_rust_accessor( + field_type: i32, + label: i32, + is_map: bool, +) -> (String, String, String) { if label == 3 { // LABEL_REPEATED let iterator_type = if is_map { @@ -314,7 +331,7 @@ fn write_message(msg_proto: &DescriptorProto, output: &mut String) { output.push_str(&format!( " pub fn {}_or_default(&self) -> roto_runtime::Result<{}> {{\n", - safe_name, rust_type + safe_name, rust_type )); output.push_str(&format!( " self.{}().or(Ok({}))\n", @@ -440,18 +457,30 @@ fn write_message(msg_proto: &DescriptorProto, output: &mut String) { output.push_str(&format!(" pub fn finish(self) -> roto_runtime::Result<&'b mut [u8]> {{\n self.builder.finish()\n }}\n}}\n\n")); - output.push_str(&format!("#[cfg(feature = \"alloc\")]\npub struct Owned{} {{\n", msg_name)); + output.push_str(&format!( + "#[cfg(feature = \"alloc\")]\npub struct Owned{} {{\n", + msg_name + )); output.push_str(" pub data: bytes::Bytes,\n"); output.push_str("}\n\n"); - output.push_str(&format!("#[cfg(feature = \"alloc\")]\nimpl roto_runtime::RotoOwned for Owned{} {{\n", msg_name)); + output.push_str(&format!( + "#[cfg(feature = \"alloc\")]\nimpl roto_runtime::RotoOwned for Owned{} {{\n", + msg_name + )); output.push_str(&format!(" type Reader<'a> = {}<'a>;\n", msg_name)); output.push_str(&format!(" fn reader(&self) -> {}<'_> {{\n", msg_name)); - output.push_str(&format!(" {}::new(&self.data).expect(\"failed to create reader\")\n", msg_name)); + output.push_str(&format!( + " {}::new(&self.data).expect(\"failed to create reader\")\n", + msg_name + )); output.push_str(" }\n"); output.push_str("}\n\n"); - output.push_str(&format!("#[cfg(feature = \"alloc\")]\nimpl roto_runtime::RotoMessage for Owned{} {{\n", msg_name)); + output.push_str(&format!( + "#[cfg(feature = \"alloc\")]\nimpl roto_runtime::RotoMessage for Owned{} {{\n", + msg_name + )); output.push_str(" fn decode(buf: bytes::Bytes) -> roto_runtime::Result {\n"); output.push_str(&format!(" Ok(Owned{} {{ data: buf }})\n", msg_name)); output.push_str(" }\n\n"); @@ -551,7 +580,14 @@ where } } - let rust_file_name = format!("{}.rs", std::path::Path::new(proto_name).file_stem().unwrap().to_str().unwrap()); + let rust_file_name = format!( + "{}.rs", + std::path::Path::new(proto_name) + .file_stem() + .unwrap() + .to_str() + .unwrap() + ); let mut output = String::new(); output.push_str("// @generated by protoc-gen-roto — do not edit\n"); @@ -639,7 +675,8 @@ pub fn generate_protobuf_code( for enum_res in file_proto.enum_type() { let (enum_data, _) = enum_res.expect("Failed to iterate enum"); write_enum( - &EnumDescriptorProto::new(enum_data).expect("Failed to parse EnumDescriptorProto"), + &EnumDescriptorProto::new(enum_data) + .expect("Failed to parse EnumDescriptorProto"), output, ); } @@ -672,7 +709,8 @@ pub fn generate_service_code( for svc_res in file_proto.service() { let (svc_data, _) = svc_res.expect("Failed to iterate service"); write_service( - &ServiceDescriptorProto::new(svc_data).expect("Failed to parse ServiceDescriptorProto"), + &ServiceDescriptorProto::new(svc_data) + .expect("Failed to parse ServiceDescriptorProto"), &package, output, ); @@ -709,13 +747,7 @@ pub fn generate_rust_code( result.sort_by(|a, b| a.0.cmp(&b.0)); if generate_mod_files { - let mods = generate_files_common( - set, - files_to_generate, - true, - "", - |_, _| {}, - ); + let mods = generate_files_common(set, files_to_generate, true, "", |_, _| {}); for (filename, content) in mods { if filename == "mod.rs" || filename.contains("/mod.rs") { result.push((filename, content)); @@ -732,12 +764,18 @@ fn strip_boilerplate(content: &str) -> String { fn write_service(svc_proto: &ServiceDescriptorProto, package: &str, output: &mut String) { output.push_str(SERVICE_IMPORTS); + output.push_str("\n"); let svc_name = to_pascal_case(svc_proto.name().unwrap()); - output.push_str(&format!("#[async_trait::async_trait]\npub trait {}: Send + Sync + 'static {{\n", svc_name)); + output.push_str("#[cfg(feature = \"alloc\")]\n"); + output.push_str(&format!( + "#[async_trait::async_trait]\npub trait {}: Send + Sync + 'static {{\n", + svc_name + )); for method_res in svc_proto.method() { let (method_data, _) = method_res.expect("Failed to iterate method"); - let method_proto = MethodDescriptorProto::new(method_data).expect("Failed to parse MethodDescriptorProto"); + let method_proto = + MethodDescriptorProto::new(method_data).expect("Failed to parse MethodDescriptorProto"); let method_name = to_snake_case(method_proto.name().unwrap()); let input_full_name = method_proto.input_type().unwrap(); @@ -759,7 +797,10 @@ fn write_service(svc_proto: &ServiceDescriptorProto, package: &str, output: &mut }; let resp_type = if server_streaming { - format!("Response> + Send>>>", output_owned) + format!( + "Response> + Send>>>", + output_owned + ) } else { format!("Response<{}>", output_owned) }; @@ -772,27 +813,46 @@ fn write_service(svc_proto: &ServiceDescriptorProto, package: &str, output: &mut output.push_str("}\n\n"); let server_name = format!("{}Server", svc_name); - output.push_str(&format!("#[derive(Clone)]\npub struct {} {{\n", server_name)); + output.push_str("#[cfg(feature = \"alloc\")]\n"); + output.push_str(&format!( + "#[derive(Clone)]\npub struct {} {{\n", + server_name + )); output.push_str(&format!(" inner: Arc,\n", svc_name)); output.push_str(" pool: Arc,\n"); output.push_str("}\n\n"); + output.push_str("#[cfg(feature = \"alloc\")]\n"); output.push_str(&format!("impl {} {{\n", server_name)); - output.push_str(&format!(" pub fn new(inner: Arc, pool: Arc) -> Self {{\n", svc_name)); + output.push_str(&format!( + " pub fn new(inner: Arc, pool: Arc) -> Self {{\n", + svc_name + )); output.push_str(" Self { inner, pool }\n"); output.push_str(" }\n"); output.push_str("}\n\n"); - output.push_str(&format!("impl tonic::server::NamedService for {} {{\n", server_name)); + output.push_str("#[cfg(feature = \"alloc\")]\n"); + output.push_str(&format!( + "impl tonic::server::NamedService for {} {{\n", + server_name + )); let full_svc_name = if package.is_empty() { svc_proto.name().unwrap().to_string() } else { format!("{}.{}", package, svc_proto.name().unwrap()) }; - output.push_str(&format!(" const NAME: &'static str = \"{}\";\n", full_svc_name)); + output.push_str(&format!( + " const NAME: &'static str = \"{}\";\n", + full_svc_name + )); output.push_str("}\n\n"); - output.push_str(&format!("impl Service> for {} {{\n", server_name)); + output.push_str("#[cfg(feature = \"alloc\")]\n"); + output.push_str(&format!( + "impl Service> for {} {{\n", + server_name + )); output.push_str(" type Response = http::Response;\n"); output.push_str(" type Error = std::convert::Infallible;\n"); output.push_str(" type Future = Pin> + Send>>;\n\n"); @@ -829,14 +889,20 @@ fn write_service(svc_proto: &ServiceDescriptorProto, package: &str, output: &mut let mut methods = Vec::new(); for method_res in svc_proto.method() { let (method_data, _) = method_res.expect("Failed to iterate method"); - let method_proto = MethodDescriptorProto::new(method_data).expect("Failed to parse MethodDescriptorProto"); + let method_proto = + MethodDescriptorProto::new(method_data).expect("Failed to parse MethodDescriptorProto"); let original_method_name = method_proto.name().unwrap().to_string(); let method_name = to_snake_case(&original_method_name); let input_full_name = method_proto.input_type().unwrap(); let input_type = input_full_name.split('.').last().unwrap(); let input_owned = format!("Owned{}", input_type); let server_streaming = method_proto.server_streaming().unwrap_or(false); - methods.push((original_method_name, method_name, input_owned, server_streaming)); + methods.push(( + original_method_name, + method_name, + input_owned, + server_streaming, + )); } for (original_method_name, method_name, input_owned, server_streaming) in methods { @@ -846,7 +912,12 @@ fn write_service(svc_proto: &ServiceDescriptorProto, package: &str, output: &mut let full_path = if package.is_empty() { format!("/{}/{}", svc_proto.name().unwrap(), original_method_name) } else { - format!("/{}.{}/{}", package, svc_proto.name().unwrap(), original_method_name) + format!( + "/{}.{}/{}", + package, + svc_proto.name().unwrap(), + original_method_name + ) }; output.push_str(&format!(" if path == \"{}\" {{\n", full_path)); output.push_str(" let res_body = BoxBody::new(StatusBody::new(Some(Bytes::from_static(&[0, 0, 0, 0, 0])), 0));\n"); @@ -857,17 +928,28 @@ fn write_service(svc_proto: &ServiceDescriptorProto, package: &str, output: &mut let full_path = if package.is_empty() { format!("/{}/{}", svc_proto.name().unwrap(), original_method_name) } else { - format!("/{}.{}/{}", package, svc_proto.name().unwrap(), original_method_name) + format!( + "/{}.{}/{}", + package, + svc_proto.name().unwrap(), + original_method_name + ) }; output.push_str(&format!(" if path == \"{}\" {{\n", full_path)); - output.push_str(&format!(" let request_msg = match {}::decode(payload) {{\n", input_owned)); + output.push_str(&format!( + " let request_msg = match {}::decode(payload) {{\n", + input_owned + )); output.push_str(" Ok(msg) => msg,\n"); output.push_str(" Err(e) => {\n"); output.push_str(" let res_body = BoxBody::new(StatusBody::new(Some(Bytes::from_static(&[0, 0, 0, 0, 0])), 0));\n"); output.push_str(" return Ok(http::Response::builder().status(200).body(res_body).unwrap());\n"); output.push_str(" }\n"); output.push_str(" };\n\n"); - output.push_str(&format!(" let response = match inner.{}(Request::new(request_msg)).await {{\n", method_name)); + output.push_str(&format!( + " let response = match inner.{}(Request::new(request_msg)).await {{\n", + method_name + )); output.push_str(" Ok(res) => res,\n"); output.push_str(" Err(e) => {\n"); output.push_str(" let res_body = BoxBody::new(StatusBody::new(Some(Bytes::from_static(&[0, 0, 0, 0, 0])), 0));\n"); @@ -884,7 +966,9 @@ fn write_service(svc_proto: &ServiceDescriptorProto, package: &str, output: &mut output.push_str(" let frame_len = res_buf.len();\n"); output.push_str(" let frame = res_buf.split_to(frame_len).freeze();\n"); output.push_str(" pool.put(res_buf);\n"); - output.push_str(" let res_body = BoxBody::new(StatusBody::new(Some(frame), 0));\n"); + output.push_str( + " let res_body = BoxBody::new(StatusBody::new(Some(frame), 0));\n", + ); output.push_str(" routed = true;\n"); output.push_str(" return Ok(http::Response::builder().status(200).header(\"content-type\", \"application/grpc\").body(res_body).unwrap());\n"); output.push_str(" }\n"); diff --git a/examples/hello_world/Cargo.toml b/examples/hello_world/Cargo.toml index cd2f19b..9fdfd50 100644 --- a/examples/hello_world/Cargo.toml +++ b/examples/hello_world/Cargo.toml @@ -24,7 +24,12 @@ futures-util = "0.3" http-body-util = "0.1" http = "1.1" http-body = "1.0" +async-trait = "0.1" [build-dependencies] tonic-build = "0.12" roto-codegen = { path = "../../codegen" } + +[features] +default = ["alloc"] +alloc = [] diff --git a/examples/no_std_test/src/main.rs b/examples/no_std_test/src/main.rs index 1cf9cf9..fb90d7c 100644 --- a/examples/no_std_test/src/main.rs +++ b/examples/no_std_test/src/main.rs @@ -1,5 +1,5 @@ #![no_std] -#![no_main] +#![cfg_attr(not(test), no_main)] mod helloworld; @@ -8,10 +8,11 @@ extern crate alloc; use roto_runtime::{ProtoAccessor, RotoMessage, RotoOwned}; -#[cfg(feature = "alloc")] +#[cfg(all(feature = "alloc", not(test)))] #[global_allocator] static ALLOCATOR: embedded_alloc::Heap = embedded_alloc::Heap::empty(); +#[cfg(not(test))] #[panic_handler] fn panic(_info: &core::panic::PanicInfo) -> ! { loop {} @@ -27,6 +28,7 @@ pub extern "C" fn _critical_section_1_0_release() {} static HELLO_DATA: &[u8] = &[0x0A, 0x05, 0x57, 0x6f, 0x72, 0x6c, 0x64]; +#[cfg(all(not(test), not(feature = "alloc")))] #[unsafe(no_mangle)] pub extern "C" fn _start() -> ! { #[cfg(not(feature = "alloc"))] @@ -58,3 +60,15 @@ pub extern "C" fn _start() -> ! { loop {} } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_helloworld_decoding() { + let hello = helloworld::Hello::new(HELLO_DATA).expect("failed to decode hello"); + let name = hello.name().expect("failed to get name"); + assert!(!name.is_empty()); + } +} diff --git a/roto-tonic/Cargo.toml b/roto-tonic/Cargo.toml index b8386cd..7b74a3c 100644 --- a/roto-tonic/Cargo.toml +++ b/roto-tonic/Cargo.toml @@ -19,3 +19,7 @@ http = "1.1" [build-dependencies] tonic-build = "0.12" + +[features] +default = ["alloc"] +alloc = [] diff --git a/roto-tonic/src/generated/interop.rs b/roto-tonic/src/generated/interop.rs index b2ce666..4429bfb 100644 --- a/roto-tonic/src/generated/interop.rs +++ b/roto-tonic/src/generated/interop.rs @@ -2,19 +2,9 @@ #[allow(unused_imports)] use roto_runtime::{ProtoAccessor, ProtoBuilder, Result, RotoError, read_varint, RepeatedFieldIterator, RotoMessage}; -use std::str; +use core::str; +#[cfg(feature = "alloc")] use bytes::{Bytes, BytesMut, Buf, BufMut}; -use tonic::{Request, Response, Status}; -use std::sync::Arc; -use std::pin::Pin; -use std::task::{Context, Poll}; -use std::future::Future; -use tower::Service; -use tonic::body::BoxBody; -use tokio_stream::Stream; -use crate::{BufferPool, StatusBody}; -use async_trait::async_trait; -use http_body_util::BodyExt; pub struct UnaryRequest<'a> { accessor: roto_runtime::ProtoAccessor<'a>, @@ -39,7 +29,7 @@ message_offset, pub fn message(&self) -> roto_runtime::Result<&'a str> { let offset = self.message_offset.ok_or(roto_runtime::RotoError::FieldNotFound)?; let (bytes, _) = self.accessor.get_value_at(offset)?; - std::str::from_utf8(bytes).map_err(|_| roto_runtime::RotoError::WireFormatViolation) + core::str::from_utf8(bytes).map_err(|_| roto_runtime::RotoError::WireFormatViolation) } pub fn message_or_default(&self) -> roto_runtime::Result<&'a str> { @@ -92,10 +82,12 @@ impl<'b> UnaryRequestBuilder<'b> { } } +#[cfg(feature = "alloc")] pub struct OwnedUnaryRequest { pub data: bytes::Bytes, } +#[cfg(feature = "alloc")] impl roto_runtime::RotoOwned for OwnedUnaryRequest { type Reader<'a> = UnaryRequest<'a>; fn reader(&self) -> UnaryRequest<'_> { @@ -103,6 +95,7 @@ impl roto_runtime::RotoOwned for OwnedUnaryRequest { } } +#[cfg(feature = "alloc")] impl roto_runtime::RotoMessage for OwnedUnaryRequest { fn decode(buf: bytes::Bytes) -> roto_runtime::Result { Ok(OwnedUnaryRequest { data: buf }) @@ -136,7 +129,7 @@ reply_offset, pub fn reply(&self) -> roto_runtime::Result<&'a str> { let offset = self.reply_offset.ok_or(roto_runtime::RotoError::FieldNotFound)?; let (bytes, _) = self.accessor.get_value_at(offset)?; - std::str::from_utf8(bytes).map_err(|_| roto_runtime::RotoError::WireFormatViolation) + core::str::from_utf8(bytes).map_err(|_| roto_runtime::RotoError::WireFormatViolation) } pub fn reply_or_default(&self) -> roto_runtime::Result<&'a str> { @@ -189,10 +182,12 @@ impl<'b> UnaryResponseBuilder<'b> { } } +#[cfg(feature = "alloc")] pub struct OwnedUnaryResponse { pub data: bytes::Bytes, } +#[cfg(feature = "alloc")] impl roto_runtime::RotoOwned for OwnedUnaryResponse { type Reader<'a> = UnaryResponse<'a>; fn reader(&self) -> UnaryResponse<'_> { @@ -200,6 +195,7 @@ impl roto_runtime::RotoOwned for OwnedUnaryResponse { } } +#[cfg(feature = "alloc")] impl roto_runtime::RotoMessage for OwnedUnaryResponse { fn decode(buf: bytes::Bytes) -> roto_runtime::Result { Ok(OwnedUnaryResponse { data: buf }) @@ -233,7 +229,7 @@ query_offset, pub fn query(&self) -> roto_runtime::Result<&'a str> { let offset = self.query_offset.ok_or(roto_runtime::RotoError::FieldNotFound)?; let (bytes, _) = self.accessor.get_value_at(offset)?; - std::str::from_utf8(bytes).map_err(|_| roto_runtime::RotoError::WireFormatViolation) + core::str::from_utf8(bytes).map_err(|_| roto_runtime::RotoError::WireFormatViolation) } pub fn query_or_default(&self) -> roto_runtime::Result<&'a str> { @@ -286,10 +282,12 @@ impl<'b> StreamingRequestBuilder<'b> { } } +#[cfg(feature = "alloc")] pub struct OwnedStreamingRequest { pub data: bytes::Bytes, } +#[cfg(feature = "alloc")] impl roto_runtime::RotoOwned for OwnedStreamingRequest { type Reader<'a> = StreamingRequest<'a>; fn reader(&self) -> StreamingRequest<'_> { @@ -297,6 +295,7 @@ impl roto_runtime::RotoOwned for OwnedStreamingRequest { } } +#[cfg(feature = "alloc")] impl roto_runtime::RotoMessage for OwnedStreamingRequest { fn decode(buf: bytes::Bytes) -> roto_runtime::Result { Ok(OwnedStreamingRequest { data: buf }) @@ -330,7 +329,7 @@ item_offset, pub fn item(&self) -> roto_runtime::Result<&'a str> { let offset = self.item_offset.ok_or(roto_runtime::RotoError::FieldNotFound)?; let (bytes, _) = self.accessor.get_value_at(offset)?; - std::str::from_utf8(bytes).map_err(|_| roto_runtime::RotoError::WireFormatViolation) + core::str::from_utf8(bytes).map_err(|_| roto_runtime::RotoError::WireFormatViolation) } pub fn item_or_default(&self) -> roto_runtime::Result<&'a str> { @@ -383,10 +382,12 @@ impl<'b> StreamingResponseBuilder<'b> { } } +#[cfg(feature = "alloc")] pub struct OwnedStreamingResponse { pub data: bytes::Bytes, } +#[cfg(feature = "alloc")] impl roto_runtime::RotoOwned for OwnedStreamingResponse { type Reader<'a> = StreamingResponse<'a>; fn reader(&self) -> StreamingResponse<'_> { @@ -394,6 +395,7 @@ impl roto_runtime::RotoOwned for OwnedStreamingResponse { } } +#[cfg(feature = "alloc")] impl roto_runtime::RotoMessage for OwnedStreamingResponse { fn decode(buf: bytes::Bytes) -> roto_runtime::Result { Ok(OwnedStreamingResponse { data: buf }) @@ -405,28 +407,62 @@ impl roto_runtime::RotoMessage for OwnedStreamingResponse { } -#[async_trait] +// @generated by protoc-gen-roto — do not edit +#[allow(unused_imports)] + + +#[cfg(feature = "alloc")] +use tonic::{Request, Response, Status}; +#[cfg(feature = "alloc")] +use tokio_stream::Stream; +#[cfg(feature = "alloc")] +use std::pin::Pin; +#[cfg(feature = "alloc")] +use std::sync::Arc; +#[cfg(feature = "alloc")] +use std::task::{Context, Poll}; +#[cfg(feature = "alloc")] +use std::future::Future; +#[cfg(feature = "alloc")] +use tonic::body::BoxBody; +#[cfg(feature = "alloc")] +use tower::Service; +#[cfg(feature = "alloc")] +use futures_util::StreamExt; +#[cfg(feature = "alloc")] +use http_body_util::BodyExt; +#[cfg(feature = "alloc")] +use http_body::Body; +#[cfg(feature = "alloc")] +use crate::{BufferPool, StatusBody}; + +#[cfg(feature = "alloc")] +#[async_trait::async_trait] pub trait InteropService: Send + Sync + 'static { async fn unary_call(&self, request: Request) -> std::result::Result, Status>; async fn streaming_call(&self, request: Request) -> std::result::Result> + Send>>>, Status>; } +#[cfg(feature = "alloc")] #[derive(Clone)] pub struct InteropServiceServer { inner: Arc, pool: Arc, } +#[cfg(feature = "alloc")] impl InteropServiceServer { pub fn new(inner: Arc, pool: Arc) -> Self { Self { inner, pool } } } +#[cfg(feature = "alloc")] impl tonic::server::NamedService for InteropServiceServer { const NAME: &'static str = "interop.InteropService"; } +#[cfg(feature = "alloc")] impl Service> for InteropServiceServer { type Response = http::Response; type Error = std::convert::Infallible; diff --git a/runtime/src/lib.rs b/runtime/src/lib.rs index 56a902b..7fdacfd 100644 --- a/runtime/src/lib.rs +++ b/runtime/src/lib.rs @@ -1,5 +1,8 @@ #![no_std] +#[cfg(feature = "alloc")] +extern crate alloc; + #[cfg(feature = "std")] extern crate std; @@ -438,6 +441,8 @@ impl<'a> Iterator for RawFieldIterator<'a> { #[cfg(test)] mod tests { use super::*; + #[cfg(feature = "alloc")] + use alloc::{vec, vec::{Vec}}; #[test] fn test_varint_read_write() {