diff --git a/sdk/identity/azure-identity/tests/test_context_manager.py b/sdk/identity/azure-identity/tests/test_context_manager.py index eab3ade98bcf..e5f6e59a8013 100644 --- a/sdk/identity/azure-identity/tests/test_context_manager.py +++ b/sdk/identity/azure-identity/tests/test_context_manager.py @@ -30,13 +30,14 @@ class CredentialFixture: - def __init__(self, cls, default_kwargs=None, ctor_patch=None): + def __init__(self, cls, default_kwargs=None, ctor_patch_factory=None): self.cls = cls self._default_kwargs = default_kwargs or {} - self._ctor_patch = ctor_patch or MagicMock() + self._ctor_patch_factory = ctor_patch_factory or MagicMock def get_credential(self, **kwargs): - with self._ctor_patch: + patch = self._ctor_patch_factory() + with patch: return self.cls(**dict(self._default_kwargs, **kwargs)) @@ -50,18 +51,20 @@ def get_credential(self, **kwargs): CredentialFixture(DeviceCodeCredential), CredentialFixture( EnvironmentCredential, - ctor_patch=patch.dict( + ctor_patch_factory=lambda: patch.dict( EnvironmentCredential.__module__ + ".os.environ", {var: "..." for var in EnvironmentVariables.CLIENT_SECRET_VARS}, ), ), CredentialFixture(InteractiveBrowserCredential), CredentialFixture(UsernamePasswordCredential, {"client_id": "...", "username": "...", "password": "..."}), - CredentialFixture(VisualStudioCodeCredential, ctor_patch=patch(GET_USER_SETTINGS, lambda: {})), + CredentialFixture(VisualStudioCodeCredential, ctor_patch_factory=lambda: patch(GET_USER_SETTINGS, lambda: {})), ) +all_fixtures = pytest.mark.parametrize("fixture", FIXTURES, ids=lambda fixture: fixture.cls.__name__) -@pytest.mark.parametrize("fixture", FIXTURES, ids=lambda fixture: fixture.cls.__name__) + +@all_fixtures def test_close(fixture): transport = MagicMock() credential = fixture.get_credential(transport=transport) @@ -73,7 +76,7 @@ def test_close(fixture): assert transport.__exit__.call_count == 1 -@pytest.mark.parametrize("fixture", FIXTURES, ids=lambda fixture: fixture.cls.__name__) +@all_fixtures def test_context_manager(fixture): transport = MagicMock() credential = fixture.get_credential(transport=transport) @@ -86,7 +89,7 @@ def test_context_manager(fixture): assert transport.__exit__.call_count == 1 -@pytest.mark.parametrize("fixture", FIXTURES, ids=lambda fixture: fixture.cls.__name__) +@all_fixtures def test_exit_args(fixture): transport = MagicMock() credential = fixture.get_credential(transport=transport)