Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support the WebSocket Denial Response ASGI extension #1916

Merged
merged 23 commits into from
Dec 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
1a32e1a
create test for websocket responses
kristjanvalur Mar 19, 2023
9c37602
update websockets protocol
kristjanvalur Mar 19, 2023
ad649ed
update wsproto
kristjanvalur Mar 19, 2023
ed7a4dd
add multi-body response test
kristjanvalur Mar 19, 2023
8f83cab
Update tests/protocols/test_websocket.py
kristjanvalur Mar 20, 2023
4d74d4d
fix tests
kristjanvalur Mar 20, 2023
3996b48
fix mypy problems
kristjanvalur Mar 20, 2023
4390fb7
Simply fail if application passes an invalid status code
kristjanvalur Mar 20, 2023
4f7f1f4
Add a test for invalid message order
kristjanvalur Mar 20, 2023
2f47ae4
Update uvicorn/protocols/websockets/websockets_impl.py
kristjanvalur Mar 22, 2023
7a745e4
Fix initial response error handling and unit-test
kristjanvalur Mar 22, 2023
c9586bb
Add a similar missing-body test
kristjanvalur Mar 22, 2023
4047a3a
Use httpx to check rejection response body
kristjanvalur Mar 27, 2023
46d654d
Add test showing how content-length/transfer-encoding is automaticall…
kristjanvalur Mar 29, 2023
d2944e0
Do not send the response start until the first response body is received
kristjanvalur Mar 30, 2023
8c9db89
Update unittest to use protocol fixture
kristjanvalur Dec 3, 2023
198be2e
Check wsproto response status code
kristjanvalur Dec 3, 2023
946eb7c
Fix linter
Kludex Dec 17, 2023
95f3b03
Send response headers as soon as we have them
Kludex Dec 17, 2023
db586a1
Add test to make sure a single `websocket.http.response.start` is sent
Kludex Dec 17, 2023
f81f3f0
Modify the conditionals on wsproto
Kludex Dec 17, 2023
0876494
lint
Kludex Dec 17, 2023
798613a
Remove handshake_complete
Kludex Dec 17, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
349 changes: 345 additions & 4 deletions tests/protocols/test_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,16 @@
from websockets.extensions.permessage_deflate import ClientPerMessageDeflateFactory
from websockets.typing import Subprotocol

from tests.response import Response
from tests.utils import run_server
from uvicorn._types import (
ASGIReceiveCallable,
ASGIReceiveEvent,
ASGISendCallable,
Scope,
WebSocketCloseEvent,
WebSocketDisconnectEvent,
WebSocketResponseStartEvent,
)
from uvicorn.config import Config
from uvicorn.protocols.websockets.websockets_impl import WebSocketProtocol
Expand Down Expand Up @@ -55,6 +58,21 @@ async def asgi(self):
break


async def wsresponse(url):
"""
A simple websocket connection request and response helper
"""
url = url.replace("ws:", "http:")
headers = {
"connection": "upgrade",
"upgrade": "websocket",
"Sec-WebSocket-Key": "x3JJHMbDL1EzLkh9GBhXDw==",
"Sec-WebSocket-Version": "13",
}
async with httpx.AsyncClient() as client:
return await client.get(url, headers=headers)


@pytest.mark.anyio
async def test_invalid_upgrade(
ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]",
Expand Down Expand Up @@ -942,7 +960,10 @@ async def test_server_reject_connection(
http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]",
unused_tcp_port: int,
):
disconnected_message: ASGIReceiveEvent = {} # type: ignore

async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
nonlocal disconnected_message
assert scope["type"] == "websocket"

# Pull up first recv message.
Expand All @@ -955,15 +976,241 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable

# This doesn't raise `TypeError`:
# See https://github.com/encode/uvicorn/issues/244
disconnected_message = await receive()

async def websocket_session(url):
with pytest.raises(websockets.exceptions.InvalidStatusCode) as exc_info:
async with websockets.client.connect(url):
pass # pragma: no cover
assert exc_info.value.status_code == 403

config = Config(
app=app,
ws=ws_protocol_cls,
http=http_protocol_cls,
lifespan="off",
port=unused_tcp_port,
)
async with run_server(config):
await websocket_session(f"ws://127.0.0.1:{unused_tcp_port}")

