Split service and proto gen
This commit is contained in:
@@ -0,0 +1,42 @@
|
|||||||
|
use clap::Parser;
|
||||||
|
use roto_codegen::generator::generate_protobuf_code;
|
||||||
|
use roto_codegen::google::protobuf::descriptor::FileDescriptorSet;
|
||||||
|
use std::fs;
|
||||||
|
use std::path::PathBuf;
|
||||||
|
|
||||||
|
#[derive(Parser)]
|
||||||
|
#[command(
|
||||||
|
author,
|
||||||
|
version,
|
||||||
|
about = "Generates Rust accessor and builder code from a protobuf descriptor set"
|
||||||
|
)]
|
||||||
|
struct Args {
|
||||||
|
/// Path to the descriptor set file (.desc)
|
||||||
|
#[arg(short, long)]
|
||||||
|
input: PathBuf,
|
||||||
|
|
||||||
|
/// Path to the output directory
|
||||||
|
#[arg(short, long)]
|
||||||
|
output: PathBuf,
|
||||||
|
|
||||||
|
/// Files to generate. If omitted, all files are generated.
|
||||||
|
#[arg(short, long, value_delimiter = ',')]
|
||||||
|
files: Option<Vec<String>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
|
let args = Args::parse();
|
||||||
|
let data = fs::read(&args.input)?;
|
||||||
|
let set = FileDescriptorSet::new(&data).expect("Failed to parse FileDescriptorSet");
|
||||||
|
|
||||||
|
let files = generate_protobuf_code(&set, args.files.as_deref(), true);
|
||||||
|
|
||||||
|
for (filename, content) in files {
|
||||||
|
let path = args.output.join(filename);
|
||||||
|
if let Some(parent) = path.parent() {
|
||||||
|
fs::create_dir_all(parent)?;
|
||||||
|
}
|
||||||
|
fs::write(path, content)?;
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
@@ -0,0 +1,42 @@
|
|||||||
|
use clap::Parser;
|
||||||
|
use roto_codegen::generator::generate_service_code;
|
||||||
|
use roto_codegen::google::protobuf::descriptor::FileDescriptorSet;
|
||||||
|
use std::fs;
|
||||||
|
use std::path::PathBuf;
|
||||||
|
|
||||||
|
#[derive(Parser)]
|
||||||
|
#[command(
|
||||||
|
author,
|
||||||
|
version,
|
||||||
|
about = "Generates Rust gRPC service code from a protobuf descriptor set"
|
||||||
|
)]
|
||||||
|
struct Args {
|
||||||
|
/// Path to the descriptor set file (.desc)
|
||||||
|
#[arg(short, long)]
|
||||||
|
input: PathBuf,
|
||||||
|
|
||||||
|
/// Path to the output directory
|
||||||
|
#[arg(short, long)]
|
||||||
|
output: PathBuf,
|
||||||
|
|
||||||
|
/// Files to generate. If omitted, all files are generated.
|
||||||
|
#[arg(short, long, value_delimiter = ',')]
|
||||||
|
files: Option<Vec<String>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
|
let args = Args::parse();
|
||||||
|
let data = fs::read(&args.input)?;
|
||||||
|
let set = FileDescriptorSet::new(&data).expect("Failed to parse FileDescriptorSet");
|
||||||
|
|
||||||
|
let files = generate_service_code(&set, args.files.as_deref(), true);
|
||||||
|
|
||||||
|
for (filename, content) in files {
|
||||||
|
let path = args.output.join(filename);
|
||||||
|
if let Some(parent) = path.parent() {
|
||||||
|
fs::create_dir_all(parent)?;
|
||||||
|
}
|
||||||
|
fs::write(path, content)?;
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
+127
-43
@@ -6,6 +6,9 @@ use roto_runtime::ProtoAccessor;
|
|||||||
use std::collections::{HashMap, HashSet};
|
use std::collections::{HashMap, HashSet};
|
||||||
use std::str;
|
use std::str;
|
||||||
|
|
||||||
|
const DATA_IMPORTS: &str = "use roto_runtime::{ProtoAccessor, ProtoBuilder, Result, RotoError, read_varint, RepeatedFieldIterator, RotoMessage};\nuse std::str;\nuse bytes::{Bytes, BytesMut, Buf, BufMut};\n";
|
||||||
|
const SERVICE_IMPORTS: &str = "use tonic::{Request, Response, Status};\nuse tokio_stream::Stream;\nuse std::pin::Pin;\nuse std::sync::Arc;\nuse std::task::{Context, Poll};\nuse std::future::Future;\nuse tonic::body::BoxBody;\nuse tower::Service;\nuse futures_util::StreamExt;\nuse http_body_util::BodyExt;\nuse http_body::Body;\nuse crate::{BufferPool, StatusBody};\n";
|
||||||
|
|
||||||
pub fn to_pascal_case(s: &str) -> String {
|
pub fn to_pascal_case(s: &str) -> String {
|
||||||
s.split('_')
|
s.split('_')
|
||||||
.map(|word| {
|
.map(|word| {
|
||||||
@@ -524,11 +527,16 @@ fn map_type_to_rust_builder(field_type: i32) -> (String, String) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn generate_rust_code(
|
fn generate_files_common<F>(
|
||||||
set: &FileDescriptorSet,
|
set: &FileDescriptorSet,
|
||||||
files_to_generate: Option<&[String]>,
|
files_to_generate: Option<&[String]>,
|
||||||
generate_mod_files: bool,
|
generate_mod_files: bool,
|
||||||
) -> Vec<(String, String)> {
|
imports: &str,
|
||||||
|
mut content_gen: F,
|
||||||
|
) -> Vec<(String, String)>
|
||||||
|
where
|
||||||
|
F: FnMut(&FileDescriptorProto, &mut String),
|
||||||
|
{
|
||||||
let mut generated_files = Vec::new();
|
let mut generated_files = Vec::new();
|
||||||
|
|
||||||
for file_res in set.file() {
|
for file_res in set.file() {
|
||||||
@@ -548,21 +556,7 @@ pub fn generate_rust_code(
|
|||||||
let mut output = String::new();
|
let mut output = String::new();
|
||||||
output.push_str("// @generated by protoc-gen-roto — do not edit\n");
|
output.push_str("// @generated by protoc-gen-roto — do not edit\n");
|
||||||
output.push_str("#[allow(unused_imports)]\n\n");
|
output.push_str("#[allow(unused_imports)]\n\n");
|
||||||
output.push_str("use roto_runtime::{ProtoAccessor, ProtoBuilder, Result, RotoError, read_varint, RepeatedFieldIterator, RotoMessage};\n");
|
output.push_str(imports);
|
||||||
output.push_str("use std::str;\n");
|
|
||||||
output.push_str("use bytes::{Bytes, BytesMut, Buf, BufMut};\n");
|
|
||||||
output.push_str("use tonic::{Request, Response, Status};\n");
|
|
||||||
output.push_str("use tokio_stream::Stream;\n");
|
|
||||||
output.push_str("use std::pin::Pin;\n");
|
|
||||||
output.push_str("use std::sync::Arc;\n");
|
|
||||||
output.push_str("use std::task::{Context, Poll};\n");
|
|
||||||
output.push_str("use std::future::Future;\n");
|
|
||||||
output.push_str("use tonic::body::BoxBody;\n");
|
|
||||||
output.push_str("use tower::Service;\n");
|
|
||||||
output.push_str("use futures_util::StreamExt;\n");
|
|
||||||
output.push_str("use http_body_util::BodyExt;\n");
|
|
||||||
output.push_str("use http_body::Body;\n");
|
|
||||||
output.push_str("use crate::{BufferPool, StatusBody};\n\n");
|
|
||||||
|
|
||||||
for dep_res in file_proto.dependency() {
|
for dep_res in file_proto.dependency() {
|
||||||
let (dep_data, _) = dep_res.expect("Failed to iterate dependency");
|
let (dep_data, _) = dep_res.expect("Failed to iterate dependency");
|
||||||
@@ -572,33 +566,8 @@ pub fn generate_rust_code(
|
|||||||
}
|
}
|
||||||
output.push_str("\n");
|
output.push_str("\n");
|
||||||
|
|
||||||
// Enums
|
content_gen(&file_proto, &mut output);
|
||||||
for enum_res in file_proto.enum_type() {
|
|
||||||
let (enum_data, _) = enum_res.expect("Failed to iterate enum");
|
|
||||||
write_enum(
|
|
||||||
&EnumDescriptorProto::new(enum_data).expect("Failed to parse EnumDescriptorProto"),
|
|
||||||
&mut output,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Messages
|
|
||||||
for msg_res in file_proto.message_type() {
|
|
||||||
let (msg_data, _) = msg_res.expect("Failed to iterate message");
|
|
||||||
write_message(
|
|
||||||
&DescriptorProto::new(msg_data).expect("Failed to parse DescriptorProto"),
|
|
||||||
&mut output,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Services
|
|
||||||
for svc_res in file_proto.service() {
|
|
||||||
let (svc_data, _) = svc_res.expect("Failed to iterate service");
|
|
||||||
write_service(
|
|
||||||
&ServiceDescriptorProto::new(svc_data).expect("Failed to parse ServiceDescriptorProto"),
|
|
||||||
file_proto.package().unwrap_or(""),
|
|
||||||
&mut output,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
generated_files.push((rust_file_name, output));
|
generated_files.push((rust_file_name, output));
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -655,6 +624,121 @@ pub fn generate_rust_code(
|
|||||||
generated_files
|
generated_files
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn generate_protobuf_code(
|
||||||
|
set: &FileDescriptorSet,
|
||||||
|
files_to_generate: Option<&[String]>,
|
||||||
|
generate_mod_files: bool,
|
||||||
|
) -> Vec<(String, String)> {
|
||||||
|
generate_files_common(
|
||||||
|
set,
|
||||||
|
files_to_generate,
|
||||||
|
generate_mod_files,
|
||||||
|
DATA_IMPORTS,
|
||||||
|
|file_proto, output| {
|
||||||
|
// Enums
|
||||||
|
for enum_res in file_proto.enum_type() {
|
||||||
|
let (enum_data, _) = enum_res.expect("Failed to iterate enum");
|
||||||
|
write_enum(
|
||||||
|
&EnumDescriptorProto::new(enum_data).expect("Failed to parse EnumDescriptorProto"),
|
||||||
|
output,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Messages
|
||||||
|
for msg_res in file_proto.message_type() {
|
||||||
|
let (msg_data, _) = msg_res.expect("Failed to iterate message");
|
||||||
|
write_message(
|
||||||
|
&DescriptorProto::new(msg_data).expect("Failed to parse DescriptorProto"),
|
||||||
|
output,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn generate_service_code(
|
||||||
|
set: &FileDescriptorSet,
|
||||||
|
files_to_generate: Option<&[String]>,
|
||||||
|
generate_mod_files: bool,
|
||||||
|
) -> Vec<(String, String)> {
|
||||||
|
generate_files_common(
|
||||||
|
set,
|
||||||
|
files_to_generate,
|
||||||
|
generate_mod_files,
|
||||||
|
SERVICE_IMPORTS,
|
||||||
|
|file_proto, output| {
|
||||||
|
let package = file_proto.package().unwrap_or("").to_string();
|
||||||
|
// Services
|
||||||
|
for svc_res in file_proto.service() {
|
||||||
|
let (svc_data, _) = svc_res.expect("Failed to iterate service");
|
||||||
|
write_service(
|
||||||
|
&ServiceDescriptorProto::new(svc_data).expect("Failed to parse ServiceDescriptorProto"),
|
||||||
|
&package,
|
||||||
|
output,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn generate_rust_code(
|
||||||
|
set: &FileDescriptorSet,
|
||||||
|
files_to_generate: Option<&[String]>,
|
||||||
|
generate_mod_files: bool,
|
||||||
|
) -> Vec<(String, String)> {
|
||||||
|
let protobuf_files = generate_protobuf_code(set, files_to_generate, false);
|
||||||
|
let service_files = generate_service_code(set, files_to_generate, false);
|
||||||
|
|
||||||
|
let mut combined_files: HashMap<String, String> = HashMap::new();
|
||||||
|
|
||||||
|
for (filename, content) in protobuf_files {
|
||||||
|
combined_files.insert(filename, content);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (filename, content) in service_files {
|
||||||
|
if let Some(existing_content) = combined_files.get_mut(&filename) {
|
||||||
|
let stripped = strip_boilerplate(&content);
|
||||||
|
existing_content.push_str("\n");
|
||||||
|
existing_content.push_str(&stripped);
|
||||||
|
} else {
|
||||||
|
combined_files.insert(filename, content);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut result = combined_files.into_iter().collect::<Vec<_>>();
|
||||||
|
result.sort_by(|a, b| a.0.cmp(&b.0));
|
||||||
|
|
||||||
|
if generate_mod_files {
|
||||||
|
let mods = generate_files_common(
|
||||||
|
set,
|
||||||
|
files_to_generate,
|
||||||
|
true,
|
||||||
|
"",
|
||||||
|
|_, _| {},
|
||||||
|
);
|
||||||
|
for (filename, content) in mods {
|
||||||
|
if filename == "mod.rs" || filename.contains("/mod.rs") {
|
||||||
|
result.push((filename, content));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
result
|
||||||
|
}
|
||||||
|
|
||||||
|
fn strip_boilerplate(content: &str) -> String {
|
||||||
|
// Find the first occurrence of a service definition or a trait
|
||||||
|
// In our case, the services start after the dependency imports and a newline.
|
||||||
|
if let Some(idx) = content.find("pub trait ") {
|
||||||
|
return content[idx..].to_string();
|
||||||
|
}
|
||||||
|
if let Some(idx) = content.find("pub struct ") {
|
||||||
|
// This might be a message, but generate_service_code only generates services (and their server structs)
|
||||||
|
return content[idx..].to_string();
|
||||||
|
}
|
||||||
|
content.to_string()
|
||||||
|
}
|
||||||
|
|
||||||
fn write_service(svc_proto: &ServiceDescriptorProto, package: &str, output: &mut String) {
|
fn write_service(svc_proto: &ServiceDescriptorProto, package: &str, output: &mut String) {
|
||||||
let svc_name = to_pascal_case(svc_proto.name().unwrap());
|
let svc_name = to_pascal_case(svc_proto.name().unwrap());
|
||||||
output.push_str(&format!("#[tonic::async_trait]\npub trait {}: Send + Sync + 'static {{\n", svc_name));
|
output.push_str(&format!("#[tonic::async_trait]\npub trait {}: Send + Sync + 'static {{\n", svc_name));
|
||||||
|
|||||||
Reference in New Issue
Block a user