Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[14.0.0] wasi-http: Allow embedder to manage outgoing connections #7297

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 10 additions & 122 deletions crates/wasi-http/src/http_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,14 @@ use crate::bindings::http::{
outgoing_handler,
types::{RequestOptions, Scheme},
};
use crate::types::{self, HostFutureIncomingResponse, IncomingResponseInternal};
use crate::types::{self, HostFutureIncomingResponse, OutgoingRequest};
use crate::WasiHttpView;
use anyhow::Context;
use bytes::Bytes;
use http_body_util::{BodyExt, Empty};
use hyper::Method;
use std::time::Duration;
use tokio::net::TcpStream;
use tokio::time::timeout;
use types::HostOutgoingRequest;
use wasmtime::component::Resource;
use wasmtime_wasi::preview2;

impl<T: WasiHttpView> outgoing_handler::Host for T {
fn handle(
Expand Down Expand Up @@ -90,122 +86,14 @@ impl<T: WasiHttpView> outgoing_handler::Host for T {
.boxed()
});

let request = builder.body(body).map_err(http_protocol_error)?;

let handle = preview2::spawn(async move {
let tcp_stream = TcpStream::connect(authority.clone())
.await
.map_err(invalid_url)?;

let (mut sender, worker) = if use_tls {
#[cfg(any(target_arch = "riscv64", target_arch = "s390x"))]
{
anyhow::bail!(crate::bindings::http::types::Error::UnexpectedError(
"unsupported architecture for SSL".to_string(),
));
}

#[cfg(not(any(target_arch = "riscv64", target_arch = "s390x")))]
{
use tokio_rustls::rustls::OwnedTrustAnchor;

// derived from https://github.com/tokio-rs/tls/blob/master/tokio-rustls/examples/client/src/main.rs
let mut root_cert_store = rustls::RootCertStore::empty();
root_cert_store.add_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.iter().map(
|ta| {
OwnedTrustAnchor::from_subject_spki_name_constraints(
ta.subject,
ta.spki,
ta.name_constraints,
)
},
));
let config = rustls::ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(root_cert_store)
.with_no_client_auth();
let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(config));
let mut parts = authority.split(":");
let host = parts.next().unwrap_or(&authority);
let domain = rustls::ServerName::try_from(host)?;
let stream = connector.connect(domain, tcp_stream).await.map_err(|e| {
crate::bindings::http::types::Error::ProtocolError(e.to_string())
})?;

let (sender, conn) = timeout(
connect_timeout,
hyper::client::conn::http1::handshake(stream),
)
.await
.map_err(|_| timeout_error("connection"))??;

let worker = preview2::spawn(async move {
conn.await.context("hyper connection failed")?;
Ok::<_, anyhow::Error>(())
});

(sender, worker)
}
} else {
let (sender, conn) = timeout(
connect_timeout,
// TODO: we should plumb the builder through the http context, and use it here
hyper::client::conn::http1::handshake(tcp_stream),
)
.await
.map_err(|_| timeout_error("connection"))??;

let worker = preview2::spawn(async move {
conn.await.context("hyper connection failed")?;
Ok::<_, anyhow::Error>(())
});

(sender, worker)
};

let resp = timeout(first_byte_timeout, sender.send_request(request))
.await
.map_err(|_| timeout_error("first byte"))?
.map_err(hyper_protocol_error)?
.map(|body| body.map_err(|e| anyhow::anyhow!(e)).boxed());

Ok(IncomingResponseInternal {
resp,
worker,
between_bytes_timeout,
})
});

let fut = self
.table()
.push_resource(HostFutureIncomingResponse::new(handle))?;

Ok(Ok(fut))
let request = builder.body(body).map_err(types::http_protocol_error)?;
Ok(Ok(self.send_request(OutgoingRequest {
use_tls,
authority,
request,
connect_timeout,
first_byte_timeout,
between_bytes_timeout,
})?))
}
}

fn timeout_error(kind: &str) -> anyhow::Error {
anyhow::anyhow!(crate::bindings::http::types::Error::TimeoutError(format!(
"{kind} timed out"
)))
}

fn http_protocol_error(e: http::Error) -> anyhow::Error {
anyhow::anyhow!(crate::bindings::http::types::Error::ProtocolError(
e.to_string()
))
}

fn hyper_protocol_error(e: hyper::Error) -> anyhow::Error {
anyhow::anyhow!(crate::bindings::http::types::Error::ProtocolError(
e.to_string()
))
}

fn invalid_url(e: std::io::Error) -> anyhow::Error {
// TODO: DNS errors show up as a Custom io error, what subset of errors should we consider for
// InvalidUrl here?
anyhow::anyhow!(crate::bindings::http::types::Error::InvalidUrl(
e.to_string()
))
}
156 changes: 154 additions & 2 deletions crates/wasi-http/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,27 @@ use crate::{
bindings::http::types::{self, Method, Scheme},
body::{HostIncomingBodyBuilder, HyperIncomingBody, HyperOutgoingBody},
};
use anyhow::Context;
use http_body_util::BodyExt;
use std::any::Any;
use std::time::Duration;
use tokio::net::TcpStream;
use tokio::time::timeout;
use wasmtime::component::Resource;
use wasmtime_wasi::preview2::{AbortOnDropJoinHandle, Subscribe, Table};
use wasmtime_wasi::preview2::{self, AbortOnDropJoinHandle, Subscribe, Table};

