From fc5718c93f8c32fa339315cb82be087feaca97ca Mon Sep 17 00:00:00 2001 From: Amos Wenger Date: Sun, 13 Oct 2024 19:35:35 +0200 Subject: [PATCH] Refine caching heuristics, support non-HTTPs requests --- Cargo.lock | 11 ++++ Cargo.toml | 1 + README.md | 14 +++-- src/main.rs | 161 ++++++++++++++++++++++++++++++++++------------------ 4 files changed, 127 insertions(+), 60 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index a6f0ac2..f2c2be2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -224,6 +224,7 @@ version = "1.0.2" dependencies = [ "byteorder", "color-eyre", + "fs-err", "futures-util", "http-body-util", "http-serde", @@ -251,6 +252,16 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "fs-err" +version = "2.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88a41f105fe1d5b6b34b2055e3dc59bb79b46b48b2040b9e6c7b4b5de097aa41" +dependencies = [ + "autocfg", + "tokio", +] + [[package]] name = "futures-channel" version = "0.3.31" diff --git a/Cargo.toml b/Cargo.toml index 9989de9..9bc895f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,6 +13,7 @@ categories = ["command-line-utilities", "network-programming"] [dependencies] byteorder = "1.5.0" color-eyre = "0.6.3" +fs-err = { version = "2.11.0", features = ["tokio"] } futures-util = "0.3.31" http-body-util = "0.1.2" http-serde = "2.1.1" diff --git a/README.md b/README.md index e466c8a..510d2c7 100644 --- a/README.md +++ b/README.md @@ -6,13 +6,19 @@ An proof-of-concept(TM) caching HTTP forward proxy -Limitations: +## Limitations * Will only accept to negotiate http/2 over TLS (via CONNECT) right now - * Will serve self-signed certificates, no way to export a CA cert so it - can be "installed" on whichever client talks to it right now * Very naive rules to decide if something is cachable (see sources) - specifically, fopro DOES NOT RESPECT CACHE-CONTROL, VARY, ETC. + specifically, **fopro DOES NOT RESPECT `cache-control`, `vary`, ETC**. * The cache is boundless (both in memory and on disk) + * Responses are buffered in memory completely before being proxied + (instead of being streamed) + * Partial responses (HTTP 206) are not cached at all. * Really you shouldn't use fopro, it currently does the bare minimum to get _most_ of the [uv](https://github.com/astral-sh/uv) test suite passing. + +## Features + + * Supports `CONNECT` requests + * Caches 200 responses in memory and on-disk diff --git a/src/main.rs b/src/main.rs index 2d75e3a..92ffd02 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,6 +1,6 @@ -use color_eyre::eyre; +use color_eyre::eyre::{self, Context}; use futures_util::future::BoxFuture; -use http_body_util::Full; +use http_body_util::{combinators::BoxBody, BodyExt, Full}; use hyper::{ body::{Body, Bytes}, server::conn, @@ -13,7 +13,9 @@ use rcgen::DistinguishedName; use serde::{Deserialize, Serialize}; use std::{ collections::HashMap, + convert::Infallible, fmt::Debug, + net::IpAddr, str::FromStr, sync::{Arc, Mutex}, time::Instant, @@ -133,35 +135,50 @@ struct UpgradeService { impl Service> for UpgradeService where - ReqBody: Body + Debug + Send + 'static, + ReqBody: Body + Send + Sync + Debug + 'static, + ::Data: Into, + ::Error: Into>, { - type Response = hyper::Response; + type Response = hyper::Response; type Error = hyper::Error; type Future = BoxFuture<'static, Result>; fn call(&self, req: Request) -> Self::Future { let settings = self.settings.clone(); + tracing::debug!("Got request {req:#?}"); Box::pin(async move { - if req.method() != Method::CONNECT { - return Ok(Response::builder() - .status(StatusCode::BAD_REQUEST) - .body(String::from("we're a forward proxy, we only serve CONNECT")) - .unwrap()); - } - let uri = req.uri().clone(); - tracing::trace!("Got CONNECT to {uri}, headers = {:#?}", req.headers()); + tracing::trace!( + "Got {} to {uri}, headers = {:#?}", + req.method(), + req.headers() + ); let host = match uri.host() { Some(host) => host.to_string(), None => { return Ok(Response::builder() .status(StatusCode::BAD_REQUEST) - .body(String::from("expected host in CONNECT request")) + .body(Full::from("expected host in request").boxed()) .unwrap()) } }; + + if req.method() != Method::CONNECT { + let service = ProxyService { host, settings }; + return match service.proxy_request(req).await { + Ok(resp) => Ok(resp), + Err(e) => { + tracing::error!("Error proxying request: {e}"); + Ok(Response::builder() + .status(StatusCode::INTERNAL_SERVER_ERROR) + .body(Full::from(format!("{e}")).boxed()) + .unwrap()) + } + }; + } + let on_upgrade = hyper::upgrade::on(req); tokio::spawn(async move { if let Err(e) = handle_upgraded_conn(on_upgrade, host, settings).await { @@ -171,9 +188,7 @@ where Ok(Response::builder() .status(StatusCode::OK) - .body(String::from( - "you're connected, prepare to accept an invalid TLS cert because we're MITM'ing today" - )) + .body(Full::from("you're connected, have you added fopro's CA cert as a root for your client??").boxed()) .unwrap()) }) } @@ -243,6 +258,8 @@ async fn handle_upgraded_conn( Ok(()) } +type OurBody = BoxBody; + #[derive(Clone)] struct ProxySettings { /// the shared reqwest client for upstream @@ -274,7 +291,7 @@ where ::Data: Into, ::Error: Into>, { - type Response = Response>; + type Response = Response; type Error = hyper::Error; type Future = BoxFuture<'static, Result>; @@ -288,7 +305,7 @@ where tracing::error!("Error proxying request: {e}"); Ok(Response::builder() .status(StatusCode::INTERNAL_SERVER_ERROR) - .body(Bytes::copy_from_slice(format!("{e}").as_bytes()).into()) + .body(Full::from(format!("{e}")).boxed()) .unwrap()) } } @@ -300,7 +317,7 @@ impl ProxyService { async fn proxy_request( self, req: Request, - ) -> Result>, eyre::Error> + ) -> Result, eyre::Error> where ReqBody: Body + Send + Sync + Debug + 'static, ::Data: Into, @@ -316,12 +333,7 @@ impl ProxyService { if uri_host != self.host { return Ok(Response::builder() .status(StatusCode::BAD_REQUEST) - .body( - Bytes::copy_from_slice( - format!("expected host {}, got {uri_host}", self.host).as_bytes(), - ) - .into(), - ) + .body(Full::from(format!("expected host {}, got {uri_host}", self.host)).boxed()) .unwrap()); } @@ -334,7 +346,7 @@ impl ProxyService { let cache_key = format!( "{}{}", - uri.host().unwrap_or_default(), + uri.authority().map(|a| a.as_str()).unwrap_or_default(), uri.path_and_query() .map(|pq| pq.as_str()) .unwrap_or_default() @@ -344,24 +356,39 @@ impl ProxyService { if cache_key.contains("..") { cachable = false; } - let cache_key = cache_key.strip_suffix('/').unwrap_or(&cache_key); + let cache_key = if cache_key.ends_with('/') { + format!("{cache_key}_INDEX_") + } else { + cache_key.to_string() + }; tracing::debug!("Cache key: {}", cache_key); - if uri.host().unwrap_or_default() == "github.com" { - // don't cache, probably a git clone, we don't know how to cache that yet - cachable = false; + if let Some(host) = uri.host() { + if host == "github.com" { + // don't cache, probably a git clone, we don't know how to cache that yet + tracing::debug!("Not caching github.com request"); + cachable = false; + } + if IpAddr::from_str(host).is_ok() { + // don't cache, probably a temp local server for testing + tracing::debug!("Not caching {host} request (IP address)"); + cachable = false; + } } + if method != Method::GET { // only cache GET requests for now + tracing::debug!("Not caching request with method {method}"); cachable = false; } if part.headers.contains_key(hyper::header::AUTHORIZATION) { + tracing::debug!("Not caching request with authorization header"); cachable = false; } let cache_dir = std::env::current_dir()?.join(".fopro-cache"); - let cache_path_on_disk = cache_dir.join(cache_key); + let cache_path_on_disk = cache_dir.join(&cache_key); if cachable { enum Source { @@ -373,21 +400,34 @@ impl ProxyService { let mut source = Source::Unknown; let mut maybe_entry_and_body: Option<(CacheEntry, Bytes)> = { let entries = self.settings.imch.entries.lock().unwrap(); - entries.get(cache_key).cloned() + entries.get(&cache_key).cloned() }; if maybe_entry_and_body.is_some() { source = Source::Memory; } else { - match tokio::fs::File::open(&cache_path_on_disk).await { + match fs_err::tokio::File::open(&cache_path_on_disk).await { Ok(mut file) => { source = Source::Disk; tracing::debug!("Cache hit: {}", cache_key); - let cache_entry = read_cache_entry(&mut file).await?; - let body = tokio::fs::read(&cache_path_on_disk).await? - [cache_entry.body_offset as usize..] - .to_vec(); + let cache_entry = + read_cache_entry(&mut file).await.wrap_err_with(|| { + format!( + "Error reading cache entry for {cache_key} at {}", + cache_path_on_disk.display() + ) + })?; + let body = + tokio::fs::read(&cache_path_on_disk) + .await + .wrap_err_with(|| { + format!( + "Error reading cache body for {cache_key} at {}", + cache_path_on_disk.display() + ) + })?[cache_entry.body_offset as usize..] + .to_vec(); let body = Bytes::from(body); { @@ -398,13 +438,9 @@ impl ProxyService { maybe_entry_and_body = Some((cache_entry, body)); } - Err(e) if e.kind() == std::io::ErrorKind::NotFound => { + Err(_) => { // Cache miss, continue with the original request } - Err(e) => { - // Unexpected error - return Err(e.into()); - } } } @@ -433,7 +469,7 @@ impl ProxyService { tracing::info!("\x1b[32m[{hit_string}]\x1b[0m {status} {res_size}B {method} {uri} (read {read_elapsed:?})"); } - return Ok(res.body(body).unwrap()); + return Ok(res.body(body.boxed()).unwrap()); } } @@ -453,7 +489,7 @@ impl ProxyService { tracing::error!("Error sending request: {e}"); return Ok(Response::builder() .status(StatusCode::BAD_GATEWAY) - .body(Bytes::copy_from_slice(format!("{e}").as_bytes()).into()) + .body(Full::from(format!("{e}")).boxed()) .unwrap()); } }; @@ -473,7 +509,7 @@ impl ProxyService { tracing::error!("Error reading upstream response body: {e}"); return Ok(Response::builder() .status(StatusCode::BAD_GATEWAY) - .body(Bytes::copy_from_slice(format!("{e}").as_bytes()).into()) + .body(Full::from(format!("{e}")).boxed()) .unwrap()); } }; @@ -487,7 +523,17 @@ impl ProxyService { ); } - if cachable && status == StatusCode::OK { + if status == StatusCode::OK { + // cacheable, okay + } else if status.is_client_error() { + // also cacheable + } else { + // not cacheable — server error, partial resopnse, etc. + tracing::debug!("Not caching request with status {status}"); + cachable = false; + } + + if cachable { // Create cache entry let cache_entry = CacheEntry { header: CacheHeader { @@ -498,21 +544,24 @@ impl ProxyService { }; // Write to a temporary file - tokio::fs::create_dir_all(cache_path_on_disk.parent().unwrap()).await?; + fs_err::tokio::create_dir_all(cache_path_on_disk.parent().unwrap()).await?; let temp_file = cache_path_on_disk.with_extension("tmp"); - let mut file = tokio::fs::File::create(&temp_file).await?; - // Write the cache entry - write_cache_entry(&mut file, cache_entry.clone()).await?; + { + let mut file = fs_err::tokio::File::create(&temp_file).await?; + + // Write the cache entry + write_cache_entry(&mut file, cache_entry.clone()).await?; - // Write the body - file.write_all(&body).await?; + // Write the body + file.write_all(&body).await?; - // Ensure all data is written to disk - file.flush().await?; + // Ensure all data is written to disk + file.shutdown().await?; + } // Rename the temporary file to the final cache file - tokio::fs::rename(temp_file, cache_path_on_disk).await?; + fs_err::tokio::rename(temp_file, cache_path_on_disk).await?; { let mut entries = self.settings.imch.entries.lock().unwrap(); @@ -522,7 +571,7 @@ impl ProxyService { let mut res = Response::builder().status(status); res.headers_mut().unwrap().extend(headers); - Ok(res.body(body.into()).unwrap()) + Ok(res.body(Full::from(body).boxed()).unwrap()) } }