Skip to content

Commit

Permalink
Create ws_protocol_cls fixture (#2049)
Browse files Browse the repository at this point in the history
  • Loading branch information
Kludex authored Jul 18, 2023
1 parent bf8f52b commit 8239373
Show file tree
Hide file tree
Showing 6 changed files with 196 additions and 110 deletions.
17 changes: 17 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import contextlib
import importlib.util
import os
import socket
import ssl
Expand All @@ -22,6 +23,7 @@
HAVE_TRUSTME = False

from uvicorn.config import LOGGING_CONFIG
from uvicorn.importer import import_from_string

# Note: We explicitly turn the propagate on just for tests, because pytest
# caplog not able to capture no-propagate loggers.
Expand Down Expand Up @@ -239,3 +241,18 @@ def _unused_port(socket_type: int) -> int:
@pytest.fixture
def unused_tcp_port() -> int:
return _unused_port(socket.SOCK_STREAM)


@pytest.fixture(
params=[
pytest.param(
"uvicorn.protocols.websockets.wsproto_impl:WSProtocol",
marks=pytest.mark.skipif(
not importlib.util.find_spec("wsproto"), reason="wsproto not installed."
),
),
"uvicorn.protocols.websockets.websockets_impl:WebSocketProtocol",
]
)
def ws_protocol_cls(request: pytest.FixtureRequest):
return import_from_string(request.param)
13 changes: 10 additions & 3 deletions tests/middleware/test_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import logging
import socket
import sys
import typing

import httpx
import pytest
Expand All @@ -19,6 +20,10 @@
except ImportError: # pragma: nocover
HTTP_PROTOCOLS = [H11Protocol]

if typing.TYPE_CHECKING:
from uvicorn.protocols.websockets.websockets_impl import WebSocketProtocol
from uvicorn.protocols.websockets.wsproto_impl import WSProtocol


@contextlib.contextmanager
def caplog_for_logger(caplog, logger_name):
Expand Down Expand Up @@ -90,9 +95,11 @@ async def test_trace_logging_on_http_protocol(


@pytest.mark.anyio
@pytest.mark.parametrize("ws_protocol", [("websockets"), ("wsproto")])
async def test_trace_logging_on_ws_protocol(
ws_protocol, caplog, logging_config, unused_tcp_port: int
ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]",
caplog,
logging_config,
unused_tcp_port: int,
):
async def websocket_app(scope, receive, send):
assert scope["type"] == "websocket"
Expand All @@ -111,7 +118,7 @@ async def open_connection(url):
app=websocket_app,
log_level="trace",
log_config=logging_config,
ws=ws_protocol,
ws=ws_protocol_cls,
port=unused_tcp_port,
)
with caplog_for_logger(caplog, "uvicorn.error"):
Expand Down
19 changes: 7 additions & 12 deletions tests/middleware/test_proxy_headers.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,19 @@
from typing import List, Union
from typing import TYPE_CHECKING, List, Type, Union

import httpx
import pytest
import websockets.client

from tests.protocols.test_http import HTTP_PROTOCOLS
from tests.response import Response
from tests.utils import run_server
from uvicorn._types import ASGIReceiveCallable, ASGISendCallable, Scope
from uvicorn.config import Config
from uvicorn.middleware.proxy_headers import ProxyHeadersMiddleware
from uvicorn.protocols.websockets.wsproto_impl import WSProtocol

try:
import websockets.client

if TYPE_CHECKING:
from uvicorn.protocols.websockets.websockets_impl import WebSocketProtocol

WS_PROTOCOLS = [WSProtocol, WebSocketProtocol]
except ImportError: # pragma: nocover
WS_PROTOCOLS = []
from uvicorn.protocols.websockets.wsproto_impl import WSProtocol


async def app(
Expand Down Expand Up @@ -119,11 +114,11 @@ async def test_proxy_headers_invalid_x_forwarded_for() -> None:


@pytest.mark.anyio
@pytest.mark.parametrize("ws_protocol_cls", WS_PROTOCOLS)
@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS)
@pytest.mark.skipif(not WS_PROTOCOLS, reason="websockets module not installed.")
async def test_proxy_headers_websocket_x_forwarded_proto(
ws_protocol_cls, http_protocol_cls, unused_tcp_port: int
ws_protocol_cls: "Type[WSProtocol | WebSocketProtocol]",
http_protocol_cls,
unused_tcp_port: int,
) -> None:
async def websocket_app(scope, receive, send):
scheme = scope["scheme"]
Expand Down
15 changes: 11 additions & 4 deletions tests/protocols/test_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import socket
import threading
import time
from typing import Optional, Union
from typing import TYPE_CHECKING, Optional, Type, Union

import pytest

Expand All @@ -19,6 +19,10 @@
except ImportError: # pragma: nocover
HttpToolsProtocol = None # type: ignore[misc,assignment]

if TYPE_CHECKING:
from uvicorn.protocols.websockets.websockets_impl import WebSocketProtocol
from uvicorn.protocols.websockets.wsproto_impl import WSProtocol


HTTP_PROTOCOLS = [p for p in [H11Protocol, HttpToolsProtocol] if p is not None]
WEBSOCKET_PROTOCOLS = WS_PROTOCOLS.keys()
Expand Down Expand Up @@ -729,6 +733,8 @@ async def test_100_continue_not_sent_when_body_not_consumed(protocol_cls):
@pytest.mark.anyio
@pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS)
async def test_supported_upgrade_request(protocol_cls):
pytest.importorskip("wsproto")

app = Response("Hello, world", media_type="text/plain")

protocol = get_connected_protocol(app, protocol_cls, ws="wsproto")
Expand Down Expand Up @@ -774,11 +780,12 @@ async def test_unsupported_ws_upgrade_request_warn_on_auto(

@pytest.mark.anyio
@pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS)
@pytest.mark.parametrize("ws", WEBSOCKET_PROTOCOLS)
async def test_http2_upgrade_request(protocol_cls, ws):
async def test_http2_upgrade_request(
protocol_cls, ws_protocol_cls: "Type[WSProtocol | WebSocketProtocol]"
):
app = Response("Hello, world", media_type="text/plain")

protocol = get_connected_protocol(app, protocol_cls, ws=ws)
protocol = get_connected_protocol(app, protocol_cls, ws=ws_protocol_cls)
protocol.data_received(UPGRADE_HTTP2_REQUEST)
await protocol.loop.run_one()
assert b"HTTP/1.1 200 OK" in protocol.transport.buffer
Expand Down
Loading

0 comments on commit 8239373

Please sign in to comment.