Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Preserve query string on redirect #48

Merged
merged 10 commits into from
Oct 2, 2024
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ Using the following categories, list your changes in this order:
### Added

- Support Python 3.13.
- Query strings are now preserved during HTTP redirection.

## [2.0.1] - 2024-09-13

Expand Down
1 change: 1 addition & 0 deletions src/servestatic/asgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ async def __call__(self, scope, receive, send):
wsgi_headers = {
"HTTP_" + key.decode().upper().replace("-", "_"): value.decode() for key, value in scope["headers"]
}
wsgi_headers["QUERY_STRING"] = scope["query_string"].decode()

# Get the ServeStatic file response
response = await self.static_file.aget_response(scope["method"], wsgi_headers)
Expand Down
13 changes: 11 additions & 2 deletions src/servestatic/responders.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,16 +315,25 @@ def get_path_and_headers(self, request_headers):


class Redirect:
location = "Location"

def __init__(self, location, headers=None):
headers = list(headers.items()) if headers else []
headers.append(("Location", quote(location.encode("utf8"))))
headers.append((self.location, quote(location.encode("utf8"))))
self.response = Response(HTTPStatus.FOUND, headers, None)

def get_response(self, method, request_headers):
query_string = request_headers.get("QUERY_STRING")
if query_string:
headers = list(self.response.headers)
i, value = next((i, value) for (i, (name, value)) in enumerate(headers) if name == self.location)
value = f"{value}?{query_string}"
headers[i] = (self.location, value)
return Response(self.response.status, headers, None)
return self.response

async def aget_response(self, method, request_headers):
return self.response
return self.get_response(method, request_headers)


class NotARegularFileError(Exception):
Expand Down
16 changes: 15 additions & 1 deletion tests/test_asgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
def test_files():
return Files(
js=str(Path("static") / "app.js"),
index=str(Path("static") / "with-index" / "index.html"),
)


Expand All @@ -34,7 +35,12 @@ async def asgi_app(scope, receive, send):
})
await send({"type": "http.response.body", "body": b"Not Found"})

return ServeStaticASGI(asgi_app, root=test_files.directory, autorefresh=request.param)
return ServeStaticASGI(
asgi_app,
root=test_files.directory,
autorefresh=request.param,
index_file=True,
)


def test_get_js_static_file(application, test_files):
Expand All @@ -47,6 +53,14 @@ def test_get_js_static_file(application, test_files):
assert send.headers[b"content-length"] == str(len(test_files.js_content)).encode()


def test_redirect_preserves_query_string(application, test_files):
scope = AsgiScopeEmulator({"path": "/static/with-index", "query_string": b"v=1&x=2"})
receive = AsgiReceiveEmulator()
send = AsgiSendEmulator()
asyncio.run(application(scope, receive, send))
assert send.headers[b"location"] == b"with-index/?v=1&x=2"


def test_user_app(application):
scope = AsgiScopeEmulator({"path": "/"})
receive = AsgiReceiveEmulator()
Expand Down
17 changes: 16 additions & 1 deletion tests/test_servestatic.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import pytest
Archmonger marked this conversation as resolved.
Show resolved Hide resolved

from servestatic import ServeStatic
from servestatic.responders import StaticFile
from servestatic.responders import Redirect, StaticFile

from .utils import AppServer, Files

Expand Down Expand Up @@ -245,6 +245,15 @@ def test_index_file_path_redirected(server, files):
assert location == directory_url


def test_index_file_path_redirected_with_query_string(server, files):
directory_url = files.index_url.rpartition("/")[0] + "/"
query_string = "v=1"
response = server.get(f"{files.index_url}?{query_string}", allow_redirects=False)
location = urljoin(files.index_url, response.headers["Location"])
assert response.status_code == 302
assert location == f"{directory_url}?{query_string}"


def test_directory_path_without_trailing_slash_redirected(server, files):
directory_url = files.index_url.rpartition("/")[0] + "/"
no_slash_url = directory_url.rstrip("/")
Expand Down Expand Up @@ -376,3 +385,9 @@ def test_chunked_file_size_matches_range_with_range_header():
while response.file.read(1):
file_size += 1
assert file_size == 14


def test_redirect_preserves_query_string():
responder = Redirect("/redirect/to/here/")
response = responder.get_response("GET", {"QUERY_STRING": "foo=1&bar=2"})
assert response.headers[0] == ("Location", "/redirect/to/here/?foo=1&bar=2")