Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement synchronous device code credential #6464

Merged
merged 8 commits into from
Aug 1, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions sdk/identity/azure-identity/azure/identity/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
CertificateCredential,
ChainedTokenCredential,
ClientSecretCredential,
DeviceCodeCredential,
EnvironmentCredential,
ManagedIdentityCredential,
UsernamePasswordCredential,
Expand Down Expand Up @@ -35,6 +36,7 @@ def __init__(self, **kwargs):
"ChainedTokenCredential",
"ClientSecretCredential",
"DefaultAzureCredential",
"DeviceCodeCredential",
"EnvironmentCredential",
"InteractiveBrowserCredential",
"ManagedIdentityCredential",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,6 @@
# Licensed under the MIT License.
# ------------------------------------
from .auth_code_redirect_handler import AuthCodeRedirectServer
from .exception_wrapper import wrap_exceptions
from .msal_credentials import ConfidentialClientCredential, PublicClientCredential
from .msal_transport_adapter import MsalTransportAdapter, MsalTransportResponse
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
import functools

from six import raise_from

from azure.core.exceptions import ClientAuthenticationError


def wrap_exceptions(fn):
"""Prevents leaking exceptions defined outside azure-core by raising ClientAuthenticationError from them."""

@functools.wraps(fn)
def wrapper(*args, **kwargs):
try:
return fn(*args, **kwargs)
except ClientAuthenticationError:
raise
except Exception as ex:
auth_error = ClientAuthenticationError(message="Authentication failed: {}".format(ex))
raise_from(auth_error, ex)

return wrapper
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from azure.core.credentials import AccessToken
from azure.core.exceptions import ClientAuthenticationError

from .exception_wrapper import wrap_exceptions
from .msal_transport_adapter import MsalTransportAdapter

try:
Expand Down Expand Up @@ -75,6 +76,7 @@ def _create_app(self, cls):
class ConfidentialClientCredential(MsalCredential):
"""Wraps an MSAL ConfidentialClientApplication with the TokenCredential API"""

@wrap_exceptions
def get_token(self, *scopes):
# type: (str) -> AccessToken

Expand Down
3 changes: 2 additions & 1 deletion sdk/identity/azure-identity/azure/identity/browser_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from azure.core.credentials import AccessToken
from azure.core.exceptions import ClientAuthenticationError

from ._internal import AuthCodeRedirectServer, ConfidentialClientCredential
from ._internal import AuthCodeRedirectServer, ConfidentialClientCredential, wrap_exceptions


class InteractiveBrowserCredential(ConfidentialClientCredential):
Expand Down Expand Up @@ -48,6 +48,7 @@ def __init__(self, client_id, client_secret, **kwargs):
client_id=client_id, client_credential=client_secret, authority=authority, **kwargs
)

