Skip to content

Commit

Permalink
Allow specifying compression encoding order
Browse files Browse the repository at this point in the history
  • Loading branch information
CapCap committed Jun 13, 2024
1 parent c783652 commit d6288d7
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 29 deletions.
17 changes: 16 additions & 1 deletion tests/compression/src/compressing_request.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use super::*;
use http_body::Body;
use tonic::codec::CompressionEncoding;
use tonic::codec::{CompressionEncoding, EnabledCompressionEncodings};

util::parametrized_tests! {
client_enabled_server_enabled,
Expand Down Expand Up @@ -230,3 +230,18 @@ async fn client_mark_compressed_without_header_server_enabled(encoding: Compress
"protocol error: received message with compressed-flag but no grpc-encoding was specified"
);
}

#[test]
fn test_compression_priority() {
let mut encodings = EnabledCompressionEncodings::default();
encodings.enable(CompressionEncoding::Gzip);
encodings.enable(CompressionEncoding::Zstd);

assert_eq!(encodings.priority(CompressionEncoding::Gzip), Some(1));
assert_eq!(encodings.priority(CompressionEncoding::Zstd), Some(0));

encodings.enable(CompressionEncoding::Gzip);

assert_eq!(encodings.priority(CompressionEncoding::Gzip), Some(0));
assert_eq!(encodings.priority(CompressionEncoding::Zstd), Some(1));
}
93 changes: 73 additions & 20 deletions tonic/src/codec/compression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,28 @@ use crate::{metadata::MetadataValue, Status};
use bytes::{Buf, BytesMut};
#[cfg(feature = "gzip")]
use flate2::read::{GzDecoder, GzEncoder};
use http::HeaderValue;
use std::fmt;
#[cfg(feature = "zstd")]
use zstd::stream::read::{Decoder, Encoder};

pub(crate) const ENCODING_HEADER: &str = "grpc-encoding";
pub(crate) const ACCEPT_ENCODING_HEADER: &str = "grpc-accept-encoding";

/// This should always match the cardinality of `CompressionEncoding`
pub(crate) const COMPRESSION_ENCODINGS_LENGTH: usize = 2;

/// Struct used to configure which encodings are enabled on a server or channel.
/// Supports setting the priority of each compression
#[derive(Debug, Default, Clone, Copy)]
pub struct EnabledCompressionEncodings {
// We keep the encodings as struct values so we get `const fn` compatibility
#[cfg(feature = "gzip")]
pub(crate) gzip: bool,
#[cfg(feature = "zstd")]
pub(crate) zstd: bool,
// And we have an array so we can keep the order of the encodings (i.e prefer `zstd` over `gzip`)
pub(crate) order: [Option<CompressionEncoding>; COMPRESSION_ENCODINGS_LENGTH],
}

