Skip to content

Commit

Permalink
Make DefaultAzureCredential in AzureBaseHook configurable (#35051)
Browse files Browse the repository at this point in the history
* feat(provider/microsoft): make managed_identity configurable in base azure hook

* test(providers/microsoft): add test case for verifying calling DefaultAzureCredential with user provided identity in AzureIdentityCredentialAdapter

* docs(microsoft/azure): update azure base hook doc for managed identity args

* refactor(provider/microsoft): extract get_default_azure_credential function
  • Loading branch information
Lee-W authored Oct 30, 2023
1 parent 1aa91a4 commit 2b011b2
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 6 deletions.
8 changes: 8 additions & 0 deletions airflow/providers/microsoft/azure/hooks/base_azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,12 @@ def get_connection_form_widgets() -> dict[str, Any]:
return {
"tenantId": StringField(lazy_gettext("Azure Tenant ID"), widget=BS3TextFieldWidget()),
"subscriptionId": StringField(lazy_gettext("Azure Subscription ID"), widget=BS3TextFieldWidget()),
"managed_identity_client_id": StringField(
lazy_gettext("Managed Identity Client ID"), widget=BS3TextFieldWidget()
),
"workload_identity_tenant_id": StringField(
lazy_gettext("Workload Identity Tenant ID"), widget=BS3TextFieldWidget()
),
}

@staticmethod
Expand All @@ -79,6 +85,8 @@ def get_ui_field_behaviour() -> dict[str, Any]:
"password": "secret (token credentials auth)",
"tenantId": "tenantId (token credentials auth)",
"subscriptionId": "subscriptionId (token credentials auth)",
"managed_identity_client_id": "Managed Identity Client ID",
"workload_identity_tenant_id": "Workload Identity Tenant ID",
},
}

Expand Down
39 changes: 35 additions & 4 deletions airflow/providers/microsoft/azure/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,24 @@ def get_field(*, conn_id: str, conn_type: str, extras: dict, field_name: str):
return ret


def get_default_azure_credential(
    managed_identity_client_id: str | None, workload_identity_tenant_id: str | None
) -> DefaultAzureCredential:
    """Build a ``DefaultAzureCredential`` from the optional identity settings.

    The identity settings are only honoured as a pair: when *both*
    ``managed_identity_client_id`` and ``workload_identity_tenant_id`` are truthy they are
    forwarded to ``DefaultAzureCredential``, and the workload identity tenant is also listed in
    ``additionally_allowed_tenants``. If either one is missing or empty, a
    ``DefaultAzureCredential`` constructed without any arguments is returned instead.

    :param managed_identity_client_id: Client ID forwarded as ``managed_identity_client_id``.
    :param workload_identity_tenant_id: Tenant ID forwarded as ``workload_identity_tenant_id``.
    :return: The configured (or default) ``DefaultAzureCredential`` instance.
    """
    identity_kwargs: dict = {}
    both_provided = bool(managed_identity_client_id) and bool(workload_identity_tenant_id)
    if both_provided:
        identity_kwargs["managed_identity_client_id"] = managed_identity_client_id
        identity_kwargs["workload_identity_tenant_id"] = workload_identity_tenant_id
        identity_kwargs["additionally_allowed_tenants"] = [workload_identity_tenant_id]
    # An empty mapping expands to a plain ``DefaultAzureCredential()`` call.
    return DefaultAzureCredential(**identity_kwargs)


class AzureIdentityCredentialAdapter(BasicTokenAuthentication):
"""Adapt azure-identity credentials for backward compatibility.
Expand All @@ -60,15 +78,28 @@ class AzureIdentityCredentialAdapter(BasicTokenAuthentication):
Check https://stackoverflow.com/questions/63384092/exception-attributeerror-defaultazurecredential-object-has-no-attribute-sig
"""

def __init__(self, credential=None, resource_id="https://management.azure.com/.default", **kwargs):
def __init__(
self,
credential=None,
resource_id="https://management.azure.com/.default",
*,
managed_identity_client_id: str | None = None,
workload_identity_tenant_id: str | None = None,
**kwargs,
):
"""Adapt azure-identity credentials for backward compatibility.
:param credential: Any azure-identity credential (DefaultAzureCredential by default)
:param str resource_id: The scope to use to get the token (default ARM)
:param resource_id: The scope to use to get the token (default ARM)
:param managed_identity_client_id: The client ID of a user-assigned managed identity.
If provided with `workload_identity_tenant_id`, they'll pass to ``DefaultAzureCredential``.
:param workload_identity_tenant_id: ID of the application's Microsoft Entra tenant.
Also called its "directory" ID.
If provided with `managed_identity_client_id`, they'll pass to ``DefaultAzureCredential``.
"""
super().__init__(None)
super().__init__(None) # type: ignore[arg-type]
if credential is None:
credential = DefaultAzureCredential()
credential = get_default_azure_credential(managed_identity_client_id, workload_identity_tenant_id)
self._policy = BearerTokenCredentialPolicy(credential, resource_id, **kwargs)