@wrap_exceptions
def get_token(self, *scopes):
# type: (str) -> AccessToken
"""
Expand Down
91 changes: 87 additions & 4 deletions sdk/identity/azure-identity/azure/identity/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from ._authn_client import AuthnClient
from ._base import ClientSecretCredentialBase, CertificateCredentialBase
from ._internal import PublicClientCredential
from ._internal import PublicClientCredential, wrap_exceptions
from ._managed_identity import ImdsCredential, MsiCredential
from .constants import Endpoints, EnvironmentVariables

Expand All @@ -26,8 +26,9 @@

if TYPE_CHECKING:
# pylint:disable=unused-import
from typing import Any, Dict, Mapping, Optional, Union
from typing import Any, Callable, Dict, Mapping, Optional, Union
from azure.core.credentials import TokenCredential

EnvironmentCredentialTypes = Union["CertificateCredential", "ClientSecretCredential", "UsernamePasswordCredential"]

# pylint:disable=too-few-public-methods
Expand Down Expand Up @@ -249,6 +250,86 @@ def _get_error_message(history):
return "No valid token received. {}".format(". ".join(attempts))


class DeviceCodeCredential(PublicClientCredential):
"""
Authenticates users through the device code flow. When ``get_token`` is called, this credential acquires a
verification URL and code from Azure Active Directory. A user must browse to the URL, enter the code, and
authenticate with Directory. If the user authenticates successfully, the credential receives an access token.
This credential doesn't cache tokens--each ``get_token`` call begins a new authentication flow.
For more information about the device code flow, see Azure Active Directory documentation:
https://docs.microsoft.com/en-us/azure/active-directory/develop/v2-oauth2-device-code
:param str client_id: the application's ID
:param prompt_callback: (optional) A callback enabling control of how authentication instructions are presented.
If not provided, the credential will print instructions to stdout.
:type prompt_callback: A callable accepting arguments (``verification_uri``, ``user_code``, ``expires_in``):
- ``verification_uri`` (str) the URL the user must visit
- ``user_code`` (str) the code the user must enter there
- ``expires_in`` (int) the number of seconds the code will be valid
**Keyword arguments:**
- *tenant (str)* - tenant ID or a domain associated with a tenant. If not provided, the credential defaults to the
'organizations' tenant, which supports only Azure Active Directory work or school accounts.
- *timeout (int)* - seconds to wait for the user to authenticate. Defaults to the validity period of the device code
as set by Azure Active Directory, which also prevails when ``timeout`` is longer.
"""

def __init__(self, client_id, prompt_callback=None, **kwargs):
# type: (str, Optional[Callable[[str, str], None]], Any) -> None
self._timeout = kwargs.pop("timeout", None) # type: Optional[int]
self._prompt_callback = prompt_callback
super(DeviceCodeCredential, self).__init__(client_id=client_id, **kwargs)

@wrap_exceptions
def get_token(self, *scopes):
# type (*str) -> AccessToken
"""
Request an access token for `scopes`. This credential won't cache the token. Each call begins a new
authentication flow.
:param str scopes: desired scopes for the token
:rtype: :class:`azure.core.credentials.AccessToken`
:raises: :class:`azure.core.exceptions.ClientAuthenticationError`
"""

# MSAL requires scopes be a list
scopes = list(scopes) # type: ignore
now = int(time.time())

app = self._get_app()
flow = app.initiate_device_flow(scopes)
if "error" in flow:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if "error" in flow: [](start = 8, length = 19)

Is this the best way to check for failure?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Failure of AAD, yes, I think so; flow here is the AAD response payload, which has a value for error when AAD can't start the flow.

raise ClientAuthenticationError(
message="Couldn't begin authentication: {}".format(flow.get("error_description") or flow.get("error"))
)

if self._prompt_callback:
self._prompt_callback(flow["verification_uri"], flow["user_code"], flow["expires_in"])
else:
print(flow["message"])

if self._timeout is not None and self._timeout < flow["expires_in"]:
deadline = now + self._timeout
result = app.acquire_token_by_device_flow(flow, exit_condition=lambda flow: time.time() > deadline)
else:
result = app.acquire_token_by_device_flow(flow)

if "access_token" not in result:
if result.get("error") == "authorization_pending":
message = "Timed out waiting for user to authenticate"
else:
message = "Authentication failed: {}".format(result.get("error_description") or result.get("error"))
raise ClientAuthenticationError(message=message)

token = AccessToken(result["access_token"], now + int(result["expires_in"]))
return token


class UsernamePasswordCredential(PublicClientCredential):
"""
Authenticates a user with a username and password. In general, Microsoft doesn't recommend this kind of
Expand All @@ -267,8 +348,9 @@ class UsernamePasswordCredential(PublicClientCredential):
**Keyword arguments:**
*tenant (str)* - a tenant ID or a domain associated with a tenant. If not provided, the credential defaults to the
'organizations' tenant.
- **tenant (str)** - a tenant ID or a domain associated with a tenant. If not provided, defaults to the
'organizations' tenant.
"""

def __init__(self, client_id, username, password, **kwargs):
Expand All @@ -277,6 +359,7 @@ def __init__(self, client_id, username, password, **kwargs):
self._username = username
self._password = password

@wrap_exceptions
def get_token(self, *scopes):
# type (*str) -> AccessToken
"""
Expand Down
2 changes: 1 addition & 1 deletion sdk/identity/azure-identity/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,6 @@
"azure",
]
),
install_requires=["azure-core<2.0.0,>=1.0.0b1", "cryptography>=2.1.4", "msal~=0.4.1"],
install_requires=["azure-core<2.0.0,>=1.0.0b1", "cryptography>=2.1.4", "msal~=0.4.1", "six>=1.6"],
extras_require={":python_version<'3.0'": ["azure-nspkg"], ":python_version<'3.5'": ["typing"]},
)
69 changes: 62 additions & 7 deletions sdk/identity/azure-identity/tests/test_identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,21 @@
except ImportError: # python < 3.3
from mock import Mock, patch

