Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add context manager API to azure.identity credentials #19746

Merged
merged 15 commits into from
Aug 16, 2021
5 changes: 5 additions & 0 deletions sdk/identity/azure-identity/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@
### Bugs Fixed

### Other Changes
- Added context manager methods and `close()` to credentials in the
`azure.identity` namespace. At the end of a `with` block, or when `close()`
is called, these credentials close their underlying transport sessions.
([#18798](https://github.com/Azure/azure-sdk-for-python/issues/18798))


## 1.7.0b3 (2021-08-10)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,45 +6,27 @@
import os
from typing import TYPE_CHECKING

from azure.core.credentials import AccessToken
from azure.core.pipeline.transport import HttpRequest

from .. import CredentialUnavailableError
from .._constants import EnvironmentVariables
from .._internal.managed_identity_base import ManagedIdentityBase
from .._internal.managed_identity_client import ManagedIdentityClient
from .._internal.get_token_mixin import GetTokenMixin

if TYPE_CHECKING:
from typing import Any, Optional


class AppServiceCredential(GetTokenMixin):
def __init__(self, **kwargs):
# type: (**Any) -> None
super(AppServiceCredential, self).__init__()

class AppServiceCredential(ManagedIdentityBase):
def get_client(self, **kwargs):
# type: (**Any) -> Optional[ManagedIdentityClient]
client_args = _get_client_args(**kwargs)
if client_args:
self._available = True
self._client = ManagedIdentityClient(**client_args)
else:
self._available = False

def get_token(self, *scopes, **kwargs):
# type: (*str, **Any) -> AccessToken
if not self._available:
raise CredentialUnavailableError(
message="App Service managed identity configuration not found in environment"
)
return super(AppServiceCredential, self).get_token(*scopes, **kwargs)

def _acquire_token_silently(self, *scopes, **kwargs):
# type: (*str, **Any) -> Optional[AccessToken]
return self._client.get_cached_token(*scopes)

def _request_token(self, *scopes, **kwargs):
# type: (*str, **Any) -> AccessToken
return self._client.request_token(*scopes, **kwargs)
return ManagedIdentityClient(**client_args)
return None

def get_unavailable_message(self):
# type: () -> str
return "App Service managed identity configuration not found in environment"


def _get_client_args(**kwargs):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,18 @@ def __init__(self, tenant_id, client_id, authorization_code, redirect_uri, **kwa
self._redirect_uri = redirect_uri
super(AuthorizationCodeCredential, self).__init__()

def __enter__(self):
self._client.__enter__()
return self

def __exit__(self, *args):
self._client.__exit__(*args)

def close(self):
# type: () -> None
"""Close the credential's transport session."""
self.__exit__()

def get_token(self, *scopes, **kwargs):
# type: (*str, **Any) -> AccessToken
"""Request an access token for `scopes`.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,51 +10,42 @@
from azure.core.pipeline.transport import HttpRequest
from azure.core.pipeline.policies import HTTPPolicy

from .. import CredentialUnavailableError
from .._constants import EnvironmentVariables
from .._internal.managed_identity_base import ManagedIdentityBase
from .._internal.managed_identity_client import ManagedIdentityClient
from .._internal.get_token_mixin import GetTokenMixin

if TYPE_CHECKING:
# pylint:disable=unused-import,ungrouped-imports
from typing import Any, Optional, Union
from azure.core.credentials import AccessToken
from typing import Any, Optional
from azure.core.pipeline import PipelineRequest, PipelineResponse
from azure.core.pipeline.policies import SansIOHTTPPolicy

PolicyType = Union[HTTPPolicy, SansIOHTTPPolicy]


class AzureArcCredential(GetTokenMixin):
def __init__(self, **kwargs):
# type: (**Any) -> None
super(AzureArcCredential, self).__init__()

class AzureArcCredential(ManagedIdentityBase):
def get_client(self, **kwargs):
# type: (**Any) -> Optional[ManagedIdentityClient]
url = os.environ.get(EnvironmentVariables.IDENTITY_ENDPOINT)
imds = os.environ.get(EnvironmentVariables.IMDS_ENDPOINT)
self._available = url and imds
if self._available:
self._client = ManagedIdentityClient(
if url and imds:
return ManagedIdentityClient(
_per_retry_policies=[ArcChallengeAuthPolicy()],
request_factory=functools.partial(_get_request, url),
**kwargs
)
return None

def get_token(self, *scopes, **kwargs):
# type: (*str, **Any) -> AccessToken
if not self._available:
raise CredentialUnavailableError(
message="Azure Arc managed identity configuration not found in environment"
)
return super(AzureArcCredential, self).get_token(*scopes, **kwargs)
def __enter__(self):
self._client.__enter__()
return self

def __exit__(self, *args):
self._client.__exit__(*args)

def _acquire_token_silently(self, *scopes, **kwargs):
# type: (*str, **Any) -> Optional[AccessToken]
return self._client.get_cached_token(*scopes)
def close(self):
self.__exit__()

def _request_token(self, *scopes, **kwargs):
# type: (*str, **Any) -> AccessToken
return self._client.request_token(*scopes, **kwargs)
def get_unavailable_message(self):
# type: () -> str
return "Azure Arc managed identity configuration not found in environment"


def _get_request(url, scope, identity_config):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,16 @@ class AzureCliCredential(object):
def __init__(self, **kwargs):
self._allow_multitenant = kwargs.get("allow_multitenant_authentication", False)

def __enter__(self):
return self

def __exit__(self, *args):
pass

def close(self):
# type: () -> None
"""Calling this method is unnecessary."""

@log_get_token("AzureCliCredential")
def get_token(self, *scopes, **kwargs):
# type: (*str, **Any) -> AccessToken
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,16 @@ def __init__(self, **kwargs):
# type: (**Any) -> None
self._allow_multitenant = kwargs.get("allow_multitenant_authentication", False)

def __enter__(self):
return self

def __exit__(self, *args):
pass

def close(self):
# type: () -> None
"""Calling this method is unnecessary."""

@log_get_token("AzurePowerShellCredential")
def get_token(self, *scopes, **kwargs):
# type: (*str, **Any) -> AccessToken
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,20 @@ def __init__(self, *credentials):
self._successful_credential = None # type: Optional[TokenCredential]
self.credentials = credentials

def __enter__(self):
for credential in self.credentials:
credential.__enter__()
return self

def __exit__(self, *args):
for credential in self.credentials:
credential.__exit__(*args)

def close(self):
# type: () -> None
"""Close the transport session of each credential in the chain."""
self.__exit__()

def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
# type: (*str, **Any) -> AccessToken
"""Request a token from each chained credential, in order, returning the first token received.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,17 @@ def __init__(self, tenant_id, client_id, get_assertion, **kwargs):
self._client = AadClient(tenant_id, client_id, **kwargs)
super(ClientAssertionCredential, self).__init__(**kwargs)

def __enter__(self):
self._client.__enter__()
return self

def __exit__(self, *args):
self._client.__exit__(*args)

def close(self):
# type: () -> None
self.__exit__()

def _acquire_token_silently(self, *scopes, **kwargs):
# type: (*str, **Any) -> Optional[AccessToken]
return self._client.get_cached_access_token(scopes, **kwargs)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,46 +8,27 @@

from azure.core.pipeline.transport import HttpRequest

from .. import CredentialUnavailableError
from .._constants import EnvironmentVariables
from .._internal.managed_identity_client import ManagedIdentityClient
from .._internal.get_token_mixin import GetTokenMixin
from .._internal.managed_identity_base import ManagedIdentityBase

if TYPE_CHECKING:
from typing import Any, Optional
from azure.core.credentials import AccessToken


class CloudShellCredential(GetTokenMixin):
def __init__(self, **kwargs):
# type: (**Any) -> None
super(CloudShellCredential, self).__init__()
class CloudShellCredential(ManagedIdentityBase):
def get_client(self, **kwargs):
# type: (**Any) -> Optional[ManagedIdentityClient]
url = os.environ.get(EnvironmentVariables.MSI_ENDPOINT)
if url:
self._available = True
self._client = ManagedIdentityClient(
request_factory=functools.partial(_get_request, url),
base_headers={"Metadata": "true"},
**kwargs
return ManagedIdentityClient(
request_factory=functools.partial(_get_request, url), base_headers={"Metadata": "true"}, **kwargs
)
else:
self._available = False

def get_token(self, *scopes, **kwargs):
# type: (*str, **Any) -> AccessToken
if not self._available:
raise CredentialUnavailableError(
message="Cloud Shell managed identity configuration not found in environment"
)
return super(CloudShellCredential, self).get_token(*scopes, **kwargs)

def _acquire_token_silently(self, *scopes, **kwargs):
# type: (*str, **Any) -> Optional[AccessToken]
return self._client.get_cached_token(*scopes)
return None

def _request_token(self, *scopes, **kwargs):
# type: (*str, **Any) -> AccessToken
return self._client.request_token(*scopes, **kwargs)
def get_unavailable_message(self):
# type: () -> str
return "Cloud Shell managed identity configuration not found in environment"


def _get_request(url, scope, identity_config):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,20 @@ def __init__(self, **kwargs):
else:
_LOGGER.info("No environment configuration found.")

def __enter__(self):
if self._credential:
self._credential.__enter__()
return self

def __exit__(self, *args):
if self._credential:
self._credential.__exit__(*args)

def close(self):
# type: () -> None
"""Close the credential's transport session."""
self.__exit__()

@log_get_token("EnvironmentCredential")
def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
# type: (*str, **Any) -> AccessToken
Expand Down
10 changes: 10 additions & 0 deletions sdk/identity/azure-identity/azure/identity/_credentials/imds.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,16 @@ def __init__(self, **kwargs):
self._error_message = None # type: Optional[str]
self._user_assigned_identity = "client_id" in kwargs or "identity_config" in kwargs

def __enter__(self):
self._client.__enter__()
return self

def __exit__(self, *args):
self._client.__exit__(*args)

def close(self):
self.__exit__()

def _acquire_token_silently(self, *scopes, **kwargs):
# type: (*str, **Any) -> Optional[AccessToken]
return self._client.get_cached_token(*scopes)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,18 @@ def __init__(self, **kwargs):
_LOGGER.info("%s will use IMDS", self.__class__.__name__)
self._credential = ImdsCredential(**kwargs)

def __enter__(self):
self._credential.__enter__()
return self

def __exit__(self, *args):
self._credential.__exit__(*args)

def close(self):
# type: () -> None
"""Close the credential's transport session."""
self.__exit__()

@log_get_token("ManagedIdentityCredential")
def get_token(self, *scopes, **kwargs):
# type: (*str, **Any) -> AccessToken
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,43 +8,25 @@

from azure.core.pipeline.transport import HttpRequest

from .. import CredentialUnavailableError
from .._constants import EnvironmentVariables
from .._internal.managed_identity_base import ManagedIdentityBase
from .._internal.managed_identity_client import ManagedIdentityClient
from .._internal.get_token_mixin import GetTokenMixin

if TYPE_CHECKING:
from typing import Any, Optional
from azure.core.credentials import AccessToken


class ServiceFabricCredential(GetTokenMixin):
def __init__(self, **kwargs):
# type: (**Any) -> None
super(ServiceFabricCredential, self).__init__()

class ServiceFabricCredential(ManagedIdentityBase):
def get_client(self, **kwargs):
# type: (**Any) -> Optional[ManagedIdentityClient]
client_args = _get_client_args(**kwargs)
if client_args:
self._available = True
self._client = ManagedIdentityClient(**client_args)
else:
self._available = False

def get_token(self, *scopes, **kwargs):
# type: (*str, **Any) -> AccessToken
if not self._available:
raise CredentialUnavailableError(
message="Service Fabric managed identity configuration not found in environment"
)
return super(ServiceFabricCredential, self).get_token(*scopes, **kwargs)

def _acquire_token_silently(self, *scopes, **kwargs):
# type: (*str, **Any) -> Optional[AccessToken]
return self._client.get_cached_token(*scopes)
return ManagedIdentityClient(**client_args)
return None

def _request_token(self, *scopes, **kwargs):
# type: (*str, **Any) -> AccessToken
return self._client.request_token(*scopes, **kwargs)
def get_unavailable_message(self):
# type: () -> str
return "Service Fabric managed identity configuration not found in environment"


def _get_client_args(**kwargs):
Expand Down
Loading