From 36365ac43c412ea54ef3637e560b24a1ad424786 Mon Sep 17 00:00:00 2001 From: Ray Luo Date: Thu, 11 Feb 2021 20:19:37 -0800 Subject: [PATCH] Implement refresh_in behavior, and some test cases --- msal/application.py | 17 +++++++-- msal/token_cache.py | 3 ++ tests/test_application.py | 80 +++++++++++++++++++++++++++++++++++++++ tests/test_token_cache.py | 30 +++++++++------ 4 files changed, 115 insertions(+), 15 deletions(-) diff --git a/msal/application.py b/msal/application.py index a1f50038..72bbecf3 100644 --- a/msal/application.py +++ b/msal/application.py @@ -822,6 +822,7 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it( force_refresh=False, # type: Optional[boolean] claims_challenge=None, **kwargs): + access_token_from_cache = None if not (force_refresh or claims_challenge): # Bypass AT when desired or using claims query={ "client_id": self.client_id, @@ -839,17 +840,27 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it( now = time.time() for entry in matches: expires_in = int(entry["expires_on"]) - now - if expires_in < 5*60: + if expires_in < 5*60: # Then consider it expired continue # Removal is not necessary, it will be overwritten logger.debug("Cache hit an AT") - return { # Mimic a real response + access_token_from_cache = { # Mimic a real response "access_token": entry["secret"], "token_type": entry.get("token_type", "Bearer"), "expires_in": int(expires_in), # OAuth2 specs defines it as int } - return self._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family( + if "refresh_on" in entry and int(entry["refresh_on"]) < now: # aging + break # With a fallback in hand, we break here to go refresh + return access_token_from_cache # It is still good as new + try: + result = self._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family( authority, decorate_scope(scopes, self.client_id), account, force_refresh=force_refresh, claims_challenge=claims_challenge, **kwargs) + if (result and "error" not in result) or (not access_token_from_cache): + return result + except: # The exact HTTP exception is transportation-layer dependent + logger.exception("Refresh token failed") # Potential AAD outage? + return access_token_from_cache + def _acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family( self, authority, scopes, account, **kwargs): diff --git a/msal/token_cache.py b/msal/token_cache.py index 34eff37c..028635b5 100644 --- a/msal/token_cache.py +++ b/msal/token_cache.py @@ -170,6 +170,9 @@ def __add(self, event, now=None): } if data.get("key_id"): # It happens in SSH-cert or POP scenario at["key_id"] = data.get("key_id") + if "refresh_in" in response: + refresh_in = response["refresh_in"] # It is an integer + at["refresh_on"] = str(now + refresh_in) # Schema wants a string self.modify(self.CredentialType.ACCESS_TOKEN, at, at) if client_info and not event.get("skip_account_creation"): diff --git a/tests/test_application.py b/tests/test_application.py index 8d48a0ac..3c3b4644 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -319,3 +319,83 @@ def test_only_client_capabilities_no_claims_merge(self): def test_both_claims_and_capabilities_none(self): self.assertEqual(_merge_claims_challenge_and_capabilities(None, None), None) + + +class TestApplicationForRefreshInBehaviors(unittest.TestCase): + """The following test cases were based on design doc here + https://identitydivision.visualstudio.com/DevEx/_git/AuthLibrariesApiReview?path=%2FRefreshAtExpirationPercentage%2Foverview.md&version=GBdev&_a=preview&anchor=scenarios + """ + def setUp(self): + self.authority_url = "https://login.microsoftonline.com/common" + self.authority = msal.authority.Authority( + self.authority_url, MinimalHttpClient()) + self.scopes = ["s1", "s2"] + self.uid = "my_uid" + self.utid = "my_utid" + self.account = {"home_account_id": "{}.{}".format(self.uid, self.utid)} + self.rt = "this is a rt" + self.cache = msal.SerializableTokenCache() + self.client_id = "my_app" + self.app = ClientApplication( + self.client_id, authority=self.authority_url, token_cache=self.cache) + + def populate_cache(self, access_token="at", expires_in=86400, refresh_in=43200): + self.cache.add({ + "client_id": self.client_id, + "scope": self.scopes, + "token_endpoint": "{}/oauth2/v2.0/token".format(self.authority_url), + "response": TokenCacheTestCase.build_response( + access_token=access_token, + expires_in=expires_in, refresh_in=refresh_in, + uid=self.uid, utid=self.utid, refresh_token=self.rt), + }) + + def test_fresh_token_should_be_returned_from_cache(self): + # a.k.a. Return unexpired token that is not above token refresh expiration threshold + access_token = "An access token prepopulated into cache" + self.populate_cache(access_token=access_token, expires_in=900, refresh_in=450) + self.assertEqual( + access_token, + self.app.acquire_token_silent(['s1'], self.account).get("access_token")) + + def test_aging_token_and_available_aad_should_return_new_token(self): + # a.k.a. Attempt to refresh unexpired token when AAD available + self.populate_cache(access_token="old AT", expires_in=3599, refresh_in=-1) + new_access_token = "new AT" + self.app._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family = ( + lambda *args, **kwargs: {"access_token": new_access_token}) + self.assertEqual( + new_access_token, + self.app.acquire_token_silent(['s1'], self.account).get("access_token")) + + def test_aging_token_and_unavailable_aad_should_return_old_token(self): + # a.k.a. Attempt refresh unexpired token when AAD unavailable + old_at = "old AT" + self.populate_cache(access_token=old_at, expires_in=3599, refresh_in=-1) + self.app._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family = ( + lambda *args, **kwargs: {"error": "sth went wrong"}) + self.assertEqual( + old_at, + self.app.acquire_token_silent(['s1'], self.account).get("access_token")) + + def test_expired_token_and_unavailable_aad_should_return_error(self): + # a.k.a. Attempt refresh expired token when AAD unavailable + self.populate_cache(access_token="expired at", expires_in=-1, refresh_in=-900) + error = "something went wrong" + self.app._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family = ( + lambda *args, **kwargs: {"error": error}) + self.assertEqual( + error, + self.app.acquire_token_silent_with_error( # This variant preserves error + ['s1'], self.account).get("error")) + + def test_expired_token_and_available_aad_should_return_new_token(self): + # a.k.a. Attempt refresh expired token when AAD available + self.populate_cache(access_token="expired at", expires_in=-1, refresh_in=-900) + new_access_token = "new AT" + self.app._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family = ( + lambda *args, **kwargs: {"access_token": new_access_token}) + self.assertEqual( + new_access_token, + self.app.acquire_token_silent(['s1'], self.account).get("access_token")) + diff --git a/tests/test_token_cache.py b/tests/test_token_cache.py index c846883d..92ab7c33 100644 --- a/tests/test_token_cache.py +++ b/tests/test_token_cache.py @@ -29,30 +29,20 @@ def build_id_token( def build_response( # simulate a response from AAD uid=None, utid=None, # If present, they will form client_info access_token=None, expires_in=3600, token_type="some type", - refresh_token=None, - foci=None, - id_token=None, # or something generated by build_id_token() - error=None, + **kwargs # Pass-through: refresh_token, foci, id_token, error, refresh_in, ... ): response = {} if uid and utid: # Mimic the AAD behavior for "client_info=1" request response["client_info"] = base64.b64encode(json.dumps({ "uid": uid, "utid": utid, }).encode()).decode('utf-8') - if error: - response["error"] = error if access_token: response.update({ "access_token": access_token, "expires_in": expires_in, "token_type": token_type, }) - if refresh_token: - response["refresh_token"] = refresh_token - if id_token: - response["id_token"] = id_token - if foci: - response["foci"] = foci + response.update(kwargs) # Pass-through key-value pairs as top-level fields return response def setUp(self): @@ -222,6 +212,21 @@ def test_key_id_is_also_recorded(self): {}).get("key_id") self.assertEqual(my_key_id, cached_key_id, "AT should be bound to the key") + def test_refresh_in_should_be_recorded_as_refresh_on(self): # Sounds weird. Yep. + self.cache.add({ + "client_id": "my_client_id", + "scope": ["s2", "s1", "s3"], # Not in particular order + "token_endpoint": "https://login.example.com/contoso/v2/token", + "response": self.build_response( + uid="uid", utid="utid", # client_info + expires_in=3600, refresh_in=1800, access_token="an access token", + ), #refresh_token="a refresh token"), + }, now=1000) + refresh_on = self.cache._cache["AccessToken"].get( + 'uid.utid-login.example.com-accesstoken-my_client_id-contoso-s2 s1 s3', + {}).get("refresh_on") + self.assertEqual("2800", refresh_on, "Should save refresh_on") + def test_old_rt_data_with_wrong_key_should_still_be_salvaged_into_new_rt(self): sample = { 'client_id': 'my_client_id', @@ -241,6 +246,7 @@ def test_old_rt_data_with_wrong_key_should_still_be_salvaged_into_new_rt(self): 'uid.utid-login.example.com-refreshtoken-my_client_id--s2 s1 s3') ) + class SerializableTokenCacheTestCase(TokenCacheTestCase): # Run all inherited test methods, and have extra check in tearDown()