Skip to content

Commit

Permalink
Identity new msal (#31267)
Browse files Browse the repository at this point in the history
* update tests to accommodate msal change

* update msal mindep

* update
  • Loading branch information
xiangyan99 authored Jul 24, 2023
1 parent 17c3695 commit 5714623
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 19 deletions.
1 change: 1 addition & 0 deletions eng/tox/install_depend_packages.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
"requests": "2.19.0",
"six": "1.12.0",
"cryptography": "3.3.2",
"msal": "1.23.0",
}

# this array contains overrides ONLY IF the package being processed the key of each item
Expand Down
6 changes: 6 additions & 0 deletions sdk/identity/azure-identity/tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,12 @@ def msal_validating_transport(requests, responses, **kwargs):
return validating_transport([Request()] * 2 + requests, [get_discovery_response(**kwargs)] * 2 + responses)


def new_msal_validating_transport(requests, responses, **kwargs):
"""a transport with default responses to MSAL's discovery requests without validation"""
"""msal made some optimizations to make less calls to discovery endpoint"""
return validating_transport([Request()] + requests, [get_discovery_response(**kwargs)] + responses)


def urlsafeb64_decode(s):
if isinstance(s, str):
s = s.encode("ascii")
Expand Down
15 changes: 8 additions & 7 deletions sdk/identity/azure-identity/tests/test_certificate_credential.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
urlsafeb64_decode,
mock_response,
msal_validating_transport,
new_msal_validating_transport,
Request,
)

Expand Down Expand Up @@ -86,7 +87,7 @@ def test_no_scopes():
def test_policies_configurable():
policy = Mock(spec_set=SansIOHTTPPolicy, on_request=Mock())

transport = msal_validating_transport(
transport = new_msal_validating_transport(
requests=[Request()], responses=[mock_response(json_payload=build_aad_response(access_token="**"))]
)

Expand All @@ -100,7 +101,7 @@ def test_policies_configurable():


def test_user_agent():
transport = msal_validating_transport(
transport = new_msal_validating_transport(
requests=[Request(required_headers={"User-Agent": USER_AGENT})],
responses=[mock_response(json_payload=build_aad_response(access_token="**"))],
)
Expand All @@ -111,7 +112,7 @@ def test_user_agent():


def test_tenant_id():
transport = msal_validating_transport(
transport = new_msal_validating_transport(
requests=[Request(required_headers={"User-Agent": USER_AGENT})],
responses=[mock_response(json_payload=build_aad_response(access_token="**"))],
)
Expand Down Expand Up @@ -316,10 +317,10 @@ def test_persistent_cache_multiple_clients(cert_path, cert_password):

access_token_a = "token a"
access_token_b = "not " + access_token_a
transport_a = msal_validating_transport(
transport_a = new_msal_validating_transport(
requests=[Request()], responses=[mock_response(json_payload=build_aad_response(access_token=access_token_a))]
)
transport_b = msal_validating_transport(
transport_b = new_msal_validating_transport(
requests=[Request()], responses=[mock_response(json_payload=build_aad_response(access_token=access_token_b))]
)

Expand Down Expand Up @@ -350,12 +351,12 @@ def test_persistent_cache_multiple_clients(cert_path, cert_password):
scope = "scope"
token_a = credential_a.get_token(scope)
assert token_a.token == access_token_a
assert transport_a.send.call_count == 3 # two MSAL discovery requests, one token request
assert transport_a.send.call_count == 2 # two MSAL discovery requests, one token request

# B should get a different token for the same scope
token_b = credential_b.get_token(scope)
assert token_b.token == access_token_b
assert transport_b.send.call_count == 3
assert transport_b.send.call_count == 2

assert len(cache.find(TokenCache.CredentialType.ACCESS_TOKEN)) == 2

Expand Down
21 changes: 9 additions & 12 deletions sdk/identity/azure-identity/tests/test_client_secret_credential.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,10 @@
id_token_claims,
mock_response,
msal_validating_transport,
new_msal_validating_transport,
Request,
)

try:
from unittest.mock import Mock, patch
except ImportError: # python < 3.3
from mock import Mock, patch # type: ignore
from unittest.mock import Mock, patch


def test_tenant_id_validation():
Expand All @@ -52,7 +49,7 @@ def test_no_scopes():
def test_policies_configurable():
policy = Mock(spec_set=SansIOHTTPPolicy, on_request=Mock())

transport = msal_validating_transport(
transport = new_msal_validating_transport(
requests=[Request()], responses=[mock_response(json_payload=build_aad_response(access_token="**"))]
)

Expand All @@ -66,7 +63,7 @@ def test_policies_configurable():


def test_user_agent():
transport = msal_validating_transport(
transport = new_msal_validating_transport(
requests=[Request(required_headers={"User-Agent": USER_AGENT})],
responses=[mock_response(json_payload=build_aad_response(access_token="**"))],
)
Expand All @@ -82,7 +79,7 @@ def test_client_secret_credential():
tenant_id = "fake-tenant-id"
access_token = "***"

transport = msal_validating_transport(
transport = new_msal_validating_transport(
endpoint="https://localhost/" + tenant_id,
requests=[Request(url_substring=tenant_id, required_data={"client_id": client_id, "client_secret": secret})],
responses=[mock_response(json_payload=build_aad_response(access_token=access_token))],
Expand Down Expand Up @@ -170,10 +167,10 @@ def test_cache_multiple_clients():

access_token_a = "token a"
access_token_b = "not " + access_token_a
transport_a = msal_validating_transport(
transport_a = new_msal_validating_transport(
requests=[Request()], responses=[mock_response(json_payload=build_aad_response(access_token=access_token_a))]
)
transport_b = msal_validating_transport(
transport_b = new_msal_validating_transport(
requests=[Request()], responses=[mock_response(json_payload=build_aad_response(access_token=access_token_b))]
)

Expand Down Expand Up @@ -202,12 +199,12 @@ def test_cache_multiple_clients():
scope = "scope"
token_a = credential_a.get_token(scope)
assert token_a.token == access_token_a
assert transport_a.send.call_count == 3 # two MSAL discovery requests, one token request
assert transport_a.send.call_count == 2 # two MSAL discovery requests, one token request

# B should get a different token for the same scope
token_b = credential_b.get_token(scope)
assert token_b.token == access_token_b
assert transport_b.send.call_count == 3
assert transport_b.send.call_count == 2

assert len(cache.find(TokenCache.CredentialType.ACCESS_TOKEN)) == 2

Expand Down

0 comments on commit 5714623

Please sign in to comment.