From 9d90e568dc4de83f5a9e37cdd9b4eca38b82c3b5 Mon Sep 17 00:00:00 2001 From: Paul Van Eck Date: Thu, 5 Sep 2024 23:14:37 +0000 Subject: [PATCH 1/8] [Identity] Implement new protocol methods Signed-off-by: Paul Van Eck --- sdk/identity/azure-identity/CHANGELOG.md | 3 + sdk/identity/azure-identity/assets.json | 2 +- .../identity/_credentials/application.py | 29 +- .../_credentials/authorization_code.py | 31 +- .../azure/identity/_credentials/azd_cli.py | 39 ++- .../azure/identity/_credentials/azure_cli.py | 43 ++- .../identity/_credentials/azure_pipelines.py | 21 +- .../identity/_credentials/azure_powershell.py | 42 ++- .../azure/identity/_credentials/chained.py | 68 ++++- .../identity/_credentials/client_assertion.py | 6 +- .../azure/identity/_credentials/default.py | 33 ++- .../identity/_credentials/environment.py | 30 +- .../azure/identity/_credentials/imds.py | 8 +- .../identity/_credentials/managed_identity.py | 29 +- .../identity/_credentials/on_behalf_of.py | 10 +- .../identity/_credentials/shared_cache.py | 65 +++- .../azure/identity/_credentials/silent.py | 47 ++- .../azure/identity/_credentials/vscode.py | 37 ++- .../azure/identity/_internal/aad_client.py | 18 +- .../identity/_internal/aad_client_base.py | 21 +- .../_internal/client_credential_base.py | 20 +- .../identity/_internal/get_token_mixin.py | 66 ++++- .../azure/identity/_internal/interactive.py | 79 ++++- .../_internal/managed_identity_base.py | 11 +- .../_internal/managed_identity_client.py | 28 +- .../_internal/msal_managed_identity_client.py | 66 ++++- .../identity/_internal/shared_token_cache.py | 7 +- .../identity/aio/_credentials/application.py | 32 +- .../aio/_credentials/authorization_code.py | 33 ++- .../identity/aio/_credentials/azd_cli.py | 39 ++- .../identity/aio/_credentials/azure_cli.py | 38 ++- .../aio/_credentials/azure_pipelines.py | 21 +- .../aio/_credentials/azure_powershell.py | 38 ++- .../identity/aio/_credentials/certificate.py | 6 +- .../identity/aio/_credentials/chained.py | 68 ++++- .../aio/_credentials/client_assertion.py | 6 +- .../aio/_credentials/client_secret.py | 6 +- .../identity/aio/_credentials/default.py | 47 ++- .../identity/aio/_credentials/environment.py | 31 +- .../azure/identity/aio/_credentials/imds.py | 10 +- .../aio/_credentials/managed_identity.py | 30 +- .../identity/aio/_credentials/on_behalf_of.py | 6 +- .../identity/aio/_credentials/shared_cache.py | 59 +++- .../azure/identity/aio/_credentials/vscode.py | 37 ++- .../identity/aio/_internal/aad_client.py | 20 +- .../identity/aio/_internal/get_token_mixin.py | 66 ++++- .../aio/_internal/managed_identity_base.py | 11 +- .../aio/_internal/managed_identity_client.py | 4 +- sdk/identity/azure-identity/setup.py | 2 +- sdk/identity/azure-identity/tests/helpers.py | 8 +- .../azure-identity/tests/helpers_async.py | 9 - .../azure-identity/tests/test_aad_client.py | 6 +- .../tests/test_app_service_async.py | 9 +- .../tests/test_application_credential.py | 46 +-- .../test_application_credential_async.py | 37 ++- .../azure-identity/tests/test_auth_code.py | 81 +++-- .../tests/test_auth_code_async.py | 71 +++-- .../azure-identity/tests/test_authority.py | 5 +- .../tests/test_azd_cli_credential.py | 103 ++++--- .../tests/test_azd_cli_credential_async.py | 103 ++++--- .../tests/test_azure_application.py | 5 +- .../azure-identity/tests/test_azure_arc.py | 7 +- .../tests/test_azure_pipelines_credential.py | 19 +- .../test_azure_pipelines_credential_async.py | 19 +- .../tests/test_bearer_token_provider.py | 12 +- .../tests/test_bearer_token_provider_async.py | 14 +- .../tests/test_browser_credential.py | 57 ++-- .../tests/test_certificate_credential.py | 119 +++++--- .../test_certificate_credential_async.py | 100 ++++--- .../tests/test_chained_credential.py | 155 +++++++--- .../test_chained_token_credential_async.py | 153 +++++++--- .../tests/test_cli_credential.py | 119 +++++--- .../tests/test_cli_credential_async.py | 119 +++++--- .../tests/test_client_assertion_credential.py | 13 +- .../test_client_assertion_credential_async.py | 12 +- .../tests/test_client_secret_credential.py | 129 +++++--- .../test_client_secret_credential_async.py | 124 +++++--- .../tests/test_context_manager.py | 5 +- .../azure-identity/tests/test_default.py | 52 ++-- .../tests/test_default_async.py | 50 ++-- .../tests/test_device_code_credential.py | 58 ++-- .../tests/test_environment_credential.py | 9 +- .../test_environment_credential_async.py | 14 +- .../tests/test_get_token_mixin.py | 61 ++-- .../tests/test_get_token_mixin_async.py | 61 ++-- .../tests/test_imds_credential.py | 75 ++--- .../tests/test_imds_credential_async.py | 117 ++++---- .../tests/test_initialization.py | 68 +++++ .../tests/test_initialization_async.py | 58 ++++ .../tests/test_interactive_credential.py | 95 ++++-- .../azure-identity/tests/test_live.py | 56 ++-- .../azure-identity/tests/test_live_async.py | 46 +-- .../tests/test_managed_identity.py | 187 +++++++----- .../tests/test_managed_identity_async.py | 231 +++++++++------ .../tests/test_multi_tenant_auth.py | 11 +- .../tests/test_multi_tenant_auth_async.py | 11 +- sdk/identity/azure-identity/tests/test_obo.py | 49 +-- .../azure-identity/tests/test_obo_async.py | 50 ++-- .../tests/test_powershell_credential.py | 144 +++++---- .../tests/test_powershell_credential_async.py | 138 +++++---- .../tests/test_shared_cache_credential.py | 278 +++++++++++------- .../test_shared_cache_credential_async.py | 203 ++++++++----- .../test_username_password_credential.py | 54 ++-- .../tests/test_vscode_credential.py | 7 +- .../test_workload_identity_credential.py | 8 +- ...test_workload_identity_credential_async.py | 7 +- 106 files changed, 3737 insertions(+), 1589 deletions(-) create mode 100644 sdk/identity/azure-identity/tests/test_initialization.py create mode 100644 sdk/identity/azure-identity/tests/test_initialization_async.py diff --git a/sdk/identity/azure-identity/CHANGELOG.md b/sdk/identity/azure-identity/CHANGELOG.md index 2f86649a7fce..d9c6788bf622 100644 --- a/sdk/identity/azure-identity/CHANGELOG.md +++ b/sdk/identity/azure-identity/CHANGELOG.md @@ -4,6 +4,9 @@ ### Features Added +- All credentials now support the `SupportsTokenInfo` protocol. Each credential now has a `get_token_info` method which returns an `AccessTokenInfo` object. The `get_token_info` method is an alternative method to `get_token` that improves support support for more complex authentication scenarios. ([#36882](https://github.com/Azure/azure-sdk-for-python/pull/36882)) + - Information on when a token should be refreshed is now saved in `AccessTokenInfo` (if available). + ### Breaking Changes ### Bugs Fixed diff --git a/sdk/identity/azure-identity/assets.json b/sdk/identity/azure-identity/assets.json index be62aaed10a1..477825b9d646 100644 --- a/sdk/identity/azure-identity/assets.json +++ b/sdk/identity/azure-identity/assets.json @@ -2,5 +2,5 @@ "AssetsRepo": "Azure/azure-sdk-assets", "AssetsRepoPrefixPath": "python", "TagPrefix": "python/identity/azure-identity", - "Tag": "python/identity/azure-identity_cb8dd6f319" + "Tag": "python/identity/azure-identity_61e626a4a0" } diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/application.py b/sdk/identity/azure-identity/azure/identity/_credentials/application.py index 852b44e7f712..1d900f0d1fcb 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/application.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/application.py @@ -4,9 +4,9 @@ # ------------------------------------ import logging import os -from typing import Any, Optional +from typing import Any, Optional, cast -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions, SupportsTokenInfo from .chained import ChainedTokenCredential from .environment import EnvironmentCredential from .managed_identity import ManagedIdentityCredential @@ -90,3 +90,28 @@ def get_token( return token return super(AzureApplicationCredential, self).get_token(*scopes, claims=claims, tenant_id=tenant_id, **kwargs) + + def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptions] = None) -> AccessTokenInfo: + """Request an access token for `scopes`. + + This is an alternative to `get_token` to enable certain scenarios that require additional properties + on the token. This method is called automatically by Azure SDK clients. + + :param str scopes: desired scopes for the access token. This method requires at least one scope. + For more information about scopes, see https://learn.microsoft.com/entra/identity-platform/scopes-oidc. + :keyword options: A dictionary of options for the token request. Unknown options will be ignored. Optional. + :paramtype options: ~azure.core.credentials.TokenRequestOptions + + :rtype: AccessTokenInfo + :return: An AccessTokenInfo instance containing information about the token. + :raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The exception has a + `message` attribute listing each authentication attempt and its error message. + """ + if self._successful_credential: + token_info = cast(SupportsTokenInfo, self._successful_credential).get_token_info(*scopes, options=options) + _LOGGER.info( + "%s acquired a token from %s", self.__class__.__name__, self._successful_credential.__class__.__name__ + ) + return token_info + + return cast(SupportsTokenInfo, super()).get_token_info(*scopes, options=options) diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/authorization_code.py b/sdk/identity/azure-identity/azure/identity/_credentials/authorization_code.py index b5997e60e80e..c98aa26c9a09 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/authorization_code.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/authorization_code.py @@ -4,7 +4,7 @@ # ------------------------------------ from typing import Optional, Any -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions from azure.core.exceptions import ClientAuthenticationError from .._internal.aad_client import AadClient from .._internal.get_token_mixin import GetTokenMixin @@ -90,10 +90,35 @@ def get_token( *scopes, claims=claims, tenant_id=tenant_id, client_secret=self._client_secret, **kwargs ) - def _acquire_token_silently(self, *scopes: str, **kwargs) -> Optional[AccessToken]: + def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptions] = None) -> AccessTokenInfo: + """Request an access token for `scopes`. + + This is an alternative to `get_token` to enable certain scenarios that require additional properties + on the token. This method is called automatically by Azure SDK clients. + + The first time this method is called, the credential will redeem its authorization code. On subsequent calls + the credential will return a cached access token or redeem a refresh token, if it acquired a refresh token upon + redeeming the authorization code. + + :param str scopes: desired scopes for the access token. This method requires at least one scope. + For more information about scopes, see https://learn.microsoft.com/entra/identity-platform/scopes-oidc. + :keyword options: A dictionary of options for the token request. Unknown options will be ignored. Optional. + :paramtype options: ~azure.core.credentials.TokenRequestOptions + + :rtype: AccessTokenInfo + :return: An AccessTokenInfo instance containing information about the token. + :raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The error's ``message`` + attribute gives a reason. Any error response from Microsoft Entra ID is available as the error's + ``response`` attribute. + """ + return super()._get_token_base( + *scopes, options=options, client_secret=self._client_secret, base_method_name="get_token_info" + ) + + def _acquire_token_silently(self, *scopes: str, **kwargs) -> Optional[AccessTokenInfo]: return self._client.get_cached_access_token(scopes, **kwargs) - def _request_token(self, *scopes: str, **kwargs) -> AccessToken: + def _request_token(self, *scopes: str, **kwargs) -> AccessTokenInfo: if self._authorization_code: token = self._client.obtain_token_by_authorization_code( scopes=scopes, code=self._authorization_code, redirect_uri=self._redirect_uri, **kwargs diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/azd_cli.py b/sdk/identity/azure-identity/azure/identity/_credentials/azd_cli.py index 2af899ae2e63..319569482abd 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/azd_cli.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/azd_cli.py @@ -12,7 +12,7 @@ import sys from typing import Any, Dict, List, Optional -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions from azure.core.exceptions import ClientAuthenticationError from .. import CredentialUnavailableError @@ -118,10 +118,43 @@ def get_token( :raises ~azure.core.exceptions.ClientAuthenticationError: the credential invoked the Azure Developer CLI but didn't receive an access token. """ + options: TokenRequestOptions = {} + if tenant_id: + options["tenant_id"] = tenant_id + + token_info = self._get_token_base(*scopes, options=options, **kwargs) + return AccessToken(token_info.token, token_info.expires_on) + + @log_get_token + def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptions] = None) -> AccessTokenInfo: + """Request an access token for `scopes`. + + This is an alternative to `get_token` to enable certain scenarios that require additional properties + on the token. This method is called automatically by Azure SDK clients. Applications calling this method + directly must also handle token caching because this credential doesn't cache the tokens it acquires. + + :param str scopes: desired scopes for the access token. This method requires at least one scope. + For more information about scopes, see https://learn.microsoft.com/entra/identity-platform/scopes-oidc. + :keyword options: A dictionary of options for the token request. Unknown options will be ignored. Optional. + :paramtype options: ~azure.core.credentials.TokenRequestOptions + + :rtype: AccessTokenInfo + :return: An AccessTokenInfo instance containing information about the token. + + :raises ~azure.identity.CredentialUnavailableError: the credential was unable to invoke + the Azure Developer CLI. + :raises ~azure.core.exceptions.ClientAuthenticationError: the credential invoked + the Azure Developer CLI but didn't receive an access token. + """ + return self._get_token_base(*scopes, options=options) + def _get_token_base( + self, *scopes: str, options: Optional[TokenRequestOptions] = None, **kwargs: Any + ) -> AccessTokenInfo: if not scopes: raise ValueError("Missing scope in request. \n") + tenant_id = options.get("tenant_id") if options else None if tenant_id: validate_tenant_id(tenant_id) for scope in scopes: @@ -154,7 +187,7 @@ def get_token( return token -def parse_token(output: str) -> Optional[AccessToken]: +def parse_token(output: str) -> Optional[AccessTokenInfo]: """Parse to an AccessToken. In particular, convert the "expiresOn" value to epoch seconds. This value is a naive local datetime as returned by @@ -169,7 +202,7 @@ def parse_token(output: str) -> Optional[AccessToken]: dt = datetime.strptime(token["expiresOn"], "%Y-%m-%dT%H:%M:%SZ") expires_on = dt.timestamp() - return AccessToken(token["token"], int(expires_on)) + return AccessTokenInfo(token["token"], int(expires_on)) except (KeyError, ValueError): return None diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/azure_cli.py b/sdk/identity/azure-identity/azure/identity/_credentials/azure_cli.py index cb48b98356d4..a6feae9d2d98 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/azure_cli.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/azure_cli.py @@ -11,7 +11,7 @@ import sys from typing import List, Optional, Any, Dict -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions from azure.core.exceptions import ClientAuthenticationError from .. import CredentialUnavailableError @@ -94,6 +94,41 @@ def get_token( :raises ~azure.core.exceptions.ClientAuthenticationError: the credential invoked the Azure CLI but didn't receive an access token. """ + + options: TokenRequestOptions = {} + if tenant_id: + options["tenant_id"] = tenant_id + + token_info = self._get_token_base(*scopes, options=options, **kwargs) + return AccessToken(token_info.token, token_info.expires_on) + + @log_get_token + def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptions] = None) -> AccessTokenInfo: + """Request an access token for `scopes`. + + This is an alternative to `get_token` to enable certain scenarios that require additional properties + on the token. This method is called automatically by Azure SDK clients. Applications calling this method + directly must also handle token caching because this credential doesn't cache the tokens it acquires. + + :param str scopes: desired scopes for the access token. This credential allows only one scope per request. + For more information about scopes, see https://learn.microsoft.com/entra/identity-platform/scopes-oidc. + :keyword options: A dictionary of options for the token request. Unknown options will be ignored. Optional. + :paramtype options: ~azure.core.credentials.TokenRequestOptions + + :rtype: AccessTokenInfo + :return: An AccessTokenInfo instance containing information about the token. + + :raises ~azure.identity.CredentialUnavailableError: the credential was unable to invoke the Azure CLI. + :raises ~azure.core.exceptions.ClientAuthenticationError: the credential invoked the Azure CLI but didn't + receive an access token. + """ + return self._get_token_base(*scopes, options=options) + + def _get_token_base( + self, *scopes: str, options: Optional[TokenRequestOptions] = None, **kwargs: Any + ) -> AccessTokenInfo: + + tenant_id = options.get("tenant_id") if options else None if tenant_id: validate_tenant_id(tenant_id) for scope in scopes: @@ -126,7 +161,7 @@ def get_token( return token -def parse_token(output) -> Optional[AccessToken]: +def parse_token(output) -> Optional[AccessTokenInfo]: """Parse output of 'az account get-access-token' to an AccessToken. In particular, convert the "expiresOn" value to epoch seconds. This value is a naive local datetime as returned by @@ -141,11 +176,11 @@ def parse_token(output) -> Optional[AccessToken]: # Use "expires_on" if it's present, otherwise use "expiresOn". if "expires_on" in token: - return AccessToken(token["accessToken"], int(token["expires_on"])) + return AccessTokenInfo(token["accessToken"], int(token["expires_on"])) dt = datetime.strptime(token["expiresOn"], "%Y-%m-%d %H:%M:%S.%f") expires_on = dt.timestamp() - return AccessToken(token["accessToken"], int(expires_on)) + return AccessTokenInfo(token["accessToken"], int(expires_on)) except (KeyError, ValueError): return None diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/azure_pipelines.py b/sdk/identity/azure-identity/azure/identity/_credentials/azure_pipelines.py index ea074b406cc7..a981ecff7a4c 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/azure_pipelines.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/azure_pipelines.py @@ -7,7 +7,7 @@ from typing import Any, Optional from azure.core.exceptions import ClientAuthenticationError -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions from azure.core.rest import HttpRequest, HttpResponse from .client_assertion import ClientAssertionCredential @@ -125,6 +125,25 @@ def get_token( *scopes, claims=claims, tenant_id=tenant_id, enable_cae=enable_cae, **kwargs ) + def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptions] = None) -> AccessTokenInfo: + """Request an access token for `scopes`. + + This is an alternative to `get_token` to enable certain scenarios that require additional properties + on the token. This method is called automatically by Azure SDK clients. + + :param str scopes: desired scope for the access token. This method requires at least one scope. + For more information about scopes, see https://learn.microsoft.com/entra/identity-platform/scopes-oidc. + :keyword options: A dictionary of options for the token request. Unknown options will be ignored. Optional. + :paramtype options: ~azure.core.credentials.TokenRequestOptions + + :rtype: AccessTokenInfo + :return: An AccessTokenInfo instance containing information about the token. + :raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The error's ``message`` + attribute gives a reason. + """ + validate_env_vars() + return self._client_assertion_credential.get_token_info(*scopes, options=options) + def _get_oidc_token(self) -> str: request = build_oidc_request(self._service_connection_id, self._system_access_token) response = self._pipeline.run(request, retry_on_methods=[request.method]) diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/azure_powershell.py b/sdk/identity/azure-identity/azure/identity/_credentials/azure_powershell.py index da3dd2c45ab0..92dd0432bce2 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/azure_powershell.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/azure_powershell.py @@ -8,7 +8,7 @@ import sys from typing import Any, List, Tuple, Optional -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions from azure.core.exceptions import ClientAuthenticationError from .azure_cli import get_safe_working_dir @@ -125,6 +125,42 @@ def get_token( :raises ~azure.core.exceptions.ClientAuthenticationError: the credential invoked Azure PowerShell but didn't receive an access token """ + + options: TokenRequestOptions = {} + if tenant_id: + options["tenant_id"] = tenant_id + + token_info = self._get_token_base(*scopes, options=options, **kwargs) + return AccessToken(token_info.token, token_info.expires_on) + + @log_get_token + def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptions] = None) -> AccessTokenInfo: + """Request an access token for `scopes`. + + This is an alternative to `get_token` to enable certain scenarios that require additional properties + on the token. This method is called automatically by Azure SDK clients. Applications calling this method + directly must also handle token caching because this credential doesn't cache the tokens it acquires. + + :param str scopes: desired scopes for the access token. TThis credential allows only one scope per request. + For more information about scopes, see https://learn.microsoft.com/entra/identity-platform/scopes-oidc. + :keyword options: A dictionary of options for the token request. Unknown options will be ignored. Optional. + :paramtype options: ~azure.core.credentials.TokenRequestOptions + + :rtype: AccessTokenInfo + :return: An AccessTokenInfo instance containing information about the token. + + :raises ~azure.identity.CredentialUnavailableError: the credential was unable to invoke Azure PowerShell, or + no account is authenticated + :raises ~azure.core.exceptions.ClientAuthenticationError: the credential invoked Azure PowerShell but didn't + receive an access token + """ + return self._get_token_base(*scopes, options=options) + + def _get_token_base( + self, *scopes: str, options: Optional[TokenRequestOptions] = None, **kwargs: Any + ) -> AccessTokenInfo: + + tenant_id = options.get("tenant_id") if options else None if tenant_id: validate_tenant_id(tenant_id) for scope in scopes: @@ -185,11 +221,11 @@ def start_process(args: List[str]) -> "subprocess.Popen": return proc -def parse_token(output: str) -> AccessToken: +def parse_token(output: str) -> AccessTokenInfo: for line in output.split(): if line.startswith("azsdk%"): _, token, expires_on = line.split("%") - return AccessToken(token, int(expires_on)) + return AccessTokenInfo(token, int(expires_on)) if within_dac.get(): raise CredentialUnavailableError(message='Unexpected output from Get-AzAccessToken: "{}"'.format(output)) diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/chained.py b/sdk/identity/azure-identity/azure/identity/_credentials/chained.py index 10e03a6daa0a..df888cb2ef70 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/chained.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/chained.py @@ -3,10 +3,10 @@ # Licensed under the MIT License. # ------------------------------------ import logging -from typing import Any, Optional, TYPE_CHECKING +from typing import Any, Optional, TYPE_CHECKING, cast from azure.core.exceptions import ClientAuthenticationError -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions, SupportsTokenInfo from .. import CredentialUnavailableError from .._internal import within_credential_chain @@ -74,6 +74,9 @@ def get_token( ) -> AccessToken: """Request a token from each chained credential, in order, returning the first token received. + If no credential provides a token, raises :class:`azure.core.exceptions.ClientAuthenticationError` + with an error message from each credential. + This method is called automatically by Azure SDK clients. :param str scopes: desired scopes for the access token. This method requires at least one scope. @@ -122,3 +125,64 @@ def get_token( ) _LOGGER.warning(message) raise ClientAuthenticationError(message=message) + + def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptions] = None) -> AccessTokenInfo: + """Request a token from each chained credential, in order, returning the first token received. + + If no credential provides a token, raises :class:`azure.core.exceptions.ClientAuthenticationError` + with an error message from each credential. + + This is an alternative to `get_token` to enable certain scenarios that require additional properties + on the token. This method is called automatically by Azure SDK clients. + + :param str scopes: desired scopes for the access token. This method requires at least one scope. + For more information about scopes, see https://learn.microsoft.com/entra/identity-platform/scopes-oidc. + :keyword options: A dictionary of options for the token request. Unknown options will be ignored. Optional. + :paramtype options: ~azure.core.credentials.TokenRequestOptions + + :rtype: AccessTokenInfo + :return: An AccessTokenInfo instance containing information about the token. + + :raises ~azure.core.exceptions.ClientAuthenticationError: no credential in the chain provided a token. + """ + within_credential_chain.set(True) + history = [] + for credential in self.credentials: + try: + # A custom credential in the chain may not implement get_token_info + if hasattr(credential, "get_token_info"): + token_info = cast(SupportsTokenInfo, credential).get_token_info(*scopes, options=options) + else: + options = options or {} + token = credential.get_token(*scopes, **options) + token_info = AccessTokenInfo(token=token.token, expires_on=token.expires_on) + _LOGGER.info("%s acquired a token from %s", self.__class__.__name__, credential.__class__.__name__) + self._successful_credential = credential + within_credential_chain.set(False) + return token_info + except CredentialUnavailableError as ex: + # credential didn't attempt authentication because it lacks required data or state -> continue + history.append((credential, ex.message)) + except Exception as ex: # pylint: disable=broad-except + # credential failed to authenticate, or something unexpectedly raised -> break + history.append((credential, str(ex))) + _LOGGER.debug( + '%s.get_token_info failed: %s raised unexpected error "%s"', + self.__class__.__name__, + credential.__class__.__name__, + ex, + exc_info=True, + ) + break + + within_credential_chain.set(False) + attempts = _get_error_message(history) + message = ( + self.__class__.__name__ + + " failed to retrieve a token from the included credentials." + + attempts + + "\nTo mitigate this issue, please refer to the troubleshooting guidelines here at " + "https://aka.ms/azsdk/python/identity/defaultazurecredential/troubleshoot." + ) + _LOGGER.warning(message) + raise ClientAuthenticationError(message=message) diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/client_assertion.py b/sdk/identity/azure-identity/azure/identity/_credentials/client_assertion.py index 9970a2fb80e2..bb371381c6b6 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/client_assertion.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/client_assertion.py @@ -4,7 +4,7 @@ # ------------------------------------ from typing import Callable, Optional, Any -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessTokenInfo from .._internal import AadClient from .._internal.get_token_mixin import GetTokenMixin @@ -68,10 +68,10 @@ def __exit__(self, *args: Any) -> None: def close(self) -> None: self.__exit__() - def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[AccessToken]: + def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[AccessTokenInfo]: return self._client.get_cached_access_token(scopes, **kwargs) - def _request_token(self, *scopes: str, **kwargs: Any) -> AccessToken: + def _request_token(self, *scopes: str, **kwargs: Any) -> AccessTokenInfo: assertion = self._func() token = self._client.obtain_token_by_jwt_assertion(scopes, assertion, **kwargs) return token diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/default.py b/sdk/identity/azure-identity/azure/identity/_credentials/default.py index 57446c062ca9..b1503cb65dff 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/default.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/default.py @@ -6,7 +6,7 @@ import os from typing import List, TYPE_CHECKING, Any, Optional, cast -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions, SupportsTokenInfo from .._constants import EnvironmentVariables from .._internal import get_default_authority, normalize_authority, within_dac from .azure_powershell import AzurePowerShellCredential @@ -214,7 +214,7 @@ def get_token( :rtype: ~azure.core.credentials.AccessToken :raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The exception has a - `message` attribute listing each authentication attempt and its error message. + `message` attribute listing each authentication attempt and its error message. """ if self._successful_credential: token = self._successful_credential.get_token(*scopes, claims=claims, tenant_id=tenant_id, **kwargs) @@ -226,3 +226,32 @@ def get_token( token = super().get_token(*scopes, claims=claims, tenant_id=tenant_id, **kwargs) within_dac.set(False) return token + + def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptions] = None) -> AccessTokenInfo: + """Request an access token for `scopes`. + + This is an alternative to `get_token` to enable certain scenarios that require additional properties + on the token. This method is called automatically by Azure SDK clients. + + :param str scopes: desired scopes for the access token. This method requires at least one scope. + For more information about scopes, see https://learn.microsoft.com/entra/identity-platform/scopes-oidc. + :keyword options: A dictionary of options for the token request. Unknown options will be ignored. Optional. + :paramtype options: ~azure.core.credentials.TokenRequestOptions + + :rtype: AccessTokenInfo + :return: An AccessTokenInfo instance containing information about the token. + + :raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The exception has a + `message` attribute listing each authentication attempt and its error message. + """ + if self._successful_credential: + token_info = cast(SupportsTokenInfo, self._successful_credential).get_token_info(*scopes, options=options) + _LOGGER.info( + "%s acquired a token from %s", self.__class__.__name__, self._successful_credential.__class__.__name__ + ) + return token_info + + within_dac.set(True) + token_info = cast(SupportsTokenInfo, super()).get_token_info(*scopes, options=options) + within_dac.set(False) + return token_info diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/environment.py b/sdk/identity/azure-identity/azure/identity/_credentials/environment.py index 146d9be5c9e1..8f87a1d9ff97 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/environment.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/environment.py @@ -4,8 +4,8 @@ # ------------------------------------ import logging import os -from typing import Optional, Union, Any -from azure.core.credentials import AccessToken +from typing import Optional, Union, Any, cast +from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions, SupportsTokenInfo from .. import CredentialUnavailableError from .._constants import EnvironmentVariables @@ -155,3 +155,29 @@ def get_token( ) raise CredentialUnavailableError(message=message) return self._credential.get_token(*scopes, claims=claims, tenant_id=tenant_id, **kwargs) + + @log_get_token + def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptions] = None) -> AccessTokenInfo: + """Request an access token for `scopes`. + + This is an alternative to `get_token` to enable certain scenarios that require additional properties + on the token. This method is called automatically by Azure SDK clients. + + :param str scopes: desired scope for the access token. This method requires at least one scope. + For more information about scopes, see https://learn.microsoft.com/entra/identity-platform/scopes-oidc. + :keyword options: A dictionary of options for the token request. Unknown options will be ignored. Optional. + :paramtype options: ~azure.core.credentials.TokenRequestOptions + + :rtype: AccessTokenInfo + :return: An AccessTokenInfo instance containing information about the token. + + :raises ~azure.identity.CredentialUnavailableError: environment variable configuration is incomplete. + """ + if not self._credential: + message = ( + "EnvironmentCredential authentication unavailable. Environment variables are not fully configured.\n" + "Visit https://aka.ms/azsdk/python/identity/environmentcredential/troubleshoot to troubleshoot " + "this issue." + ) + raise CredentialUnavailableError(message=message) + return cast(SupportsTokenInfo, self._credential).get_token_info(*scopes, options=options) diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/imds.py b/sdk/identity/azure-identity/azure/identity/_credentials/imds.py index cd9a0149afc6..6528f23b83ec 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/imds.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/imds.py @@ -8,7 +8,7 @@ from azure.core.exceptions import ClientAuthenticationError, HttpResponseError from azure.core.pipeline.transport import HttpRequest -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessTokenInfo from .. import CredentialUnavailableError from .._constants import EnvironmentVariables @@ -76,7 +76,7 @@ def __exit__(self, *args): def close(self) -> None: self.__exit__() - def _request_token(self, *scopes: str, **kwargs: Any) -> AccessToken: + def _request_token(self, *scopes: str, **kwargs: Any) -> AccessTokenInfo: if within_credential_chain.get() and not self._endpoint_available: # If within a chain (e.g. DefaultAzureCredential), we do a quick check to see if the IMDS endpoint @@ -96,7 +96,7 @@ def _request_token(self, *scopes: str, **kwargs: Any) -> AccessToken: raise CredentialUnavailableError(error_message) from ex try: - token = super()._request_token(*scopes) + token_info = super()._request_token(*scopes) except CredentialUnavailableError: # Response is not json, skip the IMDS credential raise @@ -123,7 +123,7 @@ def _request_token(self, *scopes: str, **kwargs: Any) -> AccessToken: # if anything else was raised, assume the endpoint is unavailable error_message = "ManagedIdentityCredential authentication unavailable, no response from the IMDS endpoint." raise CredentialUnavailableError(error_message) from ex - return token + return token_info def get_unavailable_message(self, desc: str = "") -> str: return f"ManagedIdentityCredential authentication unavailable, no response from the IMDS endpoint. {desc}" diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/managed_identity.py b/sdk/identity/azure-identity/azure/identity/_credentials/managed_identity.py index 8c8fc9012cfc..c9d9aef9408b 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/managed_identity.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/managed_identity.py @@ -4,9 +4,9 @@ # ------------------------------------ import logging import os -from typing import Optional, TYPE_CHECKING, Any, Mapping +from typing import Optional, TYPE_CHECKING, Any, Mapping, cast -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions, SupportsTokenInfo from .. import CredentialUnavailableError from .._constants import EnvironmentVariables from .._internal.decorators import log_get_token @@ -160,3 +160,28 @@ def get_token( "troubleshoot this issue." ) return self._credential.get_token(*scopes, claims=claims, tenant_id=tenant_id, **kwargs) + + @log_get_token + def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptions] = None) -> AccessTokenInfo: + """Request an access token for `scopes`. + + This is an alternative to `get_token` to enable certain scenarios that require additional properties + on the token. This method is called automatically by Azure SDK clients. + + :param str scopes: desired scope for the access token. This credential allows only one scope per request. + For more information about scopes, see https://learn.microsoft.com/entra/identity-platform/scopes-oidc. + :keyword options: A dictionary of options for the token request. Unknown options will be ignored. Optional. + :paramtype options: ~azure.core.credentials.TokenRequestOptions + + :rtype: AccessTokenInfo + :return: An AccessTokenInfo instance containing information about the token. + :raises ~azure.identity.CredentialUnavailableError: managed identity isn't available in the hosting environment. + """ + if not self._credential: + raise CredentialUnavailableError( + message="No managed identity endpoint found. \n" + "The Target Azure platform could not be determined from environment variables. \n" + "Visit https://aka.ms/azsdk/python/identity/managedidentitycredential/troubleshoot to " + "troubleshoot this issue." + ) + return cast(SupportsTokenInfo, self._credential).get_token_info(*scopes, options=options) diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/on_behalf_of.py b/sdk/identity/azure-identity/azure/identity/_credentials/on_behalf_of.py index 9f8889bd44f5..464014cf20b8 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/on_behalf_of.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/on_behalf_of.py @@ -7,7 +7,7 @@ import msal -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessTokenInfo from azure.core.exceptions import ClientAuthenticationError from .certificate import get_client_credential @@ -123,7 +123,7 @@ def __init__( self._auth_record: Optional[AuthenticationRecord] = None @wrap_exceptions - def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[AccessToken]: + def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[AccessTokenInfo]: if self._auth_record: claims = kwargs.get("claims") app = self._get_app(**kwargs) @@ -134,12 +134,12 @@ def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[Acces now = int(time.time()) result = app.acquire_token_silent_with_error(list(scopes), account=account, claims_challenge=claims) if result and "access_token" in result and "expires_in" in result: - return AccessToken(result["access_token"], now + int(result["expires_in"])) + return AccessTokenInfo(result["access_token"], now + int(result["expires_in"])) return None @wrap_exceptions - def _request_token(self, *scopes: str, **kwargs: Any) -> AccessToken: + def _request_token(self, *scopes: str, **kwargs: Any) -> AccessTokenInfo: app: msal.ConfidentialClientApplication = self._get_app(**kwargs) request_time = int(time.time()) result = app.acquire_token_on_behalf_of(self._assertion, list(scopes), claims_challenge=kwargs.get("claims")) @@ -153,4 +153,4 @@ def _request_token(self, *scopes: str, **kwargs: Any) -> AccessToken: except ClientAuthenticationError: pass # non-fatal; we'll use the assertion again next time instead of a refresh token - return AccessToken(result["access_token"], request_time + int(result["expires_in"])) + return AccessTokenInfo(result["access_token"], request_time + int(result["expires_in"])) diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/shared_cache.py b/sdk/identity/azure-identity/azure/identity/_credentials/shared_cache.py index 0aaaf11cab58..c23dfc0485b9 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/shared_cache.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/shared_cache.py @@ -2,8 +2,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ -from typing import TYPE_CHECKING, Any, Optional, TypeVar -from azure.core.credentials import AccessToken +from typing import TYPE_CHECKING, Any, Optional, TypeVar, cast +from azure.core.credentials import AccessToken, TokenRequestOptions, AccessTokenInfo, SupportsTokenInfo from .silent import SilentAuthenticationCredential from .. import CredentialUnavailableError @@ -39,7 +39,7 @@ class SharedTokenCacheCredential: def __init__(self, username: Optional[str] = None, **kwargs: Any) -> None: if "authentication_record" in kwargs: - self._credential = SilentAuthenticationCredential(**kwargs) # type: TokenCredential + self._credential: TokenCredential = SilentAuthenticationCredential(**kwargs) else: self._credential = _SharedTokenCacheCredential(username=username, **kwargs) @@ -61,7 +61,7 @@ def get_token( claims: Optional[str] = None, tenant_id: Optional[str] = None, enable_cae: bool = False, - **kwargs: Any + **kwargs: Any, ) -> AccessToken: """Get an access token for `scopes` from the shared cache. @@ -87,6 +87,29 @@ def get_token( """ return self._credential.get_token(*scopes, claims=claims, tenant_id=tenant_id, enable_cae=enable_cae, **kwargs) + @log_get_token + def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptions] = None) -> AccessTokenInfo: + """Request an access token for `scopes`. + + If no access token is cached, attempt to acquire one using a cached refresh token. + + This is an alternative to `get_token` to enable certain scenarios that require additional properties + on the token. This method is called automatically by Azure SDK clients. + + :param str scopes: desired scope for the access token. This method requires at least one scope. + For more information about scopes, see https://learn.microsoft.com/entra/identity-platform/scopes-oidc. + :keyword options: A dictionary of options for the token request. Unknown options will be ignored. Optional. + :paramtype options: ~azure.core.credentials.TokenRequestOptions + + :rtype: AccessTokenInfo + :return: An AccessTokenInfo instance containing information about the token. + :raises ~azure.identity.CredentialUnavailableError: the cache is unavailable or contains insufficient user + information. + :raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The error's ``message`` + attribute gives a reason. + """ + return cast(SupportsTokenInfo, self._credential).get_token_info(*scopes, options=options) + @staticmethod def supported() -> bool: """Whether the shared token cache is supported on the current platform. @@ -115,15 +138,39 @@ def get_token( claims: Optional[str] = None, tenant_id: Optional[str] = None, enable_cae: bool = False, - **kwargs: Any + **kwargs: Any, ) -> AccessToken: + options: TokenRequestOptions = {} + if claims: + options["claims"] = claims + if tenant_id: + options["tenant_id"] = tenant_id + options["enable_cae"] = enable_cae + + token_info = self._get_token_base(*scopes, options=options, base_method_name="get_token", **kwargs) + return AccessToken(token_info.token, token_info.expires_on) + + def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptions] = None) -> AccessTokenInfo: + return self._get_token_base(*scopes, options=options, base_method_name="get_token_info") + + def _get_token_base( + self, + *scopes: str, + options: Optional[TokenRequestOptions] = None, + base_method_name: str = "get_token", + **kwargs: Any, + ) -> AccessTokenInfo: if not scopes: - raise ValueError("'get_token' requires at least one scope") + raise ValueError(f"'{base_method_name}' requires at least one scope") if not self._client_initialized: self._initialize_client() - is_cae = enable_cae + options = options or {} + claims = options.get("claims") + tenant_id = options.get("tenant_id") + is_cae = options.get("enable_cae", False) + token_cache = self._cae_cache if is_cae else self._cache # Try to load the cache if it is None. @@ -142,8 +189,8 @@ def get_token( # try each refresh token, returning the first access token acquired for refresh_token in self._get_refresh_tokens(account, is_cae=is_cae): - token = self._client.obtain_token_by_refresh_token( - scopes, refresh_token, claims=claims, tenant_id=tenant_id, **kwargs + token = cast(AadClient, self._client).obtain_token_by_refresh_token( + scopes, refresh_token, claims=claims, tenant_id=tenant_id, enable_cae=is_cae, **kwargs ) return token diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/silent.py b/sdk/identity/azure-identity/azure/identity/_credentials/silent.py index 7a86656c23dc..f4add4b726c8 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/silent.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/silent.py @@ -8,7 +8,7 @@ from msal import PublicClientApplication, TokenCache -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions from azure.core.exceptions import ClientAuthenticationError from .. import CredentialUnavailableError @@ -59,16 +59,47 @@ def __exit__(self, *args): self._client.__exit__(*args) def get_token( - self, *scopes: str, claims: Optional[str] = None, tenant_id: Optional[str] = None, **kwargs: Any + self, + *scopes: str, + claims: Optional[str] = None, + tenant_id: Optional[str] = None, + enable_cae: bool = False, + **kwargs: Any, ) -> AccessToken: + options: TokenRequestOptions = {} + if claims: + options["claims"] = claims + if tenant_id: + options["tenant_id"] = tenant_id + options["enable_cae"] = enable_cae + + token_info = self._get_token_base(*scopes, options=options, base_method_name="get_token", **kwargs) + return AccessToken(token_info.token, token_info.expires_on) + + def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptions] = None) -> AccessTokenInfo: + return self._get_token_base(*scopes, options=options, base_method_name="get_token_info") + + def _get_token_base( + self, + *scopes: str, + options: Optional[TokenRequestOptions] = None, + base_method_name: str = "get_token", + **kwargs: Any, + ) -> AccessTokenInfo: + if not scopes: - raise ValueError('"get_token" requires at least one scope') + raise ValueError(f"'{base_method_name}' requires at least one scope") + + options = options or {} + claims = options.get("claims") + tenant_id = options.get("tenant_id") + enable_cae = options.get("enable_cae", False) - token_cache = self._cae_cache if kwargs.get("enable_cae") else self._cache + token_cache = self._cae_cache if enable_cae else self._cache # Try to load the cache if it is None. if not token_cache: - token_cache = self._initialize_cache(is_cae=bool(kwargs.get("enable_cae"))) + token_cache = self._initialize_cache(is_cae=enable_cae) # If the cache is still None, raise an error. if not token_cache: @@ -76,7 +107,7 @@ def get_token( raise CredentialUnavailableError(message="Shared token cache unavailable") raise ClientAuthenticationError(message="Shared token cache unavailable") - return self._acquire_token_silent(*scopes, claims=claims, tenant_id=tenant_id, **kwargs) + return self._acquire_token_silent(*scopes, claims=claims, tenant_id=tenant_id, enable_cae=enable_cae, **kwargs) def _initialize_cache(self, is_cae: bool = False) -> Optional[TokenCache]: @@ -129,7 +160,7 @@ def _get_client_application(self, **kwargs: Any): return client_applications_map[tenant_id] @wrap_exceptions - def _acquire_token_silent(self, *scopes: str, **kwargs: Any) -> AccessToken: + def _acquire_token_silent(self, *scopes: str, **kwargs: Any) -> AccessTokenInfo: """Silently acquire a token from MSAL. :param str scopes: desired scopes for the access token @@ -153,7 +184,7 @@ def _acquire_token_silent(self, *scopes: str, **kwargs: Any) -> AccessToken: list(scopes), account=account, claims_challenge=kwargs.get("claims") ) if result and "access_token" in result and "expires_in" in result: - return AccessToken(result["access_token"], now + int(result["expires_in"])) + return AccessTokenInfo(result["access_token"], now + int(result["expires_in"])) # if we get this far, the cache contained a matching account but MSAL failed to authenticate it silently if result: diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/vscode.py b/sdk/identity/azure-identity/azure/identity/_credentials/vscode.py index 4a1ca8a26c22..4990d7076292 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/vscode.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/vscode.py @@ -7,7 +7,7 @@ import sys from typing import cast, Any, Dict, Optional -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessToken, TokenRequestOptions, AccessTokenInfo from azure.core.exceptions import ClientAuthenticationError from .._exceptions import CredentialUnavailableError from .._constants import AzureAuthorityHosts, AZURE_VSCODE_CLIENT_ID, EnvironmentVariables @@ -174,11 +174,42 @@ def get_token( raise CredentialUnavailableError(message=ex.message) from ex return super().get_token(*scopes, claims=claims, tenant_id=tenant_id, **kwargs) - def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[AccessToken]: + def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptions] = None) -> AccessTokenInfo: + """Request an access token for `scopes` as the user currently signed in to Visual Studio Code. + + This is an alternative to `get_token` to enable certain scenarios that require additional properties + on the token. This method is called automatically by Azure SDK clients. + + :param str scopes: desired scopes for the access token. This method requires at least one scope. + For more information about scopes, see https://learn.microsoft.com/entra/identity-platform/scopes-oidc. + :keyword options: A dictionary of options for the token request. Unknown options will be ignored. Optional. + :paramtype options: ~azure.core.credentials.TokenRequestOptions + + :rtype: AccessTokenInfo + :return: An AccessTokenInfo instance containing information about the token. + :raises ~azure.identity.CredentialUnavailableError: the credential cannot retrieve user details from Visual + Studio Code. + """ + if self._unavailable_reason: + error_message = ( + self._unavailable_reason + "\n" + "Visit https://aka.ms/azsdk/python/identity/vscodecredential/troubleshoot" + " to troubleshoot this issue." + ) + raise CredentialUnavailableError(message=error_message) + if within_dac.get(): + try: + token = super().get_token_info(*scopes, options=options) + return token + except ClientAuthenticationError as ex: + raise CredentialUnavailableError(message=ex.message) from ex + return super().get_token_info(*scopes, options=options) + + def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[AccessTokenInfo]: self._client = cast(AadClient, self._client) return self._client.get_cached_access_token(scopes, **kwargs) - def _request_token(self, *scopes: str, **kwargs: Any) -> AccessToken: + def _request_token(self, *scopes: str, **kwargs: Any) -> AccessTokenInfo: refresh_token = self._get_refresh_token() self._client = cast(AadClient, self._client) return self._client.obtain_token_by_refresh_token(scopes, refresh_token, **kwargs) diff --git a/sdk/identity/azure-identity/azure/identity/_internal/aad_client.py b/sdk/identity/azure-identity/azure/identity/_internal/aad_client.py index f12ff5618aa2..02fb0c922f5e 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/aad_client.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/aad_client.py @@ -5,7 +5,7 @@ import time from typing import Iterable, Union, Optional, Any -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessTokenInfo from azure.core.pipeline import Pipeline from azure.core.pipeline.transport import HttpRequest from .aad_client_base import AadClientBase @@ -26,7 +26,7 @@ def close(self) -> None: def obtain_token_by_authorization_code( self, scopes: Iterable[str], code: str, redirect_uri: str, client_secret: Optional[str] = None, **kwargs: Any - ) -> AccessToken: + ) -> AccessTokenInfo: request = self._get_auth_code_request( scopes=scopes, code=code, redirect_uri=redirect_uri, client_secret=client_secret, **kwargs ) @@ -34,19 +34,21 @@ def obtain_token_by_authorization_code( def obtain_token_by_client_certificate( self, scopes: Iterable[str], certificate: AadClientCertificate, **kwargs: Any - ) -> AccessToken: + ) -> AccessTokenInfo: request = self._get_client_certificate_request(scopes, certificate, **kwargs) return self._run_pipeline(request, **kwargs) - def obtain_token_by_client_secret(self, scopes: Iterable[str], secret: str, **kwargs: Any) -> AccessToken: + def obtain_token_by_client_secret(self, scopes: Iterable[str], secret: str, **kwargs: Any) -> AccessTokenInfo: request = self._get_client_secret_request(scopes, secret, **kwargs) return self._run_pipeline(request, **kwargs) - def obtain_token_by_jwt_assertion(self, scopes: Iterable[str], assertion: str, **kwargs: Any) -> AccessToken: + def obtain_token_by_jwt_assertion(self, scopes: Iterable[str], assertion: str, **kwargs: Any) -> AccessTokenInfo: request = self._get_jwt_assertion_request(scopes, assertion, **kwargs) return self._run_pipeline(request, **kwargs) - def obtain_token_by_refresh_token(self, scopes: Iterable[str], refresh_token: str, **kwargs: Any) -> AccessToken: + def obtain_token_by_refresh_token( + self, scopes: Iterable[str], refresh_token: str, **kwargs: Any + ) -> AccessTokenInfo: request = self._get_refresh_token_request(scopes, refresh_token, **kwargs) return self._run_pipeline(request, **kwargs) @@ -56,14 +58,14 @@ def obtain_token_on_behalf_of( client_credential: Union[str, AadClientCertificate], user_assertion: str, **kwargs: Any - ) -> AccessToken: + ) -> AccessTokenInfo: # no need for an implementation, non-async OnBehalfOfCredential acquires tokens through MSAL raise NotImplementedError() def _build_pipeline(self, **kwargs: Any) -> Pipeline: return build_pipeline(**kwargs) - def _run_pipeline(self, request: HttpRequest, **kwargs: Any) -> AccessToken: + def _run_pipeline(self, request: HttpRequest, **kwargs: Any) -> AccessTokenInfo: # remove tenant_id and claims kwarg that could have been passed from credential's get_token method # tenant_id is already part of `request` at this point kwargs.pop("tenant_id", None) diff --git a/sdk/identity/azure-identity/azure/identity/_internal/aad_client_base.py b/sdk/identity/azure-identity/azure/identity/_internal/aad_client_base.py index c41f87d23612..209a4bf8a790 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/aad_client_base.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/aad_client_base.py @@ -7,14 +7,14 @@ import json import time from uuid import uuid4 -from typing import TYPE_CHECKING, List, Any, Iterable, Optional, Union, Dict +from typing import TYPE_CHECKING, List, Any, Iterable, Optional, Union, Dict, cast from msal import TokenCache from azure.core.pipeline import PipelineResponse from azure.core.pipeline.policies import ContentDecodePolicy from azure.core.pipeline.transport import HttpRequest -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessTokenInfo from azure.core.exceptions import ClientAuthenticationError from .utils import get_default_authority, normalize_authority, resolve_tenant from .aadclient_certificate import AadClientCertificate @@ -79,9 +79,9 @@ def _initialize_cache(self, is_cae: bool = False) -> TokenCache: self._cae_cache = TokenCache() else: self._cache = TokenCache() - return self._cae_cache if is_cae else self._cache + return cast(TokenCache, self._cae_cache if is_cae else self._cache) - def get_cached_access_token(self, scopes: Iterable[str], **kwargs: Any) -> Optional[AccessToken]: + def get_cached_access_token(self, scopes: Iterable[str], **kwargs: Any) -> Optional[AccessTokenInfo]: tenant = resolve_tenant( self._tenant_id, additionally_allowed_tenants=self._additionally_allowed_tenants, **kwargs ) @@ -94,7 +94,8 @@ def get_cached_access_token(self, scopes: Iterable[str], **kwargs: Any) -> Optio ): expires_on = int(token["expires_on"]) if expires_on > int(time.time()): - return AccessToken(token["secret"], expires_on) + refresh_on = int(token["refresh_on"]) if "refresh_on" in token else None + return AccessTokenInfo(token["secret"], expires_on, refresh_on=refresh_on) return None def get_cached_refresh_tokens(self, scopes: Iterable[str], **kwargs) -> List[Dict]: @@ -130,7 +131,7 @@ def obtain_token_on_behalf_of(self, scopes, client_credential, user_assertion, * def _build_pipeline(self, **kwargs): pass - def _process_response(self, response: PipelineResponse, request_time: int, **kwargs) -> AccessToken: + def _process_response(self, response: PipelineResponse, request_time: int, **kwargs) -> AccessTokenInfo: content = response.context.get( ContentDecodePolicy.CONTEXT_NAME ) or ContentDecodePolicy.deserialize_from_http_generics(response.http_response) @@ -171,7 +172,13 @@ def _process_response(self, response: PipelineResponse, request_time: int, **kwa _scrub_secrets(content) raise ClientAuthenticationError(message="Unexpected response from Microsoft Entra ID: {}".format(content)) - token = AccessToken(content["access_token"], expires_on) + expires_in = int(content.get("expires_in") or expires_on - request_time) + if "refresh_in" not in content and expires_in >= 7200: + # MSAL TokenCache expects "refresh_in" + content["refresh_in"] = expires_in // 2 + + refresh_on = request_time + int(content["refresh_in"]) if "refresh_in" in content else None + token = AccessTokenInfo(content["access_token"], expires_on, refresh_on=refresh_on) # caching is the final step because 'add' mutates 'content' cache.add( diff --git a/sdk/identity/azure-identity/azure/identity/_internal/client_credential_base.py b/sdk/identity/azure-identity/azure/identity/_internal/client_credential_base.py index 81579e3bdec6..16e4b75928f7 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/client_credential_base.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/client_credential_base.py @@ -5,7 +5,7 @@ import time from typing import Any, Optional, Dict -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessTokenInfo from azure.core.exceptions import ClientAuthenticationError from .get_token_mixin import GetTokenMixin @@ -23,18 +23,23 @@ class ClientCredentialBase(MsalCredential, GetTokenMixin): """Base class for credentials authenticating a service principal with a certificate or secret""" @wrap_exceptions - def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[AccessToken]: + def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[AccessTokenInfo]: app = self._get_app(**kwargs) request_time = int(time.time()) result = app.acquire_token_silent_with_error( list(scopes), account=None, claims_challenge=kwargs.pop("claims", None), **_get_known_kwargs(kwargs) ) if result and "access_token" in result and "expires_in" in result: - return AccessToken(result["access_token"], request_time + int(result["expires_in"])) + refresh_on = int(result["refresh_on"]) if "refresh_on" in result else None + return AccessTokenInfo( + result["access_token"], + request_time + int(result["expires_in"]), + refresh_on=refresh_on, + ) return None @wrap_exceptions - def _request_token(self, *scopes: str, **kwargs: Any) -> Optional[AccessToken]: + def _request_token(self, *scopes: str, **kwargs: Any) -> Optional[AccessTokenInfo]: app = self._get_app(**kwargs) request_time = int(time.time()) result = app.acquire_token_for_client(list(scopes), claims_challenge=kwargs.pop("claims", None)) @@ -42,4 +47,9 @@ def _request_token(self, *scopes: str, **kwargs: Any) -> Optional[AccessToken]: message = "Authentication failed: {}".format(result.get("error_description") or result.get("error")) raise ClientAuthenticationError(message=message) - return AccessToken(result["access_token"], request_time + int(result["expires_in"])) + refresh_on = int(result["refresh_on"]) if "refresh_on" in result else None + return AccessTokenInfo( + result["access_token"], + request_time + int(result["expires_in"]), + refresh_on=refresh_on, + ) diff --git a/sdk/identity/azure-identity/azure/identity/_internal/get_token_mixin.py b/sdk/identity/azure-identity/azure/identity/_internal/get_token_mixin.py index 022555b59986..a8fcb2195851 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/get_token_mixin.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/get_token_mixin.py @@ -7,7 +7,7 @@ import time from typing import Any, Optional -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions from .utils import within_credential_chain from .._constants import DEFAULT_REFRESH_OFFSET, DEFAULT_TOKEN_REFRESH_RETRY_DELAY @@ -22,7 +22,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: super(GetTokenMixin, self).__init__(*args, **kwargs) # type: ignore @abc.abstractmethod - def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[AccessToken]: + def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[AccessTokenInfo]: """Attempt to acquire an access token from a cache or by redeeming a refresh token. :param str scopes: desired scopes for the access token. This method requires at least one scope. @@ -30,11 +30,11 @@ def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[Acces https://learn.microsoft.com/entra/identity-platform/scopes-oidc. :return: An access token with the desired scopes if successful; otherwise, None. - :rtype: ~azure.core.credentials.AccessToken or None + :rtype: ~azure.core.credentials.AccessTokenInfo or None """ @abc.abstractmethod - def _request_token(self, *scopes: str, **kwargs: Any) -> AccessToken: + def _request_token(self, *scopes: str, **kwargs: Any) -> AccessTokenInfo: """Request an access token from the STS. :param str scopes: desired scopes for the access token. This method requires at least one scope. @@ -42,11 +42,13 @@ def _request_token(self, *scopes: str, **kwargs: Any) -> AccessToken: https://learn.microsoft.com/entra/identity-platform/scopes-oidc. :return: An access token with the desired scopes. - :rtype: ~azure.core.credentials.AccessToken + :rtype: ~azure.core.credentials.AccessTokenInfo """ - def _should_refresh(self, token: AccessToken) -> bool: + def _should_refresh(self, token: AccessTokenInfo) -> bool: now = int(time.time()) + if token.refresh_on is not None and now >= token.refresh_on: + return True if token.expires_on - now > DEFAULT_REFRESH_OFFSET: return False if now - self._last_request_time < DEFAULT_TOKEN_REFRESH_RETRY_DELAY: @@ -59,7 +61,7 @@ def get_token( claims: Optional[str] = None, tenant_id: Optional[str] = None, enable_cae: bool = False, - **kwargs: Any + **kwargs: Any, ) -> AccessToken: """Request an access token for `scopes`. @@ -81,8 +83,50 @@ def get_token( :raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The error's ``message`` attribute gives a reason. """ + options: TokenRequestOptions = {} + if claims: + options["claims"] = claims + if tenant_id: + options["tenant_id"] = tenant_id + options["enable_cae"] = enable_cae + + token_info = self._get_token_base(*scopes, options=options, base_method_name="get_token", **kwargs) + return AccessToken(token_info.token, token_info.expires_on) + + def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptions] = None) -> AccessTokenInfo: + """Request an access token for `scopes`. + + This is an alternative to `get_token` to enable certain scenarios that require additional properties + on the token. This method is called automatically by Azure SDK clients. + + :param str scopes: desired scopes for the access token. This method requires at least one scope. + For more information about scopes, see https://learn.microsoft.com/entra/identity-platform/scopes-oidc. + :keyword options: A dictionary of options for the token request. Unknown options will be ignored. Optional. + :paramtype options: ~azure.core.credentials.TokenRequestOptions + + :rtype: AccessTokenInfo + :return: An AccessTokenInfo instance containing information about the token. + :raises CredentialUnavailableError: the credential is unable to attempt authentication because it lacks + required data, state, or platform support + :raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The error's ``message`` + attribute gives a reason. + """ + return self._get_token_base(*scopes, options=options, base_method_name="get_token_info") + + def _get_token_base( + self, + *scopes: str, + options: Optional[TokenRequestOptions] = None, + base_method_name: str = "get_token", + **kwargs: Any, + ) -> AccessTokenInfo: if not scopes: - raise ValueError('"get_token" requires at least one scope') + raise ValueError(f'"{base_method_name}" requires at least one scope') + + options = options or {} + claims = options.get("claims") + tenant_id = options.get("tenant_id") + enable_cae = options.get("enable_cae", False) try: token = self._acquire_token_silently( @@ -103,16 +147,18 @@ def get_token( pass _LOGGER.log( logging.DEBUG if within_credential_chain.get() else logging.INFO, - "%s.get_token succeeded", + "%s.%s succeeded", self.__class__.__name__, + base_method_name, ) return token except Exception as ex: _LOGGER.log( logging.DEBUG if within_credential_chain.get() else logging.WARNING, - "%s.get_token failed: %s", + "%s.%s failed: %s", self.__class__.__name__, + base_method_name, ex, exc_info=_LOGGER.isEnabledFor(logging.DEBUG), ) diff --git a/sdk/identity/azure-identity/azure/identity/_internal/interactive.py b/sdk/identity/azure-identity/azure/identity/_internal/interactive.py index d3e671e30694..54a7474fc1b4 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/interactive.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/interactive.py @@ -9,10 +9,10 @@ import json import logging import time -from typing import Any, Optional, Iterable +from typing import Any, Optional, Iterable, Dict from urllib.parse import urlparse -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions from azure.core.exceptions import ClientAuthenticationError from .msal_credentials import MsalCredential @@ -95,7 +95,7 @@ def __init__( *, authentication_record: Optional[AuthenticationRecord] = None, disable_automatic_authentication: bool = False, - **kwargs: Any + **kwargs: Any, ) -> None: self._disable_automatic_authentication = disable_automatic_authentication self._auth_record = authentication_record @@ -106,7 +106,7 @@ def __init__( client_id=self._auth_record.client_id, authority=self._auth_record.authority, tenant_id=tenant_id, - **kwargs + **kwargs, ) else: super(InteractiveCredential, self).__init__(**kwargs) @@ -117,7 +117,7 @@ def get_token( claims: Optional[str] = None, tenant_id: Optional[str] = None, enable_cae: bool = False, - **kwargs: Any + **kwargs: Any, ) -> AccessToken: """Request an access token for `scopes`. @@ -140,23 +140,68 @@ def get_token( :raises AuthenticationRequiredError: user interaction is necessary to acquire a token, and the credential is configured not to begin this automatically. Call :func:`authenticate` to begin interactive authentication. """ + options: TokenRequestOptions = {} + if claims: + options["claims"] = claims + if tenant_id: + options["tenant_id"] = tenant_id + options["enable_cae"] = enable_cae + + token_info = self._get_token_base(*scopes, options=options, base_method_name="get_token", **kwargs) + return AccessToken(token_info.token, token_info.expires_on) + + def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptions] = None) -> AccessTokenInfo: + """Request an access token for `scopes`. + + This is an alternative to `get_token` to enable certain scenarios that require additional properties + on the token. This method is called automatically by Azure SDK clients. + + :param str scopes: desired scopes for the access token. This method requires at least one scope. + For more information about scopes, see https://learn.microsoft.com/entra/identity-platform/scopes-oidc. + :keyword options: A dictionary of options for the token request. Unknown options will be ignored. Optional. + :paramtype options: ~azure.core.credentials.TokenRequestOptions + + :rtype: AccessTokenInfo + :return: An AccessTokenInfo instance containing information about the token. + + :raises CredentialUnavailableError: the credential is unable to attempt authentication because it lacks + required data, state, or platform support + :raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The error's ``message`` + attribute gives a reason. + :raises AuthenticationRequiredError: user interaction is necessary to acquire a token, and the credential is + configured not to begin this automatically. Call :func:`authenticate` to begin interactive authentication. + """ + return self._get_token_base(*scopes, options=options, base_method_name="get_token_info") + + def _get_token_base( + self, + *scopes: str, + options: Optional[TokenRequestOptions] = None, + base_method_name: str = "get_token", + **kwargs: Any, + ) -> AccessTokenInfo: if not scopes: - message = "'get_token' requires at least one scope" - _LOGGER.warning("%s.get_token failed: %s", self.__class__.__name__, message) + message = f"'{base_method_name}' requires at least one scope" + _LOGGER.warning("%s.%s failed: %s", self.__class__.__name__, base_method_name, message) raise ValueError(message) allow_prompt = kwargs.pop("_allow_prompt", not self._disable_automatic_authentication) + options = options or {} + claims = options.get("claims") + tenant_id = options.get("tenant_id") + enable_cae = options.get("enable_cae", False) try: token = self._acquire_token_silent( *scopes, claims=claims, tenant_id=tenant_id, enable_cae=enable_cae, **kwargs ) - _LOGGER.info("%s.get_token succeeded", self.__class__.__name__) + _LOGGER.info("%s.%s succeeded", self.__class__.__name__, base_method_name) return token except Exception as ex: # pylint:disable=broad-except if not (isinstance(ex, AuthenticationRequiredError) and allow_prompt): _LOGGER.warning( - "%s.get_token failed: %s", + "%s.%s failed: %s", self.__class__.__name__, + base_method_name, ex, exc_info=_LOGGER.isEnabledFor(logging.DEBUG), ) @@ -176,15 +221,16 @@ def get_token( self._auth_record = _build_auth_record(result) except Exception as ex: # pylint:disable=broad-except _LOGGER.warning( - "%s.get_token failed: %s", + "%s.%s failed: %s", self.__class__.__name__, + base_method_name, ex, exc_info=_LOGGER.isEnabledFor(logging.DEBUG), ) raise - _LOGGER.info("%s.get_token succeeded", self.__class__.__name__) - return AccessToken(result["access_token"], now + int(result["expires_in"])) + _LOGGER.info("%s.%s succeeded", self.__class__.__name__, base_method_name) + return AccessTokenInfo(result["access_token"], now + int(result["expires_in"])) def authenticate( self, *, scopes: Optional[Iterable[str]] = None, claims: Optional[str] = None, **kwargs: Any @@ -214,7 +260,7 @@ def authenticate( return self._auth_record # type: ignore @wrap_exceptions - def _acquire_token_silent(self, *scopes: str, **kwargs: Any) -> AccessToken: + def _acquire_token_silent(self, *scopes: str, **kwargs: Any) -> AccessTokenInfo: result = None claims = kwargs.get("claims") if self._auth_record: @@ -226,7 +272,10 @@ def _acquire_token_silent(self, *scopes: str, **kwargs: Any) -> AccessToken: now = int(time.time()) result = app.acquire_token_silent_with_error(list(scopes), account=account, claims_challenge=claims) if result and "access_token" in result and "expires_in" in result: - return AccessToken(result["access_token"], now + int(result["expires_in"])) + refresh_on = int(result["refresh_on"]) if "refresh_on" in result else None + return AccessTokenInfo( + result["access_token"], now + int(result["expires_in"]), refresh_on=refresh_on + ) # if we get this far, result is either None or the content of a Microsoft Entra ID error response if result: @@ -235,5 +284,5 @@ def _acquire_token_silent(self, *scopes: str, **kwargs: Any) -> AccessToken: raise AuthenticationRequiredError(scopes, claims=claims) @abc.abstractmethod - def _request_token(self, *scopes, **kwargs): + def _request_token(self, *scopes, **kwargs) -> Dict: pass diff --git a/sdk/identity/azure-identity/azure/identity/_internal/managed_identity_base.py b/sdk/identity/azure-identity/azure/identity/_internal/managed_identity_base.py index 540e52b4f710..949ac14a844f 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/managed_identity_base.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/managed_identity_base.py @@ -5,7 +5,7 @@ import abc from typing import cast, Any, Optional, TypeVar -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions from .. import CredentialUnavailableError from .._internal.managed_identity_client import ManagedIdentityClient from .._internal.get_token_mixin import GetTokenMixin @@ -47,10 +47,15 @@ def get_token( raise CredentialUnavailableError(message=self.get_unavailable_message()) return super(ManagedIdentityBase, self).get_token(*scopes, claims=claims, tenant_id=tenant_id, **kwargs) - def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[AccessToken]: + def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptions] = None) -> AccessTokenInfo: + if not self._client: + raise CredentialUnavailableError(message=self.get_unavailable_message()) + return super().get_token_info(*scopes, options=options) + + def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[AccessTokenInfo]: # casting because mypy can't determine that these methods are called # only by get_token, which raises when self._client is None return cast(ManagedIdentityClient, self._client).get_cached_token(*scopes) - def _request_token(self, *scopes: str, **kwargs: Any) -> AccessToken: + def _request_token(self, *scopes: str, **kwargs: Any) -> AccessTokenInfo: return cast(ManagedIdentityClient, self._client).request_token(*scopes, **kwargs) diff --git a/sdk/identity/azure-identity/azure/identity/_internal/managed_identity_client.py b/sdk/identity/azure-identity/azure/identity/_internal/managed_identity_client.py index 1b867cc84407..58b5e29b9871 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/managed_identity_client.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/managed_identity_client.py @@ -8,7 +8,7 @@ from msal import TokenCache -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessTokenInfo from azure.core.exceptions import ClientAuthenticationError, DecodeError from azure.core.pipeline.policies import ContentDecodePolicy from azure.core.pipeline import PipelineResponse @@ -40,7 +40,7 @@ def __init__( self._pipeline = self._build_pipeline(**kwargs) self._request_factory = request_factory - def _process_response(self, response: PipelineResponse, request_time: int) -> AccessToken: + def _process_response(self, response: PipelineResponse, request_time: int) -> AccessTokenInfo: content = response.context.get(ContentDecodePolicy.CONTEXT_NAME) if not content: try: @@ -70,7 +70,13 @@ def _process_response(self, response: PipelineResponse, request_time: int) -> Ac expires_on = int(content.get("expires_on") or int(content["expires_in"]) + request_time) content["expires_on"] = expires_on - token = AccessToken(content["access_token"], content["expires_on"]) + expires_in = int(content.get("expires_in") or expires_on - request_time) + if "refresh_in" not in content and expires_in >= 7200: + # MSAL TokenCache expects "refresh_in" + content["refresh_in"] = expires_in // 2 + + refresh_on = request_time + int(content["refresh_in"]) if "refresh_in" in content else None + token = AccessTokenInfo(content["access_token"], content["expires_on"], refresh_on=refresh_on) # caching is the final step because TokenCache.add mutates its "event" self._cache.add( @@ -80,15 +86,15 @@ def _process_response(self, response: PipelineResponse, request_time: int) -> Ac return token - def get_cached_token(self, *scopes: str) -> Optional[AccessToken]: + def get_cached_token(self, *scopes: str) -> Optional[AccessTokenInfo]: resource = _scopes_to_resource(*scopes) - for token in self._cache.search( - TokenCache.CredentialType.ACCESS_TOKEN, - target=[resource], - ): + now = time.time() + for token in self._cache.search(TokenCache.CredentialType.ACCESS_TOKEN, target=[resource]): expires_on = int(token["expires_on"]) - if expires_on > time.time(): - return AccessToken(token["secret"], expires_on) + refresh_on = int(token["refresh_on"]) if "refresh_on" in token else None + if expires_on > now and (not refresh_on or refresh_on > now): + return AccessTokenInfo(token["secret"], expires_on, refresh_on=refresh_on) + return None @abc.abstractmethod @@ -124,7 +130,7 @@ def __exit__(self, *args: Any) -> None: def close(self) -> None: self.__exit__() - def request_token(self, *scopes: str, **kwargs: Any) -> AccessToken: + def request_token(self, *scopes: str, **kwargs: Any) -> AccessTokenInfo: resource = _scopes_to_resource(*scopes) request = self._request_factory(resource, self._identity_config) kwargs.pop("tenant_id", None) diff --git a/sdk/identity/azure-identity/azure/identity/_internal/msal_managed_identity_client.py b/sdk/identity/azure-identity/azure/identity/_internal/msal_managed_identity_client.py index cee1acc03450..a2d46074d19c 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/msal_managed_identity_client.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/msal_managed_identity_client.py @@ -8,7 +8,7 @@ import logging import msal -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions from azure.core.exceptions import ClientAuthenticationError from .msal_client import MsalClient @@ -45,14 +45,15 @@ def get_unavailable_message(self, desc: str = "") -> str: def close(self) -> None: self.__exit__() - def _request_token(self, *scopes: str, **kwargs: Any) -> AccessToken: # pylint:disable=unused-argument + def _request_token(self, *scopes: str, **kwargs: Any) -> AccessTokenInfo: # pylint:disable=unused-argument if not scopes: raise ValueError('"get_token" requires at least one scope') resource = _scopes_to_resource(*scopes) result = self._msal_client.acquire_token_for_client(resource=resource) now = int(time.time()) if result and "access_token" in result and "expires_in" in result: - return AccessToken(result["access_token"], now + int(result["expires_in"])) + refresh_on = int(result["refresh_on"]) if "refresh_on" in result else None + return AccessTokenInfo(result["access_token"], now + int(result["expires_in"]), refresh_on=refresh_on) if result and "error" in result: error_desc = cast(str, result["error"]) error_message = self.get_unavailable_message(error_desc) @@ -83,7 +84,7 @@ def get_token( claims: Optional[str] = None, tenant_id: Optional[str] = None, enable_cae: bool = False, - **kwargs: Any + **kwargs: Any, ) -> AccessToken: """Request an access token for `scopes`. @@ -105,31 +106,77 @@ def get_token( :raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The error's ``message`` attribute gives a reason. """ + options: TokenRequestOptions = {} + if claims: + options["claims"] = claims + if tenant_id: + options["tenant_id"] = tenant_id + options["enable_cae"] = enable_cae + + token_info = self._get_token_base(*scopes, options=options, base_method_name="get_token", **kwargs) + return AccessToken(token_info.token, token_info.expires_on) + + def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptions] = None) -> AccessTokenInfo: + """Request an access token for `scopes`. + + This is an alternative to `get_token` to enable certain scenarios that require additional properties + on the token. This method is called automatically by Azure SDK clients. + + :param str scopes: desired scopes for the access token. This method requires at least one scope. + For more information about scopes, see https://learn.microsoft.com/entra/identity-platform/scopes-oidc. + :keyword options: A dictionary of options for the token request. Unknown options will be ignored. Optional. + :paramtype options: ~azure.core.credentials.TokenRequestOptions + + :rtype: AccessTokenInfo + :return: An AccessTokenInfo instance containing information about the token. + :raises CredentialUnavailableError: the credential is unable to attempt authentication because it lacks + required data, state, or platform support + :raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The error's ``message`` + attribute gives a reason. + """ + return self._get_token_base(*scopes, options=options, base_method_name="get_token_info") + + def _get_token_base( + self, + *scopes: str, + options: Optional[TokenRequestOptions] = None, + base_method_name: str = "get_token", + **kwargs: Any, + ) -> AccessTokenInfo: if not scopes: - raise ValueError('"get_token" requires at least one scope') + raise ValueError(f'"{base_method_name}" requires at least one scope') _scopes_to_resource(*scopes) token = None + + options = options or {} + claims = options.get("claims") + tenant_id = options.get("tenant_id") + enable_cae = options.get("enable_cae", False) + try: token = self._request_token(*scopes, claims=claims, tenant_id=tenant_id, enable_cae=enable_cae, **kwargs) if token: _LOGGER.log( logging.DEBUG if within_credential_chain.get() else logging.INFO, - "%s.get_token succeeded", + "%s.%s succeeded", self.__class__.__name__, + base_method_name, ) return token _LOGGER.log( logging.DEBUG if within_credential_chain.get() else logging.WARNING, - "%s.get_token failed", + "%s.%s failed", self.__class__.__name__, + base_method_name, exc_info=_LOGGER.isEnabledFor(logging.DEBUG), ) raise CredentialUnavailableError(self.get_unavailable_message()) except msal.ManagedIdentityError as ex: _LOGGER.log( logging.DEBUG if within_credential_chain.get() else logging.WARNING, - "%s.get_token failed: %s", + "%s.%s failed: %s", self.__class__.__name__, + base_method_name, ex, exc_info=_LOGGER.isEnabledFor(logging.DEBUG), ) @@ -137,8 +184,9 @@ def get_token( except Exception as ex: # pylint:disable=broad-except _LOGGER.log( logging.DEBUG if within_credential_chain.get() else logging.WARNING, - "%s.get_token failed: %s", + "%s.%s failed: %s", self.__class__.__name__, + base_method_name, ex, exc_info=_LOGGER.isEnabledFor(logging.DEBUG), ) diff --git a/sdk/identity/azure-identity/azure/identity/_internal/shared_token_cache.py b/sdk/identity/azure-identity/azure/identity/_internal/shared_token_cache.py index 867b255a0c76..4afb4bbdef84 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/shared_token_cache.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/shared_token_cache.py @@ -9,7 +9,7 @@ from urllib.parse import urlparse import msal -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessTokenInfo from .. import CredentialUnavailableError from .._constants import KnownAuthorities from .._internal import get_default_authority, normalize_authority, wrap_exceptions @@ -228,7 +228,7 @@ def _get_account( def _get_cached_access_token( self, scopes: Iterable[str], account: CacheItem, is_cae: bool = False - ) -> Optional[AccessToken]: + ) -> Optional[AccessTokenInfo]: if "home_account_id" not in account: return None @@ -241,8 +241,9 @@ def _get_cached_access_token( ) for token in cache_entries: expires_on = int(token["expires_on"]) + refresh_on = int(token["refresh_on"]) if "refresh_on" in token else None if expires_on - 300 > int(time.time()): - return AccessToken(token["secret"], expires_on) + return AccessTokenInfo(token["secret"], expires_on, refresh_on=refresh_on) except Exception as ex: # pylint:disable=broad-except message = "Error accessing cached data: {}".format(ex) raise CredentialUnavailableError(message=message) from ex diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/application.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/application.py index 8a795b4afc2c..fa16e36a3609 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/application.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/application.py @@ -4,9 +4,10 @@ # ------------------------------------ import logging import os -from typing import Optional, Any +from typing import Optional, Any, cast -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions +from azure.core.credentials_async import AsyncSupportsTokenInfo from .chained import ChainedTokenCredential from .environment import EnvironmentCredential from .managed_identity import ManagedIdentityCredential @@ -89,3 +90,30 @@ async def get_token( return token return await super().get_token(*scopes, claims=claims, tenant_id=tenant_id, **kwargs) + + async def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptions] = None) -> AccessTokenInfo: + """Request an access token for `scopes`. + + This is an alternative to `get_token` to enable certain scenarios that require additional properties + on the token. This method is called automatically by Azure SDK clients. + + :param str scopes: desired scopes for the access token. This method requires at least one scope. + For more information about scopes, see https://learn.microsoft.com/entra/identity-platform/scopes-oidc. + :keyword options: A dictionary of options for the token request. Unknown options will be ignored. Optional. + :paramtype options: ~azure.core.credentials.TokenRequestOptions + + :rtype: AccessTokenInfo + :return: An AccessTokenInfo instance containing information about the token. + :raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The exception has a + `message` attribute listing each authentication attempt and its error message. + """ + if self._successful_credential: + token_info = await cast(AsyncSupportsTokenInfo, self._successful_credential).get_token_info( + *scopes, options=options + ) + _LOGGER.info( + "%s acquired a token from %s", self.__class__.__name__, self._successful_credential.__class__.__name__ + ) + return token_info + + return await cast(AsyncSupportsTokenInfo, super()).get_token_info(*scopes, options=options) diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/authorization_code.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/authorization_code.py index aa2ed7d6d87a..c075cd300553 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/authorization_code.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/authorization_code.py @@ -5,7 +5,7 @@ from typing import Optional, Any, cast from azure.core.exceptions import ClientAuthenticationError -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions from .._internal import AadClient, AsyncContextManager from .._internal.get_token_mixin import GetTokenMixin @@ -96,10 +96,35 @@ async def get_token( *scopes, claims=claims, tenant_id=tenant_id, client_secret=self._client_secret, **kwargs ) - async def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[AccessToken]: + async def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptions] = None) -> AccessTokenInfo: + """Request an access token for `scopes`. + + This is an alternative to `get_token` to enable certain scenarios that require additional properties + on the token. This method is called automatically by Azure SDK clients. + + The first time this method is called, the credential will redeem its authorization code. On subsequent calls + the credential will return a cached access token or redeem a refresh token, if it acquired a refresh token upon + redeeming the authorization code. + + :param str scopes: desired scopes for the access token. This method requires at least one scope. + For more information about scopes, see https://learn.microsoft.com/entra/identity-platform/scopes-oidc. + :keyword options: A dictionary of options for the token request. Unknown options will be ignored. Optional. + :paramtype options: ~azure.core.credentials.TokenRequestOptions + + :rtype: AccessTokenInfo + :return: An AccessTokenInfo instance containing information about the token. + :raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The error's ``message`` + attribute gives a reason. Any error response from Microsoft Entra ID is available as the error's + ``response`` attribute. + """ + return await super()._get_token_base( + *scopes, options=options, client_secret=self._client_secret, base_method_name="get_token_info" + ) + + async def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[AccessTokenInfo]: return self._client.get_cached_access_token(scopes, **kwargs) - async def _request_token(self, *scopes: str, **kwargs: Any) -> AccessToken: + async def _request_token(self, *scopes: str, **kwargs: Any) -> AccessTokenInfo: if self._authorization_code: token = await self._client.obtain_token_by_authorization_code( scopes=scopes, code=self._authorization_code, redirect_uri=self._redirect_uri, **kwargs @@ -107,7 +132,7 @@ async def _request_token(self, *scopes: str, **kwargs: Any) -> AccessToken: self._authorization_code = None # auth codes are single-use return token - token = cast(AccessToken, None) + token = cast(AccessTokenInfo, None) for refresh_token in self._client.get_cached_refresh_tokens(scopes): if "secret" in refresh_token: token = await self._client.obtain_token_by_refresh_token(scopes, refresh_token["secret"], **kwargs) diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/azd_cli.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/azd_cli.py index 19b1fda2c999..eafeb5affd47 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/azd_cli.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/azd_cli.py @@ -9,7 +9,7 @@ from typing import Any, List, Optional from azure.core.exceptions import ClientAuthenticationError -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions from .._internal import AsyncContextManager from .._internal.decorators import log_get_token_async from ... import CredentialUnavailableError @@ -108,9 +108,46 @@ async def get_token( if sys.platform.startswith("win") and not isinstance(asyncio.get_event_loop(), asyncio.ProactorEventLoop): return _SyncAzureDeveloperCliCredential().get_token(*scopes, tenant_id=tenant_id, **kwargs) + options: TokenRequestOptions = {} + if tenant_id: + options["tenant_id"] = tenant_id + + token_info = await self._get_token_base(*scopes, options=options, **kwargs) + return AccessToken(token_info.token, token_info.expires_on) + + @log_get_token_async + async def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptions] = None) -> AccessTokenInfo: + """Request an access token for `scopes`. + + This is an alternative to `get_token` to enable certain scenarios that require additional properties + on the token. This method is called automatically by Azure SDK clients. Applications calling this method + directly must also handle token caching because this credential doesn't cache the tokens it acquires. + + :param str scopes: desired scopes for the access token. This method requires at least one scope. + For more information about scopes, see https://learn.microsoft.com/entra/identity-platform/scopes-oidc. + :keyword options: A dictionary of options for the token request. Unknown options will be ignored. Optional. + :paramtype options: ~azure.core.credentials.TokenRequestOptions + + :rtype: AccessTokenInfo + :return: An AccessTokenInfo instance containing information about the token. + + :raises ~azure.identity.CredentialUnavailableError: the credential was unable to invoke + the Azure Developer CLI. + :raises ~azure.core.exceptions.ClientAuthenticationError: the credential invoked + the Azure Developer CLI but didn't receive an access token. + """ + # only ProactorEventLoop supports subprocesses on Windows (and it isn't the default loop on Python < 3.8) + if sys.platform.startswith("win") and not isinstance(asyncio.get_event_loop(), asyncio.ProactorEventLoop): + return _SyncAzureDeveloperCliCredential().get_token_info(*scopes, options=options) + return await self._get_token_base(*scopes, options=options) + + async def _get_token_base( + self, *scopes: str, options: Optional[TokenRequestOptions] = None, **kwargs: Any + ) -> AccessTokenInfo: if not scopes: raise ValueError("Missing scope in request. \n") + tenant_id = options.get("tenant_id") if options else None if tenant_id: validate_tenant_id(tenant_id) for scope in scopes: diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/azure_cli.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/azure_cli.py index dec0fcc690b1..62f4f23e478c 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/azure_cli.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/azure_cli.py @@ -9,7 +9,7 @@ from typing import Any, List, Optional from azure.core.exceptions import ClientAuthenticationError -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions from .._internal import AsyncContextManager from .._internal.decorators import log_get_token_async from ... import CredentialUnavailableError @@ -89,6 +89,42 @@ async def get_token( if sys.platform.startswith("win") and not isinstance(asyncio.get_event_loop(), asyncio.ProactorEventLoop): return _SyncAzureCliCredential().get_token(*scopes, tenant_id=tenant_id, **kwargs) + options: TokenRequestOptions = {} + if tenant_id: + options["tenant_id"] = tenant_id + + token_info = await self._get_token_base(*scopes, options=options, **kwargs) + return AccessToken(token_info.token, token_info.expires_on) + + @log_get_token_async + async def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptions] = None) -> AccessTokenInfo: + """Request an access token for `scopes`. + + This is an alternative to `get_token` to enable certain scenarios that require additional properties + on the token. This method is called automatically by Azure SDK clients. Applications calling this method + directly must also handle token caching because this credential doesn't cache the tokens it acquires. + + :param str scopes: desired scopes for the access token. This credential allows only one scope per request. + For more information about scopes, see https://learn.microsoft.com/entra/identity-platform/scopes-oidc. + :keyword options: A dictionary of options for the token request. Unknown options will be ignored. Optional. + :paramtype options: ~azure.core.credentials.TokenRequestOptions + + :rtype: AccessTokenInfo + :return: An AccessTokenInfo instance containing information about the token. + + :raises ~azure.identity.CredentialUnavailableError: the credential was unable to invoke the Azure CLI. + :raises ~azure.core.exceptions.ClientAuthenticationError: the credential invoked the Azure CLI but didn't + receive an access token. + """ + # only ProactorEventLoop supports subprocesses on Windows (and it isn't the default loop on Python < 3.8) + if sys.platform.startswith("win") and not isinstance(asyncio.get_event_loop(), asyncio.ProactorEventLoop): + return _SyncAzureCliCredential().get_token_info(*scopes, options=options) + return await self._get_token_base(*scopes, options=options) + + async def _get_token_base( + self, *scopes: str, options: Optional[TokenRequestOptions] = None, **kwargs: Any + ) -> AccessTokenInfo: + tenant_id = options.get("tenant_id") if options else None if tenant_id: validate_tenant_id(tenant_id) for scope in scopes: diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/azure_pipelines.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/azure_pipelines.py index 918cae86a921..ccaa635f1dac 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/azure_pipelines.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/azure_pipelines.py @@ -5,7 +5,7 @@ from typing import Any, Optional from azure.core.exceptions import ClientAuthenticationError -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions from azure.core.rest import HttpResponse from .client_assertion import ClientAssertionCredential @@ -103,6 +103,25 @@ async def get_token( *scopes, claims=claims, tenant_id=tenant_id, enable_cae=enable_cae, **kwargs ) + async def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptions] = None) -> AccessTokenInfo: + """Request an access token for `scopes`. + + This is an alternative to `get_token` to enable certain scenarios that require additional properties + on the token. This method is called automatically by Azure SDK clients. + + :param str scopes: desired scope for the access token. This method requires at least one scope. + For more information about scopes, see https://learn.microsoft.com/entra/identity-platform/scopes-oidc. + :keyword options: A dictionary of options for the token request. Unknown options will be ignored. Optional. + :paramtype options: ~azure.core.credentials.TokenRequestOptions + + :rtype: AccessTokenInfo + :return: An AccessTokenInfo instance containing information about the token. + :raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The error's ``message`` + attribute gives a reason. + """ + validate_env_vars() + return await self._client_assertion_credential.get_token_info(*scopes, options=options) + def _get_oidc_token(self) -> str: request = build_oidc_request(self._service_connection_id, self._system_access_token) response = self._pipeline.run(request, retry_on_methods=[request.method]) diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/azure_powershell.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/azure_powershell.py index 6f7e0ffc5fa6..f9117f2e66c3 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/azure_powershell.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/azure_powershell.py @@ -5,7 +5,7 @@ import asyncio import sys from typing import Any, cast, List, Optional -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions from .._internal import AsyncContextManager from .._internal.decorators import log_get_token_async @@ -84,6 +84,42 @@ async def get_token( if sys.platform.startswith("win") and not isinstance(asyncio.get_event_loop(), asyncio.ProactorEventLoop): return _SyncCredential().get_token(*scopes, tenant_id=tenant_id, **kwargs) + options: TokenRequestOptions = {} + if tenant_id: + options["tenant_id"] = tenant_id + + token_info = await self._get_token_base(*scopes, options=options, **kwargs) + return AccessToken(token_info.token, token_info.expires_on) + + @log_get_token_async + async def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptions] = None) -> AccessTokenInfo: + """Request an access token for `scopes`. + + This is an alternative to `get_token` to enable certain scenarios that require additional properties + on the token. This method is called automatically by Azure SDK clients. Applications calling this method + directly must also handle token caching because this credential doesn't cache the tokens it acquires. + + :param str scopes: desired scopes for the access token. TThis credential allows only one scope per request. + For more information about scopes, see https://learn.microsoft.com/entra/identity-platform/scopes-oidc. + :keyword options: A dictionary of options for the token request. Unknown options will be ignored. Optional. + :paramtype options: ~azure.core.credentials.TokenRequestOptions + + :rtype: AccessTokenInfo + :return: An AccessTokenInfo instance containing information about the token. + + :raises ~azure.identity.CredentialUnavailableError: the credential was unable to invoke Azure PowerShell, or + no account is authenticated + :raises ~azure.core.exceptions.ClientAuthenticationError: the credential invoked Azure PowerShell but didn't + receive an access token + """ + if sys.platform.startswith("win") and not isinstance(asyncio.get_event_loop(), asyncio.ProactorEventLoop): + return _SyncCredential().get_token_info(*scopes, options=options) + return await self._get_token_base(*scopes, options=options) + + async def _get_token_base( + self, *scopes: str, options: Optional[TokenRequestOptions] = None, **kwargs: Any + ) -> AccessTokenInfo: + tenant_id = options.get("tenant_id") if options else None if tenant_id: validate_tenant_id(tenant_id) for scope in scopes: diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/certificate.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/certificate.py index 67a2fefbc3b7..8d2ad9ae12f2 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/certificate.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/certificate.py @@ -4,7 +4,7 @@ # ------------------------------------ from typing import Optional, Any -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessTokenInfo from .._internal import AadClient, AsyncContextManager from .._internal.get_token_mixin import GetTokenMixin from ..._credentials.certificate import get_client_credential @@ -70,8 +70,8 @@ async def close(self) -> None: await self._client.__aexit__() - async def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[AccessToken]: + async def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[AccessTokenInfo]: return self._client.get_cached_access_token(scopes, **kwargs) - async def _request_token(self, *scopes: str, **kwargs: Any) -> AccessToken: + async def _request_token(self, *scopes: str, **kwargs: Any) -> AccessTokenInfo: return await self._client.obtain_token_by_client_certificate(scopes, self._certificate, **kwargs) diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/chained.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/chained.py index e6cb50744fd9..80d25746ad67 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/chained.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/chained.py @@ -4,10 +4,11 @@ # ------------------------------------ import asyncio import logging -from typing import Any, Optional, TYPE_CHECKING +from typing import Any, Optional, TYPE_CHECKING, cast from azure.core.exceptions import ClientAuthenticationError -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions +from azure.core.credentials_async import AsyncSupportsTokenInfo from .._internal import AsyncContextManager from ... import CredentialUnavailableError from ..._credentials.chained import _get_error_message @@ -42,7 +43,7 @@ def __init__(self, *credentials: "AsyncTokenCredential") -> None: if not credentials: raise ValueError("at least one credential is required") - self._successful_credential = None # type: Optional[AsyncTokenCredential] + self._successful_credential: Optional[AsyncTokenCredential] = None self.credentials = credentials async def close(self) -> None: @@ -105,3 +106,64 @@ async def get_token( "https://aka.ms/azsdk/python/identity/defaultazurecredential/troubleshoot." ) raise ClientAuthenticationError(message=message) + + async def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptions] = None) -> AccessTokenInfo: + """Request a token from each chained credential, in order, returning the first token received. + + If no credential provides a token, raises :class:`azure.core.exceptions.ClientAuthenticationError` + with an error message from each credential. + + This is an alternative to `get_token` to enable certain scenarios that require additional properties + on the token. This method is called automatically by Azure SDK clients. + + :param str scopes: desired scopes for the access token. This method requires at least one scope. + For more information about scopes, see https://learn.microsoft.com/entra/identity-platform/scopes-oidc. + :keyword options: A dictionary of options for the token request. Unknown options will be ignored. Optional. + :paramtype options: ~azure.core.credentials.TokenRequestOptions + + :rtype: AccessTokenInfo + :return: An AccessTokenInfo instance containing information about the token. + + :raises ~azure.core.exceptions.ClientAuthenticationError: no credential in the chain provided a token. + """ + within_credential_chain.set(True) + history = [] + for credential in self.credentials: + try: + # A custom credential in the chain may not implement get_token_info + if hasattr(credential, "get_token_info"): + token_info = await cast(AsyncSupportsTokenInfo, credential).get_token_info(*scopes, options=options) + else: + options = options or {} + token = await credential.get_token(*scopes, **options) + token_info = AccessTokenInfo(token=token.token, expires_on=token.expires_on) + _LOGGER.info("%s acquired a token from %s", self.__class__.__name__, credential.__class__.__name__) + self._successful_credential = credential + within_credential_chain.set(False) + return token_info + except CredentialUnavailableError as ex: + # credential didn't attempt authentication because it lacks required data or state -> continue + history.append((credential, ex.message)) + except Exception as ex: # pylint: disable=broad-except + # credential failed to authenticate, or something unexpectedly raised -> break + history.append((credential, str(ex))) + _LOGGER.debug( + '%s.get_token_info failed: %s raised unexpected error "%s"', + self.__class__.__name__, + credential.__class__.__name__, + ex, + exc_info=True, + ) + break + + within_credential_chain.set(False) + attempts = _get_error_message(history) + message = ( + self.__class__.__name__ + + " failed to retrieve a token from the included credentials." + + attempts + + "\nTo mitigate this issue, please refer to the troubleshooting guidelines here at " + "https://aka.ms/azsdk/python/identity/defaultazurecredential/troubleshoot." + ) + _LOGGER.warning(message) + raise ClientAuthenticationError(message=message) diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/client_assertion.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/client_assertion.py index 64150cfdcd44..a316760455e1 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/client_assertion.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/client_assertion.py @@ -4,7 +4,7 @@ # ------------------------------------ from typing import Any, Callable, Optional -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessTokenInfo from .._internal import AadClient, AsyncContextManager from .._internal.get_token_mixin import GetTokenMixin @@ -66,10 +66,10 @@ async def close(self) -> None: """Close the credential's transport session.""" await self._client.close() - async def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[AccessToken]: + async def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[AccessTokenInfo]: return self._client.get_cached_access_token(scopes, **kwargs) - async def _request_token(self, *scopes: str, **kwargs: Any) -> AccessToken: + async def _request_token(self, *scopes: str, **kwargs: Any) -> AccessTokenInfo: assertion = self._func() token = await self._client.obtain_token_by_jwt_assertion(scopes, assertion, **kwargs) return token diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/client_secret.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/client_secret.py index f5495b17da6b..50bbb3de9315 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/client_secret.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/client_secret.py @@ -4,7 +4,7 @@ # ------------------------------------ from typing import Optional, Any -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessTokenInfo from .._internal import AadClient, AsyncContextManager from .._internal.get_token_mixin import GetTokenMixin from ..._internal import validate_tenant_id @@ -60,8 +60,8 @@ async def close(self) -> None: await self._client.__aexit__() - async def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[AccessToken]: + async def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[AccessTokenInfo]: return self._client.get_cached_access_token(scopes, **kwargs) - async def _request_token(self, *scopes: str, **kwargs: Any) -> AccessToken: + async def _request_token(self, *scopes: str, **kwargs: Any) -> AccessTokenInfo: return await self._client.obtain_token_by_client_secret(scopes, self._secret, **kwargs) diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/default.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/default.py index bd8672a2c7cc..ded986453685 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/default.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/default.py @@ -4,9 +4,10 @@ # ------------------------------------ import logging import os -from typing import List, Optional, TYPE_CHECKING, Any, cast +from typing import List, Optional, Any, cast -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions +from azure.core.credentials_async import AsyncTokenCredential, AsyncSupportsTokenInfo from ..._constants import EnvironmentVariables from ..._internal import get_default_authority, normalize_authority, within_dac from .azure_cli import AzureCliCredential @@ -19,8 +20,6 @@ from .vscode import VisualStudioCodeCredential from .workload_identity import WorkloadIdentityCredential -if TYPE_CHECKING: - from azure.core.credentials_async import AsyncTokenCredential _LOGGER = logging.getLogger(__name__) @@ -134,7 +133,7 @@ def __init__(self, **kwargs: Any) -> None: # pylint: disable=too-many-statement exclude_shared_token_cache_credential = kwargs.pop("exclude_shared_token_cache_credential", False) exclude_powershell_credential = kwargs.pop("exclude_powershell_credential", False) - credentials = [] # type: List[AsyncTokenCredential] + credentials: List[AsyncTokenCredential] = [] within_dac.set(True) if not exclude_environment_credential: credentials.append(EnvironmentCredential(authority=authority, _within_dac=True, **kwargs)) @@ -197,8 +196,44 @@ async def get_token( `message` attribute listing each authentication attempt and its error message. """ if self._successful_credential: - return await self._successful_credential.get_token(*scopes, claims=claims, tenant_id=tenant_id, **kwargs) + token = await self._successful_credential.get_token(*scopes, claims=claims, tenant_id=tenant_id, **kwargs) + _LOGGER.info( + "%s acquired a token from %s", self.__class__.__name__, self._successful_credential.__class__.__name__ + ) + return token + within_dac.set(True) token = await super().get_token(*scopes, claims=claims, tenant_id=tenant_id, **kwargs) within_dac.set(False) return token + + async def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptions] = None) -> AccessTokenInfo: + """Asynchronously request an access token for `scopes`. + + This is an alternative to `get_token` to enable certain scenarios that require additional properties + on the token. This method is called automatically by Azure SDK clients. + + :param str scopes: desired scopes for the access token. This method requires at least one scope. + For more information about scopes, see https://learn.microsoft.com/entra/identity-platform/scopes-oidc. + :keyword options: A dictionary of options for the token request. Unknown options will be ignored. Optional. + :paramtype options: ~azure.core.credentials.TokenRequestOptions + + :rtype: AccessTokenInfo + :return: An AccessTokenInfo instance containing information about the token. + + :raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The exception has a + `message` attribute listing each authentication attempt and its error message. + """ + if self._successful_credential: + token_info = await cast(AsyncSupportsTokenInfo, self._successful_credential).get_token_info( + *scopes, options=options + ) + _LOGGER.info( + "%s acquired a token from %s", self.__class__.__name__, self._successful_credential.__class__.__name__ + ) + return token_info + + within_dac.set(True) + token_info = await cast(AsyncSupportsTokenInfo, super()).get_token_info(*scopes, options=options) + within_dac.set(False) + return token_info diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/environment.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/environment.py index 146c91edd4b0..bb981406d2a2 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/environment.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/environment.py @@ -4,9 +4,10 @@ # ------------------------------------ import logging import os -from typing import Optional, Union, Any +from typing import Optional, Union, Any, cast -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions +from azure.core.credentials_async import AsyncSupportsTokenInfo from .._internal.decorators import log_get_token_async from ... import CredentialUnavailableError from ..._constants import EnvironmentVariables @@ -124,3 +125,29 @@ async def get_token( ) raise CredentialUnavailableError(message=message) return await self._credential.get_token(*scopes, claims=claims, tenant_id=tenant_id, **kwargs) + + @log_get_token_async + async def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptions] = None) -> AccessTokenInfo: + """Request an access token for `scopes`. + + This is an alternative to `get_token` to enable certain scenarios that require additional properties + on the token. This method is called automatically by Azure SDK clients. + + :param str scopes: desired scope for the access token. This method requires at least one scope. + For more information about scopes, see https://learn.microsoft.com/entra/identity-platform/scopes-oidc. + :keyword options: A dictionary of options for the token request. Unknown options will be ignored. Optional. + :paramtype options: ~azure.core.credentials.TokenRequestOptions + + :rtype: AccessTokenInfo + :return: An AccessTokenInfo instance containing information about the token. + + :raises ~azure.identity.CredentialUnavailableError: environment variable configuration is incomplete. + """ + if not self._credential: + message = ( + "EnvironmentCredential authentication unavailable. Environment variables are not fully configured.\n" + "Visit https://aka.ms/azsdk/python/identity/environmentcredential/troubleshoot to troubleshoot " + "this issue." + ) + raise CredentialUnavailableError(message=message) + return await cast(AsyncSupportsTokenInfo, self._credential).get_token_info(*scopes, options=options) diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/imds.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/imds.py index 0f667be030e4..f9286c9f88f6 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/imds.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/imds.py @@ -6,7 +6,7 @@ from typing import Optional, Any from azure.core.exceptions import ClientAuthenticationError, HttpResponseError -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessTokenInfo from ... import CredentialUnavailableError from ..._constants import EnvironmentVariables from .._internal import AsyncContextManager @@ -34,10 +34,10 @@ async def __aenter__(self) -> "ImdsCredential": async def close(self) -> None: await self._client.close() - async def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[AccessToken]: + async def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[AccessTokenInfo]: return self._client.get_cached_token(*scopes) - async def _request_token(self, *scopes: str, **kwargs: Any) -> AccessToken: # pylint:disable=unused-argument + async def _request_token(self, *scopes: str, **kwargs: Any) -> AccessTokenInfo: # pylint:disable=unused-argument if within_credential_chain.get() and not self._endpoint_available: # If within a chain (e.g. DefaultAzureCredential), we do a quick check to see if the IMDS endpoint @@ -56,7 +56,7 @@ async def _request_token(self, *scopes: str, **kwargs: Any) -> AccessToken: # p raise CredentialUnavailableError(message=error_message) from ex try: - token = await self._client.request_token(*scopes, headers={"Metadata": "true"}) + token_info = await self._client.request_token(*scopes, headers={"Metadata": "true"}) except CredentialUnavailableError: # Response is not json, skip the IMDS credential raise @@ -82,4 +82,4 @@ async def _request_token(self, *scopes: str, **kwargs: Any) -> AccessToken: # p # if anything else was raised, assume the endpoint is unavailable error_message = "ManagedIdentityCredential authentication unavailable, no response from the IMDS endpoint." raise CredentialUnavailableError(error_message) from ex - return token + return token_info diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/managed_identity.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/managed_identity.py index 0362886712ed..6b8ced00219b 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/managed_identity.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/managed_identity.py @@ -4,9 +4,10 @@ # ------------------------------------ import logging import os -from typing import TYPE_CHECKING, Optional, Any, Mapping +from typing import TYPE_CHECKING, Optional, Any, Mapping, cast -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions +from azure.core.credentials_async import AsyncSupportsTokenInfo from .._internal import AsyncContextManager from .._internal.decorators import log_get_token_async from ... import CredentialUnavailableError @@ -142,3 +143,28 @@ async def get_token( "troubleshoot this issue." ) return await self._credential.get_token(*scopes, claims=claims, tenant_id=tenant_id, **kwargs) + + @log_get_token_async + async def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptions] = None) -> AccessTokenInfo: + """Request an access token for `scopes`. + + This is an alternative to `get_token` to enable certain scenarios that require additional properties + on the token. This method is called automatically by Azure SDK clients. + + :param str scopes: desired scope for the access token. This credential allows only one scope per request. + For more information about scopes, see https://learn.microsoft.com/entra/identity-platform/scopes-oidc. + :keyword options: A dictionary of options for the token request. Unknown options will be ignored. Optional. + :paramtype options: ~azure.core.credentials.TokenRequestOptions + + :rtype: AccessTokenInfo + :return: An AccessTokenInfo instance containing information about the token. + :raises ~azure.identity.CredentialUnavailableError: managed identity isn't available in the hosting environment. + """ + if not self._credential: + raise CredentialUnavailableError( + message="No managed identity endpoint found. \n" + "The Target Azure platform could not be determined from environment variables. \n" + "Visit https://aka.ms/azsdk/python/identity/managedidentitycredential/troubleshoot to " + "troubleshoot this issue." + ) + return await cast(AsyncSupportsTokenInfo, self._credential).get_token_info(*scopes, options=options) diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/on_behalf_of.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/on_behalf_of.py index d807db72c948..102db030f63a 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/on_behalf_of.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/on_behalf_of.py @@ -6,7 +6,7 @@ from typing import Optional, Union, Any, Dict, Callable from azure.core.exceptions import ClientAuthenticationError -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessTokenInfo from .._internal import AadClient, AsyncContextManager from .._internal.get_token_mixin import GetTokenMixin from ..._credentials.certificate import get_client_credential @@ -111,10 +111,10 @@ async def __aenter__(self) -> "OnBehalfOfCredential": async def close(self) -> None: await self._client.close() - async def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[AccessToken]: + async def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[AccessTokenInfo]: return self._client.get_cached_access_token(scopes, **kwargs) - async def _request_token(self, *scopes: str, **kwargs: Any) -> AccessToken: + async def _request_token(self, *scopes: str, **kwargs: Any) -> AccessTokenInfo: # Note we assume the cache has tokens for one user only. That's okay because each instance of this class is # locked to a single user (assertion). This assumption will become unsafe if this class allows applications # to change an instance's assertion. diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/shared_cache.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/shared_cache.py index 395ccb009a7e..e710013ed1b4 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/shared_cache.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/shared_cache.py @@ -2,8 +2,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ -from typing import Any, Optional -from azure.core.credentials import AccessToken +from typing import Any, Optional, cast +from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions from ..._internal.aad_client import AadClientBase from ... import CredentialUnavailableError from ..._constants import DEVELOPER_SIGN_ON_CLIENT_ID @@ -48,7 +48,7 @@ async def get_token( claims: Optional[str] = None, tenant_id: Optional[str] = None, enable_cae: bool = False, - **kwargs: Any + **kwargs: Any, ) -> AccessToken: """Get an access token for `scopes` from the shared cache. @@ -73,13 +73,58 @@ async def get_token( attribute gives a reason. Any error response from Microsoft Entra ID is available as the error's ``response`` attribute. """ + options: TokenRequestOptions = {} + if claims: + options["claims"] = claims + if tenant_id: + options["tenant_id"] = tenant_id + options["enable_cae"] = enable_cae + + token_info = await self._get_token_base(*scopes, options=options, base_method_name="get_token", **kwargs) + return AccessToken(token_info.token, token_info.expires_on) + + @log_get_token_async + async def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptions] = None) -> AccessTokenInfo: + """Get an access token for `scopes` from the shared cache. + + If no access token is cached, attempt to acquire one using a cached refresh token. + + This is an alternative to `get_token` to enable certain scenarios that require additional properties + on the token. This method is called automatically by Azure SDK clients. + + :param str scopes: desired scope for the access token. This method requires at least one scope. + For more information about scopes, see https://learn.microsoft.com/entra/identity-platform/scopes-oidc. + :keyword options: A dictionary of options for the token request. Unknown options will be ignored. Optional. + :paramtype options: ~azure.core.credentials.TokenRequestOptions + + :rtype: AccessTokenInfo + :return: An AccessTokenInfo instance containing information about the token. + :raises ~azure.identity.CredentialUnavailableError: the cache is unavailable or contains insufficient user + information + :raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The error's ``message`` + attribute gives a reason. Any error response from Microsoft Entra ID is available as the error's + ``response`` attribute. + """ + return await self._get_token_base(*scopes, options=options, base_method_name="get_token_info") + + async def _get_token_base( + self, + *scopes: str, + options: Optional[TokenRequestOptions] = None, + base_method_name: str = "get_token", + **kwargs: Any, + ) -> AccessTokenInfo: if not scopes: - raise ValueError("'get_token' requires at least one scope") + raise ValueError(f"'{base_method_name}' requires at least one scope") if not self._client_initialized: self._initialize_client() - is_cae = enable_cae + options = options or {} + claims = options.get("claims") + tenant_id = options.get("tenant_id") + is_cae = options.get("enable_cae", False) + token_cache = self._cae_cache if is_cae else self._cache # Try to load the cache if it is None. @@ -98,8 +143,8 @@ async def get_token( # try each refresh token, returning the first access token acquired for refresh_token in self._get_refresh_tokens(account, is_cae=is_cae): - token = await self._client.obtain_token_by_refresh_token( - scopes, refresh_token, claims=claims, tenant_id=tenant_id, **kwargs + token = await cast(AadClient, self._client).obtain_token_by_refresh_token( + scopes, refresh_token, claims=claims, tenant_id=tenant_id, enable_cae=is_cae, **kwargs ) return token diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/vscode.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/vscode.py index 35eda5eb547a..9451cc45f01e 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/vscode.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/vscode.py @@ -4,7 +4,7 @@ # ------------------------------------ from typing import cast, Optional, Any -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions from azure.core.exceptions import ClientAuthenticationError from ..._exceptions import CredentialUnavailableError from .._internal import AsyncContextManager @@ -83,11 +83,42 @@ async def get_token( raise CredentialUnavailableError(message=ex.message) from ex return await super().get_token(*scopes, claims=claims, tenant_id=tenant_id, **kwargs) - async def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[AccessToken]: + async def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptions] = None) -> AccessTokenInfo: + """Request an access token for `scopes` as the user currently signed in to Visual Studio Code. + + This is an alternative to `get_token` to enable certain scenarios that require additional properties + on the token. This method is called automatically by Azure SDK clients. + + :param str scopes: desired scopes for the access token. This method requires at least one scope. + For more information about scopes, see https://learn.microsoft.com/entra/identity-platform/scopes-oidc. + :keyword options: A dictionary of options for the token request. Unknown options will be ignored. Optional. + :paramtype options: ~azure.core.credentials.TokenRequestOptions + + :rtype: AccessTokenInfo + :return: An AccessTokenInfo instance containing information about the token. + :raises ~azure.identity.CredentialUnavailableError: the credential cannot retrieve user details from Visual + Studio Code. + """ + if self._unavailable_reason: + error_message = ( + self._unavailable_reason + "\n" + "Visit https://aka.ms/azsdk/python/identity/vscodecredential/troubleshoot" + " to troubleshoot this issue." + ) + raise CredentialUnavailableError(message=error_message) + if within_dac.get(): + try: + token = await super().get_token_info(*scopes, options=options) + return token + except ClientAuthenticationError as ex: + raise CredentialUnavailableError(message=ex.message) from ex + return await super().get_token_info(*scopes, options=options) + + async def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[AccessTokenInfo]: self._client = cast(AadClient, self._client) return self._client.get_cached_access_token(scopes, **kwargs) - async def _request_token(self, *scopes: str, **kwargs: Any) -> AccessToken: + async def _request_token(self, *scopes: str, **kwargs: Any) -> AccessTokenInfo: refresh_token = self._get_refresh_token() self._client = cast(AadClient, self._client) return await self._client.obtain_token_by_refresh_token(scopes, refresh_token, **kwargs) diff --git a/sdk/identity/azure-identity/azure/identity/aio/_internal/aad_client.py b/sdk/identity/azure-identity/azure/identity/aio/_internal/aad_client.py index 1a6bb03cc19b..7b99f85ac912 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_internal/aad_client.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_internal/aad_client.py @@ -5,7 +5,7 @@ import time from typing import Iterable, Optional, Union, Dict, Any -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessTokenInfo from azure.core.pipeline import AsyncPipeline from azure.core.pipeline.policies import AsyncHTTPPolicy, SansIOHTTPPolicy from azure.core.pipeline.transport import HttpRequest @@ -32,7 +32,7 @@ async def close(self) -> None: async def obtain_token_by_authorization_code( self, scopes: Iterable[str], code: str, redirect_uri: str, client_secret: Optional[str] = None, **kwargs - ) -> AccessToken: + ) -> AccessTokenInfo: request = self._get_auth_code_request( scopes=scopes, code=code, redirect_uri=redirect_uri, client_secret=client_secret, **kwargs ) @@ -40,19 +40,21 @@ async def obtain_token_by_authorization_code( async def obtain_token_by_client_certificate( self, scopes: Iterable[str], certificate: AadClientCertificate, **kwargs - ) -> AccessToken: + ) -> AccessTokenInfo: request = self._get_client_certificate_request(scopes, certificate, **kwargs) return await self._run_pipeline(request, stream=False, **kwargs) - async def obtain_token_by_client_secret(self, scopes: Iterable[str], secret: str, **kwargs) -> AccessToken: + async def obtain_token_by_client_secret(self, scopes: Iterable[str], secret: str, **kwargs) -> AccessTokenInfo: request = self._get_client_secret_request(scopes, secret, **kwargs) return await self._run_pipeline(request, **kwargs) - async def obtain_token_by_jwt_assertion(self, scopes: Iterable[str], assertion: str, **kwargs) -> AccessToken: + async def obtain_token_by_jwt_assertion(self, scopes: Iterable[str], assertion: str, **kwargs) -> AccessTokenInfo: request = self._get_jwt_assertion_request(scopes, assertion, **kwargs) return await self._run_pipeline(request, stream=False, **kwargs) - async def obtain_token_by_refresh_token(self, scopes: Iterable[str], refresh_token: str, **kwargs) -> AccessToken: + async def obtain_token_by_refresh_token( + self, scopes: Iterable[str], refresh_token: str, **kwargs + ) -> AccessTokenInfo: request = self._get_refresh_token_request(scopes, refresh_token, **kwargs) return await self._run_pipeline(request, **kwargs) @@ -62,7 +64,7 @@ async def obtain_token_by_refresh_token_on_behalf_of( # pylint: disable=name-to client_credential: Union[str, AadClientCertificate, Dict[str, Any]], refresh_token: str, **kwargs - ) -> AccessToken: + ) -> AccessTokenInfo: request = self._get_refresh_token_on_behalf_of_request( scopes, client_credential=client_credential, refresh_token=refresh_token, **kwargs ) @@ -74,7 +76,7 @@ async def obtain_token_on_behalf_of( client_credential: Union[str, AadClientCertificate, Dict[str, Any]], user_assertion: str, **kwargs - ) -> AccessToken: + ) -> AccessTokenInfo: request = self._get_on_behalf_of_request( scopes=scopes, client_credential=client_credential, user_assertion=user_assertion, **kwargs ) @@ -83,7 +85,7 @@ async def obtain_token_on_behalf_of( def _build_pipeline(self, **kwargs) -> AsyncPipeline: return build_async_pipeline(**kwargs) - async def _run_pipeline(self, request: HttpRequest, **kwargs) -> AccessToken: + async def _run_pipeline(self, request: HttpRequest, **kwargs) -> AccessTokenInfo: # remove tenant_id and claims kwarg that could have been passed from credential's get_token method # tenant_id is already part of `request` at this point kwargs.pop("tenant_id", None) diff --git a/sdk/identity/azure-identity/azure/identity/aio/_internal/get_token_mixin.py b/sdk/identity/azure-identity/azure/identity/aio/_internal/get_token_mixin.py index d9ceaf4e8612..24bf36ef8811 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_internal/get_token_mixin.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_internal/get_token_mixin.py @@ -7,7 +7,7 @@ import time from typing import Any, Optional -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions from ..._constants import DEFAULT_REFRESH_OFFSET, DEFAULT_TOKEN_REFRESH_RETRY_DELAY from ..._internal import within_credential_chain @@ -22,7 +22,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: super(GetTokenMixin, self).__init__(*args, **kwargs) # type: ignore @abc.abstractmethod - async def _acquire_token_silently(self, *scopes: str, **kwargs) -> Optional[AccessToken]: + async def _acquire_token_silently(self, *scopes: str, **kwargs) -> Optional[AccessTokenInfo]: """Attempt to acquire an access token from a cache or by redeeming a refresh token. :param str scopes: desired scopes for the access token. This method requires at least one scope. @@ -30,11 +30,11 @@ async def _acquire_token_silently(self, *scopes: str, **kwargs) -> Optional[Acce https://learn.microsoft.com/entra/identity-platform/scopes-oidc. :return: An access token with the desired scopes if successful; otherwise, None. - :rtype: ~azure.core.credentials.AccessToken or None + :rtype: ~azure.core.credentials.AccessTokenInfo or None """ @abc.abstractmethod - async def _request_token(self, *scopes: str, **kwargs) -> AccessToken: + async def _request_token(self, *scopes: str, **kwargs) -> AccessTokenInfo: """Request an access token from the STS. :param str scopes: desired scopes for the access token. This method requires at least one scope. @@ -42,11 +42,13 @@ async def _request_token(self, *scopes: str, **kwargs) -> AccessToken: https://learn.microsoft.com/entra/identity-platform/scopes-oidc. :return: An access token with the desired scopes. - :rtype: ~azure.core.credentials.AccessToken + :rtype: ~azure.core.credentials.AccessTokenInfo """ - def _should_refresh(self, token: AccessToken) -> bool: + def _should_refresh(self, token: AccessTokenInfo) -> bool: now = int(time.time()) + if token.refresh_on is not None and now >= token.refresh_on: + return True if token.expires_on - now > DEFAULT_REFRESH_OFFSET: return False if now - self._last_request_time < DEFAULT_TOKEN_REFRESH_RETRY_DELAY: @@ -59,7 +61,7 @@ async def get_token( claims: Optional[str] = None, tenant_id: Optional[str] = None, enable_cae: bool = False, - **kwargs: Any + **kwargs: Any, ) -> AccessToken: """Request an access token for `scopes`. @@ -81,8 +83,50 @@ async def get_token( :raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The error's ``message`` attribute gives a reason. """ + options: TokenRequestOptions = {} + if claims: + options["claims"] = claims + if tenant_id: + options["tenant_id"] = tenant_id + options["enable_cae"] = enable_cae + + token_info = await self._get_token_base(*scopes, options=options, base_method_name="get_token", **kwargs) + return AccessToken(token_info.token, token_info.expires_on) + + async def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptions] = None) -> AccessTokenInfo: + """Request an access token for `scopes`. + + This is an alternative to `get_token` to enable certain scenarios that require additional properties + on the token. This method is called automatically by Azure SDK clients. + + :param str scopes: desired scopes for the access token. This method requires at least one scope. + For more information about scopes, see https://learn.microsoft.com/entra/identity-platform/scopes-oidc. + :keyword options: A dictionary of options for the token request. Unknown options will be ignored. Optional. + :paramtype options: ~azure.core.credentials.TokenRequestOptions + + :rtype: AccessTokenInfo + :return: An AccessTokenInfo instance containing information about the token. + :raises CredentialUnavailableError: the credential is unable to attempt authentication because it lacks + required data, state, or platform support + :raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The error's ``message`` + attribute gives a reason. + """ + return await self._get_token_base(*scopes, options=options, base_method_name="get_token_info") + + async def _get_token_base( + self, + *scopes: str, + options: Optional[TokenRequestOptions] = None, + base_method_name: str = "get_token", + **kwargs: Any, + ) -> AccessTokenInfo: if not scopes: - raise ValueError('"get_token" requires at least one scope') + raise ValueError(f'"{base_method_name}" requires at least one scope') + + options = options or {} + claims = options.get("claims") + tenant_id = options.get("tenant_id") + enable_cae = options.get("enable_cae", False) try: token = await self._acquire_token_silently( @@ -103,16 +147,18 @@ async def get_token( pass _LOGGER.log( logging.DEBUG if within_credential_chain.get() else logging.INFO, - "%s.get_token succeeded", + "%s.%s succeeded", self.__class__.__name__, + base_method_name, ) return token except Exception as ex: _LOGGER.log( logging.DEBUG if within_credential_chain.get() else logging.WARNING, - "%s.get_token failed: %s", + "%s.%s failed: %s", self.__class__.__name__, + base_method_name, ex, exc_info=_LOGGER.isEnabledFor(logging.DEBUG), ) diff --git a/sdk/identity/azure-identity/azure/identity/aio/_internal/managed_identity_base.py b/sdk/identity/azure-identity/azure/identity/aio/_internal/managed_identity_base.py index 1bebc28fb5c7..636fbbf9b2f7 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_internal/managed_identity_base.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_internal/managed_identity_base.py @@ -6,7 +6,7 @@ from types import TracebackType from typing import Any, cast, Optional, TypeVar, Type -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions from . import AsyncContextManager from .get_token_mixin import GetTokenMixin from .managed_identity_client import AsyncManagedIdentityClient @@ -54,10 +54,15 @@ async def get_token( raise CredentialUnavailableError(message=self.get_unavailable_message()) return await super().get_token(*scopes, claims=claims, tenant_id=tenant_id, **kwargs) - async def _acquire_token_silently(self, *scopes: str, **kwargs) -> Optional[AccessToken]: + async def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptions] = None) -> AccessTokenInfo: + if not self._client: + raise CredentialUnavailableError(message=self.get_unavailable_message()) + return await super().get_token_info(*scopes, options=options) + + async def _acquire_token_silently(self, *scopes: str, **kwargs) -> Optional[AccessTokenInfo]: # casting because mypy can't determine that these methods are called # only by get_token, which raises when self._client is None return cast(AsyncManagedIdentityClient, self._client).get_cached_token(*scopes) - async def _request_token(self, *scopes: str, **kwargs) -> AccessToken: + async def _request_token(self, *scopes: str, **kwargs) -> AccessTokenInfo: return await cast(AsyncManagedIdentityClient, self._client).request_token(*scopes, **kwargs) diff --git a/sdk/identity/azure-identity/azure/identity/aio/_internal/managed_identity_client.py b/sdk/identity/azure-identity/azure/identity/aio/_internal/managed_identity_client.py index 05503316d853..340bb5a11fa2 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_internal/managed_identity_client.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_internal/managed_identity_client.py @@ -5,7 +5,7 @@ import time from typing import TypeVar -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessTokenInfo from azure.core.pipeline import AsyncPipeline from .._internal import AsyncContextManager from ..._internal import _scopes_to_resource @@ -24,7 +24,7 @@ async def __aenter__(self: T) -> T: async def close(self) -> None: await self._pipeline.__aexit__() - async def request_token(self, *scopes: str, **kwargs) -> AccessToken: + async def request_token(self, *scopes: str, **kwargs) -> AccessTokenInfo: # pylint:disable=invalid-overridden-method resource = _scopes_to_resource(*scopes) request = self._request_factory(resource, self._identity_config) diff --git a/sdk/identity/azure-identity/setup.py b/sdk/identity/azure-identity/setup.py index e16249897f84..3b28cffbe847 100644 --- a/sdk/identity/azure-identity/setup.py +++ b/sdk/identity/azure-identity/setup.py @@ -61,7 +61,7 @@ install_requires=[ "azure-core>=1.23.0", "cryptography>=2.5", - "msal>=1.29.0", + "msal>=1.30.0", "msal-extensions>=1.2.0", "typing-extensions>=4.0.0", ], diff --git a/sdk/identity/azure-identity/tests/helpers.py b/sdk/identity/azure-identity/tests/helpers.py index 23378598869e..1a6543133962 100644 --- a/sdk/identity/azure-identity/tests/helpers.py +++ b/sdk/identity/azure-identity/tests/helpers.py @@ -6,15 +6,15 @@ import json import time from urllib.parse import urlparse +from unittest import mock -try: - from unittest import mock -except ImportError: # python < 3.3 - import mock # type: ignore +from azure.core.credentials import AccessToken, AccessTokenInfo FAKE_CLIENT_ID = "fake-client-id" INVALID_CHARACTERS = "|\\`;{&' " +ACCESS_TOKEN_CLASSES = (AccessToken, AccessTokenInfo) +GET_TOKEN_METHODS = ("get_token", "get_token_info") def build_id_token( diff --git a/sdk/identity/azure-identity/tests/helpers_async.py b/sdk/identity/azure-identity/tests/helpers_async.py index 308b7db75b18..2a66e167c325 100644 --- a/sdk/identity/azure-identity/tests/helpers_async.py +++ b/sdk/identity/azure-identity/tests/helpers_async.py @@ -10,15 +10,6 @@ from helpers import validating_transport -def await_test(fn): - @functools.wraps(fn) - def wrapper(*args, **kwargs): - loop = asyncio.get_event_loop() - loop.run_until_complete(fn(*args, **kwargs)) - - return wrapper - - def get_completed_future(result=None): future = asyncio.Future() future.set_result(result) diff --git a/sdk/identity/azure-identity/tests/test_aad_client.py b/sdk/identity/azure-identity/tests/test_aad_client.py index bd0dc721eb35..b9f8c374d341 100644 --- a/sdk/identity/azure-identity/tests/test_aad_client.py +++ b/sdk/identity/azure-identity/tests/test_aad_client.py @@ -3,6 +3,7 @@ # Licensed under the MIT License. # ------------------------------------ import functools +from unittest.mock import Mock, patch from azure.core.exceptions import ClientAuthenticationError, ServiceRequestError from azure.identity._constants import EnvironmentVariables @@ -15,11 +16,6 @@ from helpers import build_aad_response, mock_response from test_certificate_credential import PEM_CERT_PATH -try: - from unittest.mock import Mock, patch -except ImportError: # python < 3.3 - from mock import Mock, patch # type: ignore - BASE_CLASS_METHODS = [ ("_get_auth_code_request", ("code", "redirect_uri")), diff --git a/sdk/identity/azure-identity/tests/test_app_service_async.py b/sdk/identity/azure-identity/tests/test_app_service_async.py index 5d32df81c742..e79964c39a73 100644 --- a/sdk/identity/azure-identity/tests/test_app_service_async.py +++ b/sdk/identity/azure-identity/tests/test_app_service_async.py @@ -11,7 +11,6 @@ from devtools_testutils import is_live from devtools_testutils.aio import recorded_by_proxy_async -from helpers_async import await_test from recorded_test_case import RecordedTestCase from test_app_service import PLAYBACK_URL @@ -33,7 +32,7 @@ def load_settings(self): self.patch = mock.patch.dict(os.environ, env, clear=True) @pytest.mark.manual - @await_test + @pytest.mark.asyncio @recorded_by_proxy_async async def test_system_assigned(self): self.load_settings() @@ -44,7 +43,7 @@ async def test_system_assigned(self): assert isinstance(token.expires_on, int) @pytest.mark.manual - @await_test + @pytest.mark.asyncio @recorded_by_proxy_async async def test_system_assigned_tenant_id(self): with self.patch: @@ -55,7 +54,7 @@ async def test_system_assigned_tenant_id(self): @pytest.mark.manual @pytest.mark.usefixtures("user_assigned_identity_client_id") - @await_test + @pytest.mark.asyncio @recorded_by_proxy_async async def test_user_assigned(self): self.load_settings() @@ -67,7 +66,7 @@ async def test_user_assigned(self): @pytest.mark.manual @pytest.mark.usefixtures("user_assigned_identity_client_id") - @await_test + @pytest.mark.asyncio @recorded_by_proxy_async async def test_user_assigned_tenant_id(self): with self.patch: diff --git a/sdk/identity/azure-identity/tests/test_application_credential.py b/sdk/identity/azure-identity/tests/test_application_credential.py index c7a8a84f9fff..cd761d17aef8 100644 --- a/sdk/identity/azure-identity/tests/test_application_credential.py +++ b/sdk/identity/azure-identity/tests/test_application_credential.py @@ -3,23 +3,21 @@ # Licensed under the MIT License. # ------------------------------------ import os +from unittest.mock import Mock, patch +from urllib.parse import urlparse -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessToken, AccessTokenInfo from azure.identity import CredentialUnavailableError from azure.identity._credentials.application import AzureApplicationCredential from azure.identity._constants import EnvironmentVariables import pytest from urllib.parse import urlparse -try: - from unittest.mock import Mock, patch -except ImportError: # python < 3.3 - from mock import Mock, patch # type: ignore - -from helpers import build_aad_response, get_discovery_response, mock_response +from helpers import build_aad_response, get_discovery_response, mock_response, GET_TOKEN_METHODS -def test_get_token(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_get_token(get_token_method): expected_token = "***" def send(request, **kwargs): @@ -35,34 +33,42 @@ def send(request, **kwargs): with patch.dict("os.environ", {var: "..." for var in EnvironmentVariables.CLIENT_SECRET_VARS}, clear=True): credential = AzureApplicationCredential(transport=Mock(send=send)) - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == expected_token -def test_iterates_only_once(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_iterates_only_once(get_token_method): """When a credential succeeds, AzureApplicationCredential should use that credential thereafter""" - expected_token = AccessToken("***", 42) + access_token = "***" unavailable_credential = Mock( - spec_set=["get_token"], get_token=Mock(side_effect=CredentialUnavailableError(message="...")) + spec_set=["get_token", "get_token_info"], + get_token=Mock(side_effect=CredentialUnavailableError(message="...")), + get_token_info=Mock(side_effect=CredentialUnavailableError(message="...")), + ) + successful_credential = Mock( + spec_set=["get_token", "get_token_info"], + get_token=Mock(return_value=AccessToken(access_token, 42)), + get_token_info=Mock(return_value=AccessTokenInfo(access_token, 42)), ) - successful_credential = Mock(spec_set=["get_token"], get_token=Mock(return_value=expected_token)) credential = AzureApplicationCredential() - credential.credentials = [ + credential.credentials = ( unavailable_credential, successful_credential, Mock( - spec_set=["get_token"], + spec_set=["get_token", "get_token_info"], get_token=Mock(side_effect=Exception("iteration didn't stop after a credential provided a token")), + get_token_info=Mock(side_effect=Exception("iteration didn't stop after a credential provided a token")), ), - ] + ) for n in range(3): - token = credential.get_token("scope") - assert token.token == expected_token.token - assert unavailable_credential.get_token.call_count == 1 - assert successful_credential.get_token.call_count == n + 1 + token = getattr(credential, get_token_method)("scope") + assert token.token == access_token + assert getattr(unavailable_credential, get_token_method).call_count == 1 + assert getattr(successful_credential, get_token_method).call_count == n + 1 @pytest.mark.parametrize("authority", ("localhost", "https://localhost")) diff --git a/sdk/identity/azure-identity/tests/test_application_credential_async.py b/sdk/identity/azure-identity/tests/test_application_credential_async.py index a317c840308c..58a4a64edabc 100644 --- a/sdk/identity/azure-identity/tests/test_application_credential_async.py +++ b/sdk/identity/azure-identity/tests/test_application_credential_async.py @@ -5,44 +5,50 @@ import os from unittest.mock import Mock, patch -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessToken, AccessTokenInfo from azure.identity import CredentialUnavailableError from azure.identity.aio._credentials.application import AzureApplicationCredential from azure.identity._constants import EnvironmentVariables import pytest from urllib.parse import urlparse -from helpers import build_aad_response, mock_response +from helpers import build_aad_response, mock_response, GET_TOKEN_METHODS from helpers_async import get_completed_future @pytest.mark.asyncio -async def test_iterates_only_once(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_iterates_only_once(get_token_method): """When a credential succeeds, AzureApplicationCredential should use that credential thereafter""" - expected_token = AccessToken("***", 42) + access_token = "***" unavailable_credential = Mock( - spec_set=["get_token"], get_token=Mock(side_effect=CredentialUnavailableError(message="...")) + spec_set=["get_token", "get_token_info"], + get_token=Mock(side_effect=CredentialUnavailableError(message="...")), + get_token_info=Mock(side_effect=CredentialUnavailableError(message="...")), ) successful_credential = Mock( - spec_set=["get_token"], get_token=Mock(return_value=get_completed_future(expected_token)) + spec_set=["get_token", "get_token_info"], + get_token=Mock(return_value=get_completed_future(AccessToken(access_token, 42))), + get_token_info=Mock(return_value=get_completed_future(AccessTokenInfo(access_token, 42))), ) credential = AzureApplicationCredential() - credential.credentials = [ + credential.credentials = ( unavailable_credential, successful_credential, Mock( - spec_set=["get_token"], + spec_set=["get_token", "get_token_info"], get_token=Mock(side_effect=Exception("iteration didn't stop after a credential provided a token")), + get_token_info=Mock(side_effect=Exception("iteration didn't stop after a credential provided a token")), ), - ] + ) for n in range(3): - token = await credential.get_token("scope") - assert token.token == expected_token.token - assert unavailable_credential.get_token.call_count == 1 - assert successful_credential.get_token.call_count == n + 1 + token = await getattr(credential, get_token_method)("scope") + assert token.token == access_token + assert getattr(unavailable_credential, get_token_method).call_count == 1 + assert getattr(successful_credential, get_token_method).call_count == n + 1 @pytest.mark.parametrize("authority", ("localhost", "https://localhost")) @@ -83,7 +89,8 @@ def test_initialization(mock_credential, expect_argument): @pytest.mark.asyncio -async def test_get_token(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_get_token(get_token_method): expected_token = "***" async def send(request, **kwargs): @@ -95,7 +102,7 @@ async def send(request, **kwargs): with patch.dict("os.environ", {var: "..." for var in EnvironmentVariables.CLIENT_SECRET_VARS}, clear=True): credential = AzureApplicationCredential(transport=Mock(send=send)) - token = await credential.get_token("scope") + token = await getattr(credential, get_token_method)("scope") assert token.token == expected_token diff --git a/sdk/identity/azure-identity/tests/test_auth_code.py b/sdk/identity/azure-identity/tests/test_auth_code.py index a4ee062e84d1..0d2f5f4d8811 100644 --- a/sdk/identity/azure-identity/tests/test_auth_code.py +++ b/sdk/identity/azure-identity/tests/test_auth_code.py @@ -2,32 +2,30 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ -from azure.core.exceptions import ClientAuthenticationError +from unittest.mock import Mock, patch +from urllib.parse import urlparse + from azure.core.pipeline.policies import SansIOHTTPPolicy from azure.identity import AuthorizationCodeCredential from azure.identity._constants import EnvironmentVariables from azure.identity._internal.user_agent import USER_AGENT import msal import pytest -from urllib.parse import urlparse - -from helpers import build_aad_response, mock_response, Request, validating_transport -try: - from unittest.mock import Mock, patch -except ImportError: # python < 3.3 - from mock import Mock, patch # type: ignore +from helpers import build_aad_response, mock_response, Request, validating_transport, GET_TOKEN_METHODS -def test_no_scopes(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_no_scopes(get_token_method): """The credential should raise ValueError when get_token is called with no scopes""" credential = AuthorizationCodeCredential("tenant-id", "client-id", "auth-code", "http://localhost") with pytest.raises(ValueError): - credential.get_token() + getattr(credential, get_token_method)() -def test_policies_configurable(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_policies_configurable(get_token_method): policy = Mock(spec_set=SansIOHTTPPolicy, on_request=Mock()) def send(*_, **kwargs): @@ -40,12 +38,13 @@ def send(*_, **kwargs): "tenant-id", "client-id", "auth-code", "http://localhost", policies=[policy], transport=Mock(send=send) ) - credential.get_token("scope") + getattr(credential, get_token_method)("scope") assert policy.on_request.called -def test_user_agent(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_user_agent(get_token_method): transport = validating_transport( requests=[Request(required_headers={"User-Agent": USER_AGENT})], responses=[mock_response(json_payload=build_aad_response(access_token="**"))], @@ -55,10 +54,11 @@ def test_user_agent(): "tenant-id", "client-id", "auth-code", "http://localhost", transport=transport ) - credential.get_token("scope") + getattr(credential, get_token_method)("scope") -def test_tenant_id(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_tenant_id(get_token_method): transport = validating_transport( requests=[Request(required_headers={"User-Agent": USER_AGENT})], responses=[mock_response(json_payload=build_aad_response(access_token="**"))], @@ -73,10 +73,14 @@ def test_tenant_id(): additionally_allowed_tenants=["*"], ) - credential.get_token("scope", tenant_id="tenant_id") + kwargs = {"tenant_id": "tenant_id"} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + getattr(credential, get_token_method)("scope", **kwargs) -def test_auth_code_credential(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_auth_code_credential(get_token_method): client_id = "client id" secret = "fake-client-secret" tenant_id = "tenant" @@ -126,24 +130,25 @@ def test_auth_code_credential(): ) # first call should redeem the auth code - token = credential.get_token(expected_scope) + token = getattr(credential, get_token_method)(expected_scope) assert token.token == expected_access_token assert transport.send.call_count == 1 # no auth code -> credential should return cached token - token = credential.get_token(expected_scope) + token = getattr(credential, get_token_method)(expected_scope) assert token.token == expected_access_token assert transport.send.call_count == 1 # no auth code, no cached token -> credential should redeem refresh token cached_access_token = list(cache.search(cache.CredentialType.ACCESS_TOKEN))[0] cache.remove_at(cached_access_token) - token = credential.get_token(expected_scope) + token = getattr(credential, get_token_method)(expected_scope) assert token.token == expected_access_token assert transport.send.call_count == 2 -def test_multitenant_authentication(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_multitenant_authentication(get_token_method): first_tenant = "first-tenant" first_token = "***" second_tenant = "second-tenant" @@ -167,21 +172,28 @@ def send(request, **kwargs): transport=Mock(send=send), additionally_allowed_tenants=["*"], ) - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == first_token - token = credential.get_token("scope", tenant_id=first_tenant) + kwargs = {"tenant_id": first_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(credential, get_token_method)("scope", **kwargs) assert token.token == first_token - token = credential.get_token("scope", tenant_id=second_tenant) + kwargs = {"tenant_id": second_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(credential, get_token_method)("scope", **kwargs) assert token.token == second_token # should still default to the first tenant - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == first_token -def test_multitenant_authentication_not_allowed(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_multitenant_authentication_not_allowed(get_token_method): expected_tenant = "expected-tenant" expected_token = "***" @@ -203,15 +215,24 @@ def send(request, **kwargs): additionally_allowed_tenants=["*"], ) - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == expected_token - token = credential.get_token("scope", tenant_id=expected_tenant) + kwargs = {"tenant_id": expected_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(credential, get_token_method)("scope", **kwargs) assert token.token == expected_token - token = credential.get_token("scope", tenant_id="un" + expected_tenant) + kwargs = {"tenant_id": "un" + expected_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(credential, get_token_method)("scope", **kwargs) assert token.token == expected_token * 2 with patch.dict("os.environ", {EnvironmentVariables.AZURE_IDENTITY_DISABLE_MULTITENANTAUTH: "true"}): - token = credential.get_token("scope", tenant_id="un" + expected_tenant) + kwargs = {"tenant_id": "un" + expected_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(credential, get_token_method)("scope", **kwargs) assert token.token == expected_token diff --git a/sdk/identity/azure-identity/tests/test_auth_code_async.py b/sdk/identity/azure-identity/tests/test_auth_code_async.py index 2d8c984851e2..04f86d510ca4 100644 --- a/sdk/identity/azure-identity/tests/test_auth_code_async.py +++ b/sdk/identity/azure-identity/tests/test_auth_code_async.py @@ -13,21 +13,23 @@ import msal import pytest -from helpers import build_aad_response, mock_response, Request +from helpers import build_aad_response, mock_response, Request, GET_TOKEN_METHODS from helpers_async import async_validating_transport, AsyncMockTransport pytestmark = pytest.mark.asyncio -async def test_no_scopes(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_no_scopes(get_token_method): """The credential should raise ValueError when get_token is called with no scopes""" credential = AuthorizationCodeCredential("tenant-id", "client-id", "auth-code", "http://localhost") with pytest.raises(ValueError): - await credential.get_token() + await getattr(credential, get_token_method)() -async def test_policies_configurable(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_policies_configurable(get_token_method): policy = Mock(spec_set=SansIOHTTPPolicy, on_request=Mock()) async def send(*_, **kwargs): @@ -40,7 +42,7 @@ async def send(*_, **kwargs): "tenant-id", "client-id", "auth-code", "http://localhost", policies=[policy], transport=Mock(send=send) ) - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") assert policy.on_request.called @@ -69,7 +71,8 @@ async def test_context_manager(): assert transport.__aexit__.call_count == 1 -async def test_user_agent(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_user_agent(get_token_method): transport = async_validating_transport( requests=[Request(required_headers={"User-Agent": USER_AGENT})], responses=[mock_response(json_payload=build_aad_response(access_token="**"))], @@ -79,10 +82,11 @@ async def test_user_agent(): "tenant-id", "client-id", "auth-code", "http://localhost", transport=transport ) - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") -async def test_tenant_id(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_tenant_id(get_token_method): transport = async_validating_transport( requests=[Request(required_headers={"User-Agent": USER_AGENT})], responses=[mock_response(json_payload=build_aad_response(access_token="**"))], @@ -97,10 +101,14 @@ async def test_tenant_id(): additionally_allowed_tenants=["*"], ) - await credential.get_token("scope", tenant_id="tenant_id") + kwargs = {"tenant_id": "tenant_id"} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + await getattr(credential, get_token_method)("scope", **kwargs) -async def test_auth_code_credential(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_auth_code_credential(get_token_method): client_id = "client id" secret = "fake-client-secret" tenant_id = "tenant" @@ -150,24 +158,25 @@ async def test_auth_code_credential(): ) # first call should redeem the auth code - token = await credential.get_token(expected_scope) + token = await getattr(credential, get_token_method)(expected_scope) assert token.token == expected_access_token assert transport.send.call_count == 1 # no auth code -> credential should return cached token - token = await credential.get_token(expected_scope) + token = await getattr(credential, get_token_method)(expected_scope) assert token.token == expected_access_token assert transport.send.call_count == 1 # no auth code, no cached token -> credential should redeem refresh token cached_access_token = list(cache.search(cache.CredentialType.ACCESS_TOKEN))[0] cache.remove_at(cached_access_token) - token = await credential.get_token(expected_scope) + token = await getattr(credential, get_token_method)(expected_scope) assert token.token == expected_access_token assert transport.send.call_count == 2 -async def test_multitenant_authentication(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_multitenant_authentication(get_token_method): first_tenant = "first-tenant" first_token = "***" second_tenant = "second-tenant" @@ -191,21 +200,28 @@ async def send(request, **kwargs): transport=Mock(send=send), additionally_allowed_tenants=["*"], ) - token = await credential.get_token("scope") + token = await getattr(credential, get_token_method)("scope") assert token.token == first_token - token = await credential.get_token("scope", tenant_id=first_tenant) + kwargs = {"tenant_id": first_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = await getattr(credential, get_token_method)("scope", **kwargs) assert token.token == first_token - token = await credential.get_token("scope", tenant_id=second_tenant) + kwargs = {"tenant_id": second_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = await getattr(credential, get_token_method)("scope", **kwargs) assert token.token == second_token # should still default to the first tenant - token = await credential.get_token("scope") + token = await getattr(credential, get_token_method)("scope") assert token.token == first_token -async def test_multitenant_authentication_not_allowed(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_multitenant_authentication_not_allowed(get_token_method): expected_tenant = "expected-tenant" expected_token = "***" @@ -227,15 +243,24 @@ async def send(request, **kwargs): additionally_allowed_tenants=["*"], ) - token = await credential.get_token("scope") + token = await getattr(credential, get_token_method)("scope") assert token.token == expected_token - token = await credential.get_token("scope", tenant_id=expected_tenant) + kwargs = {"tenant_id": expected_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = await getattr(credential, get_token_method)("scope", **kwargs) assert token.token == expected_token - token = await credential.get_token("scope", tenant_id="un" + expected_tenant) + kwargs = {"tenant_id": "un" + expected_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = await getattr(credential, get_token_method)("scope", **kwargs) assert token.token == expected_token * 2 with patch.dict("os.environ", {EnvironmentVariables.AZURE_IDENTITY_DISABLE_MULTITENANTAUTH: "true"}): - token = await credential.get_token("scope", tenant_id="un" + expected_tenant) + kwargs = {"tenant_id": "un" + expected_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = await getattr(credential, get_token_method)("scope", **kwargs) assert token.token == expected_token diff --git a/sdk/identity/azure-identity/tests/test_authority.py b/sdk/identity/azure-identity/tests/test_authority.py index d6a31b627f0a..725c481ff5ca 100644 --- a/sdk/identity/azure-identity/tests/test_authority.py +++ b/sdk/identity/azure-identity/tests/test_authority.py @@ -2,10 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ -try: - from unittest.mock import Mock, patch -except ImportError: # python < 3.3 - from mock import Mock, patch # type: ignore +from unittest.mock import patch from azure.identity._constants import EnvironmentVariables, KnownAuthorities from azure.identity._internal import get_default_authority, normalize_authority diff --git a/sdk/identity/azure-identity/tests/test_azd_cli_credential.py b/sdk/identity/azure-identity/tests/test_azd_cli_credential.py index 0bceaaba3c05..44af829dfece 100644 --- a/sdk/identity/azure-identity/tests/test_azd_cli_credential.py +++ b/sdk/identity/azure-identity/tests/test_azd_cli_credential.py @@ -14,7 +14,7 @@ import subprocess import pytest -from helpers import mock, INVALID_CHARACTERS +from helpers import mock, INVALID_CHARACTERS, GET_TOKEN_METHODS CHECK_OUTPUT = AzureDeveloperCliCredential.__module__ + ".subprocess.check_output" @@ -35,14 +35,16 @@ def raise_called_process_error(return_code, output="", cmd="...", stderr=""): return mock.Mock(side_effect=error) -def test_no_scopes(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_no_scopes(get_token_method): """The credential should raise ValueError when get_token is called with no scopes""" with pytest.raises(ValueError): - AzureDeveloperCliCredential().get_token() + getattr(AzureDeveloperCliCredential(), get_token_method)() -def test_invalid_tenant_id(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_invalid_tenant_id(get_token_method): """Invalid tenant IDs should raise ValueErrors.""" for c in INVALID_CHARACTERS: @@ -50,21 +52,26 @@ def test_invalid_tenant_id(): AzureDeveloperCliCredential(tenant_id="tenant" + c) with pytest.raises(ValueError): - AzureDeveloperCliCredential().get_token("scope", tenant_id="tenant" + c) + kwargs = {"tenant_id": "tenant" + c} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + getattr(AzureDeveloperCliCredential(), get_token_method)("scope", **kwargs) -def test_invalid_scopes(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_invalid_scopes(get_token_method): """Scopes with invalid characters should raise ValueErrors.""" for c in INVALID_CHARACTERS: with pytest.raises(ValueError): - AzureDeveloperCliCredential().get_token("scope" + c) + getattr(AzureDeveloperCliCredential(), get_token_method)("scope" + c) with pytest.raises(ValueError): - AzureDeveloperCliCredential().get_token("scope", "scope2", "scope" + c) + getattr(AzureDeveloperCliCredential(), get_token_method)("scope", "scope2", "scope" + c) -def test_get_token(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_get_token(get_token_method): """The credential should parse the CLI's output to an token""" access_token = "access token" @@ -81,40 +88,44 @@ def test_get_token(): with mock.patch("shutil.which", return_value="azd"): with mock.patch(CHECK_OUTPUT, mock.Mock(return_value=successful_output)): - token = AzureDeveloperCliCredential().get_token("scope") + token = getattr(AzureDeveloperCliCredential(), get_token_method)("scope") assert token.token == access_token assert type(token.expires_on) == int assert token.expires_on == expected_expires_on -def test_cli_not_installed(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_cli_not_installed(get_token_method): """The credential should raise CredentialUnavailableError when the CLI isn't installed""" with mock.patch("shutil.which", return_value=None): with pytest.raises(CredentialUnavailableError, match=CLI_NOT_FOUND): - AzureDeveloperCliCredential().get_token("scope") + getattr(AzureDeveloperCliCredential(), get_token_method)("scope") -def test_cannot_execute_shell(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_cannot_execute_shell(get_token_method): """The credential should raise CredentialUnavailableError when the subprocess doesn't start""" with mock.patch("shutil.which", return_value="azd"): with mock.patch(CHECK_OUTPUT, mock.Mock(side_effect=OSError())): with pytest.raises(CredentialUnavailableError): - AzureDeveloperCliCredential().get_token("scope") + getattr(AzureDeveloperCliCredential(), get_token_method)("scope") -def test_not_logged_in(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_not_logged_in(get_token_method): """When the CLI isn't logged in, the credential should raise CredentialUnavailableError""" stderr = "ERROR: not logged in, run `azd auth login` to login" with mock.patch("shutil.which", return_value="azd"): with mock.patch(CHECK_OUTPUT, raise_called_process_error(1, stderr=stderr)): with pytest.raises(CredentialUnavailableError, match=NOT_LOGGED_IN): - AzureDeveloperCliCredential().get_token("scope") + getattr(AzureDeveloperCliCredential(), get_token_method)("scope") -def test_aadsts_error(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_aadsts_error(get_token_method): """When there is an AADSTS error, the credential should raise an error containing the CLI's output even if the error also contains the 'not logged in' string.""" @@ -122,46 +133,50 @@ def test_aadsts_error(): with mock.patch("shutil.which", return_value="azd"): with mock.patch(CHECK_OUTPUT, raise_called_process_error(1, stderr=stderr)): with pytest.raises(ClientAuthenticationError, match=stderr): - AzureDeveloperCliCredential().get_token("scope") + getattr(AzureDeveloperCliCredential(), get_token_method)("scope") -def test_unexpected_error(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_unexpected_error(get_token_method): """When the CLI returns an unexpected error, the credential should raise an error containing the CLI's output""" stderr = "something went wrong" with mock.patch("shutil.which", return_value="azd"): with mock.patch(CHECK_OUTPUT, raise_called_process_error(42, stderr=stderr)): with pytest.raises(ClientAuthenticationError, match=stderr): - AzureDeveloperCliCredential().get_token("scope") + getattr(AzureDeveloperCliCredential(), get_token_method)("scope") @pytest.mark.parametrize("output", TEST_ERROR_OUTPUTS) -def test_parsing_error_does_not_expose_token(output): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_parsing_error_does_not_expose_token(output, get_token_method): """Errors during CLI output parsing shouldn't expose access tokens in that output""" with mock.patch("shutil.which", return_value="azd"): with mock.patch(CHECK_OUTPUT, mock.Mock(return_value=output)): with pytest.raises(ClientAuthenticationError) as ex: - AzureDeveloperCliCredential().get_token("scope") + getattr(AzureDeveloperCliCredential(), get_token_method)("scope") assert "secret value" not in str(ex.value) assert "secret value" not in repr(ex.value) @pytest.mark.parametrize("output", TEST_ERROR_OUTPUTS) -def test_subprocess_error_does_not_expose_token(output): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_subprocess_error_does_not_expose_token(output, get_token_method): """Errors from the subprocess shouldn't expose access tokens in CLI output""" with mock.patch("shutil.which", return_value="azd"): with mock.patch(CHECK_OUTPUT, raise_called_process_error(1, output=output)): with pytest.raises(ClientAuthenticationError) as ex: - AzureDeveloperCliCredential().get_token("scope") + getattr(AzureDeveloperCliCredential(), get_token_method)("scope") assert "secret value" not in str(ex.value) assert "secret value" not in repr(ex.value) -def test_timeout(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_timeout(get_token_method): """The credential should raise CredentialUnavailableError when the subprocess times out""" from subprocess import TimeoutExpired @@ -169,7 +184,7 @@ def test_timeout(): with mock.patch("shutil.which", return_value="azd"): with mock.patch(CHECK_OUTPUT, mock.Mock(side_effect=TimeoutExpired("", 42))) as check_output_mock: with pytest.raises(CredentialUnavailableError): - AzureDeveloperCliCredential(process_timeout=42).get_token("scope") + getattr(AzureDeveloperCliCredential(process_timeout=42), get_token_method)("scope") # Ensure custom timeout is passed to subprocess _, kwargs = check_output_mock.call_args @@ -177,7 +192,8 @@ def test_timeout(): assert kwargs["timeout"] == 42 -def test_multitenant_authentication_class(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_multitenant_authentication_class(get_token_method): default_tenant = "first-tenant" first_token = "***" second_tenant = "second-tenant" @@ -199,17 +215,18 @@ def fake_check_output(command_line, **_): with mock.patch("shutil.which", return_value="azd"): with mock.patch(CHECK_OUTPUT, fake_check_output): - token = AzureDeveloperCliCredential().get_token("scope") + token = getattr(AzureDeveloperCliCredential(), get_token_method)("scope") assert token.token == first_token - token = AzureDeveloperCliCredential(tenant_id=default_tenant).get_token("scope") + token = getattr(AzureDeveloperCliCredential(tenant_id=default_tenant), get_token_method)("scope") assert token.token == first_token - token = AzureDeveloperCliCredential(tenant_id=second_tenant).get_token("scope") + token = getattr(AzureDeveloperCliCredential(tenant_id=second_tenant), get_token_method)("scope") assert token.token == second_token -def test_multitenant_authentication(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_multitenant_authentication(get_token_method): default_tenant = "first-tenant" first_token = "***" second_tenant = "second-tenant" @@ -232,21 +249,28 @@ def fake_check_output(command_line, **_): credential = AzureDeveloperCliCredential() with mock.patch("shutil.which", return_value="azd"): with mock.patch(CHECK_OUTPUT, fake_check_output): - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == first_token - token = credential.get_token("scope", tenant_id=default_tenant) + kwargs = {"tenant_id": default_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(credential, get_token_method)("scope", **kwargs) assert token.token == first_token - token = credential.get_token("scope", tenant_id=second_tenant) + kwargs = {"tenant_id": second_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(credential, get_token_method)("scope", **kwargs) assert token.token == second_token # should still default to the first tenant - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == first_token -def test_multitenant_authentication_not_allowed(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_multitenant_authentication_not_allowed(get_token_method): expected_tenant = "expected-tenant" expected_token = "***" @@ -266,9 +290,12 @@ def fake_check_output(command_line, **_): credential = AzureDeveloperCliCredential() with mock.patch("shutil.which", return_value="azd"): with mock.patch(CHECK_OUTPUT, fake_check_output): - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == expected_token with mock.patch.dict("os.environ", {EnvironmentVariables.AZURE_IDENTITY_DISABLE_MULTITENANTAUTH: "true"}): - token = credential.get_token("scope", tenant_id="un" + expected_tenant) + kwargs = {"tenant_id": "un" + expected_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(credential, get_token_method)("scope", **kwargs) assert token.token == expected_token diff --git a/sdk/identity/azure-identity/tests/test_azd_cli_credential_async.py b/sdk/identity/azure-identity/tests/test_azd_cli_credential_async.py index 267a88ab85d6..95f7b2b2b143 100644 --- a/sdk/identity/azure-identity/tests/test_azd_cli_credential_async.py +++ b/sdk/identity/azure-identity/tests/test_azd_cli_credential_async.py @@ -16,7 +16,7 @@ from azure.core.exceptions import ClientAuthenticationError import pytest -from helpers import INVALID_CHARACTERS +from helpers import INVALID_CHARACTERS, GET_TOKEN_METHODS from helpers_async import get_completed_future from test_azd_cli_credential import TEST_ERROR_OUTPUTS @@ -33,14 +33,16 @@ async def communicate(): return mock.Mock(return_value=get_completed_future(process)) -async def test_no_scopes(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_no_scopes(get_token_method): """The credential should raise ValueError when get_token is called with no scopes""" with pytest.raises(ValueError): - await AzureDeveloperCliCredential().get_token() + await getattr(AzureDeveloperCliCredential(), get_token_method)() -async def test_invalid_tenant_id(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_invalid_tenant_id(get_token_method): """Invalid tenant IDs should raise ValueErrors.""" for c in INVALID_CHARACTERS: @@ -48,18 +50,22 @@ async def test_invalid_tenant_id(): AzureDeveloperCliCredential(tenant_id="tenant" + c) with pytest.raises(ValueError): - await AzureDeveloperCliCredential().get_token("scope", tenant_id="tenant" + c) + kwargs = {"tenant_id": "tenant" + c} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + await getattr(AzureDeveloperCliCredential(), get_token_method)("scope", **kwargs) -async def test_invalid_scopes(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_invalid_scopes(get_token_method): """Scopes with invalid characters should raise ValueErrors.""" for c in INVALID_CHARACTERS: with pytest.raises(ValueError): - await AzureDeveloperCliCredential().get_token("scope" + c) + await getattr(AzureDeveloperCliCredential(), get_token_method)("scope" + c) with pytest.raises(ValueError): - await AzureDeveloperCliCredential().get_token("scope", "scope2", "scope" + c) + await getattr(AzureDeveloperCliCredential(), get_token_method)("scope", "scope2", "scope" + c) async def test_close(): @@ -76,21 +82,25 @@ async def test_context_manager(): @pytest.mark.skipif(not sys.platform.startswith("win"), reason="tests Windows-specific behavior") -async def test_windows_fallback(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_windows_fallback(get_token_method): """The credential should fall back to the sync implementation when not using ProactorEventLoop on Windows""" sync_get_token = mock.Mock() with mock.patch("azure.identity.aio._credentials.azd_cli._SyncAzureDeveloperCliCredential") as fallback: - fallback.return_value = mock.Mock(spec_set=["get_token"], get_token=sync_get_token) + fallback.return_value = mock.Mock( + spec_set=["get_token", "get_token_info"], get_token=sync_get_token, get_token_info=sync_get_token + ) with mock.patch(AzureDeveloperCliCredential.__module__ + ".asyncio.get_event_loop"): # asyncio.get_event_loop now returns Mock, i.e. never ProactorEventLoop credential = AzureDeveloperCliCredential() - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") assert sync_get_token.call_count == 1 -async def test_get_token(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_get_token(get_token_method): """The credential should parse the CLI's output to an AccessToken""" access_token = "access token" @@ -108,33 +118,36 @@ async def test_get_token(): with mock.patch("shutil.which", return_value="azd"): with mock.patch(SUBPROCESS_EXEC, mock_exec(successful_output)): credential = AzureDeveloperCliCredential() - token = await credential.get_token("scope") + token = await getattr(credential, get_token_method)("scope") assert token.token == access_token assert type(token.expires_on) == int assert token.expires_on == expected_expires_on -async def test_cli_not_installed(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_cli_not_installed(get_token_method): """The credential should raise CredentialUnavailableError when the CLI isn't installed""" with mock.patch("shutil.which", return_value=None): with pytest.raises(CredentialUnavailableError, match=CLI_NOT_FOUND): credential = AzureDeveloperCliCredential() - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") -async def test_cannot_execute_shell(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_cannot_execute_shell(get_token_method): """The credential should raise CredentialUnavailableError when the subprocess doesn't start""" with mock.patch("shutil.which", return_value="azd"): with mock.patch(SUBPROCESS_EXEC, mock.Mock(side_effect=OSError())): with pytest.raises(CredentialUnavailableError): credential = AzureDeveloperCliCredential() - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") -async def test_not_logged_in(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_not_logged_in(get_token_method): """When the CLI isn't logged in, the credential should raise CredentialUnavailableError""" stderr = "ERROR: not logged in, run `azd auth login` to login" @@ -142,10 +155,11 @@ async def test_not_logged_in(): with mock.patch(SUBPROCESS_EXEC, mock_exec("", stderr, return_code=1)): with pytest.raises(CredentialUnavailableError, match=NOT_LOGGED_IN): credential = AzureDeveloperCliCredential() - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") -async def test_aadsts_error(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_aadsts_error(get_token_method): """When there is an AADSTS error, the credential should raise an error containing the CLI's output even if the error also contains the 'not logged in' string.""" @@ -154,10 +168,11 @@ async def test_aadsts_error(): with mock.patch(SUBPROCESS_EXEC, mock_exec("", stderr, return_code=1)): with pytest.raises(ClientAuthenticationError, match=stderr): credential = AzureDeveloperCliCredential() - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") -async def test_unexpected_error(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_unexpected_error(get_token_method): """When the CLI returns an unexpected error, the credential should raise an error containing the CLI's output""" stderr = "something went wrong" @@ -165,50 +180,54 @@ async def test_unexpected_error(): with mock.patch(SUBPROCESS_EXEC, mock_exec("", stderr, return_code=42)): with pytest.raises(ClientAuthenticationError, match=stderr): credential = AzureDeveloperCliCredential() - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") @pytest.mark.parametrize("output", TEST_ERROR_OUTPUTS) -async def test_parsing_error_does_not_expose_token(output): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_parsing_error_does_not_expose_token(output, get_token_method): """Errors during CLI output parsing shouldn't expose access tokens in that output""" with mock.patch("shutil.which", return_value="azd"): with mock.patch(SUBPROCESS_EXEC, mock_exec(output)): with pytest.raises(ClientAuthenticationError) as ex: credential = AzureDeveloperCliCredential() - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") assert "secret value" not in str(ex.value) assert "secret value" not in repr(ex.value) @pytest.mark.parametrize("output", TEST_ERROR_OUTPUTS) -async def test_subprocess_error_does_not_expose_token(output): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_subprocess_error_does_not_expose_token(output, get_token_method): """Errors from the subprocess shouldn't expose access tokens in CLI output""" with mock.patch("shutil.which", return_value="azd"): with mock.patch(SUBPROCESS_EXEC, mock_exec(output, return_code=1)): with pytest.raises(ClientAuthenticationError) as ex: credential = AzureDeveloperCliCredential() - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") assert "secret value" not in str(ex.value) assert "secret value" not in repr(ex.value) -async def test_timeout(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_timeout(get_token_method): """The credential should kill the subprocess after a timeout""" proc = mock.Mock(communicate=mock.Mock(side_effect=asyncio.TimeoutError), returncode=None) with mock.patch("shutil.which", return_value="azd"): with mock.patch(SUBPROCESS_EXEC, mock.Mock(return_value=get_completed_future(proc))): with pytest.raises(CredentialUnavailableError): - await AzureDeveloperCliCredential().get_token("scope") + await getattr(AzureDeveloperCliCredential(), get_token_method)("scope") assert proc.communicate.call_count == 1 -async def test_multitenant_authentication(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_multitenant_authentication(get_token_method): default_tenant = "first-tenant" first_token = "***" second_tenant = "second-tenant" @@ -232,21 +251,28 @@ async def fake_exec(*args, **_): credential = AzureDeveloperCliCredential() with mock.patch("shutil.which", return_value="azd"): with mock.patch(SUBPROCESS_EXEC, fake_exec): - token = await credential.get_token("scope") + token = await getattr(credential, get_token_method)("scope") assert token.token == first_token - token = await credential.get_token("scope", tenant_id=default_tenant) + kwargs = {"tenant_id": default_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = await getattr(credential, get_token_method)("scope", **kwargs) assert token.token == first_token - token = await credential.get_token("scope", tenant_id=second_tenant) + kwargs = {"tenant_id": second_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = await getattr(credential, get_token_method)("scope", **kwargs) assert token.token == second_token # should still default to the first tenant - token = await credential.get_token("scope") + token = await getattr(credential, get_token_method)("scope") assert token.token == first_token -async def test_multitenant_authentication_not_allowed(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_multitenant_authentication_not_allowed(get_token_method): expected_tenant = "expected-tenant" expected_token = "***" @@ -267,9 +293,12 @@ async def fake_exec(*args, **_): credential = AzureDeveloperCliCredential() with mock.patch("shutil.which", return_value="azd"): with mock.patch(SUBPROCESS_EXEC, fake_exec): - token = await credential.get_token("scope") + token = await getattr(credential, get_token_method)("scope") assert token.token == expected_token with mock.patch.dict("os.environ", {EnvironmentVariables.AZURE_IDENTITY_DISABLE_MULTITENANTAUTH: "true"}): - token = await credential.get_token("scope", tenant_id="un" + expected_tenant) + kwargs = {"tenant_id": "un" + expected_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = await getattr(credential, get_token_method)("scope", **kwargs) assert token.token == expected_token diff --git a/sdk/identity/azure-identity/tests/test_azure_application.py b/sdk/identity/azure-identity/tests/test_azure_application.py index 0945a1a1823d..8772264ade76 100644 --- a/sdk/identity/azure-identity/tests/test_azure_application.py +++ b/sdk/identity/azure-identity/tests/test_azure_application.py @@ -2,10 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ -try: - from unittest.mock import patch -except ImportError: - from mock import patch # type: ignore +from unittest.mock import patch from azure.identity._credentials.application import AzureApplicationCredential from azure.identity._constants import EnvironmentVariables diff --git a/sdk/identity/azure-identity/tests/test_azure_arc.py b/sdk/identity/azure-identity/tests/test_azure_arc.py index 2f216db452f5..8234eecde7cb 100644 --- a/sdk/identity/azure-identity/tests/test_azure_arc.py +++ b/sdk/identity/azure-identity/tests/test_azure_arc.py @@ -11,8 +11,11 @@ from azure.core.exceptions import ClientAuthenticationError from azure.identity._credentials.azure_arc import AzureArcCredential +from helpers import GET_TOKEN_METHODS -def test_msal_managed_identity_error(): + +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_msal_managed_identity_error(get_token_method): scopes = ["scope1"] def mock_request_token(*args, **kwargs): @@ -22,4 +25,4 @@ def mock_request_token(*args, **kwargs): cred._msal_client.acquire_token_for_client = mock_request_token with pytest.raises(ClientAuthenticationError): - cred.get_token(*scopes) + getattr(cred, get_token_method)(*scopes) diff --git a/sdk/identity/azure-identity/tests/test_azure_pipelines_credential.py b/sdk/identity/azure-identity/tests/test_azure_pipelines_credential.py index 459f2e2251f6..1f9e27dab2d5 100644 --- a/sdk/identity/azure-identity/tests/test_azure_pipelines_credential.py +++ b/sdk/identity/azure-identity/tests/test_azure_pipelines_credential.py @@ -16,6 +16,8 @@ ) from azure.identity._credentials.azure_pipelines import SYSTEM_OIDCREQUESTURI, OIDC_API_VERSION, build_oidc_request +from helpers import GET_TOKEN_METHODS + def test_azure_pipelines_credential_initialize(): system_access_token = "token" @@ -76,7 +78,8 @@ def test_build_oidc_request(): assert request.headers["Authorization"] == f"Bearer {access_token}" -def test_azure_pipelines_credential_missing_system_env_var(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_azure_pipelines_credential_missing_system_env_var(get_token_method): credential = AzurePipelinesCredential( system_access_token="token", client_id="client-id", @@ -86,11 +89,12 @@ def test_azure_pipelines_credential_missing_system_env_var(): with patch.dict("os.environ", {}, clear=True): with pytest.raises(CredentialUnavailableError) as ex: - credential.get_token("scope") + getattr(credential, get_token_method)("scope") assert f"Missing value for the {SYSTEM_OIDCREQUESTURI} environment variable" in str(ex.value) -def test_azure_pipelines_credential_in_chain(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_azure_pipelines_credential_in_chain(get_token_method): mock_credential = MagicMock() with patch.dict("os.environ", {}, clear=True): @@ -103,12 +107,13 @@ def test_azure_pipelines_credential_in_chain(): ), mock_credential, ) - chain_credential.get_token("scope") - assert mock_credential.get_token.called + getattr(chain_credential, get_token_method)("scope") + assert getattr(mock_credential, get_token_method).called @pytest.mark.live_test_only("Requires Azure Pipelines environment with configured service connection") -def test_azure_pipelines_credential_authentication(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_azure_pipelines_credential_authentication(get_token_method): system_access_token = os.environ.get("SYSTEM_ACCESSTOKEN", "") service_connection_id = os.environ.get("AZURE_SERVICE_CONNECTION_ID", "") tenant_id = os.environ.get("AZURE_SERVICE_CONNECTION_TENANT_ID", "") @@ -126,6 +131,6 @@ def test_azure_pipelines_credential_authentication(): service_connection_id=service_connection_id, ) - token = credential.get_token(scope) + token = getattr(credential, get_token_method)(scope) assert token.token assert isinstance(token.expires_on, int) diff --git a/sdk/identity/azure-identity/tests/test_azure_pipelines_credential_async.py b/sdk/identity/azure-identity/tests/test_azure_pipelines_credential_async.py index cc3cf313a88b..471bc4f32189 100644 --- a/sdk/identity/azure-identity/tests/test_azure_pipelines_credential_async.py +++ b/sdk/identity/azure-identity/tests/test_azure_pipelines_credential_async.py @@ -11,6 +11,8 @@ from azure.identity._credentials.azure_pipelines import SYSTEM_OIDCREQUESTURI from azure.identity.aio import AzurePipelinesCredential, ChainedTokenCredential, ClientAssertionCredential +from helpers import GET_TOKEN_METHODS + def test_azure_pipelines_credential_initialize(): system_access_token = "token" @@ -57,7 +59,8 @@ async def test_azure_pipelines_credential_context_manager(): @pytest.mark.asyncio -async def test_azure_pipelines_credential_missing_system_env_var(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_azure_pipelines_credential_missing_system_env_var(get_token_method): credential = AzurePipelinesCredential( system_access_token="token", client_id="client-id", @@ -67,12 +70,13 @@ async def test_azure_pipelines_credential_missing_system_env_var(): with patch.dict("os.environ", {}, clear=True): with pytest.raises(CredentialUnavailableError) as ex: - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") assert f"Missing value for the {SYSTEM_OIDCREQUESTURI} environment variable" in str(ex.value) @pytest.mark.asyncio -async def test_azure_pipelines_credential_in_chain(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_azure_pipelines_credential_in_chain(get_token_method): mock_credential = AsyncMock() with patch.dict("os.environ", {}, clear=True): @@ -85,13 +89,14 @@ async def test_azure_pipelines_credential_in_chain(): ), mock_credential, ) - await chain_credential.get_token("scope") - assert mock_credential.get_token.called + await getattr(chain_credential, get_token_method)("scope") + assert getattr(mock_credential, get_token_method).called @pytest.mark.asyncio @pytest.mark.live_test_only("Requires Azure Pipelines environment with configured service connection") -async def test_azure_pipelines_credential_authentication(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_azure_pipelines_credential_authentication(get_token_method): system_access_token = os.environ.get("SYSTEM_ACCESSTOKEN", "") service_connection_id = os.environ.get("AZURE_SERVICE_CONNECTION_ID", "") tenant_id = os.environ.get("AZURE_SERVICE_CONNECTION_TENANT_ID", "") @@ -109,6 +114,6 @@ async def test_azure_pipelines_credential_authentication(): service_connection_id=service_connection_id, ) - token = await credential.get_token(scope) + token = await getattr(credential, get_token_method)(scope) assert token.token assert isinstance(token.expires_on, int) diff --git a/sdk/identity/azure-identity/tests/test_bearer_token_provider.py b/sdk/identity/azure-identity/tests/test_bearer_token_provider.py index f20ce7ac1d88..89cd5248645f 100644 --- a/sdk/identity/azure-identity/tests/test_bearer_token_provider.py +++ b/sdk/identity/azure-identity/tests/test_bearer_token_provider.py @@ -3,7 +3,7 @@ # Licensed under the MIT License. # ------------------------------------ -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessToken, AccessTokenInfo from azure.identity import get_bearer_token_provider @@ -14,7 +14,17 @@ def get_token(self, *scopes, **kwargs): return AccessToken("mock_token", 42) +class MockCredentialTokenInfo: + def get_token_info(self, *scopes, **kwargs): + assert len(scopes) == 1 + assert scopes[0] == "scope" + return AccessTokenInfo("mock_token_2", 42) + + def test_get_bearer_token_provider(): func = get_bearer_token_provider(MockCredential(), "scope") assert func() == "mock_token" + + func = get_bearer_token_provider(MockCredentialTokenInfo(), "scope") # type: ignore + assert func() == "mock_token_2" diff --git a/sdk/identity/azure-identity/tests/test_bearer_token_provider_async.py b/sdk/identity/azure-identity/tests/test_bearer_token_provider_async.py index 35a8db46457e..9fa6f67d5759 100644 --- a/sdk/identity/azure-identity/tests/test_bearer_token_provider_async.py +++ b/sdk/identity/azure-identity/tests/test_bearer_token_provider_async.py @@ -3,7 +3,7 @@ # Licensed under the MIT License. # ------------------------------------ -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessToken, AccessTokenInfo from azure.identity.aio import get_bearer_token_provider import pytest @@ -16,8 +16,18 @@ async def get_token(self, *scopes, **kwargs): return AccessToken("mock_token", 42) +class MockCredentialTokenInfo: + async def get_token_info(self, *scopes, **kwargs): + assert len(scopes) == 1 + assert scopes[0] == "scope" + return AccessTokenInfo("mock_token_2", 42) + + @pytest.mark.asyncio async def test_get_bearer_token_provider(): - func = get_bearer_token_provider(MockCredential(), "scope") + func = get_bearer_token_provider(MockCredential(), "scope") # type: ignore assert await func() == "mock_token" + + func = get_bearer_token_provider(MockCredentialTokenInfo(), "scope") # type: ignore + assert await func() == "mock_token_2" diff --git a/sdk/identity/azure-identity/tests/test_browser_credential.py b/sdk/identity/azure-identity/tests/test_browser_credential.py index 6cef0366c095..760af0a2a8d2 100644 --- a/sdk/identity/azure-identity/tests/test_browser_credential.py +++ b/sdk/identity/azure-identity/tests/test_browser_credential.py @@ -2,7 +2,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ -import platform import random import socket import threading @@ -18,13 +17,11 @@ from unittest.mock import ANY, Mock, patch from helpers import ( - build_aad_response, - build_id_token, get_discovery_response, - id_token_claims, mock_response, Request, validating_transport, + GET_TOKEN_METHODS, ) @@ -32,7 +29,8 @@ @pytest.mark.manual -def test_browser_credential(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_browser_credential(get_token_method): transport = Mock(wraps=RequestsTransport()) credential = InteractiveBrowserCredential(transport=transport) scope = "https://management.azure.com/.default" # N.B. this is valid only in Public Cloud @@ -45,15 +43,15 @@ def test_browser_credential(): # credential should have a cached access token for the scope used in authenticate with patch(WEBBROWSER_OPEN, Mock(side_effect=Exception("credential should authenticate silently"))): - token = credential.get_token(scope) + token = getattr(credential, get_token_method)(scope) assert token.token credential = InteractiveBrowserCredential(transport=transport) - token = credential.get_token(scope) + token = getattr(credential, get_token_method)(scope) assert token.token with patch(WEBBROWSER_OPEN, Mock(side_effect=Exception("credential should authenticate silently"))): - second_token = credential.get_token(scope) + second_token = getattr(credential, get_token_method)(scope) assert second_token.token == token.token # every request should have the correct User-Agent @@ -76,14 +74,16 @@ def test_tenant_id_validation(): InteractiveBrowserCredential(tenant_id=tenant) -def test_no_scopes(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_no_scopes(get_token_method): """The credential should raise when get_token is called with no scopes""" with pytest.raises(ValueError): - InteractiveBrowserCredential().get_token() + getattr(InteractiveBrowserCredential(), get_token_method)() -def test_policies_configurable(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_policies_configurable(get_token_method): # the policy raises an exception so this test can run without authenticating i.e. opening a browser expected_message = "test_policies_configurable" policy = Mock(spec_set=SansIOHTTPPolicy, on_request=Mock(side_effect=Exception(expected_message))) @@ -91,13 +91,14 @@ def test_policies_configurable(): credential = InteractiveBrowserCredential(policies=[policy]) with pytest.raises(ClientAuthenticationError) as ex: - credential.get_token("scope") + getattr(credential, get_token_method)("scope") assert expected_message in ex.value.message assert policy.on_request.called -def test_disable_automatic_authentication(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_disable_automatic_authentication(get_token_method): """When configured for strict silent auth, the credential should raise when silent auth fails""" transport = Mock(send=Mock(side_effect=Exception("no request should be sent"))) @@ -107,10 +108,11 @@ def test_disable_automatic_authentication(): with patch(WEBBROWSER_OPEN, Mock(side_effect=Exception("credential shouldn't try interactive authentication"))): with pytest.raises(AuthenticationRequiredError): - credential.get_token("scope") + getattr(credential, get_token_method)("scope") -def test_timeout(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_timeout(get_token_method): """get_token should raise ClientAuthenticationError when the server times out without receiving a redirect""" timeout = 0.01 @@ -133,11 +135,12 @@ def handle_request(self): with patch(WEBBROWSER_OPEN, lambda _: True): with pytest.raises(ClientAuthenticationError) as ex: - credential.get_token("scope") + getattr(credential, get_token_method)("scope") assert "timed out" in ex.value.message.lower() -def test_redirect_server(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_redirect_server(get_token_method): # binding a random port prevents races when running the test in parallel server = None hostname = "127.0.0.1" @@ -167,7 +170,8 @@ def test_redirect_server(): assert server.query_params[expected_param] == expected_value -def test_no_browser(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_no_browser(get_token_method): """The credential should raise CredentialUnavailableError when it can't open a browser""" transport = validating_transport(requests=[Request()] * 2, responses=[get_discovery_response()] * 2) @@ -176,10 +180,11 @@ def test_no_browser(): ) with patch(InteractiveBrowserCredential.__module__ + "._open_browser", lambda _: False): with pytest.raises(CredentialUnavailableError, match=r".*browser.*"): - credential.get_token("scope") + getattr(credential, get_token_method)("scope") -def test_redirect_uri(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_redirect_uri(get_token_method): """The credential should configure the redirect server to use a given redirect_uri""" expected_hostname = "localhost" @@ -192,7 +197,7 @@ def test_redirect_uri(): client_credential="client_credential", ) with pytest.raises(ClientAuthenticationError) as ex: - credential.get_token("scope") + getattr(credential, get_token_method)("scope") assert expected_message in ex.value.message server.assert_called_once_with(expected_hostname, expected_port, timeout=ANY) @@ -206,17 +211,19 @@ def test_invalid_redirect_uri(redirect_uri): InteractiveBrowserCredential(redirect_uri=redirect_uri, client_credential="client_credential") -def test_cannot_bind_port(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_cannot_bind_port(get_token_method): """get_token should raise CredentialUnavailableError when the redirect listener can't bind a port""" credential = InteractiveBrowserCredential( _server_class=Mock(side_effect=socket.error), client_credential="client_credential" ) with pytest.raises(CredentialUnavailableError): - credential.get_token("scope") + getattr(credential, get_token_method)("scope") -def test_cannot_bind_redirect_uri(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_cannot_bind_redirect_uri(get_token_method): """When a user specifies a redirect URI, the credential shouldn't attempt to bind another""" server = Mock(side_effect=socket.error) @@ -225,6 +232,6 @@ def test_cannot_bind_redirect_uri(): ) with pytest.raises(CredentialUnavailableError): - credential.get_token("scope") + getattr(credential, get_token_method)("scope") server.assert_called_once_with("localhost", 42, timeout=ANY) diff --git a/sdk/identity/azure-identity/tests/test_certificate_credential.py b/sdk/identity/azure-identity/tests/test_certificate_credential.py index a9352b85a579..a00443db4045 100644 --- a/sdk/identity/azure-identity/tests/test_certificate_credential.py +++ b/sdk/identity/azure-identity/tests/test_certificate_credential.py @@ -4,8 +4,8 @@ # ------------------------------------ import json import os +from unittest.mock import Mock, patch -from azure.core.exceptions import ClientAuthenticationError from azure.core.pipeline.policies import ContentDecodePolicy, SansIOHTTPPolicy from azure.identity import CertificateCredential, TokenCachePersistenceOptions from azure.identity._enums import RegionalAuthority @@ -17,7 +17,6 @@ from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives.asymmetric import padding from msal import TokenCache -from msal_extensions import PersistedTokenCache import msal import pytest from urllib.parse import urlparse @@ -29,16 +28,11 @@ get_discovery_response, urlsafeb64_decode, mock_response, - msal_validating_transport, new_msal_validating_transport, Request, + GET_TOKEN_METHODS, ) -try: - from unittest.mock import Mock, patch -except ImportError: # python < 3.3 - from mock import Mock, patch # type: ignore - PEM_CERT_PATH = os.path.join(os.path.dirname(__file__), "certificate.pem") PEM_CERT_WITH_PASSWORD_PATH = os.path.join(os.path.dirname(__file__), "certificate-with-password.pem") PFX_CERT_PATH = os.path.join(os.path.dirname(__file__), "certificate.pfx") @@ -77,15 +71,17 @@ def test_tenant_id_validation(): CertificateCredential(tenant, "client-id", PEM_CERT_PATH) -def test_no_scopes(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_no_scopes(get_token_method): """The credential should raise ValueError when get_token is called with no scopes""" credential = CertificateCredential("tenant-id", "client-id", PEM_CERT_PATH) with pytest.raises(ValueError): - credential.get_token() + getattr(credential, get_token_method)() -def test_policies_configurable(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_policies_configurable(get_token_method): policy = Mock(spec_set=SansIOHTTPPolicy, on_request=Mock()) transport = new_msal_validating_transport( @@ -96,12 +92,13 @@ def test_policies_configurable(): "tenant-id", "client-id", PEM_CERT_PATH, policies=[ContentDecodePolicy(), policy], transport=transport ) - credential.get_token("scope") + getattr(credential, get_token_method)("scope") assert policy.on_request.called -def test_user_agent(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_user_agent(get_token_method): transport = new_msal_validating_transport( requests=[Request(required_headers={"User-Agent": USER_AGENT})], responses=[mock_response(json_payload=build_aad_response(access_token="**"))], @@ -109,10 +106,11 @@ def test_user_agent(): credential = CertificateCredential("tenant-id", "client-id", PEM_CERT_PATH, transport=transport) - credential.get_token("scope") + getattr(credential, get_token_method)("scope") -def test_tenant_id(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_tenant_id(get_token_method): transport = new_msal_validating_transport( requests=[Request(required_headers={"User-Agent": USER_AGENT})], responses=[mock_response(json_payload=build_aad_response(access_token="**"))], @@ -122,11 +120,15 @@ def test_tenant_id(): "tenant-id", "client-id", PEM_CERT_PATH, transport=transport, additionally_allowed_tenants=["*"] ) - credential.get_token("scope", tenant_id="tenant_id") + kwargs = {"tenant_id": "tenant_id"} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + getattr(credential, get_token_method)("scope", **kwargs) @pytest.mark.parametrize("authority", ("localhost", "https://localhost")) -def test_authority(authority): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_authority(authority, get_token_method): """the credential should accept an authority, with or without scheme, as an argument or environment variable""" tenant_id = "expected-tenant" @@ -141,7 +143,7 @@ def test_authority(authority): credential = CertificateCredential(tenant_id, "client-id", PEM_CERT_PATH, authority=authority) with patch("msal.ConfidentialClientApplication", mock_ctor): # must call get_token because the credential constructs the MSAL application lazily - credential.get_token("scope") + getattr(credential, get_token_method)("scope") assert mock_ctor.call_count == 1 _, kwargs = mock_ctor.call_args @@ -152,14 +154,15 @@ def test_authority(authority): with patch.dict("os.environ", {EnvironmentVariables.AZURE_AUTHORITY_HOST: authority}, clear=True): credential = CertificateCredential(tenant_id, "client-id", PEM_CERT_PATH, authority=authority) with patch("msal.ConfidentialClientApplication", mock_ctor): - credential.get_token("scope") + getattr(credential, get_token_method)("scope") assert mock_ctor.call_count == 1 _, kwargs = mock_ctor.call_args assert kwargs["authority"] == expected_authority -def test_regional_authority(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_regional_authority(get_token_method): """the credential should configure MSAL with a regional authority specified via kwarg or environment variable""" mock_confidential_client = Mock( @@ -173,7 +176,7 @@ def test_regional_authority(): with patch.dict("os.environ", {EnvironmentVariables.AZURE_REGIONAL_AUTHORITY_NAME: region.value}, clear=True): credential = CertificateCredential("tenant", "client-id", PEM_CERT_PATH) with patch("msal.ConfidentialClientApplication", mock_confidential_client): - credential.get_token("scope") + getattr(credential, get_token_method)("scope") assert mock_confidential_client.call_count == 1 _, kwargs = mock_confidential_client.call_args @@ -200,7 +203,8 @@ def test_requires_certificate(): @pytest.mark.parametrize("cert_path,cert_password", ALL_CERTS) @pytest.mark.parametrize("send_certificate_chain", (True, False)) -def test_request_body(cert_path, cert_password, send_certificate_chain): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_request_body(cert_path, cert_password, send_certificate_chain, get_token_method): access_token = "***" authority = "authority.com" client_id = "client-id" @@ -228,7 +232,7 @@ def mock_send(request, **kwargs): authority=authority, send_certificate_chain=send_certificate_chain, ) - token = cred.get_token(expected_scope) + token = getattr(cred, get_token_method)(expected_scope) assert token.token == access_token # credential should also accept the certificate as bytes @@ -244,7 +248,7 @@ def mock_send(request, **kwargs): authority=authority, send_certificate_chain=send_certificate_chain, ) - token = cred.get_token(expected_scope) + token = getattr(cred, get_token_method)(expected_scope) assert token.token == access_token @@ -294,7 +298,8 @@ def validate_jwt(request, client_id, cert_bytes, cert_password, expect_x5c=False @pytest.mark.parametrize("cert_path,cert_password", ALL_CERTS) -def test_token_cache_persistent(cert_path, cert_password): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_token_cache_persistent(cert_path, cert_password, get_token_method): """the credential should use a persistent cache if cache_persistence_options are configured""" access_token = "foo token" @@ -324,19 +329,23 @@ def send(request, **kwargs): assert credential._cache is None assert credential._cae_cache is None - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == access_token assert load_persistent_cache.call_count == 1 assert credential._cache is not None assert credential._cae_cache is None - token = credential.get_token("scope", enable_cae=True) + kwargs = {"enable_cae": True} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(credential, get_token_method)("scope", **kwargs) assert load_persistent_cache.call_count == 2 assert credential._cae_cache is not None @pytest.mark.parametrize("cert_path,cert_password", ALL_CERTS) -def test_token_cache_memory(cert_path, cert_password): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_token_cache_memory(cert_path, cert_password, get_token_method): """The credential should default to in-memory cache if no persistence options are provided.""" access_token = "foo token" @@ -356,19 +365,23 @@ def send(request, **kwargs): ) assert credential._cache is None - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == access_token assert isinstance(credential._cache, TokenCache) assert credential._cae_cache is None assert not load_persistent_cache.called - token = credential.get_token("scope", enable_cae=True) + kwargs = {"enable_cae": True} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(credential, get_token_method)("scope", **kwargs) assert isinstance(credential._cae_cache, TokenCache) assert not load_persistent_cache.called @pytest.mark.parametrize("cert_path,cert_password", ALL_CERTS) -def test_persistent_cache_multiple_clients(cert_path, cert_password): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_persistent_cache_multiple_clients(cert_path, cert_password, get_token_method): """the credential shouldn't use tokens issued to other service principals""" access_token_a = "token a" @@ -403,13 +416,13 @@ def test_persistent_cache_multiple_clients(cert_path, cert_password): # A caches a token scope = "scope" - token_a = credential_a.get_token(scope) + token_a = getattr(credential_a, get_token_method)(scope) assert mock_cache_loader.call_count == 1, "credential should use the persistent cache" assert token_a.token == access_token_a assert transport_a.send.call_count == 2 # one MSAL discovery request, one token request # B should get a different token for the same scope - token_b = credential_b.get_token(scope) + token_b = getattr(credential_b, get_token_method)(scope) assert mock_cache_loader.call_count == 2, "credential should load the persistent cache" assert token_b.token == access_token_b assert transport_b.send.call_count == 2 @@ -427,7 +440,8 @@ def test_certificate_arguments(): @pytest.mark.parametrize("cert_path,cert_password", ALL_CERTS) -def test_multitenant_authentication(cert_path, cert_password): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_multitenant_authentication(cert_path, cert_password, get_token_method): first_tenant = "first-tenant" first_token = "***" second_tenant = "second-tenant" @@ -454,22 +468,29 @@ def send(request, **kwargs): transport=Mock(send=send), additionally_allowed_tenants=["*"], ) - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == first_token - token = credential.get_token("scope", tenant_id=first_tenant) + kwargs = {"tenant_id": first_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(credential, get_token_method)("scope", **kwargs) assert token.token == first_token - token = credential.get_token("scope", tenant_id=second_tenant) + kwargs = {"tenant_id": second_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(credential, get_token_method)("scope", **kwargs) assert token.token == second_token # should still default to the first tenant - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == first_token @pytest.mark.parametrize("cert_path,cert_password", ALL_CERTS) -def test_multitenant_authentication_backcompat(cert_path, cert_password): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_multitenant_authentication_backcompat(cert_path, cert_password, get_token_method): expected_tenant = "expected-tenant" expected_token = "***" @@ -494,14 +515,20 @@ def send(request, **kwargs): additionally_allowed_tenants=["*"], ) - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == expected_token + kwargs = {"tenant_id": expected_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} # explicitly specifying the configured tenant is okay - token = credential.get_token("scope", tenant_id=expected_tenant) + token = getattr(credential, get_token_method)("scope", **kwargs) assert token.token == expected_token - token = credential.get_token("scope", tenant_id="un" + expected_tenant) + kwargs = {"tenant_id": "un" + expected_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(credential, get_token_method)("scope", **kwargs) assert token.token == expected_token @@ -524,7 +551,8 @@ def test_client_capabilities(): assert kwargs["client_capabilities"] == ["CP1"] -def test_claims_challenge(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_claims_challenge(get_token_method): """get_token should pass any claims challenge to MSAL token acquisition APIs""" msal_acquire_token_result = dict( @@ -540,7 +568,10 @@ def test_claims_challenge(): msal_app.acquire_token_silent_with_error.return_value = None msal_app.acquire_token_for_client.return_value = msal_acquire_token_result - credential.get_token("scope", claims=expected_claims) + kwargs = {"claims": expected_claims} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + getattr(credential, get_token_method)("scope", **kwargs) assert msal_app.acquire_token_silent_with_error.call_count == 1 args, kwargs = msal_app.acquire_token_silent_with_error.call_args diff --git a/sdk/identity/azure-identity/tests/test_certificate_credential_async.py b/sdk/identity/azure-identity/tests/test_certificate_credential_async.py index 79786bb0216e..a55d593e8730 100644 --- a/sdk/identity/azure-identity/tests/test_certificate_credential_async.py +++ b/sdk/identity/azure-identity/tests/test_certificate_credential_async.py @@ -5,7 +5,6 @@ from unittest.mock import Mock, patch from urllib.parse import urlparse -from azure.core.exceptions import ClientAuthenticationError from azure.core.pipeline.policies import ContentDecodePolicy, SansIOHTTPPolicy from azure.identity import TokenCachePersistenceOptions from azure.identity._constants import EnvironmentVariables @@ -15,7 +14,7 @@ from msal import TokenCache import pytest -from helpers import build_aad_response, mock_response, Request +from helpers import build_aad_response, mock_response, Request, GET_TOKEN_METHODS from helpers_async import async_validating_transport, AsyncMockTransport from test_certificate_credential import ALL_CERTS, EC_CERT_PATH, PEM_CERT_PATH, validate_jwt @@ -42,12 +41,13 @@ def test_tenant_id_validation(): @pytest.mark.asyncio -async def test_no_scopes(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_no_scopes(get_token_method): """The credential should raise ValueError when get_token is called with no scopes""" credential = CertificateCredential("tenant-id", "client-id", PEM_CERT_PATH) with pytest.raises(ValueError): - await credential.get_token() + await getattr(credential, get_token_method)() @pytest.mark.asyncio @@ -73,7 +73,8 @@ async def test_context_manager(): @pytest.mark.asyncio -async def test_policies_configurable(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_policies_configurable(get_token_method): policy = Mock(spec_set=SansIOHTTPPolicy, on_request=Mock()) async def send(*_, **kwargs): @@ -86,13 +87,14 @@ async def send(*_, **kwargs): "tenant-id", "client-id", PEM_CERT_PATH, policies=[ContentDecodePolicy(), policy], transport=Mock(send=send) ) - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") assert policy.on_request.called @pytest.mark.asyncio -async def test_user_agent(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_user_agent(get_token_method): transport = async_validating_transport( requests=[Request(required_headers={"User-Agent": USER_AGENT})], responses=[mock_response(json_payload=build_aad_response(access_token="**"))], @@ -100,11 +102,12 @@ async def test_user_agent(): credential = CertificateCredential("tenant-id", "client-id", PEM_CERT_PATH, transport=transport) - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") @pytest.mark.asyncio -async def test_tenant_id(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_tenant_id(get_token_method): transport = async_validating_transport( requests=[Request(required_headers={"User-Agent": USER_AGENT})], responses=[mock_response(json_payload=build_aad_response(access_token="**"))], @@ -113,14 +116,17 @@ async def test_tenant_id(): credential = CertificateCredential( "tenant-id", "client-id", PEM_CERT_PATH, transport=transport, additionally_allowed_tenants=["*"] ) - - await credential.get_token("scope", tenant_id="tenant_id") + kwargs = {"tenant_id": "tenant_id"} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + await getattr(credential, get_token_method)("scope", **kwargs) @pytest.mark.asyncio @pytest.mark.parametrize("authority", ("localhost", "https://localhost")) @pytest.mark.parametrize("cert_path,cert_password", ALL_CERTS) -async def test_request_url(cert_path, cert_password, authority): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_request_url(cert_path, cert_password, authority, get_token_method): """the credential should accept an authority, with or without scheme, as an argument or environment variable""" tenant_id = "expected-tenant" @@ -138,7 +144,7 @@ async def mock_send(request, **kwargs): cred = CertificateCredential( tenant_id, "client-id", cert_path, password=cert_password, transport=Mock(send=mock_send), authority=authority ) - token = await cred.get_token("scope") + token = await getattr(cred, get_token_method)("scope") assert token.token == access_token # authority can be configured via environment variable @@ -146,7 +152,7 @@ async def mock_send(request, **kwargs): credential = CertificateCredential( tenant_id, "client-id", cert_path, password=cert_password, transport=Mock(send=mock_send) ) - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") assert token.token == access_token @@ -167,7 +173,8 @@ def test_requires_certificate(): @pytest.mark.asyncio @pytest.mark.parametrize("cert_path,cert_password", ALL_CERTS) -async def test_request_body(cert_path, cert_password): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_request_body(cert_path, cert_password, get_token_method): access_token = "***" authority = "authority.com" client_id = "client-id" @@ -186,7 +193,7 @@ async def mock_send(request, **kwargs): cred = CertificateCredential( tenant_id, client_id, cert_path, password=cert_password, transport=Mock(send=mock_send), authority=authority ) - token = await cred.get_token(expected_scope) + token = await getattr(cred, get_token_method)(expected_scope) assert token.token == access_token # credential should also accept the certificate as bytes @@ -201,13 +208,14 @@ async def mock_send(request, **kwargs): transport=Mock(send=mock_send), authority=authority, ) - token = await cred.get_token(expected_scope) + token = await getattr(cred, get_token_method)(expected_scope) assert token.token == access_token @pytest.mark.asyncio @pytest.mark.parametrize("cert_path,cert_password", ALL_CERTS) -async def test_token_cache_memory(cert_path, cert_password): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_token_cache_memory(cert_path, cert_password, get_token_method): """the credential should optionally use a persistent cache, and default to an in memory cache""" access_token = "token" @@ -227,13 +235,16 @@ async def test_token_cache_memory(cert_path, cert_password): assert not mock_token_cache.called assert not load_persistent_cache.called - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") assert mock_token_cache.call_count == 1 assert load_persistent_cache.call_count == 0 assert credential._client._cache is not None assert credential._client._cae_cache is None - await credential.get_token("scope", enable_cae=True) + kwargs = {"enable_cae": True} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + await getattr(credential, get_token_method)("scope", **kwargs) assert mock_token_cache.call_count == 2 assert load_persistent_cache.call_count == 0 assert credential._client._cae_cache is not None @@ -241,7 +252,8 @@ async def test_token_cache_memory(cert_path, cert_password): @pytest.mark.asyncio @pytest.mark.parametrize("cert_path,cert_password", ALL_CERTS) -async def test_token_cache_persistent(cert_path, cert_password): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_token_cache_persistent(cert_path, cert_password, get_token_method): """the credential should optionally use a persistent cache, and default to an in memory cache""" access_token = "token" @@ -267,14 +279,17 @@ async def test_token_cache_persistent(cert_path, cert_password): assert not mock_token_cache.called assert not load_persistent_cache.called - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") assert load_persistent_cache.call_count == 1 assert credential._client._cache is not None assert credential._client._cae_cache is None args, _ = load_persistent_cache.call_args assert args[1] is False - await credential.get_token("scope", enable_cae=True) + kwargs = {"enable_cae": True} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + await getattr(credential, get_token_method)("scope", **kwargs) assert load_persistent_cache.call_count == 2 assert credential._client._cae_cache is not None args, _ = load_persistent_cache.call_args @@ -284,7 +299,8 @@ async def test_token_cache_persistent(cert_path, cert_password): @pytest.mark.asyncio @pytest.mark.parametrize("cert_path,cert_password", ALL_CERTS) -async def test_persistent_cache_multiple_clients(cert_path, cert_password): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_persistent_cache_multiple_clients(cert_path, cert_password, get_token_method): """the credential shouldn't use tokens issued to other service principals""" access_token_a = "token a" @@ -321,7 +337,7 @@ async def test_persistent_cache_multiple_clients(cert_path, cert_password): # A caches a token scope = "scope" - token_a = await credential_a.get_token(scope) + token_a = await getattr(credential_a, get_token_method)(scope) assert token_a.token == access_token_a assert transport_a.send.call_count == 1 assert mock_cache_loader.call_count == 1 @@ -329,7 +345,7 @@ async def test_persistent_cache_multiple_clients(cert_path, cert_password): assert args[1] is False # not CAE # B should get a different token for the same scope - token_b = await credential_b.get_token(scope) + token_b = await getattr(credential_b, get_token_method)(scope) assert token_b.token == access_token_b assert transport_b.send.call_count == 1 assert mock_cache_loader.call_count == 2 @@ -348,7 +364,8 @@ def test_certificate_arguments(): @pytest.mark.asyncio @pytest.mark.parametrize("cert_path,cert_password", ALL_CERTS) -async def test_multitenant_authentication(cert_path, cert_password): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_multitenant_authentication(cert_path, cert_password, get_token_method): first_tenant = "first-tenant" first_token = "***" second_tenant = "second-tenant" @@ -372,23 +389,30 @@ async def send(request, **kwargs): transport=Mock(send=send), additionally_allowed_tenants=["*"], ) - token = await credential.get_token("scope") + token = await getattr(credential, get_token_method)("scope") assert token.token == first_token - token = await credential.get_token("scope", tenant_id=first_tenant) + kwargs = {"tenant_id": first_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = await getattr(credential, get_token_method)("scope", **kwargs) assert token.token == first_token - token = await credential.get_token("scope", tenant_id=second_tenant) + kwargs = {"tenant_id": second_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = await getattr(credential, get_token_method)("scope", **kwargs) assert token.token == second_token # should still default to the first tenant - token = await credential.get_token("scope") + token = await getattr(credential, get_token_method)("scope") assert token.token == first_token @pytest.mark.asyncio @pytest.mark.parametrize("cert_path,cert_password", ALL_CERTS) -async def test_multitenant_authentication_backcompat(cert_path, cert_password): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_multitenant_authentication_backcompat(cert_path, cert_password, get_token_method): expected_tenant = "expected-tenant" expected_token = "***" @@ -410,12 +434,18 @@ async def send(request, **kwargs): additionally_allowed_tenants=["*"], ) - token = await credential.get_token("scope") + token = await getattr(credential, get_token_method)("scope") assert token.token == expected_token + kwargs = {"tenant_id": expected_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} # explicitly specifying the configured tenant is okay - token = await credential.get_token("scope", tenant_id=expected_tenant) + token = await getattr(credential, get_token_method)("scope", **kwargs) assert token.token == expected_token - token = await credential.get_token("scope", tenant_id="un" + expected_tenant) + kwargs = {"tenant_id": "un" + expected_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = await getattr(credential, get_token_method)("scope", **kwargs) assert token.token == expected_token * 2 diff --git a/sdk/identity/azure-identity/tests/test_chained_credential.py b/sdk/identity/azure-identity/tests/test_chained_credential.py index ea9efbbf43f3..3ccbe2e2849d 100644 --- a/sdk/identity/azure-identity/tests/test_chained_credential.py +++ b/sdk/identity/azure-identity/tests/test_chained_credential.py @@ -5,7 +5,7 @@ import time from unittest.mock import Mock, MagicMock, patch -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessToken, AccessTokenInfo from azure.core.exceptions import ClientAuthenticationError from azure.identity._credentials.imds import IMDS_TOKEN_PATH, IMDS_AUTHORITY from azure.identity._internal.user_agent import USER_AGENT @@ -17,7 +17,7 @@ ) import pytest -from helpers import validating_transport, Request, mock_response +from helpers import validating_transport, Request, mock_response, GET_TOKEN_METHODS def test_close(): @@ -51,74 +51,116 @@ def test_context_manager(): assert credential.__exit__.call_count == 1 -def test_error_message(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_error_message(get_token_method): first_error = "first_error" first_credential = Mock( - spec=ClientSecretCredential, get_token=Mock(side_effect=CredentialUnavailableError(first_error)) + spec=ClientSecretCredential, + get_token=Mock(side_effect=CredentialUnavailableError(first_error)), + get_token_info=Mock(side_effect=CredentialUnavailableError(first_error)), ) second_error = "second_error" second_credential = Mock( - name="second_credential", get_token=Mock(side_effect=ClientAuthenticationError(second_error)) + name="second_credential", + get_token=Mock(side_effect=ClientAuthenticationError(second_error)), + get_token_info=Mock(side_effect=ClientAuthenticationError(second_error)), ) with pytest.raises(ClientAuthenticationError) as ex: - ChainedTokenCredential(first_credential, second_credential).get_token("scope") + chained_cred = ChainedTokenCredential(first_credential, second_credential) + getattr(chained_cred, get_token_method)("scope") assert "ClientSecretCredential" in ex.value.message assert first_error in ex.value.message assert second_error in ex.value.message -def test_attempts_all_credentials(): - expected_token = AccessToken("expected_token", 0) +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_attempts_all_credentials(get_token_method): + expected_token = "expected_token" + expires_on = 42 credentials = [ - Mock(spec_set=["get_token"], get_token=Mock(side_effect=CredentialUnavailableError(message=""))), - Mock(spec_set=["get_token"], get_token=Mock(side_effect=CredentialUnavailableError(message=""))), - Mock(spec_set=["get_token"], get_token=Mock(return_value=expected_token)), + Mock( + spec_set=["get_token", "get_token_info"], + get_token=Mock(side_effect=CredentialUnavailableError(message="")), + get_token_info=Mock(side_effect=CredentialUnavailableError(message="")), + ), + Mock( + spec_set=["get_token", "get_token_info"], + get_token=Mock(side_effect=CredentialUnavailableError(message="")), + get_token_info=Mock(side_effect=CredentialUnavailableError(message="")), + ), + Mock( + spec_set=["get_token", "get_token_info"], + get_token=Mock(return_value=AccessToken(expected_token, expires_on)), + get_token_info=Mock(return_value=AccessTokenInfo(expected_token, expires_on)), + ), ] - token = ChainedTokenCredential(*credentials).get_token("scope") - assert token is expected_token + token = getattr(ChainedTokenCredential(*credentials), get_token_method)("scope") + assert token.token == expected_token for credential in credentials: - assert credential.get_token.call_count == 1 + assert getattr(credential, get_token_method).call_count == 1 -def test_raises_for_unexpected_error(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_raises_for_unexpected_error(get_token_method): """the chain should not continue after an unexpected error (i.e. anything but CredentialUnavailableError)""" expected_message = "it can't be done" credentials = [ - Mock(spec_set=["get_token"], get_token=Mock(side_effect=CredentialUnavailableError(message=""))), - Mock(spec_set=["get_token"], get_token=Mock(side_effect=ValueError(expected_message))), - Mock(spec_set=["get_token"], get_token=Mock(return_value=AccessToken("**", 42))), + Mock( + spec_set=["get_token", "get_token_info"], + get_token=Mock(side_effect=CredentialUnavailableError(message="")), + get_token_info=Mock(side_effect=CredentialUnavailableError(message="")), + ), + Mock( + spec_set=["get_token", "get_token_info"], + get_token=Mock(side_effect=ValueError(expected_message)), + get_token_info=Mock(side_effect=ValueError(expected_message)), + ), + Mock( + spec_set=["get_token", "get_token_info"], + get_token=Mock(return_value=AccessToken("**", 42)), + get_token_info=Mock(return_value=AccessTokenInfo("**", 42)), + ), ] with pytest.raises(ClientAuthenticationError) as ex: - ChainedTokenCredential(*credentials).get_token("scope") + getattr(ChainedTokenCredential(*credentials), get_token_method)("scope") assert expected_message in ex.value.message - assert credentials[-1].get_token.call_count == 0 + assert getattr(credentials[-1], get_token_method).call_count == 0 -def test_returns_first_token(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_returns_first_token(get_token_method): expected_token = Mock() - first_credential = Mock(spec_set=["get_token"], get_token=lambda _, **__: expected_token) - second_credential = Mock(spec_set=["get_token"], get_token=Mock()) + first_credential = Mock( + spec_set=["get_token", "get_token_info"], + get_token=lambda _, **__: expected_token, + get_token_info=lambda _, **__: expected_token, + ) + second_credential = Mock( + spec_set=["get_token", "get_token_info"], + get_token=Mock(), + get_token_info=Mock(), + ) aggregate = ChainedTokenCredential(first_credential, second_credential) - credential = aggregate.get_token("scope") + token = getattr(aggregate, get_token_method)("scope") - assert credential is expected_token - assert second_credential.get_token.call_count == 0 + assert token.token == expected_token.token + assert getattr(second_credential, get_token_method).call_count == 0 -def test_managed_identity_imds_probe(): - access_token = "****" +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_managed_identity_imds_probe(get_token_method): + expected_token = "****" expires_on = 42 - expected_token = AccessToken(access_token, expires_on) scope = "scope" transport = validating_transport( requests=[ @@ -135,7 +177,7 @@ def test_managed_identity_imds_probe(): mock_response(status_code=400, json_payload={"error": "this is an error message"}), mock_response( json_payload={ - "access_token": access_token, + "access_token": expected_token, "expires_in": 42, "expires_on": expires_on, "ext_expires_in": 42, @@ -149,28 +191,63 @@ def test_managed_identity_imds_probe(): with patch.dict("os.environ", clear=True): credentials = [ - Mock(spec_set=["get_token"], get_token=Mock(side_effect=CredentialUnavailableError(message=""))), + Mock( + spec_set=["get_token", "get_token_info"], + get_token=Mock(side_effect=CredentialUnavailableError(message="")), + get_token_info=Mock(side_effect=CredentialUnavailableError(message="")), + ), ManagedIdentityCredential(transport=transport), ] - token = ChainedTokenCredential(*credentials).get_token(scope) - assert token.token == expected_token.token + token = getattr(ChainedTokenCredential(*credentials), get_token_method)(scope) + assert token.token == expected_token -def test_managed_identity_failed_probe(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_managed_identity_failed_probe(get_token_method): mock_send = Mock(side_effect=Exception("timeout")) transport = Mock(send=mock_send) expected_token = Mock() credentials = [ - Mock(spec_set=["get_token"], get_token=Mock(side_effect=CredentialUnavailableError(message=""))), + Mock( + spec_set=["get_token", "get_token_info"], + get_token=Mock(side_effect=CredentialUnavailableError(message="")), + get_token_info=Mock(side_effect=CredentialUnavailableError(message="")), + ), ManagedIdentityCredential(transport=transport), - Mock(spec_set=["get_token"], get_token=Mock(return_value=expected_token)), + Mock( + spec_set=["get_token", "get_token_info"], + get_token=Mock(return_value=expected_token), + get_token_info=Mock(return_value=expected_token), + ), ] with patch.dict("os.environ", clear=True): - token = ChainedTokenCredential(*credentials).get_token("scope") + token = getattr(ChainedTokenCredential(*credentials), get_token_method)("scope") - assert token is expected_token + assert token.token == expected_token.token # ManagedIdentityCredential should be tried and skipped with the last credential in the chain # being used. - assert credentials[-1].get_token.call_count == 1 + assert getattr(credentials[-1], get_token_method).call_count == 1 + + +def test_credentials_with_no_get_token_info(): + """ChainedTokenCredential should work with credentials that don't implement get_token_info.""" + + access_token = "****" + credential1 = Mock( + spec_set=["get_token_info"], + get_token_info=Mock(side_effect=CredentialUnavailableError(message="")), + ) + credential2 = Mock( + spec_set=["get_token"], + get_token=Mock(return_value=AccessToken(access_token, 42)), + ) + credential3 = Mock( + spec_set=["get_token", "get_token_info"], + get_token=Mock(return_value=AccessToken("foo", 42)), + get_token_info=Mock(return_value=AccessTokenInfo("bar", 42)), + ) + chain = ChainedTokenCredential(credential1, credential2, credential3) # type: ignore + token_info = chain.get_token_info("scope") + assert token_info.token == access_token diff --git a/sdk/identity/azure-identity/tests/test_chained_token_credential_async.py b/sdk/identity/azure-identity/tests/test_chained_token_credential_async.py index b145e3f93ba8..2f541b859684 100644 --- a/sdk/identity/azure-identity/tests/test_chained_token_credential_async.py +++ b/sdk/identity/azure-identity/tests/test_chained_token_credential_async.py @@ -5,7 +5,7 @@ import time from unittest.mock import Mock, patch -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessToken, AccessTokenInfo from azure.core.exceptions import ClientAuthenticationError from azure.identity import CredentialUnavailableError, ClientSecretCredential from azure.identity.aio import ChainedTokenCredential, ManagedIdentityCredential @@ -13,7 +13,7 @@ from azure.identity._internal.user_agent import USER_AGENT import pytest -from helpers import mock_response, Request +from helpers import mock_response, Request, GET_TOKEN_METHODS from helpers_async import get_completed_future, wrap_in_future, async_validating_transport @@ -41,18 +41,23 @@ async def test_context_manager(): @pytest.mark.asyncio -async def test_credential_chain_error_message(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_credential_chain_error_message(get_token_method): first_error = "first_error" first_credential = Mock( - spec=ClientSecretCredential, get_token=Mock(side_effect=CredentialUnavailableError(first_error)) + spec=ClientSecretCredential, + get_token=Mock(side_effect=CredentialUnavailableError(first_error)), + get_token_info=Mock(side_effect=CredentialUnavailableError(first_error)), ) second_error = "second_error" second_credential = Mock( - name="second_credential", get_token=Mock(side_effect=ClientAuthenticationError(second_error)) + name="second_credential", + get_token=Mock(side_effect=ClientAuthenticationError(second_error)), + get_token_info=Mock(side_effect=ClientAuthenticationError(second_error)), ) with pytest.raises(ClientAuthenticationError) as ex: - await ChainedTokenCredential(first_credential, second_credential).get_token("scope") + await getattr(ChainedTokenCredential(first_credential, second_credential), get_token_method)("scope") assert "ClientSecretCredential" in ex.value.message assert first_error in ex.value.message @@ -60,26 +65,40 @@ async def test_credential_chain_error_message(): @pytest.mark.asyncio -async def test_chain_attempts_all_credentials(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_chain_attempts_all_credentials(get_token_method): async def credential_unavailable(message="it didn't work", **_): raise CredentialUnavailableError(message) - expected_token = AccessToken("expected_token", 0) + access_token = "expected_token" credentials = [ - Mock(spec_set=["get_token"], get_token=Mock(wraps=credential_unavailable)), - Mock(spec_set=["get_token"], get_token=Mock(wraps=credential_unavailable)), - Mock(spec_set=["get_token"], get_token=wrap_in_future(lambda _, **__: expected_token)), + Mock( + spec_set=["get_token", "get_token_info"], + get_token=Mock(wraps=credential_unavailable), + get_token_info=Mock(wraps=credential_unavailable), + ), + Mock( + spec_set=["get_token", "get_token_info"], + get_token=Mock(wraps=credential_unavailable), + get_token_info=Mock(wraps=credential_unavailable), + ), + Mock( + spec_set=["get_token", "get_token_info"], + get_token=wrap_in_future(lambda _, **__: AccessToken(access_token, 42)), + get_token_info=wrap_in_future(lambda _, **__: AccessTokenInfo(access_token, 42)), + ), ] - token = await ChainedTokenCredential(*credentials).get_token("scope") - assert token is expected_token + token = await getattr(ChainedTokenCredential(*credentials), get_token_method)("scope") + assert token.token == access_token for credential in credentials[:-1]: - assert credential.get_token.call_count == 1 + assert getattr(credential, get_token_method).call_count == 1 @pytest.mark.asyncio -async def test_chain_raises_for_unexpected_error(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_chain_raises_for_unexpected_error(get_token_method): """the chain should not continue after an unexpected error (i.e. anything but CredentialUnavailableError)""" async def credential_unavailable(message="it didn't work", **_): @@ -88,36 +107,53 @@ async def credential_unavailable(message="it didn't work", **_): expected_message = "it can't be done" credentials = [ - Mock(spec_set=["get_token"], get_token=Mock(wraps=credential_unavailable)), - Mock(spec_set=["get_token"], get_token=Mock(side_effect=ValueError(expected_message))), - Mock(spec_set=["get_token"], get_token=Mock(wraps=wrap_in_future(lambda _, **__: AccessToken("**", 42)))), + Mock( + spec_set=["get_token", "get_token_info"], + get_token=Mock(wraps=credential_unavailable), + get_token_info=Mock(wraps=credential_unavailable), + ), + Mock( + spec_set=["get_token", "get_token_info"], + get_token=Mock(side_effect=ValueError(expected_message)), + get_token_info=Mock(side_effect=ValueError(expected_message)), + ), + Mock( + spec_set=["get_token", "get_token_info"], + get_token=Mock(wraps=wrap_in_future(lambda _, **__: AccessToken("**", 42))), + get_token_info=Mock(wraps=wrap_in_future(lambda _, **__: AccessTokenInfo("**", 42))), + ), ] with pytest.raises(ClientAuthenticationError) as ex: - await ChainedTokenCredential(*credentials).get_token("scope") + await getattr(ChainedTokenCredential(*credentials), get_token_method)("scope") assert expected_message in ex.value.message - assert credentials[-1].get_token.call_count == 0 + assert getattr(credentials[-1], get_token_method).call_count == 0 @pytest.mark.asyncio -async def test_returns_first_token(): - expected_token = Mock() - first_credential = Mock(spec_set=["get_token"], get_token=wrap_in_future(lambda _, **__: expected_token)) - second_credential = Mock(spec_set=["get_token"], get_token=Mock()) +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_returns_first_token(get_token_method): + access_token = "expected_token" + first_credential = Mock( + spec_set=["get_token", "get_token_info"], + get_token=wrap_in_future(lambda _, **__: AccessToken(access_token, 42)), + get_token_info=wrap_in_future(lambda _, **__: AccessTokenInfo(access_token, 42)), + ) + second_credential = Mock(spec_set=["get_token", "get_token_info"], get_token=Mock(), get_token_info=Mock()) aggregate = ChainedTokenCredential(first_credential, second_credential) - credential = await aggregate.get_token("scope") + token = await getattr(aggregate, get_token_method)("scope") - assert credential is expected_token - assert second_credential.get_token.call_count == 0 + assert token.token == access_token + assert getattr(second_credential, get_token_method).call_count == 0 @pytest.mark.asyncio -async def test_managed_identity_imds_probe(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_managed_identity_imds_probe(get_token_method): access_token = "****" expires_on = 42 - expected_token = AccessToken(access_token, expires_on) scope = "scope" transport = async_validating_transport( requests=[ @@ -148,32 +184,71 @@ async def test_managed_identity_imds_probe(): # ensure e.g. $MSI_ENDPOINT isn't set, so we get ImdsCredential with patch.dict("os.environ", clear=True): credentials = [ - Mock(spec_set=["get_token"], get_token=Mock(side_effect=CredentialUnavailableError(message=""))), + Mock( + spec_set=["get_token", "get_token_info"], + get_token=Mock(side_effect=CredentialUnavailableError(message="")), + get_token_info=Mock(side_effect=CredentialUnavailableError(message="")), + ), ManagedIdentityCredential(transport=transport), ] - token = await ChainedTokenCredential(*credentials).get_token(scope) - assert token == expected_token + token = await getattr(ChainedTokenCredential(*credentials), get_token_method)(scope) + assert token.token == access_token @pytest.mark.asyncio -async def test_managed_identity_failed_probe(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_managed_identity_failed_probe(get_token_method): async def credential_unavailable(message="it didn't work", **_): raise CredentialUnavailableError(message) mock_send = Mock(side_effect=Exception("timeout")) transport = Mock(send=wrap_in_future(mock_send)) - expected_token = AccessToken("**", 42) + expected_token = "***" credentials = [ - Mock(spec_set=["get_token"], get_token=Mock(wraps=credential_unavailable)), + Mock( + spec_set=["get_token", "get_token_info"], + get_token=Mock(wraps=credential_unavailable), + get_token_info=Mock(wraps=credential_unavailable), + ), ManagedIdentityCredential(transport=transport), - Mock(spec_set=["get_token"], get_token=Mock(wraps=wrap_in_future(lambda _, **__: expected_token))), + Mock( + spec_set=["get_token", "get_token_info"], + get_token=Mock(wraps=wrap_in_future(lambda _, **__: AccessToken(expected_token, 42))), + get_token_info=Mock(wraps=wrap_in_future(lambda _, **__: AccessTokenInfo(expected_token, 42))), + ), ] with patch.dict("os.environ", clear=True): - token = await ChainedTokenCredential(*credentials).get_token("scope") + token = await getattr(ChainedTokenCredential(*credentials), get_token_method)("scope") - assert token is expected_token + assert token.token == expected_token # ManagedIdentityCredential should be tried and skipped with the last credential in the chain # being used. - assert credentials[-1].get_token.call_count == 1 + assert getattr(credentials[-1], get_token_method).call_count == 1 + + +@pytest.mark.asyncio +async def test_credentials_with_no_get_token_info(): + """ChainedTokenCredential should work with credentials that don't implement get_token_info.""" + + async def credential_unavailable(message="it didn't work", **_): + raise CredentialUnavailableError(message) + + access_token = "****" + credential1 = Mock( + spec_set=["get_token_info"], + get_token_info=Mock(wraps=credential_unavailable), + ) + credential2 = Mock( + spec_set=["get_token"], + get_token=Mock(wraps=wrap_in_future(lambda _, **__: AccessToken(access_token, 42))), + ) + credential3 = Mock( + spec_set=["get_token", "get_token_info"], + get_token=Mock(wraps=wrap_in_future(lambda _, **__: AccessToken("foo", 42))), + get_token_info=Mock(wraps=wrap_in_future(lambda _, **__: AccessTokenInfo("bar", 42))), + ) + chain = ChainedTokenCredential(credential1, credential2, credential3) # type: ignore + token_info = await chain.get_token_info("scope") + assert token_info.token == access_token diff --git a/sdk/identity/azure-identity/tests/test_cli_credential.py b/sdk/identity/azure-identity/tests/test_cli_credential.py index 0b688df36faf..2ed7c5657f75 100644 --- a/sdk/identity/azure-identity/tests/test_cli_credential.py +++ b/sdk/identity/azure-identity/tests/test_cli_credential.py @@ -3,6 +3,7 @@ # Licensed under the MIT License. # ------------------------------------ from datetime import datetime +from itertools import product import json import re @@ -14,7 +15,7 @@ import subprocess import pytest -from helpers import mock, INVALID_CHARACTERS +from helpers import mock, INVALID_CHARACTERS, GET_TOKEN_METHODS CHECK_OUTPUT = AzureCliCredential.__module__ + ".subprocess.check_output" @@ -35,21 +36,24 @@ def raise_called_process_error(return_code, output="", cmd="...", stderr=""): return mock.Mock(side_effect=error) -def test_no_scopes(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_no_scopes(get_token_method): """The credential should raise ValueError when get_token is called with no scopes""" with pytest.raises(ValueError): - AzureCliCredential().get_token() + getattr(AzureCliCredential(), get_token_method)() -def test_multiple_scopes(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_multiple_scopes(get_token_method): """The credential should raise ValueError when get_token is called with more than one scope""" with pytest.raises(ValueError): - AzureCliCredential().get_token("one scope", "and another") + getattr(AzureCliCredential(), get_token_method)("one scope", "and another") -def test_invalid_tenant_id(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_invalid_tenant_id(get_token_method): """Invalid tenant IDs should raise ValueErrors.""" for c in INVALID_CHARACTERS: @@ -57,18 +61,23 @@ def test_invalid_tenant_id(): AzureCliCredential(tenant_id="tenant" + c) with pytest.raises(ValueError): - AzureCliCredential().get_token("scope", tenant_id="tenant" + c) + kwargs = {"tenant_id": "tenant" + c} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + getattr(AzureCliCredential(), get_token_method)("scope", **kwargs) -def test_invalid_scopes(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_invalid_scopes(get_token_method): """Scopes with invalid characters should raise ValueErrors.""" for c in INVALID_CHARACTERS: with pytest.raises(ValueError): - AzureCliCredential().get_token("scope" + c) + getattr(AzureCliCredential(), get_token_method)("scope" + c) -def test_get_token(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_get_token(get_token_method): """The credential should parse the CLI's output to an AccessToken""" access_token = "access token" @@ -85,14 +94,15 @@ def test_get_token(): with mock.patch("shutil.which", return_value="az"): with mock.patch(CHECK_OUTPUT, mock.Mock(return_value=successful_output)): - token = AzureCliCredential().get_token("scope") + token = getattr(AzureCliCredential(), get_token_method)("scope") assert token.token == access_token assert type(token.expires_on) == int assert token.expires_on == expected_expires_on -def test_expires_on_used(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_expires_on_used(get_token_method): """Test that 'expires_on' is preferred over 'expiresOn'.""" expires_on = 1602015811 successful_output = json.dumps( @@ -108,12 +118,13 @@ def test_expires_on_used(): with mock.patch("shutil.which", return_value="az"): with mock.patch(CHECK_OUTPUT, mock.Mock(return_value=successful_output)): - token = AzureCliCredential().get_token("scope") + token = getattr(AzureCliCredential(), get_token_method)("scope") assert token.expires_on == expires_on -def test_expires_on_string(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_expires_on_string(get_token_method): """Test that 'expires_on' still works if it's a string.""" expires_on = 1602015811 successful_output = json.dumps( @@ -128,85 +139,91 @@ def test_expires_on_string(): with mock.patch("shutil.which", return_value="az"): with mock.patch(CHECK_OUTPUT, mock.Mock(return_value=successful_output)): - token = AzureCliCredential().get_token("scope") + token = getattr(AzureCliCredential(), get_token_method)("scope") assert type(token.expires_on) == int assert token.expires_on == expires_on -def test_cli_not_installed(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_cli_not_installed(get_token_method): """The credential should raise CredentialUnavailableError when the CLI isn't installed""" with mock.patch("shutil.which", return_value=None): with pytest.raises(CredentialUnavailableError, match=CLI_NOT_FOUND): - AzureCliCredential().get_token("scope") + getattr(AzureCliCredential(), get_token_method)("scope") -def test_cannot_execute_shell(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_cannot_execute_shell(get_token_method): """The credential should raise CredentialUnavailableError when the subprocess doesn't start""" with mock.patch("shutil.which", return_value="az"): with mock.patch(CHECK_OUTPUT, mock.Mock(side_effect=OSError())): with pytest.raises(CredentialUnavailableError): - AzureCliCredential().get_token("scope") + getattr(AzureCliCredential(), get_token_method)("scope") -def test_not_logged_in(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_not_logged_in(get_token_method): """When the CLI isn't logged in, the credential should raise CredentialUnavailableError""" stderr = "ERROR: Please run 'az login' to setup account." with mock.patch("shutil.which", return_value="az"): with mock.patch(CHECK_OUTPUT, raise_called_process_error(1, stderr=stderr)): with pytest.raises(CredentialUnavailableError, match=NOT_LOGGED_IN): - AzureCliCredential().get_token("scope") + getattr(AzureCliCredential(), get_token_method)("scope") -def test_aadsts_error(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_aadsts_error(get_token_method): """When the CLI isn't logged in, the credential should raise CredentialUnavailableError""" stderr = "ERROR: AADSTS70043: The refresh token has expired, Please run 'az login' to setup account." with mock.patch("shutil.which", return_value="az"): with mock.patch(CHECK_OUTPUT, raise_called_process_error(1, stderr=stderr)): with pytest.raises(ClientAuthenticationError, match=stderr): - AzureCliCredential().get_token("scope") + getattr(AzureCliCredential(), get_token_method)("scope") -def test_unexpected_error(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_unexpected_error(get_token_method): """When the CLI returns an unexpected error, the credential should raise an error containing the CLI's output""" stderr = "something went wrong" with mock.patch("shutil.which", return_value="az"): with mock.patch(CHECK_OUTPUT, raise_called_process_error(42, stderr=stderr)): with pytest.raises(ClientAuthenticationError, match=stderr): - AzureCliCredential().get_token("scope") + getattr(AzureCliCredential(), get_token_method)("scope") -@pytest.mark.parametrize("output", TEST_ERROR_OUTPUTS) -def test_parsing_error_does_not_expose_token(output): +@pytest.mark.parametrize("output,get_token_method", product(TEST_ERROR_OUTPUTS, GET_TOKEN_METHODS)) +def test_parsing_error_does_not_expose_token(output, get_token_method): """Errors during CLI output parsing shouldn't expose access tokens in that output""" with mock.patch("shutil.which", return_value="az"): with mock.patch(CHECK_OUTPUT, mock.Mock(return_value=output)): with pytest.raises(ClientAuthenticationError) as ex: - AzureCliCredential().get_token("scope") + getattr(AzureCliCredential(), get_token_method)("scope") assert "secret value" not in str(ex.value) assert "secret value" not in repr(ex.value) -@pytest.mark.parametrize("output", TEST_ERROR_OUTPUTS) -def test_subprocess_error_does_not_expose_token(output): +@pytest.mark.parametrize("output,get_token_method", product(TEST_ERROR_OUTPUTS, GET_TOKEN_METHODS)) +def test_subprocess_error_does_not_expose_token(output, get_token_method): """Errors from the subprocess shouldn't expose access tokens in CLI output""" with mock.patch("shutil.which", return_value="az"): with mock.patch(CHECK_OUTPUT, raise_called_process_error(1, output=output)): with pytest.raises(ClientAuthenticationError) as ex: - AzureCliCredential().get_token("scope") + getattr(AzureCliCredential(), get_token_method)("scope") assert "secret value" not in str(ex.value) assert "secret value" not in repr(ex.value) -def test_timeout(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_timeout(get_token_method): """The credential should raise CredentialUnavailableError when the subprocess times out""" from subprocess import TimeoutExpired @@ -214,7 +231,7 @@ def test_timeout(): with mock.patch("shutil.which", return_value="az"): with mock.patch(CHECK_OUTPUT, mock.Mock(side_effect=TimeoutExpired("", 42))) as check_output_mock: with pytest.raises(CredentialUnavailableError): - AzureCliCredential(process_timeout=42).get_token("scope") + getattr(AzureCliCredential(process_timeout=42), get_token_method)("scope") # Ensure custom timeout is passed to subprocess _, kwargs = check_output_mock.call_args @@ -222,7 +239,8 @@ def test_timeout(): assert kwargs["timeout"] == 42 -def test_multitenant_authentication_class(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_multitenant_authentication_class(get_token_method): default_tenant = "first-tenant" first_token = "***" second_tenant = "second-tenant" @@ -244,17 +262,18 @@ def fake_check_output(command_line, **_): with mock.patch("shutil.which", return_value="az"): with mock.patch(CHECK_OUTPUT, fake_check_output): - token = AzureCliCredential().get_token("scope") + token = getattr(AzureCliCredential(), get_token_method)("scope") assert token.token == first_token - token = AzureCliCredential(tenant_id=default_tenant).get_token("scope") + token = getattr(AzureCliCredential(tenant_id=default_tenant), get_token_method)("scope") assert token.token == first_token - token = AzureCliCredential(tenant_id=second_tenant).get_token("scope") + token = getattr(AzureCliCredential(tenant_id=second_tenant), get_token_method)("scope") assert token.token == second_token -def test_multitenant_authentication(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_multitenant_authentication(get_token_method): default_tenant = "first-tenant" first_token = "***" second_tenant = "second-tenant" @@ -277,21 +296,28 @@ def fake_check_output(command_line, **_): credential = AzureCliCredential() with mock.patch("shutil.which", return_value="az"): with mock.patch(CHECK_OUTPUT, fake_check_output): - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == first_token - token = credential.get_token("scope", tenant_id=default_tenant) + kwargs = {"tenant_id": default_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(credential, get_token_method)("scope", **kwargs) assert token.token == first_token - token = credential.get_token("scope", tenant_id=second_tenant) + kwargs = {"tenant_id": second_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(credential, get_token_method)("scope", **kwargs) assert token.token == second_token # should still default to the first tenant - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == first_token -def test_multitenant_authentication_not_allowed(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_multitenant_authentication_not_allowed(get_token_method): expected_tenant = "expected-tenant" expected_token = "***" @@ -311,9 +337,12 @@ def fake_check_output(command_line, **_): credential = AzureCliCredential() with mock.patch("shutil.which", return_value="az"): with mock.patch(CHECK_OUTPUT, fake_check_output): - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == expected_token with mock.patch.dict("os.environ", {EnvironmentVariables.AZURE_IDENTITY_DISABLE_MULTITENANTAUTH: "true"}): - token = credential.get_token("scope", tenant_id="un" + expected_tenant) + kwargs = {"tenant_id": "un" + expected_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(credential, get_token_method)("scope", **kwargs) assert token.token == expected_token diff --git a/sdk/identity/azure-identity/tests/test_cli_credential_async.py b/sdk/identity/azure-identity/tests/test_cli_credential_async.py index d771dd7d77df..7865a3e6b384 100644 --- a/sdk/identity/azure-identity/tests/test_cli_credential_async.py +++ b/sdk/identity/azure-identity/tests/test_cli_credential_async.py @@ -4,6 +4,7 @@ # ------------------------------------ import asyncio from datetime import datetime +from itertools import product import json import re import sys @@ -16,7 +17,7 @@ from azure.core.exceptions import ClientAuthenticationError import pytest -from helpers import INVALID_CHARACTERS +from helpers import INVALID_CHARACTERS, GET_TOKEN_METHODS from helpers_async import get_completed_future from test_cli_credential import TEST_ERROR_OUTPUTS @@ -33,21 +34,24 @@ async def communicate(): return mock.Mock(return_value=get_completed_future(process)) -async def test_no_scopes(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_no_scopes(get_token_method): """The credential should raise ValueError when get_token is called with no scopes""" with pytest.raises(ValueError): - await AzureCliCredential().get_token() + await getattr(AzureCliCredential(), get_token_method)() -async def test_multiple_scopes(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_multiple_scopes(get_token_method): """The credential should raise ValueError when get_token is called with more than one scope""" with pytest.raises(ValueError): - await AzureCliCredential().get_token("one scope", "and another") + await getattr(AzureCliCredential(), get_token_method)("one scope", "and another") -async def test_invalid_tenant_id(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_invalid_tenant_id(get_token_method): """Invalid tenant IDs should raise ValueErrors.""" for c in INVALID_CHARACTERS: @@ -55,15 +59,19 @@ async def test_invalid_tenant_id(): AzureCliCredential(tenant_id="tenant" + c) with pytest.raises(ValueError): - await AzureCliCredential().get_token("scope", tenant_id="tenant" + c) + kwargs = {"tenant_id": "tenant" + c} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + await getattr(AzureCliCredential(), get_token_method)("scope", **kwargs) -async def test_invalid_scopes(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_invalid_scopes(get_token_method): """Scopes with invalid characters should raise ValueErrors.""" for c in INVALID_CHARACTERS: with pytest.raises(ValueError): - await AzureCliCredential().get_token("https://scope" + c) + await getattr(AzureCliCredential(), get_token_method)("https://scope" + c) async def test_close(): @@ -80,21 +88,25 @@ async def test_context_manager(): @pytest.mark.skipif(not sys.platform.startswith("win"), reason="tests Windows-specific behavior") -async def test_windows_fallback(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_windows_fallback(get_token_method): """The credential should fall back to the sync implementation when not using ProactorEventLoop on Windows""" sync_get_token = mock.Mock() with mock.patch("azure.identity.aio._credentials.azure_cli._SyncAzureCliCredential") as fallback: - fallback.return_value = mock.Mock(spec_set=["get_token"], get_token=sync_get_token) + fallback.return_value = mock.Mock( + spec_set=["get_token", "get_token_info"], get_token=sync_get_token, get_token_info=sync_get_token + ) with mock.patch(AzureCliCredential.__module__ + ".asyncio.get_event_loop"): # asyncio.get_event_loop now returns Mock, i.e. never ProactorEventLoop credential = AzureCliCredential() - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") assert sync_get_token.call_count == 1 -async def test_get_token(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_get_token(get_token_method): """The credential should parse the CLI's output to an AccessToken""" access_token = "access token" @@ -112,14 +124,15 @@ async def test_get_token(): with mock.patch("shutil.which", return_value="az"): with mock.patch(SUBPROCESS_EXEC, mock_exec(successful_output)): credential = AzureCliCredential() - token = await credential.get_token("scope") + token = await getattr(credential, get_token_method)("scope") assert token.token == access_token assert type(token.expires_on) == int assert token.expires_on == expected_expires_on -async def test_expires_on_used(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_expires_on_used(get_token_method): """Test that 'expires_on' is preferred over 'expiresOn'.""" expires_on = 1602015811 successful_output = json.dumps( @@ -136,12 +149,13 @@ async def test_expires_on_used(): with mock.patch("shutil.which", return_value="az"): with mock.patch(SUBPROCESS_EXEC, mock_exec(successful_output)): credential = AzureCliCredential() - token = await credential.get_token("scope") + token = await getattr(credential, get_token_method)("scope") assert token.expires_on == expires_on -async def test_expires_on_string(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_expires_on_string(get_token_method): """Test that 'expires_on' still works if it's a string.""" expires_on = 1602015811 successful_output = json.dumps( @@ -157,32 +171,35 @@ async def test_expires_on_string(): with mock.patch("shutil.which", return_value="az"): with mock.patch(SUBPROCESS_EXEC, mock_exec(successful_output)): credential = AzureCliCredential() - token = await credential.get_token("scope") + token = await getattr(credential, get_token_method)("scope") assert type(token.expires_on) == int assert token.expires_on == expires_on -async def test_cli_not_installed(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_cli_not_installed(get_token_method): """The credential should raise CredentialUnavailableError when the CLI isn't installed""" with mock.patch("shutil.which", return_value=None): with pytest.raises(CredentialUnavailableError, match=CLI_NOT_FOUND): credential = AzureCliCredential() - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") -async def test_cannot_execute_shell(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_cannot_execute_shell(get_token_method): """The credential should raise CredentialUnavailableError when the subprocess doesn't start""" with mock.patch("shutil.which", return_value="az"): with mock.patch(SUBPROCESS_EXEC, mock.Mock(side_effect=OSError())): with pytest.raises(CredentialUnavailableError): credential = AzureCliCredential() - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") -async def test_not_logged_in(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_not_logged_in(get_token_method): """When the CLI isn't logged in, the credential should raise CredentialUnavailableError""" stderr = "ERROR: Please run 'az login' to setup account." @@ -190,10 +207,11 @@ async def test_not_logged_in(): with mock.patch(SUBPROCESS_EXEC, mock_exec("", stderr, return_code=1)): with pytest.raises(CredentialUnavailableError, match=NOT_LOGGED_IN): credential = AzureCliCredential() - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") -async def test_aadsts_error(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_aadsts_error(get_token_method): """When the CLI isn't logged in, the credential should raise CredentialUnavailableError""" stderr = "ERROR: AADSTS70043: The refresh token has expired, Please run 'az login' to setup account." @@ -201,10 +219,11 @@ async def test_aadsts_error(): with mock.patch(SUBPROCESS_EXEC, mock_exec("", stderr, return_code=1)): with pytest.raises(ClientAuthenticationError, match=stderr): credential = AzureCliCredential() - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") -async def test_unexpected_error(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_unexpected_error(get_token_method): """When the CLI returns an unexpected error, the credential should raise an error containing the CLI's output""" stderr = "something went wrong" @@ -212,50 +231,52 @@ async def test_unexpected_error(): with mock.patch(SUBPROCESS_EXEC, mock_exec("", stderr, return_code=42)): with pytest.raises(ClientAuthenticationError, match=stderr): credential = AzureCliCredential() - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") -@pytest.mark.parametrize("output", TEST_ERROR_OUTPUTS) -async def test_parsing_error_does_not_expose_token(output): +@pytest.mark.parametrize("output,get_token_method", product(TEST_ERROR_OUTPUTS, GET_TOKEN_METHODS)) +async def test_parsing_error_does_not_expose_token(output, get_token_method): """Errors during CLI output parsing shouldn't expose access tokens in that output""" with mock.patch("shutil.which", return_value="az"): with mock.patch(SUBPROCESS_EXEC, mock_exec(output)): with pytest.raises(ClientAuthenticationError) as ex: credential = AzureCliCredential() - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") assert "secret value" not in str(ex.value) assert "secret value" not in repr(ex.value) -@pytest.mark.parametrize("output", TEST_ERROR_OUTPUTS) -async def test_subprocess_error_does_not_expose_token(output): +@pytest.mark.parametrize("output,get_token_method", product(TEST_ERROR_OUTPUTS, GET_TOKEN_METHODS)) +async def test_subprocess_error_does_not_expose_token(output, get_token_method): """Errors from the subprocess shouldn't expose access tokens in CLI output""" with mock.patch("shutil.which", return_value="az"): with mock.patch(SUBPROCESS_EXEC, mock_exec(output, return_code=1)): with pytest.raises(ClientAuthenticationError) as ex: credential = AzureCliCredential() - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") assert "secret value" not in str(ex.value) assert "secret value" not in repr(ex.value) -async def test_timeout(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_timeout(get_token_method): """The credential should kill the subprocess after a timeout""" proc = mock.Mock(communicate=mock.Mock(side_effect=asyncio.TimeoutError), returncode=None) with mock.patch("shutil.which", return_value="az"): with mock.patch(SUBPROCESS_EXEC, mock.Mock(return_value=get_completed_future(proc))): with pytest.raises(CredentialUnavailableError): - await AzureCliCredential().get_token("scope") + await getattr(AzureCliCredential(), get_token_method)("scope") assert proc.communicate.call_count == 1 -async def test_multitenant_authentication(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_multitenant_authentication(get_token_method): default_tenant = "first-tenant" first_token = "***" second_tenant = "second-tenant" @@ -279,21 +300,28 @@ async def fake_exec(*args, **_): credential = AzureCliCredential() with mock.patch("shutil.which", return_value="az"): with mock.patch(SUBPROCESS_EXEC, fake_exec): - token = await credential.get_token("scope") + token = await getattr(credential, get_token_method)("scope") assert token.token == first_token - token = await credential.get_token("scope", tenant_id=default_tenant) + kwargs = {"tenant_id": default_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = await getattr(credential, get_token_method)("scope", **kwargs) assert token.token == first_token - token = await credential.get_token("scope", tenant_id=second_tenant) + kwargs = {"tenant_id": second_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = await getattr(credential, get_token_method)("scope", **kwargs) assert token.token == second_token # should still default to the first tenant - token = await credential.get_token("scope") + token = await getattr(credential, get_token_method)("scope") assert token.token == first_token -async def test_multitenant_authentication_not_allowed(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_multitenant_authentication_not_allowed(get_token_method): expected_tenant = "expected-tenant" expected_token = "***" @@ -314,9 +342,12 @@ async def fake_exec(*args, **_): credential = AzureCliCredential() with mock.patch("shutil.which", return_value="az"): with mock.patch(SUBPROCESS_EXEC, fake_exec): - token = await credential.get_token("scope") + token = await getattr(credential, get_token_method)("scope") assert token.token == expected_token with mock.patch.dict("os.environ", {EnvironmentVariables.AZURE_IDENTITY_DISABLE_MULTITENANTAUTH: "true"}): - token = await credential.get_token("scope", tenant_id="un" + expected_tenant) + kwargs = {"tenant_id": "un" + expected_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = await getattr(credential, get_token_method)("scope", **kwargs) assert token.token == expected_token diff --git a/sdk/identity/azure-identity/tests/test_client_assertion_credential.py b/sdk/identity/azure-identity/tests/test_client_assertion_credential.py index 30586a55777d..d1c1fb2beaff 100644 --- a/sdk/identity/azure-identity/tests/test_client_assertion_credential.py +++ b/sdk/identity/azure-identity/tests/test_client_assertion_credential.py @@ -7,8 +7,9 @@ from azure.identity._internal.aad_client_base import JWT_BEARER_ASSERTION from azure.identity import ClientAssertionCredential, TokenCachePersistenceOptions +import pytest -from helpers import build_aad_response, mock_response +from helpers import build_aad_response, mock_response, GET_TOKEN_METHODS def test_init_with_kwargs(): @@ -40,7 +41,8 @@ def test_context_manager(): assert transport.__exit__.called -def test_token_cache_persistence(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_token_cache_persistence(get_token_method): """The credential should use a persistent cache if cache_persistence_options are configured.""" access_token = "foo" @@ -72,12 +74,15 @@ def send(request, **kwargs): assert credential._client._cache is None assert credential._client._cae_cache is None - token = credential.get_token(scope) + token = getattr(credential, get_token_method)(scope) assert token.token == access_token assert load_persistent_cache.call_count == 1 assert credential._client._cache is not None assert credential._client._cae_cache is None - token = credential.get_token(scope, enable_cae=True) + kwargs = {"enable_cae": True} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(credential, get_token_method)(scope, **kwargs) assert load_persistent_cache.call_count == 2 assert credential._client._cae_cache is not None diff --git a/sdk/identity/azure-identity/tests/test_client_assertion_credential_async.py b/sdk/identity/azure-identity/tests/test_client_assertion_credential_async.py index 18a2dcb4b57a..ddf54bf35be3 100644 --- a/sdk/identity/azure-identity/tests/test_client_assertion_credential_async.py +++ b/sdk/identity/azure-identity/tests/test_client_assertion_credential_async.py @@ -10,7 +10,7 @@ from azure.identity import TokenCachePersistenceOptions from azure.identity.aio import ClientAssertionCredential -from helpers import build_aad_response, mock_response +from helpers import build_aad_response, mock_response, GET_TOKEN_METHODS def test_init_with_kwargs(): @@ -45,7 +45,8 @@ async def test_context_manager(): @pytest.mark.asyncio -async def test_token_cache_persistence(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_token_cache_persistence(get_token_method): """The credential should use a persistent cache if cache_persistence_options are configured.""" access_token = "foo" @@ -77,12 +78,15 @@ async def send(request, **kwargs): assert credential._client._cache is None assert credential._client._cae_cache is None - token = await credential.get_token(scope) + token = await getattr(credential, get_token_method)(scope) assert token.token == access_token assert load_persistent_cache.call_count == 1 assert credential._client._cache is not None assert credential._client._cae_cache is None - token = await credential.get_token(scope, enable_cae=True) + kwargs = {"enable_cae": True} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = await getattr(credential, get_token_method)(scope, **kwargs) assert load_persistent_cache.call_count == 2 assert credential._client._cae_cache is not None diff --git a/sdk/identity/azure-identity/tests/test_client_secret_credential.py b/sdk/identity/azure-identity/tests/test_client_secret_credential.py index 7b9b9a195ef4..ae778605ad24 100644 --- a/sdk/identity/azure-identity/tests/test_client_secret_credential.py +++ b/sdk/identity/azure-identity/tests/test_client_secret_credential.py @@ -2,6 +2,10 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ +from itertools import product +from urllib.parse import urlparse +from unittest.mock import Mock, patch + from azure.core.pipeline.policies import ContentDecodePolicy, SansIOHTTPPolicy from azure.identity import ClientSecretCredential, TokenCachePersistenceOptions from azure.identity._enums import RegionalAuthority @@ -10,7 +14,6 @@ from msal import TokenCache import msal import pytest -from urllib.parse import urlparse from helpers import ( build_aad_response, @@ -18,11 +21,10 @@ get_discovery_response, id_token_claims, mock_response, - msal_validating_transport, new_msal_validating_transport, Request, + GET_TOKEN_METHODS, ) -from unittest.mock import Mock, patch def test_tenant_id_validation(): @@ -38,15 +40,17 @@ def test_tenant_id_validation(): ClientSecretCredential(tenant, "client-id", "secret") -def test_no_scopes(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_no_scopes(get_token_method): """The credential should raise ValueError when get_token is called with no scopes""" credential = ClientSecretCredential("tenant-id", "client-id", "client-secret") with pytest.raises(ValueError): - credential.get_token() + getattr(credential, get_token_method)() -def test_policies_configurable(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_policies_configurable(get_token_method): policy = Mock(spec_set=SansIOHTTPPolicy, on_request=Mock()) transport = new_msal_validating_transport( @@ -57,12 +61,13 @@ def test_policies_configurable(): "tenant-id", "client-id", "client-secret", policies=[ContentDecodePolicy(), policy], transport=transport ) - credential.get_token("scope") + getattr(credential, get_token_method)("scope") assert policy.on_request.called -def test_user_agent(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_user_agent(get_token_method): transport = new_msal_validating_transport( requests=[Request(required_headers={"User-Agent": USER_AGENT})], responses=[mock_response(json_payload=build_aad_response(access_token="**"))], @@ -70,10 +75,11 @@ def test_user_agent(): credential = ClientSecretCredential("tenant-id", "client-id", "client-secret", transport=transport) - credential.get_token("scope") + getattr(credential, get_token_method)("scope") -def test_client_secret_credential(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_client_secret_credential(get_token_method): client_id = "fake-client-id" secret = "fake-client-secret" tenant_id = "fake-tenant-id" @@ -85,13 +91,15 @@ def test_client_secret_credential(): responses=[mock_response(json_payload=build_aad_response(access_token=access_token))], ) - token = ClientSecretCredential(tenant_id, client_id, secret, transport=transport).get_token("scope") + token = getattr(ClientSecretCredential(tenant_id, client_id, secret, transport=transport), get_token_method)( + "scope" + ) assert token.token == access_token -@pytest.mark.parametrize("authority", ("localhost", "https://localhost")) -def test_authority(authority): +@pytest.mark.parametrize("authority,get_token_method", product(("localhost", "https://localhost"), GET_TOKEN_METHODS)) +def test_authority(authority, get_token_method): """the credential should accept an authority, with or without scheme, as an argument or environment variable""" tenant_id = "expected-tenant" @@ -106,7 +114,7 @@ def test_authority(authority): credential = ClientSecretCredential(tenant_id, "client-id", "secret", authority=authority) with patch("msal.ConfidentialClientApplication", mock_ctor): # must call get_token because the credential constructs the MSAL application lazily - credential.get_token("scope") + getattr(credential, get_token_method)("scope") assert mock_ctor.call_count == 1 _, kwargs = mock_ctor.call_args @@ -117,14 +125,15 @@ def test_authority(authority): with patch.dict("os.environ", {EnvironmentVariables.AZURE_AUTHORITY_HOST: authority}, clear=True): credential = ClientSecretCredential(tenant_id, "client-id", "secret") with patch("msal.ConfidentialClientApplication", mock_ctor): - credential.get_token("scope") + getattr(credential, get_token_method)("scope") assert mock_ctor.call_count == 1 _, kwargs = mock_ctor.call_args assert kwargs["authority"] == expected_authority -def test_regional_authority(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_regional_authority(get_token_method): """the credential should configure MSAL with a regional authority specified via kwarg or environment variable""" mock_confidential_client = Mock( @@ -138,7 +147,7 @@ def test_regional_authority(): with patch.dict("os.environ", {EnvironmentVariables.AZURE_REGIONAL_AUTHORITY_NAME: region.value}, clear=True): credential = ClientSecretCredential("tenant", "client-id", "secret") with patch("msal.ConfidentialClientApplication", mock_confidential_client): - credential.get_token("scope") + getattr(credential, get_token_method)("scope") assert mock_confidential_client.call_count == 1 _, kwargs = mock_confidential_client.call_args @@ -148,7 +157,8 @@ def test_regional_authority(): assert kwargs["azure_region"] == region.value -def test_token_cache_persistent(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_token_cache_persistent(get_token_method): """the credential should use a persistent cache if cache_persistence_options are configured""" access_token = "foo token" @@ -176,18 +186,22 @@ def send(request, **kwargs): assert credential._cache is None assert credential._cae_cache is None - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == access_token assert load_persistent_cache.call_count == 1 assert credential._cache is not None assert credential._cae_cache is None - token = credential.get_token("scope", enable_cae=True) + kwargs = {"enable_cae": True} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(credential, get_token_method)("scope", **kwargs) assert load_persistent_cache.call_count == 2 assert credential._cae_cache is not None -def test_token_cache_memory(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_token_cache_memory(get_token_method): """The credential should default to in-memory cache if no persistence options are provided.""" access_token = "foo token" @@ -205,18 +219,22 @@ def send(request, **kwargs): credential = ClientSecretCredential("tenant", "client-id", "secret", transport=Mock(send=send)) assert credential._cache is None - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == access_token assert isinstance(credential._cache, TokenCache) assert credential._cae_cache is None assert not load_persistent_cache.called - token = credential.get_token("scope", enable_cae=True) + kwargs = {"enable_cae": True} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(credential, get_token_method)("scope", **kwargs) assert isinstance(credential._cae_cache, TokenCache) assert not load_persistent_cache.called -def test_cache_multiple_clients(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_cache_multiple_clients(get_token_method): """the credential shouldn't use tokens issued to other service principals""" access_token_a = "token a" @@ -249,13 +267,13 @@ def test_cache_multiple_clients(): # A caches a token scope = "scope" - token_a = credential_a.get_token(scope) + token_a = getattr(credential_a, get_token_method)(scope) assert mock_cache_loader.call_count == 1 assert token_a.token == access_token_a assert transport_a.send.call_count == 2 # one MSAL discovery request, one token request # B should get a different token for the same scope - token_b = credential_b.get_token(scope) + token_b = getattr(credential_b, get_token_method)(scope) assert mock_cache_loader.call_count == 2 assert token_b.token == access_token_b assert transport_b.send.call_count == 2 @@ -263,7 +281,8 @@ def test_cache_multiple_clients(): assert len(list(cache.search(TokenCache.CredentialType.ACCESS_TOKEN))) == 2 -def test_multitenant_authentication(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_multitenant_authentication(get_token_method): first_tenant = "first-tenant" first_token = "***" second_tenant = "second-tenant" @@ -286,21 +305,28 @@ def send(request, **kwargs): credential = ClientSecretCredential( first_tenant, "client-id", "secret", transport=Mock(send=send), additionally_allowed_tenants=["*"] ) - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == first_token - token = credential.get_token("scope", tenant_id=first_tenant) + kwargs = {"tenant_id": first_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(credential, get_token_method)("scope", **kwargs) assert token.token == first_token - token = credential.get_token("scope", tenant_id=second_tenant) + kwargs = {"tenant_id": second_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(credential, get_token_method)("scope", **kwargs) assert token.token == second_token # should still default to the first tenant - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == first_token -def test_live_multitenant_authentication(live_service_principal): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_live_multitenant_authentication(live_service_principal, get_token_method): # first create a credential with a non-existent tenant credential = ClientSecretCredential( "...", @@ -309,12 +335,15 @@ def test_live_multitenant_authentication(live_service_principal): additionally_allowed_tenants=["*"], ) # then get a valid token for an actual tenant - token = credential.get_token("https://vault.azure.net/.default", tenant_id=live_service_principal["tenant_id"]) + token = getattr(credential, get_token_method)( + "https://vault.azure.net/.default", tenant_id=live_service_principal["tenant_id"] + ) assert token.token assert token.expires_on -def test_multitenant_authentication_not_allowed(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_multitenant_authentication_not_allowed(get_token_method): expected_tenant = "expected-tenant" expected_token = "***" @@ -332,18 +361,25 @@ def send(request, **kwargs): credential = ClientSecretCredential(expected_tenant, "client-id", "secret", transport=Mock(send=send)) - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == expected_token - token = credential.get_token("scope", tenant_id=expected_tenant) + kwargs = {"tenant_id": expected_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(credential, get_token_method)("scope", **kwargs) assert token.token == expected_token + kwargs = {"tenant_id": "un" + expected_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} with patch.dict("os.environ", {EnvironmentVariables.AZURE_IDENTITY_DISABLE_MULTITENANTAUTH: "true"}): - token = credential.get_token("scope", tenant_id="un" + expected_tenant) + token = getattr(credential, get_token_method)("scope", **kwargs) assert token.token == expected_token -def test_client_capabilities(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_client_capabilities(get_token_method): """The credential should configure MSAL for capability only if enable_cae is passed in.""" transport = Mock(send=Mock(side_effect=Exception("this test mocks MSAL, so no request should be sent"))) @@ -362,7 +398,8 @@ def test_client_capabilities(): assert kwargs["client_capabilities"] == ["CP1"] -def test_claims_challenge(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_claims_challenge(get_token_method): """get_token should pass any claims challenge to MSAL token acquisition APIs""" msal_acquire_token_result = dict( @@ -378,7 +415,10 @@ def test_claims_challenge(): msal_app.acquire_token_silent_with_error.return_value = None msal_app.acquire_token_for_client.return_value = msal_acquire_token_result - credential.get_token("scope", claims=expected_claims) + kwargs = {"claims": expected_claims} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + getattr(credential, get_token_method)("scope", **kwargs) assert msal_app.acquire_token_silent_with_error.call_count == 1 args, kwargs = msal_app.acquire_token_silent_with_error.call_args @@ -389,7 +429,8 @@ def test_claims_challenge(): assert kwargs["claims_challenge"] == expected_claims -def test_msal_kwargs_filtered(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_msal_kwargs_filtered(get_token_method): msal_acquire_token_result = dict( build_aad_response(access_token="**", id_token=build_id_token()), id_token_claims=id_token_claims("issuer", "subject", "audience", upn="upn"), @@ -402,10 +443,12 @@ def test_msal_kwargs_filtered(): msal_app.acquire_token_silent_with_error.return_value = None msal_app.acquire_token_for_client.return_value = msal_acquire_token_result - credential.get_token("scope", claims=expected_claims, correlation_id="foo", enable_cae=True) + kwargs = {"claims": expected_claims, "enable_cae": True} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + getattr(credential, get_token_method)("scope", **kwargs) assert msal_app.acquire_token_silent_with_error.call_count == 1 _, kwargs = msal_app.acquire_token_silent_with_error.call_args assert kwargs["claims_challenge"] == expected_claims - assert kwargs["correlation_id"] == "foo" assert "enable_cae" not in kwargs diff --git a/sdk/identity/azure-identity/tests/test_client_secret_credential_async.py b/sdk/identity/azure-identity/tests/test_client_secret_credential_async.py index 0b4a91ac6a11..d8044ec52695 100644 --- a/sdk/identity/azure-identity/tests/test_client_secret_credential_async.py +++ b/sdk/identity/azure-identity/tests/test_client_secret_credential_async.py @@ -3,6 +3,7 @@ # Licensed under the MIT License. # ------------------------------------ import time +from itertools import product from unittest.mock import Mock, patch from urllib.parse import urlparse @@ -15,7 +16,7 @@ from msal import TokenCache import pytest -from helpers import build_aad_response, mock_response, Request +from helpers import build_aad_response, mock_response, Request, GET_TOKEN_METHODS from helpers_async import async_validating_transport, AsyncMockTransport, wrap_in_future @@ -33,12 +34,13 @@ def test_tenant_id_validation(): @pytest.mark.asyncio -async def test_no_scopes(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_no_scopes(get_token_method): """The credential should raise ValueError when get_token is called with no scopes""" credential = ClientSecretCredential("tenant-id", "client-id", "client-secret") with pytest.raises(ValueError): - await credential.get_token() + await getattr(credential, get_token_method)() @pytest.mark.asyncio @@ -52,7 +54,8 @@ async def test_close(): @pytest.mark.asyncio -async def test_context_manager(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_context_manager(get_token_method): transport = AsyncMockTransport() credential = ClientSecretCredential("tenant-id", "client-id", "client-secret", transport=transport) @@ -64,7 +67,8 @@ async def test_context_manager(): @pytest.mark.asyncio -async def test_policies_configurable(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_policies_configurable(get_token_method): policy = Mock(spec_set=SansIOHTTPPolicy, on_request=Mock()) async def send(*_, **kwargs): @@ -77,13 +81,14 @@ async def send(*_, **kwargs): "tenant-id", "client-id", "client-secret", policies=[ContentDecodePolicy(), policy], transport=Mock(send=send) ) - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") assert policy.on_request.called @pytest.mark.asyncio -async def test_user_agent(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_user_agent(get_token_method): transport = async_validating_transport( requests=[Request(required_headers={"User-Agent": USER_AGENT})], responses=[mock_response(json_payload=build_aad_response(access_token="**"))], @@ -91,11 +96,12 @@ async def test_user_agent(): credential = ClientSecretCredential("tenant-id", "client-id", "client-secret", transport=transport) - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") @pytest.mark.asyncio -async def test_client_secret_credential(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_client_secret_credential(get_token_method): client_id = "fake-client-id" secret = "fake-client-secret" tenant_id = "fake-tenant-id" @@ -115,17 +121,18 @@ async def test_client_secret_credential(): ], ) - token = await ClientSecretCredential( - tenant_id=tenant_id, client_id=client_id, client_secret=secret, transport=transport - ).get_token("scope") + token = await getattr( + ClientSecretCredential(tenant_id=tenant_id, client_id=client_id, client_secret=secret, transport=transport), + get_token_method, + )("scope") # not validating expires_on because doing so requires monkeypatching time, and this is tested elsewhere assert token.token == access_token @pytest.mark.asyncio -@pytest.mark.parametrize("authority", ("localhost", "https://localhost")) -async def test_request_url(authority): +@pytest.mark.parametrize("authority,get_token_method", product(("localhost", "https://localhost"), GET_TOKEN_METHODS)) +async def test_request_url(authority, get_token_method): """the credential should accept an authority, with or without scheme, as an argument or environment variable""" tenant_id = "expected-tenant" @@ -143,22 +150,23 @@ async def mock_send(request, **kwargs): credential = ClientSecretCredential( tenant_id, "client-id", "secret", transport=Mock(send=mock_send), authority=authority ) - token = await credential.get_token("scope") + token = await getattr(credential, get_token_method)("scope") assert token.token == access_token # authority can be configured via environment variable with patch.dict("os.environ", {EnvironmentVariables.AZURE_AUTHORITY_HOST: authority}, clear=True): credential = ClientSecretCredential(tenant_id, "client-id", "secret", transport=Mock(send=mock_send)) - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") assert token.token == access_token @pytest.mark.asyncio -async def test_cache(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_cache(get_token_method): expired = "this token's expired" now = int(time.time()) expired_on = now - 3600 - expired_token = AccessToken(expired, expired_on) + expired_token = expired token_payload = { "access_token": expired, "expires_in": 0, @@ -175,22 +183,22 @@ async def test_cache(): # get_token initially returns the expired token because the credential # doesn't check whether tokens it receives from the service have expired - token = await credential.get_token(scope) - assert token == expired_token + token = await getattr(credential, get_token_method)(scope) + assert token.token == expired_token access_token = "new token" token_payload["access_token"] = access_token token_payload["expires_on"] = now + 3600 - valid_token = AccessToken(access_token, now + 3600) # second call should observe the cached token has expired, and request another - token = await credential.get_token(scope) - assert token == valid_token + token = await getattr(credential, get_token_method)(scope) + assert token.token == access_token assert mock_send.call_count == 2 @pytest.mark.asyncio -async def test_token_cache(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_token_cache(get_token_method): """the credential should default to an in memory cache, and optionally use a persistent cache""" access_token = "token" @@ -208,20 +216,24 @@ async def test_token_cache(): assert mock_token_cache.call_count == 0 assert not load_persistent_cache.called - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") assert mock_token_cache.call_count == 1 assert load_persistent_cache.call_count == 0 assert credential._client._cache is not None assert credential._client._cae_cache is None - await credential.get_token("scope", enable_cae=True) + kwargs = {"enable_cae": True} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + await getattr(credential, get_token_method)("scope", **kwargs) assert mock_token_cache.call_count == 2 assert load_persistent_cache.call_count == 0 assert credential._client._cae_cache is not None @pytest.mark.asyncio -async def test_token_cache_persistent(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_token_cache_persistent(get_token_method): """the credential should use persistent cache if passed in cache options.""" access_token = "token" @@ -241,14 +253,17 @@ async def test_token_cache_persistent(): cache_persistence_options=TokenCachePersistenceOptions(), transport=transport, ) - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") assert load_persistent_cache.call_count == 1 assert credential._client._cache is not None assert credential._client._cae_cache is None args, _ = load_persistent_cache.call_args assert args[1] is False - await credential.get_token("scope", enable_cae=True) + kwargs = {"enable_cae": True} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + await getattr(credential, get_token_method)("scope", **kwargs) assert load_persistent_cache.call_count == 2 assert credential._client._cae_cache is not None args, _ = load_persistent_cache.call_args @@ -256,7 +271,8 @@ async def test_token_cache_persistent(): @pytest.mark.asyncio -async def test_cache_multiple_clients(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_cache_multiple_clients(get_token_method): """the credential shouldn't use tokens issued to other service principals""" access_token_a = "token a" @@ -291,7 +307,7 @@ async def test_cache_multiple_clients(): # A caches a token scope = "scope" - token_a = await credential_a.get_token(scope) + token_a = await getattr(credential_a, get_token_method)(scope) assert token_a.token == access_token_a assert transport_a.send.call_count == 1 assert mock_cache_loader.call_count == 1 @@ -299,7 +315,7 @@ async def test_cache_multiple_clients(): assert args[1] is False # B should get a different token for the same scope - token_b = await credential_b.get_token(scope) + token_b = await getattr(credential_b, get_token_method)(scope) assert token_b.token == access_token_b assert transport_b.send.call_count == 1 assert mock_cache_loader.call_count == 2 @@ -308,7 +324,8 @@ async def test_cache_multiple_clients(): @pytest.mark.asyncio -async def test_multitenant_authentication(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_multitenant_authentication(get_token_method): first_tenant = "first-tenant" first_token = "***" second_tenant = "second-tenant" @@ -329,22 +346,29 @@ async def send(request, **kwargs): credential = ClientSecretCredential( first_tenant, "client-id", "secret", transport=Mock(send=send), additionally_allowed_tenants=["*"] ) - token = await credential.get_token("scope") + token = await getattr(credential, get_token_method)("scope") assert token.token == first_token - token = await credential.get_token("scope", tenant_id=first_tenant) + kwargs = {"tenant_id": first_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = await getattr(credential, get_token_method)("scope", **kwargs) assert token.token == first_token - token = await credential.get_token("scope", tenant_id=second_tenant) + kwargs = {"tenant_id": second_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = await getattr(credential, get_token_method)("scope", **kwargs) assert token.token == second_token # should still default to the first tenant - token = await credential.get_token("scope") + token = await getattr(credential, get_token_method)("scope") assert token.token == first_token @pytest.mark.asyncio -async def test_live_multitenant_authentication(live_service_principal): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_live_multitenant_authentication(live_service_principal, get_token_method): # first create a credential with a non-existent tenant credential = ClientSecretCredential( "...", @@ -352,16 +376,18 @@ async def test_live_multitenant_authentication(live_service_principal): live_service_principal["client_secret"], additionally_allowed_tenants=["*"], ) + kwargs = {"tenant_id": live_service_principal["tenant_id"]} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} # then get a valid token for an actual tenant - token = await credential.get_token( - "https://vault.azure.net/.default", tenant_id=live_service_principal["tenant_id"] - ) + token = await getattr(credential, get_token_method)("https://vault.azure.net/.default", **kwargs) assert token.token assert token.expires_on @pytest.mark.asyncio -async def test_multitenant_authentication_not_allowed(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_multitenant_authentication_not_allowed(get_token_method): expected_tenant = "expected-tenant" expected_token = "***" @@ -378,15 +404,21 @@ async def send(request, **kwargs): expected_tenant, "client-id", "secret", transport=Mock(send=send), additionally_allowed_tenants=["*"] ) - token = await credential.get_token("scope") + token = await getattr(credential, get_token_method)("scope") assert token.token == expected_token - token = await credential.get_token("scope", tenant_id=expected_tenant) + kwargs = {"tenant_id": expected_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = await getattr(credential, get_token_method)("scope", **kwargs) assert token.token == expected_token - token = await credential.get_token("scope", tenant_id="un" + expected_tenant) + kwargs = {"tenant_id": "un" + expected_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = await getattr(credential, get_token_method)("scope", **kwargs) assert token.token == expected_token * 2 with patch.dict("os.environ", {EnvironmentVariables.AZURE_IDENTITY_DISABLE_MULTITENANTAUTH: "true"}): - token = await credential.get_token("scope", tenant_id="un" + expected_tenant) + token = await getattr(credential, get_token_method)("scope", **kwargs) assert token.token == expected_token diff --git a/sdk/identity/azure-identity/tests/test_context_manager.py b/sdk/identity/azure-identity/tests/test_context_manager.py index e2d98c80eb07..ba97ef28fb94 100644 --- a/sdk/identity/azure-identity/tests/test_context_manager.py +++ b/sdk/identity/azure-identity/tests/test_context_manager.py @@ -2,10 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ -try: - from unittest.mock import MagicMock, patch -except ImportError: - from mock import MagicMock, patch # type: ignore +from unittest.mock import MagicMock, patch from azure.identity._credentials.application import AzureApplicationCredential from azure.identity import ( diff --git a/sdk/identity/azure-identity/tests/test_default.py b/sdk/identity/azure-identity/tests/test_default.py index bbcff2f013ec..b77bfbcbad20 100644 --- a/sdk/identity/azure-identity/tests/test_default.py +++ b/sdk/identity/azure-identity/tests/test_default.py @@ -4,7 +4,7 @@ # ------------------------------------ import os -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessToken, AccessTokenInfo from azure.identity import ( AzureCliCredential, AzureDeveloperCliCredential, @@ -22,7 +22,7 @@ import pytest from urllib.parse import urlparse -from helpers import mock_response, Request, validating_transport +from helpers import mock_response, Request, validating_transport, GET_TOKEN_METHODS from test_shared_cache_credential import build_aad_response, get_account_event, populated_cache from unittest.mock import MagicMock, Mock, patch @@ -50,28 +50,36 @@ def test_context_manager(): assert transport.__exit__.called -def test_iterates_only_once(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_iterates_only_once(get_token_method): """When a credential succeeds, DefaultAzureCredential should use that credential thereafter, ignoring the others""" unavailable_credential = Mock( - spec_set=["get_token"], get_token=Mock(side_effect=CredentialUnavailableError(message="...")) + spec_set=["get_token", "get_token_info"], + get_token=Mock(side_effect=CredentialUnavailableError(message="...")), + get_token_info=Mock(side_effect=CredentialUnavailableError(message="...")), + ) + successful_credential = Mock( + spec_set=["get_token", "get_token_info"], + get_token=Mock(return_value=AccessToken("***", 42)), + get_token_info=Mock(return_value=AccessTokenInfo("***", 42)), ) - successful_credential = Mock(spec_set=["get_token"], get_token=Mock(return_value=AccessToken("***", 42))) credential = DefaultAzureCredential() - credential.credentials = [ + credential.credentials = ( unavailable_credential, successful_credential, Mock( - spec_set=["get_token"], + spec_set=["get_token", "get_token_info"], get_token=Mock(side_effect=Exception("iteration didn't stop after a credential provided a token")), + get_token_info=Mock(side_effect=Exception("iteration didn't stop after a credential provided a token")), ), - ] + ) for n in range(3): - credential.get_token("scope") - assert unavailable_credential.get_token.call_count == 1 - assert successful_credential.get_token.call_count == n + 1 + getattr(credential, get_token_method)("scope") + assert getattr(unavailable_credential, get_token_method).call_count == 1 + assert getattr(successful_credential, get_token_method).call_count == n + 1 @pytest.mark.parametrize("authority", ("localhost", "https://localhost")) @@ -174,7 +182,8 @@ def assert_credentials_not_present(chain, *excluded_credential_classes): assert actual - default == {InteractiveBrowserCredential} -def test_shared_cache_tenant_id(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_shared_cache_tenant_id(get_token_method): expected_access_token = "expected-access-token" refresh_token_a = "refresh-token-a" refresh_token_b = "refresh-token-b" @@ -195,14 +204,14 @@ def test_shared_cache_tenant_id(): credential = get_credential_for_shared_cache_test( refresh_token_b, expected_access_token, cache, shared_cache_tenant_id=tenant_b ) - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == expected_access_token # redundantly specifying shared_cache_username makes no difference credential = get_credential_for_shared_cache_test( refresh_token_b, expected_access_token, cache, shared_cache_tenant_id=tenant_b, shared_cache_username=upn ) - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == expected_access_token # shared_cache_tenant_id should prevail over AZURE_TENANT_ID @@ -210,17 +219,18 @@ def test_shared_cache_tenant_id(): credential = get_credential_for_shared_cache_test( refresh_token_b, expected_access_token, cache, shared_cache_tenant_id=tenant_b ) - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == expected_access_token # AZURE_TENANT_ID should be used when shared_cache_tenant_id isn't specified with patch("os.environ", {EnvironmentVariables.AZURE_TENANT_ID: tenant_b}): credential = get_credential_for_shared_cache_test(refresh_token_b, expected_access_token, cache) - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == expected_access_token -def test_shared_cache_username(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_shared_cache_username(get_token_method): expected_access_token = "expected-access-token" refresh_token_a = "refresh-token-a" refresh_token_b = "refresh-token-b" @@ -240,14 +250,14 @@ def test_shared_cache_username(): credential = get_credential_for_shared_cache_test( refresh_token_a, expected_access_token, cache, shared_cache_username=upn_a ) - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == expected_access_token # redundantly specifying shared_cache_tenant_id makes no difference credential = get_credential_for_shared_cache_test( refresh_token_a, expected_access_token, cache, shared_cache_tenant_id=tenant_id, shared_cache_username=upn_a ) - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == expected_access_token # shared_cache_username should prevail over AZURE_USERNAME @@ -255,13 +265,13 @@ def test_shared_cache_username(): credential = get_credential_for_shared_cache_test( refresh_token_a, expected_access_token, cache, shared_cache_username=upn_a ) - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == expected_access_token # AZURE_USERNAME should be used when shared_cache_username isn't specified with patch("os.environ", {EnvironmentVariables.AZURE_USERNAME: upn_b}): credential = get_credential_for_shared_cache_test(refresh_token_b, expected_access_token, cache) - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == expected_access_token diff --git a/sdk/identity/azure-identity/tests/test_default_async.py b/sdk/identity/azure-identity/tests/test_default_async.py index 57b466ebb306..18a942b705a8 100644 --- a/sdk/identity/azure-identity/tests/test_default_async.py +++ b/sdk/identity/azure-identity/tests/test_default_async.py @@ -6,7 +6,7 @@ from unittest.mock import Mock, patch from urllib.parse import urlparse -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessToken, AccessTokenInfo from azure.identity import CredentialUnavailableError from azure.identity.aio import ( AzurePowerShellCredential, @@ -20,36 +20,42 @@ from azure.identity._constants import EnvironmentVariables import pytest -from helpers import mock_response, Request +from helpers import mock_response, Request, GET_TOKEN_METHODS from helpers_async import async_validating_transport, get_completed_future, wrap_in_future from test_shared_cache_credential import build_aad_response, get_account_event, populated_cache @pytest.mark.asyncio -async def test_iterates_only_once(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_iterates_only_once(get_token_method): """When a credential succeeds, DefaultAzureCredential should use that credential thereafter, ignoring the others""" unavailable_credential = Mock( - spec_set=["get_token"], get_token=Mock(side_effect=CredentialUnavailableError(message="...")) + spec_set=["get_token", "get_token_info"], + get_token=Mock(side_effect=CredentialUnavailableError(message="...")), + get_token_info=Mock(side_effect=CredentialUnavailableError(message="...")), ) successful_credential = Mock( - spec_set=["get_token"], get_token=Mock(return_value=get_completed_future(AccessToken("***", 42))) + spec_set=["get_token", "get_token_info"], + get_token=Mock(return_value=get_completed_future(AccessToken("***", 42))), + get_token_info=Mock(return_value=get_completed_future(AccessTokenInfo("***", 42))), ) credential = DefaultAzureCredential() - credential.credentials = [ + credential.credentials = ( unavailable_credential, successful_credential, Mock( - spec_set=["get_token"], + spec_set=["get_token", "get_token_info"], get_token=Mock(side_effect=Exception("iteration didn't stop after a credential provided a token")), + get_token_info=Mock(side_effect=Exception("iteration didn't stop after a credential provided a token")), ), - ] + ) for n in range(3): - await credential.get_token("scope") - assert unavailable_credential.get_token.call_count == 1 - assert successful_credential.get_token.call_count == n + 1 + await getattr(credential, get_token_method)("scope") + assert getattr(unavailable_credential, get_token_method).call_count == 1 + assert getattr(successful_credential, get_token_method).call_count == n + 1 @pytest.mark.parametrize("authority", ("localhost", "https://localhost")) @@ -145,7 +151,8 @@ def assert_credentials_not_present(chain, *credential_classes): @pytest.mark.asyncio -async def test_shared_cache_tenant_id(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_shared_cache_tenant_id(get_token_method): expected_access_token = "expected-access-token" refresh_token_a = "refresh-token-a" refresh_token_b = "refresh-token-b" @@ -166,14 +173,14 @@ async def test_shared_cache_tenant_id(): credential = get_credential_for_shared_cache_test( refresh_token_b, expected_access_token, cache, shared_cache_tenant_id=tenant_b ) - token = await credential.get_token("scope") + token = await getattr(credential, get_token_method)("scope") assert token.token == expected_access_token # redundantly specifying shared_cache_username makes no difference credential = get_credential_for_shared_cache_test( refresh_token_b, expected_access_token, cache, shared_cache_tenant_id=tenant_b, shared_cache_username=upn ) - token = await credential.get_token("scope") + token = await getattr(credential, get_token_method)("scope") assert token.token == expected_access_token # shared_cache_tenant_id should prevail over AZURE_TENANT_ID @@ -181,18 +188,19 @@ async def test_shared_cache_tenant_id(): credential = get_credential_for_shared_cache_test( refresh_token_b, expected_access_token, cache, shared_cache_tenant_id=tenant_b ) - token = await credential.get_token("scope") + token = await getattr(credential, get_token_method)("scope") assert token.token == expected_access_token # AZURE_TENANT_ID should be used when shared_cache_tenant_id isn't specified with patch("os.environ", {EnvironmentVariables.AZURE_TENANT_ID: tenant_b}): credential = get_credential_for_shared_cache_test(refresh_token_b, expected_access_token, cache) - token = await credential.get_token("scope") + token = await getattr(credential, get_token_method)("scope") assert token.token == expected_access_token @pytest.mark.asyncio -async def test_shared_cache_username(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_shared_cache_username(get_token_method): expected_access_token = "expected-access-token" refresh_token_a = "refresh-token-a" refresh_token_b = "refresh-token-b" @@ -212,7 +220,7 @@ async def test_shared_cache_username(): credential = get_credential_for_shared_cache_test( refresh_token_a, expected_access_token, cache, shared_cache_username=upn_a ) - token = await credential.get_token("scope") + token = await getattr(credential, get_token_method)("scope") assert token.token == expected_access_token # shared_cache_username should prevail over AZURE_USERNAME @@ -220,13 +228,13 @@ async def test_shared_cache_username(): credential = get_credential_for_shared_cache_test( refresh_token_a, expected_access_token, cache, shared_cache_username=upn_a ) - token = await credential.get_token("scope") + token = await getattr(credential, get_token_method)("scope") assert token.token == expected_access_token # AZURE_USERNAME should be used when shared_cache_username isn't specified with patch("os.environ", {EnvironmentVariables.AZURE_USERNAME: upn_b}): credential = get_credential_for_shared_cache_test(refresh_token_b, expected_access_token, cache) - token = await credential.get_token("scope") + token = await getattr(credential, get_token_method)("scope") assert token.token == expected_access_token @@ -305,7 +313,7 @@ def test_process_timeout(): assert kwargs["process_timeout"] == timeout -def test_process_timeout(): +def test_process_timeout_default(): """the credential should allow configuring a process timeout for Azure CLI and PowerShell by kwarg""" with patch(DefaultAzureCredential.__module__ + ".AzureCliCredential") as mock_cli_credential: diff --git a/sdk/identity/azure-identity/tests/test_device_code_credential.py b/sdk/identity/azure-identity/tests/test_device_code_credential.py index d8ee36d4f59d..c1bebe13de29 100644 --- a/sdk/identity/azure-identity/tests/test_device_code_credential.py +++ b/sdk/identity/azure-identity/tests/test_device_code_credential.py @@ -19,6 +19,7 @@ mock_response, Request, validating_transport, + GET_TOKEN_METHODS, ) @@ -35,15 +36,17 @@ def test_tenant_id_validation(): DeviceCodeCredential(tenant_id=tenant) -def test_no_scopes(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_no_scopes(get_token_method): """The credential should raise when get_token is called with no scopes""" credential = DeviceCodeCredential("client_id") with pytest.raises(ValueError): - credential.get_token() + getattr(credential, get_token_method)() -def test_authenticate(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_authenticate(get_token_method): client_id = "client-id" environment = "localhost" issuer = "https://" + environment @@ -92,21 +95,23 @@ def test_authenticate(): assert record.username == username # credential should have a cached access token for the scope used in authenticate - token = credential.get_token(scope) + token = getattr(credential, get_token_method)(scope) assert token.token == access_token -def test_disable_automatic_authentication(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_disable_automatic_authentication(get_token_method): """When configured for strict silent auth, the credential should raise when silent auth fails""" transport = Mock(send=Mock(side_effect=Exception("no request should be sent"))) credential = DeviceCodeCredential("client-id", disable_automatic_authentication=True, transport=transport) with pytest.raises(AuthenticationRequiredError): - credential.get_token("scope") + getattr(credential, get_token_method)("scope") -def test_policies_configurable(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_policies_configurable(get_token_method): policy = Mock(spec_set=SansIOHTTPPolicy, on_request=Mock()) client_id = "client-id" @@ -135,12 +140,13 @@ def test_policies_configurable(): client_id=client_id, prompt_callback=Mock(), policies=[policy], transport=transport ) - credential.get_token("scope") + getattr(credential, get_token_method)("scope") assert policy.on_request.called -def test_user_agent(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_user_agent(get_token_method): client_id = "client-id" transport = validating_transport( requests=[Request()] * 2 + [Request(required_headers={"User-Agent": USER_AGENT})], @@ -164,10 +170,11 @@ def test_user_agent(): credential = DeviceCodeCredential(client_id=client_id, prompt_callback=Mock(), transport=transport) - credential.get_token("scope") + getattr(credential, get_token_method)("scope") -def test_device_code_credential(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_device_code_credential(get_token_method): client_id = "client-id" expected_token = "access-token" user_code = "user-code" @@ -210,7 +217,7 @@ def test_device_code_credential(): ) now = datetime.datetime.now(datetime.timezone.utc) - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == expected_token # prompt_callback should have been called as documented @@ -226,7 +233,8 @@ def test_device_code_credential(): assert expires_on - now >= datetime.timedelta(seconds=expires_in) -def test_tenant_id(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_tenant_id(get_token_method): client_id = "client-id" expected_token = "access-token" user_code = "user-code" @@ -269,11 +277,15 @@ def test_tenant_id(): additionally_allowed_tenants=["*"], ) - token = credential.get_token("scope", tenant_id="tenant_id") + kwargs = {"tenant_id": "tenant_id"} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(credential, get_token_method)("scope", **kwargs) assert token.token == expected_token -def test_timeout(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_timeout(get_token_method): flow = {"expires_in": 1800, "message": "foo"} with patch.object(DeviceCodeCredential, "_get_app") as get_app: msal_app = get_app() @@ -282,7 +294,7 @@ def test_timeout(): credential = DeviceCodeCredential(client_id="_", timeout=1, disable_instance_discovery=True) with pytest.raises(ClientAuthenticationError) as ex: - credential.get_token("scope") + getattr(credential, get_token_method)("scope") assert "timed out" in ex.value.message.lower() msal_app.acquire_token_by_device_flow.assert_called_once_with(flow, exit_condition=ANY, claims_challenge=None) @@ -306,7 +318,8 @@ def test_client_capabilities(): assert kwargs["client_capabilities"] == ["CP1"] -def test_claims_challenge(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_claims_challenge(get_token_method): """get_token and authenticate should pass any claims challenge to MSAL token acquisition APIs""" msal_acquire_token_result = dict( @@ -328,7 +341,10 @@ def test_claims_challenge(): args, kwargs = msal_app.acquire_token_by_device_flow.call_args assert kwargs["claims_challenge"] == expected_claims - credential.get_token("scope", claims=expected_claims) + kwargs = {"claims": expected_claims} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + getattr(credential, get_token_method)("scope", **kwargs) assert msal_app.acquire_token_by_device_flow.call_count == 2 args, kwargs = msal_app.acquire_token_by_device_flow.call_args @@ -336,7 +352,11 @@ def test_claims_challenge(): msal_app.get_accounts.return_value = [{"home_account_id": credential._auth_record.home_account_id}] msal_app.acquire_token_silent_with_error.return_value = msal_acquire_token_result - credential.get_token("scope", claims=expected_claims) + + kwargs = {"claims": expected_claims} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + getattr(credential, get_token_method)("scope", **kwargs) assert msal_app.acquire_token_silent_with_error.call_count == 1 args, kwargs = msal_app.acquire_token_silent_with_error.call_args diff --git a/sdk/identity/azure-identity/tests/test_environment_credential.py b/sdk/identity/azure-identity/tests/test_environment_credential.py index 1bce68eb34b5..b9e11908e632 100644 --- a/sdk/identity/azure-identity/tests/test_environment_credential.py +++ b/sdk/identity/azure-identity/tests/test_environment_credential.py @@ -9,7 +9,7 @@ from azure.identity._constants import EnvironmentVariables import pytest -from helpers import mock +from helpers import mock, GET_TOKEN_METHODS ALL_VARIABLES = { @@ -20,17 +20,18 @@ } -def test_incomplete_configuration(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_incomplete_configuration(get_token_method): """get_token should raise CredentialUnavailableError for incomplete configuration.""" with mock.patch.dict(os.environ, {}, clear=True): with pytest.raises(CredentialUnavailableError) as ex: - EnvironmentCredential().get_token("scope") + getattr(EnvironmentCredential(), get_token_method)("scope") for a, b in itertools.combinations(ALL_VARIABLES, 2): # all credentials require at least 3 variables set with mock.patch.dict(os.environ, {a: "a", b: "b"}, clear=True): with pytest.raises(CredentialUnavailableError) as ex: - EnvironmentCredential().get_token("scope") + getattr(EnvironmentCredential(), get_token_method)("scope") @pytest.mark.parametrize( diff --git a/sdk/identity/azure-identity/tests/test_environment_credential_async.py b/sdk/identity/azure-identity/tests/test_environment_credential_async.py index 480cc5ac2026..60dbfc07625a 100644 --- a/sdk/identity/azure-identity/tests/test_environment_credential_async.py +++ b/sdk/identity/azure-identity/tests/test_environment_credential_async.py @@ -10,7 +10,7 @@ from azure.identity._constants import EnvironmentVariables import pytest -from helpers import mock_response, Request +from helpers import mock_response, Request, GET_TOKEN_METHODS from helpers_async import async_validating_transport, AsyncMockTransport from test_environment_credential import ALL_VARIABLES @@ -56,17 +56,18 @@ async def test_context_manager_incomplete_configuration(): @pytest.mark.asyncio -async def test_incomplete_configuration(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_incomplete_configuration(get_token_method): """get_token should raise CredentialUnavailableError for incomplete configuration.""" with mock.patch.dict(ENVIRON, {}, clear=True): with pytest.raises(CredentialUnavailableError) as ex: - await EnvironmentCredential().get_token("scope") + await getattr(EnvironmentCredential(), get_token_method)("scope") for a, b in itertools.combinations(ALL_VARIABLES, 2): # all credentials require at least 3 variables set with mock.patch.dict(ENVIRON, {a: "a", b: "b"}, clear=True): with pytest.raises(CredentialUnavailableError) as ex: - await EnvironmentCredential().get_token("scope") + await getattr(EnvironmentCredential(), get_token_method)("scope") @pytest.mark.parametrize( @@ -169,7 +170,8 @@ def test_certificate_with_password_configuration(): @pytest.mark.asyncio -async def test_client_secret_environment_credential(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_client_secret_environment_credential(get_token_method): client_id = "fake-client-id" secret = "fake-client-secret" tenant_id = "fake-tenant-id" @@ -195,6 +197,6 @@ async def test_client_secret_environment_credential(): EnvironmentVariables.AZURE_TENANT_ID: tenant_id, } with mock.patch.dict(ENVIRON, environment, clear=True): - token = await EnvironmentCredential(transport=transport).get_token("scope") + token = await getattr(EnvironmentCredential(transport=transport), get_token_method)("scope") assert token.token == access_token diff --git a/sdk/identity/azure-identity/tests/test_get_token_mixin.py b/sdk/identity/azure-identity/tests/test_get_token_mixin.py index e3326b8f5cca..e0c877cb9a7b 100644 --- a/sdk/identity/azure-identity/tests/test_get_token_mixin.py +++ b/sdk/identity/azure-identity/tests/test_get_token_mixin.py @@ -5,15 +5,17 @@ import time from unittest import mock -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessTokenInfo import pytest from azure.identity._constants import DEFAULT_REFRESH_OFFSET from azure.identity._internal.get_token_mixin import GetTokenMixin +from helpers import GET_TOKEN_METHODS + class MockCredential(GetTokenMixin): - NEW_TOKEN = AccessToken("new token", 42) + NEW_TOKEN = AccessTokenInfo("new token", 42) def __init__(self, cached_token=None): super(MockCredential, self).__init__() @@ -29,85 +31,102 @@ def _request_token(self, *scopes, **kwargs): def get_token(self, *_, **__): return super(MockCredential, self).get_token(*_, **__) + def get_token_info(self, *_, **__): + return super(MockCredential, self).get_token_info(*_, **__) + CACHED_TOKEN = "cached token" SCOPE = "scope" -def test_no_cached_token(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_no_cached_token(get_token_method): """When it has no token cached, a credential should request one every time get_token is called""" credential = MockCredential() - token = credential.get_token(SCOPE) + token = getattr(credential, get_token_method)(SCOPE) credential.acquire_token_silently.assert_called_once_with(SCOPE, claims=None, enable_cae=False, tenant_id=None) credential.request_token.assert_called_once_with(SCOPE, claims=None, enable_cae=False, tenant_id=None) assert token.token == MockCredential.NEW_TOKEN.token -def test_tenant_id(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_tenant_id(get_token_method): credential = MockCredential() - token = credential.get_token(SCOPE, tenant_id="tenant_id") + kwargs = {"tenant_id": "tenant_id"} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(credential, get_token_method)(SCOPE, **kwargs) assert token.token == MockCredential.NEW_TOKEN.token -def test_token_acquisition_failure(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_token_acquisition_failure(get_token_method): """When the credential has no token cached, every get_token call should prompt a token request""" credential = MockCredential() credential.request_token = mock.Mock(side_effect=Exception("whoops")) for i in range(4): with pytest.raises(Exception): - credential.get_token(SCOPE) + getattr(credential, get_token_method)(SCOPE) assert credential.request_token.call_count == i + 1 credential.request_token.assert_called_with(SCOPE, claims=None, enable_cae=False, tenant_id=None) -def test_expired_token(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_expired_token(get_token_method): """A credential should request a token when it has an expired token cached""" - now = time.time() - credential = MockCredential(cached_token=AccessToken(CACHED_TOKEN, now - 1)) - token = credential.get_token(SCOPE) + now = int(time.time()) + credential = MockCredential(cached_token=AccessTokenInfo(CACHED_TOKEN, now - 1)) + token = getattr(credential, get_token_method)(SCOPE) credential.acquire_token_silently.assert_called_once_with(SCOPE, claims=None, enable_cae=False, tenant_id=None) credential.request_token.assert_called_once_with(SCOPE, claims=None, enable_cae=False, tenant_id=None) assert token.token == MockCredential.NEW_TOKEN.token -def test_cached_token_outside_refresh_window(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_cached_token_outside_refresh_window(get_token_method): """A credential shouldn't request a new token when it has a cached one with sufficient validity remaining""" - credential = MockCredential(cached_token=AccessToken(CACHED_TOKEN, time.time() + DEFAULT_REFRESH_OFFSET + 1)) - token = credential.get_token(SCOPE) + credential = MockCredential( + cached_token=AccessTokenInfo(CACHED_TOKEN, int(time.time() + DEFAULT_REFRESH_OFFSET + 1)) + ) + token = getattr(credential, get_token_method)(SCOPE) credential.acquire_token_silently.assert_called_once_with(SCOPE, claims=None, enable_cae=False, tenant_id=None) assert credential.request_token.call_count == 0 assert token.token == CACHED_TOKEN -def test_cached_token_within_refresh_window(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_cached_token_within_refresh_window(get_token_method): """A credential should request a new token when its cached one is within the refresh window""" - credential = MockCredential(cached_token=AccessToken(CACHED_TOKEN, time.time() + DEFAULT_REFRESH_OFFSET - 1)) - token = credential.get_token(SCOPE) + credential = MockCredential( + cached_token=AccessTokenInfo(CACHED_TOKEN, int(time.time() + DEFAULT_REFRESH_OFFSET - 1)) + ) + token = getattr(credential, get_token_method)(SCOPE) credential.acquire_token_silently.assert_called_once_with(SCOPE, claims=None, enable_cae=False, tenant_id=None) credential.request_token.assert_called_once_with(SCOPE, claims=None, enable_cae=False, tenant_id=None) assert token.token == MockCredential.NEW_TOKEN.token -def test_retry_delay(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_retry_delay(get_token_method): """A credential should wait between requests when trying to refresh a token""" now = time.time() - credential = MockCredential(cached_token=AccessToken(CACHED_TOKEN, now + DEFAULT_REFRESH_OFFSET - 1)) + credential = MockCredential(cached_token=AccessTokenInfo(CACHED_TOKEN, int(now + DEFAULT_REFRESH_OFFSET - 1))) # the credential should swallow exceptions during proactive refresh attempts credential.request_token = mock.Mock(side_effect=Exception("whoops")) for i in range(4): - token = credential.get_token(SCOPE) + token = getattr(credential, get_token_method)(SCOPE) assert token.token == CACHED_TOKEN credential.acquire_token_silently.assert_called_with(SCOPE, claims=None, enable_cae=False, tenant_id=None) credential.request_token.assert_called_once_with(SCOPE, claims=None, enable_cae=False, tenant_id=None) diff --git a/sdk/identity/azure-identity/tests/test_get_token_mixin_async.py b/sdk/identity/azure-identity/tests/test_get_token_mixin_async.py index 3db422190714..562ac58383dc 100644 --- a/sdk/identity/azure-identity/tests/test_get_token_mixin_async.py +++ b/sdk/identity/azure-identity/tests/test_get_token_mixin_async.py @@ -5,17 +5,19 @@ import time from unittest import mock -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessTokenInfo import pytest from azure.identity._constants import DEFAULT_REFRESH_OFFSET from azure.identity.aio._internal.get_token_mixin import GetTokenMixin +from helpers import GET_TOKEN_METHODS + pytestmark = pytest.mark.asyncio class MockCredential(GetTokenMixin): - NEW_TOKEN = AccessToken("new token", 42) + NEW_TOKEN = AccessTokenInfo("new token", 42) def __init__(self, cached_token=None): super(MockCredential, self).__init__() @@ -32,85 +34,102 @@ async def _request_token(self, *scopes, **kwargs): async def get_token(self, *_, **__): return await super().get_token(*_, **__) + async def get_token_info(self, *_, **__): + return await super().get_token_info(*_, **__) + CACHED_TOKEN = "cached token" SCOPE = "scope" -async def test_no_cached_token(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_no_cached_token(get_token_method): """When it has no token cached, a credential should request one every time get_token is called""" credential = MockCredential() - token = await credential.get_token(SCOPE) + token = await getattr(credential, get_token_method)(SCOPE) credential.acquire_token_silently.assert_called_once_with(SCOPE, claims=None, enable_cae=False, tenant_id=None) credential.request_token.assert_called_once_with(SCOPE, claims=None, enable_cae=False, tenant_id=None) assert token.token == MockCredential.NEW_TOKEN.token -async def test_tenant_id(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_tenant_id(get_token_method): credential = MockCredential() - token = await credential.get_token(SCOPE, tenant_id="tenant_id") + kwargs = {"tenant_id": "tenant_id"} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = await getattr(credential, get_token_method)(SCOPE, **kwargs) assert token.token == MockCredential.NEW_TOKEN.token -async def test_token_acquisition_failure(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_token_acquisition_failure(get_token_method): """When the credential has no token cached, every get_token call should prompt a token request""" credential = MockCredential() credential.request_token = mock.Mock(side_effect=Exception("whoops")) for i in range(4): with pytest.raises(Exception): - await credential.get_token(SCOPE) + await getattr(credential, get_token_method)(SCOPE) assert credential.request_token.call_count == i + 1 credential.request_token.assert_called_with(SCOPE, claims=None, enable_cae=False, tenant_id=None) -async def test_expired_token(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_expired_token(get_token_method): """A credential should request a token when it has an expired token cached""" - now = time.time() - credential = MockCredential(cached_token=AccessToken(CACHED_TOKEN, now - 1)) - token = await credential.get_token(SCOPE) + now = int(time.time()) + credential = MockCredential(cached_token=AccessTokenInfo(CACHED_TOKEN, now - 1)) + token = await getattr(credential, get_token_method)(SCOPE) credential.acquire_token_silently.assert_called_once_with(SCOPE, claims=None, enable_cae=False, tenant_id=None) credential.request_token.assert_called_once_with(SCOPE, claims=None, enable_cae=False, tenant_id=None) assert token.token == MockCredential.NEW_TOKEN.token -async def test_cached_token_outside_refresh_window(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_cached_token_outside_refresh_window(get_token_method): """A credential shouldn't request a new token when it has a cached one with sufficient validity remaining""" - credential = MockCredential(cached_token=AccessToken(CACHED_TOKEN, time.time() + DEFAULT_REFRESH_OFFSET + 1)) - token = await credential.get_token(SCOPE) + credential = MockCredential( + cached_token=AccessTokenInfo(CACHED_TOKEN, int(time.time() + DEFAULT_REFRESH_OFFSET + 1)) + ) + token = await getattr(credential, get_token_method)(SCOPE) credential.acquire_token_silently.assert_called_once_with(SCOPE, claims=None, enable_cae=False, tenant_id=None) assert credential.request_token.call_count == 0 assert token.token == CACHED_TOKEN -async def test_cached_token_within_refresh_window(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_cached_token_within_refresh_window(get_token_method): """A credential should request a new token when its cached one is within the refresh window""" - credential = MockCredential(cached_token=AccessToken(CACHED_TOKEN, time.time() + DEFAULT_REFRESH_OFFSET - 1)) - token = await credential.get_token(SCOPE) + credential = MockCredential( + cached_token=AccessTokenInfo(CACHED_TOKEN, int(time.time() + DEFAULT_REFRESH_OFFSET - 1)) + ) + token = await getattr(credential, get_token_method)(SCOPE) credential.acquire_token_silently.assert_called_once_with(SCOPE, claims=None, enable_cae=False, tenant_id=None) credential.request_token.assert_called_once_with(SCOPE, claims=None, enable_cae=False, tenant_id=None) assert token.token == MockCredential.NEW_TOKEN.token -async def test_retry_delay(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_retry_delay(get_token_method): """A credential should wait between requests when trying to refresh a token""" now = time.time() - credential = MockCredential(cached_token=AccessToken(CACHED_TOKEN, now + DEFAULT_REFRESH_OFFSET - 1)) + credential = MockCredential(cached_token=AccessTokenInfo(CACHED_TOKEN, int(now + DEFAULT_REFRESH_OFFSET - 1))) # the credential should swallow exceptions during proactive refresh attempts credential.request_token = mock.Mock(side_effect=Exception("whoops")) for i in range(4): - token = await credential.get_token(SCOPE) + token = await getattr(credential, get_token_method)(SCOPE) assert token.token == CACHED_TOKEN credential.acquire_token_silently.assert_called_with(SCOPE, claims=None, enable_cae=False, tenant_id=None) credential.request_token.assert_called_once_with(SCOPE, claims=None, enable_cae=False, tenant_id=None) diff --git a/sdk/identity/azure-identity/tests/test_imds_credential.py b/sdk/identity/azure-identity/tests/test_imds_credential.py index 3bced6054484..274225492763 100644 --- a/sdk/identity/azure-identity/tests/test_imds_credential.py +++ b/sdk/identity/azure-identity/tests/test_imds_credential.py @@ -2,50 +2,47 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ -import json +from itertools import product import time -from devtools_testutils import recorded_by_proxy -from azure.core.credentials import AccessToken -from azure.core.exceptions import ClientAuthenticationError - from azure.identity import CredentialUnavailableError -from azure.identity._constants import EnvironmentVariables -from azure.identity._credentials.imds import IMDS_TOKEN_PATH, ImdsCredential, IMDS_AUTHORITY, PIPELINE_SETTINGS -from azure.identity._internal.user_agent import USER_AGENT +from azure.identity._credentials.imds import IMDS_TOKEN_PATH, ImdsCredential, IMDS_AUTHORITY from azure.identity._internal.utils import within_credential_chain import pytest -from helpers import mock, mock_response, Request, validating_transport +from helpers import mock, mock_response, Request, validating_transport, GET_TOKEN_METHODS from recorded_test_case import RecordedTestCase -def test_no_scopes(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_no_scopes(get_token_method): """The credential should raise ValueError when get_token is called with no scopes""" credential = ImdsCredential() with pytest.raises(ValueError): - credential.get_token() + getattr(credential, get_token_method)() -def test_multiple_scopes(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_multiple_scopes(get_token_method): """The credential should raise ValueError when get_token is called with more than one scope""" credential = ImdsCredential() with pytest.raises(ValueError): - credential.get_token("one scope", "and another") + getattr(credential, get_token_method)("one scope", "and another") -def test_identity_not_available(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_identity_not_available(get_token_method): """The credential should raise CredentialUnavailableError when the endpoint responds 400 to a token request""" transport = validating_transport(requests=[Request()], responses=[mock_response(status_code=400, json_payload={})]) credential = ImdsCredential(transport=transport) with pytest.raises(CredentialUnavailableError): - credential.get_token("scope") + getattr(credential, get_token_method)("scope") -@pytest.mark.parametrize("error_ending", ("network", "host", "foo")) -def test_imds_request_failure_docker_desktop(error_ending): +@pytest.mark.parametrize("error_ending,get_token_method", product(("network", "host", "foo"), GET_TOKEN_METHODS)) +def test_imds_request_failure_docker_desktop(error_ending, get_token_method): """The credential should raise CredentialUnavailableError when a 403 with a specific message is received""" error_message = ( @@ -57,47 +54,55 @@ def test_imds_request_failure_docker_desktop(error_ending): credential = ImdsCredential(transport=transport) with pytest.raises(CredentialUnavailableError) as ex: - credential.get_token("scope") + getattr(credential, get_token_method)("scope") assert error_message in ex.value.message @pytest.mark.usefixtures("record_imds_test") class TestImds(RecordedTestCase): - @recorded_by_proxy - def test_system_assigned(self): + + @pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) + def test_system_assigned(self, recorded_test, get_token_method): credential = ImdsCredential() - token = credential.get_token(self.scope) + token = getattr(credential, get_token_method)(self.scope) assert token.token assert isinstance(token.expires_on, int) - @recorded_by_proxy - def test_system_assigned_tenant_id(self): + @pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) + def test_system_assigned_tenant_id(self, recorded_test, get_token_method): credential = ImdsCredential() - token = credential.get_token(self.scope, tenant_id="tenant_id") + kwargs = {"tenant_id": "tenant_id"} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(credential, get_token_method)(self.scope, **kwargs) assert token.token assert isinstance(token.expires_on, int) @pytest.mark.usefixtures("user_assigned_identity_client_id") - @recorded_by_proxy - def test_user_assigned(self): + @pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) + def test_user_assigned(self, recorded_test, get_token_method): credential = ImdsCredential(client_id=self.user_assigned_identity_client_id) - token = credential.get_token(self.scope) + token = getattr(credential, get_token_method)(self.scope) assert token.token assert isinstance(token.expires_on, int) @pytest.mark.usefixtures("user_assigned_identity_client_id") - @recorded_by_proxy - def test_user_assigned_tenant_id(self): + @pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) + def test_user_assigned_tenant_id(self, recorded_test, get_token_method): credential = ImdsCredential(client_id=self.user_assigned_identity_client_id) - token = credential.get_token(self.scope, tenant_id="tenant_id") + kwargs = {"tenant_id": "tenant_id"} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(credential, get_token_method)(self.scope, **kwargs) assert token.token assert isinstance(token.expires_on, int) - def test_managed_identity_aci_probe(self): + @pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) + def test_managed_identity_aci_probe(self, get_token_method): access_token = "****" expires_on = 42 - expected_token = AccessToken(access_token, expires_on) + expected_token = access_token scope = "scope" transport = validating_transport( requests=[ @@ -126,7 +131,7 @@ def test_managed_identity_aci_probe(self): ], ) within_credential_chain.set(True) - cred = ImdsCredential(transport=transport) - token = cred.get_token(scope) - assert token.token == expected_token.token + credential = ImdsCredential(transport=transport) + token = getattr(credential, get_token_method)(scope) + assert token.token == expected_token within_credential_chain.set(False) diff --git a/sdk/identity/azure-identity/tests/test_imds_credential_async.py b/sdk/identity/azure-identity/tests/test_imds_credential_async.py index b759335fbc12..67e169203a61 100644 --- a/sdk/identity/azure-identity/tests/test_imds_credential_async.py +++ b/sdk/identity/azure-identity/tests/test_imds_credential_async.py @@ -2,12 +2,11 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ +from itertools import product import json import time from unittest import mock -from devtools_testutils.aio import recorded_by_proxy_async -from azure.core.credentials import AccessToken from azure.core.exceptions import ClientAuthenticationError from azure.identity import CredentialUnavailableError from azure.identity._constants import EnvironmentVariables @@ -17,11 +16,10 @@ from azure.identity._internal.utils import within_credential_chain import pytest -from helpers import mock_response, Request +from helpers import mock_response, Request, GET_TOKEN_METHODS from helpers_async import ( async_validating_transport, AsyncMockTransport, - await_test, get_completed_future, wrap_in_future, ) @@ -30,18 +28,20 @@ pytestmark = pytest.mark.asyncio -async def test_no_scopes(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_no_scopes(get_token_method): """The credential should raise ValueError when get_token is called with no scopes""" credential = ImdsCredential() with pytest.raises(ValueError): - await credential.get_token() + await getattr(credential, get_token_method)() -async def test_multiple_scopes(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_multiple_scopes(get_token_method): """The credential should raise ValueError when get_token is called with more than one scope""" credential = ImdsCredential() with pytest.raises(ValueError): - await credential.get_token("one scope", "and another") + await getattr(credential, get_token_method)("one scope", "and another") async def test_imds_close(): @@ -64,7 +64,8 @@ async def test_imds_context_manager(): assert transport.__aexit__.call_count == 1 -async def test_identity_not_available(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_identity_not_available(get_token_method): """The credential should raise CredentialUnavailableError when the endpoint responds 400 to a token request""" transport = async_validating_transport( @@ -74,10 +75,11 @@ async def test_identity_not_available(): credential = ImdsCredential(transport=transport) with pytest.raises(CredentialUnavailableError): - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") -async def test_unexpected_error(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_unexpected_error(get_token_method): """The credential should raise ClientAuthenticationError when the endpoint returns an unexpected error""" error_message = "something went wrong" @@ -94,13 +96,13 @@ async def send(request, **kwargs): credential = ImdsCredential(transport=transport) with pytest.raises(ClientAuthenticationError) as ex: - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") assert error_message in ex.value.message -@pytest.mark.parametrize("error_ending", ("network", "host", "foo")) -async def test_imds_request_failure_docker_desktop(error_ending): +@pytest.mark.parametrize("error_ending,get_token_method", product(("network", "host", "foo"), GET_TOKEN_METHODS)) +async def test_imds_request_failure_docker_desktop(error_ending, get_token_method): """The credential should raise CredentialUnavailableError when a 403 with a specific message is received""" error_message = ( @@ -112,12 +114,13 @@ async def test_imds_request_failure_docker_desktop(error_ending): credential = ImdsCredential(transport=transport) with pytest.raises(CredentialUnavailableError) as ex: - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") assert error_message in ex.value.message -async def test_cache(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_cache(get_token_method): scope = "https://foo.bar" expired = "this token's expired" now = int(time.time()) @@ -140,7 +143,7 @@ async def test_cache(): mock_send = mock.Mock(return_value=mock_response) credential = ImdsCredential(transport=mock.Mock(send=wrap_in_future(mock_send))) - token = await credential.get_token(scope) + token = await getattr(credential, get_token_method)(scope) assert token.token == expired assert mock_send.call_count == 1 @@ -149,17 +152,18 @@ async def test_cache(): token_payload["expires_on"] = int(time.time()) + 3600 token_payload["expires_in"] = 3600 token_payload["access_token"] = good_for_an_hour - token = await credential.get_token(scope) + token = await getattr(credential, get_token_method)(scope) assert token.token == good_for_an_hour assert mock_send.call_count == 2 # get_token should return the cached token now - token = await credential.get_token(scope) + token = await getattr(credential, get_token_method)(scope) assert token.token == good_for_an_hour assert mock_send.call_count == 2 -async def test_retries(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_retries(get_token_method): mock_response = mock.Mock( text=lambda encoding=None: b"{}", headers={"content-type": "application/json"}, @@ -173,20 +177,24 @@ async def test_retries(): mock_send.reset_mock() mock_response.status_code = status_code try: - await ImdsCredential( - transport=mock.Mock(send=wrap_in_future(mock_send), sleep=wrap_in_future(lambda _: None)) - ).get_token("scope") + await getattr( + ImdsCredential( + transport=mock.Mock(send=wrap_in_future(mock_send), sleep=wrap_in_future(lambda _: None)) + ), + get_token_method, + )("scope") except ClientAuthenticationError: pass # credential should have then exhausted retries for each of these status codes assert mock_send.call_count == 1 + total_retries -async def test_identity_config(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_identity_config(get_token_method): param_name, param_value = "foo", "bar" access_token = "****" expires_on = 42 - expected_token = AccessToken(access_token, expires_on) + expected_token = access_token scope = "scope" client_id = "some-guid" @@ -215,12 +223,13 @@ async def test_identity_config(): ) credential = ImdsCredential(client_id=client_id, identity_config={param_name: param_value}, transport=transport) - token = await credential.get_token(scope) + token = await getattr(credential, get_token_method)(scope) - assert token == expected_token + assert token.token == expected_token -async def test_imds_authority_override(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_imds_authority_override(get_token_method): authority = "https://localhost" expected_token = "***" scope = "scope" @@ -252,52 +261,60 @@ async def test_imds_authority_override(): with mock.patch.dict("os.environ", {EnvironmentVariables.AZURE_POD_IDENTITY_AUTHORITY_HOST: authority}, clear=True): credential = ImdsCredential(transport=transport) - token = await credential.get_token(scope) + token = await getattr(credential, get_token_method)(scope) assert token.token == expected_token @pytest.mark.usefixtures("record_imds_test") class TestImdsAsync(RecordedTestCase): - @await_test - @recorded_by_proxy_async - async def test_system_assigned(self): + + @pytest.mark.asyncio + @pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) + async def test_system_assigned(self, recorded_test, get_token_method): credential = ImdsCredential() - token = await credential.get_token(self.scope) + token = await getattr(credential, get_token_method)(self.scope) assert token.token assert isinstance(token.expires_on, int) - @await_test - @recorded_by_proxy_async - async def test_system_assigned_tenant_id(self): + @pytest.mark.asyncio + @pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) + async def test_system_assigned_tenant_id(self, recorded_test, get_token_method): credential = ImdsCredential() - token = await credential.get_token(self.scope, tenant_id="tenant_id") + kwargs = {"tenant_id": "tenant_id"} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = await getattr(credential, get_token_method)(self.scope, **kwargs) assert token.token assert isinstance(token.expires_on, int) @pytest.mark.usefixtures("user_assigned_identity_client_id") - @await_test - @recorded_by_proxy_async - async def test_user_assigned(self): + @pytest.mark.asyncio + @pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) + async def test_user_assigned(self, recorded_test, get_token_method): credential = ImdsCredential(client_id=self.user_assigned_identity_client_id) - token = await credential.get_token(self.scope) + token = await getattr(credential, get_token_method)(self.scope) assert token.token assert isinstance(token.expires_on, int) @pytest.mark.usefixtures("user_assigned_identity_client_id") - @await_test - @recorded_by_proxy_async - async def test_user_assigned_tenant_id(self): + @pytest.mark.asyncio + @pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) + async def test_user_assigned_tenant_id(self, recorded_test, get_token_method): credential = ImdsCredential(client_id=self.user_assigned_identity_client_id) - token = await credential.get_token(self.scope, tenant_id="tenant_id") + kwargs = {"tenant_id": "tenant_id"} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = await getattr(credential, get_token_method)(self.scope, **kwargs) assert token.token assert isinstance(token.expires_on, int) @pytest.mark.asyncio - async def test_managed_identity_aci_probe(self): + @pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) + async def test_managed_identity_aci_probe(self, get_token_method): access_token = "****" expires_on = 42 - expected_token = AccessToken(access_token, expires_on) + expected_token = access_token scope = "scope" transport = async_validating_transport( requests=[ @@ -325,7 +342,7 @@ async def test_managed_identity_aci_probe(self): ], ) within_credential_chain.set(True) - cred = ImdsCredential(transport=transport) - token = await cred.get_token(scope) - assert token == expected_token + credential = ImdsCredential(transport=transport) + token = await getattr(credential, get_token_method)(scope) + assert token.token == expected_token within_credential_chain.set(False) diff --git a/sdk/identity/azure-identity/tests/test_initialization.py b/sdk/identity/azure-identity/tests/test_initialization.py new file mode 100644 index 000000000000..7679b1c41031 --- /dev/null +++ b/sdk/identity/azure-identity/tests/test_initialization.py @@ -0,0 +1,68 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ + +from azure.core.credentials import SupportsTokenInfo, TokenCredential +from azure.identity import ( + AuthorizationCodeCredential, + CertificateCredential, + ClientSecretCredential, + DeviceCodeCredential, + EnvironmentCredential, + InteractiveBrowserCredential, + ManagedIdentityCredential, + OnBehalfOfCredential, + SharedTokenCacheCredential, + UsernamePasswordCredential, + VisualStudioCodeCredential, + WorkloadIdentityCredential, + DefaultAzureCredential, + ChainedTokenCredential, + AzureCliCredential, + AzurePowerShellCredential, + AzureDeveloperCliCredential, + AzurePipelinesCredential, +) + + +def test_credential_is_token_credential(): + assert isinstance(AuthorizationCodeCredential, TokenCredential) + assert isinstance(CertificateCredential, TokenCredential) + assert isinstance(ClientSecretCredential, TokenCredential) + assert isinstance(DeviceCodeCredential, TokenCredential) + assert isinstance(EnvironmentCredential, TokenCredential) + assert isinstance(InteractiveBrowserCredential, TokenCredential) + assert isinstance(ManagedIdentityCredential, TokenCredential) + assert isinstance(OnBehalfOfCredential, TokenCredential) + assert isinstance(SharedTokenCacheCredential, TokenCredential) + assert isinstance(UsernamePasswordCredential, TokenCredential) + assert isinstance(VisualStudioCodeCredential, TokenCredential) + assert isinstance(WorkloadIdentityCredential, TokenCredential) + assert isinstance(DefaultAzureCredential, TokenCredential) + assert isinstance(ChainedTokenCredential, TokenCredential) + assert isinstance(AzureCliCredential, TokenCredential) + assert isinstance(AzurePowerShellCredential, TokenCredential) + assert isinstance(AzureDeveloperCliCredential, TokenCredential) + assert isinstance(AzurePipelinesCredential, TokenCredential) + + +def test_credential_is_supports_token_info(): + assert isinstance(AuthorizationCodeCredential, SupportsTokenInfo) + assert isinstance(CertificateCredential, SupportsTokenInfo) + assert isinstance(ClientSecretCredential, SupportsTokenInfo) + assert isinstance(DeviceCodeCredential, SupportsTokenInfo) + assert isinstance(EnvironmentCredential, SupportsTokenInfo) + assert isinstance(InteractiveBrowserCredential, SupportsTokenInfo) + assert isinstance(ManagedIdentityCredential, SupportsTokenInfo) + assert isinstance(OnBehalfOfCredential, SupportsTokenInfo) + assert isinstance(SharedTokenCacheCredential, SupportsTokenInfo) + assert isinstance(UsernamePasswordCredential, SupportsTokenInfo) + assert isinstance(VisualStudioCodeCredential, SupportsTokenInfo) + assert isinstance(WorkloadIdentityCredential, SupportsTokenInfo) + assert isinstance(DefaultAzureCredential, SupportsTokenInfo) + assert isinstance(ChainedTokenCredential, SupportsTokenInfo) + assert isinstance(AzureCliCredential, SupportsTokenInfo) + assert isinstance(AzurePowerShellCredential, SupportsTokenInfo) + assert isinstance(AzureDeveloperCliCredential, SupportsTokenInfo) + assert isinstance(AzurePipelinesCredential, SupportsTokenInfo) diff --git a/sdk/identity/azure-identity/tests/test_initialization_async.py b/sdk/identity/azure-identity/tests/test_initialization_async.py new file mode 100644 index 000000000000..41bc432b86b3 --- /dev/null +++ b/sdk/identity/azure-identity/tests/test_initialization_async.py @@ -0,0 +1,58 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ +from azure.core.credentials_async import AsyncSupportsTokenInfo, AsyncTokenCredential +from azure.identity.aio import ( + AuthorizationCodeCredential, + CertificateCredential, + ClientSecretCredential, + EnvironmentCredential, + ManagedIdentityCredential, + OnBehalfOfCredential, + SharedTokenCacheCredential, + VisualStudioCodeCredential, + WorkloadIdentityCredential, + DefaultAzureCredential, + ChainedTokenCredential, + AzureCliCredential, + AzurePowerShellCredential, + AzureDeveloperCliCredential, + AzurePipelinesCredential, +) + + +def test_credential_is_async_token_credential(): + assert isinstance(AuthorizationCodeCredential, AsyncTokenCredential) + assert isinstance(CertificateCredential, AsyncTokenCredential) + assert isinstance(ClientSecretCredential, AsyncTokenCredential) + assert isinstance(EnvironmentCredential, AsyncTokenCredential) + assert isinstance(ManagedIdentityCredential, AsyncTokenCredential) + assert isinstance(OnBehalfOfCredential, AsyncTokenCredential) + assert isinstance(SharedTokenCacheCredential, AsyncTokenCredential) + assert isinstance(VisualStudioCodeCredential, AsyncTokenCredential) + assert isinstance(WorkloadIdentityCredential, AsyncTokenCredential) + assert isinstance(DefaultAzureCredential, AsyncTokenCredential) + assert isinstance(ChainedTokenCredential, AsyncTokenCredential) + assert isinstance(AzureCliCredential, AsyncTokenCredential) + assert isinstance(AzurePowerShellCredential, AsyncTokenCredential) + assert isinstance(AzureDeveloperCliCredential, AsyncTokenCredential) + assert isinstance(AzurePipelinesCredential, AsyncTokenCredential) + + +def test_credential_is_async_supports_token_info(): + assert isinstance(AuthorizationCodeCredential, AsyncSupportsTokenInfo) + assert isinstance(CertificateCredential, AsyncSupportsTokenInfo) + assert isinstance(ClientSecretCredential, AsyncSupportsTokenInfo) + assert isinstance(EnvironmentCredential, AsyncSupportsTokenInfo) + assert isinstance(ManagedIdentityCredential, AsyncSupportsTokenInfo) + assert isinstance(OnBehalfOfCredential, AsyncSupportsTokenInfo) + assert isinstance(SharedTokenCacheCredential, AsyncSupportsTokenInfo) + assert isinstance(VisualStudioCodeCredential, AsyncSupportsTokenInfo) + assert isinstance(WorkloadIdentityCredential, AsyncSupportsTokenInfo) + assert isinstance(DefaultAzureCredential, AsyncSupportsTokenInfo) + assert isinstance(ChainedTokenCredential, AsyncSupportsTokenInfo) + assert isinstance(AzureCliCredential, AsyncSupportsTokenInfo) + assert isinstance(AzurePowerShellCredential, AsyncSupportsTokenInfo) + assert isinstance(AzureDeveloperCliCredential, AsyncSupportsTokenInfo) + assert isinstance(AzurePipelinesCredential, AsyncSupportsTokenInfo) diff --git a/sdk/identity/azure-identity/tests/test_interactive_credential.py b/sdk/identity/azure-identity/tests/test_interactive_credential.py index fc1ec6ad8558..c6fcd3da9993 100644 --- a/sdk/identity/azure-identity/tests/test_interactive_credential.py +++ b/sdk/identity/azure-identity/tests/test_interactive_credential.py @@ -17,7 +17,7 @@ from urllib.parse import urlparse from unittest.mock import Mock, patch -from helpers import build_aad_response, get_discovery_response, id_token_claims +from helpers import build_aad_response, get_discovery_response, id_token_claims, GET_TOKEN_METHODS # fake object for tests which need to exercise request_token but don't care about its return value @@ -49,15 +49,17 @@ def _request_token(self, *scopes, **kwargs): return self._request_token_impl(*scopes, **kwargs) -def test_no_scopes(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_no_scopes(get_token_method): """The credential should raise when get_token is called with no scopes""" request_token = Mock(side_effect=Exception("credential shouldn't begin interactive authentication")) with pytest.raises(ValueError): - MockCredential(request_token=request_token).get_token() + getattr(MockCredential(request_token=request_token), get_token_method)() -def test_authentication_record_argument(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_authentication_record_argument(get_token_method): """The credential should initialize its msal.ClientApplication with values from a given record""" record = AuthenticationRecord("tenant-id", "client-id", "localhost", "object.tenant", "username") @@ -72,12 +74,13 @@ def validate_app_parameters(authority, client_id, **_): credential = MockCredential(authentication_record=record, disable_automatic_authentication=True) with pytest.raises(AuthenticationRequiredError): with patch("msal.PublicClientApplication", mock_client_application): - credential.get_token("scope") + getattr(credential, get_token_method)("scope") assert mock_client_application.call_count == 1, "credential didn't create an msal application" -def test_enable_support_logging(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_enable_support_logging(get_token_method): """The keyword argument for enabling PII in MSAL should be passed.""" record = AuthenticationRecord("tenant-id", "client-id", "localhost", "object.tenant", "username") @@ -95,14 +98,15 @@ def validate_app_parameters(authority, client_id, **_): ) with pytest.raises(AuthenticationRequiredError): with patch("msal.PublicClientApplication", mock_client_application): - credential.get_token("scope") + getattr(credential, get_token_method)("scope") assert mock_client_application.call_count == 1, "credential didn't create an msal application" _, kwargs = mock_client_application.call_args assert kwargs["enable_pii_log"] -def test_tenant_argument_overrides_record(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_tenant_argument_overrides_record(get_token_method): """The 'tenant_ic' keyword argument should override a given record's value""" tenant_id = "some-guid" @@ -121,10 +125,11 @@ def validate_authority(authority, **_): ) with pytest.raises(AuthenticationRequiredError): with patch("msal.PublicClientApplication", validate_authority): - credential.get_token("scope") + getattr(credential, get_token_method)("scope") -def test_disable_automatic_authentication(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_disable_automatic_authentication(get_token_method): """When silent auth fails the credential should raise, if it's configured not to authenticate automatically""" expected_details = "something went wrong" @@ -144,14 +149,18 @@ def test_disable_automatic_authentication(): expected_claims = "..." with pytest.raises(AuthenticationRequiredError) as ex: with patch("msal.PublicClientApplication", lambda *_, **__: msal_app): - credential.get_token(scope, claims=expected_claims) + kwargs = {"claims": expected_claims} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + getattr(credential, get_token_method)(scope, **kwargs) # the exception should carry the requested scopes and claims, and any error message from Microsoft Entra ID assert ex.value.scopes == (scope,) assert ex.value.claims == expected_claims -def test_scopes_round_trip(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_scopes_round_trip(get_token_method): """authenticate should accept the value of AuthenticationRequiredError.scopes""" scope = "scope" @@ -163,7 +172,7 @@ def validate_scopes(*scopes, **_): request_token = Mock(wraps=validate_scopes) credential = MockCredential(disable_automatic_authentication=True, request_token=request_token) with pytest.raises(AuthenticationRequiredError) as ex: - credential.get_token(scope) + getattr(credential, get_token_method)(scope) credential.authenticate(scopes=ex.value.scopes) @@ -191,7 +200,8 @@ def validate_scopes(*scopes, **_): assert request_token.call_count == 1 -def test_authenticate_unknown_cloud(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_authenticate_unknown_cloud(get_token_method): """authenticate should raise when given no scopes in an unknown cloud""" with pytest.raises(CredentialUnavailableError): @@ -207,7 +217,8 @@ def test_authenticate_ignores_disable_automatic_authentication(option): assert request_token.call_count == 1, "credential didn't begin interactive authentication" -def test_get_token_wraps_exceptions(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_get_token_wraps_exceptions(get_token_method): """get_token shouldn't propagate exceptions from MSAL""" class CustomException(Exception): @@ -222,13 +233,14 @@ class CustomException(Exception): credential = MockCredential(authentication_record=record) with pytest.raises(ClientAuthenticationError) as ex: with patch("msal.PublicClientApplication", lambda *_, **__: msal_app): - credential.get_token("scope") + getattr(credential, get_token_method)("scope") assert expected_message in ex.value.message assert msal_app.acquire_token_silent_with_error.call_count == 1, "credential didn't attempt silent auth" -def test_token_cache_persistent(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_token_cache_persistent(get_token_method): """the credential should default to an in memory cache, and optionally use a persistent cache""" class TestCredential(InteractiveCredential): @@ -255,12 +267,15 @@ def _request_token(self, *_, **kwargs): ) assert not load_persistent_cache.called - credential.get_token("scope") + getattr(credential, get_token_method)("scope") assert load_persistent_cache.call_count == 1 assert credential._cache is not None assert credential._cae_cache is None - credential.get_token("scope", enable_cae=True) + kwargs = {"enable_cae": True} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + getattr(credential, get_token_method)("scope", **kwargs) assert load_persistent_cache.call_count == 2 assert credential._cae_cache is not None @@ -268,15 +283,16 @@ def _request_token(self, *_, **kwargs): assert credential2._cache is None assert credential2._cae_cache is None - credential2.get_token("scope") + getattr(credential2, get_token_method)("scope") assert isinstance(credential2._cache, TokenCache) assert credential2._cae_cache is None - credential2.get_token("scope", enable_cae=True) + getattr(credential2, get_token_method)("scope", **kwargs) assert isinstance(credential2._cae_cache, TokenCache) -def test_home_account_id_client_info(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_home_account_id_client_info(get_token_method): """when MSAL returns client_info, the credential should decode it to get the home_account_id""" object_id = "object-id" @@ -302,7 +318,8 @@ def _request_token(self, *_, **__): assert record.home_account_id == "{}.{}".format(object_id, home_tenant) -def test_adfs(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_adfs(get_token_method): """the credential should be able to construct an AuthenticationRecord from an ADFS response returned by MSAL""" authority = "localhost" @@ -333,7 +350,8 @@ def _request_token(self, *_, **__): assert record.username == username -def test_multitenant_authentication(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_multitenant_authentication(get_token_method): first_tenant = "first-tenant" first_token = "***" second_tenant = "second-tenant" @@ -368,21 +386,28 @@ def send(request, **kwargs): transport=Mock(send=send), additionally_allowed_tenants=["*"], ) - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == first_token - token = credential.get_token("scope", tenant_id=first_tenant) + kwargs = {"tenant_id": first_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(credential, get_token_method)("scope", **kwargs) assert token.token == first_token - token = credential.get_token("scope", tenant_id=second_tenant) + kwargs = {"tenant_id": second_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(credential, get_token_method)("scope", **kwargs) assert token.token == second_token # should still default to the first tenant - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == first_token -def test_multitenant_authentication_not_allowed(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_multitenant_authentication_not_allowed(get_token_method): expected_tenant = "expected-tenant" expected_token = "***" @@ -410,12 +435,18 @@ def send(request, **kwargs): credential = MockCredential(tenant_id=expected_tenant, transport=Mock(send=send), request_token=request_token) - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == expected_token - token = credential.get_token("scope", tenant_id=expected_tenant) + kwargs = {"tenant_id": expected_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(credential, get_token_method)("scope", **kwargs) assert token.token == expected_token with patch.dict("os.environ", {EnvironmentVariables.AZURE_IDENTITY_DISABLE_MULTITENANTAUTH: "true"}): - token = credential.get_token("scope", tenant_id="un" + expected_tenant) + kwargs = {"tenant_id": "un" + expected_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(credential, get_token_method)("scope", **kwargs) assert token.token == expected_token diff --git a/sdk/identity/azure-identity/tests/test_live.py b/sdk/identity/azure-identity/tests/test_live.py index f57244dd843d..f475404a5dce 100644 --- a/sdk/identity/azure-identity/tests/test_live.py +++ b/sdk/identity/azure-identity/tests/test_live.py @@ -2,6 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ +from itertools import product import pytest from azure.identity import ( @@ -16,31 +17,33 @@ ) from azure.identity._constants import DEVELOPER_SIGN_ON_CLIENT_ID -from helpers import get_token_payload_contents +from helpers import get_token_payload_contents, GET_TOKEN_METHODS ARM_SCOPE = "https://management.azure.com/.default" -def get_token(credential, **kwargs): - token = credential.get_token(ARM_SCOPE, **kwargs) +def get_token(credential, method, **kwargs): + token = getattr(credential, method)(ARM_SCOPE, **kwargs) assert token assert token.token assert token.expires_on return token -@pytest.mark.parametrize("certificate_fixture", ("live_pem_certificate", "live_pfx_certificate")) -def test_certificate_credential(certificate_fixture, request): +@pytest.mark.parametrize( + "certificate_fixture,get_token_method", product(("live_pem_certificate", "live_pfx_certificate"), GET_TOKEN_METHODS) +) +def test_certificate_credential(certificate_fixture, get_token_method, request): cert = request.getfixturevalue(certificate_fixture) tenant_id = cert["tenant_id"] client_id = cert["client_id"] credential = CertificateCredential(tenant_id, client_id, cert["cert_path"]) - get_token(credential) + get_token(credential, get_token_method) credential = CertificateCredential(tenant_id, client_id, certificate_data=cert["cert_bytes"]) - token = get_token(credential, enable_cae=True) + token = get_token(credential, get_token_method, enable_cae=True) parsed_payload = get_token_payload_contents(token.token) assert "xms_cc" in parsed_payload and "CP1" in parsed_payload["xms_cc"] @@ -48,61 +51,68 @@ def test_certificate_credential(certificate_fixture, request): credential = CertificateCredential( tenant_id, client_id, cert["cert_with_password_path"], password=cert["password"] ) - get_token(credential) + get_token(credential, get_token_method) credential = CertificateCredential( tenant_id, client_id, certificate_data=cert["cert_with_password_bytes"], password=cert["password"] ) - get_token(credential) + get_token(credential, get_token_method) -def test_client_secret_credential(live_service_principal): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_client_secret_credential(live_service_principal, get_token_method): credential = ClientSecretCredential( live_service_principal["tenant_id"], live_service_principal["client_id"], live_service_principal["client_secret"], ) - token = get_token(credential, enable_cae=True) + token = get_token(credential, get_token_method, enable_cae=True) parsed_payload = get_token_payload_contents(token.token) assert "xms_cc" in parsed_payload and "CP1" in parsed_payload["xms_cc"] -def test_default_credential(live_service_principal): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_default_credential(live_service_principal, get_token_method): credential = DefaultAzureCredential() - get_token(credential) + get_token(credential, get_token_method) -def test_username_password_auth(live_user_details): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_username_password_auth(live_user_details, get_token_method): credential = UsernamePasswordCredential( client_id=live_user_details["client_id"], username=live_user_details["username"], password=live_user_details["password"], tenant_id=live_user_details["tenant"], ) - get_token(credential) + get_token(credential, get_token_method) @pytest.mark.manual -def test_cli_credential(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_cli_credential(get_token_method): credential = AzureCliCredential() - get_token(credential) + get_token(credential, get_token_method) @pytest.mark.manual -def test_dev_cli_credential(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_dev_cli_credential(get_token_method): credential = AzureDeveloperCliCredential() - get_token(credential) + get_token(credential, get_token_method) @pytest.mark.manual -def test_powershell_credential(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_powershell_credential(get_token_method): credential = AzurePowerShellCredential() - get_token(credential) + get_token(credential, get_token_method) @pytest.mark.manual @pytest.mark.prints -def test_device_code(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_device_code(get_token_method): import webbrowser def prompt(url, user_code, _): @@ -110,4 +120,4 @@ def prompt(url, user_code, _): webbrowser.open_new_tab(url) credential = DeviceCodeCredential(client_id=DEVELOPER_SIGN_ON_CLIENT_ID, prompt_callback=prompt, timeout=40) - get_token(credential) + get_token(credential, get_token_method) diff --git a/sdk/identity/azure-identity/tests/test_live_async.py b/sdk/identity/azure-identity/tests/test_live_async.py index 073ef3f65e7d..8c8ed19f64fb 100644 --- a/sdk/identity/azure-identity/tests/test_live_async.py +++ b/sdk/identity/azure-identity/tests/test_live_async.py @@ -2,6 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ +from itertools import product import pytest from azure.identity.aio import ( @@ -13,13 +14,13 @@ AzureDeveloperCliCredential, ) -from helpers import get_token_payload_contents +from helpers import get_token_payload_contents, GET_TOKEN_METHODS ARM_SCOPE = "https://management.azure.com/.default" -async def get_token(credential, **kwargs): - token = await credential.get_token(ARM_SCOPE, **kwargs) +async def get_token(credential, get_token_method, **kwargs): + token = await getattr(credential, get_token_method)(ARM_SCOPE, **kwargs) assert token assert token.token assert token.expires_on @@ -27,18 +28,20 @@ async def get_token(credential, **kwargs): @pytest.mark.asyncio -@pytest.mark.parametrize("certificate_fixture", ("live_pem_certificate", "live_pfx_certificate")) -async def test_certificate_credential(certificate_fixture, request): +@pytest.mark.parametrize( + "certificate_fixture,get_token_method", product(("live_pem_certificate", "live_pfx_certificate"), GET_TOKEN_METHODS) +) +async def test_certificate_credential(certificate_fixture, get_token_method, request): cert = request.getfixturevalue(certificate_fixture) tenant_id = cert["tenant_id"] client_id = cert["client_id"] credential = CertificateCredential(tenant_id, client_id, cert["cert_path"]) - await get_token(credential) + await get_token(credential, get_token_method) credential = CertificateCredential(tenant_id, client_id, certificate_data=cert["cert_bytes"]) - token = await get_token(credential, enable_cae=True) + token = await get_token(credential, get_token_method, enable_cae=True) parsed_payload = get_token_payload_contents(token.token) assert "xms_cc" in parsed_payload and "CP1" in parsed_payload["xms_cc"] @@ -46,48 +49,53 @@ async def test_certificate_credential(certificate_fixture, request): credential = CertificateCredential( tenant_id, client_id, cert["cert_with_password_path"], password=cert["password"] ) - await get_token(credential) + await get_token(credential, get_token_method) credential = CertificateCredential( tenant_id, client_id, certificate_data=cert["cert_with_password_bytes"], password=cert["password"] ) - await get_token(credential, enable_cae=True) + await get_token(credential, get_token_method, enable_cae=True) @pytest.mark.asyncio -async def test_client_secret_credential(live_service_principal): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_client_secret_credential(live_service_principal, get_token_method): credential = ClientSecretCredential( live_service_principal["tenant_id"], live_service_principal["client_id"], live_service_principal["client_secret"], ) - token = await get_token(credential, enable_cae=True) + token = await get_token(credential, get_token_method, enable_cae=True) parsed_payload = get_token_payload_contents(token.token) assert "xms_cc" in parsed_payload and "CP1" in parsed_payload["xms_cc"] @pytest.mark.asyncio -async def test_default_credential(live_service_principal): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_default_credential(live_service_principal, get_token_method): credential = DefaultAzureCredential() - await get_token(credential) + await get_token(credential, get_token_method) @pytest.mark.manual @pytest.mark.asyncio -async def test_cli_credential(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_cli_credential(get_token_method): credential = AzureCliCredential() - await get_token(credential) + await get_token(credential, get_token_method) @pytest.mark.manual @pytest.mark.asyncio -async def test_dev_cli_credential(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_dev_cli_credential(get_token_method): credential = AzureDeveloperCliCredential() - await get_token(credential) + await get_token(credential, get_token_method) @pytest.mark.manual @pytest.mark.asyncio -async def test_powershell_credential(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_powershell_credential(get_token_method): credential = AzurePowerShellCredential() - await get_token(credential) + await get_token(credential, get_token_method) diff --git a/sdk/identity/azure-identity/tests/test_managed_identity.py b/sdk/identity/azure-identity/tests/test_managed_identity.py index 54ff51802e7f..e14dfe724b7a 100644 --- a/sdk/identity/azure-identity/tests/test_managed_identity.py +++ b/sdk/identity/azure-identity/tests/test_managed_identity.py @@ -2,13 +2,10 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ -import os -import sys +from itertools import product import time from unittest import mock -from azure.core.credentials import AccessToken -from azure.core.exceptions import ClientAuthenticationError from azure.identity import ManagedIdentityCredential, CredentialUnavailableError from azure.identity._constants import EnvironmentVariables from azure.identity._credentials.imds import IMDS_AUTHORITY, IMDS_TOKEN_PATH @@ -16,7 +13,7 @@ from azure.identity._internal import within_credential_chain import pytest -from helpers import build_aad_response, validating_transport, mock_response, Request +from helpers import build_aad_response, validating_transport, mock_response, Request, GET_TOKEN_METHODS MANAGED_IDENTITY_ENVIRON = "azure.identity._credentials.managed_identity.os.environ" ALL_ENVIRONMENTS = ( @@ -73,8 +70,8 @@ def test_context_manager_incomplete_configuration(): pass -@pytest.mark.parametrize("environ", ALL_ENVIRONMENTS) -def test_custom_hooks(environ): +@pytest.mark.parametrize("environ,get_token_method", product(ALL_ENVIRONMENTS, GET_TOKEN_METHODS)) +def test_custom_hooks(environ, get_token_method): """The credential's pipeline should include azure-core's CustomHookPolicy""" scope = "scope" @@ -99,7 +96,7 @@ def test_custom_hooks(environ): credential = ManagedIdentityCredential( transport=transport, raw_request_hook=request_hook, raw_response_hook=response_hook ) - credential.get_token(scope) + getattr(credential, get_token_method)(scope) assert request_hook.call_count == 1 assert response_hook.call_count == 1 @@ -108,8 +105,8 @@ def test_custom_hooks(environ): assert pipeline_response.http_response == expected_response -@pytest.mark.parametrize("environ", ALL_ENVIRONMENTS) -def test_tenant_id(environ): +@pytest.mark.parametrize("environ,get_token_method", product(ALL_ENVIRONMENTS, GET_TOKEN_METHODS)) +def test_tenant_id(environ, get_token_method): scope = "scope" expected_token = "***" request_hook = mock.Mock() @@ -132,7 +129,7 @@ def test_tenant_id(environ): credential = ManagedIdentityCredential( transport=transport, raw_request_hook=request_hook, raw_response_hook=response_hook ) - credential.get_token(scope) + getattr(credential, get_token_method)(scope) assert request_hook.call_count == 1 assert response_hook.call_count == 1 @@ -141,12 +138,13 @@ def test_tenant_id(environ): assert pipeline_response.http_response == expected_response -def test_cloud_shell(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_cloud_shell(get_token_method): """Cloud Shell environment: only MSI_ENDPOINT set""" access_token = "****" expires_on = 42 - expected_token = AccessToken(access_token, expires_on) + expected_token = access_token endpoint = "http://localhost:42/token" scope = "scope" transport = validating_transport( @@ -173,14 +171,16 @@ def test_cloud_shell(): ) with mock.patch("os.environ", {EnvironmentVariables.MSI_ENDPOINT: endpoint}): - token = ManagedIdentityCredential(transport=transport).get_token(scope) - assert token == expected_token + token = getattr(ManagedIdentityCredential(transport=transport), get_token_method)(scope) + assert token.token == expected_token + assert token.expires_on == expires_on -def test_cloud_shell_tenant_id(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_cloud_shell_tenant_id(get_token_method): access_token = "****" expires_on = 42 - expected_token = AccessToken(access_token, expires_on) + expected_token = access_token endpoint = "http://localhost:42/token" scope = "scope" transport = validating_transport( @@ -207,14 +207,20 @@ def test_cloud_shell_tenant_id(): ) with mock.patch("os.environ", {EnvironmentVariables.MSI_ENDPOINT: endpoint}): - token = ManagedIdentityCredential(transport=transport).get_token(scope, tenant_id="tenant_id") - assert token == expected_token + kwargs = {"tenant_id": "tenant_id"} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(ManagedIdentityCredential(transport=transport), get_token_method)(scope, **kwargs) + assert token.token == expected_token + assert token.expires_on == expires_on -def test_azure_ml(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_azure_ml(get_token_method): """Azure ML: MSI_ENDPOINT, MSI_SECRET set (like App Service 2017-09-01 but with a different response format)""" - expected_token = AccessToken("****", int(time.time()) + 3600) + expected_token = "****" + expires_on = int(time.time()) + 3600 url = "http://localhost:42/token" secret = "expected-secret" scope = "scope" @@ -238,9 +244,9 @@ def test_azure_ml(): responses=[ mock_response( json_payload={ - "access_token": expected_token.token, + "access_token": expected_token, "expires_in": 3600, - "expires_on": expected_token.expires_on, + "expires_on": expires_on, "resource": scope, "token_type": "Bearer", } @@ -254,17 +260,19 @@ def test_azure_ml(): {EnvironmentVariables.MSI_ENDPOINT: url, EnvironmentVariables.MSI_SECRET: secret}, clear=True, ): - token = ManagedIdentityCredential(transport=transport).get_token(scope) - assert token.token == expected_token.token - assert token.expires_on == expected_token.expires_on + token = getattr(ManagedIdentityCredential(transport=transport), get_token_method)(scope) + assert token.token == expected_token + assert token.expires_on == expires_on - token = ManagedIdentityCredential(transport=transport, client_id=client_id).get_token(scope) - assert token.token == expected_token.token - assert token.expires_on == expected_token.expires_on + token = getattr(ManagedIdentityCredential(transport=transport, client_id=client_id), get_token_method)(scope) + assert token.token == expected_token + assert token.expires_on == expires_on -def test_azure_ml_tenant_id(): - expected_token = AccessToken("****", int(time.time()) + 3600) +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_azure_ml_tenant_id(get_token_method): + expected_token = "****" + expires_on = int(time.time()) + 3600 url = "http://localhost:42/token" secret = "expected-secret" scope = "scope" @@ -288,9 +296,9 @@ def test_azure_ml_tenant_id(): responses=[ mock_response( json_payload={ - "access_token": expected_token.token, + "access_token": expected_token, "expires_in": 3600, - "expires_on": expected_token.expires_on, + "expires_on": expires_on, "resource": scope, "token_type": "Bearer", } @@ -304,12 +312,16 @@ def test_azure_ml_tenant_id(): {EnvironmentVariables.MSI_ENDPOINT: url, EnvironmentVariables.MSI_SECRET: secret}, clear=True, ): - token = ManagedIdentityCredential(transport=transport).get_token(scope, tenant_id="tenant_id") - assert token.token == expected_token.token - assert token.expires_on == expected_token.expires_on + kwargs = {"tenant_id": "tenant_id"} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(ManagedIdentityCredential(transport=transport), get_token_method)(scope, **kwargs) + assert token.token == expected_token + assert token.expires_on == expires_on -def test_cloud_shell_identity_config(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_cloud_shell_identity_config(get_token_method): """Cloud Shell environment: only MSI_ENDPOINT set""" expected_token = "****" @@ -349,17 +361,18 @@ def test_cloud_shell_identity_config(): ) with mock.patch.dict(MANAGED_IDENTITY_ENVIRON, {EnvironmentVariables.MSI_ENDPOINT: endpoint}, clear=True): - token = ManagedIdentityCredential(transport=transport).get_token(scope) + token = getattr(ManagedIdentityCredential(transport=transport), get_token_method)(scope) assert token.token == expected_token assert token.expires_on == expires_on credential = ManagedIdentityCredential(transport=transport, identity_config={param_name: param_value}) - token = credential.get_token(scope) + token = getattr(credential, get_token_method)(scope) assert token.token == expected_token assert token.expires_on == expires_on -def test_prefers_app_service_2019_08_01(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_prefers_app_service_2019_08_01(get_token_method): """When the environment is configured for both App Service versions, the credential should prefer the most recent""" access_token = "****" @@ -395,12 +408,13 @@ def test_prefers_app_service_2019_08_01(): EnvironmentVariables.MSI_SECRET: secret, } with mock.patch.dict("os.environ", environ, clear=True): - token = ManagedIdentityCredential(transport=transport).get_token(scope) + token = getattr(ManagedIdentityCredential(transport=transport), get_token_method)(scope) assert token.token == access_token assert token.expires_on == expires_on -def test_app_service_2019_08_01(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_app_service_2019_08_01(get_token_method): """App Service 2019-08-01: IDENTITY_ENDPOINT, IDENTITY_HEADER set""" access_token = "****" @@ -442,12 +456,13 @@ def send(request, **kwargs): }, clear=True, ): - token = ManagedIdentityCredential(transport=mock.Mock(send=send)).get_token(scope) + token = getattr(ManagedIdentityCredential(transport=mock.Mock(send=send)), get_token_method)(scope) assert token.token == access_token assert token.expires_on == expires_on -def test_app_service_2019_08_01_tenant_id(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_app_service_2019_08_01_tenant_id(get_token_method): """App Service 2019-08-01: IDENTITY_ENDPOINT, IDENTITY_HEADER set""" access_token = "****" @@ -489,12 +504,16 @@ def send(request, **kwargs): }, clear=True, ): - token = ManagedIdentityCredential(transport=mock.Mock(send=send)).get_token(scope, tenant_id="tenant_id") + kwargs = {"tenant_id": "tenant_id"} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(ManagedIdentityCredential(transport=mock.Mock(send=send)), get_token_method)(scope, **kwargs) assert token.token == access_token assert token.expires_on == expires_on -def test_app_service_user_assigned_identity(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_app_service_user_assigned_identity(get_token_method): """App Service 2019-08-01: IDENTITY_ENDPOINT, IDENTITY_HEADER set""" expected_token = "****" @@ -541,20 +560,21 @@ def test_app_service_user_assigned_identity(): {EnvironmentVariables.IDENTITY_ENDPOINT: endpoint, EnvironmentVariables.IDENTITY_HEADER: secret}, clear=True, ): - token = ManagedIdentityCredential(client_id=client_id, transport=transport).get_token(scope) + token = getattr(ManagedIdentityCredential(client_id=client_id, transport=transport), get_token_method)(scope) assert token.token == expected_token assert token.expires_on == expires_on credential = ManagedIdentityCredential(client_id=client_id, transport=transport) - token = credential.get_token(scope) + token = getattr(credential, get_token_method)(scope) assert token.token == expected_token assert token.expires_on == expires_on -def test_imds(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_imds(get_token_method): access_token = "****" expires_on = 42 - expected_token = AccessToken(access_token, expires_on) + expected_token = access_token scope = "scope" transport = validating_transport( requests=[ @@ -582,14 +602,15 @@ def test_imds(): # ensure e.g. $MSI_ENDPOINT isn't set, so we get ImdsCredential with mock.patch.dict("os.environ", clear=True): - token = ManagedIdentityCredential(transport=transport).get_token(scope) - assert token.token == expected_token.token + token = getattr(ManagedIdentityCredential(transport=transport), get_token_method)(scope) + assert token.token == expected_token -def test_imds_tenant_id(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_imds_tenant_id(get_token_method): access_token = "****" expires_on = 42 - expected_token = AccessToken(access_token, expires_on) + expected_token = access_token scope = "scope" transport = validating_transport( requests=[ @@ -617,11 +638,15 @@ def test_imds_tenant_id(): # ensure e.g. $MSI_ENDPOINT isn't set, so we get ImdsCredential with mock.patch.dict("os.environ", clear=True): - token = ManagedIdentityCredential(transport=transport).get_token(scope, tenant_id="tenant_id") - assert token.token == expected_token.token + kwargs = {"tenant_id": "tenant_id"} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(ManagedIdentityCredential(transport=transport), get_token_method)(scope, **kwargs) + assert token.token == expected_token -def test_imds_text_response(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_imds_text_response(get_token_method): within_credential_chain.set(True) response = mock.Mock( text=lambda encoding=None: b"{This is a text response}", @@ -632,11 +657,12 @@ def test_imds_text_response(): mock_send = mock.Mock(return_value=response) credential = ManagedIdentityCredential(transport=mock.Mock(send=mock_send)) with pytest.raises(CredentialUnavailableError): - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") within_credential_chain.set(False) -def test_client_id_none(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_client_id_none(get_token_method): """the credential should ignore client_id=None""" expected_access_token = "****" @@ -655,7 +681,7 @@ def send(request, **kwargs): # IMDS credential = ManagedIdentityCredential(client_id=None, transport=mock.Mock(send=send)) - token = credential.get_token(scope) + token = getattr(credential, get_token_method)(scope) assert token.token == expected_access_token # Cloud Shell @@ -663,14 +689,15 @@ def send(request, **kwargs): MANAGED_IDENTITY_ENVIRON, {EnvironmentVariables.MSI_ENDPOINT: "https://localhost"}, clear=True ): credential = ManagedIdentityCredential(client_id=None, transport=mock.Mock(send=send)) - token = credential.get_token(scope) + token = getattr(credential, get_token_method)(scope) assert token.token == expected_access_token -def test_imds_user_assigned_identity(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_imds_user_assigned_identity(get_token_method): access_token = "****" expires_on = 42 - expected_token = AccessToken(access_token, expires_on) + expected_token = access_token endpoint = IMDS_AUTHORITY + IMDS_TOKEN_PATH scope = "scope" client_id = "some-guid" @@ -701,11 +728,12 @@ def test_imds_user_assigned_identity(): # ensure e.g. $MSI_ENDPOINT isn't set, so we get ImdsCredential with mock.patch.dict("os.environ", clear=True): - token = ManagedIdentityCredential(client_id=client_id, transport=transport).get_token(scope) - assert token.token == expected_token.token + token = getattr(ManagedIdentityCredential(client_id=client_id, transport=transport), get_token_method)(scope) + assert token.token == expected_token -def test_service_fabric(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_service_fabric(get_token_method): """Service Fabric 2019-07-01-preview""" access_token = "****" expires_on = 42 @@ -741,12 +769,13 @@ def send(request, **kwargs): EnvironmentVariables.IDENTITY_SERVER_THUMBPRINT: thumbprint, }, ): - token = ManagedIdentityCredential(transport=mock.Mock(send=send)).get_token(scope) + token = getattr(ManagedIdentityCredential(transport=mock.Mock(send=send)), get_token_method)(scope) assert token.token == access_token assert token.expires_on == expires_on -def test_service_fabric_tenant_id(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_service_fabric_tenant_id(get_token_method): access_token = "****" expires_on = 42 endpoint = "http://localhost:42/token" @@ -781,12 +810,16 @@ def send(request, **kwargs): EnvironmentVariables.IDENTITY_SERVER_THUMBPRINT: thumbprint, }, ): - token = ManagedIdentityCredential(transport=mock.Mock(send=send)).get_token(scope, tenant_id="tenant_id") + kwargs = {"tenant_id": "tenant_id"} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(ManagedIdentityCredential(transport=mock.Mock(send=send)), get_token_method)(scope, **kwargs) assert token.token == access_token assert token.expires_on == expires_on -def test_token_exchange(tmpdir): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_token_exchange(tmpdir, get_token_method): exchange_token = "exchange-token" token_file = tmpdir.join("token") token_file.write(exchange_token) @@ -833,7 +866,7 @@ def test_token_exchange(tmpdir): # credential should default to AZURE_CLIENT_ID with mock.patch.dict("os.environ", mock_environ, clear=True): credential = ManagedIdentityCredential(transport=transport) - token = credential.get_token(scope) + token = getattr(credential, get_token_method)(scope) assert token.token == access_token # client_id kwarg should override AZURE_CLIENT_ID @@ -857,7 +890,7 @@ def test_token_exchange(tmpdir): with mock.patch.dict("os.environ", mock_environ, clear=True): credential = ManagedIdentityCredential(client_id=nondefault_client_id, transport=transport) - token = credential.get_token(scope) + token = getattr(credential, get_token_method)(scope) assert token.token == access_token # AZURE_CLIENT_ID may not have a value, in which case client_id is required @@ -891,11 +924,12 @@ def test_token_exchange(tmpdir): ManagedIdentityCredential() credential = ManagedIdentityCredential(client_id=nondefault_client_id, transport=transport) - token = credential.get_token(scope) + token = getattr(credential, get_token_method)(scope) assert token.token == access_token -def test_token_exchange_tenant_id(tmpdir): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_token_exchange_tenant_id(tmpdir, get_token_method): exchange_token = "exchange-token" token_file = tmpdir.join("token") token_file.write(exchange_token) @@ -941,7 +975,10 @@ def test_token_exchange_tenant_id(tmpdir): } with mock.patch.dict("os.environ", mock_environ, clear=True): credential = ManagedIdentityCredential(transport=transport) - token = credential.get_token(scope, tenant_id="tenant_id") + kwargs = {"tenant_id": "tenant_id"} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(credential, get_token_method)(scope, **kwargs) assert token.token == access_token diff --git a/sdk/identity/azure-identity/tests/test_managed_identity_async.py b/sdk/identity/azure-identity/tests/test_managed_identity_async.py index 239f074c93f1..ee988f67dbf5 100644 --- a/sdk/identity/azure-identity/tests/test_managed_identity_async.py +++ b/sdk/identity/azure-identity/tests/test_managed_identity_async.py @@ -2,11 +2,11 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ +from itertools import product import os import time from unittest import mock -from azure.core.credentials import AccessToken from azure.core.exceptions import ClientAuthenticationError from azure.identity import CredentialUnavailableError from azure.identity.aio import ManagedIdentityCredential @@ -17,7 +17,7 @@ import pytest -from helpers import build_aad_response, mock_response, Request +from helpers import build_aad_response, mock_response, Request, GET_TOKEN_METHODS from helpers_async import async_validating_transport, AsyncMockTransport from test_managed_identity import ALL_ENVIRONMENTS @@ -26,8 +26,8 @@ @pytest.mark.asyncio -@pytest.mark.parametrize("environ", ALL_ENVIRONMENTS) -async def test_custom_hooks(environ): +@pytest.mark.parametrize("environ,get_token_method", product(ALL_ENVIRONMENTS, GET_TOKEN_METHODS)) +async def test_custom_hooks(environ, get_token_method): """The credential's pipeline should include azure-core's CustomHookPolicy""" scope = "scope" @@ -52,7 +52,7 @@ async def test_custom_hooks(environ): credential = ManagedIdentityCredential( transport=transport, raw_request_hook=request_hook, raw_response_hook=response_hook ) - await credential.get_token(scope) + await getattr(credential, get_token_method)(scope) assert request_hook.call_count == 1 assert response_hook.call_count == 1 @@ -62,8 +62,8 @@ async def test_custom_hooks(environ): @pytest.mark.asyncio -@pytest.mark.parametrize("environ", ALL_ENVIRONMENTS) -async def test_tenant_id(environ): +@pytest.mark.parametrize("environ,get_token_method", product(ALL_ENVIRONMENTS, GET_TOKEN_METHODS)) +async def test_tenant_id(environ, get_token_method): scope = "scope" expected_token = "***" request_hook = mock.Mock() @@ -86,7 +86,7 @@ async def test_tenant_id(environ): credential = ManagedIdentityCredential( transport=transport, raw_request_hook=request_hook, raw_response_hook=response_hook ) - await credential.get_token(scope) + await getattr(credential, get_token_method)(scope) assert request_hook.call_count == 1 assert response_hook.call_count == 1 @@ -134,12 +134,13 @@ async def test_context_manager_incomplete_configuration(): @pytest.mark.asyncio -async def test_cloud_shell(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_cloud_shell(get_token_method): """Cloud Shell environment: only MSI_ENDPOINT set""" access_token = "****" expires_on = 42 - expected_token = AccessToken(access_token, expires_on) + expected_token = access_token endpoint = "http://localhost:42/token" scope = "scope" transport = async_validating_transport( @@ -166,17 +167,18 @@ async def test_cloud_shell(): ) with mock.patch("os.environ", {EnvironmentVariables.MSI_ENDPOINT: endpoint}): - token = await ManagedIdentityCredential(transport=transport).get_token(scope) - assert token == expected_token + token = await getattr(ManagedIdentityCredential(transport=transport), get_token_method)(scope) + assert token.token == expected_token @pytest.mark.asyncio -async def test_cloud_shell_tenant_id(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_cloud_shell_tenant_id(get_token_method): """Cloud Shell environment: only MSI_ENDPOINT set""" access_token = "****" expires_on = 42 - expected_token = AccessToken(access_token, expires_on) + expected_token = access_token endpoint = "http://localhost:42/token" scope = "scope" transport = async_validating_transport( @@ -203,15 +205,20 @@ async def test_cloud_shell_tenant_id(): ) with mock.patch("os.environ", {EnvironmentVariables.MSI_ENDPOINT: endpoint}): - token = await ManagedIdentityCredential(transport=transport).get_token(scope, tenant_id="tenant_id") - assert token == expected_token + kwargs = {"tenant_id": "tenant_id"} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = await getattr(ManagedIdentityCredential(transport=transport), get_token_method)(scope, **kwargs) + assert token.token == expected_token @pytest.mark.asyncio -async def test_azure_ml(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_azure_ml(get_token_method): """Azure ML: MSI_ENDPOINT, MSI_SECRET set (like App Service 2017-09-01 but with a different response format)""" - expected_token = AccessToken("****", int(time.time()) + 3600) + expected_token = "****" + expires_on = int(time.time()) + 3600 url = "http://localhost:42/token" secret = "expected-secret" scope = "scope" @@ -235,9 +242,9 @@ async def test_azure_ml(): responses=[ mock_response( json_payload={ - "access_token": expected_token.token, + "access_token": expected_token, "expires_in": 3600, - "expires_on": expected_token.expires_on, + "expires_on": expires_on, "resource": scope, "token_type": "Bearer", } @@ -252,21 +259,23 @@ async def test_azure_ml(): clear=True, ): credential = ManagedIdentityCredential(transport=transport) - token = await credential.get_token(scope) - assert token.token == expected_token.token - assert token.expires_on == expected_token.expires_on + token = await getattr(credential, get_token_method)(scope) + assert token.token == expected_token + assert token.expires_on == expires_on credential = ManagedIdentityCredential(transport=transport, client_id=client_id) - token = await credential.get_token(scope) - assert token.token == expected_token.token - assert token.expires_on == expected_token.expires_on + token = await getattr(credential, get_token_method)(scope) + assert token.token == expected_token + assert token.expires_on == expires_on @pytest.mark.asyncio -async def test_azure_ml_tenant_id(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_azure_ml_tenant_id(get_token_method): """Azure ML: MSI_ENDPOINT, MSI_SECRET set (like App Service 2017-09-01 but with a different response format)""" - expected_token = AccessToken("****", int(time.time()) + 3600) + expected_token = "****" + expires_on = int(time.time()) + 3600 url = "http://localhost:42/token" secret = "expected-secret" scope = "scope" @@ -290,9 +299,9 @@ async def test_azure_ml_tenant_id(): responses=[ mock_response( json_payload={ - "access_token": expected_token.token, + "access_token": expected_token, "expires_in": 3600, - "expires_on": expected_token.expires_on, + "expires_on": expires_on, "resource": scope, "token_type": "Bearer", } @@ -307,13 +316,17 @@ async def test_azure_ml_tenant_id(): clear=True, ): credential = ManagedIdentityCredential(transport=transport) - token = await credential.get_token(scope, tenant_id="tenant_id") - assert token.token == expected_token.token - assert token.expires_on == expected_token.expires_on + kwargs = {"tenant_id": "tenant_id"} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = await getattr(credential, get_token_method)(scope, **kwargs) + assert token.token == expected_token + assert token.expires_on == expires_on @pytest.mark.asyncio -async def test_cloud_shell_identity_config(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_cloud_shell_identity_config(get_token_method): """Cloud Shell environment: only MSI_ENDPOINT set""" expected_token = "****" @@ -354,23 +367,24 @@ async def test_cloud_shell_identity_config(): with mock.patch.dict(MANAGED_IDENTITY_ENVIRON, {EnvironmentVariables.MSI_ENDPOINT: endpoint}, clear=True): credential = ManagedIdentityCredential(transport=transport) - token = await credential.get_token(scope) + token = await getattr(credential, get_token_method)(scope) assert token.token == expected_token assert token.expires_on == expires_on credential = ManagedIdentityCredential(transport=transport, identity_config={param_name: param_value}) - token = await credential.get_token(scope) + token = await getattr(credential, get_token_method)(scope) assert token.token == expected_token assert token.expires_on == expires_on @pytest.mark.asyncio -async def test_app_service_2017_09_01(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_app_service_2017_09_01(get_token_method): """When the environment for 2019-08-01 is not configured, 2017-09-01 should be used.""" access_token = "****" expires_on = 42 - expected_token = AccessToken(access_token, expires_on) + expected_token = access_token url = "http://localhost:42/token" secret = "expected-secret" scope = "scope" @@ -414,18 +428,19 @@ async def test_app_service_2017_09_01(): clear=True, ): credential = ManagedIdentityCredential(transport=transport) - token = await credential.get_token(scope) - assert token == expected_token + token = await getattr(credential, get_token_method)(scope) + assert token.token == expected_token assert token.expires_on == expires_on credential = ManagedIdentityCredential(transport=transport) - token = await credential.get_token(scope) - assert token == expected_token + token = await getattr(credential, get_token_method)(scope) + assert token.token == expected_token assert token.expires_on == expires_on @pytest.mark.asyncio -async def test_app_service_2019_08_01(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_app_service_2019_08_01(get_token_method): """App Service 2019-08-01: IDENTITY_ENDPOINT, IDENTITY_HEADER set""" access_token = "****" @@ -467,13 +482,14 @@ async def send(request, **kwargs): }, clear=True, ): - token = await ManagedIdentityCredential(transport=mock.Mock(send=send)).get_token(scope) + token = await getattr(ManagedIdentityCredential(transport=mock.Mock(send=send)), get_token_method)(scope) assert token.token == access_token assert token.expires_on == expires_on @pytest.mark.asyncio -async def test_app_service_2019_08_01_tenant_id(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_app_service_2019_08_01_tenant_id(get_token_method): access_token = "****" expires_on = 42 endpoint = "http://localhost:42/token" @@ -513,13 +529,19 @@ async def send(request, **kwargs): }, clear=True, ): - token = await ManagedIdentityCredential(transport=mock.Mock(send=send)).get_token(scope, tenant_id="tenant_id") + kwargs = {"tenant_id": "tenant_id"} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = await getattr(ManagedIdentityCredential(transport=mock.Mock(send=send)), get_token_method)( + scope, **kwargs + ) assert token.token == access_token assert token.expires_on == expires_on @pytest.mark.asyncio -async def test_app_service_user_assigned_identity(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_app_service_user_assigned_identity(get_token_method): """App Service 2019-08-01: MSI_ENDPOINT, MSI_SECRET set""" expected_token = "****" @@ -564,20 +586,21 @@ async def test_app_service_user_assigned_identity(): clear=True, ): credential = ManagedIdentityCredential(client_id=client_id, transport=transport) - token = await credential.get_token(scope) + token = await getattr(credential, get_token_method)(scope) assert token.token == expected_token assert token.expires_on == expires_on credential = ManagedIdentityCredential( client_id=client_id, transport=transport, identity_config={param_name: param_value} ) - token = await credential.get_token(scope) + token = await getattr(credential, get_token_method)(scope) assert token.token == expected_token assert token.expires_on == expires_on @pytest.mark.asyncio -async def test_client_id_none(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_client_id_none(get_token_method): """the credential should ignore client_id=None""" expected_access_token = "****" @@ -596,22 +619,23 @@ async def send(request, **kwargs): with mock.patch.dict(MANAGED_IDENTITY_ENVIRON, {}, clear=True): credential = ManagedIdentityCredential(client_id=None, transport=mock.Mock(send=send)) - token = await credential.get_token(scope) + token = await getattr(credential, get_token_method)(scope) assert token.token == expected_access_token with mock.patch.dict( MANAGED_IDENTITY_ENVIRON, {EnvironmentVariables.MSI_ENDPOINT: "https://localhost"}, clear=True ): credential = ManagedIdentityCredential(client_id=None, transport=mock.Mock(send=send)) - token = await credential.get_token(scope) + token = await getattr(credential, get_token_method)(scope) assert token.token == expected_access_token @pytest.mark.asyncio -async def test_imds(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_imds(get_token_method): access_token = "****" expires_on = 42 - expected_token = AccessToken(access_token, expires_on) + expected_token = access_token scope = "scope" transport = async_validating_transport( requests=[ @@ -639,15 +663,16 @@ async def test_imds(): # ensure e.g. $MSI_ENDPOINT isn't set, so we get ImdsCredential with mock.patch.dict("os.environ", clear=True): - token = await ManagedIdentityCredential(transport=transport).get_token(scope) - assert token == expected_token + token = await getattr(ManagedIdentityCredential(transport=transport), get_token_method)(scope) + assert token.token == expected_token @pytest.mark.asyncio -async def test_imds_tenant_id(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_imds_tenant_id(get_token_method): access_token = "****" expires_on = 42 - expected_token = AccessToken(access_token, expires_on) + expected_token = access_token scope = "scope" transport = async_validating_transport( requests=[ @@ -675,15 +700,19 @@ async def test_imds_tenant_id(): # ensure e.g. $MSI_ENDPOINT isn't set, so we get ImdsCredential with mock.patch.dict("os.environ", clear=True): - token = await ManagedIdentityCredential(transport=transport).get_token(scope, tenant_id="tenant_id") - assert token == expected_token + kwargs = {"tenant_id": "tenant_id"} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = await getattr(ManagedIdentityCredential(transport=transport), get_token_method)(scope, **kwargs) + assert token.token == expected_token @pytest.mark.asyncio -async def test_imds_user_assigned_identity(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_imds_user_assigned_identity(get_token_method): access_token = "****" expires_on = 42 - expected_token = AccessToken(access_token, expires_on) + expected_token = access_token scope = "scope" client_id = "some-guid" transport = async_validating_transport( @@ -713,12 +742,15 @@ async def test_imds_user_assigned_identity(): # ensure e.g. $MSI_ENDPOINT isn't set, so we get ImdsCredential with mock.patch.dict("os.environ", clear=True): - token = await ManagedIdentityCredential(client_id=client_id, transport=transport).get_token(scope) - assert token == expected_token + token = await getattr(ManagedIdentityCredential(client_id=client_id, transport=transport), get_token_method)( + scope + ) + assert token.token == expected_token @pytest.mark.asyncio -async def test_imds_text_response(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_imds_text_response(get_token_method): async def send(request, **kwargs): response = mock.Mock( text=lambda encoding=None: b"{This is a text response}", @@ -731,12 +763,13 @@ async def send(request, **kwargs): within_credential_chain.set(True) credential = ManagedIdentityCredential(transport=mock.Mock(send=send)) with pytest.raises(CredentialUnavailableError): - token = await credential.get_token("") + token = await getattr(credential, get_token_method)("") within_credential_chain.set(False) @pytest.mark.asyncio -async def test_service_fabric(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_service_fabric(get_token_method): """Service Fabric 2019-07-01-preview""" access_token = "****" expires_on = 42 @@ -772,13 +805,14 @@ async def send(request, **kwargs): EnvironmentVariables.IDENTITY_SERVER_THUMBPRINT: thumbprint, }, ): - token = await ManagedIdentityCredential(transport=mock.Mock(send=send)).get_token(scope) + token = await getattr(ManagedIdentityCredential(transport=mock.Mock(send=send)), get_token_method)(scope) assert token.token == access_token assert token.expires_on == expires_on @pytest.mark.asyncio -async def test_service_fabric_tenant_id(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_service_fabric_tenant_id(get_token_method): access_token = "****" expires_on = 42 endpoint = "http://localhost:42/token" @@ -813,13 +847,19 @@ async def send(request, **kwargs): EnvironmentVariables.IDENTITY_SERVER_THUMBPRINT: thumbprint, }, ): - token = await ManagedIdentityCredential(transport=mock.Mock(send=send)).get_token(scope, tenant_id="tenant_id") + kwargs = {"tenant_id": "tenant_id"} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = await getattr(ManagedIdentityCredential(transport=mock.Mock(send=send)), get_token_method)( + scope, **kwargs + ) assert token.token == access_token assert token.expires_on == expires_on @pytest.mark.asyncio -async def test_azure_arc(tmpdir): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_azure_arc(tmpdir, get_token_method): """Azure Arc 2020-06-01""" access_token = "****" api_version = "2020-06-01" @@ -868,13 +908,14 @@ async def test_azure_arc(tmpdir): {EnvironmentVariables.IDENTITY_ENDPOINT: identity_endpoint, EnvironmentVariables.IMDS_ENDPOINT: imds_endpoint}, ): with mock.patch("azure.identity._credentials.azure_arc._validate_key_file", lambda x: None): - token = await ManagedIdentityCredential(transport=transport).get_token(scope) + token = await getattr(ManagedIdentityCredential(transport=transport), get_token_method)(scope) assert token.token == access_token assert token.expires_on == expires_on @pytest.mark.asyncio -async def test_azure_arc_tenant_id(tmpdir): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_azure_arc_tenant_id(tmpdir, get_token_method): access_token = "****" api_version = "2020-06-01" expires_on = 42 @@ -922,13 +963,17 @@ async def test_azure_arc_tenant_id(tmpdir): {EnvironmentVariables.IDENTITY_ENDPOINT: identity_endpoint, EnvironmentVariables.IMDS_ENDPOINT: imds_endpoint}, ): with mock.patch("azure.identity._credentials.azure_arc._validate_key_file", lambda x: None): - token = await ManagedIdentityCredential(transport=transport).get_token(scope, tenant_id="tenant_id") + kwargs = {"tenant_id": "tenant_id"} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = await getattr(ManagedIdentityCredential(transport=transport), get_token_method)(scope, **kwargs) assert token.token == access_token assert token.expires_on == expires_on @pytest.mark.asyncio -async def test_azure_arc_client_id(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_azure_arc_client_id(get_token_method): """Azure Arc doesn't support user-assigned managed identity""" with mock.patch( "os.environ", @@ -940,11 +985,12 @@ async def test_azure_arc_client_id(): credential = ManagedIdentityCredential(client_id="some-guid") with pytest.raises(ClientAuthenticationError): - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") @pytest.mark.asyncio -async def test_azure_arc_key_too_large(tmp_path): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_azure_arc_key_too_large(tmp_path, get_token_method): api_version = "2020-06-01" identity_endpoint = "http://localhost:42/token" imds_endpoint = "http://localhost:42" @@ -974,12 +1020,13 @@ async def test_azure_arc_key_too_large(tmp_path): ): with mock.patch("azure.identity._credentials.azure_arc._get_key_file_path", lambda: str(tmp_path)): with pytest.raises(ClientAuthenticationError) as ex: - await ManagedIdentityCredential(transport=transport).get_token(scope) + await getattr(ManagedIdentityCredential(transport=transport), get_token_method)(scope) assert "file size" in str(ex.value) @pytest.mark.asyncio -async def test_azure_arc_key_not_exist(tmp_path): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_azure_arc_key_not_exist(get_token_method): api_version = "2020-06-01" identity_endpoint = "http://localhost:42/token" imds_endpoint = "http://localhost:42" @@ -1003,12 +1050,13 @@ async def test_azure_arc_key_not_exist(tmp_path): {EnvironmentVariables.IDENTITY_ENDPOINT: identity_endpoint, EnvironmentVariables.IMDS_ENDPOINT: imds_endpoint}, ): with pytest.raises(ClientAuthenticationError) as ex: - await ManagedIdentityCredential(transport=transport).get_token(scope) + await getattr(ManagedIdentityCredential(transport=transport), get_token_method)(scope) assert "not exist" in str(ex.value) @pytest.mark.asyncio -async def test_azure_arc_key_invalid(tmp_path): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_azure_arc_key_invalid(tmp_path, get_token_method): api_version = "2020-06-01" identity_endpoint = "http://localhost:42/token" imds_endpoint = "http://localhost:42" @@ -1043,17 +1091,18 @@ async def test_azure_arc_key_invalid(tmp_path): ): with mock.patch("azure.identity._credentials.azure_arc._get_key_file_path", lambda: "/foo"): with pytest.raises(ClientAuthenticationError) as ex: - await ManagedIdentityCredential(transport=transport).get_token(scope) + await getattr(ManagedIdentityCredential(transport=transport), get_token_method)(scope) assert "Unexpected file path" in str(ex.value) with mock.patch("azure.identity._credentials.azure_arc._get_key_file_path", lambda: str(tmp_path)): with pytest.raises(ClientAuthenticationError) as ex: - await ManagedIdentityCredential(transport=transport).get_token(scope) + await getattr(ManagedIdentityCredential(transport=transport), get_token_method)(scope) assert "extension" in str(ex.value) @pytest.mark.asyncio -async def test_token_exchange(tmpdir): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_token_exchange(tmpdir, get_token_method): exchange_token = "exchange-token" token_file = tmpdir.join("token") token_file.write(exchange_token) @@ -1100,7 +1149,7 @@ async def test_token_exchange(tmpdir): # credential should default to AZURE_CLIENT_ID with mock.patch.dict("os.environ", mock_environ, clear=True): credential = ManagedIdentityCredential(transport=transport) - token = await credential.get_token(scope) + token = await getattr(credential, get_token_method)(scope) assert token.token == access_token # client_id kwarg should override AZURE_CLIENT_ID @@ -1124,7 +1173,7 @@ async def test_token_exchange(tmpdir): with mock.patch.dict("os.environ", mock_environ, clear=True): credential = ManagedIdentityCredential(client_id=nondefault_client_id, transport=transport) - token = await credential.get_token(scope) + token = await getattr(credential, get_token_method)(scope) assert token.token == access_token # AZURE_CLIENT_ID may not have a value, in which case client_id is required @@ -1158,12 +1207,13 @@ async def test_token_exchange(tmpdir): ManagedIdentityCredential() credential = ManagedIdentityCredential(client_id=nondefault_client_id, transport=transport) - token = await credential.get_token(scope) + token = await getattr(credential, get_token_method)(scope) assert token.token == access_token @pytest.mark.asyncio -async def test_token_exchange_tenant_id(tmpdir): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_token_exchange_tenant_id(tmpdir, get_token_method): exchange_token = "exchange-token" token_file = tmpdir.join("token") token_file.write(exchange_token) @@ -1210,7 +1260,10 @@ async def test_token_exchange_tenant_id(tmpdir): # credential should default to AZURE_CLIENT_ID with mock.patch.dict("os.environ", mock_environ, clear=True): credential = ManagedIdentityCredential(transport=transport) - token = await credential.get_token(scope, tenant_id="tenant_id") + kwargs = {"tenant_id": "tenant_id"} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = await getattr(credential, get_token_method)(scope, **kwargs) assert token.token == access_token diff --git a/sdk/identity/azure-identity/tests/test_multi_tenant_auth.py b/sdk/identity/azure-identity/tests/test_multi_tenant_auth.py index 2c5d8c385bfe..f86df54a172b 100644 --- a/sdk/identity/azure-identity/tests/test_multi_tenant_auth.py +++ b/sdk/identity/azure-identity/tests/test_multi_tenant_auth.py @@ -11,11 +11,13 @@ from azure.core.rest import HttpRequest, HttpResponse from azure.identity import ClientSecretCredential +from helpers import GET_TOKEN_METHODS + class TestMultiTenantAuth(AzureRecordedTestCase): - def _send_request(self, credential: ClientSecretCredential) -> HttpResponse: + def _send_request(self, credential: ClientSecretCredential, get_token_method: str) -> HttpResponse: client = PipelineClient(base_url="https://graph.microsoft.com") - token = credential.get_token("https://graph.microsoft.com/.default") + token = getattr(credential, get_token_method)("https://graph.microsoft.com/.default") headers = {"Authorization": "Bearer " + token.token, "ConsistencyLevel": "eventual"} request = HttpRequest("GET", "https://graph.microsoft.com/v1.0/applications/$count", headers=headers) response = client.send_request(request) @@ -26,11 +28,12 @@ def _send_request(self, credential: ClientSecretCredential) -> HttpResponse: is_live() and not os.environ.get("AZURE_IDENTITY_MULTI_TENANT_CLIENT_ID"), reason="Multi-tenant envvars not configured.", ) - def test_multi_tenant_client_secret_graph_call(self, recorded_test, environment_variables): + @pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) + def test_multi_tenant_client_secret_graph_call(self, recorded_test, environment_variables, get_token_method): client_id = environment_variables.get("AZURE_IDENTITY_MULTI_TENANT_CLIENT_ID") tenant_id = environment_variables.get("AZURE_IDENTITY_MULTI_TENANT_TENANT_ID") client_secret = environment_variables.get("AZURE_IDENTITY_MULTI_TENANT_CLIENT_SECRET") credential = ClientSecretCredential(tenant_id, client_id, client_secret) - response = self._send_request(credential) + response = self._send_request(credential, get_token_method) assert response.status_code == 200 assert int(response.text()) > 0 diff --git a/sdk/identity/azure-identity/tests/test_multi_tenant_auth_async.py b/sdk/identity/azure-identity/tests/test_multi_tenant_auth_async.py index c6ca05477760..6c1e50a7a0d9 100644 --- a/sdk/identity/azure-identity/tests/test_multi_tenant_auth_async.py +++ b/sdk/identity/azure-identity/tests/test_multi_tenant_auth_async.py @@ -11,11 +11,13 @@ from azure.core.rest import HttpRequest, HttpResponse from azure.identity.aio import ClientSecretCredential +from helpers import GET_TOKEN_METHODS + class TestMultiTenantAuthAsync(AzureRecordedTestCase): - async def _send_request(self, credential: ClientSecretCredential) -> HttpResponse: + async def _send_request(self, credential: ClientSecretCredential, get_token_method: str) -> HttpResponse: client = AsyncPipelineClient(base_url="https://graph.microsoft.com") - token = await credential.get_token("https://graph.microsoft.com/.default") + token = await getattr(credential, get_token_method)("https://graph.microsoft.com/.default") headers = {"Authorization": "Bearer " + token.token, "ConsistencyLevel": "eventual"} request = HttpRequest("GET", "https://graph.microsoft.com/v1.0/applications/$count", headers=headers) response = await client.send_request(request, stream=False) @@ -27,12 +29,13 @@ async def _send_request(self, credential: ClientSecretCredential) -> HttpRespons is_live() and not os.environ.get("AZURE_IDENTITY_MULTI_TENANT_CLIENT_ID"), reason="Multi-tenant envvars not configured.", ) - async def test_multi_tenant_client_secret_graph_call(self, recorded_test, environment_variables): + @pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) + async def test_multi_tenant_client_secret_graph_call(self, recorded_test, environment_variables, get_token_method): client_id = environment_variables.get("AZURE_IDENTITY_MULTI_TENANT_CLIENT_ID") tenant_id = environment_variables.get("AZURE_IDENTITY_MULTI_TENANT_TENANT_ID") client_secret = environment_variables.get("AZURE_IDENTITY_MULTI_TENANT_CLIENT_SECRET") credential = ClientSecretCredential(tenant_id, client_id, client_secret) async with credential: - response = await self._send_request(credential) + response = await self._send_request(credential, get_token_method) assert response.status_code == 200 assert int(response.text()) > 0 diff --git a/sdk/identity/azure-identity/tests/test_obo.py b/sdk/identity/azure-identity/tests/test_obo.py index 3678b4fa8506..bbcd50e6c603 100644 --- a/sdk/identity/azure-identity/tests/test_obo.py +++ b/sdk/identity/azure-identity/tests/test_obo.py @@ -3,11 +3,8 @@ # Licensed under the MIT License. # ------------------------------------ import os - -try: - from unittest.mock import Mock, patch -except ImportError: - from mock import Mock, patch # type: ignore +from itertools import product +from unittest.mock import Mock, patch from azure.core.pipeline.policies import ContentDecodePolicy, SansIOHTTPPolicy from azure.identity import OnBehalfOfCredential, UsernamePasswordCredential @@ -17,7 +14,7 @@ import pytest from urllib.parse import urlparse -from helpers import build_aad_response, FAKE_CLIENT_ID, get_discovery_response, mock_response +from helpers import build_aad_response, FAKE_CLIENT_ID, get_discovery_response, mock_response, GET_TOKEN_METHODS from recorded_test_case import RecordedTestCase from test_certificate_credential import PEM_CERT_PATH from devtools_testutils import is_live, recorded_by_proxy @@ -95,7 +92,8 @@ def test_obo_cert(self): credential.get_token(self.obo_settings["scope"]) -def test_multitenant_authentication(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_multitenant_authentication(get_token_method): first_tenant = "first-tenant" first_token = "***" second_tenant = "second-tenant" @@ -124,22 +122,28 @@ def send(request, **kwargs): transport=transport, additionally_allowed_tenants=["*"], ) - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == first_token - token = credential.get_token("scope", tenant_id=first_tenant) + kwargs = {"tenant_id": first_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(credential, get_token_method)("scope", **kwargs) assert token.token == first_token - token = credential.get_token("scope", tenant_id=second_tenant) + kwargs = {"tenant_id": second_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(credential, get_token_method)("scope", **kwargs) assert token.token == second_token # should still default to the first tenant - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == first_token -@pytest.mark.parametrize("authority", ("localhost", "https://localhost")) -def test_authority(authority): +@pytest.mark.parametrize("authority,get_token_method", product(("localhost", "https://localhost"), GET_TOKEN_METHODS)) +def test_authority(authority, get_token_method): """the credential should accept an authority, with or without scheme, as an argument or environment variable""" tenant_id = "expected-tenant" @@ -156,7 +160,7 @@ def test_authority(authority): ) with patch("msal.ConfidentialClientApplication", mock_ctor): # must call get_token because the credential constructs the MSAL application lazily - credential.get_token("scope") + getattr(credential, get_token_method)("scope") assert mock_ctor.call_count == 1 _, kwargs = mock_ctor.call_args @@ -167,7 +171,7 @@ def test_authority(authority): with patch.dict("os.environ", {EnvironmentVariables.AZURE_AUTHORITY_HOST: authority}, clear=True): credential = OnBehalfOfCredential(tenant_id, "client-id", client_secret="secret", user_assertion="assertion") with patch("msal.ConfidentialClientApplication", mock_ctor): - credential.get_token("scope") + getattr(credential, get_token_method)("scope") assert mock_ctor.call_count == 1 _, kwargs = mock_ctor.call_args @@ -185,16 +189,18 @@ def test_tenant_id_validation(): OnBehalfOfCredential(tenant, "client-id", client_secret="secret", user_assertion="assertion") -def test_no_scopes(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_no_scopes(get_token_method): """The credential should raise ValueError when get_token is called with no scopes""" credential = OnBehalfOfCredential( "tenant-id", "client-id", client_secret="client-secret", user_assertion="assertion" ) with pytest.raises(ValueError): - credential.get_token() + getattr(credential, get_token_method)() -def test_policies_configurable(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_policies_configurable(get_token_method): policy = Mock(spec_set=SansIOHTTPPolicy, on_request=Mock(), on_exception=lambda _: False) def send(request, **kwargs): @@ -215,7 +221,7 @@ def send(request, **kwargs): policies=[ContentDecodePolicy(), policy], transport=Mock(send=send), ) - credential.get_token("scope") + getattr(credential, get_token_method)("scope") assert policy.on_request.called @@ -231,7 +237,8 @@ def test_no_client_credential(): credential = OnBehalfOfCredential("tenant-id", "client-id", user_assertion="assertion") -def test_client_assertion_func(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_client_assertion_func(get_token_method): """The credential should accept a client_assertion_func""" expected_client_assertion = "client-assertion" expected_user_assertion = "user-assertion" @@ -263,7 +270,7 @@ def send(request, **kwargs): transport=transport, ) - access_token = credential.get_token("scope") + access_token = getattr(credential, get_token_method)("scope") assert access_token.token == expected_token assert func_call_count == 1 diff --git a/sdk/identity/azure-identity/tests/test_obo_async.py b/sdk/identity/azure-identity/tests/test_obo_async.py index f9d80323e84f..5e90040a0af1 100644 --- a/sdk/identity/azure-identity/tests/test_obo_async.py +++ b/sdk/identity/azure-identity/tests/test_obo_async.py @@ -3,6 +3,7 @@ # Licensed under the MIT License. # ------------------------------------ import os +from itertools import product from urllib.parse import urlparse from unittest.mock import Mock, patch from test_certificate_credential import PEM_CERT_PATH @@ -17,7 +18,7 @@ from azure.identity.aio import OnBehalfOfCredential import pytest -from helpers import build_aad_response, get_discovery_response, mock_response, FAKE_CLIENT_ID +from helpers import build_aad_response, get_discovery_response, mock_response, FAKE_CLIENT_ID, GET_TOKEN_METHODS from helpers_async import AsyncMockTransport from recorded_test_case import RecordedTestCase @@ -123,7 +124,8 @@ async def test_context_manager(): @pytest.mark.asyncio -async def test_multitenant_authentication(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_multitenant_authentication(get_token_method): first_tenant = "first-tenant" first_token = "***" second_tenant = "second-tenant" @@ -149,27 +151,33 @@ async def send(request, **kwargs): transport=transport, additionally_allowed_tenants=["*"], ) - token = await credential.get_token("scope") + token = await getattr(credential, get_token_method)("scope") assert token.token == first_token assert transport.send.call_count == 1 - token = await credential.get_token("scope", tenant_id=first_tenant) + kwargs = {"tenant_id": first_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = await getattr(credential, get_token_method)("scope", **kwargs) assert token.token == first_token assert transport.send.call_count == 1 # should be a cached token - token = await credential.get_token("scope", tenant_id=second_tenant) + kwargs = {"tenant_id": second_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = await getattr(credential, get_token_method)("scope", **kwargs) assert token.token == second_token assert transport.send.call_count == 2 # should still default to the first tenant - token = await credential.get_token("scope") + token = await getattr(credential, get_token_method)("scope") assert token.token == first_token assert transport.send.call_count == 2 # should be a cached token @pytest.mark.asyncio -@pytest.mark.parametrize("authority", ("localhost", "https://localhost")) -async def test_authority(authority): +@pytest.mark.parametrize("authority,get_token_method", product(("localhost", "https://localhost"), GET_TOKEN_METHODS)) +async def test_authority(authority, get_token_method): """the credential should accept an authority, with or without scheme, as an argument or environment variable""" tenant_id = "expected-tenant" @@ -194,7 +202,7 @@ async def send(request, **kwargs): authority=authority, transport=transport, ) - token = await credential.get_token("scope") + token = await getattr(credential, get_token_method)("scope") assert token.token == expected_token # authority can be configured via environment variable @@ -202,12 +210,13 @@ async def send(request, **kwargs): credential = OnBehalfOfCredential( tenant_id, "client-id", client_secret="secret", user_assertion="assertion", transport=transport ) - token = await credential.get_token("scope") + token = await getattr(credential, get_token_method)("scope") assert token.token == expected_token @pytest.mark.asyncio -async def test_policies_configurable(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_policies_configurable(get_token_method): policy = Mock(spec_set=SansIOHTTPPolicy, on_request=Mock(), on_exception=lambda _: False) async def send(request, **kwargs): @@ -228,7 +237,7 @@ async def send(request, **kwargs): policies=[ContentDecodePolicy(), policy], transport=Mock(send=send), ) - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") assert policy.on_request.called @@ -239,7 +248,8 @@ def test_invalid_cert(): @pytest.mark.asyncio -async def test_refresh_token(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_refresh_token(get_token_method): first_token = "***" second_token = first_token * 2 refresh_token = "refresh-token" @@ -264,10 +274,10 @@ async def send(request, **kwargs): credential = OnBehalfOfCredential( "tenant-id", "client-id", client_secret="secret", user_assertion="assertion", transport=Mock(send=send) ) - token = await credential.get_token("scope") + token = await getattr(credential, get_token_method)("scope") assert token.token == first_token - token = await credential.get_token("scope") + token = await getattr(credential, get_token_method)("scope") assert token.token == second_token assert requests == 2 @@ -285,13 +295,14 @@ def test_tenant_id_validation(): @pytest.mark.asyncio -async def test_no_scopes(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_no_scopes(get_token_method): """The credential should raise ValueError when get_token is called with no scopes""" credential = OnBehalfOfCredential( "tenant-id", "client-id", client_secret="client-secret", user_assertion="assertion" ) with pytest.raises(ValueError): - await credential.get_token() + await getattr(credential, get_token_method)() @pytest.mark.asyncio @@ -309,7 +320,8 @@ async def test_no_client_credential(): @pytest.mark.asyncio -async def test_client_assertion_func(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_client_assertion_func(get_token_method): """The credential should accept a client_assertion_func""" expected_client_assertion = "client-assertion" expected_user_assertion = "user-assertion" @@ -340,7 +352,7 @@ async def send(request, **kwargs): user_assertion=expected_user_assertion, transport=transport, ) - token = await credential.get_token("scope") + token = await getattr(credential, get_token_method)("scope") assert token.token == expected_token assert func_call_count == 1 diff --git a/sdk/identity/azure-identity/tests/test_powershell_credential.py b/sdk/identity/azure-identity/tests/test_powershell_credential.py index 9441f1c69756..23b663a667e8 100644 --- a/sdk/identity/azure-identity/tests/test_powershell_credential.py +++ b/sdk/identity/azure-identity/tests/test_powershell_credential.py @@ -3,17 +3,14 @@ # Licensed under the MIT License. # ------------------------------------ import base64 +from itertools import product import logging from platform import python_version import re import subprocess import sys import time - -try: - from unittest.mock import Mock, patch -except ImportError: # python < 3.3 - from mock import Mock, patch # type: ignore +from unittest.mock import Mock, patch from azure.core.exceptions import ClientAuthenticationError from azure.identity import AzurePowerShellCredential, CredentialUnavailableError @@ -28,7 +25,7 @@ import pytest from credscan_ignore import POWERSHELL_INVALID_OPERATION_EXCEPTION, POWERSHELL_NOT_LOGGED_IN_ERROR -from helpers import INVALID_CHARACTERS +from helpers import INVALID_CHARACTERS, GET_TOKEN_METHODS POPEN = AzurePowerShellCredential.__module__ + ".subprocess.Popen" @@ -48,29 +45,33 @@ def get_mock_Popen(return_code=0, stdout="", stderr=""): return Mock(return_value=Mock(communicate=communicate, returncode=return_code)) -def test_no_scopes(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_no_scopes(get_token_method): """The credential should raise ValueError when get_token is called with no scopes""" with pytest.raises(ValueError): - AzurePowerShellCredential().get_token() + getattr(AzurePowerShellCredential(), get_token_method)() -def test_multiple_scopes(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_multiple_scopes(get_token_method): """The credential should raise ValueError when get_token is called with more than one scope""" with pytest.raises(ValueError): - AzurePowerShellCredential().get_token("one scope", "and another") + getattr(AzurePowerShellCredential(), get_token_method)("one scope", "and another") -def test_cannot_execute_shell(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_cannot_execute_shell(get_token_method): """The credential should raise CredentialUnavailableError when the subprocess doesn't start""" with patch(POPEN, Mock(side_effect=OSError)): with pytest.raises(CredentialUnavailableError): - AzurePowerShellCredential().get_token("scope") + getattr(AzurePowerShellCredential(), get_token_method)("scope") -def test_invalid_tenant_id(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_invalid_tenant_id(get_token_method): """Invalid tenant IDs should raise ValueErrors.""" for c in INVALID_CHARACTERS: @@ -78,19 +79,23 @@ def test_invalid_tenant_id(): AzurePowerShellCredential(tenant_id="tenant" + c) with pytest.raises(ValueError): - AzurePowerShellCredential().get_token("scope", tenant_id="tenant" + c) + kwargs = {"tenant_id": "tenant" + c} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + getattr(AzurePowerShellCredential(), get_token_method)("scope", **kwargs) -def test_invalid_scopes(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_invalid_scopes(get_token_method): """Scopes with invalid characters should raise ValueErrors.""" for c in INVALID_CHARACTERS: with pytest.raises(ValueError): - AzurePowerShellCredential().get_token("scope" + c) + getattr(AzurePowerShellCredential(), get_token_method)("scope" + c) -@pytest.mark.parametrize("stderr", ("", PREPARING_MODULES)) -def test_get_token(stderr): +@pytest.mark.parametrize("stderr,get_token_method", product(("", PREPARING_MODULES), GET_TOKEN_METHODS)) +def test_get_token(stderr, get_token_method): """The credential should parse Azure PowerShell's output to an AccessToken""" expected_access_token = "access" @@ -100,7 +105,7 @@ def test_get_token(stderr): Popen = get_mock_Popen(stdout=stdout, stderr=stderr) with patch(POPEN, Popen): - token = AzurePowerShellCredential().get_token(scope) + token = getattr(AzurePowerShellCredential(), get_token_method)(scope) assert token.token == expected_access_token assert token.expires_on == expected_expires_on @@ -123,8 +128,8 @@ def test_get_token(stderr): assert "timeout" in kwargs -@pytest.mark.parametrize("stderr", ("", PREPARING_MODULES)) -def test_get_token_tenant_id(stderr): +@pytest.mark.parametrize("stderr,get_token_method", product(("", PREPARING_MODULES), GET_TOKEN_METHODS)) +def test_get_token_tenant_id(stderr, get_token_method): expected_access_token = "access" expected_expires_on = 1617923581 scope = "scope" @@ -132,69 +137,82 @@ def test_get_token_tenant_id(stderr): Popen = get_mock_Popen(stdout=stdout, stderr=stderr) with patch(POPEN, Popen): - token = AzurePowerShellCredential().get_token(scope, tenant_id="tenant-id") + kwargs = {"tenant_id": "tenant-id"} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(AzurePowerShellCredential(), get_token_method)(scope, **kwargs) assert token.token == expected_access_token assert token.expires_on == expected_expires_on -def test_ignores_extraneous_stdout_content(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_ignores_extraneous_stdout_content(get_token_method): expected_access_token = "access" expected_expires_on = 1617923581 motd = "MOTD: Customize your experience: save your profile to $HOME/.config/PowerShell\n" Popen = get_mock_Popen(stdout=motd + "azsdk%{}%{}".format(expected_access_token, expected_expires_on)) with patch(POPEN, Popen): - token = AzurePowerShellCredential().get_token("scope") + token = getattr(AzurePowerShellCredential(), get_token_method)("scope") assert token.token == expected_access_token assert token.expires_on == expected_expires_on -def test_az_powershell_not_installed(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_az_powershell_not_installed(get_token_method): """The credential should raise CredentialUnavailableError when Azure PowerShell isn't installed""" with patch(POPEN, get_mock_Popen(stdout=NO_AZ_ACCOUNT_MODULE)): with pytest.raises(CredentialUnavailableError, match=AZ_ACCOUNT_NOT_INSTALLED): - AzurePowerShellCredential().get_token("scope") + getattr(AzurePowerShellCredential(), get_token_method)("scope") @pytest.mark.parametrize( - "stderr", - ( - "'pwsh' is not recognized as an internal or external command,\r\noperable program or batch file.", - "'powershell' is not recognized as an internal or external command,\r\noperable program or batch file.", + "stderr,get_token_method", + product( + ( + "'pwsh' is not recognized as an internal or external command,\r\noperable program or batch file.", + "'powershell' is not recognized as an internal or external command,\r\noperable program or batch file.", + ), + GET_TOKEN_METHODS, ), ) -def test_powershell_not_installed_cmd(stderr): +def test_powershell_not_installed_cmd(stderr, get_token_method): """The credential should raise CredentialUnavailableError when PowerShell isn't installed""" Popen = get_mock_Popen(return_code=1, stderr=stderr) with patch(POPEN, Popen): with pytest.raises(CredentialUnavailableError, match=POWERSHELL_NOT_INSTALLED): - AzurePowerShellCredential().get_token("scope") + getattr(AzurePowerShellCredential(), get_token_method)("scope") -def test_powershell_not_installed_sh(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_powershell_not_installed_sh(get_token_method): """The credential should raise CredentialUnavailableError when PowerShell isn't installed""" Popen = get_mock_Popen(return_code=127, stderr="/bin/sh: 0: Can't open pwsh") with patch(POPEN, Popen): with pytest.raises(CredentialUnavailableError, match=POWERSHELL_NOT_INSTALLED): - AzurePowerShellCredential().get_token("scope") + getattr(AzurePowerShellCredential(), get_token_method)("scope") -@pytest.mark.parametrize("stderr", (POWERSHELL_INVALID_OPERATION_EXCEPTION, POWERSHELL_NOT_LOGGED_IN_ERROR)) -def test_not_logged_in(stderr): +@pytest.mark.parametrize( + "stderr,get_token_method", + product((POWERSHELL_INVALID_OPERATION_EXCEPTION, POWERSHELL_NOT_LOGGED_IN_ERROR), GET_TOKEN_METHODS), +) +def test_not_logged_in(stderr, get_token_method): """The credential should raise CredentialUnavailableError when a user isn't logged in to Azure PowerShell""" Popen = get_mock_Popen(return_code=1, stderr=stderr) with patch(POPEN, Popen): with pytest.raises(CredentialUnavailableError, match=RUN_CONNECT_AZ_ACCOUNT): - AzurePowerShellCredential().get_token("scope") + getattr(AzurePowerShellCredential(), get_token_method)("scope") -def test_blocked_by_execution_policy(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_blocked_by_execution_policy(get_token_method): """The credential should raise CredentialUnavailableError when execution policy blocks Get-AzAccessToken""" stderr = r"""#< CLIXML @@ -202,11 +220,11 @@ def test_blocked_by_execution_policy(): Popen = get_mock_Popen(return_code=1, stderr=stderr) with patch(POPEN, Popen): with pytest.raises(CredentialUnavailableError, match=BLOCKED_BY_EXECUTION_POLICY): - AzurePowerShellCredential().get_token("scope") + getattr(AzurePowerShellCredential(), get_token_method)("scope") -@pytest.mark.skipif(sys.version_info < (3, 3), reason="Python 3.3 added timeout support to Popen") -def test_timeout(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_timeout(get_token_method): """The credential should kill the subprocess after a timeout""" from subprocess import TimeoutExpired @@ -214,7 +232,7 @@ def test_timeout(): proc = Mock(communicate=Mock(side_effect=TimeoutExpired("", 42)), returncode=None) with patch(POPEN, Mock(return_value=proc)): with pytest.raises(CredentialUnavailableError): - AzurePowerShellCredential(process_timeout=42).get_token("scope") + getattr(AzurePowerShellCredential(process_timeout=42), get_token_method)("scope") assert proc.communicate.call_count == 1 # Ensure custom timeout is passed to subprocess @@ -223,7 +241,8 @@ def test_timeout(): assert kwargs["timeout"] == 42 -def test_unexpected_error(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_unexpected_error(get_token_method): """The credential should log stderr when Get-AzAccessToken returns an unexpected error""" class MockHandler(logging.Handler): @@ -243,7 +262,7 @@ def emit(self, record): Popen = get_mock_Popen(return_code=42, stderr=expected_output) with patch(POPEN, Popen): with pytest.raises(ClientAuthenticationError): - AzurePowerShellCredential().get_token("scope") + getattr(AzurePowerShellCredential(), get_token_method)("scope") for message in mock_handler.messages: if message.levelname == "DEBUG" and expected_output in message.message: @@ -253,13 +272,16 @@ def emit(self, record): @pytest.mark.parametrize( - "error_message", - ( - "'pwsh' is not recognized as an internal or external command,\r\noperable program or batch file.", - "some other message", + "error_message,get_token_method", + product( + ( + "'pwsh' is not recognized as an internal or external command,\r\noperable program or batch file.", + "some other message", + ), + GET_TOKEN_METHODS, ), ) -def test_windows_powershell_fallback(error_message): +def test_windows_powershell_fallback(error_message, get_token_method): """On Windows, the credential should fall back to powershell.exe when pwsh.exe isn't on the path""" class Fake: @@ -285,12 +307,13 @@ def Popen(args, **kwargs): with patch.dict("os.environ", {"SYSTEMROOT": "foo"}): with patch(POPEN, Popen): with pytest.raises(CredentialUnavailableError, match=AZ_ACCOUNT_NOT_INSTALLED): - AzurePowerShellCredential().get_token("scope") + getattr(AzurePowerShellCredential(), get_token_method)("scope") assert Fake.calls == 2 -def test_multitenant_authentication(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_multitenant_authentication(get_token_method): first_token = "***" second_tenant = "12345" second_token = first_token * 2 @@ -314,18 +337,22 @@ def fake_Popen(command, **_): credential = AzurePowerShellCredential() with patch(POPEN, fake_Popen): - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == first_token - token = credential.get_token("scope", tenant_id=second_tenant) + kwargs = {"tenant_id": second_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(credential, get_token_method)("scope", **kwargs) assert token.token == second_token # should still default to the first tenant - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == first_token -def test_multitenant_authentication_not_allowed(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_multitenant_authentication_not_allowed(get_token_method): expected_token = "***" def fake_Popen(command, **_): @@ -346,9 +373,12 @@ def fake_Popen(command, **_): credential = AzurePowerShellCredential() with patch(POPEN, fake_Popen): - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == expected_token with patch.dict("os.environ", {EnvironmentVariables.AZURE_IDENTITY_DISABLE_MULTITENANTAUTH: "true"}): - token = credential.get_token("scope", tenant_id="12345") + kwargs = {"tenant_id": "12345"} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(credential, get_token_method)("scope", **kwargs) assert token.token == expected_token diff --git a/sdk/identity/azure-identity/tests/test_powershell_credential_async.py b/sdk/identity/azure-identity/tests/test_powershell_credential_async.py index bc694b7ef01a..2764575655e8 100644 --- a/sdk/identity/azure-identity/tests/test_powershell_credential_async.py +++ b/sdk/identity/azure-identity/tests/test_powershell_credential_async.py @@ -4,6 +4,7 @@ # ------------------------------------ import asyncio import base64 +from itertools import product import logging import re import sys @@ -24,7 +25,7 @@ import pytest from credscan_ignore import POWERSHELL_INVALID_OPERATION_EXCEPTION, POWERSHELL_NOT_LOGGED_IN_ERROR -from helpers import INVALID_CHARACTERS +from helpers import INVALID_CHARACTERS, GET_TOKEN_METHODS from helpers_async import get_completed_future from test_powershell_credential import PREPARING_MODULES @@ -38,29 +39,33 @@ def get_mock_exec(return_code=0, stdout="", stderr=""): return Mock(return_value=get_completed_future(Mock(communicate=communicate, returncode=return_code))) -async def test_no_scopes(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_no_scopes(get_token_method): """The credential should raise ValueError when get_token is called with no scopes""" with pytest.raises(ValueError): - await AzurePowerShellCredential().get_token() + await getattr(AzurePowerShellCredential(), get_token_method)() -async def test_multiple_scopes(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_multiple_scopes(get_token_method): """The credential should raise ValueError when get_token is called with more than one scope""" with pytest.raises(ValueError): - await AzurePowerShellCredential().get_token("one scope", "and another") + await getattr(AzurePowerShellCredential(), get_token_method)("one scope", "and another") -async def test_cannot_execute_shell(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_cannot_execute_shell(get_token_method): """The credential should raise CredentialUnavailableError when the subprocess doesn't start""" with patch(CREATE_SUBPROCESS_EXEC, Mock(side_effect=OSError)): with pytest.raises(CredentialUnavailableError): - await AzurePowerShellCredential().get_token("scope") + await getattr(AzurePowerShellCredential(), get_token_method)("scope") -async def test_invalid_tenant_id(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_invalid_tenant_id(get_token_method): """Invalid tenant IDs should raise ValueErrors.""" for c in INVALID_CHARACTERS: @@ -68,19 +73,23 @@ async def test_invalid_tenant_id(): AzurePowerShellCredential(tenant_id="tenant" + c) with pytest.raises(ValueError): - await AzurePowerShellCredential().get_token("scope", tenant_id="tenant" + c) + kwargs = {"tenant_id": "tenant" + c} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + await getattr(AzurePowerShellCredential(), get_token_method)("scope", **kwargs) -async def test_invalid_scopes(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_invalid_scopes(get_token_method): """Scopes with invalid characters should raise ValueErrors.""" for c in INVALID_CHARACTERS: with pytest.raises(ValueError): - await AzurePowerShellCredential().get_token("scope" + c) + await getattr(AzurePowerShellCredential(), get_token_method)("scope" + c) -@pytest.mark.parametrize("stderr", ("", PREPARING_MODULES)) -async def test_get_token(stderr): +@pytest.mark.parametrize("stderr,get_token_method", product(("", PREPARING_MODULES), GET_TOKEN_METHODS)) +async def test_get_token(stderr, get_token_method): """The credential should parse Azure PowerShell's output to an AccessToken""" expected_access_token = "access" @@ -90,7 +99,7 @@ async def test_get_token(stderr): mock_exec = get_mock_exec(stdout=stdout, stderr=stderr) with patch(CREATE_SUBPROCESS_EXEC, mock_exec): - token = await AzurePowerShellCredential().get_token(scope) + token = await getattr(AzurePowerShellCredential(), get_token_method)(scope) assert token.token == expected_access_token assert token.expires_on == expected_expires_on @@ -110,8 +119,8 @@ async def test_get_token(stderr): assert mock_exec().result().communicate.call_count == 1 -@pytest.mark.parametrize("stderr", ("", PREPARING_MODULES)) -async def test_get_token_tenant_id(stderr): +@pytest.mark.parametrize("stderr,get_token_method", product(("", PREPARING_MODULES), GET_TOKEN_METHODS)) +async def test_get_token_tenant_id(stderr, get_token_method): expected_access_token = "access" expected_expires_on = 1617923581 scope = "scope" @@ -119,69 +128,82 @@ async def test_get_token_tenant_id(stderr): mock_exec = get_mock_exec(stdout=stdout, stderr=stderr) with patch(CREATE_SUBPROCESS_EXEC, mock_exec): - token = await AzurePowerShellCredential().get_token(scope, tenant_id="tenant-id") + kwargs = {"tenant_id": "tenant-id"} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = await getattr(AzurePowerShellCredential(), get_token_method)(scope, **kwargs) assert token.token == expected_access_token assert token.expires_on == expected_expires_on -async def test_ignores_extraneous_stdout_content(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_ignores_extraneous_stdout_content(get_token_method): expected_access_token = "access" expected_expires_on = 1617923581 motd = "MOTD: Customize your experience: save your profile to $HOME/.config/PowerShell\n" mock_exec = get_mock_exec(stdout=motd + "azsdk%{}%{}".format(expected_access_token, expected_expires_on)) with patch(CREATE_SUBPROCESS_EXEC, mock_exec): - token = await AzurePowerShellCredential().get_token("scope") + token = await getattr(AzurePowerShellCredential(), get_token_method)("scope") assert token.token == expected_access_token assert token.expires_on == expected_expires_on -async def test_az_powershell_not_installed(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_az_powershell_not_installed(get_token_method): """The credential should raise CredentialUnavailableError when Azure PowerShell isn't installed""" with patch(CREATE_SUBPROCESS_EXEC, get_mock_exec(stdout=NO_AZ_ACCOUNT_MODULE)): with pytest.raises(CredentialUnavailableError, match=AZ_ACCOUNT_NOT_INSTALLED): - await AzurePowerShellCredential().get_token("scope") + await getattr(AzurePowerShellCredential(), get_token_method)("scope") @pytest.mark.parametrize( - "stderr", - ( - "'pwsh' is not recognized as an internal or external command,\r\noperable program or batch file.", - "'powershell' is not recognized as an internal or external command,\r\noperable program or batch file.", + "stderr,get_token_method", + product( + ( + "'pwsh' is not recognized as an internal or external command,\r\noperable program or batch file.", + "'powershell' is not recognized as an internal or external command,\r\noperable program or batch file.", + ), + GET_TOKEN_METHODS, ), ) -async def test_powershell_not_installed_cmd(stderr): +async def test_powershell_not_installed_cmd(stderr, get_token_method): """The credential should raise CredentialUnavailableError when PowerShell isn't installed""" mock_exec = get_mock_exec(return_code=1, stderr=stderr) with patch(CREATE_SUBPROCESS_EXEC, mock_exec): with pytest.raises(CredentialUnavailableError, match=POWERSHELL_NOT_INSTALLED): - await AzurePowerShellCredential().get_token("scope") + await getattr(AzurePowerShellCredential(), get_token_method)("scope") -async def test_powershell_not_installed_sh(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_powershell_not_installed_sh(get_token_method): """The credential should raise CredentialUnavailableError when PowerShell isn't installed""" mock_exec = get_mock_exec(return_code=127, stderr="/bin/sh: 0: Can't open pwsh") with patch(CREATE_SUBPROCESS_EXEC, mock_exec): with pytest.raises(CredentialUnavailableError, match=POWERSHELL_NOT_INSTALLED): - await AzurePowerShellCredential().get_token("scope") + await getattr(AzurePowerShellCredential(), get_token_method)("scope") -@pytest.mark.parametrize("stderr", (POWERSHELL_INVALID_OPERATION_EXCEPTION, POWERSHELL_NOT_LOGGED_IN_ERROR)) -async def test_not_logged_in(stderr): +@pytest.mark.parametrize( + "stderr,get_token_method", + product((POWERSHELL_INVALID_OPERATION_EXCEPTION, POWERSHELL_NOT_LOGGED_IN_ERROR), GET_TOKEN_METHODS), +) +async def test_not_logged_in(stderr, get_token_method): """The credential should raise CredentialUnavailableError when a user isn't logged in to Azure PowerShell""" mock_exec = get_mock_exec(return_code=1, stderr=stderr) with patch(CREATE_SUBPROCESS_EXEC, mock_exec): with pytest.raises(CredentialUnavailableError, match=RUN_CONNECT_AZ_ACCOUNT): - await AzurePowerShellCredential().get_token("scope") + await getattr(AzurePowerShellCredential(), get_token_method)("scope") -async def test_blocked_by_execution_policy(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_blocked_by_execution_policy(get_token_method): """The credential should raise CredentialUnavailableError when execution policy blocks Get-AzAccessToken""" stderr = r"""#< CLIXML @@ -189,21 +211,23 @@ async def test_blocked_by_execution_policy(): mock_exec = get_mock_exec(return_code=1, stderr=stderr) with patch(CREATE_SUBPROCESS_EXEC, mock_exec): with pytest.raises(CredentialUnavailableError, match=BLOCKED_BY_EXECUTION_POLICY): - await AzurePowerShellCredential().get_token("scope") + await getattr(AzurePowerShellCredential(), get_token_method)("scope") -async def test_timeout(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_timeout(get_token_method): """The credential should kill the subprocess after a timeout""" proc = Mock(communicate=Mock(side_effect=asyncio.TimeoutError), returncode=None) with patch(CREATE_SUBPROCESS_EXEC, Mock(return_value=get_completed_future(proc))): with pytest.raises(CredentialUnavailableError): - await AzurePowerShellCredential().get_token("scope") + await getattr(AzurePowerShellCredential(), get_token_method)("scope") assert proc.communicate.call_count == 1 -async def test_unexpected_error(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_unexpected_error(get_token_method): """The credential should log stderr when Get-AzAccessToken returns an unexpected error""" class MockHandler(logging.Handler): @@ -223,7 +247,7 @@ def emit(self, record): mock_exec = get_mock_exec(return_code=42, stderr=expected_output) with patch(CREATE_SUBPROCESS_EXEC, mock_exec): with pytest.raises(ClientAuthenticationError): - await AzurePowerShellCredential().get_token("scope") + await getattr(AzurePowerShellCredential(), get_token_method)("scope") for message in mock_handler.messages: if message.levelname == "DEBUG" and expected_output in message.message: @@ -233,17 +257,22 @@ def emit(self, record): @pytest.mark.skipif(not sys.platform.startswith("win"), reason="tests Windows-specific behavior") -async def test_windows_event_loop(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_windows_event_loop(get_token_method): """The credential should fall back to the sync implementation when not using ProactorEventLoop on Windows""" sync_get_token = Mock() credential = AzurePowerShellCredential() with patch(AzurePowerShellCredential.__module__ + "._SyncCredential") as fallback: - fallback.return_value = Mock(spec_set=["get_token"], get_token=sync_get_token) + fallback.return_value = Mock( + spec_set=["get_token", "get_token_info"], + get_token=sync_get_token, + get_token_info=sync_get_token, + ) with patch(AzurePowerShellCredential.__module__ + ".asyncio.get_event_loop"): # asyncio.get_event_loop now returns Mock, i.e. never ProactorEventLoop - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") assert sync_get_token.call_count == 1 @@ -256,7 +285,8 @@ async def test_windows_event_loop(): "some other message", ), ) -async def test_windows_powershell_fallback(error_message): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_windows_powershell_fallback(error_message, get_token_method): """On Windows, the credential should fall back to powershell.exe when pwsh.exe isn't on the path""" calls = 0 @@ -282,12 +312,13 @@ async def mock_exec(*args, **kwargs): credential = AzurePowerShellCredential() with pytest.raises(CredentialUnavailableError, match=AZ_ACCOUNT_NOT_INSTALLED): with patch(CREATE_SUBPROCESS_EXEC, mock_exec): - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") assert calls == 2 -async def test_multitenant_authentication(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_multitenant_authentication(get_token_method): first_token = "***" second_tenant = "12345" second_token = first_token * 2 @@ -312,18 +343,22 @@ async def fake_exec(*args, **_): credential = AzurePowerShellCredential() with patch(CREATE_SUBPROCESS_EXEC, fake_exec): - token = await credential.get_token("scope") + token = await getattr(credential, get_token_method)("scope") assert token.token == first_token - token = await credential.get_token("scope", tenant_id=second_tenant) + kwargs = {"tenant_id": second_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = await getattr(credential, get_token_method)("scope", **kwargs) assert token.token == second_token # should still default to the first tenant - token = await credential.get_token("scope") + token = await getattr(credential, get_token_method)("scope") assert token.token == first_token -async def test_multitenant_authentication_not_allowed(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_multitenant_authentication_not_allowed(get_token_method): expected_token = "***" async def fake_exec(*args, **_): @@ -344,9 +379,12 @@ async def fake_exec(*args, **_): credential = AzurePowerShellCredential() with patch(CREATE_SUBPROCESS_EXEC, fake_exec): - token = await credential.get_token("scope") + token = await getattr(credential, get_token_method)("scope") assert token.token == expected_token with patch.dict("os.environ", {EnvironmentVariables.AZURE_IDENTITY_DISABLE_MULTITENANTAUTH: "true"}): - token = await credential.get_token("scope", tenant_id="12345") + kwargs = {"tenant_id": "12345"} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = await getattr(credential, get_token_method)("scope", **kwargs) assert token.token == expected_token diff --git a/sdk/identity/azure-identity/tests/test_shared_cache_credential.py b/sdk/identity/azure-identity/tests/test_shared_cache_credential.py index ba1052ac52c3..66709b61f0d5 100644 --- a/sdk/identity/azure-identity/tests/test_shared_cache_credential.py +++ b/sdk/identity/azure-identity/tests/test_shared_cache_credential.py @@ -35,14 +35,16 @@ msal_validating_transport, Request, validating_transport, + GET_TOKEN_METHODS, ) -def test_close(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_close(get_token_method): transport = MagicMock() credential = SharedTokenCacheCredential(transport=transport, _cache=TokenCache()) with pytest.raises(CredentialUnavailableError): - credential.get_token("scope") + getattr(credential, get_token_method)("scope") assert not transport.__enter__.called assert not transport.__exit__.called @@ -52,11 +54,12 @@ def test_close(): assert transport.__exit__.call_count == 1 -def test_context_manager(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_context_manager(get_token_method): transport = MagicMock() credential = SharedTokenCacheCredential(transport=transport, _cache=TokenCache()) with pytest.raises(CredentialUnavailableError): - credential.get_token("scope") + getattr(credential, get_token_method)("scope") assert not transport.__enter__.called assert not transport.__exit__.called @@ -92,15 +95,17 @@ def test_supported(): assert SharedTokenCacheCredential.supported() -def test_no_scopes(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_no_scopes(get_token_method): """The credential should raise when get_token is called with no scopes""" credential = SharedTokenCacheCredential(_cache=TokenCache()) with pytest.raises(ValueError): - credential.get_token() + getattr(credential, get_token_method)() -def test_policies_configurable(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_policies_configurable(get_token_method): policy = Mock(spec_set=SansIOHTTPPolicy, on_request=Mock()) def send(*_, **kwargs): @@ -115,12 +120,13 @@ def send(*_, **kwargs): transport=Mock(send=send), ) - credential.get_token("scope") + getattr(credential, get_token_method)("scope") assert policy.on_request.called -def test_user_agent(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_user_agent(get_token_method): transport = validating_transport( requests=[Request(required_headers={"User-Agent": USER_AGENT})], responses=[mock_response(json_payload=build_aad_response(access_token="**"))], @@ -130,10 +136,11 @@ def test_user_agent(): _cache=populated_cache(get_account_event("test@user", "uid", "utid")), transport=transport ) - credential.get_token("scope") + getattr(credential, get_token_method)("scope") -def test_tenant_id(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_tenant_id(get_token_method): transport = validating_transport( requests=[Request(required_headers={"User-Agent": USER_AGENT})], responses=[mock_response(json_payload=build_aad_response(access_token="**"))], @@ -145,7 +152,11 @@ def test_tenant_id(): additionally_allowed_tenants=["*"], ) - credential.get_token("scope", tenant_id="tenant_id") + kwargs = {"tenant_id": "tenant_id"} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + + getattr(credential, get_token_method)("scope", **kwargs) @pytest.mark.parametrize("authority", ("localhost", "https://localhost")) @@ -168,20 +179,25 @@ def _get_auth_client(self, authority=None, **kwargs): MockCredential(_cache=TokenCache(), authority=authority, transport=transport) -def test_empty_cache(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_empty_cache(get_token_method): """the credential should raise CredentialUnavailableError when the cache is empty""" with pytest.raises(CredentialUnavailableError, match=NO_ACCOUNTS): - SharedTokenCacheCredential(_cache=TokenCache()).get_token("scope") + getattr(SharedTokenCacheCredential(_cache=TokenCache()), get_token_method)("scope") with pytest.raises(CredentialUnavailableError, match=NO_ACCOUNTS): - SharedTokenCacheCredential(_cache=TokenCache(), username="not@cache").get_token("scope") + getattr(SharedTokenCacheCredential(_cache=TokenCache(), username="not@cache"), get_token_method)("scope") with pytest.raises(CredentialUnavailableError, match=NO_ACCOUNTS): - SharedTokenCacheCredential(_cache=TokenCache(), tenant_id="not-cached").get_token("scope") + getattr(SharedTokenCacheCredential(_cache=TokenCache(), tenant_id="not-cached"), get_token_method)("scope") with pytest.raises(CredentialUnavailableError, match=NO_ACCOUNTS): - SharedTokenCacheCredential(_cache=TokenCache(), tenant_id="not-cached", username="not@cache").get_token("scope") + getattr( + SharedTokenCacheCredential(_cache=TokenCache(), tenant_id="not-cached", username="not@cache"), + get_token_method, + )("scope") -def test_no_matching_account_for_username(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_no_matching_account_for_username(get_token_method): """one cached account, username specified, username doesn't match -> credential should raise""" upn = "spam@eggs" @@ -190,13 +206,14 @@ def test_no_matching_account_for_username(): cache = populated_cache(account) with pytest.raises(CredentialUnavailableError) as ex: - SharedTokenCacheCredential(_cache=cache, username="not" + upn).get_token("scope") + getattr(SharedTokenCacheCredential(_cache=cache, username="not" + upn), get_token_method)("scope") assert ex.value.message.startswith(NO_MATCHING_ACCOUNTS[: NO_MATCHING_ACCOUNTS.index("{")]) assert "not" + upn in ex.value.message -def test_no_matching_account_for_tenant(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_no_matching_account_for_tenant(get_token_method): """one cached account, tenant specified, tenant doesn't match -> credential should raise""" upn = "spam@eggs" @@ -205,13 +222,14 @@ def test_no_matching_account_for_tenant(): cache = populated_cache(account) with pytest.raises(CredentialUnavailableError) as ex: - SharedTokenCacheCredential(_cache=cache, tenant_id="not-" + tenant).get_token("scope") + getattr(SharedTokenCacheCredential(_cache=cache, tenant_id="not-" + tenant), get_token_method)("scope") assert ex.value.message.startswith(NO_MATCHING_ACCOUNTS[: NO_MATCHING_ACCOUNTS.index("{")]) assert "not-" + tenant in ex.value.message -def test_no_matching_account_for_tenant_and_username(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_no_matching_account_for_tenant_and_username(get_token_method): """one cached account, tenant and username specified, neither match -> credential should raise""" upn = "spam@eggs" @@ -220,13 +238,16 @@ def test_no_matching_account_for_tenant_and_username(): cache = populated_cache(account) with pytest.raises(CredentialUnavailableError) as ex: - SharedTokenCacheCredential(_cache=cache, tenant_id="not-" + tenant, username="not" + upn).get_token("scope") + getattr( + SharedTokenCacheCredential(_cache=cache, tenant_id="not-" + tenant, username="not" + upn), get_token_method + )("scope") assert ex.value.message.startswith(NO_MATCHING_ACCOUNTS[: NO_MATCHING_ACCOUNTS.index("{")]) assert "not" + upn in ex.value.message and "not-" + tenant in ex.value.message -def test_no_matching_account_for_tenant_or_username(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_no_matching_account_for_tenant_or_username(get_token_method): """two cached accounts, username and tenant specified, one account matches each -> credential should raise""" refresh_token_a = "refresh-token-a" @@ -243,18 +264,19 @@ def test_no_matching_account_for_tenant_or_username(): credential = SharedTokenCacheCredential(username=upn_a, tenant_id=tenant_b, _cache=cache, transport=transport) with pytest.raises(CredentialUnavailableError) as ex: - credential.get_token("scope") + getattr(credential, get_token_method)("scope") assert ex.value.message.startswith(NO_MATCHING_ACCOUNTS[: NO_MATCHING_ACCOUNTS.index("{")]) assert upn_a in ex.value.message and tenant_b in ex.value.message credential = SharedTokenCacheCredential(username=upn_b, tenant_id=tenant_a, _cache=cache, transport=transport) with pytest.raises(CredentialUnavailableError) as ex: - credential.get_token("scope") + getattr(credential, get_token_method)("scope") assert ex.value.message.startswith(NO_MATCHING_ACCOUNTS[: NO_MATCHING_ACCOUNTS.index("{")]) assert upn_b in ex.value.message and tenant_a in ex.value.message -def test_single_account_matching_username(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_single_account_matching_username(get_token_method): """one cached account, username specified, username matches -> credential should auth that account""" upn = "spam@eggs" @@ -269,11 +291,12 @@ def test_single_account_matching_username(): responses=[mock_response(json_payload=build_aad_response(access_token=expected_token))], ) credential = SharedTokenCacheCredential(_cache=cache, transport=transport, username=upn) - token = credential.get_token(scope) + token = getattr(credential, get_token_method)(scope) assert token.token == expected_token -def test_single_account_matching_tenant(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_single_account_matching_tenant(get_token_method): """one cached account, tenant specified, tenant matches -> credential should auth that account""" tenant_id = "tenant-id" @@ -288,11 +311,12 @@ def test_single_account_matching_tenant(): responses=[mock_response(json_payload=build_aad_response(access_token=expected_token))], ) credential = SharedTokenCacheCredential(_cache=cache, transport=transport, tenant_id=tenant_id) - token = credential.get_token(scope) + token = getattr(credential, get_token_method)(scope) assert token.token == expected_token -def test_single_account_matching_tenant_and_username(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_single_account_matching_tenant_and_username(get_token_method): """one cached account, tenant and username specified, both match -> credential should auth that account""" upn = "spam@eggs" @@ -308,11 +332,12 @@ def test_single_account_matching_tenant_and_username(): responses=[mock_response(json_payload=build_aad_response(access_token=expected_token))], ) credential = SharedTokenCacheCredential(_cache=cache, transport=transport, tenant_id=tenant_id, username=upn) - token = credential.get_token(scope) + token = getattr(credential, get_token_method)(scope) assert token.token == expected_token -def test_single_account(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_single_account(get_token_method): """one cached account, no username specified -> credential should auth that account""" refresh_token = "refresh-token" @@ -327,11 +352,12 @@ def test_single_account(): ) credential = SharedTokenCacheCredential(_cache=cache, transport=transport) - token = credential.get_token(scope) + token = getattr(credential, get_token_method)(scope) assert token.token == expected_token -def test_no_refresh_token(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_no_refresh_token(get_token_method): """one cached account, account has no refresh token -> credential should raise""" account = get_account_event(uid="uid_a", utid="utid", username="spam@eggs", refresh_token=None) @@ -341,14 +367,15 @@ def test_no_refresh_token(): credential = SharedTokenCacheCredential(_cache=cache, transport=transport) with pytest.raises(CredentialUnavailableError, match=NO_ACCOUNTS): - credential.get_token("scope") + getattr(credential, get_token_method)("scope") credential = SharedTokenCacheCredential(_cache=cache, transport=transport, username="not@cache") with pytest.raises(CredentialUnavailableError, match=NO_ACCOUNTS): - credential.get_token("scope") + getattr(credential, get_token_method)("scope") -def test_two_accounts_no_username_or_tenant(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_two_accounts_no_username_or_tenant(get_token_method): """two cached accounts, no username or tenant specified -> credential should raise""" upn_a = "a@foo" @@ -363,10 +390,11 @@ def test_two_accounts_no_username_or_tenant(): # two users in the cache, no username specified -> CredentialUnavailableError credential = SharedTokenCacheCredential(_cache=cache, transport=transport) with pytest.raises(ClientAuthenticationError, match=MULTIPLE_ACCOUNTS) as ex: - credential.get_token("scope") + getattr(credential, get_token_method)("scope") -def test_two_accounts_username_specified(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_two_accounts_username_specified(get_token_method): """two cached accounts, username specified, one account matches -> credential should auth that account""" scope = "scope" @@ -383,11 +411,12 @@ def test_two_accounts_username_specified(): responses=[mock_response(json_payload=build_aad_response(access_token=expected_token))], ) credential = SharedTokenCacheCredential(username=upn_a, _cache=cache, transport=transport) - token = credential.get_token(scope) + token = getattr(credential, get_token_method)(scope) assert token.token == expected_token -def test_two_accounts_tenant_specified(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_two_accounts_tenant_specified(get_token_method): """two cached accounts, tenant specified, one account matches -> credential should auth that account""" scope = "scope" @@ -405,11 +434,12 @@ def test_two_accounts_tenant_specified(): responses=[mock_response(json_payload=build_aad_response(access_token=expected_token))], ) credential = SharedTokenCacheCredential(tenant_id=tenant_id, _cache=cache, transport=transport) - token = credential.get_token(scope) + token = getattr(credential, get_token_method)(scope) assert token.token == expected_token -def test_two_accounts_tenant_and_username_specified(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_two_accounts_tenant_and_username_specified(get_token_method): """two cached accounts, tenant and username specified, one account matches both -> credential should auth that account""" scope = "scope" @@ -427,11 +457,12 @@ def test_two_accounts_tenant_and_username_specified(): responses=[mock_response(json_payload=build_aad_response(access_token=expected_token))], ) credential = SharedTokenCacheCredential(tenant_id=tenant_id, username=upn_a, _cache=cache, transport=transport) - token = credential.get_token(scope) + token = getattr(credential, get_token_method)(scope) assert token.token == expected_token -def test_same_username_different_tenants(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_same_username_different_tenants(get_token_method): """two cached accounts, same username, different tenants""" access_token_a = "access-token-a" @@ -450,7 +481,7 @@ def test_same_username_different_tenants(): transport = Mock(side_effect=Exception()) # (so it shouldn't use the network) credential = SharedTokenCacheCredential(username=upn, _cache=cache, transport=transport) with pytest.raises(CredentialUnavailableError) as ex: - credential.get_token("scope") + getattr(credential, get_token_method)("scope") assert ex.value.message.startswith(MULTIPLE_MATCHING_ACCOUNTS[: MULTIPLE_MATCHING_ACCOUNTS.index("{")]) assert upn in ex.value.message @@ -462,7 +493,7 @@ def test_same_username_different_tenants(): responses=[mock_response(json_payload=build_aad_response(access_token=access_token_a))], ) credential = SharedTokenCacheCredential(tenant_id=tenant_a, _cache=cache, transport=transport) - token = credential.get_token(scope) + token = getattr(credential, get_token_method)(scope) assert token.token == access_token_a transport = validating_transport( @@ -470,11 +501,12 @@ def test_same_username_different_tenants(): responses=[mock_response(json_payload=build_aad_response(access_token=access_token_b))], ) credential = SharedTokenCacheCredential(tenant_id=tenant_b, _cache=cache, transport=transport) - token = credential.get_token(scope) + token = getattr(credential, get_token_method)(scope) assert token.token == access_token_b -def test_same_tenant_different_usernames(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_same_tenant_different_usernames(get_token_method): """two cached accounts, same tenant, different usernames""" access_token_a = "access-token-a" @@ -493,7 +525,7 @@ def test_same_tenant_different_usernames(): transport = Mock(side_effect=Exception()) # (so it shouldn't use the network) credential = SharedTokenCacheCredential(tenant_id=tenant_id, _cache=cache, transport=transport) with pytest.raises(CredentialUnavailableError) as ex: - credential.get_token("scope") + getattr(credential, get_token_method)("scope") assert ex.value.message.startswith(MULTIPLE_MATCHING_ACCOUNTS[: MULTIPLE_MATCHING_ACCOUNTS.index("{")]) assert tenant_id in ex.value.message @@ -505,7 +537,7 @@ def test_same_tenant_different_usernames(): responses=[mock_response(json_payload=build_aad_response(access_token=access_token_a))], ) credential = SharedTokenCacheCredential(username=upn_b, _cache=cache, transport=transport) - token = credential.get_token(scope) + token = getattr(credential, get_token_method)(scope) assert token.token == access_token_a transport = validating_transport( @@ -513,11 +545,12 @@ def test_same_tenant_different_usernames(): responses=[mock_response(json_payload=build_aad_response(access_token=access_token_a))], ) credential = SharedTokenCacheCredential(username=upn_a, _cache=cache, transport=transport) - token = credential.get_token(scope) + token = getattr(credential, get_token_method)(scope) assert token.token == access_token_a -def test_authority_aliases(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_authority_aliases(get_token_method): """the credential should use a refresh token valid for any known alias of its authority""" expected_access_token = "access-token" @@ -536,7 +569,7 @@ def test_authority_aliases(): responses=[mock_response(json_payload=build_aad_response(access_token=expected_access_token))], ) credential = SharedTokenCacheCredential(authority=authority, _cache=cache, transport=transport) - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == expected_access_token # it should be acceptable for every known alias of this authority @@ -546,11 +579,12 @@ def test_authority_aliases(): responses=[mock_response(json_payload=build_aad_response(access_token=expected_access_token))], ) credential = SharedTokenCacheCredential(authority=alias, _cache=cache, transport=transport) - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == expected_access_token -def test_authority_with_no_known_alias(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_authority_with_no_known_alias(get_token_method): """given an appropriate token, an authority with no known aliases should work""" authority = "unknown.authority" @@ -563,11 +597,12 @@ def test_authority_with_no_known_alias(): responses=[mock_response(json_payload=build_aad_response(access_token=expected_access_token))], ) credential = SharedTokenCacheCredential(authority=authority, _cache=cache, transport=transport) - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == expected_access_token -def test_authority_environment_variable(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_authority_environment_variable(get_token_method): """the credential should accept an authority by environment variable when none is otherwise specified""" authority = "localhost" @@ -581,11 +616,12 @@ def test_authority_environment_variable(): ) with patch.dict("os.environ", {EnvironmentVariables.AZURE_AUTHORITY_HOST: authority}, clear=True): credential = SharedTokenCacheCredential(transport=transport, _cache=cache) - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == expected_access_token -def test_authentication_record_empty_cache(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_authentication_record_empty_cache(get_token_method): record = AuthenticationRecord("tenant-id", "client_id", "authority", "home_account_id", "username") def send(request, **kwargs): @@ -601,10 +637,11 @@ def send(request, **kwargs): ) with pytest.raises(CredentialUnavailableError): - credential.get_token("scope") + getattr(credential, get_token_method)("scope") -def test_authentication_record_no_match(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_authentication_record_no_match(get_token_method): tenant_id = "tenant-id" client_id = "client-id" authority = "localhost" @@ -632,10 +669,11 @@ def send(request, **kwargs): credential = SharedTokenCacheCredential(authentication_record=record, transport=Mock(send=send), _cache=cache) with pytest.raises(CredentialUnavailableError): - credential.get_token("scope") + getattr(credential, get_token_method)("scope") -def test_authentication_record(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_authentication_record(get_token_method): tenant_id = "tenant-id" client_id = "client-id" authority = "localhost" @@ -658,11 +696,12 @@ def test_authentication_record(): ) credential = SharedTokenCacheCredential(authentication_record=record, transport=transport, _cache=cache) - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == expected_access_token -def test_auth_record_multiple_accounts_for_username(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_auth_record_multiple_accounts_for_username(get_token_method): tenant_id = "tenant-id" client_id = "client-id" authority = "localhost" @@ -695,11 +734,12 @@ def test_auth_record_multiple_accounts_for_username(): ) credential = SharedTokenCacheCredential(authentication_record=record, transport=transport, _cache=cache) - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == expected_access_token -def test_writes_to_cache(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_writes_to_cache(get_token_method): """the credential should write tokens it acquires to the cache""" scope = "scope" @@ -731,14 +771,14 @@ def test_writes_to_cache(): ], ) credential = SharedTokenCacheCredential(_cache=cache, transport=transport) - token = credential.get_token(scope) + token = getattr(credential, get_token_method)(scope) assert token.token == expected_access_token # access token should be in the cache, and another instance should retrieve it credential = SharedTokenCacheCredential( _cache=cache, transport=Mock(send=Mock(side_effect=Exception("the credential should return a cached token"))) ) - token = credential.get_token(scope) + token = getattr(credential, get_token_method)(scope) assert token.token == expected_access_token # and the credential should have updated the cached refresh token @@ -748,14 +788,15 @@ def test_writes_to_cache(): responses=[mock_response(json_payload=build_aad_response(access_token=second_access_token))], ) credential = SharedTokenCacheCredential(_cache=cache, transport=transport) - token = credential.get_token("some other " + scope) + token = getattr(credential, get_token_method)("some other " + scope) assert token.token == second_access_token # verify the credential didn't add a new cache entry assert len(list(cache.search(TokenCache.CredentialType.REFRESH_TOKEN))) == 1 -def test_initialization(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_initialization(get_token_method): """the credential should attempt to load the cache when it's needed and no cache has been established.""" with patch("azure.identity._internal.shared_token_cache._load_persistent_cache") as mock_cache_loader: @@ -765,15 +806,16 @@ def test_initialization(): assert mock_cache_loader.call_count == 0 with pytest.raises(CredentialUnavailableError, match="Shared token cache unavailable"): - credential.get_token("scope") + getattr(credential, get_token_method)("scope") assert mock_cache_loader.call_count == 1 with pytest.raises(CredentialUnavailableError, match="Shared token cache unavailable"): - credential.get_token("scope") + getattr(credential, get_token_method)("scope") assert mock_cache_loader.call_count == 2 -def test_initialization_with_cache_options(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_initialization_with_cache_options(get_token_method): """the credential should use user-supplied persistence options""" with patch("azure.identity._internal.shared_token_cache._load_persistent_cache") as mock_cache_loader: @@ -781,21 +823,25 @@ def test_initialization_with_cache_options(): credential = SharedTokenCacheCredential(cache_persistence_options=options) with pytest.raises(CredentialUnavailableError): - credential.get_token("scope") + getattr(credential, get_token_method)("scope") assert mock_cache_loader.call_count == 1 args, _ = mock_cache_loader.call_args assert args[0] == options assert args[1] is False # is_cae is False. with pytest.raises(CredentialUnavailableError): - credential.get_token("scope", enable_cae=True) + kwargs = {"enable_cae": True} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + getattr(credential, get_token_method)("scope", **kwargs) assert mock_cache_loader.call_count == 2 args, _ = mock_cache_loader.call_args assert args[0] == options assert args[1] is True # is_cae is True. -def test_authentication_record_authenticating_tenant(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_authentication_record_authenticating_tenant(get_token_method): """when given a record and 'tenant_id', the credential should authenticate in the latter""" expected_tenant_id = "tenant-id" @@ -812,12 +858,13 @@ def mock_send(request, **_): authentication_record=record, _cache=TokenCache(), tenant_id=expected_tenant_id, transport=transport ) with pytest.raises(CredentialUnavailableError): - credential.get_token("scope") # this raises because the cache is empty + getattr(credential, get_token_method)("scope") # this raises because the cache is empty assert transport.send.called -def test_client_capabilities(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_client_capabilities(get_token_method): """the credential should configure MSAL for capability CP1 only if enable_cae is passed.""" def send(request, **kwargs): @@ -834,20 +881,24 @@ def send(request, **kwargs): with patch("azure.identity._credentials.silent.PublicClientApplication") as PublicClientApplication: with pytest.raises(ClientAuthenticationError): # (cache is empty) - credential.get_token("scope") + getattr(credential, get_token_method)("scope") assert PublicClientApplication.call_count == 1 _, kwargs = PublicClientApplication.call_args assert kwargs["client_capabilities"] is None with pytest.raises(ClientAuthenticationError): - credential.get_token("scope", enable_cae=True) + kwargs = {"enable_cae": True} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + getattr(credential, get_token_method)("scope", **kwargs) assert PublicClientApplication.call_count == 2 _, kwargs = PublicClientApplication.call_args assert kwargs["client_capabilities"] == ["CP1"] -def test_within_dac_error(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_within_dac_error(get_token_method): def send(request, **kwargs): # ensure the `claims` and `tenant_id` keywords from credential's `get_token` method don't make it to transport assert "claims" not in kwargs @@ -862,11 +913,12 @@ def send(request, **kwargs): within_dac.set(True) with patch("azure.identity._credentials.silent.PublicClientApplication") as PublicClientApplication: with pytest.raises(CredentialUnavailableError): # (cache is empty) - credential.get_token("scope") + getattr(credential, get_token_method)("scope") within_dac.set(False) -def test_claims_challenge(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_claims_challenge(get_token_method): """get_token should pass any claims challenge to MSAL token acquisition APIs""" expected_claims = '{"access_token": {"essential": "true"}' @@ -882,14 +934,18 @@ def test_claims_challenge(): transport = Mock(send=Mock(side_effect=Exception("this test mocks MSAL, so no request should be sent"))) credential = SharedTokenCacheCredential(transport=transport, authentication_record=record, _cache=TokenCache()) with patch("azure.identity._credentials.silent.PublicClientApplication", lambda *_, **__: msal_app): - credential.get_token("scope", claims=expected_claims) + kwargs = {"claims": expected_claims} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + getattr(credential, get_token_method)("scope", **kwargs) assert msal_app.acquire_token_silent_with_error.call_count == 1 args, kwargs = msal_app.acquire_token_silent_with_error.call_args assert kwargs["claims_challenge"] == expected_claims -def test_multitenant_authentication(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_multitenant_authentication(get_token_method): default_tenant = "organizations" first_token = "***" second_tenant = "second-tenant" @@ -918,21 +974,28 @@ def send(request, **kwargs): credential = SharedTokenCacheCredential( authority=authority, transport=Mock(send=send), _cache=cache, additionally_allowed_tenants=["*"] ) - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == first_token - token = credential.get_token("scope", tenant_id=default_tenant) + kwargs = {"tenant_id": default_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(credential, get_token_method)("scope", **kwargs) assert token.token == first_token - token = credential.get_token("scope", tenant_id=second_tenant) + kwargs = {"tenant_id": second_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(credential, get_token_method)("scope", **kwargs) assert token.token == second_token # should still default to the first tenant - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == first_token -def test_multitenant_authentication_auth_record(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_multitenant_authentication_auth_record(get_token_method): default_tenant = "organizations" first_token = "***" second_tenant = "second-tenant" @@ -972,17 +1035,23 @@ def send(request, **kwargs): _cache=cache, additionally_allowed_tenants=["*"], ) - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == first_token - token = credential.get_token("scope", tenant_id=default_tenant) + kwargs = {"tenant_id": default_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(credential, get_token_method)("scope", **kwargs) assert token.token == first_token - token = credential.get_token("scope", tenant_id=second_tenant) + kwargs = {"tenant_id": second_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(credential, get_token_method)("scope", **kwargs) assert token.token == second_token # should still default to the first tenant - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == first_token @@ -1017,7 +1086,8 @@ def populated_cache(*accounts): return cache -def test_multitenant_authentication_not_allowed(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_multitenant_authentication_not_allowed(get_token_method): default_tenant = "organizations" expected_token = "***" @@ -1048,12 +1118,18 @@ def send(request, **kwargs): credential = SharedTokenCacheCredential(authority=authority, transport=Mock(send=send), _cache=cache) - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == expected_token - token = credential.get_token("scope", tenant_id=default_tenant) + kwargs = {"tenant_id": default_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(credential, get_token_method)("scope", **kwargs) assert token.token == expected_token with patch.dict("os.environ", {EnvironmentVariables.AZURE_IDENTITY_DISABLE_MULTITENANTAUTH: "true"}): - token = credential.get_token("scope", tenant_id="some tenant") + kwargs = {"tenant_id": "some_tenant"} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(credential, get_token_method)("scope", **kwargs) assert token.token == expected_token diff --git a/sdk/identity/azure-identity/tests/test_shared_cache_credential_async.py b/sdk/identity/azure-identity/tests/test_shared_cache_credential_async.py index 1982b98ccf7e..012e965d6272 100644 --- a/sdk/identity/azure-identity/tests/test_shared_cache_credential_async.py +++ b/sdk/identity/azure-identity/tests/test_shared_cache_credential_async.py @@ -21,7 +21,7 @@ from msal import TokenCache import pytest -from helpers import build_aad_response, id_token_claims, mock_response, Request +from helpers import build_aad_response, id_token_claims, mock_response, Request, GET_TOKEN_METHODS from helpers_async import async_validating_transport, AsyncMockTransport from test_shared_cache_credential import get_account_event, populated_cache @@ -32,16 +32,18 @@ def test_supported(): @pytest.mark.asyncio -async def test_no_scopes(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_no_scopes(get_token_method): """The credential should raise when get_token is called with no scopes""" credential = SharedTokenCacheCredential(_cache=TokenCache()) with pytest.raises(ValueError): - await credential.get_token() + await getattr(credential, get_token_method)() @pytest.mark.asyncio -async def test_close(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_close(get_token_method): async def send(*_, **kwargs): # ensure the `claims` and `tenant_id` keywords from credential's `get_token` method don't make it to transport assert "claims" not in kwargs @@ -54,7 +56,7 @@ async def send(*_, **kwargs): ) # the credential doesn't open a transport session before one is needed, so we send a request - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") await credential.close() @@ -62,7 +64,8 @@ async def send(*_, **kwargs): @pytest.mark.asyncio -async def test_context_manager(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_context_manager(get_token_method): async def send(*_, **kwargs): # ensure the `claims` and `tenant_id` keywords from credential's `get_token` method don't make it to transport assert "claims" not in kwargs @@ -76,14 +79,14 @@ async def send(*_, **kwargs): # async with before initialization: credential should call __aexit__ but not __aenter__ async with credential: - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") assert transport.__aenter__.call_count == 0 assert transport.__aexit__.call_count == 1 # async with after initialization: credential should call __aenter__ and __aexit__ async with credential: - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") assert transport.__aenter__.call_count == 1 assert transport.__aexit__.call_count == 2 @@ -105,7 +108,8 @@ async def test_context_manager_no_cache(): @pytest.mark.asyncio -async def test_policies_configurable(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_policies_configurable(get_token_method): policy = Mock(spec_set=SansIOHTTPPolicy, on_request=Mock()) async def send(*_, **kwargs): @@ -120,13 +124,14 @@ async def send(*_, **kwargs): transport=Mock(send=send), ) - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") assert policy.on_request.called @pytest.mark.asyncio -async def test_user_agent(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_user_agent(get_token_method): transport = async_validating_transport( requests=[Request(required_headers={"User-Agent": USER_AGENT})], responses=[mock_response(json_payload=build_aad_response(access_token="**"))], @@ -136,11 +141,12 @@ async def test_user_agent(): _cache=populated_cache(get_account_event("test@user", "uid", "utid")), transport=transport ) - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") @pytest.mark.asyncio -async def test_tenant_id(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_tenant_id(get_token_method): transport = async_validating_transport( requests=[Request(required_headers={"User-Agent": USER_AGENT})], responses=[mock_response(json_payload=build_aad_response(access_token="**"))], @@ -152,7 +158,10 @@ async def test_tenant_id(): additionally_allowed_tenants=["*"], ) - await credential.get_token("scope", tenant_id="tenant_id") + kwargs = {"tenant_id": "tenant_id"} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + await getattr(credential, get_token_method)("scope", **kwargs) @pytest.mark.parametrize("authority", ("localhost", "https://localhost")) @@ -176,22 +185,26 @@ def _get_auth_client(self, authority=None, **kwargs): @pytest.mark.asyncio -async def test_empty_cache(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_empty_cache(get_token_method): """the credential should raise CredentialUnavailableError when the cache is empty""" with pytest.raises(CredentialUnavailableError, match=NO_ACCOUNTS): - await SharedTokenCacheCredential(_cache=TokenCache()).get_token("scope") + await getattr(SharedTokenCacheCredential(_cache=TokenCache()), get_token_method)("scope") with pytest.raises(CredentialUnavailableError, match=NO_ACCOUNTS): - await SharedTokenCacheCredential(_cache=TokenCache(), username="not@cache").get_token("scope") + await getattr(SharedTokenCacheCredential(_cache=TokenCache(), username="not@cache"), get_token_method)("scope") with pytest.raises(CredentialUnavailableError, match=NO_ACCOUNTS): - await SharedTokenCacheCredential(_cache=TokenCache(), tenant_id="not-cached").get_token("scope") + await getattr(SharedTokenCacheCredential(_cache=TokenCache(), tenant_id="not-cached"), get_token_method)( + "scope" + ) with pytest.raises(CredentialUnavailableError, match=NO_ACCOUNTS): credential = SharedTokenCacheCredential(_cache=TokenCache(), tenant_id="not-cached", username="not@cache") - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") @pytest.mark.asyncio -async def test_no_matching_account_for_username(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_no_matching_account_for_username(get_token_method): """one cached account, username specified, username doesn't match -> credential should raise""" upn = "spam@eggs" @@ -200,14 +213,15 @@ async def test_no_matching_account_for_username(): cache = populated_cache(account) with pytest.raises(CredentialUnavailableError) as ex: - await SharedTokenCacheCredential(_cache=cache, username="not" + upn).get_token("scope") + await getattr(SharedTokenCacheCredential(_cache=cache, username="not" + upn), get_token_method)("scope") assert ex.value.message.startswith(NO_MATCHING_ACCOUNTS[: NO_MATCHING_ACCOUNTS.index("{")]) assert "not" + upn in ex.value.message @pytest.mark.asyncio -async def test_no_matching_account_for_tenant(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_no_matching_account_for_tenant(get_token_method): """one cached account, tenant specified, tenant doesn't match -> credential should raise""" upn = "spam@eggs" @@ -216,14 +230,15 @@ async def test_no_matching_account_for_tenant(): cache = populated_cache(account) with pytest.raises(CredentialUnavailableError) as ex: - await SharedTokenCacheCredential(_cache=cache, tenant_id="not-" + tenant).get_token("scope") + await getattr(SharedTokenCacheCredential(_cache=cache, tenant_id="not-" + tenant), get_token_method)("scope") assert ex.value.message.startswith(NO_MATCHING_ACCOUNTS[: NO_MATCHING_ACCOUNTS.index("{")]) assert "not-" + tenant in ex.value.message @pytest.mark.asyncio -async def test_no_matching_account_for_tenant_and_username(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_no_matching_account_for_tenant_and_username(get_token_method): """one cached account, tenant and username specified, neither match -> credential should raise""" upn = "spam@eggs" @@ -232,16 +247,18 @@ async def test_no_matching_account_for_tenant_and_username(): cache = populated_cache(account) with pytest.raises(CredentialUnavailableError) as ex: - await SharedTokenCacheCredential(_cache=cache, tenant_id="not-" + tenant, username="not" + upn).get_token( - "scope" - ) + await getattr( + SharedTokenCacheCredential(_cache=cache, tenant_id="not-" + tenant, username="not" + upn), + get_token_method, + )("scope") assert ex.value.message.startswith(NO_MATCHING_ACCOUNTS[: NO_MATCHING_ACCOUNTS.index("{")]) assert "not" + upn in ex.value.message and "not-" + tenant in ex.value.message @pytest.mark.asyncio -async def test_no_matching_account_for_tenant_or_username(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_no_matching_account_for_tenant_or_username(get_token_method): """two cached accounts, username and tenant specified, one account matches each -> credential should raise""" refresh_token_a = "refresh-token-a" @@ -258,19 +275,20 @@ async def test_no_matching_account_for_tenant_or_username(): credential = SharedTokenCacheCredential(username=upn_a, tenant_id=tenant_b, _cache=cache, transport=transport) with pytest.raises(CredentialUnavailableError) as ex: - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") assert ex.value.message.startswith(NO_MATCHING_ACCOUNTS[: NO_MATCHING_ACCOUNTS.index("{")]) assert upn_a in ex.value.message and tenant_b in ex.value.message credential = SharedTokenCacheCredential(username=upn_b, tenant_id=tenant_a, _cache=cache, transport=transport) with pytest.raises(CredentialUnavailableError) as ex: - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") assert ex.value.message.startswith(NO_MATCHING_ACCOUNTS[: NO_MATCHING_ACCOUNTS.index("{")]) assert upn_b in ex.value.message and tenant_a in ex.value.message @pytest.mark.asyncio -async def test_single_account_matching_username(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_single_account_matching_username(get_token_method): """one cached account, username specified, username matches -> credential should auth that account""" upn = "spam@eggs" @@ -285,12 +303,13 @@ async def test_single_account_matching_username(): responses=[mock_response(json_payload=build_aad_response(access_token=expected_token))], ) credential = SharedTokenCacheCredential(_cache=cache, transport=transport, username=upn) - token = await credential.get_token(scope) + token = await getattr(credential, get_token_method)(scope) assert token.token == expected_token @pytest.mark.asyncio -async def test_single_account_matching_tenant(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_single_account_matching_tenant(get_token_method): """one cached account, tenant specified, tenant matches -> credential should auth that account""" tenant_id = "tenant-id" @@ -305,12 +324,13 @@ async def test_single_account_matching_tenant(): responses=[mock_response(json_payload=build_aad_response(access_token=expected_token))], ) credential = SharedTokenCacheCredential(_cache=cache, transport=transport, tenant_id=tenant_id) - token = await credential.get_token(scope) + token = await getattr(credential, get_token_method)(scope) assert token.token == expected_token @pytest.mark.asyncio -async def test_single_account_matching_tenant_and_username(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_single_account_matching_tenant_and_username(get_token_method): """one cached account, tenant and username specified, both match -> credential should auth that account""" upn = "spam@eggs" @@ -326,12 +346,13 @@ async def test_single_account_matching_tenant_and_username(): responses=[mock_response(json_payload=build_aad_response(access_token=expected_token))], ) credential = SharedTokenCacheCredential(_cache=cache, transport=transport, tenant_id=tenant_id, username=upn) - token = await credential.get_token(scope) + token = await getattr(credential, get_token_method)(scope) assert token.token == expected_token @pytest.mark.asyncio -async def test_single_account(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_single_account(get_token_method): """one cached account, no username specified -> credential should auth that account""" refresh_token = "refresh-token" @@ -346,12 +367,13 @@ async def test_single_account(): ) credential = SharedTokenCacheCredential(_cache=cache, transport=transport) - token = await credential.get_token(scope) + token = await getattr(credential, get_token_method)(scope) assert token.token == expected_token @pytest.mark.asyncio -async def test_no_refresh_token(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_no_refresh_token(get_token_method): """one cached account, account has no refresh token -> credential should raise""" account = get_account_event(uid="uid_a", utid="utid", username="spam@eggs", refresh_token=None) @@ -361,15 +383,16 @@ async def test_no_refresh_token(): credential = SharedTokenCacheCredential(_cache=cache, transport=transport) with pytest.raises(CredentialUnavailableError, match=NO_ACCOUNTS): - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") credential = SharedTokenCacheCredential(_cache=cache, transport=transport, username="not@cache") with pytest.raises(CredentialUnavailableError, match=NO_ACCOUNTS): - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") @pytest.mark.asyncio -async def test_two_accounts_no_username_or_tenant(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_two_accounts_no_username_or_tenant(get_token_method): """two cached accounts, no username or tenant specified -> credential should raise""" upn_a = "a@foo" @@ -384,11 +407,12 @@ async def test_two_accounts_no_username_or_tenant(): # two users in the cache, no username specified -> CredentialUnavailableError credential = SharedTokenCacheCredential(_cache=cache, transport=transport) with pytest.raises(ClientAuthenticationError, match=MULTIPLE_ACCOUNTS) as ex: - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") @pytest.mark.asyncio -async def test_two_accounts_username_specified(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_two_accounts_username_specified(get_token_method): """two cached accounts, username specified, one account matches -> credential should auth that account""" scope = "scope" @@ -405,12 +429,13 @@ async def test_two_accounts_username_specified(): responses=[mock_response(json_payload=build_aad_response(access_token=expected_token))], ) credential = SharedTokenCacheCredential(username=upn_a, _cache=cache, transport=transport) - token = await credential.get_token(scope) + token = await getattr(credential, get_token_method)(scope) assert token.token == expected_token @pytest.mark.asyncio -async def test_two_accounts_tenant_specified(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_two_accounts_tenant_specified(get_token_method): """two cached accounts, tenant specified, one account matches -> credential should auth that account""" scope = "scope" @@ -428,12 +453,13 @@ async def test_two_accounts_tenant_specified(): responses=[mock_response(json_payload=build_aad_response(access_token=expected_token))], ) credential = SharedTokenCacheCredential(tenant_id=tenant_id, _cache=cache, transport=transport) - token = await credential.get_token(scope) + token = await getattr(credential, get_token_method)(scope) assert token.token == expected_token @pytest.mark.asyncio -async def test_two_accounts_tenant_and_username_specified(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_two_accounts_tenant_and_username_specified(get_token_method): """two cached accounts, tenant and username specified, one account matches both -> credential should auth that account""" scope = "scope" @@ -451,12 +477,13 @@ async def test_two_accounts_tenant_and_username_specified(): responses=[mock_response(json_payload=build_aad_response(access_token=expected_token))], ) credential = SharedTokenCacheCredential(tenant_id=tenant_id, username=upn_a, _cache=cache, transport=transport) - token = await credential.get_token(scope) + token = await getattr(credential, get_token_method)(scope) assert token.token == expected_token @pytest.mark.asyncio -async def test_same_username_different_tenants(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_same_username_different_tenants(get_token_method): """two cached accounts, same username, different tenants""" access_token_a = "access-token-a" @@ -475,7 +502,7 @@ async def test_same_username_different_tenants(): transport = Mock(side_effect=Exception()) # (so it shouldn't use the network) credential = SharedTokenCacheCredential(username=upn, _cache=cache, transport=transport) with pytest.raises(CredentialUnavailableError) as ex: - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") assert ex.value.message.startswith(MULTIPLE_MATCHING_ACCOUNTS[: MULTIPLE_MATCHING_ACCOUNTS.index("{")]) assert upn in ex.value.message @@ -487,7 +514,7 @@ async def test_same_username_different_tenants(): responses=[mock_response(json_payload=build_aad_response(access_token=access_token_a))], ) credential = SharedTokenCacheCredential(tenant_id=tenant_a, _cache=cache, transport=transport) - token = await credential.get_token(scope) + token = await getattr(credential, get_token_method)(scope) assert token.token == access_token_a transport = async_validating_transport( @@ -495,12 +522,13 @@ async def test_same_username_different_tenants(): responses=[mock_response(json_payload=build_aad_response(access_token=access_token_b))], ) credential = SharedTokenCacheCredential(tenant_id=tenant_b, _cache=cache, transport=transport) - token = await credential.get_token(scope) + token = await getattr(credential, get_token_method)(scope) assert token.token == access_token_b @pytest.mark.asyncio -async def test_same_tenant_different_usernames(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_same_tenant_different_usernames(get_token_method): """two cached accounts, same tenant, different usernames""" access_token_a = "access-token-a" @@ -519,7 +547,7 @@ async def test_same_tenant_different_usernames(): transport = Mock(side_effect=Exception()) # (so it shouldn't use the network) credential = SharedTokenCacheCredential(tenant_id=tenant_id, _cache=cache, transport=transport) with pytest.raises(CredentialUnavailableError) as ex: - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") assert ex.value.message.startswith(MULTIPLE_MATCHING_ACCOUNTS[: MULTIPLE_MATCHING_ACCOUNTS.index("{")]) assert tenant_id in ex.value.message @@ -531,7 +559,7 @@ async def test_same_tenant_different_usernames(): responses=[mock_response(json_payload=build_aad_response(access_token=access_token_a))], ) credential = SharedTokenCacheCredential(username=upn_b, _cache=cache, transport=transport) - token = await credential.get_token(scope) + token = await getattr(credential, get_token_method)(scope) assert token.token == access_token_a transport = async_validating_transport( @@ -539,12 +567,13 @@ async def test_same_tenant_different_usernames(): responses=[mock_response(json_payload=build_aad_response(access_token=access_token_a))], ) credential = SharedTokenCacheCredential(username=upn_a, _cache=cache, transport=transport) - token = await credential.get_token(scope) + token = await getattr(credential, get_token_method)(scope) assert token.token == access_token_a @pytest.mark.asyncio -async def test_authority_aliases(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_authority_aliases(get_token_method): """the credential should use a refresh token valid for any known alias of its authority""" expected_access_token = "access-token" @@ -563,7 +592,7 @@ async def test_authority_aliases(): responses=[mock_response(json_payload=build_aad_response(access_token=expected_access_token))], ) credential = SharedTokenCacheCredential(authority=authority, _cache=cache, transport=transport) - token = await credential.get_token("scope") + token = await getattr(credential, get_token_method)("scope") assert token.token == expected_access_token # it should also be acceptable for every known alias of this authority @@ -573,12 +602,13 @@ async def test_authority_aliases(): responses=[mock_response(json_payload=build_aad_response(access_token=expected_access_token))], ) credential = SharedTokenCacheCredential(authority=alias, _cache=cache, transport=transport) - token = await credential.get_token("scope") + token = await getattr(credential, get_token_method)("scope") assert token.token == expected_access_token @pytest.mark.asyncio -async def test_authority_with_no_known_alias(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_authority_with_no_known_alias(get_token_method): """given an appropriate token, an authority with no known aliases should work""" authority = "unknown.authority" @@ -591,12 +621,13 @@ async def test_authority_with_no_known_alias(): responses=[mock_response(json_payload=build_aad_response(access_token=expected_access_token))], ) credential = SharedTokenCacheCredential(authority=authority, _cache=cache, transport=transport) - token = await credential.get_token("scope") + token = await getattr(credential, get_token_method)("scope") assert token.token == expected_access_token @pytest.mark.asyncio -async def test_authority_environment_variable(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_authority_environment_variable(get_token_method): """the credential should accept an authority by environment variable when none is otherwise specified""" authority = "localhost" @@ -610,12 +641,13 @@ async def test_authority_environment_variable(): ) with patch.dict("os.environ", {EnvironmentVariables.AZURE_AUTHORITY_HOST: authority}, clear=True): credential = SharedTokenCacheCredential(transport=transport, _cache=cache) - token = await credential.get_token("scope") + token = await getattr(credential, get_token_method)("scope") assert token.token == expected_access_token @pytest.mark.asyncio -async def test_initialization(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_initialization(get_token_method): """the credential should attempt to load the cache when it's needed and no cache has been established.""" with patch("azure.identity._persistent_cache._get_persistence") as mock_cache_loader: @@ -625,16 +657,17 @@ async def test_initialization(): assert mock_cache_loader.call_count == 0 with pytest.raises(CredentialUnavailableError, match="Shared token cache unavailable"): - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") assert mock_cache_loader.call_count == 1 with pytest.raises(CredentialUnavailableError, match="Shared token cache unavailable"): - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") assert mock_cache_loader.call_count == 2 @pytest.mark.asyncio -async def test_initialization_with_cache_options(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_initialization_with_cache_options(get_token_method): """the credential should use user-supplied persistence options""" with patch("azure.identity._internal.shared_token_cache._load_persistent_cache") as mock_cache_loader: @@ -642,12 +675,13 @@ async def test_initialization_with_cache_options(): credential = SharedTokenCacheCredential(cache_persistence_options=options) with pytest.raises(CredentialUnavailableError): - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") assert mock_cache_loader.call_count == 1 @pytest.mark.asyncio -async def test_multitenant_authentication(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_multitenant_authentication(get_token_method): first_token = "***" second_tenant = "second-tenant" second_token = first_token * 2 @@ -674,22 +708,29 @@ async def send(request, **kwargs): credential = SharedTokenCacheCredential( authority=authority, transport=Mock(send=send), _cache=cache, additionally_allowed_tenants=["*"] ) - token = await credential.get_token("scope") + token = await getattr(credential, get_token_method)("scope") assert token.token == first_token - token = await credential.get_token("scope", tenant_id="organizations") + kwargs = {"tenant_id": "organizations"} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = await getattr(credential, get_token_method)("scope", **kwargs) assert token.token == first_token - token = await credential.get_token("scope", tenant_id=second_tenant) + kwargs = {"tenant_id": second_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = await getattr(credential, get_token_method)("scope", **kwargs) assert token.token == second_token # should still default to the first tenant - token = await credential.get_token("scope") + token = await getattr(credential, get_token_method)("scope") assert token.token == first_token @pytest.mark.asyncio -async def test_multitenant_authentication_not_allowed(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_multitenant_authentication_not_allowed(get_token_method): default_tenant = "organizations" expected_token = "***" @@ -715,12 +756,18 @@ async def send(request, **kwargs): credential = SharedTokenCacheCredential(authority=authority, transport=Mock(send=send), _cache=cache) - token = await credential.get_token("scope") + token = await getattr(credential, get_token_method)("scope") assert token.token == expected_token - token = await credential.get_token("scope", tenant_id=default_tenant) + kwargs = {"tenant_id": default_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = await getattr(credential, get_token_method)("scope", **kwargs) assert token.token == expected_token + kwargs = {"tenant_id": "some_tenant"} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} with patch.dict("os.environ", {EnvironmentVariables.AZURE_IDENTITY_DISABLE_MULTITENANTAUTH: "true"}): - token = await credential.get_token("scope", tenant_id="some tenant") + token = await getattr(credential, get_token_method)("scope", **kwargs) assert token.token == expected_token diff --git a/sdk/identity/azure-identity/tests/test_username_password_credential.py b/sdk/identity/azure-identity/tests/test_username_password_credential.py index ec86c1174249..3ce95488fc3b 100644 --- a/sdk/identity/azure-identity/tests/test_username_password_credential.py +++ b/sdk/identity/azure-identity/tests/test_username_password_credential.py @@ -2,6 +2,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ +from unittest.mock import Mock, patch + from azure.core.pipeline.policies import SansIOHTTPPolicy from azure.identity import UsernamePasswordCredential from azure.identity._internal.user_agent import USER_AGENT @@ -15,13 +17,9 @@ mock_response, Request, validating_transport, + GET_TOKEN_METHODS, ) -try: - from unittest.mock import Mock, patch -except ImportError: # python < 3.3 - from mock import Mock, patch # type: ignore - def test_tenant_id_validation(): """The credential should raise ValueError when given an invalid tenant_id""" @@ -36,15 +34,17 @@ def test_tenant_id_validation(): UsernamePasswordCredential("client-id", "username", "password", tenant_id=tenant) -def test_no_scopes(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_no_scopes(get_token_method): """The credential should raise when get_token is called with no scopes""" credential = UsernamePasswordCredential("client-id", "username", "password") with pytest.raises(ValueError): - credential.get_token() + getattr(credential, get_token_method)() -def test_policies_configurable(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_policies_configurable(get_token_method): policy = Mock(spec_set=SansIOHTTPPolicy, on_request=Mock()) transport = validating_transport( @@ -54,12 +54,13 @@ def test_policies_configurable(): ) credential = UsernamePasswordCredential("client-id", "username", "password", policies=[policy], transport=transport) - credential.get_token("scope") + getattr(credential, get_token_method)("scope") assert policy.on_request.called -def test_user_agent(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_user_agent(get_token_method): transport = validating_transport( requests=[Request()] * 2 + [Request(required_headers={"User-Agent": USER_AGENT})], responses=[get_discovery_response()] * 2 @@ -68,10 +69,11 @@ def test_user_agent(): credential = UsernamePasswordCredential("client-id", "username", "password", transport=transport) - credential.get_token("scope") + getattr(credential, get_token_method)("scope") -def test_tenant_id(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_tenant_id(get_token_method): transport = validating_transport( requests=[Request()] * 2 + [Request(required_headers={"User-Agent": USER_AGENT})], responses=[get_discovery_response()] * 2 @@ -82,10 +84,14 @@ def test_tenant_id(): "client-id", "username", "password", transport=transport, additionally_allowed_tenants=["*"] ) - credential.get_token("scope", tenant_id="tenant_id") + kwargs = {"tenant_id": "tenant_id"} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + getattr(credential, get_token_method)("scope", **kwargs) -def test_username_password_credential(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_username_password_credential(get_token_method): expected_token = "access-token" client_id = "client-id" transport = validating_transport( @@ -110,11 +116,12 @@ def test_username_password_credential(): disable_instance_discovery=True, # kwargs are passed to MSAL; this one prevents a Microsoft Entra verification request ) - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == expected_token -def test_authenticate(): +@pytest.mark.parametrize("get_token_method", ["get_token_info"]) +def test_authenticate(get_token_method): client_id = "client-id" environment = "localhost" issuer = "https://" + environment @@ -158,7 +165,7 @@ def test_authenticate(): assert record.username == username # credential should have a cached access token for the scope passed to authenticate - token = credential.get_token(scope) + token = getattr(credential, get_token_method)(scope) assert token.token == access_token @@ -182,7 +189,8 @@ def test_client_capabilities(): assert kwargs["client_capabilities"] == ["CP1"] -def test_claims_challenge(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_claims_challenge(get_token_method): """get_token should and authenticate pass any claims challenge to MSAL token acquisition APIs""" msal_acquire_token_result = dict( @@ -202,7 +210,10 @@ def test_claims_challenge(): args, kwargs = msal_app.acquire_token_by_username_password.call_args assert kwargs["claims_challenge"] == expected_claims - credential.get_token("scope", claims=expected_claims) + kwargs = {"claims": expected_claims} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + getattr(credential, get_token_method)("scope", **kwargs) assert msal_app.acquire_token_by_username_password.call_count == 2 args, kwargs = msal_app.acquire_token_by_username_password.call_args @@ -210,7 +221,10 @@ def test_claims_challenge(): msal_app.get_accounts.return_value = [{"home_account_id": credential._auth_record.home_account_id}] msal_app.acquire_token_silent_with_error.return_value = msal_acquire_token_result - credential.get_token("scope", claims=expected_claims) + kwargs = {"claims": expected_claims} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + getattr(credential, get_token_method)("scope", **kwargs) assert msal_app.acquire_token_silent_with_error.call_count == 1 args, kwargs = msal_app.acquire_token_silent_with_error.call_args diff --git a/sdk/identity/azure-identity/tests/test_vscode_credential.py b/sdk/identity/azure-identity/tests/test_vscode_credential.py index a2e5b37e5925..becfe82af14b 100644 --- a/sdk/identity/azure-identity/tests/test_vscode_credential.py +++ b/sdk/identity/azure-identity/tests/test_vscode_credential.py @@ -4,6 +4,8 @@ # ------------------------------------ import sys import time +from unittest import mock +from urllib.parse import urlparse from azure.core.credentials import AccessToken from azure.core.exceptions import ClientAuthenticationError @@ -12,14 +14,9 @@ from azure.identity._constants import EnvironmentVariables from azure.identity._internal.user_agent import USER_AGENT import pytest -from urllib.parse import urlparse from helpers import build_aad_response, mock_response, Request, validating_transport -try: - from unittest import mock -except ImportError: # python < 3.3 - import mock GET_REFRESH_TOKEN = VisualStudioCodeCredential.__module__ + ".get_refresh_token" GET_USER_SETTINGS = VisualStudioCodeCredential.__module__ + ".get_user_settings" diff --git a/sdk/identity/azure-identity/tests/test_workload_identity_credential.py b/sdk/identity/azure-identity/tests/test_workload_identity_credential.py index a744c7e36b7b..1db0874a77ce 100644 --- a/sdk/identity/azure-identity/tests/test_workload_identity_credential.py +++ b/sdk/identity/azure-identity/tests/test_workload_identity_credential.py @@ -4,9 +4,10 @@ # ------------------------------------ from unittest.mock import mock_open, MagicMock, patch +import pytest from azure.identity import WorkloadIdentityCredential -from helpers import mock_response, build_aad_response +from helpers import mock_response, build_aad_response, GET_TOKEN_METHODS def test_workload_identity_credential_initialize(): @@ -18,7 +19,8 @@ def test_workload_identity_credential_initialize(): ) -def test_workload_identity_credential_get_token(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_workload_identity_credential_get_token(get_token_method): tenant_id = "tenant-id" client_id = "client-id" access_token = "foo" @@ -38,7 +40,7 @@ def send(request, **kwargs): open_mock = mock_open(read_data=assertion) with patch("builtins.open", open_mock): - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == access_token open_mock.assert_called_once_with(token_file_path, encoding="utf-8") diff --git a/sdk/identity/azure-identity/tests/test_workload_identity_credential_async.py b/sdk/identity/azure-identity/tests/test_workload_identity_credential_async.py index cbda554886a7..fc6f3e8c6cb5 100644 --- a/sdk/identity/azure-identity/tests/test_workload_identity_credential_async.py +++ b/sdk/identity/azure-identity/tests/test_workload_identity_credential_async.py @@ -7,7 +7,7 @@ import pytest from azure.identity.aio import WorkloadIdentityCredential -from helpers import mock_response, build_aad_response +from helpers import mock_response, build_aad_response, GET_TOKEN_METHODS def test_workload_identity_credential_initialize(): @@ -20,7 +20,8 @@ def test_workload_identity_credential_initialize(): @pytest.mark.asyncio -async def test_workload_identity_credential_get_token(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_workload_identity_credential_get_token(get_token_method): tenant_id = "tenant-id" client_id = "client-id" access_token = "foo" @@ -40,7 +41,7 @@ async def send(request, **kwargs): open_mock = mock_open(read_data=assertion) with patch("builtins.open", open_mock): - token = await credential.get_token("scope") + token = await getattr(credential, get_token_method)("scope") assert token.token == access_token open_mock.assert_called_once_with(token_file_path, encoding="utf-8") From c4dfbbceb3bcbbd706706de409759dc18e0d658a Mon Sep 17 00:00:00 2001 From: Paul Van Eck Date: Thu, 12 Sep 2024 01:38:27 +0000 Subject: [PATCH 2/8] pytest skips Signed-off-by: Paul Van Eck --- .../azure-identity/tests/test_initialization.py | 6 ++++++ .../azure-identity/tests/test_initialization_async.py | 11 +++++++++++ 2 files changed, 17 insertions(+) diff --git a/sdk/identity/azure-identity/tests/test_initialization.py b/sdk/identity/azure-identity/tests/test_initialization.py index 7679b1c41031..e44fe487a3db 100644 --- a/sdk/identity/azure-identity/tests/test_initialization.py +++ b/sdk/identity/azure-identity/tests/test_initialization.py @@ -2,6 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ +import sys from azure.core.credentials import SupportsTokenInfo, TokenCredential from azure.identity import ( @@ -24,6 +25,7 @@ AzureDeveloperCliCredential, AzurePipelinesCredential, ) +import pytest def test_credential_is_token_credential(): @@ -47,6 +49,10 @@ def test_credential_is_token_credential(): assert isinstance(AzurePipelinesCredential, TokenCredential) +@pytest.mark.skipif( + sys.version_info < (3, 9), + reason="isinstance check doesn't seem to work when the Protocol subclasses ContextManager in Python <=3.8", +) def test_credential_is_supports_token_info(): assert isinstance(AuthorizationCodeCredential, SupportsTokenInfo) assert isinstance(CertificateCredential, SupportsTokenInfo) diff --git a/sdk/identity/azure-identity/tests/test_initialization_async.py b/sdk/identity/azure-identity/tests/test_initialization_async.py index 41bc432b86b3..56a5f9175740 100644 --- a/sdk/identity/azure-identity/tests/test_initialization_async.py +++ b/sdk/identity/azure-identity/tests/test_initialization_async.py @@ -2,6 +2,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ +import sys + from azure.core.credentials_async import AsyncSupportsTokenInfo, AsyncTokenCredential from azure.identity.aio import ( AuthorizationCodeCredential, @@ -20,8 +22,13 @@ AzureDeveloperCliCredential, AzurePipelinesCredential, ) +import pytest +@pytest.mark.skipif( + sys.version_info < (3, 9), + reason="isinstance check doesn't seem to work when the Protocol subclasses AsyncContextManager in Python <=3.8", +) def test_credential_is_async_token_credential(): assert isinstance(AuthorizationCodeCredential, AsyncTokenCredential) assert isinstance(CertificateCredential, AsyncTokenCredential) @@ -40,6 +47,10 @@ def test_credential_is_async_token_credential(): assert isinstance(AzurePipelinesCredential, AsyncTokenCredential) +@pytest.mark.skipif( + sys.version_info < (3, 9), + reason="isinstance check doesn't seem to work when the Protocol subclasses AsyncContextManager in Python <=3.8", +) def test_credential_is_async_supports_token_info(): assert isinstance(AuthorizationCodeCredential, AsyncSupportsTokenInfo) assert isinstance(CertificateCredential, AsyncSupportsTokenInfo) From cf0e44ce4533350fede74898a2c0b51a691952d1 Mon Sep 17 00:00:00 2001 From: Paul Van Eck Date: Thu, 12 Sep 2024 23:03:58 +0000 Subject: [PATCH 3/8] Bump core Signed-off-by: Paul Van Eck --- sdk/identity/azure-identity/CHANGELOG.md | 1 + sdk/identity/azure-identity/setup.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/sdk/identity/azure-identity/CHANGELOG.md b/sdk/identity/azure-identity/CHANGELOG.md index d9c6788bf622..6fae76863593 100644 --- a/sdk/identity/azure-identity/CHANGELOG.md +++ b/sdk/identity/azure-identity/CHANGELOG.md @@ -15,6 +15,7 @@ - Added identity config validation to `ManagedIdentityCredential` to avoid non-deterministic states (e.g. both `resource_id` and `object_id` are specified). ([#36950](https://github.com/Azure/azure-sdk-for-python/pull/36950)) - Additional validation was added for `ManagedIdentityCredential` in Azure Cloud Shell environments. ([#36438](https://github.com/Azure/azure-sdk-for-python/issues/36438)) +- Bumped minimum dependency on `azure-core` to `>=1.31.0`. ## 1.18.0b2 (2024-08-09) diff --git a/sdk/identity/azure-identity/setup.py b/sdk/identity/azure-identity/setup.py index 3b28cffbe847..b56ebc2abaa9 100644 --- a/sdk/identity/azure-identity/setup.py +++ b/sdk/identity/azure-identity/setup.py @@ -59,7 +59,7 @@ ), python_requires=">=3.8", install_requires=[ - "azure-core>=1.23.0", + "azure-core>=1.31.0", "cryptography>=2.5", "msal>=1.30.0", "msal-extensions>=1.2.0", From a32601da370f1bdcb1c10eaca639c2d2d1754c21 Mon Sep 17 00:00:00 2001 From: Paul Van Eck Date: Fri, 13 Sep 2024 18:44:09 +0000 Subject: [PATCH 4/8] update default base method name Signed-off-by: Paul Van Eck --- .../azure-identity/azure/identity/_credentials/shared_cache.py | 2 +- .../azure-identity/azure/identity/_credentials/silent.py | 2 +- .../azure-identity/azure/identity/_internal/get_token_mixin.py | 2 +- .../azure-identity/azure/identity/_internal/interactive.py | 2 +- .../azure/identity/_internal/msal_managed_identity_client.py | 2 +- .../azure/identity/aio/_credentials/shared_cache.py | 2 +- .../azure/identity/aio/_internal/get_token_mixin.py | 2 +- 7 files changed, 7 insertions(+), 7 deletions(-) diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/shared_cache.py b/sdk/identity/azure-identity/azure/identity/_credentials/shared_cache.py index c23dfc0485b9..31870a6401db 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/shared_cache.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/shared_cache.py @@ -157,7 +157,7 @@ def _get_token_base( self, *scopes: str, options: Optional[TokenRequestOptions] = None, - base_method_name: str = "get_token", + base_method_name: str = "get_token_info", **kwargs: Any, ) -> AccessTokenInfo: if not scopes: diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/silent.py b/sdk/identity/azure-identity/azure/identity/_credentials/silent.py index f4add4b726c8..19b2b49a01a7 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/silent.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/silent.py @@ -83,7 +83,7 @@ def _get_token_base( self, *scopes: str, options: Optional[TokenRequestOptions] = None, - base_method_name: str = "get_token", + base_method_name: str = "get_token_info", **kwargs: Any, ) -> AccessTokenInfo: diff --git a/sdk/identity/azure-identity/azure/identity/_internal/get_token_mixin.py b/sdk/identity/azure-identity/azure/identity/_internal/get_token_mixin.py index a8fcb2195851..16fd1ea9f9c4 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/get_token_mixin.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/get_token_mixin.py @@ -117,7 +117,7 @@ def _get_token_base( self, *scopes: str, options: Optional[TokenRequestOptions] = None, - base_method_name: str = "get_token", + base_method_name: str = "get_token_info", **kwargs: Any, ) -> AccessTokenInfo: if not scopes: diff --git a/sdk/identity/azure-identity/azure/identity/_internal/interactive.py b/sdk/identity/azure-identity/azure/identity/_internal/interactive.py index 54a7474fc1b4..d1b8ab869719 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/interactive.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/interactive.py @@ -177,7 +177,7 @@ def _get_token_base( self, *scopes: str, options: Optional[TokenRequestOptions] = None, - base_method_name: str = "get_token", + base_method_name: str = "get_token_info", **kwargs: Any, ) -> AccessTokenInfo: if not scopes: diff --git a/sdk/identity/azure-identity/azure/identity/_internal/msal_managed_identity_client.py b/sdk/identity/azure-identity/azure/identity/_internal/msal_managed_identity_client.py index a2d46074d19c..1c768ec09245 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/msal_managed_identity_client.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/msal_managed_identity_client.py @@ -140,7 +140,7 @@ def _get_token_base( self, *scopes: str, options: Optional[TokenRequestOptions] = None, - base_method_name: str = "get_token", + base_method_name: str = "get_token_info", **kwargs: Any, ) -> AccessTokenInfo: if not scopes: diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/shared_cache.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/shared_cache.py index e710013ed1b4..6d5f38d497f1 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/shared_cache.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/shared_cache.py @@ -111,7 +111,7 @@ async def _get_token_base( self, *scopes: str, options: Optional[TokenRequestOptions] = None, - base_method_name: str = "get_token", + base_method_name: str = "get_token_info", **kwargs: Any, ) -> AccessTokenInfo: if not scopes: diff --git a/sdk/identity/azure-identity/azure/identity/aio/_internal/get_token_mixin.py b/sdk/identity/azure-identity/azure/identity/aio/_internal/get_token_mixin.py index 24bf36ef8811..61ad01eaab75 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_internal/get_token_mixin.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_internal/get_token_mixin.py @@ -117,7 +117,7 @@ async def _get_token_base( self, *scopes: str, options: Optional[TokenRequestOptions] = None, - base_method_name: str = "get_token", + base_method_name: str = "get_token_info", **kwargs: Any, ) -> AccessTokenInfo: if not scopes: From 76598c1271c3062a21418581a71a013d04bb6736 Mon Sep 17 00:00:00 2001 From: Paul Van Eck Date: Sat, 14 Sep 2024 00:35:18 +0000 Subject: [PATCH 5/8] Handle chained pop scenario Signed-off-by: Paul Van Eck --- .../azure/identity/_credentials/chained.py | 6 ++++- .../identity/aio/_credentials/chained.py | 6 ++++- .../tests/test_chained_credential.py | 22 ++++++++++++++++ .../test_chained_token_credential_async.py | 26 +++++++++++++++++++ 4 files changed, 58 insertions(+), 2 deletions(-) diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/chained.py b/sdk/identity/azure-identity/azure/identity/_credentials/chained.py index df888cb2ef70..be161cb483ed 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/chained.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/chained.py @@ -147,13 +147,17 @@ def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptions] = """ within_credential_chain.set(True) history = [] + options = options or {} for credential in self.credentials: try: # A custom credential in the chain may not implement get_token_info if hasattr(credential, "get_token_info"): token_info = cast(SupportsTokenInfo, credential).get_token_info(*scopes, options=options) else: - options = options or {} + if options.get("pop"): + raise CredentialUnavailableError( + "Proof of possession arguments are not supported for this credential." + ) token = credential.get_token(*scopes, **options) token_info = AccessTokenInfo(token=token.token, expires_on=token.expires_on) _LOGGER.info("%s acquired a token from %s", self.__class__.__name__, credential.__class__.__name__) diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/chained.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/chained.py index 80d25746ad67..3221cfca42ed 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/chained.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/chained.py @@ -128,13 +128,17 @@ async def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptio """ within_credential_chain.set(True) history = [] + options = options or {} for credential in self.credentials: try: # A custom credential in the chain may not implement get_token_info if hasattr(credential, "get_token_info"): token_info = await cast(AsyncSupportsTokenInfo, credential).get_token_info(*scopes, options=options) else: - options = options or {} + if options.get("pop"): + raise CredentialUnavailableError( + "Proof of possession arguments are not supported for this credential." + ) token = await credential.get_token(*scopes, **options) token_info = AccessTokenInfo(token=token.token, expires_on=token.expires_on) _LOGGER.info("%s acquired a token from %s", self.__class__.__name__, credential.__class__.__name__) diff --git a/sdk/identity/azure-identity/tests/test_chained_credential.py b/sdk/identity/azure-identity/tests/test_chained_credential.py index 3ccbe2e2849d..9ab770991fe4 100644 --- a/sdk/identity/azure-identity/tests/test_chained_credential.py +++ b/sdk/identity/azure-identity/tests/test_chained_credential.py @@ -251,3 +251,25 @@ def test_credentials_with_no_get_token_info(): chain = ChainedTokenCredential(credential1, credential2, credential3) # type: ignore token_info = chain.get_token_info("scope") assert token_info.token == access_token + + +def test_credentials_with_pop_option(): + """ChainedTokenCredential should skip credentials that don't support get_token_info and the pop option is set.""" + + access_token = "****" + credential1 = Mock( + spec_set=["get_token_info"], + get_token_info=Mock(side_effect=CredentialUnavailableError(message="")), + ) + credential2 = Mock( + spec_set=["get_token"], + get_token=Mock(return_value=AccessToken("foo", 42)), + ) + credential3 = Mock( + spec_set=["get_token", "get_token_info"], + get_token=Mock(return_value=AccessToken("bar", 42)), + get_token_info=Mock(return_value=AccessTokenInfo(access_token, 42)), + ) + chain = ChainedTokenCredential(credential1, credential2, credential3) # type: ignore + token_info = chain.get_token_info("scope", options={"pop": True}) # type: ignore + assert token_info.token == access_token diff --git a/sdk/identity/azure-identity/tests/test_chained_token_credential_async.py b/sdk/identity/azure-identity/tests/test_chained_token_credential_async.py index 2f541b859684..aa105db967bd 100644 --- a/sdk/identity/azure-identity/tests/test_chained_token_credential_async.py +++ b/sdk/identity/azure-identity/tests/test_chained_token_credential_async.py @@ -252,3 +252,29 @@ async def credential_unavailable(message="it didn't work", **_): chain = ChainedTokenCredential(credential1, credential2, credential3) # type: ignore token_info = await chain.get_token_info("scope") assert token_info.token == access_token + + +@pytest.mark.asyncio +async def test_credentials_with_pop_option(): + """ChainedTokenCredential should skip credentials that don't support get_token_info and the pop option is set.""" + + async def credential_unavailable(message="it didn't work", **_): + raise CredentialUnavailableError(message) + + access_token = "****" + credential1 = Mock( + spec_set=["get_token_info"], + get_token_info=Mock(wraps=credential_unavailable), + ) + credential2 = Mock( + spec_set=["get_token"], + get_token=Mock(wraps=wrap_in_future(lambda _, **__: AccessToken("foo", 42))), + ) + credential3 = Mock( + spec_set=["get_token", "get_token_info"], + get_token=Mock(wraps=wrap_in_future(lambda _, **__: AccessToken("bar", 42))), + get_token_info=Mock(wraps=wrap_in_future(lambda _, **__: AccessTokenInfo(access_token, 42))), + ) + chain = ChainedTokenCredential(credential1, credential2, credential3) # type: ignore + token_info = await chain.get_token_info("scope", options={"pop": True}) # type: ignore + assert token_info.token == access_token From 3e5958110862079b932ae51ae968ecd80d43996c Mon Sep 17 00:00:00 2001 From: Paul Van Eck Date: Mon, 16 Sep 2024 20:42:38 +0000 Subject: [PATCH 6/8] chained updates Signed-off-by: Paul Van Eck --- .../identity/_credentials/application.py | 6 ++- .../azure/identity/_credentials/chained.py | 52 ++++++++++++++----- .../azure/identity/_credentials/default.py | 11 ++-- .../azure/identity/_internal/interactive.py | 6 +++ .../identity/aio/_credentials/application.py | 6 ++- .../identity/aio/_credentials/chained.py | 42 ++++++++++----- .../identity/aio/_credentials/default.py | 4 +- .../tests/test_aad_client_async.py | 2 +- .../tests/test_chained_credential.py | 22 ++++++++ .../test_chained_token_credential_async.py | 26 ++++++++++ .../tests/test_interactive_credential.py | 16 ++++++ 11 files changed, 156 insertions(+), 37 deletions(-) diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/application.py b/sdk/identity/azure-identity/azure/identity/_credentials/application.py index 1d900f0d1fcb..81a06fd989d1 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/application.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/application.py @@ -6,7 +6,7 @@ import os from typing import Any, Optional, cast -from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions, SupportsTokenInfo +from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions, SupportsTokenInfo, TokenCredential from .chained import ChainedTokenCredential from .environment import EnvironmentCredential from .managed_identity import ManagedIdentityCredential @@ -83,7 +83,9 @@ def get_token( `message` attribute listing each authentication attempt and its error message. """ if self._successful_credential: - token = self._successful_credential.get_token(*scopes, claims=claims, tenant_id=tenant_id, **kwargs) + token = cast(TokenCredential, self._successful_credential).get_token( + *scopes, claims=claims, tenant_id=tenant_id, **kwargs + ) _LOGGER.info( "%s acquired a token from %s", self.__class__.__name__, self._successful_credential.__class__.__name__ ) diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/chained.py b/sdk/identity/azure-identity/azure/identity/_credentials/chained.py index be161cb483ed..31f38937f1d7 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/chained.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/chained.py @@ -3,16 +3,20 @@ # Licensed under the MIT License. # ------------------------------------ import logging -from typing import Any, Optional, TYPE_CHECKING, cast +from typing import Any, Optional, cast from azure.core.exceptions import ClientAuthenticationError -from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions, SupportsTokenInfo +from azure.core.credentials import ( + AccessToken, + AccessTokenInfo, + TokenRequestOptions, + SupportsTokenInfo, + TokenCredential, + TokenProvider, +) from .. import CredentialUnavailableError from .._internal import within_credential_chain -if TYPE_CHECKING: - from azure.core.credentials import TokenCredential - _LOGGER = logging.getLogger(__name__) @@ -48,12 +52,11 @@ class ChainedTokenCredential: :caption: Create a ChainedTokenCredential. """ - def __init__(self, *credentials): - # type: (*TokenCredential) -> None + def __init__(self, *credentials: TokenProvider) -> None: if not credentials: raise ValueError("at least one credential is required") - self._successful_credential = None # type: Optional[TokenCredential] + self._successful_credential: Optional[TokenProvider] = None self.credentials = credentials def __enter__(self) -> "ChainedTokenCredential": @@ -70,7 +73,12 @@ def close(self) -> None: self.__exit__() def get_token( - self, *scopes: str, claims: Optional[str] = None, tenant_id: Optional[str] = None, **kwargs: Any + self, + *scopes: str, + claims: Optional[str] = None, + tenant_id: Optional[str] = None, + enable_cae: bool = False, + **kwargs: Any, ) -> AccessToken: """Request a token from each chained credential, in order, returning the first token received. @@ -85,20 +93,38 @@ def get_token( :keyword str claims: additional claims required in the token, such as those returned in a resource provider's claims challenge following an authorization failure. :keyword str tenant_id: optional tenant to include in the token request. + :keyword bool enable_cae: indicates whether to enable Continuous Access Evaluation (CAE) for the requested + token. Defaults to False. :return: An access token with the desired scopes. :rtype: ~azure.core.credentials.AccessToken :raises ~azure.core.exceptions.ClientAuthenticationError: no credential in the chain provided a token """ + within_credential_chain.set(True) history = [] for credential in self.credentials: try: - token = credential.get_token(*scopes, claims=claims, tenant_id=tenant_id, **kwargs) + # Prioritize "get_token". Fall back to "get_token_info" if not available. + if hasattr(credential, "get_token"): + token = cast(TokenCredential, credential).get_token( + *scopes, claims=claims, tenant_id=tenant_id, enable_cae=enable_cae, **kwargs + ) + else: + options: TokenRequestOptions = {} + if claims: + options["claims"] = claims + if tenant_id: + options["tenant_id"] = tenant_id + options["enable_cae"] = enable_cae + token_info = cast(SupportsTokenInfo, credential).get_token_info(*scopes, options=options) + token = AccessToken(token_info.token, token_info.expires_on) + _LOGGER.info("%s acquired a token from %s", self.__class__.__name__, credential.__class__.__name__) self._successful_credential = credential within_credential_chain.set(False) return token + except CredentialUnavailableError as ex: # credential didn't attempt authentication because it lacks required data or state -> continue history.append((credential, ex.message)) @@ -113,7 +139,6 @@ def get_token( exc_info=True, ) break - within_credential_chain.set(False) attempts = _get_error_message(history) message = ( @@ -150,7 +175,7 @@ def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptions] = options = options or {} for credential in self.credentials: try: - # A custom credential in the chain may not implement get_token_info + # Prioritize "get_token_info". Fall back to "get_token" if not available. if hasattr(credential, "get_token_info"): token_info = cast(SupportsTokenInfo, credential).get_token_info(*scopes, options=options) else: @@ -158,8 +183,9 @@ def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptions] = raise CredentialUnavailableError( "Proof of possession arguments are not supported for this credential." ) - token = credential.get_token(*scopes, **options) + token = cast(TokenCredential, credential).get_token(*scopes, **options) token_info = AccessTokenInfo(token=token.token, expires_on=token.expires_on) + _LOGGER.info("%s acquired a token from %s", self.__class__.__name__, credential.__class__.__name__) self._successful_credential = credential within_credential_chain.set(False) diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/default.py b/sdk/identity/azure-identity/azure/identity/_credentials/default.py index b1503cb65dff..b19916275b8f 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/default.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/default.py @@ -4,9 +4,9 @@ # ------------------------------------ import logging import os -from typing import List, TYPE_CHECKING, Any, Optional, cast +from typing import List, Any, Optional, cast -from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions, SupportsTokenInfo +from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions, SupportsTokenInfo, TokenCredential from .._constants import EnvironmentVariables from .._internal import get_default_authority, normalize_authority, within_dac from .azure_powershell import AzurePowerShellCredential @@ -20,9 +20,6 @@ from .vscode import VisualStudioCodeCredential from .workload_identity import WorkloadIdentityCredential -if TYPE_CHECKING: - from azure.core.credentials import TokenCredential - _LOGGER = logging.getLogger(__name__) @@ -217,7 +214,9 @@ def get_token( `message` attribute listing each authentication attempt and its error message. """ if self._successful_credential: - token = self._successful_credential.get_token(*scopes, claims=claims, tenant_id=tenant_id, **kwargs) + token = cast(TokenCredential, self._successful_credential).get_token( + *scopes, claims=claims, tenant_id=tenant_id, **kwargs + ) _LOGGER.info( "%s acquired a token from %s", self.__class__.__name__, self._successful_credential.__class__.__name__ ) diff --git a/sdk/identity/azure-identity/azure/identity/_internal/interactive.py b/sdk/identity/azure-identity/azure/identity/_internal/interactive.py index d1b8ab869719..e6397e600418 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/interactive.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/interactive.py @@ -190,6 +190,12 @@ def _get_token_base( claims = options.get("claims") tenant_id = options.get("tenant_id") enable_cae = options.get("enable_cae", False) + + # Check for arbitrary additional options to enable intermediary support for PoP tokens. + for key in options: + if key not in TokenRequestOptions.__annotations__: + kwargs.setdefault(key, options[key]) + try: token = self._acquire_token_silent( *scopes, claims=claims, tenant_id=tenant_id, enable_cae=enable_cae, **kwargs diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/application.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/application.py index fa16e36a3609..980171ef1e6c 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/application.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/application.py @@ -7,7 +7,7 @@ from typing import Optional, Any, cast from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions -from azure.core.credentials_async import AsyncSupportsTokenInfo +from azure.core.credentials_async import AsyncSupportsTokenInfo, AsyncTokenCredential from .chained import ChainedTokenCredential from .environment import EnvironmentCredential from .managed_identity import ManagedIdentityCredential @@ -83,7 +83,9 @@ async def get_token( `message` attribute listing each authentication attempt and its error message. """ if self._successful_credential: - token = await self._successful_credential.get_token(*scopes, claims=claims, tenant_id=tenant_id, **kwargs) + token = await cast(AsyncTokenCredential, self._successful_credential).get_token( + *scopes, claims=claims, tenant_id=tenant_id, **kwargs + ) _LOGGER.info( "%s acquired a token from %s", self.__class__.__name__, self._successful_credential.__class__.__name__ ) diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/chained.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/chained.py index 3221cfca42ed..16ce23709dcc 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/chained.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/chained.py @@ -4,19 +4,16 @@ # ------------------------------------ import asyncio import logging -from typing import Any, Optional, TYPE_CHECKING, cast +from typing import Any, Optional, cast from azure.core.exceptions import ClientAuthenticationError from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions -from azure.core.credentials_async import AsyncSupportsTokenInfo +from azure.core.credentials_async import AsyncSupportsTokenInfo, AsyncTokenCredential, AsyncTokenProvider from .._internal import AsyncContextManager from ... import CredentialUnavailableError from ..._credentials.chained import _get_error_message from ..._internal import within_credential_chain -if TYPE_CHECKING: - from azure.core.credentials_async import AsyncTokenCredential - _LOGGER = logging.getLogger(__name__) @@ -39,11 +36,11 @@ class ChainedTokenCredential(AsyncContextManager): :caption: Create a ChainedTokenCredential. """ - def __init__(self, *credentials: "AsyncTokenCredential") -> None: + def __init__(self, *credentials: AsyncTokenProvider) -> None: if not credentials: raise ValueError("at least one credential is required") - self._successful_credential: Optional[AsyncTokenCredential] = None + self._successful_credential: Optional[AsyncTokenProvider] = None self.credentials = credentials async def close(self) -> None: @@ -52,7 +49,12 @@ async def close(self) -> None: await asyncio.gather(*(credential.close() for credential in self.credentials)) async def get_token( - self, *scopes: str, claims: Optional[str] = None, tenant_id: Optional[str] = None, **kwargs: Any + self, + *scopes: str, + claims: Optional[str] = None, + tenant_id: Optional[str] = None, + enable_cae: bool = False, + **kwargs: Any, ) -> AccessToken: """Asynchronously request a token from each credential, in order, returning the first token received. @@ -67,6 +69,8 @@ async def get_token( :keyword str claims: additional claims required in the token, such as those returned in a resource provider's claims challenge following an authorization failure. :keyword str tenant_id: optional tenant to include in the token request. + :keyword bool enable_cae: indicates whether to enable Continuous Access Evaluation (CAE) for the requested + token. Defaults to False. :return: An access token with the desired scopes. :rtype: ~azure.core.credentials.AccessToken @@ -76,7 +80,21 @@ async def get_token( history = [] for credential in self.credentials: try: - token = await credential.get_token(*scopes, claims=claims, tenant_id=tenant_id, **kwargs) + # Prioritize "get_token". Fall back to "get_token_info" if not available. + if hasattr(credential, "get_token"): + token = await cast(AsyncTokenCredential, credential).get_token( + *scopes, claims=claims, tenant_id=tenant_id, enable_cae=enable_cae, **kwargs + ) + else: + options: TokenRequestOptions = {} + if claims: + options["claims"] = claims + if tenant_id: + options["tenant_id"] = tenant_id + options["enable_cae"] = enable_cae + token_info = await cast(AsyncSupportsTokenInfo, credential).get_token_info(*scopes, **kwargs) + token = AccessToken(token_info.token, token_info.expires_on) + _LOGGER.info("%s acquired a token from %s", self.__class__.__name__, credential.__class__.__name__) self._successful_credential = credential within_credential_chain.set(False) @@ -95,7 +113,6 @@ async def get_token( exc_info=True, ) break - within_credential_chain.set(False) attempts = _get_error_message(history) message = ( @@ -131,7 +148,7 @@ async def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptio options = options or {} for credential in self.credentials: try: - # A custom credential in the chain may not implement get_token_info + # Prioritize "get_token_info". Fall back to "get_token" if not available. if hasattr(credential, "get_token_info"): token_info = await cast(AsyncSupportsTokenInfo, credential).get_token_info(*scopes, options=options) else: @@ -139,8 +156,9 @@ async def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptio raise CredentialUnavailableError( "Proof of possession arguments are not supported for this credential." ) - token = await credential.get_token(*scopes, **options) + token = await cast(AsyncTokenCredential, credential).get_token(*scopes, **options) token_info = AccessTokenInfo(token=token.token, expires_on=token.expires_on) + _LOGGER.info("%s acquired a token from %s", self.__class__.__name__, credential.__class__.__name__) self._successful_credential = credential within_credential_chain.set(False) diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/default.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/default.py index ded986453685..a5d60715d57b 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/default.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/default.py @@ -196,7 +196,9 @@ async def get_token( `message` attribute listing each authentication attempt and its error message. """ if self._successful_credential: - token = await self._successful_credential.get_token(*scopes, claims=claims, tenant_id=tenant_id, **kwargs) + token = await cast(AsyncTokenCredential, self._successful_credential).get_token( + *scopes, claims=claims, tenant_id=tenant_id, **kwargs + ) _LOGGER.info( "%s acquired a token from %s", self.__class__.__name__, self._successful_credential.__class__.__name__ ) diff --git a/sdk/identity/azure-identity/tests/test_aad_client_async.py b/sdk/identity/azure-identity/tests/test_aad_client_async.py index f618669fba35..56a2f1486a8c 100644 --- a/sdk/identity/azure-identity/tests/test_aad_client_async.py +++ b/sdk/identity/azure-identity/tests/test_aad_client_async.py @@ -180,7 +180,7 @@ async def send(request, **_): await client.obtain_token_by_refresh_token("scope", "refresh token") # obtain_token_by_refresh_token is client_secret safe - client.obtain_token_by_refresh_token("scope", "refresh token", client_secret="secret") + await client.obtain_token_by_refresh_token("scope", "refresh token", client_secret="secret") # authority can be configured via environment variable with patch.dict("os.environ", {EnvironmentVariables.AZURE_AUTHORITY_HOST: authority}, clear=True): diff --git a/sdk/identity/azure-identity/tests/test_chained_credential.py b/sdk/identity/azure-identity/tests/test_chained_credential.py index 9ab770991fe4..0099802d0c68 100644 --- a/sdk/identity/azure-identity/tests/test_chained_credential.py +++ b/sdk/identity/azure-identity/tests/test_chained_credential.py @@ -253,6 +253,28 @@ def test_credentials_with_no_get_token_info(): assert token_info.token == access_token +def test_credentials_with_no_get_token(): + """ChainedTokenCredential should work with credentials that only implement get_token_info.""" + + access_token = "****" + credential1 = Mock( + spec_set=["get_token"], + get_token=Mock(side_effect=CredentialUnavailableError(message="")), + ) + credential2 = Mock( + spec_set=["get_token_info"], + get_token_info=Mock(return_value=AccessTokenInfo(access_token, 42)), + ) + credential3 = Mock( + spec_set=["get_token", "get_token_info"], + get_token=Mock(return_value=AccessToken("foo", 42)), + get_token_info=Mock(return_value=AccessTokenInfo("bar", 42)), + ) + chain = ChainedTokenCredential(credential1, credential2, credential3) # type: ignore + token_info = chain.get_token("scope") + assert token_info.token == access_token + + def test_credentials_with_pop_option(): """ChainedTokenCredential should skip credentials that don't support get_token_info and the pop option is set.""" diff --git a/sdk/identity/azure-identity/tests/test_chained_token_credential_async.py b/sdk/identity/azure-identity/tests/test_chained_token_credential_async.py index aa105db967bd..fb6686856690 100644 --- a/sdk/identity/azure-identity/tests/test_chained_token_credential_async.py +++ b/sdk/identity/azure-identity/tests/test_chained_token_credential_async.py @@ -254,6 +254,32 @@ async def credential_unavailable(message="it didn't work", **_): assert token_info.token == access_token +@pytest.mark.asyncio +async def test_credentials_with_no_get_token(): + """ChainedTokenCredential should work with credentials that only implement get_token_info.""" + + async def credential_unavailable(message="it didn't work", **_): + raise CredentialUnavailableError(message) + + access_token = "****" + credential1 = Mock( + spec_set=["get_token"], + get_token=Mock(wraps=credential_unavailable), + ) + credential2 = Mock( + spec_set=["get_token_info"], + get_token_info=Mock(wraps=wrap_in_future(lambda _, **__: AccessTokenInfo(access_token, 42))), + ) + credential3 = Mock( + spec_set=["get_token", "get_token_info"], + get_token=Mock(wraps=wrap_in_future(lambda _, **__: AccessToken("foo", 42))), + get_token_info=Mock(wraps=wrap_in_future(lambda _, **__: AccessTokenInfo("bar", 42))), + ) + chain = ChainedTokenCredential(credential1, credential2, credential3) # type: ignore + token_info = await chain.get_token("scope") + assert token_info.token == access_token + + @pytest.mark.asyncio async def test_credentials_with_pop_option(): """ChainedTokenCredential should skip credentials that don't support get_token_info and the pop option is set.""" diff --git a/sdk/identity/azure-identity/tests/test_interactive_credential.py b/sdk/identity/azure-identity/tests/test_interactive_credential.py index c6fcd3da9993..d623f8b80048 100644 --- a/sdk/identity/azure-identity/tests/test_interactive_credential.py +++ b/sdk/identity/azure-identity/tests/test_interactive_credential.py @@ -450,3 +450,19 @@ def send(request, **kwargs): kwargs = {"options": kwargs} token = getattr(credential, get_token_method)("scope", **kwargs) assert token.token == expected_token + + +def test_arbitrary_kwargs_propagated_get_token_info(): + """For intermediary testing of PoP support.""" + + class TestCredential(InteractiveCredential): + def __init__(self, **kwargs): + super(TestCredential, self).__init__(client_id="...", **kwargs) + + def _request_token(self, *_, **kwargs): + assert "foo" in kwargs + raise ValueError("Raising here since keyword arg was propagated") + + credential = TestCredential() + with pytest.raises(ValueError): + credential.get_token_info("scope", options={"foo": "bar"}) # type: ignore From c25e1eb4b1e57fb3e63130cc6e44e3f2be585365 Mon Sep 17 00:00:00 2001 From: Paul Van Eck Date: Tue, 17 Sep 2024 08:03:32 +0000 Subject: [PATCH 7/8] add some ignores Signed-off-by: Paul Van Eck --- .../azure-identity/azure/identity/_internal/interactive.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sdk/identity/azure-identity/azure/identity/_internal/interactive.py b/sdk/identity/azure-identity/azure/identity/_internal/interactive.py index e6397e600418..9667132888c6 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/interactive.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/interactive.py @@ -193,8 +193,8 @@ def _get_token_base( # Check for arbitrary additional options to enable intermediary support for PoP tokens. for key in options: - if key not in TokenRequestOptions.__annotations__: - kwargs.setdefault(key, options[key]) + if key not in TokenRequestOptions.__annotations__: # pylint:disable=no-member + kwargs.setdefault(key, options[key]) # type: ignore try: token = self._acquire_token_silent( From 56273c93ffa84a067713ed77cca7989c8eecee49 Mon Sep 17 00:00:00 2001 From: Paul Van Eck Date: Tue, 17 Sep 2024 22:31:50 +0000 Subject: [PATCH 8/8] Changes based on feedback - Updated some typing - Propagated refresh_on in additional places - Propagated "token_type" for InteractiveCredential Signed-off-by: Paul Van Eck --- .../azure/identity/_bearer_token_provider.py | 4 ++-- .../azure/identity/_credentials/default.py | 2 +- .../identity/_credentials/managed_identity.py | 10 ++++------ .../azure/identity/_credentials/on_behalf_of.py | 8 ++++++-- .../azure/identity/_credentials/shared_cache.py | 16 +++++++++------- .../azure/identity/_credentials/silent.py | 3 +++ .../identity/_internal/client_credential_base.py | 2 +- .../azure/identity/_internal/interactive.py | 8 +++++++- .../azure/identity/aio/_credentials/default.py | 2 +- .../aio/_credentials/managed_identity.py | 12 ++++++------ 10 files changed, 40 insertions(+), 27 deletions(-) diff --git a/sdk/identity/azure-identity/azure/identity/_bearer_token_provider.py b/sdk/identity/azure-identity/azure/identity/_bearer_token_provider.py index 209f46d46ef7..3617f56eab33 100644 --- a/sdk/identity/azure-identity/azure/identity/_bearer_token_provider.py +++ b/sdk/identity/azure-identity/azure/identity/_bearer_token_provider.py @@ -4,7 +4,7 @@ # ------------------------------------ from typing import Callable -from azure.core.credentials import TokenCredential +from azure.core.credentials import TokenProvider from azure.core.pipeline.policies import BearerTokenCredentialPolicy from azure.core.pipeline import PipelineRequest, PipelineContext from azure.core.rest import HttpRequest @@ -14,7 +14,7 @@ def _make_request() -> PipelineRequest[HttpRequest]: return PipelineRequest(HttpRequest("CredentialWrapper", "https://fakeurl"), PipelineContext(None)) -def get_bearer_token_provider(credential: TokenCredential, *scopes: str) -> Callable[[], str]: +def get_bearer_token_provider(credential: TokenProvider, *scopes: str) -> Callable[[], str]: """Returns a callable that provides a bearer token. It can be used for instance to write code like: diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/default.py b/sdk/identity/azure-identity/azure/identity/_credentials/default.py index b19916275b8f..035efb52bd39 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/default.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/default.py @@ -141,7 +141,7 @@ def __init__(self, **kwargs: Any) -> None: # pylint: disable=too-many-statement exclude_interactive_browser_credential = kwargs.pop("exclude_interactive_browser_credential", True) exclude_powershell_credential = kwargs.pop("exclude_powershell_credential", False) - credentials: List["TokenCredential"] = [] + credentials: List[SupportsTokenInfo] = [] within_dac.set(True) if not exclude_environment_credential: credentials.append(EnvironmentCredential(authority=authority, _within_dac=True, **kwargs)) diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/managed_identity.py b/sdk/identity/azure-identity/azure/identity/_credentials/managed_identity.py index c9d9aef9408b..db8667b1d519 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/managed_identity.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/managed_identity.py @@ -4,15 +4,13 @@ # ------------------------------------ import logging import os -from typing import Optional, TYPE_CHECKING, Any, Mapping, cast +from typing import Optional, Any, Mapping, cast -from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions, SupportsTokenInfo +from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions, TokenCredential, SupportsTokenInfo from .. import CredentialUnavailableError from .._constants import EnvironmentVariables from .._internal.decorators import log_get_token -if TYPE_CHECKING: - from azure.core.credentials import TokenCredential _LOGGER = logging.getLogger(__name__) @@ -62,7 +60,7 @@ def __init__( self, *, client_id: Optional[str] = None, identity_config: Optional[Mapping[str, str]] = None, **kwargs: Any ) -> None: validate_identity_config(client_id, identity_config) - self._credential: Optional[TokenCredential] = None + self._credential: Optional[SupportsTokenInfo] = None exclude_workload_identity = kwargs.pop("_exclude_workload_identity_credential", False) if os.environ.get(EnvironmentVariables.IDENTITY_ENDPOINT): if os.environ.get(EnvironmentVariables.IDENTITY_HEADER): @@ -159,7 +157,7 @@ def get_token( "Visit https://aka.ms/azsdk/python/identity/managedidentitycredential/troubleshoot to " "troubleshoot this issue." ) - return self._credential.get_token(*scopes, claims=claims, tenant_id=tenant_id, **kwargs) + return cast(TokenCredential, self._credential).get_token(*scopes, claims=claims, tenant_id=tenant_id, **kwargs) @log_get_token def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptions] = None) -> AccessTokenInfo: diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/on_behalf_of.py b/sdk/identity/azure-identity/azure/identity/_credentials/on_behalf_of.py index 464014cf20b8..93675adb84e7 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/on_behalf_of.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/on_behalf_of.py @@ -134,7 +134,10 @@ def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[Acces now = int(time.time()) result = app.acquire_token_silent_with_error(list(scopes), account=account, claims_challenge=claims) if result and "access_token" in result and "expires_in" in result: - return AccessTokenInfo(result["access_token"], now + int(result["expires_in"])) + refresh_on = int(result["refresh_on"]) if "refresh_on" in result else None + return AccessTokenInfo( + result["access_token"], now + int(result["expires_in"]), refresh_on=refresh_on + ) return None @@ -153,4 +156,5 @@ def _request_token(self, *scopes: str, **kwargs: Any) -> AccessTokenInfo: except ClientAuthenticationError: pass # non-fatal; we'll use the assertion again next time instead of a refresh token - return AccessTokenInfo(result["access_token"], request_time + int(result["expires_in"])) + refresh_on = int(result["refresh_on"]) if "refresh_on" in result else None + return AccessTokenInfo(result["access_token"], request_time + int(result["expires_in"]), refresh_on=refresh_on) diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/shared_cache.py b/sdk/identity/azure-identity/azure/identity/_credentials/shared_cache.py index 31870a6401db..39e895c6997b 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/shared_cache.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/shared_cache.py @@ -2,8 +2,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ -from typing import TYPE_CHECKING, Any, Optional, TypeVar, cast -from azure.core.credentials import AccessToken, TokenRequestOptions, AccessTokenInfo, SupportsTokenInfo +from typing import Any, Optional, TypeVar, cast +from azure.core.credentials import AccessToken, TokenRequestOptions, AccessTokenInfo, SupportsTokenInfo, TokenCredential from .silent import SilentAuthenticationCredential from .. import CredentialUnavailableError @@ -12,9 +12,6 @@ from .._internal.decorators import log_get_token from .._internal.shared_token_cache import NO_TOKEN, SharedTokenCacheBase -if TYPE_CHECKING: - from azure.core.credentials import TokenCredential - T = TypeVar("T", bound="_SharedTokenCacheCredential") @@ -39,7 +36,7 @@ class SharedTokenCacheCredential: def __init__(self, username: Optional[str] = None, **kwargs: Any) -> None: if "authentication_record" in kwargs: - self._credential: TokenCredential = SilentAuthenticationCredential(**kwargs) + self._credential: SupportsTokenInfo = SilentAuthenticationCredential(**kwargs) else: self._credential = _SharedTokenCacheCredential(username=username, **kwargs) @@ -85,7 +82,9 @@ def get_token( :raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The error's ``message`` attribute gives a reason. """ - return self._credential.get_token(*scopes, claims=claims, tenant_id=tenant_id, enable_cae=enable_cae, **kwargs) + return cast(TokenCredential, self._credential).get_token( + *scopes, claims=claims, tenant_id=tenant_id, enable_cae=enable_cae, **kwargs + ) @log_get_token def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptions] = None) -> AccessTokenInfo: @@ -132,6 +131,9 @@ def __exit__(self, *args): if self._client: self._client.__exit__(*args) + def close(self) -> None: + self.__exit__() + def get_token( self, *scopes: str, diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/silent.py b/sdk/identity/azure-identity/azure/identity/_credentials/silent.py index 19b2b49a01a7..80d170a629d7 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/silent.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/silent.py @@ -58,6 +58,9 @@ def __enter__(self) -> "SilentAuthenticationCredential": def __exit__(self, *args): self._client.__exit__(*args) + def close(self) -> None: + self.__exit__() + def get_token( self, *scopes: str, diff --git a/sdk/identity/azure-identity/azure/identity/_internal/client_credential_base.py b/sdk/identity/azure-identity/azure/identity/_internal/client_credential_base.py index 16e4b75928f7..7685e8522c39 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/client_credential_base.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/client_credential_base.py @@ -39,7 +39,7 @@ def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[Acces return None @wrap_exceptions - def _request_token(self, *scopes: str, **kwargs: Any) -> Optional[AccessTokenInfo]: + def _request_token(self, *scopes: str, **kwargs: Any) -> AccessTokenInfo: app = self._get_app(**kwargs) request_time = int(time.time()) result = app.acquire_token_for_client(list(scopes), claims_challenge=kwargs.pop("claims", None)) diff --git a/sdk/identity/azure-identity/azure/identity/_internal/interactive.py b/sdk/identity/azure-identity/azure/identity/_internal/interactive.py index 9667132888c6..c2665ee15932 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/interactive.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/interactive.py @@ -236,7 +236,13 @@ def _get_token_base( raise _LOGGER.info("%s.%s succeeded", self.__class__.__name__, base_method_name) - return AccessTokenInfo(result["access_token"], now + int(result["expires_in"])) + refresh_on = int(result["refresh_on"]) if "refresh_on" in result else None + return AccessTokenInfo( + result["access_token"], + now + int(result["expires_in"]), + token_type=result.get("token_type", "Bearer"), + refresh_on=refresh_on, + ) def authenticate( self, *, scopes: Optional[Iterable[str]] = None, claims: Optional[str] = None, **kwargs: Any diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/default.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/default.py index a5d60715d57b..fe1bc03b6084 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/default.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/default.py @@ -133,7 +133,7 @@ def __init__(self, **kwargs: Any) -> None: # pylint: disable=too-many-statement exclude_shared_token_cache_credential = kwargs.pop("exclude_shared_token_cache_credential", False) exclude_powershell_credential = kwargs.pop("exclude_powershell_credential", False) - credentials: List[AsyncTokenCredential] = [] + credentials: List[AsyncSupportsTokenInfo] = [] within_dac.set(True) if not exclude_environment_credential: credentials.append(EnvironmentCredential(authority=authority, _within_dac=True, **kwargs)) diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/managed_identity.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/managed_identity.py index 6b8ced00219b..22ab3f4a269a 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/managed_identity.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/managed_identity.py @@ -4,18 +4,16 @@ # ------------------------------------ import logging import os -from typing import TYPE_CHECKING, Optional, Any, Mapping, cast +from typing import Optional, Any, Mapping, cast from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions -from azure.core.credentials_async import AsyncSupportsTokenInfo +from azure.core.credentials_async import AsyncTokenCredential, AsyncSupportsTokenInfo from .._internal import AsyncContextManager from .._internal.decorators import log_get_token_async from ... import CredentialUnavailableError from ..._constants import EnvironmentVariables from ..._credentials.managed_identity import validate_identity_config -if TYPE_CHECKING: - from azure.core.credentials_async import AsyncTokenCredential _LOGGER = logging.getLogger(__name__) @@ -49,7 +47,7 @@ def __init__( self, *, client_id: Optional[str] = None, identity_config: Optional[Mapping[str, str]] = None, **kwargs: Any ) -> None: validate_identity_config(client_id, identity_config) - self._credential: Optional[AsyncTokenCredential] = None + self._credential: Optional[AsyncSupportsTokenInfo] = None exclude_workload_identity = kwargs.pop("_exclude_workload_identity_credential", False) if os.environ.get(EnvironmentVariables.IDENTITY_ENDPOINT): @@ -142,7 +140,9 @@ async def get_token( "Visit https://aka.ms/azsdk/python/identity/managedidentitycredential/troubleshoot to " "troubleshoot this issue." ) - return await self._credential.get_token(*scopes, claims=claims, tenant_id=tenant_id, **kwargs) + return await cast(AsyncTokenCredential, self._credential).get_token( + *scopes, claims=claims, tenant_id=tenant_id, **kwargs + ) @log_get_token_async async def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptions] = None) -> AccessTokenInfo: