Skip to content

Commit

Permalink
♻️ Update root_path handling (from --root-path CLI option) to inc…
Browse files Browse the repository at this point in the history
…lude the root path prefix in the full ASGI `path` as per the ASGI spec (#2213)

* ♻️ Update root-path handling to include it in the full path as per the ASGI spe, related to Starlette 0.35.0

* ✅ Update tests for root_path, ensure it's added to the prefix of the path in the ASGI scope

* ♻️ Update the (deprecated) WSGIMiddleware to follow closely the ASGI spec

* ✅ Update tests for WSGIMiddleware

* 🎨 Fix format in tests

* Update tests/protocols/test_http.py

Co-authored-by: Marcelo Trylesinski <[email protected]>

* Update uvicorn/protocols/http/httptools_impl.py

Co-authored-by: Marcelo Trylesinski <[email protected]>

---------

Co-authored-by: Marcelo Trylesinski <[email protected]>
  • Loading branch information
tiangolo and Kludex authored Jan 16, 2024
1 parent ebcd996 commit 4af46c9
Show file tree
Hide file tree
Showing 7 changed files with 37 additions and 18 deletions.
7 changes: 4 additions & 3 deletions tests/middleware/test_wsgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,11 +122,11 @@ def test_build_environ_encoding() -> None:
scope: "HTTPScope" = {
"asgi": {"version": "3.0", "spec_version": "2.0"},
"scheme": "http",
"raw_path": b"/\xe6\x96\x87",
"raw_path": b"/\xe6\x96\x87%2Fall",
"type": "http",
"http_version": "1.1",
"method": "GET",
"path": "/文",
"path": "/文/all",
"root_path": "/文",
"client": None,
"server": None,
Expand All @@ -140,5 +140,6 @@ def test_build_environ_encoding() -> None:
"more_body": False,
}
environ = wsgi.build_environ(scope, message, io.BytesIO(b""))
assert environ["PATH_INFO"] == "/文".encode("utf8").decode("latin-1")
assert environ["SCRIPT_NAME"] == "/文".encode("utf8").decode("latin-1")
assert environ["PATH_INFO"] == "/all".encode("utf8").decode("latin-1")
assert environ["HTTP_KEY"] == "value1,value2"
13 changes: 8 additions & 5 deletions tests/protocols/test_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -630,15 +630,18 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable
async def test_root_path(http_protocol_cls: "Type[HttpToolsProtocol | H11Protocol]"):
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
assert scope["type"] == "http"
path = scope.get("root_path", "") + scope["path"]
response = Response("Path: " + path, media_type="text/plain")
root_path = scope.get("root_path", "")
path = scope["path"]
response = Response(
f"root_path={root_path} path={path}", media_type="text/plain"
)
await response(scope, receive, send)

protocol = get_connected_protocol(app, http_protocol_cls, root_path="/app")
protocol.data_received(SIMPLE_GET_REQUEST)
await protocol.loop.run_one()
assert b"HTTP/1.1 200 OK" in protocol.transport.buffer
assert b"Path: /app/" in protocol.transport.buffer
assert b"root_path=/app path=/app/" in protocol.transport.buffer


@pytest.mark.anyio
Expand All @@ -647,8 +650,8 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable
assert scope["type"] == "http"
path = scope["path"]
raw_path = scope.get("raw_path", None)
assert "/one/two" == path
assert b"/one%2Ftwo" == raw_path
assert "/app/one/two" == path
assert b"/app/one%2Ftwo" == raw_path

response = Response("Done", media_type="text/plain")
await response(scope, receive, send)
Expand Down
8 changes: 6 additions & 2 deletions uvicorn/middleware/wsgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,14 @@ def build_environ(
"""
Builds a scope and request message into a WSGI environ object.
"""
script_name = scope.get("root_path", "").encode("utf8").decode("latin1")
path_info = scope["path"].encode("utf8").decode("latin1")
if path_info.startswith(script_name):
path_info = path_info[len(script_name) :]
environ = {
"REQUEST_METHOD": scope["method"],
"SCRIPT_NAME": "",
"PATH_INFO": scope["path"].encode("utf8").decode("latin1"),
"SCRIPT_NAME": script_name,
"PATH_INFO": path_info,
"QUERY_STRING": scope["query_string"].decode("ascii"),
"SERVER_PROTOCOL": "HTTP/%s" % scope["http_version"],
"wsgi.version": (1, 0),
Expand Down
7 changes: 5 additions & 2 deletions uvicorn/protocols/http/h11_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,9 @@ def handle_events(self) -> None:
elif isinstance(event, h11.Request):
self.headers = [(key.lower(), value) for key, value in event.headers]
raw_path, _, query_string = event.target.partition(b"?")
path = unquote(raw_path.decode("ascii"))
full_path = self.root_path + path
full_raw_path = self.root_path.encode("ascii") + raw_path
self.scope = {
"type": "http",
"asgi": {
Expand All @@ -210,8 +213,8 @@ def handle_events(self) -> None:
"scheme": self.scheme, # type: ignore[typeddict-item]
"method": event.method.decode("ascii"),
"root_path": self.root_path,
"path": unquote(raw_path.decode("ascii")),
"raw_path": raw_path,
"path": full_path,
"raw_path": full_raw_path,
"query_string": query_string,
"headers": self.headers,
"state": self.app_state.copy(),
Expand Down
6 changes: 4 additions & 2 deletions uvicorn/protocols/http/httptools_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,8 +260,10 @@ def on_headers_complete(self) -> None:
path = raw_path.decode("ascii")
if "%" in path:
path = urllib.parse.unquote(path)
self.scope["path"] = path
self.scope["raw_path"] = raw_path
full_path = self.root_path + path
full_raw_path = self.root_path.encode("ascii") + raw_path
self.scope["path"] = full_path
self.scope["raw_path"] = full_raw_path
self.scope["query_string"] = parsed_url.query or b""

# Handle 503 responses when 'limit_concurrency' is exceeded.
Expand Down
7 changes: 5 additions & 2 deletions uvicorn/protocols/websockets/websockets_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,9 @@ async def process_request(
(name.encode("ascii"), value.encode("ascii", errors="surrogateescape"))
for name, value in headers.raw_items()
]
path = unquote(path_portion)
full_path = self.root_path + path
full_raw_path = self.root_path.encode("ascii") + path_portion.encode("ascii")

self.scope = {
"type": "websocket",
Expand All @@ -193,8 +196,8 @@ async def process_request(
"server": self.server,
"client": self.client,
"root_path": self.root_path,
"path": unquote(path_portion),
"raw_path": path_portion.encode("ascii"),
"path": full_path,
"raw_path": full_raw_path,
"query_string": query_string.encode("ascii"),
"headers": asgi_headers,
"subprotocols": subprotocols,
Expand Down
7 changes: 5 additions & 2 deletions uvicorn/protocols/websockets/wsproto_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,9 @@ def handle_connect(self, event: events.Request) -> None:
headers = [(b"host", event.host.encode())]
headers += [(key.lower(), value) for key, value in event.extra_headers]
raw_path, _, query_string = event.target.partition("?")
path = unquote(raw_path)
full_path = self.root_path + path
full_raw_path = self.root_path.encode("ascii") + raw_path.encode("ascii")
self.scope: "WebSocketScope" = {
"type": "websocket",
"asgi": {"version": self.config.asgi_version, "spec_version": "2.3"},
Expand All @@ -172,8 +175,8 @@ def handle_connect(self, event: events.Request) -> None:
"server": self.server,
"client": self.client,
"root_path": self.root_path,
"path": unquote(raw_path),
"raw_path": raw_path.encode("ascii"),
"path": full_path,
"raw_path": full_raw_path,
"query_string": query_string.encode("ascii"),
"headers": headers,
"subprotocols": event.subprotocols,
Expand Down

0 comments on commit 4af46c9

Please sign in to comment.