Skip to content

Commit

Permalink
feat: handle forward headers
Browse files Browse the repository at this point in the history
  • Loading branch information
sbrunato committed Apr 18, 2023
1 parent 5c1fea3 commit 03848e7
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 43 deletions.
93 changes: 50 additions & 43 deletions eodag/rest/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
get_stac_item_by_id,
search_stac_items,
)
from eodag.utils import update_nested_dict
from eodag.utils import parse_header, update_nested_dict
from eodag.utils.exceptions import (
MisconfiguredError,
NoMatchingProductType,
Expand Down Expand Up @@ -108,20 +108,6 @@ async def lifespan(app: FastAPI):
stac_api_config = load_stac_api_config()


@router.get("/", tags=["Capabilities"])
def catalogs_root(request: Request):
"""STAC catalogs root"""

response = get_stac_catalogs(
url=f"{request.url.scheme}://{request.url.netloc}{request.url.path}",
root=f"{request.url.scheme}://{request.url.netloc}",
catalogs=[],
provider=request.query_params.get("provider", None),
)

return jsonable_encoder(response)


@router.get("/api", tags=["Capabilities"])
def eodag_openapi():
"""Customized openapi"""
Expand Down Expand Up @@ -181,6 +167,25 @@ def eodag_openapi():
)


@app.middleware("http")
async def forward_middleware(request: Request, call_next):
"""Middleware that handles forward headers and sets request.state.url*"""

forwarded_host = request.headers.get("x-forwarded-host", None)
forwarded_proto = request.headers.get("x-forwarded-proto", None)

if "forwarded" in request.headers:
header_forwarded = parse_header(request.headers["forwarded"])
forwarded_host = header_forwarded.get_param("host", None) or forwarded_host
forwarded_proto = header_forwarded.get_param("proto", None) or forwarded_proto

request.state.url_root = f"{forwarded_proto or request.url.scheme}://{forwarded_host or request.url.netloc}"
request.state.url = f"{request.state.url_root}{request.url.path}"

response = await call_next(request)
return response


@app.exception_handler(StarletteHTTPException)
async def default_exception_handler(request: Request, error):
"""Default errors handle"""
Expand Down Expand Up @@ -224,10 +229,23 @@ async def handle_resource_not_found(request: Request, error):
)


@router.get("/", tags=["Capabilities"])
def catalogs_root(request: Request):
"""STAC catalogs root"""

response = get_stac_catalogs(
url=request.state.url,
root=request.state.url_root,
catalogs=[],
provider=request.query_params.get("provider", None),
)

return jsonable_encoder(response)


@router.get("/conformance", tags=["Capabilities"])
def conformance():
"""STAC conformance"""

response = get_stac_conformance()

return jsonable_encoder(response)
Expand All @@ -236,7 +254,6 @@ def conformance():
@router.get("/extensions/oseo/json-schema/schema.json", include_in_schema=False)
def stac_extension_oseo(request: Request):
"""STAC OGC / OpenSearch extension for EO"""

response = get_stac_extension_oseo(url=request.url.split("?")[0])

