Implement buffer pooling in hello_world example

Introduce BufferPool to reuse BytesMut allocations for requests and
responses. Update AGENTS.md to require build and test success.
This commit is contained in:
2026-05-13 09:45:27 -07:00
parent dfdcd8ae46
commit db2bf1bffd
2 changed files with 83 additions and 26 deletions
+2
View File
@@ -7,6 +7,8 @@ you should be able to work without user assistance.
If you are writing code, write tests first. The tests must pass for your work to be complete. If you are writing code, write tests first. The tests must pass for your work to be complete.
Before considering a task complete, make sure that all target build, and all tests suceed.
## Special instructions ## Special instructions
### Fork ### Fork
+81 -26
View File
@@ -1,12 +1,12 @@
use std::pin::Pin; use std::pin::Pin;
use std::future::Future; use std::future::Future;
use std::task::{Context, Poll}; use std::task::{Context, Poll};
use std::sync::Arc; use std::sync::{Arc, Mutex};
use tonic::{transport::Server, Request, Response, Status}; use tonic::{transport::Server, Request, Response, Status};
use roto_tonic::RotoCodec; use roto_tonic::RotoCodec;
use hello::{HelloWorldService, OwnedHelloRequest, OwnedHelloResponse}; use hello::{HelloWorldService, OwnedHelloRequest, OwnedHelloResponse};
use tower::Service; use tower::Service;
use bytes::{Bytes, Buf, BufMut}; use bytes::{Bytes, BytesMut, Buf, BufMut};
use tonic::body::BoxBody; use tonic::body::BoxBody;
use futures_util::StreamExt; use futures_util::StreamExt;
use roto_runtime::{RotoOwned, RotoMessage}; use roto_runtime::{RotoOwned, RotoMessage};
@@ -17,8 +17,41 @@ pub mod hello {
include!("../../proto/hello.rs"); include!("../../proto/hello.rs");
} }
#[derive(Default, Clone)] struct BufferPool {
pub struct MyHelloWorld {} pool: Mutex<Vec<BytesMut>>,
default_capacity: usize,
}
impl BufferPool {
fn new(default_capacity: usize) -> Self {
Self {
pool: Mutex::new(Vec::new()),
default_capacity,
}
}
fn get(&self) -> BytesMut {
self.pool.lock().unwrap().pop().unwrap_or_else(|| BytesMut::with_capacity(self.default_capacity))
}
fn put(&self, mut buf: BytesMut) {
buf.clear();
if buf.capacity() >= self.default_capacity {
self.pool.lock().unwrap().push(buf);
}
}
}
#[derive(Clone)]
pub struct MyHelloWorld {
pool: Arc<BufferPool>,
}
impl MyHelloWorld {
pub fn new(pool: Arc<BufferPool>) -> Self {
Self { pool }
}
}
#[tonic::async_trait] #[tonic::async_trait]
impl HelloWorldService for MyHelloWorld { impl HelloWorldService for MyHelloWorld {
@@ -30,13 +63,18 @@ impl HelloWorldService for MyHelloWorld {
let reader = req.reader(); let reader = req.reader();
let name = reader.name().unwrap_or("Unknown"); let name = reader.name().unwrap_or("Unknown");
let mut buf = vec![0u8; 1024]; let mut buf = self.pool.get();
let slice = hello::HelloResponseBuilder::builder(&mut buf) buf.resize(1024, 0);
let slice = hello::HelloResponseBuilder::builder(&mut buf[..])
.message(&format!("Hello {}!", name)).unwrap() .message(&format!("Hello {}!", name)).unwrap()
.finish().unwrap(); .finish().unwrap();
let res_len = slice.len();
let response_bytes = buf.split_to(res_len).freeze();
self.pool.put(buf);
let reply = OwnedHelloResponse { let reply = OwnedHelloResponse {
data: bytes::Bytes::copy_from_slice(slice), data: response_bytes,
}; };
Ok(Response::new(reply)) Ok(Response::new(reply))
@@ -48,11 +86,12 @@ impl HelloWorldService for MyHelloWorld {
#[derive(Clone)] #[derive(Clone)]
pub struct HelloWorldServer { pub struct HelloWorldServer {
inner: Arc<MyHelloWorld>, inner: Arc<MyHelloWorld>,
pool: Arc<BufferPool>,
} }
impl HelloWorldServer { impl HelloWorldServer {
pub fn new(inner: MyHelloWorld) -> Self { pub fn new(inner: MyHelloWorld, pool: Arc<BufferPool>) -> Self {
Self { inner: Arc::new(inner) } Self { inner: Arc::new(inner), pool }
} }
} }
@@ -89,32 +128,43 @@ impl Service<http::Request<BoxBody>> for HelloWorldServer {
fn call(&mut self, req: http::Request<BoxBody>) -> Self::Future { fn call(&mut self, req: http::Request<BoxBody>) -> Self::Future {
let inner = self.inner.clone(); let inner = self.inner.clone();
let pool = self.pool.clone();
println!("Server received request: {} {}", req.method(), req.uri()); println!("Server received request: {} {}", req.method(), req.uri());
Box::pin(async move { Box::pin(async move {
let body = req.into_body(); let body = req.into_body();
let bytes_vec = body.collect().await.map_err(|e| { let mut buf = pool.get();
println!("Body collect error: {}", e); let mut stream = body;
panic!("Body collect error: {}", e); while let Some(frame_result) = stream.frame().await {
})?.to_bytes(); let frame = frame_result.map_err(|e| {
println!("Body frame error: {}", e);
panic!("Body frame error: {}", e);
})?;
if let Some(data) = frame.data_ref() {
buf.put(data.clone());
}
}
let total_len = buf.len();
let bytes_vec = buf.split_to(total_len).freeze();
pool.put(buf);
println!("Collected body bytes: {} bytes", bytes_vec.len()); println!("Collected body bytes: {} bytes", bytes_vec.len());
if bytes_vec.len() < 5 { if bytes_vec.len() < 5 {
println!("Body too short: {} bytes", bytes_vec.len()); println!("Body too short: {} bytes", bytes_vec.len());
let res_body = BoxBody::new(StatusBody(Some(Bytes::from(vec![0, 0, 0, 0, 0])))); let res_body = BoxBody::new(StatusBody(Some(Bytes::from_static(&[0, 0, 0, 0, 0]))));
return Ok(http::Response::builder() return Ok(http::Response::builder()
.status(200) .status(200)
.body(res_body) .body(res_body)
.unwrap()); .unwrap());
} }
let data = &bytes_vec[5..]; println!("Decoding request from {} bytes", bytes_vec.len() - 5);
println!("Decoding request from {} bytes", data.len()); let request_msg = match OwnedHelloRequest::decode(bytes_vec.slice(5..)) {
let request_msg = match OwnedHelloRequest::decode(Bytes::copy_from_slice(data)) {
Ok(msg) => msg, Ok(msg) => msg,
Err(e) => { Err(e) => {
println!("Decode error: {}", e); println!("Decode error: {}", e);
let res_body = BoxBody::new(StatusBody(Some(Bytes::from(vec![0, 0, 0, 0, 0])))); let res_body = BoxBody::new(StatusBody(Some(Bytes::from_static(&[0, 0, 0, 0, 0]))));
return Ok(http::Response::builder().status(200).body(res_body).unwrap()); return Ok(http::Response::builder().status(200).body(res_body).unwrap());
} }
}; };
@@ -124,7 +174,7 @@ impl Service<http::Request<BoxBody>> for HelloWorldServer {
Ok(res) => res, Ok(res) => res,
Err(e) => { Err(e) => {
println!("Service error: {}", e); println!("Service error: {}", e);
let res_body = BoxBody::new(StatusBody(Some(Bytes::from(vec![0, 0, 0, 0, 0])))); let res_body = BoxBody::new(StatusBody(Some(Bytes::from_static(&[0, 0, 0, 0, 0]))));
return Ok(http::Response::builder().status(200).body(res_body).unwrap()); return Ok(http::Response::builder().status(200).body(res_body).unwrap());
} }
}; };
@@ -133,13 +183,17 @@ impl Service<http::Request<BoxBody>> for HelloWorldServer {
let response_bytes = response_msg.bytes(); let response_bytes = response_msg.bytes();
println!("Service responded with {} bytes", response_bytes.len()); println!("Service responded with {} bytes", response_bytes.len());
let mut res_buf = vec![0u8; 5 + response_bytes.len()]; let mut res_buf = pool.get();
res_buf[0] = 0; res_buf.put_u8(0);
let len = response_bytes.len() as u32; let len = response_bytes.len() as u32;
res_buf[1..5].copy_from_slice(&len.to_be_bytes()); res_buf.put_slice(&len.to_be_bytes());
res_buf[5..].copy_from_slice(&response_bytes); res_buf.put_slice(&response_bytes);
let res_body = BoxBody::new(StatusBody(Some(Bytes::from(res_buf)))); let frame_len = res_buf.len();
let frame = res_buf.split_to(frame_len).freeze();
pool.put(res_buf);
let res_body = BoxBody::new(StatusBody(Some(frame)));
Ok(http::Response::builder() Ok(http::Response::builder()
.status(200) .status(200)
.header("content-type", "application/grpc") .header("content-type", "application/grpc")
@@ -152,12 +206,13 @@ impl Service<http::Request<BoxBody>> for HelloWorldServer {
#[tokio::main] #[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> { async fn main() -> Result<(), Box<dyn std::error::Error>> {
let addr: std::net::SocketAddr = "[::1]:50051".parse()?; let addr: std::net::SocketAddr = "[::1]:50051".parse()?;
let hello = MyHelloWorld::default(); let pool = Arc::new(BufferPool::new(1024));
let hello = MyHelloWorld::new(pool.clone());
println!("Server listening on {}", addr); println!("Server listening on {}", addr);
Server::builder() Server::builder()
.add_service(HelloWorldServer::new(hello)) .add_service(HelloWorldServer::new(hello, pool))
.serve(addr) .serve(addr)
.await?; .await?;