From 57f8ce752dedcb009f9f0ca6d408a77b459a2174 Mon Sep 17 00:00:00 2001 From: Joel Dice Date: Thu, 19 Oct 2023 10:28:50 -0600 Subject: [PATCH] wasi-http: Allow embedder to manage outgoing connections This is a backport of https://github.com/bytecodealliance/wasmtime/pull/7288, minus the tests. The test suite has been refactored since the release-14.0.0 branch was created, so backporting the tests would be complicated and not worth the effort. Signed-off-by: Joel Dice --- crates/wasi-http/src/http_impl.rs | 132 ++----------------------- crates/wasi-http/src/types.rs | 156 +++++++++++++++++++++++++++++- 2 files changed, 164 insertions(+), 124 deletions(-) diff --git a/crates/wasi-http/src/http_impl.rs b/crates/wasi-http/src/http_impl.rs index 2df2213134de..3a01f773da42 100644 --- a/crates/wasi-http/src/http_impl.rs +++ b/crates/wasi-http/src/http_impl.rs @@ -2,18 +2,14 @@ use crate::bindings::http::{ outgoing_handler, types::{RequestOptions, Scheme}, }; -use crate::types::{self, HostFutureIncomingResponse, IncomingResponseInternal}; +use crate::types::{self, HostFutureIncomingResponse, OutgoingRequest}; use crate::WasiHttpView; -use anyhow::Context; use bytes::Bytes; use http_body_util::{BodyExt, Empty}; use hyper::Method; use std::time::Duration; -use tokio::net::TcpStream; -use tokio::time::timeout; use types::HostOutgoingRequest; use wasmtime::component::Resource; -use wasmtime_wasi::preview2; impl outgoing_handler::Host for T { fn handle( @@ -90,122 +86,14 @@ impl outgoing_handler::Host for T { .boxed() }); - let request = builder.body(body).map_err(http_protocol_error)?; - - let handle = preview2::spawn(async move { - let tcp_stream = TcpStream::connect(authority.clone()) - .await - .map_err(invalid_url)?; - - let (mut sender, worker) = if use_tls { - #[cfg(any(target_arch = "riscv64", target_arch = "s390x"))] - { - anyhow::bail!(crate::bindings::http::types::Error::UnexpectedError( - "unsupported architecture for SSL".to_string(), - )); - } - - #[cfg(not(any(target_arch = "riscv64", target_arch = "s390x")))] - { - use tokio_rustls::rustls::OwnedTrustAnchor; - - // derived from https://github.com/tokio-rs/tls/blob/master/tokio-rustls/examples/client/src/main.rs - let mut root_cert_store = rustls::RootCertStore::empty(); - root_cert_store.add_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.iter().map( - |ta| { - OwnedTrustAnchor::from_subject_spki_name_constraints( - ta.subject, - ta.spki, - ta.name_constraints, - ) - }, - )); - let config = rustls::ClientConfig::builder() - .with_safe_defaults() - .with_root_certificates(root_cert_store) - .with_no_client_auth(); - let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(config)); - let mut parts = authority.split(":"); - let host = parts.next().unwrap_or(&authority); - let domain = rustls::ServerName::try_from(host)?; - let stream = connector.connect(domain, tcp_stream).await.map_err(|e| { - crate::bindings::http::types::Error::ProtocolError(e.to_string()) - })?; - - let (sender, conn) = timeout( - connect_timeout, - hyper::client::conn::http1::handshake(stream), - ) - .await - .map_err(|_| timeout_error("connection"))??; - - let worker = preview2::spawn(async move { - conn.await.context("hyper connection failed")?; - Ok::<_, anyhow::Error>(()) - }); - - (sender, worker) - } - } else { - let (sender, conn) = timeout( - connect_timeout, - // TODO: we should plumb the builder through the http context, and use it here - hyper::client::conn::http1::handshake(tcp_stream), - ) - .await - .map_err(|_| timeout_error("connection"))??; - - let worker = preview2::spawn(async move { - conn.await.context("hyper connection failed")?; - Ok::<_, anyhow::Error>(()) - }); - - (sender, worker) - }; - - let resp = timeout(first_byte_timeout, sender.send_request(request)) - .await - .map_err(|_| timeout_error("first byte"))? - .map_err(hyper_protocol_error)? - .map(|body| body.map_err(|e| anyhow::anyhow!(e)).boxed()); - - Ok(IncomingResponseInternal { - resp, - worker, - between_bytes_timeout, - }) - }); - - let fut = self - .table() - .push_resource(HostFutureIncomingResponse::new(handle))?; - - Ok(Ok(fut)) + let request = builder.body(body).map_err(types::http_protocol_error)?; + Ok(Ok(self.send_request(OutgoingRequest { + use_tls, + authority, + request, + connect_timeout, + first_byte_timeout, + between_bytes_timeout, + })?)) } } - -fn timeout_error(kind: &str) -> anyhow::Error { - anyhow::anyhow!(crate::bindings::http::types::Error::TimeoutError(format!( - "{kind} timed out" - ))) -} - -fn http_protocol_error(e: http::Error) -> anyhow::Error { - anyhow::anyhow!(crate::bindings::http::types::Error::ProtocolError( - e.to_string() - )) -} - -fn hyper_protocol_error(e: hyper::Error) -> anyhow::Error { - anyhow::anyhow!(crate::bindings::http::types::Error::ProtocolError( - e.to_string() - )) -} - -fn invalid_url(e: std::io::Error) -> anyhow::Error { - // TODO: DNS errors show up as a Custom io error, what subset of errors should we consider for - // InvalidUrl here? - anyhow::anyhow!(crate::bindings::http::types::Error::InvalidUrl( - e.to_string() - )) -} diff --git a/crates/wasi-http/src/types.rs b/crates/wasi-http/src/types.rs index 308b65283cb3..5176fa9add04 100644 --- a/crates/wasi-http/src/types.rs +++ b/crates/wasi-http/src/types.rs @@ -5,13 +5,27 @@ use crate::{ bindings::http::types::{self, Method, Scheme}, body::{HostIncomingBodyBuilder, HyperIncomingBody, HyperOutgoingBody}, }; +use anyhow::Context; +use http_body_util::BodyExt; use std::any::Any; +use std::time::Duration; +use tokio::net::TcpStream; +use tokio::time::timeout; use wasmtime::component::Resource; -use wasmtime_wasi::preview2::{AbortOnDropJoinHandle, Subscribe, Table}; +use wasmtime_wasi::preview2::{self, AbortOnDropJoinHandle, Subscribe, Table}; /// Capture the state necessary for use in the wasi-http API implementation. pub struct WasiHttpCtx; +pub struct OutgoingRequest { + pub use_tls: bool, + pub authority: String, + pub request: hyper::Request, + pub connect_timeout: Duration, + pub first_byte_timeout: Duration, + pub between_bytes_timeout: Duration, +} + pub trait WasiHttpView: Send { fn ctx(&mut self) -> &mut WasiHttpCtx; fn table(&mut self) -> &mut Table; @@ -43,6 +57,144 @@ pub trait WasiHttpView: Send { .push_resource(HostResponseOutparam { result })?; Ok(id) } + + fn send_request( + &mut self, + request: OutgoingRequest, + ) -> wasmtime::Result> + where + Self: Sized, + { + default_send_request(self, request) + } +} + +pub fn default_send_request( + view: &mut dyn WasiHttpView, + OutgoingRequest { + use_tls, + authority, + request, + connect_timeout, + first_byte_timeout, + between_bytes_timeout, + }: OutgoingRequest, +) -> wasmtime::Result> { + let handle = preview2::spawn(async move { + let tcp_stream = TcpStream::connect(authority.clone()) + .await + .map_err(invalid_url)?; + + let (mut sender, worker) = if use_tls { + #[cfg(any(target_arch = "riscv64", target_arch = "s390x"))] + { + anyhow::bail!(crate::bindings::http::types::Error::UnexpectedError( + "unsupported architecture for SSL".to_string(), + )); + } + + #[cfg(not(any(target_arch = "riscv64", target_arch = "s390x")))] + { + use tokio_rustls::rustls::OwnedTrustAnchor; + + // derived from https://github.com/tokio-rs/tls/blob/master/tokio-rustls/examples/client/src/main.rs + let mut root_cert_store = rustls::RootCertStore::empty(); + root_cert_store.add_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.iter().map( + |ta| { + OwnedTrustAnchor::from_subject_spki_name_constraints( + ta.subject, + ta.spki, + ta.name_constraints, + ) + }, + )); + let config = rustls::ClientConfig::builder() + .with_safe_defaults() + .with_root_certificates(root_cert_store) + .with_no_client_auth(); + let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(config)); + let mut parts = authority.split(":"); + let host = parts.next().unwrap_or(&authority); + let domain = rustls::ServerName::try_from(host)?; + let stream = connector.connect(domain, tcp_stream).await.map_err(|e| { + crate::bindings::http::types::Error::ProtocolError(e.to_string()) + })?; + + let (sender, conn) = timeout( + connect_timeout, + hyper::client::conn::http1::handshake(stream), + ) + .await + .map_err(|_| timeout_error("connection"))??; + + let worker = preview2::spawn(async move { + conn.await.context("hyper connection failed")?; + Ok::<_, anyhow::Error>(()) + }); + + (sender, worker) + } + } else { + let (sender, conn) = timeout( + connect_timeout, + // TODO: we should plumb the builder through the http context, and use it here + hyper::client::conn::http1::handshake(tcp_stream), + ) + .await + .map_err(|_| timeout_error("connection"))??; + + let worker = preview2::spawn(async move { + conn.await.context("hyper connection failed")?; + Ok::<_, anyhow::Error>(()) + }); + + (sender, worker) + }; + + let resp = timeout(first_byte_timeout, sender.send_request(request)) + .await + .map_err(|_| timeout_error("first byte"))? + .map_err(hyper_protocol_error)? + .map(|body| body.map_err(|e| anyhow::anyhow!(e)).boxed()); + + Ok(IncomingResponseInternal { + resp, + worker, + between_bytes_timeout, + }) + }); + + let fut = view + .table() + .push_resource(HostFutureIncomingResponse::new(handle))?; + + Ok(fut) +} + +pub fn timeout_error(kind: &str) -> anyhow::Error { + anyhow::anyhow!(crate::bindings::http::types::Error::TimeoutError(format!( + "{kind} timed out" + ))) +} + +pub fn http_protocol_error(e: http::Error) -> anyhow::Error { + anyhow::anyhow!(crate::bindings::http::types::Error::ProtocolError( + e.to_string() + )) +} + +pub fn hyper_protocol_error(e: hyper::Error) -> anyhow::Error { + anyhow::anyhow!(crate::bindings::http::types::Error::ProtocolError( + e.to_string() + )) +} + +fn invalid_url(e: std::io::Error) -> anyhow::Error { + // TODO: DNS errors show up as a Custom io error, what subset of errors should we consider for + // InvalidUrl here? + anyhow::anyhow!(crate::bindings::http::types::Error::InvalidUrl( + e.to_string() + )) } pub struct HostIncomingRequest { @@ -83,7 +235,7 @@ impl TryFrom for hyper::Response { fn try_from( resp: HostOutgoingResponse, ) -> Result, Self::Error> { - use http_body_util::{BodyExt, Empty}; + use http_body_util::Empty; let mut builder = hyper::Response::builder().status(resp.status);