Update StatusBody to support gRPC trailers

Change StatusBody from a tuple struct to a struct containing both data
and trailers. Update the codegen to use the new StatusBody::new
constructor to specify gRPC status codes.

Also remove the temp_test_project.
This commit is contained in:
2026-05-15 18:57:15 -07:00
parent db89c9842a
commit 809a0d844c
5 changed files with 26 additions and 26 deletions
Generated
+1 -17
View File
@@ -1196,6 +1196,7 @@ version = "0.1.0"
dependencies = [ dependencies = [
"bytes", "bytes",
"futures-util", "futures-util",
"http",
"http-body", "http-body",
"http-body-util", "http-body-util",
"prost", "prost",
@@ -1352,23 +1353,6 @@ version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0bf256ce5efdfa370213c1dabab5935a12e49f2c58d15e9eac2870d3b4f27263" checksum = "0bf256ce5efdfa370213c1dabab5935a12e49f2c58d15e9eac2870d3b4f27263"
[[package]]
name = "temp_test_project"
version = "0.1.0"
dependencies = [
"bytes",
"futures-util",
"http",
"http-body",
"http-body-util",
"roto-codegen",
"roto-runtime",
"roto-tonic",
"tokio-stream",
"tonic",
"tower 0.4.13",
]
[[package]] [[package]]
name = "tempfile" name = "tempfile"
version = "3.27.0" version = "3.27.0"
-1
View File
@@ -5,7 +5,6 @@ members = [
"protos", "protos",
"benches", "benches",
"roto-tonic", "roto-tonic",
"temp_test_project",
"examples/hello_world", "examples/hello_world",
] ]
+6 -6
View File
@@ -738,7 +738,7 @@ fn write_service(svc_proto: &ServiceDescriptorProto, output: &mut String) {
output.push_str(" let bytes_vec = buf.split_to(total_len).freeze();\n"); output.push_str(" let bytes_vec = buf.split_to(total_len).freeze();\n");
output.push_str(" pool.put(buf);\n"); output.push_str(" pool.put(buf);\n");
output.push_str(" if bytes_vec.len() < 5 {\n"); output.push_str(" if bytes_vec.len() < 5 {\n");
output.push_str(" let res_body = BoxBody::new(StatusBody(Some(Bytes::from_static(&[0, 0, 0, 0, 0]))));\n"); output.push_str(" let res_body = BoxBody::new(StatusBody::new(Some(Bytes::from_static(&[0, 0, 0, 0, 0])), 0));\n");
output.push_str(" return Ok(http::Response::builder().status(200).body(res_body).unwrap());\n"); output.push_str(" return Ok(http::Response::builder().status(200).body(res_body).unwrap());\n");
output.push_str(" }\n\n"); output.push_str(" }\n\n");
output.push_str(" let payload = bytes_vec.slice(5..);\n"); output.push_str(" let payload = bytes_vec.slice(5..);\n");
@@ -760,14 +760,14 @@ fn write_service(svc_proto: &ServiceDescriptorProto, output: &mut String) {
output.push_str(&format!(" let request_msg = match {}::decode(payload) {{\n", input_owned)); output.push_str(&format!(" let request_msg = match {}::decode(payload) {{\n", input_owned));
output.push_str(" Ok(msg) => msg,\n"); output.push_str(" Ok(msg) => msg,\n");
output.push_str(" Err(e) => {\n"); output.push_str(" Err(e) => {\n");
output.push_str(" let res_body = BoxBody::new(StatusBody(Some(Bytes::from_static(&[0, 0, 0, 0, 0]))));\n"); output.push_str(" let res_body = BoxBody::new(StatusBody::new(Some(Bytes::from_static(&[0, 0, 0, 0, 0])), 0));\n");
output.push_str(" return Ok(http::Response::builder().status(200).body(res_body).unwrap());\n"); output.push_str(" return Ok(http::Response::builder().status(200).body(res_body).unwrap());\n");
output.push_str(" }\n"); output.push_str(" }\n");
output.push_str(" };\n\n"); output.push_str(" };\n\n");
output.push_str(&format!(" let response = match inner.{}(Request::new(request_msg)).await {{\n", method_name)); output.push_str(&format!(" let response = match inner.{}(Request::new(request_msg)).await {{\n", method_name));
output.push_str(" Ok(res) => res,\n"); output.push_str(" Ok(res) => res,\n");
output.push_str(" Err(e) => {\n"); output.push_str(" Err(e) => {\n");
output.push_str(" let res_body = BoxBody::new(StatusBody(Some(Bytes::from_static(&[0, 0, 0, 0, 0]))));\n"); output.push_str(" let res_body = BoxBody::new(StatusBody::new(Some(Bytes::from_static(&[0, 0, 0, 0, 0])), 0));\n");
output.push_str(" return Ok(http::Response::builder().status(200).body(res_body).unwrap());\n"); output.push_str(" return Ok(http::Response::builder().status(200).body(res_body).unwrap());\n");
output.push_str(" }\n"); output.push_str(" }\n");
output.push_str(" };\n\n"); output.push_str(" };\n\n");
@@ -781,17 +781,17 @@ fn write_service(svc_proto: &ServiceDescriptorProto, output: &mut String) {
output.push_str(" let frame_len = res_buf.len();\n"); output.push_str(" let frame_len = res_buf.len();\n");
output.push_str(" let frame = res_buf.split_to(frame_len).freeze();\n"); output.push_str(" let frame = res_buf.split_to(frame_len).freeze();\n");
output.push_str(" pool.put(res_buf);\n"); output.push_str(" pool.put(res_buf);\n");
output.push_str(" let res_body = BoxBody::new(StatusBody(Some(frame)));\n"); output.push_str(" let res_body = BoxBody::new(StatusBody::new(Some(frame), 0));\n");
output.push_str(" routed = true;\n"); output.push_str(" routed = true;\n");
output.push_str(" return Ok(http::Response::builder().status(200).header(\"content-type\", \"application/grpc\").body(res_body).unwrap());\n"); output.push_str(" return Ok(http::Response::builder().status(200).header(\"content-type\", \"application/grpc\").body(res_body).unwrap());\n");
output.push_str(" }\n"); output.push_str(" }\n");
} }
output.push_str(" if !routed {\n"); output.push_str(" if !routed {\n");
output.push_str(" let res_body = BoxBody::new(StatusBody(Some(Bytes::from_static(&[0, 0, 0, 0, 0]))));\n"); output.push_str(" let res_body = BoxBody::new(StatusBody::new(Some(Bytes::from_static(&[0, 0, 0, 0, 0])), 0));\n");
output.push_str(" return Ok(http::Response::builder().status(200).body(res_body).unwrap());\n"); output.push_str(" return Ok(http::Response::builder().status(200).body(res_body).unwrap());\n");
output.push_str(" }\n"); output.push_str(" }\n");
output.push_str(" Ok(http::Response::builder().status(200).body(BoxBody::new(StatusBody(None))).unwrap())\n"); output.push_str(" Ok(http::Response::builder().status(200).body(BoxBody::new(StatusBody::new(None, 0))).unwrap())\n");
output.push_str(" })\n"); output.push_str(" })\n");
output.push_str(" }\n"); output.push_str(" }\n");
output.push_str("}\n"); output.push_str("}\n");
+1
View File
@@ -12,3 +12,4 @@ http-body = "1.0"
http-body-util = "0.1" http-body-util = "0.1"
tower = "0.4" tower = "0.4"
futures-util = "0.3" futures-util = "0.3"
http = "1.1"
+18 -2
View File
@@ -101,7 +101,21 @@ impl BufferPool {
} }
} }
pub struct StatusBody(pub Option<Bytes>); pub struct StatusBody {
pub data: Option<Bytes>,
pub trailers: Option<http::HeaderMap>,
}
impl StatusBody {
pub fn new(data: Option<Bytes>, status: u8) -> Self {
let mut trailers = http::HeaderMap::new();
trailers.insert("grpc-status", status.to_string().parse().unwrap());
Self {
data,
trailers: Some(trailers),
}
}
}
impl Body for StatusBody { impl Body for StatusBody {
type Data = Bytes; type Data = Bytes;
@@ -111,8 +125,10 @@ impl Body for StatusBody {
mut self: Pin<&mut Self>, mut self: Pin<&mut Self>,
cx: &mut Context<'_>, cx: &mut Context<'_>,
) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> { ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
if let Some(data) = self.0.take() { if let Some(data) = self.data.take() {
Poll::Ready(Some(Ok(http_body::Frame::data(data)))) Poll::Ready(Some(Ok(http_body::Frame::data(data))))
} else if let Some(trailers) = self.trailers.take() {
Poll::Ready(Some(Ok(http_body::Frame::trailers(trailers))))
} else { } else {
Poll::Ready(None) Poll::Ready(None)
} }