Skip to content

Commit

Permalink
(Async)SupportsTokenInfo support/tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mccoyp committed Oct 8, 2024
1 parent 8e65726 commit 93c7eaa
Show file tree
Hide file tree
Showing 10 changed files with 623 additions and 275 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,18 @@

from copy import deepcopy
import time
from typing import Any, Awaitable, Callable, Optional, overload, TypeVar, Union
from typing import Any, Awaitable, Callable, cast, Optional, overload, TypeVar, Union
from urllib.parse import urlparse

from typing_extensions import ParamSpec

from azure.core.credentials import AccessToken
from azure.core.credentials_async import AsyncTokenCredential
from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions
from azure.core.credentials_async import AsyncSupportsTokenInfo, AsyncTokenCredential, AsyncTokenProvider
from azure.core.pipeline import PipelineRequest, PipelineResponse
from azure.core.pipeline.policies import AsyncBearerTokenCredentialPolicy
from azure.core.rest import AsyncHttpResponse, HttpRequest

from .http_challenge import HttpChallenge
from . import http_challenge_cache as ChallengeCache
from .challenge_auth_policy import _enforce_tls, _has_claims, _update_challenge

Expand Down Expand Up @@ -64,14 +65,14 @@ class AsyncChallengeAuthPolicy(AsyncBearerTokenCredentialPolicy):
:param credential: An object which can provide an access token for the vault, such as a credential from
:mod:`azure.identity.aio`
:type credential: ~azure.core.credentials_async.AsyncTokenCredential
:type credential: ~azure.core.credentials_async.AsyncTokenProvider
"""

def __init__(self, credential: AsyncTokenCredential, *scopes: str, **kwargs: Any) -> None:
def __init__(self, credential: AsyncTokenProvider, *scopes: str, **kwargs: Any) -> None:
# Pass `enable_cae` so `enable_cae=True` is always passed through self.authorize_request
super().__init__(credential, *scopes, enable_cae=True, **kwargs)
self._credential: AsyncTokenCredential = credential
self._token: Optional[AccessToken] = None
self._credential: AsyncTokenProvider = credential
self._token: Optional[Union["AccessToken", "AccessTokenInfo"]] = None
self._verify_challenge_resource = kwargs.pop("verify_challenge_resource", True)
self._request_copy: Optional[HttpRequest] = None

Expand Down Expand Up @@ -157,16 +158,10 @@ async def on_request(self, request: PipelineRequest) -> None:
if self._need_new_token():
# azure-identity credentials require an AADv2 scope but the challenge may specify an AADv1 resource
scope = challenge.get_scope() or challenge.get_resource() + "/.default"
# Exclude tenant for AD FS authentication
if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"):
self._token = await self._credential.get_token(scope, enable_cae=True)
else:
self._token = await self._credential.get_token(
scope, tenant_id=challenge.tenant_id, enable_cae=True
)

# ignore mypy's warning -- although self._token is Optional, get_token raises when it fails to get a token
request.http_request.headers["Authorization"] = f"Bearer {self._token.token}" # type: ignore
await self._request_kv_token(scope, challenge)

bearer_token = cast(Union[AccessToken, AccessTokenInfo], self._token).token
request.http_request.headers["Authorization"] = f"Bearer {bearer_token}"
return

# else: discover authentication information by eliciting a challenge from Key Vault. Remove any request data,
Expand Down Expand Up @@ -233,4 +228,28 @@ async def on_challenge(self, request: PipelineRequest, response: PipelineRespons
return True

def _need_new_token(self) -> bool:
return not self._token or self._token.expires_on - time.time() < 300
now = time.time()
refresh_on = getattr(self._token, "refresh_on", None)
return not self._token or (refresh_on and refresh_on <= now) or self._token.expires_on - now < 300

async def _request_kv_token(self, scope: str, challenge: HttpChallenge, **kwargs: Any) -> None:
"""Implementation of BearerTokenCredentialPolicy's _request_token method, but specific to Key Vault.
:param str scope: The scope for which to request a token.
:param challenge: The challenge for the request being made.
"""
# Exclude tenant for AD FS authentication
exclude_tenant = challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs")
# The AsyncSupportsTokenInfo protocol needs TokenRequestOptions for token requests instead of kwargs
if hasattr(self._credential, "get_token_info"):
options: TokenRequestOptions = {"enable_cae": True}
if challenge.tenant_id and not exclude_tenant:
options["tenant_id"] = challenge.tenant_id
self._token = await cast(AsyncSupportsTokenInfo, self._credential).get_token_info(scope, options=options)
else:
if exclude_tenant:
self._token = await self._credential.get_token(scope, enable_cae=True)
else:
self._token = await cast(AsyncTokenCredential, self._credential).get_token(
scope, tenant_id=challenge.tenant_id, enable_cae=True
)
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,17 @@

from copy import deepcopy
import time
from typing import Any, Optional
from typing import Any, cast, Optional, Union
from urllib.parse import urlparse

from azure.core.credentials import AccessToken, TokenCredential
from azure.core.credentials import (
AccessToken,
AccessTokenInfo,
TokenCredential,
TokenProvider,
TokenRequestOptions,
SupportsTokenInfo,
)
from azure.core.exceptions import ServiceRequestError
from azure.core.pipeline import PipelineRequest, PipelineResponse
from azure.core.pipeline.policies import BearerTokenCredentialPolicy
Expand Down Expand Up @@ -76,14 +83,15 @@ class ChallengeAuthPolicy(BearerTokenCredentialPolicy):
:param credential: An object which can provide an access token for the vault, such as a credential from
:mod:`azure.identity`
:type credential: ~azure.core.credentials.TokenCredential
:type credential: ~azure.core.credentials.TokenProvider
:param str scopes: Lets you specify the type of access needed.
"""

