diff --git a/docs/cookbook.rst b/docs/cookbook.rst index ebe0bca0b..7bc85a8a0 100644 --- a/docs/cookbook.rst +++ b/docs/cookbook.rst @@ -116,3 +116,32 @@ Starlette. You can add it to your application, ideally in front of the ``Routing :noindex: .. _CORSMiddleware: https://www.starlette.io/middleware/#corsmiddleware + +Reverse Proxy +------------- + +When running behind a reverse proxy with stripped path prefix, you need to configure your +application to properly handle this. + +Single known path prefix +'''''''''''''''''''''''' + +If there is only a single known prefix your application will be running behind, you can simply +pass this path prefix as the `root_path` to your ASGI server: + +.. code-block:: bash + + $ uvicorn run:app --root-path + +.. code-block:: bash + + $ gunicorn -k uvicorn.workers.UvicornWorker run:app --root-path + + +Dynamic path prefix +''''''''''''''''''' + +If you are running behind multiple proxies, or the path is not known, you can wrap your +application in a `ReverseProxied` middleware as shown in `this example`_. + +.. _this example: https://github.com/spec-first/connexion/tree/main/examples/reverseproxy diff --git a/examples/reverseproxy/app.py b/examples/reverseproxy/app.py index c11298bd3..f589cb249 100755 --- a/examples/reverseproxy/app.py +++ b/examples/reverseproxy/app.py @@ -54,10 +54,10 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send): root_path = value.decode() break if root_path: - scope["root_path"] = "/" + root_path.strip("/") - path_info = scope.get("PATH_INFO", scope.get("path")) - if path_info.startswith(root_path): - scope["PATH_INFO"] = path_info[len(root_path) :] + root_path = "/" + root_path.strip("/") + scope["root_path"] = root_path + scope["path"] = root_path + scope.get("path", "") + scope["raw_path"] = root_path.encode() + scope.get("raw_path", "") scope["scheme"] = scope.get("scheme") or self.scheme scope["server"] = scope.get("server") or (self.server, None) diff --git a/pyproject.toml b/pyproject.toml index 2ac6c6057..260569d9a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,7 +53,7 @@ Jinja2 = ">= 3.0.0" python-multipart = ">= 0.0.5" PyYAML = ">= 5.1" requests = ">= 2.27" -starlette = ">= 0.27" +starlette = ">= 0.35" typing-extensions = ">= 4" werkzeug = ">= 2.2.1" diff --git a/tests/api/conftest.py b/tests/api/conftest.py index 9249e6126..bfde3ed03 100644 --- a/tests/api/conftest.py +++ b/tests/api/conftest.py @@ -73,10 +73,10 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send): root_path = value.decode() break if root_path: - scope["root_path"] = "/" + root_path.strip("/") - path_info = scope.get("PATH_INFO", scope.get("path")) - if path_info.startswith(root_path): - scope["PATH_INFO"] = path_info[len(root_path) :] + root_path = "/" + root_path.strip("/") + scope["root_path"] = root_path + scope["path"] = root_path + scope.get("path", "") + scope["raw_path"] = root_path.encode() + scope.get("raw_path", "") scope["scheme"] = scope.get("scheme") or self.scheme scope["server"] = scope.get("server") or (self.server, None)