Skip to content

Commit

Permalink
wasi-http: Allow embedder to manage outgoing connections (#7297)
Browse files Browse the repository at this point in the history
This is a backport of #7288,
minus the tests.  The test suite has been refactored since the release-14.0.0
branch was created, so backporting the tests would be complicated and not worth
the effort.

Signed-off-by: Joel Dice <[email protected]>
  • Loading branch information
dicej authored Oct 19, 2023
1 parent 6943ed0 commit ce3d775
Show file tree
Hide file tree
Showing 2 changed files with 164 additions and 124 deletions.
132 changes: 10 additions & 122 deletions crates/wasi-http/src/http_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,14 @@ use crate::bindings::http::{
outgoing_handler,
types::{RequestOptions, Scheme},
};
use crate::types::{self, HostFutureIncomingResponse, IncomingResponseInternal};
use crate::types::{self, HostFutureIncomingResponse, OutgoingRequest};
use crate::WasiHttpView;
use anyhow::Context;
use bytes::Bytes;
use http_body_util::{BodyExt, Empty};
use hyper::Method;
use std::time::Duration;
use tokio::net::TcpStream;
use tokio::time::timeout;
use types::HostOutgoingRequest;
use wasmtime::component::Resource;
use wasmtime_wasi::preview2;

impl<T: WasiHttpView> outgoing_handler::Host for T {
fn handle(
Expand Down Expand Up @@ -90,122 +86,14 @@ impl<T: WasiHttpView> outgoing_handler::Host for T {
.boxed()
});

let request = builder.body(body).map_err(http_protocol_error)?;

let handle = preview2::spawn(async move {
let tcp_stream = TcpStream::connect(authority.clone())
.await
.map_err(invalid_url)?;

let (mut sender, worker) = if use_tls {
#[cfg(any(target_arch = "riscv64", target_arch = "s390x"))]
{
anyhow::bail!(crate::bindings::http::types::Error::UnexpectedError(
"unsupported architecture for SSL".to_string(),
));
}

#[cfg(not(any(target_arch = "riscv64", target_arch = "s390x")))]
{
use tokio_rustls::rustls::OwnedTrustAnchor;

// derived from https://github.com/tokio-rs/tls/blob/master/tokio-rustls/examples/client/src/main.rs
let mut root_cert_store = rustls::RootCertStore::empty();
root_cert_store.add_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.iter().map(
|ta| {
OwnedTrustAnchor::from_subject_spki_name_constraints(
ta.subject,
ta.spki,
ta.name_constraints,
)
},
));
let config = rustls::ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(root_cert_store)
.with_no_client_auth();
let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(config));
let mut parts = authority.split(":");
let host = parts.next().unwrap_or(&authority);
let domain = rustls::ServerName::try_from(host)?;
let stream = connector.connect(domain, tcp_stream).await.map_err(|e| {
crate::bindings::http::types::Error::ProtocolError(e.to_string())
})?;

let (sender, conn) = timeout(
connect_timeout,
hyper::client::conn::http1::handshake(stream),
)
.await
.map_err(|_| timeout_error("connection"))??;

let worker = preview2::spawn(async move {
conn.await.context("hyper connection failed")?;
Ok::<_, anyhow::Error>(())
});

(sender, worker)
}
} else {
let (sender, conn) = timeout(
connect_timeout,
// TODO: we should plumb the builder through the http context, and use it here
hyper::client::conn::http1::handshake(tcp_stream),
)
.await
.map_err(|_| timeout_error("connection"))??;

let worker = preview2::spawn(async move {
conn.await.context("hyper connection failed")?;
Ok::<_, anyhow::Error>(())
});

(sender, worker)
};

let resp = timeout(first_byte_timeout, sender.send_request(request))
.await
.map_err(|_| timeout_error("first byte"))?
.map_err(hyper_protocol_error)?
.map(|body| body.map_err(|e| anyhow::anyhow!(e)).boxed());

Ok(IncomingResponseInternal {
resp,
worker,
between_bytes_timeout,
})
});

let fut = self
.table()
.push_resource(HostFutureIncomingResponse::new(handle))?;

Ok(Ok(fut))
let request = builder.body(body).map_err(types::http_protocol_error)?;
Ok(Ok(self.send_request(OutgoingRequest {
use_tls,
authority,
request,
connect_timeout,
first_byte_timeout,
between_bytes_timeout,
})?))
}
}

fn timeout_error(kind: &str) -> anyhow::Error {
anyhow::anyhow!(crate::bindings::http::types::Error::TimeoutError(format!(
"{kind} timed out"
)))
}

fn http_protocol_error(e: http::Error) -> anyhow::Error {
anyhow::anyhow!(crate::bindings::http::types::Error::ProtocolError(
e.to_string()
))
}

fn hyper_protocol_error(e: hyper::Error) -> anyhow::Error {
anyhow::anyhow!(crate::bindings::http::types::Error::ProtocolError(
e.to_string()
))
}

fn invalid_url(e: std::io::Error) -> anyhow::Error {
// TODO: DNS errors show up as a Custom io error, what subset of errors should we consider for
// InvalidUrl here?
anyhow::anyhow!(crate::bindings::http::types::Error::InvalidUrl(
e.to_string()
))
}
156 changes: 154 additions & 2 deletions crates/wasi-http/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,27 @@ use crate::{
bindings::http::types::{self, Method, Scheme},
body::{HostIncomingBodyBuilder, HyperIncomingBody, HyperOutgoingBody},
};
use anyhow::Context;
use http_body_util::BodyExt;
use std::any::Any;
use std::time::Duration;
use tokio::net::TcpStream;
use tokio::time::timeout;
use wasmtime::component::Resource;
use wasmtime_wasi::preview2::{AbortOnDropJoinHandle, Subscribe, Table};
use wasmtime_wasi::preview2::{self, AbortOnDropJoinHandle, Subscribe, Table};

