Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ Using the following categories, list your changes in this order:
### Added

- Support Python 3.13.
- Query strings are now preserved during HTTP redirection.

## [2.0.1] - 2024-09-13

Expand Down
1 change: 1 addition & 0 deletions src/servestatic/asgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ async def __call__(self, scope, receive, send):
wsgi_headers = {
"HTTP_" + key.decode().upper().replace("-", "_"): value.decode() for key, value in scope["headers"]
}
wsgi_headers["QUERY_STRING"] = scope["query_string"].decode()

# Get the ServeStatic file response
response = await self.static_file.aget_response(scope["method"], wsgi_headers)
Expand Down
13 changes: 11 additions & 2 deletions src/servestatic/responders.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,16 +315,25 @@ def get_path_and_headers(self, request_headers):


class Redirect:
location = "Location"

def __init__(self, location, headers=None):
headers = list(headers.items()) if headers else []
headers.append(("Location", quote(location.encode("utf8"))))
headers.append((self.location, quote(location.encode("utf8"))))
self.response = Response(HTTPStatus.FOUND, headers, None)

def get_response(self, method, request_headers):
query_string = request_headers.get("QUERY_STRING")
if query_string:
headers = list(self.response.headers)
i, value = next((i, value) for (i, (name, value)) in enumerate(headers) if name == self.location)
value = f"{value}?{query_string}"
headers[i] = (self.location, value)
return Response(self.response.status, headers, None)
return self.response

async def aget_response(self, method, request_headers):
return self.response
return self.get_response(method, request_headers)


class NotARegularFileError(Exception):
Expand Down
16 changes: 15 additions & 1 deletion tests/test_asgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
def test_files():
return Files(
js=str(Path("static") / "app.js"),
index=str(Path("static") / "with-index" / "index.html"),
)


Expand All @@ -34,7 +35,12 @@ async def asgi_app(scope, receive, send):
})
await send({"type": "http.response.body", "body": b"Not Found"})

return ServeStaticASGI(asgi_app, root=test_files.directory, autorefresh=request.param)
return ServeStaticASGI(
asgi_app,
root=test_files.directory,
autorefresh=request.param,
index_file=True,
)


def test_get_js_static_file(application, test_files):
Expand All @@ -47,6 +53,14 @@ def test_get_js_static_file(application, test_files):
assert send.headers[b"content-length"] == str(len(test_files.js_content)).encode()


def test_redirect_preserves_query_string(application, test_files):
scope = AsgiScopeEmulator({"path": "/static/with-index", "query_string": b"v=1&x=2"})
receive = AsgiReceiveEmulator()
send = AsgiSendEmulator()
asyncio.run(application(scope, receive, send))
assert send.headers[b"location"] == b"with-index/?v=1&x=2"


def test_user_app(application):
scope = AsgiScopeEmulator({"path": "/"})
receive = AsgiReceiveEmulator()
Expand Down
17 changes: 16 additions & 1 deletion tests/test_servestatic.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import pytest

from servestatic import ServeStatic
from servestatic.responders import StaticFile
from servestatic.responders import Redirect, StaticFile

from .utils import AppServer, Files

Expand Down Expand Up @@ -245,6 +245,15 @@ def test_index_file_path_redirected(server, files):
assert location == directory_url


def test_index_file_path_redirected_with_query_string(server, files):
directory_url = files.index_url.rpartition("/")[0] + "/"
query_string = "v=1"
response = server.get(f"{files.index_url}?{query_string}", allow_redirects=False)
location = urljoin(files.index_url, response.headers["Location"])
assert response.status_code == 302
assert location == f"{directory_url}?{query_string}"


def test_directory_path_without_trailing_slash_redirected(server, files):
directory_url = files.index_url.rpartition("/")[0] + "/"
no_slash_url = directory_url.rstrip("/")
Expand Down Expand Up @@ -376,3 +385,9 @@ def test_chunked_file_size_matches_range_with_range_header():
while response.file.read(1):
file_size += 1
assert file_size == 14


def test_redirect_preserves_query_string():
responder = Redirect("/redirect/to/here/")
response = responder.get_response("GET", {"QUERY_STRING": "foo=1&bar=2"})
assert response.headers[0] == ("Location", "/redirect/to/here/?foo=1&bar=2")