Skip to content

Commit

Permalink
Add default impl to handle token challenges (#37652)
Browse files Browse the repository at this point in the history
* Add default impl to handle token challenges

* update version

* update

* update

* update

* update

* Update sdk/core/azure-core/azure/core/pipeline/policies/_utils.py

Co-authored-by: Paul Van Eck <[email protected]>

* Update sdk/core/azure-core/azure/core/pipeline/policies/_utils.py

Co-authored-by: Paul Van Eck <[email protected]>

* update

* Update sdk/core/azure-core/tests/test_utils.py

Co-authored-by: Paul Van Eck <[email protected]>

* Update sdk/core/azure-core/azure/core/pipeline/policies/_utils.py

Co-authored-by: Paul Van Eck <[email protected]>

* update

---------

Co-authored-by: Paul Van Eck <[email protected]>
  • Loading branch information
xiangyan99 and pvaneck authored Oct 4, 2024
1 parent 37a2e61 commit 55632a7
Show file tree
Hide file tree
Showing 6 changed files with 225 additions and 28 deletions.
4 changes: 3 additions & 1 deletion sdk/core/azure-core/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
# Release History

## 1.31.1 (Unreleased)
## 1.32.0 (Unreleased)

### Features Added

- Added a default implementation to handle token challenges in `BearerTokenCredentialPolicy` and `AsyncBearerTokenCredentialPolicy`.

### Breaking Changes

### Bugs Fixed
Expand Down
2 changes: 1 addition & 1 deletion sdk/core/azure-core/azure/core/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@
# regenerated.
# --------------------------------------------------------------------------

VERSION = "1.31.1"
VERSION = "1.32.0"
40 changes: 30 additions & 10 deletions sdk/core/azure-core/azure/core/pipeline/policies/_authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# license information.
# -------------------------------------------------------------------------
import time
import base64
from typing import TYPE_CHECKING, Optional, TypeVar, MutableMapping, Any, Union, cast
from azure.core.credentials import (
TokenCredential,
Expand All @@ -19,6 +20,7 @@
from azure.core.rest import HttpResponse, HttpRequest
from . import HTTPPolicy, SansIOHTTPPolicy
from ...exceptions import ServiceRequestError
from ._utils import get_challenge_parameter

if TYPE_CHECKING:

Expand Down Expand Up @@ -82,13 +84,7 @@ def _need_new_token(self) -> bool:
refresh_on = getattr(self._token, "refresh_on", None)
return not self._token or (refresh_on and refresh_on <= now) or self._token.expires_on - now < 300

def _request_token(self, *scopes: str, **kwargs: Any) -> None:
"""Request a new token from the credential.
This will call the credential's appropriate method to get a token and store it in the policy.
:param str scopes: The type of access needed.
"""
def _get_token(self, *scopes: str, **kwargs: Any) -> Union["AccessToken", "AccessTokenInfo"]:
if self._enable_cae:
kwargs.setdefault("enable_cae", self._enable_cae)

Expand All @@ -99,9 +95,17 @@ def _request_token(self, *scopes: str, **kwargs: Any) -> None:
if key in TokenRequestOptions.__annotations__: # pylint: disable=no-member
options[key] = kwargs.pop(key) # type: ignore[literal-required]

self._token = cast(SupportsTokenInfo, self._credential).get_token_info(*scopes, options=options)
else:
self._token = cast(TokenCredential, self._credential).get_token(*scopes, **kwargs)
return cast(SupportsTokenInfo, self._credential).get_token_info(*scopes, options=options)
return cast(TokenCredential, self._credential).get_token(*scopes, **kwargs)

def _request_token(self, *scopes: str, **kwargs: Any) -> None:
"""Request a new token from the credential.
This will call the credential's appropriate method to get a token and store it in the policy.
:param str scopes: The type of access needed.
"""
self._token = self._get_token(*scopes, **kwargs)


class BearerTokenCredentialPolicy(_BearerTokenCredentialPolicyBase, HTTPPolicy[HTTPRequestType, HTTPResponseType]):
Expand Down Expand Up @@ -191,6 +195,22 @@ def on_challenge(
:rtype: bool
"""
# pylint:disable=unused-argument
headers = response.http_response.headers
error = get_challenge_parameter(headers, "Bearer", "error")
if error == "insufficient_claims":
encoded_claims = get_challenge_parameter(headers, "Bearer", "claims")
if not encoded_claims:
return False
try:
padding_needed = -len(encoded_claims) % 4
claims = base64.urlsafe_b64decode(encoded_claims + "=" * padding_needed).decode("utf-8")
if claims:
token = self._get_token(*self._scopes, claims=claims)
bearer_token = cast(Union["AccessToken", "AccessTokenInfo"], token).token
request.http_request.headers["Authorization"] = "Bearer " + bearer_token
return True
except Exception: # pylint:disable=broad-except
return False
return False

def on_response(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# license information.
# -------------------------------------------------------------------------
import time
import base64
from typing import Any, Awaitable, Optional, cast, TypeVar, Union

from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions
Expand All @@ -23,6 +24,7 @@
)
from azure.core.rest import AsyncHttpResponse, HttpRequest
from azure.core.utils._utils import get_running_async_lock
from ._utils import get_challenge_parameter

from .._tools_async import await_result

Expand Down Expand Up @@ -138,6 +140,22 @@ async def on_challenge(
:rtype: bool
"""
# pylint:disable=unused-argument
headers = response.http_response.headers
error = get_challenge_parameter(headers, "Bearer", "error")
if error == "insufficient_claims":
encoded_claims = get_challenge_parameter(headers, "Bearer", "claims")
if not encoded_claims:
return False
try:
padding_needed = -len(encoded_claims) % 4
claims = base64.urlsafe_b64decode(encoded_claims + "=" * padding_needed).decode("utf-8")
if claims:
token = await self._get_token(*self._scopes, claims=claims)
bearer_token = cast(Union["AccessToken", "AccessTokenInfo"], token).token
request.http_request.headers["Authorization"] = "Bearer " + bearer_token
return True
except Exception: # pylint:disable=broad-except
return False
return False

def on_response(
Expand Down Expand Up @@ -169,13 +187,7 @@ def _need_new_token(self) -> bool:
refresh_on = getattr(self._token, "refresh_on", None)
return not self._token or (refresh_on and refresh_on <= now) or self._token.expires_on - now < 300

async def _request_token(self, *scopes: str, **kwargs: Any) -> None:
"""Request a new token from the credential.
This will call the credential's appropriate method to get a token and store it in the policy.
:param str scopes: The type of access needed.
"""
async def _get_token(self, *scopes: str, **kwargs: Any) -> Union["AccessToken", "AccessTokenInfo"]:
if self._enable_cae:
kwargs.setdefault("enable_cae", self._enable_cae)

Expand All @@ -186,14 +198,22 @@ async def _request_token(self, *scopes: str, **kwargs: Any) -> None:
if key in TokenRequestOptions.__annotations__: # pylint: disable=no-member
options[key] = kwargs.pop(key) # type: ignore[literal-required]

self._token = await await_result(
return await await_result(
cast(AsyncSupportsTokenInfo, self._credential).get_token_info,
*scopes,
options=options,
)
else:
self._token = await await_result(
cast(AsyncTokenCredential, self._credential).get_token,
*scopes,
**kwargs,
)
return await await_result(
cast(AsyncTokenCredential, self._credential).get_token,
*scopes,
**kwargs,
)

async def _request_token(self, *scopes: str, **kwargs: Any) -> None:
"""Request a new token from the credential.
This will call the credential's appropriate method to get a token and store it in the policy.
:param str scopes: The type of access needed.
"""
self._token = await self._get_token(*scopes, **kwargs)
102 changes: 101 additions & 1 deletion sdk/core/azure-core/azure/core/pipeline/policies/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
# --------------------------------------------------------------------------
import datetime
import email.utils
from typing import Optional, cast, Union
from typing import Optional, cast, Union, Tuple
from urllib.parse import urlparse

from azure.core.pipeline.transport import (
Expand Down Expand Up @@ -102,3 +102,103 @@ def get_domain(url: str) -> str:
:return: The domain of the url.
"""
return str(urlparse(url).netloc).lower()


def get_challenge_parameter(headers, challenge_scheme: str, challenge_parameter: str) -> Optional[str]:
"""
Parses the specified parameter from a challenge header found in the response.
:param dict[str, str] headers: The response headers to parse.
:param str challenge_scheme: The challenge scheme containing the challenge parameter, e.g., "Bearer".
:param str challenge_parameter: The parameter key name to search for.
:return: The value of the parameter name if found.
:rtype: str or None
"""
header_value = headers.get("WWW-Authenticate")
if not header_value:
return None

scheme = challenge_scheme
parameter = challenge_parameter
header_span = header_value

# Iterate through each challenge value.
while True:
challenge = get_next_challenge(header_span)
if not challenge:
break
challenge_key, header_span = challenge
if challenge_key.lower() != scheme.lower():
continue
# Enumerate each key-value parameter until we find the parameter key on the specified scheme challenge.
while True:
parameters = get_next_parameter(header_span)
if not parameters:
break
key, value, header_span = parameters
if key.lower() == parameter.lower():
return value

return None


def get_next_challenge(header_value: str) -> Optional[Tuple[str, str]]:
"""
Iterates through the challenge schemes present in a challenge header.
:param str header_value: The header value which will be sliced to remove the first parsed challenge key.
:return: The parsed challenge scheme and the remaining header value.
:rtype: tuple[str, str] or None
"""
header_value = header_value.lstrip(" ")
end_of_challenge_key = header_value.find(" ")

if end_of_challenge_key < 0:
return None

challenge_key = header_value[:end_of_challenge_key]
header_value = header_value[end_of_challenge_key + 1 :]

return challenge_key, header_value


def get_next_parameter(header_value: str, separator: str = "=") -> Optional[Tuple[str, str, str]]:
"""
Iterates through a challenge header value to extract key-value parameters.
:param str header_value: The header value after being parsed by get_next_challenge.
:param str separator: The challenge parameter key-value pair separator, default is '='.
:return: The next available challenge parameter as a tuple (param_key, param_value, remaining header_value).
:rtype: tuple[str, str, str] or None
"""
space_or_comma = " ,"
header_value = header_value.lstrip(space_or_comma)

next_space = header_value.find(" ")
next_separator = header_value.find(separator)

if next_space < next_separator and next_space != -1:
return None

if next_separator < 0:
return None

param_key = header_value[:next_separator].strip()
header_value = header_value[next_separator + 1 :]

quote_index = header_value.find('"')

if quote_index >= 0:
header_value = header_value[quote_index + 1 :]
param_value = header_value[: header_value.find('"')]
else:
trailing_delimiter_index = header_value.find(" ")
if trailing_delimiter_index >= 0:
param_value = header_value[:trailing_delimiter_index]
else:
param_value = header_value

if header_value != param_value:
header_value = header_value[len(param_value) + 1 :]

return param_key, param_value, header_value
57 changes: 56 additions & 1 deletion sdk/core/azure-core/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import pytest
from azure.core.utils import case_insensitive_dict
from azure.core.utils._utils import get_running_async_lock
from azure.core.pipeline.policies._utils import parse_retry_after
from azure.core.pipeline.policies._utils import parse_retry_after, get_challenge_parameter


@pytest.fixture()
Expand Down Expand Up @@ -146,3 +146,58 @@ def test_parse_retry_after():
assert ret == 0
ret = parse_retry_after("0.9")
assert ret == 0.9


def test_get_challenge_parameter():
headers = {
"WWW-Authenticate": 'Bearer authorization_uri="https://login.microsoftonline.com/tenant-id", resource="https://vault.azure.net"'
}
assert (
get_challenge_parameter(headers, "Bearer", "authorization_uri") == "https://login.microsoftonline.com/tenant-id"
)
assert get_challenge_parameter(headers, "Bearer", "resource") == "https://vault.azure.net"
assert get_challenge_parameter(headers, "Bearer", "foo") is None

headers = {
"WWW-Authenticate": 'Bearer realm="", authorization_uri="https://login.microsoftonline.com/common/oauth2/authorize", error="insufficient_claims", claims="eyJhY2Nlc3NfdG9rZW4iOnsibmJmIjp7ImVzc2VudGlhbCI6dHJ1ZSwidmFsdWUiOiIxNzI2MDc3NTk1In0sInhtc19jYWVlcnJvciI6eyJ2YWx1ZSI6IjEwMDEyIn19fQ=="'
}
assert (
get_challenge_parameter(headers, "Bearer", "authorization_uri")
== "https://login.microsoftonline.com/common/oauth2/authorize"
)
assert get_challenge_parameter(headers, "Bearer", "error") == "insufficient_claims"
assert (
get_challenge_parameter(headers, "Bearer", "claims")
== "eyJhY2Nlc3NfdG9rZW4iOnsibmJmIjp7ImVzc2VudGlhbCI6dHJ1ZSwidmFsdWUiOiIxNzI2MDc3NTk1In0sInhtc19jYWVlcnJvciI6eyJ2YWx1ZSI6IjEwMDEyIn19fQ=="
)


def test_get_challenge_parameter_not_found():
headers = {
"WWW-Authenticate": 'Pop authorization_uri="https://login.microsoftonline.com/tenant-id", resource="https://vault.azure.net"'
}
assert get_challenge_parameter(headers, "Bearer", "resource") is None


def test_get_multi_challenge_parameter():
headers = {
"WWW-Authenticate": 'Bearer authorization_uri="https://login.microsoftonline.com/tenant-id", resource="https://vault.azure.net" Bearer authorization_uri="https://login.microsoftonline.com/tenant-id", resource="https://vault.azure.net"'
}
assert (
get_challenge_parameter(headers, "Bearer", "authorization_uri") == "https://login.microsoftonline.com/tenant-id"
)
assert get_challenge_parameter(headers, "Bearer", "resource") == "https://vault.azure.net"
assert get_challenge_parameter(headers, "Bearer", "foo") is None

headers = {
"WWW-Authenticate": 'Digest realm="[email protected]", qop="auth,auth-int", nonce="123456abcdefg", opaque="123456", Bearer realm="", authorization_uri="https://login.microsoftonline.com/common/oauth2/authorize", error="insufficient_claims", claims="eyJhY2Nlc3NfdG9rZW4iOnsibmJmIjp7ImVzc2VudGlhbCI6dHJ1ZSwidmFsdWUiOiIxNzI2MDc3NTk1In0sInhtc19jYWVlcnJvciI6eyJ2YWx1ZSI6IjEwMDEyIn19fQ=="'
}
assert (
get_challenge_parameter(headers, "Bearer", "authorization_uri")
== "https://login.microsoftonline.com/common/oauth2/authorize"
)
assert get_challenge_parameter(headers, "Bearer", "error") == "insufficient_claims"
assert (
get_challenge_parameter(headers, "Bearer", "claims")
== "eyJhY2Nlc3NfdG9rZW4iOnsibmJmIjp7ImVzc2VudGlhbCI6dHJ1ZSwidmFsdWUiOiIxNzI2MDc3NTk1In0sInhtc19jYWVlcnJvciI6eyJ2YWx1ZSI6IjEwMDEyIn19fQ=="
)

0 comments on commit 55632a7

Please sign in to comment.