def __init__(self, credential: TokenCredential, *scopes: str, **kwargs: Any) -> None:
def __init__(self, credential: TokenProvider, *scopes: str, **kwargs: Any) -> None:
# Pass `enable_cae` so `enable_cae=True` is always passed through self.authorize_request
super(ChallengeAuthPolicy, self).__init__(credential, *scopes, enable_cae=True, **kwargs)
self._credential: TokenCredential = credential
self._token: Optional[AccessToken] = None
self._credential: TokenProvider = credential
self._token: Optional[Union["AccessToken", "AccessTokenInfo"]] = None
self._verify_challenge_resource = kwargs.pop("verify_challenge_resource", True)
self._request_copy: Optional[HttpRequest] = None

Expand Down Expand Up @@ -166,14 +174,10 @@ def on_request(self, request: PipelineRequest) -> None:
if self._need_new_token:
# azure-identity credentials require an AADv2 scope but the challenge may specify an AADv1 resource
scope = challenge.get_scope() or challenge.get_resource() + "/.default"
# Exclude tenant for AD FS authentication
if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"):
self._token = self._credential.get_token(scope, enable_cae=True)
else:
self._token = self._credential.get_token(scope, tenant_id=challenge.tenant_id, enable_cae=True)

# ignore mypy's warning -- although self._token is Optional, get_token raises when it fails to get a token
request.http_request.headers["Authorization"] = f"Bearer {self._token.token}" # type: ignore
self._request_kv_token(scope, challenge)

bearer_token = cast(Union["AccessToken", "AccessTokenInfo"], self._token).token
request.http_request.headers["Authorization"] = f"Bearer {bearer_token}"
return

# else: discover authentication information by eliciting a challenge from Key Vault. Remove any request data,
Expand Down Expand Up @@ -238,4 +242,28 @@ def on_challenge(self, request: PipelineRequest, response: PipelineResponse) ->

@property
def _need_new_token(self) -> bool:
return not self._token or self._token.expires_on - time.time() < 300
now = time.time()
refresh_on = getattr(self._token, "refresh_on", None)
return not self._token or (refresh_on and refresh_on <= now) or self._token.expires_on - now < 300

def _request_kv_token(self, scope: str, challenge: HttpChallenge, **kwargs: Any) -> None:
"""Implementation of BearerTokenCredentialPolicy's _request_token method, but specific to Key Vault.
:param str scope: The scope for which to request a token.
:param challenge: The challenge for the request being made.
"""
# Exclude tenant for AD FS authentication
exclude_tenant = challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs")
# The SupportsTokenInfo protocol needs TokenRequestOptions for token requests instead of kwargs
if hasattr(self._credential, "get_token_info"):
options: TokenRequestOptions = {"enable_cae": True}
if challenge.tenant_id and not exclude_tenant:
options["tenant_id"] = challenge.tenant_id
self._token = cast(SupportsTokenInfo, self._credential).get_token_info(scope, options=options)
else:
if exclude_tenant:
self._token = self._credential.get_token(scope, enable_cae=True)
else:
self._token = cast(TokenCredential, self._credential).get_token(
scope, tenant_id=challenge.tenant_id, enable_cae=True
)
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,18 @@

