Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reimplement AadClient without msal.oauth2cli #11466

Merged
merged 10 commits into from
May 29, 2020
4 changes: 4 additions & 0 deletions sdk/identity/azure-identity/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
# Release History

## 1.4.0b4 (Unreleased)
- `azure.identity.aio.AuthorizationCodeCredential.get_token()` no longer accepts
optional keyword arguments `executor` or `loop`. Prior versions of the method
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

executor & loop were already in 1.3.1?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, they've been around since 1.0.0b4.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So do we really want 1.4.0 to break 1.3.1?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As I wrote at the top of this PR, these arguments never worked. Trying to use them just raises exceptions.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So the behavior is
Raise exception -> silently ignored?

We should add it into Breaking Change section

didn't use these correctly, provoking exceptions, and internal changes in this
version have made them obsolete.
- `InteractiveBrowserCredential` raises `CredentialUnavailableError` when it
can't start an HTTP server on `localhost`.
([#11665](https://github.com/Azure/azure-sdk-for-python/pull/11665))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

if TYPE_CHECKING:
# pylint:disable=unused-import,ungrouped-imports
from typing import Any, Iterable, Optional
from typing import Any, Optional, Sequence
from azure.core.credentials import AccessToken


Expand Down Expand Up @@ -59,7 +59,7 @@ def get_token(self, *scopes, **kwargs):

if self._authorization_code:
token = self._client.obtain_token_by_authorization_code(
code=self._authorization_code, redirect_uri=self._redirect_uri, scopes=scopes, **kwargs
scopes=scopes, code=self._authorization_code, redirect_uri=self._redirect_uri, **kwargs
xiangyan99 marked this conversation as resolved.
Show resolved Hide resolved
)
self._authorization_code = None # auth codes are single-use
return token
Expand All @@ -73,9 +73,11 @@ def get_token(self, *scopes, **kwargs):
return token

def _redeem_refresh_token(self, scopes, **kwargs):
# type: (Iterable[str], **Any) -> Optional[AccessToken]
# type: (Sequence[str], **Any) -> Optional[AccessToken]
for refresh_token in self._client.get_cached_refresh_tokens(scopes):
token = self._client.obtain_token_by_refresh_token(refresh_token, scopes, **kwargs)
if "secret" not in refresh_token:
continue
token = self._client.obtain_token_by_refresh_token(scopes, refresh_token["secret"], **kwargs)
xiangyan99 marked this conversation as resolved.
Show resolved Hide resolved
if token:
return token
return None
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument

# try each refresh token, returning the first access token acquired
for refresh_token in self._get_refresh_tokens(account):
token = self._client.obtain_token_by_refresh_token(refresh_token, scopes)
token = self._client.obtain_token_by_refresh_token(scopes, refresh_token)
return token
xiangyan99 marked this conversation as resolved.
Show resolved Hide resolved

raise CredentialUnavailableError(message=NO_TOKEN.format(account.get("username")))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,5 +56,5 @@ def get_token(self, *scopes, **kwargs):
if not self._refresh_token:
raise CredentialUnavailableError(message="No Azure user is logged in to Visual Studio Code.")

token = self._client.obtain_token_by_refresh_token(self._refresh_token, scopes, **kwargs)
token = self._client.obtain_token_by_refresh_token(scopes, self._refresh_token, **kwargs)
return token
75 changes: 62 additions & 13 deletions sdk/identity/azure-identity/azure/identity/_internal/aad_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,78 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
"""A thin wrapper around MSAL's token cache and OAuth 2 client"""

import time
from typing import TYPE_CHECKING

from azure.core.credentials import AccessToken
from azure.core.configuration import Configuration
from azure.core.pipeline import Pipeline
from azure.core.pipeline.policies import (
NetworkTraceLoggingPolicy,
RetryPolicy,
ProxyPolicy,
UserAgentPolicy,
ContentDecodePolicy,
DistributedTracingPolicy,
HttpLoggingPolicy,
)

from .aad_client_base import AadClientBase
from .msal_transport_adapter import MsalTransportAdapter
from .exception_wrapper import wrap_exceptions
from .user_agent import USER_AGENT

if TYPE_CHECKING:
# pylint:disable=unused-import,ungrouped-imports
from typing import Any, Callable, Iterable
from typing import Any, List, Optional, Sequence, Union
from azure.core.credentials import AccessToken
from azure.core.pipeline.policies import HTTPPolicy, SansIOHTTPPolicy
from azure.core.pipeline.transport import HttpTransport

Policy = Union[HTTPPolicy, SansIOHTTPPolicy]


class AadClient(AadClientBase):
def _get_client_session(self, **kwargs):
return MsalTransportAdapter(**kwargs)
def obtain_token_by_authorization_code(self, scopes, code, redirect_uri, client_secret=None, **kwargs):
# type: (str, str, Sequence[str], Optional[str], **Any) -> AccessToken
request = self._get_auth_code_request(
scopes=scopes, code=code, redirect_uri=redirect_uri, client_secret=client_secret
)
now = int(time.time())
response = self._pipeline.run(request, stream=False, **kwargs)
content = ContentDecodePolicy.deserialize_from_http_generics(response.http_response)
return self._process_response(response=content, scopes=scopes, now=now)

@wrap_exceptions
def _obtain_token(self, scopes, fn, **kwargs): # pylint:disable=unused-argument
# type: (Iterable[str], Callable, **Any) -> AccessToken
def obtain_token_by_refresh_token(self, scopes, refresh_token, **kwargs):
# type: (str, Sequence[str], **Any) -> AccessToken
request = self._get_refresh_token_request(scopes, refresh_token)
now = int(time.time())
response = fn()
return self._process_response(response=response, scopes=scopes, now=now)
response = self._pipeline.run(request, stream=False, **kwargs)
xiangyan99 marked this conversation as resolved.
Show resolved Hide resolved
content = ContentDecodePolicy.deserialize_from_http_generics(response.http_response)
return self._process_response(response=content, scopes=scopes, now=now)

# pylint:disable=no-self-use
def _build_pipeline(self, config=None, policies=None, transport=None, **kwargs):
# type: (Optional[Configuration], Optional[List[Policy]], Optional[HttpTransport], **Any) -> Pipeline
config = config or _create_config(**kwargs)
policies = policies or [
config.user_agent_policy,
config.proxy_policy,
config.retry_policy,
config.logging_policy,
DistributedTracingPolicy(**kwargs),
HttpLoggingPolicy(**kwargs),
]
if not transport:
xiangyan99 marked this conversation as resolved.
Show resolved Hide resolved
from azure.core.pipeline.transport import RequestsTransport

transport = RequestsTransport(**kwargs)

return Pipeline(transport=transport, policies=policies)


def _create_config(**kwargs):
# type: (**Any) -> Configuration
config = Configuration(**kwargs)
config.logging_policy = NetworkTraceLoggingPolicy(**kwargs)
config.retry_policy = RetryPolicy(**kwargs)
config.proxy_policy = ProxyPolicy(**kwargs)
config.user_agent_policy = UserAgentPolicy(base_user_agent=USER_AGENT, **kwargs)
return config
Original file line number Diff line number Diff line change
Expand Up @@ -4,50 +4,48 @@
# ------------------------------------
import abc
import copy
import functools
import time

try:
from typing import TYPE_CHECKING
except ImportError:
TYPE_CHECKING = False

from msal import TokenCache
from msal.oauth2cli.oauth2 import Client

from azure.core.pipeline.transport import HttpRequest
from azure.core.credentials import AccessToken
from azure.core.exceptions import ClientAuthenticationError
from . import get_default_authority, normalize_authority

try:
from typing import TYPE_CHECKING
except ImportError:
TYPE_CHECKING = False

try:
ABC = abc.ABC
except AttributeError: # Python 2.7, abc exists, but not ABC
ABC = abc.ABCMeta("ABC", (object,), {"__slots__": ()}) # type: ignore

if TYPE_CHECKING:
# pylint:disable=unused-import,ungrouped-imports
from typing import Any, Callable, Iterable, Optional
from typing import Any, Optional, Sequence, Union
from azure.core.pipeline import AsyncPipeline, Pipeline
from azure.core.pipeline.policies import AsyncHTTPPolicy, HTTPPolicy, SansIOHTTPPolicy
from azure.core.pipeline.transport import AsyncHttpTransport, HttpTransport

PipelineType = Union[AsyncPipeline, Pipeline]
PolicyType = Union[AsyncHTTPPolicy, HTTPPolicy, SansIOHTTPPolicy]
TransportType = Union[AsyncHttpTransport, HttpTransport]

class AadClientBase(ABC):
"""Sans I/O methods for AAD clients wrapping MSAL's OAuth client"""

def __init__(self, tenant_id, client_id, cache=None, **kwargs):
# type: (str, str, Optional[TokenCache], **Any) -> None
authority = kwargs.pop("authority", None)
class AadClientBase(ABC):
def __init__(self, tenant_id, client_id, authority=None, cache=None, **kwargs):
# type: (str, str, Optional[str], Optional[TokenCache], **Any) -> None
authority = normalize_authority(authority) if authority else get_default_authority()

token_endpoint = "/".join((authority, tenant_id, "oauth2/v2.0/token"))
config = {"token_endpoint": token_endpoint}

self._token_endpoint = "/".join((authority, tenant_id, "oauth2/v2.0/token"))
self._cache = cache or TokenCache()

self._client = Client(server_configuration=config, client_id=client_id)
self._client.session.close()
self._client.session = self._get_client_session(**kwargs)
self._client_id = client_id
self._pipeline = self._build_pipeline(**kwargs)

def get_cached_access_token(self, scopes):
# type: (Iterable[str]) -> Optional[AccessToken]
# type: (Sequence[str]) -> Optional[AccessToken]
tokens = self._cache.find(TokenCache.CredentialType.ACCESS_TOKEN, target=list(scopes))
for token in tokens:
expires_on = int(token["expires_on"])
Expand All @@ -56,35 +54,30 @@ def get_cached_access_token(self, scopes):
return None

def get_cached_refresh_tokens(self, scopes):
# type: (Sequence[str]) -> Sequence[dict]
"""Assumes all cached refresh tokens belong to the same user"""
return self._cache.find(TokenCache.CredentialType.REFRESH_TOKEN, target=list(scopes))

def obtain_token_by_authorization_code(self, code, redirect_uri, scopes, **kwargs):
xiangyan99 marked this conversation as resolved.
Show resolved Hide resolved
# type: (str, str, Iterable[str], **Any) -> AccessToken
fn = functools.partial(
self._client.obtain_token_by_authorization_code, code=code, redirect_uri=redirect_uri, **kwargs
)
return self._obtain_token(scopes, fn, **kwargs)

def obtain_token_by_refresh_token(self, refresh_token, scopes, **kwargs):
# type: (str, Iterable[str], **Any) -> AccessToken
fn = functools.partial(
self._client.obtain_token_by_refresh_token,
token_item=refresh_token,
scope=scopes,
rt_getter=lambda token: token["secret"],
**kwargs
)
return self._obtain_token(scopes, fn, **kwargs)
@abc.abstractmethod
def obtain_token_by_authorization_code(self, scopes, code, redirect_uri, client_secret=None, **kwargs):
pass

@abc.abstractmethod
def obtain_token_by_refresh_token(self, scopes, refresh_token, **kwargs):
pass

@abc.abstractmethod
def _build_pipeline(self, config=None, policies=None, transport=None, **kwargs):
pass

def _process_response(self, response, scopes, now):
# type: (dict, Iterable[str], int) -> AccessToken
# type: (dict, Sequence[str], int) -> AccessToken
_raise_for_error(response)

# TokenCache.add mutates the response. In particular, it removes tokens.
response_copy = copy.deepcopy(response)

self._cache.add(event={"response": response, "scope": scopes}, now=now)
self._cache.add(event={"response": response, "scope": scopes, "client_id": self._client_id}, now=now)
if "expires_on" in response_copy:
expires_on = int(response_copy["expires_on"])
elif "expires_in" in response_copy:
Expand All @@ -96,17 +89,41 @@ def _process_response(self, response, scopes, now):
)
return AccessToken(response_copy["access_token"], expires_on)

@abc.abstractmethod
def _get_client_session(self, **kwargs):
pass

@abc.abstractmethod
def _obtain_token(self, scopes, fn, **kwargs):
# type: (Iterable[str], Callable, **Any) -> AccessToken
pass
def _get_auth_code_request(self, scopes, code, redirect_uri, client_secret=None):
# type: (str, str, Sequence[str], Optional[str]) -> HttpRequest

data = {
"client_id": self._client_id,
"code": code,
"grant_type": "authorization_code",
"redirect_uri": redirect_uri,
"scope": " ".join(scopes),
}
if client_secret:
data["client_secret"] = client_secret

request = HttpRequest(
"POST", self._token_endpoint, headers={"Content-Type": "application/x-www-form-urlencoded"}, data=data
)
return request

def _get_refresh_token_request(self, scopes, refresh_token):
# type: (str, Sequence[str]) -> HttpRequest

data = {
xiangyan99 marked this conversation as resolved.
Show resolved Hide resolved
"grant_type": "refresh_token",
"refresh_token": refresh_token,
"scope": " ".join(scopes),
"client_id": self._client_id,
}
request = HttpRequest(
"POST", self._token_endpoint, headers={"Content-Type": "application/x-www-form-urlencoded"}, data=data
)
return request


def _scrub_secrets(response):
# type: (dict) -> None
for secret in ("access_token", "refresh_token"):
if secret in response:
response[secret] = "***"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@
if TYPE_CHECKING:
# pylint:disable=unused-import,ungrouped-imports
from typing import Any, Iterable, List, Mapping, Optional
import msal_extensions
from azure.core.credentials import AccessToken
from .._internal import AadClientBase

CacheItem = Mapping[str, str]
Expand Down Expand Up @@ -182,9 +180,10 @@ def _get_account(self, username=None, tenant_id=None):
raise CredentialUnavailableError(message=message)

def _get_refresh_tokens(self, account):
return self._cache.find(
cache_entries = self._cache.find(
TokenCache.CredentialType.REFRESH_TOKEN, query={"home_account_id": account.get("home_account_id")}
)
return (token["secret"] for token in cache_entries if "secret" in token)
xiangyan99 marked this conversation as resolved.
Show resolved Hide resolved

@staticmethod
def supported():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
import asyncio
from typing import TYPE_CHECKING

from azure.core.exceptions import ClientAuthenticationError
Expand All @@ -11,7 +10,7 @@

if TYPE_CHECKING:
# pylint:disable=unused-import,ungrouped-imports
from typing import Any, Iterable, Optional
from typing import Any, Optional, Sequence
from azure.core.credentials import AccessToken


Expand Down Expand Up @@ -66,18 +65,15 @@ async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken":
:raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The error's ``message``
attribute gives a reason. Any error response from Azure Active Directory is available as the error's
``response`` attribute.
:keyword ~concurrent.futures.Executor executor: An Executor instance used to execute asynchronous calls
:keyword loop: An event loop on which to schedule network I/O. If not provided, the currently running
loop will be used.
"""
if not scopes:
raise ValueError("'get_token' requires at least one scope")

if self._authorization_code:
loop = kwargs.pop("loop", None) or asyncio.get_event_loop()
token = await self._client.obtain_token_by_authorization_code(
code=self._authorization_code, redirect_uri=self._redirect_uri, scopes=scopes, loop=loop, **kwargs
scopes=scopes, code=self._authorization_code, redirect_uri=self._redirect_uri, **kwargs
)

self._authorization_code = None # auth codes are single-use
return token

Expand All @@ -92,10 +88,11 @@ async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken":

return token

async def _redeem_refresh_token(self, scopes: "Iterable[str]", **kwargs: "Any") -> "Optional[AccessToken]":
loop = kwargs.pop("loop", None) or asyncio.get_event_loop()
async def _redeem_refresh_token(self, scopes: "Sequence[str]", **kwargs: "Any") -> "Optional[AccessToken]":
for refresh_token in self._client.get_cached_refresh_tokens(scopes):
token = await self._client.obtain_token_by_refresh_token(refresh_token, scopes, loop=loop, **kwargs)
if "secret" not in refresh_token:
continue
token = await self._client.obtain_token_by_refresh_token(scopes, refresh_token["secret"], **kwargs)
if token:
return token
return None
Loading