Skip to content

Commit

Permalink
Add type annotation to wsproto_impl.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Kludex committed Nov 2, 2022
1 parent c3aa2c0 commit 2522f90
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 28 deletions.
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ files =
uvicorn/protocols/http/__init__.py,
uvicorn/protocols/websockets/__init__.py,
uvicorn/protocols/websockets/websockets_impl.py,
uvicorn/protocols/websockets/wsproto_impl.py,
uvicorn/protocols/http/h11_impl.py,
uvicorn/protocols/http/httptools_impl.py,
tests/middleware/test_wsgi.py,
Expand Down
2 changes: 2 additions & 0 deletions uvicorn/protocols/http/h11_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@ def handle_events(self) -> None:
break

elif event_type is h11.Request:
event = cast(h11.Request, event)
self.headers = [(key.lower(), value) for key, value in event.headers]
raw_path, _, query_string = event.target.partition(b"?")
self.scope = { # type: ignore[typeddict-item]
Expand Down Expand Up @@ -258,6 +259,7 @@ def handle_events(self) -> None:
self.tasks.add(task)

elif event_type is h11.Data:
event = cast(h11.Data, event)
if self.conn.our_state is h11.DONE:
continue
self.cycle.body += event.data
Expand Down
81 changes: 53 additions & 28 deletions uvicorn/protocols/websockets/wsproto_impl.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import logging
import typing
from urllib.parse import unquote

import h11
Expand All @@ -9,17 +10,36 @@
from wsproto.extensions import PerMessageDeflate
from wsproto.utilities import RemoteProtocolError

from uvicorn.config import Config
from uvicorn.logging import TRACE_LOG_LEVEL
from uvicorn.protocols.utils import (
get_local_addr,
get_path_with_query_string,
get_remote_addr,
is_ssl,
)
from uvicorn.server import ServerState

if typing.TYPE_CHECKING:
from asgiref.typing import (
ASGISendEvent,
WebSocketAcceptEvent,
WebSocketCloseEvent,
WebSocketConnectEvent,
WebSocketDisconnectEvent,
WebSocketReceiveEvent,
WebSocketScope,
WebSocketSendEvent,
)


class WSProtocol(asyncio.Protocol):
def __init__(self, config, server_state, _loop=None):
def __init__(
self,
config: Config,
server_state: ServerState,
_loop: typing.Optional[asyncio.AbstractEventLoop] = None,
) -> None:
if not config.loaded:
config.load()

Expand All @@ -35,14 +55,14 @@ def __init__(self, config, server_state, _loop=None):
self.default_headers = server_state.default_headers

# Connection state
self.transport = None
self.server = None
self.client = None
self.scheme = None
self.transport: asyncio.Transport = None # type: ignore[assignment]
self.server: typing.Optional[typing.Tuple[str, int]] = None
self.client: typing.Optional[typing.Tuple[str, int]] = None
self.scheme: typing.Literal["wss", "ws"] = None # type: ignore[assignment]

# WebSocket state
self.connect_event = None
self.queue = asyncio.Queue()
self.queue: asyncio.Queue[WebSocketReceiveEvent] = asyncio.Queue()
self.handshake_complete = False
self.close_sent = False

Expand All @@ -58,41 +78,44 @@ def __init__(self, config, server_state, _loop=None):

# Protocol interface

def connection_made(self, transport):
def connection_made( # type: ignore[override]
self, transport: asyncio.Transport
) -> None:
self.connections.add(self)
self.transport = transport
self.server = get_local_addr(transport)
self.client = get_remote_addr(transport)
self.scheme = "wss" if is_ssl(transport) else "ws"

if self.logger.level <= TRACE_LOG_LEVEL:
prefix = "%s:%d - " % tuple(self.client) if self.client else ""
prefix = "%s:%d - " % self.client if self.client else ""
self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection made", prefix)

def connection_lost(self, exc):
def connection_lost(self, exc: typing.Optional[Exception]) -> None:
self.queue.put_nowait({"type": "websocket.disconnect"})
self.connections.remove(self)

if self.logger.level <= TRACE_LOG_LEVEL:
prefix = "%s:%d - " % tuple(self.client) if self.client else ""
prefix = "%s:%d - " % self.client if self.client else ""
self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection lost", prefix)

if exc is None:
self.transport.close()

def eof_received(self):
def eof_received(self) -> None:
pass

def data_received(self, data):
def data_received(self, data: bytes) -> None:
try:
self.conn.receive_data(data)
except RemoteProtocolError as err:
self.transport.write(self.conn.send(err.event_hint))
# TODO: Remove `type: ignore` when wsproto fixes the type annotation.
self.transport.write(self.conn.send(err.event_hint)) # type: ignore[arg-type]
self.transport.close()
else:
self.handle_events()

def handle_events(self):
def handle_events(self) -> None:
for event in self.conn.events():
if isinstance(event, events.Request):
self.handle_connect(event)
Expand All @@ -109,19 +132,19 @@ def handle_events(self):
elif isinstance(event, events.Ping):
self.handle_ping(event)

def pause_writing(self):
def pause_writing(self) -> None:
"""
Called by the transport when the write buffer exceeds the high water mark.
"""
self.writable.clear()

def resume_writing(self):
def resume_writing(self) -> None:
"""
Called by the transport when the write buffer drops below the low water mark.
"""
self.writable.set()

def shutdown(self):
def shutdown(self) -> None:
if self.handshake_complete:
self.queue.put_nowait({"type": "websocket.disconnect", "code": 1012})
output = self.conn.send(wsproto.events.CloseConnection(code=1012))
Expand All @@ -130,12 +153,12 @@ def shutdown(self):
self.send_500_response()
self.transport.close()

def on_task_complete(self, task):
def on_task_complete(self, task: asyncio.Task) -> None:
self.tasks.discard(task)

# Event handlers

def handle_connect(self, event):
def handle_connect(self, event: events.Request) -> None:
self.connect_event = event
headers = [(b"host", event.host.encode())]
headers += [(key.lower(), value) for key, value in event.extra_headers]
Expand All @@ -159,7 +182,9 @@ def handle_connect(self, event):
task.add_done_callback(self.on_task_complete)
self.tasks.add(task)

def handle_no_connect(self, event):
def handle_no_connect(
self, event: typing.Union[events.RejectData, events.RejectConnection]
) -> None:
headers = [
(b"content-type", b"text/plain; charset=utf-8"),
(b"connection", b"close"),
Expand All @@ -173,7 +198,7 @@ def handle_no_connect(self, event):
self.transport.write(output)
self.transport.close()

def handle_text(self, event):
def handle_text(self, event: events.TextMessage) -> None:
self.text += event.data
if event.message_finished:
self.queue.put_nowait({"type": "websocket.receive", "text": self.text})
Expand All @@ -182,7 +207,7 @@ def handle_text(self, event):
self.read_paused = True
self.transport.pause_reading()

def handle_bytes(self, event):
def handle_bytes(self, event: events.BytesMessage) -> None:
self.bytes += event.data
# todo: we may want to guard the size of self.bytes and self.text
if event.message_finished:
Expand All @@ -192,16 +217,16 @@ def handle_bytes(self, event):
self.read_paused = True
self.transport.pause_reading()

def handle_close(self, event):
def handle_close(self, event: events.CloseConnection) -> None:
if self.conn.state == ConnectionState.REMOTE_CLOSING:
self.transport.write(self.conn.send(event.response()))
self.queue.put_nowait({"type": "websocket.disconnect", "code": event.code})
self.transport.close()

def handle_ping(self, event):
def handle_ping(self, event: events.Ping) -> None:
self.transport.write(self.conn.send(event.response()))

def send_500_response(self):
def send_500_response(self) -> None:
headers = [
(b"content-type", b"text/plain; charset=utf-8"),
(b"connection", b"close"),
Expand All @@ -219,7 +244,7 @@ def send_500_response(self):
output += self.conn.send(msg)
self.transport.write(output)

async def run_asgi(self):
async def run_asgi(self) -> None:
try:
result = await self.app(self.scope, self.receive, self.send)
except BaseException:
Expand All @@ -238,7 +263,7 @@ async def run_asgi(self):
self.logger.error(msg, result)
self.transport.close()

async def send(self, message):
async def send(self, message: WebSocketSendEvent) -> None:
await self.writable.wait()

message_type = message["type"]
Expand Down Expand Up @@ -319,7 +344,7 @@ async def send(self, message):
msg = "Unexpected ASGI message '%s', after sending 'websocket.close'."
raise RuntimeError(msg % message_type)

async def receive(self):
async def receive(self) -> WebSocketReceiveEvent:
message = await self.queue.get()
if self.read_paused and self.queue.empty():
self.read_paused = False
Expand Down

0 comments on commit 2522f90

Please sign in to comment.