diff --git a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/async_challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/async_challenge_auth_policy.py index 1a872f36b6a83..17cd0674e8942 100644 --- a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/async_challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/async_challenge_auth_policy.py @@ -53,9 +53,11 @@ async def on_request(self, request: PipelineRequest) -> None: scope = challenge.get_scope() or challenge.get_resource() + "/.default" # Exclude tenant for AD FS authentication if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"): - self._token = await self._credential.get_token(scope) + self._token = await self._credential.get_token(scope, claims=challenge.claims) else: - self._token = await self._credential.get_token(scope, tenant_id=challenge.tenant_id) + self._token = await self._credential.get_token( + scope, claims=challenge.claims, tenant_id=challenge.tenant_id + ) # ignore mypy's warning -- although self._token is Optional, get_token raises when it fails to get a token request.http_request.headers["Authorization"] = f"Bearer {self._token.token}" # type: ignore @@ -104,9 +106,9 @@ async def on_challenge(self, request: PipelineRequest, response: PipelineRespons # The tenant parsed from AD FS challenges is "adfs"; we don't actually need a tenant for AD FS authentication # For AD FS we skip cross-tenant authentication per https://github.com/Azure/azure-sdk-for-python/issues/28648 if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"): - await self.authorize_request(request, scope) + await self.authorize_request(request, scope, claims=challenge.claims) else: - await self.authorize_request(request, scope, tenant_id=challenge.tenant_id) + await self.authorize_request(request, scope, claims=challenge.claims, tenant_id=challenge.tenant_id) return True diff --git a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/challenge_auth_policy.py index f16297aa50263..b9858736b13d2 100644 --- a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/challenge_auth_policy.py @@ -82,9 +82,11 @@ def on_request(self, request: PipelineRequest) -> None: scope = challenge.get_scope() or challenge.get_resource() + "/.default" # Exclude tenant for AD FS authentication if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"): - self._token = self._credential.get_token(scope) + self._token = self._credential.get_token(scope, claims=challenge.claims) else: - self._token = self._credential.get_token(scope, tenant_id=challenge.tenant_id) + self._token = self._credential.get_token( + scope, claims=challenge.claims, tenant_id=challenge.tenant_id + ) # ignore mypy's warning -- although self._token is Optional, get_token raises when it fails to get a token request.http_request.headers["Authorization"] = f"Bearer {self._token.token}" # type: ignore @@ -132,9 +134,9 @@ def on_challenge(self, request: PipelineRequest, response: PipelineResponse) -> # The tenant parsed from AD FS challenges is "adfs"; we don't actually need a tenant for AD FS authentication # For AD FS we skip cross-tenant authentication per https://github.com/Azure/azure-sdk-for-python/issues/28648 if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"): - self.authorize_request(request, scope) + self.authorize_request(request, scope, claims=challenge.claims) else: - self.authorize_request(request, scope, tenant_id=challenge.tenant_id) + self.authorize_request(request, scope, claims=challenge.claims, tenant_id=challenge.tenant_id) return True diff --git a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/http_challenge.py b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/http_challenge.py index df9055c7bda6a..62427ec1372ac 100644 --- a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/http_challenge.py +++ b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/http_challenge.py @@ -2,6 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ +import base64 from typing import Dict, MutableMapping, Optional from urllib import parse @@ -18,7 +19,13 @@ class HttpChallenge(object): def __init__( self, request_uri: str, challenge: str, response_headers: "Optional[MutableMapping[str, str]]" = None ) -> None: - """Parses an HTTP WWW-Authentication Bearer challenge from a server.""" + """Parses an HTTP WWW-Authentication Bearer challenge from a server. + + Example challenge with claims: + Bearer authorization="https://login.windows-ppe.net/", error="invalid_token", + error_description="User session has been revoked", + claims="eyJhY2Nlc3NfdG9rZW4iOnsibmJmIjp7ImVzc2VudGlhbCI6dHJ1ZSwgInZhbHVlIjoiMTYwMzc0MjgwMCJ9fX0=" + """ self.source_authority = self._validate_request_uri(request_uri) self.source_uri = request_uri self._parameters: "Dict[str, str]" = {} @@ -29,16 +36,32 @@ def __init__( self.scheme = split_challenge[0] trimmed_challenge = split_challenge[1] + self.claims = None + encoded_claims = None # split trimmed challenge into comma-separated name=value pairs. Values are expected # to be surrounded by quotes which are stripped here. for item in trimmed_challenge.split(","): + # special case for claims, which can contain = symbols as padding + if "claims=" in item: + if encoded_claims: + # multiple claims challenges, e.g. for cross-tenant auth, would require special handling + # we can't support this scenario for now, so we ignore claims altogether if there are multiple + self.claims = None + encoded_claims = item[item.index("=") + 1 :].strip(" \"'") + padding_needed = -len(encoded_claims) % 4 + try: + decoded_claims = base64.urlsafe_b64decode(encoded_claims + "=" * padding_needed).decode() + self.claims = decoded_claims + except Exception: + continue # process name=value pairs - comps = item.split("=") - if len(comps) == 2: - key = comps[0].strip(' "') - value = comps[1].strip(' "') - if key: - self._parameters[key] = value + else: + comps = item.split("=") + if len(comps) == 2: + key = comps[0].strip(' "') + value = comps[1].strip(' "') + if key: + self._parameters[key] = value # minimum set of parameters if not self._parameters: