diff --git a/src/context/api.rs b/src/context/api.rs index ca871f6..72476f9 100644 --- a/src/context/api.rs +++ b/src/context/api.rs @@ -1,10 +1,8 @@ use super::*; -impl Context -{ +impl Context { /// Create a new context. - pub fn new() -> Self - { + pub fn new() -> Self { Context { packages: Default::default(), types: Default::default(), @@ -19,30 +17,26 @@ impl Context /// Will **panic** if the package defined by the `PackageRef` does not exist in this context. /// Such panic means the `PackageRef` came from a different context. The panic is not /// guaranteed, as a message with an equal `MessageRef` may exist in multiple contexts. - pub fn resolve_package(&self, package_ref: PackageRef) -> &Package - { + pub fn resolve_package(&self, package_ref: PackageRef) -> &Package { &self.packages[package_ref.0 .0] } /// Gets type info by name. - pub fn get_type(&self, full_name: &str) -> Option<&TypeInfo> - { + pub fn get_type(&self, full_name: &str) -> Option<&TypeInfo> { self.types_by_name .get(full_name) .map(|idx| &self.types[*idx]) } /// Gets a message type info by name. - pub fn get_message(&self, full_name: &str) -> Option<&MessageInfo> - { + pub fn get_message(&self, full_name: &str) -> Option<&MessageInfo> { match self.get_type(full_name) { Some(TypeInfo::Message(m)) => Some(m), _ => None, } } - fn resolve_type(&self, tr: InternalRef) -> Option<&TypeInfo> - { + fn resolve_type(&self, tr: InternalRef) -> Option<&TypeInfo> { self.types.get(tr.0) } @@ -51,8 +45,7 @@ impl Context /// Will **panic** if the message defined by the `MessageRef` does not exist in this context. /// Such panic means the `MessageRef` came from a different context. The panic is not /// guaranteed, as a message with an equal `MessageRef` may exist in multiple contexts. 
- pub fn resolve_message(&self, tr: MessageRef) -> &MessageInfo - { + pub fn resolve_message(&self, tr: MessageRef) -> &MessageInfo { match self.resolve_type(tr.0) { Some(TypeInfo::Message(msg)) => msg, _ => panic!("Message did not exist in this context"), @@ -64,8 +57,7 @@ impl Context /// Will **panic** if the enum defined by the `EnumRef` does not exist in this context. /// Such panic means the `EnumRef` came from a different context. The panic is not /// guaranteed, as an enum with an equal `EnumRef` may exist in multiple contexts. - pub fn resolve_enum(&self, tr: EnumRef) -> &EnumInfo - { + pub fn resolve_enum(&self, tr: EnumRef) -> &EnumInfo { match self.resolve_type(tr.0) { Some(TypeInfo::Enum(e)) => e, _ => panic!("Message did not exist in this context"), @@ -73,19 +65,16 @@ impl Context } /// Gets a service by full name. - pub fn get_service(&self, full_name: &str) -> Option<&Service> - { + pub fn get_service(&self, full_name: &str) -> Option<&Service> { self.services_by_name .get(full_name) .map(|idx| &self.services[*idx]) } } -impl TypeInfo -{ +impl TypeInfo { /// Get the full name of the type. - pub fn name(&self) -> &str - { + pub fn name(&self) -> &str { match self { TypeInfo::Message(m) => &m.name, TypeInfo::Enum(e) => &e.name, @@ -93,8 +82,7 @@ impl TypeInfo } /// Get the full name of the type. - pub fn full_name(&self) -> &str - { + pub fn full_name(&self) -> &str { match self { TypeInfo::Message(m) => &m.full_name, TypeInfo::Enum(e) => &e.full_name, @@ -102,8 +90,7 @@ impl TypeInfo } /// Get the parent information for the type. - pub fn parent(&self) -> TypeParent - { + pub fn parent(&self) -> TypeParent { match self { TypeInfo::Message(m) => m.parent, TypeInfo::Enum(e) => e.parent, @@ -111,59 +98,48 @@ impl TypeInfo } } -impl MessageInfo -{ +impl MessageInfo { /// Iterates all message fields. - pub fn iter_fields(&self) -> impl Iterator - { + pub fn iter_fields(&self) -> impl Iterator { self.fields.values() } /// Get a field by its number. 
- pub fn get_field(&self, number: u64) -> Option<&MessageField> - { + pub fn get_field(&self, number: u64) -> Option<&MessageField> { self.fields.get(&number) } /// Get a field by its name. - pub fn get_field_by_name(&self, name: &str) -> Option<&MessageField> - { + pub fn get_field_by_name(&self, name: &str) -> Option<&MessageField> { self.fields_by_name .get(name) .and_then(|id| self.get_field(*id)) } /// Gets a oneof by a oneof reference. - pub fn get_oneof(&self, oneof: OneofRef) -> Option<&Oneof> - { + pub fn get_oneof(&self, oneof: OneofRef) -> Option<&Oneof> { self.oneofs.iter().find(|oo| oo.self_ref == oneof) } } -impl EnumInfo -{ +impl EnumInfo { /// Gets a field by value. /// /// If the field is aliased, an undefined field alias is returned. - pub fn get_field_by_value(&self, value: i64) -> Option<&EnumField> - { + pub fn get_field_by_value(&self, value: i64) -> Option<&EnumField> { self.fields_by_value.get(&value) } } -impl Service -{ +impl Service { /// Gets an `Rpc` info by operation name. 
- pub fn rpc_by_name(&self, name: &str) -> Option<&Rpc> - { + pub fn rpc_by_name(&self, name: &str) -> Option<&Rpc> { self.rpcs_by_name.get(name).map(|idx| &self.rpcs[*idx]) } } -impl ValueType -{ - pub(crate) fn wire_type(&self) -> u8 - { +impl ValueType { + pub(crate) fn wire_type(&self) -> u8 { match self { Self::Double => 1, Self::Float => 5, diff --git a/src/context/builder.rs b/src/context/builder.rs index 80d8563..e1812da 100644 --- a/src/context/builder.rs +++ b/src/context/builder.rs @@ -5,14 +5,12 @@ use std::path::PathBuf; use super::*; #[derive(Default)] -pub(crate) struct ContextBuilder -{ +pub(crate) struct ContextBuilder { pub(crate) packages: Vec, } #[derive(Default, Debug, PartialEq)] -pub(crate) struct PackageBuilder -{ +pub(crate) struct PackageBuilder { pub(crate) path: PathBuf, pub(crate) name: Option, pub(crate) imported_types: Vec, @@ -20,22 +18,19 @@ pub(crate) struct PackageBuilder } #[derive(Debug, PartialEq)] -pub(crate) enum ProtobufItemBuilder -{ +pub(crate) enum ProtobufItemBuilder { Type(ProtobufTypeBuilder), Service(ServiceBuilder), } #[derive(Debug, PartialEq)] -pub(crate) enum ProtobufTypeBuilder -{ +pub(crate) enum ProtobufTypeBuilder { Message(MessageBuilder), Enum(EnumBuilder), } #[derive(Default, Debug, PartialEq, Clone)] -pub(crate) struct MessageBuilder -{ +pub(crate) struct MessageBuilder { pub(crate) name: String, pub(crate) fields: Vec, pub(crate) oneofs: Vec, @@ -44,31 +39,36 @@ pub(crate) struct MessageBuilder } #[derive(Debug, PartialEq, Clone)] -pub(crate) enum InnerTypeBuilder -{ +pub(crate) struct PseudoMapBuilder { + pub(crate) key_type: FieldTypeBuilder, + pub(crate) value_type: FieldTypeBuilder, + pub(crate) field_name: String, + pub(crate) number: u64, + pub(crate) options: Vec, +} + +#[derive(Debug, PartialEq, Clone)] +pub(crate) enum InnerTypeBuilder { Message(MessageBuilder), Enum(EnumBuilder), } #[derive(Default, Debug, PartialEq, Clone)] -pub(crate) struct EnumBuilder -{ +pub(crate) struct EnumBuilder { 
pub(crate) name: String, pub(crate) fields: Vec, pub(crate) options: Vec, } #[derive(Default, Debug, PartialEq)] -pub(crate) struct ServiceBuilder -{ +pub(crate) struct ServiceBuilder { pub(crate) name: String, pub(crate) rpcs: Vec, pub(crate) options: Vec, } #[derive(Debug, PartialEq, Clone)] -pub(crate) struct FieldBuilder -{ +pub(crate) struct FieldBuilder { pub(crate) multiplicity: Multiplicity, pub(crate) field_type: FieldTypeBuilder, pub(crate) name: String, @@ -77,23 +77,20 @@ pub(crate) struct FieldBuilder } #[derive(Default, Debug, PartialEq, Clone)] -pub(crate) struct OneofBuilder -{ +pub(crate) struct OneofBuilder { pub(crate) name: String, pub(crate) fields: Vec, pub(crate) options: Vec, } #[derive(Debug, PartialEq, Clone)] -pub(crate) enum FieldTypeBuilder -{ +pub(crate) enum FieldTypeBuilder { Builtin(ValueType), Unknown(String), } #[derive(Default, Debug, PartialEq)] -pub(crate) struct RpcBuilder -{ +pub(crate) struct RpcBuilder { pub(crate) name: String, pub(crate) input: RpcArgBuilder, pub(crate) output: RpcArgBuilder, @@ -101,16 +98,13 @@ pub(crate) struct RpcBuilder } #[derive(Default, Debug, PartialEq)] -pub(crate) struct RpcArgBuilder -{ +pub(crate) struct RpcArgBuilder { pub(crate) stream: bool, pub(crate) message: String, } -impl ContextBuilder -{ - pub fn build(mut self) -> Result - { +impl ContextBuilder { + pub fn build(mut self) -> Result { let mut cache = BuildCache::default(); for (i, p) in self.packages.iter().enumerate() { p.populate(&mut cache, &mut vec![i])?; @@ -194,21 +188,17 @@ impl ContextBuilder }) } - fn take_type(&mut self, idx: &[usize]) -> ProtobufTypeBuilder - { + fn take_type(&mut self, idx: &[usize]) -> ProtobufTypeBuilder { self.packages[idx[0]].take_type(&idx[1..]) } - fn take_service(&mut self, idx: &[usize]) -> ServiceBuilder - { + fn take_service(&mut self, idx: &[usize]) -> ServiceBuilder { self.packages[idx[0]].take_service(&idx[1..]) } } -impl PackageBuilder -{ - fn populate(&self, cache: &mut BuildCache, idx: 
&mut Vec) -> Result<(), ParseError> - { +impl PackageBuilder { + fn populate(&self, cache: &mut BuildCache, idx: &mut Vec) -> Result<(), ParseError> { let mut path = match &self.name { Some(name) => name.split('.').collect(), None => vec![], @@ -233,8 +223,7 @@ impl PackageBuilder Ok(()) } - fn take_type(&mut self, idx: &[usize]) -> ProtobufTypeBuilder - { + fn take_type(&mut self, idx: &[usize]) -> ProtobufTypeBuilder { match &mut self.types[idx[0]] { ProtobufItemBuilder::Type(t) => match t { ProtobufTypeBuilder::Message(m) => m.take_type(&idx[1..]), @@ -248,8 +237,7 @@ impl PackageBuilder } } - fn take_service(&mut self, idx: &[usize]) -> ServiceBuilder - { + fn take_service(&mut self, idx: &[usize]) -> ServiceBuilder { match &mut self.types[idx[0]] { ProtobufItemBuilder::Service(e) => std::mem::take(e), @@ -259,10 +247,8 @@ impl PackageBuilder } } -impl ProtobufTypeBuilder -{ - fn build(self, self_data: &CacheData, cache: &BuildCache) -> Result - { +impl ProtobufTypeBuilder { + fn build(self, self_data: &CacheData, cache: &BuildCache) -> Result { Ok(match self { ProtobufTypeBuilder::Message(m) => TypeInfo::Message(m.build(self_data, cache)?), ProtobufTypeBuilder::Enum(e) => TypeInfo::Enum(e.build(self_data, cache)?), @@ -270,8 +256,7 @@ impl ProtobufTypeBuilder } } -impl MessageBuilder -{ +impl MessageBuilder { /// Lists types found in this message builder recursively into the build cache. /// /// On error the `path` and `idx` will be left in an undefined state. 
@@ -280,8 +265,7 @@ impl MessageBuilder cache: &mut BuildCache, path: &mut Vec<&'a str>, idx: &mut Vec, - ) -> Result<(), ParseError> - { + ) -> Result<(), ParseError> { path.push(&self.name); let full_name = path.join("."); let cache_idx = cache.types.len(); @@ -320,8 +304,7 @@ impl MessageBuilder Ok(()) } - fn take_type(&mut self, idx: &[usize]) -> ProtobufTypeBuilder - { + fn take_type(&mut self, idx: &[usize]) -> ProtobufTypeBuilder { if idx.is_empty() { ProtobufTypeBuilder::Message(MessageBuilder { name: self.name.clone(), @@ -339,8 +322,7 @@ impl MessageBuilder } } - fn build(self, self_data: &CacheData, cache: &BuildCache) -> Result - { + fn build(self, self_data: &CacheData, cache: &BuildCache) -> Result { let inner_types: Vec<_> = self .inner_types .iter() @@ -419,10 +401,8 @@ impl MessageBuilder } } -impl InnerTypeBuilder -{ - fn clone_name(&self) -> InnerTypeBuilder - { +impl InnerTypeBuilder { + fn clone_name(&self) -> InnerTypeBuilder { match self { InnerTypeBuilder::Message(m) => InnerTypeBuilder::Message(MessageBuilder { name: m.name.clone(), @@ -436,15 +416,13 @@ impl InnerTypeBuilder } } -impl FieldBuilder -{ +impl FieldBuilder { fn build( self, self_data: &CacheData, cache: &BuildCache, oneof: Option, - ) -> Result - { + ) -> Result { let multiplicity = resolve_multiplicity(self.multiplicity, &self.field_type, &self.options); Ok(MessageField { name: self.name, @@ -461,8 +439,7 @@ fn resolve_multiplicity( proto_multiplicity: Multiplicity, field_type: &FieldTypeBuilder, options: &[ProtoOption], -) -> Multiplicity -{ +) -> Multiplicity { // If this isn't a repeated field, the multiplicity follows the proto one (single or optional). 
if proto_multiplicity != Multiplicity::Repeated { return proto_multiplicity; @@ -489,10 +466,8 @@ fn resolve_multiplicity( Multiplicity::RepeatedPacked } -impl FieldTypeBuilder -{ - fn build(self, self_data: &CacheData, cache: &BuildCache) -> Result - { +impl FieldTypeBuilder { + fn build(self, self_data: &CacheData, cache: &BuildCache) -> Result { Ok(match self { FieldTypeBuilder::Builtin(vt) => vt, FieldTypeBuilder::Unknown(s) => { @@ -513,23 +488,20 @@ impl FieldTypeBuilder } } -impl InnerTypeBuilder -{ +impl InnerTypeBuilder { fn populate<'a>( &'a self, cache: &mut BuildCache, path: &mut Vec<&'a str>, idx: &mut Vec, - ) -> Result<(), ParseError> - { + ) -> Result<(), ParseError> { match self { InnerTypeBuilder::Message(m) => m.populate(cache, path, idx), InnerTypeBuilder::Enum(e) => e.populate(cache, path, idx), } } - fn take_type(&mut self, idx: &[usize]) -> ProtobufTypeBuilder - { + fn take_type(&mut self, idx: &[usize]) -> ProtobufTypeBuilder { match self { InnerTypeBuilder::Message(m) => m.take_type(idx), InnerTypeBuilder::Enum(e) => e.take_type(idx), @@ -537,8 +509,7 @@ impl InnerTypeBuilder } } -impl EnumBuilder -{ +impl EnumBuilder { /// Lists types found in this message builder recursively into the build cache. /// /// On error the `path` and `idx` will be left in an undefined state. 
@@ -547,8 +518,7 @@ impl EnumBuilder cache: &mut BuildCache, path: &mut Vec<&'a str>, idx: &mut Vec, - ) -> Result<(), ParseError> - { + ) -> Result<(), ParseError> { path.push(&self.name); let full_name = path.join("."); let cache_idx = cache.types.len(); @@ -573,8 +543,7 @@ impl EnumBuilder Ok(()) } - fn build(self, self_data: &CacheData, cache: &BuildCache) -> Result - { + fn build(self, self_data: &CacheData, cache: &BuildCache) -> Result { let fields_by_name = self .fields .iter() @@ -594,8 +563,7 @@ impl EnumBuilder }) } - fn take_type(&mut self, idx: &[usize]) -> ProtobufTypeBuilder - { + fn take_type(&mut self, idx: &[usize]) -> ProtobufTypeBuilder { if !idx.is_empty() { panic!("Trying to take an inner type from an enum"); } @@ -604,8 +572,7 @@ impl EnumBuilder } } -impl ServiceBuilder -{ +impl ServiceBuilder { /// Lists types found in this message builder recursively into the build cache. /// /// On error the `path` and `idx` will be left in an undefined state. @@ -614,8 +581,7 @@ impl ServiceBuilder cache: &mut BuildCache, path: &mut Vec<&'a str>, idx: &mut Vec, - ) -> Result<(), ParseError> - { + ) -> Result<(), ParseError> { path.push(&self.name); let full_name = path.join("."); let cache_idx = cache.services.len(); @@ -639,8 +605,7 @@ impl ServiceBuilder Ok(()) } - fn build(self, self_data: &CacheData, cache: &BuildCache) -> Result - { + fn build(self, self_data: &CacheData, cache: &BuildCache) -> Result { let rpcs: Vec<_> = self .rpcs .into_iter() @@ -669,10 +634,8 @@ impl ServiceBuilder } } -impl RpcBuilder -{ - fn build(self, self_data: &CacheData, cache: &BuildCache) -> Result - { +impl RpcBuilder { + fn build(self, self_data: &CacheData, cache: &BuildCache) -> Result { Ok(Rpc { name: self.name, input: self.input.build(self_data, cache)?, @@ -682,10 +645,8 @@ impl RpcBuilder } } -impl RpcArgBuilder -{ - fn build(self, rpc_data: &CacheData, cache: &BuildCache) -> Result - { +impl RpcArgBuilder { + fn build(self, rpc_data: &CacheData, cache: 
&BuildCache) -> Result { // Fetch the type data from the cache so we can figure out the type reference. let self_data = match cache.resolve_type(&self.message, &rpc_data.full_name) { Some(data) => data, @@ -715,10 +676,8 @@ impl RpcArgBuilder } } -impl MessageRef -{ - fn from(data: &CacheData) -> Self - { +impl MessageRef { + fn from(data: &CacheData) -> Self { if data.item_type != ItemType::Message { panic!("Trying to create MessageRef for {:?}", data.item_type); } @@ -726,10 +685,8 @@ impl MessageRef } } -impl EnumRef -{ - fn from(data: &CacheData) -> Self - { +impl EnumRef { + fn from(data: &CacheData) -> Self { if data.item_type != ItemType::Enum { panic!("Trying to create EnumRef for {:?}", data.item_type); } @@ -738,26 +695,22 @@ impl EnumRef } #[derive(Default)] -struct BuildCache -{ +struct BuildCache { items: BTreeMap, items_by_idx: BTreeMap, (ItemType, usize)>, types: Vec, services: Vec, } -struct CacheData -{ +struct CacheData { item_type: ItemType, idx_path: Vec, final_idx: usize, full_name: String, } -impl BuildCache -{ - fn resolve_type(&self, relative_name: &str, mut current_path: &str) -> Option<&CacheData> - { +impl BuildCache { + fn resolve_type(&self, relative_name: &str, mut current_path: &str) -> Option<&CacheData> { if let Some(absolute) = relative_name.strip_prefix('.') { return self.type_by_full_name(absolute); } @@ -788,8 +741,7 @@ impl BuildCache } } - fn parent_type(&self, current: &[usize]) -> TypeParent - { + fn parent_type(&self, current: &[usize]) -> TypeParent { match current.len() { 0 | 1 => panic!("Empty type ID path"), 2 => TypeParent::Package(PackageRef(InternalRef(current[0]))), @@ -800,22 +752,19 @@ impl BuildCache } } - fn type_by_full_name(&self, full_name: &str) -> Option<&CacheData> - { + fn type_by_full_name(&self, full_name: &str) -> Option<&CacheData> { self.items .get(full_name) .and_then(|(ty, i)| self.type_by_idx(*ty, *i)) } - fn type_by_idx_path(&self, idx: &[usize]) -> Option<&CacheData> - { + fn 
type_by_idx_path(&self, idx: &[usize]) -> Option<&CacheData> { self.items_by_idx .get(idx) .and_then(|(ty, i)| self.type_by_idx(*ty, *i)) } - fn type_by_idx(&self, item_type: ItemType, idx: usize) -> Option<&CacheData> - { + fn type_by_idx(&self, item_type: ItemType, idx: usize) -> Option<&CacheData> { match item_type { ItemType::Message => self.types.get(idx), ItemType::Enum => self.types.get(idx), diff --git a/src/context/mod.rs b/src/context/mod.rs index 518e150..32d17c9 100644 --- a/src/context/mod.rs +++ b/src/context/mod.rs @@ -1,537 +1,506 @@ -//! Decoding context built from the proto-files. - -use bytes::Bytes; -use snafu::{ResultExt, Snafu}; -use std::collections::{BTreeMap, HashMap}; - -mod api; -mod builder; -mod modify_api; -mod parse; - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -struct InternalRef(usize); - -/// A reference to a message. Can be resolved to `MessageInfo` through a `Context`. -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub struct MessageRef(InternalRef); - -/// A reference to an enum. Can be resolved to `EnumInfo` through a `Context`. -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub struct EnumRef(InternalRef); - -/// A reference to a package. -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub struct PackageRef(InternalRef); - -/// A reference to a service. -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub struct ServiceRef(InternalRef); - -/// A reference to a service. -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub struct OneofRef(InternalRef); - -/// Protofish error type. -#[derive(Debug, Snafu)] -#[snafu(visibility(pub(crate)))] -#[non_exhaustive] -pub enum ParseError -{ - /// Syntax error in the input files. - #[snafu(display("Parsing error: {}", source))] - SyntaxError - { - /// Source error. - source: Box, - }, - - /// Duplicate type. - #[snafu(display("Duplicate type: {}", name))] - DuplicateType - { - /// Type. - name: String, - }, - - /// Unknown type reference. 
- #[snafu(display("Unknown type '{}' in '{}'", name, context))] - TypeNotFound - { - /// Type name. - name: String, - /// Type that referred to the unknown type. - context: String, - }, - - /// Wrong kind of type used in a specific context. - #[snafu(display( - "Invalid type '{}' ({:?}) for {}, expected {:?}", - type_name, - actual, - context, - expected - ))] - InvalidTypeKind - { - /// Type that is of the wrong kind. - type_name: String, - - /// The context where the type was used. - context: &'static str, - - /// Expected item type. - expected: ItemType, - - /// Actual item type. - actual: ItemType, - }, -} - -/// Error modifying the context. -#[derive(Debug, Snafu)] -#[non_exhaustive] -pub enum InsertError -{ - /// A type conflicts with an existing type. - TypeExists - { - /// The previous type that conflicts with the new one. - original: TypeRef, - }, -} - -/// Error modifying a type. -#[derive(Debug)] -#[non_exhaustive] -pub enum MemberInsertError -{ - /// A field with the same number already exists. - NumberConflict, - - /// A field with the same name already exists. - NameConflict, - - /// A field refers to a oneof that does not exist. - MissingOneof, -} - -/// Error modifying a type. -#[derive(Debug)] -#[non_exhaustive] -pub enum OneofInsertError -{ - /// A oneof with the same name already exists. - NameConflict, - - /// The oneof refers to a field that doesn't exist. - FieldNotFound - { - /// Field number the Oneof referenced. - field: u64, - }, -} - -/// Type reference that references either message or enum type. -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub enum TypeRef -{ - /// Message type reference. - Message(MessageRef), - - /// Enum type reference. - Enum(EnumRef), -} - -/// Protobuf item type -#[derive(Clone, Copy, Debug, PartialEq)] -pub enum ItemType -{ - /// `message` item - Message, - - /// `enum` item - Enum, - - /// `service` item - Service, -} - -/// Protofish decoding context. 
-/// -/// Contains type information parsed from the files. Required for decoding -/// incoming Protobuf messages. -#[derive(Default, Debug, PartialEq)] -pub struct Context -{ - packages: Vec, - types: Vec, - types_by_name: HashMap, - services: Vec, - services_by_name: HashMap, -} - -/// Package details. -#[derive(Debug, PartialEq)] -pub struct Package -{ - /// Package name. None for an anonymous package. - name: Option, - - /// Package self reference. - self_ref: PackageRef, - - /// Top level types. - types: Vec, - - /// Services. - services: Vec, -} - -/// Message or enum type. -#[derive(Debug, PartialEq)] -pub enum TypeInfo -{ - /// Message. - Message(MessageInfo), - - /// Enum. - Enum(EnumInfo), -} - -/// Message details -#[derive(Debug, PartialEq)] -#[non_exhaustive] -pub struct MessageInfo -{ - /// Message name. - pub name: String, - - /// Full message name, including package and parent type names. - pub full_name: String, - - /// Parent - pub parent: TypeParent, - - /// `MessageRef` that references this message. - pub self_ref: MessageRef, - - /// `oneof` structures defined within the message. - pub oneofs: Vec, - - /// References to the inner types defined within this message. - pub inner_types: Vec, - - // Using BTreeMap here to ensure ordering. - fields: BTreeMap, - fields_by_name: BTreeMap, -} - -/// Reference to a type parent. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum TypeParent -{ - /// Reference to a package for top-level types. - Package(PackageRef), - - /// Reference to a message for inner types. - Message(MessageRef), -} - -/// Enum details -#[derive(Debug, PartialEq)] -#[non_exhaustive] -pub struct EnumInfo -{ - /// Enum name. - pub name: String, - - /// Full message name, including package and parent type names. - pub full_name: String, - - /// Parent - pub parent: TypeParent, - - /// `EnumRef` that references this enum. 
- pub self_ref: EnumRef, - - fields_by_value: BTreeMap, - fields_by_name: BTreeMap, -} - -/// Message field details. -#[derive(Debug, PartialEq)] -#[non_exhaustive] -pub struct MessageField -{ - /// Field name. - pub name: String, - - /// Field number. - pub number: u64, - - /// Field type - pub field_type: ValueType, - - /// True, if this field is a repeated field. - pub multiplicity: Multiplicity, - - /// Field options. - pub options: Vec, - - /// Index to the ´oneof` structure in the parent type if this field is part of a `oneof`. - pub oneof: Option, -} - -/// Defines the multiplicity of the field values. -#[derive(Debug, PartialEq, Clone)] -pub enum Multiplicity -{ - /// Field is not repeated. - Single, - - /// Field may be repeated. - Repeated, - - /// Field is repeated by packing. - RepeatedPacked, - - /// Field is optional. - Optional, -} - -/// Message `oneof` details. -#[derive(Debug, PartialEq)] -#[non_exhaustive] -pub struct Oneof -{ - /// Name of the `oneof` structure. - pub name: String, - - /// Self reference of the `Oneof` in the owning type. - pub self_ref: OneofRef, - - /// Field numbers of the fields contained in the `oneof`. - pub fields: Vec, - - /// Options. - pub options: Vec, -} - -/// Enum field details. -#[derive(Debug, PartialEq, Clone)] -#[non_exhaustive] -pub struct EnumField -{ - /// Enum field name. - pub name: String, - - /// Enum field value. - pub value: i64, - - /// Options. - pub options: Vec, -} - -/// Field value types. -#[derive(Clone, Debug, PartialEq)] -pub enum ValueType -{ - /// `double` - Double, - - /// `float` - Float, - - /// `int32` - Int32, - - /// `int64` - Int64, - - /// `uint32` - UInt32, - - /// `uint64` - UInt64, - - /// `sint32` - SInt32, - - /// `sint64` - SInt64, - - /// `fixed32` - Fixed32, - - /// `fixed64` - Fixed64, - - /// `sfixed32` - SFixed32, - - /// `sfixed64` - SFixed64, - - /// `bool` - Bool, - - /// `string` - String, - - /// `bytes` - Bytes, - - /// A message type. 
- Message(MessageRef), - - /// An enum type. - Enum(EnumRef), -} - -/// Service details -#[derive(Debug, PartialEq)] -#[non_exhaustive] -pub struct Service -{ - /// Service name. - pub name: String, - - /// Full service name, including the package name. - pub full_name: String, - - /// Service self reference. - pub self_ref: ServiceRef, - - /// Package that contains the service. - pub parent: PackageRef, - - /// List of `rpc` operations defined in the service. - pub rpcs: Vec, - - /// Options. - pub options: Vec, - - rpcs_by_name: HashMap, -} - -/// Rpc operation -#[derive(Debug, PartialEq)] -#[non_exhaustive] -pub struct Rpc -{ - /// Operation name. - pub name: String, - - /// Input details. - pub input: RpcArg, - - /// Output details. - pub output: RpcArg, - - /// Options. - pub options: Vec, -} - -/// Rpc operation input or output details. -#[derive(Debug, PartialEq)] -#[non_exhaustive] -pub struct RpcArg -{ - /// References to the message type. - pub message: MessageRef, - - /// True, if this is a stream. - pub stream: bool, -} - -/// A single option. -#[derive(Debug, PartialEq, Clone)] -pub struct ProtoOption -{ - /// Option name. - pub name: String, - - /// Optionn value. - pub value: Constant, -} - -/// Constant value, used for options. -#[derive(Debug, PartialEq, Clone)] -pub enum Constant -{ - /// An ident `foo.bar.baz`. - Ident(String), - - /// An integer constant. - Integer(i64), - - /// A floating point constant. - Float(f64), - - /// A string constant. - /// - /// The string isn't guaranteed to be well formed UTF-8 so it's stored as - /// Bytes here. - String(Bytes), - - /// A boolean constant. 
- Bool(bool), -} - -#[cfg(test)] -mod test -{ - use super::*; - - #[test] - fn basic_package() - { - let ctx = Context::parse(&[r#" - syntax = "proto3"; - message Message {} - "#]) - .unwrap(); - - let m = ctx.get_message("Message").unwrap(); - assert_eq!(m.parent, TypeParent::Package(PackageRef(InternalRef(0)))); - } - - #[test] - fn basic_multiple_package() - { - let ctx = Context::parse(&[ - r#" - syntax = "proto3"; - package First; - message Message {} - "#, - r#" - syntax = "proto3"; - package Second; - message Message {} - "#, - ]) - .unwrap(); - - let m = ctx.get_message("First.Message").unwrap(); - let pkg_ref = match m.parent { - TypeParent::Package(p) => p, - _ => panic!("Not a package reference: {:?}", m.parent), - }; - let pkg = ctx.resolve_package(pkg_ref); - assert_eq!(m.parent, TypeParent::Package(PackageRef(InternalRef(0)))); - assert_eq!(pkg.name.as_deref(), Some("First")); - assert_eq!(pkg.types.len(), 1); - - let m = ctx.get_message("Second.Message").unwrap(); - let pkg_ref = match m.parent { - TypeParent::Package(p) => p, - _ => panic!("Not a package reference: {:?}", m.parent), - }; - let pkg = ctx.resolve_package(pkg_ref); - assert_eq!(m.parent, TypeParent::Package(PackageRef(InternalRef(1)))); - assert_eq!(pkg.name.as_deref(), Some("Second")); - assert_eq!(pkg.types.len(), 1); - } -} +//! Decoding context built from the proto-files. + +use bytes::Bytes; +use snafu::{ResultExt, Snafu}; +use std::collections::{BTreeMap, HashMap}; + +mod api; +mod builder; +mod modify_api; +mod parse; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +struct InternalRef(usize); + +/// A reference to a message. Can be resolved to `MessageInfo` through a `Context`. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct MessageRef(InternalRef); + +/// A reference to an enum. Can be resolved to `EnumInfo` through a `Context`. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct EnumRef(InternalRef); + +/// A reference to a package. 
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct PackageRef(InternalRef); + +/// A reference to a service. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct ServiceRef(InternalRef); + +/// A reference to a oneof. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct OneofRef(InternalRef); + +/// Protofish error type. +#[derive(Debug, Snafu)] +#[snafu(visibility(pub(crate)))] +#[non_exhaustive] +pub enum ParseError { + /// Syntax error in the input files. + #[snafu(display("Parsing error: {}", source))] + SyntaxError { + /// Source error. + source: Box, + }, + + /// Duplicate type. + #[snafu(display("Duplicate type: {}", name))] + DuplicateType { + /// Type. + name: String, + }, + + /// Unknown type reference. + #[snafu(display("Unknown type '{}' in '{}'", name, context))] + TypeNotFound { + /// Type name. + name: String, + /// Type that referred to the unknown type. + context: String, + }, + + /// Wrong kind of type used in a specific context. + #[snafu(display( + "Invalid type '{}' ({:?}) for {}, expected {:?}", + type_name, + actual, + context, + expected + ))] + InvalidTypeKind { + /// Type that is of the wrong kind. + type_name: String, + + /// The context where the type was used. + context: &'static str, + + /// Expected item type. + expected: ItemType, + + /// Actual item type. + actual: ItemType, + }, +} + +/// Error modifying the context. +#[derive(Debug, Snafu)] +#[non_exhaustive] +pub enum InsertError { + /// A type conflicts with an existing type. + TypeExists { + /// The previous type that conflicts with the new one. + original: TypeRef, + }, +} + +/// Error modifying a type. +#[derive(Debug)] +#[non_exhaustive] +pub enum MemberInsertError { + /// A field with the same number already exists. + NumberConflict, + + /// A field with the same name already exists. + NameConflict, + + /// A field refers to a oneof that does not exist. + MissingOneof, +} + +/// Error modifying a type. 
+#[derive(Debug)] +#[non_exhaustive] +pub enum OneofInsertError { + /// A oneof with the same name already exists. + NameConflict, + + /// The oneof refers to a field that doesn't exist. + FieldNotFound { + /// Field number the Oneof referenced. + field: u64, + }, +} + +/// Type reference that references either message or enum type. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum TypeRef { + /// Message type reference. + Message(MessageRef), + + /// Enum type reference. + Enum(EnumRef), +} + +/// Protobuf item type +#[derive(Clone, Copy, Debug, PartialEq)] +pub enum ItemType { + /// `message` item + Message, + + /// `enum` item + Enum, + + /// `service` item + Service, +} + +/// Protofish decoding context. +/// +/// Contains type information parsed from the files. Required for decoding +/// incoming Protobuf messages. +#[derive(Default, Debug, PartialEq)] +pub struct Context { + packages: Vec, + types: Vec, + types_by_name: HashMap, + services: Vec, + services_by_name: HashMap, +} + +/// Package details. +#[derive(Debug, PartialEq)] +pub struct Package { + /// Package name. None for an anonymous package. + name: Option, + + /// Package self reference. + self_ref: PackageRef, + + /// Top level types. + types: Vec, + + /// Services. + services: Vec, +} + +/// Message or enum type. +#[derive(Debug, PartialEq)] +pub enum TypeInfo { + /// Message. + Message(MessageInfo), + + /// Enum. + Enum(EnumInfo), +} + +/// Message details +#[derive(Debug, PartialEq)] +#[non_exhaustive] +pub struct MessageInfo { + /// Message name. + pub name: String, + + /// Full message name, including package and parent type names. + pub full_name: String, + + /// Parent + pub parent: TypeParent, + + /// `MessageRef` that references this message. + pub self_ref: MessageRef, + + /// `oneof` structures defined within the message. + pub oneofs: Vec, + + /// References to the inner types defined within this message. + pub inner_types: Vec, + + // Using BTreeMap here to ensure ordering. 
+ fields: BTreeMap, + fields_by_name: BTreeMap, +} + +/// Reference to a type parent. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum TypeParent { + /// Reference to a package for top-level types. + Package(PackageRef), + + /// Reference to a message for inner types. + Message(MessageRef), +} + +/// Enum details +#[derive(Debug, PartialEq)] +#[non_exhaustive] +pub struct EnumInfo { + /// Enum name. + pub name: String, + + /// Full enum name, including package and parent type names. + pub full_name: String, + + /// Parent + pub parent: TypeParent, + + /// `EnumRef` that references this enum. + pub self_ref: EnumRef, + + fields_by_value: BTreeMap, + fields_by_name: BTreeMap, +} + +/// Message field details. +#[derive(Debug, PartialEq)] +#[non_exhaustive] +pub struct MessageField { + /// Field name. + pub name: String, + + /// Field number. + pub number: u64, + + /// Field type + pub field_type: ValueType, + + /// Multiplicity of the field: single, repeated, repeated packed or optional. + pub multiplicity: Multiplicity, + + /// Field options. + pub options: Vec, + + /// Index to the `oneof` structure in the parent type if this field is part of a `oneof`. + pub oneof: Option, +} + +/// Defines the multiplicity of the field values. +#[derive(Debug, PartialEq, Clone)] +pub enum Multiplicity { + /// Field is not repeated. + Single, + + /// Field may be repeated. + Repeated, + + /// Field is repeated by packing. + RepeatedPacked, + + /// Field is optional. + Optional, +} + +/// Message `oneof` details. +#[derive(Debug, PartialEq)] +#[non_exhaustive] +pub struct Oneof { + /// Name of the `oneof` structure. + pub name: String, + + /// Self reference of the `Oneof` in the owning type. + pub self_ref: OneofRef, + + /// Field numbers of the fields contained in the `oneof`. + pub fields: Vec, + + /// Options. + pub options: Vec, +} + +/// Enum field details. +#[derive(Debug, PartialEq, Clone)] +#[non_exhaustive] +pub struct EnumField { + /// Enum field name.
+ pub name: String, + + /// Enum field value. + pub value: i64, + + /// Options. + pub options: Vec, +} + +/// Field value types. +#[derive(Clone, Debug, PartialEq)] +pub enum ValueType { + /// `double` + Double, + + /// `float` + Float, + + /// `int32` + Int32, + + /// `int64` + Int64, + + /// `uint32` + UInt32, + + /// `uint64` + UInt64, + + /// `sint32` + SInt32, + + /// `sint64` + SInt64, + + /// `fixed32` + Fixed32, + + /// `fixed64` + Fixed64, + + /// `sfixed32` + SFixed32, + + /// `sfixed64` + SFixed64, + + /// `bool` + Bool, + + /// `string` + String, + + /// `bytes` + Bytes, + + /// A message type. + Message(MessageRef), + + /// An enum type. + Enum(EnumRef), +} + +/// Service details +#[derive(Debug, PartialEq)] +#[non_exhaustive] +pub struct Service { + /// Service name. + pub name: String, + + /// Full service name, including the package name. + pub full_name: String, + + /// Service self reference. + pub self_ref: ServiceRef, + + /// Package that contains the service. + pub parent: PackageRef, + + /// List of `rpc` operations defined in the service. + pub rpcs: Vec, + + /// Options. + pub options: Vec, + + rpcs_by_name: HashMap, +} + +/// Rpc operation +#[derive(Debug, PartialEq)] +#[non_exhaustive] +pub struct Rpc { + /// Operation name. + pub name: String, + + /// Input details. + pub input: RpcArg, + + /// Output details. + pub output: RpcArg, + + /// Options. + pub options: Vec, +} + +/// Rpc operation input or output details. +#[derive(Debug, PartialEq)] +#[non_exhaustive] +pub struct RpcArg { + /// Reference to the message type. + pub message: MessageRef, + + /// True, if this is a stream. + pub stream: bool, +} + +/// A single option. +#[derive(Debug, PartialEq, Clone)] +pub struct ProtoOption { + /// Option name. + pub name: String, + + /// Option value. + pub value: Constant, +} + +/// Constant value, used for options. +#[derive(Debug, PartialEq, Clone)] +pub enum Constant { + /// An ident `foo.bar.baz`.
+ Ident(String), + + /// An integer constant. + Integer(i64), + + /// A floating point constant. + Float(f64), + + /// A string constant. + /// + /// The string isn't guaranteed to be well formed UTF-8 so it's stored as + /// Bytes here. + String(Bytes), + + /// A boolean constant. + Bool(bool), +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn basic_package() { + let ctx = Context::parse(&[r#" + syntax = "proto3"; + message Message {} + "#]) + .unwrap(); + + let m = ctx.get_message("Message").unwrap(); + assert_eq!(m.parent, TypeParent::Package(PackageRef(InternalRef(0)))); + } + + #[test] + fn basic_multiple_package() { + let ctx = Context::parse(&[ + r#" + syntax = "proto3"; + package First; + message Message {} + "#, + r#" + syntax = "proto3"; + package Second; + message Message {} + "#, + ]) + .unwrap(); + + let m = ctx.get_message("First.Message").unwrap(); + let pkg_ref = match m.parent { + TypeParent::Package(p) => p, + _ => panic!("Not a package reference: {:?}", m.parent), + }; + let pkg = ctx.resolve_package(pkg_ref); + assert_eq!(m.parent, TypeParent::Package(PackageRef(InternalRef(0)))); + assert_eq!(pkg.name.as_deref(), Some("First")); + assert_eq!(pkg.types.len(), 1); + + let m = ctx.get_message("Second.Message").unwrap(); + let pkg_ref = match m.parent { + TypeParent::Package(p) => p, + _ => panic!("Not a package reference: {:?}", m.parent), + }; + let pkg = ctx.resolve_package(pkg_ref); + assert_eq!(m.parent, TypeParent::Package(PackageRef(InternalRef(1)))); + assert_eq!(pkg.name.as_deref(), Some("Second")); + assert_eq!(pkg.types.len(), 1); + } +} diff --git a/src/context/modify_api.rs b/src/context/modify_api.rs index 2a80035..abfa732 100644 --- a/src/context/modify_api.rs +++ b/src/context/modify_api.rs @@ -1,24 +1,20 @@ use super::*; -impl Context -{ +impl Context { /// Insert a new message definition to the context. 
- pub fn insert_message(&mut self, ty: MessageInfo) -> Result - { + pub fn insert_message(&mut self, ty: MessageInfo) -> Result { self.insert_type(TypeInfo::Message(ty)).map(MessageRef) } /// Insert a new enum definition to the context. - pub fn insert_enum(&mut self, ty: EnumInfo) -> Result - { + pub fn insert_enum(&mut self, ty: EnumInfo) -> Result { self.insert_type(TypeInfo::Enum(ty)).map(EnumRef) } /// Insert a new package to the context. /// /// Returns an error if the package with the same name already exists. - pub fn insert_package(&mut self, mut pkg: Package) -> Result - { + pub fn insert_package(&mut self, mut pkg: Package) -> Result { let pkg_ref = PackageRef(InternalRef(self.packages.len())); for existing in &self.packages { if existing.name == pkg.name { @@ -31,8 +27,7 @@ impl Context Ok(pkg_ref) } - fn insert_type(&mut self, mut ty: TypeInfo) -> Result - { + fn insert_type(&mut self, mut ty: TypeInfo) -> Result { use std::collections::hash_map::Entry; // First validate the operation. We'll want to ensure the operation succeeds before we make @@ -108,11 +103,9 @@ impl Context } } -impl Package -{ +impl Package { /// Create a new package. - pub fn new(name: Option) -> Self - { + pub fn new(name: Option) -> Self { Self { name, self_ref: PackageRef(InternalRef(0)), @@ -122,14 +115,12 @@ impl Package } } -impl MessageInfo -{ +impl MessageInfo { /// Create a new message info. /// /// Before inserting the message info into a [`Context`] certain fields such as `self_ref` or /// `full_name` are not valid. - pub fn new(name: String, parent: TypeParent) -> Self - { + pub fn new(name: String, parent: TypeParent) -> Self { MessageInfo { name, parent, @@ -145,8 +136,7 @@ impl MessageInfo } /// Add a field to the type. 
- pub fn add_field(&mut self, field: MessageField) -> Result<(), MemberInsertError> - { + pub fn add_field(&mut self, field: MessageField) -> Result<(), MemberInsertError> { use std::collections::btree_map::Entry; let num = field.number; @@ -174,8 +164,7 @@ impl MessageInfo } /// Add a oneof record to the message. - pub fn add_oneof(&mut self, mut oneof: Oneof) -> Result - { + pub fn add_oneof(&mut self, mut oneof: Oneof) -> Result { let oneof_ref = OneofRef(InternalRef(self.oneofs.len())); for o in &self.oneofs { if o.name == oneof.name { @@ -205,11 +194,9 @@ impl MessageInfo } } -impl MessageField -{ +impl MessageField { /// Create a new message field. - pub fn new(name: String, number: u64, field_type: ValueType) -> Self - { + pub fn new(name: String, number: u64, field_type: ValueType) -> Self { Self { name, number, @@ -221,11 +208,9 @@ impl MessageField } } -impl Oneof -{ +impl Oneof { /// Create a new Oneof definition. - pub fn new(name: String) -> Self - { + pub fn new(name: String) -> Self { Self { name, self_ref: OneofRef(InternalRef(0)), @@ -235,11 +220,9 @@ impl Oneof } } -impl EnumInfo -{ +impl EnumInfo { /// Create a new enum info. - pub fn new(name: String, parent: TypeParent) -> Self - { + pub fn new(name: String, parent: TypeParent) -> Self { Self { name, parent, @@ -251,8 +234,7 @@ impl EnumInfo } /// Add a field to the enum definition. - pub fn add_field(&mut self, field: EnumField) -> Result<(), MemberInsertError> - { + pub fn add_field(&mut self, field: EnumField) -> Result<(), MemberInsertError> { use std::collections::btree_map::Entry; let value = field.value; @@ -272,11 +254,9 @@ impl EnumInfo } } -impl EnumField -{ +impl EnumField { /// Create a new enum field. 
- pub fn new(name: String, value: i64) -> Self - { + pub fn new(name: String, value: i64) -> Self { Self { name, value, diff --git a/src/context/parse.rs b/src/context/parse.rs index 2291819..81e1a52 100644 --- a/src/context/parse.rs +++ b/src/context/parse.rs @@ -1,735 +1,760 @@ -use bytes::{BufMut, Bytes, BytesMut}; -use pest::{ - iterators::{Pair, Pairs}, - Parser, -}; - -use super::builder::*; -use super::*; - -#[derive(pest_derive::Parser)] -#[grammar = "proto.pest"] -struct ProtoParser; - -impl Context -{ - /// Parses the files and creates a decoding context. - pub fn parse(files: T) -> Result - where - T: IntoIterator, - S: AsRef, - { - let builder = ContextBuilder { - packages: files - .into_iter() - .map(|f| PackageBuilder::parse_str(f.as_ref())) - .collect::>()?, - }; - - builder.build() - } -} - -impl PackageBuilder -{ - pub fn parse_str(input: &str) -> Result - { - let pairs = ProtoParser::parse(Rule::proto, input) - .map_err(|e| Box::new(e) as Box) - .context(SyntaxError {})?; - - let mut current_package = PackageBuilder::default(); - for pair in pairs { - for inner in pair.into_inner() { - match inner.as_rule() { - Rule::syntax => {} - Rule::topLevelDef => current_package - .types - .push(ProtobufItemBuilder::parse(inner)), - Rule::import => {} - Rule::package => { - current_package.name = - Some(inner.into_inner().next().unwrap().as_str().to_string()) - } - Rule::option => {} - Rule::EOI => {} - r => unreachable!("{:?}: {:?}", r, inner), - } - } - } - - Ok(current_package) - } -} - -impl ProtobufItemBuilder -{ - pub fn parse(p: Pair) -> Self - { - let pair = p.into_inner().next().unwrap(); - match pair.as_rule() { - Rule::message => { - ProtobufItemBuilder::Type(ProtobufTypeBuilder::Message(MessageBuilder::parse(pair))) - } - Rule::enum_ => { - ProtobufItemBuilder::Type(ProtobufTypeBuilder::Enum(EnumBuilder::parse(pair))) - } - Rule::service => ProtobufItemBuilder::Service(ServiceBuilder::parse(pair)), - r => unreachable!("{:?}: {:?}", r, pair), - } 
- } -} - -impl MessageBuilder -{ - pub fn parse(p: Pair) -> Self - { - let mut inner = p.into_inner(); - let name = inner.next().unwrap().as_str().to_string(); - - let mut fields = vec![]; - let mut oneofs = vec![]; - let mut inner_types = vec![]; - let mut options = vec![]; - let body = inner.next().unwrap(); - for p in body.into_inner() { - match p.as_rule() { - Rule::field => fields.push(FieldBuilder::parse(p)), - Rule::enum_ => inner_types.push(InnerTypeBuilder::Enum(EnumBuilder::parse(p))), - Rule::message => { - inner_types.push(InnerTypeBuilder::Message(MessageBuilder::parse(p))) - } - Rule::option => options.push(ProtoOption::parse(p)), - Rule::oneof => oneofs.push(OneofBuilder::parse(p)), - Rule::mapField => unimplemented!("Maps are not supported"), - Rule::reserved => {} // We don't need to care about reserved field numbers. - Rule::emptyStatement => {} - r => unreachable!("{:?}: {:?}", r, p), - } - } - - MessageBuilder { - name, - fields, - oneofs, - inner_types, - options, - } - } -} - -impl EnumBuilder -{ - fn parse(p: Pair) -> EnumBuilder - { - let mut inner = p.into_inner(); - let name = inner.next().unwrap().as_str().to_string(); - - let mut fields = vec![]; - let mut options = vec![]; - let body = inner.next().unwrap(); - for p in body.into_inner() { - match p.as_rule() { - Rule::enumField => { - let mut inner = p.into_inner(); - fields.push(EnumField { - name: inner.next().unwrap().as_str().to_string(), - value: parse_int_literal(inner.next().unwrap()), - options: ProtoOption::parse_options(inner), - }) - } - Rule::option => options.push(ProtoOption::parse(p)), - Rule::emptyStatement => {} - r => unreachable!("{:?}: {:?}", r, p), - } - } - - EnumBuilder { - name, - fields, - options, - } - } -} - -impl ServiceBuilder -{ - pub fn parse(p: Pair) -> Self - { - let mut inner = p.into_inner(); - let name = inner.next().unwrap(); - let mut rpcs = vec![]; - let mut options = vec![]; - for p in inner { - match p.as_rule() { - Rule::option => 
options.push(ProtoOption::parse(p)), - Rule::rpc => rpcs.push(RpcBuilder::parse(p)), - Rule::emptyStatement => {} - r => unreachable!("{:?}: {:?}", r, p), - } - } - - ServiceBuilder { - name: name.as_str().to_string(), - rpcs, - options, - } - } -} - -impl FieldBuilder -{ - pub fn parse(p: Pair) -> Self - { - let mut inner = p.into_inner(); - let multiplicity = match inner.next().unwrap().into_inner().next() { - Some(t) => { - let multiplicity = t.into_inner().next().unwrap().as_rule(); - match multiplicity { - Rule::optional => Multiplicity::Optional, - Rule::repeated => Multiplicity::Repeated, - r => unreachable!("{:?}: {:?}", r, multiplicity), - } - } - None => Multiplicity::Single, - }; - let field_type = parse_field_type(inner.next().unwrap().as_str()); - let name = inner.next().unwrap().as_str().to_string(); - let number = parse_uint_literal(inner.next().unwrap()); - - let options = match inner.next() { - Some(p) => ProtoOption::parse_options(p.into_inner()), - None => vec![], - }; - - FieldBuilder { - multiplicity, - field_type, - name, - number, - options, - } - } - - pub fn parse_oneof(p: Pair) -> Self - { - let mut inner = p.into_inner(); - let field_type = parse_field_type(inner.next().unwrap().as_str()); - let name = inner.next().unwrap().as_str().to_string(); - let number = parse_uint_literal(inner.next().unwrap()); - - let options = match inner.next() { - Some(p) => ProtoOption::parse_options(p.into_inner()), - None => vec![], - }; - - FieldBuilder { - multiplicity: Multiplicity::Single, - field_type, - name, - number, - options, - } - } -} - -impl OneofBuilder -{ - pub fn parse(p: Pair) -> Self - { - let mut inner = p.into_inner(); - let name = inner.next().unwrap().as_str().to_string(); - let mut options = Vec::new(); - let mut fields = vec![]; - for p in inner { - match p.as_rule() { - Rule::option => options.push(ProtoOption::parse(p)), - Rule::oneofField => fields.push(FieldBuilder::parse_oneof(p)), - Rule::emptyStatement => {} - r => 
unreachable!("{:?}: {:?}", r, p), - } - } - OneofBuilder { - name, - fields, - options, - } - } -} - -fn parse_field_type(t: &str) -> FieldTypeBuilder -{ - FieldTypeBuilder::Builtin(match t { - "double" => ValueType::Double, - "float" => ValueType::Float, - "int32" => ValueType::Int32, - "int64" => ValueType::Int64, - "uint32" => ValueType::UInt32, - "uint64" => ValueType::UInt64, - "sint32" => ValueType::SInt32, - "sint64" => ValueType::SInt64, - "fixed32" => ValueType::Fixed32, - "fixed64" => ValueType::Fixed64, - "sfixed32" => ValueType::SFixed32, - "sfixed64" => ValueType::SFixed64, - "bool" => ValueType::Bool, - "string" => ValueType::String, - "bytes" => ValueType::Bytes, - _ => return FieldTypeBuilder::Unknown(t.to_string()), - }) -} - -impl RpcBuilder -{ - pub fn parse(p: Pair) -> Self - { - let mut inner = p.into_inner(); - let name = inner.next().unwrap(); - - let input = RpcArgBuilder::parse(inner.next().unwrap()); - let output = RpcArgBuilder::parse(inner.next().unwrap()); - - let mut options = vec![]; - for p in inner { - match p.as_rule() { - Rule::option => options.push(ProtoOption::parse(p)), - Rule::emptyStatement => {} - r => unreachable!("{:?}: {:?}", r, p), - } - } - - RpcBuilder { - name: name.as_str().to_string(), - input, - output, - options, - } - } -} - -impl RpcArgBuilder -{ - pub fn parse(p: Pair) -> Self - { - let mut inner = p.into_inner(); - RpcArgBuilder { - stream: inner.next().unwrap().into_inner().next().is_some(), - message: inner.next().unwrap().as_str().to_string(), - } - } -} - -pub fn parse_uint_literal(p: Pair) -> u64 -{ - match p.as_rule() { - Rule::fieldNumber => parse_uint_literal(p.into_inner().next().unwrap()), - Rule::intLit => { - let mut inner = p.into_inner(); - let lit = inner.next().unwrap(); - match lit.as_rule() { - Rule::decimalLit => str::parse(lit.as_str()).unwrap(), - Rule::octalLit => u64::from_str_radix(&lit.as_str()[1..], 8).unwrap(), - Rule::hexLit => u64::from_str_radix(&lit.as_str()[2..], 16).unwrap(), 
- r => unreachable!("{:?}: {:?}", r, lit), - } - } - r => unreachable!("{:?}: {:?}", r, p), - } -} - -pub fn parse_int_literal(p: Pair) -> i64 -{ - match p.as_rule() { - Rule::intLit => { - let mut inner = p.into_inner(); - let sign = inner.next().unwrap(); - let (sign, lit) = match sign.as_rule() { - Rule::sign if sign.as_str() == "-" => (-1, inner.next().unwrap()), - Rule::sign if sign.as_str() == "+" => (1, inner.next().unwrap()), - _ => (1, sign), - }; - match lit.as_rule() { - Rule::decimalLit => sign * str::parse::(lit.as_str()).unwrap(), - Rule::octalLit => sign * i64::from_str_radix(lit.as_str(), 8).unwrap(), - Rule::hexLit => sign * i64::from_str_radix(&lit.as_str()[2..], 16).unwrap(), - r => unreachable!("{:?}: {:?}", r, lit), - } - } - r => unreachable!("{:?}: {:?}", r, p), - } -} - -pub fn parse_float_literal(p: Pair) -> f64 -{ - match p.as_rule() { - Rule::floatLit => p.as_str().parse::().unwrap(), - r => unreachable!("{:?}: {:?}", r, p), - } -} - -impl ProtoOption -{ - fn parse(p: Pair) -> Self - { - let mut inner = p.into_inner(); - Self { - name: parse_ident(inner.next().unwrap()), - value: Constant::parse(inner.next().unwrap()), - } - } - - fn parse_options(pairs: Pairs) -> Vec - { - pairs - .map(|p| match p.as_rule() { - Rule::fieldOption => Self::parse(p), - Rule::enumValueOption => Self::parse(p), - Rule::option => Self::parse(p), - r => unreachable!("{:?}: {:?}", r, p), - }) - .collect() - } -} - -impl Constant -{ - fn parse(p: Pair) -> Self - { - let p = p.into_inner().next().unwrap(); - match p.as_rule() { - Rule::fullIdent => Constant::Ident(parse_ident(p)), - Rule::intLit => Constant::Integer(parse_int_literal(p)), - Rule::floatLit => Constant::Float(parse_float_literal(p)), - Rule::strLit => Constant::String(parse_string_literal(p)), - Rule::boolLit => Constant::Bool(p.as_str() == "true"), - r => unreachable!("{:?}: {:?}", r, p), - } - } -} - -fn parse_ident(p: Pair) -> String -{ - let mut ident = vec![]; - let mut inner = p.into_inner(); 
- - let first = inner.next().unwrap(); - match first.as_rule() { - Rule::ident => ident.push(first.as_str().to_string()), - Rule::fullIdent => ident.push(format!("({})", parse_ident(first))), - r => unreachable!("{:?}: {:?}", r, first), - } - - for other in inner { - match other.as_rule() { - Rule::ident => ident.push(other.as_str().to_string()), - r => unreachable!("{:?}: {:?}", r, other), - } - } - - ident.join(".") -} - -fn parse_string_literal(s: Pair) -> Bytes -{ - let inner = s.into_inner(); - let mut output = BytesMut::new(); - for c in inner { - let c = c.into_inner().next().unwrap(); - match c.as_rule() { - Rule::hexEscape => { - output.put_u8( - u8::from_str_radix(c.into_inner().next().unwrap().as_str(), 16).unwrap(), - ); - } - Rule::octEscape => { - output.put_u8( - u8::from_str_radix(c.into_inner().next().unwrap().as_str(), 8).unwrap(), - ); - } - Rule::charEscape => match c.into_inner().next().unwrap().as_str() { - "a" => output.put_u8(0x07), - "b" => output.put_u8(0x08), - "f" => output.put_u8(0x0C), - "n" => output.put_u8(0x0A), - "r" => output.put_u8(0x0D), - "t" => output.put_u8(0x09), - "v" => output.put_u8(0x0B), - "\\" => output.put_u8(0x5C), - "\'" => output.put_u8(0x27), - "\"" => output.put_u8(0x22), - o => unreachable!("Invalid escape sequence \\{}", o), - }, - Rule::anyChar => output.put(c.as_str().as_ref()), - r => unreachable!("{:?}: {:?}", r, c), - } - } - output.freeze() -} - -#[cfg(test)] -mod test -{ - use super::*; - - #[test] - fn empty() - { - assert_eq!( - PackageBuilder::parse_str( - r#" - syntax = "proto3"; - "# - ) - .unwrap(), - PackageBuilder::default(), - ); - } - - #[test] - fn package() - { - assert_eq!( - PackageBuilder::parse_str( - r#" - syntax = "proto3"; - package Test; - "# - ) - .unwrap(), - PackageBuilder { - name: Some("Test".to_string()), - ..Default::default() - } - ); - } - - #[test] - fn bom() - { - assert_eq!( - PackageBuilder::parse_str(&format!( - "\u{FEFF}{}", - r#" - syntax = "proto3"; - package Test; - 
"# - )) - .unwrap(), - PackageBuilder { - name: Some("Test".to_string()), - ..Default::default() - } - ); - } - - #[test] - fn message() - { - assert_eq!( - PackageBuilder::parse_str( - r#" - syntax = "proto3"; - - message MyMessage { - int32 value = 1; - } - "# - ) - .unwrap(), - PackageBuilder { - types: vec![ProtobufItemBuilder::Type(ProtobufTypeBuilder::Message( - MessageBuilder { - name: "MyMessage".to_string(), - fields: vec![FieldBuilder { - multiplicity: Multiplicity::Single, - field_type: FieldTypeBuilder::Builtin(ValueType::Int32), - name: "value".to_string(), - number: 1, - options: vec![], - }], - ..Default::default() - } - )),], - ..Default::default() - } - ); - } - - #[test] - fn pbenum() - { - assert_eq!( - PackageBuilder::parse_str( - r#" - syntax = "proto3"; - - enum MyEnum { - a = 1; - b = -1; - } - "# - ) - .unwrap(), - PackageBuilder { - types: vec![ProtobufItemBuilder::Type(ProtobufTypeBuilder::Enum( - EnumBuilder { - name: "MyEnum".to_string(), - fields: vec![ - EnumField { - name: "a".to_string(), - value: 1, - options: vec![], - }, - EnumField { - name: "b".to_string(), - value: -1, - options: vec![], - } - ], - ..Default::default() - } - )),], - ..Default::default() - } - ); - } - - #[test] - fn service() - { - assert_eq!( - PackageBuilder::parse_str( - r#" - syntax = "proto3"; - - service MyService { - rpc function( Foo ) returns ( stream Bar ); - } - "# - ) - .unwrap(), - PackageBuilder { - types: vec![ProtobufItemBuilder::Service(ServiceBuilder { - name: "MyService".to_string(), - rpcs: vec![RpcBuilder { - name: "function".to_string(), - input: RpcArgBuilder { - stream: false, - message: "Foo".to_string(), - }, - output: RpcArgBuilder { - stream: true, - message: "Bar".to_string(), - }, - ..Default::default() - },], - ..Default::default() - }),], - ..Default::default() - } - ); - } - - #[test] - fn options() - { - assert_eq!( - PackageBuilder::parse_str( - r#" - syntax = "proto3"; - - message Message { - option mOption = "foo"; - uint32 
field = 1 [ fOption = bar ]; - } - - enum Enum { - value = 1 [ (a.b).c = 1, o2 = 2 ]; - option eOption = "banana"; - } - - service MyService { - rpc function( Foo ) returns ( stream Bar ) { option o = true; } - option sOption = "bar"; - } - "# - ) - .unwrap(), - PackageBuilder { - types: vec![ - ProtobufItemBuilder::Type(ProtobufTypeBuilder::Message(MessageBuilder { - name: "Message".to_string(), - fields: vec![FieldBuilder { - multiplicity: Multiplicity::Single, - field_type: FieldTypeBuilder::Builtin(ValueType::UInt32), - name: "field".to_string(), - number: 1, - options: vec![ProtoOption { - name: "fOption".to_string(), - value: Constant::Ident("bar".to_string()), - }], - }], - options: vec![ProtoOption { - name: "mOption".to_string(), - value: Constant::String(Bytes::from_static(b"foo")), - }], - ..Default::default() - })), - ProtobufItemBuilder::Type(ProtobufTypeBuilder::Enum(EnumBuilder { - name: "Enum".to_string(), - fields: vec![EnumField { - name: "value".to_string(), - value: 1, - options: vec![ - ProtoOption { - name: "(a.b).c".to_string(), - value: Constant::Integer(1), - }, - ProtoOption { - name: "o2".to_string(), - value: Constant::Integer(2), - } - ], - }], - options: vec![ProtoOption { - name: "eOption".to_string(), - value: Constant::String(Bytes::from_static(b"banana")), - }], - ..Default::default() - })), - ProtobufItemBuilder::Service(ServiceBuilder { - name: "MyService".to_string(), - rpcs: vec![RpcBuilder { - name: "function".to_string(), - input: RpcArgBuilder { - stream: false, - message: "Foo".to_string(), - }, - output: RpcArgBuilder { - stream: true, - message: "Bar".to_string(), - }, - options: vec![ProtoOption { - name: "o".to_string(), - value: Constant::Bool(true), - }] - },], - options: vec![ProtoOption { - name: "sOption".to_string(), - value: Constant::String(Bytes::from_static(b"bar")), - }] - }), - ], - ..Default::default() - } - ); - } - - #[test] - fn parse_string_vec() - { - let _ = Context::parse(&["foo", "bar"]); - let _ = 
Context::parse(vec!["foo", "bar"]); - let _ = Context::parse(vec!["foo".to_string(), "bar".to_string()]); - } -} +use bytes::{BufMut, Bytes, BytesMut}; +use pest::{ + iterators::{Pair, Pairs}, + Parser, +}; + +use super::builder::*; +use super::*; + +#[derive(pest_derive::Parser)] +#[grammar = "proto.pest"] +struct ProtoParser; + +impl Context { + /// Parses the files and creates a decoding context. + pub fn parse(files: T) -> Result + where + T: IntoIterator, + S: AsRef, + { + let builder = ContextBuilder { + packages: files + .into_iter() + .map(|f| PackageBuilder::parse_str(f.as_ref())) + .collect::>()?, + }; + + builder.build() + } +} + +impl PackageBuilder { + pub fn parse_str(input: &str) -> Result { + let pairs = ProtoParser::parse(Rule::proto, input) + .map_err(|e| Box::new(e) as Box) + .context(SyntaxError {})?; + + let mut current_package = PackageBuilder::default(); + for pair in pairs { + for inner in pair.into_inner() { + match inner.as_rule() { + Rule::syntax => {} + Rule::topLevelDef => current_package + .types + .push(ProtobufItemBuilder::parse(inner)), + Rule::import => {} + Rule::package => { + current_package.name = + Some(inner.into_inner().next().unwrap().as_str().to_string()) + } + Rule::option => {} + Rule::EOI => {} + r => unreachable!("{:?}: {:?}", r, inner), + } + } + } + + Ok(current_package) + } +} + +impl ProtobufItemBuilder { + pub fn parse(p: Pair) -> Self { + let pair = p.into_inner().next().unwrap(); + match pair.as_rule() { + Rule::message => { + ProtobufItemBuilder::Type(ProtobufTypeBuilder::Message(MessageBuilder::parse(pair))) + } + Rule::enum_ => { + ProtobufItemBuilder::Type(ProtobufTypeBuilder::Enum(EnumBuilder::parse(pair))) + } + Rule::service => ProtobufItemBuilder::Service(ServiceBuilder::parse(pair)), + r => unreachable!("{:?}: {:?}", r, pair), + } + } +} + +impl MessageBuilder { + pub fn parse(p: Pair) -> Self { + let mut inner = p.into_inner(); + let name = inner.next().unwrap().as_str().to_string(); + + let mut fields 
= vec![]; + let mut oneofs = vec![]; + let mut inner_types = vec![]; + let mut options = vec![]; + let body = inner.next().unwrap(); + for p in body.into_inner() { + match p.as_rule() { + Rule::field => fields.push(FieldBuilder::parse(p)), + Rule::enum_ => inner_types.push(InnerTypeBuilder::Enum(EnumBuilder::parse(p))), + Rule::message => { + inner_types.push(InnerTypeBuilder::Message(MessageBuilder::parse(p))) + } + Rule::option => options.push(ProtoOption::parse(p)), + Rule::oneof => oneofs.push(OneofBuilder::parse(p)), + Rule::mapField => { + let map_builder = PseudoMapBuilder::parse(p); + let message_builder = map_builder.create_message_builder(); + inner_types.push(InnerTypeBuilder::Message(message_builder)); + fields.push(map_builder.create_field_builder()); + } + + Rule::reserved => {} // We don't need to care about reserved field numbers. + Rule::emptyStatement => {} + r => unreachable!("{:?}: {:?}", r, p), + } + } + + MessageBuilder { + name, + fields, + oneofs, + inner_types, + options, + } + } +} + +impl PseudoMapBuilder { + pub fn parse(p: Pair) -> Self { + let mut inner = p.into_inner(); + let key_type = parse_field_type(inner.next().unwrap().as_str()); + let value_type = parse_field_type(inner.next().unwrap().as_str()); + let field_name = inner.next().unwrap().as_str().to_string(); + let number = parse_uint_literal(inner.next().unwrap()); + + let options = match inner.next() { + Some(p) => ProtoOption::parse_options(p.into_inner()), + None => vec![], + }; + Self { + key_type, + value_type, + field_name, + number, + options, + } + } + + fn entry_type_name(&self) -> String { + format!("[map {}]FieldEntry", self.field_name) + } + + pub fn create_message_builder(&self) -> MessageBuilder { + MessageBuilder { + name: self.entry_type_name(), + fields: vec![ + FieldBuilder { + multiplicity: Multiplicity::Single, + field_type: self.key_type.clone(), + name: "key".to_string(), + number: 1, + options: vec![], + }, + FieldBuilder { + multiplicity: 
Multiplicity::Single, + field_type: self.value_type.clone(), + name: "value".to_string(), + number: 2, + options: vec![], + }, + ], + ..Default::default() + } + } + + pub fn create_field_builder(&self) -> FieldBuilder { + FieldBuilder { + multiplicity: Multiplicity::Repeated, + field_type: FieldTypeBuilder::Unknown(self.entry_type_name()), + name: self.field_name.clone(), + number: self.number, + options: self.options.clone(), + } + } +} + +impl EnumBuilder { + fn parse(p: Pair) -> EnumBuilder { + let mut inner = p.into_inner(); + let name = inner.next().unwrap().as_str().to_string(); + + let mut fields = vec![]; + let mut options = vec![]; + let body = inner.next().unwrap(); + for p in body.into_inner() { + match p.as_rule() { + Rule::enumField => { + let mut inner = p.into_inner(); + fields.push(EnumField { + name: inner.next().unwrap().as_str().to_string(), + value: parse_int_literal(inner.next().unwrap()), + options: ProtoOption::parse_options(inner), + }) + } + Rule::option => options.push(ProtoOption::parse(p)), + Rule::emptyStatement => {} + r => unreachable!("{:?}: {:?}", r, p), + } + } + + EnumBuilder { + name, + fields, + options, + } + } +} + +impl ServiceBuilder { + pub fn parse(p: Pair) -> Self { + let mut inner = p.into_inner(); + let name = inner.next().unwrap(); + let mut rpcs = vec![]; + let mut options = vec![]; + for p in inner { + match p.as_rule() { + Rule::option => options.push(ProtoOption::parse(p)), + Rule::rpc => rpcs.push(RpcBuilder::parse(p)), + Rule::emptyStatement => {} + r => unreachable!("{:?}: {:?}", r, p), + } + } + + ServiceBuilder { + name: name.as_str().to_string(), + rpcs, + options, + } + } +} + +impl FieldBuilder { + pub fn parse(p: Pair) -> Self { + let mut inner = p.into_inner(); + let multiplicity = match inner.next().unwrap().into_inner().next() { + Some(t) => { + let rule = t.into_inner().next().unwrap().as_rule(); + match rule { + Rule::optional => Multiplicity::Optional, + Rule::repeated => Multiplicity::Repeated, + r 
=> unreachable!("{:?}: {:?}", r, rule), + } + } + None => Multiplicity::Single, + }; + let field_type = parse_field_type(inner.next().unwrap().as_str()); + let name = inner.next().unwrap().as_str().to_string(); + let number = parse_uint_literal(inner.next().unwrap()); + + let options = match inner.next() { + Some(p) => ProtoOption::parse_options(p.into_inner()), + None => vec![], + }; + + FieldBuilder { + multiplicity, + field_type, + name, + number, + options, + } + } + + pub fn parse_oneof(p: Pair) -> Self { + let mut inner = p.into_inner(); + let field_type = parse_field_type(inner.next().unwrap().as_str()); + let name = inner.next().unwrap().as_str().to_string(); + let number = parse_uint_literal(inner.next().unwrap()); + + let options = match inner.next() { + Some(p) => ProtoOption::parse_options(p.into_inner()), + None => vec![], + }; + + FieldBuilder { + multiplicity: Multiplicity::Single, + field_type, + name, + number, + options, + } + } +} + +impl OneofBuilder { + pub fn parse(p: Pair) -> Self { + let mut inner = p.into_inner(); + let name = inner.next().unwrap().as_str().to_string(); + let mut options = Vec::new(); + let mut fields = vec![]; + for p in inner { + match p.as_rule() { + Rule::option => options.push(ProtoOption::parse(p)), + Rule::oneofField => fields.push(FieldBuilder::parse_oneof(p)), + Rule::emptyStatement => {} + r => unreachable!("{:?}: {:?}", r, p), + } + } + OneofBuilder { + name, + fields, + options, + } + } +} + +fn parse_field_type(t: &str) -> FieldTypeBuilder { + FieldTypeBuilder::Builtin(match t { + "double" => ValueType::Double, + "float" => ValueType::Float, + "int32" => ValueType::Int32, + "int64" => ValueType::Int64, + "uint32" => ValueType::UInt32, + "uint64" => ValueType::UInt64, + "sint32" => ValueType::SInt32, + "sint64" => ValueType::SInt64, + "fixed32" => ValueType::Fixed32, + "fixed64" => ValueType::Fixed64, + "sfixed32" => ValueType::SFixed32, + "sfixed64" => ValueType::SFixed64, + "bool" => ValueType::Bool, + 
"string" => ValueType::String, + "bytes" => ValueType::Bytes, + _ => return FieldTypeBuilder::Unknown(t.to_string()), + }) +} + +impl RpcBuilder { + pub fn parse(p: Pair) -> Self { + let mut inner = p.into_inner(); + let name = inner.next().unwrap(); + + let input = RpcArgBuilder::parse(inner.next().unwrap()); + let output = RpcArgBuilder::parse(inner.next().unwrap()); + + let mut options = vec![]; + for p in inner { + match p.as_rule() { + Rule::option => options.push(ProtoOption::parse(p)), + Rule::emptyStatement => {} + r => unreachable!("{:?}: {:?}", r, p), + } + } + + RpcBuilder { + name: name.as_str().to_string(), + input, + output, + options, + } + } +} + +impl RpcArgBuilder { + pub fn parse(p: Pair) -> Self { + let mut inner = p.into_inner(); + RpcArgBuilder { + stream: inner.next().unwrap().into_inner().next().is_some(), + message: inner.next().unwrap().as_str().to_string(), + } + } +} + +pub fn parse_uint_literal(p: Pair) -> u64 { + match p.as_rule() { + Rule::fieldNumber => parse_uint_literal(p.into_inner().next().unwrap()), + Rule::intLit => { + let mut inner = p.into_inner(); + let lit = inner.next().unwrap(); + match lit.as_rule() { + Rule::decimalLit => str::parse(lit.as_str()).unwrap(), + Rule::octalLit => u64::from_str_radix(&lit.as_str()[1..], 8).unwrap(), + Rule::hexLit => u64::from_str_radix(&lit.as_str()[2..], 16).unwrap(), + r => unreachable!("{:?}: {:?}", r, lit), + } + } + r => unreachable!("{:?}: {:?}", r, p), + } +} + +pub fn parse_int_literal(p: Pair) -> i64 { + match p.as_rule() { + Rule::intLit => { + let mut inner = p.into_inner(); + let sign = inner.next().unwrap(); + let (sign, lit) = match sign.as_rule() { + Rule::sign if sign.as_str() == "-" => (-1, inner.next().unwrap()), + Rule::sign if sign.as_str() == "+" => (1, inner.next().unwrap()), + _ => (1, sign), + }; + match lit.as_rule() { + Rule::decimalLit => sign * str::parse::(lit.as_str()).unwrap(), + Rule::octalLit => sign * i64::from_str_radix(lit.as_str(), 8).unwrap(), + 
Rule::hexLit => sign * i64::from_str_radix(&lit.as_str()[2..], 16).unwrap(), + r => unreachable!("{:?}: {:?}", r, lit), + } + } + r => unreachable!("{:?}: {:?}", r, p), + } +} + +pub fn parse_float_literal(p: Pair) -> f64 { + match p.as_rule() { + Rule::floatLit => p.as_str().parse::().unwrap(), + r => unreachable!("{:?}: {:?}", r, p), + } +} + +impl ProtoOption { + fn parse(p: Pair) -> Self { + let mut inner = p.into_inner(); + Self { + name: parse_ident(inner.next().unwrap()), + value: Constant::parse(inner.next().unwrap()), + } + } + + fn parse_options(pairs: Pairs) -> Vec { + pairs + .map(|p| match p.as_rule() { + Rule::fieldOption => Self::parse(p), + Rule::enumValueOption => Self::parse(p), + Rule::option => Self::parse(p), + r => unreachable!("{:?}: {:?}", r, p), + }) + .collect() + } +} + +impl Constant { + fn parse(p: Pair) -> Self { + let p = p.into_inner().next().unwrap(); + match p.as_rule() { + Rule::fullIdent => Constant::Ident(parse_ident(p)), + Rule::intLit => Constant::Integer(parse_int_literal(p)), + Rule::floatLit => Constant::Float(parse_float_literal(p)), + Rule::strLit => Constant::String(parse_string_literal(p)), + Rule::boolLit => Constant::Bool(p.as_str() == "true"), + r => unreachable!("{:?}: {:?}", r, p), + } + } +} + +fn parse_ident(p: Pair) -> String { + let mut ident = vec![]; + let mut inner = p.into_inner(); + + let first = inner.next().unwrap(); + match first.as_rule() { + Rule::ident => ident.push(first.as_str().to_string()), + Rule::fullIdent => ident.push(format!("({})", parse_ident(first))), + r => unreachable!("{:?}: {:?}", r, first), + } + + for other in inner { + match other.as_rule() { + Rule::ident => ident.push(other.as_str().to_string()), + r => unreachable!("{:?}: {:?}", r, other), + } + } + + ident.join(".") +} + +fn parse_string_literal(s: Pair) -> Bytes { + let inner = s.into_inner(); + let mut output = BytesMut::new(); + for c in inner { + let c = c.into_inner().next().unwrap(); + match c.as_rule() { + 
Rule::hexEscape => { + output.put_u8( + u8::from_str_radix(c.into_inner().next().unwrap().as_str(), 16).unwrap(), + ); + } + Rule::octEscape => { + output.put_u8( + u8::from_str_radix(c.into_inner().next().unwrap().as_str(), 8).unwrap(), + ); + } + Rule::charEscape => match c.into_inner().next().unwrap().as_str() { + "a" => output.put_u8(0x07), + "b" => output.put_u8(0x08), + "f" => output.put_u8(0x0C), + "n" => output.put_u8(0x0A), + "r" => output.put_u8(0x0D), + "t" => output.put_u8(0x09), + "v" => output.put_u8(0x0B), + "\\" => output.put_u8(0x5C), + "\'" => output.put_u8(0x27), + "\"" => output.put_u8(0x22), + o => unreachable!("Invalid escape sequence \\{}", o), + }, + Rule::anyChar => output.put(c.as_str().as_ref()), + r => unreachable!("{:?}: {:?}", r, c), + } + } + output.freeze() +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn empty() { + assert_eq!( + PackageBuilder::parse_str( + r#" + syntax = "proto3"; + "# + ) + .unwrap(), + PackageBuilder::default(), + ); + } + + #[test] + fn package() { + assert_eq!( + PackageBuilder::parse_str( + r#" + syntax = "proto3"; + package Test; + "# + ) + .unwrap(), + PackageBuilder { + name: Some("Test".to_string()), + ..Default::default() + } + ); + } + + #[test] + fn bom() { + assert_eq!( + PackageBuilder::parse_str(&format!( + "\u{FEFF}{}", + r#" + syntax = "proto3"; + package Test; + "# + )) + .unwrap(), + PackageBuilder { + name: Some("Test".to_string()), + ..Default::default() + } + ); + } + + #[test] + fn message() { + assert_eq!( + PackageBuilder::parse_str( + r#" + syntax = "proto3"; + + message MyMessage { + int32 value = 1; + } + "# + ) + .unwrap(), + PackageBuilder { + types: vec![ProtobufItemBuilder::Type(ProtobufTypeBuilder::Message( + MessageBuilder { + name: "MyMessage".to_string(), + fields: vec![FieldBuilder { + multiplicity: Multiplicity::Single, + field_type: FieldTypeBuilder::Builtin(ValueType::Int32), + name: "value".to_string(), + number: 1, + options: vec![], + }], + 
..Default::default() + } + )),], + ..Default::default() + } + ); + } + + #[test] + fn pbenum() { + assert_eq!( + PackageBuilder::parse_str( + r#" + syntax = "proto3"; + + enum MyEnum { + a = 1; + b = -1; + } + "# + ) + .unwrap(), + PackageBuilder { + types: vec![ProtobufItemBuilder::Type(ProtobufTypeBuilder::Enum( + EnumBuilder { + name: "MyEnum".to_string(), + fields: vec![ + EnumField { + name: "a".to_string(), + value: 1, + options: vec![], + }, + EnumField { + name: "b".to_string(), + value: -1, + options: vec![], + } + ], + ..Default::default() + } + )),], + ..Default::default() + } + ); + } + + #[test] + fn service() { + assert_eq!( + PackageBuilder::parse_str( + r#" + syntax = "proto3"; + + service MyService { + rpc function( Foo ) returns ( stream Bar ); + } + "# + ) + .unwrap(), + PackageBuilder { + types: vec![ProtobufItemBuilder::Service(ServiceBuilder { + name: "MyService".to_string(), + rpcs: vec![RpcBuilder { + name: "function".to_string(), + input: RpcArgBuilder { + stream: false, + message: "Foo".to_string(), + }, + output: RpcArgBuilder { + stream: true, + message: "Bar".to_string(), + }, + ..Default::default() + },], + ..Default::default() + }),], + ..Default::default() + } + ); + } + + #[test] + fn options() { + assert_eq!( + PackageBuilder::parse_str( + r#" + syntax = "proto3"; + + message Message { + option mOption = "foo"; + uint32 field = 1 [ fOption = bar ]; + } + + enum Enum { + value = 1 [ (a.b).c = 1, o2 = 2 ]; + option eOption = "banana"; + } + + service MyService { + rpc function( Foo ) returns ( stream Bar ) { option o = true; } + option sOption = "bar"; + } + "# + ) + .unwrap(), + PackageBuilder { + types: vec![ + ProtobufItemBuilder::Type(ProtobufTypeBuilder::Message(MessageBuilder { + name: "Message".to_string(), + fields: vec![FieldBuilder { + multiplicity: Multiplicity::Single, + field_type: FieldTypeBuilder::Builtin(ValueType::UInt32), + name: "field".to_string(), + number: 1, + options: vec![ProtoOption { + name: 
"fOption".to_string(), + value: Constant::Ident("bar".to_string()), + }], + }], + options: vec![ProtoOption { + name: "mOption".to_string(), + value: Constant::String(Bytes::from_static(b"foo")), + }], + ..Default::default() + })), + ProtobufItemBuilder::Type(ProtobufTypeBuilder::Enum(EnumBuilder { + name: "Enum".to_string(), + fields: vec![EnumField { + name: "value".to_string(), + value: 1, + options: vec![ + ProtoOption { + name: "(a.b).c".to_string(), + value: Constant::Integer(1), + }, + ProtoOption { + name: "o2".to_string(), + value: Constant::Integer(2), + } + ], + }], + options: vec![ProtoOption { + name: "eOption".to_string(), + value: Constant::String(Bytes::from_static(b"banana")), + }], + ..Default::default() + })), + ProtobufItemBuilder::Service(ServiceBuilder { + name: "MyService".to_string(), + rpcs: vec![RpcBuilder { + name: "function".to_string(), + input: RpcArgBuilder { + stream: false, + message: "Foo".to_string(), + }, + output: RpcArgBuilder { + stream: true, + message: "Bar".to_string(), + }, + options: vec![ProtoOption { + name: "o".to_string(), + value: Constant::Bool(true), + }] + },], + options: vec![ProtoOption { + name: "sOption".to_string(), + value: Constant::String(Bytes::from_static(b"bar")), + }] + }), + ], + ..Default::default() + } + ); + } + + #[test] + fn parse_string_vec() { + let _ = Context::parse(&["foo", "bar"]); + let _ = Context::parse(vec!["foo", "bar"]); + let _ = Context::parse(vec!["foo".to_string(), "bar".to_string()]); + } +} diff --git a/src/decode.rs b/src/decode.rs index 32632c4..c7beffd 100644 --- a/src/decode.rs +++ b/src/decode.rs @@ -1,801 +1,765 @@ -//! Protocol buffer binary payload decoding. -//! -//! The decoding functionality can be accessed by building a decoding context and acquiring a -//! message or message reference. See the example in the [crate root](crate). 
- -use crate::context::*; -use bytes::{Bytes, BytesMut}; -use std::convert::{TryFrom, TryInto}; -use std::fmt::Debug; - -impl Context -{ - /// Decode a message. - pub fn decode(&self, msg: MessageRef, data: &[u8]) -> MessageValue - { - self.resolve_message(msg).decode(data, self) - } -} - -/// Decoded protocol buffer value. -#[derive(Debug, PartialEq, Clone)] -pub enum Value -{ - /// `double` value. - Double(f64), - /// `float` value. - Float(f32), - /// `int32` value. - Int32(i32), - /// `int64` value. - Int64(i64), - /// `uint32` value. - UInt32(u32), - /// `uint64` value. - UInt64(u64), - /// `sint32` value. - SInt32(i32), - /// `sint64` value. - SInt64(i64), - /// `fixed32` value. - Fixed32(u32), - /// `fixed64` value. - Fixed64(u64), - /// `sfixed32` value. - SFixed32(i32), - /// `sfixed64` value. - SFixed64(i64), - /// `bool` value. - Bool(bool), - /// `string` value. - String(String), - /// `bytes` value. - Bytes(Bytes), - - /// A repeated packed value. - Packed(PackedArray), - - /// Message type value. - Message(Box), - - /// Enum type value. - Enum(EnumValue), - - /// Value which was incomplete due to missing bytes in the payload. - Incomplete(u8, Bytes), - - /// Value which wasn't defined in the context. - /// - /// The wire type allows the decoder to tell how large an unknown value is. This allows the - /// unknown value to be skipped and decoding can continue from the next value. - Unknown(UnknownValue), -} - -/// Packed scalar fields. -#[derive(Debug, PartialEq, Clone)] -pub enum PackedArray -{ - /// `double` value. - Double(Vec), - /// `float` value. - Float(Vec), - /// `int32` value. - Int32(Vec), - /// `int64` value. - Int64(Vec), - /// `uint32` value. - UInt32(Vec), - /// `uint64` value. - UInt64(Vec), - /// `sint32` value. - SInt32(Vec), - /// `sint64` value. - SInt64(Vec), - /// `fixed32` value. - Fixed32(Vec), - /// `fixed64` value. - Fixed64(Vec), - /// `sfixed32` value. - SFixed32(Vec), - /// `sfixed64` value. 
- SFixed64(Vec), - /// `bool` value. - Bool(Vec), -} - -/// Unknown value. -#[derive(Debug, PartialEq, Clone)] -pub enum UnknownValue -{ - /// Unknown varint (wire type = 0). - Varint(u128), - - /// Unknown 64-bit value (wire type = 1). - Fixed64(u64), - - /// Unknown variable length value (wire type = 2). - VariableLength(Bytes), - - /// Unknown 32-bit value (wire type = 5). - Fixed32(u32), - - /// Invalid value. - /// - /// Invalid value is a value for which the wire type wasn't valid. Encountering invalid wire - /// type will result in the remaining bytes to be consumed from the current variable length - /// stream as it is imposible to tell how large such invalid value is. - /// - /// The decoding will continue after the current variable length value. - Invalid(u8, Bytes), -} - -/// Enum value. -#[derive(Debug, PartialEq, Clone)] -pub struct EnumValue -{ - /// Reference to the enum type. - pub enum_ref: EnumRef, - - /// Value. - pub value: i64, -} - -/// Message value. -#[derive(Debug, PartialEq, Clone)] -pub struct MessageValue -{ - /// Reference to the message type. - pub msg_ref: MessageRef, - - /// Mesage field values. - pub fields: Vec, - - /// Garbage data at the end of the message. - /// - /// As opposed to an `UnknownValue::Invalid`, the garbage data did not have a valid field - /// number and for that reason cannot be placed into the `fields` vector. - pub garbage: Option, -} - -/// Field value. -#[derive(Debug, PartialEq, Clone)] -pub struct FieldValue -{ - /// Field number. - pub number: u64, - - /// Field value. 
- pub value: Value, -} - -impl Value -{ - fn decode(data: &mut &[u8], vt_raw: u8, vt: &ValueType, ctx: &Context) -> Self - { - let original = *data; - let opt = match vt { - ValueType::Double => { - try_read_8_bytes(data).map(|b| Value::Double(f64::from_le_bytes(b))) - } - ValueType::Float => try_read_4_bytes(data).map(|b| Value::Float(f32::from_le_bytes(b))), - ValueType::Int32 => i32::from_signed_varint(data).map(Value::Int32), - ValueType::Int64 => i64::from_signed_varint(data).map(Value::Int64), - ValueType::UInt32 => u32::from_unsigned_varint(data).map(Value::UInt32), - ValueType::UInt64 => u64::from_unsigned_varint(data).map(Value::UInt64), - ValueType::SInt32 => u32::from_unsigned_varint(data).map(|u| { - let (sign, sign_bit) = if u % 2 == 0 { (1i32, 0) } else { (-1i32, 1) }; - let magnitude = (u / 2) as i32 + sign_bit; - Value::SInt32(sign * magnitude) - }), - ValueType::SInt64 => u64::from_unsigned_varint(data).map(|u| { - let (sign, sign_bit) = if u % 2 == 0 { (1i64, 0) } else { (-1i64, 1) }; - let magnitude = (u / 2) as i64 + sign_bit; - Value::SInt64(sign * magnitude) - }), - ValueType::Fixed32 => { - try_read_4_bytes(data).map(|b| Value::Fixed32(u32::from_le_bytes(b))) - } - ValueType::Fixed64 => { - try_read_8_bytes(data).map(|b| Value::Fixed64(u64::from_le_bytes(b))) - } - ValueType::SFixed32 => { - try_read_4_bytes(data).map(|b| Value::SFixed32(i32::from_le_bytes(b))) - } - ValueType::SFixed64 => { - try_read_8_bytes(data).map(|b| Value::SFixed64(i64::from_le_bytes(b))) - } - ValueType::Bool => usize::from_unsigned_varint(data).map(|u| Value::Bool(u != 0)), - ValueType::String => read_string(data).map(Value::String), - ValueType::Bytes => read_bytes(data).map(Value::Bytes), - ValueType::Enum(eref) => i64::from_signed_varint(data).map(|v| { - Value::Enum(EnumValue { - enum_ref: *eref, - value: v, - }) - }), - ValueType::Message(mref) => usize::from_unsigned_varint(data).and_then(|length| { - if data.len() < length { - *data = original; - return None; 
- } - let (consumed, remainder) = data.split_at(length); - *data = remainder; - - Some(Value::Message(Box::new( - ctx.resolve_message(*mref).decode(consumed, ctx), - ))) - }), - }; - - opt.unwrap_or_else(|| { - *data = &[]; - Value::Incomplete(vt_raw, Bytes::copy_from_slice(original)) - }) - } - - fn decode_packed(data: &mut &[u8], vt_raw: u8, vt: &ValueType) -> Self - { - let original = *data; - let length = match usize::from_unsigned_varint(data) { - Some(len) => len, - None => { - return return_incomplete(data, vt_raw, original); - } - }; - - if data.len() < length { - return return_incomplete(data, vt_raw, original); - } - - let mut array = &data[..length]; - *data = &data[length..]; - - // Reading the packed arrays follows very similar format for each type. The variances are - // in how to read the data from the stream and what to do with the data to get the final - // value. - // - // This macro implements the basic structure with holes for the varying bits. - macro_rules! read_packed { - ($variant:ident @ $val:ident = $try_read:expr => $insert:expr ) => { - let mut output = vec![]; - loop { - if array.is_empty() { - break Value::Packed(PackedArray::$variant(output)); - } - - match $try_read { - Some($val) => output.push($insert), - None => return return_incomplete(&mut array, vt_raw, original), - } - } - }; - } - - match vt { - ValueType::Double => { - read_packed! { Double @ b = try_read_8_bytes(&mut array) => f64::from_le_bytes(b) } - } - ValueType::Float => { - read_packed! { Float @ b = try_read_4_bytes(&mut array) => f32::from_le_bytes(b) } - } - ValueType::Int32 => { - read_packed! { Int32 @ b = i32::from_signed_varint(&mut array) => b } - } - ValueType::Int64 => { - read_packed! { Int64 @ b = i64::from_signed_varint(&mut array) => b } - } - ValueType::UInt32 => { - read_packed! { UInt32 @ b = u32::from_signed_varint(&mut array) => b } - } - ValueType::UInt64 => { - read_packed! 
{ UInt64 @ b = u64::from_signed_varint(&mut array) => b } - } - ValueType::SInt32 => { - read_packed! { SInt32 @ b = u32::from_signed_varint(&mut array) => { - let (sign, sign_bit) = if b % 2 == 0 { (1i32, 0) } else { (-1i32, 1) }; - let magnitude = (b / 2) as i32 + sign_bit; - sign * magnitude - } } - } - ValueType::SInt64 => { - read_packed! { SInt64 @ b = u64::from_signed_varint(&mut array) => { - let (sign, sign_bit) = if b % 2 == 0 { (1i64, 0) } else { (-1i64, 1) }; - let magnitude = (b / 2) as i64 + sign_bit; - sign * magnitude - } } - } - ValueType::Fixed32 => { - read_packed! { Fixed32 @ b = try_read_4_bytes(&mut array) => u32::from_le_bytes(b) } - } - ValueType::Fixed64 => { - read_packed! { Fixed64 @ b = try_read_8_bytes(&mut array) => u64::from_le_bytes(b) } - } - ValueType::SFixed32 => { - read_packed! { SFixed32 @ b = try_read_4_bytes(&mut array) => i32::from_le_bytes(b) } - } - ValueType::SFixed64 => { - read_packed! { SFixed64 @ b = try_read_8_bytes(&mut array) => i64::from_le_bytes(b) } - } - ValueType::Bool => { - read_packed! 
{ Bool @ b = u8::from_unsigned_varint(&mut array) => b != 0 } - } - _ => panic!("Non-scalar type was handled as packed"), - } - } - - fn decode_unknown(data: &mut &[u8], vt: u8) -> Value - { - let original = *data; - let value = - match vt { - 0 => u128::from_unsigned_varint(data).map(UnknownValue::Varint), - 1 => try_read_8_bytes(data) - .map(|value| UnknownValue::Fixed64(u64::from_le_bytes(value))), - 2 => usize::from_unsigned_varint(data).and_then(|length| { - if length > data.len() { - *data = original; - return None; - } - let (consumed, remainder) = data.split_at(length); - *data = remainder; - Some(UnknownValue::VariableLength(Bytes::copy_from_slice( - consumed, - ))) - }), - 5 => try_read_4_bytes(data) - .map(|value| UnknownValue::Fixed32(u32::from_le_bytes(value))), - _ => { - let bytes = Bytes::copy_from_slice(data); - *data = &[]; - Some(UnknownValue::Invalid(vt, bytes)) - } - }; - - value - .map(Value::Unknown) - .unwrap_or_else(|| Value::Incomplete(vt, Bytes::copy_from_slice(data))) - } - - fn encode(&self, ctx: &Context) -> Option<(u8, BytesMut)> - { - let bytes = match self { - Value::Double(v) => BytesMut::from(v.to_le_bytes().as_ref()), - Value::Float(v) => BytesMut::from(v.to_le_bytes().as_ref()), - Value::Int32(v) => BytesMut::from(v.into_signed_varint().as_ref()), - Value::Int64(v) => BytesMut::from(v.into_signed_varint().as_ref()), - Value::UInt32(v) => BytesMut::from(v.into_unsigned_varint().as_ref()), - Value::UInt64(v) => BytesMut::from(v.into_unsigned_varint().as_ref()), - Value::SInt32(v) => { - let (v, sign_bit) = if *v < 0 { (-v, 1) } else { (*v, 0) }; - (v * 2 - sign_bit).into_unsigned_varint() - } - Value::SInt64(v) => { - let (v, sign_bit) = if *v < 0 { (-v, 1) } else { (*v, 0) }; - (v * 2 - sign_bit).into_unsigned_varint() - } - Value::Fixed32(v) => BytesMut::from(v.to_le_bytes().as_ref()), - Value::Fixed64(v) => BytesMut::from(v.to_le_bytes().as_ref()), - Value::SFixed32(v) => BytesMut::from(v.to_le_bytes().as_ref()), - 
Value::SFixed64(v) => BytesMut::from(v.to_le_bytes().as_ref()), - Value::Bool(v) => BytesMut::from(if *v { [1u8].as_ref() } else { [0u8].as_ref() }), - Value::String(v) => { - let mut output = v.len().into_unsigned_varint(); - output.extend_from_slice(v.as_bytes()); - output - } - Value::Bytes(v) => { - let mut output = v.len().into_unsigned_varint(); - output.extend_from_slice(v); - output - } - Value::Enum(v) => BytesMut::from(v.value.into_signed_varint().as_ref()), - Value::Message(v) => { - let data = v.encode(ctx); - let mut output = data.len().into_unsigned_varint(); - output.extend_from_slice(&data); - output - } - Value::Packed(p) => p.encode(), - Value::Unknown(u) => u.encode(), - Value::Incomplete(_, bytes) => BytesMut::from(bytes.as_ref()), - }; - - Some((self.wire_type(), bytes)) - } - - fn wire_type(&self) -> u8 - { - match self { - Value::Double(..) => 1, - Value::Float(..) => 5, - Value::Int32(..) => 0, - Value::Int64(..) => 0, - Value::UInt32(..) => 0, - Value::UInt64(..) => 0, - Value::SInt32(..) => 0, - Value::SInt64(..) => 0, - Value::Fixed32(..) => 5, - Value::Fixed64(..) => 1, - Value::SFixed32(..) => 5, - Value::SFixed64(..) => 1, - Value::Bool(..) => 0, - Value::String(..) => 2, - Value::Bytes(..) => 2, - Value::Message(..) => 2, - Value::Enum(..) => 0, - Value::Packed(..) => 2, - Value::Unknown(unk) => match unk { - UnknownValue::Varint(..) => 0, - UnknownValue::Fixed64(..) => 1, - UnknownValue::VariableLength(..) => 2, - UnknownValue::Fixed32(..) => 5, - UnknownValue::Invalid(vt, ..) => *vt, - }, - Value::Incomplete(vt, ..) => *vt, - } - } -} - -impl PackedArray -{ - fn encode(&self) -> BytesMut - { - macro_rules! 
write_packed { - ($value:ident => $convert:expr ) => { - $value.iter().flat_map($convert).collect() - }; - } - - let data: Bytes = match self { - PackedArray::Double(v) => { - write_packed!(v => |v| BytesMut::from(v.to_le_bytes().as_ref())) - } - PackedArray::Float(v) => { - write_packed!(v => |v| BytesMut::from(v.to_le_bytes().as_ref())) - } - PackedArray::Int32(v) => { - write_packed!(v => |v| BytesMut::from(v.into_signed_varint().as_ref())) - } - PackedArray::Int64(v) => { - write_packed!(v => |v| BytesMut::from(v.into_signed_varint().as_ref())) - } - PackedArray::UInt32(v) => { - write_packed!(v => |v| BytesMut::from(v.into_unsigned_varint().as_ref())) - } - PackedArray::UInt64(v) => { - write_packed!(v => |v| BytesMut::from(v.into_unsigned_varint().as_ref())) - } - PackedArray::SInt32(v) => { - write_packed! { v => |v| { - let (v, sign_bit) = if *v < 0 { (-v, 1) } else { (*v, 0) }; - (v * 2 - sign_bit).into_unsigned_varint() - } } - } - PackedArray::SInt64(v) => { - write_packed! { v => |v| { - let (v, sign_bit) = if *v < 0 { (-v, 1) } else { (*v, 0) }; - (v * 2 - sign_bit).into_unsigned_varint() - } } - } - PackedArray::Fixed32(v) => { - write_packed!( v => |v| BytesMut::from(v.to_le_bytes().as_ref()) ) - } - PackedArray::Fixed64(v) => { - write_packed!( v => |v| BytesMut::from(v.to_le_bytes().as_ref()) ) - } - PackedArray::SFixed32(v) => { - write_packed!( v => |v| BytesMut::from(v.to_le_bytes().as_ref()) ) - } - PackedArray::SFixed64(v) => { - write_packed!( v => |v| BytesMut::from(v.to_le_bytes().as_ref()) ) - } - PackedArray::Bool(v) => { - write_packed!( v => |v| BytesMut::from(if *v { [1u8].as_ref() } else { [0u8].as_ref() })) - } - }; - - let mut output = data.len().into_unsigned_varint(); - output.extend_from_slice(data.as_ref()); - output - } -} - -fn return_incomplete(data: &mut &[u8], vt: u8, original: &[u8]) -> Value -{ - *data = &[]; - Value::Incomplete(vt, Bytes::copy_from_slice(original)) -} - -fn try_read_8_bytes(data: &mut &[u8]) -> 
Option<[u8; 8]> -{ - if data.len() < 8 { - return None; - } - - match (data[..8]).try_into() { - Ok(v) => { - *data = &data[8..]; - Some(v) - } - Err(_) => None, - } -} - -fn try_read_4_bytes(data: &mut &[u8]) -> Option<[u8; 4]> -{ - if data.len() < 4 { - return None; - } - - match (data[..4]).try_into() { - Ok(v) => { - *data = &data[4..]; - Some(v) - } - Err(_) => None, - } -} - -fn read_string(data: &mut &[u8]) -> Option -{ - let original = *data; - let len = usize::from_unsigned_varint(data)?; - if len > data.len() { - *data = original; - return None; - } - let (str_data, remainder) = data.split_at(len); - *data = remainder; - Some(String::from_utf8_lossy(str_data).to_string()) -} - -fn read_bytes(data: &mut &[u8]) -> Option -{ - let original = *data; - let len = usize::from_unsigned_varint(data)?; - if len > data.len() { - *data = original; - return None; - } - let (str_data, remainder) = data.split_at(len); - *data = remainder; - Some(Bytes::copy_from_slice(str_data)) -} - -impl MessageInfo -{ - /// Decode a message. - /// - /// Will **panic** if the message defined by the `MessageRef` does not exist in this context. - /// Such panic means the `MessageRef` came from a different context. The panic is not - /// guaranteed, as a message with an equal `MessageRef` may exist in multiple contexts. 
- pub fn decode(&self, mut data: &[u8], ctx: &Context) -> MessageValue - { - let mut msg = MessageValue { - msg_ref: self.self_ref, - fields: vec![], - garbage: None, - }; - - loop { - if data.is_empty() { - break; - } - - let tag = match u64::from_unsigned_varint(&mut data) { - Some(tag) => tag, - None => { - msg.garbage = Some(Bytes::copy_from_slice(data)); - break; - } - }; - - let number = tag >> 3; - let wire_type = (tag & 0x07) as u8; - - let value = match self.get_field(number) { - Some(field) => { - if field.multiplicity == Multiplicity::RepeatedPacked { - if wire_type == 2 { - Value::decode_packed(&mut data, wire_type, &field.field_type) - } else { - Value::decode_unknown(&mut data, wire_type) - } - } else if field.field_type.wire_type() == wire_type { - Value::decode(&mut data, wire_type, &field.field_type, ctx) - } else { - Value::decode_unknown(&mut data, wire_type) - } - } - _ => Value::decode_unknown(&mut data, wire_type), - }; - - msg.fields.push(FieldValue { number, value }) - } - - msg - } -} - -impl MessageValue -{ - /// Encodes a message value into protobuf wire format. - /// - /// Will **panic** if the message defined by the `MessageRef` does not exist in this context. - /// Such panic means the `MessageRef` came from a different context. The panic is not - /// guaranteed, as a message with an equal `MessageRef` may exist in multiple contexts. - pub fn encode(&self, ctx: &Context) -> bytes::BytesMut - { - self.fields - .iter() - .filter_map(|f| f.value.encode(ctx).map(|(w, b)| (f, w, b))) - .flat_map(|(field, wire_type, bytes)| { - let tag = wire_type as u64 + (field.number << 3); - let mut field_data = tag.into_unsigned_varint(); - field_data.extend_from_slice(&bytes); - field_data - }) - .collect() - } -} - -impl UnknownValue -{ - /// Encodes a message value into protobuf wire format. 
- fn encode(&self) -> bytes::BytesMut - { - match self { - UnknownValue::Varint(v) => v.into_unsigned_varint(), - UnknownValue::Fixed64(v) => BytesMut::from(v.to_le_bytes().as_ref()), - UnknownValue::VariableLength(b) => { - let mut output = b.len().into_unsigned_varint(); - output.extend_from_slice(b); - output - } - UnknownValue::Fixed32(v) => BytesMut::from(v.to_le_bytes().as_ref()), - UnknownValue::Invalid(_, v) => BytesMut::from(v.as_ref()), - } - } -} - -trait FromUnsignedVarint: Sized -{ - fn from_unsigned_varint(data: &mut &[u8]) -> Option; -} - -trait ToUnsignedVarint: Sized -{ - fn into_unsigned_varint(self) -> BytesMut; -} - -impl> FromUnsignedVarint for T -where - T::Error: Debug, -{ - fn from_unsigned_varint(data: &mut &[u8]) -> Option - { - let mut result = 0u64; - let mut idx = 0; - loop { - if idx >= data.len() { - return None; - } - - let b = data[idx]; - let value = (b & 0x7f) as u64; - result += value << (idx * 7); - - idx += 1; - if b & 0x80 == 0 { - break; - } - } - - let result = T::try_from(result).expect("Out of range"); - *data = &data[idx..]; - Some(result) - } -} - -impl> ToUnsignedVarint for T -where - T::Error: Debug, -{ - fn into_unsigned_varint(self) -> BytesMut - { - let mut value: u64 = self.try_into().unwrap(); - let mut data: Vec = Vec::with_capacity(8); - loop { - let mut byte = (value & 0x7f) as u8; - value >>= 7; - if value > 0 { - byte |= 0x80; - data.push(byte); - } else { - data.push(byte); - break BytesMut::from(data.as_slice()); - } - } - } -} - -trait FromSignedVarint: Sized -{ - fn from_signed_varint(data: &mut &[u8]) -> Option; -} - -trait ToSignedVarint: Sized -{ - fn into_signed_varint(self) -> BytesMut; -} - -impl> FromSignedVarint for T -where - T::Error: Debug, -{ - fn from_signed_varint(data: &mut &[u8]) -> Option - { - u64::from_unsigned_varint(data).map(|u| { - let signed: i64 = unsafe { std::mem::transmute(u) }; - signed.try_into().unwrap() - }) - } -} - -impl> ToSignedVarint for T -where - T::Error: Debug, -{ 
- fn into_signed_varint(self) -> BytesMut - { - let v: u64 = unsafe { std::mem::transmute(self.try_into().unwrap()) }; - v.into_unsigned_varint() - } -} - -#[cfg(test)] -mod test -{ - use super::*; - - #[test] - fn test_zigzag_encoding() - // Source: https://developers.google.com/protocol-buffers/docs/encoding#signed-ints - // Signed Original Encoded As - // 0 0 - // -1 1 - // 1 2 - // -2 3 - // 2147483647 4294967294 - // -2147483648 4294967295 - { - let ctx = Context::parse(&[r#" - syntax = "proto3"; - message Message {} - "#]) - .unwrap(); - - // Singular - assert_eq!(Value::SInt32(0).encode(&ctx), Value::Int32(0).encode(&ctx)); - assert_eq!(Value::SInt32(-1).encode(&ctx), Value::Int32(1).encode(&ctx)); - assert_eq!(Value::SInt32(1).encode(&ctx), Value::Int32(2).encode(&ctx)); - assert_eq!( - Value::SInt64(2147483647).encode(&ctx), - Value::Int64(4294967294).encode(&ctx) - ); - assert_eq!( - Value::SInt64(-2147483648).encode(&ctx), - Value::Int64(4294967295).encode(&ctx) - ); - - // Packed - assert_eq!( - Value::Packed(PackedArray::SInt32(vec![0, -1, 1, -2, 2])).encode(&ctx), - Value::Packed(PackedArray::Int32(vec![0, 1, 2, 3, 4])).encode(&ctx), - ); - assert_eq!( - Value::Packed(PackedArray::SInt64(vec![0, 2147483647, -2147483648])).encode(&ctx), - Value::Packed(PackedArray::Int64(vec![0, 4294967294, 4294967295])).encode(&ctx), - ); - } -} +//! Protocol buffer binary payload decoding. +//! +//! The decoding functionality can be accessed by building a decoding context and acquiring a +//! message or message reference. See the example in the [crate root](crate). + +use crate::context::*; +use bytes::{Bytes, BytesMut}; +use std::convert::{TryFrom, TryInto}; +use std::fmt::Debug; + +impl Context { + /// Decode a message. + pub fn decode(&self, msg: MessageRef, data: &[u8]) -> MessageValue { + self.resolve_message(msg).decode(data, self) + } +} + +/// Decoded protocol buffer value. +#[derive(Debug, PartialEq, Clone)] +pub enum Value { + /// `double` value. 
+ Double(f64), + /// `float` value. + Float(f32), + /// `int32` value. + Int32(i32), + /// `int64` value. + Int64(i64), + /// `uint32` value. + UInt32(u32), + /// `uint64` value. + UInt64(u64), + /// `sint32` value. + SInt32(i32), + /// `sint64` value. + SInt64(i64), + /// `fixed32` value. + Fixed32(u32), + /// `fixed64` value. + Fixed64(u64), + /// `sfixed32` value. + SFixed32(i32), + /// `sfixed64` value. + SFixed64(i64), + /// `bool` value. + Bool(bool), + /// `string` value. + String(String), + /// `bytes` value. + Bytes(Bytes), + + /// A repeated packed value. + Packed(PackedArray), + + /// Message type value. + Message(Box), + + /// Enum type value. + Enum(EnumValue), + + /// Value which was incomplete due to missing bytes in the payload. + Incomplete(u8, Bytes), + + /// Value which wasn't defined in the context. + /// + /// The wire type allows the decoder to tell how large an unknown value is. This allows the + /// unknown value to be skipped and decoding can continue from the next value. + Unknown(UnknownValue), +} + +/// Packed scalar fields. +#[derive(Debug, PartialEq, Clone)] +pub enum PackedArray { + /// `double` value. + Double(Vec), + /// `float` value. + Float(Vec), + /// `int32` value. + Int32(Vec), + /// `int64` value. + Int64(Vec), + /// `uint32` value. + UInt32(Vec), + /// `uint64` value. + UInt64(Vec), + /// `sint32` value. + SInt32(Vec), + /// `sint64` value. + SInt64(Vec), + /// `fixed32` value. + Fixed32(Vec), + /// `fixed64` value. + Fixed64(Vec), + /// `sfixed32` value. + SFixed32(Vec), + /// `sfixed64` value. + SFixed64(Vec), + /// `bool` value. + Bool(Vec), +} + +/// Unknown value. +#[derive(Debug, PartialEq, Clone)] +pub enum UnknownValue { + /// Unknown varint (wire type = 0). + Varint(u128), + + /// Unknown 64-bit value (wire type = 1). + Fixed64(u64), + + /// Unknown variable length value (wire type = 2). + VariableLength(Bytes), + + /// Unknown 32-bit value (wire type = 5). + Fixed32(u32), + + /// Invalid value. 
+ /// + /// Invalid value is a value for which the wire type wasn't valid. Encountering invalid wire + /// type will result in the remaining bytes to be consumed from the current variable length + /// stream as it is impossible to tell how large such invalid value is. + /// + /// The decoding will continue after the current variable length value. + Invalid(u8, Bytes), +} + +/// Enum value. +#[derive(Debug, PartialEq, Clone)] +pub struct EnumValue { + /// Reference to the enum type. + pub enum_ref: EnumRef, + + /// Value. + pub value: i64, +} + +/// Message value. +#[derive(Debug, PartialEq, Clone)] +pub struct MessageValue { + /// Reference to the message type. + pub msg_ref: MessageRef, + + /// Message field values. + pub fields: Vec, + + /// Garbage data at the end of the message. + /// + /// As opposed to an `UnknownValue::Invalid`, the garbage data did not have a valid field + /// number and for that reason cannot be placed into the `fields` vector. + pub garbage: Option, +} + +/// Field value. +#[derive(Debug, PartialEq, Clone)] +pub struct FieldValue { + /// Field number. + pub number: u64, + + /// Field value. 
+ pub value: Value, +} + +impl Value { + fn decode(data: &mut &[u8], vt_raw: u8, vt: &ValueType, ctx: &Context) -> Self { + let original = *data; + let opt = match vt { + ValueType::Double => { + try_read_8_bytes(data).map(|b| Value::Double(f64::from_le_bytes(b))) + } + ValueType::Float => try_read_4_bytes(data).map(|b| Value::Float(f32::from_le_bytes(b))), + ValueType::Int32 => i32::from_signed_varint(data).map(Value::Int32), + ValueType::Int64 => i64::from_signed_varint(data).map(Value::Int64), + ValueType::UInt32 => u32::from_unsigned_varint(data).map(Value::UInt32), + ValueType::UInt64 => u64::from_unsigned_varint(data).map(Value::UInt64), + ValueType::SInt32 => u32::from_unsigned_varint(data).map(|u| { + let (sign, sign_bit) = if u % 2 == 0 { (1i32, 0) } else { (-1i32, 1) }; + let magnitude = (u / 2) as i32 + sign_bit; + Value::SInt32(sign * magnitude) + }), + ValueType::SInt64 => u64::from_unsigned_varint(data).map(|u| { + let (sign, sign_bit) = if u % 2 == 0 { (1i64, 0) } else { (-1i64, 1) }; + let magnitude = (u / 2) as i64 + sign_bit; + Value::SInt64(sign * magnitude) + }), + ValueType::Fixed32 => { + try_read_4_bytes(data).map(|b| Value::Fixed32(u32::from_le_bytes(b))) + } + ValueType::Fixed64 => { + try_read_8_bytes(data).map(|b| Value::Fixed64(u64::from_le_bytes(b))) + } + ValueType::SFixed32 => { + try_read_4_bytes(data).map(|b| Value::SFixed32(i32::from_le_bytes(b))) + } + ValueType::SFixed64 => { + try_read_8_bytes(data).map(|b| Value::SFixed64(i64::from_le_bytes(b))) + } + ValueType::Bool => usize::from_unsigned_varint(data).map(|u| Value::Bool(u != 0)), + ValueType::String => read_string(data).map(Value::String), + ValueType::Bytes => read_bytes(data).map(Value::Bytes), + ValueType::Enum(eref) => i64::from_signed_varint(data).map(|v| { + Value::Enum(EnumValue { + enum_ref: *eref, + value: v, + }) + }), + ValueType::Message(mref) => usize::from_unsigned_varint(data).and_then(|length| { + if data.len() < length { + *data = original; + return None; + 
} + let (consumed, remainder) = data.split_at(length); + *data = remainder; + + Some(Value::Message(Box::new( + ctx.resolve_message(*mref).decode(consumed, ctx), + ))) + }), + }; + + opt.unwrap_or_else(|| { + *data = &[]; + Value::Incomplete(vt_raw, Bytes::copy_from_slice(original)) + }) + } + + fn decode_packed(data: &mut &[u8], vt_raw: u8, vt: &ValueType) -> Self { + let original = *data; + let length = match usize::from_unsigned_varint(data) { + Some(len) => len, + None => { + return return_incomplete(data, vt_raw, original); + } + }; + + if data.len() < length { + return return_incomplete(data, vt_raw, original); + } + + let mut array = &data[..length]; + *data = &data[length..]; + + // Reading the packed arrays follows very similar format for each type. The variances are + // in how to read the data from the stream and what to do with the data to get the final + // value. + // + // This macro implements the basic structure with holes for the varying bits. + macro_rules! read_packed { + ($variant:ident @ $val:ident = $try_read:expr => $insert:expr ) => { + let mut output = vec![]; + loop { + if array.is_empty() { + break Value::Packed(PackedArray::$variant(output)); + } + + match $try_read { + Some($val) => output.push($insert), + None => return return_incomplete(&mut array, vt_raw, original), + } + } + }; + } + + match vt { + ValueType::Double => { + read_packed! { Double @ b = try_read_8_bytes(&mut array) => f64::from_le_bytes(b) } + } + ValueType::Float => { + read_packed! { Float @ b = try_read_4_bytes(&mut array) => f32::from_le_bytes(b) } + } + ValueType::Int32 => { + read_packed! { Int32 @ b = i32::from_signed_varint(&mut array) => b } + } + ValueType::Int64 => { + read_packed! { Int64 @ b = i64::from_signed_varint(&mut array) => b } + } + ValueType::UInt32 => { + read_packed! { UInt32 @ b = u32::from_signed_varint(&mut array) => b } + } + ValueType::UInt64 => { + read_packed! 
{ UInt64 @ b = u64::from_signed_varint(&mut array) => b } + } + ValueType::SInt32 => { + read_packed! { SInt32 @ b = u32::from_signed_varint(&mut array) => { + let (sign, sign_bit) = if b % 2 == 0 { (1i32, 0) } else { (-1i32, 1) }; + let magnitude = (b / 2) as i32 + sign_bit; + sign * magnitude + } } + } + ValueType::SInt64 => { + read_packed! { SInt64 @ b = u64::from_signed_varint(&mut array) => { + let (sign, sign_bit) = if b % 2 == 0 { (1i64, 0) } else { (-1i64, 1) }; + let magnitude = (b / 2) as i64 + sign_bit; + sign * magnitude + } } + } + ValueType::Fixed32 => { + read_packed! { Fixed32 @ b = try_read_4_bytes(&mut array) => u32::from_le_bytes(b) } + } + ValueType::Fixed64 => { + read_packed! { Fixed64 @ b = try_read_8_bytes(&mut array) => u64::from_le_bytes(b) } + } + ValueType::SFixed32 => { + read_packed! { SFixed32 @ b = try_read_4_bytes(&mut array) => i32::from_le_bytes(b) } + } + ValueType::SFixed64 => { + read_packed! { SFixed64 @ b = try_read_8_bytes(&mut array) => i64::from_le_bytes(b) } + } + ValueType::Bool => { + read_packed! 
{ Bool @ b = u8::from_unsigned_varint(&mut array) => b != 0 } + } + _ => panic!("Non-scalar type was handled as packed"), + } + } + + fn decode_unknown(data: &mut &[u8], vt: u8) -> Value { + let original = *data; + let value = + match vt { + 0 => u128::from_unsigned_varint(data).map(UnknownValue::Varint), + 1 => try_read_8_bytes(data) + .map(|value| UnknownValue::Fixed64(u64::from_le_bytes(value))), + 2 => usize::from_unsigned_varint(data).and_then(|length| { + if length > data.len() { + *data = original; + return None; + } + let (consumed, remainder) = data.split_at(length); + *data = remainder; + Some(UnknownValue::VariableLength(Bytes::copy_from_slice( + consumed, + ))) + }), + 5 => try_read_4_bytes(data) + .map(|value| UnknownValue::Fixed32(u32::from_le_bytes(value))), + _ => { + let bytes = Bytes::copy_from_slice(data); + *data = &[]; + Some(UnknownValue::Invalid(vt, bytes)) + } + }; + + value + .map(Value::Unknown) + .unwrap_or_else(|| Value::Incomplete(vt, Bytes::copy_from_slice(data))) + } + + fn encode(&self, ctx: &Context) -> Option<(u8, BytesMut)> { + let bytes = match self { + Value::Double(v) => BytesMut::from(v.to_le_bytes().as_ref()), + Value::Float(v) => BytesMut::from(v.to_le_bytes().as_ref()), + Value::Int32(v) => BytesMut::from(v.into_signed_varint().as_ref()), + Value::Int64(v) => BytesMut::from(v.into_signed_varint().as_ref()), + Value::UInt32(v) => BytesMut::from(v.into_unsigned_varint().as_ref()), + Value::UInt64(v) => BytesMut::from(v.into_unsigned_varint().as_ref()), + Value::SInt32(v) => { + let (v, sign_bit) = if *v < 0 { (-v, 1) } else { (*v, 0) }; + (v * 2 - sign_bit).into_unsigned_varint() + } + Value::SInt64(v) => { + let (v, sign_bit) = if *v < 0 { (-v, 1) } else { (*v, 0) }; + (v * 2 - sign_bit).into_unsigned_varint() + } + Value::Fixed32(v) => BytesMut::from(v.to_le_bytes().as_ref()), + Value::Fixed64(v) => BytesMut::from(v.to_le_bytes().as_ref()), + Value::SFixed32(v) => BytesMut::from(v.to_le_bytes().as_ref()), + 
Value::SFixed64(v) => BytesMut::from(v.to_le_bytes().as_ref()), + Value::Bool(v) => BytesMut::from(if *v { [1u8].as_ref() } else { [0u8].as_ref() }), + Value::String(v) => { + let mut output = v.len().into_unsigned_varint(); + output.extend_from_slice(v.as_bytes()); + output + } + Value::Bytes(v) => { + let mut output = v.len().into_unsigned_varint(); + output.extend_from_slice(v); + output + } + Value::Enum(v) => BytesMut::from(v.value.into_signed_varint().as_ref()), + Value::Message(v) => { + let data = v.encode(ctx); + let mut output = data.len().into_unsigned_varint(); + output.extend_from_slice(&data); + output + } + Value::Packed(p) => p.encode(), + Value::Unknown(u) => u.encode(), + Value::Incomplete(_, bytes) => BytesMut::from(bytes.as_ref()), + }; + + Some((self.wire_type(), bytes)) + } + + fn wire_type(&self) -> u8 { + match self { + Value::Double(..) => 1, + Value::Float(..) => 5, + Value::Int32(..) => 0, + Value::Int64(..) => 0, + Value::UInt32(..) => 0, + Value::UInt64(..) => 0, + Value::SInt32(..) => 0, + Value::SInt64(..) => 0, + Value::Fixed32(..) => 5, + Value::Fixed64(..) => 1, + Value::SFixed32(..) => 5, + Value::SFixed64(..) => 1, + Value::Bool(..) => 0, + Value::String(..) => 2, + Value::Bytes(..) => 2, + Value::Message(..) => 2, + Value::Enum(..) => 0, + Value::Packed(..) => 2, + Value::Unknown(unk) => match unk { + UnknownValue::Varint(..) => 0, + UnknownValue::Fixed64(..) => 1, + UnknownValue::VariableLength(..) => 2, + UnknownValue::Fixed32(..) => 5, + UnknownValue::Invalid(vt, ..) => *vt, + }, + Value::Incomplete(vt, ..) => *vt, + } + } +} + +impl PackedArray { + fn encode(&self) -> BytesMut { + macro_rules! 
write_packed { + ($value:ident => $convert:expr ) => { + $value.iter().flat_map($convert).collect() + }; + } + + let data: Bytes = match self { + PackedArray::Double(v) => { + write_packed!(v => |v| BytesMut::from(v.to_le_bytes().as_ref())) + } + PackedArray::Float(v) => { + write_packed!(v => |v| BytesMut::from(v.to_le_bytes().as_ref())) + } + PackedArray::Int32(v) => { + write_packed!(v => |v| BytesMut::from(v.into_signed_varint().as_ref())) + } + PackedArray::Int64(v) => { + write_packed!(v => |v| BytesMut::from(v.into_signed_varint().as_ref())) + } + PackedArray::UInt32(v) => { + write_packed!(v => |v| BytesMut::from(v.into_unsigned_varint().as_ref())) + } + PackedArray::UInt64(v) => { + write_packed!(v => |v| BytesMut::from(v.into_unsigned_varint().as_ref())) + } + PackedArray::SInt32(v) => { + write_packed! { v => |v| { + let (v, sign_bit) = if *v < 0 { (-v, 1) } else { (*v, 0) }; + (v * 2 - sign_bit).into_unsigned_varint() + } } + } + PackedArray::SInt64(v) => { + write_packed! { v => |v| { + let (v, sign_bit) = if *v < 0 { (-v, 1) } else { (*v, 0) }; + (v * 2 - sign_bit).into_unsigned_varint() + } } + } + PackedArray::Fixed32(v) => { + write_packed!( v => |v| BytesMut::from(v.to_le_bytes().as_ref()) ) + } + PackedArray::Fixed64(v) => { + write_packed!( v => |v| BytesMut::from(v.to_le_bytes().as_ref()) ) + } + PackedArray::SFixed32(v) => { + write_packed!( v => |v| BytesMut::from(v.to_le_bytes().as_ref()) ) + } + PackedArray::SFixed64(v) => { + write_packed!( v => |v| BytesMut::from(v.to_le_bytes().as_ref()) ) + } + PackedArray::Bool(v) => { + write_packed!( v => |v| BytesMut::from(if *v { [1u8].as_ref() } else { [0u8].as_ref() })) + } + }; + + let mut output = data.len().into_unsigned_varint(); + output.extend_from_slice(data.as_ref()); + output + } +} + +fn return_incomplete(data: &mut &[u8], vt: u8, original: &[u8]) -> Value { + *data = &[]; + Value::Incomplete(vt, Bytes::copy_from_slice(original)) +} + +fn try_read_8_bytes(data: &mut &[u8]) -> 
Option<[u8; 8]> { + if data.len() < 8 { + return None; + } + + match (data[..8]).try_into() { + Ok(v) => { + *data = &data[8..]; + Some(v) + } + Err(_) => None, + } +} + +fn try_read_4_bytes(data: &mut &[u8]) -> Option<[u8; 4]> { + if data.len() < 4 { + return None; + } + + match (data[..4]).try_into() { + Ok(v) => { + *data = &data[4..]; + Some(v) + } + Err(_) => None, + } +} + +fn read_string(data: &mut &[u8]) -> Option { + let original = *data; + let len = usize::from_unsigned_varint(data)?; + if len > data.len() { + *data = original; + return None; + } + let (str_data, remainder) = data.split_at(len); + *data = remainder; + Some(String::from_utf8_lossy(str_data).to_string()) +} + +fn read_bytes(data: &mut &[u8]) -> Option { + let original = *data; + let len = usize::from_unsigned_varint(data)?; + if len > data.len() { + *data = original; + return None; + } + let (str_data, remainder) = data.split_at(len); + *data = remainder; + Some(Bytes::copy_from_slice(str_data)) +} + +impl MessageInfo { + /// Decode a message. + /// + /// Will **panic** if the message defined by the `MessageRef` does not exist in this context. + /// Such panic means the `MessageRef` came from a different context. The panic is not + /// guaranteed, as a message with an equal `MessageRef` may exist in multiple contexts. 
+ pub fn decode(&self, mut data: &[u8], ctx: &Context) -> MessageValue { + let mut msg = MessageValue { + msg_ref: self.self_ref, + fields: vec![], + garbage: None, + }; + + loop { + if data.is_empty() { + break; + } + + let tag = match u64::from_unsigned_varint(&mut data) { + Some(tag) => tag, + None => { + msg.garbage = Some(Bytes::copy_from_slice(data)); + break; + } + }; + + let number = tag >> 3; + let wire_type = (tag & 0x07) as u8; + + let value = match self.get_field(number) { + Some(field) => { + if field.multiplicity == Multiplicity::RepeatedPacked { + if wire_type == 2 { + Value::decode_packed(&mut data, wire_type, &field.field_type) + } else { + Value::decode_unknown(&mut data, wire_type) + } + } else if field.field_type.wire_type() == wire_type { + Value::decode(&mut data, wire_type, &field.field_type, ctx) + } else { + Value::decode_unknown(&mut data, wire_type) + } + } + _ => Value::decode_unknown(&mut data, wire_type), + }; + + msg.fields.push(FieldValue { number, value }) + } + + msg + } +} + +impl MessageValue { + /// Encodes a message value into protobuf wire format. + /// + /// Will **panic** if the message defined by the `MessageRef` does not exist in this context. + /// Such panic means the `MessageRef` came from a different context. The panic is not + /// guaranteed, as a message with an equal `MessageRef` may exist in multiple contexts. + pub fn encode(&self, ctx: &Context) -> bytes::BytesMut { + self.fields + .iter() + .filter_map(|f| f.value.encode(ctx).map(|(w, b)| (f, w, b))) + .flat_map(|(field, wire_type, bytes)| { + let tag = wire_type as u64 + (field.number << 3); + let mut field_data = tag.into_unsigned_varint(); + field_data.extend_from_slice(&bytes); + field_data + }) + .collect() + } +} + +impl UnknownValue { + /// Encodes a message value into protobuf wire format. 
+ fn encode(&self) -> bytes::BytesMut { + match self { + UnknownValue::Varint(v) => v.into_unsigned_varint(), + UnknownValue::Fixed64(v) => BytesMut::from(v.to_le_bytes().as_ref()), + UnknownValue::VariableLength(b) => { + let mut output = b.len().into_unsigned_varint(); + output.extend_from_slice(b); + output + } + UnknownValue::Fixed32(v) => BytesMut::from(v.to_le_bytes().as_ref()), + UnknownValue::Invalid(_, v) => BytesMut::from(v.as_ref()), + } + } +} + +trait FromUnsignedVarint: Sized { + fn from_unsigned_varint(data: &mut &[u8]) -> Option; +} + +trait ToUnsignedVarint: Sized { + fn into_unsigned_varint(self) -> BytesMut; +} + +impl> FromUnsignedVarint for T +where + T::Error: Debug, +{ + fn from_unsigned_varint(data: &mut &[u8]) -> Option { + let mut result = 0u64; + let mut idx = 0; + loop { + if idx >= data.len() { + return None; + } + + let b = data[idx]; + let value = (b & 0x7f) as u64; + result += value << (idx * 7); + + idx += 1; + if b & 0x80 == 0 { + break; + } + } + + let result = T::try_from(result).expect("Out of range"); + *data = &data[idx..]; + Some(result) + } +} + +impl> ToUnsignedVarint for T +where + T::Error: Debug, +{ + fn into_unsigned_varint(self) -> BytesMut { + let mut value: u64 = self.try_into().unwrap(); + let mut data: Vec = Vec::with_capacity(8); + loop { + let mut byte = (value & 0x7f) as u8; + value >>= 7; + if value > 0 { + byte |= 0x80; + data.push(byte); + } else { + data.push(byte); + break BytesMut::from(data.as_slice()); + } + } + } +} + +trait FromSignedVarint: Sized { + fn from_signed_varint(data: &mut &[u8]) -> Option; +} + +trait ToSignedVarint: Sized { + fn into_signed_varint(self) -> BytesMut; +} + +impl> FromSignedVarint for T +where + T::Error: Debug, +{ + fn from_signed_varint(data: &mut &[u8]) -> Option { + u64::from_unsigned_varint(data).map(|u| { + let signed: i64 = unsafe { std::mem::transmute(u) }; + signed.try_into().unwrap() + }) + } +} + +impl> ToSignedVarint for T +where + T::Error: Debug, +{ + fn 
into_signed_varint(self) -> BytesMut { + let v: u64 = unsafe { std::mem::transmute(self.try_into().unwrap()) }; + v.into_unsigned_varint() + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_zigzag_encoding() + // Source: https://developers.google.com/protocol-buffers/docs/encoding#signed-ints + // Signed Original Encoded As + // 0 0 + // -1 1 + // 1 2 + // -2 3 + // 2147483647 4294967294 + // -2147483648 4294967295 + { + let ctx = Context::parse(&[r#" + syntax = "proto3"; + message Message {} + "#]) + .unwrap(); + + // Singular + assert_eq!(Value::SInt32(0).encode(&ctx), Value::Int32(0).encode(&ctx)); + assert_eq!(Value::SInt32(-1).encode(&ctx), Value::Int32(1).encode(&ctx)); + assert_eq!(Value::SInt32(1).encode(&ctx), Value::Int32(2).encode(&ctx)); + assert_eq!( + Value::SInt64(2147483647).encode(&ctx), + Value::Int64(4294967294).encode(&ctx) + ); + assert_eq!( + Value::SInt64(-2147483648).encode(&ctx), + Value::Int64(4294967295).encode(&ctx) + ); + + // Packed + assert_eq!( + Value::Packed(PackedArray::SInt32(vec![0, -1, 1, -2, 2])).encode(&ctx), + Value::Packed(PackedArray::Int32(vec![0, 1, 2, 3, 4])).encode(&ctx), + ); + assert_eq!( + Value::Packed(PackedArray::SInt64(vec![0, 2147483647, -2147483648])).encode(&ctx), + Value::Packed(PackedArray::Int64(vec![0, 4294967294, 4294967295])).encode(&ctx), + ); + } +} diff --git a/tests/create.rs b/tests/create.rs index 8c63d4f..ac22ca3 100644 --- a/tests/create.rs +++ b/tests/create.rs @@ -3,8 +3,7 @@ use protofish::context::{ }; #[test] -fn create_context_by_hand() -{ +fn create_context_by_hand() { let parsed_context = Context::parse(&[r#" syntax = "proto3"; @@ -76,8 +75,7 @@ fn create_context_by_hand() } #[test] -fn iterate_fields() -{ +fn iterate_fields() { let context = Context::parse(&[r#" syntax = "proto3"; diff --git a/tests/encode-message.rs b/tests/encode-message.rs index 8d0a0d7..111f79d 100644 --- a/tests/encode-message.rs +++ b/tests/encode-message.rs @@ -1,79 +1,81 @@ -#[test] -fn 
encode_message() -{ - use bytes::BufMut; - use protofish::{ - context::Context, - decode::{FieldValue, MessageValue, Value}, - }; - - let context = Context::parse(&[r#" - syntax = "proto3"; - message Message { - string s = 1; - int32 small = 2; - int64 large = 3; - sint32 signed = 4; - fixed64 fixed = 5; - double dbl = 6; - bool b = 7; - Message child = 10; - } - "#]) - .unwrap(); - - let msg = context.get_message("Message").unwrap(); - - let original = MessageValue { - msg_ref: msg.self_ref.clone(), - garbage: None, - fields: vec![ - FieldValue { - number: 1, - value: Value::String("parent".to_string()), - }, - FieldValue { - number: 2, - value: Value::Int32(123), - }, - FieldValue { - number: 3, - value: Value::Int64(12356), - }, - FieldValue { - number: 4, - value: Value::SInt32(-123), - }, - FieldValue { - number: 5, - value: Value::Fixed64(12356), - }, - FieldValue { - number: 6, - value: Value::Double(1.2345), - }, - FieldValue { - number: 7, - value: Value::Bool(true), - }, - FieldValue { - number: 10, - value: Value::Message(Box::new(MessageValue { - msg_ref: msg.self_ref.clone(), - garbage: None, - fields: vec![FieldValue { - number: 1, - value: Value::String("child".to_string()), - }], - })), - }, - ], - }; - - let expected = original.encode(&context); - let decoded = msg.decode(&expected, &context); - let actual = decoded.encode(&context); - - assert_eq!(original, decoded); - assert_eq!(expected, actual); -} +#[test] +fn encode_message() { + use bytes::BufMut; + use protofish::{ + context::Context, + decode::{FieldValue, MessageValue, Value}, + }; + + let context = Context::parse(&[r#" + syntax = "proto3"; + message Message { + string s = 1; + int32 small = 2; + int64 large = 3; + sint32 signed = 4; + fixed64 fixed = 5; + double dbl = 6; + bool b = 7; + Message child = 10; + map mymap = 11; + } + "#]) + .unwrap(); + + let msg = context.get_message("Message").unwrap(); + + let original = MessageValue { + msg_ref: msg.self_ref.clone(), + garbage: None, + 
fields: vec![ + FieldValue { + number: 1, + value: Value::String("parent".to_string()), + }, + FieldValue { + number: 2, + value: Value::Int32(123), + }, + FieldValue { + number: 3, + value: Value::Int64(12356), + }, + FieldValue { + number: 4, + value: Value::SInt32(-123), + }, + FieldValue { + number: 5, + value: Value::Fixed64(12356), + }, + FieldValue { + number: 6, + value: Value::Double(1.2345), + }, + FieldValue { + number: 7, + value: Value::Bool(true), + }, + FieldValue { + number: 10, + value: Value::Message(Box::new(MessageValue { + msg_ref: msg.self_ref.clone(), + garbage: None, + fields: vec![FieldValue { + number: 1, + value: Value::String("child".to_string()), + }], + })), + }, + // note we don't add any data for the map type + // but due to how wire encoding works, decoding will still work + ], + }; + + let expected = original.encode(&context); + let decoded = msg.decode(&expected, &context); + let actual = decoded.encode(&context); + + assert_eq!(original, decoded); + assert_eq!(expected, actual); +} diff --git a/tests/map.rs b/tests/map.rs new file mode 100644 index 0000000..3698a86 --- /dev/null +++ b/tests/map.rs @@ -0,0 +1,18 @@ +#[test] +fn maptype() { + use protofish::context::Context; + + // Hey at least we're ensuring this doesn't panic. :< + Context::parse(&[r#" + syntax = "proto3"; + message Message { + message mydata { + uint32 b1 = 1; + map m1 = 2; + map m2 = 3; + map m3 = 4; + } + } + "#]) + .unwrap(); +} diff --git a/tests/oneof.rs b/tests/oneof.rs index 5ab19ef..1dba961 100644 --- a/tests/oneof.rs +++ b/tests/oneof.rs @@ -1,23 +1,22 @@ -#[test] -fn oneof() -{ - use protofish::context::Context; - - // Hey at least we're ensuring this doesn't panic. 
:< - Context::parse(&[r#" - syntax = "proto3"; - message Message { - oneof a { - string a1 = 1; - string a2 = 2; - string a3 = 3; - }; - oneof b { - uint32 b1 = 4; - uint32 b2 = 5; - uint32 b3 = 6; - } - } - "#]) - .unwrap(); -} +#[test] +fn oneof() { + use protofish::context::Context; + + // Hey at least we're ensuring this doesn't panic. :< + Context::parse(&[r#" + syntax = "proto3"; + message Message { + oneof a { + string a1 = 1; + string a2 = 2; + string a3 = 3; + }; + oneof b { + uint32 b1 = 4; + uint32 b2 = 5; + uint32 b3 = 6; + } + } + "#]) + .unwrap(); +} diff --git a/tests/packed_array.rs b/tests/packed_array.rs index 2762a26..b47b314 100644 --- a/tests/packed_array.rs +++ b/tests/packed_array.rs @@ -1,73 +1,72 @@ -#[test] -fn repeated() -{ - use bytes::BufMut; - use protofish::{ - context::Context, - decode::{FieldValue, MessageValue, PackedArray, Value}, - }; - - let context = Context::parse(&[r#" - syntax = "proto3"; - message Message { - repeated string s = 1; - repeated int32 small = 2; - repeated int32 large = 3; - } - "#]) - .unwrap(); - - let mut payload = bytes::BytesMut::new(); - - payload.put_u8(1 << 3 | 2); // String tag. - payload.put_u8(11); - payload.put_slice(b"first value"); - - payload.put_u8(1 << 3 | 2); // String tag. - payload.put_u8(12); - payload.put_slice(b"second value"); - - payload.put_u8(2 << 3 | 2); // Packed integer array. - payload.put_slice(b"\x06"); // Length - payload.put_slice(b"\x01"); - payload.put_slice(b"\x80\x01"); - payload.put_slice(b"\x80\x80\x02"); - - payload.put_u8(3 << 3 | 2); // Packed integer array. 
- payload.put_slice(b"\x80\x01"); // Length - payload.put_slice(&(b"\x01".repeat(128))); - - let msg = context.get_message("Message").unwrap(); - let value = msg.decode(&payload, &context); - - assert_eq!( - value, - MessageValue { - msg_ref: msg.self_ref.clone(), - garbage: None, - fields: vec![ - FieldValue { - number: 1, - value: Value::String("first value".to_string()), - }, - FieldValue { - number: 1, - value: Value::String("second value".to_string()), - }, - FieldValue { - number: 2, - value: Value::Packed(PackedArray::Int32(vec![1, 1 << 7, 1 << 15])), - }, - FieldValue { - number: 3, - value: Value::Packed(PackedArray::Int32( - std::iter::repeat(1).take(128).collect() - )), - }, - ] - } - ); - - let encoded = value.encode(&context); - assert_eq!(payload, encoded); -} +#[test] +fn repeated() { + use bytes::BufMut; + use protofish::{ + context::Context, + decode::{FieldValue, MessageValue, PackedArray, Value}, + }; + + let context = Context::parse(&[r#" + syntax = "proto3"; + message Message { + repeated string s = 1; + repeated int32 small = 2; + repeated int32 large = 3; + } + "#]) + .unwrap(); + + let mut payload = bytes::BytesMut::new(); + + payload.put_u8(1 << 3 | 2); // String tag. + payload.put_u8(11); + payload.put_slice(b"first value"); + + payload.put_u8(1 << 3 | 2); // String tag. + payload.put_u8(12); + payload.put_slice(b"second value"); + + payload.put_u8(2 << 3 | 2); // Packed integer array. + payload.put_slice(b"\x06"); // Length + payload.put_slice(b"\x01"); + payload.put_slice(b"\x80\x01"); + payload.put_slice(b"\x80\x80\x02"); + + payload.put_u8(3 << 3 | 2); // Packed integer array. 
+ payload.put_slice(b"\x80\x01"); // Length + payload.put_slice(&(b"\x01".repeat(128))); + + let msg = context.get_message("Message").unwrap(); + let value = msg.decode(&payload, &context); + + assert_eq!( + value, + MessageValue { + msg_ref: msg.self_ref.clone(), + garbage: None, + fields: vec![ + FieldValue { + number: 1, + value: Value::String("first value".to_string()), + }, + FieldValue { + number: 1, + value: Value::String("second value".to_string()), + }, + FieldValue { + number: 2, + value: Value::Packed(PackedArray::Int32(vec![1, 1 << 7, 1 << 15])), + }, + FieldValue { + number: 3, + value: Value::Packed(PackedArray::Int32( + std::iter::repeat(1).take(128).collect() + )), + }, + ] + } + ); + + let encoded = value.encode(&context); + assert_eq!(payload, encoded); +} diff --git a/tests/parse.rs b/tests/parse.rs index b324900..5c0cfda 100644 --- a/tests/parse.rs +++ b/tests/parse.rs @@ -1,51 +1,50 @@ -#[test] -fn parse() -{ - use protofish::context::{ - Context, MessageField, MessageInfo, Multiplicity, Package, TypeParent, ValueType, - }; - - let context = Context::parse(&[r#" - syntax = "proto3"; - message Message { - string s = 1; - repeated bytes b = 2; - optional int64 large = 3; - repeated sint32 signed = 4; - Message child = 10; - } - "#]) - .unwrap(); - - let mut expected = Context::new(); - let package = expected.insert_package(Package::new(None)).unwrap(); - let mut message = MessageInfo::new("Message".to_string(), TypeParent::Package(package)); - - message - .add_field(MessageField::new("s".to_string(), 1, ValueType::String)) - .unwrap(); - - let mut b_field = MessageField::new("b".to_string(), 2, ValueType::Bytes); - b_field.multiplicity = Multiplicity::Repeated; - - let mut large_field = MessageField::new("large".to_string(), 3, ValueType::Int64); - large_field.multiplicity = Multiplicity::Optional; - - let mut signed_field = MessageField::new("signed".to_string(), 4, ValueType::SInt32); - signed_field.multiplicity = Multiplicity::RepeatedPacked; 
- - let child_field = MessageField::new( - "child".to_string(), - 10, - ValueType::Message(message.self_ref), - ); - - message.add_field(b_field).unwrap(); - message.add_field(large_field).unwrap(); - message.add_field(signed_field).unwrap(); - message.add_field(child_field).unwrap(); - - expected.insert_message(message).unwrap(); - - assert_eq!(expected, context); -} +#[test] +fn parse() { + use protofish::context::{ + Context, MessageField, MessageInfo, Multiplicity, Package, TypeParent, ValueType, + }; + + let context = Context::parse(&[r#" + syntax = "proto3"; + message Message { + string s = 1; + repeated bytes b = 2; + optional int64 large = 3; + repeated sint32 signed = 4; + Message child = 10; + } + "#]) + .unwrap(); + + let mut expected = Context::new(); + let package = expected.insert_package(Package::new(None)).unwrap(); + let mut message = MessageInfo::new("Message".to_string(), TypeParent::Package(package)); + + message + .add_field(MessageField::new("s".to_string(), 1, ValueType::String)) + .unwrap(); + + let mut b_field = MessageField::new("b".to_string(), 2, ValueType::Bytes); + b_field.multiplicity = Multiplicity::Repeated; + + let mut large_field = MessageField::new("large".to_string(), 3, ValueType::Int64); + large_field.multiplicity = Multiplicity::Optional; + + let mut signed_field = MessageField::new("signed".to_string(), 4, ValueType::SInt32); + signed_field.multiplicity = Multiplicity::RepeatedPacked; + + let child_field = MessageField::new( + "child".to_string(), + 10, + ValueType::Message(message.self_ref), + ); + + message.add_field(b_field).unwrap(); + message.add_field(large_field).unwrap(); + message.add_field(signed_field).unwrap(); + message.add_field(child_field).unwrap(); + + expected.insert_message(message).unwrap(); + + assert_eq!(expected, context); +}