diff --git a/starlette/responses.py b/starlette/responses.py index 1c9aaa1148..527db9c5f9 100644 --- a/starlette/responses.py +++ b/starlette/responses.py @@ -239,7 +239,7 @@ class FileResponse(Response): def __init__( self, - path: str, + path: typing.Union[str, "os.PathLike[str]"], status_code: int = 200, headers: dict = None, media_type: str = None, diff --git a/starlette/staticfiles.py b/starlette/staticfiles.py index 22b9d3ae6b..41df98062e 100644 --- a/starlette/staticfiles.py +++ b/starlette/staticfiles.py @@ -15,6 +15,8 @@ ) from starlette.types import Receive, Scope, Send +PathLike = typing.Union[str, "os.PathLike[str]"] + class NotModifiedResponse(Response): NOT_MODIFIED_HEADERS = ( @@ -41,7 +43,7 @@ class StaticFiles: def __init__( self, *, - directory: str = None, + directory: PathLike = None, packages: typing.List[str] = None, html: bool = False, check_dir: bool = True, @@ -55,8 +57,8 @@ def __init__( raise RuntimeError(f"Directory '{directory}' does not exist") def get_directories( - self, directory: str = None, packages: typing.List[str] = None - ) -> typing.List[str]: + self, directory: PathLike = None, packages: typing.List[str] = None + ) -> typing.List[PathLike]: """ Given `directory` and `packages` arguments, return a list of all the directories that should be used for serving static files from. @@ -71,11 +73,13 @@ def get_directories( assert ( spec.origin is not None ), f"Directory 'statics' in package {package!r} could not be found." - directory = os.path.normpath(os.path.join(spec.origin, "..", "statics")) + package_directory = os.path.normpath( + os.path.join(spec.origin, "..", "statics") + ) assert os.path.isdir( - directory + package_directory ), f"Directory 'statics' in package {package!r} could not be found." - directories.append(directory) + directories.append(package_directory) return directories @@ -154,7 +158,7 @@ async def lookup_path( def file_response( self, - full_path: str, + full_path: PathLike, stat_result: os.stat_result, scope: Scope, status_code: int = 200, diff --git a/tests/test_staticfiles.py b/tests/test_staticfiles.py index 9e8101bc83..e2cae08f19 100644 --- a/tests/test_staticfiles.py +++ b/tests/test_staticfiles.py @@ -1,5 +1,6 @@ import asyncio import os +import pathlib import time import pytest @@ -23,6 +24,19 @@ def test_staticfiles(tmpdir): assert response.text == "" +def test_staticfiles_with_pathlib(tmpdir): + base_dir = pathlib.Path(tmpdir) + path = base_dir / "example.txt" + with open(path, "w") as file: + file.write("") + + app = StaticFiles(directory=base_dir) + client = TestClient(app) + response = client.get("/example.txt") + assert response.status_code == 200 + assert response.text == "" + + def test_staticfiles_head_with_middleware(tmpdir): """ see https://github.com/encode/starlette/pull/935