Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add test to detect regression in transfer-encoding header #516

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions worker-sandbox/src/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ use futures_util::TryStreamExt;
use std::time::Duration;
use worker::Env;
use worker::{console_log, Date, Delay, Request, Response, ResponseBody, Result};

#[cfg(not(feature = "http"))]
pub fn handle_a_request(req: Request, _env: Env, _data: SomeSharedData) -> Result<Response> {
Response::ok(format!(
"req at: {}, located at: {:?}, within: {}",
Expand All @@ -17,6 +19,11 @@ pub fn handle_a_request(req: Request, _env: Env, _data: SomeSharedData) -> Resul
))
}

#[cfg(feature = "http")]
pub async fn handle_a_request() -> &'static str {
"Hello World"
}

pub async fn handle_async_request(
req: Request,
_env: Env,
Expand Down Expand Up @@ -200,6 +207,24 @@ pub async fn handle_cloned_stream(
Response::ok((left == right).to_string())
}

#[worker::send]
pub async fn handle_stream_response(
_req: Request,
_env: Env,
_data: SomeSharedData,
) -> Result<Response> {
let stream =
futures_util::stream::repeat(())
.take(10)
.enumerate()
.then(|(index, _)| async move {
Delay::from(Duration::from_millis(100)).await;
Result::Ok(index.to_string().into_bytes())
});
let resp = Response::from_stream(stream)?;
Ok(resp)
}

pub async fn handle_custom_response_body(
_req: Request,
_env: Env,
Expand Down
4 changes: 3 additions & 1 deletion worker-sandbox/src/router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ macro_rules! handler_sync (
#[cfg(feature = "http")]
pub fn make_router(data: SomeSharedData, env: Env) -> axum::Router {
axum::Router::new()
.route("/request", get(handler_sync!(request::handle_a_request)))
.route("/request", get(request::handle_a_request))
.route(
"/async-request",
get(handler!(request::handle_async_request)),
Expand All @@ -70,6 +70,7 @@ pub fn make_router(data: SomeSharedData, env: Env) -> axum::Router {
.route("/test-data", get(handler!(request::handle_test_data)))
.route("/xor/:num", post(handler!(request::handle_xor)))
.route("/headers", post(handler!(request::handle_headers)))
.route("/stream", get(handler!(request::handle_stream_response)))
.route("/formdata-name", post(handler!(form::handle_formdata_name)))
.route("/is-secret", post(handler!(form::handle_is_secret)))
.route(
Expand Down Expand Up @@ -215,6 +216,7 @@ pub fn make_router<'a>(data: SomeSharedData) -> Router<'a, SomeSharedData> {
.get_async("/test-data", handler!(request::handle_test_data))
.post_async("/xor/:num", handler!(request::handle_xor))
.post_async("/headers", handler!(request::handle_headers))
.get_async("/stream", handler!(request::handle_stream_response))
.post_async("/formdata-name", handler!(form::handle_formdata_name))
.post_async("/is-secret", handler!(form::handle_is_secret))
.post_async(
Expand Down
7 changes: 7 additions & 0 deletions worker-sandbox/tests/request.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import { mf } from "./mf";

test("basic sync request", async () => {
const resp = await mf.dispatchFetch("https://fake.host/request");
expect(resp.headers.get("Transfer-Encoding")).not.toBe("chunked");
expect(resp.status).toBe(200);
});

Expand All @@ -29,6 +30,12 @@ test("headers", async () => {
expect(resp.headers.get("A")).toBe("B");
});

test("stream response", async () => {
const resp = await mf.dispatchFetch("https://fake.host/stream");
expect(resp.headers.get("Transfer-Encoding")).toBe("chunked");
expect(resp.status).toBe(200);
});

test("secret", async () => {
const formData = new FormData();
formData.append("secret", "EXAMPLE_SECRET");
Expand Down
72 changes: 57 additions & 15 deletions worker/src/http/response.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,31 @@ use crate::HttpResponse;
use crate::Result;
use crate::WebSocket;
use bytes::Bytes;
use http_body::Body as HttpBody;

use crate::http::body::BodyStream;
use js_sys::Uint8Array;
use worker_sys::ext::ResponseExt;
use worker_sys::ext::ResponseInitExt;

use crate::Error;
use std::sync::Arc;
use std::task::Context;
use std::task::Poll;
use std::task::Wake;

struct NoopWaker;

impl Wake for NoopWaker {
// Required method
fn wake(self: Arc<Self>) {}
}

/// **Requires** `http` feature. Convert generic [`http::Response<B>`](crate::HttpResponse)
/// to [`web_sys::Resopnse`](web_sys::Response) where `B` can be any [`http_body::Body`](http_body::Body)
pub fn to_wasm<B>(mut res: http::Response<B>) -> Result<web_sys::Response>
where
B: http_body::Body<Data = Bytes> + 'static,
B: http_body::Body<Data = Bytes> + Unpin + 'static,
{
let mut init = web_sys::ResponseInit::new();
init.status(res.status().as_u16());
Expand All @@ -23,22 +38,49 @@ where
init.websocket(ws.as_ref());
}

let body = res.into_body();
// I'm not sure how we are supposed to determine if there is no
// body for an `http::Response`, seems like this may be the only
// option given the trait? This appears to work for things like
// `hyper::Empty`.
let readable_stream = if body.is_end_stream() {
None
let mut body = res.into_body();

if let Some(body_size) = body.size_hint().upper() {
let waker = Arc::new(NoopWaker).into();
let mut cx = Context::from_waker(&waker);
let poll = HttpBody::poll_frame(std::pin::Pin::new(&mut body), &mut cx);
match poll {
Poll::Ready(Some(Ok(frame))) => {
// Fixed size body
let array = Uint8Array::new_with_length(body_size as u32);
array.copy_from(frame.data_ref().unwrap());
Ok(web_sys::Response::new_with_opt_buffer_source_and_init(
Some(&array),
&init,
)?)
}
Poll::Pending => Err(Error::RustError(
"Unable to poll fixed-length body: Pending".to_owned(),
)),
Poll::Ready(None) => Ok(web_sys::Response::new_with_opt_buffer_source_and_init(
None, &init,
)?),
Poll::Ready(Some(Err(_))) => Err(Error::RustError(
"Unable to poll fixed-length body: Err".to_owned(),
)),
}
} else {
let stream = BodyStream::new(body);
Some(wasm_streams::ReadableStream::from_stream(stream).into_raw())
};
// I'm not sure how we are supposed to determine if there is no
// body for an `http::Response`, seems like this may be the only
// option given the trait? This appears to work for things like
// `hyper::Empty`.
let readable_stream = if body.is_end_stream() {
None
} else {
let stream = BodyStream::new(body);
Some(wasm_streams::ReadableStream::from_stream(stream).into_raw())
};

Ok(web_sys::Response::new_with_opt_readable_stream_and_init(
readable_stream.as_ref(),
&init,
)?)
Ok(web_sys::Response::new_with_opt_readable_stream_and_init(
readable_stream.as_ref(),
&init,
)?)
}
}

/// **Requires** `http` feature. Convert [`web_sys::Response`](web_sys::Response)
Expand Down
2 changes: 1 addition & 1 deletion worker/src/response.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ pub struct Response {
}

#[cfg(feature = "http")]
impl<B: http_body::Body<Data = Bytes> + 'static> TryFrom<http::Response<B>> for Response {
impl<B: http_body::Body<Data = Bytes> + Unpin + 'static> TryFrom<http::Response<B>> for Response {
type Error = crate::Error;
fn try_from(res: http::Response<B>) -> Result<Self> {
let resp = crate::http::response::to_wasm(res)?;
Expand Down
Loading