Skip to content

Commit

Permalink
Merge pull request #312 from AzureAD/refresh-in
Browse files Browse the repository at this point in the history
Implement refresh_in behavior, and some test cases
  • Loading branch information
rayluo authored Feb 23, 2021
2 parents 963505f + 36365ac commit 78e9ccf
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 15 deletions.
17 changes: 14 additions & 3 deletions msal/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -822,6 +822,7 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it(
force_refresh=False, # type: Optional[boolean]
claims_challenge=None,
**kwargs):
access_token_from_cache = None
if not (force_refresh or claims_challenge): # Bypass AT when desired or using claims
query={
"client_id": self.client_id,
Expand All @@ -839,17 +840,27 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it(
now = time.time()
for entry in matches:
expires_in = int(entry["expires_on"]) - now
if expires_in < 5*60:
if expires_in < 5*60: # Then consider it expired
continue # Removal is not necessary, it will be overwritten
logger.debug("Cache hit an AT")
return { # Mimic a real response
access_token_from_cache = { # Mimic a real response
"access_token": entry["secret"],
"token_type": entry.get("token_type", "Bearer"),
"expires_in": int(expires_in), # OAuth2 specs defines it as int
}
return self._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family(
if "refresh_on" in entry and int(entry["refresh_on"]) < now: # aging
break # With a fallback in hand, we break here to go refresh
return access_token_from_cache # It is still good as new
try:
result = self._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family(
authority, decorate_scope(scopes, self.client_id), account,
force_refresh=force_refresh, claims_challenge=claims_challenge, **kwargs)
if (result and "error" not in result) or (not access_token_from_cache):
return result
except: # The exact HTTP exception is transportation-layer dependent
logger.exception("Refresh token failed") # Potential AAD outage?
return access_token_from_cache


def _acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family(
self, authority, scopes, account, **kwargs):
Expand Down
3 changes: 3 additions & 0 deletions msal/token_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,9 @@ def __add(self, event, now=None):
}
if data.get("key_id"): # It happens in SSH-cert or POP scenario
at["key_id"] = data.get("key_id")
if "refresh_in" in response:
refresh_in = response["refresh_in"] # It is an integer
at["refresh_on"] = str(now + refresh_in) # Schema wants a string
self.modify(self.CredentialType.ACCESS_TOKEN, at, at)

if client_info and not event.get("skip_account_creation"):
Expand Down
80 changes: 80 additions & 0 deletions tests/test_application.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,3 +319,83 @@ def test_only_client_capabilities_no_claims_merge(self):

def test_both_claims_and_capabilities_none(self):
self.assertEqual(_merge_claims_challenge_and_capabilities(None, None), None)


class TestApplicationForRefreshInBehaviors(unittest.TestCase):
"""The following test cases were based on design doc here
https://identitydivision.visualstudio.com/DevEx/_git/AuthLibrariesApiReview?path=%2FRefreshAtExpirationPercentage%2Foverview.md&version=GBdev&_a=preview&anchor=scenarios
"""
def setUp(self):
self.authority_url = "https://login.microsoftonline.com/common"
self.authority = msal.authority.Authority(
self.authority_url, MinimalHttpClient())
self.scopes = ["s1", "s2"]
self.uid = "my_uid"
self.utid = "my_utid"
self.account = {"home_account_id": "{}.{}".format(self.uid, self.utid)}
self.rt = "this is a rt"
self.cache = msal.SerializableTokenCache()
self.client_id = "my_app"
self.app = ClientApplication(
self.client_id, authority=self.authority_url, token_cache=self.cache)

def populate_cache(self, access_token="at", expires_in=86400, refresh_in=43200):
self.cache.add({
"client_id": self.client_id,
"scope": self.scopes,
"token_endpoint": "{}/oauth2/v2.0/token".format(self.authority_url),
"response": TokenCacheTestCase.build_response(
access_token=access_token,
expires_in=expires_in, refresh_in=refresh_in,
uid=self.uid, utid=self.utid, refresh_token=self.rt),
})

def test_fresh_token_should_be_returned_from_cache(self):
# a.k.a. Return unexpired token that is not above token refresh expiration threshold
access_token = "An access token prepopulated into cache"
self.populate_cache(access_token=access_token, expires_in=900, refresh_in=450)
self.assertEqual(
access_token,
self.app.acquire_token_silent(['s1'], self.account).get("access_token"))

def test_aging_token_and_available_aad_should_return_new_token(self):
# a.k.a. Attempt to refresh unexpired token when AAD available
self.populate_cache(access_token="old AT", expires_in=3599, refresh_in=-1)
new_access_token = "new AT"
self.app._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family = (
lambda *args, **kwargs: {"access_token": new_access_token})
self.assertEqual(
new_access_token,
self.app.acquire_token_silent(['s1'], self.account).get("access_token"))

def test_aging_token_and_unavailable_aad_should_return_old_token(self):
# a.k.a. Attempt refresh unexpired token when AAD unavailable
old_at = "old AT"
self.populate_cache(access_token=old_at, expires_in=3599, refresh_in=-1)
self.app._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family = (
lambda *args, **kwargs: {"error": "sth went wrong"})
self.assertEqual(
old_at,
self.app.acquire_token_silent(['s1'], self.account).get("access_token"))

def test_expired_token_and_unavailable_aad_should_return_error(self):
# a.k.a. Attempt refresh expired token when AAD unavailable
self.populate_cache(access_token="expired at", expires_in=-1, refresh_in=-900)
error = "something went wrong"
self.app._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family = (
lambda *args, **kwargs: {"error": error})
self.assertEqual(
error,
self.app.acquire_token_silent_with_error( # This variant preserves error
['s1'], self.account).get("error"))

def test_expired_token_and_available_aad_should_return_new_token(self):
# a.k.a. Attempt refresh expired token when AAD available
self.populate_cache(access_token="expired at", expires_in=-1, refresh_in=-900)
new_access_token = "new AT"
self.app._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family = (
lambda *args, **kwargs: {"access_token": new_access_token})
self.assertEqual(
new_access_token,
self.app.acquire_token_silent(['s1'], self.account).get("access_token"))

30 changes: 18 additions & 12 deletions tests/test_token_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,30 +29,20 @@ def build_id_token(
def build_response( # simulate a response from AAD
uid=None, utid=None, # If present, they will form client_info
access_token=None, expires_in=3600, token_type="some type",
refresh_token=None,
foci=None,
id_token=None, # or something generated by build_id_token()
error=None,
**kwargs # Pass-through: refresh_token, foci, id_token, error, refresh_in, ...
):
response = {}
if uid and utid: # Mimic the AAD behavior for "client_info=1" request
response["client_info"] = base64.b64encode(json.dumps({
"uid": uid, "utid": utid,
}).encode()).decode('utf-8')
if error:
response["error"] = error
if access_token:
response.update({
"access_token": access_token,
"expires_in": expires_in,
"token_type": token_type,
})
if refresh_token:
response["refresh_token"] = refresh_token
if id_token:
response["id_token"] = id_token
if foci:
response["foci"] = foci
response.update(kwargs) # Pass-through key-value pairs as top-level fields
return response

def setUp(self):
Expand Down Expand Up @@ -222,6 +212,21 @@ def test_key_id_is_also_recorded(self):
{}).get("key_id")
self.assertEqual(my_key_id, cached_key_id, "AT should be bound to the key")

def test_refresh_in_should_be_recorded_as_refresh_on(self): # Sounds weird. Yep.
self.cache.add({
"client_id": "my_client_id",
"scope": ["s2", "s1", "s3"], # Not in particular order
"token_endpoint": "https://login.example.com/contoso/v2/token",
"response": self.build_response(
uid="uid", utid="utid", # client_info
expires_in=3600, refresh_in=1800, access_token="an access token",
), #refresh_token="a refresh token"),
}, now=1000)
refresh_on = self.cache._cache["AccessToken"].get(
'uid.utid-login.example.com-accesstoken-my_client_id-contoso-s2 s1 s3',
{}).get("refresh_on")
self.assertEqual("2800", refresh_on, "Should save refresh_on")

def test_old_rt_data_with_wrong_key_should_still_be_salvaged_into_new_rt(self):
sample = {
'client_id': 'my_client_id',
Expand All @@ -241,6 +246,7 @@ def test_old_rt_data_with_wrong_key_should_still_be_salvaged_into_new_rt(self):
'uid.utid-login.example.com-refreshtoken-my_client_id--s2 s1 s3')
)


class SerializableTokenCacheTestCase(TokenCacheTestCase):
# Run all inherited test methods, and have extra check in tearDown()

Expand Down

0 comments on commit 78e9ccf

Please sign in to comment.