From 57d5d41b5ed772584da60664b849db28dc98606e Mon Sep 17 00:00:00 2001 From: Charles Lowell Date: Tue, 16 Jul 2019 09:57:03 -0700 Subject: [PATCH 1/8] device code credential --- .../azure-identity/azure/identity/__init__.py | 2 + .../azure/identity/credentials.py | 72 ++++++++++++++++++- 2 files changed, 73 insertions(+), 1 deletion(-) diff --git a/sdk/identity/azure-identity/azure/identity/__init__.py b/sdk/identity/azure-identity/azure/identity/__init__.py index 1df14dd617f5..5a9ccc7aad36 100644 --- a/sdk/identity/azure-identity/azure/identity/__init__.py +++ b/sdk/identity/azure-identity/azure/identity/__init__.py @@ -7,6 +7,7 @@ CertificateCredential, ChainedTokenCredential, ClientSecretCredential, + DeviceCodeCredential, EnvironmentCredential, ManagedIdentityCredential, UsernamePasswordCredential, @@ -35,6 +36,7 @@ def __init__(self, **kwargs): "ChainedTokenCredential", "ClientSecretCredential", "DefaultAzureCredential", + "DeviceCodeCredential", "EnvironmentCredential", "InteractiveBrowserCredential", "ManagedIdentityCredential", diff --git a/sdk/identity/azure-identity/azure/identity/credentials.py b/sdk/identity/azure-identity/azure/identity/credentials.py index 2e09a306aa71..17a85e163e1f 100644 --- a/sdk/identity/azure-identity/azure/identity/credentials.py +++ b/sdk/identity/azure-identity/azure/identity/credentials.py @@ -26,7 +26,7 @@ if TYPE_CHECKING: # pylint:disable=unused-import - from typing import Any, Dict, Mapping, Optional, Union + from typing import Any, Callable, Dict, Mapping, Optional, Union from azure.core.credentials import TokenCredential EnvironmentCredentialTypes = Union["CertificateCredential", "ClientSecretCredential", "UsernamePasswordCredential"] @@ -309,3 +309,73 @@ def get_token(self, *scopes): raise ClientAuthenticationError(message="authentication failed: {}".format(result.get("error_description"))) return AccessToken(result["access_token"], now + int(result["expires_in"])) + + +class DeviceCodeCredential(PublicClientCredential): + """ + Authenticates users through the device code flow. When ``get_token`` is called, this credential acquires a + verification URL and code from Azure Active Directory. A user must browse to the URL, enter the code, and + authenticate with Directory. If the user authenticates successfully, the credential receives an access token. + + This credential doesn't cache tokens--each ``get_token`` call begins a new authentication flow. + + For more information about the device code flow, see Azure Active Directory documentation: + https://docs.microsoft.com/en-us/azure/active-directory/develop/v2-oauth2-device-code + + :param str client_id: the application's ID + :param prompt_callback: (optional) A callable with string parameters (``verification_uri``, ``user_code``). + ``verification_uri`` is the URL the user must visit. ``user_code`` is the code the user must enter there. + Provide this callback if you want to control how authentication instructions are presented. Otherwise, the + credential will print them to stdout. + + **Keyword arguments:** + + *tenant (str)* - a tenant ID or a domain associated with a tenant. If not provided, the credential defaults to the + 'organizations' tenant, which supports only Azure Active Directory work or school accounts. + *timeout (int)* - seconds to wait for the user to complete authentication. Defaults to 300 (5 minutes). + """ + + def __init__(self, client_id, prompt_callback=None, **kwargs): + # type: (str, Optional[Callable[[str, str], None]], Any) -> None + self._timeout = kwargs.pop("timeout", 300) + self._prompt_callback = prompt_callback + super(DeviceCodeCredential, self).__init__(client_id=client_id, **kwargs) + + def get_token(self, *scopes): + # type (*str) -> AccessToken + """ + Request an access token for `scopes`. This credential won't cache the token. Each call begins a new + authentication flow. + + :param str scopes: desired scopes for the token + :rtype: :class:`azure.core.credentials.AccessToken` + :raises: :class:`azure.core.exceptions.ClientAuthenticationError` + """ + + # MSAL requires scopes be a list + scopes = list(scopes) # type: ignore + now = int(time.time()) + + flow = self._app.initiate_device_flow(scopes) + if "error" in flow: + raise ClientAuthenticationError( + message="Couldn't begin authentication: {}".format(flow.get("error_description") or flow.get("error")) + ) + + if self._prompt_callback: + self._prompt_callback(flow["verification_uri"], flow["user_code"]) + else: + print(flow["message"]) + + deadline = now + self._timeout + result = self._app.acquire_token_by_device_flow(flow, exit_condition=lambda flow: time.time() >= deadline) + + if "access_token" not in result: + if result.get("error") == "authorization_pending": + message = "Timed out after waiting {} seconds for user to authenticate".format(self._timeout) + else: + message = "Authentication failed: {}".format(result.get("error_description")) + raise ClientAuthenticationError(message=message) + + token = AccessToken(result["access_token"], now + int(result["expires_in"])) + return token From b0e32b0bb935c354136a50a6ca2fe4d20ea0686b Mon Sep 17 00:00:00 2001 From: Charles Lowell Date: Tue, 23 Jul 2019 14:06:32 -0700 Subject: [PATCH 2/8] offline tests --- .../azure-identity/tests/test_identity.py | 63 ++++++++++++++++++- 1 file changed, 61 insertions(+), 2 deletions(-) diff --git a/sdk/identity/azure-identity/tests/test_identity.py b/sdk/identity/azure-identity/tests/test_identity.py index 3aaa54a1f28f..618d81501321 100644 --- a/sdk/identity/azure-identity/tests/test_identity.py +++ b/sdk/identity/azure-identity/tests/test_identity.py @@ -13,20 +13,21 @@ except ImportError: # python < 3.3 from mock import Mock, patch -import pytest from azure.core.credentials import AccessToken from azure.core.exceptions import ClientAuthenticationError from azure.identity import ( + ChainedTokenCredential, ClientSecretCredential, DefaultAzureCredential, + DeviceCodeCredential, EnvironmentCredential, ManagedIdentityCredential, - ChainedTokenCredential, InteractiveBrowserCredential, UsernamePasswordCredential, ) from azure.identity._managed_identity import ImdsCredential from azure.identity.constants import EnvironmentVariables +import pytest from helpers import mock_response, Request, validating_transport @@ -244,6 +245,64 @@ def test_default_credential(): DefaultAzureCredential() +def test_device_code_credential(): + expected_token = "access-token" + user_code = "user-code" + verification_uri = "verification-uri" + + transport = validating_transport( + requests=[Request()] * 3, # not validating requests because they're formed by MSAL + responses=[ + # expected requests: discover tenant, start device code flow, poll for completion + mock_response(json_payload={"authorization_endpoint": "https://a/b", "token_endpoint": "https://a/b"}), + mock_response( + json_payload={"device_code": "_", "user_code": user_code, "verification_uri": verification_uri} + ), + mock_response( + json_payload={ + "access_token": expected_token, + "expires_in": 42, + "scope": "scope", + "token_type": "Bearer", + "refresh_token": "_", + } + ), + ], + ) + + callback = Mock() + credential = DeviceCodeCredential( + client_id="_", prompt_callback=callback, transport=transport, instance_discovery=False + ) + + token = credential.get_token("scope") + assert token.token == expected_token + + # prompt_callback should have been called as documented + assert callback.call_count == 1 + assert callback.call_args[0] == (verification_uri, user_code) + + +def test_device_code_credential_timeout(): + transport = validating_transport( + requests=[Request()] * 3, # not validating requests because they're formed by MSAL + responses=[ + # expected requests: discover tenant, start device code flow, poll for completion + mock_response(json_payload={"authorization_endpoint": "https://a/b", "token_endpoint": "https://a/b"}), + mock_response(json_payload={"device_code": "_", "user_code": "_", "verification_uri": "_"}), + mock_response(json_payload={"error": "authorization_pending"}), + ], + ) + + credential = DeviceCodeCredential( + client_id="_", prompt_callback=Mock(), transport=transport, timeout=0.1, instance_discovery=False + ) + + with pytest.raises(ClientAuthenticationError) as ex: + credential.get_token("scope") + assert "timed out" in ex.value.message.lower() + + @patch("azure.identity.browser_auth.webbrowser.open", lambda _: None) # prevent the credential opening a browser def test_interactive_credential(): oauth_state = "state" From 84d67dd1f54b5a30607876646a1e67a3b552a501 Mon Sep 17 00:00:00 2001 From: Charles Lowell Date: Fri, 26 Jul 2019 12:13:22 -0700 Subject: [PATCH 3/8] AAD timeout trumps user's --- .../azure-identity/azure/identity/credentials.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/sdk/identity/azure-identity/azure/identity/credentials.py b/sdk/identity/azure-identity/azure/identity/credentials.py index 17a85e163e1f..50bea0ba63ce 100644 --- a/sdk/identity/azure-identity/azure/identity/credentials.py +++ b/sdk/identity/azure-identity/azure/identity/credentials.py @@ -332,12 +332,13 @@ class DeviceCodeCredential(PublicClientCredential): *tenant (str)* - a tenant ID or a domain associated with a tenant. If not provided, the credential defaults to the 'organizations' tenant, which supports only Azure Active Directory work or school accounts. - *timeout (int)* - seconds to wait for the user to complete authentication. Defaults to 300 (5 minutes). + *timeout (int)* - seconds to wait for the user to authenticate. Defaults to the validity period of the device code + as set by Azure Active Directory, which also prevails when ``timeout`` is longer. """ def __init__(self, client_id, prompt_callback=None, **kwargs): # type: (str, Optional[Callable[[str, str], None]], Any) -> None - self._timeout = kwargs.pop("timeout", 300) + self._timeout = kwargs.pop("timeout", None) # type: Optional[int] self._prompt_callback = prompt_callback super(DeviceCodeCredential, self).__init__(client_id=client_id, **kwargs) @@ -367,14 +368,17 @@ def get_token(self, *scopes): else: print(flow["message"]) - deadline = now + self._timeout - result = self._app.acquire_token_by_device_flow(flow, exit_condition=lambda flow: time.time() >= deadline) + if self._timeout is not None and self._timeout < flow["expires_in"]: + deadline = now + self._timeout + result = app.acquire_token_by_device_flow(flow, exit_condition=lambda flow: time.time() > deadline) + else: + result = app.acquire_token_by_device_flow(flow) if "access_token" not in result: if result.get("error") == "authorization_pending": - message = "Timed out after waiting {} seconds for user to authenticate".format(self._timeout) + message = "Timed out waiting for user to authenticate" else: - message = "Authentication failed: {}".format(result.get("error_description")) + message = "Authentication failed: {}".format(result.get("error_description") or result.get("error")) raise ClientAuthenticationError(message=message) token = AccessToken(result["access_token"], now + int(result["expires_in"])) From 923c9a33634ddb1b4c47ba2ed1b9aabc1002813d Mon Sep 17 00:00:00 2001 From: Charles Lowell Date: Fri, 26 Jul 2019 12:14:00 -0700 Subject: [PATCH 4/8] include code lifetime in callback --- .../azure-identity/azure/identity/credentials.py | 12 +++++++----- sdk/identity/azure-identity/tests/test_identity.py | 7 ++++--- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/sdk/identity/azure-identity/azure/identity/credentials.py b/sdk/identity/azure-identity/azure/identity/credentials.py index 50bea0ba63ce..beb6981fe3ef 100644 --- a/sdk/identity/azure-identity/azure/identity/credentials.py +++ b/sdk/identity/azure-identity/azure/identity/credentials.py @@ -323,10 +323,12 @@ class DeviceCodeCredential(PublicClientCredential): https://docs.microsoft.com/en-us/azure/active-directory/develop/v2-oauth2-device-code :param str client_id: the application's ID - :param prompt_callback: (optional) A callable with string parameters (``verification_uri``, ``user_code``). - ``verification_uri`` is the URL the user must visit. ``user_code`` is the code the user must enter there. - Provide this callback if you want to control how authentication instructions are presented. Otherwise, the - credential will print them to stdout. + :param prompt_callback: (optional) A callback enabling control of how authentication instructions are presented. + If not provided, the credential will print instructions to stdout. + :type prompt_callback: A callable accepting arguments (``verification_uri``, ``user_code``, ``expires_in``): + - ``verification_uri`` (str) the URL the user must visit + - ``user_code`` (str) the code the user must enter there + - ``expires_in`` (int) the number of seconds the code will be valid **Keyword arguments:** @@ -364,7 +366,7 @@ def get_token(self, *scopes): ) if self._prompt_callback: - self._prompt_callback(flow["verification_uri"], flow["user_code"]) + self._prompt_callback(flow["verification_uri"], flow["user_code"], flow["expires_in"]) else: print(flow["message"]) diff --git a/sdk/identity/azure-identity/tests/test_identity.py b/sdk/identity/azure-identity/tests/test_identity.py index 618d81501321..b984c38959c7 100644 --- a/sdk/identity/azure-identity/tests/test_identity.py +++ b/sdk/identity/azure-identity/tests/test_identity.py @@ -249,6 +249,7 @@ def test_device_code_credential(): expected_token = "access-token" user_code = "user-code" verification_uri = "verification-uri" + expires_in = 42 transport = validating_transport( requests=[Request()] * 3, # not validating requests because they're formed by MSAL @@ -256,12 +257,12 @@ def test_device_code_credential(): # expected requests: discover tenant, start device code flow, poll for completion mock_response(json_payload={"authorization_endpoint": "https://a/b", "token_endpoint": "https://a/b"}), mock_response( - json_payload={"device_code": "_", "user_code": user_code, "verification_uri": verification_uri} + json_payload={"device_code": "_", "user_code": user_code, "verification_uri": verification_uri, "expires_in": expires_in} ), mock_response( json_payload={ "access_token": expected_token, - "expires_in": 42, + "expires_in": expires_in, "scope": "scope", "token_type": "Bearer", "refresh_token": "_", @@ -280,7 +281,7 @@ def test_device_code_credential(): # prompt_callback should have been called as documented assert callback.call_count == 1 - assert callback.call_args[0] == (verification_uri, user_code) + assert callback.call_args[0] == (verification_uri, user_code, expires_in) def test_device_code_credential_timeout(): From 2ee0f3dd9847f0497d7e7413f879d17fd2e6f584 Mon Sep 17 00:00:00 2001 From: Charles Lowell Date: Fri, 26 Jul 2019 12:19:58 -0700 Subject: [PATCH 5/8] remove low-value test --- sdk/identity/azure-identity/tests/test_identity.py | 5 ----- sdk/identity/azure-identity/tests/test_identity_async.py | 6 ------ 2 files changed, 11 deletions(-) diff --git a/sdk/identity/azure-identity/tests/test_identity.py b/sdk/identity/azure-identity/tests/test_identity.py index b984c38959c7..e3cfc64544c0 100644 --- a/sdk/identity/azure-identity/tests/test_identity.py +++ b/sdk/identity/azure-identity/tests/test_identity.py @@ -124,11 +124,6 @@ def test_client_secret_environment_credential(monkeypatch): assert token.token == access_token -def test_environment_credential_error(): - with pytest.raises(ClientAuthenticationError): - EnvironmentCredential().get_token("scope") - - def test_credential_chain_error_message(): def raise_authn_error(message): raise ClientAuthenticationError(message) diff --git a/sdk/identity/azure-identity/tests/test_identity_async.py b/sdk/identity/azure-identity/tests/test_identity_async.py index 03b53db9eccd..06254dc898ba 100644 --- a/sdk/identity/azure-identity/tests/test_identity_async.py +++ b/sdk/identity/azure-identity/tests/test_identity_async.py @@ -121,12 +121,6 @@ async def test_client_secret_environment_credential(monkeypatch): assert token.token == access_token -@pytest.mark.asyncio -async def test_environment_credential_error(): - with pytest.raises(ClientAuthenticationError): - await EnvironmentCredential().get_token("scope") - - @pytest.mark.asyncio async def test_credential_chain_error_message(): def raise_authn_error(message): From 4ad311caf1b7d06154cdbf084caa934f6ac5edc9 Mon Sep 17 00:00:00 2001 From: Charles Lowell Date: Tue, 30 Jul 2019 09:10:24 -0700 Subject: [PATCH 6/8] reorganize --- .../azure/identity/credentials.py | 128 +++++++++--------- 1 file changed, 65 insertions(+), 63 deletions(-) diff --git a/sdk/identity/azure-identity/azure/identity/credentials.py b/sdk/identity/azure-identity/azure/identity/credentials.py index beb6981fe3ef..7968c868527a 100644 --- a/sdk/identity/azure-identity/azure/identity/credentials.py +++ b/sdk/identity/azure-identity/azure/identity/credentials.py @@ -28,6 +28,7 @@ # pylint:disable=unused-import from typing import Any, Callable, Dict, Mapping, Optional, Union from azure.core.credentials import TokenCredential + EnvironmentCredentialTypes = Union["CertificateCredential", "ClientSecretCredential", "UsernamePasswordCredential"] # pylint:disable=too-few-public-methods @@ -249,68 +250,6 @@ def _get_error_message(history): return "No valid token received. {}".format(". ".join(attempts)) -class UsernamePasswordCredential(PublicClientCredential): - """ - Authenticates a user with a username and password. In general, Microsoft doesn't recommend this kind of - authentication, because it's less secure than other authentication flows. - - Authentication with this credential is not interactive, so it is **not compatible with any form of - multi-factor authentication or consent prompting**. The application must already have the user's consent. - - This credential can only authenticate work and school accounts; Microsoft accounts are not supported. - See this document for more information about account types: - https://docs.microsoft.com/en-us/azure/active-directory/fundamentals/sign-up-organization - - :param str client_id: the application's client ID - :param str username: the user's username (usually an email address) - :param str password: the user's password - - **Keyword arguments:** - - *tenant (str)* - a tenant ID or a domain associated with a tenant. If not provided, the credential defaults to the - 'organizations' tenant. - """ - - def __init__(self, client_id, username, password, **kwargs): - # type: (str, str, str, Any) -> None - super(UsernamePasswordCredential, self).__init__(client_id=client_id, **kwargs) - self._username = username - self._password = password - - def get_token(self, *scopes): - # type (*str) -> AccessToken - """ - Request an access token for `scopes`. - - :param str scopes: desired scopes for the token - :rtype: :class:`azure.core.credentials.AccessToken` - :raises: :class:`azure.core.exceptions.ClientAuthenticationError` - """ - - # MSAL requires scopes be a list - scopes = list(scopes) # type: ignore - now = int(time.time()) - - app = self._get_app() - accounts = app.get_accounts(username=self._username) - result = None - for account in accounts: - result = app.acquire_token_silent(scopes, account=account) - if result: - break - - if not result: - # cache miss -> request a new token - result = app.acquire_token_by_username_password( - username=self._username, password=self._password, scopes=scopes - ) - - if "access_token" not in result: - raise ClientAuthenticationError(message="authentication failed: {}".format(result.get("error_description"))) - - return AccessToken(result["access_token"], now + int(result["expires_in"])) - - class DeviceCodeCredential(PublicClientCredential): """ Authenticates users through the device code flow. When ``get_token`` is called, this credential acquires a @@ -359,7 +298,8 @@ def get_token(self, *scopes): scopes = list(scopes) # type: ignore now = int(time.time()) - flow = self._app.initiate_device_flow(scopes) + app = self._get_app() + flow = app.initiate_device_flow(scopes) if "error" in flow: raise ClientAuthenticationError( message="Couldn't begin authentication: {}".format(flow.get("error_description") or flow.get("error")) @@ -385,3 +325,65 @@ def get_token(self, *scopes): token = AccessToken(result["access_token"], now + int(result["expires_in"])) return token + + +class UsernamePasswordCredential(PublicClientCredential): + """ + Authenticates a user with a username and password. In general, Microsoft doesn't recommend this kind of + authentication, because it's less secure than other authentication flows. + + Authentication with this credential is not interactive, so it is **not compatible with any form of + multi-factor authentication or consent prompting**. The application must already have the user's consent. + + This credential can only authenticate work and school accounts; Microsoft accounts are not supported. + See this document for more information about account types: + https://docs.microsoft.com/en-us/azure/active-directory/fundamentals/sign-up-organization + + :param str client_id: the application's client ID + :param str username: the user's username (usually an email address) + :param str password: the user's password + + **Keyword arguments:** + + *tenant (str)* - a tenant ID or a domain associated with a tenant. If not provided, the credential defaults to the + 'organizations' tenant. + """ + + def __init__(self, client_id, username, password, **kwargs): + # type: (str, str, str, Any) -> None + super(UsernamePasswordCredential, self).__init__(client_id=client_id, **kwargs) + self._username = username + self._password = password + + def get_token(self, *scopes): + # type (*str) -> AccessToken + """ + Request an access token for `scopes`. + + :param str scopes: desired scopes for the token + :rtype: :class:`azure.core.credentials.AccessToken` + :raises: :class:`azure.core.exceptions.ClientAuthenticationError` + """ + + # MSAL requires scopes be a list + scopes = list(scopes) # type: ignore + now = int(time.time()) + + app = self._get_app() + accounts = app.get_accounts(username=self._username) + result = None + for account in accounts: + result = app.acquire_token_silent(scopes, account=account) + if result: + break + + if not result: + # cache miss -> request a new token + result = app.acquire_token_by_username_password( + username=self._username, password=self._password, scopes=scopes + ) + + if "access_token" not in result: + raise ClientAuthenticationError(message="authentication failed: {}".format(result.get("error_description"))) + + return AccessToken(result["access_token"], now + int(result["expires_in"])) From 39ea2a76f3d0bdba39dcca7da5904ccf49176cdb Mon Sep 17 00:00:00 2001 From: Charles Lowell Date: Tue, 30 Jul 2019 13:59:07 -0700 Subject: [PATCH 7/8] fix docstring formatting --- .../azure-identity/azure/identity/credentials.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/sdk/identity/azure-identity/azure/identity/credentials.py b/sdk/identity/azure-identity/azure/identity/credentials.py index 7968c868527a..575c69df6312 100644 --- a/sdk/identity/azure-identity/azure/identity/credentials.py +++ b/sdk/identity/azure-identity/azure/identity/credentials.py @@ -271,10 +271,12 @@ class DeviceCodeCredential(PublicClientCredential): **Keyword arguments:** - *tenant (str)* - a tenant ID or a domain associated with a tenant. If not provided, the credential defaults to the - 'organizations' tenant, which supports only Azure Active Directory work or school accounts. - *timeout (int)* - seconds to wait for the user to authenticate. Defaults to the validity period of the device code - as set by Azure Active Directory, which also prevails when ``timeout`` is longer. + - *tenant (str)* - tenant ID or a domain associated with a tenant. If not provided, the credential defaults to the + 'organizations' tenant, which supports only Azure Active Directory work or school accounts. + + - *timeout (int)* - seconds to wait for the user to authenticate. Defaults to the validity period of the device code + as set by Azure Active Directory, which also prevails when ``timeout`` is longer. + """ def __init__(self, client_id, prompt_callback=None, **kwargs): @@ -345,8 +347,9 @@ class UsernamePasswordCredential(PublicClientCredential): **Keyword arguments:** - *tenant (str)* - a tenant ID or a domain associated with a tenant. If not provided, the credential defaults to the - 'organizations' tenant. + - **tenant (str)** - a tenant ID or a domain associated with a tenant. If not provided, defaults to the + 'organizations' tenant. + """ def __init__(self, client_id, username, password, **kwargs): From 18a2a9508e299fc7ea67fdad628b02a358b8a496 Mon Sep 17 00:00:00 2001 From: Charles Lowell Date: Wed, 31 Jul 2019 09:27:08 -0700 Subject: [PATCH 8/8] prevent leaking MSAL exceptions --- .../azure/identity/_internal/__init__.py | 1 + .../identity/_internal/exception_wrapper.py | 25 +++++++++++++++++++ .../identity/_internal/msal_credentials.py | 2 ++ .../azure/identity/browser_auth.py | 3 ++- .../azure/identity/credentials.py | 4 ++- sdk/identity/azure-identity/setup.py | 2 +- 6 files changed, 34 insertions(+), 3 deletions(-) create mode 100644 sdk/identity/azure-identity/azure/identity/_internal/exception_wrapper.py diff --git a/sdk/identity/azure-identity/azure/identity/_internal/__init__.py b/sdk/identity/azure-identity/azure/identity/_internal/__init__.py index e7583139395b..dccfb9cf7e42 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/__init__.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/__init__.py @@ -3,5 +3,6 @@ # Licensed under the MIT License. # ------------------------------------ from .auth_code_redirect_handler import AuthCodeRedirectServer +from .exception_wrapper import wrap_exceptions from .msal_credentials import ConfidentialClientCredential, PublicClientCredential from .msal_transport_adapter import MsalTransportAdapter, MsalTransportResponse diff --git a/sdk/identity/azure-identity/azure/identity/_internal/exception_wrapper.py b/sdk/identity/azure-identity/azure/identity/_internal/exception_wrapper.py new file mode 100644 index 000000000000..2eca1e0a25ad --- /dev/null +++ b/sdk/identity/azure-identity/azure/identity/_internal/exception_wrapper.py @@ -0,0 +1,25 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ +import functools + +from six import raise_from + +from azure.core.exceptions import ClientAuthenticationError + + +def wrap_exceptions(fn): + """Prevents leaking exceptions defined outside azure-core by raising ClientAuthenticationError from them.""" + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + try: + return fn(*args, **kwargs) + except ClientAuthenticationError: + raise + except Exception as ex: + auth_error = ClientAuthenticationError(message="Authentication failed: {}".format(ex)) + raise_from(auth_error, ex) + + return wrapper diff --git a/sdk/identity/azure-identity/azure/identity/_internal/msal_credentials.py b/sdk/identity/azure-identity/azure/identity/_internal/msal_credentials.py index 83906bf71c2e..165d7328d800 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/msal_credentials.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/msal_credentials.py @@ -12,6 +12,7 @@ from azure.core.credentials import AccessToken from azure.core.exceptions import ClientAuthenticationError +from .exception_wrapper import wrap_exceptions from .msal_transport_adapter import MsalTransportAdapter try: @@ -75,6 +76,7 @@ def _create_app(self, cls): class ConfidentialClientCredential(MsalCredential): """Wraps an MSAL ConfidentialClientApplication with the TokenCredential API""" + @wrap_exceptions def get_token(self, *scopes): # type: (str) -> AccessToken diff --git a/sdk/identity/azure-identity/azure/identity/browser_auth.py b/sdk/identity/azure-identity/azure/identity/browser_auth.py index 82460513092d..921d70e635f2 100644 --- a/sdk/identity/azure-identity/azure/identity/browser_auth.py +++ b/sdk/identity/azure-identity/azure/identity/browser_auth.py @@ -18,7 +18,7 @@ from azure.core.credentials import AccessToken from azure.core.exceptions import ClientAuthenticationError -from ._internal import AuthCodeRedirectServer, ConfidentialClientCredential +from ._internal import AuthCodeRedirectServer, ConfidentialClientCredential, wrap_exceptions class InteractiveBrowserCredential(ConfidentialClientCredential): @@ -48,6 +48,7 @@ def __init__(self, client_id, client_secret, **kwargs): client_id=client_id, client_credential=client_secret, authority=authority, **kwargs ) + @wrap_exceptions def get_token(self, *scopes): # type: (str) -> AccessToken """ diff --git a/sdk/identity/azure-identity/azure/identity/credentials.py b/sdk/identity/azure-identity/azure/identity/credentials.py index 575c69df6312..5c96c0bf4dbc 100644 --- a/sdk/identity/azure-identity/azure/identity/credentials.py +++ b/sdk/identity/azure-identity/azure/identity/credentials.py @@ -15,7 +15,7 @@ from ._authn_client import AuthnClient from ._base import ClientSecretCredentialBase, CertificateCredentialBase -from ._internal import PublicClientCredential +from ._internal import PublicClientCredential, wrap_exceptions from ._managed_identity import ImdsCredential, MsiCredential from .constants import Endpoints, EnvironmentVariables @@ -285,6 +285,7 @@ def __init__(self, client_id, prompt_callback=None, **kwargs): self._prompt_callback = prompt_callback super(DeviceCodeCredential, self).__init__(client_id=client_id, **kwargs) + @wrap_exceptions def get_token(self, *scopes): # type (*str) -> AccessToken """ @@ -358,6 +359,7 @@ def __init__(self, client_id, username, password, **kwargs): self._username = username self._password = password + @wrap_exceptions def get_token(self, *scopes): # type (*str) -> AccessToken """ diff --git a/sdk/identity/azure-identity/setup.py b/sdk/identity/azure-identity/setup.py index ac671fd6e283..783ef669f72a 100644 --- a/sdk/identity/azure-identity/setup.py +++ b/sdk/identity/azure-identity/setup.py @@ -69,6 +69,6 @@ "azure", ] ), - install_requires=["azure-core<2.0.0,>=1.0.0b1", "cryptography>=2.1.4", "msal~=0.4.1"], + install_requires=["azure-core<2.0.0,>=1.0.0b1", "cryptography>=2.1.4", "msal~=0.4.1", "six>=1.6"], extras_require={":python_version<'3.0'": ["azure-nspkg"], ":python_version<'3.5'": ["typing"]}, )