diff --git a/Cargo.lock b/Cargo.lock index 7b45f8915c33..ff9a25c755aa 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2668,8 +2668,12 @@ name = "test-programs" version = "0.0.0" dependencies = [ "anyhow", + "base64", + "futures", "getrandom", "libc", + "sha2", + "url", "wasi", "wit-bindgen", ] @@ -3782,6 +3786,7 @@ version = "15.0.0" dependencies = [ "anyhow", "async-trait", + "base64", "bytes", "futures", "http", @@ -3789,6 +3794,7 @@ dependencies = [ "http-body-util", "hyper", "rustls", + "sha2", "test-log", "test-programs-artifacts", "tokio", diff --git a/crates/test-programs/Cargo.toml b/crates/test-programs/Cargo.toml index 422e851d34f8..f195b3d591d1 100644 --- a/crates/test-programs/Cargo.toml +++ b/crates/test-programs/Cargo.toml @@ -12,3 +12,7 @@ wasi = "0.11.0" wit-bindgen = { workspace = true, features = ['default'] } libc = { workspace = true } getrandom = "0.2.9" +futures = { workspace = true, default-features = false, features = ['alloc'] } +url = { workspace = true } +sha2 = "0.10.2" +base64 = "0.21.0" diff --git a/crates/test-programs/src/bin/api_proxy_streaming.rs b/crates/test-programs/src/bin/api_proxy_streaming.rs new file mode 100644 index 000000000000..e7215cb4af9b --- /dev/null +++ b/crates/test-programs/src/bin/api_proxy_streaming.rs @@ -0,0 +1,362 @@ +use anyhow::{bail, Result}; +use bindings::wasi::http::types::{ + Fields, IncomingRequest, Method, OutgoingBody, OutgoingRequest, OutgoingResponse, + ResponseOutparam, Scheme, +}; +use futures::{stream, SinkExt, StreamExt, TryStreamExt}; +use url::Url; + +mod bindings { + use super::Handler; + + wit_bindgen::generate!({ + path: "../wasi-http/wit", + world: "wasi:http/proxy", + exports: { + "wasi:http/incoming-handler": Handler, + }, + }); +} + +const MAX_CONCURRENCY: usize = 16; + +struct Handler; + +impl bindings::exports::wasi::http::incoming_handler::Guest for Handler { + fn handle(request: IncomingRequest, response_out: ResponseOutparam) { + executor::run(async move { + handle_request(request, response_out).await; + }) + } +} + +async fn handle_request(request: IncomingRequest, response_out: ResponseOutparam) { + let headers = request.headers().entries(); + + match (request.method(), request.path_with_query().as_deref()) { + (Method::Get, Some("/hash-all")) => { + let urls = headers.iter().filter_map(|(k, v)| { + (k == "url") + .then_some(v) + .and_then(|v| std::str::from_utf8(v).ok()) + .and_then(|v| Url::parse(v).ok()) + }); + + let results = urls.map(|url| async move { + let result = hash(&url).await; + (url, result) + }); + + let mut results = stream::iter(results).buffer_unordered(MAX_CONCURRENCY); + + let response = OutgoingResponse::new( + 200, + &Fields::new(&[("content-type".to_string(), b"text/plain".to_vec())]), + ); + + let mut body = + executor::outgoing_body(response.write().expect("response should be writable")); + + ResponseOutparam::set(response_out, Ok(response)); + + while let Some((url, result)) = results.next().await { + let payload = match result { + Ok(hash) => format!("{url}: {hash}\n"), + Err(e) => format!("{url}: {e:?}\n"), + } + .into_bytes(); + + if let Err(e) = body.send(payload).await { + eprintln!("Error sending payload: {e}"); + } + } + } + + (Method::Post, Some("/echo")) => { + let response = OutgoingResponse::new( + 200, + &Fields::new( + &headers + .into_iter() + .filter_map(|(k, v)| (k == "content-type").then_some((k, v))) + .collect::>(), + ), + ); + + let mut body = + executor::outgoing_body(response.write().expect("response should be writable")); + + ResponseOutparam::set(response_out, Ok(response)); + + let mut stream = + executor::incoming_body(request.consume().expect("request should be readable")); + + while let Some(chunk) = stream.next().await { + match chunk { + Ok(chunk) => { + if let Err(e) = body.send(chunk).await { + eprintln!("Error sending body: {e}"); + break; + } + } + Err(e) => { + eprintln!("Error receiving body: {e}"); + break; + } + } + } + } + + _ => { + let response = OutgoingResponse::new(405, &Fields::new(&[])); + + let body = response.write().expect("response should be writable"); + + ResponseOutparam::set(response_out, Ok(response)); + + OutgoingBody::finish(body, None); + } + } +} + +async fn hash(url: &Url) -> Result { + let request = OutgoingRequest::new( + &Method::Get, + Some(url.path()), + Some(&match url.scheme() { + "http" => Scheme::Http, + "https" => Scheme::Https, + scheme => Scheme::Other(scheme.into()), + }), + Some(url.authority()), + &Fields::new(&[]), + ); + + let response = executor::outgoing_request_send(request).await?; + + let status = response.status(); + + if !(200..300).contains(&status) { + bail!("unexpected status: {status}"); + } + + let mut body = + executor::incoming_body(response.consume().expect("response should be readable")); + + use sha2::Digest; + let mut hasher = sha2::Sha256::new(); + while let Some(chunk) = body.try_next().await? { + hasher.update(&chunk); + } + + use base64::Engine; + Ok(base64::engine::general_purpose::STANDARD_NO_PAD.encode(hasher.finalize())) +} + +// Technically this should not be here for a proxy, but given the current +// framework for tests it's required since this file is built as a `bin` +fn main() {} + +mod executor { + use super::bindings::wasi::{ + http::{ + outgoing_handler, + types::{ + self, IncomingBody, IncomingResponse, InputStream, OutgoingBody, OutgoingRequest, + OutputStream, + }, + }, + io::{self, streams::StreamError}, + }; + use anyhow::{anyhow, Error, Result}; + use futures::{future, sink, stream, Sink, Stream}; + use std::{ + cell::RefCell, + future::Future, + mem, + rc::Rc, + sync::{Arc, Mutex}, + task::{Context, Poll, Wake, Waker}, + }; + + const READ_SIZE: u64 = 16 * 1024; + + static WAKERS: Mutex> = Mutex::new(Vec::new()); + + pub fn run(future: impl Future) -> T { + futures::pin_mut!(future); + + struct DummyWaker; + + impl Wake for DummyWaker { + fn wake(self: Arc) {} + } + + let waker = Arc::new(DummyWaker).into(); + + loop { + match future.as_mut().poll(&mut Context::from_waker(&waker)) { + Poll::Pending => { + let mut new_wakers = Vec::new(); + + let wakers = mem::take::>(&mut WAKERS.lock().unwrap()); + + assert!(!wakers.is_empty()); + + let pollables = wakers + .iter() + .map(|(pollable, _)| pollable) + .collect::>(); + + let mut ready = vec![false; wakers.len()]; + + for index in io::poll::poll_list(&pollables) { + ready[usize::try_from(index).unwrap()] = true; + } + + for (ready, (pollable, waker)) in ready.into_iter().zip(wakers) { + if ready { + waker.wake() + } else { + new_wakers.push((pollable, waker)); + } + } + + *WAKERS.lock().unwrap() = new_wakers; + } + Poll::Ready(result) => break result, + } + } + } + + pub fn outgoing_body(body: OutgoingBody) -> impl Sink, Error = Error> { + struct Outgoing(Option<(OutputStream, OutgoingBody)>); + + impl Drop for Outgoing { + fn drop(&mut self) { + if let Some((stream, body)) = self.0.take() { + drop(stream); + OutgoingBody::finish(body, None); + } + } + } + + let stream = body.write().expect("response body should be writable"); + let pair = Rc::new(RefCell::new(Outgoing(Some((stream, body))))); + + sink::unfold((), { + move |(), chunk: Vec| { + future::poll_fn({ + let mut offset = 0; + let mut flushing = false; + let pair = pair.clone(); + + move |context| { + let pair = pair.borrow(); + let (stream, _) = &pair.0.as_ref().unwrap(); + + loop { + match stream.check_write() { + Ok(0) => { + WAKERS + .lock() + .unwrap() + .push((stream.subscribe(), context.waker().clone())); + + break Poll::Pending; + } + Ok(count) => { + if offset == chunk.len() { + if flushing { + break Poll::Ready(Ok(())); + } else { + stream.flush().expect("stream should be flushable"); + flushing = true; + } + } else { + let count = usize::try_from(count) + .unwrap() + .min(chunk.len() - offset); + + match stream.write(&chunk[offset..][..count]) { + Ok(()) => { + offset += count; + } + Err(_) => break Poll::Ready(Err(anyhow!("I/O error"))), + } + } + } + Err(_) => break Poll::Ready(Err(anyhow!("I/O error"))), + } + } + } + }) + } + }) + } + + pub fn outgoing_request_send( + request: OutgoingRequest, + ) -> impl Future> { + future::poll_fn({ + let response = outgoing_handler::handle(request, None); + + move |context| match &response { + Ok(response) => { + if let Some(response) = response.get() { + Poll::Ready(response.unwrap()) + } else { + WAKERS + .lock() + .unwrap() + .push((response.subscribe(), context.waker().clone())); + Poll::Pending + } + } + Err(error) => Poll::Ready(Err(error.clone())), + } + }) + } + + pub fn incoming_body(body: IncomingBody) -> impl Stream>> { + struct Incoming(Option<(InputStream, IncomingBody)>); + + impl Drop for Incoming { + fn drop(&mut self) { + if let Some((stream, body)) = self.0.take() { + drop(stream); + IncomingBody::finish(body); + } + } + } + + stream::poll_fn({ + let stream = body.stream().expect("response body should be readable"); + let pair = Incoming(Some((stream, body))); + + move |context| { + if let Some((stream, _)) = &pair.0 { + match stream.read(READ_SIZE) { + Ok(buffer) => { + if buffer.is_empty() { + WAKERS + .lock() + .unwrap() + .push((stream.subscribe(), context.waker().clone())); + Poll::Pending + } else { + Poll::Ready(Some(Ok(buffer))) + } + } + Err(StreamError::Closed) => Poll::Ready(None), + Err(StreamError::LastOperationFailed(error)) => { + Poll::Ready(Some(Err(anyhow!("{}", error.to_debug_string())))) + } + } + } else { + Poll::Ready(None) + } + } + }) + } +} diff --git a/crates/wasi-http/Cargo.toml b/crates/wasi-http/Cargo.toml index 9210d50384e7..39388f2e91a3 100644 --- a/crates/wasi-http/Cargo.toml +++ b/crates/wasi-http/Cargo.toml @@ -39,6 +39,9 @@ test-log = { workspace = true } tracing-subscriber = { workspace = true } wasmtime = { workspace = true, features = ['cranelift'] } tokio = { workspace = true, features = ['macros'] } +futures = { workspace = true, default-features = false, features = ['alloc'] } +sha2 = "0.10.2" +base64 = "0.21.0" [features] default = ["sync"] diff --git a/crates/wasi-http/src/http_impl.rs b/crates/wasi-http/src/http_impl.rs index 287556e60b7e..d9e25a4ce490 100644 --- a/crates/wasi-http/src/http_impl.rs +++ b/crates/wasi-http/src/http_impl.rs @@ -2,18 +2,14 @@ use crate::bindings::http::{ outgoing_handler, types::{RequestOptions, Scheme}, }; -use crate::types::{self, HostFutureIncomingResponse, IncomingResponseInternal}; +use crate::types::{self, HostFutureIncomingResponse, OutgoingRequest}; use crate::WasiHttpView; -use anyhow::Context; use bytes::Bytes; use http_body_util::{BodyExt, Empty}; use hyper::Method; use std::time::Duration; -use tokio::net::TcpStream; -use tokio::time::timeout; use types::HostOutgoingRequest; use wasmtime::component::Resource; -use wasmtime_wasi::preview2; impl outgoing_handler::Host for T { fn handle( @@ -90,120 +86,15 @@ impl outgoing_handler::Host for T { .boxed() }); - let request = builder.body(body).map_err(http_protocol_error)?; + let request = builder.body(body).map_err(types::http_protocol_error)?; - let handle = preview2::spawn(async move { - let tcp_stream = TcpStream::connect(authority.clone()) - .await - .map_err(invalid_url)?; - - let (mut sender, worker) = if use_tls { - #[cfg(any(target_arch = "riscv64", target_arch = "s390x"))] - { - anyhow::bail!(crate::bindings::http::types::Error::UnexpectedError( - "unsupported architecture for SSL".to_string(), - )); - } - - #[cfg(not(any(target_arch = "riscv64", target_arch = "s390x")))] - { - use tokio_rustls::rustls::OwnedTrustAnchor; - - // derived from https://github.com/tokio-rs/tls/blob/master/tokio-rustls/examples/client/src/main.rs - let mut root_cert_store = rustls::RootCertStore::empty(); - root_cert_store.add_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.iter().map( - |ta| { - OwnedTrustAnchor::from_subject_spki_name_constraints( - ta.subject, - ta.spki, - ta.name_constraints, - ) - }, - )); - let config = rustls::ClientConfig::builder() - .with_safe_defaults() - .with_root_certificates(root_cert_store) - .with_no_client_auth(); - let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(config)); - let mut parts = authority.split(":"); - let host = parts.next().unwrap_or(&authority); - let domain = rustls::ServerName::try_from(host)?; - let stream = connector.connect(domain, tcp_stream).await.map_err(|e| { - crate::bindings::http::types::Error::ProtocolError(e.to_string()) - })?; - - let (sender, conn) = timeout( - connect_timeout, - hyper::client::conn::http1::handshake(stream), - ) - .await - .map_err(|_| timeout_error("connection"))??; - - let worker = preview2::spawn(async move { - conn.await.context("hyper connection failed")?; - Ok::<_, anyhow::Error>(()) - }); - - (sender, worker) - } - } else { - let (sender, conn) = timeout( - connect_timeout, - // TODO: we should plumb the builder through the http context, and use it here - hyper::client::conn::http1::handshake(tcp_stream), - ) - .await - .map_err(|_| timeout_error("connection"))??; - - let worker = preview2::spawn(async move { - conn.await.context("hyper connection failed")?; - Ok::<_, anyhow::Error>(()) - }); - - (sender, worker) - }; - - let resp = timeout(first_byte_timeout, sender.send_request(request)) - .await - .map_err(|_| timeout_error("first byte"))? - .map_err(hyper_protocol_error)? - .map(|body| body.map_err(|e| anyhow::anyhow!(e)).boxed()); - - Ok(IncomingResponseInternal { - resp, - worker, - between_bytes_timeout, - }) - }); - - let fut = self.table().push(HostFutureIncomingResponse::new(handle))?; - - Ok(Ok(fut)) + Ok(Ok(self.send_request(OutgoingRequest { + use_tls, + authority, + request, + connect_timeout, + first_byte_timeout, + between_bytes_timeout, + })?)) } } - -fn timeout_error(kind: &str) -> anyhow::Error { - anyhow::anyhow!(crate::bindings::http::types::Error::TimeoutError(format!( - "{kind} timed out" - ))) -} - -fn http_protocol_error(e: http::Error) -> anyhow::Error { - anyhow::anyhow!(crate::bindings::http::types::Error::ProtocolError( - e.to_string() - )) -} - -fn hyper_protocol_error(e: hyper::Error) -> anyhow::Error { - anyhow::anyhow!(crate::bindings::http::types::Error::ProtocolError( - e.to_string() - )) -} - -fn invalid_url(e: std::io::Error) -> anyhow::Error { - // TODO: DNS errors show up as a Custom io error, what subset of errors should we consider for - // InvalidUrl here? - anyhow::anyhow!(crate::bindings::http::types::Error::InvalidUrl( - e.to_string() - )) -} diff --git a/crates/wasi-http/src/types.rs b/crates/wasi-http/src/types.rs index c73249ac9ee4..76530811ccec 100644 --- a/crates/wasi-http/src/types.rs +++ b/crates/wasi-http/src/types.rs @@ -5,13 +5,27 @@ use crate::{ bindings::http::types::{self, Method, Scheme}, body::{HostIncomingBodyBuilder, HyperIncomingBody, HyperOutgoingBody}, }; +use anyhow::Context; +use http_body_util::BodyExt; use std::any::Any; +use std::time::Duration; +use tokio::net::TcpStream; +use tokio::time::timeout; use wasmtime::component::Resource; -use wasmtime_wasi::preview2::{AbortOnDropJoinHandle, Subscribe, Table}; +use wasmtime_wasi::preview2::{self, AbortOnDropJoinHandle, Subscribe, Table}; /// Capture the state necessary for use in the wasi-http API implementation. pub struct WasiHttpCtx; +pub struct OutgoingRequest { + pub use_tls: bool, + pub authority: String, + pub request: hyper::Request, + pub connect_timeout: Duration, + pub first_byte_timeout: Duration, + pub between_bytes_timeout: Duration, +} + pub trait WasiHttpView: Send { fn ctx(&mut self) -> &mut WasiHttpCtx; fn table(&mut self) -> &mut Table; @@ -41,6 +55,142 @@ pub trait WasiHttpView: Send { let id = self.table().push(HostResponseOutparam { result })?; Ok(id) } + + fn send_request( + &mut self, + request: OutgoingRequest, + ) -> wasmtime::Result> + where + Self: Sized, + { + default_send_request(self, request) + } +} + +pub fn default_send_request( + view: &mut dyn WasiHttpView, + OutgoingRequest { + use_tls, + authority, + request, + connect_timeout, + first_byte_timeout, + between_bytes_timeout, + }: OutgoingRequest, +) -> wasmtime::Result> { + let handle = preview2::spawn(async move { + let tcp_stream = TcpStream::connect(authority.clone()) + .await + .map_err(invalid_url)?; + + let (mut sender, worker) = if use_tls { + #[cfg(any(target_arch = "riscv64", target_arch = "s390x"))] + { + anyhow::bail!(crate::bindings::http::types::Error::UnexpectedError( + "unsupported architecture for SSL".to_string(), + )); + } + + #[cfg(not(any(target_arch = "riscv64", target_arch = "s390x")))] + { + use tokio_rustls::rustls::OwnedTrustAnchor; + + // derived from https://github.com/tokio-rs/tls/blob/master/tokio-rustls/examples/client/src/main.rs + let mut root_cert_store = rustls::RootCertStore::empty(); + root_cert_store.add_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.iter().map( + |ta| { + OwnedTrustAnchor::from_subject_spki_name_constraints( + ta.subject, + ta.spki, + ta.name_constraints, + ) + }, + )); + let config = rustls::ClientConfig::builder() + .with_safe_defaults() + .with_root_certificates(root_cert_store) + .with_no_client_auth(); + let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(config)); + let mut parts = authority.split(":"); + let host = parts.next().unwrap_or(&authority); + let domain = rustls::ServerName::try_from(host)?; + let stream = connector.connect(domain, tcp_stream).await.map_err(|e| { + crate::bindings::http::types::Error::ProtocolError(e.to_string()) + })?; + + let (sender, conn) = timeout( + connect_timeout, + hyper::client::conn::http1::handshake(stream), + ) + .await + .map_err(|_| timeout_error("connection"))??; + + let worker = preview2::spawn(async move { + conn.await.context("hyper connection failed")?; + Ok::<_, anyhow::Error>(()) + }); + + (sender, worker) + } + } else { + let (sender, conn) = timeout( + connect_timeout, + // TODO: we should plumb the builder through the http context, and use it here + hyper::client::conn::http1::handshake(tcp_stream), + ) + .await + .map_err(|_| timeout_error("connection"))??; + + let worker = preview2::spawn(async move { + conn.await.context("hyper connection failed")?; + Ok::<_, anyhow::Error>(()) + }); + + (sender, worker) + }; + + let resp = timeout(first_byte_timeout, sender.send_request(request)) + .await + .map_err(|_| timeout_error("first byte"))? + .map_err(hyper_protocol_error)? + .map(|body| body.map_err(|e| anyhow::anyhow!(e)).boxed()); + + Ok(IncomingResponseInternal { + resp, + worker, + between_bytes_timeout, + }) + }); + + let fut = view.table().push(HostFutureIncomingResponse::new(handle))?; + + Ok(fut) +} + +pub fn timeout_error(kind: &str) -> anyhow::Error { + anyhow::anyhow!(crate::bindings::http::types::Error::TimeoutError(format!( + "{kind} timed out" + ))) +} + +pub fn http_protocol_error(e: http::Error) -> anyhow::Error { + anyhow::anyhow!(crate::bindings::http::types::Error::ProtocolError( + e.to_string() + )) +} + +pub fn hyper_protocol_error(e: hyper::Error) -> anyhow::Error { + anyhow::anyhow!(crate::bindings::http::types::Error::ProtocolError( + e.to_string() + )) +} + +fn invalid_url(e: std::io::Error) -> anyhow::Error { + // TODO: DNS errors show up as a Custom io error, what subset of errors should we consider for + // InvalidUrl here? + anyhow::anyhow!(crate::bindings::http::types::Error::InvalidUrl( + e.to_string() + )) } pub struct HostIncomingRequest { @@ -81,7 +231,7 @@ impl TryFrom for hyper::Response { fn try_from( resp: HostOutgoingResponse, ) -> Result, Self::Error> { - use http_body_util::{BodyExt, Empty}; + use http_body_util::Empty; let mut builder = hyper::Response::builder().status(resp.status); diff --git a/crates/wasi-http/tests/all/main.rs b/crates/wasi-http/tests/all/main.rs index 9df4e557c4c0..f9c68da0a9b7 100644 --- a/crates/wasi-http/tests/all/main.rs +++ b/crates/wasi-http/tests/all/main.rs @@ -1,20 +1,41 @@ use crate::http_server::Server; -use anyhow::{Context, Result}; +use anyhow::{anyhow, Context, Result}; +use futures::{channel::oneshot, future, stream, FutureExt}; +use http_body::Frame; +use http_body_util::{combinators::BoxBody, Collected, StreamBody}; +use hyper::{body::Bytes, server::conn::http1, service::service_fn, Method, StatusCode}; +use sha2::{Digest, Sha256}; +use std::{collections::HashMap, iter, net::Ipv4Addr, str, sync::Arc}; +use tokio::task; use wasmtime::{ - component::{Component, Linker}, + component::{Component, Linker, Resource}, Config, Engine, Store, }; -use wasmtime_wasi::preview2::{pipe::MemoryOutputPipe, Table, WasiCtx, WasiCtxBuilder, WasiView}; -use wasmtime_wasi_http::{WasiHttpCtx, WasiHttpView}; +use wasmtime_wasi::preview2::{ + self, pipe::MemoryOutputPipe, Table, WasiCtx, WasiCtxBuilder, WasiView, +}; +use wasmtime_wasi_http::{ + bindings::http::types::Error, + body::HyperIncomingBody, + types::{self, HostFutureIncomingResponse, IncomingResponseInternal, OutgoingRequest}, + WasiHttpCtx, WasiHttpView, +}; mod http_server; +type RequestSender = Arc< + dyn Fn(&mut Ctx, OutgoingRequest) -> wasmtime::Result> + + Send + + Sync, +>; + struct Ctx { table: Table, wasi: WasiCtx, http: WasiHttpCtx, stdout: MemoryOutputPipe, stderr: MemoryOutputPipe, + send_request: Option, } impl WasiView for Ctx { @@ -40,6 +61,17 @@ impl WasiHttpView for Ctx { fn table(&mut self) -> &mut Table { &mut self.table } + + fn send_request( + &mut self, + request: OutgoingRequest, + ) -> wasmtime::Result> { + if let Some(send_request) = self.send_request.clone() { + send_request(self, request) + } else { + types::default_send_request(self, request) + } + } } fn store(engine: &Engine, server: &Server) -> Store { @@ -57,6 +89,7 @@ fn store(engine: &Engine, server: &Server) -> Store { http: WasiHttpCtx {}, stderr, stdout, + send_request: None, }; Store::new(&engine, ctx) @@ -87,8 +120,11 @@ macro_rules! assert_test_exists { mod async_; mod sync; -#[test_log::test(tokio::test)] -async fn wasi_http_proxy_tests() -> anyhow::Result<()> { +async fn run_wasi_http( + component_filename: &str, + req: hyper::Request, + send_request: Option, +) -> anyhow::Result>, Error>> { let stdout = MemoryOutputPipe::new(4096); let stderr = MemoryOutputPipe::new(4096); let table = Table::new(); @@ -98,7 +134,7 @@ async fn wasi_http_proxy_tests() -> anyhow::Result<()> { config.wasm_component_model(true); config.async_support(true); let engine = Engine::new(&config)?; - let component = Component::from_file(&engine, test_programs_artifacts::API_PROXY_COMPONENT)?; + let component = Component::from_file(&engine, component_filename)?; // Create our wasi context. let mut builder = WasiCtxBuilder::new(); @@ -112,6 +148,7 @@ async fn wasi_http_proxy_tests() -> anyhow::Result<()> { http, stderr, stdout, + send_request, }; let mut store = Store::new(&engine, ctx); @@ -121,21 +158,12 @@ async fn wasi_http_proxy_tests() -> anyhow::Result<()> { wasmtime_wasi_http::proxy::Proxy::instantiate_async(&mut store, &component, &linker) .await?; - let req = { - use http_body_util::{BodyExt, Empty}; - - let req = hyper::Request::builder().method(http::Method::GET).body( - Empty::::new() - .map_err(|e| anyhow::anyhow!(e)) - .boxed(), - )?; - store.data_mut().new_incoming_request(req)? - }; + let req = store.data_mut().new_incoming_request(req)?; let (sender, receiver) = tokio::sync::oneshot::channel(); let out = store.data_mut().new_response_outparam(sender)?; - let handle = wasmtime_wasi::preview2::spawn(async move { + let handle = preview2::spawn(async move { proxy .wasi_http_incoming_handler() .call_handle(&mut store, req, out) @@ -162,6 +190,17 @@ async fn wasi_http_proxy_tests() -> anyhow::Result<()> { // deadlocking. handle.await.context("Component execution")?; + Ok(resp) +} + +#[test_log::test(tokio::test)] +async fn wasi_http_proxy_tests() -> anyhow::Result<()> { + let req = hyper::Request::builder() + .method(http::Method::GET) + .body(body::empty())?; + + let resp = run_wasi_http(test_programs_artifacts::API_PROXY_COMPONENT, req, None).await?; + match resp { Ok(resp) => println!("response: {resp:?}"), Err(e) => panic!("Error given in response: {e:?}"), @@ -169,3 +208,207 @@ async fn wasi_http_proxy_tests() -> anyhow::Result<()> { Ok(()) } + +#[test_log::test(tokio::test)] +async fn wasi_http_hash_all() -> Result<()> { + do_wasi_http_hash_all(false).await +} + +#[test_log::test(tokio::test)] +async fn wasi_http_hash_all_with_override() -> Result<()> { + do_wasi_http_hash_all(true).await +} + +async fn do_wasi_http_hash_all(override_send_request: bool) -> Result<()> { + let bodies = Arc::new( + [ + ("/a", "’Twas brillig, and the slithy toves"), + ("/b", "Did gyre and gimble in the wabe:"), + ("/c", "All mimsy were the borogoves,"), + ("/d", "And the mome raths outgrabe."), + ] + .into_iter() + .collect::>(), + ); + + let listener = tokio::net::TcpListener::bind((Ipv4Addr::new(127, 0, 0, 1), 0)).await?; + + let prefix = format!("http://{}", listener.local_addr()?); + + let (_tx, rx) = oneshot::channel::<()>(); + + let handle = { + let bodies = bodies.clone(); + + move |request: http::request::Parts| { + if let (Method::GET, Some(body)) = (request.method, bodies.get(request.uri.path())) { + Ok::<_, anyhow::Error>(hyper::Response::new(body::full(Bytes::copy_from_slice( + body.as_bytes(), + )))) + } else { + Ok(hyper::Response::builder() + .status(StatusCode::METHOD_NOT_ALLOWED) + .body(body::empty())?) + } + } + }; + + let send_request = if override_send_request { + Some(Arc::new( + move |view: &mut Ctx, + OutgoingRequest { + request, + between_bytes_timeout, + .. + }| { + Ok(view.table().push(HostFutureIncomingResponse::Ready( + handle(request.into_parts().0).map(|resp| IncomingResponseInternal { + resp, + worker: preview2::spawn(future::ready(Ok(()))), + between_bytes_timeout, + }), + ))?) + }, + ) as RequestSender) + } else { + let server = async move { + loop { + let (stream, _) = listener.accept().await?; + let handle = handle.clone(); + task::spawn(async move { + if let Err(e) = http1::Builder::new() + .keep_alive(true) + .serve_connection( + stream, + service_fn(move |request| { + let handle = handle.clone(); + async move { handle(request.into_parts().0) } + }), + ) + .await + { + eprintln!("error serving connection: {e:?}"); + } + }); + + // Help rustc with type inference: + if false { + return Ok::<_, anyhow::Error>(()); + } + } + } + .then(|result| { + if let Err(e) = result { + eprintln!("error listening for connections: {e:?}"); + } + future::ready(()) + }) + .boxed(); + + task::spawn(async move { + drop(future::select(server, rx).await); + }); + + None + }; + + let mut request = hyper::Request::get("/hash-all"); + for path in bodies.keys() { + request = request.header("url", format!("{prefix}{path}")); + } + let request = request.body(body::empty())?; + + let response = run_wasi_http( + test_programs_artifacts::API_PROXY_STREAMING_COMPONENT, + request, + send_request, + ) + .await??; + + assert_eq!(StatusCode::OK, response.status()); + let body = response.into_body().to_bytes(); + let body = str::from_utf8(&body)?; + for line in body.lines() { + let (url, hash) = line + .split_once(": ") + .ok_or_else(|| anyhow!("expected string of form `: `; got {line}"))?; + + let path = url + .strip_prefix(&prefix) + .ok_or_else(|| anyhow!("expected string with prefix {prefix}; got {url}"))?; + + let mut hasher = Sha256::new(); + hasher.update( + bodies + .get(path) + .ok_or_else(|| anyhow!("unexpected path: {path}"))?, + ); + + use base64::Engine; + assert_eq!( + hash, + base64::engine::general_purpose::STANDARD_NO_PAD.encode(hasher.finalize()) + ); + } + + Ok(()) +} + +#[test_log::test(tokio::test)] +async fn wasi_http_echo() -> Result<()> { + let body = { + // A sorta-random-ish megabyte + let mut n = 0_u8; + iter::repeat_with(move || { + n = n.wrapping_add(251); + n + }) + .take(1024 * 1024) + .collect::>() + }; + + let request = hyper::Request::post("/echo") + .header("content-type", "application/octet-stream") + .body(BoxBody::new(StreamBody::new(stream::iter( + body.chunks(16 * 1024) + .map(|chunk| Ok::<_, anyhow::Error>(Frame::data(Bytes::copy_from_slice(chunk)))) + .collect::>(), + ))))?; + + let response = run_wasi_http( + test_programs_artifacts::API_PROXY_STREAMING_COMPONENT, + request, + None, + ) + .await??; + + assert_eq!(StatusCode::OK, response.status()); + assert_eq!( + response.headers()["content-type"], + "application/octet-stream" + ); + let received = Vec::from(response.into_body().to_bytes()); + if body != received { + panic!( + "body content mismatch (expected length {}; actual length {})", + body.len(), + received.len() + ); + } + + Ok(()) +} + +mod body { + use http_body_util::{combinators::BoxBody, BodyExt, Empty, Full}; + use hyper::body::Bytes; + use wasmtime_wasi_http::body::HyperIncomingBody; + + pub fn full(bytes: Bytes) -> HyperIncomingBody { + BoxBody::new(Full::new(bytes).map_err(|_| unreachable!())) + } + + pub fn empty() -> HyperIncomingBody { + BoxBody::new(Empty::new().map_err(|_| unreachable!())) + } +}