Skip to content

Commit

Permalink
Allow configurable websocket per-message-deflate setting (#1300)
Browse files Browse the repository at this point in the history
  • Loading branch information
cfal authored Dec 29, 2021
1 parent 66e22f8 commit 6e6a841
Show file tree
Hide file tree
Showing 7 changed files with 56 additions and 2 deletions.
3 changes: 3 additions & 0 deletions docs/deployment.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ Options:
[default: 16777216]
--ws-ping-interval FLOAT WebSocket ping interval [default: 20.0]
--ws-ping-timeout FLOAT WebSocket ping timeout [default: 20.0]
--ws-per-message-deflate BOOLEAN
WebSocket per-message-deflate compression
[default: True]
--lifespan [auto|on|off] Lifespan implementation. [default: auto]
--interface [auto|asgi3|asgi2|wsgi]
Select ASGI3, ASGI2, or WSGI as the
Expand Down
3 changes: 3 additions & 0 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,9 @@ Options:
[default: 16777216]
--ws-ping-interval FLOAT WebSocket ping interval [default: 20.0]
--ws-ping-timeout FLOAT WebSocket ping timeout [default: 20.0]
--ws-per-message-deflate BOOLEAN
WebSocket per-message-deflate compression
[default: True]
--lifespan [auto|on|off] Lifespan implementation. [default: auto]
--interface [auto|asgi3|asgi2|wsgi]
Select ASGI3, ASGI2, or WSGI as the
Expand Down
29 changes: 29 additions & 0 deletions tests/protocols/test_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,35 @@ async def open_connection(url):
assert "permessage-deflate" in extension_names


@pytest.mark.asyncio
@pytest.mark.parametrize("ws_protocol_cls", WS_PROTOCOLS)
@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS)
async def test_can_disable_permessage_deflate_extension(
ws_protocol_cls, http_protocol_cls
):
class App(WebSocketResponse):
async def websocket_connect(self, message):
await self.send({"type": "websocket.accept"})

async def open_connection(url):
# enable per-message deflate on the client, so that we can check the server
# won't support it when it's disabled.
extension_factories = [ClientPerMessageDeflateFactory()]
async with websockets.connect(url, extensions=extension_factories) as websocket:
return [extension.name for extension in websocket.extensions]

config = Config(
app=App,
ws=ws_protocol_cls,
http=http_protocol_cls,
lifespan="off",
ws_per_message_deflate=False,
)
async with run_server(config):
extension_names = await open_connection("ws://127.0.0.1:8000")
assert "permessage-deflate" not in extension_names


@pytest.mark.asyncio
@pytest.mark.parametrize("ws_protocol_cls", WS_PROTOCOLS)
@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS)
Expand Down
2 changes: 2 additions & 0 deletions uvicorn/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@ def __init__(
ws_max_size: int = 16 * 1024 * 1024,
ws_ping_interval: Optional[float] = 20,
ws_ping_timeout: Optional[float] = 20,
ws_per_message_deflate: Optional[bool] = True,
lifespan: LifespanType = "auto",
env_file: Optional[Union[str, os.PathLike]] = None,
log_config: Optional[Union[dict, str]] = LOGGING_CONFIG,
Expand Down Expand Up @@ -251,6 +252,7 @@ def __init__(
self.ws_max_size = ws_max_size
self.ws_ping_interval = ws_ping_interval
self.ws_ping_timeout = ws_ping_timeout
self.ws_per_message_deflate = ws_per_message_deflate
self.lifespan = lifespan
self.log_config = log_config
self.log_level = log_level
Expand Down
9 changes: 9 additions & 0 deletions uvicorn/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,13 @@ def print_version(ctx: click.Context, param: click.Parameter, value: bool) -> No
help="WebSocket ping timeout",
show_default=True,
)
@click.option(
"--ws-per-message-deflate",
type=bool,
default=True,
help="WebSocket per-message-deflate compression",
show_default=True,
)
@click.option(
"--lifespan",
type=LIFESPAN_CHOICES,
Expand Down Expand Up @@ -344,6 +351,7 @@ def main(
ws_max_size: int,
ws_ping_interval: float,
ws_ping_timeout: float,
ws_per_message_deflate: bool,
lifespan: str,
interface: str,
debug: bool,
Expand Down Expand Up @@ -389,6 +397,7 @@ def main(
"ws_max_size": ws_max_size,
"ws_ping_interval": ws_ping_interval,
"ws_ping_timeout": ws_ping_timeout,
"ws_per_message_deflate": ws_per_message_deflate,
"lifespan": lifespan,
"env_file": env_file,
"log_config": LOGGING_CONFIG if log_config is None else log_config,
Expand Down
7 changes: 6 additions & 1 deletion uvicorn/protocols/websockets/websockets_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,13 +74,18 @@ def __init__(
self.transfer_data_task = None

self.ws_server = Server()

extensions = []
if self.config.ws_per_message_deflate:
extensions.append(ServerPerMessageDeflateFactory())

super().__init__(
ws_handler=self.ws_handler,
ws_server=self.ws_server,
max_size=self.config.ws_max_size,
ping_interval=self.config.ws_ping_interval,
ping_timeout=self.config.ws_ping_timeout,
extensions=[ServerPerMessageDeflateFactory()],
extensions=extensions,
logger=logging.getLogger("uvicorn.error"),
)

Expand Down
5 changes: 4 additions & 1 deletion uvicorn/protocols/websockets/wsproto_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,9 +257,12 @@ async def send(self, message):
)
self.handshake_complete = True
subprotocol = message.get("subprotocol")
extensions = []
if self.config.ws_per_message_deflate:
extensions.append(PerMessageDeflate())
output = self.conn.send(
wsproto.events.AcceptConnection(
subprotocol=subprotocol, extensions=[PerMessageDeflate()]
subprotocol=subprotocol, extensions=extensions
)
)
self.transport.write(output)
Expand Down

0 comments on commit 6e6a841

Please sign in to comment.