diff --git a/uvicorn/_types.py b/uvicorn/_types.py index 501d04337..e63f4723c 100644 --- a/uvicorn/_types.py +++ b/uvicorn/_types.py @@ -1,10 +1,10 @@ import sys -from typing import Dict, Iterable, Optional, Tuple, Union +from typing import Awaitable, Callable, Dict, Iterable, Optional, Tuple, Type, Union if sys.version_info < (3, 8): - from typing_extensions import Literal, TypedDict + from typing_extensions import Literal, Protocol, TypedDict else: - from typing import Literal, TypedDict + from typing import Literal, Protocol, TypedDict class ASGISpecInfo(TypedDict): @@ -65,3 +65,152 @@ class WebsocketScope(TypedDict): WWWScope = Union[HTTPScope, WebsocketScope] Scope = Union[HTTPScope, WebsocketScope, LifespanScope] + + +class HTTPRequestEvent(TypedDict): + type: Literal["http.request"] + body: bytes + more_body: bool + + +class HTTPResponseStartEvent(TypedDict): + type: Literal["http.response.start"] + status: int + headers: Iterable[Tuple[bytes, bytes]] + + +class HTTPResponseBodyEvent(TypedDict): + type: Literal["http.response.body"] + body: bytes + more_body: bool + + +class HTTPServerPushEvent(TypedDict): + type: Literal["http.response.push"] + path: str + headers: Iterable[Tuple[bytes, bytes]] + + +class HTTPDisconnectEvent(TypedDict): + type: Literal["http.disconnect"] + + +class WebsocketConnectEvent(TypedDict): + type: Literal["websocket.connect"] + + +class WebsocketAcceptEvent(TypedDict): + type: Literal["websocket.accept"] + subprotocol: Optional[str] + headers: Iterable[Tuple[bytes, bytes]] + + +class WebsocketReceiveEvent(TypedDict): + type: Literal["websocket.receive"] + bytes: Optional[bytes] + text: Optional[str] + + +class WebsocketSendEvent(TypedDict): + type: Literal["websocket.send"] + bytes: Optional[bytes] + text: Optional[str] + + +class WebsocketResponseStartEvent(TypedDict): + type: Literal["websocket.http.response.start"] + status: int + headers: Iterable[Tuple[bytes, bytes]] + + +class WebsocketResponseBodyEvent(TypedDict): + type: Literal["websocket.http.response.body"] + body: bytes + more_body: bool + + +class WebsocketDisconnectEvent(TypedDict): + type: Literal["websocket.disconnect"] + code: int + + +class WebsocketCloseEvent(TypedDict): + type: Literal["websocket.close"] + code: int + reason: Optional[str] + + +class LifespanStartupEvent(TypedDict): + type: Literal["lifespan.startup"] + + +class LifespanShutdownEvent(TypedDict): + type: Literal["lifespan.shutdown"] + + +class LifespanStartupCompleteEvent(TypedDict): + type: Literal["lifespan.startup.complete"] + + +class LifespanStartupFailedEvent(TypedDict): + type: Literal["lifespan.startup.failed"] + message: str + + +class LifespanShutdownCompleteEvent(TypedDict): + type: Literal["lifespan.shutdown.complete"] + + +class LifespanShutdownFailedEvent(TypedDict): + type: Literal["lifespan.shutdown.failed"] + message: str + + +ASGIReceiveEvent = Union[ + HTTPRequestEvent, + HTTPDisconnectEvent, + WebsocketConnectEvent, + WebsocketReceiveEvent, + WebsocketDisconnectEvent, + LifespanStartupEvent, + LifespanShutdownEvent, +] + + +ASGISendEvent = Union[ + HTTPResponseStartEvent, + HTTPResponseBodyEvent, + HTTPServerPushEvent, + HTTPDisconnectEvent, + WebsocketAcceptEvent, + WebsocketSendEvent, + WebsocketResponseStartEvent, + WebsocketResponseBodyEvent, + WebsocketCloseEvent, + LifespanStartupCompleteEvent, + LifespanStartupFailedEvent, + LifespanShutdownCompleteEvent, + LifespanShutdownFailedEvent, +] + + +ASGIReceiveCallable = Callable[[], Awaitable[ASGIReceiveEvent]] +ASGISendCallable = Callable[[ASGISendEvent], Awaitable[None]] + + +class ASGI2Protocol(Protocol): + def __init__(self, scope: Scope) -> None: + ... + + async def __call__( + self, receive: ASGIReceiveCallable, send: ASGISendCallable + ) -> None: + ... + + +ASGI2Application = Type[ASGI2Protocol] +ASGI3Application = Callable[ + [Scope, ASGIReceiveCallable, ASGISendCallable], + Awaitable[None], +] +ASGIApplication = Union[ASGI2Application, ASGI3Application] diff --git a/uvicorn/middleware/wsgi.py b/uvicorn/middleware/wsgi.py index fca747ed7..23d091b91 100644 --- a/uvicorn/middleware/wsgi.py +++ b/uvicorn/middleware/wsgi.py @@ -2,9 +2,18 @@ import concurrent.futures import io import sys +from typing import Awaitable, Dict, Iterable, Optional, Tuple +from uvicorn._types import ( + ASGI3Application, + ASGIReceiveCallable, + ASGISendCallable, + ASGISendEvent, + HTTPScope, +) -def build_environ(scope, message, body): + +def build_environ(scope: HTTPScope, message: ASGISendEvent, body: bytes) -> Dict: """ Builds a scope and request message into a WSGI environ object. """ @@ -54,18 +63,25 @@ def build_environ(scope, message, body): class WSGIMiddleware: - def __init__(self, app, workers=10): + def __init__(self, app: ASGI3Application, workers: int = 10) -> None: self.app = app self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=workers) - async def __call__(self, scope, receive, send): + async def __call__( + self, scope: HTTPScope, receive: ASGIReceiveCallable, send: ASGISendCallable + ) -> None: assert scope["type"] == "http" instance = WSGIResponder(self.app, self.executor, scope) await instance(receive, send) class WSGIResponder: - def __init__(self, app, executor, scope): + def __init__( + self, + app: ASGI3Application, + executor: concurrent.futures.ThreadPoolExecutor, + scope: HTTPScope, + ) -> Awaitable: self.app = app self.executor = executor self.scope = scope @@ -75,9 +91,11 @@ def __init__(self, app, executor, scope): self.send_queue = [] self.loop = None self.response_started = False - self.exc_info = None + self.exc_info: Optional[str] = None - async def __call__(self, receive, send): + async def __call__( + self, receive: ASGIReceiveCallable, send: ASGISendCallable + ) -> None: message = await receive() body = message.get("body", b"") more_body = message.get("more_body", False) @@ -100,7 +118,7 @@ async def __call__(self, receive, send): if self.exc_info is not None: raise self.exc_info[0].with_traceback(self.exc_info[1], self.exc_info[2]) - async def sender(self, send): + async def sender(self, send: ASGISendCallable) -> None: while True: if self.send_queue: message = self.send_queue.pop(0) @@ -111,7 +129,12 @@ async def sender(self, send): await self.send_event.wait() self.send_event.clear() - def start_response(self, status, response_headers, exc_info=None): + def start_response( + self, + status: str, + response_headers: Iterable[Tuple[bytes, bytes]], + exc_info: Optional[str] = None, + ) -> None: self.exc_info = exc_info if not self.response_started: self.response_started = True @@ -130,7 +153,7 @@ def start_response(self, status, response_headers, exc_info=None): ) self.loop.call_soon_threadsafe(self.send_event.set) - def wsgi(self, environ, start_response): + def wsgi(self, environ: Dict, start_response: Awaitable[None]) -> None: for chunk in self.app(environ, start_response): self.send_queue.append( {"type": "http.response.body", "body": chunk, "more_body": True}