From 03848e7f3a6187e5658d3a99523170cfd67ac9f3 Mon Sep 17 00:00:00 2001
From: Sylvain Brunato <sylvain.brunato@c-s.fr>
Date: Fri, 7 Apr 2023 10:08:47 +0200
Subject: [PATCH] feat: handle forward headers

---
 eodag/rest/server.py            | 93 ++++++++++++++++++---------------
 tests/units/test_http_server.py | 22 ++++++++
 2 files changed, 72 insertions(+), 43 deletions(-)

diff --git a/eodag/rest/server.py b/eodag/rest/server.py
index 4173cda02..f832ec3a9 100755
--- a/eodag/rest/server.py
+++ b/eodag/rest/server.py
@@ -49,7 +49,7 @@
     get_stac_item_by_id,
     search_stac_items,
 )
-from eodag.utils import update_nested_dict
+from eodag.utils import parse_header, update_nested_dict
 from eodag.utils.exceptions import (
     MisconfiguredError,
     NoMatchingProductType,
@@ -108,20 +108,6 @@ async def lifespan(app: FastAPI):
 stac_api_config = load_stac_api_config()
 
 
-@router.get("/", tags=["Capabilities"])
-def catalogs_root(request: Request):
-    """STAC catalogs root"""
-
-    response = get_stac_catalogs(
-        url=f"{request.url.scheme}://{request.url.netloc}{request.url.path}",
-        root=f"{request.url.scheme}://{request.url.netloc}",
-        catalogs=[],
-        provider=request.query_params.get("provider", None),
-    )
-
-    return jsonable_encoder(response)
-
-
 @router.get("/api", tags=["Capabilities"])
 def eodag_openapi():
     """Customized openapi"""
@@ -181,6 +167,25 @@ def eodag_openapi():
 )
 
 
+@app.middleware("http")
+async def forward_middleware(request: Request, call_next):
+    """Middleware that handles forward headers and sets request.state.url*"""
+
+    forwarded_host = request.headers.get("x-forwarded-host", None)
+    forwarded_proto = request.headers.get("x-forwarded-proto", None)
+
+    if "forwarded" in request.headers:
+        header_forwarded = parse_header(request.headers["forwarded"])
+        forwarded_host = header_forwarded.get_param("host", None) or forwarded_host
+        forwarded_proto = header_forwarded.get_param("proto", None) or forwarded_proto
+
+    request.state.url_root = f"{forwarded_proto or request.url.scheme}://{forwarded_host or request.url.netloc}"
+    request.state.url = f"{request.state.url_root}{request.url.path}"
+
+    response = await call_next(request)
+    return response
+
+
 @app.exception_handler(StarletteHTTPException)
 async def default_exception_handler(request: Request, error):
     """Default errors handle"""
@@ -224,10 +229,23 @@ async def handle_resource_not_found(request: Request, error):
     )
 
 
+@router.get("/", tags=["Capabilities"])
+def catalogs_root(request: Request):
+    """STAC catalogs root"""
+
+    response = get_stac_catalogs(
+        url=request.state.url,
+        root=request.state.url_root,
+        catalogs=[],
+        provider=request.query_params.get("provider", None),
+    )
+
+    return jsonable_encoder(response)
+
+
 @router.get("/conformance", tags=["Capabilities"])
 def conformance():
     """STAC conformance"""
-
     response = get_stac_conformance()
 
     return jsonable_encoder(response)
@@ -236,7 +254,6 @@ def conformance():
 @router.get("/extensions/oseo/json-schema/schema.json", include_in_schema=False)
 def stac_extension_oseo(request: Request):
     """STAC OGC / OpenSearch extension for EO"""
