From ede5f09c633be5ae1d2153457a24e03787fe0c23 Mon Sep 17 00:00:00 2001 From: Clemens Wolff Date: Wed, 2 Oct 2024 11:24:01 +0200 Subject: [PATCH] Preserve query string on redirect (#48) --- CHANGELOG.md | 1 + src/servestatic/asgi.py | 1 + src/servestatic/responders.py | 13 +++++++++++-- tests/test_asgi.py | 16 +++++++++++++++- tests/test_servestatic.py | 17 ++++++++++++++++- 5 files changed, 44 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index edabdecb..87c9b329 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,7 @@ Using the following categories, list your changes in this order: ### Added - Support Python 3.13. +- Query strings are now preserved during HTTP redirection. ## [2.0.1] - 2024-09-13 diff --git a/src/servestatic/asgi.py b/src/servestatic/asgi.py index 91e88b37..b2c2d944 100644 --- a/src/servestatic/asgi.py +++ b/src/servestatic/asgi.py @@ -47,6 +47,7 @@ async def __call__(self, scope, receive, send): wsgi_headers = { "HTTP_" + key.decode().upper().replace("-", "_"): value.decode() for key, value in scope["headers"] } + wsgi_headers["QUERY_STRING"] = scope["query_string"].decode() # Get the ServeStatic file response response = await self.static_file.aget_response(scope["method"], wsgi_headers) diff --git a/src/servestatic/responders.py b/src/servestatic/responders.py index 7b4e8a02..e83d191e 100644 --- a/src/servestatic/responders.py +++ b/src/servestatic/responders.py @@ -315,16 +315,25 @@ def get_path_and_headers(self, request_headers): class Redirect: + location = "Location" + def __init__(self, location, headers=None): headers = list(headers.items()) if headers else [] - headers.append(("Location", quote(location.encode("utf8")))) + headers.append((self.location, quote(location.encode("utf8")))) self.response = Response(HTTPStatus.FOUND, headers, None) def get_response(self, method, request_headers): + query_string = request_headers.get("QUERY_STRING") + if query_string: + headers = list(self.response.headers) + i, value = next((i, value) for (i, (name, value)) in enumerate(headers) if name == self.location) + value = f"{value}?{query_string}" + headers[i] = (self.location, value) + return Response(self.response.status, headers, None) return self.response async def aget_response(self, method, request_headers): - return self.response + return self.get_response(method, request_headers) class NotARegularFileError(Exception): diff --git a/tests/test_asgi.py b/tests/test_asgi.py index 6489e407..94ee8d6d 100644 --- a/tests/test_asgi.py +++ b/tests/test_asgi.py @@ -15,6 +15,7 @@ def test_files(): return Files( js=str(Path("static") / "app.js"), + index=str(Path("static") / "with-index" / "index.html"), ) @@ -34,7 +35,12 @@ async def asgi_app(scope, receive, send): }) await send({"type": "http.response.body", "body": b"Not Found"}) - return ServeStaticASGI(asgi_app, root=test_files.directory, autorefresh=request.param) + return ServeStaticASGI( + asgi_app, + root=test_files.directory, + autorefresh=request.param, + index_file=True, + ) def test_get_js_static_file(application, test_files): @@ -47,6 +53,14 @@ def test_get_js_static_file(application, test_files): assert send.headers[b"content-length"] == str(len(test_files.js_content)).encode() +def test_redirect_preserves_query_string(application, test_files): + scope = AsgiScopeEmulator({"path": "/static/with-index", "query_string": b"v=1&x=2"}) + receive = AsgiReceiveEmulator() + send = AsgiSendEmulator() + asyncio.run(application(scope, receive, send)) + assert send.headers[b"location"] == b"with-index/?v=1&x=2" + + def test_user_app(application): scope = AsgiScopeEmulator({"path": "/"}) receive = AsgiReceiveEmulator() diff --git a/tests/test_servestatic.py b/tests/test_servestatic.py index dd9c3119..84002b77 100644 --- a/tests/test_servestatic.py +++ b/tests/test_servestatic.py @@ -16,7 +16,7 @@ import pytest from servestatic import ServeStatic -from servestatic.responders import StaticFile +from servestatic.responders import Redirect, StaticFile from .utils import AppServer, Files @@ -245,6 +245,15 @@ def test_index_file_path_redirected(server, files): assert location == directory_url +def test_index_file_path_redirected_with_query_string(server, files): + directory_url = files.index_url.rpartition("/")[0] + "/" + query_string = "v=1" + response = server.get(f"{files.index_url}?{query_string}", allow_redirects=False) + location = urljoin(files.index_url, response.headers["Location"]) + assert response.status_code == 302 + assert location == f"{directory_url}?{query_string}" + + def test_directory_path_without_trailing_slash_redirected(server, files): directory_url = files.index_url.rpartition("/")[0] + "/" no_slash_url = directory_url.rstrip("/") @@ -376,3 +385,9 @@ def test_chunked_file_size_matches_range_with_range_header(): while response.file.read(1): file_size += 1 assert file_size == 14 + + +def test_redirect_preserves_query_string(): + responder = Redirect("/redirect/to/here/") + response = responder.get_response("GET", {"QUERY_STRING": "foo=1&bar=2"}) + assert response.headers[0] == ("Location", "/redirect/to/here/?foo=1&bar=2")