Skip to content

Commit

Permalink
Refine caching heuristics, support non-HTTPs requests
Browse files Browse the repository at this point in the history
  • Loading branch information
fasterthanlime committed Oct 13, 2024
1 parent 8897fc7 commit fc5718c
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 60 deletions.
11 changes: 11 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ categories = ["command-line-utilities", "network-programming"]
[dependencies]
byteorder = "1.5.0"
color-eyre = "0.6.3"
fs-err = { version = "2.11.0", features = ["tokio"] }
futures-util = "0.3.31"
http-body-util = "0.1.2"
http-serde = "2.1.1"
Expand Down
14 changes: 10 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,19 @@

An proof-of-concept(TM) caching HTTP forward proxy

Limitations:
## Limitations

* Will only accept to negotiate http/2 over TLS (via CONNECT) right now
* Will serve self-signed certificates, no way to export a CA cert so it
can be "installed" on whichever client talks to it right now
* Very naive rules to decide if something is cachable (see sources)
specifically, fopro DOES NOT RESPECT CACHE-CONTROL, VARY, ETC.
specifically, **fopro DOES NOT RESPECT `cache-control`, `vary`, ETC**.
* The cache is boundless (both in memory and on disk)
* Responses are buffered in memory completely before being proxied
(instead of being streamed)
* Partial responses (HTTP 206) are not cached at all.
* Really you shouldn't use fopro, it currently does the bare minimum
to get _most_ of the [uv](https://github.com/astral-sh/uv) test suite passing.

## Features

* Supports `CONNECT` requests
* Caches 200 responses in memory and on-disk
161 changes: 105 additions & 56 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use color_eyre::eyre;
use color_eyre::eyre::{self, Context};
use futures_util::future::BoxFuture;
use http_body_util::Full;
use http_body_util::{combinators::BoxBody, BodyExt, Full};
use hyper::{
body::{Body, Bytes},
server::conn,
Expand All @@ -13,7 +13,9 @@ use rcgen::DistinguishedName;
use serde::{Deserialize, Serialize};
use std::{
collections::HashMap,
convert::Infallible,
fmt::Debug,
net::IpAddr,
str::FromStr,
sync::{Arc, Mutex},
time::Instant,
Expand Down Expand Up @@ -133,35 +135,50 @@ struct UpgradeService {

impl<ReqBody> Service<Request<ReqBody>> for UpgradeService
where
ReqBody: Body + Debug + Send + 'static,
ReqBody: Body + Send + Sync + Debug + 'static,
<ReqBody as Body>::Data: Into<Bytes>,
<ReqBody as Body>::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
{
type Response = hyper::Response<String>;
type Response = hyper::Response<OurBody>;
type Error = hyper::Error;
type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;

fn call(&self, req: Request<ReqBody>) -> Self::Future {
let settings = self.settings.clone();
tracing::debug!("Got request {req:#?}");

Box::pin(async move {
if req.method() != Method::CONNECT {
return Ok(Response::builder()
.status(StatusCode::BAD_REQUEST)
.body(String::from("we're a forward proxy, we only serve CONNECT"))
.unwrap());
}

let uri = req.uri().clone();
tracing::trace!("Got CONNECT to {uri}, headers = {:#?}", req.headers());
tracing::trace!(
"Got {} to {uri}, headers = {:#?}",
req.method(),
req.headers()
);

let host = match uri.host() {
Some(host) => host.to_string(),
None => {
return Ok(Response::builder()
.status(StatusCode::BAD_REQUEST)
.body(String::from("expected host in CONNECT request"))
.body(Full::from("expected host in request").boxed())
.unwrap())
}
};

if req.method() != Method::CONNECT {
let service = ProxyService { host, settings };
return match service.proxy_request(req).await {
Ok(resp) => Ok(resp),
Err(e) => {
tracing::error!("Error proxying request: {e}");
Ok(Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body(Full::from(format!("{e}")).boxed())
.unwrap())
}
};
}

let on_upgrade = hyper::upgrade::on(req);
tokio::spawn(async move {
if let Err(e) = handle_upgraded_conn(on_upgrade, host, settings).await {
Expand All @@ -171,9 +188,7 @@ where

Ok(Response::builder()
.status(StatusCode::OK)
.body(String::from(
"you're connected, prepare to accept an invalid TLS cert because we're MITM'ing today"
))
.body(Full::from("you're connected, have you added fopro's CA cert as a root for your client??").boxed())
.unwrap())
})
}
Expand Down Expand Up @@ -243,6 +258,8 @@ async fn handle_upgraded_conn(
Ok(())
}

type OurBody = BoxBody<Bytes, Infallible>;

#[derive(Clone)]
struct ProxySettings {
/// the shared reqwest client for upstream
Expand Down Expand Up @@ -274,7 +291,7 @@ where
<ReqBody as Body>::Data: Into<Bytes>,
<ReqBody as Body>::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
{
type Response = Response<Full<Bytes>>;
type Response = Response<OurBody>;
type Error = hyper::Error;
type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;

Expand All @@ -288,7 +305,7 @@ where
tracing::error!("Error proxying request: {e}");
Ok(Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body(Bytes::copy_from_slice(format!("{e}").as_bytes()).into())
.body(Full::from(format!("{e}")).boxed())
.unwrap())
}
}
Expand All @@ -300,7 +317,7 @@ impl ProxyService {
async fn proxy_request<ReqBody>(
self,
req: Request<ReqBody>,
) -> Result<Response<Full<Bytes>>, eyre::Error>
) -> Result<Response<OurBody>, eyre::Error>
where
ReqBody: Body + Send + Sync + Debug + 'static,
<ReqBody as Body>::Data: Into<Bytes>,
Expand All @@ -316,12 +333,7 @@ impl ProxyService {
if uri_host != self.host {
return Ok(Response::builder()
.status(StatusCode::BAD_REQUEST)
.body(
Bytes::copy_from_slice(
format!("expected host {}, got {uri_host}", self.host).as_bytes(),
)
.into(),
)
.body(Full::from(format!("expected host {}, got {uri_host}", self.host)).boxed())
.unwrap());
}

Expand All @@ -334,7 +346,7 @@ impl ProxyService {

let cache_key = format!(
"{}{}",
uri.host().unwrap_or_default(),
uri.authority().map(|a| a.as_str()).unwrap_or_default(),
uri.path_and_query()
.map(|pq| pq.as_str())
.unwrap_or_default()
Expand All @@ -344,24 +356,39 @@ impl ProxyService {
if cache_key.contains("..") {
cachable = false;
}
let cache_key = cache_key.strip_suffix('/').unwrap_or(&cache_key);
let cache_key = if cache_key.ends_with('/') {
format!("{cache_key}_INDEX_")
} else {
cache_key.to_string()
};
tracing::debug!("Cache key: {}", cache_key);

if uri.host().unwrap_or_default() == "github.com" {
// don't cache, probably a git clone, we don't know how to cache that yet
cachable = false;
if let Some(host) = uri.host() {
if host == "github.com" {
// don't cache, probably a git clone, we don't know how to cache that yet
tracing::debug!("Not caching github.com request");
cachable = false;
}
if IpAddr::from_str(host).is_ok() {
// don't cache, probably a temp local server for testing
tracing::debug!("Not caching {host} request (IP address)");
cachable = false;
}
}

if method != Method::GET {
// only cache GET requests for now
tracing::debug!("Not caching request with method {method}");
cachable = false;
}

if part.headers.contains_key(hyper::header::AUTHORIZATION) {
tracing::debug!("Not caching request with authorization header");
cachable = false;
}

let cache_dir = std::env::current_dir()?.join(".fopro-cache");
let cache_path_on_disk = cache_dir.join(cache_key);
let cache_path_on_disk = cache_dir.join(&cache_key);

if cachable {
enum Source {
Expand All @@ -373,21 +400,34 @@ impl ProxyService {
let mut source = Source::Unknown;
let mut maybe_entry_and_body: Option<(CacheEntry, Bytes)> = {
let entries = self.settings.imch.entries.lock().unwrap();
entries.get(cache_key).cloned()
entries.get(&cache_key).cloned()
};

if maybe_entry_and_body.is_some() {
source = Source::Memory;
} else {
match tokio::fs::File::open(&cache_path_on_disk).await {
match fs_err::tokio::File::open(&cache_path_on_disk).await {
Ok(mut file) => {
source = Source::Disk;

tracing::debug!("Cache hit: {}", cache_key);
let cache_entry = read_cache_entry(&mut file).await?;
let body = tokio::fs::read(&cache_path_on_disk).await?
[cache_entry.body_offset as usize..]
.to_vec();
let cache_entry =
read_cache_entry(&mut file).await.wrap_err_with(|| {
format!(
"Error reading cache entry for {cache_key} at {}",
cache_path_on_disk.display()
)
})?;
let body =
tokio::fs::read(&cache_path_on_disk)
.await
.wrap_err_with(|| {
format!(
"Error reading cache body for {cache_key} at {}",
cache_path_on_disk.display()
)
})?[cache_entry.body_offset as usize..]
.to_vec();
let body = Bytes::from(body);

{
Expand All @@ -398,13 +438,9 @@ impl ProxyService {

maybe_entry_and_body = Some((cache_entry, body));
}
Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
Err(_) => {
// Cache miss, continue with the original request
}
Err(e) => {
// Unexpected error
return Err(e.into());
}
}
}

Expand Down Expand Up @@ -433,7 +469,7 @@ impl ProxyService {
tracing::info!("\x1b[32m[{hit_string}]\x1b[0m {status} {res_size}B {method} {uri} (read {read_elapsed:?})");
}

return Ok(res.body(body).unwrap());
return Ok(res.body(body.boxed()).unwrap());
}
}

Expand All @@ -453,7 +489,7 @@ impl ProxyService {
tracing::error!("Error sending request: {e}");
return Ok(Response::builder()
.status(StatusCode::BAD_GATEWAY)
.body(Bytes::copy_from_slice(format!("{e}").as_bytes()).into())
.body(Full::from(format!("{e}")).boxed())
.unwrap());
}
};
Expand All @@ -473,7 +509,7 @@ impl ProxyService {
tracing::error!("Error reading upstream response body: {e}");
return Ok(Response::builder()
.status(StatusCode::BAD_GATEWAY)
.body(Bytes::copy_from_slice(format!("{e}").as_bytes()).into())
.body(Full::from(format!("{e}")).boxed())
.unwrap());
}
};
Expand All @@ -487,7 +523,17 @@ impl ProxyService {
);
}

if cachable && status == StatusCode::OK {
if status == StatusCode::OK {
// cacheable, okay
} else if status.is_client_error() {
// also cacheable
} else {
// not cacheable — server error, partial resopnse, etc.
tracing::debug!("Not caching request with status {status}");
cachable = false;
}

if cachable {
// Create cache entry
let cache_entry = CacheEntry {
header: CacheHeader {
Expand All @@ -498,21 +544,24 @@ impl ProxyService {
};

// Write to a temporary file
tokio::fs::create_dir_all(cache_path_on_disk.parent().unwrap()).await?;
fs_err::tokio::create_dir_all(cache_path_on_disk.parent().unwrap()).await?;
let temp_file = cache_path_on_disk.with_extension("tmp");
let mut file = tokio::fs::File::create(&temp_file).await?;

// Write the cache entry
write_cache_entry(&mut file, cache_entry.clone()).await?;
{
let mut file = fs_err::tokio::File::create(&temp_file).await?;

// Write the cache entry
write_cache_entry(&mut file, cache_entry.clone()).await?;

// Write the body
file.write_all(&body).await?;
// Write the body
file.write_all(&body).await?;

// Ensure all data is written to disk
file.flush().await?;
// Ensure all data is written to disk
file.shutdown().await?;
}

// Rename the temporary file to the final cache file
tokio::fs::rename(temp_file, cache_path_on_disk).await?;
fs_err::tokio::rename(temp_file, cache_path_on_disk).await?;

{
let mut entries = self.settings.imch.entries.lock().unwrap();
Expand All @@ -522,7 +571,7 @@ impl ProxyService {

let mut res = Response::builder().status(status);
res.headers_mut().unwrap().extend(headers);
Ok(res.body(body.into()).unwrap())
Ok(res.body(Full::from(body).boxed()).unwrap())
}
}

Expand Down

0 comments on commit fc5718c

Please sign in to comment.