assert disconnected_message == {"type": "websocket.disconnect", "code": 1006}


@pytest.mark.anyio
async def test_server_reject_connection_with_response(
ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]",
http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]",
unused_tcp_port: int,
):
disconnected_message = {}

async def app(scope, receive, send):
nonlocal disconnected_message
assert scope["type"] == "websocket"
assert "websocket.http.response" in scope["extensions"]

# Pull up first recv message.
message = await receive()
assert message["type"] == "websocket.disconnect"
assert message["type"] == "websocket.connect"

# Reject the connection with a response
response = Response(b"goodbye", status_code=400)
await response(scope, receive, send)
disconnected_message = await receive()

async def websocket_session(url):
response = await wsresponse(url)
assert response.status_code == 400
assert response.content == b"goodbye"

config = Config(
app=app,
ws=ws_protocol_cls,
http=http_protocol_cls,
lifespan="off",
port=unused_tcp_port,
)
async with run_server(config):
await websocket_session(f"ws://127.0.0.1:{unused_tcp_port}")

assert disconnected_message == {"type": "websocket.disconnect", "code": 1006}


@pytest.mark.anyio
async def test_server_reject_connection_with_multibody_response(
ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]",
http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]",
unused_tcp_port: int,
):
disconnected_message: ASGIReceiveEvent = {} # type: ignore

async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
nonlocal disconnected_message
assert scope["type"] == "websocket"
assert "extensions" in scope
assert "websocket.http.response" in scope["extensions"]

# Pull up first recv message.
message = await receive()
assert message["type"] == "websocket.connect"
await send(
{
"type": "websocket.http.response.start",
"status": 400,
"headers": [
(b"Content-Length", b"20"),
(b"Content-Type", b"text/plain"),
],
}
)
await send(
{
"type": "websocket.http.response.body",
"body": b"x" * 10,
"more_body": True,
}
)
await send({"type": "websocket.http.response.body", "body": b"y" * 10})
disconnected_message = await receive()

async def websocket_session(url: str):
try:
response = await wsresponse(url)
assert response.status_code == 400
assert response.content == (b"x" * 10) + (b"y" * 10)

config = Config(
app=app,
ws=ws_protocol_cls,
http=http_protocol_cls,
lifespan="off",
port=unused_tcp_port,
)
async with run_server(config):
await websocket_session(f"ws://127.0.0.1:{unused_tcp_port}")

assert disconnected_message == {"type": "websocket.disconnect", "code": 1006}


@pytest.mark.anyio
async def test_server_reject_connection_with_invalid_status(
ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]",
http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]",
unused_tcp_port: int,
):
# this test checks that even if there is an error in the response, the server
# can successfully send a 500 error back to the client
async def app(scope, receive, send):
assert scope["type"] == "websocket"
assert "websocket.http.response" in scope["extensions"]

# Pull up first recv message.
message = await receive()
assert message["type"] == "websocket.connect"

message = {
"type": "websocket.http.response.start",
"status": 700, # invalid status code
"headers": [(b"Content-Length", b"0"), (b"Content-Type", b"text/plain")],
}
await send(message)
message = {
"type": "websocket.http.response.body",
"body": b"",
}
await send(message)

async def websocket_session(url):
response = await wsresponse(url)
assert response.status_code == 500
assert response.content == b"Internal Server Error"

config = Config(
app=app,
ws=ws_protocol_cls,
http=http_protocol_cls,
lifespan="off",
port=unused_tcp_port,
)
async with run_server(config):
await websocket_session(f"ws://127.0.0.1:{unused_tcp_port}")


@pytest.mark.anyio
async def test_server_reject_connection_with_body_nolength(
ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]",
http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]",
unused_tcp_port: int,
):
# test that the server can send a response with a body but no content-length
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
assert scope["type"] == "websocket"
assert "extensions" in scope
assert "websocket.http.response" in scope["extensions"]

# Pull up first recv message.
message = await receive()
assert message["type"] == "websocket.connect"

await send(
{
"type": "websocket.http.response.start",
"status": 403,
"headers": [],
}
)
await send({"type": "websocket.http.response.body", "body": b"hardbody"})

