diff --git a/nucliadb/nucliadb/learning_proxy.py b/nucliadb/nucliadb/learning_proxy.py index 574e9d7f36..e29984cac3 100644 --- a/nucliadb/nucliadb/learning_proxy.py +++ b/nucliadb/nucliadb/learning_proxy.py @@ -35,8 +35,16 @@ SERVICE_NAME = "nucliadb.learning_proxy" logger = logging.getLogger(SERVICE_NAME) - -NUCLIA_ONPREM_AUTH_HEADER = "X-NUCLIA-NUAKEY" +WHITELISTED_HEADERS = { + "x-nucliadb-user", + "x-nucliadb-roles", + "x-stf-roles", + "x-stf-user", + "x-forwarded-for", + "x-forwarded-host", + "x-forwarded-proto", + "x-forwarded-port", +} class LearningService(str, Enum): @@ -113,6 +121,10 @@ async def learning_collector_proxy( ) +def is_white_listed_header(header: str) -> bool: + return header.lower() in WHITELISTED_HEADERS + + async def proxy( service: LearningService, request: Request, @@ -131,9 +143,11 @@ async def proxy( Returns: Response. The response from the learning API. If the response is chunked, a StreamingResponse is returned. """ + proxied_headers = extra_headers or {} - proxied_headers.update({k.lower(): v for k, v in request.headers.items()}) - proxied_headers.pop("host", None) + proxied_headers.update( + {k.lower(): v for k, v in request.headers.items() if is_white_listed_header(k)} + ) async with service_client( base_url=get_base_url(service=service),