From 30817a5c6df1be3ec080ff1c542899092679768f Mon Sep 17 00:00:00 2001 From: uzhastik Date: Sat, 23 Mar 2024 00:30:33 +0300 Subject: [PATCH] support iam token from metadata, simplify code (#38411) --- airflow/providers/yandex/hooks/yq.py | 54 ++++++++----------------- airflow/providers/yandex/provider.yaml | 3 -- generated/provider_dependencies.json | 2 - pyproject.toml | 2 - tests/providers/yandex/hooks/test_yq.py | 49 ++++++++++++++++++---- 5 files changed, 59 insertions(+), 51 deletions(-) diff --git a/airflow/providers/yandex/hooks/yq.py b/airflow/providers/yandex/hooks/yq.py index 963709d89b66c..37f7550df623a 100644 --- a/airflow/providers/yandex/hooks/yq.py +++ b/airflow/providers/yandex/hooks/yq.py @@ -16,16 +16,14 @@ # under the License. from __future__ import annotations -import time from datetime import timedelta from typing import Any -import jwt -import requests -from urllib3.util.retry import Retry +import yandexcloud +import yandexcloud._auth_fabric as auth_fabric +from yandex.cloud.iam.v1.iam_token_service_pb2_grpc import IamTokenServiceStub from yandex_query_client import YQHttpClient, YQHttpClientConfig -from airflow.exceptions import AirflowException from airflow.providers.yandex.hooks.yandex import YandexCloudBaseHook from airflow.providers.yandex.utils.user_agent import provider_user_agent @@ -98,35 +96,17 @@ def compose_query_web_link(self, query_id: str): return self.client.compose_query_web_link(query_id) def _get_iam_token(self) -> str: - if "token" in self.credentials: - return self.credentials["token"] - if "service_account_key" in self.credentials: - return YQHook._resolve_service_account_key(self.credentials["service_account_key"]) - raise AirflowException(f"Unknown credentials type, available keys {self.credentials.keys()}") - - @staticmethod - def _resolve_service_account_key(sa_info: dict) -> str: - with YQHook._create_session() as session: - api = "https://iam.api.cloud.yandex.net/iam/v1/tokens" - now = int(time.time()) - payload = {"aud": api, "iss": sa_info["service_account_id"], "iat": now, "exp": now + 360} - - encoded_token = jwt.encode( - payload, sa_info["private_key"], algorithm="PS256", headers={"kid": sa_info["id"]} - ) - - data = {"jwt": encoded_token} - iam_response = session.post(api, json=data) - iam_response.raise_for_status() - - return iam_response.json()["iamToken"] - - @staticmethod - def _create_session() -> requests.Session: - session = requests.Session() - session.verify = False - retry = Retry(backoff_factor=0.3, total=10) - session.mount("http://", requests.adapters.HTTPAdapter(max_retries=retry)) - session.mount("https://", requests.adapters.HTTPAdapter(max_retries=retry)) - - return session + iam_token = self.credentials.get("token") + if iam_token is not None: + return iam_token + + service_account_key = self.credentials.get("service_account_key") + # if service_account_key is None metadata server will be used + token_requester = auth_fabric.get_auth_token_requester(service_account_key=service_account_key) + + if service_account_key is None: + return token_requester.get_token() + + sdk = yandexcloud.SDK() + client = sdk.client(IamTokenServiceStub) + return client.Create(token_requester.get_token_request()).iam_token diff --git a/airflow/providers/yandex/provider.yaml b/airflow/providers/yandex/provider.yaml index 0135ac3fb4835..df700127c4aea 100644 --- a/airflow/providers/yandex/provider.yaml +++ b/airflow/providers/yandex/provider.yaml @@ -50,9 +50,6 @@ dependencies: - apache-airflow>=2.6.0 - yandexcloud>=0.228.0 - yandex-query-client>=0.1.2 - - python-dateutil>=2.8.0 - # Requests 3 if it will be released, will be heavily breaking. - - requests>=2.27.0,<3 integrations: - integration-name: Yandex.Cloud diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json index fe9848069dc08..4f110f99182f1 100644 --- a/generated/provider_dependencies.json +++ b/generated/provider_dependencies.json @@ -1180,8 +1180,6 @@ "yandex": { "deps": [ "apache-airflow>=2.6.0", - "python-dateutil>=2.8.0", - "requests>=2.27.0,<3", "yandex-query-client>=0.1.2", "yandexcloud>=0.228.0" ], diff --git a/pyproject.toml b/pyproject.toml index 89d7496c4afa9..ace1b0800a519 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -975,8 +975,6 @@ weaviate = [ # source: airflow/providers/weaviate/provider.yaml "weaviate-client>=3.24.2", ] yandex = [ # source: airflow/providers/yandex/provider.yaml - "python-dateutil>=2.8.0", - "requests>=2.27.0,<3", "yandex-query-client>=0.1.2", "yandexcloud>=0.228.0", ] diff --git a/tests/providers/yandex/hooks/test_yq.py b/tests/providers/yandex/hooks/test_yq.py index 3b3db91dd1eab..c378c65347f87 100644 --- a/tests/providers/yandex/hooks/test_yq.py +++ b/tests/providers/yandex/hooks/test_yq.py @@ -16,6 +16,7 @@ # under the License. from __future__ import annotations +import json from datetime import timedelta from unittest import mock @@ -26,6 +27,7 @@ from airflow.providers.yandex.hooks.yq import YQHook OAUTH_TOKEN = "my_oauth_token" +IAM_TOKEN = "my_iam_token" SERVICE_ACCOUNT_AUTH_KEY_JSON = """{"id":"my_id", "service_account_id":"my_sa1", "private_key":"my_pk"}""" @@ -34,6 +36,18 @@ def __init__(self) -> None: self.client = None +class DummyTokenRequester: + def get_token(self) -> str: + return IAM_TOKEN + + def get_token_request(self) -> str: + return "my_dummy_request" + + +class DummyCreateTokenResponse: + iam_token = "zzz" + + class TestYandexCloudYqHook: def _init_hook(self): with mock.patch("airflow.hooks.base.BaseHook.get_connection") as mock_get_connection: @@ -68,18 +82,33 @@ def test_oauth_token_usage(self): m.assert_called_once_with("query1") @responses.activate() - @mock.patch("yandexcloud.SDK") - @mock.patch("jwt.encode") - def test_select_results(self, mock_jwt, mock_sdk): + @mock.patch("yandexcloud._auth_fabric.get_auth_token_requester", return_value=DummyTokenRequester()) + def test_metadata_token_usage(self, mock_get_auth_token_requester): responses.post( - "https://iam.api.cloud.yandex.net/iam/v1/tokens", - json={"iamToken": "super_token"}, + "https://api.yandex-query.cloud.yandex.net/api/fq/v1/queries", + match=[ + matchers.header_matcher( + {"Content-Type": "application/json", "Authorization": f"Bearer {IAM_TOKEN}"} + ), + matchers.query_param_matcher({"project": "my_folder_id"}), + ], + json={"id": "query1"}, status=200, ) - mock_jwt.return_value = "zzzz" - mock_sdk.return_value = DummySDK() + self.connection = Connection(extra={}) + self._init_hook() + query_id = self.hook.create_query(query_text="select 777", name="my query") + assert query_id == "query1" + @mock.patch( + "yandex.cloud.iam.v1.iam_token_service_pb2_grpc.IamTokenServiceStub.Create", + create=True, + new_callable=mock.PropertyMock, + ) + @mock.patch("yandexcloud._auth_fabric.__validate_service_account_key") + @mock.patch("yandexcloud._auth_fabric.get_auth_token_requester", return_value=DummyTokenRequester()) + def test_select_results(self, mock_get_auth_token_requester, mock_validate, mock_create_token): with mock.patch.multiple( "yandex_query_client.YQHttpClient", create_query=mock.DEFAULT, @@ -90,6 +119,12 @@ def test_select_results(self, mock_jwt, mock_sdk): stop_query=mock.DEFAULT, ) as mocks: self._init_hook() + mock_validate.assert_called() + mock_create_token.assert_called() + mock_get_auth_token_requester.assert_called_once_with( + service_account_key=json.loads(SERVICE_ACCOUNT_AUTH_KEY_JSON) + ) + mocks["create_query"].return_value = "query1" mocks["wait_query_to_succeed"].return_value = 2 mocks["get_query_all_result_sets"].return_value = {"x": 765}