/// Capture the state necessary for use in the wasi-http API implementation.
pub struct WasiHttpCtx;

pub struct OutgoingRequest {
pub use_tls: bool,
pub authority: String,
pub request: hyper::Request<HyperOutgoingBody>,
pub connect_timeout: Duration,
pub first_byte_timeout: Duration,
pub between_bytes_timeout: Duration,
}

pub trait WasiHttpView: Send {
fn ctx(&mut self) -> &mut WasiHttpCtx;
fn table(&mut self) -> &mut Table;
Expand Down Expand Up @@ -43,6 +57,144 @@ pub trait WasiHttpView: Send {
.push_resource(HostResponseOutparam { result })?;
Ok(id)
}

fn send_request(
&mut self,
request: OutgoingRequest,
) -> wasmtime::Result<Resource<HostFutureIncomingResponse>>
where
Self: Sized,
{
default_send_request(self, request)
}
}

pub fn default_send_request(
view: &mut dyn WasiHttpView,
OutgoingRequest {
use_tls,
authority,
request,
connect_timeout,
first_byte_timeout,
between_bytes_timeout,
}: OutgoingRequest,
) -> wasmtime::Result<Resource<HostFutureIncomingResponse>> {
let handle = preview2::spawn(async move {
let tcp_stream = TcpStream::connect(authority.clone())
.await
.map_err(invalid_url)?;

let (mut sender, worker) = if use_tls {
#[cfg(any(target_arch = "riscv64", target_arch = "s390x"))]
{
anyhow::bail!(crate::bindings::http::types::Error::UnexpectedError(
"unsupported architecture for SSL".to_string(),
));
}

#[cfg(not(any(target_arch = "riscv64", target_arch = "s390x")))]
{
use tokio_rustls::rustls::OwnedTrustAnchor;

// derived from https://github.com/tokio-rs/tls/blob/master/tokio-rustls/examples/client/src/main.rs
let mut root_cert_store = rustls::RootCertStore::empty();
root_cert_store.add_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.iter().map(
|ta| {
OwnedTrustAnchor::from_subject_spki_name_constraints(
ta.subject,
ta.spki,
ta.name_constraints,
)
},
));
let config = rustls::ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(root_cert_store)
.with_no_client_auth();
let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(config));
let mut parts = authority.split(":");
let host = parts.next().unwrap_or(&authority);
let domain = rustls::ServerName::try_from(host)?;
let stream = connector.connect(domain, tcp_stream).await.map_err(|e| {
crate::bindings::http::types::Error::ProtocolError(e.to_string())
})?;

let (sender, conn) = timeout(
connect_timeout,
hyper::client::conn::http1::handshake(stream),
)
.await
.map_err(|_| timeout_error("connection"))??;

let worker = preview2::spawn(async move {
conn.await.context("hyper connection failed")?;
Ok::<_, anyhow::Error>(())
});

(sender, worker)
}
} else {
let (sender, conn) = timeout(
connect_timeout,
// TODO: we should plumb the builder through the http context, and use it here
hyper::client::conn::http1::handshake(tcp_stream),
)
.await
.map_err(|_| timeout_error("connection"))??;

let worker = preview2::spawn(async move {
conn.await.context("hyper connection failed")?;
Ok::<_, anyhow::Error>(())
});

(sender, worker)
};

let resp = timeout(first_byte_timeout, sender.send_request(request))
.await
.map_err(|_| timeout_error("first byte"))?
.map_err(hyper_protocol_error)?
.map(|body| body.map_err(|e| anyhow::anyhow!(e)).boxed());

Ok(IncomingResponseInternal {
resp,
worker,
between_bytes_timeout,
})
});

let fut = view
.table()
.push_resource(HostFutureIncomingResponse::new(handle))?;

Ok(fut)
}

pub fn timeout_error(kind: &str) -> anyhow::Error {
anyhow::anyhow!(crate::bindings::http::types::Error::TimeoutError(format!(
"{kind} timed out"
)))
}

pub fn http_protocol_error(e: http::Error) -> anyhow::Error {
anyhow::anyhow!(crate::bindings::http::types::Error::ProtocolError(
e.to_string()
))
}

pub fn hyper_protocol_error(e: hyper::Error) -> anyhow::Error {
anyhow::anyhow!(crate::bindings::http::types::Error::ProtocolError(
e.to_string()
))
}

fn invalid_url(e: std::io::Error) -> anyhow::Error {
// TODO: DNS errors show up as a Custom io error, what subset of errors should we consider for
// InvalidUrl here?
anyhow::anyhow!(crate::bindings::http::types::Error::InvalidUrl(
e.to_string()
))
}

pub struct HostIncomingRequest {
Expand Down Expand Up @@ -83,7 +235,7 @@ impl TryFrom<HostOutgoingResponse> for hyper::Response<HyperOutgoingBody> {
fn try_from(
resp: HostOutgoingResponse,
) -> Result<hyper::Response<HyperOutgoingBody>, Self::Error> {
use http_body_util::{BodyExt, Empty};
use http_body_util::Empty;

let mut builder = hyper::Response::builder().status(resp.status);

Expand Down

0 comments on commit ce3d775

Please sign in to comment.