Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TestClient timeout simulates http.disconnect #1446

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 26 additions & 2 deletions starlette/testclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from concurrent.futures import Future
from urllib.parse import unquote, urljoin, urlsplit

import anyio
import anyio.abc
import requests
from anyio.streams.stapled import StapledObjectStream
Expand Down Expand Up @@ -190,10 +191,18 @@ def send(
request_complete = False
response_started = False
response_complete: anyio.Event
timeout_called = False
raw_kwargs: typing.Dict[str, typing.Any] = {"body": io.BytesIO()}
template = None
context = None

def do_timeout() -> None:
nonlocal timeout_called, response_complete, response_started
timeout_called = True
if request_complete:
response_started = True
response_complete.set()

async def receive() -> Message:
nonlocal request_complete

Expand All @@ -215,17 +224,23 @@ async def receive() -> Message:
return {"type": "http.request", "body": chunk, "more_body": True}
except StopIteration:
request_complete = True
if timeout_called:
do_timeout()
return {"type": "http.request", "body": b""}
else:
body_bytes = body

request_complete = True
if timeout_called:
do_timeout()
return {"type": "http.request", "body": body_bytes}

async def send(message: Message) -> None:
nonlocal raw_kwargs, response_started, template, context

if message["type"] == "http.response.start":
if timeout_called:
pass
elif message["type"] == "http.response.start":
assert (
not response_started
), 'Received multiple "http.response.start" messages.'
Expand Down Expand Up @@ -259,15 +274,24 @@ async def send(message: Message) -> None:
template = message["template"]
context = message["context"]

async def timeout_task(delay: float) -> None:
await anyio.sleep(delay)
do_timeout()

timeout: typing.Optional[float] = kwargs.get("timeout")
try:
with self.portal_factory() as portal:
response_complete = portal.call(anyio.Event)
if timeout:
portal.start_task_soon(timeout_task, timeout)
portal.call(self.app, scope, receive, send)
except BaseException as exc:
if self.raise_server_exceptions:
raise exc

if self.raise_server_exceptions:
if timeout_called:
raise requests.exceptions.ReadTimeout()
elif self.raise_server_exceptions:
assert response_started, "TestClient did not receive any response."
elif not response_started:
raw_kwargs = {
Expand Down
77 changes: 77 additions & 0 deletions tests/test_testclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import anyio
import pytest
import requests
import sniffio
import trio.lowlevel

Expand Down Expand Up @@ -229,3 +230,79 @@ async def asgi(receive, send):
with client.websocket_connect("/") as websocket:
data = websocket.receive_json()
assert data == {"message": "test"}


def test_timeout(test_client_factory):
done = False

async def app(scope, receive, send):
nonlocal done
assert (await receive())["type"] == "http.request"
assert (await receive())["type"] == "http.disconnect"
await send(
{
"type": "http.response.start",
"status": 200,
"headers": [[b"content-type", b"text/plain"]],
}
)
await send({"type": "http.response.body", "body": b"Hello, world!"})
done = True

client = test_client_factory(app)
with pytest.raises(requests.ReadTimeout):
client.get("/", timeout=0.001)
assert done


def test_timeout_generator(test_client_factory):
done = False

async def app(scope, receive, send):
nonlocal done
assert (await receive())["type"] == "http.request"
await anyio.sleep(0.01)
assert (await receive())["type"] == "http.request"
assert (await receive())["type"] == "http.disconnect"
await send(
{
"type": "http.response.start",
"status": 200,
"headers": [[b"content-type", b"text/plain"]],
}
)
await send({"type": "http.response.body", "body": b"Hello, world!"})
done = True

client = test_client_factory(app)

def gen():
yield "hello"

with pytest.raises(requests.ReadTimeout):
client.post("/", data=gen(), timeout=0.001)
assert done


def test_timeout_early_done(test_client_factory):
done = False

async def app(scope, receive, send):
nonlocal done
await anyio.sleep(0.01)
assert (await receive())["type"] == "http.request"
assert (await receive())["type"] == "http.disconnect"
await send(
{
"type": "http.response.start",
"status": 200,
"headers": [[b"content-type", b"text/plain"]],
}
)
await send({"type": "http.response.body", "body": b"Hello, world!"})
done = True

client = test_client_factory(app)
with pytest.raises(requests.ReadTimeout):
client.get("/", timeout=0.001)
assert done