diff --git a/eng/tox/install_depend_packages.py b/eng/tox/install_depend_packages.py index c460c13307e2..644e240cec85 100644 --- a/eng/tox/install_depend_packages.py +++ b/eng/tox/install_depend_packages.py @@ -41,6 +41,7 @@ "requests": "2.19.0", "six": "1.12.0", "cryptography": "3.3.2", + "msal": "1.23.0", } # this array contains overrides ONLY IF the package being processed the key of each item diff --git a/sdk/identity/azure-identity/tests/helpers.py b/sdk/identity/azure-identity/tests/helpers.py index 77ba6110f476..ab8e0e4a451b 100644 --- a/sdk/identity/azure-identity/tests/helpers.py +++ b/sdk/identity/azure-identity/tests/helpers.py @@ -192,6 +192,12 @@ def msal_validating_transport(requests, responses, **kwargs): return validating_transport([Request()] * 2 + requests, [get_discovery_response(**kwargs)] * 2 + responses) +def new_msal_validating_transport(requests, responses, **kwargs): + """a transport with default responses to MSAL's discovery requests without validation""" + """msal made some optimizations to make less calls to discovery endpoint""" + return validating_transport([Request()] + requests, [get_discovery_response(**kwargs)] + responses) + + def urlsafeb64_decode(s): if isinstance(s, str): s = s.encode("ascii") diff --git a/sdk/identity/azure-identity/tests/test_certificate_credential.py b/sdk/identity/azure-identity/tests/test_certificate_credential.py index 44ce40c4aa80..b02b0a7653d8 100644 --- a/sdk/identity/azure-identity/tests/test_certificate_credential.py +++ b/sdk/identity/azure-identity/tests/test_certificate_credential.py @@ -29,6 +29,7 @@ urlsafeb64_decode, mock_response, msal_validating_transport, + new_msal_validating_transport, Request, ) @@ -86,7 +87,7 @@ def test_no_scopes(): def test_policies_configurable(): policy = Mock(spec_set=SansIOHTTPPolicy, on_request=Mock()) - transport = msal_validating_transport( + transport = new_msal_validating_transport( requests=[Request()], responses=[mock_response(json_payload=build_aad_response(access_token="**"))] ) @@ -100,7 +101,7 @@ def test_policies_configurable(): def test_user_agent(): - transport = msal_validating_transport( + transport = new_msal_validating_transport( requests=[Request(required_headers={"User-Agent": USER_AGENT})], responses=[mock_response(json_payload=build_aad_response(access_token="**"))], ) @@ -111,7 +112,7 @@ def test_user_agent(): def test_tenant_id(): - transport = msal_validating_transport( + transport = new_msal_validating_transport( requests=[Request(required_headers={"User-Agent": USER_AGENT})], responses=[mock_response(json_payload=build_aad_response(access_token="**"))], ) @@ -316,10 +317,10 @@ def test_persistent_cache_multiple_clients(cert_path, cert_password): access_token_a = "token a" access_token_b = "not " + access_token_a - transport_a = msal_validating_transport( + transport_a = new_msal_validating_transport( requests=[Request()], responses=[mock_response(json_payload=build_aad_response(access_token=access_token_a))] ) - transport_b = msal_validating_transport( + transport_b = new_msal_validating_transport( requests=[Request()], responses=[mock_response(json_payload=build_aad_response(access_token=access_token_b))] ) @@ -350,12 +351,12 @@ def test_persistent_cache_multiple_clients(cert_path, cert_password): scope = "scope" token_a = credential_a.get_token(scope) assert token_a.token == access_token_a - assert transport_a.send.call_count == 3 # two MSAL discovery requests, one token request + assert transport_a.send.call_count == 2 # two MSAL discovery requests, one token request # B should get a different token for the same scope token_b = credential_b.get_token(scope) assert token_b.token == access_token_b - assert transport_b.send.call_count == 3 + assert transport_b.send.call_count == 2 assert len(cache.find(TokenCache.CredentialType.ACCESS_TOKEN)) == 2 diff --git a/sdk/identity/azure-identity/tests/test_client_secret_credential.py b/sdk/identity/azure-identity/tests/test_client_secret_credential.py index efaf7fc5759a..c91d291790b7 100644 --- a/sdk/identity/azure-identity/tests/test_client_secret_credential.py +++ b/sdk/identity/azure-identity/tests/test_client_secret_credential.py @@ -19,13 +19,10 @@ id_token_claims, mock_response, msal_validating_transport, + new_msal_validating_transport, Request, ) - -try: - from unittest.mock import Mock, patch -except ImportError: # python < 3.3 - from mock import Mock, patch # type: ignore +from unittest.mock import Mock, patch def test_tenant_id_validation(): @@ -52,7 +49,7 @@ def test_no_scopes(): def test_policies_configurable(): policy = Mock(spec_set=SansIOHTTPPolicy, on_request=Mock()) - transport = msal_validating_transport( + transport = new_msal_validating_transport( requests=[Request()], responses=[mock_response(json_payload=build_aad_response(access_token="**"))] ) @@ -66,7 +63,7 @@ def test_policies_configurable(): def test_user_agent(): - transport = msal_validating_transport( + transport = new_msal_validating_transport( requests=[Request(required_headers={"User-Agent": USER_AGENT})], responses=[mock_response(json_payload=build_aad_response(access_token="**"))], ) @@ -82,7 +79,7 @@ def test_client_secret_credential(): tenant_id = "fake-tenant-id" access_token = "***" - transport = msal_validating_transport( + transport = new_msal_validating_transport( endpoint="https://localhost/" + tenant_id, requests=[Request(url_substring=tenant_id, required_data={"client_id": client_id, "client_secret": secret})], responses=[mock_response(json_payload=build_aad_response(access_token=access_token))], @@ -170,10 +167,10 @@ def test_cache_multiple_clients(): access_token_a = "token a" access_token_b = "not " + access_token_a - transport_a = msal_validating_transport( + transport_a = new_msal_validating_transport( requests=[Request()], responses=[mock_response(json_payload=build_aad_response(access_token=access_token_a))] ) - transport_b = msal_validating_transport( + transport_b = new_msal_validating_transport( requests=[Request()], responses=[mock_response(json_payload=build_aad_response(access_token=access_token_b))] ) @@ -202,12 +199,12 @@ def test_cache_multiple_clients(): scope = "scope" token_a = credential_a.get_token(scope) assert token_a.token == access_token_a - assert transport_a.send.call_count == 3 # two MSAL discovery requests, one token request + assert transport_a.send.call_count == 2 # two MSAL discovery requests, one token request # B should get a different token for the same scope token_b = credential_b.get_token(scope) assert token_b.token == access_token_b - assert transport_b.send.call_count == 3 + assert transport_b.send.call_count == 2 assert len(cache.find(TokenCache.CredentialType.ACCESS_TOKEN)) == 2