return app.response_class(
Expand All @@ -248,8 +265,8 @@ def stac_extension_oseo(request: Request):
@router.post("/search", tags=["STAC"])
async def stac_search(request: Request):
"""STAC collections items"""
url = f"{request.url.scheme}://{request.url.netloc}{request.url.path}"
url_root = f"{request.url.scheme}://{request.url.netloc}"
url = request.state.url
url_root = request.state.url_root
try:
body = await request.json()
except JSONDecodeError:
Expand All @@ -274,9 +291,8 @@ async def collections(request: Request):
Can be filtered using parameters: instrument, platform, platformSerialIdentifier, sensorType, processingLevel
"""
url = f"{request.url.scheme}://{request.url.netloc}{request.url.path}"
url_root = f"{request.url.scheme}://{request.url.netloc}"

url = request.state.url
url_root = request.state.url_root
try:
body = await request.json()
except JSONDecodeError:
Expand All @@ -297,9 +313,8 @@ async def collections(request: Request):
@router.get("/collections/{collection_id}/items", tags=["Data"])
async def stac_collections_items(collection_id, request: Request):
"""STAC collections items"""
url = f"{request.url.scheme}://{request.url.netloc}{request.url.path}"
url_root = f"{request.url.scheme}://{request.url.netloc}"

url = request.state.url
url_root = request.state.url_root
try:
body = await request.json()
except JSONDecodeError:
Expand All @@ -321,9 +336,8 @@ async def stac_collections_items(collection_id, request: Request):
@router.get("/collections/{collection_id}", tags=["Capabilities"])
async def collection_by_id(collection_id, request: Request):
"""STAC collection by id"""
url = f"{request.url.scheme}://{request.url.netloc}{request.url.path}"
url_root = f"{request.url.scheme}://{request.url.netloc}"

url = request.state.url
url_root = request.state.url_root
try:
body = await request.json()
except JSONDecodeError:
Expand All @@ -344,10 +358,8 @@ async def collection_by_id(collection_id, request: Request):
@router.get("/collections/{collection_id}/items/{item_id}", tags=["Data"])
async def stac_collections_item(collection_id, item_id, request: Request):
"""STAC collection item by id"""

url = f"{request.url.scheme}://{request.url.netloc}{request.url.path}"
url_root = f"{request.url.scheme}://{request.url.netloc}"

url = request.state.url
url_root = request.state.url_root
try:
body = await request.json()
except JSONDecodeError:
Expand Down Expand Up @@ -377,7 +389,6 @@ async def stac_collections_item(collection_id, item_id, request: Request):
@router.get("/collections/{collection_id}/items/{item_id}/download", tags=["Data"])
async def stac_collections_item_download(collection_id, item_id, request: Request):
"""STAC collection item local download"""

try:
body = await request.json()
except JSONDecodeError:
Expand Down Expand Up @@ -427,10 +438,8 @@ async def stac_catalogs_items(catalogs, request: Request):
'500':
$ref: '#/components/responses/ServerError
'"""

url = f"{request.url.scheme}://{request.url.netloc}{request.url.path}"
url_root = f"{request.url.scheme}://{request.url.netloc}"

url = request.state.url
url_root = request.state.url_root
try:
body = await request.json()
except JSONDecodeError:
Expand Down Expand Up @@ -485,9 +494,8 @@ async def stac_catalogs_item(catalogs, item_id, request: Request):
'500':
$ref: '#/components/responses/ServerError'
"""
url = f"{request.url.scheme}://{request.url.netloc}{request.url.path}"
url_root = f"{request.url.scheme}://{request.url.netloc}"

url = request.state.url
url_root = request.state.url_root
try:
body = await request.json()
except JSONDecodeError:
Expand Down Expand Up @@ -562,9 +570,8 @@ async def stac_catalogs(catalogs, request: Request):
'500':
$ref: '#/components/responses/ServerError'
"""
url = f"{request.url.scheme}://{request.url.netloc}{request.url.path}"
url_root = f"{request.url.scheme}://{request.url.netloc}"

url = request.state.url
url_root = request.state.url_root
try:
body = await request.json()
except JSONDecodeError:
Expand Down
22 changes: 22 additions & 0 deletions tests/units/test_http_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,28 @@ def setUp(self):
def test_route(self):
self._request_valid("/")

def test_forward(self):
response = self.app.get("/", follow_redirects=True)
self.assertEqual(200, response.status_code)
resp_json = json.loads(response.content.decode("utf-8"))
self.assertEqual(resp_json["links"][0]["href"], "http://testserver")

response = self.app.get(
"/", follow_redirects=True, headers={"Forwarded": "host=foo;proto=https"}
)
self.assertEqual(200, response.status_code)
resp_json = json.loads(response.content.decode("utf-8"))
self.assertEqual(resp_json["links"][0]["href"], "https://foo")

response = self.app.get(
"/",
follow_redirects=True,
headers={"X-Forwarded-Host": "bar", "X-Forwarded-Proto": "httpz"},
)
self.assertEqual(200, response.status_code)
resp_json = json.loads(response.content.decode("utf-8"))
self.assertEqual(resp_json["links"][0]["href"], "httpz://bar")

@mock.patch(
"eodag.rest.utils.eodag_api.search",
autospec=True,
Expand Down

0 comments on commit 03848e7

Please sign in to comment.