Skip to content

Commit

Permalink
[Identity] Enable CAE toggle per token request (#30777)
Browse files Browse the repository at this point in the history
- All relevant credentials (User Credentials + Service Principal Credentials + SharedTokenCacheCredential) now accept and honor an enable_cae keyword argument. This denotes that the token request should include "CP1" client capabilities indicating that the SDK is ready to handle CAE claims challenges.

- Two token caches are now maintained — one for non-CAE tokens and one for CAE-tokens.

- The AZURE_IDENTITY_DISABLE_CP1 environment variable is removed since the behavior of the CP1 capability being "always-on" has been changed.

Signed-off-by: Paul Van Eck <[email protected]>
  • Loading branch information
pvaneck authored Jul 28, 2023
1 parent 8fb77fc commit f6d7789
Show file tree
Hide file tree
Showing 28 changed files with 696 additions and 342 deletions.
6 changes: 6 additions & 0 deletions sdk/identity/azure-identity/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,14 @@

### Features Added

- Continuous Access Evaluation (CAE) is now configurable per-request by setting the `enable_cae` keyword argument to `True` in `get_token`. This applies to user credentials and service principal credentials. ([#30777](https://github.com/Azure/azure-sdk-for-python/pull/30777))

### Breaking Changes

- CP1 client capabilities for CAE is no longer always-on by default for user credentials. This capability will now be configured as-needed in each `get_token` request by each SDK. ([#30777](https://github.com/Azure/azure-sdk-for-python/pull/30777))
- Suffixes are now appended to persistent cache names to indicate whether CAE or non-CAE tokens are stored in the cache. This is to prevent CAE and non-CAE tokens from being mixed/overwritten in the same cache. This could potentially cause issues if you are trying to share the same cache between applications that are using different versions of the Azure Identity library as each application would be reading from a different cache file.
- Since CAE is no longer always enabled for user-credentials, the `AZURE_IDENTITY_DISABLE_CP1` environment variable is no longer supported.

### Bugs Fixed

### Other Changes
Expand Down
5 changes: 3 additions & 2 deletions sdk/identity/azure-identity/azure/identity/_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
DEFAULT_REFRESH_OFFSET = 300
DEFAULT_TOKEN_REFRESH_RETRY_DELAY = 30

CACHE_NON_CAE_SUFFIX = ".nocae" # cspell:disable-line
CACHE_CAE_SUFFIX = ".cae"


class AzureAuthorityHosts:
AZURE_CHINA = "login.chinacloudapi.cn"
Expand Down Expand Up @@ -50,5 +53,3 @@ class EnvironmentVariables:

AZURE_FEDERATED_TOKEN_FILE = "AZURE_FEDERATED_TOKEN_FILE"
WORKLOAD_IDENTITY_VARS = (AZURE_AUTHORITY_HOST, AZURE_TENANT_ID, AZURE_FEDERATED_TOKEN_FILE)

AZURE_IDENTITY_DISABLE_CP1 = "AZURE_IDENTITY_DISABLE_CP1"
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@ def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
https://learn.microsoft.com/azure/active-directory/develop/scopes-oidc.
:keyword str claims: additional claims required in the token, such as those returned in a resource provider's
claims challenge following an authorization failure
:keyword bool enable_cae: indicates whether to enable Continuous Access Evaluation (CAE) for the requested
token. Defaults to False.
:return: An access token with the desired scopes.
:rtype: ~azure.core.credentials.AccessToken
:raises ~azure.identity.CredentialUnavailableError: the cache is unavailable or contains insufficient user
Expand Down Expand Up @@ -100,20 +101,28 @@ def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
if not scopes:
raise ValueError("'get_token' requires at least one scope")

if not self._initialized:
self._initialize()
if not self._client_initialized:
self._initialize_client()

is_cae = bool(kwargs.get("enable_cae", False))
token_cache = self._cae_cache if is_cae else self._cache

# Try to load the cache if it is None.
if not token_cache:
token_cache = self._initialize_cache(is_cae=is_cae)

if not self._cache:
raise CredentialUnavailableError(message="Shared token cache unavailable")
# If the cache is still None, raise an error.
if not token_cache:
raise CredentialUnavailableError(message="Shared token cache unavailable")

account = self._get_account(self._username, self._tenant_id)
account = self._get_account(self._username, self._tenant_id, is_cae=is_cae)

token = self._get_cached_access_token(scopes, account)
token = self._get_cached_access_token(scopes, account, is_cae=is_cae)
if token:
return token

# try each refresh token, returning the first access token acquired
for refresh_token in self._get_refresh_tokens(account):
for refresh_token in self._get_refresh_tokens(account, is_cae=is_cae):
token = self._client.obtain_token_by_refresh_token(scopes, refresh_token, **kwargs)
return token

Expand Down
78 changes: 51 additions & 27 deletions sdk/identity/azure-identity/azure/identity/_credentials/silent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
import os
import platform
import time
from typing import Dict, Optional, Any

from msal import PublicClientApplication
from msal import PublicClientApplication, TokenCache

from azure.core.credentials import AccessToken
from azure.core.exceptions import ClientAuthenticationError
Expand All @@ -18,7 +17,6 @@
from .._internal.msal_client import MsalClient
from .._internal.shared_token_cache import NO_TOKEN
from .._persistent_cache import _load_persistent_cache, TokenCachePersistenceOptions
from .._constants import EnvironmentVariables
from .. import AuthenticationRecord


Expand All @@ -39,11 +37,15 @@ def __init__(
self._tenant_id = tenant_id or self._auth_record.tenant_id
validate_tenant_id(self._tenant_id)
self._cache = kwargs.pop("_cache", None)
self._cae_cache = kwargs.pop("_cae_cache", None)

self._cache_persistence_options = kwargs.pop("cache_persistence_options", None)

self._client_applications: Dict[str, PublicClientApplication] = {}
self._cae_client_applications: Dict[str, PublicClientApplication] = {}

self._additionally_allowed_tenants = kwargs.pop("additionally_allowed_tenants", [])
self._client = MsalClient(**kwargs)
self._initialized = False

def __enter__(self):
self._client.__enter__()
Expand All @@ -56,47 +58,69 @@ def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
if not scopes:
raise ValueError('"get_token" requires at least one scope')

if not self._initialized:
self._initialize()
token_cache = self._cae_cache if kwargs.get("enable_cae") else self._cache

# Try to load the cache if it is None.
if not token_cache:
token_cache = self._initialize_cache(is_cae=bool(kwargs.get("enable_cae")))

if not self._cache:
if within_dac.get():
raise CredentialUnavailableError(message="Shared token cache unavailable")
raise ClientAuthenticationError(message="Shared token cache unavailable")
# If the cache is still None, raise an error.
if not token_cache:
if within_dac.get():
raise CredentialUnavailableError(message="Shared token cache unavailable")
raise ClientAuthenticationError(message="Shared token cache unavailable")

return self._acquire_token_silent(*scopes, **kwargs)

def _initialize(self):
if not self._cache and platform.system() in {"Darwin", "Linux", "Windows"}:
def _initialize_cache(self, is_cae: bool = False) -> Optional[TokenCache]:

# If no cache options were provided, the default cache will be used. This credential accepts the
# user's default cache regardless of whether it's encrypted. It doesn't create a new cache. If the
# default cache exists, the user must have created it earlier. If it's unencrypted, the user must
# have allowed that.
cache_options = self._cache_persistence_options or TokenCachePersistenceOptions(allow_unencrypted_storage=True)

if platform.system() not in {"Darwin", "Linux", "Windows"}:
raise CredentialUnavailableError(message="Shared token cache is not supported on this platform.")

if not self._cache and not is_cae:
try:
# If no cache options were provided, the default cache will be used. This credential accepts the
# user's default cache regardless of whether it's encrypted. It doesn't create a new cache. If the
# default cache exists, the user must have created it earlier. If it's unencrypted, the user must
# have allowed that.
options = self._cache_persistence_options or TokenCachePersistenceOptions(
allow_unencrypted_storage=True
)
self._cache = _load_persistent_cache(options)
self._cache = _load_persistent_cache(cache_options, is_cae)
except Exception: # pylint:disable=broad-except
pass
return None

self._initialized = True
if not self._cae_cache and is_cae:
try:
self._cae_cache = _load_persistent_cache(cache_options, is_cae)
except Exception: # pylint:disable=broad-except
return None

return self._cae_cache if is_cae else self._cache

def _get_client_application(self, **kwargs: Any):
tenant_id = resolve_tenant(
self._tenant_id, additionally_allowed_tenants=self._additionally_allowed_tenants, **kwargs
)
if tenant_id not in self._client_applications:

client_applications_map = self._client_applications
capabilities = None
token_cache = self._cache

if kwargs.get("enable_cae"):
client_applications_map = self._cae_client_applications
# CP1 = can handle claims challenges (CAE)
capabilities = None if EnvironmentVariables.AZURE_IDENTITY_DISABLE_CP1 in os.environ else ["CP1"]
self._client_applications[tenant_id] = PublicClientApplication(
capabilities = ["CP1"]
token_cache = self._cae_cache

if tenant_id not in client_applications_map:
client_applications_map[tenant_id] = PublicClientApplication(
client_id=self._auth_record.client_id,
authority="https://{}/{}".format(self._auth_record.authority, tenant_id),
token_cache=self._cache,
token_cache=token_cache,
http_client=self._client,
client_capabilities=capabilities,
)
return self._client_applications[tenant_id]
return client_applications_map[tenant_id]

@wrap_exceptions
def _acquire_token_silent(self, *scopes: str, **kwargs: Any) -> AccessToken:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,4 +70,4 @@ def _run_pipeline(self, request: HttpRequest, **kwargs: Any) -> AccessToken:
kwargs.pop("claims", None)
now = int(time.time())
response = self._pipeline.run(request, retry_on_methods=self._POST, **kwargs)
return self._process_response(response, now)
return self._process_response(response, now, **kwargs)
Loading

0 comments on commit f6d7789

Please sign in to comment.