Skip to content
This repository has been archived by the owner on Mar 27, 2023. It is now read-only.

Fix severe compression bug introduced by #97 #99

Merged
merged 8 commits into from
Sep 17, 2022
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 29 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 7 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,18 @@ edition = "2021"
clap = { version = "3", default-features = false, features = ["std", "cargo"] }
# Server
tokio = { version = "1", features = ["rt-multi-thread", "macros"] }
tokio-util = { version = "0.7", features = ["io"] }
hyper = { version = "0.14.20", features = ["http1", "server", "tcp", "stream"] }
headers = "0.3"
mime_guess = "2.0"
percent-encoding = "2.1"
# Compression
brotli = "3"
flate2 = "1"
async-compression = { version = "0.3.7", features = [
"brotli",
"deflate",
"gzip",
"tokio",
] }
# Rendering
tera = "1"
serde = { version = "1.0", features = [
Expand Down
74 changes: 42 additions & 32 deletions src/http/content_encoding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,17 @@
// except according to those terms.

use std::cmp::Ordering;
use std::io::{self, BufReader};
use std::io;

use flate2::read::{DeflateEncoder, GzEncoder};
use flate2::Compression;
use async_compression::{
tokio::bufread::{BrotliEncoder, DeflateEncoder, GzipEncoder},
Level,
};
use bytes::Bytes;
use futures::Stream;
use hyper::header::HeaderValue;
use hyper::Body;
use tokio_util::io::{ReaderStream, StreamReader};

pub const IDENTITY: &str = "identity";
pub const DEFLATE: &str = "deflate";
Expand Down Expand Up @@ -122,30 +128,28 @@ pub fn get_prior_encoding<'a>(accept_encoding: &'a HeaderValue) -> &'static str
.unwrap_or(IDENTITY)
}

/// Compress data.
/// Compress data stream.
///
/// # Parameters
///
/// * `data` - Data to be compressed.
/// * `encoding` - Only support `br`, `gzip`, `deflate` and `identity`.
pub fn compress(data: &[u8], encoding: &str) -> io::Result<Vec<u8>> {
use std::io::prelude::*;
let mut buf = Vec::new();
/// * `input` - [`futures::stream::Stream`](futures::stream::Stream) to be compressed e.g. [`hyper::body::Body`](hyper::body::Body).
henry40408 marked this conversation as resolved.
Show resolved Hide resolved
/// * `encoding` - Only support `br`, `deflate`, `gzip` and `identity`.
pub fn compress_stream(
henry40408 marked this conversation as resolved.
Show resolved Hide resolved
input: impl Stream<Item = io::Result<Bytes>> + Send + 'static,
encoding: &str,
) -> io::Result<hyper::Body> {
match encoding {
BR => {
BufReader::new(brotli::CompressorReader::new(data, 4096, 6, 20))
.read_to_end(&mut buf)?;
}
GZIP => {
BufReader::new(GzEncoder::new(data, Compression::default())).read_to_end(&mut buf)?;
}
DEFLATE => {
BufReader::new(DeflateEncoder::new(data, Compression::default()))
.read_to_end(&mut buf)?;
}
_ => return Err(io::Error::new(io::ErrorKind::Other, "Unsupported Encoding")),
};
Ok(buf)
BR => Ok(Body::wrap_stream(ReaderStream::new(
BrotliEncoder::with_quality(StreamReader::new(input), Level::Fastest),
))),
DEFLATE => Ok(Body::wrap_stream(ReaderStream::new(DeflateEncoder::new(
StreamReader::new(input),
)))),
GZIP => Ok(Body::wrap_stream(ReaderStream::new(GzipEncoder::new(
StreamReader::new(input),
)))),
_ => Err(io::Error::new(io::ErrorKind::Other, "Unsupported Encoding")),
}
}

pub fn should_compress(enc: &str) -> bool {
Expand Down Expand Up @@ -251,17 +255,23 @@ mod t_compress {

#[test]
fn failed() {
let error = compress(b"hello", "unrecognized").unwrap_err();
let s = futures::stream::iter(vec![Ok::<_, io::Error>(Bytes::from_static(b"hello"))]);
let error = compress_stream(s, "unrecognized").unwrap_err();
assert_eq!(error.kind(), io::ErrorKind::Other);
}

#[test]
fn compressed() {
let buf = compress(b"xxxxx", DEFLATE);
assert!(!buf.unwrap().is_empty());
let buf = compress(b"xxxxx", GZIP);
assert!(!buf.unwrap().is_empty());
let buf = compress(b"xxxxx", BR);
assert!(!buf.unwrap().is_empty());
#[tokio::test]
async fn compressed() {
let s = futures::stream::iter(vec![Ok::<_, io::Error>(Bytes::from_static(b"xxxxx"))]);
let body = compress_stream(s, BR).unwrap();
assert_eq!(hyper::body::to_bytes(body).await.unwrap().len(), 9);

let s = futures::stream::iter(vec![Ok::<_, io::Error>(Bytes::from_static(b"xxxxx"))]);
let body = compress_stream(s, DEFLATE).unwrap();
assert_eq!(hyper::body::to_bytes(body).await.unwrap().len(), 5);

let s = futures::stream::iter(vec![Ok::<_, io::Error>(Bytes::from_static(b"xxxxx"))]);
let body = compress_stream(s, GZIP).unwrap();
assert_eq!(hyper::body::to_bytes(body).await.unwrap().len(), 23);
}
}
17 changes: 7 additions & 10 deletions src/server/serve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use std::sync::Arc;
use std::time::Duration;

use chrono::Local;
use futures::{StreamExt as _, TryStreamExt as _};
use futures::TryStreamExt as _;
use headers::{
AcceptRanges, AccessControlAllowHeaders, AccessControlAllowOrigin, CacheControl, ContentLength,
ContentType, ETag, HeaderMapExt, LastModified, Range, Server,
Expand All @@ -32,7 +32,7 @@ use serde::Serialize;
use crate::cli::Args;
use crate::extensions::{MimeExt, PathExt, SystemTimeExt};
use crate::http::conditional_requests::{is_fresh, is_precondition_failed};
use crate::http::content_encoding::{compress, get_prior_encoding, should_compress};
use crate::http::content_encoding::{compress_stream, get_prior_encoding, should_compress};
use crate::http::range_requests::{is_range_fresh, is_satisfiable_range};

use crate::server::send::{send_dir, send_dir_as_zip, send_file, send_file_with_range};
Expand Down Expand Up @@ -411,13 +411,10 @@ impl InnerService {
if let Some(encoding) = req.headers().get(hyper::header::ACCEPT_ENCODING) {
let content_encoding = get_prior_encoding(encoding);
if should_compress(content_encoding) {
let stream = body
.map_err(|_e| io::Error::from(io::ErrorKind::InvalidData))
.map(|b| match b {
Ok(b) => compress(&b, content_encoding),
Err(e) => Err(e),
});
let body = Body::wrap_stream(stream);
let b = compress_stream(
body.map_err(|e| io::Error::new(io::ErrorKind::Other, e)),
content_encoding,
)?;
res.headers_mut().insert(
hyper::header::CONTENT_ENCODING,
hyper::header::HeaderValue::from_static(content_encoding),
Expand All @@ -427,7 +424,7 @@ impl InnerService {
hyper::header::VARY,
hyper::header::HeaderValue::from_name(hyper::header::ACCEPT_ENCODING),
);
body
b
} else {
body
}
Expand Down