From c0c6559a2ec48b3149872a26d02af53ec93e1bcc Mon Sep 17 00:00:00 2001
From: Aidan Beggs
Date: Sat, 13 Jul 2024 18:35:50 -0700
Subject: [PATCH] Enable customizing Zstd decoding parameters.

---
 Cargo.toml                                  |   6 ++-
 src/codec/zstd/decoder.rs                   |  18 +++++++++
 src/macros.rs                               |  11 +++++
 src/zstd.rs                                 |  20 +++++++++
 tests/artifacts/long-window-size-lib.rs.zst | Bin 0 -> 3369 bytes
 tests/zstd-window-size.rs                   |  45 ++++++++++++++++++++
 6 files changed, 99 insertions(+), 1 deletion(-)
 create mode 100644 tests/artifacts/long-window-size-lib.rs.zst
 create mode 100644 tests/zstd-window-size.rs

diff --git a/Cargo.toml b/Cargo.toml
index 171ecf2a..c1456a0c 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -37,7 +37,7 @@ bzip2 = { version = "0.4.4", optional = true }
 flate2 = { version = "1.0.13", optional = true }
 futures-core = { version = "0.3", default-features = false }
 futures-io = { version = "0.3", default-features = false, features = ["std"], optional = true }
-libzstd = { package = "zstd", version = "0.13", optional = true, default-features = false }
+libzstd = { package = "zstd", version = "0.13.1", optional = true, default-features = false }
 memchr = "2"
 pin-project-lite = "0.2"
 tokio = { version = "1.24.2", optional = true, default-features = false }
@@ -92,6 +92,10 @@ required-features = ["zstd"]
 name = "zstd-dict"
 required-features = ["zstd", "tokio"]
 
+[[test]]
+name = "zstd-window-size"
+required-features = ["zstd", "tokio"]
+
 [[example]]
 name = "zlib_tokio_write"
 required-features = ["zlib", "tokio"]
diff --git a/src/codec/zstd/decoder.rs b/src/codec/zstd/decoder.rs
index c2696799..d1a92f78 100644
--- a/src/codec/zstd/decoder.rs
+++ b/src/codec/zstd/decoder.rs
@@ -16,6 +16,24 @@ impl ZstdDecoder {
         }
     }
 
+    /// Creates a decoder with every entry of `params` applied, in slice order.
+    ///
+    /// # Panics
+    ///
+    /// Panics if the underlying zstd decoder cannot be created, or if setting any of the
+    /// given parameters returns an error.
+    pub(crate) fn new_with_params(params: &[crate::zstd::DParameter]) -> Self {
+        let mut decoder = Decoder::new().expect("failed to create zstd decoder");
+        for param in params {
+            decoder
+                .set_parameter(param.as_zstd())
+                .expect("zstd rejected a decoder parameter");
+        }
+        Self {
+            decoder: Unshared::new(decoder),
+        }
+    }
+
     pub(crate) fn
