Skip to content

Commit

Permalink
style: Move wire type to separate module
Browse files Browse the repository at this point in the history
While the `encoding` module is undocumented and not stable, there are known users. Leave a alias in place to prevent breaking them.
  • Loading branch information
caspermeijn committed Jul 26, 2024
1 parent 8634d3f commit 997de1d
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 49 deletions.
4 changes: 2 additions & 2 deletions prost-derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ fn try_message(input: TokenStream) -> Result<TokenStream, Error> {
fn merge_field(
&mut self,
tag: u32,
wire_type: ::prost::encoding::WireType,
wire_type: ::prost::encoding::wire_type::WireType,
buf: &mut impl ::prost::bytes::Buf,
ctx: ::prost::encoding::DecodeContext,
) -> ::core::result::Result<(), ::prost::DecodeError>
Expand Down Expand Up @@ -472,7 +472,7 @@ fn try_oneof(input: TokenStream) -> Result<TokenStream, Error> {
pub fn merge(
field: &mut ::core::option::Option<#ident #ty_generics>,
tag: u32,
wire_type: ::prost::encoding::WireType,
wire_type: ::prost::encoding::wire_type::WireType,
buf: &mut impl ::prost::bytes::Buf,
ctx: ::prost::encoding::DecodeContext,
) -> ::core::result::Result<(), ::prost::DecodeError>
Expand Down
47 changes: 3 additions & 44 deletions prost/src/encoding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ pub use length_delimiter::{
decode_length_delimiter, encode_length_delimiter, length_delimiter_len,
};

pub mod wire_type;
pub use wire_type::{check_wire_type, WireType};

/// Additional information passed to every decode/merge function.
///
/// The context should be passed by value and can be freely cloned. When passing
Expand Down Expand Up @@ -94,40 +97,9 @@ impl DecodeContext {
}
}

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(u8)]
pub enum WireType {
Varint = 0,
SixtyFourBit = 1,
LengthDelimited = 2,
StartGroup = 3,
EndGroup = 4,
ThirtyTwoBit = 5,
}

pub const MIN_TAG: u32 = 1;
pub const MAX_TAG: u32 = (1 << 29) - 1;

impl TryFrom<u64> for WireType {
type Error = DecodeError;

#[inline]
fn try_from(value: u64) -> Result<Self, Self::Error> {
match value {
0 => Ok(WireType::Varint),
1 => Ok(WireType::SixtyFourBit),
2 => Ok(WireType::LengthDelimited),
3 => Ok(WireType::StartGroup),
4 => Ok(WireType::EndGroup),
5 => Ok(WireType::ThirtyTwoBit),
_ => Err(DecodeError::new(format!(
"invalid wire type value: {}",
value
))),
}
}
}

/// Encodes a Protobuf field key, which consists of a wire type designator and
/// the field tag.
#[inline]
Expand Down Expand Up @@ -162,19 +134,6 @@ pub fn key_len(tag: u32) -> usize {
encoded_len_varint(u64::from(tag << 3))
}

/// Checks that the expected wire type matches the actual wire type,
/// or returns an error result.
#[inline]
pub fn check_wire_type(expected: WireType, actual: WireType) -> Result<(), DecodeError> {
if expected != actual {
return Err(DecodeError::new(format!(
"invalid wire type: {:?} (expected {:?})",
actual, expected
)));
}
Ok(())
}

/// Helper function which abstracts reading a length delimiter prefix followed
/// by decoding values until the length of bytes is exhausted.
pub fn merge_loop<T, M, B>(
Expand Down
49 changes: 49 additions & 0 deletions prost/src/encoding/wire_type.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
use crate::DecodeError;
use alloc::format;

/// Represent the wire type for protobuf encoding.
///
/// The integer value is equvilant with the encoded value.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(u8)]
pub enum WireType {
Varint = 0,
SixtyFourBit = 1,
LengthDelimited = 2,
StartGroup = 3,
EndGroup = 4,
ThirtyTwoBit = 5,
}

impl TryFrom<u64> for WireType {
type Error = DecodeError;

#[inline]
fn try_from(value: u64) -> Result<Self, Self::Error> {
match value {
0 => Ok(WireType::Varint),
1 => Ok(WireType::SixtyFourBit),
2 => Ok(WireType::LengthDelimited),
3 => Ok(WireType::StartGroup),
4 => Ok(WireType::EndGroup),
5 => Ok(WireType::ThirtyTwoBit),
_ => Err(DecodeError::new(format!(
"invalid wire type value: {}",
value
))),
}
}
}

/// Checks that the expected wire type matches the actual wire type,
/// or returns an error result.
#[inline]
pub fn check_wire_type(expected: WireType, actual: WireType) -> Result<(), DecodeError> {
if expected != actual {
return Err(DecodeError::new(format!(
"invalid wire type: {:?} (expected {:?})",
actual, expected
)));
}
Ok(())
}
3 changes: 2 additions & 1 deletion prost/src/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ use core::fmt::Debug;
use bytes::{Buf, BufMut};

use crate::encoding::varint::{encode_varint, encoded_len_varint};
use crate::encoding::{decode_key, message, DecodeContext, WireType};
use crate::encoding::wire_type::WireType;
use crate::encoding::{decode_key, message, DecodeContext};
use crate::DecodeError;
use crate::EncodeError;

Expand Down
4 changes: 2 additions & 2 deletions prost/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@ use alloc::vec::Vec;

use ::bytes::{Buf, BufMut, Bytes};

use crate::encoding::wire_type::WireType;
use crate::{
encoding::{
bool, bytes, double, float, int32, int64, skip_field, string, uint32, uint64,
DecodeContext, WireType,
bool, bytes, double, float, int32, int64, skip_field, string, uint32, uint64, DecodeContext,
},
DecodeError, Message,
};
Expand Down

0 comments on commit 997de1d

Please sign in to comment.