Skip to content

Commit

Permalink
Reimplement AadClient without msal.oauth2cli (#11466)
Browse files Browse the repository at this point in the history
  • Loading branch information
chlowell authored May 29, 2020
1 parent 3a05ac7 commit ee78e5d
Show file tree
Hide file tree
Showing 18 changed files with 515 additions and 289 deletions.
4 changes: 4 additions & 0 deletions sdk/identity/azure-identity/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
# Release History

## 1.4.0b4 (Unreleased)
- `azure.identity.aio.AuthorizationCodeCredential.get_token()` no longer accepts
optional keyword arguments `executor` or `loop`. Prior versions of the method
didn't use these correctly, provoking exceptions, and internal changes in this
version have made them obsolete.
- `InteractiveBrowserCredential` raises `CredentialUnavailableError` when it
can't start an HTTP server on `localhost`.
([#11665](https://github.com/Azure/azure-sdk-for-python/pull/11665))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

if TYPE_CHECKING:
# pylint:disable=unused-import,ungrouped-imports
from typing import Any, Iterable, Optional
from typing import Any, Optional, Sequence
from azure.core.credentials import AccessToken


Expand Down Expand Up @@ -59,7 +59,7 @@ def get_token(self, *scopes, **kwargs):

if self._authorization_code:
token = self._client.obtain_token_by_authorization_code(
code=self._authorization_code, redirect_uri=self._redirect_uri, scopes=scopes, **kwargs
scopes=scopes, code=self._authorization_code, redirect_uri=self._redirect_uri, **kwargs
)
self._authorization_code = None # auth codes are single-use
return token
Expand All @@ -73,9 +73,11 @@ def get_token(self, *scopes, **kwargs):
return token

def _redeem_refresh_token(self, scopes, **kwargs):
# type: (Iterable[str], **Any) -> Optional[AccessToken]
# type: (Sequence[str], **Any) -> Optional[AccessToken]
for refresh_token in self._client.get_cached_refresh_tokens(scopes):
token = self._client.obtain_token_by_refresh_token(refresh_token, scopes, **kwargs)
if "secret" not in refresh_token:
continue
token = self._client.obtain_token_by_refresh_token(scopes, refresh_token["secret"], **kwargs)
if token:
return token
return None
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument

# try each refresh token, returning the first access token acquired
for refresh_token in self._get_refresh_tokens(account):
token = self._client.obtain_token_by_refresh_token(refresh_token, scopes)
token = self._client.obtain_token_by_refresh_token(scopes, refresh_token)
return token

raise CredentialUnavailableError(message=NO_TOKEN.format(account.get("username")))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,5 +56,5 @@ def get_token(self, *scopes, **kwargs):
if not self._refresh_token:
raise CredentialUnavailableError(message="No Azure user is logged in to Visual Studio Code.")

token = self._client.obtain_token_by_refresh_token(self._refresh_token, scopes, **kwargs)
token = self._client.obtain_token_by_refresh_token(scopes, self._refresh_token, **kwargs)
return token
75 changes: 62 additions & 13 deletions sdk/identity/azure-identity/azure/identity/_internal/aad_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,78 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
"""A thin wrapper around MSAL's token cache and OAuth 2 client"""

import time
from typing import TYPE_CHECKING

from azure.core.credentials import AccessToken
from azure.core.configuration import Configuration
from azure.core.pipeline import Pipeline
from azure.core.pipeline.policies import (
NetworkTraceLoggingPolicy,
RetryPolicy,
ProxyPolicy,
UserAgentPolicy,
ContentDecodePolicy,
DistributedTracingPolicy,
HttpLoggingPolicy,
)

from .aad_client_base import AadClientBase
from .msal_transport_adapter import MsalTransportAdapter
from .exception_wrapper import wrap_exceptions
from .user_agent import USER_AGENT

if TYPE_CHECKING:
# pylint:disable=unused-import,ungrouped-imports
from typing import Any, Callable, Iterable
from typing import Any, List, Optional, Sequence, Union
from azure.core.credentials import AccessToken
from azure.core.pipeline.policies import HTTPPolicy, SansIOHTTPPolicy
from azure.core.pipeline.transport import HttpTransport

Policy = Union[HTTPPolicy, SansIOHTTPPolicy]


class AadClient(AadClientBase):
def _get_client_session(self, **kwargs):
return MsalTransportAdapter(**kwargs)
def obtain_token_by_authorization_code(self, scopes, code, redirect_uri, client_secret=None, **kwargs):
# type: (str, str, Sequence[str], Optional[str], **Any) -> AccessToken
request = self._get_auth_code_request(
scopes=scopes, code=code, redirect_uri=redirect_uri, client_secret=client_secret
)
now = int(time.time())
response = self._pipeline.run(request, stream=False, **kwargs)
content = ContentDecodePolicy.deserialize_from_http_generics(response.http_response)
return self._process_response(response=content, scopes=scopes, now=now)

@wrap_exceptions
def _obtain_token(self, scopes, fn, **kwargs): # pylint:disable=unused-argument
# type: (Iterable[str], Callable, **Any) -> AccessToken
def obtain_token_by_refresh_token(self, scopes, refresh_token, **kwargs):
# type: (str, Sequence[str], **Any) -> AccessToken
request = self._get_refresh_token_request(scopes, refresh_token)
now = int(time.time())
response = fn()
return self._process_response(response=response, scopes=scopes, now=now)
response = self._pipeline.run(request, stream=False, **kwargs)
content = ContentDecodePolicy.deserialize_from_http_generics(response.http_response)
return self._process_response(response=content, scopes=scopes, now=now)

# pylint:disable=no-self-use
def _build_pipeline(self, config=None, policies=None, transport=None, **kwargs):
# type: (Optional[Configuration], Optional[List[Policy]], Optional[HttpTransport], **Any) -> Pipeline
config = config or _create_config(**kwargs)
policies = policies or [
config.user_agent_policy,
config.proxy_policy,
config.retry_policy,
config.logging_policy,
DistributedTracingPolicy(**kwargs),
HttpLoggingPolicy(**kwargs),
]
if not transport:
from azure.core.pipeline.transport import RequestsTransport

transport = RequestsTransport(**kwargs)

return Pipeline(transport=transport, policies=policies)


def _create_config(**kwargs):
# type: (**Any) -> Configuration
config = Configuration(**kwargs)
config.logging_policy = NetworkTraceLoggingPolicy(**kwargs)
config.retry_policy = RetryPolicy(**kwargs)
config.proxy_policy = ProxyPolicy(**kwargs)
config.user_agent_policy = UserAgentPolicy(base_user_agent=USER_AGENT, **kwargs)
return config
Original file line number Diff line number Diff line change
Expand Up @@ -4,50 +4,48 @@
# ------------------------------------
import abc
import copy
import functools
import time

try:
from typing import TYPE_CHECKING
except ImportError:
TYPE_CHECKING = False

from msal import TokenCache
from msal.oauth2cli.oauth2 import Client

from azure.core.pipeline.transport import HttpRequest
from azure.core.credentials import AccessToken
from azure.core.exceptions import ClientAuthenticationError
from . import get_default_authority, normalize_authority

try:
from typing import TYPE_CHECKING
except ImportError:
TYPE_CHECKING = False

try:
ABC = abc.ABC
except AttributeError: # Python 2.7, abc exists, but not ABC
ABC = abc.ABCMeta("ABC", (object,), {"__slots__": ()}) # type: ignore

if TYPE_CHECKING:
# pylint:disable=unused-import,ungrouped-imports
from typing import Any, Callable, Iterable, Optional
from typing import Any, Optional, Sequence, Union
from azure.core.pipeline import AsyncPipeline, Pipeline
from azure.core.pipeline.policies import AsyncHTTPPolicy, HTTPPolicy, SansIOHTTPPolicy
from azure.core.pipeline.transport import AsyncHttpTransport, HttpTransport

PipelineType = Union[AsyncPipeline, Pipeline]
PolicyType = Union[AsyncHTTPPolicy, HTTPPolicy, SansIOHTTPPolicy]
TransportType = Union[AsyncHttpTransport, HttpTransport]

class AadClientBase(ABC):
"""Sans I/O methods for AAD clients wrapping MSAL's OAuth client"""

def __init__(self, tenant_id, client_id, cache=None, **kwargs):
# type: (str, str, Optional[TokenCache], **Any) -> None
authority = kwargs.pop("authority", None)
class AadClientBase(ABC):
def __init__(self, tenant_id, client_id, authority=None, cache=None, **kwargs):
# type: (str, str, Optional[str], Optional[TokenCache], **Any) -> None
authority = normalize_authority(authority) if authority else get_default_authority()

token_endpoint = "/".join((authority, tenant_id, "oauth2/v2.0/token"))
config = {"token_endpoint": token_endpoint}

self._token_endpoint = "/".join((authority, tenant_id, "oauth2/v2.0/token"))
self._cache = cache or TokenCache()

self._client = Client(server_configuration=config, client_id=client_id)
self._client.session.close()
self._client.session = self._get_client_session(**kwargs)
self._client_id = client_id
self._pipeline = self._build_pipeline(**kwargs)

def get_cached_access_token(self, scopes):
# type: (Iterable[str]) -> Optional[AccessToken]
# type: (Sequence[str]) -> Optional[AccessToken]
tokens = self._cache.find(TokenCache.CredentialType.ACCESS_TOKEN, target=list(scopes))
for token in tokens:
expires_on = int(token["expires_on"])
Expand All @@ -56,35 +54,30 @@ def get_cached_access_token(self, scopes):
return None

def get_cached_refresh_tokens(self, scopes):
# type: (Sequence[str]) -> Sequence[dict]
"""Assumes all cached refresh tokens belong to the same user"""
return self._cache.find(TokenCache.CredentialType.REFRESH_TOKEN, target=list(scopes))

def obtain_token_by_authorization_code(self, code, redirect_uri, scopes, **kwargs):
# type: (str, str, Iterable[str], **Any) -> AccessToken
fn = functools.partial(
self._client.obtain_token_by_authorization_code, code=code, redirect_uri=redirect_uri, **kwargs
)
return self._obtain_token(scopes, fn, **kwargs)

def obtain_token_by_refresh_token(self, refresh_token, scopes, **kwargs):
# type: (str, Iterable[str], **Any) -> AccessToken
fn = functools.partial(
self._client.obtain_token_by_refresh_token,
token_item=refresh_token,
scope=scopes,
rt_getter=lambda token: token["secret"],
**kwargs
)
return self._obtain_token(scopes, fn, **kwargs)
@abc.abstractmethod
def obtain_token_by_authorization_code(self, scopes, code, redirect_uri, client_secret=None, **kwargs):
pass

@abc.abstractmethod
def obtain_token_by_refresh_token(self, scopes, refresh_token, **kwargs):
pass

@abc.abstractmethod
def _build_pipeline(self, config=None, policies=None, transport=None, **kwargs):
pass

def _process_response(self, response, scopes, now):
# type: (dict, Iterable[str], int) -> AccessToken
# type: (dict, Sequence[str], int) -> AccessToken
_raise_for_error(response)

# TokenCache.add mutates the response. In particular, it removes tokens.
response_copy = copy.deepcopy(response)

self._cache.add(event={"response": response, "scope": scopes}, now=now)
self._cache.add(event={"response": response, "scope": scopes, "client_id": self._client_id}, now=now)
if "expires_on" in response_copy:
expires_on = int(response_copy["expires_on"])
elif "expires_in" in response_copy:
Expand All @@ -96,17 +89,41 @@ def _process_response(self, response, scopes, now):
)
return AccessToken(response_copy["access_token"], expires_on)

@abc.abstractmethod
def _get_client_session(self, **kwargs):
pass

@abc.abstractmethod
def _obtain_token(self, scopes, fn, **kwargs):
# type: (Iterable[str], Callable, **Any) -> AccessToken
pass
def _get_auth_code_request(self, scopes, code, redirect_uri, client_secret=None):
# type: (str, str, Sequence[str], Optional[str]) -> HttpRequest

data = {
"client_id": self._client_id,
"code": code,
"grant_type": "authorization_code",
"redirect_uri": redirect_uri,
"scope": " ".join(scopes),
}
if client_secret:
data["client_secret"] = client_secret

request = HttpRequest(
"POST", self._token_endpoint, headers={"Content-Type": "application/x-www-form-urlencoded"}, data=data
)
return request

def _get_refresh_token_request(self, scopes, refresh_token):
# type: (str, Sequence[str]) -> HttpRequest

data = {
"grant_type": "refresh_token",
"refresh_token": refresh_token,
"scope": " ".join(scopes),
"client_id": self._client_id,
}
request = HttpRequest(
"POST", self._token_endpoint, headers={"Content-Type": "application/x-www-form-urlencoded"}, data=data
)
return request


def _scrub_secrets(response):
# type: (dict) -> None
for secret in ("access_token", "refresh_token"):
if secret in response:
response[secret] = "***"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@
if TYPE_CHECKING:
# pylint:disable=unused-import,ungrouped-imports
from typing import Any, Iterable, List, Mapping, Optional
import msal_extensions
from azure.core.credentials import AccessToken
from .._internal import AadClientBase

CacheItem = Mapping[str, str]
Expand Down Expand Up @@ -182,9 +180,10 @@ def _get_account(self, username=None, tenant_id=None):
raise CredentialUnavailableError(message=message)

def _get_refresh_tokens(self, account):
return self._cache.find(
cache_entries = self._cache.find(
TokenCache.CredentialType.REFRESH_TOKEN, query={"home_account_id": account.get("home_account_id")}
)
return (token["secret"] for token in cache_entries if "secret" in token)

@staticmethod
def supported():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
import asyncio
from typing import TYPE_CHECKING

from azure.core.exceptions import ClientAuthenticationError
Expand All @@ -11,7 +10,7 @@

if TYPE_CHECKING:
# pylint:disable=unused-import,ungrouped-imports
from typing import Any, Iterable, Optional
from typing import Any, Optional, Sequence
from azure.core.credentials import AccessToken


Expand Down Expand Up @@ -66,18 +65,15 @@ async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken":
:raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The error's ``message``
attribute gives a reason. Any error response from Azure Active Directory is available as the error's
``response`` attribute.
:keyword ~concurrent.futures.Executor executor: An Executor instance used to execute asynchronous calls
:keyword loop: An event loop on which to schedule network I/O. If not provided, the currently running
loop will be used.
"""
if not scopes:
raise ValueError("'get_token' requires at least one scope")

if self._authorization_code:
loop = kwargs.pop("loop", None) or asyncio.get_event_loop()
token = await self._client.obtain_token_by_authorization_code(
code=self._authorization_code, redirect_uri=self._redirect_uri, scopes=scopes, loop=loop, **kwargs
scopes=scopes, code=self._authorization_code, redirect_uri=self._redirect_uri, **kwargs
)

self._authorization_code = None # auth codes are single-use
return token

Expand All @@ -92,10 +88,11 @@ async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken":

return token

async def _redeem_refresh_token(self, scopes: "Iterable[str]", **kwargs: "Any") -> "Optional[AccessToken]":
loop = kwargs.pop("loop", None) or asyncio.get_event_loop()
async def _redeem_refresh_token(self, scopes: "Sequence[str]", **kwargs: "Any") -> "Optional[AccessToken]":
for refresh_token in self._client.get_cached_refresh_tokens(scopes):
token = await self._client.obtain_token_by_refresh_token(refresh_token, scopes, loop=loop, **kwargs)
if "secret" not in refresh_token:
continue
token = await self._client.obtain_token_by_refresh_token(scopes, refresh_token["secret"], **kwargs)
if token:
return token
return None
Loading

0 comments on commit ee78e5d

Please sign in to comment.