from copy import deepcopy
import time
from typing import Any, Awaitable, Callable, Optional, overload, TypeVar, Union
from typing import Any, Awaitable, Callable, cast, Optional, overload, TypeVar, Union
from urllib.parse import urlparse

from typing_extensions import ParamSpec

from azure.core.credentials import AccessToken
from azure.core.credentials_async import AsyncTokenCredential
from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions
from azure.core.credentials_async import AsyncSupportsTokenInfo, AsyncTokenCredential, AsyncTokenProvider
from azure.core.pipeline import PipelineRequest, PipelineResponse
from azure.core.pipeline.policies import AsyncBearerTokenCredentialPolicy
from azure.core.rest import AsyncHttpResponse, HttpRequest

from .http_challenge import HttpChallenge
from . import http_challenge_cache as ChallengeCache
from .challenge_auth_policy import _enforce_tls, _has_claims, _update_challenge

Expand Down Expand Up @@ -64,14 +65,14 @@ class AsyncChallengeAuthPolicy(AsyncBearerTokenCredentialPolicy):
:param credential: An object which can provide an access token for the vault, such as a credential from
:mod:`azure.identity.aio`
:type credential: ~azure.core.credentials_async.AsyncTokenCredential
:type credential: ~azure.core.credentials_async.AsyncTokenProvider
"""

def __init__(self, credential: AsyncTokenCredential, *scopes: str, **kwargs: Any) -> None:
def __init__(self, credential: AsyncTokenProvider, *scopes: str, **kwargs: Any) -> None:
# Pass `enable_cae` so `enable_cae=True` is always passed through self.authorize_request
super().__init__(credential, *scopes, enable_cae=True, **kwargs)
self._credential: AsyncTokenCredential = credential
self._token: Optional[AccessToken] = None
self._credential: AsyncTokenProvider = credential
self._token: Optional[Union["AccessToken", "AccessTokenInfo"]] = None
self._verify_challenge_resource = kwargs.pop("verify_challenge_resource", True)
self._request_copy: Optional[HttpRequest] = None

Expand Down Expand Up @@ -157,16 +158,10 @@ async def on_request(self, request: PipelineRequest) -> None:
if self._need_new_token():
# azure-identity credentials require an AADv2 scope but the challenge may specify an AADv1 resource
scope = challenge.get_scope() or challenge.get_resource() + "/.default"
# Exclude tenant for AD FS authentication
if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"):
self._token = await self._credential.get_token(scope, enable_cae=True)
else:
self._token = await self._credential.get_token(
scope, tenant_id=challenge.tenant_id, enable_cae=True
)

# ignore mypy's warning -- although self._token is Optional, get_token raises when it fails to get a token
request.http_request.headers["Authorization"] = f"Bearer {self._token.token}" # type: ignore
await self._request_kv_token(scope, challenge)

bearer_token = cast(Union[AccessToken, AccessTokenInfo], self._token).token
request.http_request.headers["Authorization"] = f"Bearer {bearer_token}"
return

# else: discover authentication information by eliciting a challenge from Key Vault. Remove any request data,
Expand Down Expand Up @@ -233,4 +228,28 @@ async def on_challenge(self, request: PipelineRequest, response: PipelineRespons
return True

def _need_new_token(self) -> bool:
return not self._token or self._token.expires_on - time.time() < 300
now = time.time()
refresh_on = getattr(self._token, "refresh_on", None)
return not self._token or (refresh_on and refresh_on <= now) or self._token.expires_on - now < 300

async def _request_kv_token(self, scope: str, challenge: HttpChallenge, **kwargs: Any) -> None:
"""Implementation of BearerTokenCredentialPolicy's _request_token method, but specific to Key Vault.
:param str scope: The scope for which to request a token.
:param challenge: The challenge for the request being made.
"""
# Exclude tenant for AD FS authentication
exclude_tenant = challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs")
# The AsyncSupportsTokenInfo protocol needs TokenRequestOptions for token requests instead of kwargs
if hasattr(self._credential, "get_token_info"):
options: TokenRequestOptions = {"enable_cae": True}
if challenge.tenant_id and not exclude_tenant:
options["tenant_id"] = challenge.tenant_id
self._token = await cast(AsyncSupportsTokenInfo, self._credential).get_token_info(scope, options=options)
else:
if exclude_tenant:
self._token = await self._credential.get_token(scope, enable_cae=True)
else:
self._token = await cast(AsyncTokenCredential, self._credential).get_token(
scope, tenant_id=challenge.tenant_id, enable_cae=True
)
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,17 @@

from copy import deepcopy
import time
from typing import Any, Optional
from typing import Any, cast, Optional, Union
from urllib.parse import urlparse

from azure.core.credentials import AccessToken, TokenCredential
from azure.core.credentials import (
AccessToken,
AccessTokenInfo,
TokenCredential,
TokenProvider,
TokenRequestOptions,
SupportsTokenInfo,
)
from azure.core.exceptions import ServiceRequestError
from azure.core.pipeline import PipelineRequest, PipelineResponse
from azure.core.pipeline.policies import BearerTokenCredentialPolicy
Expand Down Expand Up @@ -76,14 +83,15 @@ class ChallengeAuthPolicy(BearerTokenCredentialPolicy):
:param credential: An object which can provide an access token for the vault, such as a credential from
:mod:`azure.identity`
:type credential: ~azure.core.credentials.TokenCredential
:type credential: ~azure.core.credentials.TokenProvider
:param str scopes: Lets you specify the type of access needed.
"""

