Skip to content

Commit

Permalink
Add context manager API to azure.identity credentials (#19746)
Browse files Browse the repository at this point in the history
  • Loading branch information
chlowell authored and iscai-msft committed Sep 29, 2021
1 parent 872f513 commit 94ad583
Show file tree
Hide file tree
Showing 36 changed files with 718 additions and 260 deletions.
5 changes: 5 additions & 0 deletions sdk/identity/azure-identity/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@
### Bugs Fixed

### Other Changes
- Added context manager methods and `close()` to credentials in the
`azure.identity` namespace. At the end of a `with` block, or when `close()`
is called, these credentials close their underlying transport sessions.
([#18798](https://github.com/Azure/azure-sdk-for-python/issues/18798))


## 1.7.0b3 (2021-08-10)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,45 +6,27 @@
import os
from typing import TYPE_CHECKING

from azure.core.credentials import AccessToken
from azure.core.pipeline.transport import HttpRequest

from .. import CredentialUnavailableError
from .._constants import EnvironmentVariables
from .._internal.managed_identity_base import ManagedIdentityBase
from .._internal.managed_identity_client import ManagedIdentityClient
from .._internal.get_token_mixin import GetTokenMixin

if TYPE_CHECKING:
from typing import Any, Optional


class AppServiceCredential(GetTokenMixin):
def __init__(self, **kwargs):
# type: (**Any) -> None
super(AppServiceCredential, self).__init__()

class AppServiceCredential(ManagedIdentityBase):
def get_client(self, **kwargs):
# type: (**Any) -> Optional[ManagedIdentityClient]
client_args = _get_client_args(**kwargs)
if client_args:
self._available = True
self._client = ManagedIdentityClient(**client_args)
else:
self._available = False

def get_token(self, *scopes, **kwargs):
# type: (*str, **Any) -> AccessToken
if not self._available:
raise CredentialUnavailableError(
message="App Service managed identity configuration not found in environment"
)
return super(AppServiceCredential, self).get_token(*scopes, **kwargs)

def _acquire_token_silently(self, *scopes, **kwargs):
# type: (*str, **Any) -> Optional[AccessToken]
return self._client.get_cached_token(*scopes)

def _request_token(self, *scopes, **kwargs):
# type: (*str, **Any) -> AccessToken
return self._client.request_token(*scopes, **kwargs)
return ManagedIdentityClient(**client_args)
return None

def get_unavailable_message(self):
# type: () -> str
return "App Service managed identity configuration not found in environment"


def _get_client_args(**kwargs):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,18 @@ def __init__(self, tenant_id, client_id, authorization_code, redirect_uri, **kwa
self._redirect_uri = redirect_uri
super(AuthorizationCodeCredential, self).__init__()

def __enter__(self):
self._client.__enter__()
return self

def __exit__(self, *args):
self._client.__exit__(*args)

def close(self):
# type: () -> None
"""Close the credential's transport session."""
self.__exit__()

def get_token(self, *scopes, **kwargs):
# type: (*str, **Any) -> AccessToken
"""Request an access token for `scopes`.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,51 +10,42 @@
from azure.core.pipeline.transport import HttpRequest
from azure.core.pipeline.policies import HTTPPolicy

from .. import CredentialUnavailableError
from .._constants import EnvironmentVariables
from .._internal.managed_identity_base import ManagedIdentityBase
from .._internal.managed_identity_client import ManagedIdentityClient
from .._internal.get_token_mixin import GetTokenMixin

if TYPE_CHECKING:
# pylint:disable=unused-import,ungrouped-imports
from typing import Any, Optional, Union
from azure.core.credentials import AccessToken
from typing import Any, Optional
from azure.core.pipeline import PipelineRequest, PipelineResponse
from azure.core.pipeline.policies import SansIOHTTPPolicy

PolicyType = Union[HTTPPolicy, SansIOHTTPPolicy]


class AzureArcCredential(GetTokenMixin):
def __init__(self, **kwargs):
# type: (**Any) -> None
super(AzureArcCredential, self).__init__()

class AzureArcCredential(ManagedIdentityBase):
def get_client(self, **kwargs):
# type: (**Any) -> Optional[ManagedIdentityClient]
url = os.environ.get(EnvironmentVariables.IDENTITY_ENDPOINT)
imds = os.environ.get(EnvironmentVariables.IMDS_ENDPOINT)
self._available = url and imds
if self._available:
self._client = ManagedIdentityClient(
if url and imds:
return ManagedIdentityClient(
_per_retry_policies=[ArcChallengeAuthPolicy()],
request_factory=functools.partial(_get_request, url),
**kwargs
)
return None

def get_token(self, *scopes, **kwargs):
# type: (*str, **Any) -> AccessToken
if not self._available:
raise CredentialUnavailableError(
message="Azure Arc managed identity configuration not found in environment"
)
return super(AzureArcCredential, self).get_token(*scopes, **kwargs)
def __enter__(self):
self._client.__enter__()
return self

def __exit__(self, *args):
self._client.__exit__(*args)

def _acquire_token_silently(self, *scopes, **kwargs):
# type: (*str, **Any) -> Optional[AccessToken]
return self._client.get_cached_token(*scopes)
def close(self):
self.__exit__()

def _request_token(self, *scopes, **kwargs):
# type: (*str, **Any) -> AccessToken
return self._client.request_token(*scopes, **kwargs)
def get_unavailable_message(self):
# type: () -> str
return "Azure Arc managed identity configuration not found in environment"


def _get_request(url, scope, identity_config):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,16 @@ class AzureCliCredential(object):
def __init__(self, **kwargs):
self._allow_multitenant = kwargs.get("allow_multitenant_authentication", False)

def __enter__(self):
return self

def __exit__(self, *args):
pass

def close(self):
# type: () -> None
"""Calling this method is unnecessary."""

@log_get_token("AzureCliCredential")
def get_token(self, *scopes, **kwargs):
# type: (*str, **Any) -> AccessToken
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,16 @@ def __init__(self, **kwargs):
# type: (**Any) -> None
self._allow_multitenant = kwargs.get("allow_multitenant_authentication", False)

def __enter__(self):
return self

def __exit__(self, *args):
pass

def close(self):
# type: () -> None
"""Calling this method is unnecessary."""

@log_get_token("AzurePowerShellCredential")
def get_token(self, *scopes, **kwargs):
# type: (*str, **Any) -> AccessToken
Expand Down
14 changes: 14 additions & 0 deletions sdk/identity/azure-identity/azure/identity/_credentials/chained.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,20 @@ def __init__(self, *credentials):
self._successful_credential = None # type: Optional[TokenCredential]
self.credentials = credentials

def __enter__(self):
for credential in self.credentials:
credential.__enter__()
return self

def __exit__(self, *args):
for credential in self.credentials:
credential.__exit__(*args)

def close(self):
# type: () -> None
"""Close the transport session of each credential in the chain."""
self.__exit__()

def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
# type: (*str, **Any) -> AccessToken
"""Request a token from each chained credential, in order, returning the first token received.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,17 @@ def __init__(self, tenant_id, client_id, get_assertion, **kwargs):
self._client = AadClient(tenant_id, client_id, **kwargs)
super(ClientAssertionCredential, self).__init__(**kwargs)

def __enter__(self):
self._client.__enter__()
return self

def __exit__(self, *args):
self._client.__exit__(*args)

def close(self):
# type: () -> None
self.__exit__()

def _acquire_token_silently(self, *scopes, **kwargs):
# type: (*str, **Any) -> Optional[AccessToken]
return self._client.get_cached_access_token(scopes, **kwargs)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,46 +8,27 @@

from azure.core.pipeline.transport import HttpRequest

from .. import CredentialUnavailableError
from .._constants import EnvironmentVariables
from .._internal.managed_identity_client import ManagedIdentityClient
from .._internal.get_token_mixin import GetTokenMixin
from .._internal.managed_identity_base import ManagedIdentityBase

if TYPE_CHECKING:
from typing import Any, Optional
from azure.core.credentials import AccessToken


class CloudShellCredential(GetTokenMixin):
def __init__(self, **kwargs):
# type: (**Any) -> None
super(CloudShellCredential, self).__init__()
class CloudShellCredential(ManagedIdentityBase):
def get_client(self, **kwargs):
# type: (**Any) -> Optional[ManagedIdentityClient]
url = os.environ.get(EnvironmentVariables.MSI_ENDPOINT)
if url:
self._available = True
self._client = ManagedIdentityClient(
request_factory=functools.partial(_get_request, url),
base_headers={"Metadata": "true"},
**kwargs
return ManagedIdentityClient(
request_factory=functools.partial(_get_request, url), base_headers={"Metadata": "true"}, **kwargs
)
else:
self._available = False

def get_token(self, *scopes, **kwargs):
# type: (*str, **Any) -> AccessToken
if not self._available:
raise CredentialUnavailableError(
message="Cloud Shell managed identity configuration not found in environment"
)
return super(CloudShellCredential, self).get_token(*scopes, **kwargs)

def _acquire_token_silently(self, *scopes, **kwargs):
# type: (*str, **Any) -> Optional[AccessToken]
return self._client.get_cached_token(*scopes)
return None

def _request_token(self, *scopes, **kwargs):
# type: (*str, **Any) -> AccessToken
return self._client.request_token(*scopes, **kwargs)
def get_unavailable_message(self):
# type: () -> str
return "Cloud Shell managed identity configuration not found in environment"


def _get_request(url, scope, identity_config):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,20 @@ def __init__(self, **kwargs):
else:
_LOGGER.info("No environment configuration found.")

def __enter__(self):
if self._credential:
self._credential.__enter__()
return self

def __exit__(self, *args):
if self._credential:
self._credential.__exit__(*args)

def close(self):
# type: () -> None
"""Close the credential's transport session."""
self.__exit__()

@log_get_token("EnvironmentCredential")
def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
# type: (*str, **Any) -> AccessToken
Expand Down
10 changes: 10 additions & 0 deletions sdk/identity/azure-identity/azure/identity/_credentials/imds.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,16 @@ def __init__(self, **kwargs):
self._error_message = None # type: Optional[str]
self._user_assigned_identity = "client_id" in kwargs or "identity_config" in kwargs

def __enter__(self):
self._client.__enter__()
return self

def __exit__(self, *args):
self._client.__exit__(*args)

def close(self):
self.__exit__()

def _acquire_token_silently(self, *scopes, **kwargs):
# type: (*str, **Any) -> Optional[AccessToken]
return self._client.get_cached_token(*scopes)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,18 @@ def __init__(self, **kwargs):
_LOGGER.info("%s will use IMDS", self.__class__.__name__)
self._credential = ImdsCredential(**kwargs)

def __enter__(self):
self._credential.__enter__()
return self

def __exit__(self, *args):
self._credential.__exit__(*args)

def close(self):
# type: () -> None
"""Close the credential's transport session."""
self.__exit__()

@log_get_token("ManagedIdentityCredential")
def get_token(self, *scopes, **kwargs):
# type: (*str, **Any) -> AccessToken
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,43 +8,25 @@

from azure.core.pipeline.transport import HttpRequest

from .. import CredentialUnavailableError
from .._constants import EnvironmentVariables
from .._internal.managed_identity_base import ManagedIdentityBase
from .._internal.managed_identity_client import ManagedIdentityClient
from .._internal.get_token_mixin import GetTokenMixin

if TYPE_CHECKING:
from typing import Any, Optional
from azure.core.credentials import AccessToken


class ServiceFabricCredential(GetTokenMixin):
def __init__(self, **kwargs):
# type: (**Any) -> None
super(ServiceFabricCredential, self).__init__()

class ServiceFabricCredential(ManagedIdentityBase):
def get_client(self, **kwargs):
# type: (**Any) -> Optional[ManagedIdentityClient]
client_args = _get_client_args(**kwargs)
if client_args:
self._available = True
self._client = ManagedIdentityClient(**client_args)
else:
self._available = False

def get_token(self, *scopes, **kwargs):
# type: (*str, **Any) -> AccessToken
if not self._available:
raise CredentialUnavailableError(
message="Service Fabric managed identity configuration not found in environment"
)
return super(ServiceFabricCredential, self).get_token(*scopes, **kwargs)

def _acquire_token_silently(self, *scopes, **kwargs):
# type: (*str, **Any) -> Optional[AccessToken]
return self._client.get_cached_token(*scopes)
return ManagedIdentityClient(**client_args)
return None

def _request_token(self, *scopes, **kwargs):
# type: (*str, **Any) -> AccessToken
return self._client.request_token(*scopes, **kwargs)
def get_unavailable_message(self):
# type: () -> str
return "Service Fabric managed identity configuration not found in environment"


def _get_client_args(**kwargs):
Expand Down
Loading

0 comments on commit 94ad583

Please sign in to comment.