diff --git a/tests/protocols/test_websocket.py b/tests/protocols/test_websocket.py index 26e094d9f..4ee9e570f 100644 --- a/tests/protocols/test_websocket.py +++ b/tests/protocols/test_websocket.py @@ -51,6 +51,21 @@ async def asgi(self): break +async def wsresponse(url): + """ + A simple websocket connection request and response helper + """ + url = url.replace("ws:", "http:") + headers = { + "connection": "upgrade", + "upgrade": "websocket", + "Sec-WebSocket-Key": "x3JJHMbDL1EzLkh9GBhXDw==", + "Sec-WebSocket-Version": "13", + } + async with httpx.AsyncClient() as client: + return await client.get(url, headers=headers) + + @pytest.mark.anyio @pytest.mark.parametrize("ws_protocol_cls", WS_PROTOCOLS) @pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) @@ -968,11 +983,9 @@ async def app(scope, receive, send): disconnected_message = await receive() async def websocket_session(url): - with pytest.raises(websockets.exceptions.InvalidStatusCode) as exc_info: - async with websockets.client.connect(url): - pass # pragma: no cover - assert exc_info.value.status_code == 400 - # Websockets module currently does not read the response body from the socket. + response = await wsresponse(url) + assert response.status_code == 400 + assert response.content == b"goodbye" config = Config( app=app, @@ -1023,11 +1036,9 @@ async def app(scope, receive, send): disconnected_message = await receive() async def websocket_session(url): - with pytest.raises(websockets.exceptions.InvalidStatusCode) as exc_info: - async with websockets.client.connect(url): - pass # pragma: no cover - assert exc_info.value.status_code == 400 - # Websockets module currently does not read the response body from the socket. + response = await wsresponse(url) + assert response.status_code == 400 + assert response.content == (b"x" * 10) + (b"y" * 10) config = Config( app=app, @@ -1069,10 +1080,9 @@ async def app(scope, receive, send): await send(message) async def websocket_session(url): - with pytest.raises(websockets.exceptions.InvalidStatusCode) as exc_info: - async with websockets.client.connect(url): - pass # pragma: no cover - assert exc_info.value.status_code == 500 + response = await wsresponse(url) + assert response.status_code == 500 + assert response.content == b"Internal Server Error" config = Config( app=app,