-
Notifications
You must be signed in to change notification settings - Fork 14.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add Yandex Query support from Yandex.Cloud (#37458)
* initial commit * support web link * move jwt logic out of base hook * use http client in hook * add yq_results * add token_prefix, format exception message * use YQResults inside client * add tests, fix provider.yaml * fix oauth token usage, add tests for complex results * add tests for YQ operator * fix test name * linting * restyling * improve tests, fix close(), add link to YQ service * trim spaces * add docstrings, remove query description, move privates to bottom of the file * fix last newline * restyling * restyling * refactor, restyling * revert version * change text to trigger CI checks * fixes for linters * rework * restyling * fix CI tests, add yq link tests * add doc strings * fix link style tests * rename files, add deps, fix doc string * replace SQLExecuteQueryOperator with BaseOperator * fix static checks * fight with static checks * remove http client, use py package * fix static checks
- Loading branch information
Showing
12 changed files
with
614 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,132 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
from __future__ import annotations | ||
|
||
import time | ||
from datetime import timedelta | ||
from typing import Any | ||
|
||
import jwt | ||
import requests | ||
from urllib3.util.retry import Retry | ||
from yandex_query_client import YQHttpClient, YQHttpClientConfig | ||
|
||
from airflow.exceptions import AirflowException | ||
from airflow.providers.yandex.hooks.yandex import YandexCloudBaseHook | ||
from airflow.providers.yandex.utils.user_agent import provider_user_agent | ||
|
||
|
||
class YQHook(YandexCloudBaseHook): | ||
"""A hook for Yandex Query.""" | ||
|
||
def __init__(self, *args, **kwargs) -> None: | ||
super().__init__(*args, **kwargs) | ||
|
||
config = YQHttpClientConfig( | ||
token=self._get_iam_token(), project=self.default_folder_id, user_agent=provider_user_agent() | ||
) | ||
|
||
self.client: YQHttpClient = YQHttpClient(config=config) | ||
|
||
def close(self): | ||
"""Release all resources.""" | ||
self.client.close() | ||
|
||
def create_query(self, query_text: str | None, name: str | None = None) -> str: | ||
"""Create and run query. | ||
:param query_text: SQL text. | ||
:param name: name for the query | ||
""" | ||
return self.client.create_query( | ||
name=name, | ||
query_text=query_text, | ||
) | ||
|
||
def wait_results(self, query_id: str, execution_timeout: timedelta = timedelta(minutes=30)) -> Any: | ||
"""Wait for query complete and get results. | ||
:param query_id: ID of query. | ||
:param execution_timeout: how long to wait for the query to complete. | ||
""" | ||
result_set_count = self.client.wait_query_to_succeed( | ||
query_id, execution_timeout=execution_timeout, stop_on_timeout=True | ||
) | ||
|
||
return self.client.get_query_all_result_sets(query_id=query_id, result_set_count=result_set_count) | ||
|
||
def stop_query(self, query_id: str) -> None: | ||
"""Stop the query. | ||
:param query_id: ID of the query. | ||
""" | ||
self.client.stop_query(query_id) | ||
|
||
def get_query(self, query_id: str) -> Any: | ||
"""Get query info. | ||
:param query_id: ID of the query. | ||
""" | ||
return self.client.get_query(query_id) | ||
|
||
def get_query_status(self, query_id: str) -> str: | ||
"""Get status fo the query. | ||
:param query_id: ID of query. | ||
""" | ||
return self.client.get_query_status(query_id) | ||
|
||
def compose_query_web_link(self, query_id: str): | ||
"""Compose web link to query in Yandex Query UI. | ||
:param query_id: ID of query. | ||
""" | ||
return self.client.compose_query_web_link(query_id) | ||
|
||
def _get_iam_token(self) -> str: | ||
if "token" in self.credentials: | ||
return self.credentials["token"] | ||
if "service_account_key" in self.credentials: | ||
return YQHook._resolve_service_account_key(self.credentials["service_account_key"]) | ||
raise AirflowException(f"Unknown credentials type, available keys {self.credentials.keys()}") | ||
|
||
@staticmethod | ||
def _resolve_service_account_key(sa_info: dict) -> str: | ||
with YQHook._create_session() as session: | ||
api = "https://iam.api.cloud.yandex.net/iam/v1/tokens" | ||
now = int(time.time()) | ||
payload = {"aud": api, "iss": sa_info["service_account_id"], "iat": now, "exp": now + 360} | ||
|
||
encoded_token = jwt.encode( | ||
payload, sa_info["private_key"], algorithm="PS256", headers={"kid": sa_info["id"]} | ||
) | ||
|
||
data = {"jwt": encoded_token} | ||
iam_response = session.post(api, json=data) | ||
iam_response.raise_for_status() | ||
|
||
return iam_response.json()["iamToken"] | ||
|
||
@staticmethod | ||
def _create_session() -> requests.Session: | ||
session = requests.Session() | ||
session.verify = False | ||
retry = Retry(backoff_factor=0.3, total=10) | ||
session.mount("http://", requests.adapters.HTTPAdapter(max_retries=retry)) | ||
session.mount("https://", requests.adapters.HTTPAdapter(max_retries=retry)) | ||
|
||
return session |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
from __future__ import annotations | ||
|
||
from typing import TYPE_CHECKING | ||
|
||
from airflow.models import BaseOperatorLink, XCom | ||
|
||
if TYPE_CHECKING: | ||
from airflow.models import BaseOperator | ||
from airflow.models.taskinstancekey import TaskInstanceKey | ||
from airflow.utils.context import Context | ||
|
||
XCOM_WEBLINK_KEY = "web_link" | ||
|
||
|
||
class YQLink(BaseOperatorLink): | ||
"""Web link to query in Yandex Query UI.""" | ||
|
||
name = "Yandex Query" | ||
|
||
def get_link(self, operator: BaseOperator, *, ti_key: TaskInstanceKey): | ||
return XCom.get_value(key=XCOM_WEBLINK_KEY, ti_key=ti_key) or "https://yq.cloud.yandex.ru" | ||
|
||
@staticmethod | ||
def persist(context: Context, task_instance: BaseOperator, web_link: str) -> None: | ||
task_instance.xcom_push(context, key=XCOM_WEBLINK_KEY, value=web_link) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
from __future__ import annotations | ||
|
||
from functools import cached_property | ||
from typing import TYPE_CHECKING, Any, Sequence | ||
|
||
from airflow.models import BaseOperator | ||
from airflow.providers.yandex.hooks.yq import YQHook | ||
from airflow.providers.yandex.links.yq import YQLink | ||
|
||
if TYPE_CHECKING: | ||
from airflow.utils.context import Context | ||
|
||
|
||
class YQExecuteQueryOperator(BaseOperator): | ||
""" | ||
Executes sql code using Yandex Query service. | ||
:param sql: the SQL code to be executed as a single string | ||
:param name: name of the query in YandexQuery | ||
:param folder_id: cloud folder id where to create query | ||
:param yandex_conn_id: Airflow connection ID to get parameters from | ||
""" | ||
|
||
operator_extra_links = (YQLink(),) | ||
template_fields: Sequence[str] = ("sql",) | ||
template_fields_renderers = {"sql": "sql"} | ||
template_ext: Sequence[str] = (".sql",) | ||
ui_color = "#ededed" | ||
|
||
def __init__( | ||
self, | ||
*, | ||
name: str | None = None, | ||
folder_id: str | None = None, | ||
yandex_conn_id: str | None = None, | ||
public_ssh_key: str | None = None, | ||
service_account_id: str | None = None, | ||
sql: str, | ||
**kwargs, | ||
) -> None: | ||
super().__init__(**kwargs) | ||
self.name = name | ||
self.folder_id = folder_id | ||
self.yandex_conn_id = yandex_conn_id | ||
self.public_ssh_key = public_ssh_key | ||
self.service_account_id = service_account_id | ||
self.sql = sql | ||
|
||
self.query_id: str | None = None | ||
|
||
@cached_property | ||
def hook(self) -> YQHook: | ||
"""Get valid hook.""" | ||
return YQHook( | ||
yandex_conn_id=self.yandex_conn_id, | ||
default_folder_id=self.folder_id, | ||
default_public_ssh_key=self.public_ssh_key, | ||
default_service_account_id=self.service_account_id, | ||
) | ||
|
||
def execute(self, context: Context) -> Any: | ||
self.query_id = self.hook.create_query(query_text=self.sql, name=self.name) | ||
|
||
# pass to YQLink | ||
web_link = self.hook.compose_query_web_link(self.query_id) | ||
YQLink.persist(context, self, web_link) | ||
|
||
results = self.hook.wait_results(self.query_id) | ||
# forget query to avoid 'stop_query' in on_kill | ||
self.query_id = None | ||
return results | ||
|
||
def on_kill(self) -> None: | ||
if self.hook is not None and self.query_id is not None: | ||
self.hook.stop_query(self.query_id) | ||
self.hook.close() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.