/// Capture the state necessary for use in the wasi-http API implementation.
pub struct WasiHttpCtx;

pub struct OutgoingRequest {
pub use_tls: bool,
pub authority: String,
pub request: hyper::Request<HyperOutgoingBody>,
pub connect_timeout: Duration,
pub first_byte_timeout: Duration,
pub between_bytes_timeout: Duration,
}

pub trait WasiHttpView: Send {
fn ctx(&mut self) -> &mut WasiHttpCtx;
fn table(&mut self) -> &mut Table;
Expand Down Expand Up @@ -43,6 +57,144 @@ pub trait WasiHttpView: Send {
.push_resource(HostResponseOutparam { result })?;
Ok(id)
}

fn send_request(
&mut self,
request: OutgoingRequest,
) -> wasmtime::Result<Resource<HostFutureIncomingResponse>>
where
Self: Sized,
{
default_send_request(self, request)
}
}

pub fn default_send_request(
view: &mut dyn WasiHttpView,
OutgoingRequest {
use_tls,
authority,
request,
connect_timeout,
first_byte_timeout,
between_bytes_timeout,
}: OutgoingRequest,
) -> wasmtime::Result<Resource<HostFutureIncomingResponse>> {
let handle = preview2::spawn(async move {
let tcp_stream = TcpStream::connect(authority.clone())
.await
.map_err(invalid_url)?;

let (mut sender, worker) = if use_tls {
#[cfg(any(target_arch = "riscv64", target_arch = "s390x"))]
{
anyhow::bail!(crate::bindings::http::types::Error::UnexpectedError(
"unsupported architecture for SSL".to_string(),
));
}

#[cfg(not(any(target_arch = "riscv64", target_arch = "s390x")))]
{
use tokio_rustls::rustls::OwnedTrustAnchor;

// derived from https://github.com/tokio-rs/tls/blob/master/tokio-rustls/examples/client/src/main.rs
let mut root_cert_store = rustls::RootCertStore::empty();
root_cert_store.add_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.iter().map(
|ta| {
OwnedTrustAnchor::from_subject_spki_name_constraints(
ta.subject,
ta.spki,
ta.name_constraints,
)
},
));
let config = rustls::ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(root_cert_store)
.with_no_client_auth();
let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(config));
let mut parts = authority.split(":");
let host = parts.next().unwrap_or(&authority);
let domain = rustls::ServerName::try_from(host)?;
let stream = connector.connect(domain, tcp_stream).await.map_err(|e| {
crate::bindings::http::types::Error::ProtocolError(e.to_string())
})?;

let (sender, conn) = timeout(
connect_timeout,
hyper::client::conn::http1::handshake(stream),
)
.await
.map_err(|_| timeout_error("connection"))??;

let worker = preview2::spawn(async move {
conn.await.context("hyper connection failed")?;
Ok::<_, anyhow::Error>(())
});

(sender, worker)
}
} else {
let (sender, conn) = timeout(
connect_timeout,
// TODO: we should plumb the builder through the http context, and use it here
hyper::client::conn::http1::handshake(tcp_stream),
)
.await
.map_err(|_| timeout_error("connection"))??;

let worker = preview2::spawn(async move {
conn.await.context("hyper connection failed")?;
Ok::<_, anyhow::Error>(())
});

(sender, worker)
};

let resp = timeout(first_byte_timeout, sender.send_request(request))
.await
.map_err(|_| timeout_error("first byte"))?
.map_err(hyper_protocol_error)?
.map(|body| body.map_err(|e| anyhow::anyhow!(e)).boxed());

Ok(IncomingResponseInternal {
resp,
worker,
between_bytes_timeout,
})
});

let fut = view
.table()
.push_resource(HostFutureIncomingResponse::new(handle))?;

Ok(fut)
}

pub fn timeout_error(kind: &str) -> anyhow::Error {
anyhow::anyhow!(crate::bindings::http::types::Error::TimeoutError(format!(
"{kind} timed out"
)))
}

pub fn http_protocol_error(e: http::Error) -> anyhow::Error {
anyhow::anyhow!(crate::bindings::http::types::Error::ProtocolError(
e.to_string()
))
}

pub fn hyper_protocol_error(e: hyper::Error) -> anyhow::Error {
anyhow::anyhow!(crate::bindings::http::types::Error::ProtocolError(
e.to_string()
))
}

fn invalid_url(e: std::io::Error) -> anyhow::Error {
// TODO: DNS errors show up as a Custom io error, what subset of errors should we consider for
// InvalidUrl here?
anyhow::anyhow!(crate::bindings::http::types::Error::InvalidUrl(
e.to_string()
))
}

pub struct HostIncomingRequest {
Expand Down Expand Up @@ -83,7 +235,7 @@ impl TryFrom<HostOutgoingResponse> for hyper::Response<HyperOutgoingBody> {
fn try_from(
resp: HostOutgoingResponse,
) -> Result<hyper::Response<HyperOutgoingBody>, Self::Error> {
use http_body_util::{BodyExt, Empty};
use http_body_util::Empty;

let mut builder = hyper::Response::builder().status(resp.status);

Expand Down