async def websocket_session(url):
response = await wsresponse(url)
assert response.status_code == 403
assert response.content == b"hardbody"
if ws_protocol_cls == WSProtocol: # pragma: no cover
# wsproto automatically makes the message chunked
assert response.headers["transfer-encoding"] == "chunked"
else: # pragma: no cover
# websockets automatically adds a content-length
assert response.headers["content-length"] == "8"

config = Config(
app=app,
ws=ws_protocol_cls,
http=http_protocol_cls,
lifespan="off",
port=unused_tcp_port,
)
async with run_server(config):
await websocket_session(f"ws://127.0.0.1:{unused_tcp_port}")


@pytest.mark.anyio
async def test_server_reject_connection_with_invalid_msg(
ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]",
http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]",
unused_tcp_port: int,
):
async def app(scope, receive, send):
assert scope["type"] == "websocket"
assert "websocket.http.response" in scope["extensions"]

# Pull up first recv message.
message = await receive()
assert message["type"] == "websocket.connect"

message = {
"type": "websocket.http.response.start",
"status": 404,
"headers": [(b"Content-Length", b"0"), (b"Content-Type", b"text/plain")],
}
await send(message)
# send invalid message. This will raise an exception here
await send(message)

async def websocket_session(url):
with pytest.raises(websockets.exceptions.InvalidStatusCode) as exc_info:
async with websockets.client.connect(url):
pass # pragma: no cover
except Exception:
pass
assert exc_info.value.status_code == 404

config = Config(
app=app,
Expand All @@ -976,6 +1223,100 @@ async def websocket_session(url: str):
await websocket_session(f"ws://127.0.0.1:{unused_tcp_port}")


@pytest.mark.anyio
async def test_server_reject_connection_with_missing_body(
ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]",
http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]",
unused_tcp_port: int,
):
async def app(scope, receive, send):
assert scope["type"] == "websocket"
assert "websocket.http.response" in scope["extensions"]

# Pull up first recv message.
message = await receive()
assert message["type"] == "websocket.connect"

message = {
"type": "websocket.http.response.start",
"status": 404,
"headers": [(b"Content-Length", b"0"), (b"Content-Type", b"text/plain")],
}
await send(message)
# no further message

async def websocket_session(url):
with pytest.raises(websockets.exceptions.InvalidStatusCode) as exc_info:
async with websockets.client.connect(url):
pass # pragma: no cover
assert exc_info.value.status_code == 404

config = Config(
app=app,
ws=ws_protocol_cls,
http=http_protocol_cls,
lifespan="off",
port=unused_tcp_port,
)
async with run_server(config):
await websocket_session(f"ws://127.0.0.1:{unused_tcp_port}")


@pytest.mark.anyio
async def test_server_multiple_websocket_http_response_start_events(
ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]",
http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]",
unused_tcp_port: int,
):
"""
The server should raise an exception if it sends multiple
websocket.http.response.start events.
"""
exception_message: typing.Optional[str] = None

async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
nonlocal exception_message
assert scope["type"] == "websocket"
assert "extensions" in scope
assert "websocket.http.response" in scope["extensions"]

# Pull up first recv message.
message = await receive()
assert message["type"] == "websocket.connect"

start_event: WebSocketResponseStartEvent = {
"type": "websocket.http.response.start",
"status": 404,
"headers": [(b"Content-Length", b"0"), (b"Content-Type", b"text/plain")],
}
await send(start_event)
try:
await send(start_event)
except Exception as exc:
exception_message = str(exc)

async def websocket_session(url: str):
with pytest.raises(websockets.exceptions.InvalidStatusCode) as exc_info:
async with websockets.client.connect(url):
pass
assert exc_info.value.status_code == 404

config = Config(
app=app,
ws=ws_protocol_cls,
http=http_protocol_cls,
lifespan="off",
port=unused_tcp_port,
)
async with run_server(config):
await websocket_session(f"ws://127.0.0.1:{unused_tcp_port}")

assert exception_message == (
"Expected ASGI message 'websocket.http.response.body' but got "
"'websocket.http.response.start'."
)


@pytest.mark.anyio
async def test_server_can_read_messages_in_buffer_after_close(
ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]",
Expand Down
Loading