Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

style: Move encoding functions into separate modules #1111

Merged
merged 3 commits into from
Jul 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions prost-derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ fn try_message(input: TokenStream) -> Result<TokenStream, Error> {
fn merge_field(
&mut self,
tag: u32,
wire_type: ::prost::encoding::WireType,
wire_type: ::prost::encoding::wire_type::WireType,
buf: &mut impl ::prost::bytes::Buf,
ctx: ::prost::encoding::DecodeContext,
) -> ::core::result::Result<(), ::prost::DecodeError>
Expand Down Expand Up @@ -472,7 +472,7 @@ fn try_oneof(input: TokenStream) -> Result<TokenStream, Error> {
pub fn merge(
field: &mut ::core::option::Option<#ident #ty_generics>,
tag: u32,
wire_type: ::prost::encoding::WireType,
wire_type: ::prost::encoding::wire_type::WireType,
buf: &mut impl ::prost::bytes::Buf,
ctx: ::prost::encoding::DecodeContext,
) -> ::core::result::Result<(), ::prost::DecodeError>
Expand Down
2 changes: 1 addition & 1 deletion prost/benches/varint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::mem;

use bytes::Buf;
use criterion::{Criterion, Throughput};
use prost::encoding::{decode_varint, encode_varint, encoded_len_varint};
use prost::encoding::varint::{decode_varint, encode_varint, encoded_len_varint};
use rand::{rngs::StdRng, seq::SliceRandom, SeedableRng};

fn benchmark_varint(criterion: &mut Criterion, name: &str, mut values: Vec<u64>) {
Expand Down
311 changes: 8 additions & 303 deletions prost/src/encoding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ use alloc::collections::BTreeMap;
use alloc::format;
use alloc::string::String;
use alloc::vec::Vec;
use core::cmp::min;
use core::mem;
use core::str;

Expand All @@ -17,162 +16,16 @@ use ::bytes::{Buf, BufMut, Bytes};
use crate::DecodeError;
use crate::Message;

/// Encodes an integer value into LEB128 variable length format, and writes it to the buffer.
/// The buffer must have enough remaining space (maximum 10 bytes).
#[inline]
pub fn encode_varint(mut value: u64, buf: &mut impl BufMut) {
// Varints are never more than 10 bytes
for _ in 0..10 {
if value < 0x80 {
buf.put_u8(value as u8);
break;
} else {
buf.put_u8(((value & 0x7F) | 0x80) as u8);
value >>= 7;
}
}
}

/// Decodes a LEB128-encoded variable length integer from the buffer.
#[inline]
pub fn decode_varint(buf: &mut impl Buf) -> Result<u64, DecodeError> {
let bytes = buf.chunk();
let len = bytes.len();
if len == 0 {
return Err(DecodeError::new("invalid varint"));
}

let byte = bytes[0];
if byte < 0x80 {
buf.advance(1);
Ok(u64::from(byte))
} else if len > 10 || bytes[len - 1] < 0x80 {
let (value, advance) = decode_varint_slice(bytes)?;
buf.advance(advance);
Ok(value)
} else {
decode_varint_slow(buf)
}
}
pub mod varint;
pub use varint::{decode_varint, encode_varint, encoded_len_varint};

/// Decodes a LEB128-encoded variable length integer from the slice, returning the value and the
/// number of bytes read.
///
/// Based loosely on [`ReadVarint64FromArray`][1] with a varint overflow check from
/// [`ConsumeVarint`][2].
///
/// ## Safety
///
/// The caller must ensure that `bytes` is non-empty and either `bytes.len() >= 10` or the last
/// element in bytes is < `0x80`.
///
/// [1]: https://github.com/google/protobuf/blob/3.3.x/src/google/protobuf/io/coded_stream.cc#L365-L406
/// [2]: https://github.com/protocolbuffers/protobuf-go/blob/v1.27.1/encoding/protowire/wire.go#L358
#[inline]
fn decode_varint_slice(bytes: &[u8]) -> Result<(u64, usize), DecodeError> {
// Fully unrolled varint decoding loop. Splitting into 32-bit pieces gives better performance.
pub mod length_delimiter;
pub use length_delimiter::{
decode_length_delimiter, encode_length_delimiter, length_delimiter_len,
};

// Use assertions to ensure memory safety, but it should always be optimized after inline.
assert!(!bytes.is_empty());
assert!(bytes.len() > 10 || bytes[bytes.len() - 1] < 0x80);

let mut b: u8 = unsafe { *bytes.get_unchecked(0) };
let mut part0: u32 = u32::from(b);
if b < 0x80 {
return Ok((u64::from(part0), 1));
};
part0 -= 0x80;
b = unsafe { *bytes.get_unchecked(1) };
part0 += u32::from(b) << 7;
if b < 0x80 {
return Ok((u64::from(part0), 2));
};
part0 -= 0x80 << 7;
b = unsafe { *bytes.get_unchecked(2) };
part0 += u32::from(b) << 14;
if b < 0x80 {
return Ok((u64::from(part0), 3));
};
part0 -= 0x80 << 14;
b = unsafe { *bytes.get_unchecked(3) };
part0 += u32::from(b) << 21;
if b < 0x80 {
return Ok((u64::from(part0), 4));
};
part0 -= 0x80 << 21;
let value = u64::from(part0);

b = unsafe { *bytes.get_unchecked(4) };
let mut part1: u32 = u32::from(b);
if b < 0x80 {
return Ok((value + (u64::from(part1) << 28), 5));
};
part1 -= 0x80;
b = unsafe { *bytes.get_unchecked(5) };
part1 += u32::from(b) << 7;
if b < 0x80 {
return Ok((value + (u64::from(part1) << 28), 6));
};
part1 -= 0x80 << 7;
b = unsafe { *bytes.get_unchecked(6) };
part1 += u32::from(b) << 14;
if b < 0x80 {
return Ok((value + (u64::from(part1) << 28), 7));
};
part1 -= 0x80 << 14;
b = unsafe { *bytes.get_unchecked(7) };
part1 += u32::from(b) << 21;
if b < 0x80 {
return Ok((value + (u64::from(part1) << 28), 8));
};
part1 -= 0x80 << 21;
let value = value + ((u64::from(part1)) << 28);

b = unsafe { *bytes.get_unchecked(8) };
let mut part2: u32 = u32::from(b);
if b < 0x80 {
return Ok((value + (u64::from(part2) << 56), 9));
};
part2 -= 0x80;
b = unsafe { *bytes.get_unchecked(9) };
part2 += u32::from(b) << 7;
// Check for u64::MAX overflow. See [`ConsumeVarint`][1] for details.
// [1]: https://github.com/protocolbuffers/protobuf-go/blob/v1.27.1/encoding/protowire/wire.go#L358
if b < 0x02 {
return Ok((value + (u64::from(part2) << 56), 10));
};

// We have overrun the maximum size of a varint (10 bytes) or the final byte caused an overflow.
// Assume the data is corrupt.
Err(DecodeError::new("invalid varint"))
}

/// Decodes a LEB128-encoded variable length integer from the buffer, advancing the buffer as
/// necessary.
///
/// Contains a varint overflow check from [`ConsumeVarint`][1].
///
/// [1]: https://github.com/protocolbuffers/protobuf-go/blob/v1.27.1/encoding/protowire/wire.go#L358
#[inline(never)]
#[cold]
fn decode_varint_slow(buf: &mut impl Buf) -> Result<u64, DecodeError> {
let mut value = 0;
for count in 0..min(10, buf.remaining()) {
let byte = buf.get_u8();
value |= u64::from(byte & 0x7F) << (count * 7);
if byte <= 0x7F {
// Check for u64::MAX overflow. See [`ConsumeVarint`][1] for details.
// [1]: https://github.com/protocolbuffers/protobuf-go/blob/v1.27.1/encoding/protowire/wire.go#L358
if count == 9 && byte >= 0x02 {
return Err(DecodeError::new("invalid varint"));
} else {
return Ok(value);
}
}
}

Err(DecodeError::new("invalid varint"))
}
pub mod wire_type;
pub use wire_type::{check_wire_type, WireType};

/// Additional information passed to every decode/merge function.
///
Expand Down Expand Up @@ -244,49 +97,9 @@ impl DecodeContext {
}
}

/// Returns the encoded length of the value in LEB128 variable length format.
/// The returned value will be between 1 and 10, inclusive.
#[inline]
pub fn encoded_len_varint(value: u64) -> usize {
// Based on [VarintSize64][1].
// [1]: https://github.com/google/protobuf/blob/3.3.x/src/google/protobuf/io/coded_stream.h#L1301-L1309
((((value | 1).leading_zeros() ^ 63) * 9 + 73) / 64) as usize
}

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(u8)]
pub enum WireType {
Varint = 0,
SixtyFourBit = 1,
LengthDelimited = 2,
StartGroup = 3,
EndGroup = 4,
ThirtyTwoBit = 5,
}

pub const MIN_TAG: u32 = 1;
pub const MAX_TAG: u32 = (1 << 29) - 1;

impl TryFrom<u64> for WireType {
type Error = DecodeError;

#[inline]
fn try_from(value: u64) -> Result<Self, Self::Error> {
match value {
0 => Ok(WireType::Varint),
1 => Ok(WireType::SixtyFourBit),
2 => Ok(WireType::LengthDelimited),
3 => Ok(WireType::StartGroup),
4 => Ok(WireType::EndGroup),
5 => Ok(WireType::ThirtyTwoBit),
_ => Err(DecodeError::new(format!(
"invalid wire type value: {}",
value
))),
}
}
}

/// Encodes a Protobuf field key, which consists of a wire type designator and
/// the field tag.
#[inline]
Expand Down Expand Up @@ -321,19 +134,6 @@ pub fn key_len(tag: u32) -> usize {
encoded_len_varint(u64::from(tag << 3))
}

/// Checks that the expected wire type matches the actual wire type,
/// or returns an error result.
#[inline]
pub fn check_wire_type(expected: WireType, actual: WireType) -> Result<(), DecodeError> {
if expected != actual {
return Err(DecodeError::new(format!(
"invalid wire type: {:?} (expected {:?})",
actual, expected
)));
}
Ok(())
}

/// Helper function which abstracts reading a length delimiter prefix followed
/// by decoding values until the length of bytes is exhausted.
pub fn merge_loop<T, M, B>(
Expand Down Expand Up @@ -1522,101 +1322,6 @@ mod test {
assert!(s.is_empty());
}

#[test]
fn varint() {
fn check(value: u64, encoded: &[u8]) {
// Small buffer.
let mut buf = Vec::with_capacity(1);
encode_varint(value, &mut buf);
assert_eq!(buf, encoded);

// Large buffer.
let mut buf = Vec::with_capacity(100);
encode_varint(value, &mut buf);
assert_eq!(buf, encoded);

assert_eq!(encoded_len_varint(value), encoded.len());

// See: https://github.com/tokio-rs/prost/pull/1008 for copying reasoning.
let mut encoded_copy = encoded;
let roundtrip_value = decode_varint(&mut encoded_copy).expect("decoding failed");
assert_eq!(value, roundtrip_value);

let mut encoded_copy = encoded;
let roundtrip_value =
decode_varint_slow(&mut encoded_copy).expect("slow decoding failed");
assert_eq!(value, roundtrip_value);
}

check(2u64.pow(0) - 1, &[0x00]);
check(2u64.pow(0), &[0x01]);

check(2u64.pow(7) - 1, &[0x7F]);
check(2u64.pow(7), &[0x80, 0x01]);
check(300, &[0xAC, 0x02]);

check(2u64.pow(14) - 1, &[0xFF, 0x7F]);
check(2u64.pow(14), &[0x80, 0x80, 0x01]);

check(2u64.pow(21) - 1, &[0xFF, 0xFF, 0x7F]);
check(2u64.pow(21), &[0x80, 0x80, 0x80, 0x01]);

check(2u64.pow(28) - 1, &[0xFF, 0xFF, 0xFF, 0x7F]);
check(2u64.pow(28), &[0x80, 0x80, 0x80, 0x80, 0x01]);

check(2u64.pow(35) - 1, &[0xFF, 0xFF, 0xFF, 0xFF, 0x7F]);
check(2u64.pow(35), &[0x80, 0x80, 0x80, 0x80, 0x80, 0x01]);

check(2u64.pow(42) - 1, &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F]);
check(2u64.pow(42), &[0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01]);

check(
2u64.pow(49) - 1,
&[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F],
);
check(
2u64.pow(49),
&[0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01],
);

check(
2u64.pow(56) - 1,
&[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F],
);
check(
2u64.pow(56),
&[0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01],
);

check(
2u64.pow(63) - 1,
&[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F],
);
check(
2u64.pow(63),
&[0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01],
);

check(
u64::MAX,
&[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x01],
);
}

const U64_MAX_PLUS_ONE: &[u8] = &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x02];

#[test]
fn varint_overflow() {
let mut copy = U64_MAX_PLUS_ONE;
decode_varint(&mut copy).expect_err("decoding u64::MAX + 1 succeeded");
}

#[test]
fn variant_slow_overflow() {
let mut copy = U64_MAX_PLUS_ONE;
decode_varint_slow(&mut copy).expect_err("slow decoding u64::MAX + 1 succeeded");
}

/// This big bowl o' macro soup generates an encoding property test for each combination of map
/// type, scalar map key, and value type.
/// TODO: these tests take a long time to compile, can this be improved?
Expand Down
Loading
Loading