Skip to content

Commit

Permalink
[Identity] Correctly implement TokenCredential protocols (#31047)
Browse files Browse the repository at this point in the history
  • Loading branch information
mccoyp authored Aug 4, 2023
1 parent 85856b2 commit 0464b2a
Show file tree
Hide file tree
Showing 53 changed files with 526 additions and 187 deletions.
3 changes: 3 additions & 0 deletions sdk/identity/azure-identity/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@

### Bugs Fixed

- Credential types correctly implement `azure-core`'s `TokenCredential` protocol.
([#25175](https://github.com/Azure/azure-sdk-for-python/issues/25175))

### Other Changes

## 1.14.0b2 (2023-07-11)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# ------------------------------------
import logging
import os
from typing import Any
from typing import Any, Optional

from azure.core.credentials import AccessToken
from .chained import ChainedTokenCredential
Expand Down Expand Up @@ -63,24 +63,30 @@ def __init__(self, **kwargs: Any) -> None:
ManagedIdentityCredential(client_id=managed_identity_client_id, **kwargs),
)

def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
def get_token(
self, *scopes: str, claims: Optional[str] = None, tenant_id: Optional[str] = None, **kwargs: Any
) -> AccessToken:
"""Request an access token for `scopes`.
This method is called automatically by Azure SDK clients.
:param str scopes: desired scopes for the access token. This method requires at least one scope.
For more information about scopes, see
https://learn.microsoft.com/azure/active-directory/develop/scopes-oidc.
:keyword str claims: additional claims required in the token, such as those returned in a resource provider's
claims challenge following an authorization failure.
:keyword str tenant_id: optional tenant to include in the token request.
:return: An access token with the desired scopes.
:rtype: ~azure.core.credentials.AccessToken
:raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The exception has a
`message` attribute listing each authentication attempt and its error message.
"""
if self._successful_credential:
token = self._successful_credential.get_token(*scopes, **kwargs)
token = self._successful_credential.get_token(*scopes, claims=claims, tenant_id=tenant_id, **kwargs)
_LOGGER.info(
"%s acquired a token from %s", self.__class__.__name__, self._successful_credential.__class__.__name__
)
return token

return super(AzureApplicationCredential, self).get_token(*scopes, **kwargs)
return super(AzureApplicationCredential, self).get_token(*scopes, claims=claims, tenant_id=tenant_id, **kwargs)
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,9 @@ def close(self) -> None:
"""Close the credential's transport session."""
self.__exit__()

def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
def get_token(
self, *scopes: str, claims: Optional[str] = None, tenant_id: Optional[str] = None, **kwargs: Any
) -> AccessToken:
"""Request an access token for `scopes`.
This method is called automatically by Azure SDK clients.
Expand All @@ -73,6 +75,8 @@ def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
:param str scopes: desired scopes for the access token. This method requires at least one scope.
For more information about scopes, see
https://learn.microsoft.com/azure/active-directory/develop/scopes-oidc.
:keyword str claims: additional claims required in the token, such as those returned in a resource provider's
claims challenge following an authorization failure.
:keyword str tenant_id: optional tenant to include in the token request.
:return: An access token with the desired scopes.
Expand All @@ -82,7 +86,7 @@ def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
``response`` attribute.
"""
# pylint:disable=useless-super-delegation
return super(AuthorizationCodeCredential, self).get_token(*scopes, **kwargs)
return super(AuthorizationCodeCredential, self).get_token(*scopes, claims=claims, tenant_id=tenant_id, **kwargs)

def _acquire_token_silently(self, *scopes: str, **kwargs) -> Optional[AccessToken]:
return self._client.get_cached_access_token(scopes, **kwargs)
Expand Down
14 changes: 12 additions & 2 deletions sdk/identity/azure-identity/azure/identity/_credentials/azd_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,13 @@ def close(self) -> None:
"""Calling this method is unnecessary."""

@log_get_token("AzureDeveloperCliCredential")
def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
def get_token(
self,
*scopes: str,
claims: Optional[str] = None, # pylint:disable=unused-argument
tenant_id: Optional[str] = None,
**kwargs: Any,
) -> AccessToken:
"""Request an access token for `scopes`.
This method is called automatically by Azure SDK clients. Applications calling this method directly must
Expand All @@ -100,6 +106,7 @@ def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
:param str scopes: desired scope for the access token. This credential allows only one scope per request.
For more information about scopes, see
https://learn.microsoft.com/azure/active-directory/develop/scopes-oidc.
:keyword str claims: not used by this credential; any value provided will be ignored.
:keyword str tenant_id: optional tenant to include in the token request.
:return: An access token with the desired scopes.
Expand All @@ -117,7 +124,10 @@ def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
commandString = " --scope ".join(scopes)
command = COMMAND_LINE.format(commandString)
tenant = resolve_tenant(
default_tenant=self.tenant_id, additionally_allowed_tenants=self._additionally_allowed_tenants, **kwargs
default_tenant=self.tenant_id,
tenant_id=tenant_id,
additionally_allowed_tenants=self._additionally_allowed_tenants,
**kwargs,
)
if tenant:
command += " --tenant-id " + tenant
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,13 @@ def close(self) -> None:
"""Calling this method is unnecessary."""

@log_get_token("AzureCliCredential")
def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
def get_token(
self,
*scopes: str,
claims: Optional[str] = None, # pylint:disable=unused-argument
tenant_id: Optional[str] = None,
**kwargs: Any,
) -> AccessToken:
"""Request an access token for `scopes`.
This method is called automatically by Azure SDK clients. Applications calling this method directly must
Expand All @@ -78,6 +84,7 @@ def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
:param str scopes: desired scope for the access token. This credential allows only one scope per request.
For more information about scopes, see
https://learn.microsoft.com/azure/active-directory/develop/scopes-oidc.
:keyword str claims: not used by this credential; any value provided will be ignored.
:keyword str tenant_id: optional tenant to include in the token request.
:return: An access token with the desired scopes.
Expand All @@ -91,7 +98,10 @@ def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
resource = _scopes_to_resource(*scopes)
command = COMMAND_LINE.format(resource)
tenant = resolve_tenant(
default_tenant=self.tenant_id, additionally_allowed_tenants=self._additionally_allowed_tenants, **kwargs
default_tenant=self.tenant_id,
tenant_id=tenant_id,
additionally_allowed_tenants=self._additionally_allowed_tenants,
**kwargs,
)
if tenant:
command += " --tenant " + tenant
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import logging
import subprocess
import sys
from typing import List, Tuple, Optional, Any
from typing import Any, List, Tuple, Optional

from azure.core.credentials import AccessToken
from azure.core.exceptions import ClientAuthenticationError
Expand Down Expand Up @@ -83,7 +83,13 @@ def close(self) -> None:
"""Calling this method is unnecessary."""

@log_get_token("AzurePowerShellCredential")
def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
def get_token(
self,
*scopes: str,
claims: Optional[str] = None, # pylint:disable=unused-argument
tenant_id: Optional[str] = None,
**kwargs: Any,
) -> AccessToken:
"""Request an access token for `scopes`.
This method is called automatically by Azure SDK clients. Applications calling this method directly must
Expand All @@ -92,6 +98,7 @@ def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
:param str scopes: desired scope for the access token. This credential allows only one scope per request.
For more information about scopes, see
https://learn.microsoft.com/azure/active-directory/develop/scopes-oidc.
:keyword str claims: not used by this credential; any value provided will be ignored.
:keyword str tenant_id: optional tenant to include in the token request.
:return: An access token with the desired scopes.
Expand All @@ -103,7 +110,10 @@ def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
receive an access token
"""
tenant_id = resolve_tenant(
default_tenant=self.tenant_id, additionally_allowed_tenants=self._additionally_allowed_tenants, **kwargs
default_tenant=self.tenant_id,
tenant_id=tenant_id,
additionally_allowed_tenants=self._additionally_allowed_tenants,
**kwargs,
)
command_line = get_command_line(scopes, tenant_id)
output = run_command_line(command_line, self._process_timeout)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,14 +69,19 @@ def close(self) -> None:
"""Close the transport session of each credential in the chain."""
self.__exit__()

def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken: # pylint:disable=unused-argument
def get_token(
self, *scopes: str, claims: Optional[str] = None, tenant_id: Optional[str] = None, **kwargs: Any
) -> AccessToken:
"""Request a token from each chained credential, in order, returning the first token received.
This method is called automatically by Azure SDK clients.
:param str scopes: desired scopes for the access token. This method requires at least one scope.
For more information about scopes, see
https://learn.microsoft.com/azure/active-directory/develop/scopes-oidc.
:keyword str claims: additional claims required in the token, such as those returned in a resource provider's
claims challenge following an authorization failure.
:keyword str tenant_id: optional tenant to include in the token request.
:return: An access token with the desired scopes.
:rtype: ~azure.core.credentials.AccessToken
Expand All @@ -86,7 +91,7 @@ def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken: # pylint:disab
history = []
for credential in self.credentials:
try:
token = credential.get_token(*scopes, **kwargs)
token = credential.get_token(*scopes, claims=claims, tenant_id=tenant_id, **kwargs)
_LOGGER.info("%s acquired a token from %s", self.__class__.__name__, credential.__class__.__name__)
self._successful_credential = credential
return token
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# ------------------------------------
import logging
import os
from typing import List, TYPE_CHECKING, Any, cast
from typing import List, TYPE_CHECKING, Any, Optional, cast

from azure.core.credentials import AccessToken
from .._constants import EnvironmentVariables
Expand Down Expand Up @@ -195,14 +195,18 @@ def __init__(self, **kwargs: Any) -> None: # pylint: disable=too-many-statement

super(DefaultAzureCredential, self).__init__(*credentials)

def get_token(self, *scopes: str, **kwargs) -> AccessToken:
def get_token(
self, *scopes: str, claims: Optional[str] = None, tenant_id: Optional[str] = None, **kwargs: Any
) -> AccessToken:
"""Request an access token for `scopes`.
This method is called automatically by Azure SDK clients.
:param str scopes: desired scopes for the access token. This method requires at least one scope.
For more information about scopes, see
https://learn.microsoft.com/azure/active-directory/develop/scopes-oidc.
:keyword str claims: additional claims required in the token, such as those returned in a resource provider's
claims challenge following an authorization failure.
:keyword str tenant_id: optional tenant to include in the token request.
:return: An access token with the desired scopes.
Expand All @@ -212,12 +216,12 @@ def get_token(self, *scopes: str, **kwargs) -> AccessToken:
`message` attribute listing each authentication attempt and its error message.
"""
if self._successful_credential:
token = self._successful_credential.get_token(*scopes, **kwargs)
token = self._successful_credential.get_token(*scopes, claims=claims, tenant_id=tenant_id, **kwargs)
_LOGGER.info(
"%s acquired a token from %s", self.__class__.__name__, self._successful_credential.__class__.__name__
)
return token
within_dac.set(True)
token = super(DefaultAzureCredential, self).get_token(*scopes, **kwargs)
token = super().get_token(*scopes, claims=claims, tenant_id=tenant_id, **kwargs)
within_dac.set(False)
return token
Original file line number Diff line number Diff line change
Expand Up @@ -120,14 +120,18 @@ def close(self) -> None:
self.__exit__()

@log_get_token("EnvironmentCredential")
def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
def get_token(
self, *scopes: str, claims: Optional[str] = None, tenant_id: Optional[str] = None, **kwargs: Any
) -> AccessToken:
"""Request an access token for `scopes`.
This method is called automatically by Azure SDK clients.
:param str scopes: desired scopes for the access token. This method requires at least one scope.
For more information about scopes, see
https://learn.microsoft.com/azure/active-directory/develop/scopes-oidc.
:keyword str claims: additional claims required in the token, such as those returned in a resource provider's
claims challenge following an authorization failure.
:keyword str tenant_id: optional tenant to include in the token request.
:return: An access token with the desired scopes.
Expand All @@ -142,4 +146,4 @@ def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
"this issue."
)
raise CredentialUnavailableError(message=message)
return self._credential.get_token(*scopes, **kwargs)
return self._credential.get_token(*scopes, claims=claims, tenant_id=tenant_id, **kwargs)
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,9 @@ def close(self) -> None:
self.__exit__()

@log_get_token("ManagedIdentityCredential")
def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
def get_token(
self, *scopes: str, claims: Optional[str] = None, tenant_id: Optional[str] = None, **kwargs: Any
) -> AccessToken:
"""Request an access token for `scopes`.
This method is called automatically by Azure SDK clients.
Expand All @@ -117,6 +119,9 @@ def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
For more information about scopes, see
https://learn.microsoft.com/azure/active-directory/develop/scopes-oidc.
:keyword str claims: not used by this credential; any value provided will be ignored.
:keyword str tenant_id: not used by this credential; any value provided will be ignored.
:return: An access token with the desired scopes.
:rtype: ~azure.core.credentials.AccessToken
:raises ~azure.identity.CredentialUnavailableError: managed identity isn't available in the hosting environment
Expand All @@ -129,4 +134,4 @@ def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
"Visit https://aka.ms/azsdk/python/identity/managedidentitycredential/troubleshoot to "
"troubleshoot this issue."
)
return self._credential.get_token(*scopes, **kwargs)
return self._credential.get_token(*scopes, claims=claims, tenant_id=tenant_id, **kwargs)
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,9 @@ def close(self) -> None:
self.__exit__()

@log_get_token("SharedTokenCacheCredential")
def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
def get_token(
self, *scopes: str, claims: Optional[str] = None, tenant_id: Optional[str] = None, **kwargs: Any
) -> AccessToken:
"""Get an access token for `scopes` from the shared cache.
If no access token is cached, attempt to acquire one using a cached refresh token.
Expand All @@ -64,16 +66,18 @@ def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
https://learn.microsoft.com/azure/active-directory/develop/scopes-oidc.
:keyword str claims: additional claims required in the token, such as those returned in a resource provider's
claims challenge following an authorization failure
:keyword str tenant_id: not used by this credential; any value provided will be ignored.
:keyword bool enable_cae: indicates whether to enable Continuous Access Evaluation (CAE) for the requested
token. Defaults to False.
:return: An access token with the desired scopes.
:rtype: ~azure.core.credentials.AccessToken
:raises ~azure.identity.CredentialUnavailableError: the cache is unavailable or contains insufficient user
information
:raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The error's ``message``
attribute gives a reason.
"""
return self._credential.get_token(*scopes, **kwargs)
return self._credential.get_token(*scopes, claims=claims, tenant_id=tenant_id, **kwargs)

@staticmethod
def supported() -> bool:
Expand All @@ -97,7 +101,9 @@ def __exit__(self, *args):
if self._client:
self._client.__exit__(*args)

def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
def get_token(
self, *scopes: str, claims: Optional[str] = None, tenant_id: Optional[str] = None, **kwargs: Any
) -> AccessToken:
if not scopes:
raise ValueError("'get_token' requires at least one scope")

Expand All @@ -123,7 +129,9 @@ def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:

# try each refresh token, returning the first access token acquired
for refresh_token in self._get_refresh_tokens(account, is_cae=is_cae):
token = self._client.obtain_token_by_refresh_token(scopes, refresh_token, **kwargs)
token = self._client.obtain_token_by_refresh_token(
scopes, refresh_token, claims=claims, tenant_id=tenant_id, **kwargs
)
return token

raise CredentialUnavailableError(message=NO_TOKEN.format(account.get("username")))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,9 @@ def __enter__(self):
def __exit__(self, *args):
self._client.__exit__(*args)

def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
def get_token(
self, *scopes: str, claims: Optional[str] = None, tenant_id: Optional[str] = None, **kwargs: Any
) -> AccessToken:
if not scopes:
raise ValueError('"get_token" requires at least one scope')

Expand All @@ -70,7 +72,7 @@ def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
raise CredentialUnavailableError(message="Shared token cache unavailable")
raise ClientAuthenticationError(message="Shared token cache unavailable")

return self._acquire_token_silent(*scopes, **kwargs)
return self._acquire_token_silent(*scopes, claims=claims, tenant_id=tenant_id, **kwargs)

def _initialize_cache(self, is_cae: bool = False) -> Optional[TokenCache]:

Expand Down
Loading

0 comments on commit 0464b2a

Please sign in to comment.