Skip to content

Commit

Permalink
Merge pull request #1 from aptos-labs/max_specify_encoding_orderings
Browse files Browse the repository at this point in the history
Allow specifying compression encoding order
  • Loading branch information
CapCap committed Jun 15, 2024
2 parents c783652 + e7f9d3f commit 5cc673b
Show file tree
Hide file tree
Showing 3 changed files with 170 additions and 48 deletions.
91 changes: 90 additions & 1 deletion tests/compression/src/compressing_request.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use super::*;
use http_body::Body;
use tonic::codec::CompressionEncoding;
use tonic::codec::{CompressionEncoding, EnabledCompressionEncodings};

util::parametrized_tests! {
client_enabled_server_enabled,
Expand Down Expand Up @@ -230,3 +230,92 @@ async fn client_mark_compressed_without_header_server_enabled(encoding: Compress
"protocol error: received message with compressed-flag but no grpc-encoding was specified"
);
}

#[test]
fn test_compression_priority() {
let mut encodings = EnabledCompressionEncodings::default();
encodings.enable(CompressionEncoding::Gzip);
encodings.enable(CompressionEncoding::Zstd);

assert_eq!(encodings.priority(CompressionEncoding::Gzip), Some(1));
assert_eq!(encodings.priority(CompressionEncoding::Zstd), Some(0));

encodings.enable(CompressionEncoding::Gzip);

assert_eq!(encodings.priority(CompressionEncoding::Gzip), Some(0));
assert_eq!(encodings.priority(CompressionEncoding::Zstd), Some(1));
}

#[allow(dead_code)]
fn build_accept_encoding_header(encodings: &str) -> http::HeaderMap {
let mut headers = http::HeaderMap::new();
headers.insert("grpc-accept-encoding", encodings.parse().unwrap());
headers
}

#[allow(dead_code)]
fn build_enabled_compression_settings(
encodings: &[CompressionEncoding],
) -> EnabledCompressionEncodings {
let mut settings = EnabledCompressionEncodings::default();
for encoding in encodings {
settings.enable(*encoding);
}
settings
}

#[allow(dead_code)]
fn build_and_run_accept_encoding_header_test(
encodings: &str,
enabled_encodings: &[CompressionEncoding],
expected: Option<CompressionEncoding>,
) {
let header = build_accept_encoding_header(encodings);
let compression = CompressionEncoding::from_accept_encoding_header(
&header,
build_enabled_compression_settings(enabled_encodings),
);
assert_eq!(compression, expected);
}

#[test]
fn test_from_accept_encoding_header() {
build_and_run_accept_encoding_header_test(
"gzip",
&[CompressionEncoding::Gzip],
Some(CompressionEncoding::Gzip),
);

build_and_run_accept_encoding_header_test(
"zstd",
&[CompressionEncoding::Zstd],
Some(CompressionEncoding::Zstd),
);

// Client provides ordering preferring gzip, but we prefer zstd
build_and_run_accept_encoding_header_test(
"gzip,zstd",
&[CompressionEncoding::Zstd, CompressionEncoding::Gzip],
Some(CompressionEncoding::Zstd),
);

// Client provides ordering preferring zstd, but we prefer gzip
build_and_run_accept_encoding_header_test(
"zstd,gzip",
&[CompressionEncoding::Gzip, CompressionEncoding::Zstd],
Some(CompressionEncoding::Gzip),
);

// Client provides ordering preferring gzip, and we also prefer gzip
build_and_run_accept_encoding_header_test(
"gzip,zstd",
&[CompressionEncoding::Gzip, CompressionEncoding::Zstd],
Some(CompressionEncoding::Gzip),
);

// Client provides two, but we don't support any
build_and_run_accept_encoding_header_test("gzip,zstd", &[], None);

// Client provides gzip, but we only support zstd
build_and_run_accept_encoding_header_test("gzip", &[CompressionEncoding::Zstd], None);
}
117 changes: 78 additions & 39 deletions tonic/src/codec/compression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,69 +2,101 @@ use crate::{metadata::MetadataValue, Status};
use bytes::{Buf, BytesMut};
#[cfg(feature = "gzip")]
use flate2::read::{GzDecoder, GzEncoder};
use http::HeaderValue;
use std::fmt;
#[cfg(feature = "zstd")]
use zstd::stream::read::{Decoder, Encoder};

