Skip to content

Commit

Permalink
Implement CAE support
Browse files Browse the repository at this point in the history
  • Loading branch information
mccoyp committed Sep 12, 2024
1 parent 40c6c85 commit 614cf35
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,11 @@ async def on_request(self, request: PipelineRequest) -> None:
scope = challenge.get_scope() or challenge.get_resource() + "/.default"
# Exclude tenant for AD FS authentication
if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"):
self._token = await self._credential.get_token(scope)
self._token = await self._credential.get_token(scope, claims=challenge.claims)
else:
self._token = await self._credential.get_token(scope, tenant_id=challenge.tenant_id)
self._token = await self._credential.get_token(
scope, claims=challenge.claims, tenant_id=challenge.tenant_id
)

# ignore mypy's warning -- although self._token is Optional, get_token raises when it fails to get a token
request.http_request.headers["Authorization"] = f"Bearer {self._token.token}" # type: ignore
Expand Down Expand Up @@ -104,9 +106,9 @@ async def on_challenge(self, request: PipelineRequest, response: PipelineRespons
# The tenant parsed from AD FS challenges is "adfs"; we don't actually need a tenant for AD FS authentication
# For AD FS we skip cross-tenant authentication per https://github.com/Azure/azure-sdk-for-python/issues/28648
if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"):
await self.authorize_request(request, scope)
await self.authorize_request(request, scope, claims=challenge.claims)
else:
await self.authorize_request(request, scope, tenant_id=challenge.tenant_id)
await self.authorize_request(request, scope, claims=challenge.claims, tenant_id=challenge.tenant_id)

return True

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,11 @@ def on_request(self, request: PipelineRequest) -> None:
scope = challenge.get_scope() or challenge.get_resource() + "/.default"
# Exclude tenant for AD FS authentication
if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"):
self._token = self._credential.get_token(scope)
self._token = self._credential.get_token(scope, claims=challenge.claims)
else:
self._token = self._credential.get_token(scope, tenant_id=challenge.tenant_id)
self._token = self._credential.get_token(
scope, claims=challenge.claims, tenant_id=challenge.tenant_id
)

# ignore mypy's warning -- although self._token is Optional, get_token raises when it fails to get a token
request.http_request.headers["Authorization"] = f"Bearer {self._token.token}" # type: ignore
Expand Down Expand Up @@ -132,9 +134,9 @@ def on_challenge(self, request: PipelineRequest, response: PipelineResponse) ->
# The tenant parsed from AD FS challenges is "adfs"; we don't actually need a tenant for AD FS authentication
# For AD FS we skip cross-tenant authentication per https://github.com/Azure/azure-sdk-for-python/issues/28648
if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"):
self.authorize_request(request, scope)
self.authorize_request(request, scope, claims=challenge.claims)
else:
self.authorize_request(request, scope, tenant_id=challenge.tenant_id)
self.authorize_request(request, scope, claims=challenge.claims, tenant_id=challenge.tenant_id)

return True

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
import base64
from typing import Dict, MutableMapping, Optional
from urllib import parse

Expand All @@ -18,7 +19,13 @@ class HttpChallenge(object):
def __init__(
self, request_uri: str, challenge: str, response_headers: "Optional[MutableMapping[str, str]]" = None
) -> None:
"""Parses an HTTP WWW-Authentication Bearer challenge from a server."""
"""Parses an HTTP WWW-Authentication Bearer challenge from a server.
Example challenge with claims:
Bearer authorization="https://login.windows-ppe.net/", error="invalid_token",
error_description="User session has been revoked",
claims="eyJhY2Nlc3NfdG9rZW4iOnsibmJmIjp7ImVzc2VudGlhbCI6dHJ1ZSwgInZhbHVlIjoiMTYwMzc0MjgwMCJ9fX0="
"""
self.source_authority = self._validate_request_uri(request_uri)
self.source_uri = request_uri
self._parameters: "Dict[str, str]" = {}
Expand All @@ -29,16 +36,32 @@ def __init__(
self.scheme = split_challenge[0]
trimmed_challenge = split_challenge[1]

self.claims = None
encoded_claims = None
# split trimmed challenge into comma-separated name=value pairs. Values are expected
# to be surrounded by quotes which are stripped here.
for item in trimmed_challenge.split(","):
# special case for claims, which can contain = symbols as padding
if "claims=" in item:
if encoded_claims:
# multiple claims challenges, e.g. for cross-tenant auth, would require special handling
# we can't support this scenario for now, so we ignore claims altogether if there are multiple
self.claims = None
encoded_claims = item[item.index("=") + 1 :].strip(" \"'")
padding_needed = -len(encoded_claims) % 4
try:
decoded_claims = base64.urlsafe_b64decode(encoded_claims + "=" * padding_needed).decode()
self.claims = decoded_claims
except Exception:
continue
# process name=value pairs
comps = item.split("=")
if len(comps) == 2:
key = comps[0].strip(' "')
value = comps[1].strip(' "')
if key:
self._parameters[key] = value
else:
comps = item.split("=")
if len(comps) == 2:
key = comps[0].strip(' "')
value = comps[1].strip(' "')
if key:
self._parameters[key] = value

# minimum set of parameters
if not self._parameters:
Expand Down

0 comments on commit 614cf35

Please sign in to comment.