diff --git a/prost-build/src/code_generator.rs b/prost-build/src/code_generator.rs index 89c48f30b..c303a6952 100644 --- a/prost-build/src/code_generator.rs +++ b/prost-build/src/code_generator.rs @@ -49,6 +49,44 @@ fn prost_path(config: &Config) -> &str { config.prost_path.as_deref().unwrap_or("::prost") } +struct Field { + descriptor: FieldDescriptorProto, + path_index: i32, +} + +impl Field { + fn new(descriptor: FieldDescriptorProto, path_index: i32) -> Self { + Self { + descriptor, + path_index, + } + } + + fn rust_name(&self) -> String { + to_snake(self.descriptor.name()) + } +} + +struct OneofField { + descriptor: OneofDescriptorProto, + fields: Vec, + path_index: i32, +} + +impl OneofField { + fn new(descriptor: OneofDescriptorProto, fields: Vec, path_index: i32) -> Self { + Self { + descriptor, + fields, + path_index, + } + } + + fn rust_name(&self) -> String { + to_snake(self.descriptor.name()) + } +} + impl<'a> CodeGenerator<'a> { pub fn generate( config: &mut Config, @@ -158,21 +196,33 @@ impl<'a> CodeGenerator<'a> { // Split the fields into a vector of the normal fields, and oneof fields. // Path indexes are preserved so that comments can be retrieved. - type Fields = Vec<(FieldDescriptorProto, usize)>; - type OneofFields = MultiMap; - let (fields, mut oneof_fields): (Fields, OneofFields) = message + type OneofFieldsByIndex = MultiMap; + let (fields, mut oneof_map): (Vec, OneofFieldsByIndex) = message .field .into_iter() .enumerate() - .partition_map(|(idx, field)| { - if field.proto3_optional.unwrap_or(false) { - Either::Left((field, idx)) - } else if let Some(oneof_index) = field.oneof_index { - Either::Right((oneof_index, (field, idx))) + .partition_map(|(idx, proto)| { + let idx = idx as i32; + if proto.proto3_optional.unwrap_or(false) { + Either::Left(Field::new(proto, idx)) + } else if let Some(oneof_index) = proto.oneof_index { + Either::Right((oneof_index, Field::new(proto, idx))) } else { - Either::Left((field, idx)) + Either::Left(Field::new(proto, idx)) } }); + // Optional fields create a synthetic oneof that we want to skip + let oneof_fields: Vec = message + .oneof_decl + .into_iter() + .enumerate() + .filter_map(move |(idx, proto)| { + let idx = idx as i32; + oneof_map + .remove(&idx) + .map(|fields| OneofField::new(proto, fields, idx)) + }) + .collect(); self.append_doc(&fq_message_name, None); self.append_type_attributes(&fq_message_name); @@ -192,9 +242,10 @@ impl<'a> CodeGenerator<'a> { self.depth += 1; self.path.push(2); - for (field, idx) in fields { - self.path.push(idx as i32); + for field in &fields { + self.path.push(field.path_index); match field + .descriptor .type_name .as_ref() .and_then(|type_name| map_types.get(type_name)) @@ -207,16 +258,9 @@ impl<'a> CodeGenerator<'a> { self.path.pop(); self.path.push(8); - for (idx, oneof) in message.oneof_decl.iter().enumerate() { - let idx = idx as i32; - - let fields = match oneof_fields.get_vec(&idx) { - Some(fields) => fields, - None => continue, - }; - - self.path.push(idx); - self.append_oneof_field(&message_name, &fq_message_name, oneof, fields); + for oneof in &oneof_fields { + self.path.push(oneof.path_index); + self.append_oneof_field(&message_name, &fq_message_name, oneof); self.path.pop(); } self.path.pop(); @@ -243,14 +287,8 @@ impl<'a> CodeGenerator<'a> { } self.path.pop(); - for (idx, oneof) in message.oneof_decl.into_iter().enumerate() { - let idx = idx as i32; - // optional fields create a synthetic oneof that we want to skip - let fields = match oneof_fields.remove(&idx) { - Some(fields) => fields, - None => continue, - }; - self.append_oneof(&fq_message_name, oneof, idx, fields); + for oneof in &oneof_fields { + self.append_oneof(&fq_message_name, oneof); } self.pop_mod(); @@ -359,32 +397,32 @@ impl<'a> CodeGenerator<'a> { } } - fn append_field(&mut self, fq_message_name: &str, field: FieldDescriptorProto) { - let type_ = field.r#type(); - let repeated = field.label == Some(Label::Repeated as i32); - let deprecated = self.deprecated(&field); - let optional = self.optional(&field); - let ty = self.resolve_type(&field, fq_message_name); + fn append_field(&mut self, fq_message_name: &str, field: &Field) { + let type_ = field.descriptor.r#type(); + let repeated = field.descriptor.label == Some(Label::Repeated as i32); + let deprecated = self.deprecated(&field.descriptor); + let optional = self.optional(&field.descriptor); + let ty = self.resolve_type(&field.descriptor, fq_message_name); let boxed = !repeated && ((type_ == Type::Message || type_ == Type::Group) && self .message_graph - .is_nested(field.type_name(), fq_message_name)) + .is_nested(field.descriptor.type_name(), fq_message_name)) || (self .config .boxed - .get_first_field(fq_message_name, field.name()) + .get_first_field(fq_message_name, field.descriptor.name()) .is_some()); debug!( " field: {:?}, type: {:?}, boxed: {}", - field.name(), + field.descriptor.name(), ty, boxed ); - self.append_doc(fq_message_name, Some(field.name())); + self.append_doc(fq_message_name, Some(field.descriptor.name())); if deprecated { self.push_indent(); @@ -393,21 +431,21 @@ impl<'a> CodeGenerator<'a> { self.push_indent(); self.buf.push_str("#[prost("); - let type_tag = self.field_type_tag(&field); + let type_tag = self.field_type_tag(&field.descriptor); self.buf.push_str(&type_tag); if type_ == Type::Bytes { let bytes_type = self .config .bytes_type - .get_first_field(fq_message_name, field.name()) + .get_first_field(fq_message_name, field.descriptor.name()) .copied() .unwrap_or_default(); self.buf .push_str(&format!("={:?}", bytes_type.annotation())); } - match field.label() { + match field.descriptor.label() { Label::Optional => { if optional { self.buf.push_str(", optional"); @@ -416,8 +454,9 @@ impl<'a> CodeGenerator<'a> { Label::Required => self.buf.push_str(", required"), Label::Repeated => { self.buf.push_str(", repeated"); - if can_pack(&field) + if can_pack(&field.descriptor) && !field + .descriptor .options .as_ref() .map_or(self.syntax == Syntax::Proto3, |options| options.packed()) @@ -431,9 +470,9 @@ impl<'a> CodeGenerator<'a> { self.buf.push_str(", boxed"); } self.buf.push_str(", tag=\""); - self.buf.push_str(&field.number().to_string()); + self.buf.push_str(&field.descriptor.number().to_string()); - if let Some(ref default) = field.default_value { + if let Some(ref default) = field.descriptor.default_value { self.buf.push_str("\", default=\""); if type_ == Type::Bytes { self.buf.push_str("b\\\""); @@ -450,6 +489,7 @@ impl<'a> CodeGenerator<'a> { // the last segment and strip it from the left // side of the default value. let enum_type = field + .descriptor .type_name .as_ref() .and_then(|ty| ty.split('.').last()) @@ -464,10 +504,10 @@ impl<'a> CodeGenerator<'a> { } self.buf.push_str("\")]\n"); - self.append_field_attributes(fq_message_name, field.name()); + self.append_field_attributes(fq_message_name, field.descriptor.name()); self.push_indent(); self.buf.push_str("pub "); - self.buf.push_str(&to_snake(field.name())); + self.buf.push_str(&field.rust_name()); self.buf.push_str(": "); let prost_path = prost_path(self.config); @@ -495,7 +535,7 @@ impl<'a> CodeGenerator<'a> { fn append_map_field( &mut self, fq_message_name: &str, - field: FieldDescriptorProto, + field: &Field, key: &FieldDescriptorProto, value: &FieldDescriptorProto, ) { @@ -504,18 +544,18 @@ impl<'a> CodeGenerator<'a> { debug!( " map field: {:?}, key type: {:?}, value type: {:?}", - field.name(), + field.descriptor.name(), key_ty, value_ty ); - self.append_doc(fq_message_name, Some(field.name())); + self.append_doc(fq_message_name, Some(field.descriptor.name())); self.push_indent(); let map_type = self .config .map_type - .get_first_field(fq_message_name, field.name()) + .get_first_field(fq_message_name, field.descriptor.name()) .copied() .unwrap_or_default(); let key_tag = self.field_type_tag(key); @@ -526,13 +566,13 @@ impl<'a> CodeGenerator<'a> { map_type.annotation(), key_tag, value_tag, - field.number() + field.descriptor.number() )); - self.append_field_attributes(fq_message_name, field.name()); + self.append_field_attributes(fq_message_name, field.descriptor.name()); self.push_indent(); self.buf.push_str(&format!( "pub {}: {}<{}, {}>,\n", - to_snake(field.name()), + field.rust_name(), map_type.rust_type(), key_ty, value_ty @@ -543,44 +583,41 @@ impl<'a> CodeGenerator<'a> { &mut self, message_name: &str, fq_message_name: &str, - oneof: &OneofDescriptorProto, - fields: &[(FieldDescriptorProto, usize)], + oneof: &OneofField, ) { - let name = format!( + let type_name = format!( "{}::{}", to_snake(message_name), - to_upper_camel(oneof.name()) + to_upper_camel(oneof.descriptor.name()) ); self.append_doc(fq_message_name, None); self.push_indent(); self.buf.push_str(&format!( "#[prost(oneof=\"{}\", tags=\"{}\")]\n", - name, - fields.iter().map(|(field, _)| field.number()).join(", ") + type_name, + oneof + .fields + .iter() + .map(|field| field.descriptor.number()) + .join(", "), )); - self.append_field_attributes(fq_message_name, oneof.name()); + self.append_field_attributes(fq_message_name, oneof.descriptor.name()); self.push_indent(); self.buf.push_str(&format!( "pub {}: ::core::option::Option<{}>,\n", - to_snake(oneof.name()), - name + oneof.rust_name(), + type_name )); } - fn append_oneof( - &mut self, - fq_message_name: &str, - oneof: OneofDescriptorProto, - idx: i32, - fields: Vec<(FieldDescriptorProto, usize)>, - ) { + fn append_oneof(&mut self, fq_message_name: &str, oneof: &OneofField) { self.path.push(8); - self.path.push(idx); + self.path.push(oneof.path_index); self.append_doc(fq_message_name, None); self.path.pop(); self.path.pop(); - let oneof_name = format!("{}.{}", fq_message_name, oneof.name()); + let oneof_name = format!("{}.{}", fq_message_name, oneof.descriptor.name()); self.append_type_attributes(&oneof_name); self.append_enum_attributes(&oneof_name); self.push_indent(); @@ -593,43 +630,43 @@ impl<'a> CodeGenerator<'a> { self.append_skip_debug(fq_message_name); self.push_indent(); self.buf.push_str("pub enum "); - self.buf.push_str(&to_upper_camel(oneof.name())); + self.buf.push_str(&to_upper_camel(oneof.descriptor.name())); self.buf.push_str(" {\n"); self.path.push(2); self.depth += 1; - for (field, idx) in fields { - let type_ = field.r#type(); + for field in &oneof.fields { + let type_ = field.descriptor.r#type(); - self.path.push(idx as i32); - self.append_doc(fq_message_name, Some(field.name())); + self.path.push(field.path_index); + self.append_doc(fq_message_name, Some(field.descriptor.name())); self.path.pop(); self.push_indent(); - let ty_tag = self.field_type_tag(&field); + let ty_tag = self.field_type_tag(&field.descriptor); self.buf.push_str(&format!( "#[prost({}, tag=\"{}\")]\n", ty_tag, - field.number() + field.descriptor.number() )); - self.append_field_attributes(&oneof_name, field.name()); + self.append_field_attributes(&oneof_name, field.descriptor.name()); self.push_indent(); - let ty = self.resolve_type(&field, fq_message_name); + let ty = self.resolve_type(&field.descriptor, fq_message_name); let boxed = ((type_ == Type::Message || type_ == Type::Group) && self .message_graph - .is_nested(field.type_name(), fq_message_name)) + .is_nested(field.descriptor.type_name(), fq_message_name)) || (self .config .boxed - .get_first_field(&oneof_name, field.name()) + .get_first_field(&oneof_name, field.descriptor.name()) .is_some()); debug!( " oneof: {:?}, type: {:?}, boxed: {}", - field.name(), + field.descriptor.name(), ty, boxed ); @@ -637,12 +674,15 @@ impl<'a> CodeGenerator<'a> { if boxed { self.buf.push_str(&format!( "{}(::prost::alloc::boxed::Box<{}>),\n", - to_upper_camel(field.name()), + to_upper_camel(field.descriptor.name()), ty )); } else { - self.buf - .push_str(&format!("{}({}),\n", to_upper_camel(field.name()), ty)); + self.buf.push_str(&format!( + "{}({}),\n", + to_upper_camel(field.descriptor.name()), + ty + )); } } self.depth -= 1;