pub(crate) const ENCODING_HEADER: &str = "grpc-encoding";
pub(crate) const ACCEPT_ENCODING_HEADER: &str = "grpc-accept-encoding";

/// This should always match the cardinality of the `CompressionEncoding` enum
pub(crate) const COMPRESSION_ENCODINGS_LENGTH: usize = 2;

/// Struct used to configure which encodings are enabled on a server or channel.
/// Supports setting the priority of each compression
#[derive(Debug, Default, Clone, Copy)]
pub struct EnabledCompressionEncodings {
#[cfg(feature = "gzip")]
pub(crate) gzip: bool,
#[cfg(feature = "zstd")]
pub(crate) zstd: bool,
// And we have an array so we can keep the order of the encodings (i.e prefer `zstd` over `gzip`)
pub(crate) order: [Option<CompressionEncoding>; COMPRESSION_ENCODINGS_LENGTH],
}

impl EnabledCompressionEncodings {
/// Check if a [`CompressionEncoding`] is enabled.
pub fn is_enabled(&self, encoding: CompressionEncoding) -> bool {
match encoding {
#[cfg(feature = "gzip")]
CompressionEncoding::Gzip => self.gzip,
CompressionEncoding::Gzip => self.is_gzip_enabled(),
#[cfg(feature = "zstd")]
CompressionEncoding::Zstd => self.zstd,
CompressionEncoding::Zstd => self.is_zstd_enabled(),
}
}

/// Enable a [`CompressionEncoding`].
/// Every time an encoding is enabled, it is given the lowest priority (the start of the list)
/// In order to enable both `gzip` and `zstd`, and have zstd have higher priority, you would call:
/// `enable(CompressionEncoding::Zstd).enable(CompressionEncoding::Gzip)`
pub fn enable(&mut self, encoding: CompressionEncoding) {
match encoding {
#[cfg(feature = "gzip")]
CompressionEncoding::Gzip => self.gzip = true,
#[cfg(feature = "zstd")]
CompressionEncoding::Zstd => self.zstd = true,
// If it is already enabled, move everything after it left to overwrite it
if let Some(index) = self.order.iter().position(|&e| e == Some(encoding)) {
for i in (0..index).rev() {
self.order[i + 1] = self.order[i];
}
}

// Free up the element at the front of the list by shifting all elements to the right
self.order.rotate_right(1);

// Add the new encoding to the front of the list
self.order[0] = Some(encoding);
}

/// Get the priority of a given encoding
#[inline]
pub fn priority(&self, encoding: CompressionEncoding) -> Option<usize> {
self.order.iter().position(|&e| e == Some(encoding))
}

pub(crate) fn into_accept_encoding_header_value(self) -> Option<http::HeaderValue> {
match (self.is_gzip_enabled(), self.is_zstd_enabled()) {
(true, false) => Some(http::HeaderValue::from_static("gzip,identity")),
(false, true) => Some(http::HeaderValue::from_static("zstd,identity")),
(true, true) => Some(http::HeaderValue::from_static("gzip,zstd,identity")),
(false, false) => None,
if !self.is_gzip_enabled() && !self.is_zstd_enabled() {
return None;
}
// Here we are guaranteed to have at least one, so we can concat with comma
// They are sent in priority order (i.e `zstd,gzip,identity`)
HeaderValue::from_str(
&(self
.order
.iter()
.rev()
.filter_map(|encoding| encoding.map(|e| e.as_str()))
.collect::<Vec<_>>()
.join(",")
+ ",identity"),
)
.ok()
}

#[cfg(feature = "gzip")]
const fn is_gzip_enabled(&self) -> bool {
self.gzip
fn is_gzip_enabled(&self) -> bool {
self.order
.iter()
.any(|&e| e == Some(CompressionEncoding::Gzip))
}

#[cfg(not(feature = "gzip"))]
const fn is_gzip_enabled(&self) -> bool {
fn is_gzip_enabled(&self) -> bool {
false
}

#[cfg(feature = "zstd")]
const fn is_zstd_enabled(&self) -> bool {
self.zstd
fn is_zstd_enabled(&self) -> bool {
self.order
.iter()
.any(|&e| e == Some(CompressionEncoding::Zstd))
}

#[cfg(not(feature = "zstd"))]
const fn is_zstd_enabled(&self) -> bool {
fn is_zstd_enabled(&self) -> bool {
false
}
}
Expand Down Expand Up @@ -93,7 +125,7 @@ pub enum CompressionEncoding {

impl CompressionEncoding {
/// Based on the `grpc-accept-encoding` header, pick an encoding to use.
pub(crate) fn from_accept_encoding_header(
pub fn from_accept_encoding_header(
map: &http::HeaderMap,
enabled_encodings: EnabledCompressionEncodings,
) -> Option<Self> {
Expand All @@ -104,13 +136,28 @@ impl CompressionEncoding {
let header_value = map.get(ACCEPT_ENCODING_HEADER)?;
let header_value_str = header_value.to_str().ok()?;

split_by_comma(header_value_str).find_map(|value| match value {
#[cfg(feature = "gzip")]
"gzip" => Some(CompressionEncoding::Gzip),
#[cfg(feature = "zstd")]
"zstd" => Some(CompressionEncoding::Zstd),
_ => None,
})
// Get the highest priority supported encoding
split_by_comma(header_value_str)
// We allow for +1 to account for the identity encoding
.take(COMPRESSION_ENCODINGS_LENGTH + 1)
.filter_map(|value| {
let encoding = match value {
#[cfg(feature = "gzip")]
"gzip" => Some(CompressionEncoding::Gzip),
#[cfg(feature = "zstd")]
"zstd" => Some(CompressionEncoding::Zstd),
_ => None,
};
if let Some(encoding) = encoding {
enabled_encodings
.priority(encoding)
.map(|priority| (encoding, priority))
} else {
None
}
})
.max_by_key(|(_, priority)| *priority)
.map(|(encoding, _)| encoding)
}

/// Get the value of `grpc-encoding` header. Returns an error if the encoding isn't supported.
Expand Down Expand Up @@ -150,6 +197,7 @@ impl CompressionEncoding {
.into_accept_encoding_header_value()
.map(MetadataValue::unchecked_from_header_value)
.unwrap_or_else(|| MetadataValue::from_static("identity"));

status
.metadata_mut()
.insert(ACCEPT_ENCODING_HEADER, header_value);
Expand All @@ -174,15 +222,6 @@ impl CompressionEncoding {
pub(crate) fn into_header_value(self) -> http::HeaderValue {
http::HeaderValue::from_static(self.as_str())
}

pub(crate) fn encodings() -> &'static [Self] {
&[
#[cfg(feature = "gzip")]
CompressionEncoding::Gzip,
#[cfg(feature = "zstd")]
CompressionEncoding::Zstd,
]
}
}

impl fmt::Display for CompressionEncoding {
Expand Down
10 changes: 2 additions & 8 deletions tonic/src/server/grpc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -187,14 +187,8 @@ where
) -> Self {
let mut this = self;

for &encoding in CompressionEncoding::encodings() {
if accept_encodings.is_enabled(encoding) {
this = this.accept_compressed(encoding);
}
if send_encodings.is_enabled(encoding) {
this = this.send_compressed(encoding);
}
}
this.accept_compression_encodings = accept_encodings;
this.send_compression_encodings = send_encodings;

this
}
Expand Down

0 comments on commit 5cc673b

Please sign in to comment.