new_with_dict(dictionary: &[u8]) -> io::Result<Self> {
         let mut decoder = Decoder::with_dictionary(dictionary)?;
         Ok(Self {
diff --git a/src/macros.rs b/src/macros.rs
index 70844ad1..75909638 100644
--- a/src/macros.rs
+++ b/src/macros.rs
@@ -198,6 +198,21 @@ macro_rules! algos {
             }
         }
         { @dec
+            /// Creates a new decoder, using the specified parameters, which will read compressed
+            /// data from the given stream and emit a decompressed stream.
+            ///
+            /// # Panics
+            ///
+            /// Panics if zstd rejects any entry of `params` when it is applied to the decoder.
+            pub fn with_params(inner: $inner, params: &[crate::zstd::DParameter]) -> Self {
+                Self {
+                    inner: crate::$($mod::)+generic::Decoder::new(
+                        inner,
+                        crate::codec::ZstdDecoder::new_with_params(params),
+                    ),
+                }
+            }
+
             /// Creates a new decoder, using the specified compression level and pre-trained
             /// dictionary, which will read compressed data from the given stream and emit an
             /// uncompressed stream.
diff --git a/src/zstd.rs b/src/zstd.rs
index 4c97bd5d..be40b94a 100644
--- a/src/zstd.rs
+++ b/src/zstd.rs
@@ -1,6 +1,7 @@
 //! This module contains zstd-specific types for async-compression.
 
 use libzstd::stream::raw::CParameter::*;
+use libzstd::stream::raw::DParameter::*;
 
 /// A compression parameter for zstd. This is a stable wrapper around zstd's own `CParameter`
 /// type, to abstract over different versions of the zstd library.
@@ -110,3 +111,22 @@ impl CParameter {
         self.0
     }
 }
+
+/// A decompression parameter for zstd. This is a stable wrapper around zstd's own `DParameter`
+/// type, to abstract over different versions of the zstd library.
+///
+/// See the [zstd documentation](https://facebook.github.io/zstd/zstd_manual.html) for more
+/// information on these parameters.
+#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub struct DParameter(libzstd::stream::raw::DParameter); + +impl DParameter { + /// Maximum window size in bytes (as a power of two) + pub fn window_log_max(value: u32) -> Self { + Self(WindowLogMax(value)) + } + + pub(crate) fn as_zstd(&self) -> libzstd::stream::raw::DParameter { + self.0 + } +} diff --git a/tests/artifacts/long-window-size-lib.rs.zst b/tests/artifacts/long-window-size-lib.rs.zst new file mode 100644 index 0000000000000000000000000000000000000000..5bf51d5e299ffa8b73acd6a4bb03fd8e33460f06 GIT binary patch literal 3369 zcmV+^4c77~wJ-eysO4w?nwoSPFrW^Zx&ar6BVt8%pY6s$cao6W_^8t!CZxJ?)$@-K zyi`9(Q_o;cIgFR%#ly?73Wfoy0i*#88IhU!P&5`A9_u1~e=(gw(>y*#$vkH_Xk&r@ z`b-j2d_P7E|J89U=5WW+SZIKfdCkw()0z5wonJ}s5vFs@Kspx_7a6(MB*WLm)>XQV ztaGrsy|#QIY?d2QDqpErYdu{ZYeapiPo}M9DU{?!=^RV)S>g0};^4Lyo!=o#xm3nQ zMK&}@ov&GwdX{8c$ZF+A*`Llp4i)tbdvf}3t94ez>0E@LJ;(}~2|1@wI#h7jp2!ML z+sD#2D>j^R7gE1)sn6*#^%J;}&By6nw6Z_{sYI%fwA;wv*;NgtY`f3*buRV!9xKRo zk-p95VpE>6C&|2vDs9V2lxp5{wL6vYYjU0!)bsWbZ+d`%M#RmjLwyrXY)?;H~fMDC2 zL3p;V9vkcNX~?wFXm~sz`d`HaW;o8njl5=#)82DEJ~ozx;o<7oSI01Us!S5oXm~s} zmZh!77Zpi6r^zhNsq+1%7;SqdR(_?q89#qLPN$g?Wl9?p- z^hbSIH-&NNtXq3;@9mv}_I|bZD`H1M#Eyc9-4(IBASfB=>bN>4nBeNTI<&ehw7M*` zx|~**gP>#}i?0MMz7nwbnk>F12AbGcCvAhS>%rpd!QzX7b|QjeP}dPqHXyjI*`t5} zLF~R^w2_W<8~fvZT*3WqzKok1LHW#21h@{ zWQAM;*a$z<#{nq0DBo{mc`z*d3?*QeN7;ay?gLVUz~00kI8L+UWQANDN7;ZN4Tg+J z){S9kc8F3|OW;b{N=;WKeSZ?=5uhavuv8?0s;(B9#;?p;=|QpK@kCk#nZq=1qKGi0}~`%NIV%95Il%L zF#)kCaNlFdqU`h(KF<}g%ZOa^*F6Zp?fu$UC+(|~GBcS)fizHHKbAtti;L3#_?f_m zD3X1V^w=nZSd@2_97oDu=QtaW3zR${VsYRgnFb1Eao~VykP(mrA{<9@i;j4hAq+{jBDjCVGD^LI9e9o%)@wni@EDr*(5-aY|DxE=pwzVpCEGO=nv9l_60P>vE zV-BWKGM1RF)tmFH3qfr0GB zo^Yp$qx5C5HM9Z^Ad*|p%mlzSpOi}!EsUrah`NHPD@aZ6&$khtB#0VtxE!!0PxO|B zL1En*)^*)NT^Qwe?31!>gBqvQ3ieA~4z1idWeF{cKke!BP|p&_f%0*-W&G4i&Go1( z*I5FW`c9@nU-?3g6c(m4emWr=U{h0UNWCxR(MpK&=p-oH(}{!Wi!hyINjl(!;9%`u zl$ZwVDOEg)TCNMJZxzpYqVe;LTh;H8rM{&kIZ=@lsNCGCR7K}tOR>}ldohNDI%(Mm 
zKOH6iE%jVcygop=D)0(DvbT}K!U%gY2860v?8W{>*S^rXW(ix8Fkz9LzZ@)9&Elz= ztD41AH78+WdTrslqvx4kyT{jcA$v{JK%T43_LQUBWq(eOmpy&w<}kf>6c!oJ?*#Ri zC2n0M_0wCI|JGqL?JIl1=cwQKoi}B%hQ~41&L? z!fOt*3^6q$A|gqWqyv%&5Cp}-XcDv-9}++ii7bnPB!rAY1{p&}L_|bHL_{=AvK;Lv z0qf<~`4`p>_JJqdN0lI&#M>j23BIvTA47lg4d*lS9hFs>h>m{#D5G3s$z%*9N;QX1 z9sg;_&kM~FC;fZVanb}PX>J`aq%e%-jiPF+6$Mil-?mN_;Wf}W-O=AFx;3rhEwSsN zOzA8_2v&95V8r& zdnP*>ip`lFr^u0|U~1QcHle*z1IHfdZmh5?ys-^6_poFG%F|3DgR8fURA+N=wPlBp z-4sF(AW?BH^HURrL9#|4ro|YzrnK$O9~;?DXrSUy3L~18sY}WcdkG9QS_SBEenMlL zxy#^^;VbP0g=cyKiEj^tw)0i!$WJa2@m3XC^2DfOLbfTWNML-U!xgkDp* z#D>e{x0Pj48apu^2qklfTMRggzqAf-SUzy6!@ZTbm`JfX@~UeS0TjagWFRWK2Nc9F zNN}ZXM&q11q*R+@k|p0X?KTZ>DE}YjQ|A_5;aV;o$E_WPmthrShj7LhR$ZLXxf|3U z)->sw76QVR8cIy_Z6y8zWZ~9=M)ELKknTmb?qa>?T{DvtcGnx=@FRe%#^(v>xa>0U(bL6`O$tiCo0yd??SxT)! z_C5LJ#SpuzP%yhHc~*QXcD*p|A+YuA^gmIiEazY-ki!xQ8rg!PCt@7TgB}fOSgH@} z4hm{y((#{Q1Dt*3SB;pUd>g-M|B;^tU*X#^5fQwbPaYuhz?)FPFOPZab8&f{=%Lxy z>^_840~IJI1#2%2!g&be?ihh0pl8>JrbN&>mmzNva`m&(Ao^nm$g+I)7ZJr@Ehj`H+@D8#%uR}mIaEXDYOa<}L5J>~9Zf(JWsq+CkRwwX4y3mEeFsrj0 zTVebq*5TF+$w_b}M!5>iakLHzP|bIFmf)@S@95Kj*VL7Zug!aq=Ms5g)Vzx%u38fX zD)rU{ekNW&yK7*~GFg$#%InikpN-neBAt0csFtYmDzjg+6_NAvW7&*00m>4~sM=HS zQQ)vmKUnaK?!`7K8y6y>i`&ja^sT>+;(+ilW=ia*DmAu&2nwbU1&H8fe{7)FuwQAj zs(baxT{pr*9i>bXM~S&ute%P}b~uP7M$bx`(>WG*@sm%62pT@u1-2M+5X+BW$X`4+ zygx=e$zEd@y)X+7%SaEHR5=+^Yd`lrj?|NS?NVauuZjIlDC5X5_8KBi^KGHTi(aaO zzk-9#7L*fcx>fQ}4Bb+Xt!S{dc;)b}0Giu?wX7O%pc%n)YgF)kw;*~1*ZLi6HlN_a zTGPyhvT-&IB{ioPJUVMDwH*y4V5vDO-8V9Fc<==PM)vZEhhy?tKj%_euj}z~bPA>c zr=?q>f3?3ByrSFW3my17X~Xq_bKB8@1S~X1CHqE3ga=;&!e*};dALl>^D~NCKOxw> z>Db5@HxTike(DLdc~%#Q%g7^8vJiI%%Lsh(TqW04qbA@F*g-uDR6ar+ILE^kD6El( zE+h>3LfK({)}`{bMSlKrXYQ2`n$*a8EF2UV<5 z8rMq04h&EjnsHAg5j4;({MD@aO{}ai*8J_EWCFw>qUzk_Ig@|5o!Z30;RPQ(Pw^hu zQ1Rxp7CkNc_Ff&X%zSEBl z39SXn(4?5!2c3{tq!}m#r=lN*%?{7T{w52JQAOMsOwNyMctDHARNJmD60tirIS%vl zg2)`j3iYni7N|JF!$NWBX;3Tqr|PO4{yG&3Y5y?;(Q}OaX|RsLW9n`NGBb@k9tL%e literal 0 HcmV?d00001 diff --git a/tests/zstd-window-size.rs 
b/tests/zstd-window-size.rs
new file mode 100644
index 00000000..7b5e6dec
--- /dev/null
+++ b/tests/zstd-window-size.rs
@@ -0,0 +1,45 @@
+#![cfg(not(windows))]
+
+use async_compression::zstd::DParameter;
+use tokio::io::AsyncWriteExt as _;
+
+#[tokio::test]
+async fn zstd_decode_large_window_size_default() {
+    let compressed = include_bytes!("./artifacts/long-window-size-lib.rs.zst");
+
+    // The default decoder must fail here: its maximum window size is too small for this input.
+    let mut decoder = async_compression::tokio::write::ZstdDecoder::new(Vec::new());
+    decoder.write_all(compressed).await.unwrap_err();
+}
+
+#[tokio::test]
+async fn zstd_decode_large_window_size_explicit_small_window_size() {
+    let compressed = include_bytes!("./artifacts/long-window-size-lib.rs.zst");
+
+    // A decoder capped at a window log of 16 must fail too: that limit is still too small.
+    let mut decoder = async_compression::tokio::write::ZstdDecoder::with_params(
+        Vec::new(),
+        &[DParameter::window_log_max(16)],
+    );
+    decoder.write_all(compressed).await.unwrap_err();
+}
+
+#[tokio::test]
+async fn zstd_decode_large_window_size_explicit_large_window_size() {
+    let compressed = include_bytes!("./artifacts/long-window-size-lib.rs.zst");
+    let source = include_bytes!("./artifacts/lib.rs");
+
+    // With the window log limit raised to 31 the decoder is expected to accept this input.
+    let mut long_window_size_decoder = async_compression::tokio::write::ZstdDecoder::with_params(
+        Vec::new(),
+        &[DParameter::window_log_max(31)],
+    );
+    // Feeding the whole compressed artifact must succeed, as must the shutdown that follows.
+    long_window_size_decoder
+        .write_all(compressed)
+        .await
+        .unwrap();
+    long_window_size_decoder.shutdown().await.unwrap();
+
+    assert_eq!(long_window_size_decoder.into_inner(), source);
+}