Skip to content

Commit

Permalink
Fix x_forwarded_proto for websockets (#2043)
Browse files Browse the repository at this point in the history
  • Loading branch information
aminalaee authored Jul 12, 2023
1 parent 57c6d57 commit 806c227
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 2 deletions.
44 changes: 44 additions & 0 deletions tests/middleware/test_proxy_headers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,22 @@
import httpx
import pytest

from tests.protocols.test_http import HTTP_PROTOCOLS
from tests.response import Response
from tests.utils import run_server
from uvicorn._types import ASGIReceiveCallable, ASGISendCallable, Scope
from uvicorn.config import Config
from uvicorn.middleware.proxy_headers import ProxyHeadersMiddleware
from uvicorn.protocols.websockets.wsproto_impl import WSProtocol

try:
import websockets.client

from uvicorn.protocols.websockets.websockets_impl import WebSocketProtocol

WS_PROTOCOLS = [WSProtocol, WebSocketProtocol]
except ImportError: # pragma: nocover
WS_PROTOCOLS = []


async def app(
Expand Down Expand Up @@ -103,3 +116,34 @@ async def test_proxy_headers_invalid_x_forwarded_for() -> None:
response = await client.get("/", headers=headers)
assert response.status_code == 200
assert response.text == "Remote: https://1.2.3.4:0"


@pytest.mark.anyio
@pytest.mark.parametrize("ws_protocol_cls", WS_PROTOCOLS)
@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS)
@pytest.mark.skipif(not WS_PROTOCOLS, reason="websockets module not installed.")
async def test_proxy_headers_websocket_x_forwarded_proto(
ws_protocol_cls, http_protocol_cls, unused_tcp_port: int
) -> None:
async def websocket_app(scope, receive, send):
scheme = scope["scheme"]
host, port = scope["client"]
addr = "%s://%s:%d" % (scheme, host, port)
await send({"type": "websocket.accept"})
await send({"type": "websocket.send", "text": addr})

app_with_middleware = ProxyHeadersMiddleware(websocket_app, trusted_hosts="*")
config = Config(
app=app_with_middleware,
ws=ws_protocol_cls,
http=http_protocol_cls,
lifespan="off",
port=unused_tcp_port,
)

async with run_server(config):
url = f"ws://127.0.0.1:{unused_tcp_port}"
headers = {"X-Forwarded-Proto": "https", "X-Forwarded-For": "1.2.3.4"}
async with websockets.client.connect(url, extra_headers=headers) as websocket:
data = await websocket.recv()
assert data == "wss://1.2.3.4:0"
11 changes: 9 additions & 2 deletions uvicorn/middleware/proxy_headers.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,15 @@ async def __call__(
if b"x-forwarded-proto" in headers:
# Determine if the incoming request was http or https based on
# the X-Forwarded-Proto header.
x_forwarded_proto = headers[b"x-forwarded-proto"].decode("latin1")
scope["scheme"] = x_forwarded_proto.strip()
x_forwarded_proto = (
headers[b"x-forwarded-proto"].decode("latin1").strip()
)
if scope["type"] == "websocket":
scope["scheme"] = (
"wss" if x_forwarded_proto == "https" else "ws"
)
else:
scope["scheme"] = x_forwarded_proto

if b"x-forwarded-for" in headers:
# Determine the client address from the last trusted IP in the
Expand Down

0 comments on commit 806c227

Please sign in to comment.