diff --git a/docs/deployment.md b/docs/deployment.md index 726c58749..60d1af763 100644 --- a/docs/deployment.md +++ b/docs/deployment.md @@ -61,6 +61,9 @@ Options: [default: 16777216] --ws-ping-interval FLOAT WebSocket ping interval [default: 20.0] --ws-ping-timeout FLOAT WebSocket ping timeout [default: 20.0] + --ws-per-message-deflate BOOLEAN + WebSocket per-message-deflate compression + [default: True] --lifespan [auto|on|off] Lifespan implementation. [default: auto] --interface [auto|asgi3|asgi2|wsgi] Select ASGI3, ASGI2, or WSGI as the diff --git a/docs/index.md b/docs/index.md index 624f0eefd..63460bcc8 100644 --- a/docs/index.md +++ b/docs/index.md @@ -134,6 +134,9 @@ Options: [default: 16777216] --ws-ping-interval FLOAT WebSocket ping interval [default: 20.0] --ws-ping-timeout FLOAT WebSocket ping timeout [default: 20.0] + --ws-per-message-deflate BOOLEAN + WebSocket per-message-deflate compression + [default: True] --lifespan [auto|on|off] Lifespan implementation. [default: auto] --interface [auto|asgi3|asgi2|wsgi] Select ASGI3, ASGI2, or WSGI as the diff --git a/tests/protocols/test_websocket.py b/tests/protocols/test_websocket.py index b444b3ffc..63def1785 100644 --- a/tests/protocols/test_websocket.py +++ b/tests/protocols/test_websocket.py @@ -120,6 +120,35 @@ async def open_connection(url): assert "permessage-deflate" in extension_names +@pytest.mark.asyncio +@pytest.mark.parametrize("ws_protocol_cls", WS_PROTOCOLS) +@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) +async def test_can_disable_permessage_deflate_extension( + ws_protocol_cls, http_protocol_cls +): + class App(WebSocketResponse): + async def websocket_connect(self, message): + await self.send({"type": "websocket.accept"}) + + async def open_connection(url): + # enable per-message deflate on the client, so that we can check the server + # won't support it when it's disabled. + extension_factories = [ClientPerMessageDeflateFactory()] + async with websockets.connect(url, extensions=extension_factories) as websocket: + return [extension.name for extension in websocket.extensions] + + config = Config( + app=App, + ws=ws_protocol_cls, + http=http_protocol_cls, + lifespan="off", + ws_per_message_deflate=False, + ) + async with run_server(config): + extension_names = await open_connection("ws://127.0.0.1:8000") + assert "permessage-deflate" not in extension_names + + @pytest.mark.asyncio @pytest.mark.parametrize("ws_protocol_cls", WS_PROTOCOLS) @pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) diff --git a/uvicorn/config.py b/uvicorn/config.py index 59c94b412..681ca0caf 100644 --- a/uvicorn/config.py +++ b/uvicorn/config.py @@ -205,6 +205,7 @@ def __init__( ws_max_size: int = 16 * 1024 * 1024, ws_ping_interval: Optional[float] = 20, ws_ping_timeout: Optional[float] = 20, + ws_per_message_deflate: Optional[bool] = True, lifespan: LifespanType = "auto", env_file: Optional[Union[str, os.PathLike]] = None, log_config: Optional[Union[dict, str]] = LOGGING_CONFIG, @@ -251,6 +252,7 @@ def __init__( self.ws_max_size = ws_max_size self.ws_ping_interval = ws_ping_interval self.ws_ping_timeout = ws_ping_timeout + self.ws_per_message_deflate = ws_per_message_deflate self.lifespan = lifespan self.log_config = log_config self.log_level = log_level diff --git a/uvicorn/main.py b/uvicorn/main.py index 798e7f854..cd853b613 100644 --- a/uvicorn/main.py +++ b/uvicorn/main.py @@ -154,6 +154,13 @@ def print_version(ctx: click.Context, param: click.Parameter, value: bool) -> No help="WebSocket ping timeout", show_default=True, ) +@click.option( + "--ws-per-message-deflate", + type=bool, + default=True, + help="WebSocket per-message-deflate compression", + show_default=True, +) @click.option( "--lifespan", type=LIFESPAN_CHOICES, @@ -344,6 +351,7 @@ def main( ws_max_size: int, ws_ping_interval: float, ws_ping_timeout: float, + ws_per_message_deflate: bool, lifespan: str, interface: str, debug: bool, @@ -389,6 +397,7 @@ def main( "ws_max_size": ws_max_size, "ws_ping_interval": ws_ping_interval, "ws_ping_timeout": ws_ping_timeout, + "ws_per_message_deflate": ws_per_message_deflate, "lifespan": lifespan, "env_file": env_file, "log_config": LOGGING_CONFIG if log_config is None else log_config, diff --git a/uvicorn/protocols/websockets/websockets_impl.py b/uvicorn/protocols/websockets/websockets_impl.py index af0f954e6..6c61f3ea5 100644 --- a/uvicorn/protocols/websockets/websockets_impl.py +++ b/uvicorn/protocols/websockets/websockets_impl.py @@ -74,13 +74,18 @@ def __init__( self.transfer_data_task = None self.ws_server = Server() + + extensions = [] + if self.config.ws_per_message_deflate: + extensions.append(ServerPerMessageDeflateFactory()) + super().__init__( ws_handler=self.ws_handler, ws_server=self.ws_server, max_size=self.config.ws_max_size, ping_interval=self.config.ws_ping_interval, ping_timeout=self.config.ws_ping_timeout, - extensions=[ServerPerMessageDeflateFactory()], + extensions=extensions, logger=logging.getLogger("uvicorn.error"), ) diff --git a/uvicorn/protocols/websockets/wsproto_impl.py b/uvicorn/protocols/websockets/wsproto_impl.py index 6a08b5475..6ed3ab702 100644 --- a/uvicorn/protocols/websockets/wsproto_impl.py +++ b/uvicorn/protocols/websockets/wsproto_impl.py @@ -257,9 +257,12 @@ async def send(self, message): ) self.handshake_complete = True subprotocol = message.get("subprotocol") + extensions = [] + if self.config.ws_per_message_deflate: + extensions.append(PerMessageDeflate()) output = self.conn.send( wsproto.events.AcceptConnection( - subprotocol=subprotocol, extensions=[PerMessageDeflate()] + subprotocol=subprotocol, extensions=extensions ) ) self.transport.write(output)