diff --git a/msal/application.py b/msal/application.py index 9ff16514..cf4a1a3a 100644 --- a/msal/application.py +++ b/msal/application.py @@ -8,7 +8,7 @@ import logging import sys import warnings -import uuid +from threading import Lock import requests @@ -18,6 +18,7 @@ from .wstrust_request import send_request as wst_send_request from .wstrust_response import * from .token_cache import TokenCache +import msal.telemetry # The __init__.py will import this. Not the other way around. @@ -52,18 +53,6 @@ def decorate_scope( decorated = scope_set | reserved_scope return list(decorated) -CLIENT_REQUEST_ID = 'client-request-id' -CLIENT_CURRENT_TELEMETRY = 'x-client-current-telemetry' - -def _get_new_correlation_id(): - correlation_id = str(uuid.uuid4()) - logger.debug("Generates correlation_id: %s", correlation_id) - return correlation_id - - -def _build_current_telemetry_request_header(public_api_id, force_refresh=False): - return "1|{},{}|".format(public_api_id, "1" if force_refresh else "0") - def extract_certs(public_cert_content): # Parses raw public certificate file contents and returns a list of strings @@ -257,6 +246,14 @@ def __init__( self.token_cache = token_cache or TokenCache() self.client = self._build_client(client_credential, self.authority) self.authority_groups = None + self._telemetry_buffer = {} + self._telemetry_lock = Lock() + + def _build_telemetry_context( + self, api_id, correlation_id=None, refresh_reason=None): + return msal.telemetry._TelemetryContext( + self._telemetry_buffer, self._telemetry_lock, api_id, + correlation_id=correlation_id, refresh_reason=refresh_reason) def _build_client(self, client_credential, authority): client_assertion = None @@ -520,21 +517,21 @@ def authorize(): # A controller in a web app return redirect(url_for("index")) """ self._validate_ssh_cert_input_data(kwargs.get("data", {})) - return _clean_up(self.client.obtain_token_by_auth_code_flow( + telemetry_context = self._build_telemetry_context( + self.ACQUIRE_TOKEN_BY_AUTHORIZATION_CODE_ID) + response =_clean_up(self.client.obtain_token_by_auth_code_flow( auth_code_flow, auth_response, scope=decorate_scope(scopes, self.client_id) if scopes else None, - headers={ - CLIENT_REQUEST_ID: _get_new_correlation_id(), - CLIENT_CURRENT_TELEMETRY: _build_current_telemetry_request_header( - self.ACQUIRE_TOKEN_BY_AUTHORIZATION_CODE_ID), - }, + headers=telemetry_context.generate_headers(), data=dict( kwargs.pop("data", {}), claims=_merge_claims_challenge_and_capabilities( self._client_capabilities, auth_code_flow.pop("claims_challenge", None))), **kwargs)) + telemetry_context.update_telemetry(response) + return response def acquire_token_by_authorization_code( self, @@ -593,20 +590,20 @@ def acquire_token_by_authorization_code( "Change your acquire_token_by_authorization_code() " "to acquire_token_by_auth_code_flow()", DeprecationWarning) with warnings.catch_warnings(record=True): - return _clean_up(self.client.obtain_token_by_authorization_code( + telemetry_context = self._build_telemetry_context( + self.ACQUIRE_TOKEN_BY_AUTHORIZATION_CODE_ID) + response = _clean_up(self.client.obtain_token_by_authorization_code( code, redirect_uri=redirect_uri, scope=decorate_scope(scopes, self.client_id), - headers={ - CLIENT_REQUEST_ID: _get_new_correlation_id(), - CLIENT_CURRENT_TELEMETRY: _build_current_telemetry_request_header( - self.ACQUIRE_TOKEN_BY_AUTHORIZATION_CODE_ID), - }, + headers=telemetry_context.generate_headers(), data=dict( kwargs.pop("data", {}), claims=_merge_claims_challenge_and_capabilities( self._client_capabilities, claims_challenge)), nonce=nonce, **kwargs)) + telemetry_context.update_telemetry(resposne) + return response def get_accounts(self, username=None): """Get a list of accounts which previously signed in, i.e. exists in cache. @@ -735,7 +732,7 @@ def acquire_token_silent( - None when cache lookup does not yield a token. """ result = self.acquire_token_silent_with_error( - scopes, account, authority, force_refresh, + scopes, account, authority=authority, force_refresh=force_refresh, claims_challenge=claims_challenge, **kwargs) return result if result and "error" not in result else None @@ -780,7 +777,7 @@ def acquire_token_silent_with_error( """ assert isinstance(scopes, list), "Invalid parameter type" self._validate_ssh_cert_input_data(kwargs.get("data", {})) - correlation_id = _get_new_correlation_id() + correlation_id = msal.telemetry._get_new_correlation_id() if authority: warnings.warn("We haven't decided how/if this method will accept authority parameter") # the_authority = Authority( @@ -851,9 +848,11 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it( target=scopes, query=query) now = time.time() + refresh_reason = msal.telemetry.AT_ABSENT for entry in matches: expires_in = int(entry["expires_on"]) - now if expires_in < 5*60: # Then consider it expired + refresh_reason = msal.telemetry.AT_EXPIRED continue # Removal is not necessary, it will be overwritten logger.debug("Cache hit an AT") access_token_from_cache = { # Mimic a real response @@ -862,13 +861,18 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it( "expires_in": int(expires_in), # OAuth2 specs defines it as int } if "refresh_on" in entry and int(entry["refresh_on"]) < now: # aging + refresh_reason = msal.telemetry.AT_AGING break # With a fallback in hand, we break here to go refresh + self._build_telemetry_context(-1).hit_an_access_token() return access_token_from_cache # It is still good as new + else: + refresh_reason = msal.telemetry.FORCE_REFRESH # TODO: It could also mean claims_challenge + assert refresh_reason, "It should have been established at this point" try: - result = self._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family( + result = _clean_up(self._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family( authority, decorate_scope(scopes, self.client_id), account, - force_refresh=force_refresh, claims_challenge=claims_challenge, **kwargs) - result = _clean_up(result) + refresh_reason=refresh_reason, claims_challenge=claims_challenge, + **kwargs)) if (result and "error" not in result) or (not access_token_from_cache): return result except: # The exact HTTP exception is transportation-layer dependent @@ -922,7 +926,8 @@ def _get_app_metadata(self, environment): def _acquire_token_silent_by_finding_specific_refresh_token( self, authority, scopes, query, rt_remover=None, break_condition=lambda response: False, - force_refresh=False, correlation_id=None, claims_challenge=None, **kwargs): + refresh_reason=None, correlation_id=None, claims_challenge=None, + **kwargs): matches = self.token_cache.find( self.token_cache.CredentialType.REFRESH_TOKEN, # target=scopes, # AAD RTs are scope-independent @@ -931,6 +936,9 @@ def _acquire_token_silent_by_finding_specific_refresh_token( client = self._build_client(self.client_credential, authority) response = None # A distinguishable value to mean cache is empty + telemetry_context = self._build_telemetry_context( + self.ACQUIRE_TOKEN_SILENT_ID, + correlation_id=correlation_id, refresh_reason=refresh_reason) for entry in sorted( # Since unfit RTs would not be aggressively removed, # we start from newer RTs which are more likely fit. matches, @@ -948,16 +956,13 @@ def _acquire_token_silent_by_finding_specific_refresh_token( skip_account_creation=True, # To honor a concurrent remove_account() )), scope=scopes, - headers={ - CLIENT_REQUEST_ID: correlation_id or _get_new_correlation_id(), - CLIENT_CURRENT_TELEMETRY: _build_current_telemetry_request_header( - self.ACQUIRE_TOKEN_SILENT_ID, force_refresh=force_refresh), - }, + headers=telemetry_context.generate_headers(), data=dict( kwargs.pop("data", {}), claims=_merge_claims_challenge_and_capabilities( self._client_capabilities, claims_challenge)), **kwargs) + telemetry_context.update_telemetry(response) if "error" not in response: return response logger.debug("Refresh failed. {error}: {error_description}".format( @@ -1006,18 +1011,19 @@ def acquire_token_by_refresh_token(self, refresh_token, scopes, **kwargs): * A dict contains no "error" key means migration was successful. """ self._validate_ssh_cert_input_data(kwargs.get("data", {})) - return _clean_up(self.client.obtain_token_by_refresh_token( + telemetry_context = self._build_telemetry_context( + self.ACQUIRE_TOKEN_BY_REFRESH_TOKEN, + refresh_reason=msal.telemetry.FORCE_REFRESH) + response = _clean_up(self.client.obtain_token_by_refresh_token( refresh_token, scope=decorate_scope(scopes, self.client_id), - headers={ - CLIENT_REQUEST_ID: _get_new_correlation_id(), - CLIENT_CURRENT_TELEMETRY: _build_current_telemetry_request_header( - self.ACQUIRE_TOKEN_BY_REFRESH_TOKEN), - }, + headers=telemetry_context.generate_headers(), rt_getter=lambda rt: rt, on_updating_rt=False, on_removing_rt=lambda rt_item: None, # No OP **kwargs)) + telemetry_context.update_telemetry(response) + return response class PublicClientApplication(ClientApplication): # browser app or mobile app @@ -1093,7 +1099,9 @@ def acquire_token_interactive( self._validate_ssh_cert_input_data(kwargs.get("data", {})) claims = _merge_claims_challenge_and_capabilities( self._client_capabilities, claims_challenge) - return _clean_up(self.client.obtain_token_by_browser( + telemetry_context = self._build_telemetry_context( + self.ACQUIRE_TOKEN_INTERACTIVE) + response = _clean_up(self.client.obtain_token_by_browser( scope=decorate_scope(scopes, self.client_id) if scopes else None, extra_scope_to_consent=extra_scopes_to_consent, redirect_uri="http://localhost:{port}".format( @@ -1107,12 +1115,10 @@ def acquire_token_interactive( "domain_hint": domain_hint, }, data=dict(kwargs.pop("data", {}), claims=claims), - headers={ - CLIENT_REQUEST_ID: _get_new_correlation_id(), - CLIENT_CURRENT_TELEMETRY: _build_current_telemetry_request_header( - self.ACQUIRE_TOKEN_INTERACTIVE), - }, + headers=telemetry_context.generate_headers(), **kwargs)) + telemetry_context.update_telemetry(response) + return response def initiate_device_flow(self, scopes=None, **kwargs): """Initiate a Device Flow instance, @@ -1125,13 +1131,10 @@ def initiate_device_flow(self, scopes=None, **kwargs): - A successful response would contain "user_code" key, among others - an error response would contain some other readable key/value pairs. """ - correlation_id = _get_new_correlation_id() + correlation_id = msal.telemetry._get_new_correlation_id() flow = self.client.initiate_device_flow( scope=decorate_scope(scopes or [], self.client_id), - headers={ - CLIENT_REQUEST_ID: correlation_id, - # CLIENT_CURRENT_TELEMETRY is not currently required - }, + headers={msal.telemetry.CLIENT_REQUEST_ID: correlation_id}, **kwargs) flow[self.DEVICE_FLOW_CORRELATION_ID] = correlation_id return flow @@ -1155,7 +1158,10 @@ def acquire_token_by_device_flow(self, flow, claims_challenge=None, **kwargs): - A successful response would contain "access_token" key, - an error response would contain "error" and usually "error_description". """ - return _clean_up(self.client.obtain_token_by_device_flow( + telemetry_context = self._build_telemetry_context( + self.ACQUIRE_TOKEN_BY_DEVICE_FLOW_ID, + correlation_id=flow.get(self.DEVICE_FLOW_CORRELATION_ID)) + response = _clean_up(self.client.obtain_token_by_device_flow( flow, data=dict( kwargs.pop("data", {}), @@ -1165,13 +1171,10 @@ def acquire_token_by_device_flow(self, flow, claims_challenge=None, **kwargs): claims=_merge_claims_challenge_and_capabilities( self._client_capabilities, claims_challenge), ), - headers={ - CLIENT_REQUEST_ID: - flow.get(self.DEVICE_FLOW_CORRELATION_ID) or _get_new_correlation_id(), - CLIENT_CURRENT_TELEMETRY: _build_current_telemetry_request_header( - self.ACQUIRE_TOKEN_BY_DEVICE_FLOW_ID), - }, + headers=telemetry_context.generate_headers(), **kwargs)) + telemetry_context.update_telemetry(response) + return response def acquire_token_by_username_password( self, username, password, scopes, claims_challenge=None, **kwargs): @@ -1196,28 +1199,30 @@ def acquire_token_by_username_password( - an error response would contain "error" and usually "error_description". """ scopes = decorate_scope(scopes, self.client_id) - headers = { - CLIENT_REQUEST_ID: _get_new_correlation_id(), - CLIENT_CURRENT_TELEMETRY: _build_current_telemetry_request_header( - self.ACQUIRE_TOKEN_BY_USERNAME_PASSWORD_ID), - } + telemetry_context = self._build_telemetry_context( + self.ACQUIRE_TOKEN_BY_USERNAME_PASSWORD_ID) + headers = telemetry_context.generate_headers() data = dict( kwargs.pop("data", {}), claims=_merge_claims_challenge_and_capabilities( self._client_capabilities, claims_challenge)) if not self.authority.is_adfs: user_realm_result = self.authority.user_realm_discovery( - username, correlation_id=headers[CLIENT_REQUEST_ID]) + username, correlation_id=headers[msal.telemetry.CLIENT_REQUEST_ID]) if user_realm_result.get("account_type") == "Federated": - return _clean_up(self._acquire_token_by_username_password_federated( + response = _clean_up(self._acquire_token_by_username_password_federated( user_realm_result, username, password, scopes=scopes, data=data, headers=headers, **kwargs)) - return _clean_up(self.client.obtain_token_by_username_password( + telemetry_context.update_telemetry(response) + return response + response = _clean_up(self.client.obtain_token_by_username_password( username, password, scope=scopes, headers=headers, data=data, **kwargs)) + telemetry_context.update_telemetry(response) + return response def _acquire_token_by_username_password_federated( self, user_realm_result, username, password, scopes=None, **kwargs): @@ -1277,18 +1282,18 @@ def acquire_token_for_client(self, scopes, claims_challenge=None, **kwargs): """ # TBD: force_refresh behavior self._validate_ssh_cert_input_data(kwargs.get("data", {})) - return _clean_up(self.client.obtain_token_for_client( + telemetry_context = self._build_telemetry_context( + self.ACQUIRE_TOKEN_FOR_CLIENT_ID) + response = _clean_up(self.client.obtain_token_for_client( scope=scopes, # This grant flow requires no scope decoration - headers={ - CLIENT_REQUEST_ID: _get_new_correlation_id(), - CLIENT_CURRENT_TELEMETRY: _build_current_telemetry_request_header( - self.ACQUIRE_TOKEN_FOR_CLIENT_ID), - }, + headers=telemetry_context.generate_headers(), data=dict( kwargs.pop("data", {}), claims=_merge_claims_challenge_and_capabilities( self._client_capabilities, claims_challenge)), **kwargs)) + telemetry_context.update_telemetry(response) + return response def acquire_token_on_behalf_of(self, user_assertion, scopes, claims_challenge=None, **kwargs): """Acquires token using on-behalf-of (OBO) flow. @@ -1316,9 +1321,11 @@ def acquire_token_on_behalf_of(self, user_assertion, scopes, claims_challenge=No - A successful response would contain "access_token" key, - an error response would contain "error" and usually "error_description". """ + telemetry_context = self._build_telemetry_context( + self.ACQUIRE_TOKEN_ON_BEHALF_OF_ID) # The implementation is NOT based on Token Exchange # https://tools.ietf.org/html/draft-ietf-oauth-token-exchange-16 - return _clean_up(self.client.obtain_token_by_assertion( # bases on assertion RFC 7521 + response = _clean_up(self.client.obtain_token_by_assertion( # bases on assertion RFC 7521 user_assertion, self.client.GRANT_TYPE_JWT, # IDTs and AAD ATs are all JWTs scope=decorate_scope(scopes, self.client_id), # Decoration is used for: @@ -1332,9 +1339,8 @@ def acquire_token_on_behalf_of(self, user_assertion, scopes, claims_challenge=No requested_token_use="on_behalf_of", claims=_merge_claims_challenge_and_capabilities( self._client_capabilities, claims_challenge)), - headers={ - CLIENT_REQUEST_ID: _get_new_correlation_id(), - CLIENT_CURRENT_TELEMETRY: _build_current_telemetry_request_header( - self.ACQUIRE_TOKEN_ON_BEHALF_OF_ID), - }, + headers=telemetry_context.generate_headers(), **kwargs)) + telemetry_context.update_telemetry(response) + return response + diff --git a/msal/telemetry.py b/msal/telemetry.py new file mode 100644 index 00000000..b07ab3ed --- /dev/null +++ b/msal/telemetry.py @@ -0,0 +1,78 @@ +import uuid +import logging + + +logger = logging.getLogger(__name__) + +CLIENT_REQUEST_ID = 'client-request-id' +CLIENT_CURRENT_TELEMETRY = "x-client-current-telemetry" +CLIENT_LAST_TELEMETRY = "x-client-last-telemetry" +NON_SILENT_CALL = 0 +FORCE_REFRESH = 1 +AT_ABSENT = 2 +AT_EXPIRED = 3 +AT_AGING = 4 +RESERVED = 5 + + +def _get_new_correlation_id(): + return str(uuid.uuid4()) + + +class _TelemetryContext(object): + """It is used for handling the telemetry context for current OAuth2 "exchange".""" + # https://identitydivision.visualstudio.com/DevEx/_git/AuthLibrariesApiReview?path=%2FTelemetry%2FMSALServerSideTelemetry.md&_a=preview + _SUCCEEDED = "succeeded" + _FAILED = "failed" + _FAILURE_SIZE = "failure_size" + _CURRENT_HEADER_SIZE_LIMIT = 100 + _LAST_HEADER_SIZE_LIMIT = 350 + + def __init__(self, buffer, lock, api_id, correlation_id=None, refresh_reason=None): + self._buffer = buffer + self._lock = lock + self._api_id = api_id + self._correlation_id = correlation_id or _get_new_correlation_id() + self._refresh_reason = refresh_reason or NON_SILENT_CALL + logger.debug("Generate or reuse correlation_id: %s", self._correlation_id) + + def generate_headers(self): + with self._lock: + current = "4|{api_id},{cache_refresh}|".format( + api_id=self._api_id, cache_refresh=self._refresh_reason) + if len(current) > self._CURRENT_HEADER_SIZE_LIMIT: + logger.warning( + "Telemetry header greater than {} will be truncated by AAD".format( + self._CURRENT_HEADER_SIZE_LIMIT)) + failures = self._buffer.get(self._FAILED, []) + return { + CLIENT_REQUEST_ID: self._correlation_id, + CLIENT_CURRENT_TELEMETRY: current, + CLIENT_LAST_TELEMETRY: "4|{succeeded}|{failed_requests}|{errors}|".format( + succeeded=self._buffer.get(self._SUCCEEDED, 0), + failed_requests=",".join("{a},{c}".format(**f) for f in failures), + errors=",".join(f["e"] for f in failures), + ) + } + + def hit_an_access_token(self): + with self._lock: + self._buffer[self._SUCCEEDED] = self._buffer.get(self._SUCCEEDED, 0) + 1 + + def update_telemetry(self, auth_result): + if auth_result: + with self._lock: + if "error" in auth_result: + self._record_failure(auth_result["error"]) + else: # Telemetry sent successfully. Reset buffer + self._buffer.clear() # This won't work: self._buffer = {} + + def _record_failure(self, error): + simulation = len(",{api_id},{correlation_id},{error}".format( + api_id=self._api_id, correlation_id=self._correlation_id, error=error)) + if self._buffer.get(self._FAILURE_SIZE, 0) + simulation < self._LAST_HEADER_SIZE_LIMIT: + self._buffer[self._FAILURE_SIZE] = self._buffer.get( + self._FAILURE_SIZE, 0) + simulation + self._buffer.setdefault(self._FAILED, []).append({ + "a": self._api_id, "c": self._correlation_id, "e": error}) + diff --git a/msal/token_cache.py b/msal/token_cache.py index edc7dcb6..b0731278 100644 --- a/msal/token_cache.py +++ b/msal/token_cache.py @@ -145,7 +145,7 @@ def __add(self, event, now=None): client_info["uid"] = id_token_claims.get("sub") home_account_id = id_token_claims.get("sub") - target = ' '.join(event.get("scope", [])) # Per schema, we don't sort it + target = ' '.join(event.get("scope") or []) # Per schema, we don't sort it with self._lock: now = int(time.time() if now is None else now) diff --git a/tests/test_application.py b/tests/test_application.py index 93b3d002..f4787e2c 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -7,6 +7,7 @@ from tests import unittest from tests.test_token_cache import TokenCacheTestCase from tests.http_client import MinimalHttpClient, MinimalResponse +from msal.telemetry import CLIENT_CURRENT_TELEMETRY, CLIENT_LAST_TELEMETRY logger = logging.getLogger(__name__) @@ -282,7 +283,7 @@ class TestApplicationForClientCapabilities(unittest.TestCase): def test_capabilities_and_id_token_claims_merge(self): client_capabilities = ["foo", "bar"] claims_challenge = '''{"id_token": {"auth_time": {"essential": true}}}''' - merged_claims = '''{"id_token": {"auth_time": {"essential": true}}, + merged_claims = '''{"id_token": {"auth_time": {"essential": true}}, "access_token": {"xms_cc": {"values": ["foo", "bar"]}}}''' # Comparing dictionaries as JSON object order differs based on python version self.assertEqual( @@ -292,7 +293,7 @@ def test_capabilities_and_id_token_claims_merge(self): def test_capabilities_and_id_token_claims_and_access_token_claims_merge(self): client_capabilities = ["foo", "bar"] - claims_challenge = '''{"id_token": {"auth_time": {"essential": true}}, + claims_challenge = '''{"id_token": {"auth_time": {"essential": true}}, "access_token": {"nbf":{"essential":true, "value":"1563308371"}}}''' merged_claims = '''{"id_token": {"auth_time": {"essential": true}}, "access_token": {"nbf": {"essential": true, "value": "1563308371"}, @@ -324,19 +325,17 @@ class TestApplicationForRefreshInBehaviors(unittest.TestCase): """The following test cases were based on design doc here https://identitydivision.visualstudio.com/DevEx/_git/AuthLibrariesApiReview?path=%2FRefreshAtExpirationPercentage%2Foverview.md&version=GBdev&_a=preview&anchor=scenarios """ + authority_url = "https://login.microsoftonline.com/common" + scopes = ["s1", "s2"] + uid = "my_uid" + utid = "my_utid" + account = {"home_account_id": "{}.{}".format(uid, utid)} + rt = "this is a rt" + client_id = "my_app" + app = ClientApplication(client_id, authority=authority_url) + def setUp(self): - self.authority_url = "https://login.microsoftonline.com/common" - self.authority = msal.authority.Authority( - self.authority_url, MinimalHttpClient()) - self.scopes = ["s1", "s2"] - self.uid = "my_uid" - self.utid = "my_utid" - self.account = {"home_account_id": "{}.{}".format(self.uid, self.utid)} - self.rt = "this is a rt" - self.cache = msal.SerializableTokenCache() - self.client_id = "my_app" - self.app = ClientApplication( - self.client_id, authority=self.authority_url, token_cache=self.cache) + self.app.token_cache = self.cache = msal.SerializableTokenCache() def populate_cache(self, access_token="at", expires_in=86400, refresh_in=43200): self.cache.add({ @@ -353,7 +352,11 @@ def test_fresh_token_should_be_returned_from_cache(self): # a.k.a. Return unexpired token that is not above token refresh expiration threshold access_token = "An access token prepopulated into cache" self.populate_cache(access_token=access_token, expires_in=900, refresh_in=450) - result = self.app.acquire_token_silent(['s1'], self.account) + result = self.app.acquire_token_silent( + ['s1'], self.account, + post=lambda url, *args, **kwargs: # Utilize the undocumented test feature + self.fail("I/O shouldn't happen in cache hit AT scenario") + ) self.assertEqual(access_token, result.get("access_token")) self.assertNotIn("refresh_in", result, "Customers need not know refresh_in") @@ -361,13 +364,13 @@ def test_aging_token_and_available_aad_should_return_new_token(self): # a.k.a. Attempt to refresh unexpired token when AAD available self.populate_cache(access_token="old AT", expires_in=3599, refresh_in=-1) new_access_token = "new AT" - def mock_post(*args, **kwargs): + def mock_post(url, headers=None, *args, **kwargs): + self.assertEqual("4|84,4|", (headers or {}).get(CLIENT_CURRENT_TELEMETRY)) return MinimalResponse(status_code=200, text=json.dumps({ "access_token": new_access_token, "refresh_in": 123, })) - self.app.http_client.post = mock_post - result = self.app.acquire_token_silent(['s1'], self.account) + result = self.app.acquire_token_silent(['s1'], self.account, post=mock_post) self.assertEqual(new_access_token, result.get("access_token")) self.assertNotIn("refresh_in", result, "Customers need not know refresh_in") @@ -375,34 +378,180 @@ def test_aging_token_and_unavailable_aad_should_return_old_token(self): # a.k.a. Attempt refresh unexpired token when AAD unavailable old_at = "old AT" self.populate_cache(access_token=old_at, expires_in=3599, refresh_in=-1) - self.app._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family = ( - lambda *args, **kwargs: {"error": "sth went wrong"}) - self.assertEqual( - old_at, - self.app.acquire_token_silent(['s1'], self.account).get("access_token")) + def mock_post(url, headers=None, *args, **kwargs): + self.assertEqual("4|84,2|", (headers or {}).get(CLIENT_CURRENT_TELEMETRY)) + return MinimalResponse(status_code=400, text=json.dumps({"error": error})) + result = self.app.acquire_token_silent(['s1'], self.account, post=mock_post) + self.assertEqual(old_at, result.get("access_token")) def test_expired_token_and_unavailable_aad_should_return_error(self): # a.k.a. Attempt refresh expired token when AAD unavailable self.populate_cache(access_token="expired at", expires_in=-1, refresh_in=-900) error = "something went wrong" - self.app._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family = ( - lambda *args, **kwargs: {"error": error}) - self.assertEqual( - error, - self.app.acquire_token_silent_with_error( # This variant preserves error - ['s1'], self.account).get("error")) + def mock_post(url, headers=None, *args, **kwargs): + self.assertEqual("4|84,3|", (headers or {}).get(CLIENT_CURRENT_TELEMETRY)) + return MinimalResponse(status_code=400, text=json.dumps({"error": error})) + result = self.app.acquire_token_silent_with_error( + ['s1'], self.account, post=mock_post) + self.assertEqual(error, result.get("error"), "Error should be returned") def test_expired_token_and_available_aad_should_return_new_token(self): # a.k.a. Attempt refresh expired token when AAD available self.populate_cache(access_token="expired at", expires_in=-1, refresh_in=-900) new_access_token = "new AT" - def mock_post(*args, **kwargs): + def mock_post(url, headers=None, *args, **kwargs): + self.assertEqual("4|84,3|", (headers or {}).get(CLIENT_CURRENT_TELEMETRY)) return MinimalResponse(status_code=200, text=json.dumps({ "access_token": new_access_token, "refresh_in": 123, })) - self.app.http_client.post = mock_post - result = self.app.acquire_token_silent(['s1'], self.account) + result = self.app.acquire_token_silent(['s1'], self.account, post=mock_post) self.assertEqual(new_access_token, result.get("access_token")) self.assertNotIn("refresh_in", result, "Customers need not know refresh_in") + +class TestTelemetryMaintainingOfflineState(unittest.TestCase): + authority_url = "https://login.microsoftonline.com/common" + scopes = ["s1", "s2"] + uid = "my_uid" + utid = "my_utid" + account = {"home_account_id": "{}.{}".format(uid, utid)} + rt = "this is a rt" + client_id = "my_app" + + def populate_cache(self, cache, access_token="at"): + cache.add({ + "client_id": self.client_id, + "scope": self.scopes, + "token_endpoint": "{}/oauth2/v2.0/token".format(self.authority_url), + "response": TokenCacheTestCase.build_response( + access_token=access_token, + uid=self.uid, utid=self.utid, refresh_token=self.rt), + }) + + def test_maintaining_offline_state_and_sending_them(self): + app = PublicClientApplication( + self.client_id, + authority=self.authority_url, token_cache=msal.SerializableTokenCache()) + cached_access_token = "cached_at" + self.populate_cache(app.token_cache, access_token=cached_access_token) + + result = app.acquire_token_silent( + self.scopes, self.account, + post=lambda url, *args, **kwargs: # Utilize the undocumented test feature + self.fail("I/O shouldn't happen in cache hit AT scenario") + ) + self.assertEqual(cached_access_token, result.get("access_token")) + + error1 = "error_1" + def mock_post(url, headers=None, *args, **kwargs): + self.assertEqual("4|622,0|", (headers or {}).get(CLIENT_CURRENT_TELEMETRY)) + self.assertEqual("4|1|||", (headers or {}).get(CLIENT_LAST_TELEMETRY), + "The previous cache hit should result in success counter value as 1") + return MinimalResponse(status_code=400, text=json.dumps({"error": error1})) + result = app.acquire_token_by_device_flow({ # It allows customizing correlation_id + "device_code": "123", + PublicClientApplication.DEVICE_FLOW_CORRELATION_ID: "id_1", + }, post=mock_post) + self.assertEqual(error1, result.get("error")) + + error2 = "error_2" + def mock_post(url, headers=None, *args, **kwargs): + self.assertEqual("4|622,0|", (headers or {}).get(CLIENT_CURRENT_TELEMETRY)) + self.assertEqual("4|1|622,id_1|error_1|", (headers or {}).get(CLIENT_LAST_TELEMETRY), + "The previous error should result in same success counter plus latest error info") + return MinimalResponse(status_code=400, text=json.dumps({"error": error2})) + result = app.acquire_token_by_device_flow({ + "device_code": "123", + PublicClientApplication.DEVICE_FLOW_CORRELATION_ID: "id_2", + }, post=mock_post) + self.assertEqual(error2, result.get("error")) + + at = "ensures the successful path (which includes the mock) been used" + def mock_post(url, headers=None, *args, **kwargs): + self.assertEqual("4|622,0|", (headers or {}).get(CLIENT_CURRENT_TELEMETRY)) + self.assertEqual("4|1|622,id_1,622,id_2|error_1,error_2|", (headers or {}).get(CLIENT_LAST_TELEMETRY), + "The previous error should result in same success counter plus latest error info") + return MinimalResponse(status_code=200, text=json.dumps({"access_token": at})) + result = app.acquire_token_by_device_flow({"device_code": "123"}, post=mock_post) + self.assertEqual(at, result.get("access_token")) + + def mock_post(url, headers=None, *args, **kwargs): + self.assertEqual("4|622,0|", (headers or {}).get(CLIENT_CURRENT_TELEMETRY)) + self.assertEqual("4|0|||", (headers or {}).get(CLIENT_LAST_TELEMETRY), + "The previous success should reset all offline telemetry counters") + return MinimalResponse(status_code=200, text=json.dumps({"access_token": at})) + result = app.acquire_token_by_device_flow({"device_code": "123"}, post=mock_post) + self.assertEqual(at, result.get("access_token")) + + +class TestTelemetryOnClientApplication(unittest.TestCase): + app = ClientApplication( + "client_id", authority="https://login.microsoftonline.com/common") + + def test_acquire_token_by_auth_code_flow(self): + at = "this is an access token" + def mock_post(url, headers=None, *args, **kwargs): + self.assertEqual("4|832,0|", (headers or {}).get(CLIENT_CURRENT_TELEMETRY)) + return MinimalResponse(status_code=200, text=json.dumps({"access_token": at})) + state = "foo" + result = self.app.acquire_token_by_auth_code_flow( + {"state": state, "code_verifier": "bar"}, {"state": state, "code": "012"}, + post=mock_post) + self.assertEqual(at, result.get("access_token")) + + def test_acquire_token_by_refresh_token(self): + at = "this is an access token" + def mock_post(url, headers=None, *args, **kwargs): + self.assertEqual("4|85,1|", (headers or {}).get(CLIENT_CURRENT_TELEMETRY)) + return MinimalResponse(status_code=200, text=json.dumps({"access_token": at})) + result = self.app.acquire_token_by_refresh_token("rt", ["s"], post=mock_post) + self.assertEqual(at, result.get("access_token")) + + +class TestTelemetryOnPublicClientApplication(unittest.TestCase): + app = PublicClientApplication( + "client_id", authority="https://login.microsoftonline.com/common") + + # For now, acquire_token_interactive() is verified by code review. + + def test_acquire_token_by_device_flow(self): + at = "this is an access token" + def mock_post(url, headers=None, *args, **kwargs): + self.assertEqual("4|622,0|", (headers or {}).get(CLIENT_CURRENT_TELEMETRY)) + return MinimalResponse(status_code=200, text=json.dumps({"access_token": at})) + result = self.app.acquire_token_by_device_flow( + {"device_code": "123"}, post=mock_post) + self.assertEqual(at, result.get("access_token")) + + def test_acquire_token_by_username_password(self): + at = "this is an access token" + def mock_post(url, headers=None, *args, **kwargs): + self.assertEqual("4|301,0|", (headers or {}).get(CLIENT_CURRENT_TELEMETRY)) + return MinimalResponse(status_code=200, text=json.dumps({"access_token": at})) + result = self.app.acquire_token_by_username_password( + "username", "password", ["scope"], post=mock_post) + self.assertEqual(at, result.get("access_token")) + + +class TestTelemetryOnConfidentialClientApplication(unittest.TestCase): + app = ConfidentialClientApplication( + "client_id", client_credential="secret", + authority="https://login.microsoftonline.com/common") + + def test_acquire_token_for_client(self): + at = "this is an access token" + def mock_post(url, headers=None, *args, **kwargs): + self.assertEqual("4|730,0|", (headers or {}).get(CLIENT_CURRENT_TELEMETRY)) + return MinimalResponse(status_code=200, text=json.dumps({"access_token": at})) + result = self.app.acquire_token_for_client(["scope"], post=mock_post) + self.assertEqual(at, result.get("access_token")) + + def test_acquire_token_on_behalf_of(self): + at = "this is an access token" + def mock_post(url, headers=None, *args, **kwargs): + self.assertEqual("4|523,0|", (headers or {}).get(CLIENT_CURRENT_TELEMETRY)) + return MinimalResponse(status_code=200, text=json.dumps({"access_token": at})) + result = self.app.acquire_token_on_behalf_of("assertion", ["s"], post=mock_post) + self.assertEqual(at, result.get("access_token")) +