diff --git a/prost-derive/src/lib.rs b/prost-derive/src/lib.rs index 789ea2c5b..f34fa0a35 100644 --- a/prost-derive/src/lib.rs +++ b/prost-derive/src/lib.rs @@ -183,7 +183,7 @@ fn try_message(input: TokenStream) -> Result { fn merge_field( &mut self, tag: u32, - wire_type: ::prost::encoding::WireType, + wire_type: ::prost::encoding::wire_type::WireType, buf: &mut impl ::prost::bytes::Buf, ctx: ::prost::encoding::DecodeContext, ) -> ::core::result::Result<(), ::prost::DecodeError> @@ -472,7 +472,7 @@ fn try_oneof(input: TokenStream) -> Result { pub fn merge( field: &mut ::core::option::Option<#ident #ty_generics>, tag: u32, - wire_type: ::prost::encoding::WireType, + wire_type: ::prost::encoding::wire_type::WireType, buf: &mut impl ::prost::bytes::Buf, ctx: ::prost::encoding::DecodeContext, ) -> ::core::result::Result<(), ::prost::DecodeError> diff --git a/prost/src/encoding.rs b/prost/src/encoding.rs index bf27fcf1b..96eb99ae4 100644 --- a/prost/src/encoding.rs +++ b/prost/src/encoding.rs @@ -24,6 +24,9 @@ pub use length_delimiter::{ decode_length_delimiter, encode_length_delimiter, length_delimiter_len, }; +pub mod wire_type; +pub use wire_type::{check_wire_type, WireType}; + /// Additional information passed to every decode/merge function. /// /// The context should be passed by value and can be freely cloned. When passing @@ -94,40 +97,9 @@ impl DecodeContext { } } -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -#[repr(u8)] -pub enum WireType { - Varint = 0, - SixtyFourBit = 1, - LengthDelimited = 2, - StartGroup = 3, - EndGroup = 4, - ThirtyTwoBit = 5, -} - pub const MIN_TAG: u32 = 1; pub const MAX_TAG: u32 = (1 << 29) - 1; -impl TryFrom for WireType { - type Error = DecodeError; - - #[inline] - fn try_from(value: u64) -> Result { - match value { - 0 => Ok(WireType::Varint), - 1 => Ok(WireType::SixtyFourBit), - 2 => Ok(WireType::LengthDelimited), - 3 => Ok(WireType::StartGroup), - 4 => Ok(WireType::EndGroup), - 5 => Ok(WireType::ThirtyTwoBit), - _ => Err(DecodeError::new(format!( - "invalid wire type value: {}", - value - ))), - } - } -} - /// Encodes a Protobuf field key, which consists of a wire type designator and /// the field tag. #[inline] @@ -162,19 +134,6 @@ pub fn key_len(tag: u32) -> usize { encoded_len_varint(u64::from(tag << 3)) } -/// Checks that the expected wire type matches the actual wire type, -/// or returns an error result. -#[inline] -pub fn check_wire_type(expected: WireType, actual: WireType) -> Result<(), DecodeError> { - if expected != actual { - return Err(DecodeError::new(format!( - "invalid wire type: {:?} (expected {:?})", - actual, expected - ))); - } - Ok(()) -} - /// Helper function which abstracts reading a length delimiter prefix followed /// by decoding values until the length of bytes is exhausted. pub fn merge_loop( diff --git a/prost/src/encoding/wire_type.rs b/prost/src/encoding/wire_type.rs new file mode 100644 index 000000000..74a857a2e --- /dev/null +++ b/prost/src/encoding/wire_type.rs @@ -0,0 +1,49 @@ +use crate::DecodeError; +use alloc::format; + +/// Represent the wire type for protobuf encoding. +/// +/// The integer value is equvilant with the encoded value. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +#[repr(u8)] +pub enum WireType { + Varint = 0, + SixtyFourBit = 1, + LengthDelimited = 2, + StartGroup = 3, + EndGroup = 4, + ThirtyTwoBit = 5, +} + +impl TryFrom for WireType { + type Error = DecodeError; + + #[inline] + fn try_from(value: u64) -> Result { + match value { + 0 => Ok(WireType::Varint), + 1 => Ok(WireType::SixtyFourBit), + 2 => Ok(WireType::LengthDelimited), + 3 => Ok(WireType::StartGroup), + 4 => Ok(WireType::EndGroup), + 5 => Ok(WireType::ThirtyTwoBit), + _ => Err(DecodeError::new(format!( + "invalid wire type value: {}", + value + ))), + } + } +} + +/// Checks that the expected wire type matches the actual wire type, +/// or returns an error result. +#[inline] +pub fn check_wire_type(expected: WireType, actual: WireType) -> Result<(), DecodeError> { + if expected != actual { + return Err(DecodeError::new(format!( + "invalid wire type: {:?} (expected {:?})", + actual, expected + ))); + } + Ok(()) +} diff --git a/prost/src/message.rs b/prost/src/message.rs index d0b25a1c2..ee33eecd9 100644 --- a/prost/src/message.rs +++ b/prost/src/message.rs @@ -8,7 +8,8 @@ use core::fmt::Debug; use bytes::{Buf, BufMut}; use crate::encoding::varint::{encode_varint, encoded_len_varint}; -use crate::encoding::{decode_key, message, DecodeContext, WireType}; +use crate::encoding::wire_type::WireType; +use crate::encoding::{decode_key, message, DecodeContext}; use crate::DecodeError; use crate::EncodeError; diff --git a/prost/src/types.rs b/prost/src/types.rs index 6e4994bfb..5abb513bc 100644 --- a/prost/src/types.rs +++ b/prost/src/types.rs @@ -12,10 +12,10 @@ use alloc::vec::Vec; use ::bytes::{Buf, BufMut, Bytes}; +use crate::encoding::wire_type::WireType; use crate::{ encoding::{ - bool, bytes, double, float, int32, int64, skip_field, string, uint32, uint64, - DecodeContext, WireType, + bool, bytes, double, float, int32, int64, skip_field, string, uint32, uint64, DecodeContext, }, DecodeError, Message, };