Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(http): allow custom OutgoingHandler implementations #6272

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions crates/wasi-http/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ readme = "readme.md"

[dependencies]
anyhow = { workspace = true }
async-trait = { workspace = true }
bytes = "1.1.0"
hyper = { version = "1.0.0-rc.3", features = ["full"] }
tokio = { version = "1", default-features = false, features = ["net", "rt-multi-thread", "time"] }
Expand Down
193 changes: 67 additions & 126 deletions crates/wasi-http/src/http_impl.rs
Original file line number Diff line number Diff line change
@@ -1,22 +1,15 @@
use crate::r#struct::ActiveResponse;
use crate::r#struct::{Stream, WasiHttp};
use crate::r#struct::{ActiveResponse, WasiHttp};
use crate::types::{RequestOptions, Scheme};
#[cfg(not(any(target_arch = "riscv64", target_arch = "s390x")))]
use anyhow::anyhow;
use anyhow::bail;
use bytes::{BufMut, Bytes, BytesMut};
use anyhow::{bail, Context};
use bytes::{BufMut, BytesMut};
use http::Uri;
use http_body_util::{BodyExt, Full};
use hyper::Method;
use hyper::Request;
use std::collections::HashMap;
#[cfg(not(any(target_arch = "riscv64", target_arch = "s390x")))]
use std::sync::Arc;
use std::time::Duration;
use tokio::net::TcpStream;
use tokio::runtime::Runtime;
use tokio::time::timeout;
#[cfg(not(any(target_arch = "riscv64", target_arch = "s390x")))]
use tokio_rustls::rustls::{self, OwnedTrustAnchor};

