Split service and proto gen

This commit is contained in:
2026-05-17 18:53:00 -07:00
parent b2c5639338
commit 956993d1d0
3 changed files with 211 additions and 43 deletions
+42
View File
@@ -0,0 +1,42 @@
use clap::Parser;
use roto_codegen::generator::generate_protobuf_code;
use roto_codegen::google::protobuf::descriptor::FileDescriptorSet;
use std::fs;
use std::path::PathBuf;
#[derive(Parser)]
#[command(
author,
version,
about = "Generates Rust accessor and builder code from a protobuf descriptor set"
)]
struct Args {
/// Path to the descriptor set file (.desc)
#[arg(short, long)]
input: PathBuf,
/// Path to the output directory
#[arg(short, long)]
output: PathBuf,
/// Files to generate. If omitted, all files are generated.
#[arg(short, long, value_delimiter = ',')]
files: Option<Vec<String>>,
}
fn main() -> Result<(), Box<dyn std::error::Error>> {
let args = Args::parse();
let data = fs::read(&args.input)?;
let set = FileDescriptorSet::new(&data).expect("Failed to parse FileDescriptorSet");
let files = generate_protobuf_code(&set, args.files.as_deref(), true);
for (filename, content) in files {
let path = args.output.join(filename);
if let Some(parent) = path.parent() {
fs::create_dir_all(parent)?;
}
fs::write(path, content)?;
}
Ok(())
}
+42
View File
@@ -0,0 +1,42 @@
use clap::Parser;
use roto_codegen::generator::generate_service_code;
use roto_codegen::google::protobuf::descriptor::FileDescriptorSet;
use std::fs;
use std::path::PathBuf;
#[derive(Parser)]
#[command(
author,
version,
about = "Generates Rust gRPC service code from a protobuf descriptor set"
)]
struct Args {
/// Path to the descriptor set file (.desc)
#[arg(short, long)]
input: PathBuf,
/// Path to the output directory
#[arg(short, long)]
output: PathBuf,
/// Files to generate. If omitted, all files are generated.
#[arg(short, long, value_delimiter = ',')]
files: Option<Vec<String>>,
}
fn main() -> Result<(), Box<dyn std::error::Error>> {
let args = Args::parse();
let data = fs::read(&args.input)?;
let set = FileDescriptorSet::new(&data).expect("Failed to parse FileDescriptorSet");
let files = generate_service_code(&set, args.files.as_deref(), true);
for (filename, content) in files {
let path = args.output.join(filename);
if let Some(parent) = path.parent() {
fs::create_dir_all(parent)?;
}
fs::write(path, content)?;
}
Ok(())
}
+127 -43
View File
@@ -6,6 +6,9 @@ use roto_runtime::ProtoAccessor;
use std::collections::{HashMap, HashSet}; use std::collections::{HashMap, HashSet};
use std::str; use std::str;
const DATA_IMPORTS: &str = "use roto_runtime::{ProtoAccessor, ProtoBuilder, Result, RotoError, read_varint, RepeatedFieldIterator, RotoMessage};\nuse std::str;\nuse bytes::{Bytes, BytesMut, Buf, BufMut};\n";
const SERVICE_IMPORTS: &str = "use tonic::{Request, Response, Status};\nuse tokio_stream::Stream;\nuse std::pin::Pin;\nuse std::sync::Arc;\nuse std::task::{Context, Poll};\nuse std::future::Future;\nuse tonic::body::BoxBody;\nuse tower::Service;\nuse futures_util::StreamExt;\nuse http_body_util::BodyExt;\nuse http_body::Body;\nuse crate::{BufferPool, StatusBody};\n";
pub fn to_pascal_case(s: &str) -> String { pub fn to_pascal_case(s: &str) -> String {
s.split('_') s.split('_')
.map(|word| { .map(|word| {
@@ -524,11 +527,16 @@ fn map_type_to_rust_builder(field_type: i32) -> (String, String) {
} }
} }
pub fn generate_rust_code( fn generate_files_common<F>(
set: &FileDescriptorSet, set: &FileDescriptorSet,
files_to_generate: Option<&[String]>, files_to_generate: Option<&[String]>,
generate_mod_files: bool, generate_mod_files: bool,
) -> Vec<(String, String)> { imports: &str,
mut content_gen: F,
) -> Vec<(String, String)>
where
F: FnMut(&FileDescriptorProto, &mut String),
{
let mut generated_files = Vec::new(); let mut generated_files = Vec::new();
for file_res in set.file() { for file_res in set.file() {
@@ -548,21 +556,7 @@ pub fn generate_rust_code(
let mut output = String::new(); let mut output = String::new();
output.push_str("// @generated by protoc-gen-roto — do not edit\n"); output.push_str("// @generated by protoc-gen-roto — do not edit\n");
output.push_str("#[allow(unused_imports)]\n\n"); output.push_str("#[allow(unused_imports)]\n\n");
output.push_str("use roto_runtime::{ProtoAccessor, ProtoBuilder, Result, RotoError, read_varint, RepeatedFieldIterator, RotoMessage};\n"); output.push_str(imports);
output.push_str("use std::str;\n");
output.push_str("use bytes::{Bytes, BytesMut, Buf, BufMut};\n");
output.push_str("use tonic::{Request, Response, Status};\n");
output.push_str("use tokio_stream::Stream;\n");
output.push_str("use std::pin::Pin;\n");
output.push_str("use std::sync::Arc;\n");
output.push_str("use std::task::{Context, Poll};\n");
output.push_str("use std::future::Future;\n");
output.push_str("use tonic::body::BoxBody;\n");
output.push_str("use tower::Service;\n");
output.push_str("use futures_util::StreamExt;\n");
output.push_str("use http_body_util::BodyExt;\n");
output.push_str("use http_body::Body;\n");
output.push_str("use crate::{BufferPool, StatusBody};\n\n");
for dep_res in file_proto.dependency() { for dep_res in file_proto.dependency() {
let (dep_data, _) = dep_res.expect("Failed to iterate dependency"); let (dep_data, _) = dep_res.expect("Failed to iterate dependency");
@@ -572,33 +566,8 @@ pub fn generate_rust_code(
} }
output.push_str("\n"); output.push_str("\n");
// Enums content_gen(&file_proto, &mut output);
for enum_res in file_proto.enum_type() {
let (enum_data, _) = enum_res.expect("Failed to iterate enum");
write_enum(
&EnumDescriptorProto::new(enum_data).expect("Failed to parse EnumDescriptorProto"),
&mut output,
);
}
// Messages
for msg_res in file_proto.message_type() {
let (msg_data, _) = msg_res.expect("Failed to iterate message");
write_message(
&DescriptorProto::new(msg_data).expect("Failed to parse DescriptorProto"),
&mut output,
);
}
// Services
for svc_res in file_proto.service() {
let (svc_data, _) = svc_res.expect("Failed to iterate service");
write_service(
&ServiceDescriptorProto::new(svc_data).expect("Failed to parse ServiceDescriptorProto"),
file_proto.package().unwrap_or(""),
&mut output,
);
}
generated_files.push((rust_file_name, output)); generated_files.push((rust_file_name, output));
} }
@@ -655,6 +624,121 @@ pub fn generate_rust_code(
generated_files generated_files
} }
pub fn generate_protobuf_code(
set: &FileDescriptorSet,
files_to_generate: Option<&[String]>,
generate_mod_files: bool,
) -> Vec<(String, String)> {
generate_files_common(
set,
files_to_generate,
generate_mod_files,
DATA_IMPORTS,
|file_proto, output| {
// Enums
for enum_res in file_proto.enum_type() {
let (enum_data, _) = enum_res.expect("Failed to iterate enum");
write_enum(
&EnumDescriptorProto::new(enum_data).expect("Failed to parse EnumDescriptorProto"),
output,
);
}
// Messages
for msg_res in file_proto.message_type() {
let (msg_data, _) = msg_res.expect("Failed to iterate message");
write_message(
&DescriptorProto::new(msg_data).expect("Failed to parse DescriptorProto"),
output,
);
}
},
)
}
pub fn generate_service_code(
set: &FileDescriptorSet,
files_to_generate: Option<&[String]>,
generate_mod_files: bool,
) -> Vec<(String, String)> {
generate_files_common(
set,
files_to_generate,
generate_mod_files,
SERVICE_IMPORTS,
|file_proto, output| {
let package = file_proto.package().unwrap_or("").to_string();
// Services
for svc_res in file_proto.service() {
let (svc_data, _) = svc_res.expect("Failed to iterate service");
write_service(
&ServiceDescriptorProto::new(svc_data).expect("Failed to parse ServiceDescriptorProto"),
&package,
output,
);
}
},
)
}
pub fn generate_rust_code(
set: &FileDescriptorSet,
files_to_generate: Option<&[String]>,
generate_mod_files: bool,
) -> Vec<(String, String)> {
let protobuf_files = generate_protobuf_code(set, files_to_generate, false);
let service_files = generate_service_code(set, files_to_generate, false);
let mut combined_files: HashMap<String, String> = HashMap::new();
for (filename, content) in protobuf_files {
combined_files.insert(filename, content);
}
for (filename, content) in service_files {
if let Some(existing_content) = combined_files.get_mut(&filename) {
let stripped = strip_boilerplate(&content);
existing_content.push_str("\n");
existing_content.push_str(&stripped);
} else {
combined_files.insert(filename, content);
}
}
let mut result = combined_files.into_iter().collect::<Vec<_>>();
result.sort_by(|a, b| a.0.cmp(&b.0));
if generate_mod_files {
let mods = generate_files_common(
set,
files_to_generate,
true,
"",
|_, _| {},
);
for (filename, content) in mods {
if filename == "mod.rs" || filename.contains("/mod.rs") {
result.push((filename, content));
}
}
}
result
}
fn strip_boilerplate(content: &str) -> String {
// Find the first occurrence of a service definition or a trait
// In our case, the services start after the dependency imports and a newline.
if let Some(idx) = content.find("pub trait ") {
return content[idx..].to_string();
}
if let Some(idx) = content.find("pub struct ") {
// This might be a message, but generate_service_code only generates services (and their server structs)
return content[idx..].to_string();
}
content.to_string()
}
fn write_service(svc_proto: &ServiceDescriptorProto, package: &str, output: &mut String) { fn write_service(svc_proto: &ServiceDescriptorProto, package: &str, output: &mut String) {
let svc_name = to_pascal_case(svc_proto.name().unwrap()); let svc_name = to_pascal_case(svc_proto.name().unwrap());
output.push_str(&format!("#[tonic::async_trait]\npub trait {}: Send + Sync + 'static {{\n", svc_name)); output.push_str(&format!("#[tonic::async_trait]\npub trait {}: Send + Sync + 'static {{\n", svc_name));