impl EnabledCompressionEncodings {
Expand All @@ -30,22 +38,57 @@ impl EnabledCompressionEncodings {
}

/// Enable a [`CompressionEncoding`].
/// Every time an encoding is enabled, it is given the lowest priority (the start of the list)
/// In order to enable both `gzip` and `zstd`, and have zstd have higher priority, you would call:
/// `enable(CompressionEncoding::Zstd).enable(CompressionEncoding::Gzip)`
pub fn enable(&mut self, encoding: CompressionEncoding) {
match encoding {
#[cfg(feature = "gzip")]
CompressionEncoding::Gzip => self.gzip = true,
#[cfg(feature = "zstd")]
CompressionEncoding::Zstd => self.zstd = true,
}

// If it is already enabled, remove it, so we can add it to the front of the list
for encoding_opt in &mut self.order {
if let Some(enabled_encoding) = encoding_opt {
if *enabled_encoding == encoding {
*encoding_opt = None;
}
}
}

// Free up the element at the front of the list by shifting all elements to the right
for i in (1..self.order.len()).rev() {
self.order[i] = self.order[i - 1];
}

// Add the new encoding to the front of the list
self.order[0] = Some(encoding);
}

/// Get the priority of a given encoding
#[inline]
pub fn priority(&self, encoding: CompressionEncoding) -> Option<usize> {
self.order.iter().position(|&e| e == Some(encoding))
}

pub(crate) fn into_accept_encoding_header_value(self) -> Option<http::HeaderValue> {
match (self.is_gzip_enabled(), self.is_zstd_enabled()) {
(true, false) => Some(http::HeaderValue::from_static("gzip,identity")),
(false, true) => Some(http::HeaderValue::from_static("zstd,identity")),
(true, true) => Some(http::HeaderValue::from_static("gzip,zstd,identity")),
(false, false) => None,
if !self.is_gzip_enabled() && !self.is_zstd_enabled() {
return None;
}
// Here we are guaranteed to have at least one, so we can concat with comma
// They are sent in priority order (i.e `zstd,gzip,identity`)
HeaderValue::from_str(
&(self
.order
.iter()
.rev()
.filter_map(|encoding| encoding.map(|e| e.as_str()))
.collect::<Vec<_>>()
.join(",") + ",identity"),
)
.ok()
}

#[cfg(feature = "gzip")]
Expand Down Expand Up @@ -104,13 +147,30 @@ impl CompressionEncoding {
let header_value = map.get(ACCEPT_ENCODING_HEADER)?;
let header_value_str = header_value.to_str().ok()?;

split_by_comma(header_value_str).find_map(|value| match value {
#[cfg(feature = "gzip")]
"gzip" => Some(CompressionEncoding::Gzip),
#[cfg(feature = "zstd")]
"zstd" => Some(CompressionEncoding::Zstd),
_ => None,
})
// Get the highest priority supported encoding
split_by_comma(header_value_str)
// We allow for +1 to account for the identity encoding
.take(COMPRESSION_ENCODINGS_LENGTH + 1)
.filter_map(|value| {
let encoding = match value {
#[cfg(feature = "gzip")]
"gzip" => Some(CompressionEncoding::Gzip),
#[cfg(feature = "zstd")]
"zstd" => Some(CompressionEncoding::Zstd),
_ => None,
};
if let Some(encoding) = encoding {
if let Some(priority) = enabled_encodings.priority(encoding) {
Some((encoding, priority))
} else {
None
}
} else {
None
}
})
.max_by_key(|(_, priority)| *priority)
.map(|(encoding, _)| encoding)
}

/// Get the value of `grpc-encoding` header. Returns an error if the encoding isn't supported.
Expand Down Expand Up @@ -150,6 +210,7 @@ impl CompressionEncoding {
.into_accept_encoding_header_value()
.map(MetadataValue::unchecked_from_header_value)
.unwrap_or_else(|| MetadataValue::from_static("identity"));

status
.metadata_mut()
.insert(ACCEPT_ENCODING_HEADER, header_value);
Expand All @@ -175,14 +236,6 @@ impl CompressionEncoding {
http::HeaderValue::from_static(self.as_str())
}

pub(crate) fn encodings() -> &'static [Self] {
&[
#[cfg(feature = "gzip")]
CompressionEncoding::Gzip,
#[cfg(feature = "zstd")]
CompressionEncoding::Zstd,
]
}
}

impl fmt::Display for CompressionEncoding {
Expand Down
10 changes: 2 additions & 8 deletions tonic/src/server/grpc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -187,14 +187,8 @@ where
) -> Self {
let mut this = self;

for &encoding in CompressionEncoding::encodings() {
if accept_encodings.is_enabled(encoding) {
this = this.accept_compressed(encoding);
}
if send_encodings.is_enabled(encoding) {
this = this.send_compressed(encoding);
}
}
this.accept_compression_encodings = accept_encodings;
this.send_compression_encodings = send_encodings;

this
}
Expand Down

0 comments on commit d6288d7

Please sign in to comment.