diff --git a/sdk/identity/azure-identity/azure/identity/__init__.py b/sdk/identity/azure-identity/azure/identity/__init__.py index 1df14dd617f5..5a9ccc7aad36 100644 --- a/sdk/identity/azure-identity/azure/identity/__init__.py +++ b/sdk/identity/azure-identity/azure/identity/__init__.py @@ -7,6 +7,7 @@ CertificateCredential, ChainedTokenCredential, ClientSecretCredential, + DeviceCodeCredential, EnvironmentCredential, ManagedIdentityCredential, UsernamePasswordCredential, @@ -35,6 +36,7 @@ def __init__(self, **kwargs): "ChainedTokenCredential", "ClientSecretCredential", "DefaultAzureCredential", + "DeviceCodeCredential", "EnvironmentCredential", "InteractiveBrowserCredential", "ManagedIdentityCredential", diff --git a/sdk/identity/azure-identity/azure/identity/_internal/__init__.py b/sdk/identity/azure-identity/azure/identity/_internal/__init__.py index e7583139395b..dccfb9cf7e42 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/__init__.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/__init__.py @@ -3,5 +3,6 @@ # Licensed under the MIT License. # ------------------------------------ from .auth_code_redirect_handler import AuthCodeRedirectServer +from .exception_wrapper import wrap_exceptions from .msal_credentials import ConfidentialClientCredential, PublicClientCredential from .msal_transport_adapter import MsalTransportAdapter, MsalTransportResponse diff --git a/sdk/identity/azure-identity/azure/identity/_internal/exception_wrapper.py b/sdk/identity/azure-identity/azure/identity/_internal/exception_wrapper.py new file mode 100644 index 000000000000..2eca1e0a25ad --- /dev/null +++ b/sdk/identity/azure-identity/azure/identity/_internal/exception_wrapper.py @@ -0,0 +1,25 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ +import functools + +from six import raise_from + +from azure.core.exceptions import ClientAuthenticationError + + +def wrap_exceptions(fn): + """Prevents leaking exceptions defined outside azure-core by raising ClientAuthenticationError from them.""" + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + try: + return fn(*args, **kwargs) + except ClientAuthenticationError: + raise + except Exception as ex: + auth_error = ClientAuthenticationError(message="Authentication failed: {}".format(ex)) + raise_from(auth_error, ex) + + return wrapper diff --git a/sdk/identity/azure-identity/azure/identity/_internal/msal_credentials.py b/sdk/identity/azure-identity/azure/identity/_internal/msal_credentials.py index 83906bf71c2e..165d7328d800 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/msal_credentials.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/msal_credentials.py @@ -12,6 +12,7 @@ from azure.core.credentials import AccessToken from azure.core.exceptions import ClientAuthenticationError +from .exception_wrapper import wrap_exceptions from .msal_transport_adapter import MsalTransportAdapter try: @@ -75,6 +76,7 @@ def _create_app(self, cls): class ConfidentialClientCredential(MsalCredential): """Wraps an MSAL ConfidentialClientApplication with the TokenCredential API""" + @wrap_exceptions def get_token(self, *scopes): # type: (str) -> AccessToken diff --git a/sdk/identity/azure-identity/azure/identity/browser_auth.py b/sdk/identity/azure-identity/azure/identity/browser_auth.py index 82460513092d..921d70e635f2 100644 --- a/sdk/identity/azure-identity/azure/identity/browser_auth.py +++ b/sdk/identity/azure-identity/azure/identity/browser_auth.py @@ -18,7 +18,7 @@ from azure.core.credentials import AccessToken from azure.core.exceptions import ClientAuthenticationError -from ._internal import AuthCodeRedirectServer, ConfidentialClientCredential +from ._internal import AuthCodeRedirectServer, ConfidentialClientCredential, wrap_exceptions class InteractiveBrowserCredential(ConfidentialClientCredential): @@ -48,6 +48,7 @@ def __init__(self, client_id, client_secret, **kwargs): client_id=client_id, client_credential=client_secret, authority=authority, **kwargs ) + @wrap_exceptions def get_token(self, *scopes): # type: (str) -> AccessToken """ diff --git a/sdk/identity/azure-identity/azure/identity/credentials.py b/sdk/identity/azure-identity/azure/identity/credentials.py index 2e09a306aa71..5c96c0bf4dbc 100644 --- a/sdk/identity/azure-identity/azure/identity/credentials.py +++ b/sdk/identity/azure-identity/azure/identity/credentials.py @@ -15,7 +15,7 @@ from ._authn_client import AuthnClient from ._base import ClientSecretCredentialBase, CertificateCredentialBase -from ._internal import PublicClientCredential +from ._internal import PublicClientCredential, wrap_exceptions from ._managed_identity import ImdsCredential, MsiCredential from .constants import Endpoints, EnvironmentVariables @@ -26,8 +26,9 @@ if TYPE_CHECKING: # pylint:disable=unused-import - from typing import Any, Dict, Mapping, Optional, Union + from typing import Any, Callable, Dict, Mapping, Optional, Union from azure.core.credentials import TokenCredential + EnvironmentCredentialTypes = Union["CertificateCredential", "ClientSecretCredential", "UsernamePasswordCredential"] # pylint:disable=too-few-public-methods @@ -249,6 +250,86 @@ def _get_error_message(history): return "No valid token received. {}".format(". ".join(attempts)) +class DeviceCodeCredential(PublicClientCredential): + """ + Authenticates users through the device code flow. When ``get_token`` is called, this credential acquires a + verification URL and code from Azure Active Directory. A user must browse to the URL, enter the code, and + authenticate with Directory. If the user authenticates successfully, the credential receives an access token. + + This credential doesn't cache tokens--each ``get_token`` call begins a new authentication flow. + + For more information about the device code flow, see Azure Active Directory documentation: + https://docs.microsoft.com/en-us/azure/active-directory/develop/v2-oauth2-device-code + + :param str client_id: the application's ID + :param prompt_callback: (optional) A callback enabling control of how authentication instructions are presented. + If not provided, the credential will print instructions to stdout. + :type prompt_callback: A callable accepting arguments (``verification_uri``, ``user_code``, ``expires_in``): + - ``verification_uri`` (str) the URL the user must visit + - ``user_code`` (str) the code the user must enter there + - ``expires_in`` (int) the number of seconds the code will be valid + + **Keyword arguments:** + + - *tenant (str)* - tenant ID or a domain associated with a tenant. If not provided, the credential defaults to the + 'organizations' tenant, which supports only Azure Active Directory work or school accounts. + + - *timeout (int)* - seconds to wait for the user to authenticate. Defaults to the validity period of the device code + as set by Azure Active Directory, which also prevails when ``timeout`` is longer. + + """ + + def __init__(self, client_id, prompt_callback=None, **kwargs): + # type: (str, Optional[Callable[[str, str], None]], Any) -> None + self._timeout = kwargs.pop("timeout", None) # type: Optional[int] + self._prompt_callback = prompt_callback + super(DeviceCodeCredential, self).__init__(client_id=client_id, **kwargs) + + @wrap_exceptions + def get_token(self, *scopes): + # type (*str) -> AccessToken + """ + Request an access token for `scopes`. This credential won't cache the token. Each call begins a new + authentication flow. + + :param str scopes: desired scopes for the token + :rtype: :class:`azure.core.credentials.AccessToken` + :raises: :class:`azure.core.exceptions.ClientAuthenticationError` + """ + + # MSAL requires scopes be a list + scopes = list(scopes) # type: ignore + now = int(time.time()) + + app = self._get_app() + flow = app.initiate_device_flow(scopes) + if "error" in flow: + raise ClientAuthenticationError( + message="Couldn't begin authentication: {}".format(flow.get("error_description") or flow.get("error")) + ) + + if self._prompt_callback: + self._prompt_callback(flow["verification_uri"], flow["user_code"], flow["expires_in"]) + else: + print(flow["message"]) + + if self._timeout is not None and self._timeout < flow["expires_in"]: + deadline = now + self._timeout + result = app.acquire_token_by_device_flow(flow, exit_condition=lambda flow: time.time() > deadline) + else: + result = app.acquire_token_by_device_flow(flow) + + if "access_token" not in result: + if result.get("error") == "authorization_pending": + message = "Timed out waiting for user to authenticate" + else: + message = "Authentication failed: {}".format(result.get("error_description") or result.get("error")) + raise ClientAuthenticationError(message=message) + + token = AccessToken(result["access_token"], now + int(result["expires_in"])) + return token + + class UsernamePasswordCredential(PublicClientCredential): """ Authenticates a user with a username and password. In general, Microsoft doesn't recommend this kind of @@ -267,8 +348,9 @@ class UsernamePasswordCredential(PublicClientCredential): **Keyword arguments:** - *tenant (str)* - a tenant ID or a domain associated with a tenant. If not provided, the credential defaults to the - 'organizations' tenant. + - **tenant (str)** - a tenant ID or a domain associated with a tenant. If not provided, defaults to the + 'organizations' tenant. + """ def __init__(self, client_id, username, password, **kwargs): @@ -277,6 +359,7 @@ def __init__(self, client_id, username, password, **kwargs): self._username = username self._password = password + @wrap_exceptions def get_token(self, *scopes): # type (*str) -> AccessToken """ diff --git a/sdk/identity/azure-identity/setup.py b/sdk/identity/azure-identity/setup.py index ac671fd6e283..783ef669f72a 100644 --- a/sdk/identity/azure-identity/setup.py +++ b/sdk/identity/azure-identity/setup.py @@ -69,6 +69,6 @@ "azure", ] ), - install_requires=["azure-core<2.0.0,>=1.0.0b1", "cryptography>=2.1.4", "msal~=0.4.1"], + install_requires=["azure-core<2.0.0,>=1.0.0b1", "cryptography>=2.1.4", "msal~=0.4.1", "six>=1.6"], extras_require={":python_version<'3.0'": ["azure-nspkg"], ":python_version<'3.5'": ["typing"]}, ) diff --git a/sdk/identity/azure-identity/tests/test_identity.py b/sdk/identity/azure-identity/tests/test_identity.py index 3aaa54a1f28f..e3cfc64544c0 100644 --- a/sdk/identity/azure-identity/tests/test_identity.py +++ b/sdk/identity/azure-identity/tests/test_identity.py @@ -13,20 +13,21 @@ except ImportError: # python < 3.3 from mock import Mock, patch -import pytest from azure.core.credentials import AccessToken from azure.core.exceptions import ClientAuthenticationError from azure.identity import ( + ChainedTokenCredential, ClientSecretCredential, DefaultAzureCredential, + DeviceCodeCredential, EnvironmentCredential, ManagedIdentityCredential, - ChainedTokenCredential, InteractiveBrowserCredential, UsernamePasswordCredential, ) from azure.identity._managed_identity import ImdsCredential from azure.identity.constants import EnvironmentVariables +import pytest from helpers import mock_response, Request, validating_transport @@ -123,11 +124,6 @@ def test_client_secret_environment_credential(monkeypatch): assert token.token == access_token -def test_environment_credential_error(): - with pytest.raises(ClientAuthenticationError): - EnvironmentCredential().get_token("scope") - - def test_credential_chain_error_message(): def raise_authn_error(message): raise ClientAuthenticationError(message) @@ -244,6 +240,65 @@ def test_default_credential(): DefaultAzureCredential() +def test_device_code_credential(): + expected_token = "access-token" + user_code = "user-code" + verification_uri = "verification-uri" + expires_in = 42 + + transport = validating_transport( + requests=[Request()] * 3, # not validating requests because they're formed by MSAL + responses=[ + # expected requests: discover tenant, start device code flow, poll for completion + mock_response(json_payload={"authorization_endpoint": "https://a/b", "token_endpoint": "https://a/b"}), + mock_response( + json_payload={"device_code": "_", "user_code": user_code, "verification_uri": verification_uri, "expires_in": expires_in} + ), + mock_response( + json_payload={ + "access_token": expected_token, + "expires_in": expires_in, + "scope": "scope", + "token_type": "Bearer", + "refresh_token": "_", + } + ), + ], + ) + + callback = Mock() + credential = DeviceCodeCredential( + client_id="_", prompt_callback=callback, transport=transport, instance_discovery=False + ) + + token = credential.get_token("scope") + assert token.token == expected_token + + # prompt_callback should have been called as documented + assert callback.call_count == 1 + assert callback.call_args[0] == (verification_uri, user_code, expires_in) + + +def test_device_code_credential_timeout(): + transport = validating_transport( + requests=[Request()] * 3, # not validating requests because they're formed by MSAL + responses=[ + # expected requests: discover tenant, start device code flow, poll for completion + mock_response(json_payload={"authorization_endpoint": "https://a/b", "token_endpoint": "https://a/b"}), + mock_response(json_payload={"device_code": "_", "user_code": "_", "verification_uri": "_"}), + mock_response(json_payload={"error": "authorization_pending"}), + ], + ) + + credential = DeviceCodeCredential( + client_id="_", prompt_callback=Mock(), transport=transport, timeout=0.1, instance_discovery=False + ) + + with pytest.raises(ClientAuthenticationError) as ex: + credential.get_token("scope") + assert "timed out" in ex.value.message.lower() + + @patch("azure.identity.browser_auth.webbrowser.open", lambda _: None) # prevent the credential opening a browser def test_interactive_credential(): oauth_state = "state" diff --git a/sdk/identity/azure-identity/tests/test_identity_async.py b/sdk/identity/azure-identity/tests/test_identity_async.py index 03b53db9eccd..06254dc898ba 100644 --- a/sdk/identity/azure-identity/tests/test_identity_async.py +++ b/sdk/identity/azure-identity/tests/test_identity_async.py @@ -121,12 +121,6 @@ async def test_client_secret_environment_credential(monkeypatch): assert token.token == access_token -@pytest.mark.asyncio -async def test_environment_credential_error(): - with pytest.raises(ClientAuthenticationError): - await EnvironmentCredential().get_token("scope") - - @pytest.mark.asyncio async def test_credential_chain_error_message(): def raise_authn_error(message):