impl crate::default_outgoing_http::Host for WasiHttp {
fn handle(
Expand All @@ -42,39 +35,12 @@ impl crate::default_outgoing_http::Host for WasiHttp {
}
}

fn port_for_scheme(scheme: &Option<Scheme>) -> &str {
match scheme {
Some(s) => match s {
Scheme::Http => ":80",
Scheme::Https => ":443",
// This should never happen.
_ => panic!("unsupported scheme!"),
},
None => ":443",
}
}

impl WasiHttp {
async fn handle_async(
&mut self,
request_id: crate::default_outgoing_http::OutgoingRequest,
options: Option<crate::default_outgoing_http::RequestOptions>,
) -> wasmtime::Result<crate::default_outgoing_http::FutureIncomingResponse> {
let opts = options.unwrap_or(
// TODO: Configurable defaults here?
RequestOptions {
connect_timeout_ms: Some(600 * 1000),
first_byte_timeout_ms: Some(600 * 1000),
between_bytes_timeout_ms: Some(600 * 1000),
},
);
let connect_timeout =
Duration::from_millis(opts.connect_timeout_ms.unwrap_or(600 * 1000).into());
let first_bytes_timeout =
Duration::from_millis(opts.first_byte_timeout_ms.unwrap_or(600 * 1000).into());
let between_bytes_timeout =
Duration::from_millis(opts.between_bytes_timeout_ms.unwrap_or(600 * 1000).into());

let request = match self.requests.get(&request_id) {
Some(r) => r,
None => bail!("not found!"),
Expand All @@ -92,104 +58,66 @@ impl WasiHttp {
crate::types::Method::Patch => Method::PATCH,
_ => bail!("unknown method!"),
};
let mut uri = Uri::builder()
.authority(request.authority.as_str())
// NOTE: this is broken, but will be fixed by `wasi-http` dependency update
.path_and_query(request.path.to_owned() + &request.query);
match &request.scheme {
Some(Scheme::Http) => uri = uri.scheme("http"),
Some(Scheme::Https) => uri = uri.scheme("https"),
Some(scheme) => bail!("unsupported scheme `{scheme:?}`"),
_ => {}
}
// NOTE: This does not belong here, the complete struct should have been constructed
// on request creation
let uri = uri.build().context("failed to build URI")?;

let scheme = match request.scheme.as_ref().unwrap_or(&Scheme::Https) {
Scheme::Http => "http://",
Scheme::Https => "https://",
// TODO: this is wrong, fix this.
_ => panic!("Unsupported scheme!"),
};
let mut req = Request::builder()
.method(method)
.uri(uri)
.header(hyper::header::HOST, &request.authority);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we really be injecting headers on behalf of the caller?

Copy link
Contributor

@pchickey pchickey May 11, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree, the responsibility of setting the Host header should be on the user of this interface, rather than on the implementer. It is both valid and frequently useful to set a different host header than the authority the network connection is made to

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree as well, this logic was simply moved from existing implementation, will update

for (key, val) in request.headers.iter() {
for item in val {
req = req.header(key, item.clone());
}
}
let body = self
.streams
.get(&request.body)
.map(|stream| stream.clone().into())
.unwrap_or_default();
let req = req.body(Full::new(body))?;

// Largely adapted from https://hyper.rs/guides/1/client/basic/
let authority = match request.authority.find(":") {
Some(_) => request.authority.clone(),
None => request.authority.clone() + port_for_scheme(&request.scheme),
let connect_timeout = if let Some(RequestOptions {
connect_timeout_ms: Some(connect_timeout_ms),
..
}) = options
{
Duration::from_millis(connect_timeout_ms.into())
} else {
// TODO: Configurable default
Duration::from_millis(600)
};
let mut sender = if scheme == "https://" {
#[cfg(not(any(target_arch = "riscv64", target_arch = "s390x")))]
{
let stream = TcpStream::connect(authority.clone()).await?;
//TODO: uncomment this code and make the tls implementation a feature decision.
//let connector = tokio_native_tls::native_tls::TlsConnector::builder().build()?;
//let connector = tokio_native_tls::TlsConnector::from(connector);
//let host = authority.split(":").next().unwrap_or(&authority);
//let stream = connector.connect(&host, stream).await?;

// derived from https://github.com/tokio-rs/tls/blob/master/tokio-rustls/examples/client/src/main.rs
let mut root_cert_store = rustls::RootCertStore::empty();
root_cert_store.add_server_trust_anchors(
webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| {
OwnedTrustAnchor::from_subject_spki_name_constraints(
ta.subject,
ta.spki,
ta.name_constraints,
)
}),
);
let config = rustls::ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(root_cert_store)
.with_no_client_auth();
let connector = tokio_rustls::TlsConnector::from(Arc::new(config));
let mut parts = authority.split(":");
let host = parts.next().unwrap_or(&authority);
let domain =
rustls::ServerName::try_from(host).map_err(|_| anyhow!("invalid dnsname"))?;
let stream = connector.connect(domain, stream).await?;

let t = timeout(
connect_timeout,
hyper::client::conn::http1::handshake(stream),
)
.await?;
let (s, conn) = t?;
tokio::task::spawn(async move {
if let Err(err) = conn.await {
println!("Connection failed: {:?}", err);
}
});
s
}
#[cfg(any(target_arch = "riscv64", target_arch = "s390x"))]
bail!("unsupported architecture for SSL")
let first_byte_timeout = if let Some(RequestOptions {
first_byte_timeout_ms: Some(first_byte_timeout_ms),
..
}) = options
{
Duration::from_millis(first_byte_timeout_ms.into())
} else {
let tcp = TcpStream::connect(authority).await?;
let t = timeout(connect_timeout, hyper::client::conn::http1::handshake(tcp)).await?;
let (s, conn) = t?;
tokio::task::spawn(async move {
if let Err(err) = conn.await {
println!("Connection failed: {:?}", err);
}
});
s
// TODO: Configurable default
Duration::from_millis(600)
};

let url = scheme.to_owned() + &request.authority + &request.path + &request.query;

let mut call = Request::builder()
.method(method)
.uri(url)
.header(hyper::header::HOST, request.authority.as_str());

for (key, val) in request.headers.iter() {
for item in val {
call = call.header(key, item.clone());
}
}
let res = self
.outgoing_handler
.handle(req, connect_timeout, first_byte_timeout)
.await?;

let response_id = self.response_id_base;
self.response_id_base = self.response_id_base + 1;
let mut response = ActiveResponse::new(response_id);
let body = Full::<Bytes>::new(
self.streams
.get(&request.body)
.unwrap_or(&Stream::default())
.data
.clone()
.freeze(),
);
let t = timeout(first_bytes_timeout, sender.send_request(call.body(body)?)).await?;
let mut res = t?;
response.status = res.status().try_into()?;
for (key, value) in res.headers().iter() {
let mut vec = std::vec::Vec::new();
Expand All @@ -198,8 +126,21 @@ impl WasiHttp {
.response_headers
.insert(key.as_str().to_string(), vec);
}

let between_bytes_timeout = if let Some(RequestOptions {
between_bytes_timeout_ms: Some(between_bytes_timeout_ms),
..
}) = options
{
Duration::from_millis(between_bytes_timeout_ms.into())
} else {
// TODO: Configurable default
Duration::from_millis(600)
};
let body = res.into_body();
let mut body = body.lock().await;
let mut buf = BytesMut::new();
while let Some(next) = timeout(between_bytes_timeout, res.frame()).await? {
while let Some(next) = timeout(between_bytes_timeout, body.frame()).await? {
let frame = next?;
if let Some(chunk) = frame.data_ref() {
buf.put(chunk.clone());
Expand Down
Loading