diff --git a/airflow/providers/hashicorp/hooks/vault.py b/airflow/providers/hashicorp/hooks/vault.py index a17eceda31946..b298e3b990205 100644 --- a/airflow/providers/hashicorp/hooks/vault.py +++ b/airflow/providers/hashicorp/hooks/vault.py @@ -17,6 +17,7 @@ """Hook for HashiCorp Vault""" import json +import warnings from typing import Optional, Tuple import hvac @@ -63,7 +64,7 @@ class VaultHook(BaseHook): Login/Password are used as credentials: - * approle: password -> secret_id + * approle: login -> role_id, password -> secret_id * github: password -> token * token: password -> token * aws_iam: login -> key_id, password -> secret_id @@ -83,7 +84,7 @@ class VaultHook(BaseHook): :param kv_engine_version: Select the version of the engine to run (``1`` or ``2``). Defaults to version defined in connection or ``2`` if not defined in connection. :type kv_engine_version: int - :param role_id: Role ID for Authentication (for ``approle``, ``aws_iam`` auth_types) + :param role_id: Role ID for ``aws_iam`` Authentication. :type role_id: str :param kubernetes_role: Role for Authentication (for ``kubernetes`` auth_type) :type kubernetes_role: str @@ -148,7 +149,26 @@ def __init__( except ValueError: raise VaultError(f"The version is not an int: {conn_version}. ") - if auth_type in ["approle", "aws_iam"]: + if auth_type == "approle": + if role_id: + warnings.warn( + """The usage of role_id for AppRole authentication has been deprecated. + Please use connection login.""", + DeprecationWarning, + stacklevel=2, + ) + elif self.connection.extra_dejson.get('role_id'): + role_id = self.connection.extra_dejson.get('role_id') + warnings.warn( + """The usage of role_id in connection extra for AppRole authentication has been + deprecated. Please use connection login.""", + DeprecationWarning, + stacklevel=2, + ) + elif self.connection.login: + role_id = self.connection.login + + if auth_type == "aws_iam": if not role_id: role_id = self.connection.extra_dejson.get('role_id') diff --git a/tests/providers/hashicorp/hooks/test_vault.py b/tests/providers/hashicorp/hooks/test_vault.py index 394bd0b43a784..fdc3cb51852e4 100644 --- a/tests/providers/hashicorp/hooks/test_vault.py +++ b/tests/providers/hashicorp/hooks/test_vault.py @@ -176,7 +176,6 @@ def test_protocol(self, protocol, expected_url, mock_hvac, mock_get_connection): kwargs = { "vault_conn_id": "vault_conn_id", "auth_type": "approle", - "role_id": "role", "kv_engine_version": 2, } @@ -184,7 +183,7 @@ def test_protocol(self, protocol, expected_url, mock_hvac, mock_get_connection): mock_get_connection.assert_called_with("vault_conn_id") test_client = test_hook.get_conn() mock_hvac.Client.assert_called_with(url=expected_url) - test_client.auth.approle.login.assert_called_with(role_id="role", secret_id="pass") + test_client.auth.approle.login.assert_called_with(role_id="user", secret_id="pass") test_client.is_authenticated.assert_called_with() assert 2 == test_hook.vault_client.kv_engine_version @@ -202,7 +201,6 @@ def test_approle_init_params(self, mock_hvac, mock_get_connection): kwargs = { "vault_conn_id": "vault_conn_id", "auth_type": "approle", - "role_id": "role", "kv_engine_version": 2, } @@ -210,7 +208,7 @@ def test_approle_init_params(self, mock_hvac, mock_get_connection): mock_get_connection.assert_called_with("vault_conn_id") test_client = test_hook.get_conn() mock_hvac.Client.assert_called_with(url='http://localhost:8180') - test_client.auth.approle.login.assert_called_with(role_id="role", secret_id="pass") + test_client.auth.approle.login.assert_called_with(role_id="user", secret_id="pass") test_client.is_authenticated.assert_called_with() assert 2 == test_hook.vault_client.kv_engine_version @@ -222,10 +220,7 @@ def test_approle_dejson(self, mock_hvac, mock_get_connection): mock_connection = self.get_mock_connection() mock_get_connection.return_value = mock_connection - connection_dict = { - "auth_type": "approle", - 'role_id': "role", - } + connection_dict = {"auth_type": "approle"} mock_connection.extra_dejson.get.side_effect = connection_dict.get kwargs = { @@ -236,7 +231,20 @@ def test_approle_dejson(self, mock_hvac, mock_get_connection): mock_get_connection.assert_called_with("vault_conn_id") test_client = test_hook.get_conn() mock_hvac.Client.assert_called_with(url='http://localhost:8180') - test_client.auth.approle.login.assert_called_with(role_id="role", secret_id="pass") + test_client.auth.approle.login.assert_called_with(role_id="user", secret_id="pass") + test_client.is_authenticated.assert_called_with() + assert 2 == test_hook.vault_client.kv_engine_version + + @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac") + @mock.patch.dict( + 'os.environ', + AIRFLOW_CONN_VAULT_CONN_ID='https://role:secret@vault.example.com?auth_type=approle', + ) + def test_approle_uri(self, mock_hvac): + test_hook = VaultHook(vault_conn_id='vault_conn_id') + test_client = test_hook.get_conn() + mock_hvac.Client.assert_called_with(url='https://vault.example.com') + test_client.auth.approle.login.assert_called_with(role_id="role", secret_id="secret") test_client.is_authenticated.assert_called_with() assert 2 == test_hook.vault_client.kv_engine_version @@ -290,6 +298,23 @@ def test_aws_iam_dejson(self, mock_hvac, mock_get_connection): role="role", ) + @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac") + @mock.patch.dict( + 'os.environ', + AIRFLOW_CONN_VAULT_CONN_ID='https://login:pass@vault.example.com?auth_type=aws_iam&role_id=role', + ) + def test_aws_uri(self, mock_hvac): + test_hook = VaultHook(vault_conn_id='vault_conn_id') + test_client = test_hook.get_conn() + mock_hvac.Client.assert_called_with(url='https://vault.example.com') + test_client.auth_aws_iam.assert_called_with( + access_key='login', + secret_key='pass', + role="role", + ) + test_client.is_authenticated.assert_called_with() + assert 2 == test_hook.vault_client.kv_engine_version + @mock.patch("airflow.providers.hashicorp.hooks.vault.VaultHook.get_connection") @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac") def test_azure_init_params(self, mock_hvac, mock_get_connection):