From ee78e5d29a73d821977947a3aeacaa10ee9235f3 Mon Sep 17 00:00:00 2001 From: Charles Lowell Date: Fri, 29 May 2020 16:40:53 -0700 Subject: [PATCH] Reimplement AadClient without msal.oauth2cli (#11466) --- sdk/identity/azure-identity/CHANGELOG.md | 4 + .../_credentials/authorization_code.py | 10 +- .../identity/_credentials/shared_cache.py | 2 +- .../_credentials/vscode_credential.py | 2 +- .../azure/identity/_internal/aad_client.py | 75 +++++++-- .../identity/_internal/aad_client_base.py | 115 ++++++++------ .../identity/_internal/shared_token_cache.py | 5 +- .../aio/_credentials/authorization_code.py | 17 +-- .../identity/aio/_credentials/shared_cache.py | 2 +- .../aio/_credentials/vscode_credential.py | 2 +- .../identity/aio/_internal/aad_client.py | 101 +++++++++---- .../azure-identity/tests/helpers_async.py | 2 +- .../azure-identity/tests/test_aad_client.py | 121 +++++++++------ .../tests/test_aad_client_async.py | 143 ++++++++++++++---- .../azure-identity/tests/test_auth_code.py | 67 +++++--- .../tests/test_auth_code_async.py | 126 ++++++--------- .../tests/test_vscode_credential.py | 5 +- .../tests/test_vscode_credential_async.py | 5 +- 18 files changed, 515 insertions(+), 289 deletions(-) diff --git a/sdk/identity/azure-identity/CHANGELOG.md b/sdk/identity/azure-identity/CHANGELOG.md index d2a4b255b6c9..ea5b7fca0727 100644 --- a/sdk/identity/azure-identity/CHANGELOG.md +++ b/sdk/identity/azure-identity/CHANGELOG.md @@ -1,6 +1,10 @@ # Release History ## 1.4.0b4 (Unreleased) +- `azure.identity.aio.AuthorizationCodeCredential.get_token()` no longer accepts + optional keyword arguments `executor` or `loop`. Prior versions of the method + didn't use these correctly, provoking exceptions, and internal changes in this + version have made them obsolete. - `InteractiveBrowserCredential` raises `CredentialUnavailableError` when it can't start an HTTP server on `localhost`. ([#11665](https://github.com/Azure/azure-sdk-for-python/pull/11665)) diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/authorization_code.py b/sdk/identity/azure-identity/azure/identity/_credentials/authorization_code.py index 108d0cb865ef..3568f8c921ce 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/authorization_code.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/authorization_code.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: # pylint:disable=unused-import,ungrouped-imports - from typing import Any, Iterable, Optional + from typing import Any, Optional, Sequence from azure.core.credentials import AccessToken @@ -59,7 +59,7 @@ def get_token(self, *scopes, **kwargs): if self._authorization_code: token = self._client.obtain_token_by_authorization_code( - code=self._authorization_code, redirect_uri=self._redirect_uri, scopes=scopes, **kwargs + scopes=scopes, code=self._authorization_code, redirect_uri=self._redirect_uri, **kwargs ) self._authorization_code = None # auth codes are single-use return token @@ -73,9 +73,11 @@ def get_token(self, *scopes, **kwargs): return token def _redeem_refresh_token(self, scopes, **kwargs): - # type: (Iterable[str], **Any) -> Optional[AccessToken] + # type: (Sequence[str], **Any) -> Optional[AccessToken] for refresh_token in self._client.get_cached_refresh_tokens(scopes): - token = self._client.obtain_token_by_refresh_token(refresh_token, scopes, **kwargs) + if "secret" not in refresh_token: + continue + token = self._client.obtain_token_by_refresh_token(scopes, refresh_token["secret"], **kwargs) if token: return token return None diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/shared_cache.py b/sdk/identity/azure-identity/azure/identity/_credentials/shared_cache.py index 00354d194d70..b748b700e5bf 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/shared_cache.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/shared_cache.py @@ -60,7 +60,7 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument # try each refresh token, returning the first access token acquired for refresh_token in self._get_refresh_tokens(account): - token = self._client.obtain_token_by_refresh_token(refresh_token, scopes) + token = self._client.obtain_token_by_refresh_token(scopes, refresh_token) return token raise CredentialUnavailableError(message=NO_TOKEN.format(account.get("username"))) diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/vscode_credential.py b/sdk/identity/azure-identity/azure/identity/_credentials/vscode_credential.py index 808686c4b2e5..c40636c24e96 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/vscode_credential.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/vscode_credential.py @@ -56,5 +56,5 @@ def get_token(self, *scopes, **kwargs): if not self._refresh_token: raise CredentialUnavailableError(message="No Azure user is logged in to Visual Studio Code.") - token = self._client.obtain_token_by_refresh_token(self._refresh_token, scopes, **kwargs) + token = self._client.obtain_token_by_refresh_token(scopes, self._refresh_token, **kwargs) return token diff --git a/sdk/identity/azure-identity/azure/identity/_internal/aad_client.py b/sdk/identity/azure-identity/azure/identity/_internal/aad_client.py index aa25b4a4b1a2..75de3d05cf41 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/aad_client.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/aad_client.py @@ -2,29 +2,78 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ -"""A thin wrapper around MSAL's token cache and OAuth 2 client""" - import time from typing import TYPE_CHECKING -from azure.core.credentials import AccessToken +from azure.core.configuration import Configuration +from azure.core.pipeline import Pipeline +from azure.core.pipeline.policies import ( + NetworkTraceLoggingPolicy, + RetryPolicy, + ProxyPolicy, + UserAgentPolicy, + ContentDecodePolicy, + DistributedTracingPolicy, + HttpLoggingPolicy, +) from .aad_client_base import AadClientBase -from .msal_transport_adapter import MsalTransportAdapter -from .exception_wrapper import wrap_exceptions +from .user_agent import USER_AGENT if TYPE_CHECKING: # pylint:disable=unused-import,ungrouped-imports - from typing import Any, Callable, Iterable + from typing import Any, List, Optional, Sequence, Union + from azure.core.credentials import AccessToken + from azure.core.pipeline.policies import HTTPPolicy, SansIOHTTPPolicy + from azure.core.pipeline.transport import HttpTransport + + Policy = Union[HTTPPolicy, SansIOHTTPPolicy] class AadClient(AadClientBase): - def _get_client_session(self, **kwargs): - return MsalTransportAdapter(**kwargs) + def obtain_token_by_authorization_code(self, scopes, code, redirect_uri, client_secret=None, **kwargs): + # type: (str, str, Sequence[str], Optional[str], **Any) -> AccessToken + request = self._get_auth_code_request( + scopes=scopes, code=code, redirect_uri=redirect_uri, client_secret=client_secret + ) + now = int(time.time()) + response = self._pipeline.run(request, stream=False, **kwargs) + content = ContentDecodePolicy.deserialize_from_http_generics(response.http_response) + return self._process_response(response=content, scopes=scopes, now=now) - @wrap_exceptions - def _obtain_token(self, scopes, fn, **kwargs): # pylint:disable=unused-argument - # type: (Iterable[str], Callable, **Any) -> AccessToken + def obtain_token_by_refresh_token(self, scopes, refresh_token, **kwargs): + # type: (str, Sequence[str], **Any) -> AccessToken + request = self._get_refresh_token_request(scopes, refresh_token) now = int(time.time()) - response = fn() - return self._process_response(response=response, scopes=scopes, now=now) + response = self._pipeline.run(request, stream=False, **kwargs) + content = ContentDecodePolicy.deserialize_from_http_generics(response.http_response) + return self._process_response(response=content, scopes=scopes, now=now) + + # pylint:disable=no-self-use + def _build_pipeline(self, config=None, policies=None, transport=None, **kwargs): + # type: (Optional[Configuration], Optional[List[Policy]], Optional[HttpTransport], **Any) -> Pipeline + config = config or _create_config(**kwargs) + policies = policies or [ + config.user_agent_policy, + config.proxy_policy, + config.retry_policy, + config.logging_policy, + DistributedTracingPolicy(**kwargs), + HttpLoggingPolicy(**kwargs), + ] + if not transport: + from azure.core.pipeline.transport import RequestsTransport + + transport = RequestsTransport(**kwargs) + + return Pipeline(transport=transport, policies=policies) + + +def _create_config(**kwargs): + # type: (**Any) -> Configuration + config = Configuration(**kwargs) + config.logging_policy = NetworkTraceLoggingPolicy(**kwargs) + config.retry_policy = RetryPolicy(**kwargs) + config.proxy_policy = ProxyPolicy(**kwargs) + config.user_agent_policy = UserAgentPolicy(base_user_agent=USER_AGENT, **kwargs) + return config diff --git a/sdk/identity/azure-identity/azure/identity/_internal/aad_client_base.py b/sdk/identity/azure-identity/azure/identity/_internal/aad_client_base.py index 10f2a0ef8ea8..bf659087a2cb 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/aad_client_base.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/aad_client_base.py @@ -4,21 +4,20 @@ # ------------------------------------ import abc import copy -import functools import time -try: - from typing import TYPE_CHECKING -except ImportError: - TYPE_CHECKING = False - from msal import TokenCache -from msal.oauth2cli.oauth2 import Client +from azure.core.pipeline.transport import HttpRequest from azure.core.credentials import AccessToken from azure.core.exceptions import ClientAuthenticationError from . import get_default_authority, normalize_authority +try: + from typing import TYPE_CHECKING +except ImportError: + TYPE_CHECKING = False + try: ABC = abc.ABC except AttributeError: # Python 2.7, abc exists, but not ABC @@ -26,28 +25,27 @@ if TYPE_CHECKING: # pylint:disable=unused-import,ungrouped-imports - from typing import Any, Callable, Iterable, Optional + from typing import Any, Optional, Sequence, Union + from azure.core.pipeline import AsyncPipeline, Pipeline + from azure.core.pipeline.policies import AsyncHTTPPolicy, HTTPPolicy, SansIOHTTPPolicy + from azure.core.pipeline.transport import AsyncHttpTransport, HttpTransport + PipelineType = Union[AsyncPipeline, Pipeline] + PolicyType = Union[AsyncHTTPPolicy, HTTPPolicy, SansIOHTTPPolicy] + TransportType = Union[AsyncHttpTransport, HttpTransport] -class AadClientBase(ABC): - """Sans I/O methods for AAD clients wrapping MSAL's OAuth client""" - def __init__(self, tenant_id, client_id, cache=None, **kwargs): - # type: (str, str, Optional[TokenCache], **Any) -> None - authority = kwargs.pop("authority", None) +class AadClientBase(ABC): + def __init__(self, tenant_id, client_id, authority=None, cache=None, **kwargs): + # type: (str, str, Optional[str], Optional[TokenCache], **Any) -> None authority = normalize_authority(authority) if authority else get_default_authority() - - token_endpoint = "/".join((authority, tenant_id, "oauth2/v2.0/token")) - config = {"token_endpoint": token_endpoint} - + self._token_endpoint = "/".join((authority, tenant_id, "oauth2/v2.0/token")) self._cache = cache or TokenCache() - - self._client = Client(server_configuration=config, client_id=client_id) - self._client.session.close() - self._client.session = self._get_client_session(**kwargs) + self._client_id = client_id + self._pipeline = self._build_pipeline(**kwargs) def get_cached_access_token(self, scopes): - # type: (Iterable[str]) -> Optional[AccessToken] + # type: (Sequence[str]) -> Optional[AccessToken] tokens = self._cache.find(TokenCache.CredentialType.ACCESS_TOKEN, target=list(scopes)) for token in tokens: expires_on = int(token["expires_on"]) @@ -56,35 +54,30 @@ def get_cached_access_token(self, scopes): return None def get_cached_refresh_tokens(self, scopes): + # type: (Sequence[str]) -> Sequence[dict] """Assumes all cached refresh tokens belong to the same user""" return self._cache.find(TokenCache.CredentialType.REFRESH_TOKEN, target=list(scopes)) - def obtain_token_by_authorization_code(self, code, redirect_uri, scopes, **kwargs): - # type: (str, str, Iterable[str], **Any) -> AccessToken - fn = functools.partial( - self._client.obtain_token_by_authorization_code, code=code, redirect_uri=redirect_uri, **kwargs - ) - return self._obtain_token(scopes, fn, **kwargs) - - def obtain_token_by_refresh_token(self, refresh_token, scopes, **kwargs): - # type: (str, Iterable[str], **Any) -> AccessToken - fn = functools.partial( - self._client.obtain_token_by_refresh_token, - token_item=refresh_token, - scope=scopes, - rt_getter=lambda token: token["secret"], - **kwargs - ) - return self._obtain_token(scopes, fn, **kwargs) + @abc.abstractmethod + def obtain_token_by_authorization_code(self, scopes, code, redirect_uri, client_secret=None, **kwargs): + pass + + @abc.abstractmethod + def obtain_token_by_refresh_token(self, scopes, refresh_token, **kwargs): + pass + + @abc.abstractmethod + def _build_pipeline(self, config=None, policies=None, transport=None, **kwargs): + pass def _process_response(self, response, scopes, now): - # type: (dict, Iterable[str], int) -> AccessToken + # type: (dict, Sequence[str], int) -> AccessToken _raise_for_error(response) # TokenCache.add mutates the response. In particular, it removes tokens. response_copy = copy.deepcopy(response) - self._cache.add(event={"response": response, "scope": scopes}, now=now) + self._cache.add(event={"response": response, "scope": scopes, "client_id": self._client_id}, now=now) if "expires_on" in response_copy: expires_on = int(response_copy["expires_on"]) elif "expires_in" in response_copy: @@ -96,17 +89,41 @@ def _process_response(self, response, scopes, now): ) return AccessToken(response_copy["access_token"], expires_on) - @abc.abstractmethod - def _get_client_session(self, **kwargs): - pass - - @abc.abstractmethod - def _obtain_token(self, scopes, fn, **kwargs): - # type: (Iterable[str], Callable, **Any) -> AccessToken - pass + def _get_auth_code_request(self, scopes, code, redirect_uri, client_secret=None): + # type: (str, str, Sequence[str], Optional[str]) -> HttpRequest + + data = { + "client_id": self._client_id, + "code": code, + "grant_type": "authorization_code", + "redirect_uri": redirect_uri, + "scope": " ".join(scopes), + } + if client_secret: + data["client_secret"] = client_secret + + request = HttpRequest( + "POST", self._token_endpoint, headers={"Content-Type": "application/x-www-form-urlencoded"}, data=data + ) + return request + + def _get_refresh_token_request(self, scopes, refresh_token): + # type: (str, Sequence[str]) -> HttpRequest + + data = { + "grant_type": "refresh_token", + "refresh_token": refresh_token, + "scope": " ".join(scopes), + "client_id": self._client_id, + } + request = HttpRequest( + "POST", self._token_endpoint, headers={"Content-Type": "application/x-www-form-urlencoded"}, data=data + ) + return request def _scrub_secrets(response): + # type: (dict) -> None for secret in ("access_token", "refresh_token"): if secret in response: response[secret] = "***" diff --git a/sdk/identity/azure-identity/azure/identity/_internal/shared_token_cache.py b/sdk/identity/azure-identity/azure/identity/_internal/shared_token_cache.py index c21ecc27947f..6fba6c09b986 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/shared_token_cache.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/shared_token_cache.py @@ -26,8 +26,6 @@ if TYPE_CHECKING: # pylint:disable=unused-import,ungrouped-imports from typing import Any, Iterable, List, Mapping, Optional - import msal_extensions - from azure.core.credentials import AccessToken from .._internal import AadClientBase CacheItem = Mapping[str, str] @@ -182,9 +180,10 @@ def _get_account(self, username=None, tenant_id=None): raise CredentialUnavailableError(message=message) def _get_refresh_tokens(self, account): - return self._cache.find( + cache_entries = self._cache.find( TokenCache.CredentialType.REFRESH_TOKEN, query={"home_account_id": account.get("home_account_id")} ) + return (token["secret"] for token in cache_entries if "secret" in token) @staticmethod def supported(): diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/authorization_code.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/authorization_code.py index 65134dcf3e82..0b5fbb53dc33 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/authorization_code.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/authorization_code.py @@ -2,7 +2,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ -import asyncio from typing import TYPE_CHECKING from azure.core.exceptions import ClientAuthenticationError @@ -11,7 +10,7 @@ if TYPE_CHECKING: # pylint:disable=unused-import,ungrouped-imports - from typing import Any, Iterable, Optional + from typing import Any, Optional, Sequence from azure.core.credentials import AccessToken @@ -66,18 +65,15 @@ async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": :raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The error's ``message`` attribute gives a reason. Any error response from Azure Active Directory is available as the error's ``response`` attribute. - :keyword ~concurrent.futures.Executor executor: An Executor instance used to execute asynchronous calls - :keyword loop: An event loop on which to schedule network I/O. If not provided, the currently running - loop will be used. """ if not scopes: raise ValueError("'get_token' requires at least one scope") if self._authorization_code: - loop = kwargs.pop("loop", None) or asyncio.get_event_loop() token = await self._client.obtain_token_by_authorization_code( - code=self._authorization_code, redirect_uri=self._redirect_uri, scopes=scopes, loop=loop, **kwargs + scopes=scopes, code=self._authorization_code, redirect_uri=self._redirect_uri, **kwargs ) + self._authorization_code = None # auth codes are single-use return token @@ -92,10 +88,11 @@ async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": return token - async def _redeem_refresh_token(self, scopes: "Iterable[str]", **kwargs: "Any") -> "Optional[AccessToken]": - loop = kwargs.pop("loop", None) or asyncio.get_event_loop() + async def _redeem_refresh_token(self, scopes: "Sequence[str]", **kwargs: "Any") -> "Optional[AccessToken]": for refresh_token in self._client.get_cached_refresh_tokens(scopes): - token = await self._client.obtain_token_by_refresh_token(refresh_token, scopes, loop=loop, **kwargs) + if "secret" not in refresh_token: + continue + token = await self._client.obtain_token_by_refresh_token(scopes, refresh_token["secret"], **kwargs) if token: return token return None diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/shared_cache.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/shared_cache.py index 81d0f6843f2f..ec021eba460d 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/shared_cache.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/shared_cache.py @@ -68,7 +68,7 @@ async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": # py # try each refresh token, returning the first access token acquired for refresh_token in self._get_refresh_tokens(account): - token = await self._client.obtain_token_by_refresh_token(refresh_token, scopes) + token = await self._client.obtain_token_by_refresh_token(scopes, refresh_token) return token raise CredentialUnavailableError(message=NO_TOKEN.format(account.get("username"))) diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/vscode_credential.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/vscode_credential.py index 337460611b51..fcf392421294 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/vscode_credential.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/vscode_credential.py @@ -60,5 +60,5 @@ async def get_token(self, *scopes, **kwargs): if not self._refresh_token: raise CredentialUnavailableError(message="No Azure user is logged in to Visual Studio Code.") - token = await self._client.obtain_token_by_refresh_token(self._refresh_token, scopes, **kwargs) + token = await self._client.obtain_token_by_refresh_token(scopes, self._refresh_token, **kwargs) return token diff --git a/sdk/identity/azure-identity/azure/identity/aio/_internal/aad_client.py b/sdk/identity/azure-identity/azure/identity/aio/_internal/aad_client.py index ceb96ba87d21..98075b9aeb66 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_internal/aad_client.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_internal/aad_client.py @@ -2,25 +2,36 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ -"""A thin wrapper around MSAL's token cache and OAuth 2 client""" - -import asyncio import time from typing import TYPE_CHECKING -from azure.identity._internal import AadClientBase -from .msal_transport_adapter import MsalTransportAdapter -from .exception_wrapper import wrap_exceptions +from azure.core.configuration import Configuration +from azure.core.pipeline import AsyncPipeline +from azure.core.pipeline.policies import ( + ContentDecodePolicy, + ProxyPolicy, + NetworkTraceLoggingPolicy, + AsyncRetryPolicy, + UserAgentPolicy, + DistributedTracingPolicy, + HttpLoggingPolicy, +) +from ..._internal import AadClientBase +from ..._internal.user_agent import USER_AGENT if TYPE_CHECKING: # pylint:disable=unused-import,ungrouped-imports - from typing import Any, Callable, Iterable + from typing import Any, List, Optional, Sequence, Union from azure.core.credentials import AccessToken + from azure.core.pipeline.policies import AsyncHTTPPolicy, SansIOHTTPPolicy + from azure.core.pipeline.transport import AsyncHttpTransport + + Policy = Union[AsyncHTTPPolicy, SansIOHTTPPolicy] class AadClient(AadClientBase): async def __aenter__(self): - await self._client.session.__aenter__() + await self._pipeline.__aenter__() return self async def __aexit__(self, *args): @@ -29,28 +40,62 @@ async def __aexit__(self, *args): async def close(self) -> None: """Close the client's transport session.""" - await self._client.session.__aexit__() + await self._pipeline.__aexit__() - # pylint:disable=arguments-differ - def obtain_token_by_authorization_code( - self, *args: "Any", loop: "asyncio.AbstractEventLoop" = None, **kwargs: "Any" + async def obtain_token_by_authorization_code( + self, + scopes: "Sequence[str]", + code: str, + redirect_uri: str, + client_secret: "Optional[str]" = None, + **kwargs: "Any" ) -> "AccessToken": - # 'loop' will reach the transport adapter as a kwarg, so here we ensure it's passed - loop = loop or asyncio.get_event_loop() - return super().obtain_token_by_authorization_code(*args, loop=loop, **kwargs) - - def obtain_token_by_refresh_token(self, *args, loop: "asyncio.AbstractEventLoop" = None, **kwargs) -> "AccessToken": - # 'loop' will reach the transport adapter as a kwarg, so here we ensure it's passed - loop = loop or asyncio.get_event_loop() - return super().obtain_token_by_refresh_token(*args, loop=loop, **kwargs) - - def _get_client_session(self, **kwargs): - return MsalTransportAdapter(**kwargs) + request = self._get_auth_code_request( + scopes=scopes, code=code, redirect_uri=redirect_uri, client_secret=client_secret + ) + now = int(time.time()) + response = await self._pipeline.run(request, **kwargs) + content = ContentDecodePolicy.deserialize_from_http_generics(response.http_response) + return self._process_response(response=content, scopes=scopes, now=now) - @wrap_exceptions - async def _obtain_token( - self, scopes: "Iterable[str]", fn: "Callable", loop: "asyncio.AbstractEventLoop", executor=None, **kwargs: "Any" + async def obtain_token_by_refresh_token( + self, scopes: "Sequence[str]", refresh_token: str, **kwargs: "Any" ) -> "AccessToken": + request = self._get_refresh_token_request(scopes, refresh_token) now = int(time.time()) - response = await loop.run_in_executor(executor, fn) - return self._process_response(response=response, scopes=scopes, now=now) + response = await self._pipeline.run(request, **kwargs) + content = ContentDecodePolicy.deserialize_from_http_generics(response.http_response) + return self._process_response(response=content, scopes=scopes, now=now) + + # pylint:disable=no-self-use + def _build_pipeline( + self, + config: Configuration = None, + policies: "Optional[List[Policy]]" = None, + transport: "Optional[AsyncHttpTransport]" = None, + **kwargs: "Any" + ) -> AsyncPipeline: + config = config or _create_config(**kwargs) + policies = policies or [ + config.user_agent_policy, + config.proxy_policy, + config.retry_policy, + config.logging_policy, + DistributedTracingPolicy(**kwargs), + HttpLoggingPolicy(**kwargs), + ] + if not transport: + from azure.core.pipeline.transport import AioHttpTransport + + transport = AioHttpTransport(configuration=config) + + return AsyncPipeline(transport=transport, policies=policies) + + +def _create_config(**kwargs: "Any") -> Configuration: + config = Configuration(**kwargs) + config.proxy_policy = ProxyPolicy(**kwargs) + config.logging_policy = NetworkTraceLoggingPolicy(**kwargs) + config.retry_policy = AsyncRetryPolicy(**kwargs) + config.user_agent_policy = UserAgentPolicy(base_user_agent=USER_AGENT, **kwargs) + return config diff --git a/sdk/identity/azure-identity/tests/helpers_async.py b/sdk/identity/azure-identity/tests/helpers_async.py index 38fac1d36395..2a66e167c325 100644 --- a/sdk/identity/azure-identity/tests/helpers_async.py +++ b/sdk/identity/azure-identity/tests/helpers_async.py @@ -46,4 +46,4 @@ def __init__(self, *args, **kwargs): def async_validating_transport(requests, responses): sync_transport = validating_transport(requests, responses) - return AsyncMockTransport(send=wrap_in_future(sync_transport.send)) + return AsyncMockTransport(send=mock.Mock(wraps=wrap_in_future(sync_transport.send))) diff --git a/sdk/identity/azure-identity/tests/test_aad_client.py b/sdk/identity/azure-identity/tests/test_aad_client.py index 41612cba5853..98a74fa4ae68 100644 --- a/sdk/identity/azure-identity/tests/test_aad_client.py +++ b/sdk/identity/azure-identity/tests/test_aad_client.py @@ -3,7 +3,6 @@ # Licensed under the MIT License. # ------------------------------------ import functools -import json from azure.core.exceptions import ClientAuthenticationError from azure.identity._constants import EnvironmentVariables @@ -19,49 +18,18 @@ from mock import Mock, patch # type: ignore -class MockClient(AadClient): - def __init__(self, *args, **kwargs): - self.session = kwargs.pop("session") - super(MockClient, self).__init__(*args, **kwargs) - - def _get_client_session(self, **kwargs): - return self.session - - -def test_uses_msal_correctly(): - session = Mock() - transport = Mock() - session.get = session.post = transport - - client = MockClient("tenant id", "client id", session=session) - - # MSAL will raise on each call because the mock transport returns nothing useful. - # That's okay because we only want to verify the transport was called, i.e. that - # the client used the MSAL API correctly, such that MSAL tried to send a request. - with pytest.raises(ClientAuthenticationError): - client.obtain_token_by_authorization_code("code", "redirect uri", "scope") - assert transport.call_count == 1 - - transport.reset_mock() - - with pytest.raises(ClientAuthenticationError): - client.obtain_token_by_refresh_token("refresh token", "scope") - assert transport.call_count == 1 - - def test_error_reporting(): error_name = "everything's sideways" error_description = "something went wrong" error_response = {"error": error_name, "error_description": error_description} - response = Mock(status_code=403, json=lambda: error_response, text=json.dumps(error_response)) - transport = Mock(return_value=response) - session = Mock(get=transport, post=transport) - client = MockClient("tenant id", "client id", session=session) + response = mock_response(status_code=403, json_payload=error_response) + transport = Mock(send=Mock(return_value=response)) + client = AadClient("tenant id", "client id", transport=transport) fns = [ - functools.partial(client.obtain_token_by_authorization_code, "code", "uri", "scope"), - functools.partial(client.obtain_token_by_refresh_token, {"secret": "refresh token"}, "scope"), + functools.partial(client.obtain_token_by_authorization_code, ("scope",), "code", "uri"), + functools.partial(client.obtain_token_by_refresh_token, ("scope",), "refresh token"), ] # exceptions raised for AAD errors should contain AAD's error description @@ -70,27 +38,30 @@ def test_error_reporting(): fn() message = str(ex.value) assert error_name in message and error_description in message + assert transport.send.call_count == 1 + transport.send.reset_mock() def test_exceptions_do_not_expose_secrets(): secret = "secret" body = {"error": "bad thing", "access_token": secret, "refresh_token": secret} - response = Mock(status_code=403, json=lambda: body) - transport = Mock(return_value=response) - session = Mock(get=transport, post=transport) - client = MockClient("tenant id", "client id", session=session) + response = mock_response(status_code=403, json_payload=body) + transport = Mock(send=Mock(return_value=response)) + client = AadClient("tenant id", "client id", transport=transport) fns = [ functools.partial(client.obtain_token_by_authorization_code, "code", "uri", "scope"), - functools.partial(client.obtain_token_by_refresh_token, {"secret": "refresh token"}, "scope"), + functools.partial(client.obtain_token_by_refresh_token, "refresh token", ("scope"),), ] def assert_secrets_not_exposed(): for fn in fns: with pytest.raises(ClientAuthenticationError) as ex: fn() - assert secret not in str(ex.value) - assert secret not in repr(ex.value) + assert secret not in str(ex.value) + assert secret not in repr(ex.value) + assert transport.send.call_count == 1 + transport.send.reset_mock() # AAD errors shouldn't provoke exceptions exposing secrets assert_secrets_not_exposed() @@ -115,11 +86,65 @@ def send(request, **_): client = AadClient(tenant_id, "client id", transport=Mock(send=send), authority=authority) - client.obtain_token_by_authorization_code("code", "uri", "scope") - client.obtain_token_by_refresh_token("refresh token", "scope") + client.obtain_token_by_authorization_code("scope", "code", "uri") + client.obtain_token_by_refresh_token("scope", "refresh token") # authority can be configured via environment variable with patch.dict("os.environ", {EnvironmentVariables.AZURE_AUTHORITY_HOST: authority}, clear=True): client = AadClient(tenant_id=tenant_id, client_id="client id", transport=Mock(send=send)) - client.obtain_token_by_authorization_code("code", "uri", "scope") - client.obtain_token_by_refresh_token("refresh token", "scope") + client.obtain_token_by_authorization_code("scope", "code", "uri") + client.obtain_token_by_refresh_token("scope", "refresh token") + + +@pytest.mark.parametrize("secret", (None, "client secret")) +def test_authorization_code(secret): + tenant_id = "tenant-id" + client_id = "client-id" + auth_code = "code" + scope = "scope" + redirect_uri = "https://localhost" + access_token = "***" + + def send(request, **_): + assert request.data["client_id"] == client_id + assert request.data["code"] == auth_code + assert request.data["grant_type"] == "authorization_code" + assert request.data["redirect_uri"] == redirect_uri + assert request.data["scope"] == scope + assert request.data.get("client_secret") == secret + + return mock_response(json_payload={"access_token": access_token, "expires_in": 42}) + + transport = Mock(send=Mock(wraps=send)) + + client = AadClient(tenant_id, client_id, transport=transport) + token = client.obtain_token_by_authorization_code( + scopes=(scope,), code=auth_code, redirect_uri=redirect_uri, client_secret=secret + ) + + assert token.token == access_token + assert transport.send.call_count == 1 + + +def test_refresh_token(): + tenant_id = "tenant-id" + client_id = "client-id" + scope = "scope" + refresh_token = "refresh-token" + access_token = "***" + + def send(request, **_): + assert request.data["client_id"] == client_id + assert request.data["grant_type"] == "refresh_token" + assert request.data["refresh_token"] == refresh_token + assert request.data["scope"] == scope + + return mock_response(json_payload={"access_token": access_token, "expires_in": 42}) + + transport = Mock(send=Mock(wraps=send)) + + client = AadClient(tenant_id, client_id, transport=transport) + token = client.obtain_token_by_refresh_token(scopes=(scope,), refresh_token=refresh_token) + + assert token.token == access_token + assert transport.send.call_count == 1 diff --git a/sdk/identity/azure-identity/tests/test_aad_client_async.py b/sdk/identity/azure-identity/tests/test_aad_client_async.py index de77ab384736..db7ce99c3e30 100644 --- a/sdk/identity/azure-identity/tests/test_aad_client_async.py +++ b/sdk/identity/azure-identity/tests/test_aad_client_async.py @@ -2,47 +2,136 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ +import functools from unittest.mock import Mock, patch from urllib.parse import urlparse +from azure.core.exceptions import ClientAuthenticationError from azure.identity._constants import EnvironmentVariables from azure.identity.aio._internal.aad_client import AadClient import pytest -from helpers import mock_response +from helpers import build_aad_response, mock_response +pytestmark = pytest.mark.asyncio -class MockClient(AadClient): - def __init__(self, *args, **kwargs): - self.session = kwargs.pop("session") - super(MockClient, self).__init__(*args, **kwargs) - def _get_client_session(self, **kwargs): - return self.session +async def test_error_reporting(): + error_name = "everything's sideways" + error_description = "something went wrong" + error_response = {"error": error_name, "error_description": error_description} + response = mock_response(status_code=403, json_payload=error_response) -@pytest.mark.asyncio -async def test_uses_msal_correctly(): - transport = Mock() - session = Mock(get=transport, post=transport) + async def send(*_, **__): + return response - client = MockClient("tenant id", "client id", session=session) + transport = Mock(send=Mock(wraps=send)) + client = AadClient("tenant id", "client id", transport=transport) - # MSAL will raise on each call because the mock transport returns nothing useful. - # That's okay because we only want to verify the transport was called, i.e. that - # the client used the MSAL API correctly, such that MSAL tried to send a request. - with pytest.raises(Exception): - await client.obtain_token_by_authorization_code("code", "redirect uri", "scope") - assert transport.call_count == 1 + fns = [ + functools.partial(client.obtain_token_by_authorization_code, ("scope",), "code", "uri"), + functools.partial(client.obtain_token_by_refresh_token, ("scope",), "refresh token"), + ] - transport.reset_mock() + # exceptions raised for AAD errors should contain AAD's error description + for fn in fns: + with pytest.raises(ClientAuthenticationError) as ex: + await fn() + message = str(ex.value) + assert error_name in message and error_description in message + assert transport.send.call_count == 1 + transport.send.reset_mock() - with pytest.raises(Exception): - await client.obtain_token_by_refresh_token("refresh token", "scope") - assert transport.call_count == 1 + +async def test_exceptions_do_not_expose_secrets(): + secret = "secret" + body = {"error": "bad thing", "access_token": secret, "refresh_token": secret} + response = mock_response(status_code=403, json_payload=body) + + async def send(*_, **__): + return response + + transport = Mock(send=Mock(wraps=send)) + + client = AadClient("tenant id", "client id", transport=transport) + + fns = [ + functools.partial(client.obtain_token_by_authorization_code, "code", "uri", ("scope",)), + functools.partial(client.obtain_token_by_refresh_token, "refresh token", ("scope",)), + ] + + async def assert_secrets_not_exposed(): + for fn in fns: + with pytest.raises(ClientAuthenticationError) as ex: + await fn() + assert secret not in str(ex.value) + assert secret not in repr(ex.value) + assert transport.send.call_count == 1 + transport.send.reset_mock() + + # AAD errors shouldn't provoke exceptions exposing secrets + await assert_secrets_not_exposed() + + # neither should unexpected AAD responses + del body["error"] + await assert_secrets_not_exposed() + + +@pytest.mark.parametrize("secret", (None, "client secret")) +async def test_authorization_code(secret): + tenant_id = "tenant-id" + client_id = "client-id" + auth_code = "code" + scope = "scope" + redirect_uri = "https://localhost" + access_token = "***" + + async def send(request, **_): + assert request.data["client_id"] == client_id + assert request.data["code"] == auth_code + assert request.data["grant_type"] == "authorization_code" + assert request.data["redirect_uri"] == redirect_uri + assert request.data["scope"] == scope + assert request.data.get("client_secret") == secret + + return mock_response(json_payload={"access_token": access_token, "expires_in": 42}) + + transport = Mock(send=Mock(wraps=send)) + + client = AadClient(tenant_id, client_id, transport=transport) + token = await client.obtain_token_by_authorization_code( + scopes=(scope,), code=auth_code, redirect_uri=redirect_uri, client_secret=secret + ) + + assert token.token == access_token + assert transport.send.call_count == 1 + + +async def test_refresh_token(): + tenant_id = "tenant-id" + client_id = "client-id" + scope = "scope" + refresh_token = "refresh-token" + access_token = "***" + + async def send(request, **_): + assert request.data["client_id"] == client_id + assert request.data["grant_type"] == "refresh_token" + assert request.data["refresh_token"] == refresh_token + assert request.data["scope"] == scope + + return mock_response(json_payload={"access_token": access_token, "expires_in": 42}) + + transport = Mock(send=Mock(wraps=send)) + + client = AadClient(tenant_id, client_id, transport=transport) + token = await client.obtain_token_by_refresh_token(scopes=(scope,), refresh_token=refresh_token) + + assert token.token == access_token + assert transport.send.call_count == 1 -@pytest.mark.asyncio @pytest.mark.parametrize("authority", ("localhost", "https://localhost")) async def test_request_url(authority): tenant_id = "expected_tenant" @@ -58,11 +147,11 @@ async def send(request, **_): client = AadClient(tenant_id, "client id", transport=Mock(send=send), authority=authority) - await client.obtain_token_by_authorization_code("code", "uri", "scope") - await client.obtain_token_by_refresh_token("refresh token", "scope") + await client.obtain_token_by_authorization_code("scope", "code", "uri") + await client.obtain_token_by_refresh_token("scope", "refresh token") # authority can be configured via environment variable with patch.dict("os.environ", {EnvironmentVariables.AZURE_AUTHORITY_HOST: authority}, clear=True): client = AadClient(tenant_id=tenant_id, client_id="client id", transport=Mock(send=send)) - await client.obtain_token_by_authorization_code("code", "uri", "scope") - await client.obtain_token_by_refresh_token("refresh token", "scope") + await client.obtain_token_by_authorization_code("scope", "code", "uri") + await client.obtain_token_by_refresh_token("scope", "refresh token") diff --git a/sdk/identity/azure-identity/tests/test_auth_code.py b/sdk/identity/azure-identity/tests/test_auth_code.py index 7093e48e7171..7b6ce76a75ed 100644 --- a/sdk/identity/azure-identity/tests/test_auth_code.py +++ b/sdk/identity/azure-identity/tests/test_auth_code.py @@ -6,6 +6,7 @@ from azure.core.pipeline.policies import SansIOHTTPPolicy from azure.identity import AuthorizationCodeCredential from azure.identity._internal.user_agent import USER_AGENT +import msal import pytest from helpers import build_aad_response, mock_response, Request, validating_transport @@ -56,36 +57,60 @@ def test_auth_code_credential(): client_id = "client id" tenant_id = "tenant" expected_code = "auth code" - redirect_uri = "https://foo.bar" - expected_token = AccessToken("token", 42) + redirect_uri = "https://localhost" + expected_access_token = "access" + expected_refresh_token = "refresh" + expected_scope = "scope" - mock_client = Mock(spec=object) - mock_client.obtain_token_by_authorization_code = Mock(return_value=expected_token) + auth_response = build_aad_response(access_token=expected_access_token, refresh_token=expected_refresh_token) + transport = validating_transport( + requests=[ + Request( # first call should redeem the auth code + url_substring=tenant_id, + required_data={ + "client_id": client_id, + "code": expected_code, + "grant_type": "authorization_code", + "redirect_uri": redirect_uri, + "scope": expected_scope, + }, + ), + Request( # third call should redeem the refresh token + url_substring=tenant_id, + required_data={ + "client_id": client_id, + "grant_type": "refresh_token", + "refresh_token": expected_refresh_token, + "scope": expected_scope, + }, + ), + ], + responses=[mock_response(json_payload=auth_response)] * 2, + ) + cache = msal.TokenCache() credential = AuthorizationCodeCredential( client_id=client_id, tenant_id=tenant_id, authorization_code=expected_code, redirect_uri=redirect_uri, - client=mock_client, + transport=transport, + cache=cache, ) # first call should redeem the auth code - token = credential.get_token("scope") - assert token is expected_token - assert mock_client.obtain_token_by_authorization_code.call_count == 1 - _, kwargs = mock_client.obtain_token_by_authorization_code.call_args - assert kwargs["code"] == expected_code + token = credential.get_token(expected_scope) + assert token.token == expected_access_token + assert transport.send.call_count == 1 # no auth code -> credential should return cached token - mock_client.obtain_token_by_authorization_code = None # raise if credential calls this again - mock_client.get_cached_access_token = lambda *_: expected_token - token = credential.get_token("scope") - assert token is expected_token - - # no auth code, no cached token -> credential should use refresh token - mock_client.get_cached_access_token = lambda *_: None - mock_client.get_cached_refresh_tokens = lambda *_: ["this is a refresh token"] - mock_client.obtain_token_by_refresh_token = lambda *_, **__: expected_token - token = credential.get_token("scope") - assert token is expected_token + token = credential.get_token(expected_scope) + assert token.token == expected_access_token + assert transport.send.call_count == 1 + + # no auth code, no cached token -> credential should redeem refresh token + cached_access_token = cache.find(cache.CredentialType.ACCESS_TOKEN)[0] + cache.remove_at(cached_access_token) + token = credential.get_token(expected_scope) + assert token.token == expected_access_token + assert transport.send.call_count == 2 diff --git a/sdk/identity/azure-identity/tests/test_auth_code_async.py b/sdk/identity/azure-identity/tests/test_auth_code_async.py index bb284a2ff4dc..3b754b40a6f4 100644 --- a/sdk/identity/azure-identity/tests/test_auth_code_async.py +++ b/sdk/identity/azure-identity/tests/test_auth_code_async.py @@ -2,21 +2,20 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ -import asyncio from unittest.mock import Mock -from azure.core.credentials import AccessToken -from azure.core.exceptions import ClientAuthenticationError from azure.core.pipeline.policies import SansIOHTTPPolicy from azure.identity._internal.user_agent import USER_AGENT from azure.identity.aio import AuthorizationCodeCredential +import msal import pytest from helpers import build_aad_response, mock_response, Request -from helpers_async import async_validating_transport, AsyncMockTransport, wrap_in_future +from helpers_async import async_validating_transport, AsyncMockTransport + +pytestmark = pytest.mark.asyncio -@pytest.mark.asyncio async def test_no_scopes(): """The credential should raise ValueError when get_token is called with no scopes""" @@ -25,7 +24,6 @@ async def test_no_scopes(): await credential.get_token() -@pytest.mark.asyncio async def test_policies_configurable(): policy = Mock(spec_set=SansIOHTTPPolicy, on_request=Mock()) @@ -41,7 +39,6 @@ async def send(*_, **__): assert policy.on_request.called -@pytest.mark.asyncio async def test_close(): transport = AsyncMockTransport() credential = AuthorizationCodeCredential( @@ -53,7 +50,6 @@ async def test_close(): assert transport.__aexit__.call_count == 1 -@pytest.mark.asyncio async def test_context_manager(): transport = AsyncMockTransport() credential = AuthorizationCodeCredential( @@ -67,7 +63,6 @@ async def test_context_manager(): assert transport.__aexit__.call_count == 1 -@pytest.mark.asyncio async def test_user_agent(): transport = async_validating_transport( requests=[Request(required_headers={"User-Agent": USER_AGENT})], @@ -81,87 +76,64 @@ async def test_user_agent(): await credential.get_token("scope") -@pytest.mark.asyncio async def test_auth_code_credential(): client_id = "client id" tenant_id = "tenant" expected_code = "auth code" - redirect_uri = "https://foo.bar" - expected_token = AccessToken("token", 42) + redirect_uri = "https://localhost" + expected_access_token = "access" + expected_refresh_token = "refresh" + expected_scope = "scope" - mock_client = Mock(spec=object) - obtain_by_auth_code = Mock(return_value=expected_token) - mock_client.obtain_token_by_authorization_code = wrap_in_future(obtain_by_auth_code) + auth_response = build_aad_response(access_token=expected_access_token, refresh_token=expected_refresh_token) + transport = async_validating_transport( + requests=[ + Request( # first call should redeem the auth code + url_substring=tenant_id, + required_data={ + "client_id": client_id, + "code": expected_code, + "grant_type": "authorization_code", + "redirect_uri": redirect_uri, + "scope": expected_scope, + }, + ), + Request( # third call should redeem the refresh token + url_substring=tenant_id, + required_data={ + "client_id": client_id, + "grant_type": "refresh_token", + "refresh_token": expected_refresh_token, + "scope": expected_scope, + }, + ), + ], + responses=[mock_response(json_payload=auth_response)] * 2, + ) + cache = msal.TokenCache() credential = AuthorizationCodeCredential( client_id=client_id, tenant_id=tenant_id, authorization_code=expected_code, redirect_uri=redirect_uri, - client=mock_client, + transport=transport, + cache=cache, ) # first call should redeem the auth code - token = await credential.get_token("scope") - assert token is expected_token - assert obtain_by_auth_code.call_count == 1 - _, kwargs = obtain_by_auth_code.call_args - assert kwargs["code"] == expected_code + token = await credential.get_token(expected_scope) + assert token.token == expected_access_token + assert transport.send.call_count == 1 # no auth code -> credential should return cached token - mock_client.obtain_token_by_authorization_code = None # raise if credential calls this again - mock_client.get_cached_access_token = lambda *_: expected_token - token = await credential.get_token("scope") - assert token is expected_token - - # no auth code, no cached token -> credential should use refresh token - mock_client.get_cached_access_token = lambda *_: None - mock_client.get_cached_refresh_tokens = lambda *_: ["this is a refresh token"] - mock_client.obtain_token_by_refresh_token = wrap_in_future(lambda *_, **__: expected_token) - token = await credential.get_token("scope") - assert token is expected_token - - -@pytest.mark.asyncio -async def test_custom_executor_used(): - credential = AuthorizationCodeCredential( - client_id="client id", tenant_id="tenant id", authorization_code="auth code", redirect_uri="https://foo.bar" - ) - - executor = Mock() - - with pytest.raises(ClientAuthenticationError): - await credential.get_token("scope", executor=executor) - - assert executor.submit.call_count == 1 - - -@pytest.mark.asyncio -async def test_custom_loop_used(): - credential = AuthorizationCodeCredential( - client_id="client id", tenant_id="tenant id", authorization_code="auth code", redirect_uri="https://foo.bar" - ) - - loop = Mock() - - with pytest.raises(ClientAuthenticationError): - await credential.get_token("scope", loop=loop) - - assert loop.run_in_executor.call_count == 1 - - -@pytest.mark.asyncio -async def test_custom_loop_and_executor_used(): - credential = AuthorizationCodeCredential( - client_id="client id", tenant_id="tenant id", authorization_code="auth code", redirect_uri="https://foo.bar" - ) - - executor = Mock() - loop = Mock() - - with pytest.raises(ClientAuthenticationError): - await credential.get_token("scope", executor=executor, loop=loop) - - assert loop.run_in_executor.call_count == 1 - executor_arg, _ = loop.run_in_executor.call_args[0] - assert executor_arg is executor + token = await credential.get_token(expected_scope) + assert token.token == expected_access_token + assert transport.send.call_count == 1 + + # no auth code, no cached token -> credential should redeem refresh token + cached_access_token = cache.find(cache.CredentialType.ACCESS_TOKEN)[0] + cache.remove_at(cached_access_token) + token = await credential.get_token(expected_scope) + assert token.token == expected_access_token + assert transport.send.call_count == 2 diff --git a/sdk/identity/azure-identity/tests/test_vscode_credential.py b/sdk/identity/azure-identity/tests/test_vscode_credential.py index fdfa919fc366..d4e0afe2b338 100644 --- a/sdk/identity/azure-identity/tests/test_vscode_credential.py +++ b/sdk/identity/azure-identity/tests/test_vscode_credential.py @@ -57,16 +57,17 @@ def test_credential_unavailable_error(): def test_redeem_token(): expected_token = AccessToken("token", 42) + expected_value = "value" mock_client = mock.Mock(spec=object) mock_client.obtain_token_by_refresh_token = mock.Mock(return_value=expected_token) mock_client.get_cached_access_token = mock.Mock(return_value=None) - with mock.patch(VSCodeCredential.__module__ + ".get_credentials", return_value="VALUE"): + with mock.patch(VSCodeCredential.__module__ + ".get_credentials", return_value=expected_value): credential = VSCodeCredential(_client=mock_client) token = credential.get_token("scope") assert token is expected_token - mock_client.obtain_token_by_refresh_token.assert_called_with("VALUE", ("scope",)) + mock_client.obtain_token_by_refresh_token.assert_called_with(("scope",), expected_value) assert mock_client.obtain_token_by_refresh_token.call_count == 1 diff --git a/sdk/identity/azure-identity/tests/test_vscode_credential_async.py b/sdk/identity/azure-identity/tests/test_vscode_credential_async.py index 3ab4059bfc56..89f27a78d4f5 100644 --- a/sdk/identity/azure-identity/tests/test_vscode_credential_async.py +++ b/sdk/identity/azure-identity/tests/test_vscode_credential_async.py @@ -58,17 +58,18 @@ async def test_credential_unavailable_error(): @pytest.mark.asyncio async def test_redeem_token(): expected_token = AccessToken("token", 42) + expected_value = "value" mock_client = mock.Mock(spec=object) token_by_refresh_token = mock.Mock(return_value=expected_token) mock_client.obtain_token_by_refresh_token = wrap_in_future(token_by_refresh_token) mock_client.get_cached_access_token = mock.Mock(return_value=None) - with mock.patch(VSCodeCredential.__module__ + ".get_credentials", return_value="VALUE"): + with mock.patch(VSCodeCredential.__module__ + ".get_credentials", return_value=expected_value): credential = VSCodeCredential(_client=mock_client) token = await credential.get_token("scope") assert token is expected_token - token_by_refresh_token.assert_called_with("VALUE", ("scope",)) + token_by_refresh_token.assert_called_with(("scope",), expected_value) @pytest.mark.asyncio