Implement buffer pooling in hello_world example
Introduce BufferPool to reuse BytesMut allocations for requests and responses. Update AGENTS.md to require build and test success.
This commit is contained in:
@@ -1,12 +1,12 @@
|
||||
use std::pin::Pin;
|
||||
use std::future::Future;
|
||||
use std::task::{Context, Poll};
|
||||
use std::sync::Arc;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use tonic::{transport::Server, Request, Response, Status};
|
||||
use roto_tonic::RotoCodec;
|
||||
use hello::{HelloWorldService, OwnedHelloRequest, OwnedHelloResponse};
|
||||
use tower::Service;
|
||||
use bytes::{Bytes, Buf, BufMut};
|
||||
use bytes::{Bytes, BytesMut, Buf, BufMut};
|
||||
use tonic::body::BoxBody;
|
||||
use futures_util::StreamExt;
|
||||
use roto_runtime::{RotoOwned, RotoMessage};
|
||||
@@ -17,8 +17,41 @@ pub mod hello {
|
||||
include!("../../proto/hello.rs");
|
||||
}
|
||||
|
||||
#[derive(Default, Clone)]
|
||||
pub struct MyHelloWorld {}
|
||||
struct BufferPool {
|
||||
pool: Mutex<Vec<BytesMut>>,
|
||||
default_capacity: usize,
|
||||
}
|
||||
|
||||
impl BufferPool {
|
||||
fn new(default_capacity: usize) -> Self {
|
||||
Self {
|
||||
pool: Mutex::new(Vec::new()),
|
||||
default_capacity,
|
||||
}
|
||||
}
|
||||
|
||||
fn get(&self) -> BytesMut {
|
||||
self.pool.lock().unwrap().pop().unwrap_or_else(|| BytesMut::with_capacity(self.default_capacity))
|
||||
}
|
||||
|
||||
fn put(&self, mut buf: BytesMut) {
|
||||
buf.clear();
|
||||
if buf.capacity() >= self.default_capacity {
|
||||
self.pool.lock().unwrap().push(buf);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct MyHelloWorld {
|
||||
pool: Arc<BufferPool>,
|
||||
}
|
||||
|
||||
impl MyHelloWorld {
|
||||
pub fn new(pool: Arc<BufferPool>) -> Self {
|
||||
Self { pool }
|
||||
}
|
||||
}
|
||||
|
||||
#[tonic::async_trait]
|
||||
impl HelloWorldService for MyHelloWorld {
|
||||
@@ -30,13 +63,18 @@ impl HelloWorldService for MyHelloWorld {
|
||||
let reader = req.reader();
|
||||
let name = reader.name().unwrap_or("Unknown");
|
||||
|
||||
let mut buf = vec![0u8; 1024];
|
||||
let slice = hello::HelloResponseBuilder::builder(&mut buf)
|
||||
let mut buf = self.pool.get();
|
||||
buf.resize(1024, 0);
|
||||
let slice = hello::HelloResponseBuilder::builder(&mut buf[..])
|
||||
.message(&format!("Hello {}!", name)).unwrap()
|
||||
.finish().unwrap();
|
||||
|
||||
let res_len = slice.len();
|
||||
let response_bytes = buf.split_to(res_len).freeze();
|
||||
self.pool.put(buf);
|
||||
|
||||
let reply = OwnedHelloResponse {
|
||||
data: bytes::Bytes::copy_from_slice(slice),
|
||||
data: response_bytes,
|
||||
};
|
||||
|
||||
Ok(Response::new(reply))
|
||||
@@ -48,11 +86,12 @@ impl HelloWorldService for MyHelloWorld {
|
||||
#[derive(Clone)]
|
||||
pub struct HelloWorldServer {
|
||||
inner: Arc<MyHelloWorld>,
|
||||
pool: Arc<BufferPool>,
|
||||
}
|
||||
|
||||
impl HelloWorldServer {
|
||||
pub fn new(inner: MyHelloWorld) -> Self {
|
||||
Self { inner: Arc::new(inner) }
|
||||
pub fn new(inner: MyHelloWorld, pool: Arc<BufferPool>) -> Self {
|
||||
Self { inner: Arc::new(inner), pool }
|
||||
}
|
||||
}
|
||||
|
||||
@@ -89,32 +128,43 @@ impl Service<http::Request<BoxBody>> for HelloWorldServer {
|
||||
|
||||
fn call(&mut self, req: http::Request<BoxBody>) -> Self::Future {
|
||||
let inner = self.inner.clone();
|
||||
let pool = self.pool.clone();
|
||||
println!("Server received request: {} {}", req.method(), req.uri());
|
||||
|
||||
Box::pin(async move {
|
||||
let body = req.into_body();
|
||||
let bytes_vec = body.collect().await.map_err(|e| {
|
||||
println!("Body collect error: {}", e);
|
||||
panic!("Body collect error: {}", e);
|
||||
})?.to_bytes();
|
||||
let mut buf = pool.get();
|
||||
let mut stream = body;
|
||||
while let Some(frame_result) = stream.frame().await {
|
||||
let frame = frame_result.map_err(|e| {
|
||||
println!("Body frame error: {}", e);
|
||||
panic!("Body frame error: {}", e);
|
||||
})?;
|
||||
if let Some(data) = frame.data_ref() {
|
||||
buf.put(data.clone());
|
||||
}
|
||||
}
|
||||
|
||||
let total_len = buf.len();
|
||||
let bytes_vec = buf.split_to(total_len).freeze();
|
||||
pool.put(buf);
|
||||
println!("Collected body bytes: {} bytes", bytes_vec.len());
|
||||
|
||||
if bytes_vec.len() < 5 {
|
||||
println!("Body too short: {} bytes", bytes_vec.len());
|
||||
let res_body = BoxBody::new(StatusBody(Some(Bytes::from(vec![0, 0, 0, 0, 0]))));
|
||||
let res_body = BoxBody::new(StatusBody(Some(Bytes::from_static(&[0, 0, 0, 0, 0]))));
|
||||
return Ok(http::Response::builder()
|
||||
.status(200)
|
||||
.body(res_body)
|
||||
.unwrap());
|
||||
}
|
||||
|
||||
let data = &bytes_vec[5..];
|
||||
println!("Decoding request from {} bytes", data.len());
|
||||
let request_msg = match OwnedHelloRequest::decode(Bytes::copy_from_slice(data)) {
|
||||
println!("Decoding request from {} bytes", bytes_vec.len() - 5);
|
||||
let request_msg = match OwnedHelloRequest::decode(bytes_vec.slice(5..)) {
|
||||
Ok(msg) => msg,
|
||||
Err(e) => {
|
||||
println!("Decode error: {}", e);
|
||||
let res_body = BoxBody::new(StatusBody(Some(Bytes::from(vec![0, 0, 0, 0, 0]))));
|
||||
let res_body = BoxBody::new(StatusBody(Some(Bytes::from_static(&[0, 0, 0, 0, 0]))));
|
||||
return Ok(http::Response::builder().status(200).body(res_body).unwrap());
|
||||
}
|
||||
};
|
||||
@@ -124,7 +174,7 @@ impl Service<http::Request<BoxBody>> for HelloWorldServer {
|
||||
Ok(res) => res,
|
||||
Err(e) => {
|
||||
println!("Service error: {}", e);
|
||||
let res_body = BoxBody::new(StatusBody(Some(Bytes::from(vec![0, 0, 0, 0, 0]))));
|
||||
let res_body = BoxBody::new(StatusBody(Some(Bytes::from_static(&[0, 0, 0, 0, 0]))));
|
||||
return Ok(http::Response::builder().status(200).body(res_body).unwrap());
|
||||
}
|
||||
};
|
||||
@@ -133,13 +183,17 @@ impl Service<http::Request<BoxBody>> for HelloWorldServer {
|
||||
let response_bytes = response_msg.bytes();
|
||||
println!("Service responded with {} bytes", response_bytes.len());
|
||||
|
||||
let mut res_buf = vec![0u8; 5 + response_bytes.len()];
|
||||
res_buf[0] = 0;
|
||||
let mut res_buf = pool.get();
|
||||
res_buf.put_u8(0);
|
||||
let len = response_bytes.len() as u32;
|
||||
res_buf[1..5].copy_from_slice(&len.to_be_bytes());
|
||||
res_buf[5..].copy_from_slice(&response_bytes);
|
||||
res_buf.put_slice(&len.to_be_bytes());
|
||||
res_buf.put_slice(&response_bytes);
|
||||
|
||||
let res_body = BoxBody::new(StatusBody(Some(Bytes::from(res_buf))));
|
||||
let frame_len = res_buf.len();
|
||||
let frame = res_buf.split_to(frame_len).freeze();
|
||||
pool.put(res_buf);
|
||||
|
||||
let res_body = BoxBody::new(StatusBody(Some(frame)));
|
||||
Ok(http::Response::builder()
|
||||
.status(200)
|
||||
.header("content-type", "application/grpc")
|
||||
@@ -152,12 +206,13 @@ impl Service<http::Request<BoxBody>> for HelloWorldServer {
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let addr: std::net::SocketAddr = "[::1]:50051".parse()?;
|
||||
let hello = MyHelloWorld::default();
|
||||
let pool = Arc::new(BufferPool::new(1024));
|
||||
let hello = MyHelloWorld::new(pool.clone());
|
||||
|
||||
println!("Server listening on {}", addr);
|
||||
|
||||
Server::builder()
|
||||
.add_service(HelloWorldServer::new(hello))
|
||||
.add_service(HelloWorldServer::new(hello, pool))
|
||||
.serve(addr)
|
||||
.await?;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user