Skip to content

Commit

Permalink
Actually stream the response back
Browse files Browse the repository at this point in the history
We have to use the async version of reqwest to call
Response::bytes_stream() to do actual streaming. This requires
explicitly bringing in tokio and making various functions async. Since
we're just firing off one request, set up the single-threaded flavor of
tokio instead of spawning multiple worker threads.

Also implement serializing all headers per discussion with @cfm.
  • Loading branch information
legoktm committed Jan 18, 2024
1 parent 0541715 commit 31fdf38
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 60 deletions.
78 changes: 51 additions & 27 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion proxy/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@ edition = "2021"

[dependencies]
anyhow = {version = "1.0.75"}
reqwest = {version = "0.11.20", features = ["blocking", "gzip"]}
futures-util = "0.3.30"
reqwest = {version = "0.11.20", features = ["gzip", "stream"]}
serde = {version = "1.0.188", features = ["derive"]}
serde_json = "1.0.107"
tokio = {version = "1.0", features = ["macros", "rt"]}
url = "2.4.1"
63 changes: 31 additions & 32 deletions proxy/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
use anyhow::{anyhow, bail, Result};
use reqwest::blocking::{Client, Response};
use reqwest::header::{self, HeaderMap};
use anyhow::{bail, Result};
use futures_util::StreamExt;
use reqwest::header::HeaderMap;
use reqwest::Method;
use reqwest::{Client, Response};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::io::Write;
use std::process::ExitCode;
use std::str::FromStr;
use std::time::Duration;
Expand Down Expand Up @@ -42,8 +44,7 @@ struct OutgoingResponse {
/// Metadata for a streamed response, serialized to JSON and emitted on stderr
/// after the body has finished streaming to stdout.
#[derive(Serialize, Debug)]
struct StreamMetadataResponse {
/// Value of the response's `ETag` header.
sha256sum: String,
/// Value of the response's `Content-Disposition` header.
filename: String,
/// Response headers, keyed by header name.
headers: HashMap<String, String>,
}

/// Serialization format for errors, always over stderr
Expand All @@ -52,48 +53,45 @@ struct ErrorResponse {
error: String,
}

fn handle_json_response(resp: Response) -> Result<()> {
/// Collect the response headers into a map of owned `name -> value` strings.
///
/// When a header name occurs more than once, the value encountered last is
/// the one that ends up in the map.
///
/// # Errors
///
/// Returns an error as soon as a header value cannot be converted with
/// `to_str()`.
fn headers_to_map(resp: &Response) -> Result<HashMap<String, String>> {
    resp.headers()
        .iter()
        .map(|(key, raw)| -> Result<(String, String)> {
            let text = raw.to_str()?;
            Ok((key.to_string(), text.to_string()))
        })
        .collect()
}

/// Buffer the whole response and print it to stdout as a single JSON object.
///
/// The serialized `OutgoingResponse` carries the numeric status code, the
/// headers collected by `headers_to_map`, and the body read as text.
///
/// # Errors
///
/// Returns an error if a header cannot be converted, the body cannot be read,
/// or JSON serialization fails.
async fn handle_json_response(resp: Response) -> Result<()> {
// Collect the headers through a borrow before the body is read out of `resp`.
let headers = headers_to_map(&resp)?;
let outgoing_response = OutgoingResponse {
status: resp.status().as_u16(),
headers,
body: resp.text()?,
body: resp.text().await?,
};
// One JSON document on a single line of stdout.
println!("{}", serde_json::to_string(&outgoing_response)?);
Ok(())
}

fn handle_stream_response(mut resp: Response) -> Result<()> {
/// Stream the response body to stdout, then report the response headers on stderr.
///
/// Body chunks are written to a locked stdout handle as they arrive and are
/// flushed after every chunk. Once the stream ends, a `StreamMetadataResponse`
/// is serialized to JSON and printed to stderr.
///
/// # Errors
///
/// Returns an error if a header cannot be converted, a chunk fails to
/// download, writing to stdout fails, or JSON serialization fails.
async fn handle_stream_response(resp: Response) -> Result<()> {
// Collect the headers up front: they are only emitted after the body, but we
// want to fail early if any of them are missing/invalid
let etag = resp
.headers()
.get(header::ETAG)
.ok_or_else(|| anyhow!("missing etag header"))?
.to_str()?
.to_string();
let content_disposition = resp
.headers()
.get(header::CONTENT_DISPOSITION)
.ok_or_else(|| anyhow!("missing content-disposition"))?
.to_str()?
.to_string();
let headers = headers_to_map(&resp)?;
// Lock stdout once and reuse the handle for every write below.
let mut stdout = io::stdout().lock();
resp.copy_to(&mut stdout)?;
// Emit the checksum as stderr
let mut stream = resp.bytes_stream();
// Stream the response, printing bytes as we receive them
while let Some(item) = stream.next().await {
// `item` is a Result; a failed chunk aborts the transfer via `?`.
stdout.write_all(&item?)?;
// TODO: can we flush at smarter intervals?
stdout.flush()?;
}
// Emit the headers as JSON on stderr, after the body has been fully written
eprintln!(
"{}",
serde_json::to_string(&StreamMetadataResponse {
sha256sum: etag,
filename: content_disposition
})?
serde_json::to_string(&StreamMetadataResponse { headers })?
);
Ok(())
}

fn proxy() -> Result<()> {
async fn proxy() -> Result<()> {
// Get the hostname from the environment
let origin = env::var(ENV_CONFIG)?;
// Read incoming request from stdin (must be on single line)
Expand Down Expand Up @@ -122,22 +120,23 @@ fn proxy() -> Result<()> {
req = req.body(body);
}
// Fire off the request!
let resp = req.send()?;
let resp = req.send().await?;
// We return the output in two ways, either a JSON blob or stream the output.
// JSON is used for HTTP 4xx, 5xx, and all non-stream requests.
if !incoming_request.stream
|| resp.status().is_client_error()
|| resp.status().is_server_error()
{
handle_json_response(resp)?;
handle_json_response(resp).await?;
} else {
handle_stream_response(resp)?;
handle_stream_response(resp).await?;
}
Ok(())
}

fn main() -> ExitCode {
match proxy() {
#[tokio::main(flavor = "current_thread")]
async fn main() -> ExitCode {
match proxy().await {
Ok(()) => ExitCode::SUCCESS,
Err(err) => {
// Try to serialize into our error format
Expand Down

0 comments on commit 31fdf38

Please sign in to comment.