-
     response = get_stac_extension_oseo(url=request.url.split("?")[0])
 
     return app.response_class(
@@ -248,8 +265,8 @@ def stac_extension_oseo(request: Request):
 @router.post("/search", tags=["STAC"])
 async def stac_search(request: Request):
     """STAC collections items"""
-    url = f"{request.url.scheme}://{request.url.netloc}{request.url.path}"
-    url_root = f"{request.url.scheme}://{request.url.netloc}"
+    url = request.state.url
+    url_root = request.state.url_root
     try:
         body = await request.json()
     except JSONDecodeError:
@@ -274,9 +291,8 @@ async def collections(request: Request):
 
     Can be filtered using parameters: instrument, platform, platformSerialIdentifier, sensorType, processingLevel
     """
-    url = f"{request.url.scheme}://{request.url.netloc}{request.url.path}"
-    url_root = f"{request.url.scheme}://{request.url.netloc}"
-
+    url = request.state.url
+    url_root = request.state.url_root
     try:
         body = await request.json()
     except JSONDecodeError:
@@ -297,9 +313,8 @@ async def collections(request: Request):
 @router.get("/collections/{collection_id}/items", tags=["Data"])
 async def stac_collections_items(collection_id, request: Request):
     """STAC collections items"""
-    url = f"{request.url.scheme}://{request.url.netloc}{request.url.path}"
-    url_root = f"{request.url.scheme}://{request.url.netloc}"
-
+    url = request.state.url
+    url_root = request.state.url_root
     try:
         body = await request.json()
     except JSONDecodeError:
@@ -321,9 +336,8 @@ async def stac_collections_items(collection_id, request: Request):
 @router.get("/collections/{collection_id}", tags=["Capabilities"])
 async def collection_by_id(collection_id, request: Request):
     """STAC collection by id"""
-    url = f"{request.url.scheme}://{request.url.netloc}{request.url.path}"
-    url_root = f"{request.url.scheme}://{request.url.netloc}"
-
+    url = request.state.url
+    url_root = request.state.url_root
     try:
         body = await request.json()
     except JSONDecodeError:
@@ -344,10 +358,8 @@ async def collection_by_id(collection_id, request: Request):
 @router.get("/collections/{collection_id}/items/{item_id}", tags=["Data"])
 async def stac_collections_item(collection_id, item_id, request: Request):
     """STAC collection item by id"""
-
-    url = f"{request.url.scheme}://{request.url.netloc}{request.url.path}"
-    url_root = f"{request.url.scheme}://{request.url.netloc}"
-
+    url = request.state.url
+    url_root = request.state.url_root
     try:
         body = await request.json()
     except JSONDecodeError:
@@ -377,7 +389,6 @@ async def stac_collections_item(collection_id, item_id, request: Request):
 @router.get("/collections/{collection_id}/items/{item_id}/download", tags=["Data"])
 async def stac_collections_item_download(collection_id, item_id, request: Request):
     """STAC collection item local download"""
-
     try:
         body = await request.json()
     except JSONDecodeError:
@@ -427,10 +438,8 @@ async def stac_catalogs_items(catalogs, request: Request):
         '500':
             $ref: '#/components/responses/ServerError
     '"""
-
-    url = f"{request.url.scheme}://{request.url.netloc}{request.url.path}"
-    url_root = f"{request.url.scheme}://{request.url.netloc}"
-
+    url = request.state.url
+    url_root = request.state.url_root
     try:
         body = await request.json()
     except JSONDecodeError:
@@ -485,9 +494,8 @@ async def stac_catalogs_item(catalogs, item_id, request: Request):
         '500':
           $ref: '#/components/responses/ServerError'
     """
-    url = f"{request.url.scheme}://{request.url.netloc}{request.url.path}"
-    url_root = f"{request.url.scheme}://{request.url.netloc}"
-
+    url = request.state.url
+    url_root = request.state.url_root
     try:
         body = await request.json()
     except JSONDecodeError:
@@ -562,9 +570,8 @@ async def stac_catalogs(catalogs, request: Request):
         '500':
             $ref: '#/components/responses/ServerError'
     """
-    url = f"{request.url.scheme}://{request.url.netloc}{request.url.path}"
-    url_root = f"{request.url.scheme}://{request.url.netloc}"
-
+    url = request.state.url
+    url_root = request.state.url_root
     try:
         body = await request.json()
     except JSONDecodeError:
diff --git a/tests/units/test_http_server.py b/tests/units/test_http_server.py
index c84f5215a..4b462f776 100644
--- a/tests/units/test_http_server.py
+++ b/tests/units/test_http_server.py
@@ -76,6 +76,28 @@ def setUp(self):
     def test_route(self):
         self._request_valid("/")
 
+    def test_forward(self):
+        response = self.app.get("/", follow_redirects=True)
+        self.assertEqual(200, response.status_code)
+        resp_json = json.loads(response.content.decode("utf-8"))
+        self.assertEqual(resp_json["links"][0]["href"], "http://testserver")
+
+        response = self.app.get(
+            "/", follow_redirects=True, headers={"Forwarded": "host=foo;proto=https"}
+        )
+        self.assertEqual(200, response.status_code)
+        resp_json = json.loads(response.content.decode("utf-8"))
+        self.assertEqual(resp_json["links"][0]["href"], "https://foo")
+
+        response = self.app.get(
+            "/",
+            follow_redirects=True,
+            headers={"X-Forwarded-Host": "bar", "X-Forwarded-Proto": "httpz"},
+        )
+        self.assertEqual(200, response.status_code)
+        resp_json = json.loads(response.content.decode("utf-8"))
+        self.assertEqual(resp_json["links"][0]["href"], "httpz://bar")
+
     @mock.patch(
         "eodag.rest.utils.eodag_api.search",
         autospec=True,