diff --git a/setup.cfg b/setup.cfg index 9c3b85b2c..91aac9730 100644 --- a/setup.cfg +++ b/setup.cfg @@ -44,6 +44,12 @@ known_third_party = click,does_not_exist,gunicorn,h11,httptools,pytest,requests, [tool:pytest] addopts = -rxXs + --strict-config + --strict-markers +xfail_strict=True +filterwarnings= + # Turn warnings that aren't filtered into exceptions + error [coverage:run] omit = venv/* diff --git a/tests/protocols/test_http.py b/tests/protocols/test_http.py index 25110fc66..577f70edb 100644 --- a/tests/protocols/test_http.py +++ b/tests/protocols/test_http.py @@ -1,4 +1,4 @@ -import asyncio +import contextlib import logging import pytest @@ -124,9 +124,10 @@ def set_protocol(self, protocol): class MockLoop: - def __init__(self): + def __init__(self, event_loop): self.tasks = [] self.later = [] + self.loop = event_loop def create_task(self, coroutine): self.tasks.insert(0, coroutine) @@ -137,7 +138,10 @@ def call_later(self, delay, callback, *args): def run_one(self): coroutine = self.tasks.pop() - asyncio.get_event_loop().run_until_complete(coroutine) + self.loop.run_until_complete(coroutine) + + def close(self): + self.loop.close() def run_later(self, with_delay): later = [] @@ -154,31 +158,32 @@ def add_done_callback(self, callback): pass -def get_connected_protocol(app, protocol_cls, **kwargs): - loop = MockLoop() +@contextlib.contextmanager +def get_connected_protocol(app, protocol_cls, event_loop, **kwargs): + loop = MockLoop(event_loop) transport = MockTransport() config = Config(app=app, **kwargs) server_state = ServerState() protocol = protocol_cls(config=config, server_state=server_state, _loop=loop) protocol.connection_made(transport) - return protocol + yield protocol + protocol.loop.close() @pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS) -def test_get_request(protocol_cls): +def test_get_request(protocol_cls, event_loop): app = Response("Hello, world", media_type="text/plain") - protocol = get_connected_protocol(app, protocol_cls) - protocol.data_received(SIMPLE_GET_REQUEST) - protocol.loop.run_one() - - assert b"HTTP/1.1 200 OK" in protocol.transport.buffer - assert b"Hello, world" in protocol.transport.buffer + with get_connected_protocol(app, protocol_cls, event_loop) as protocol: + protocol.data_received(SIMPLE_GET_REQUEST) + protocol.loop.run_one() + assert b"HTTP/1.1 200 OK" in protocol.transport.buffer + assert b"Hello, world" in protocol.transport.buffer @pytest.mark.parametrize("path", ["/", "/?foo", "/?foo=bar", "/?foo=bar&baz=1"]) @pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS) -def test_request_logging(path, protocol_cls, caplog): +def test_request_logging(path, protocol_cls, caplog, event_loop): get_request_with_query_string = b"\r\n".join( ["GET {} HTTP/1.1".format(path).encode("ascii"), b"Host: example.org", b"", b""] ) @@ -187,27 +192,27 @@ def test_request_logging(path, protocol_cls, caplog): app = Response("Hello, world", media_type="text/plain") - protocol = get_connected_protocol(app, protocol_cls, log_config=None) - protocol.data_received(get_request_with_query_string) - protocol.loop.run_one() - - assert '"GET {} HTTP/1.1" 200'.format(path) in caplog.records[0].message + with get_connected_protocol( + app, protocol_cls, event_loop, log_config=None + ) as protocol: + protocol.data_received(get_request_with_query_string) + protocol.loop.run_one() + assert '"GET {} HTTP/1.1" 200'.format(path) in caplog.records[0].message @pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS) -def test_head_request(protocol_cls): +def test_head_request(protocol_cls, event_loop): app = Response("Hello, world", media_type="text/plain") - protocol = get_connected_protocol(app, protocol_cls) - protocol.data_received(SIMPLE_HEAD_REQUEST) - protocol.loop.run_one() - - assert b"HTTP/1.1 200 OK" in protocol.transport.buffer - assert b"Hello, world" not in protocol.transport.buffer + with get_connected_protocol(app, protocol_cls, event_loop) as protocol: + protocol.data_received(SIMPLE_HEAD_REQUEST) + protocol.loop.run_one() + assert b"HTTP/1.1 200 OK" in protocol.transport.buffer + assert b"Hello, world" not in protocol.transport.buffer @pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS) -def test_post_request(protocol_cls): +def test_post_request(protocol_cls, event_loop): async def app(scope, receive, send): body = b"" more_body = True @@ -218,267 +223,259 @@ async def app(scope, receive, send): response = Response(b"Body: " + body, media_type="text/plain") await response(scope, receive, send) - protocol = get_connected_protocol(app, protocol_cls) - protocol.data_received(SIMPLE_POST_REQUEST) - protocol.loop.run_one() - - assert b"HTTP/1.1 200 OK" in protocol.transport.buffer - assert b'Body: {"hello": "world"}' in protocol.transport.buffer + with get_connected_protocol(app, protocol_cls, event_loop) as protocol: + protocol.data_received(SIMPLE_POST_REQUEST) + protocol.loop.run_one() + assert b"HTTP/1.1 200 OK" in protocol.transport.buffer + assert b'Body: {"hello": "world"}' in protocol.transport.buffer @pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS) -def test_keepalive(protocol_cls): +def test_keepalive(protocol_cls, event_loop): app = Response(b"", status_code=204) - protocol = get_connected_protocol(app, protocol_cls) - protocol.data_received(SIMPLE_GET_REQUEST) - protocol.loop.run_one() + with get_connected_protocol(app, protocol_cls, event_loop) as protocol: + protocol.data_received(SIMPLE_GET_REQUEST) + protocol.loop.run_one() - assert b"HTTP/1.1 204 No Content" in protocol.transport.buffer - assert not protocol.transport.is_closing() + assert b"HTTP/1.1 204 No Content" in protocol.transport.buffer + assert not protocol.transport.is_closing() @pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS) -def test_keepalive_timeout(protocol_cls): +def test_keepalive_timeout(protocol_cls, event_loop): app = Response(b"", status_code=204) - protocol = get_connected_protocol(app, protocol_cls) - protocol.data_received(SIMPLE_GET_REQUEST) - protocol.loop.run_one() - assert b"HTTP/1.1 204 No Content" in protocol.transport.buffer - assert not protocol.transport.is_closing() - - protocol.loop.run_later(with_delay=1) - assert not protocol.transport.is_closing() - - protocol.loop.run_later(with_delay=10) - assert protocol.transport.is_closing() + with get_connected_protocol(app, protocol_cls, event_loop) as protocol: + protocol.data_received(SIMPLE_GET_REQUEST) + protocol.loop.run_one() + assert b"HTTP/1.1 204 No Content" in protocol.transport.buffer + assert not protocol.transport.is_closing() + protocol.loop.run_later(with_delay=1) + assert not protocol.transport.is_closing() + protocol.loop.run_later(with_delay=5) + assert protocol.transport.is_closing() @pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS) -def test_close(protocol_cls): +def test_close(protocol_cls, event_loop): app = Response(b"", status_code=204, headers={"connection": "close"}) - protocol = get_connected_protocol(app, protocol_cls) - protocol.data_received(SIMPLE_GET_REQUEST) - protocol.loop.run_one() - assert b"HTTP/1.1 204 No Content" in protocol.transport.buffer - assert protocol.transport.is_closing() + with get_connected_protocol(app, protocol_cls, event_loop) as protocol: + protocol.data_received(SIMPLE_GET_REQUEST) + protocol.loop.run_one() + assert b"HTTP/1.1 204 No Content" in protocol.transport.buffer + assert protocol.transport.is_closing() @pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS) -def test_chunked_encoding(protocol_cls): +def test_chunked_encoding(protocol_cls, event_loop): app = Response( b"Hello, world!", status_code=200, headers={"transfer-encoding": "chunked"} ) - protocol = get_connected_protocol(app, protocol_cls) - protocol.data_received(SIMPLE_GET_REQUEST) - protocol.loop.run_one() - assert b"HTTP/1.1 200 OK" in protocol.transport.buffer - assert b"0\r\n\r\n" in protocol.transport.buffer - assert not protocol.transport.is_closing() + with get_connected_protocol(app, protocol_cls, event_loop) as protocol: + protocol.data_received(SIMPLE_GET_REQUEST) + protocol.loop.run_one() + assert b"HTTP/1.1 200 OK" in protocol.transport.buffer + assert b"0\r\n\r\n" in protocol.transport.buffer + assert not protocol.transport.is_closing() @pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS) -def test_chunked_encoding_empty_body(protocol_cls): +def test_chunked_encoding_empty_body(protocol_cls, event_loop): app = Response( b"Hello, world!", status_code=200, headers={"transfer-encoding": "chunked"} ) - protocol = get_connected_protocol(app, protocol_cls) - protocol.data_received(SIMPLE_GET_REQUEST) - protocol.loop.run_one() - assert b"HTTP/1.1 200 OK" in protocol.transport.buffer - assert protocol.transport.buffer.count(b"0\r\n\r\n") == 1 - assert not protocol.transport.is_closing() + with get_connected_protocol(app, protocol_cls, event_loop) as protocol: + protocol.data_received(SIMPLE_GET_REQUEST) + protocol.loop.run_one() + assert b"HTTP/1.1 200 OK" in protocol.transport.buffer + assert protocol.transport.buffer.count(b"0\r\n\r\n") == 1 + assert not protocol.transport.is_closing() @pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS) -def test_chunked_encoding_head_request(protocol_cls): +def test_chunked_encoding_head_request(protocol_cls, event_loop): app = Response( b"Hello, world!", status_code=200, headers={"transfer-encoding": "chunked"} ) - protocol = get_connected_protocol(app, protocol_cls) - protocol.data_received(SIMPLE_HEAD_REQUEST) - protocol.loop.run_one() - assert b"HTTP/1.1 200 OK" in protocol.transport.buffer - assert not protocol.transport.is_closing() + with get_connected_protocol(app, protocol_cls, event_loop) as protocol: + protocol.data_received(SIMPLE_HEAD_REQUEST) + protocol.loop.run_one() + assert b"HTTP/1.1 200 OK" in protocol.transport.buffer + assert not protocol.transport.is_closing() @pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS) -def test_pipelined_requests(protocol_cls): +def test_pipelined_requests(protocol_cls, event_loop): app = Response("Hello, world", media_type="text/plain") - protocol = get_connected_protocol(app, protocol_cls) - protocol.data_received(SIMPLE_GET_REQUEST) - protocol.data_received(SIMPLE_GET_REQUEST) - protocol.data_received(SIMPLE_GET_REQUEST) - - protocol.loop.run_one() - assert b"HTTP/1.1 200 OK" in protocol.transport.buffer - assert b"Hello, world" in protocol.transport.buffer - protocol.transport.clear_buffer() - - protocol.loop.run_one() - assert b"HTTP/1.1 200 OK" in protocol.transport.buffer - assert b"Hello, world" in protocol.transport.buffer - protocol.transport.clear_buffer() - - protocol.loop.run_one() - assert b"HTTP/1.1 200 OK" in protocol.transport.buffer - assert b"Hello, world" in protocol.transport.buffer - protocol.transport.clear_buffer() - - -@pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS) -def test_undersized_request(protocol_cls): + with get_connected_protocol(app, protocol_cls, event_loop) as protocol: + protocol.data_received(SIMPLE_GET_REQUEST) + protocol.data_received(SIMPLE_GET_REQUEST) + protocol.data_received(SIMPLE_GET_REQUEST) + protocol.loop.run_one() + assert b"HTTP/1.1 200 OK" in protocol.transport.buffer + assert b"Hello, world" in protocol.transport.buffer + protocol.transport.clear_buffer() + protocol.loop.run_one() + assert b"HTTP/1.1 200 OK" in protocol.transport.buffer + assert b"Hello, world" in protocol.transport.buffer + protocol.transport.clear_buffer() + protocol.loop.run_one() + assert b"HTTP/1.1 200 OK" in protocol.transport.buffer + assert b"Hello, world" in protocol.transport.buffer + protocol.transport.clear_buffer() + + +@pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS) +def test_undersized_request(protocol_cls, event_loop): app = Response(b"xxx", headers={"content-length": "10"}) - protocol = get_connected_protocol(app, protocol_cls) - protocol.data_received(SIMPLE_GET_REQUEST) - protocol.loop.run_one() - assert protocol.transport.is_closing() + with get_connected_protocol(app, protocol_cls, event_loop) as protocol: + protocol.data_received(SIMPLE_GET_REQUEST) + protocol.loop.run_one() + assert protocol.transport.is_closing() @pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS) -def test_oversized_request(protocol_cls): +def test_oversized_request(protocol_cls, event_loop): app = Response(b"xxx" * 20, headers={"content-length": "10"}) - protocol = get_connected_protocol(app, protocol_cls) - protocol.data_received(SIMPLE_GET_REQUEST) - protocol.loop.run_one() - assert protocol.transport.is_closing() + with get_connected_protocol(app, protocol_cls, event_loop) as protocol: + protocol.data_received(SIMPLE_GET_REQUEST) + protocol.loop.run_one() + assert protocol.transport.is_closing() @pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS) -def test_large_post_request(protocol_cls): +def test_large_post_request(protocol_cls, event_loop): app = Response("Hello, world", media_type="text/plain") - protocol = get_connected_protocol(app, protocol_cls) - protocol.data_received(LARGE_POST_REQUEST) - assert protocol.transport.read_paused - protocol.loop.run_one() - assert not protocol.transport.read_paused + with get_connected_protocol(app, protocol_cls, event_loop) as protocol: + protocol.data_received(LARGE_POST_REQUEST) + assert protocol.transport.read_paused + protocol.loop.run_one() + assert not protocol.transport.read_paused @pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS) -def test_invalid_http(protocol_cls): +def test_invalid_http(protocol_cls, event_loop): app = Response("Hello, world", media_type="text/plain") - protocol = get_connected_protocol(app, protocol_cls) - protocol.data_received(b"x" * 100000) - - assert protocol.transport.is_closing() + with get_connected_protocol(app, protocol_cls, event_loop) as protocol: + protocol.data_received(b"x" * 100000) + assert protocol.transport.is_closing() @pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS) -def test_app_exception(protocol_cls): +def test_app_exception(protocol_cls, event_loop): async def app(scope, receive, send): raise Exception() - protocol = get_connected_protocol(app, protocol_cls) - protocol.data_received(SIMPLE_GET_REQUEST) - protocol.loop.run_one() - assert b"HTTP/1.1 500 Internal Server Error" in protocol.transport.buffer - assert protocol.transport.is_closing() + with get_connected_protocol(app, protocol_cls, event_loop) as protocol: + protocol.data_received(SIMPLE_GET_REQUEST) + protocol.loop.run_one() + assert b"HTTP/1.1 500 Internal Server Error" in protocol.transport.buffer + assert protocol.transport.is_closing() @pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS) -def test_exception_during_response(protocol_cls): +def test_exception_during_response(protocol_cls, event_loop): async def app(scope, receive, send): await send({"type": "http.response.start", "status": 200}) await send({"type": "http.response.body", "body": b"1", "more_body": True}) raise Exception() - protocol = get_connected_protocol(app, protocol_cls) - protocol.data_received(SIMPLE_GET_REQUEST) - protocol.loop.run_one() - assert b"HTTP/1.1 500 Internal Server Error" not in protocol.transport.buffer - assert protocol.transport.is_closing() + with get_connected_protocol(app, protocol_cls, event_loop) as protocol: + protocol.data_received(SIMPLE_GET_REQUEST) + protocol.loop.run_one() + assert b"HTTP/1.1 500 Internal Server Error" not in protocol.transport.buffer + assert protocol.transport.is_closing() @pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS) -def test_no_response_returned(protocol_cls): +def test_no_response_returned(protocol_cls, event_loop): async def app(scope, receive, send): pass - protocol = get_connected_protocol(app, protocol_cls) - protocol.data_received(SIMPLE_GET_REQUEST) - protocol.loop.run_one() - assert b"HTTP/1.1 500 Internal Server Error" in protocol.transport.buffer - assert protocol.transport.is_closing() + with get_connected_protocol(app, protocol_cls, event_loop) as protocol: + protocol.data_received(SIMPLE_GET_REQUEST) + protocol.loop.run_one() + assert b"HTTP/1.1 500 Internal Server Error" in protocol.transport.buffer + assert protocol.transport.is_closing() @pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS) -def test_partial_response_returned(protocol_cls): +def test_partial_response_returned(protocol_cls, event_loop): async def app(scope, receive, send): await send({"type": "http.response.start", "status": 200}) - protocol = get_connected_protocol(app, protocol_cls) - protocol.data_received(SIMPLE_GET_REQUEST) - protocol.loop.run_one() - - assert b"HTTP/1.1 500 Internal Server Error" not in protocol.transport.buffer - assert protocol.transport.is_closing() + with get_connected_protocol(app, protocol_cls, event_loop) as protocol: + protocol.data_received(SIMPLE_GET_REQUEST) + protocol.loop.run_one() + assert b"HTTP/1.1 500 Internal Server Error" not in protocol.transport.buffer + assert protocol.transport.is_closing() @pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS) -def test_duplicate_start_message(protocol_cls): +def test_duplicate_start_message(protocol_cls, event_loop): async def app(scope, receive, send): await send({"type": "http.response.start", "status": 200}) await send({"type": "http.response.start", "status": 200}) - protocol = get_connected_protocol(app, protocol_cls) - protocol.data_received(SIMPLE_GET_REQUEST) - protocol.loop.run_one() - assert b"HTTP/1.1 500 Internal Server Error" not in protocol.transport.buffer - assert protocol.transport.is_closing() + with get_connected_protocol(app, protocol_cls, event_loop) as protocol: + protocol.data_received(SIMPLE_GET_REQUEST) + protocol.loop.run_one() + assert b"HTTP/1.1 500 Internal Server Error" not in protocol.transport.buffer + assert protocol.transport.is_closing() @pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS) -def test_missing_start_message(protocol_cls): +def test_missing_start_message(protocol_cls, event_loop): async def app(scope, receive, send): await send({"type": "http.response.body", "body": b""}) - protocol = get_connected_protocol(app, protocol_cls) - protocol.data_received(SIMPLE_GET_REQUEST) - protocol.loop.run_one() - assert b"HTTP/1.1 500 Internal Server Error" in protocol.transport.buffer - assert protocol.transport.is_closing() + with get_connected_protocol(app, protocol_cls, event_loop) as protocol: + protocol.data_received(SIMPLE_GET_REQUEST) + protocol.loop.run_one() + assert b"HTTP/1.1 500 Internal Server Error" in protocol.transport.buffer + assert protocol.transport.is_closing() @pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS) -def test_message_after_body_complete(protocol_cls): +def test_message_after_body_complete(protocol_cls, event_loop): async def app(scope, receive, send): await send({"type": "http.response.start", "status": 200}) await send({"type": "http.response.body", "body": b""}) await send({"type": "http.response.body", "body": b""}) - protocol = get_connected_protocol(app, protocol_cls) - protocol.data_received(SIMPLE_GET_REQUEST) - protocol.loop.run_one() - assert b"HTTP/1.1 200 OK" in protocol.transport.buffer - assert protocol.transport.is_closing() + with get_connected_protocol(app, protocol_cls, event_loop) as protocol: + protocol.data_received(SIMPLE_GET_REQUEST) + protocol.loop.run_one() + assert b"HTTP/1.1 200 OK" in protocol.transport.buffer + assert protocol.transport.is_closing() @pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS) -def test_value_returned(protocol_cls): +def test_value_returned(protocol_cls, event_loop): async def app(scope, receive, send): await send({"type": "http.response.start", "status": 200}) await send({"type": "http.response.body", "body": b""}) return 123 - protocol = get_connected_protocol(app, protocol_cls) - protocol.data_received(SIMPLE_GET_REQUEST) - protocol.loop.run_one() - assert b"HTTP/1.1 200 OK" in protocol.transport.buffer - assert protocol.transport.is_closing() + with get_connected_protocol(app, protocol_cls, event_loop) as protocol: + protocol.data_received(SIMPLE_GET_REQUEST) + protocol.loop.run_one() + assert b"HTTP/1.1 200 OK" in protocol.transport.buffer + assert protocol.transport.is_closing() @pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS) -def test_early_disconnect(protocol_cls): +def test_early_disconnect(protocol_cls, event_loop): got_disconnect_event = False async def app(scope, receive, send): @@ -491,28 +488,28 @@ async def app(scope, receive, send): got_disconnect_event = True - protocol = get_connected_protocol(app, protocol_cls) - protocol.data_received(SIMPLE_POST_REQUEST) - protocol.eof_received() - protocol.connection_lost(None) - protocol.loop.run_one() - assert got_disconnect_event + with get_connected_protocol(app, protocol_cls, event_loop) as protocol: + protocol.data_received(SIMPLE_POST_REQUEST) + protocol.eof_received() + protocol.connection_lost(None) + protocol.loop.run_one() + assert got_disconnect_event @pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS) -def test_early_response(protocol_cls): +def test_early_response(protocol_cls, event_loop): app = Response("Hello, world", media_type="text/plain") - protocol = get_connected_protocol(app, protocol_cls) - protocol.data_received(START_POST_REQUEST) - protocol.loop.run_one() - assert b"HTTP/1.1 200 OK" in protocol.transport.buffer - protocol.data_received(FINISH_POST_REQUEST) - assert not protocol.transport.is_closing() + with get_connected_protocol(app, protocol_cls, event_loop) as protocol: + protocol.data_received(START_POST_REQUEST) + protocol.loop.run_one() + assert b"HTTP/1.1 200 OK" in protocol.transport.buffer + protocol.data_received(FINISH_POST_REQUEST) + assert not protocol.transport.is_closing() @pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS) -def test_read_after_response(protocol_cls): +def test_read_after_response(protocol_cls, event_loop): message_after_response = None async def app(scope, receive, send): @@ -522,44 +519,45 @@ async def app(scope, receive, send): await response(scope, receive, send) message_after_response = await receive() - protocol = get_connected_protocol(app, protocol_cls) - protocol.data_received(SIMPLE_POST_REQUEST) - protocol.loop.run_one() - assert b"HTTP/1.1 200 OK" in protocol.transport.buffer - assert message_after_response == {"type": "http.disconnect"} + with get_connected_protocol(app, protocol_cls, event_loop) as protocol: + protocol.data_received(SIMPLE_POST_REQUEST) + protocol.loop.run_one() + assert b"HTTP/1.1 200 OK" in protocol.transport.buffer + assert message_after_response == {"type": "http.disconnect"} @pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS) -def test_http10_request(protocol_cls): +def test_http10_request(protocol_cls, event_loop): async def app(scope, receive, send): content = "Version: %s" % scope["http_version"] response = Response(content, media_type="text/plain") await response(scope, receive, send) - protocol = get_connected_protocol(app, protocol_cls) - protocol.data_received(HTTP10_GET_REQUEST) - protocol.loop.run_one() - assert b"HTTP/1.1 200 OK" in protocol.transport.buffer - assert b"Version: 1.0" in protocol.transport.buffer + with get_connected_protocol(app, protocol_cls, event_loop) as protocol: + protocol.data_received(HTTP10_GET_REQUEST) + protocol.loop.run_one() + assert b"HTTP/1.1 200 OK" in protocol.transport.buffer + assert b"Version: 1.0" in protocol.transport.buffer @pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS) -def test_root_path(protocol_cls): +def test_root_path(protocol_cls, event_loop): async def app(scope, receive, send): path = scope.get("root_path", "") + scope["path"] response = Response("Path: " + path, media_type="text/plain") await response(scope, receive, send) - protocol = get_connected_protocol(app, protocol_cls, root_path="/app") - protocol.data_received(SIMPLE_GET_REQUEST) - protocol.loop.run_one() - - assert b"HTTP/1.1 200 OK" in protocol.transport.buffer - assert b"Path: /app/" in protocol.transport.buffer + with get_connected_protocol( + app, protocol_cls, event_loop, root_path="/app" + ) as protocol: + protocol.data_received(SIMPLE_GET_REQUEST) + protocol.loop.run_one() + assert b"HTTP/1.1 200 OK" in protocol.transport.buffer + assert b"Path: /app/" in protocol.transport.buffer @pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS) -def test_raw_path(protocol_cls): +def test_raw_path(protocol_cls, event_loop): async def app(scope, receive, send): path = scope["path"] raw_path = scope.get("raw_path", None) @@ -569,49 +567,50 @@ async def app(scope, receive, send): response = Response("Done", media_type="text/plain") await response(scope, receive, send) - protocol = get_connected_protocol(app, protocol_cls, root_path="/app") - protocol.data_received(GET_REQUEST_WITH_RAW_PATH) - protocol.loop.run_one() - assert b"Done" in protocol.transport.buffer + with get_connected_protocol( + app, protocol_cls, event_loop, root_path="/app" + ) as protocol: + protocol.data_received(GET_REQUEST_WITH_RAW_PATH) + protocol.loop.run_one() + assert b"Done" in protocol.transport.buffer @pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS) -def test_max_concurrency(protocol_cls): +def test_max_concurrency(protocol_cls, event_loop): app = Response("Hello, world", media_type="text/plain") - protocol = get_connected_protocol(app, protocol_cls, limit_concurrency=1) - protocol.data_received(SIMPLE_GET_REQUEST) - protocol.loop.run_one() - - assert b"HTTP/1.1 503 Service Unavailable" in protocol.transport.buffer + with get_connected_protocol( + app, protocol_cls, event_loop, limit_concurrency=1 + ) as protocol: + protocol.data_received(SIMPLE_GET_REQUEST) + protocol.loop.run_one() + assert b"HTTP/1.1 503 Service Unavailable" in protocol.transport.buffer @pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS) -def test_shutdown_during_request(protocol_cls): +def test_shutdown_during_request(protocol_cls, event_loop): app = Response(b"", status_code=204) - protocol = get_connected_protocol(app, protocol_cls) - protocol.data_received(SIMPLE_GET_REQUEST) - protocol.shutdown() - protocol.loop.run_one() - - assert b"HTTP/1.1 204 No Content" in protocol.transport.buffer - assert protocol.transport.is_closing() + with get_connected_protocol(app, protocol_cls, event_loop) as protocol: + protocol.data_received(SIMPLE_GET_REQUEST) + protocol.shutdown() + protocol.loop.run_one() + assert b"HTTP/1.1 204 No Content" in protocol.transport.buffer + assert protocol.transport.is_closing() @pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS) -def test_shutdown_during_idle(protocol_cls): +def test_shutdown_during_idle(protocol_cls, event_loop): app = Response("Hello, world", media_type="text/plain") - protocol = get_connected_protocol(app, protocol_cls) - protocol.shutdown() - - assert protocol.transport.buffer == b"" - assert protocol.transport.is_closing() + with get_connected_protocol(app, protocol_cls, event_loop) as protocol: + protocol.shutdown() + assert protocol.transport.buffer == b"" + assert protocol.transport.is_closing() @pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS) -def test_100_continue_sent_when_body_consumed(protocol_cls): +def test_100_continue_sent_when_body_consumed(protocol_cls, event_loop): async def app(scope, receive, send): body = b"" more_body = True @@ -622,68 +621,66 @@ async def app(scope, receive, send): response = Response(b"Body: " + body, media_type="text/plain") await response(scope, receive, send) - protocol = get_connected_protocol(app, protocol_cls) - EXPECT_100_REQUEST = b"\r\n".join( - [ - b"POST / HTTP/1.1", - b"Host: example.org", - b"Expect: 100-continue", - b"Content-Type: application/json", - b"Content-Length: 18", - b"", - b'{"hello": "world"}', - ] - ) - protocol.data_received(EXPECT_100_REQUEST) - protocol.loop.run_one() - - assert b"HTTP/1.1 100 Continue" in protocol.transport.buffer - assert b"HTTP/1.1 200 OK" in protocol.transport.buffer - assert b'Body: {"hello": "world"}' in protocol.transport.buffer - - -@pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS) -def test_100_continue_not_sent_when_body_not_consumed(protocol_cls): + with get_connected_protocol(app, protocol_cls, event_loop) as protocol: + EXPECT_100_REQUEST = b"\r\n".join( + [ + b"POST / HTTP/1.1", + b"Host: example.org", + b"Expect: 100-continue", + b"Content-Type: application/json", + b"Content-Length: 18", + b"", + b'{"hello": "world"}', + ] + ) + protocol.data_received(EXPECT_100_REQUEST) + protocol.loop.run_one() + assert b"HTTP/1.1 100 Continue" in protocol.transport.buffer + assert b"HTTP/1.1 200 OK" in protocol.transport.buffer + assert b'Body: {"hello": "world"}' in protocol.transport.buffer + + +@pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS) +def test_100_continue_not_sent_when_body_not_consumed(protocol_cls, event_loop): app = Response(b"", status_code=204) - protocol = get_connected_protocol(app, protocol_cls) - EXPECT_100_REQUEST = b"\r\n".join( - [ - b"POST / HTTP/1.1", - b"Host: example.org", - b"Expect: 100-continue", - b"Content-Type: application/json", - b"Content-Length: 18", - b"", - b'{"hello": "world"}', - ] - ) - protocol.data_received(EXPECT_100_REQUEST) - protocol.loop.run_one() - - assert b"HTTP/1.1 100 Continue" not in protocol.transport.buffer - assert b"HTTP/1.1 204 No Content" in protocol.transport.buffer - - -@pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS) -def test_unsupported_upgrade_request(protocol_cls): + with get_connected_protocol(app, protocol_cls, event_loop) as protocol: + EXPECT_100_REQUEST = b"\r\n".join( + [ + b"POST / HTTP/1.1", + b"Host: example.org", + b"Expect: 100-continue", + b"Content-Type: application/json", + b"Content-Length: 18", + b"", + b'{"hello": "world"}', + ] + ) + protocol.data_received(EXPECT_100_REQUEST) + protocol.loop.run_one() + assert b"HTTP/1.1 100 Continue" not in protocol.transport.buffer + assert b"HTTP/1.1 204 No Content" in protocol.transport.buffer + + +@pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS) +def test_unsupported_upgrade_request(protocol_cls, event_loop): app = Response("Hello, world", media_type="text/plain") - protocol = get_connected_protocol(app, protocol_cls, ws="none") - protocol.data_received(UPGRADE_REQUEST) - - assert b"HTTP/1.1 400 Bad Request" in protocol.transport.buffer - assert b"Unsupported upgrade request." in protocol.transport.buffer + with get_connected_protocol(app, protocol_cls, event_loop, ws="none") as protocol: + protocol.data_received(UPGRADE_REQUEST) + assert b"HTTP/1.1 400 Bad Request" in protocol.transport.buffer + assert b"Unsupported upgrade request." in protocol.transport.buffer @pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS) -def test_supported_upgrade_request(protocol_cls): +def test_supported_upgrade_request(protocol_cls, event_loop): app = Response("Hello, world", media_type="text/plain") - protocol = get_connected_protocol(app, protocol_cls, ws="wsproto") - protocol.data_received(UPGRADE_REQUEST) - - assert b"HTTP/1.1 426 " in protocol.transport.buffer + with get_connected_protocol( + app, protocol_cls, event_loop, ws="wsproto" + ) as protocol: + protocol.data_received(UPGRADE_REQUEST) + assert b"HTTP/1.1 426 " in protocol.transport.buffer async def asgi3app(scope, receive, send): @@ -705,11 +702,11 @@ async def asgi(receive, send): @pytest.mark.parametrize("asgi2or3_app, expected_scopes", asgi_scope_data) @pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS) -def test_scopes(asgi2or3_app, expected_scopes, protocol_cls): - protocol = get_connected_protocol(asgi2or3_app, protocol_cls) - protocol.data_received(SIMPLE_GET_REQUEST) - protocol.loop.run_one() - assert expected_scopes == protocol.scope.get("asgi") +def test_scopes(asgi2or3_app, expected_scopes, protocol_cls, event_loop): + with get_connected_protocol(asgi2or3_app, protocol_cls, event_loop) as protocol: + protocol.data_received(SIMPLE_GET_REQUEST) + protocol.loop.run_one() + assert expected_scopes == protocol.scope.get("asgi") @pytest.mark.parametrize( @@ -721,14 +718,14 @@ def test_scopes(asgi2or3_app, expected_scopes, protocol_cls): ], ) @pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS) -def test_invalid_http_request(request_line, protocol_cls, caplog): +def test_invalid_http_request(request_line, protocol_cls, caplog, event_loop): app = Response("Hello, world", media_type="text/plain") request = INVALID_REQUEST_TEMPLATE % request_line caplog.set_level(logging.INFO, logger="uvicorn.error") logging.getLogger("uvicorn.error").propagate = True - protocol = get_connected_protocol(app, protocol_cls) - protocol.data_received(request) - assert not protocol.transport.buffer - assert "Invalid HTTP request received." in caplog.messages + with get_connected_protocol(app, protocol_cls, event_loop) as protocol: + protocol.data_received(request) + assert not protocol.transport.buffer + assert "Invalid HTTP request received." in caplog.messages diff --git a/tests/test_config.py b/tests/test_config.py index 21405aecc..0b7383e52 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -140,8 +140,9 @@ def test_concrete_http_class(): def test_socket_bind(): config = Config(app=asgi_app) config.load() - - assert isinstance(config.bind_socket(), socket.socket) + sock = config.bind_socket() + assert isinstance(sock, socket.socket) + sock.close() def test_ssl_config(tls_ca_certificate_pem_path, tls_ca_certificate_private_key_path): diff --git a/tests/test_lifespan.py b/tests/test_lifespan.py index 462d4f6ae..d49a9dcf9 100644 --- a/tests/test_lifespan.py +++ b/tests/test_lifespan.py @@ -37,6 +37,7 @@ async def test(): loop = asyncio.new_event_loop() loop.run_until_complete(test()) + loop.close() def test_lifespan_off(): @@ -52,6 +53,7 @@ async def test(): loop = asyncio.new_event_loop() loop.run_until_complete(test()) + loop.close() def test_lifespan_auto(): @@ -84,6 +86,7 @@ async def test(): loop = asyncio.new_event_loop() loop.run_until_complete(test()) + loop.close() def test_lifespan_auto_with_error(): @@ -101,6 +104,7 @@ async def test(): loop = asyncio.new_event_loop() loop.run_until_complete(test()) + loop.close() def test_lifespan_on_with_error(): @@ -119,6 +123,7 @@ async def test(): loop = asyncio.new_event_loop() loop.run_until_complete(test()) + loop.close() @pytest.mark.parametrize("mode", ("auto", "on")) @@ -146,6 +151,7 @@ async def test(): loop = asyncio.new_event_loop() loop.run_until_complete(test()) + loop.close() error_messages = [ record.message for record in caplog.records @@ -174,6 +180,7 @@ async def test(): loop = asyncio.new_event_loop() loop.run_until_complete(test()) + loop.close() def test_lifespan_scope_asgi2app(): @@ -197,6 +204,7 @@ async def test(): loop = asyncio.new_event_loop() loop.run_until_complete(test()) + loop.close() @pytest.mark.parametrize("mode", ("auto", "on")) @@ -236,3 +244,4 @@ async def test(): ] assert "the lifespan event failed" in error_messages.pop(0) assert "Application shutdown failed. Exiting." in error_messages.pop(0) + loop.close() diff --git a/uvicorn/protocols/http/h11_impl.py b/uvicorn/protocols/http/h11_impl.py index 2b5df1f52..7e4834782 100644 --- a/uvicorn/protocols/http/h11_impl.py +++ b/uvicorn/protocols/http/h11_impl.py @@ -110,6 +110,8 @@ def connection_lost(self, exc): self.cycle.message_event.set() if self.flow is not None: self.flow.resume_writing() + if exc is None: + self.transport.close() if self.on_connection_lost is not None: self.on_connection_lost() diff --git a/uvicorn/protocols/http/httptools_impl.py b/uvicorn/protocols/http/httptools_impl.py index 5f0bc864b..51b095db9 100644 --- a/uvicorn/protocols/http/httptools_impl.py +++ b/uvicorn/protocols/http/httptools_impl.py @@ -110,6 +110,8 @@ def connection_lost(self, exc): self.cycle.message_event.set() if self.flow is not None: self.flow.resume_writing() + if exc is None: + self.transport.close() if self.on_connection_lost is not None: self.on_connection_lost() diff --git a/uvicorn/protocols/websockets/websockets_impl.py b/uvicorn/protocols/websockets/websockets_impl.py index 0238a5ee0..f166f7185 100644 --- a/uvicorn/protocols/websockets/websockets_impl.py +++ b/uvicorn/protocols/websockets/websockets_impl.py @@ -80,6 +80,8 @@ def connection_lost(self, exc): super().connection_lost(exc) if self.on_connection_lost is not None: self.on_connection_lost() + if exc is None: + self.transport.close() def shutdown(self): self.ws_server.closing = True @@ -99,7 +101,7 @@ async def process_request(self, path, headers): """ path_portion, _, query_string = path.partition("?") - websockets.handshake.check_request(headers) + websockets.legacy.handshake.check_request(headers) subprotocols = [] for header in headers.get_all("Sec-WebSocket-Protocol"): diff --git a/uvicorn/protocols/websockets/wsproto_impl.py b/uvicorn/protocols/websockets/wsproto_impl.py index cdb3ad685..eef1de5ce 100644 --- a/uvicorn/protocols/websockets/wsproto_impl.py +++ b/uvicorn/protocols/websockets/wsproto_impl.py @@ -70,6 +70,8 @@ def connection_lost(self, exc): self.connections.remove(self) if self.on_connection_lost is not None: self.on_connection_lost() + if exc is None: + self.transport.close() def eof_received(self): pass