import pytest
from azure.core.credentials import AccessToken
from azure.core.exceptions import ClientAuthenticationError
from azure.identity import (
ChainedTokenCredential,
ClientSecretCredential,
DefaultAzureCredential,
DeviceCodeCredential,
EnvironmentCredential,
ManagedIdentityCredential,
ChainedTokenCredential,
InteractiveBrowserCredential,
UsernamePasswordCredential,
)
from azure.identity._managed_identity import ImdsCredential
from azure.identity.constants import EnvironmentVariables
import pytest

from helpers import mock_response, Request, validating_transport

Expand Down Expand Up @@ -123,11 +124,6 @@ def test_client_secret_environment_credential(monkeypatch):
assert token.token == access_token


def test_environment_credential_error():
with pytest.raises(ClientAuthenticationError):
EnvironmentCredential().get_token("scope")


def test_credential_chain_error_message():
def raise_authn_error(message):
raise ClientAuthenticationError(message)
Expand Down Expand Up @@ -244,6 +240,65 @@ def test_default_credential():
DefaultAzureCredential()


def test_device_code_credential():
expected_token = "access-token"
user_code = "user-code"
verification_uri = "verification-uri"
expires_in = 42

transport = validating_transport(
requests=[Request()] * 3, # not validating requests because they're formed by MSAL
responses=[
# expected requests: discover tenant, start device code flow, poll for completion
mock_response(json_payload={"authorization_endpoint": "https://a/b", "token_endpoint": "https://a/b"}),
mock_response(
json_payload={"device_code": "_", "user_code": user_code, "verification_uri": verification_uri, "expires_in": expires_in}
),
mock_response(
json_payload={
"access_token": expected_token,
"expires_in": expires_in,
"scope": "scope",
"token_type": "Bearer",
"refresh_token": "_",
}
),
],
)

callback = Mock()
credential = DeviceCodeCredential(
client_id="_", prompt_callback=callback, transport=transport, instance_discovery=False
)

token = credential.get_token("scope")
assert token.token == expected_token

# prompt_callback should have been called as documented
assert callback.call_count == 1
assert callback.call_args[0] == (verification_uri, user_code, expires_in)


def test_device_code_credential_timeout():
transport = validating_transport(
requests=[Request()] * 3, # not validating requests because they're formed by MSAL
responses=[
# expected requests: discover tenant, start device code flow, poll for completion
mock_response(json_payload={"authorization_endpoint": "https://a/b", "token_endpoint": "https://a/b"}),
mock_response(json_payload={"device_code": "_", "user_code": "_", "verification_uri": "_"}),
mock_response(json_payload={"error": "authorization_pending"}),
],
)

credential = DeviceCodeCredential(
client_id="_", prompt_callback=Mock(), transport=transport, timeout=0.1, instance_discovery=False
)

with pytest.raises(ClientAuthenticationError) as ex:
credential.get_token("scope")
assert "timed out" in ex.value.message.lower()


@patch("azure.identity.browser_auth.webbrowser.open", lambda _: None) # prevent the credential opening a browser
def test_interactive_credential():
oauth_state = "state"
Expand Down
6 changes: 0 additions & 6 deletions sdk/identity/azure-identity/tests/test_identity_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,12 +121,6 @@ async def test_client_secret_environment_credential(monkeypatch):
assert token.token == access_token


@pytest.mark.asyncio
async def test_environment_credential_error():
with pytest.raises(ClientAuthenticationError):
await EnvironmentCredential().get_token("scope")


@pytest.mark.asyncio
async def test_credential_chain_error_message():
def raise_authn_error(message):
Expand Down