diff --git a/Cargo.toml b/Cargo.toml index 171ecf2a..07c5537a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -42,7 +42,7 @@ memchr = "2" pin-project-lite = "0.2" tokio = { version = "1.24.2", optional = true, default-features = false } xz2 = { version = "0.1.6", optional = true } -zstd-safe = { version = "7", optional = true, default-features = false } +zstd-safe = { version = "7.1", optional = true, default-features = false } deflate64 = { version = "0.1.5", optional = true } [dev-dependencies] diff --git a/src/codec/zstd/decoder.rs b/src/codec/zstd/decoder.rs index c2696799..d1a92f78 100644 --- a/src/codec/zstd/decoder.rs +++ b/src/codec/zstd/decoder.rs @@ -16,6 +16,16 @@ impl ZstdDecoder { } } + pub(crate) fn new_with_params(params: &[crate::zstd::DParameter]) -> Self { + let mut decoder = Decoder::new().unwrap(); + for param in params { + decoder.set_parameter(param.as_zstd()).unwrap(); + } + Self { + decoder: Unshared::new(decoder), + } + } + pub(crate) fn new_with_dict(dictionary: &[u8]) -> io::Result { let mut decoder = Decoder::with_dictionary(dictionary)?; Ok(Self { diff --git a/src/macros.rs b/src/macros.rs index 70844ad1..75909638 100644 --- a/src/macros.rs +++ b/src/macros.rs @@ -198,6 +198,17 @@ macro_rules! algos { } } { @dec + /// Creates a new decoder, using the specified parameters, which will read compressed + /// data from the given stream and emit a decompressed stream. + pub fn with_params(inner: $inner, params: &[crate::zstd::DParameter]) -> Self { + Self { + inner: crate::$($mod::)+generic::Decoder::new( + inner, + crate::codec::ZstdDecoder::new_with_params(params), + ), + } + } + /// Creates a new decoder, using the specified compression level and pre-trained /// dictionary, which will read compressed data from the given stream and emit an /// uncompressed stream. diff --git a/src/zstd.rs b/src/zstd.rs index 4c97bd5d..be40b94a 100644 --- a/src/zstd.rs +++ b/src/zstd.rs @@ -1,6 +1,7 @@ //! This module contains zstd-specific types for async-compression. use libzstd::stream::raw::CParameter::*; +use libzstd::stream::raw::DParameter::*; /// A compression parameter for zstd. This is a stable wrapper around zstd's own `CParameter` /// type, to abstract over different versions of the zstd library. @@ -110,3 +111,22 @@ impl CParameter { self.0 } } + +/// A decompression parameter for zstd. This is a stable wrapper around zstd's own `DParameter` +/// type, to abstract over different versions of the zstd library. +/// +/// See the [zstd documentation](https://facebook.github.io/zstd/zstd_manual.html) for more +/// information on these parameters. +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub struct DParameter(libzstd::stream::raw::DParameter); + +impl DParameter { + /// Maximum window size in bytes (as a power of two) + pub fn window_log_max(value: u32) -> Self { + Self(WindowLogMax(value)) + } + + pub(crate) fn as_zstd(&self) -> libzstd::stream::raw::DParameter { + self.0 + } +} diff --git a/tests/artifacts/long-window-size-lib.rs.zst b/tests/artifacts/long-window-size-lib.rs.zst new file mode 100644 index 00000000..5bf51d5e Binary files /dev/null and b/tests/artifacts/long-window-size-lib.rs.zst differ diff --git a/tests/zstd.rs b/tests/zstd.rs index 8401aff9..b1e7a070 100644 --- a/tests/zstd.rs +++ b/tests/zstd.rs @@ -2,3 +2,47 @@ mod utils; test_cases!(zstd); + +use async_compression::zstd::DParameter; +use tokio::io::AsyncWriteExt as _; + +#[tokio::test] +async fn zstd_decode_large_window_size_default() { + let compressed = include_bytes!("./artifacts/long-window-size-lib.rs.zst"); + + // Default decoder should throw with an error, window size maximum is too low. + let mut decoder = async_compression::tokio::write::ZstdDecoder::new(Vec::new()); + decoder.write_all(compressed).await.unwrap_err(); +} + +#[tokio::test] +async fn zstd_decode_large_window_size_explicit_small_window_size() { + let compressed = include_bytes!("./artifacts/long-window-size-lib.rs.zst"); + + // Short window decoder should throw with an error, window size maximum is too low. + let mut decoder = async_compression::tokio::write::ZstdDecoder::with_params( + Vec::new(), + &[DParameter::window_log_max(16)], + ); + decoder.write_all(compressed).await.unwrap_err(); +} + +#[tokio::test] +async fn zstd_decode_large_window_size_explicit_large_window_size() { + let compressed = include_bytes!("./artifacts/long-window-size-lib.rs.zst"); + let source = include_bytes!("./artifacts/lib.rs"); + + // Long window decoder should succeed as the window size is large enough to decompress the given input. + let mut long_window_size_decoder = async_compression::tokio::write::ZstdDecoder::with_params( + Vec::new(), + &[DParameter::window_log_max(31)], + ); + // Long window size decoder should successfully decode the given input data. + long_window_size_decoder + .write_all(compressed) + .await + .unwrap(); + long_window_size_decoder.shutdown().await.unwrap(); + + assert_eq!(long_window_size_decoder.into_inner(), source); +}