def __init__(self, credential: TokenCredential, *scopes: str, **kwargs: Any) -> None:
def __init__(self, credential: TokenProvider, *scopes: str, **kwargs: Any) -> None:
# Pass `enable_cae` so `enable_cae=True` is always passed through self.authorize_request
super(ChallengeAuthPolicy, self).__init__(credential, *scopes, enable_cae=True, **kwargs)
self._credential: TokenCredential = credential
self._token: Optional[AccessToken] = None
self._credential: TokenProvider = credential
self._token: Optional[Union["AccessToken", "AccessTokenInfo"]] = None
self._verify_challenge_resource = kwargs.pop("verify_challenge_resource", True)
self._request_copy: Optional[HttpRequest] = None

Expand Down Expand Up @@ -166,14 +174,10 @@ def on_request(self, request: PipelineRequest) -> None:
if self._need_new_token:
# azure-identity credentials require an AADv2 scope but the challenge may specify an AADv1 resource
scope = challenge.get_scope() or challenge.get_resource() + "/.default"
# Exclude tenant for AD FS authentication
if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"):
self._token = self._credential.get_token(scope, enable_cae=True)
else:
self._token = self._credential.get_token(scope, tenant_id=challenge.tenant_id, enable_cae=True)

# ignore mypy's warning -- although self._token is Optional, get_token raises when it fails to get a token
request.http_request.headers["Authorization"] = f"Bearer {self._token.token}" # type: ignore
self._request_kv_token(scope, challenge)

bearer_token = cast(Union["AccessToken", "AccessTokenInfo"], self._token).token
request.http_request.headers["Authorization"] = f"Bearer {bearer_token}"
return

# else: discover authentication information by eliciting a challenge from Key Vault. Remove any request data,
Expand Down Expand Up @@ -238,4 +242,28 @@ def on_challenge(self, request: PipelineRequest, response: PipelineResponse) ->

@property
def _need_new_token(self) -> bool:
return not self._token or self._token.expires_on - time.time() < 300
now = time.time()
refresh_on = getattr(self._token, "refresh_on", None)
return not self._token or (refresh_on and refresh_on <= now) or self._token.expires_on - now < 300

def _request_kv_token(self, scope: str, challenge: HttpChallenge, **kwargs: Any) -> None:
"""Implementation of BearerTokenCredentialPolicy's _request_token method, but specific to Key Vault.
:param str scope: The scope for which to request a token.
:param challenge: The challenge for the request being made.
"""
# Exclude tenant for AD FS authentication
exclude_tenant = challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs")
# The SupportsTokenInfo protocol needs TokenRequestOptions for token requests instead of kwargs
if hasattr(self._credential, "get_token_info"):
options: TokenRequestOptions = {"enable_cae": True}
if challenge.tenant_id and not exclude_tenant:
options["tenant_id"] = challenge.tenant_id
self._token = cast(SupportsTokenInfo, self._credential).get_token_info(scope, options=options)
else:
if exclude_tenant:
self._token = self._credential.get_token(scope, enable_cae=True)
else:
self._token = cast(TokenCredential, self._credential).get_token(
scope, tenant_id=challenge.tenant_id, enable_cae=True
)
Loading

0 comments on commit 93c7eaa

Please sign in to comment.