def _make_request(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
.. _howto/connection:azure:


Microsoft Azure Connection
==========================

Expand All @@ -27,14 +28,15 @@ The Microsoft Azure connection type enables the Azure Integrations.
Authenticating to Azure
-----------------------

There are four ways to connect to Azure using Airflow.
There are five ways to connect to Azure using Airflow.

1. Use `token credentials`_
i.e. add specific credentials (client_id, secret, tenant) and subscription id to the Airflow connection.
2. Use a `JSON file`_
3. Use a `JSON dictionary`_
i.e. add a key config directly into the Airflow connection.
4. Fallback on `DefaultAzureCredential`_.
4. Use managed identity through providing ``managed_identity_client_id`` and ``workload_identity_tenant_id``.
5. Fallback on `DefaultAzureCredential`_.
This includes a mechanism to try different options to authenticate: Managed System Identity, environment variables, authentication through Azure CLI, etc.
``subscriptionId`` is required in this authentication mechanism.

Expand Down Expand Up @@ -71,6 +73,8 @@ Extra (optional)
It specifies the path to the json file that contains the authentication information.
* ``key_json``: If set, it uses the *JSON dictionary* authentication mechanism.
It specifies the json that contains the authentication information.
* ``managed_identity_client_id``: The client ID of a user-assigned managed identity. If provided together with ``workload_identity_tenant_id``, both are passed to ``DefaultAzureCredential``.
* ``workload_identity_tenant_id``: ID of the application's Microsoft Entra tenant. Also called its "directory" ID. If provided together with ``managed_identity_client_id``, both are passed to ``DefaultAzureCredential``.

The entire extra column can be left out to fall back on DefaultAzureCredential_.

Expand All @@ -90,3 +94,7 @@ For example:
.. _JSON file: https://docs.microsoft.com/en-us/azure/developer/python/azure-sdk-authenticate?tabs=cmd#authenticate-with-a-json-file
.. _JSON dictionary: https://docs.microsoft.com/en-us/azure/developer/python/azure-sdk-authenticate?tabs=cmd#authenticate-with-a-json-dictionary
.. _DefaultAzureCredential: https://docs.microsoft.com/en-us/python/api/overview/azure/identity-readme?view=azure-python#defaultazurecredential

.. spelling:word-list::
Entra
21 changes: 21 additions & 0 deletions tests/providers/microsoft/azure/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,24 @@ def test_signed_session(self, mock_default_azure_credential, mock_policy, mock_r

adapter.signed_session()
assert adapter.token == {"access_token": "token"}

@mock.patch(f"{MODULE}.PipelineRequest")
@mock.patch(f"{MODULE}.BearerTokenCredentialPolicy")
@mock.patch(f"{MODULE}.DefaultAzureCredential")
def test_init_with_identity(self, mock_default_azure_credential, mock_policy, mock_request):
    """Check that user-provided identity settings reach ``DefaultAzureCredential``.

    The adapter is created with only ``managed_identity_client_id`` and
    ``workload_identity_tenant_id``. ``additionally_allowed_tenants`` is intentionally not
    passed to the adapter: it is not one of the adapter's parameters (it would merely be
    forwarded to the policy through ``**kwargs``), and leaving it out lets the assertion on
    ``DefaultAzureCredential`` show that the value is derived from the tenant ID.
    """
    mock_request.return_value.http_request.headers = {"Authorization": "Bearer token"}

    adapter = AzureIdentityCredentialAdapter(
        managed_identity_client_id="managed_identity_client_id",
        workload_identity_tenant_id="workload_identity_tenant_id",
    )

    # The credential must be built from both identity settings plus the derived tenant list.
    mock_default_azure_credential.assert_called_once_with(
        managed_identity_client_id="managed_identity_client_id",
        workload_identity_tenant_id="workload_identity_tenant_id",
        additionally_allowed_tenants=["workload_identity_tenant_id"],
    )
    # The policy receives the created credential and the default scope, with no extra kwargs.
    mock_policy.assert_called_once_with(
        mock_default_azure_credential.return_value,
        "https://management.azure.com/.default",
    )

    adapter.signed_session()
    assert adapter.token == {"access_token": "token"}

0 comments on commit 2b011b2

Please sign in to comment.