diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index c6b90bfd..09840ba6 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -8,6 +8,9 @@ on: pull_request: branches: [ dev ] + # This guards against unknown PR until a community member vet it and label it. + types: [ labeled ] + jobs: ci: env: diff --git a/.gitignore b/.gitignore index ff05e560..e776c10e 100644 --- a/.gitignore +++ b/.gitignore @@ -45,7 +45,8 @@ src/build # Virtual Environments /env* - +.venv/ +docs/_build/ # Visual Studio Files /.vs/* /tests/.vs/* diff --git a/docs/conf.py b/docs/conf.py index 251cf948..810dfc02 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -12,6 +12,7 @@ # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. # +from datetime import date import os import sys sys.path.insert(0, os.path.abspath('..')) @@ -20,7 +21,7 @@ # -- Project information ----------------------------------------------------- project = u'MSAL Python' -copyright = u'2018, Microsoft' +copyright = u'{0}, Microsoft'.format(date.today().year) author = u'Microsoft' # The short X.Y version @@ -77,13 +78,18 @@ # a list of builtin themes. # # html_theme = 'alabaster' -html_theme = 'sphinx_rtd_theme' +html_theme = 'furo' # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the # documentation. # -# html_theme_options = {} +html_theme_options = { + "light_css_variables": { + "font-stack": "'Segoe UI', SegoeUI, 'Helvetica Neue', Helvetica, Arial, sans-serif", + "font-stack--monospace": "SFMono-Regular, Consolas, 'Liberation Mono', Menlo, Courier, monospace", + }, +} # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, @@ -176,4 +182,4 @@ epub_exclude_files = ['search.html'] -# -- Extension configuration ------------------------------------------------- +# -- Extension configuration ------------------------------------------------- \ No newline at end of file diff --git a/docs/index.rst b/docs/index.rst index baad12fd..439ca0ee 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -1,17 +1,13 @@ -.. MSAL Python documentation master file, created by - sphinx-quickstart on Tue Dec 18 10:53:22 2018. - You can adapt this file completely to your liking, but it should at least - contain the root `toctree` directive. - -.. This file is also inspired by - https://pythonhosted.org/an_example_pypi_project/sphinx.html#full-code-example - -Welcome to MSAL Python's documentation! -======================================= +MSAL Python documentation +========================= .. toctree:: :maxdepth: 2 :caption: Contents: + :hidden: + + MSAL Documentation + GitHub Repository You can find high level conceptual documentations in the project `README `_ @@ -22,9 +18,8 @@ and The documentation hosted here is for API Reference. - -PublicClientApplication and ConfidentialClientApplication -========================================================= +API +=== MSAL proposes a clean separation between `public client applications and confidential client applications @@ -35,31 +30,22 @@ with different methods for different authentication scenarios. PublicClientApplication ----------------------- + .. 
autoclass:: msal.PublicClientApplication :members: + :inherited-members: ConfidentialClientApplication ----------------------------- -.. autoclass:: msal.ConfidentialClientApplication - :members: - -Shared Methods --------------- -Both PublicClientApplication and ConfidentialClientApplication -have following methods inherited from their base class. -You typically do not need to initiate this base class, though. - -.. autoclass:: msal.ClientApplication +.. autoclass:: msal.ConfidentialClientApplication :members: - - .. automethod:: __init__ - + :inherited-members: TokenCache -========== +---------- -One of the parameter accepted by +One of the parameters accepted by both `PublicClientApplication` and `ConfidentialClientApplication` is the `TokenCache`. @@ -71,11 +57,3 @@ See `SerializableTokenCache` for example. .. autoclass:: msal.SerializableTokenCache :members: - - -Indices and tables -================== - -* :ref:`genindex` -* :ref:`search` - diff --git a/docs/requirements.txt b/docs/requirements.txt new file mode 100644 index 00000000..d5de57fe --- /dev/null +++ b/docs/requirements.txt @@ -0,0 +1,2 @@ +furo +-r ../requirements.txt \ No newline at end of file diff --git a/msal/application.py b/msal/application.py index a1f50038..c4a46b1f 100644 --- a/msal/application.py +++ b/msal/application.py @@ -21,7 +21,7 @@ # The __init__.py will import this. Not the other way around. -__version__ = "1.9.0" +__version__ = "1.10.0" logger = logging.getLogger(__name__) @@ -100,6 +100,12 @@ def _str2bytes(raw): return raw +def _clean_up(result): + if isinstance(result, dict): + result.pop("refresh_in", None) # MSAL handled refresh_in, customers need not + return result + + class ClientApplication(object): ACQUIRE_TOKEN_SILENT_ID = "84" @@ -507,7 +513,7 @@ def authorize(): # A controller in a web app return redirect(url_for("index")) """ self._validate_ssh_cert_input_data(kwargs.get("data", {})) - return self.client.obtain_token_by_auth_code_flow( + return _clean_up(self.client.obtain_token_by_auth_code_flow( auth_code_flow, auth_response, scope=decorate_scope(scopes, self.client_id) if scopes else None, @@ -521,7 +527,7 @@ def authorize(): # A controller in a web app claims=_merge_claims_challenge_and_capabilities( self._client_capabilities, auth_code_flow.pop("claims_challenge", None))), - **kwargs) + **kwargs)) def acquire_token_by_authorization_code( self, @@ -580,7 +586,7 @@ def acquire_token_by_authorization_code( "Change your acquire_token_by_authorization_code() " "to acquire_token_by_auth_code_flow()", DeprecationWarning) with warnings.catch_warnings(record=True): - return self.client.obtain_token_by_authorization_code( + return _clean_up(self.client.obtain_token_by_authorization_code( code, redirect_uri=redirect_uri, scope=decorate_scope(scopes, self.client_id), headers={ @@ -593,7 +599,7 @@ def acquire_token_by_authorization_code( claims=_merge_claims_challenge_and_capabilities( self._client_capabilities, claims_challenge)), nonce=nonce, - **kwargs) + **kwargs)) def get_accounts(self, username=None): """Get a list of accounts which previously signed in, i.e. exists in cache. 
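
The _clean_up() wrapper introduced above changes what callers observe: the server-supplied refresh_in hint is consumed by MSAL itself and stripped before the result dict reaches application code. A minimal illustration of that contract, using made-up response values (not an MSAL API call):

    # Mirrors the _clean_up() helper added in this diff; values are illustrative.
    def _clean_up(result):
        if isinstance(result, dict):
            result.pop("refresh_in", None)  # consumed internally, hidden from callers
        return result

    raw = {"access_token": "AT", "expires_in": 3599, "refresh_in": 1800}
    cleaned = _clean_up(raw)
    assert "access_token" in cleaned and "refresh_in" not in cleaned
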
@@ -822,6 +828,7 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it( force_refresh=False, # type: Optional[boolean] claims_challenge=None, **kwargs): + access_token_from_cache = None if not (force_refresh or claims_challenge): # Bypass AT when desired or using claims query={ "client_id": self.client_id, @@ -839,17 +846,27 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it( now = time.time() for entry in matches: expires_in = int(entry["expires_on"]) - now - if expires_in < 5*60: + if expires_in < 5*60: # Then consider it expired continue # Removal is not necessary, it will be overwritten logger.debug("Cache hit an AT") - return { # Mimic a real response + access_token_from_cache = { # Mimic a real response "access_token": entry["secret"], "token_type": entry.get("token_type", "Bearer"), "expires_in": int(expires_in), # OAuth2 specs defines it as int } - return self._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family( + if "refresh_on" in entry and int(entry["refresh_on"]) < now: # aging + break # With a fallback in hand, we break here to go refresh + return access_token_from_cache # It is still good as new + try: + result = self._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family( authority, decorate_scope(scopes, self.client_id), account, force_refresh=force_refresh, claims_challenge=claims_challenge, **kwargs) + result = _clean_up(result) + if (result and "error" not in result) or (not access_token_from_cache): + return result + except: # The exact HTTP exception is transportation-layer dependent + logger.exception("Refresh token failed") # Potential AAD outage? + return access_token_from_cache def _acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family( self, authority, scopes, account, **kwargs): @@ -907,11 +924,17 @@ def _acquire_token_silent_by_finding_specific_refresh_token( client = self._build_client(self.client_credential, authority) response = None # A distinguishable value to mean cache is empty - for entry in matches: + for entry in sorted( # Since unfit RTs would not be aggressively removed, + # we start from newer RTs which are more likely fit. + matches, + key=lambda e: int(e.get("last_modification_time", "0")), + reverse=True): logger.debug("Cache attempts an RT") response = client.obtain_token_by_refresh_token( entry, rt_getter=lambda token_item: token_item["secret"], - on_removing_rt=rt_remover or self.token_cache.remove_rt, + on_removing_rt=lambda rt_item: None, # Disable RT removal, + # because an invalid_grant could be caused by new MFA policy, + # the RT could still be useful for other MFA-less scope or tenant on_obtaining_tokens=lambda event: self.token_cache.add(dict( event, environment=authority.instance, @@ -976,7 +999,7 @@ def acquire_token_by_refresh_token(self, refresh_token, scopes, **kwargs): * A dict contains no "error" key means migration was successful. """ self._validate_ssh_cert_input_data(kwargs.get("data", {})) - return self.client.obtain_token_by_refresh_token( + return _clean_up(self.client.obtain_token_by_refresh_token( refresh_token, scope=decorate_scope(scopes, self.client_id), headers={ @@ -987,7 +1010,7 @@ def acquire_token_by_refresh_token(self, refresh_token, scopes, **kwargs): rt_getter=lambda rt: rt, on_updating_rt=False, on_removing_rt=lambda rt_item: None, # No OP - **kwargs) + **kwargs)) class PublicClientApplication(ClientApplication): # browser app or mobile app @@ -1013,6 +1036,9 @@ def acquire_token_interactive( **kwargs): """Acquire token interactively i.e. via a local browser. 
+ Prerequisite: In Azure Portal, configure the Redirect URI of your + "Mobile and Desktop application" as ``http://localhost``. + :param list scope: It is a list of case-sensitive strings. :param str prompt: @@ -1061,7 +1087,7 @@ def acquire_token_interactive( self._validate_ssh_cert_input_data(kwargs.get("data", {})) claims = _merge_claims_challenge_and_capabilities( self._client_capabilities, claims_challenge) - return self.client.obtain_token_by_browser( + return _clean_up(self.client.obtain_token_by_browser( scope=decorate_scope(scopes, self.client_id) if scopes else None, extra_scope_to_consent=extra_scopes_to_consent, redirect_uri="http://localhost:{port}".format( @@ -1080,7 +1106,7 @@ def acquire_token_interactive( CLIENT_CURRENT_TELEMETRY: _build_current_telemetry_request_header( self.ACQUIRE_TOKEN_INTERACTIVE), }, - **kwargs) + **kwargs)) def initiate_device_flow(self, scopes=None, **kwargs): """Initiate a Device Flow instance, @@ -1123,7 +1149,7 @@ def acquire_token_by_device_flow(self, flow, claims_challenge=None, **kwargs): - A successful response would contain "access_token" key, - an error response would contain "error" and usually "error_description". """ - return self.client.obtain_token_by_device_flow( + return _clean_up(self.client.obtain_token_by_device_flow( flow, data=dict( kwargs.pop("data", {}), @@ -1139,7 +1165,7 @@ def acquire_token_by_device_flow(self, flow, claims_challenge=None, **kwargs): CLIENT_CURRENT_TELEMETRY: _build_current_telemetry_request_header( self.ACQUIRE_TOKEN_BY_DEVICE_FLOW_ID), }, - **kwargs) + **kwargs)) def acquire_token_by_username_password( self, username, password, scopes, claims_challenge=None, **kwargs): @@ -1177,15 +1203,15 @@ def acquire_token_by_username_password( user_realm_result = self.authority.user_realm_discovery( username, correlation_id=headers[CLIENT_REQUEST_ID]) if user_realm_result.get("account_type") == "Federated": - return self._acquire_token_by_username_password_federated( + return _clean_up(self._acquire_token_by_username_password_federated( user_realm_result, username, password, scopes=scopes, data=data, - headers=headers, **kwargs) - return self.client.obtain_token_by_username_password( + headers=headers, **kwargs)) + return _clean_up(self.client.obtain_token_by_username_password( username, password, scope=scopes, headers=headers, data=data, - **kwargs) + **kwargs)) def _acquire_token_by_username_password_federated( self, user_realm_result, username, password, scopes=None, **kwargs): @@ -1245,7 +1271,7 @@ def acquire_token_for_client(self, scopes, claims_challenge=None, **kwargs): """ # TBD: force_refresh behavior self._validate_ssh_cert_input_data(kwargs.get("data", {})) - return self.client.obtain_token_for_client( + return _clean_up(self.client.obtain_token_for_client( scope=scopes, # This grant flow requires no scope decoration headers={ CLIENT_REQUEST_ID: _get_new_correlation_id(), @@ -1256,7 +1282,7 @@ def acquire_token_for_client(self, scopes, claims_challenge=None, **kwargs): kwargs.pop("data", {}), claims=_merge_claims_challenge_and_capabilities( self._client_capabilities, claims_challenge)), - **kwargs) + **kwargs)) def acquire_token_on_behalf_of(self, user_assertion, scopes, claims_challenge=None, **kwargs): """Acquires token using on-behalf-of (OBO) flow. 
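
To summarize the proactive-refresh behavior added to _acquire_token_silent_from_cache_and_possibly_refresh_it() above: a cached access token past its refresh_on point is refreshed ahead of expiry, while the still-valid cached token is retained as a fallback in case AAD returns an error or the request raises. A hedged sketch of that order of preference, using hypothetical helper names rather than the real MSAL internals:

    import logging
    import time

    logger = logging.getLogger(__name__)

    def acquire_silently(cached_at_entry, refresh_with_rt, now=None):
        # cached_at_entry mimics a cached AT: {"secret", "expires_on", "refresh_on"}
        now = now if now is not None else time.time()
        fallback = None
        if cached_at_entry:
            expires_in = int(cached_at_entry["expires_on"]) - now
            if expires_in >= 5 * 60:  # Not (nearly) expired
                fallback = {"access_token": cached_at_entry["secret"],
                            "token_type": "Bearer", "expires_in": int(expires_in)}
                refresh_on = cached_at_entry.get("refresh_on")
                if not (refresh_on and int(refresh_on) < now):
                    return fallback  # Still fresh: no network call at all
        try:
            result = refresh_with_rt()  # Aging or expired: spend the refresh token
            if (result and "error" not in result) or fallback is None:
                return result  # A new token, or the error when nothing was cached
        except Exception:  # e.g. an AAD outage surfacing from the transport layer
            logger.exception("Refresh token request failed")
        return fallback  # Serve the still-valid cached token as a safety net

The TestApplicationForRefreshInBehaviors cases added in tests/test_application.py below exercise exactly these branches (fresh, aging, and expired tokens, with AAD available or unavailable).
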
@@ -1286,7 +1312,7 @@ def acquire_token_on_behalf_of(self, user_assertion, scopes, claims_challenge=No """ # The implementation is NOT based on Token Exchange # https://tools.ietf.org/html/draft-ietf-oauth-token-exchange-16 - return self.client.obtain_token_by_assertion( # bases on assertion RFC 7521 + return _clean_up(self.client.obtain_token_by_assertion( # bases on assertion RFC 7521 user_assertion, self.client.GRANT_TYPE_JWT, # IDTs and AAD ATs are all JWTs scope=decorate_scope(scopes, self.client_id), # Decoration is used for: @@ -1305,4 +1331,4 @@ def acquire_token_on_behalf_of(self, user_assertion, scopes, claims_challenge=No CLIENT_CURRENT_TELEMETRY: _build_current_telemetry_request_header( self.ACQUIRE_TOKEN_ON_BEHALF_OF_ID), }, - **kwargs) + **kwargs)) diff --git a/msal/token_cache.py b/msal/token_cache.py index 34eff37c..edc7dcb6 100644 --- a/msal/token_cache.py +++ b/msal/token_cache.py @@ -148,9 +148,9 @@ def __add(self, event, now=None): target = ' '.join(event.get("scope", [])) # Per schema, we don't sort it with self._lock: + now = int(time.time() if now is None else now) if access_token: - now = int(time.time() if now is None else now) expires_in = int( # AADv1-like endpoint returns a string response.get("expires_in", 3599)) ext_expires_in = int( # AADv1-like endpoint returns a string @@ -170,6 +170,9 @@ def __add(self, event, now=None): } if data.get("key_id"): # It happens in SSH-cert or POP scenario at["key_id"] = data.get("key_id") + if "refresh_in" in response: + refresh_in = response["refresh_in"] # It is an integer + at["refresh_on"] = str(now + refresh_in) # Schema wants a string self.modify(self.CredentialType.ACCESS_TOKEN, at, at) if client_info and not event.get("skip_account_creation"): @@ -209,6 +212,7 @@ def __add(self, event, now=None): "environment": environment, "client_id": event.get("client_id"), "target": target, # Optional per schema though + "last_modification_time": str(now), # Optional. Schema defines it as a string. } if "foci" in response: rt["family_id"] = response["foci"] @@ -246,8 +250,10 @@ def remove_rt(self, rt_item): def update_rt(self, rt_item, new_rt): assert rt_item.get("credential_type") == self.CredentialType.REFRESH_TOKEN - return self.modify( - self.CredentialType.REFRESH_TOKEN, rt_item, {"secret": new_rt}) + return self.modify(self.CredentialType.REFRESH_TOKEN, rt_item, { + "secret": new_rt, + "last_modification_time": str(int(time.time())), # Optional. Schema defines it as a string. + }) def remove_at(self, at_item): assert at_item.get("credential_type") == self.CredentialType.ACCESS_TOKEN diff --git a/sample/confidential_client_certificate_sample.py b/sample/confidential_client_certificate_sample.py index e3b1bf86..7e5d8069 100644 --- a/sample/confidential_client_certificate_sample.py +++ b/sample/confidential_client_certificate_sample.py @@ -48,7 +48,7 @@ client_credential={"thumbprint": config["thumbprint"], "private_key": open(config['private_key_file']).read()}, # token_cache=... # Default cache is in memory only. # You can learn how to use SerializableTokenCache from - # https://msal-python.rtfd.io/en/latest/#msal.SerializableTokenCache + # https://msal-python.readthedocs.io/en/latest/#msal.SerializableTokenCache ) # The pattern to acquire a token looks like this. 
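
For the refresh_on bookkeeping added to msal/token_cache.py above, the stored value is simply the cache-write time plus the server-supplied refresh_in, kept as a string per the cache schema. A small worked example with arbitrary numbers:

    now = 1000             # seconds; when the response is written to the cache
    refresh_in = 1800      # integer hint returned by the token endpoint
    refresh_on = str(now + refresh_in)   # the schema stores timestamps as strings
    assert refresh_on == "2800"          # the value asserted by the cache test added below
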
diff --git a/sample/confidential_client_secret_sample.py b/sample/confidential_client_secret_sample.py index c7bc7374..d4c06e20 100644 --- a/sample/confidential_client_secret_sample.py +++ b/sample/confidential_client_secret_sample.py @@ -47,7 +47,7 @@ client_credential=config["secret"], # token_cache=... # Default cache is in memory only. # You can learn how to use SerializableTokenCache from - # https://msal-python.rtfd.io/en/latest/#msal.SerializableTokenCache + # https://msal-python.readthedocs.io/en/latest/#msal.SerializableTokenCache ) # The pattern to acquire a token looks like this. diff --git a/sample/device_flow_sample.py b/sample/device_flow_sample.py index 51667ce7..48f8e7f4 100644 --- a/sample/device_flow_sample.py +++ b/sample/device_flow_sample.py @@ -36,7 +36,7 @@ config["client_id"], authority=config["authority"], # token_cache=... # Default cache is in memory only. # You can learn how to use SerializableTokenCache from - # https://msal-python.rtfd.io/en/latest/#msal.SerializableTokenCache + # https://msal-python.readthedocs.io/en/latest/#msal.SerializableTokenCache ) # The pattern to acquire a token looks like this. diff --git a/sample/interactive_sample.py b/sample/interactive_sample.py index 2e5b1cf6..6aafd160 100644 --- a/sample/interactive_sample.py +++ b/sample/interactive_sample.py @@ -32,7 +32,7 @@ config["client_id"], authority=config["authority"], # token_cache=... # Default cache is in memory only. # You can learn how to use SerializableTokenCache from - # https://msal-python.rtfd.io/en/latest/#msal.SerializableTokenCache + # https://msal-python.readthedocs.io/en/latest/#msal.SerializableTokenCache ) # The pattern to acquire a token looks like this. diff --git a/sample/migrate_rt.py b/sample/migrate_rt.py index eb623733..ed0011ed 100644 --- a/sample/migrate_rt.py +++ b/sample/migrate_rt.py @@ -50,7 +50,7 @@ def get_preexisting_rt_and_their_scopes_from_elsewhere(): config["client_id"], authority=config["authority"], # token_cache=... # Default cache is in memory only. # You can learn how to use SerializableTokenCache from - # https://msal-python.rtfd.io/en/latest/#msal.SerializableTokenCache + # https://msal-python.readthedocs.io/en/latest/#msal.SerializableTokenCache ) # We choose a migration strategy of migrating all RTs in one loop diff --git a/sample/username_password_sample.py b/sample/username_password_sample.py index 9c9b3c06..bcc8b7d5 100644 --- a/sample/username_password_sample.py +++ b/sample/username_password_sample.py @@ -38,7 +38,7 @@ config["client_id"], authority=config["authority"], # token_cache=... # Default cache is in memory only. # You can learn how to use SerializableTokenCache from - # https://msal-python.rtfd.io/en/latest/#msal.SerializableTokenCache + # https://msal-python.readthedocs.io/en/latest/#msal.SerializableTokenCache ) # The pattern to acquire a token looks like this. 
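
Each of the sample comments above points to SerializableTokenCache; for readers following that link, here is a minimal persistence sketch of the documented pattern (the cache file name and the atexit strategy are illustrative choices, not requirements):

    import atexit
    import os

    import msal

    cache = msal.SerializableTokenCache()
    if os.path.exists("my_cache.bin"):
        cache.deserialize(open("my_cache.bin").read())
    atexit.register(lambda: open("my_cache.bin", "w").write(cache.serialize())
                    if cache.has_state_changed else None)  # Persist only when dirty
    # Then pass it in, e.g. msal.PublicClientApplication(client_id, token_cache=cache)
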
diff --git a/tests/test_application.py b/tests/test_application.py index 8d48a0ac..93b3d002 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -73,8 +73,7 @@ def setUp(self): self.client_id, authority=self.authority_url, token_cache=self.cache) def test_cache_empty_will_be_returned_as_None(self): - self.assertEqual( - None, self.app.acquire_token_silent(['cache_miss'], self.account)) + self.app.token_cache = msal.SerializableTokenCache() # Reset it to empty self.assertEqual( None, self.app.acquire_token_silent_with_error(['cache_miss'], self.account)) @@ -319,3 +318,91 @@ def test_only_client_capabilities_no_claims_merge(self): def test_both_claims_and_capabilities_none(self): self.assertEqual(_merge_claims_challenge_and_capabilities(None, None), None) + + +class TestApplicationForRefreshInBehaviors(unittest.TestCase): + """The following test cases were based on design doc here + https://identitydivision.visualstudio.com/DevEx/_git/AuthLibrariesApiReview?path=%2FRefreshAtExpirationPercentage%2Foverview.md&version=GBdev&_a=preview&anchor=scenarios + """ + def setUp(self): + self.authority_url = "https://login.microsoftonline.com/common" + self.authority = msal.authority.Authority( + self.authority_url, MinimalHttpClient()) + self.scopes = ["s1", "s2"] + self.uid = "my_uid" + self.utid = "my_utid" + self.account = {"home_account_id": "{}.{}".format(self.uid, self.utid)} + self.rt = "this is a rt" + self.cache = msal.SerializableTokenCache() + self.client_id = "my_app" + self.app = ClientApplication( + self.client_id, authority=self.authority_url, token_cache=self.cache) + + def populate_cache(self, access_token="at", expires_in=86400, refresh_in=43200): + self.cache.add({ + "client_id": self.client_id, + "scope": self.scopes, + "token_endpoint": "{}/oauth2/v2.0/token".format(self.authority_url), + "response": TokenCacheTestCase.build_response( + access_token=access_token, + expires_in=expires_in, refresh_in=refresh_in, + uid=self.uid, utid=self.utid, refresh_token=self.rt), + }) + + def test_fresh_token_should_be_returned_from_cache(self): + # a.k.a. Return unexpired token that is not above token refresh expiration threshold + access_token = "An access token prepopulated into cache" + self.populate_cache(access_token=access_token, expires_in=900, refresh_in=450) + result = self.app.acquire_token_silent(['s1'], self.account) + self.assertEqual(access_token, result.get("access_token")) + self.assertNotIn("refresh_in", result, "Customers need not know refresh_in") + + def test_aging_token_and_available_aad_should_return_new_token(self): + # a.k.a. Attempt to refresh unexpired token when AAD available + self.populate_cache(access_token="old AT", expires_in=3599, refresh_in=-1) + new_access_token = "new AT" + def mock_post(*args, **kwargs): + return MinimalResponse(status_code=200, text=json.dumps({ + "access_token": new_access_token, + "refresh_in": 123, + })) + self.app.http_client.post = mock_post + result = self.app.acquire_token_silent(['s1'], self.account) + self.assertEqual(new_access_token, result.get("access_token")) + self.assertNotIn("refresh_in", result, "Customers need not know refresh_in") + + def test_aging_token_and_unavailable_aad_should_return_old_token(self): + # a.k.a. 
Attempt refresh unexpired token when AAD unavailable + old_at = "old AT" + self.populate_cache(access_token=old_at, expires_in=3599, refresh_in=-1) + self.app._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family = ( + lambda *args, **kwargs: {"error": "sth went wrong"}) + self.assertEqual( + old_at, + self.app.acquire_token_silent(['s1'], self.account).get("access_token")) + + def test_expired_token_and_unavailable_aad_should_return_error(self): + # a.k.a. Attempt refresh expired token when AAD unavailable + self.populate_cache(access_token="expired at", expires_in=-1, refresh_in=-900) + error = "something went wrong" + self.app._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family = ( + lambda *args, **kwargs: {"error": error}) + self.assertEqual( + error, + self.app.acquire_token_silent_with_error( # This variant preserves error + ['s1'], self.account).get("error")) + + def test_expired_token_and_available_aad_should_return_new_token(self): + # a.k.a. Attempt refresh expired token when AAD available + self.populate_cache(access_token="expired at", expires_in=-1, refresh_in=-900) + new_access_token = "new AT" + def mock_post(*args, **kwargs): + return MinimalResponse(status_code=200, text=json.dumps({ + "access_token": new_access_token, + "refresh_in": 123, + })) + self.app.http_client.post = mock_post + result = self.app.acquire_token_silent(['s1'], self.account) + self.assertEqual(new_access_token, result.get("access_token")) + self.assertNotIn("refresh_in", result, "Customers need not know refresh_in") + diff --git a/tests/test_e2e.py b/tests/test_e2e.py index 94e8e17b..f57a3a48 100644 --- a/tests/test_e2e.py +++ b/tests/test_e2e.py @@ -23,8 +23,14 @@ def _get_app_and_auth_code( scopes=["https://graph.microsoft.com/.default"], # Microsoft Graph **kwargs): from msal.oauth2cli.authcode import obtain_auth_code - app = msal.ClientApplication( - client_id, client_secret, authority=authority, http_client=MinimalHttpClient()) + if client_secret: + app = msal.ConfidentialClientApplication( + client_id, + client_credential=client_secret, + authority=authority, http_client=MinimalHttpClient()) + else: + app = msal.PublicClientApplication( + client_id, authority=authority, http_client=MinimalHttpClient()) redirect_uri = "http://localhost:%d" % port ac = obtain_auth_code(port, auth_uri=app.get_authorization_request_url( scopes, redirect_uri=redirect_uri, **kwargs)) diff --git a/tests/test_token_cache.py b/tests/test_token_cache.py index c846883d..3cce0c82 100644 --- a/tests/test_token_cache.py +++ b/tests/test_token_cache.py @@ -29,30 +29,20 @@ def build_id_token( def build_response( # simulate a response from AAD uid=None, utid=None, # If present, they will form client_info access_token=None, expires_in=3600, token_type="some type", - refresh_token=None, - foci=None, - id_token=None, # or something generated by build_id_token() - error=None, + **kwargs # Pass-through: refresh_token, foci, id_token, error, refresh_in, ... 
): response = {} if uid and utid: # Mimic the AAD behavior for "client_info=1" request response["client_info"] = base64.b64encode(json.dumps({ "uid": uid, "utid": utid, }).encode()).decode('utf-8') - if error: - response["error"] = error if access_token: response.update({ "access_token": access_token, "expires_in": expires_in, "token_type": token_type, }) - if refresh_token: - response["refresh_token"] = refresh_token - if id_token: - response["id_token"] = id_token - if foci: - response["foci"] = foci + response.update(kwargs) # Pass-through key-value pairs as top-level fields return response def setUp(self): @@ -94,6 +84,7 @@ def testAddByAad(self): 'credential_type': 'RefreshToken', 'environment': 'login.example.com', 'home_account_id': "uid.utid", + 'last_modification_time': '1000', 'secret': 'a refresh token', 'target': 's2 s1 s3', }, @@ -167,6 +158,7 @@ def testAddByAdfs(self): 'credential_type': 'RefreshToken', 'environment': 'fs.msidlab8.com', 'home_account_id': "subject", + 'last_modification_time': "1000", 'secret': 'a refresh token', 'target': 's2 s1 s3', }, @@ -222,6 +214,21 @@ def test_key_id_is_also_recorded(self): {}).get("key_id") self.assertEqual(my_key_id, cached_key_id, "AT should be bound to the key") + def test_refresh_in_should_be_recorded_as_refresh_on(self): # Sounds weird. Yep. + self.cache.add({ + "client_id": "my_client_id", + "scope": ["s2", "s1", "s3"], # Not in particular order + "token_endpoint": "https://login.example.com/contoso/v2/token", + "response": self.build_response( + uid="uid", utid="utid", # client_info + expires_in=3600, refresh_in=1800, access_token="an access token", + ), #refresh_token="a refresh token"), + }, now=1000) + refresh_on = self.cache._cache["AccessToken"].get( + 'uid.utid-login.example.com-accesstoken-my_client_id-contoso-s2 s1 s3', + {}).get("refresh_on") + self.assertEqual("2800", refresh_on, "Should save refresh_on") + def test_old_rt_data_with_wrong_key_should_still_be_salvaged_into_new_rt(self): sample = { 'client_id': 'my_client_id', @@ -241,6 +248,7 @@ def test_old_rt_data_with_wrong_key_should_still_be_salvaged_into_new_rt(self): 'uid.utid-login.example.com-refreshtoken-my_client_id--s2 s1 s3') ) + class SerializableTokenCacheTestCase(TokenCacheTestCase): # Run all inherited test methods, and have extra check in tearDown()
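
A brief usage note on the streamlined build_response() above: unrecognized keyword arguments now pass straight through into the simulated AAD response, so a test can fabricate fields such as refresh_in or foci without touching the helper again. The values below are made up, and TokenCacheTestCase is assumed to be importable from tests/test_token_cache.py:

    fake = TokenCacheTestCase.build_response(
        uid="uid", utid="utid",                  # becomes client_info
        access_token="an access token", expires_in=3600,
        refresh_in=1800, refresh_token="a refresh token", foci="1")
    assert fake["refresh_in"] == 1800 and fake["foci"] == "1"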