Skip to content

Commit

Permalink
[Identity] Allow use of client assertion in OBO cred
Browse files Browse the repository at this point in the history
Signed-off-by: Paul Van Eck <[email protected]>
  • Loading branch information
pvaneck committed Jun 5, 2024
1 parent edccdfa commit f7e11e5
Show file tree
Hide file tree
Showing 9 changed files with 196 additions and 26 deletions.
2 changes: 2 additions & 0 deletions sdk/identity/azure-identity/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

### Features Added

- `OnBehalfOfCredential` now supports client assertion callbacks through the `client_assertion_func` keyword argument. This enables authenticating with client assertions such as federated credentials. ([#35812](https://github.com/Azure/azure-sdk-for-python/pull/35812))

### Breaking Changes

### Bugs Fixed
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# Licensed under the MIT License.
# ------------------------------------
import time
from typing import Any, Optional
from typing import Any, Optional, Callable, Union, Dict

import msal

Expand All @@ -30,10 +30,14 @@ class OnBehalfOfCredential(MsalCredential, GetTokenMixin):
:param str tenant_id: ID of the service principal's tenant. Also called its "directory" ID.
:param str client_id: The service principal's client ID
:keyword str client_secret: Optional. A client secret to authenticate the service principal.
Either **client_secret** or **client_certificate** must be provided.
One of **client_secret**, **client_certificate**, or **client_assertion_func** must be provided.
:keyword bytes client_certificate: Optional. The bytes of a certificate in PEM or PKCS12 format including
the private key to authenticate the service principal. Either **client_secret** or **client_certificate** must
be provided.
the private key to authenticate the service principal. One of **client_secret**, **client_certificate**,
or **client_assertion_func** must be provided.
:keyword client_assertion_func: Optional. Function that returns client assertions that authenticate the
application to Microsoft Entra ID. This function is called each time the credential requests a token. It must
return a valid assertion for the target resource.
:paramtype client_assertion_func: Callable[[], str]
:keyword str user_assertion: Required. The access token the credential will use as the user assertion when
requesting on-behalf-of tokens
Expand Down Expand Up @@ -65,14 +69,30 @@ class OnBehalfOfCredential(MsalCredential, GetTokenMixin):
:caption: Create an OnBehalfOfCredential.
"""

def __init__(self, tenant_id: str, client_id: str, **kwargs: Any) -> None:
self._assertion = kwargs.pop("user_assertion", None)
def __init__(
self,
tenant_id: str,
client_id: str,
*,
client_certificate: Optional[bytes] = None,
client_secret: Optional[str] = None,
client_assertion_func: Optional[Callable[[], str]] = None,
user_assertion: str,
**kwargs: Any
) -> None:
self._assertion = user_assertion
if not self._assertion:
raise TypeError('"user_assertion" is required.')
client_certificate = kwargs.pop("client_certificate", None)
client_secret = kwargs.pop("client_secret", None)
raise TypeError('"user_assertion" must not be empty.')

if client_certificate:
if client_assertion_func:
if client_certificate or client_secret:
raise ValueError(
'Specifying both "client_assertion_func" and "client_certificate" or "client_secret" is not valid.'
)
credential: Union[str, Dict[str, Any]] = {
"client_assertion": client_assertion_func,
}
elif client_certificate:
if client_secret:
raise ValueError('Specifying both "client_certificate" and "client_secret" is not valid.')
try:
Expand All @@ -86,7 +106,7 @@ def __init__(self, tenant_id: str, client_id: str, **kwargs: Any) -> None:
elif client_secret:
credential = client_secret
else:
raise TypeError('Either "client_certificate" or "client_secret" must be provided')
raise TypeError('Either "client_certificate", "client_secret", or "client_assertion_func" must be provided')

super(OnBehalfOfCredential, self).__init__(
client_id=client_id, client_credential=credential, tenant_id=tenant_id, **kwargs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ def _get_client_secret_request(self, scopes: Iterable[str], secret: str, **kwarg
def _get_on_behalf_of_request(
self,
scopes: Iterable[str],
client_credential: Union[str, AadClientCertificate],
client_credential: Union[str, AadClientCertificate, Dict[str, Any]],
user_assertion: str,
**kwargs: Any
) -> HttpRequest:
Expand All @@ -288,6 +288,10 @@ def _get_on_behalf_of_request(
if isinstance(client_credential, AadClientCertificate):
data["client_assertion"] = self._get_client_certificate_assertion(client_credential)
data["client_assertion_type"] = JWT_BEARER_ASSERTION
elif isinstance(client_credential, dict):
func = client_credential["client_assertion"]
data["client_assertion"] = func()
data["client_assertion_type"] = JWT_BEARER_ASSERTION
else:
data["client_secret"] = client_credential

Expand Down Expand Up @@ -318,7 +322,7 @@ def _get_refresh_token_request(self, scopes: Iterable[str], refresh_token: str,
def _get_refresh_token_on_behalf_of_request(
self,
scopes: Iterable[str],
client_credential: Union[str, AadClientCertificate],
client_credential: Union[str, AadClientCertificate, Dict[str, Any]],
refresh_token: str,
**kwargs: Any
) -> HttpRequest:
Expand All @@ -338,6 +342,10 @@ def _get_refresh_token_on_behalf_of_request(
if isinstance(client_credential, AadClientCertificate):
data["client_assertion"] = self._get_client_certificate_assertion(client_credential)
data["client_assertion_type"] = JWT_BEARER_ASSERTION
elif isinstance(client_credential, dict):
func = client_credential["client_assertion"]
data["client_assertion"] = func()
data["client_assertion_type"] = JWT_BEARER_ASSERTION
else:
data["client_secret"] = client_credential
request = self._post(data, **kwargs)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class MsalCredential: # pylint: disable=too-many-instance-attributes
def __init__(
self,
client_id: str,
client_credential: Optional[Union[str, Dict[str, str]]] = None,
client_credential: Optional[Union[str, Dict[str, Any]]] = None,
*,
additionally_allowed_tenants: Optional[List[str]] = None,
authority: Optional[str] = None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# Licensed under the MIT License.
# ------------------------------------
import logging
from typing import Optional, Union, Any
from typing import Optional, Union, Any, Dict, Callable

from azure.core.exceptions import ClientAuthenticationError
from azure.core.credentials import AccessToken
Expand All @@ -27,10 +27,14 @@ class OnBehalfOfCredential(AsyncContextManager, GetTokenMixin):
:param str tenant_id: ID of the service principal's tenant. Also called its "directory" ID.
:param str client_id: The service principal's client ID
:keyword str client_secret: Optional. A client secret to authenticate the service principal.
Either **client_secret** or **client_certificate** must be provided.
One of **client_secret**, **client_certificate**, or **client_assertion_func** must be provided.
:keyword bytes client_certificate: Optional. The bytes of a certificate in PEM or PKCS12 format including
the private key to authenticate the service principal. Either **client_secret** or **client_certificate** must
be provided.
the private key to authenticate the service principal. One of **client_secret**, **client_certificate**,
or **client_assertion_func** must be provided.
:keyword client_assertion_func: Optional. Function that returns client assertions that authenticate the
application to Microsoft Entra ID. This function is called each time the credential requests a token. It must
return a valid assertion for the target resource.
:paramtype client_assertion_func: Callable[[], str]
:keyword str user_assertion: Required. The access token the credential will use as the user assertion when
requesting on-behalf-of tokens
Expand Down Expand Up @@ -62,29 +66,38 @@ def __init__(
*,
client_certificate: Optional[bytes] = None,
client_secret: Optional[str] = None,
client_assertion_func: Optional[Callable[[], str]] = None,
user_assertion: str,
**kwargs: Any
) -> None:
super().__init__()
validate_tenant_id(tenant_id)

self._assertion = user_assertion
if not self._assertion:
raise TypeError('"user_assertion" must not be empty.')

if client_certificate:
if client_assertion_func:
if client_certificate or client_secret:
raise ValueError(
'Specifying both "client_assertion_func" and "client_certificate" or "client_secret" is not valid.'
)
self._client_credential: Union[str, AadClientCertificate, Dict[str, Any]] = {
"client_assertion": client_assertion_func,
}
elif client_certificate:
if client_secret:
raise ValueError('Specifying both "client_certificate" and "client_secret" is not valid.')
try:
cert = get_client_credential(None, kwargs.pop("password", None), client_certificate)
except ValueError as ex:
message = '"client_certificate" is not a valid certificate in PEM or PKCS12 format'
raise ValueError(message) from ex
self._client_credential: Union[str, AadClientCertificate] = AadClientCertificate(
cert["private_key"], password=cert.get("passphrase")
)
self._client_credential = AadClientCertificate(cert["private_key"], password=cert.get("passphrase"))
elif client_secret:
self._client_credential = client_secret
else:
raise TypeError('Either "client_certificate" or "client_secret" must be provided')
raise TypeError('Either "client_certificate", "client_secret", or "client_assertion_func" must be provided')

# note AadClient handles "authority" and any pipeline kwargs
self._client = AadClient(tenant_id, client_id, **kwargs)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# Licensed under the MIT License.
# ------------------------------------
import time
from typing import Iterable, Optional, Union
from typing import Iterable, Optional, Union, Dict, Any

from azure.core.credentials import AccessToken
from azure.core.pipeline import AsyncPipeline
Expand Down Expand Up @@ -57,15 +57,23 @@ async def obtain_token_by_refresh_token(self, scopes: Iterable[str], refresh_tok
return await self._run_pipeline(request, **kwargs)

async def obtain_token_by_refresh_token_on_behalf_of( # pylint: disable=name-too-long
self, scopes: Iterable[str], client_credential: Union[str, AadClientCertificate], refresh_token: str, **kwargs
self,
scopes: Iterable[str],
client_credential: Union[str, AadClientCertificate, Dict[str, Any]],
refresh_token: str,
**kwargs
) -> AccessToken:
request = self._get_refresh_token_on_behalf_of_request(
scopes, client_credential=client_credential, refresh_token=refresh_token, **kwargs
)
return await self._run_pipeline(request, **kwargs)

async def obtain_token_on_behalf_of(
self, scopes: Iterable[str], client_credential: Union[str, AadClientCertificate], user_assertion: str, **kwargs
self,
scopes: Iterable[str],
client_credential: Union[str, AadClientCertificate, Dict[str, Any]],
user_assertion: str,
**kwargs
) -> AccessToken:
request = self._get_on_behalf_of_request(
scopes=scopes, client_credential=client_credential, user_assertion=user_assertion, **kwargs
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
"""
FILE: on_behalf_of_client_assertion.py
DESCRIPTION:
This sample demonstrates the use of OnBehalfOfCredential to authenticate the Key Vault SecretClient using a managed
identity as the client assertion. More information about the On-Behalf-Of flow can be found here:
https://learn.microsoft.com/entra/identity-platform/v2-oauth2-on-behalf-of-flow.
USAGE:
python on_behalf_of_client_assertion.py
**Note** - This sample requires the `azure-keyvault-secrets` package.
"""
# [START obo_client_assertion]
from azure.identity import OnBehalfOfCredential, ManagedIdentityCredential
from azure.keyvault.secrets import SecretClient


# Replace the following variables with your own values.
tenant_id = "<tenant_id>"
client_id = "<client_id>"
user_assertion = "<user_assertion>"

managed_identity_credential = ManagedIdentityCredential()


def get_managed_identity_token() -> str:
# This function should return an access token obtained from a managed identity.
access_token = managed_identity_credential.get_token("api://AzureADTokenExchange")
return access_token.token


credential = OnBehalfOfCredential(
tenant_id=tenant_id,
client_id=client_id,
user_assertion=user_assertion,
client_assertion_func=get_managed_identity_token,
)

client = SecretClient(vault_url="https://<your-key-vault-name>.vault.azure.net/", credential=credential)
# [END obo_client_assertion]
38 changes: 38 additions & 0 deletions sdk/identity/azure-identity/tests/test_obo.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from azure.core.pipeline.policies import ContentDecodePolicy, SansIOHTTPPolicy
from azure.identity import OnBehalfOfCredential, UsernamePasswordCredential
from azure.identity._constants import EnvironmentVariables
from azure.identity._internal.aad_client_base import JWT_BEARER_ASSERTION
from azure.identity._internal.user_agent import USER_AGENT
import pytest
from urllib.parse import urlparse
Expand Down Expand Up @@ -228,3 +229,40 @@ def test_no_client_credential():
"""The credential should raise ValueError when ctoring with no client_secret or client_certificate"""
with pytest.raises(TypeError):
credential = OnBehalfOfCredential("tenant-id", "client-id", user_assertion="assertion")


def test_client_assertion_func():
"""The credential should accept a client_assertion_func"""
expected_client_assertion = "client-assertion"
expected_user_assertion = "user-assertion"
expected_token = "***"
func_call_count = 0

def client_assertion_func():
nonlocal func_call_count
func_call_count += 1
return expected_client_assertion

def send(request, **kwargs):
parsed = urlparse(request.url)
tenant = parsed.path.split("/")[1]
if "/oauth2/v2.0/token" not in parsed.path:
return get_discovery_response("https://{}/{}".format(parsed.netloc, tenant))

assert request.data.get("client_assertion") == expected_client_assertion
assert request.data.get("client_assertion_type") == JWT_BEARER_ASSERTION
assert request.data.get("assertion") == expected_user_assertion
return mock_response(json_payload=build_aad_response(access_token=expected_token))

transport = Mock(send=Mock(wraps=send))
credential = OnBehalfOfCredential(
"tenant-id",
"client-id",
client_assertion_func=client_assertion_func,
user_assertion=expected_user_assertion,
transport=transport,
)

access_token = credential.get_token("scope")
assert access_token.token == expected_token
assert func_call_count == 1
38 changes: 38 additions & 0 deletions sdk/identity/azure-identity/tests/test_obo_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from azure.core.pipeline.policies import ContentDecodePolicy, SansIOHTTPPolicy
from azure.identity import UsernamePasswordCredential
from azure.identity._constants import EnvironmentVariables
from azure.identity._internal.aad_client_base import JWT_BEARER_ASSERTION
from azure.identity._internal.user_agent import USER_AGENT
from azure.identity.aio import OnBehalfOfCredential
import pytest
Expand Down Expand Up @@ -305,3 +306,40 @@ async def test_no_client_credential():
"""The credential should raise ValueError when ctoring with no client_secret or client_certificate"""
with pytest.raises(TypeError):
credential = OnBehalfOfCredential("tenant-id", "client-id", user_assertion="assertion")


@pytest.mark.asyncio
async def test_client_assertion_func():
"""The credential should accept a client_assertion_func"""
expected_client_assertion = "client-assertion"
expected_user_assertion = "user-assertion"
expected_token = "***"
func_call_count = 0

def client_assertion_func():
nonlocal func_call_count
func_call_count += 1
return expected_client_assertion

async def send(request, **kwargs):
parsed = urlparse(request.url)
tenant = parsed.path.split("/")[1]
if "/oauth2/v2.0/token" not in parsed.path:
return get_discovery_response("https://{}/{}".format(parsed.netloc, tenant))

assert request.data.get("client_assertion") == expected_client_assertion
assert request.data.get("client_assertion_type") == JWT_BEARER_ASSERTION
assert request.data.get("assertion") == expected_user_assertion
return mock_response(json_payload=build_aad_response(access_token=expected_token))

transport = Mock(send=send)
credential = OnBehalfOfCredential(
"tenant-id",
"client-id",
client_assertion_func=client_assertion_func,
user_assertion=expected_user_assertion,
transport=transport,
)
token = await credential.get_token("scope")
assert token.token == expected_token
assert func_call_count == 1

0 comments on commit f7e11e5

Please sign in to comment.