diff --git a/eodag/rest/server.py b/eodag/rest/server.py index 4173cda02..f832ec3a9 100755 --- a/eodag/rest/server.py +++ b/eodag/rest/server.py @@ -49,7 +49,7 @@ get_stac_item_by_id, search_stac_items, ) -from eodag.utils import update_nested_dict +from eodag.utils import parse_header, update_nested_dict from eodag.utils.exceptions import ( MisconfiguredError, NoMatchingProductType, @@ -108,20 +108,6 @@ async def lifespan(app: FastAPI): stac_api_config = load_stac_api_config() -@router.get("/", tags=["Capabilities"]) -def catalogs_root(request: Request): - """STAC catalogs root""" - - response = get_stac_catalogs( - url=f"{request.url.scheme}://{request.url.netloc}{request.url.path}", - root=f"{request.url.scheme}://{request.url.netloc}", - catalogs=[], - provider=request.query_params.get("provider", None), - ) - - return jsonable_encoder(response) - - @router.get("/api", tags=["Capabilities"]) def eodag_openapi(): """Customized openapi""" @@ -181,6 +167,25 @@ def eodag_openapi(): ) +@app.middleware("http") +async def forward_middleware(request: Request, call_next): + """Middleware that handles forward headers and sets request.state.url*""" + + forwarded_host = request.headers.get("x-forwarded-host", None) + forwarded_proto = request.headers.get("x-forwarded-proto", None) + + if "forwarded" in request.headers: + header_forwarded = parse_header(request.headers["forwarded"]) + forwarded_host = header_forwarded.get_param("host", None) or forwarded_host + forwarded_proto = header_forwarded.get_param("proto", None) or forwarded_proto + + request.state.url_root = f"{forwarded_proto or request.url.scheme}://{forwarded_host or request.url.netloc}" + request.state.url = f"{request.state.url_root}{request.url.path}" + + response = await call_next(request) + return response + + @app.exception_handler(StarletteHTTPException) async def default_exception_handler(request: Request, error): """Default errors handle""" @@ -224,10 +229,23 @@ async def handle_resource_not_found(request: Request, error): ) +@router.get("/", tags=["Capabilities"]) +def catalogs_root(request: Request): + """STAC catalogs root""" + + response = get_stac_catalogs( + url=request.state.url, + root=request.state.url_root, + catalogs=[], + provider=request.query_params.get("provider", None), + ) + + return jsonable_encoder(response) + + @router.get("/conformance", tags=["Capabilities"]) def conformance(): """STAC conformance""" - response = get_stac_conformance() return jsonable_encoder(response) @@ -236,7 +254,6 @@ def conformance(): @router.get("/extensions/oseo/json-schema/schema.json", include_in_schema=False) def stac_extension_oseo(request: Request): """STAC OGC / OpenSearch extension for EO""" - response = get_stac_extension_oseo(url=request.url.split("?")[0]) return app.response_class( @@ -248,8 +265,8 @@ def stac_extension_oseo(request: Request): @router.post("/search", tags=["STAC"]) async def stac_search(request: Request): """STAC collections items""" - url = f"{request.url.scheme}://{request.url.netloc}{request.url.path}" - url_root = f"{request.url.scheme}://{request.url.netloc}" + url = request.state.url + url_root = request.state.url_root try: body = await request.json() except JSONDecodeError: @@ -274,9 +291,8 @@ async def collections(request: Request): Can be filtered using parameters: instrument, platform, platformSerialIdentifier, sensorType, processingLevel """ - url = f"{request.url.scheme}://{request.url.netloc}{request.url.path}" - url_root = f"{request.url.scheme}://{request.url.netloc}" - + url = request.state.url + url_root = request.state.url_root try: body = await request.json() except JSONDecodeError: @@ -297,9 +313,8 @@ async def collections(request: Request): @router.get("/collections/{collection_id}/items", tags=["Data"]) async def stac_collections_items(collection_id, request: Request): """STAC collections items""" - url = f"{request.url.scheme}://{request.url.netloc}{request.url.path}" - url_root = f"{request.url.scheme}://{request.url.netloc}" - + url = request.state.url + url_root = request.state.url_root try: body = await request.json() except JSONDecodeError: @@ -321,9 +336,8 @@ async def stac_collections_items(collection_id, request: Request): @router.get("/collections/{collection_id}", tags=["Capabilities"]) async def collection_by_id(collection_id, request: Request): """STAC collection by id""" - url = f"{request.url.scheme}://{request.url.netloc}{request.url.path}" - url_root = f"{request.url.scheme}://{request.url.netloc}" - + url = request.state.url + url_root = request.state.url_root try: body = await request.json() except JSONDecodeError: @@ -344,10 +358,8 @@ async def collection_by_id(collection_id, request: Request): @router.get("/collections/{collection_id}/items/{item_id}", tags=["Data"]) async def stac_collections_item(collection_id, item_id, request: Request): """STAC collection item by id""" - - url = f"{request.url.scheme}://{request.url.netloc}{request.url.path}" - url_root = f"{request.url.scheme}://{request.url.netloc}" - + url = request.state.url + url_root = request.state.url_root try: body = await request.json() except JSONDecodeError: @@ -377,7 +389,6 @@ async def stac_collections_item(collection_id, item_id, request: Request): @router.get("/collections/{collection_id}/items/{item_id}/download", tags=["Data"]) async def stac_collections_item_download(collection_id, item_id, request: Request): """STAC collection item local download""" - try: body = await request.json() except JSONDecodeError: @@ -427,10 +438,8 @@ async def stac_catalogs_items(catalogs, request: Request): '500': $ref: '#/components/responses/ServerError '""" - - url = f"{request.url.scheme}://{request.url.netloc}{request.url.path}" - url_root = f"{request.url.scheme}://{request.url.netloc}" - + url = request.state.url + url_root = request.state.url_root try: body = await request.json() except JSONDecodeError: @@ -485,9 +494,8 @@ async def stac_catalogs_item(catalogs, item_id, request: Request): '500': $ref: '#/components/responses/ServerError' """ - url = f"{request.url.scheme}://{request.url.netloc}{request.url.path}" - url_root = f"{request.url.scheme}://{request.url.netloc}" - + url = request.state.url + url_root = request.state.url_root try: body = await request.json() except JSONDecodeError: @@ -562,9 +570,8 @@ async def stac_catalogs(catalogs, request: Request): '500': $ref: '#/components/responses/ServerError' """ - url = f"{request.url.scheme}://{request.url.netloc}{request.url.path}" - url_root = f"{request.url.scheme}://{request.url.netloc}" - + url = request.state.url + url_root = request.state.url_root try: body = await request.json() except JSONDecodeError: diff --git a/tests/units/test_http_server.py b/tests/units/test_http_server.py index c84f5215a..4b462f776 100644 --- a/tests/units/test_http_server.py +++ b/tests/units/test_http_server.py @@ -76,6 +76,28 @@ def setUp(self): def test_route(self): self._request_valid("/") + def test_forward(self): + response = self.app.get("/", follow_redirects=True) + self.assertEqual(200, response.status_code) + resp_json = json.loads(response.content.decode("utf-8")) + self.assertEqual(resp_json["links"][0]["href"], "http://testserver") + + response = self.app.get( + "/", follow_redirects=True, headers={"Forwarded": "host=foo;proto=https"} + ) + self.assertEqual(200, response.status_code) + resp_json = json.loads(response.content.decode("utf-8")) + self.assertEqual(resp_json["links"][0]["href"], "https://foo") + + response = self.app.get( + "/", + follow_redirects=True, + headers={"X-Forwarded-Host": "bar", "X-Forwarded-Proto": "httpz"}, + ) + self.assertEqual(200, response.status_code) + resp_json = json.loads(response.content.decode("utf-8")) + self.assertEqual(resp_json["links"][0]["href"], "httpz://bar") + @mock.patch( "eodag.rest.utils.eodag_api.search", autospec=True,