From 4a503d84fa8703d7534d810bb10b3a0b0e6e1a39 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Sun, 3 Mar 2024 10:15:10 +0100 Subject: [PATCH] Change ruff rules (#2251) * Change ruff rules * fix type checker --- pyproject.toml | 3 +- tests/conftest.py | 25 ++-- tests/middleware/test_logging.py | 57 ++------ tests/middleware/test_message_logger.py | 8 +- tests/middleware/test_proxy_headers.py | 32 ++--- tests/middleware/test_wsgi.py | 24 ++-- tests/protocols/test_http.py | 55 ++------ tests/protocols/test_utils.py | 24 +--- tests/protocols/test_websocket.py | 121 +++++------------ tests/response.py | 5 +- tests/supervisors/test_multiprocess.py | 9 +- tests/supervisors/test_reload.py | 69 +++------- tests/supervisors/test_signal.py | 17 +-- tests/test_auto_detection.py | 4 +- tests/test_cli.py | 22 ++- tests/test_config.py | 128 ++++++------------ tests/test_default_headers.py | 9 +- tests/test_lifespan.py | 16 +-- tests/test_main.py | 19 +-- tests/test_ssl.py | 4 +- tests/test_subprocess.py | 9 +- uvicorn/_types.py | 16 +-- uvicorn/config.py | 74 +++------- uvicorn/importer.py | 8 +- uvicorn/lifespan/off.py | 6 +- uvicorn/lifespan/on.py | 16 +-- uvicorn/logging.py | 8 +- uvicorn/main.py | 36 ++--- uvicorn/middleware/asgi2.py | 4 +- uvicorn/middleware/proxy_headers.py | 35 ++--- uvicorn/middleware/wsgi.py | 41 +++--- uvicorn/protocols/http/flow_control.py | 4 +- uvicorn/protocols/http/h11_impl.py | 19 +-- uvicorn/protocols/http/httptools_impl.py | 27 ++-- uvicorn/protocols/utils.py | 4 +- uvicorn/protocols/websockets/auto.py | 4 +- .../protocols/websockets/websockets_impl.py | 61 +++------ uvicorn/protocols/websockets/wsproto_impl.py | 40 ++---- uvicorn/server.py | 22 +-- uvicorn/supervisors/__init__.py | 6 +- uvicorn/supervisors/basereload.py | 14 +- uvicorn/supervisors/multiprocess.py | 16 +-- uvicorn/supervisors/statreload.py | 5 +- uvicorn/supervisors/watchfilesreload.py | 12 +- uvicorn/supervisors/watchgodreload.py | 25 +--- uvicorn/workers.py | 10 +- 46 files changed, 358 insertions(+), 815 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index fc8f0af50..f8502bcb4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,7 +61,8 @@ path = "uvicorn/__init__.py" include = ["/uvicorn"] [tool.ruff] -select = ["E", "F", "I"] +line-length = 120 +select = ["E", "F", "I", "FA", "UP"] ignore = ["B904", "B028"] [tool.ruff.lint.isort] diff --git a/tests/conftest.py b/tests/conftest.py index 6ee34b11a..b1214061a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import contextlib import importlib.util import os @@ -9,6 +11,7 @@ from tempfile import TemporaryDirectory from threading import Thread from time import sleep +from typing import Any from uuid import uuid4 import pytest @@ -38,14 +41,14 @@ @pytest.fixture -def tls_certificate_authority() -> "trustme.CA": +def tls_certificate_authority() -> trustme.CA: if not HAVE_TRUSTME: pytest.skip("trustme not installed") # pragma: no cover return trustme.CA() @pytest.fixture -def tls_certificate(tls_certificate_authority: "trustme.CA") -> "trustme.LeafCert": +def tls_certificate(tls_certificate_authority: trustme.CA) -> trustme.LeafCert: return tls_certificate_authority.issue_cert( "localhost", "127.0.0.1", @@ -54,13 +57,13 @@ def tls_certificate(tls_certificate_authority: "trustme.CA") -> "trustme.LeafCer @pytest.fixture -def tls_ca_certificate_pem_path(tls_certificate_authority: "trustme.CA"): +def tls_ca_certificate_pem_path(tls_certificate_authority: trustme.CA): with tls_certificate_authority.cert_pem.tempfile() as ca_cert_pem: yield ca_cert_pem @pytest.fixture -def tls_ca_certificate_private_key_path(tls_certificate_authority: "trustme.CA"): +def tls_ca_certificate_private_key_path(tls_certificate_authority: trustme.CA): with tls_certificate_authority.private_key_pem.tempfile() as private_key: yield private_key @@ -82,25 +85,25 @@ def tls_certificate_private_key_encrypted_path(tls_certificate): @pytest.fixture -def tls_certificate_private_key_path(tls_certificate: "trustme.CA"): +def tls_certificate_private_key_path(tls_certificate: trustme.CA): with tls_certificate.private_key_pem.tempfile() as private_key: yield private_key @pytest.fixture -def tls_certificate_key_and_chain_path(tls_certificate: "trustme.LeafCert"): +def tls_certificate_key_and_chain_path(tls_certificate: trustme.LeafCert): with tls_certificate.private_key_and_cert_chain_pem.tempfile() as cert_pem: yield cert_pem @pytest.fixture -def tls_certificate_server_cert_path(tls_certificate: "trustme.LeafCert"): +def tls_certificate_server_cert_path(tls_certificate: trustme.LeafCert): with tls_certificate.cert_chain_pems[0].tempfile() as cert_pem: yield cert_pem @pytest.fixture -def tls_ca_ssl_context(tls_certificate_authority: "trustme.CA") -> ssl.SSLContext: +def tls_ca_ssl_context(tls_certificate_authority: trustme.CA) -> ssl.SSLContext: ssl_ctx = ssl.create_default_context(ssl.Purpose.SERVER_AUTH) tls_certificate_authority.configure_trust(ssl_ctx) return ssl_ctx @@ -172,7 +175,7 @@ def anyio_backend() -> str: @pytest.fixture(scope="function") -def logging_config() -> dict: +def logging_config() -> dict[str, Any]: return deepcopy(LOGGING_CONFIG) @@ -250,9 +253,7 @@ def unused_tcp_port() -> int: params=[ pytest.param( "uvicorn.protocols.websockets.wsproto_impl:WSProtocol", - marks=pytest.mark.skipif( - not importlib.util.find_spec("wsproto"), reason="wsproto not installed." - ), + marks=pytest.mark.skipif(not importlib.util.find_spec("wsproto"), reason="wsproto not installed."), id="wsproto", ), pytest.param( diff --git a/tests/middleware/test_logging.py b/tests/middleware/test_logging.py index a8d2d596f..59bef1d37 100644 --- a/tests/middleware/test_logging.py +++ b/tests/middleware/test_logging.py @@ -60,9 +60,7 @@ async def test_trace_logging(caplog, logging_config, unused_tcp_port: int): async with httpx.AsyncClient() as client: response = await client.get(f"http://127.0.0.1:{unused_tcp_port}") assert response.status_code == 204 - messages = [ - record.message for record in caplog.records if record.name == "uvicorn.asgi" - ] + messages = [record.message for record in caplog.records if record.name == "uvicorn.asgi"] assert "ASGI [1] Started scope=" in messages.pop(0) assert "ASGI [1] Raised exception" in messages.pop(0) assert "ASGI [2] Started scope=" in messages.pop(0) @@ -72,9 +70,7 @@ async def test_trace_logging(caplog, logging_config, unused_tcp_port: int): @pytest.mark.anyio -async def test_trace_logging_on_http_protocol( - http_protocol_cls, caplog, logging_config, unused_tcp_port: int -): +async def test_trace_logging_on_http_protocol(http_protocol_cls, caplog, logging_config, unused_tcp_port: int): config = Config( app=app, log_level="trace", @@ -87,11 +83,7 @@ async def test_trace_logging_on_http_protocol( async with httpx.AsyncClient() as client: response = await client.get(f"http://127.0.0.1:{unused_tcp_port}") assert response.status_code == 204 - messages = [ - record.message - for record in caplog.records - if record.name == "uvicorn.error" - ] + messages = [record.message for record in caplog.records if record.name == "uvicorn.error"] assert any(" - HTTP connection made" in message for message in messages) assert any(" - HTTP connection lost" in message for message in messages) @@ -127,11 +119,7 @@ async def open_connection(url): async with run_server(config): is_open = await open_connection(f"ws://127.0.0.1:{unused_tcp_port}") assert is_open - messages = [ - record.message - for record in caplog.records - if record.name == "uvicorn.error" - ] + messages = [record.message for record in caplog.records if record.name == "uvicorn.error"] assert any(" - Upgrading to WebSocket" in message for message in messages) assert any(" - WebSocket connection made" in message for message in messages) assert any(" - WebSocket connection lost" in message for message in messages) @@ -140,39 +128,27 @@ async def open_connection(url): @pytest.mark.anyio @pytest.mark.parametrize("use_colors", [(True), (False), (None)]) async def test_access_logging(use_colors, caplog, logging_config, unused_tcp_port: int): - config = Config( - app=app, use_colors=use_colors, log_config=logging_config, port=unused_tcp_port - ) + config = Config(app=app, use_colors=use_colors, log_config=logging_config, port=unused_tcp_port) with caplog_for_logger(caplog, "uvicorn.access"): async with run_server(config): async with httpx.AsyncClient() as client: response = await client.get(f"http://127.0.0.1:{unused_tcp_port}") assert response.status_code == 204 - messages = [ - record.message - for record in caplog.records - if record.name == "uvicorn.access" - ] + messages = [record.message for record in caplog.records if record.name == "uvicorn.access"] assert '"GET / HTTP/1.1" 204' in messages.pop() @pytest.mark.anyio @pytest.mark.parametrize("use_colors", [(True), (False)]) -async def test_default_logging( - use_colors, caplog, logging_config, unused_tcp_port: int -): - config = Config( - app=app, use_colors=use_colors, log_config=logging_config, port=unused_tcp_port - ) +async def test_default_logging(use_colors, caplog, logging_config, unused_tcp_port: int): + config = Config(app=app, use_colors=use_colors, log_config=logging_config, port=unused_tcp_port) with caplog_for_logger(caplog, "uvicorn.access"): async with run_server(config): async with httpx.AsyncClient() as client: response = await client.get(f"http://127.0.0.1:{unused_tcp_port}") assert response.status_code == 204 - messages = [ - record.message for record in caplog.records if "uvicorn" in record.name - ] + messages = [record.message for record in caplog.records if "uvicorn" in record.name] assert "Started server process" in messages.pop(0) assert "Waiting for application startup" in messages.pop(0) assert "ASGI 'lifespan' protocol appears unsupported" in messages.pop(0) @@ -184,19 +160,14 @@ async def test_default_logging( @pytest.mark.anyio @pytest.mark.skipif(sys.platform == "win32", reason="require unix-like system") -async def test_running_log_using_uds( - caplog, short_socket_name, unused_tcp_port: int -): # pragma: py-win32 +async def test_running_log_using_uds(caplog, short_socket_name, unused_tcp_port: int): # pragma: py-win32 config = Config(app=app, uds=short_socket_name, port=unused_tcp_port) with caplog_for_logger(caplog, "uvicorn.access"): async with run_server(config): ... messages = [record.message for record in caplog.records if "uvicorn" in record.name] - assert ( - f"Uvicorn running on unix socket {short_socket_name} (Press CTRL+C to quit)" - in messages - ) + assert f"Uvicorn running on unix socket {short_socket_name} (Press CTRL+C to quit)" in messages @pytest.mark.anyio @@ -227,11 +198,7 @@ async def app(scope, receive, send): response = await client.get(f"http://127.0.0.1:{unused_tcp_port}") assert response.status_code == 599 - messages = [ - record.message - for record in caplog.records - if record.name == "uvicorn.access" - ] + messages = [record.message for record in caplog.records if record.name == "uvicorn.access"] assert '"GET / HTTP/1.1" 599' in messages.pop() diff --git a/tests/middleware/test_message_logger.py b/tests/middleware/test_message_logger.py index 11fc4810f..3f5c3af2d 100644 --- a/tests/middleware/test_message_logger.py +++ b/tests/middleware/test_message_logger.py @@ -26,9 +26,7 @@ async def app(scope, receive, send): assert sum(["ASGI [1] Send" in message for message in messages]) == 2 assert sum(["ASGI [1] Receive" in message for message in messages]) == 1 assert sum(["ASGI [1] Completed" in message for message in messages]) == 1 - assert ( - sum(["ASGI [1] Raised exception" in message for message in messages]) == 0 - ) + assert sum(["ASGI [1] Raised exception" in message for message in messages]) == 0 @pytest.mark.anyio @@ -48,6 +46,4 @@ async def app(scope, receive, send): assert sum(["ASGI [1] Send" in message for message in messages]) == 0 assert sum(["ASGI [1] Receive" in message for message in messages]) == 0 assert sum(["ASGI [1] Completed" in message for message in messages]) == 0 - assert ( - sum(["ASGI [1] Raised exception" in message for message in messages]) == 1 - ) + assert sum(["ASGI [1] Raised exception" in message for message in messages]) == 1 diff --git a/tests/middleware/test_proxy_headers.py b/tests/middleware/test_proxy_headers.py index f86573fb5..6d7fc8c23 100644 --- a/tests/middleware/test_proxy_headers.py +++ b/tests/middleware/test_proxy_headers.py @@ -1,4 +1,6 @@ -from typing import TYPE_CHECKING, List, Type, Union +from __future__ import annotations + +from typing import TYPE_CHECKING import httpx import pytest @@ -45,13 +47,9 @@ async def app( ("192.168.0.1", "Remote: http://127.0.0.1:123"), ], ) -async def test_proxy_headers_trusted_hosts( - trusted_hosts: Union[List[str], str], response_text: str -) -> None: +async def test_proxy_headers_trusted_hosts(trusted_hosts: list[str] | str, response_text: str) -> None: app_with_middleware = ProxyHeadersMiddleware(app, trusted_hosts=trusted_hosts) - async with httpx.AsyncClient( - app=app_with_middleware, base_url="http://testserver" - ) as client: + async with httpx.AsyncClient(app=app_with_middleware, base_url="http://testserver") as client: headers = {"X-Forwarded-Proto": "https", "X-Forwarded-For": "1.2.3.4"} response = await client.get("/", headers=headers) @@ -79,13 +77,9 @@ async def test_proxy_headers_trusted_hosts( (["192.168.0.2", "127.0.0.1"], "Remote: https://10.0.2.1:0"), ], ) -async def test_proxy_headers_multiple_proxies( - trusted_hosts: Union[List[str], str], response_text: str -) -> None: +async def test_proxy_headers_multiple_proxies(trusted_hosts: list[str] | str, response_text: str) -> None: app_with_middleware = ProxyHeadersMiddleware(app, trusted_hosts=trusted_hosts) - async with httpx.AsyncClient( - app=app_with_middleware, base_url="http://testserver" - ) as client: + async with httpx.AsyncClient(app=app_with_middleware, base_url="http://testserver") as client: headers = { "X-Forwarded-Proto": "https", "X-Forwarded-For": "1.2.3.4, 10.0.2.1, 192.168.0.2", @@ -99,9 +93,7 @@ async def test_proxy_headers_multiple_proxies( @pytest.mark.anyio async def test_proxy_headers_invalid_x_forwarded_for() -> None: app_with_middleware = ProxyHeadersMiddleware(app, trusted_hosts="*") - async with httpx.AsyncClient( - app=app_with_middleware, base_url="http://testserver" - ) as client: + async with httpx.AsyncClient(app=app_with_middleware, base_url="http://testserver") as client: headers = httpx.Headers( { "X-Forwarded-Proto": "https", @@ -127,12 +119,14 @@ async def test_proxy_headers_invalid_x_forwarded_for() -> None: async def test_proxy_headers_websocket_x_forwarded_proto( x_forwarded_proto: str, addr: str, - ws_protocol_cls: "Type[WSProtocol | WebSocketProtocol]", - http_protocol_cls: "Type[H11Protocol | HttpToolsProtocol]", + ws_protocol_cls: type[WSProtocol | WebSocketProtocol], + http_protocol_cls: type[H11Protocol | HttpToolsProtocol], unused_tcp_port: int, ) -> None: - async def websocket_app(scope, receive, send): + async def websocket_app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable) -> None: + assert scope["type"] == "websocket" scheme = scope["scheme"] + assert scope["client"] is not None host, port = scope["client"] addr = "%s://%s:%d" % (scheme, host, port) await send({"type": "websocket.accept"}) diff --git a/tests/middleware/test_wsgi.py b/tests/middleware/test_wsgi.py index 96a23def1..adc8e241a 100644 --- a/tests/middleware/test_wsgi.py +++ b/tests/middleware/test_wsgi.py @@ -1,6 +1,8 @@ +from __future__ import annotations + import io import sys -from typing import AsyncGenerator, Callable, List +from typing import AsyncGenerator, Callable import a2wsgi import httpx @@ -10,7 +12,7 @@ from uvicorn.middleware import wsgi -def hello_world(environ: Environ, start_response: StartResponse) -> List[bytes]: +def hello_world(environ: Environ, start_response: StartResponse) -> list[bytes]: status = "200 OK" output = b"Hello World!\n" headers = [ @@ -21,7 +23,7 @@ def hello_world(environ: Environ, start_response: StartResponse) -> List[bytes]: return [output] -def echo_body(environ: Environ, start_response: StartResponse) -> List[bytes]: +def echo_body(environ: Environ, start_response: StartResponse) -> list[bytes]: status = "200 OK" output = environ["wsgi.input"].read() headers = [ @@ -32,11 +34,11 @@ def echo_body(environ: Environ, start_response: StartResponse) -> List[bytes]: return [output] -def raise_exception(environ: Environ, start_response: StartResponse) -> List[bytes]: +def raise_exception(environ: Environ, start_response: StartResponse) -> list[bytes]: raise RuntimeError("Something went wrong") -def return_exc_info(environ: Environ, start_response: StartResponse) -> List[bytes]: +def return_exc_info(environ: Environ, start_response: StartResponse) -> list[bytes]: try: raise RuntimeError("Something went wrong") except RuntimeError: @@ -110,16 +112,14 @@ async def test_wsgi_exc_info(wsgi_middleware: Callable) -> None: app=app, raise_app_exceptions=False, ) - async with httpx.AsyncClient( - transport=transport, base_url="http://testserver" - ) as client: + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: response = await client.get("/") assert response.status_code == 500 assert response.text == "Internal Server Error" def test_build_environ_encoding() -> None: - scope: "HTTPScope" = { + scope: HTTPScope = { "asgi": {"version": "3.0", "spec_version": "2.0"}, "scheme": "http", "raw_path": b"/\xe6\x96\x87%2Fall", @@ -134,12 +134,12 @@ def test_build_environ_encoding() -> None: "headers": [(b"key", b"value1"), (b"key", b"value2")], "extensions": {}, } - message: "HTTPRequestEvent" = { + message: HTTPRequestEvent = { "type": "http.request", "body": b"", "more_body": False, } environ = wsgi.build_environ(scope, message, io.BytesIO(b"")) - assert environ["SCRIPT_NAME"] == "/文".encode("utf8").decode("latin-1") - assert environ["PATH_INFO"] == "/all".encode("utf8").decode("latin-1") + assert environ["SCRIPT_NAME"] == "/文".encode().decode("latin-1") + assert environ["PATH_INFO"] == b"/all".decode("latin-1") assert environ["HTTP_KEY"] == "value1,value2" diff --git a/tests/protocols/test_http.py b/tests/protocols/test_http.py index bc794c52b..d570f8c39 100644 --- a/tests/protocols/test_http.py +++ b/tests/protocols/test_http.py @@ -58,9 +58,7 @@ ] ) -CONNECTION_CLOSE_REQUEST = b"\r\n".join( - [b"GET / HTTP/1.1", b"Host: example.org", b"Connection: close", b"", b""] -) +CONNECTION_CLOSE_REQUEST = b"\r\n".join([b"GET / HTTP/1.1", b"Host: example.org", b"Connection: close", b"", b""]) LARGE_POST_REQUEST = b"\r\n".join( [ @@ -88,9 +86,7 @@ HTTP10_GET_REQUEST = b"\r\n".join([b"GET / HTTP/1.0", b"Host: example.org", b"", b""]) -GET_REQUEST_WITH_RAW_PATH = b"\r\n".join( - [b"GET /one%2Ftwo HTTP/1.1", b"Host: example.org", b"", b""] -) +GET_REQUEST_WITH_RAW_PATH = b"\r\n".join([b"GET /one%2Ftwo HTTP/1.1", b"Host: example.org", b"", b""]) UPGRADE_REQUEST = b"\r\n".join( [ @@ -257,11 +253,9 @@ async def test_get_request(http_protocol_cls: HTTPProtocol): @pytest.mark.anyio @pytest.mark.parametrize("path", ["/", "/?foo", "/?foo=bar", "/?foo=bar&baz=1"]) -async def test_request_logging( - path: str, http_protocol_cls: HTTPProtocol, caplog: pytest.LogCaptureFixture -): +async def test_request_logging(path: str, http_protocol_cls: HTTPProtocol, caplog: pytest.LogCaptureFixture): get_request_with_query_string = b"\r\n".join( - ["GET {} HTTP/1.1".format(path).encode("ascii"), b"Host: example.org", b"", b""] + [f"GET {path} HTTP/1.1".encode("ascii"), b"Host: example.org", b"", b""] ) caplog.set_level(logging.INFO, logger="uvicorn.access") logging.getLogger("uvicorn.access").propagate = True @@ -271,7 +265,7 @@ async def test_request_logging( protocol = get_connected_protocol(app, http_protocol_cls, log_config=None) protocol.data_received(get_request_with_query_string) await protocol.loop.run_one() - assert '"GET {} HTTP/1.1" 200'.format(path) in caplog.records[0].message + assert f'"GET {path} HTTP/1.1" 200' in caplog.records[0].message @pytest.mark.anyio @@ -371,9 +365,7 @@ async def test_close(http_protocol_cls: HTTPProtocol): @pytest.mark.anyio async def test_chunked_encoding(http_protocol_cls: HTTPProtocol): - app = Response( - b"Hello, world!", status_code=200, headers={"transfer-encoding": "chunked"} - ) + app = Response(b"Hello, world!", status_code=200, headers={"transfer-encoding": "chunked"}) protocol = get_connected_protocol(app, http_protocol_cls) protocol.data_received(SIMPLE_GET_REQUEST) @@ -385,9 +377,7 @@ async def test_chunked_encoding(http_protocol_cls: HTTPProtocol): @pytest.mark.anyio async def test_chunked_encoding_empty_body(http_protocol_cls: HTTPProtocol): - app = Response( - b"Hello, world!", status_code=200, headers={"transfer-encoding": "chunked"} - ) + app = Response(b"Hello, world!", status_code=200, headers={"transfer-encoding": "chunked"}) protocol = get_connected_protocol(app, http_protocol_cls) protocol.data_received(SIMPLE_GET_REQUEST) @@ -401,9 +391,7 @@ async def test_chunked_encoding_empty_body(http_protocol_cls: HTTPProtocol): async def test_chunked_encoding_head_request( http_protocol_cls: HTTPProtocol, ): - app = Response( - b"Hello, world!", status_code=200, headers={"transfer-encoding": "chunked"} - ) + app = Response(b"Hello, world!", status_code=200, headers={"transfer-encoding": "chunked"}) protocol = get_connected_protocol(app, http_protocol_cls) protocol.data_received(SIMPLE_HEAD_REQUEST) @@ -669,9 +657,7 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable assert scope["type"] == "http" root_path = scope.get("root_path", "") path = scope["path"] - response = Response( - f"root_path={root_path} path={path}", media_type="text/plain" - ) + response = Response(f"root_path={root_path} path={path}", media_type="text/plain") await response(scope, receive, send) protocol = get_connected_protocol(app, http_protocol_cls, root_path="/app") @@ -821,21 +807,14 @@ async def test_unsupported_ws_upgrade_request_warn_on_auto( await protocol.loop.run_one() assert b"HTTP/1.1 200 OK" in protocol.transport.buffer assert b"Hello, world" in protocol.transport.buffer - warnings = [ - record.msg - for record in filter( - lambda record: record.levelname == "WARNING", caplog.records - ) - ] + warnings = [record.msg for record in filter(lambda record: record.levelname == "WARNING", caplog.records)] assert "Unsupported upgrade request." in warnings msg = "No supported WebSocket library detected. Please use \"pip install 'uvicorn[standard]'\", or install 'websockets' or 'wsproto' manually." # noqa: E501 assert msg in warnings @pytest.mark.anyio -async def test_http2_upgrade_request( - http_protocol_cls: HTTPProtocol, ws_protocol_cls: WSProtocol -): +async def test_http2_upgrade_request(http_protocol_cls: HTTPProtocol, ws_protocol_cls: WSProtocol): app = Response("Hello, world", media_type="text/plain") protocol = get_connected_protocol(app, http_protocol_cls, ws=ws_protocol_cls) @@ -915,9 +894,7 @@ def receive_all(sock: socket.socket): def send_fragmented_req(path: str): sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.connect(("127.0.0.1", unused_tcp_port)) - d = ( - f"GET {path} HTTP/1.1\r\n" "Host: localhost\r\n" "Connection: close\r\n\r\n" - ).encode() + d = (f"GET {path} HTTP/1.1\r\n" "Host: localhost\r\n" "Connection: close\r\n\r\n").encode() split = len(path) // 2 sock.sendall(d[:split]) time.sleep(0.01) @@ -979,9 +956,7 @@ async def test_huge_headers_httptools_will_pass(): async def test_huge_headers_h11protocol_failure_with_setting(): app = Response("Hello, world", media_type="text/plain") - protocol = get_connected_protocol( - app, H11Protocol, h11_max_incomplete_event_size=20 * 1024 - ) + protocol = get_connected_protocol(app, H11Protocol, h11_max_incomplete_event_size=20 * 1024) # Huge headers make h11 fail in it's default config # h11 sends back a 400 in this case protocol.data_received(GET_REQUEST_HUGE_HEADERS[0]) @@ -1009,9 +984,7 @@ async def test_huge_headers_httptools(): async def test_huge_headers_h11_max_incomplete(): app = Response("Hello, world", media_type="text/plain") - protocol = get_connected_protocol( - app, H11Protocol, h11_max_incomplete_event_size=64 * 1024 - ) + protocol = get_connected_protocol(app, H11Protocol, h11_max_incomplete_event_size=64 * 1024) protocol.data_received(GET_REQUEST_HUGE_HEADERS[0]) protocol.data_received(GET_REQUEST_HUGE_HEADERS[1]) await protocol.loop.run_one() diff --git a/tests/protocols/test_utils.py b/tests/protocols/test_utils.py index b1f4cae1b..7639a99df 100644 --- a/tests/protocols/test_utils.py +++ b/tests/protocols/test_utils.py @@ -31,20 +31,14 @@ def test_get_local_addr_with_socket(): transport = MockTransport({"socket": MockSocket(family=socket.AF_IPX)}) assert get_local_addr(transport) is None - transport = MockTransport( - {"socket": MockSocket(family=socket.AF_INET6, sockname=("::1", 123))} - ) + transport = MockTransport({"socket": MockSocket(family=socket.AF_INET6, sockname=("::1", 123))}) assert get_local_addr(transport) == ("::1", 123) - transport = MockTransport( - {"socket": MockSocket(family=socket.AF_INET, sockname=("123.45.6.7", 123))} - ) + transport = MockTransport({"socket": MockSocket(family=socket.AF_INET, sockname=("123.45.6.7", 123))}) assert get_local_addr(transport) == ("123.45.6.7", 123) if hasattr(socket, "AF_UNIX"): # pragma: no cover - transport = MockTransport( - {"socket": MockSocket(family=socket.AF_UNIX, sockname=("127.0.0.1", 8000))} - ) + transport = MockTransport({"socket": MockSocket(family=socket.AF_UNIX, sockname=("127.0.0.1", 8000))}) assert get_local_addr(transport) == ("127.0.0.1", 8000) @@ -52,20 +46,14 @@ def test_get_remote_addr_with_socket(): transport = MockTransport({"socket": MockSocket(family=socket.AF_IPX)}) assert get_remote_addr(transport) is None - transport = MockTransport( - {"socket": MockSocket(family=socket.AF_INET6, peername=("::1", 123))} - ) + transport = MockTransport({"socket": MockSocket(family=socket.AF_INET6, peername=("::1", 123))}) assert get_remote_addr(transport) == ("::1", 123) - transport = MockTransport( - {"socket": MockSocket(family=socket.AF_INET, peername=("123.45.6.7", 123))} - ) + transport = MockTransport({"socket": MockSocket(family=socket.AF_INET, peername=("123.45.6.7", 123))}) assert get_remote_addr(transport) == ("123.45.6.7", 123) if hasattr(socket, "AF_UNIX"): # pragma: no cover - transport = MockTransport( - {"socket": MockSocket(family=socket.AF_UNIX, peername=("127.0.0.1", 8000))} - ) + transport = MockTransport({"socket": MockSocket(family=socket.AF_UNIX, peername=("127.0.0.1", 8000))}) assert get_remote_addr(transport) == ("127.0.0.1", 8000) diff --git a/tests/protocols/test_websocket.py b/tests/protocols/test_websocket.py index 93070b2c7..2d1e667de 100644 --- a/tests/protocols/test_websocket.py +++ b/tests/protocols/test_websocket.py @@ -50,9 +50,7 @@ class WebSocketResponse: - def __init__( - self, scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable - ): + def __init__(self, scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): self.scope = scope self.receive = receive self.send = send @@ -87,15 +85,11 @@ async def wsresponse(url): @pytest.mark.anyio -async def test_invalid_upgrade( - ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int -): +async def test_invalid_upgrade(ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int): def app(scope: Scope): return None - config = Config( - app=app, ws=ws_protocol_cls, http=http_protocol_cls, port=unused_tcp_port - ) + config = Config(app=app, ws=ws_protocol_cls, http=http_protocol_cls, port=unused_tcp_port) async with run_server(config): async with httpx.AsyncClient() as client: response = await client.get( @@ -117,18 +111,14 @@ def app(scope: Scope): "missing sec-websocket-key header", "missing sec-websocket-version header", # websockets "missing or empty sec-websocket-key header", # wsproto - "failed to open a websocket connection: missing " - "sec-websocket-key header", - "failed to open a websocket connection: missing or empty " - "sec-websocket-key header", + "failed to open a websocket connection: missing " "sec-websocket-key header", + "failed to open a websocket connection: missing or empty " "sec-websocket-key header", ] ) @pytest.mark.anyio -async def test_accept_connection( - ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int -): +async def test_accept_connection(ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int): class App(WebSocketResponse): async def websocket_connect(self, message): await self.send({"type": "websocket.accept"}) @@ -150,9 +140,7 @@ async def open_connection(url): @pytest.mark.anyio -async def test_shutdown( - ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int -): +async def test_shutdown(ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int): class App(WebSocketResponse): async def websocket_connect(self, message): await self.send({"type": "websocket.accept"}) @@ -180,9 +168,7 @@ async def websocket_connect(self, message): async def open_connection(url): extension_factories = [ClientPerMessageDeflateFactory()] - async with websockets.client.connect( - url, extensions=extension_factories - ) as websocket: + async with websockets.client.connect(url, extensions=extension_factories) as websocket: return [extension.name for extension in websocket.extensions] config = Config( @@ -209,9 +195,7 @@ async def open_connection(url: str): # enable per-message deflate on the client, so that we can check the server # won't support it when it's disabled. extension_factories = [ClientPerMessageDeflateFactory()] - async with websockets.client.connect( - url, extensions=extension_factories - ) as websocket: + async with websockets.client.connect(url, extensions=extension_factories) as websocket: return [extension.name for extension in websocket.extensions] config = Config( @@ -228,9 +212,7 @@ async def open_connection(url: str): @pytest.mark.anyio -async def test_close_connection( - ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int -): +async def test_close_connection(ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int): class App(WebSocketResponse): async def websocket_connect(self, message): await self.send({"type": "websocket.close"}) @@ -255,9 +237,7 @@ async def open_connection(url: str): @pytest.mark.anyio -async def test_headers( - ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int -): +async def test_headers(ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int): class App(WebSocketResponse): async def websocket_connect(self, message): headers = self.scope.get("headers") @@ -267,9 +247,7 @@ async def websocket_connect(self, message): await self.send({"type": "websocket.accept"}) async def open_connection(url: str): - async with websockets.client.connect( - url, extra_headers=[("username", "abraão")] - ) as websocket: + async with websockets.client.connect(url, extra_headers=[("username", "abraão")]) as websocket: return websocket.open config = Config( @@ -285,14 +263,10 @@ async def open_connection(url: str): @pytest.mark.anyio -async def test_extra_headers( - ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int -): +async def test_extra_headers(ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int): class App(WebSocketResponse): async def websocket_connect(self, message): - await self.send( - {"type": "websocket.accept", "headers": [(b"extra", b"header")]} - ) + await self.send({"type": "websocket.accept", "headers": [(b"extra", b"header")]}) async def open_connection(url: str): async with websockets.client.connect(url) as websocket: @@ -311,9 +285,7 @@ async def open_connection(url: str): @pytest.mark.anyio -async def test_path_and_raw_path( - ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int -): +async def test_path_and_raw_path(ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int): class App(WebSocketResponse): async def websocket_connect(self, message): path = self.scope.get("path") @@ -515,9 +487,7 @@ async def get_data(url: str): @pytest.mark.anyio -async def test_missing_handshake( - ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int -): +async def test_missing_handshake(ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int): async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): pass @@ -561,9 +531,7 @@ async def connect(url: str): @pytest.mark.anyio -async def test_duplicate_handshake( - ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int -): +async def test_duplicate_handshake(ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int): async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): await send({"type": "websocket.accept"}) await send({"type": "websocket.accept"}) @@ -586,9 +554,7 @@ async def connect(url: str): @pytest.mark.anyio -async def test_asgi_return_value( - ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int -): +async def test_asgi_return_value(ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int): """ The ASGI callable should return 'None'. If it doesn't, make sure that the connection is closed with an error condition. @@ -668,9 +634,7 @@ async def websocket_session(url: str): @pytest.mark.anyio -async def test_client_close( - ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int -): +async def test_client_close(ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int): async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): while True: message = await receive() @@ -723,9 +687,7 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable port=unused_tcp_port, ) async with run_server(config): - async with websockets.client.connect( - f"ws://127.0.0.1:{unused_tcp_port}" - ) as websocket: + async with websockets.client.connect(f"ws://127.0.0.1:{unused_tcp_port}") as websocket: websocket.transport.close() await asyncio.sleep(0.1) got_disconnect_event_before_shutdown = got_disconnect_event @@ -748,7 +710,7 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable try: await disconnect.wait() await send({"type": "websocket.send", "text": "123"}) - except IOError: + except OSError: got_disconnect_event = True config = Config( @@ -781,7 +743,7 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable await send_accept_task.wait() disconnect_message = await receive() # type: ignore - response: typing.Optional[httpx.Response] = None + response: httpx.Response | None = None async def websocket_session(uri: str): nonlocal response @@ -804,9 +766,7 @@ async def websocket_session(uri: str): port=unused_tcp_port, ) async with run_server(config): - task = asyncio.create_task( - websocket_session(f"ws://127.0.0.1:{unused_tcp_port}") - ) + task = asyncio.create_task(websocket_session(f"ws://127.0.0.1:{unused_tcp_port}")) await asyncio.sleep(0.1) send_accept_task.set() await asyncio.sleep(0.1) @@ -835,7 +795,7 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable disconnect_message = message break - websocket: typing.Optional[websockets.client.WebSocketClientProtocol] = None + websocket: websockets.client.WebSocketClientProtocol | None = None async def websocket_session(uri: str): nonlocal websocket @@ -851,9 +811,7 @@ async def websocket_session(uri: str): port=unused_tcp_port, ) async with run_server(config): - task = asyncio.create_task( - websocket_session(f"ws://127.0.0.1:{unused_tcp_port}") - ) + task = asyncio.create_task(websocket_session(f"ws://127.0.0.1:{unused_tcp_port}")) await asyncio.sleep(0.1) disconnect_message_before_shutdown = disconnect_message server_shutdown_event.set() @@ -891,9 +849,7 @@ async def get_subprotocol(url: str): port=unused_tcp_port, ) async with run_server(config): - accepted_subprotocol = await get_subprotocol( - f"ws://127.0.0.1:{unused_tcp_port}" - ) + accepted_subprotocol = await get_subprotocol(f"ws://127.0.0.1:{unused_tcp_port}") assert accepted_subprotocol == subprotocol @@ -1257,7 +1213,7 @@ async def test_server_multiple_websocket_http_response_start_events( The server should raise an exception if it sends multiple websocket.http.response.start events. """ - exception_message: typing.Optional[str] = None + exception_message: str | None = None async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): nonlocal exception_message @@ -1297,8 +1253,7 @@ async def websocket_session(url: str): await websocket_session(f"ws://127.0.0.1:{unused_tcp_port}") assert exception_message == ( - "Expected ASGI message 'websocket.http.response.body' but got " - "'websocket.http.response.start'." + "Expected ASGI message 'websocket.http.response.body' but got " "'websocket.http.response.start'." ) @@ -1369,9 +1324,7 @@ async def open_connection(url: str): @pytest.mark.anyio -async def test_no_server_headers( - ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int -): +async def test_no_server_headers(ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int): class App(WebSocketResponse): async def websocket_connect(self, message): await self.send({"type": "websocket.accept"}) @@ -1395,9 +1348,7 @@ async def open_connection(url: str): @pytest.mark.anyio @skip_if_no_wsproto -async def test_no_date_header_on_wsproto( - http_protocol_cls: HTTPProtocol, unused_tcp_port: int -): +async def test_no_date_header_on_wsproto(http_protocol_cls: HTTPProtocol, unused_tcp_port: int): class App(WebSocketResponse): async def websocket_connect(self, message): await self.send({"type": "websocket.accept"}) @@ -1452,9 +1403,7 @@ async def open_connection(url: str): @pytest.mark.anyio -async def test_lifespan_state( - ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int -): +async def test_lifespan_state(ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int): expected_states = [ {"a": 123, "b": [1]}, {"a": 123, "b": [1, 2]}, @@ -1462,9 +1411,7 @@ async def test_lifespan_state( actual_states = [] - async def lifespan_app( - scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable - ): + async def lifespan_app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): message = await receive() assert message["type"] == "lifespan.startup" and "state" in scope scope["state"]["a"] = 123 @@ -1485,9 +1432,7 @@ async def open_connection(url: str): async with websockets.client.connect(url) as websocket: return websocket.open - async def app_wrapper( - scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable - ): + async def app_wrapper(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): if scope["type"] == "lifespan": return await lifespan_app(scope, receive, send) return await App(scope, receive, send) diff --git a/tests/response.py b/tests/response.py index 55766d3f1..c88fdf53b 100644 --- a/tests/response.py +++ b/tests/response.py @@ -15,10 +15,7 @@ async def __call__(self, scope, receive, send) -> None: { "type": prefix + "http.response.start", "status": self.status_code, - "headers": [ - [key.encode(), value.encode()] - for key, value in self.headers.items() - ], + "headers": [[key.encode(), value.encode()] for key, value in self.headers.items()], } ) await send({"type": prefix + "http.response.body", "body": self.body}) diff --git a/tests/supervisors/test_multiprocess.py b/tests/supervisors/test_multiprocess.py index 82dc1118a..391b66a73 100644 --- a/tests/supervisors/test_multiprocess.py +++ b/tests/supervisors/test_multiprocess.py @@ -1,19 +1,18 @@ +from __future__ import annotations + import signal import socket -from typing import List, Optional from uvicorn import Config from uvicorn._types import ASGIReceiveCallable, ASGISendCallable, Scope from uvicorn.supervisors import Multiprocess -async def app( - scope: "Scope", receive: "ASGIReceiveCallable", send: "ASGISendCallable" -) -> None: +async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable) -> None: pass # pragma: no cover -def run(sockets: Optional[List[socket.socket]]) -> None: +def run(sockets: list[socket.socket] | None) -> None: pass # pragma: no cover diff --git a/tests/supervisors/test_reload.py b/tests/supervisors/test_reload.py index 4e3822375..30eea2321 100644 --- a/tests/supervisors/test_reload.py +++ b/tests/supervisors/test_reload.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import logging import platform import signal @@ -5,7 +7,6 @@ import sys from pathlib import Path from time import sleep -from typing import List, Optional, Type import pytest @@ -41,7 +42,7 @@ class TestBaseReload: def setup( self, reload_directory_structure: Path, - reloader_class: Optional[Type[BaseReload]], + reloader_class: type[BaseReload] | None, ): if reloader_class is None: # pragma: no cover pytest.skip("Needed dependency not installed") @@ -61,9 +62,7 @@ def _setup_reloader(self, config: Config) -> BaseReload: reloader.startup() return reloader - def _reload_tester( - self, touch_soon, reloader: BaseReload, *files: Path - ) -> Optional[List[Path]]: + def _reload_tester(self, touch_soon, reloader: BaseReload, *files: Path) -> list[Path] | None: reloader.restart() if WatchFilesReload is not None and isinstance(reloader, WatchFilesReload): touch_soon(*files) @@ -74,9 +73,7 @@ def _reload_tester( file.touch() return next(reloader) - @pytest.mark.parametrize( - "reloader_class", [StatReload, WatchGodReload, WatchFilesReload] - ) + @pytest.mark.parametrize("reloader_class", [StatReload, WatchGodReload, WatchFilesReload]) def test_reloader_should_initialize(self) -> None: """ A basic sanity check. @@ -89,9 +86,7 @@ def test_reloader_should_initialize(self) -> None: reloader = self._setup_reloader(config) reloader.shutdown() - @pytest.mark.parametrize( - "reloader_class", [StatReload, WatchGodReload, WatchFilesReload] - ) + @pytest.mark.parametrize("reloader_class", [StatReload, WatchGodReload, WatchFilesReload]) def test_reload_when_python_file_is_changed(self, touch_soon) -> None: file = self.reload_path / "main.py" @@ -104,12 +99,8 @@ def test_reload_when_python_file_is_changed(self, touch_soon) -> None: reloader.shutdown() - @pytest.mark.parametrize( - "reloader_class", [StatReload, WatchGodReload, WatchFilesReload] - ) - def test_should_reload_when_python_file_in_subdir_is_changed( - self, touch_soon - ) -> None: + @pytest.mark.parametrize("reloader_class", [StatReload, WatchGodReload, WatchFilesReload]) + def test_should_reload_when_python_file_in_subdir_is_changed(self, touch_soon) -> None: file = self.reload_path / "app" / "sub" / "sub.py" with as_cwd(self.reload_path): @@ -121,9 +112,7 @@ def test_should_reload_when_python_file_in_subdir_is_changed( reloader.shutdown() @pytest.mark.parametrize("reloader_class", [WatchFilesReload, WatchGodReload]) - def test_should_not_reload_when_python_file_in_excluded_subdir_is_changed( - self, touch_soon - ) -> None: + def test_should_not_reload_when_python_file_in_excluded_subdir_is_changed(self, touch_soon) -> None: sub_dir = self.reload_path / "app" / "sub" sub_file = sub_dir / "sub.py" @@ -139,18 +128,12 @@ def test_should_not_reload_when_python_file_in_excluded_subdir_is_changed( reloader.shutdown() - @pytest.mark.parametrize( - "reloader_class, result", [(StatReload, False), (WatchFilesReload, True)] - ) - def test_reload_when_pattern_matched_file_is_changed( - self, result: bool, touch_soon - ) -> None: + @pytest.mark.parametrize("reloader_class, result", [(StatReload, False), (WatchFilesReload, True)]) + def test_reload_when_pattern_matched_file_is_changed(self, result: bool, touch_soon) -> None: file = self.reload_path / "app" / "js" / "main.js" with as_cwd(self.reload_path): - config = Config( - app="tests.test_config:asgi_app", reload=True, reload_includes=["*.js"] - ) + config = Config(app="tests.test_config:asgi_app", reload=True, reload_includes=["*.js"]) reloader = self._setup_reloader(config) assert bool(self._reload_tester(touch_soon, reloader, file)) == result @@ -164,9 +147,7 @@ def test_reload_when_pattern_matched_file_is_changed( WatchGodReload, ], ) - def test_should_not_reload_when_exclude_pattern_match_file_is_changed( - self, touch_soon - ) -> None: + def test_should_not_reload_when_exclude_pattern_match_file_is_changed(self, touch_soon) -> None: python_file = self.reload_path / "app" / "src" / "main.py" css_file = self.reload_path / "app" / "css" / "main.css" js_file = self.reload_path / "app" / "js" / "main.js" @@ -186,9 +167,7 @@ def test_should_not_reload_when_exclude_pattern_match_file_is_changed( reloader.shutdown() - @pytest.mark.parametrize( - "reloader_class", [StatReload, WatchGodReload, WatchFilesReload] - ) + @pytest.mark.parametrize("reloader_class", [StatReload, WatchGodReload, WatchFilesReload]) def test_should_not_reload_when_dot_file_is_changed(self, touch_soon) -> None: file = self.reload_path / ".dotted" @@ -200,9 +179,7 @@ def test_should_not_reload_when_dot_file_is_changed(self, touch_soon) -> None: reloader.shutdown() - @pytest.mark.parametrize( - "reloader_class", [StatReload, WatchGodReload, WatchFilesReload] - ) + @pytest.mark.parametrize("reloader_class", [StatReload, WatchGodReload, WatchFilesReload]) def test_should_reload_when_directories_have_same_prefix(self, touch_soon) -> None: app_dir = self.reload_path / "app" app_file = app_dir / "src" / "main.py" @@ -230,9 +207,7 @@ def test_should_reload_when_directories_have_same_prefix(self, touch_soon) -> No pytest.param(WatchFilesReload, marks=skip_if_m1), ], ) - def test_should_not_reload_when_only_subdirectory_is_watched( - self, touch_soon - ) -> None: + def test_should_not_reload_when_only_subdirectory_is_watched(self, touch_soon) -> None: app_dir = self.reload_path / "app" app_dir_file = self.reload_path / "app" / "src" / "main.py" root_file = self.reload_path / "main.py" @@ -245,9 +220,7 @@ def test_should_not_reload_when_only_subdirectory_is_watched( reloader = self._setup_reloader(config) assert self._reload_tester(touch_soon, reloader, app_dir_file) - assert not self._reload_tester( - touch_soon, reloader, root_file, app_dir / "~ignored" - ) + assert not self._reload_tester(touch_soon, reloader, root_file, app_dir / "~ignored") reloader.shutdown() @@ -335,9 +308,7 @@ def test_watchfiles_no_changes(self) -> None: reloader.shutdown() @pytest.mark.parametrize("reloader_class", [WatchGodReload]) - def test_should_detect_new_reload_dirs( - self, touch_soon, caplog: pytest.LogCaptureFixture, tmp_path: Path - ) -> None: + def test_should_detect_new_reload_dirs(self, touch_soon, caplog: pytest.LogCaptureFixture, tmp_path: Path) -> None: app_dir = tmp_path / "app" app_file = app_dir / "file.py" app_dir.mkdir() @@ -346,9 +317,7 @@ def test_should_detect_new_reload_dirs( app_first_file = app_first_dir / "file.py" with as_cwd(tmp_path): - config = Config( - app="tests.test_config:asgi_app", reload=True, reload_includes=["app*"] - ) + config = Config(app="tests.test_config:asgi_app", reload=True, reload_includes=["app*"]) reloader = self._setup_reloader(config) assert self._reload_tester(touch_soon, reloader, app_file) diff --git a/tests/supervisors/test_signal.py b/tests/supervisors/test_signal.py index 32801a8a6..95c4675d6 100644 --- a/tests/supervisors/test_signal.py +++ b/tests/supervisors/test_signal.py @@ -29,9 +29,7 @@ async def wait_app(scope, receive, send): await server_event.wait() await send({"type": "http.response.body", "body": b"end", "more_body": False}) - config = Config( - app=wait_app, reload=False, port=unused_tcp_port, timeout_graceful_shutdown=1 - ) + config = Config(app=wait_app, reload=False, port=unused_tcp_port, timeout_graceful_shutdown=1) server: Server async with run_server(config) as server: async with httpx.AsyncClient() as client: @@ -64,9 +62,7 @@ async def forever_app(scope, receive, send): await server_event.wait() await send({"type": "http.response.body", "body": b"end", "more_body": False}) - config = Config( - app=forever_app, reload=False, port=unused_tcp_port, timeout_graceful_shutdown=1 - ) + config = Config(app=forever_app, reload=False, port=unused_tcp_port, timeout_graceful_shutdown=1) server: Server async with run_server(config) as server: async with httpx.AsyncClient() as client: @@ -78,10 +74,7 @@ async def forever_app(scope, receive, send): await req # req.result() - assert ( - "Cancel 1 running task(s), timeout graceful shutdown exceeded" - in caplog.messages - ) + assert "Cancel 1 running task(s), timeout graceful shutdown exceeded" in caplog.messages @pytest.mark.anyio @@ -99,9 +92,7 @@ async def app(scope, receive, send): await send({"type": "http.response.start", "status": 200, "headers": []}) await asyncio.sleep(1) - config = Config( - app=app, reload=False, port=unused_tcp_port, timeout_graceful_shutdown=1 - ) + config = Config(app=app, reload=False, port=unused_tcp_port, timeout_graceful_shutdown=1) server: Server async with run_server(config) as server: # exit and ensure we do not accept more requests diff --git a/tests/test_auto_detection.py b/tests/test_auto_detection.py index 3cc8e19d0..54ab904e5 100644 --- a/tests/test_auto_detection.py +++ b/tests/test_auto_detection.py @@ -59,7 +59,5 @@ async def test_websocket_auto(): server_state = ServerState() assert AutoWebSocketsProtocol is not None - protocol = AutoWebSocketsProtocol( - config=config, server_state=server_state, app_state={} - ) + protocol = AutoWebSocketsProtocol(config=config, server_state=server_state, app_state={}) assert type(protocol).__name__ == expected_websockets diff --git a/tests/test_cli.py b/tests/test_cli.py index b5d18cacc..cb24b10a6 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -41,12 +41,11 @@ def test_cli_print_version() -> None: assert result.exit_code == 0 assert ( - "Running uvicorn %s with %s %s on %s" - % ( - uvicorn.__version__, - platform.python_implementation(), - platform.python_version(), - platform.system(), + "Running uvicorn {version} with {py_implementation} {py_version} on {system}".format( + version=uvicorn.__version__, + py_implementation=platform.python_implementation(), + py_version=platform.python_version(), + system=platform.system(), ) ) in result.output @@ -103,9 +102,7 @@ def test_cli_call_multiprocess_run() -> None: @pytest.fixture(params=(True, False)) -def uds_file( - tmp_path: Path, request: pytest.FixtureRequest -) -> Path: # pragma: py-win32 +def uds_file(tmp_path: Path, request: pytest.FixtureRequest) -> Path: # pragma: py-win32 file = tmp_path / "uvicorn.sock" should_create_file = request.param if should_create_file: @@ -119,9 +116,7 @@ def test_cli_uds(uds_file: Path) -> None: # pragma: py-win32 with mock.patch.object(Config, "bind_socket") as mock_bind_socket: with mock.patch.object(Multiprocess, "run") as mock_run: - result = runner.invoke( - cli, ["tests.test_cli:App", "--workers=2", "--uds", str(uds_file)] - ) + result = runner.invoke(cli, ["tests.test_cli:App", "--workers=2", "--uds", str(uds_file)]) assert result.exit_code == 0 assert result.output == "" @@ -136,8 +131,7 @@ def test_cli_incomplete_app_parameter() -> None: result = runner.invoke(cli, ["tests.test_cli"]) assert ( - 'Error loading ASGI app. Import string "tests.test_cli" ' - 'must be in format ":".' + 'Error loading ASGI app. Import string "tests.test_cli" ' 'must be in format ":".' ) in result.output assert result.exit_code == 1 diff --git a/tests/test_config.py b/tests/test_config.py index 17ab097e6..ca305f6c2 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import json import logging import os @@ -5,7 +7,7 @@ import sys import typing from pathlib import Path -from typing import Literal, Optional +from typing import Any, Literal from unittest.mock import MagicMock import pytest @@ -42,9 +44,7 @@ def yaml_logging_config(logging_config: dict) -> str: return yaml.dump(logging_config) -async def asgi_app( - scope: "Scope", receive: "ASGIReceiveCallable", send: "ASGISendCallable" -) -> None: +async def asgi_app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable) -> None: pass # pragma: nocover @@ -56,65 +56,44 @@ def wsgi_app(environ: Environ, start_response: StartResponse) -> None: "app, expected_should_reload", [(asgi_app, False), ("tests.test_config:asgi_app", True)], ) -def test_config_should_reload_is_set( - app: "ASGIApplication", expected_should_reload: bool -) -> None: +def test_config_should_reload_is_set(app: ASGIApplication, expected_should_reload: bool) -> None: config = Config(app=app, reload=True) assert config.reload is True assert config.should_reload is expected_should_reload -def test_should_warn_on_invalid_reload_configuration( - tmp_path: Path, caplog: pytest.LogCaptureFixture -) -> None: +def test_should_warn_on_invalid_reload_configuration(tmp_path: Path, caplog: pytest.LogCaptureFixture) -> None: config_class = Config(app=asgi_app, reload_dirs=[str(tmp_path)]) assert not config_class.should_reload assert len(caplog.records) == 1 assert ( - caplog.records[-1].message - == "Current configuration will not reload as not all conditions are met, " + caplog.records[-1].message == "Current configuration will not reload as not all conditions are met, " "please refer to documentation." ) - config_no_reload = Config( - app="tests.test_config:asgi_app", reload_dirs=[str(tmp_path)] - ) + config_no_reload = Config(app="tests.test_config:asgi_app", reload_dirs=[str(tmp_path)]) assert not config_no_reload.should_reload assert len(caplog.records) == 2 assert ( - caplog.records[-1].message - == "Current configuration will not reload as not all conditions are met, " + caplog.records[-1].message == "Current configuration will not reload as not all conditions are met, " "please refer to documentation." ) -def test_reload_dir_is_set( - reload_directory_structure: Path, caplog: pytest.LogCaptureFixture -) -> None: +def test_reload_dir_is_set(reload_directory_structure: Path, caplog: pytest.LogCaptureFixture) -> None: app_dir = reload_directory_structure / "app" with caplog.at_level(logging.INFO): - config = Config( - app="tests.test_config:asgi_app", reload=True, reload_dirs=[str(app_dir)] - ) + config = Config(app="tests.test_config:asgi_app", reload=True, reload_dirs=[str(app_dir)]) assert len(caplog.records) == 1 - assert ( - caplog.records[-1].message - == f"Will watch for changes in these directories: {[str(app_dir)]}" - ) + assert caplog.records[-1].message == f"Will watch for changes in these directories: {[str(app_dir)]}" assert config.reload_dirs == [app_dir] - config = Config( - app="tests.test_config:asgi_app", reload=True, reload_dirs=str(app_dir) - ) + config = Config(app="tests.test_config:asgi_app", reload=True, reload_dirs=str(app_dir)) assert config.reload_dirs == [app_dir] -def test_non_existant_reload_dir_is_not_set( - reload_directory_structure: Path, caplog: pytest.LogCaptureFixture -) -> None: +def test_non_existant_reload_dir_is_not_set(reload_directory_structure: Path, caplog: pytest.LogCaptureFixture) -> None: with as_cwd(reload_directory_structure), caplog.at_level(logging.WARNING): - config = Config( - app="tests.test_config:asgi_app", reload=True, reload_dirs=["reload"] - ) + config = Config(app="tests.test_config:asgi_app", reload=True, reload_dirs=["reload"]) assert config.reload_dirs == [reload_directory_structure] assert ( caplog.records[-1].message @@ -129,9 +108,7 @@ def test_reload_subdir_removal(reload_directory_structure: Path) -> None: reload_dirs = [str(reload_directory_structure), "app", str(app_dir)] with as_cwd(reload_directory_structure): - config = Config( - app="tests.test_config:asgi_app", reload=True, reload_dirs=reload_dirs - ) + config = Config(app="tests.test_config:asgi_app", reload=True, reload_dirs=reload_dirs) assert config.reload_dirs == [reload_directory_structure] @@ -188,9 +165,7 @@ def test_reload_excluded_subdirectories_are_removed( ) assert frozenset(config.reload_dirs) == frozenset([reload_directory_structure]) assert frozenset(config.reload_dirs_excludes) == frozenset([app_dir]) - assert frozenset(config.reload_excludes) == frozenset( - [str(app_dir), str(app_sub_dir)] - ) + assert frozenset(config.reload_excludes) == frozenset([str(app_dir), str(app_sub_dir)]) def test_reload_includes_exclude_dir_patterns_are_matched( @@ -209,13 +184,10 @@ def test_reload_includes_exclude_dir_patterns_are_matched( ) assert len(caplog.records) == 1 assert ( - caplog.records[-1].message - == "Will watch for changes in these directories: " + caplog.records[-1].message == "Will watch for changes in these directories: " f"{sorted([str(first_app_dir), str(second_app_dir)])}" ) - assert frozenset(config.reload_dirs) == frozenset( - [first_app_dir, second_app_dir] - ) + assert frozenset(config.reload_dirs) == frozenset([first_app_dir, second_app_dir]) assert config.reload_includes == ["*/src"] @@ -247,9 +219,7 @@ def test_app_unimportable_other(caplog: pytest.LogCaptureFixture) -> None: with pytest.raises(SystemExit): config.load() error_messages = [ - record.message - for record in caplog.records - if record.name == "uvicorn.error" and record.levelname == "ERROR" + record.message for record in caplog.records if record.name == "uvicorn.error" and record.levelname == "ERROR" ] assert ( 'Error loading ASGI app. Attribute "app" not found in module "tests.test_config".' # noqa: E501 @@ -258,7 +228,7 @@ def test_app_unimportable_other(caplog: pytest.LogCaptureFixture) -> None: def test_app_factory(caplog: pytest.LogCaptureFixture) -> None: - def create_app() -> "ASGIApplication": + def create_app() -> ASGIApplication: return asgi_app config = Config(app=create_app, factory=True, proxy_headers=False) @@ -319,21 +289,15 @@ def test_ssl_config_combined(tls_certificate_key_and_chain_path: str) -> None: assert config.is_ssl is True -def asgi2_app(scope: "Scope") -> typing.Callable: - async def asgi( - receive: "ASGIReceiveCallable", send: "ASGISendCallable" - ) -> None: # pragma: nocover +def asgi2_app(scope: Scope) -> typing.Callable: + async def asgi(receive: ASGIReceiveCallable, send: ASGISendCallable) -> None: # pragma: nocover pass return asgi # pragma: nocover -@pytest.mark.parametrize( - "app, expected_interface", [(asgi_app, "3.0"), (asgi2_app, "2.0")] -) -def test_asgi_version( - app: "ASGIApplication", expected_interface: Literal["2.0", "3.0"] -) -> None: +@pytest.mark.parametrize("app, expected_interface", [(asgi_app, "3.0"), (asgi2_app, "2.0")]) +def test_asgi_version(app: ASGIApplication, expected_interface: Literal["2.0", "3.0"]) -> None: config = Config(app=app) config.load() assert config.asgi_version == expected_interface @@ -350,9 +314,9 @@ def test_asgi_version( ) def test_log_config_default( mocked_logging_config_module: MagicMock, - use_colors: typing.Optional[bool], - expected: typing.Optional[bool], - logging_config, + use_colors: bool | None, + expected: bool | None, + logging_config: dict[str, Any], ) -> None: """ Test that one can specify the use_colors option when using the default logging @@ -369,16 +333,14 @@ def test_log_config_default( def test_log_config_json( mocked_logging_config_module: MagicMock, - logging_config: dict, + logging_config: dict[str, Any], json_logging_config: str, mocker: MockerFixture, ) -> None: """ Test that one can load a json config from disk. """ - mocked_open = mocker.patch( - "uvicorn.config.open", mocker.mock_open(read_data=json_logging_config) - ) + mocked_open = mocker.patch("uvicorn.config.open", mocker.mock_open(read_data=json_logging_config)) config = Config(app=asgi_app, log_config="log_config.json") config.load() @@ -390,7 +352,7 @@ def test_log_config_json( @pytest.mark.parametrize("config_filename", ["log_config.yml", "log_config.yaml"]) def test_log_config_yaml( mocked_logging_config_module: MagicMock, - logging_config: dict, + logging_config: dict[str, Any], yaml_logging_config: str, mocker: MockerFixture, config_filename: str, @@ -398,9 +360,7 @@ def test_log_config_yaml( """ Test that one can load a yaml config from disk. """ - mocked_open = mocker.patch( - "uvicorn.config.open", mocker.mock_open(read_data=yaml_logging_config) - ) + mocked_open = mocker.patch("uvicorn.config.open", mocker.mock_open(read_data=yaml_logging_config)) config = Config(app=asgi_app, log_config=config_filename) config.load() @@ -416,9 +376,7 @@ def test_log_config_file(mocked_logging_config_module: MagicMock) -> None: config = Config(app=asgi_app, log_config="log_config") config.load() - mocked_logging_config_module.fileConfig.assert_called_once_with( - "log_config", disable_existing_loggers=False - ) + mocked_logging_config_module.fileConfig.assert_called_once_with("log_config", disable_existing_loggers=False) @pytest.fixture(params=[0, 1]) @@ -445,10 +403,7 @@ def test_env_file( Test that one can load environment variables using an env file. """ fp = tmp_path / ".env" - content = ( - f"WEB_CONCURRENCY={web_concurrency}\n" - f"FORWARDED_ALLOW_IPS={forwarded_allow_ips}\n" - ) + content = f"WEB_CONCURRENCY={web_concurrency}\n" f"FORWARDED_ALLOW_IPS={forwarded_allow_ips}\n" fp.write_text(content) with caplog.at_level(logging.INFO): config = Config(app=asgi_app, env_file=fp) @@ -488,9 +443,7 @@ def test_config_log_level(log_level: int) -> None: @pytest.mark.parametrize("log_level", [None, 0, 5, 10, 20, 30, 40, 50]) @pytest.mark.parametrize("uvicorn_logger_level", [0, 5, 10, 20, 30, 40, 50]) -def test_config_log_effective_level( - log_level: Optional[int], uvicorn_logger_level: Optional[int] -) -> None: +def test_config_log_effective_level(log_level: int, uvicorn_logger_level: int) -> None: default_level = 30 log_config = { "version": 1, @@ -530,7 +483,7 @@ def test_ws_max_queue() -> None: ) @pytest.mark.skipif(sys.platform == "win32", reason="require unix-like system") def test_bind_unix_socket_works_with_reload_or_workers( - tmp_path, reload, workers, short_socket_name + tmp_path: Path, reload: bool, workers: int, short_socket_name: str ): # pragma: py-win32 config = Config(app=asgi_app, uds=short_socket_name, reload=reload, workers=workers) config.load() @@ -550,7 +503,7 @@ def test_bind_unix_socket_works_with_reload_or_workers( ids=["--reload=True --workers=1", "--reload=False --workers=2"], ) @pytest.mark.skipif(sys.platform == "win32", reason="require unix-like system") -def test_bind_fd_works_with_reload_or_workers(reload, workers): # pragma: py-win32 +def test_bind_fd_works_with_reload_or_workers(reload: bool, workers: int): # pragma: py-win32 fdsock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) fd = fdsock.fileno() config = Config(app=asgi_app, fd=fd, reload=reload, workers=workers) @@ -576,7 +529,7 @@ def test_bind_fd_works_with_reload_or_workers(reload, workers): # pragma: py-wi "--reload=False --workers=1", ], ) -def test_config_use_subprocess(reload, workers, expected): +def test_config_use_subprocess(reload: bool, workers: int, expected: bool): config = Config(app=asgi_app, reload=reload, workers=workers) config.load() assert config.use_subprocess == expected @@ -585,7 +538,4 @@ def test_config_use_subprocess(reload, workers, expected): def test_warn_when_using_reload_and_workers(caplog: pytest.LogCaptureFixture) -> None: Config(app=asgi_app, reload=True, workers=2) assert len(caplog.records) == 1 - assert ( - '"workers" flag is ignored when reloading is enabled.' - in caplog.records[0].message - ) + assert '"workers" flag is ignored when reloading is enabled.' in caplog.records[0].message diff --git a/tests/test_default_headers.py b/tests/test_default_headers.py index e458b9234..ab3375ea6 100644 --- a/tests/test_default_headers.py +++ b/tests/test_default_headers.py @@ -32,9 +32,7 @@ async def test_override_server_header(unused_tcp_port: int): async with run_server(config): async with httpx.AsyncClient() as client: response = await client.get(f"http://127.0.0.1:{unused_tcp_port}") - assert ( - response.headers["server"] == "over-ridden" and response.headers["date"] - ) + assert response.headers["server"] == "over-ridden" and response.headers["date"] @pytest.mark.anyio @@ -64,10 +62,7 @@ async def test_override_server_header_multiple_times(unused_tcp_port: int): async with run_server(config): async with httpx.AsyncClient() as client: response = await client.get(f"http://127.0.0.1:{unused_tcp_port}") - assert ( - response.headers["server"] == "over-ridden, another-value" - and response.headers["date"] - ) + assert response.headers["server"] == "over-ridden, another-value" and response.headers["date"] @pytest.mark.anyio diff --git a/tests/test_lifespan.py b/tests/test_lifespan.py index a9cb73e3a..89368fa98 100644 --- a/tests/test_lifespan.py +++ b/tests/test_lifespan.py @@ -132,9 +132,7 @@ def test_lifespan_with_failed_startup(mode, raise_exception, caplog): async def app(scope, receive, send): message = await receive() assert message["type"] == "lifespan.startup" - await send( - {"type": "lifespan.startup.failed", "message": "the lifespan event failed"} - ) + await send({"type": "lifespan.startup.failed", "message": "the lifespan event failed"}) if raise_exception: # App should be able to re-raise an exception if startup failed. raise RuntimeError() @@ -153,9 +151,7 @@ async def test(): loop.run_until_complete(test()) loop.close() error_messages = [ - record.message - for record in caplog.records - if record.name == "uvicorn.error" and record.levelname == "ERROR" + record.message for record in caplog.records if record.name == "uvicorn.error" and record.levelname == "ERROR" ] assert "the lifespan event failed" in error_messages.pop(0) assert "Application startup failed. Exiting." in error_messages.pop(0) @@ -218,9 +214,7 @@ async def app(scope, receive, send): await send({"type": "lifespan.startup.complete"}) message = await receive() assert message["type"] == "lifespan.shutdown" - await send( - {"type": "lifespan.shutdown.failed", "message": "the lifespan event failed"} - ) + await send({"type": "lifespan.shutdown.failed", "message": "the lifespan event failed"}) if raise_exception: # App should be able to re-raise an exception if startup failed. @@ -240,9 +234,7 @@ async def test(): loop = asyncio.new_event_loop() loop.run_until_complete(test()) error_messages = [ - record.message - for record in caplog.records - if record.name == "uvicorn.error" and record.levelname == "ERROR" + record.message for record in caplog.records if record.name == "uvicorn.error" and record.levelname == "ERROR" ] assert "the lifespan event failed" in error_messages.pop(0) assert "Application shutdown failed. Exiting." in error_messages.pop(0) diff --git a/tests/test_main.py b/tests/test_main.py index ddd80cf76..fc2532749 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -47,9 +47,7 @@ def _has_ipv6(host): ], ) async def test_run(host, url: str, unused_tcp_port: int): - config = Config( - app=app, host=host, loop="asyncio", limit_max_requests=1, port=unused_tcp_port - ) + config = Config(app=app, host=host, loop="asyncio", limit_max_requests=1, port=unused_tcp_port) async with run_server(config): async with httpx.AsyncClient() as client: response = await client.get(f"{url}:{unused_tcp_port}") @@ -58,9 +56,7 @@ async def test_run(host, url: str, unused_tcp_port: int): @pytest.mark.anyio async def test_run_multiprocess(unused_tcp_port: int): - config = Config( - app=app, loop="asyncio", workers=2, limit_max_requests=1, port=unused_tcp_port - ) + config = Config(app=app, loop="asyncio", workers=2, limit_max_requests=1, port=unused_tcp_port) async with run_server(config): async with httpx.AsyncClient() as client: response = await client.get(f"http://127.0.0.1:{unused_tcp_port}") @@ -69,9 +65,7 @@ async def test_run_multiprocess(unused_tcp_port: int): @pytest.mark.anyio async def test_run_reload(unused_tcp_port: int): - config = Config( - app=app, loop="asyncio", reload=True, limit_max_requests=1, port=unused_tcp_port - ) + config = Config(app=app, loop="asyncio", reload=True, limit_max_requests=1, port=unused_tcp_port) async with run_server(config): async with httpx.AsyncClient() as client: response = await client.get(f"http://127.0.0.1:{unused_tcp_port}") @@ -85,8 +79,7 @@ def test_run_invalid_app_config_combination(caplog: pytest.LogCaptureFixture) -> assert caplog.records[-1].name == "uvicorn.error" assert caplog.records[-1].levelno == WARNING assert caplog.records[-1].message == ( - "You must pass the application as an import string to enable " - "'reload' or 'workers'." + "You must pass the application as an import string to enable " "'reload' or 'workers'." ) @@ -109,9 +102,7 @@ def test_run_match_config_params() -> None: if key not in ("self", "timeout_notify", "callback_notify") } run_params = { - key: repr(value) - for key, value in inspect.signature(run).parameters.items() - if key not in ("app_dir",) + key: repr(value) for key, value in inspect.signature(run).parameters.items() if key not in ("app_dir",) } assert config_params == run_params diff --git a/tests/test_ssl.py b/tests/test_ssl.py index d60bcf54e..da60bb8dd 100644 --- a/tests/test_ssl.py +++ b/tests/test_ssl.py @@ -56,9 +56,7 @@ async def test_run_chain( @pytest.mark.anyio -async def test_run_chain_only( - tls_ca_ssl_context, tls_certificate_key_and_chain_path, unused_tcp_port: int -): +async def test_run_chain_only(tls_ca_ssl_context, tls_certificate_key_and_chain_path, unused_tcp_port: int): config = Config( app=app, loop="asyncio", diff --git a/tests/test_subprocess.py b/tests/test_subprocess.py index f32721a6c..93191bacb 100644 --- a/tests/test_subprocess.py +++ b/tests/test_subprocess.py @@ -1,5 +1,6 @@ +from __future__ import annotations + import socket -from typing import List from unittest.mock import patch from uvicorn._subprocess import SpawnProcess, get_subprocess, subprocess_started @@ -7,13 +8,11 @@ from uvicorn.config import Config -def server_run(sockets: List[socket.socket]): # pragma: no cover +def server_run(sockets: list[socket.socket]): # pragma: no cover ... -async def app( - scope: "Scope", receive: "ASGIReceiveCallable", send: "ASGISendCallable" -) -> None: # pragma: no cover +async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable) -> None: # pragma: no cover ... diff --git a/uvicorn/_types.py b/uvicorn/_types.py index 3a510d0ac..7546262a8 100644 --- a/uvicorn/_types.py +++ b/uvicorn/_types.py @@ -36,18 +36,16 @@ Awaitable, Callable, Iterable, + Literal, MutableMapping, Optional, + Protocol, Tuple, Type, + TypedDict, Union, ) -if sys.version_info >= (3, 8): # pragma: py-lt-38 - from typing import Literal, Protocol, TypedDict -else: # pragma: py-gte-38 - from typing_extensions import Literal, Protocol, TypedDict - if sys.version_info >= (3, 11): # pragma: py-lt-311 from typing import NotRequired else: # pragma: py-gte-311 @@ -239,9 +237,7 @@ class LifespanShutdownFailedEvent(TypedDict): message: str -WebSocketEvent = Union[ - WebSocketReceiveEvent, WebSocketDisconnectEvent, WebSocketConnectEvent -] +WebSocketEvent = Union[WebSocketReceiveEvent, WebSocketDisconnectEvent, WebSocketConnectEvent] ASGIReceiveEvent = Union[ @@ -281,9 +277,7 @@ class ASGI2Protocol(Protocol): def __init__(self, scope: Scope) -> None: ... # pragma: no cover - async def __call__( - self, receive: ASGIReceiveCallable, send: ASGISendCallable - ) -> None: + async def __call__(self, receive: ASGIReceiveCallable, send: ASGISendCallable) -> None: ... # pragma: no cover diff --git a/uvicorn/config.py b/uvicorn/config.py index b0dff4604..3cad1d90f 100644 --- a/uvicorn/config.py +++ b/uvicorn/config.py @@ -127,9 +127,7 @@ def is_dir(path: Path) -> bool: return False -def resolve_reload_patterns( - patterns_list: list[str], directories_list: list[str] -) -> tuple[list[str], list[Path]]: +def resolve_reload_patterns(patterns_list: list[str], directories_list: list[str]) -> tuple[list[str], list[Path]]: directories: list[Path] = list(set(map(Path, directories_list.copy()))) patterns: list[str] = patterns_list.copy() @@ -150,9 +148,7 @@ def resolve_reload_patterns( directories = list(set(directories)) directories = list(map(Path, directories)) directories = list(map(lambda x: x.resolve(), directories)) - directories = list( - {reload_path for reload_path in directories if is_dir(reload_path)} - ) + directories = list({reload_path for reload_path in directories if is_dir(reload_path)}) children = [] for j in range(len(directories)): @@ -280,12 +276,9 @@ def __init__( self.reload_includes: list[str] = [] self.reload_excludes: list[str] = [] - if ( - reload_dirs or reload_includes or reload_excludes - ) and not self.should_reload: + if (reload_dirs or reload_includes or reload_excludes) and not self.should_reload: logger.warning( - "Current configuration will not reload as not all conditions are met, " - "please refer to documentation." + "Current configuration will not reload as not all conditions are met, " "please refer to documentation." ) if self.should_reload: @@ -293,22 +286,15 @@ def __init__( reload_includes = _normalize_dirs(reload_includes) reload_excludes = _normalize_dirs(reload_excludes) - self.reload_includes, self.reload_dirs = resolve_reload_patterns( - reload_includes, reload_dirs - ) + self.reload_includes, self.reload_dirs = resolve_reload_patterns(reload_includes, reload_dirs) - self.reload_excludes, self.reload_dirs_excludes = resolve_reload_patterns( - reload_excludes, [] - ) + self.reload_excludes, self.reload_dirs_excludes = resolve_reload_patterns(reload_excludes, []) reload_dirs_tmp = self.reload_dirs.copy() for directory in self.reload_dirs_excludes: for reload_directory in reload_dirs_tmp: - if ( - directory == reload_directory - or directory in reload_directory.parents - ): + if directory == reload_directory or directory in reload_directory.parents: try: self.reload_dirs.remove(reload_directory) except ValueError: @@ -343,9 +329,7 @@ def __init__( self.forwarded_allow_ips: list[str] | str if forwarded_allow_ips is None: - self.forwarded_allow_ips = os.environ.get( - "FORWARDED_ALLOW_IPS", "127.0.0.1" - ) + self.forwarded_allow_ips = os.environ.get("FORWARDED_ALLOW_IPS", "127.0.0.1") else: self.forwarded_allow_ips = forwarded_allow_ips @@ -375,12 +359,8 @@ def configure_logging(self) -> None: if self.log_config is not None: if isinstance(self.log_config, dict): if self.use_colors in (True, False): - self.log_config["formatters"]["default"][ - "use_colors" - ] = self.use_colors - self.log_config["formatters"]["access"][ - "use_colors" - ] = self.use_colors + self.log_config["formatters"]["default"]["use_colors"] = self.use_colors + self.log_config["formatters"]["access"]["use_colors"] = self.use_colors logging.config.dictConfig(self.log_config) elif self.log_config.endswith(".json"): with open(self.log_config) as file: @@ -397,9 +377,7 @@ def configure_logging(self) -> None: else: # See the note about fileConfig() here: # https://docs.python.org/3/library/logging.config.html#configuration-file-format - logging.config.fileConfig( - self.log_config, disable_existing_loggers=False - ) + logging.config.fileConfig(self.log_config, disable_existing_loggers=False) if self.log_level is not None: if isinstance(self.log_level, str): @@ -430,10 +408,7 @@ def load(self) -> None: else: self.ssl = None - encoded_headers = [ - (key.lower().encode("latin1"), value.encode("latin1")) - for key, value in self.headers - ] + encoded_headers = [(key.lower().encode("latin1"), value.encode("latin1")) for key, value in self.headers] self.encoded_headers = ( [(b"server", b"uvicorn")] + encoded_headers if b"server" not in dict(encoded_headers) and self.server_header @@ -469,8 +444,7 @@ def load(self) -> None: else: if not self.factory: logger.warning( - "ASGI app factory detected. Using it, " - "but please consider setting the --factory flag explicitly." + "ASGI app factory detected. Using it, " "but please consider setting the --factory flag explicitly." ) if self.interface == "auto": @@ -492,9 +466,7 @@ def load(self) -> None: if logger.getEffectiveLevel() <= TRACE_LOG_LEVEL: self.loaded_app = MessageLoggerMiddleware(self.loaded_app) if self.proxy_headers: - self.loaded_app = ProxyHeadersMiddleware( - self.loaded_app, trusted_hosts=self.forwarded_allow_ips - ) + self.loaded_app = ProxyHeadersMiddleware(self.loaded_app, trusted_hosts=self.forwarded_allow_ips) self.loaded = True @@ -518,21 +490,13 @@ def bind_socket(self) -> socket.socket: message = "Uvicorn running on unix socket %s (Press CTRL+C to quit)" sock_name_format = "%s" - color_message = ( - "Uvicorn running on " - + click.style(sock_name_format, bold=True) - + " (Press CTRL+C to quit)" - ) + color_message = "Uvicorn running on " + click.style(sock_name_format, bold=True) + " (Press CTRL+C to quit)" logger_args = [self.uds] elif self.fd: # pragma: py-win32 sock = socket.fromfd(self.fd, socket.AF_UNIX, socket.SOCK_STREAM) message = "Uvicorn running on socket %s (Press CTRL+C to quit)" fd_name_format = "%s" - color_message = ( - "Uvicorn running on " - + click.style(fd_name_format, bold=True) - + " (Press CTRL+C to quit)" - ) + color_message = "Uvicorn running on " + click.style(fd_name_format, bold=True) + " (Press CTRL+C to quit)" logger_args = [sock.getsockname()] else: family = socket.AF_INET @@ -552,11 +516,7 @@ def bind_socket(self) -> socket.socket: sys.exit(1) message = f"Uvicorn running on {addr_format} (Press CTRL+C to quit)" - color_message = ( - "Uvicorn running on " - + click.style(addr_format, bold=True) - + " (Press CTRL+C to quit)" - ) + color_message = "Uvicorn running on " + click.style(addr_format, bold=True) + " (Press CTRL+C to quit)" protocol_name = "https" if self.is_ssl else "http" logger_args = [protocol_name, self.host, sock.getsockname()[1]] logger.info(message, *logger_args, extra={"color_message": color_message}) diff --git a/uvicorn/importer.py b/uvicorn/importer.py index 338eba25c..f77520ee1 100644 --- a/uvicorn/importer.py +++ b/uvicorn/importer.py @@ -12,9 +12,7 @@ def import_from_string(import_str: Any) -> Any: module_str, _, attrs_str = import_str.partition(":") if not module_str or not attrs_str: - message = ( - 'Import string "{import_str}" must be in format ":".' - ) + message = 'Import string "{import_str}" must be in format ":".' raise ImportFromStringError(message.format(import_str=import_str)) try: @@ -31,8 +29,6 @@ def import_from_string(import_str: Any) -> Any: instance = getattr(instance, attr_str) except AttributeError: message = 'Attribute "{attrs_str}" not found in module "{module_str}".' - raise ImportFromStringError( - message.format(attrs_str=attrs_str, module_str=module_str) - ) + raise ImportFromStringError(message.format(attrs_str=attrs_str, module_str=module_str)) return instance diff --git a/uvicorn/lifespan/off.py b/uvicorn/lifespan/off.py index e1516f16a..74554b1e2 100644 --- a/uvicorn/lifespan/off.py +++ b/uvicorn/lifespan/off.py @@ -1,4 +1,6 @@ -from typing import Any, Dict +from __future__ import annotations + +from typing import Any from uvicorn import Config @@ -6,7 +8,7 @@ class LifespanOff: def __init__(self, config: Config) -> None: self.should_exit = False - self.state: Dict[str, Any] = {} + self.state: dict[str, Any] = {} async def startup(self) -> None: pass diff --git a/uvicorn/lifespan/on.py b/uvicorn/lifespan/on.py index 34dfdb1c5..09df984ea 100644 --- a/uvicorn/lifespan/on.py +++ b/uvicorn/lifespan/on.py @@ -1,7 +1,9 @@ +from __future__ import annotations + import asyncio import logging from asyncio import Queue -from typing import Any, Dict, Union +from typing import Any, Union from uvicorn import Config from uvicorn._types import ( @@ -35,12 +37,12 @@ def __init__(self, config: Config) -> None: self.logger = logging.getLogger("uvicorn.error") self.startup_event = asyncio.Event() self.shutdown_event = asyncio.Event() - self.receive_queue: "Queue[LifespanReceiveMessage]" = asyncio.Queue() + self.receive_queue: Queue[LifespanReceiveMessage] = asyncio.Queue() self.error_occured = False self.startup_failed = False self.shutdown_failed = False self.should_exit = False - self.state: Dict[str, Any] = {} + self.state: dict[str, Any] = {} async def startup(self) -> None: self.logger.info("Waiting for application startup.") @@ -67,9 +69,7 @@ async def shutdown(self) -> None: await self.receive_queue.put(shutdown_event) await self.shutdown_event.wait() - if self.shutdown_failed or ( - self.error_occured and self.config.lifespan == "on" - ): + if self.shutdown_failed or (self.error_occured and self.config.lifespan == "on"): self.logger.error("Application shutdown failed. Exiting.") self.should_exit = True else: @@ -99,7 +99,7 @@ async def main(self) -> None: self.startup_event.set() self.shutdown_event.set() - async def send(self, message: "LifespanSendMessage") -> None: + async def send(self, message: LifespanSendMessage) -> None: assert message["type"] in ( "lifespan.startup.complete", "lifespan.startup.failed", @@ -133,5 +133,5 @@ async def send(self, message: "LifespanSendMessage") -> None: if message.get("message"): self.logger.error(message["message"]) - async def receive(self) -> "LifespanReceiveMessage": + async def receive(self) -> LifespanReceiveMessage: return await self.receive_queue.get() diff --git a/uvicorn/logging.py b/uvicorn/logging.py index 74e864ed3..ab6261d14 100644 --- a/uvicorn/logging.py +++ b/uvicorn/logging.py @@ -26,9 +26,7 @@ class ColourizedFormatter(logging.Formatter): logging.INFO: lambda level_name: click.style(str(level_name), fg="green"), logging.WARNING: lambda level_name: click.style(str(level_name), fg="yellow"), logging.ERROR: lambda level_name: click.style(str(level_name), fg="red"), - logging.CRITICAL: lambda level_name: click.style( - str(level_name), fg="bright_red" - ), + logging.CRITICAL: lambda level_name: click.style(str(level_name), fg="bright_red"), } def __init__( @@ -86,7 +84,7 @@ def get_status_code(self, status_code: int) -> str: status_phrase = http.HTTPStatus(status_code).phrase except ValueError: status_phrase = "" - status_and_phrase = "%s %s" % (status_code, status_phrase) + status_and_phrase = f"{status_code} {status_phrase}" if self.use_colors: def default(code: int) -> str: @@ -106,7 +104,7 @@ def formatMessage(self, record: logging.LogRecord) -> str: status_code, ) = recordcopy.args # type: ignore[misc] status_code = self.get_status_code(int(status_code)) # type: ignore[arg-type] - request_line = "%s %s HTTP/%s" % (method, full_path, http_version) + request_line = f"{method} {full_path} HTTP/{http_version}" if self.use_colors: request_line = click.style(request_line, bold=True) recordcopy.__dict__.update( diff --git a/uvicorn/main.py b/uvicorn/main.py index fee6c5b4c..ace2b70d7 100644 --- a/uvicorn/main.py +++ b/uvicorn/main.py @@ -47,12 +47,11 @@ def print_version(ctx: click.Context, param: click.Parameter, value: bool) -> No if not value or ctx.resilient_parsing: return click.echo( - "Running uvicorn %s with %s %s on %s" - % ( - uvicorn.__version__, - platform.python_implementation(), - platform.python_version(), - platform.system(), + "Running uvicorn {version} with {py_implementation} {py_version} on {system}".format( + version=uvicorn.__version__, + py_implementation=platform.python_implementation(), + py_version=platform.python_version(), + system=platform.system(), ) ) ctx.exit() @@ -75,16 +74,13 @@ def print_version(ctx: click.Context, param: click.Parameter, value: bool) -> No show_default=True, ) @click.option("--uds", type=str, default=None, help="Bind to a UNIX domain socket.") -@click.option( - "--fd", type=int, default=None, help="Bind to socket from this file descriptor." -) +@click.option("--fd", type=int, default=None, help="Bind to socket from this file descriptor.") @click.option("--reload", is_flag=True, default=False, help="Enable auto-reload.") @click.option( "--reload-dir", "reload_dirs", multiple=True, - help="Set reload directories explicitly, instead of using the current working" - " directory.", + help="Set reload directories explicitly, instead of using the current working" " directory.", type=click.Path(exists=True), ) @click.option( @@ -109,8 +105,7 @@ def print_version(ctx: click.Context, param: click.Parameter, value: bool) -> No type=float, default=0.25, show_default=True, - help="Delay between previous and next check if application needs to be." - " Defaults to 0.25s.", + help="Delay between previous and next check if application needs to be." " Defaults to 0.25s.", ) @click.option( "--workers", @@ -226,8 +221,7 @@ def print_version(ctx: click.Context, param: click.Parameter, value: bool) -> No "--proxy-headers/--no-proxy-headers", is_flag=True, default=True, - help="Enable/Disable X-Forwarded-Proto, X-Forwarded-For, X-Forwarded-Port to " - "populate remote address info.", + help="Enable/Disable X-Forwarded-Proto, X-Forwarded-For, X-Forwarded-Port to " "populate remote address info.", ) @click.option( "--server-header/--no-server-header", @@ -258,8 +252,7 @@ def print_version(ctx: click.Context, param: click.Parameter, value: bool) -> No "--limit-concurrency", type=int, default=None, - help="Maximum number of concurrent connections or tasks to allow, before issuing" - " HTTP 503 responses.", + help="Maximum number of concurrent connections or tasks to allow, before issuing" " HTTP 503 responses.", ) @click.option( "--backlog", @@ -286,9 +279,7 @@ def print_version(ctx: click.Context, param: click.Parameter, value: bool) -> No default=None, help="Maximum number of seconds to wait for graceful shutdown.", ) -@click.option( - "--ssl-keyfile", type=str, default=None, help="SSL key file", show_default=True -) +@click.option("--ssl-keyfile", type=str, default=None, help="SSL key file", show_default=True) @click.option( "--ssl-certfile", type=str, @@ -571,10 +562,7 @@ def run( if (config.reload or config.workers > 1) and not isinstance(app, str): logger = logging.getLogger("uvicorn.error") - logger.warning( - "You must pass the application as an import string to enable 'reload' or " - "'workers'." - ) + logger.warning("You must pass the application as an import string to enable 'reload' or " "'workers'.") sys.exit(1) if config.should_reload: diff --git a/uvicorn/middleware/asgi2.py b/uvicorn/middleware/asgi2.py index 75145f732..4e15d1599 100644 --- a/uvicorn/middleware/asgi2.py +++ b/uvicorn/middleware/asgi2.py @@ -10,8 +10,6 @@ class ASGI2Middleware: def __init__(self, app: "ASGI2Application"): self.app = app - async def __call__( - self, scope: "Scope", receive: "ASGIReceiveCallable", send: "ASGISendCallable" - ) -> None: + async def __call__(self, scope: "Scope", receive: "ASGIReceiveCallable", send: "ASGISendCallable") -> None: instance = self.app(scope) await instance(receive, send) diff --git a/uvicorn/middleware/proxy_headers.py b/uvicorn/middleware/proxy_headers.py index 1c254416e..8f987ab0b 100644 --- a/uvicorn/middleware/proxy_headers.py +++ b/uvicorn/middleware/proxy_headers.py @@ -8,23 +8,18 @@ https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers#Proxies """ -from typing import List, Optional, Tuple, Union, cast +from __future__ import annotations -from uvicorn._types import ( - ASGI3Application, - ASGIReceiveCallable, - ASGISendCallable, - HTTPScope, - Scope, - WebSocketScope, -) +from typing import Union, cast + +from uvicorn._types import ASGI3Application, ASGIReceiveCallable, ASGISendCallable, HTTPScope, Scope, WebSocketScope class ProxyHeadersMiddleware: def __init__( self, - app: "ASGI3Application", - trusted_hosts: Union[List[str], str] = "127.0.0.1", + app: ASGI3Application, + trusted_hosts: list[str] | str = "127.0.0.1", ) -> None: self.app = app if isinstance(trusted_hosts, str): @@ -33,9 +28,7 @@ def __init__( self.trusted_hosts = set(trusted_hosts) self.always_trust = "*" in self.trusted_hosts - def get_trusted_client_host( - self, x_forwarded_for_hosts: List[str] - ) -> Optional[str]: + def get_trusted_client_host(self, x_forwarded_for_hosts: list[str]) -> str | None: if self.always_trust: return x_forwarded_for_hosts[0] @@ -45,12 +38,10 @@ def get_trusted_client_host( return None - async def __call__( - self, scope: "Scope", receive: "ASGIReceiveCallable", send: "ASGISendCallable" - ) -> None: + async def __call__(self, scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable) -> None: if scope["type"] in ("http", "websocket"): scope = cast(Union["HTTPScope", "WebSocketScope"], scope) - client_addr: Optional[Tuple[str, int]] = scope.get("client") + client_addr: tuple[str, int] | None = scope.get("client") client_host = client_addr[0] if client_addr else None if self.always_trust or client_host in self.trusted_hosts: @@ -59,9 +50,7 @@ async def __call__( if b"x-forwarded-proto" in headers: # Determine if the incoming request was http or https based on # the X-Forwarded-Proto header. - x_forwarded_proto = ( - headers[b"x-forwarded-proto"].decode("latin1").strip() - ) + x_forwarded_proto = headers[b"x-forwarded-proto"].decode("latin1").strip() if scope["type"] == "websocket": scope["scheme"] = x_forwarded_proto.replace("http", "ws") else: @@ -72,9 +61,7 @@ async def __call__( # X-Forwarded-For header. We've lost the connecting client's port # information by now, so only include the host. x_forwarded_for = headers[b"x-forwarded-for"].decode("latin1") - x_forwarded_for_hosts = [ - item.strip() for item in x_forwarded_for.split(",") - ] + x_forwarded_for_hosts = [item.strip() for item in x_forwarded_for.split(",")] host = self.get_trusted_client_host(x_forwarded_for_hosts) port = 0 scope["client"] = (host, port) # type: ignore[arg-type] diff --git a/uvicorn/middleware/wsgi.py b/uvicorn/middleware/wsgi.py index b181e0f16..078de1af0 100644 --- a/uvicorn/middleware/wsgi.py +++ b/uvicorn/middleware/wsgi.py @@ -1,10 +1,12 @@ +from __future__ import annotations + import asyncio import concurrent.futures import io import sys import warnings from collections import deque -from typing import Deque, Iterable, Optional, Tuple +from typing import Iterable from uvicorn._types import ( ASGIReceiveCallable, @@ -22,9 +24,7 @@ ) -def build_environ( - scope: "HTTPScope", message: "ASGIReceiveEvent", body: io.BytesIO -) -> Environ: +def build_environ(scope: HTTPScope, message: ASGIReceiveEvent, body: io.BytesIO) -> Environ: """ Builds a scope and request message into a WSGI environ object. """ @@ -91,9 +91,9 @@ def __init__(self, app: WSGIApp, workers: int = 10): async def __call__( self, - scope: "HTTPScope", - receive: "ASGIReceiveCallable", - send: "ASGISendCallable", + scope: HTTPScope, + receive: ASGIReceiveCallable, + send: ASGISendCallable, ) -> None: assert scope["type"] == "http" instance = WSGIResponder(self.app, self.executor, scope) @@ -105,7 +105,7 @@ def __init__( self, app: WSGIApp, executor: concurrent.futures.ThreadPoolExecutor, - scope: "HTTPScope", + scope: HTTPScope, ): self.app = app self.executor = executor @@ -113,21 +113,19 @@ def __init__( self.status = None self.response_headers = None self.send_event = asyncio.Event() - self.send_queue: Deque[Optional["ASGISendEvent"]] = deque() + self.send_queue: deque[ASGISendEvent | None] = deque() self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() self.response_started = False - self.exc_info: Optional[ExcInfo] = None + self.exc_info: ExcInfo | None = None - async def __call__( - self, receive: "ASGIReceiveCallable", send: "ASGISendCallable" - ) -> None: + async def __call__(self, receive: ASGIReceiveCallable, send: ASGISendCallable) -> None: message: HTTPRequestEvent = await receive() # type: ignore[assignment] body = io.BytesIO(message.get("body", b"")) more_body = message.get("more_body", False) if more_body: body.seek(0, io.SEEK_END) while more_body: - body_message: "HTTPRequestEvent" = ( + body_message: HTTPRequestEvent = ( await receive() # type: ignore[assignment] ) body.write(body_message.get("body", b"")) @@ -135,9 +133,7 @@ async def __call__( body.seek(0) environ = build_environ(self.scope, message, body) self.loop = asyncio.get_event_loop() - wsgi = self.loop.run_in_executor( - self.executor, self.wsgi, environ, self.start_response - ) + wsgi = self.loop.run_in_executor(self.executor, self.wsgi, environ, self.start_response) sender = self.loop.create_task(self.sender(send)) try: await asyncio.wait_for(wsgi, None) @@ -148,7 +144,7 @@ async def __call__( if self.exc_info is not None: raise self.exc_info[0].with_traceback(self.exc_info[1], self.exc_info[2]) - async def sender(self, send: "ASGISendCallable") -> None: + async def sender(self, send: ASGISendCallable) -> None: while True: if self.send_queue: message = self.send_queue.popleft() @@ -162,18 +158,15 @@ async def sender(self, send: "ASGISendCallable") -> None: def start_response( self, status: str, - response_headers: Iterable[Tuple[str, str]], - exc_info: Optional[ExcInfo] = None, + response_headers: Iterable[tuple[str, str]], + exc_info: ExcInfo | None = None, ) -> None: self.exc_info = exc_info if not self.response_started: self.response_started = True status_code_str, _ = status.split(" ", 1) status_code = int(status_code_str) - headers = [ - (name.encode("ascii"), value.encode("ascii")) - for name, value in response_headers - ] + headers = [(name.encode("ascii"), value.encode("ascii")) for name, value in response_headers] http_response_start_event: HTTPResponseStartEvent = { "type": "http.response.start", "status": status_code, diff --git a/uvicorn/protocols/http/flow_control.py b/uvicorn/protocols/http/flow_control.py index df642c7b6..893a26c80 100644 --- a/uvicorn/protocols/http/flow_control.py +++ b/uvicorn/protocols/http/flow_control.py @@ -45,9 +45,7 @@ def resume_writing(self) -> None: self._is_writable_event.set() -async def service_unavailable( - scope: "Scope", receive: "ASGIReceiveCallable", send: "ASGISendCallable" -) -> None: +async def service_unavailable(scope: "Scope", receive: "ASGIReceiveCallable", send: "ASGISendCallable") -> None: response_start: "HTTPResponseStartEvent" = { "type": "http.response.start", "status": 503, diff --git a/uvicorn/protocols/http/h11_impl.py b/uvicorn/protocols/http/h11_impl.py index 90bfaeadf..1e1872ef6 100644 --- a/uvicorn/protocols/http/h11_impl.py +++ b/uvicorn/protocols/http/h11_impl.py @@ -44,9 +44,7 @@ def _get_status_phrase(status_code: int) -> bytes: return b"" -STATUS_PHRASES = { - status_code: _get_status_phrase(status_code) for status_code in range(100, 600) -} +STATUS_PHRASES = {status_code: _get_status_phrase(status_code) for status_code in range(100, 600)} class H11Protocol(asyncio.Protocol): @@ -228,8 +226,7 @@ def handle_events(self) -> None: # Handle 503 responses when 'limit_concurrency' is exceeded. if self.limit_concurrency is not None and ( - len(self.connections) >= self.limit_concurrency - or len(self.tasks) >= self.limit_concurrency + len(self.connections) >= self.limit_concurrency or len(self.tasks) >= self.limit_concurrency ): app = service_unavailable message = "Exceeded concurrency limit." @@ -323,9 +320,7 @@ def on_response_complete(self) -> None: # Set a short Keep-Alive timeout. self._unset_keepalive_if_required() - self.timeout_keep_alive_task = self.loop.call_later( - self.timeout_keep_alive, self.timeout_keep_alive_handler - ) + self.timeout_keep_alive_task = self.loop.call_later(self.timeout_keep_alive, self.timeout_keep_alive_handler) # Unpause data reads if needed. self.flow.resume_reading() @@ -372,7 +367,7 @@ def timeout_keep_alive_handler(self) -> None: class RequestResponseCycle: def __init__( self, - scope: "HTTPScope", + scope: HTTPScope, conn: h11.Connection, transport: asyncio.Transport, flow: FlowControl, @@ -408,7 +403,7 @@ def __init__( self.response_complete = False # ASGI exception wrapper - async def run_asgi(self, app: "ASGI3Application") -> None: + async def run_asgi(self, app: ASGI3Application) -> None: try: result = await app( # type: ignore[func-returns-value] self.scope, self.receive, self.send @@ -533,9 +528,7 @@ async def send(self, message: ASGISendEvent) -> None: async def receive(self) -> ASGIReceiveEvent: if self.waiting_for_100_continue and not self.transport.is_closing(): headers: list[tuple[str, str]] = [] - event = h11.InformationalResponse( - status_code=100, headers=headers, reason="Continue" - ) + event = h11.InformationalResponse(status_code=100, headers=headers, reason="Continue") output = self.conn.send(event=event) self.transport.write(output) self.waiting_for_100_continue = False diff --git a/uvicorn/protocols/http/httptools_impl.py b/uvicorn/protocols/http/httptools_impl.py index 78e38154d..2950bd537 100644 --- a/uvicorn/protocols/http/httptools_impl.py +++ b/uvicorn/protocols/http/httptools_impl.py @@ -7,7 +7,7 @@ import urllib from asyncio.events import TimerHandle from collections import deque -from typing import Any, Callable, Deque, Literal, cast +from typing import Any, Callable, Literal, cast import httptools @@ -50,9 +50,7 @@ def _get_status_line(status_code: int) -> bytes: return b"".join([b"HTTP/1.1 ", str(status_code).encode(), b" ", phrase, b"\r\n"]) -STATUS_LINE = { - status_code: _get_status_line(status_code) for status_code in range(100, 600) -} +STATUS_LINE = {status_code: _get_status_line(status_code) for status_code in range(100, 600)} class HttpToolsProtocol(asyncio.Protocol): @@ -93,7 +91,7 @@ def __init__( self.server: tuple[str, int] | None = None self.client: tuple[str, int] | None = None self.scheme: Literal["http", "https"] | None = None - self.pipeline: Deque[tuple[RequestResponseCycle, ASGI3Application]] = deque() + self.pipeline: deque[tuple[RequestResponseCycle, ASGI3Application]] = deque() # Per-request state self.scope: HTTPScope = None # type: ignore[assignment] @@ -268,8 +266,7 @@ def on_headers_complete(self) -> None: # Handle 503 responses when 'limit_concurrency' is exceeded. if self.limit_concurrency is not None and ( - len(self.connections) >= self.limit_concurrency - or len(self.tasks) >= self.limit_concurrency + len(self.connections) >= self.limit_concurrency or len(self.tasks) >= self.limit_concurrency ): app = service_unavailable message = "Exceeded concurrency limit." @@ -302,9 +299,7 @@ def on_headers_complete(self) -> None: self.pipeline.appendleft((self.cycle, app)) def on_body(self, body: bytes) -> None: - if ( - self.parser.should_upgrade() and self._should_upgrade() - ) or self.cycle.response_complete: + if (self.parser.should_upgrade() and self._should_upgrade()) or self.cycle.response_complete: return self.cycle.body += body if len(self.cycle.body) > HIGH_WATER_LIMIT: @@ -312,9 +307,7 @@ def on_body(self, body: bytes) -> None: self.cycle.message_event.set() def on_message_complete(self) -> None: - if ( - self.parser.should_upgrade() and self._should_upgrade() - ) or self.cycle.response_complete: + if (self.parser.should_upgrade() and self._should_upgrade()) or self.cycle.response_complete: return self.cycle.more_body = False self.cycle.message_event.set() @@ -376,7 +369,7 @@ def timeout_keep_alive_handler(self) -> None: class RequestResponseCycle: def __init__( self, - scope: "HTTPScope", + scope: HTTPScope, transport: asyncio.Transport, flow: FlowControl, logger: logging.Logger, @@ -517,11 +510,7 @@ async def send(self, message: ASGISendEvent) -> None: self.keep_alive = False content.extend([name, b": ", value, b"\r\n"]) - if ( - self.chunked_encoding is None - and self.scope["method"] != "HEAD" - and status_code not in (204, 304) - ): + if self.chunked_encoding is None and self.scope["method"] != "HEAD" and status_code not in (204, 304): # Neither content-length nor transfer-encoding specified self.chunked_encoding = True content.append(b"transfer-encoding: chunked\r\n") diff --git a/uvicorn/protocols/utils.py b/uvicorn/protocols/utils.py index a064e0189..4e65806ca 100644 --- a/uvicorn/protocols/utils.py +++ b/uvicorn/protocols/utils.py @@ -53,7 +53,5 @@ def get_client_addr(scope: WWWScope) -> str: def get_path_with_query_string(scope: WWWScope) -> str: path_with_query_string = urllib.parse.quote(scope["path"]) if scope["query_string"]: - path_with_query_string = "{}?{}".format( - path_with_query_string, scope["query_string"].decode("ascii") - ) + path_with_query_string = "{}?{}".format(path_with_query_string, scope["query_string"].decode("ascii")) return path_with_query_string diff --git a/uvicorn/protocols/websockets/auto.py b/uvicorn/protocols/websockets/auto.py index 368b98242..08fd13678 100644 --- a/uvicorn/protocols/websockets/auto.py +++ b/uvicorn/protocols/websockets/auto.py @@ -1,7 +1,9 @@ +from __future__ import annotations + import asyncio import typing -AutoWebSocketsProtocol: typing.Optional[typing.Callable[..., asyncio.Protocol]] +AutoWebSocketsProtocol: typing.Callable[..., asyncio.Protocol] | None try: import websockets # noqa except ImportError: # pragma: no cover diff --git a/uvicorn/protocols/websockets/websockets_impl.py b/uvicorn/protocols/websockets/websockets_impl.py index 9aab66759..6d098d5af 100644 --- a/uvicorn/protocols/websockets/websockets_impl.py +++ b/uvicorn/protocols/websockets/websockets_impl.py @@ -3,16 +3,7 @@ import asyncio import http import logging -from typing import ( - Any, - List, - Literal, - Optional, - Sequence, - Tuple, - Union, - cast, -) +from typing import Any, Literal, Optional, Sequence, cast from urllib.parse import unquote import websockets @@ -61,7 +52,7 @@ def is_serving(self) -> bool: class WebSocketProtocol(WebSocketServerProtocol): - extra_headers: List[Tuple[str, str]] + extra_headers: list[tuple[str, str]] def __init__( self, @@ -117,8 +108,7 @@ def __init__( ) self.server_header = None self.extra_headers = [ - (name.decode("latin-1"), value.decode("latin-1")) - for name, value in server_state.default_headers + (name.decode("latin-1"), value.decode("latin-1")) for name, value in server_state.default_headers ] def connection_made( # type: ignore[override] @@ -136,16 +126,14 @@ def connection_made( # type: ignore[override] super().connection_made(transport) - def connection_lost(self, exc: Optional[Exception]) -> None: + def connection_lost(self, exc: Exception | None) -> None: self.connections.remove(self) if self.logger.isEnabledFor(TRACE_LOG_LEVEL): prefix = "%s:%d - " % self.client if self.client else "" self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection lost", prefix) - self.lost_connection_before_handshake = ( - not self.handshake_completed_event.is_set() - ) + self.lost_connection_before_handshake = not self.handshake_completed_event.is_set() self.handshake_completed_event.set() super().connection_lost(exc) if exc is None: @@ -162,9 +150,7 @@ def shutdown(self) -> None: def on_task_complete(self, task: asyncio.Task) -> None: self.tasks.discard(task) - async def process_request( - self, path: str, headers: Headers - ) -> Optional[HTTPResponse]: + async def process_request(self, path: str, headers: Headers) -> HTTPResponse | None: """ This hook is called to determine if the websocket should return an HTTP response and close. @@ -212,8 +198,8 @@ async def process_request( return self.initial_response def process_subprotocol( - self, headers: Headers, available_subprotocols: Optional[Sequence[Subprotocol]] - ) -> Optional[Subprotocol]: + self, headers: Headers, available_subprotocols: Sequence[Subprotocol] | None + ) -> Subprotocol | None: """ We override the standard 'process_subprotocol' behavior here so that we return whatever subprotocol is sent in the 'accept' message. @@ -223,8 +209,7 @@ def process_subprotocol( def send_500_response(self) -> None: msg = b"Internal Server Error" content = [ - b"HTTP/1.1 500 Internal Server Error\r\n" - b"content-type: text/plain; charset=utf-8\r\n", + b"HTTP/1.1 500 Internal Server Error\r\n" b"content-type: text/plain; charset=utf-8\r\n", b"content-length: " + str(len(msg)).encode("ascii") + b"\r\n", b"connection: close\r\n", b"\r\n", @@ -278,7 +263,7 @@ async def run_asgi(self) -> None: await self.handshake_completed_event.wait() self.transport.close() - async def asgi_send(self, message: "ASGISendEvent") -> None: + async def asgi_send(self, message: ASGISendEvent) -> None: message_type = message["type"] if not self.handshake_started_event.is_set(): @@ -290,9 +275,7 @@ async def asgi_send(self, message: "ASGISendEvent") -> None: get_path_with_query_string(self.scope), ) self.initial_response = None - self.accepted_subprotocol = cast( - Optional[Subprotocol], message.get("subprotocol") - ) + self.accepted_subprotocol = cast(Optional[Subprotocol], message.get("subprotocol")) if "headers" in message: self.extra_headers.extend( # ASGI spec requires bytes @@ -324,8 +307,7 @@ async def asgi_send(self, message: "ASGISendEvent") -> None: # websockets requires the status to be an enum. look it up. status = http.HTTPStatus(message["status"]) headers = [ - (name.decode("latin-1"), value.decode("latin-1")) - for name, value in message.get("headers", []) + (name.decode("latin-1"), value.decode("latin-1")) for name, value in message.get("headers", []) ] self.initial_response = (status, headers, b"") self.handshake_started_event.set() @@ -356,10 +338,7 @@ async def asgi_send(self, message: "ASGISendEvent") -> None: self.closed_event.set() else: - msg = ( - "Expected ASGI message 'websocket.send' or 'websocket.close'," - " but got '%s'." - ) + msg = "Expected ASGI message 'websocket.send' or 'websocket.close'," " but got '%s'." raise RuntimeError(msg % message_type) except ConnectionClosed as exc: raise ClientDisconnected from exc @@ -372,24 +351,16 @@ async def asgi_send(self, message: "ASGISendEvent") -> None: if not message.get("more_body", False): self.closed_event.set() else: - msg = ( - "Expected ASGI message 'websocket.http.response.body' " - "but got '%s'." - ) + msg = "Expected ASGI message 'websocket.http.response.body' " "but got '%s'." raise RuntimeError(msg % message_type) else: - msg = ( - "Unexpected ASGI message '%s', after sending 'websocket.close' " - "or response already completed." - ) + msg = "Unexpected ASGI message '%s', after sending 'websocket.close' " "or response already completed." raise RuntimeError(msg % message_type) async def asgi_receive( self, - ) -> Union[ - "WebSocketDisconnectEvent", "WebSocketConnectEvent", "WebSocketReceiveEvent" - ]: + ) -> WebSocketDisconnectEvent | WebSocketConnectEvent | WebSocketReceiveEvent: if not self.connect_sent: self.connect_sent = True return {"type": "websocket.connect"} diff --git a/uvicorn/protocols/websockets/wsproto_impl.py b/uvicorn/protocols/websockets/wsproto_impl.py index 85880a408..c92625277 100644 --- a/uvicorn/protocols/websockets/wsproto_impl.py +++ b/uvicorn/protocols/websockets/wsproto_impl.py @@ -168,7 +168,7 @@ def handle_connect(self, event: events.Request) -> None: path = unquote(raw_path) full_path = self.root_path + path full_raw_path = self.root_path.encode("ascii") + raw_path.encode("ascii") - self.scope: "WebSocketScope" = { + self.scope: WebSocketScope = { "type": "websocket", "asgi": {"version": self.config.asgi_version, "spec_version": "2.4"}, "http_version": "1.1", @@ -224,14 +224,8 @@ def send_500_response(self) -> None: (b"content-type", b"text/plain; charset=utf-8"), (b"connection", b"close"), ] - output = self.conn.send( - wsproto.events.RejectConnection( - status_code=500, headers=headers, has_body=True - ) - ) - output += self.conn.send( - wsproto.events.RejectData(data=b"Internal Server Error") - ) + output = self.conn.send(wsproto.events.RejectConnection(status_code=500, headers=headers, has_body=True)) + output += self.conn.send(wsproto.events.RejectData(data=b"Internal Server Error")) self.transport.write(output) async def run_asgi(self) -> None: @@ -269,7 +263,7 @@ async def send(self, message: ASGISendEvent) -> None: ) subprotocol = message.get("subprotocol") extra_headers = self.default_headers + list(message.get("headers", [])) - extensions: typing.List[Extension] = [] + extensions: list[Extension] = [] if self.config.ws_per_message_deflate: extensions.append(PerMessageDeflate()) if not self.transport.is_closing(): @@ -343,21 +337,14 @@ async def send(self, message: ASGISendEvent) -> None: self.close_sent = True code = message.get("code", 1000) reason = message.get("reason", "") or "" - self.queue.put_nowait( - {"type": "websocket.disconnect", "code": code} - ) - output = self.conn.send( - wsproto.events.CloseConnection(code=code, reason=reason) - ) + self.queue.put_nowait({"type": "websocket.disconnect", "code": code}) + output = self.conn.send(wsproto.events.CloseConnection(code=code, reason=reason)) if not self.transport.is_closing(): self.transport.write(output) self.transport.close() else: - msg = ( - "Expected ASGI message 'websocket.send' or 'websocket.close'," - " but got '%s'." - ) + msg = "Expected ASGI message 'websocket.send' or 'websocket.close'," " but got '%s'." raise RuntimeError(msg % message_type) except LocalProtocolError as exc: raise ClientDisconnected from exc @@ -365,24 +352,17 @@ async def send(self, message: ASGISendEvent) -> None: if message_type == "websocket.http.response.body": message = typing.cast("WebSocketResponseBodyEvent", message) body_finished = not message.get("more_body", False) - reject_data = events.RejectData( - data=message["body"], body_finished=body_finished - ) + reject_data = events.RejectData(data=message["body"], body_finished=body_finished) output = self.conn.send(reject_data) self.transport.write(output) if body_finished: - self.queue.put_nowait( - {"type": "websocket.disconnect", "code": 1006} - ) + self.queue.put_nowait({"type": "websocket.disconnect", "code": 1006}) self.close_sent = True self.transport.close() else: - msg = ( - "Expected ASGI message 'websocket.http.response.body' " - "but got '%s'." - ) + msg = "Expected ASGI message 'websocket.http.response.body' " "but got '%s'." raise RuntimeError(msg % message_type) else: diff --git a/uvicorn/server.py b/uvicorn/server.py index 1f0b726f8..c7645f3ce 100644 --- a/uvicorn/server.py +++ b/uvicorn/server.py @@ -126,18 +126,14 @@ def _share_socket( is_windows = platform.system() == "Windows" if config.workers > 1 and is_windows: # pragma: py-not-win32 sock = _share_socket(sock) # type: ignore[assignment] - server = await loop.create_server( - create_protocol, sock=sock, ssl=config.ssl, backlog=config.backlog - ) + server = await loop.create_server(create_protocol, sock=sock, ssl=config.ssl, backlog=config.backlog) self.servers.append(server) listeners = sockets elif config.fd is not None: # pragma: py-win32 # Use an existing socket, from a file descriptor. sock = socket.fromfd(config.fd, socket.AF_UNIX, socket.SOCK_STREAM) - server = await loop.create_server( - create_protocol, sock=sock, ssl=config.ssl, backlog=config.backlog - ) + server = await loop.create_server(create_protocol, sock=sock, ssl=config.ssl, backlog=config.backlog) assert server.sockets is not None # mypy listeners = server.sockets self.servers = [server] @@ -194,9 +190,7 @@ def _log_started_message(self, listeners: Sequence[socket.SocketType]) -> None: ) elif config.uds is not None: # pragma: py-win32 - logger.info( - "Uvicorn running on unix socket %s (Press CTRL+C to quit)", config.uds - ) + logger.info("Uvicorn running on unix socket %s (Press CTRL+C to quit)", config.uds) else: addr_format = "%s://%s:%d" @@ -211,11 +205,7 @@ def _log_started_message(self, listeners: Sequence[socket.SocketType]) -> None: protocol_name = "https" if config.ssl else "http" message = f"Uvicorn running on {addr_format} (Press CTRL+C to quit)" - color_message = ( - "Uvicorn running on " - + click.style(addr_format, bold=True) - + " (Press CTRL+C to quit)" - ) + color_message = "Uvicorn running on " + click.style(addr_format, bold=True) + " (Press CTRL+C to quit)" logger.info( message, protocol_name, @@ -244,9 +234,7 @@ async def on_tick(self, counter: int) -> bool: else: date_header = [] - self.server_state.default_headers = ( - date_header + self.config.encoded_headers - ) + self.server_state.default_headers = date_header + self.config.encoded_headers # Callback to `callback_notify` once every `timeout_notify` seconds. if self.config.callback_notify is not None: diff --git a/uvicorn/supervisors/__init__.py b/uvicorn/supervisors/__init__.py index deaf12ede..c90f24e4a 100644 --- a/uvicorn/supervisors/__init__.py +++ b/uvicorn/supervisors/__init__.py @@ -1,10 +1,12 @@ -from typing import TYPE_CHECKING, Type +from __future__ import annotations + +from typing import TYPE_CHECKING from uvicorn.supervisors.basereload import BaseReload from uvicorn.supervisors.multiprocess import Multiprocess if TYPE_CHECKING: - ChangeReload: Type[BaseReload] + ChangeReload: type[BaseReload] else: try: from uvicorn.supervisors.watchfilesreload import ( diff --git a/uvicorn/supervisors/basereload.py b/uvicorn/supervisors/basereload.py index 6e2e0c359..1c791a8fb 100644 --- a/uvicorn/supervisors/basereload.py +++ b/uvicorn/supervisors/basereload.py @@ -81,9 +81,7 @@ def startup(self) -> None: for sig in HANDLED_SIGNALS: signal.signal(sig, self.signal_handler) - self.process = get_subprocess( - config=self.config, target=self.target, sockets=self.sockets - ) + self.process = get_subprocess(config=self.config, target=self.target, sockets=self.sockets) self.process.start() def restart(self) -> None: @@ -95,9 +93,7 @@ def restart(self) -> None: self.process.terminate() self.process.join() - self.process = get_subprocess( - config=self.config, target=self.target, sockets=self.sockets - ) + self.process = get_subprocess(config=self.config, target=self.target, sockets=self.sockets) self.process.start() def shutdown(self) -> None: @@ -110,10 +106,8 @@ def shutdown(self) -> None: for sock in self.sockets: sock.close() - message = "Stopping reloader process [{}]".format(str(self.pid)) - color_message = "Stopping reloader process [{}]".format( - click.style(str(self.pid), fg="cyan", bold=True) - ) + message = f"Stopping reloader process [{str(self.pid)}]" + color_message = "Stopping reloader process [{}]".format(click.style(str(self.pid), fg="cyan", bold=True)) logger.info(message, extra={"color_message": color_message}) def should_restart(self) -> list[Path] | None: diff --git a/uvicorn/supervisors/multiprocess.py b/uvicorn/supervisors/multiprocess.py index 153b3d658..e0916721b 100644 --- a/uvicorn/supervisors/multiprocess.py +++ b/uvicorn/supervisors/multiprocess.py @@ -48,19 +48,15 @@ def run(self) -> None: self.shutdown() def startup(self) -> None: - message = "Started parent process [{}]".format(str(self.pid)) - color_message = "Started parent process [{}]".format( - click.style(str(self.pid), fg="cyan", bold=True) - ) + message = f"Started parent process [{str(self.pid)}]" + color_message = "Started parent process [{}]".format(click.style(str(self.pid), fg="cyan", bold=True)) logger.info(message, extra={"color_message": color_message}) for sig in HANDLED_SIGNALS: signal.signal(sig, self.signal_handler) for _idx in range(self.config.workers): - process = get_subprocess( - config=self.config, target=self.target, sockets=self.sockets - ) + process = get_subprocess(config=self.config, target=self.target, sockets=self.sockets) process.start() self.processes.append(process) @@ -69,8 +65,6 @@ def shutdown(self) -> None: process.terminate() process.join() - message = "Stopping parent process [{}]".format(str(self.pid)) - color_message = "Stopping parent process [{}]".format( - click.style(str(self.pid), fg="cyan", bold=True) - ) + message = f"Stopping parent process [{str(self.pid)}]" + color_message = "Stopping parent process [{}]".format(click.style(str(self.pid), fg="cyan", bold=True)) logger.info(message, extra={"color_message": color_message}) diff --git a/uvicorn/supervisors/statreload.py b/uvicorn/supervisors/statreload.py index 2e25dd4a9..70d0a6d5c 100644 --- a/uvicorn/supervisors/statreload.py +++ b/uvicorn/supervisors/statreload.py @@ -23,10 +23,7 @@ def __init__( self.mtimes: dict[Path, float] = {} if config.reload_excludes or config.reload_includes: - logger.warning( - "--reload-include and --reload-exclude have no effect unless " - "watchfiles is installed." - ) + logger.warning("--reload-include and --reload-exclude have no effect unless " "watchfiles is installed.") def should_restart(self) -> list[Path] | None: self.pause() diff --git a/uvicorn/supervisors/watchfilesreload.py b/uvicorn/supervisors/watchfilesreload.py index e1cb311f2..292a7bab8 100644 --- a/uvicorn/supervisors/watchfilesreload.py +++ b/uvicorn/supervisors/watchfilesreload.py @@ -13,20 +13,12 @@ class FileFilter: def __init__(self, config: Config): default_includes = ["*.py"] - self.includes = [ - default - for default in default_includes - if default not in config.reload_excludes - ] + self.includes = [default for default in default_includes if default not in config.reload_excludes] self.includes.extend(config.reload_includes) self.includes = list(set(self.includes)) default_excludes = [".*", ".py[cod]", ".sw.*", "~*"] - self.excludes = [ - default - for default in default_excludes - if default not in config.reload_includes - ] + self.excludes = [default for default in default_excludes if default not in config.reload_includes] self.exclude_dirs = [] for e in config.reload_excludes: p = Path(e) diff --git a/uvicorn/supervisors/watchgodreload.py b/uvicorn/supervisors/watchgodreload.py index 987909fd6..6f248faa7 100644 --- a/uvicorn/supervisors/watchgodreload.py +++ b/uvicorn/supervisors/watchgodreload.py @@ -22,20 +22,12 @@ class CustomWatcher(DefaultWatcher): def __init__(self, root_path: Path, config: Config): default_includes = ["*.py"] - self.includes = [ - default - for default in default_includes - if default not in config.reload_excludes - ] + self.includes = [default for default in default_includes if default not in config.reload_excludes] self.includes.extend(config.reload_includes) self.includes = list(set(self.includes)) default_excludes = [".*", ".py[cod]", ".sw.*", "~*"] - self.excludes = [ - default - for default in default_excludes - if default not in config.reload_includes - ] + self.excludes = [default for default in default_excludes if default not in config.reload_includes] self.excludes.extend(config.reload_excludes) self.excludes = list(set(self.excludes)) @@ -46,7 +38,7 @@ def __init__(self, root_path: Path, config: Config): self.resolved_root = root_path super().__init__(str(root_path)) - def should_watch_file(self, entry: "DirEntry") -> bool: + def should_watch_file(self, entry: DirEntry) -> bool: cached_result = self.watched_files.get(entry.path) if cached_result is not None: return cached_result @@ -71,7 +63,7 @@ def should_watch_file(self, entry: "DirEntry") -> bool: self.watched_files[entry.path] = False return False - def should_watch_dir(self, entry: "DirEntry") -> bool: + def should_watch_dir(self, entry: DirEntry) -> bool: cached_result = self.watched_dirs.get(entry.path) if cached_result is not None: return cached_result @@ -94,8 +86,7 @@ def should_watch_dir(self, entry: "DirEntry") -> bool: if is_watched: logger.debug( - "WatchGodReload detected a new excluded dir '%s' in '%s'; " - "Adding to exclude list.", + "WatchGodReload detected a new excluded dir '%s' in '%s'; " "Adding to exclude list.", entry_path.relative_to(self.resolved_root), str(self.resolved_root), ) @@ -115,8 +106,7 @@ def should_watch_dir(self, entry: "DirEntry") -> bool: for include_pattern in self.includes: if entry_path.match(include_pattern): logger.info( - "WatchGodReload detected a new reload dir '%s' in '%s'; " - "Adding to watch list.", + "WatchGodReload detected a new reload dir '%s' in '%s'; " "Adding to watch list.", str(entry_path.relative_to(self.resolved_root)), str(self.resolved_root), ) @@ -136,8 +126,7 @@ def __init__( sockets: list[socket], ) -> None: warnings.warn( - '"watchgod" is deprecated, you should switch ' - "to watchfiles (`pip install watchfiles`).", + '"watchgod" is deprecated, you should switch ' "to watchfiles (`pip install watchfiles`).", DeprecationWarning, ) super().__init__(config, target, sockets) diff --git a/uvicorn/workers.py b/uvicorn/workers.py index 0d6b49bdc..3b46471e3 100644 --- a/uvicorn/workers.py +++ b/uvicorn/workers.py @@ -1,8 +1,10 @@ +from __future__ import annotations + import asyncio import logging import signal import sys -from typing import Any, Dict +from typing import Any from gunicorn.arbiter import Arbiter from gunicorn.workers.base import Worker @@ -17,10 +19,10 @@ class UvicornWorker(Worker): rather than a WSGI callable. """ - CONFIG_KWARGS: Dict[str, Any] = {"loop": "auto", "http": "auto"} + CONFIG_KWARGS: dict[str, Any] = {"loop": "auto", "http": "auto"} def __init__(self, *args: Any, **kwargs: Any) -> None: - super(UvicornWorker, self).__init__(*args, **kwargs) + super().__init__(*args, **kwargs) logger = logging.getLogger("uvicorn.error") logger.handlers = self.log.error_log.handlers @@ -63,7 +65,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: def init_process(self) -> None: self.config.setup_event_loop() - super(UvicornWorker, self).init_process() + super().init_process() def init_signals(self) -> None: # Reset signals so Gunicorn doesn't swallow subprocess return codes