From a61e0c1df7cd8a25ac67fdfc778350e148510743 Mon Sep 17 00:00:00 2001 From: Peter Reznikov Date: Fri, 29 Jul 2022 12:56:36 +0300 Subject: [PATCH] YandexCloud provider: Support new Yandex SDK features for DataProc (#25158) --- airflow/providers/yandex/hooks/yandex.py | 4 +- .../yandex/operators/yandexcloud_dataproc.py | 190 ++++++++++------- airflow/providers/yandex/provider.yaml | 2 +- generated/README.md | 2 + generated/provider_dependencies.json | 2 +- tests/providers/yandex/hooks/test_yandex.py | 18 +- .../operators/test_yandexcloud_dataproc.py | 5 + .../providers/yandex/example_yandexcloud.py | 197 ++++++++++++++++++ ...xample_yandexcloud_dataproc_lightweight.py | 80 +++++++ 9 files changed, 412 insertions(+), 88 deletions(-) create mode 100644 tests/system/providers/yandex/example_yandexcloud.py create mode 100644 tests/system/providers/yandex/example_yandexcloud_dataproc_lightweight.py diff --git a/airflow/providers/yandex/hooks/yandex.py b/airflow/providers/yandex/hooks/yandex.py index a337954496241..deeac7b0fcc35 100644 --- a/airflow/providers/yandex/hooks/yandex.py +++ b/airflow/providers/yandex/hooks/yandex.py @@ -17,7 +17,7 @@ import json import warnings -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Optional import yandexcloud @@ -107,7 +107,7 @@ def __init__( # Connection id is deprecated. Use yandex_conn_id instead connection_id: Optional[str] = None, yandex_conn_id: Optional[str] = None, - default_folder_id: Union[dict, bool, None] = None, + default_folder_id: Optional[str] = None, default_public_ssh_key: Optional[str] = None, ) -> None: super().__init__() diff --git a/airflow/providers/yandex/operators/yandexcloud_dataproc.py b/airflow/providers/yandex/operators/yandexcloud_dataproc.py index 1a9dd1acf05bc..ec6d8d684953d 100644 --- a/airflow/providers/yandex/operators/yandexcloud_dataproc.py +++ b/airflow/providers/yandex/operators/yandexcloud_dataproc.py @@ -14,7 +14,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - +import warnings +from dataclasses import dataclass from typing import TYPE_CHECKING, Dict, Iterable, Optional, Sequence, Union from airflow.models import BaseOperator @@ -24,6 +25,15 @@ from airflow.utils.context import Context +@dataclass +class InitializationAction: + """Data for initialization action to be run at start of DataProc cluster.""" + + uri: str # Uri of the executable file + args: Sequence[str] # Arguments to the initialization action + timeout: int # Execution timeout + + class DataprocCreateClusterOperator(BaseOperator): """Creates Yandex.Cloud Data Proc cluster. @@ -69,9 +79,20 @@ class DataprocCreateClusterOperator(BaseOperator): in percents. 10-100. By default is not set and default autoscaling strategy is used. :param computenode_decommission_timeout: Timeout to gracefully decommission nodes during downscaling. - In seconds. + In seconds + :param properties: Properties passed to main node software. + Docs: https://cloud.yandex.com/docs/data-proc/concepts/settings-list + :param enable_ui_proxy: Enable UI Proxy feature for forwarding Hadoop components web interfaces + Docs: https://cloud.yandex.com/docs/data-proc/concepts/ui-proxy + :param host_group_ids: Dedicated host groups to place VMs of cluster on. + Docs: https://cloud.yandex.com/docs/compute/concepts/dedicated-host + :param security_group_ids: User security groups. 
+ Docs: https://cloud.yandex.com/docs/data-proc/concepts/network#security-groups :param log_group_id: Id of log group to write logs. By default logs will be sent to default log group. To disable cloud log sending set cluster property dataproc:disable_cloud_logging = true + Docs: https://cloud.yandex.com/docs/data-proc/concepts/logs + :param initialization_actions: Set of init-actions to run when cluster starts. + Docs: https://cloud.yandex.com/docs/data-proc/concepts/init-action """ def __init__( @@ -106,7 +127,12 @@ def __init__( computenode_cpu_utilization_target: Optional[int] = None, computenode_decommission_timeout: Optional[int] = None, connection_id: Optional[str] = None, + properties: Optional[Dict[str, str]] = None, + enable_ui_proxy: bool = False, + host_group_ids: Optional[Iterable[str]] = None, + security_group_ids: Optional[Iterable[str]] = None, log_group_id: Optional[str] = None, + initialization_actions: Optional[Iterable[InitializationAction]] = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -139,11 +165,16 @@ def __init__( self.computenode_preemptible = computenode_preemptible self.computenode_cpu_utilization_target = computenode_cpu_utilization_target self.computenode_decommission_timeout = computenode_decommission_timeout + self.properties = properties + self.enable_ui_proxy = enable_ui_proxy + self.host_group_ids = host_group_ids + self.security_group_ids = security_group_ids self.log_group_id = log_group_id + self.initialization_actions = initialization_actions self.hook: Optional[DataprocHook] = None - def execute(self, context: 'Context') -> None: + def execute(self, context: 'Context') -> dict: self.hook = DataprocHook( yandex_conn_id=self.yandex_conn_id, ) @@ -176,14 +207,35 @@ def execute(self, context: 'Context') -> None: computenode_preemptible=self.computenode_preemptible, computenode_cpu_utilization_target=self.computenode_cpu_utilization_target, computenode_decommission_timeout=self.computenode_decommission_timeout, + properties=self.properties, + enable_ui_proxy=self.enable_ui_proxy, + host_group_ids=self.host_group_ids, + security_group_ids=self.security_group_ids, log_group_id=self.log_group_id, + initialization_actions=self.initialization_actions + and [ + self.hook.sdk.wrappers.InitializationAction( + uri=init_action.uri, + args=init_action.args, + timeout=init_action.timeout, + ) + for init_action in self.initialization_actions + ], ) - context['task_instance'].xcom_push(key='cluster_id', value=operation_result.response.id) + cluster_id = operation_result.response.id + + context['task_instance'].xcom_push(key='cluster_id', value=cluster_id) + # Deprecated context['task_instance'].xcom_push(key='yandexcloud_connection_id', value=self.yandex_conn_id) + return cluster_id + @property + def cluster_id(self): + return self.output -class DataprocDeleteClusterOperator(BaseOperator): - """Deletes Yandex.Cloud Data Proc cluster. + +class DataprocBaseOperator(BaseOperator): + """Base class for DataProc operators working with given cluster. :param connection_id: ID of the Yandex.Cloud Airflow connection. :param cluster_id: ID of the cluster to remove. 
(templated) @@ -192,25 +244,45 @@ class DataprocDeleteClusterOperator(BaseOperator): template_fields: Sequence[str] = ('cluster_id',) def __init__( - self, *, connection_id: Optional[str] = None, cluster_id: Optional[str] = None, **kwargs + self, *, yandex_conn_id: Optional[str] = None, cluster_id: Optional[str] = None, **kwargs ) -> None: super().__init__(**kwargs) - self.yandex_conn_id = connection_id self.cluster_id = cluster_id - self.hook: Optional[DataprocHook] = None + self.yandex_conn_id = yandex_conn_id + + def _setup(self, context: 'Context') -> DataprocHook: + if self.cluster_id is None: + self.cluster_id = context['task_instance'].xcom_pull(key='cluster_id') + if self.yandex_conn_id is None: + xcom_yandex_conn_id = context['task_instance'].xcom_pull(key='yandexcloud_connection_id') + if xcom_yandex_conn_id: + warnings.warn('Implicit pass of `yandex_conn_id` is deprecated, please pass it explicitly') + self.yandex_conn_id = xcom_yandex_conn_id + + return DataprocHook(yandex_conn_id=self.yandex_conn_id) + + def execute(self, context: 'Context'): + raise NotImplementedError() + + +class DataprocDeleteClusterOperator(DataprocBaseOperator): + """Deletes Yandex.Cloud Data Proc cluster. + + :param connection_id: ID of the Yandex.Cloud Airflow connection. + :param cluster_id: ID of the cluster to remove. (templated) + """ + + def __init__( + self, *, connection_id: Optional[str] = None, cluster_id: Optional[str] = None, **kwargs + ) -> None: + super().__init__(yandex_conn_id=connection_id, cluster_id=cluster_id, **kwargs) def execute(self, context: 'Context') -> None: - cluster_id = self.cluster_id or context['task_instance'].xcom_pull(key='cluster_id') - yandex_conn_id = self.yandex_conn_id or context['task_instance'].xcom_pull( - key='yandexcloud_connection_id' - ) - self.hook = DataprocHook( - yandex_conn_id=yandex_conn_id, - ) - self.hook.client.delete_cluster(cluster_id) + hook = self._setup(context) + hook.client.delete_cluster(self.cluster_id) -class DataprocCreateHiveJobOperator(BaseOperator): +class DataprocCreateHiveJobOperator(DataprocBaseOperator): """Runs Hive job in Data Proc cluster. :param query: Hive query. @@ -224,8 +296,6 @@ class DataprocCreateHiveJobOperator(BaseOperator): :param connection_id: ID of the Yandex.Cloud Airflow connection. 
""" - template_fields: Sequence[str] = ('cluster_id',) - def __init__( self, *, @@ -239,37 +309,28 @@ def __init__( connection_id: Optional[str] = None, **kwargs, ) -> None: - super().__init__(**kwargs) + super().__init__(yandex_conn_id=connection_id, cluster_id=cluster_id, **kwargs) self.query = query self.query_file_uri = query_file_uri self.script_variables = script_variables self.continue_on_failure = continue_on_failure self.properties = properties self.name = name - self.cluster_id = cluster_id - self.connection_id = connection_id - self.hook: Optional[DataprocHook] = None def execute(self, context: 'Context') -> None: - cluster_id = self.cluster_id or context['task_instance'].xcom_pull(key='cluster_id') - yandex_conn_id = self.connection_id or context['task_instance'].xcom_pull( - key='yandexcloud_connection_id' - ) - self.hook = DataprocHook( - yandex_conn_id=yandex_conn_id, - ) - self.hook.client.create_hive_job( + hook = self._setup(context) + hook.client.create_hive_job( query=self.query, query_file_uri=self.query_file_uri, script_variables=self.script_variables, continue_on_failure=self.continue_on_failure, properties=self.properties, name=self.name, - cluster_id=cluster_id, + cluster_id=self.cluster_id, ) -class DataprocCreateMapReduceJobOperator(BaseOperator): +class DataprocCreateMapReduceJobOperator(DataprocBaseOperator): """Runs Mapreduce job in Data Proc cluster. :param main_jar_file_uri: URI of jar file with job. @@ -286,8 +347,6 @@ class DataprocCreateMapReduceJobOperator(BaseOperator): :param connection_id: ID of the Yandex.Cloud Airflow connection. """ - template_fields: Sequence[str] = ('cluster_id',) - def __init__( self, *, @@ -303,7 +362,7 @@ def __init__( connection_id: Optional[str] = None, **kwargs, ) -> None: - super().__init__(**kwargs) + super().__init__(yandex_conn_id=connection_id, cluster_id=cluster_id, **kwargs) self.main_class = main_class self.main_jar_file_uri = main_jar_file_uri self.jar_file_uris = jar_file_uris @@ -312,19 +371,10 @@ def __init__( self.args = args self.properties = properties self.name = name - self.cluster_id = cluster_id - self.connection_id = connection_id - self.hook: Optional[DataprocHook] = None def execute(self, context: 'Context') -> None: - cluster_id = self.cluster_id or context['task_instance'].xcom_pull(key='cluster_id') - yandex_conn_id = self.connection_id or context['task_instance'].xcom_pull( - key='yandexcloud_connection_id' - ) - self.hook = DataprocHook( - yandex_conn_id=yandex_conn_id, - ) - self.hook.client.create_mapreduce_job( + hook = self._setup(context) + hook.client.create_mapreduce_job( main_class=self.main_class, main_jar_file_uri=self.main_jar_file_uri, jar_file_uris=self.jar_file_uris, @@ -333,11 +383,11 @@ def execute(self, context: 'Context') -> None: args=self.args, properties=self.properties, name=self.name, - cluster_id=cluster_id, + cluster_id=self.cluster_id, ) -class DataprocCreateSparkJobOperator(BaseOperator): +class DataprocCreateSparkJobOperator(DataprocBaseOperator): """Runs Spark job in Data Proc cluster. :param main_jar_file_uri: URI of jar file with job. Can be placed in HDFS or S3. @@ -358,8 +408,6 @@ class DataprocCreateSparkJobOperator(BaseOperator): provided in --packages to avoid dependency conflicts. 
""" - template_fields: Sequence[str] = ('cluster_id',) - def __init__( self, *, @@ -378,7 +426,7 @@ def __init__( exclude_packages: Optional[Iterable[str]] = None, **kwargs, ) -> None: - super().__init__(**kwargs) + super().__init__(yandex_conn_id=connection_id, cluster_id=cluster_id, **kwargs) self.main_class = main_class self.main_jar_file_uri = main_jar_file_uri self.jar_file_uris = jar_file_uris @@ -387,22 +435,13 @@ def __init__( self.args = args self.properties = properties self.name = name - self.cluster_id = cluster_id - self.connection_id = connection_id self.packages = packages self.repositories = repositories self.exclude_packages = exclude_packages - self.hook: Optional[DataprocHook] = None def execute(self, context: 'Context') -> None: - cluster_id = self.cluster_id or context['task_instance'].xcom_pull(key='cluster_id') - yandex_conn_id = self.connection_id or context['task_instance'].xcom_pull( - key='yandexcloud_connection_id' - ) - self.hook = DataprocHook( - yandex_conn_id=yandex_conn_id, - ) - self.hook.client.create_spark_job( + hook = self._setup(context) + hook.client.create_spark_job( main_class=self.main_class, main_jar_file_uri=self.main_jar_file_uri, jar_file_uris=self.jar_file_uris, @@ -414,11 +453,11 @@ def execute(self, context: 'Context') -> None: repositories=self.repositories, exclude_packages=self.exclude_packages, name=self.name, - cluster_id=cluster_id, + cluster_id=self.cluster_id, ) -class DataprocCreatePysparkJobOperator(BaseOperator): +class DataprocCreatePysparkJobOperator(DataprocBaseOperator): """Runs Pyspark job in Data Proc cluster. :param main_python_file_uri: URI of python file with job. Can be placed in HDFS or S3. @@ -439,8 +478,6 @@ class DataprocCreatePysparkJobOperator(BaseOperator): provided in --packages to avoid dependency conflicts. 
""" - template_fields: Sequence[str] = ('cluster_id',) - def __init__( self, *, @@ -459,7 +496,7 @@ def __init__( exclude_packages: Optional[Iterable[str]] = None, **kwargs, ) -> None: - super().__init__(**kwargs) + super().__init__(yandex_conn_id=connection_id, cluster_id=cluster_id, **kwargs) self.main_python_file_uri = main_python_file_uri self.python_file_uris = python_file_uris self.jar_file_uris = jar_file_uris @@ -468,22 +505,13 @@ def __init__( self.args = args self.properties = properties self.name = name - self.cluster_id = cluster_id - self.connection_id = connection_id self.packages = packages self.repositories = repositories self.exclude_packages = exclude_packages - self.hook: Optional[DataprocHook] = None def execute(self, context: 'Context') -> None: - cluster_id = self.cluster_id or context['task_instance'].xcom_pull(key='cluster_id') - yandex_conn_id = self.connection_id or context['task_instance'].xcom_pull( - key='yandexcloud_connection_id' - ) - self.hook = DataprocHook( - yandex_conn_id=yandex_conn_id, - ) - self.hook.client.create_pyspark_job( + hook = self._setup(context) + hook.client.create_pyspark_job( main_python_file_uri=self.main_python_file_uri, python_file_uris=self.python_file_uris, jar_file_uris=self.jar_file_uris, @@ -495,5 +523,5 @@ def execute(self, context: 'Context') -> None: repositories=self.repositories, exclude_packages=self.exclude_packages, name=self.name, - cluster_id=cluster_id, + cluster_id=self.cluster_id, ) diff --git a/airflow/providers/yandex/provider.yaml b/airflow/providers/yandex/provider.yaml index c066a2f8ed837..90f6cc1c9a321 100644 --- a/airflow/providers/yandex/provider.yaml +++ b/airflow/providers/yandex/provider.yaml @@ -34,7 +34,7 @@ versions: dependencies: - apache-airflow>=2.2.0 - - yandexcloud>=0.146.0 + - yandexcloud>=0.173.0 integrations: - integration-name: Yandex.Cloud diff --git a/generated/README.md b/generated/README.md index f87a767da44c1..d1dcc1f78307e 100644 --- a/generated/README.md +++ b/generated/README.md @@ -20,6 +20,8 @@ NOTE! The files in this folder are generated by pre-commit based on airflow sources. They are not supposed to be manually modified. +You can read more about pre-commit hooks [here](../STATIC_CODE_CHECKS.rst#pre-commit-hooks). + * `provider_dependencies.json` - is generated based on `provider.yaml` files in `airflow/providers` and based on the imports in the provider code. 
If you want to add new dependency to a provider, you need to modify the corresponding `provider.yaml` file diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json index 743a73d0da786..51e3ea26f88bd 100644 --- a/generated/provider_dependencies.json +++ b/generated/provider_dependencies.json @@ -721,7 +721,7 @@ "yandex": { "deps": [ "apache-airflow>=2.2.0", - "yandexcloud>=0.146.0" + "yandexcloud>=0.173.0" ], "cross-providers-deps": [] }, diff --git a/tests/providers/yandex/hooks/test_yandex.py b/tests/providers/yandex/hooks/test_yandex.py index b4ddf0e121ab0..a1ada7aefa78c 100644 --- a/tests/providers/yandex/hooks/test_yandex.py +++ b/tests/providers/yandex/hooks/test_yandex.py @@ -43,7 +43,11 @@ def test_client_created_without_exceptions(self, get_credentials_mock, get_conne ) get_credentials_mock.return_value = {"token": 122323} - hook = YandexCloudBaseHook(None, default_folder_id, default_public_ssh_key) + hook = YandexCloudBaseHook( + yandex_conn_id=None, + default_folder_id=default_folder_id, + default_public_ssh_key=default_public_ssh_key, + ) assert hook.client is not None @mock.patch('airflow.hooks.base.BaseHook.get_connection') @@ -63,7 +67,11 @@ def test_get_credentials_raise_exception(self, get_connection_mock): ) with pytest.raises(AirflowException): - YandexCloudBaseHook(None, default_folder_id, default_public_ssh_key) + YandexCloudBaseHook( + yandex_conn_id=None, + default_folder_id=default_folder_id, + default_public_ssh_key=default_public_ssh_key, + ) @mock.patch('airflow.hooks.base.BaseHook.get_connection') @mock.patch('airflow.providers.yandex.hooks.yandex.YandexCloudBaseHook._get_credentials') @@ -80,6 +88,10 @@ def test_get_field(self, get_credentials_mock, get_connection_mock): ) get_credentials_mock.return_value = {"token": 122323} - hook = YandexCloudBaseHook(None, default_folder_id, default_public_ssh_key) + hook = YandexCloudBaseHook( + yandex_conn_id=None, + default_folder_id=default_folder_id, + default_public_ssh_key=default_public_ssh_key, + ) assert hook._get_field('one') == 'value_one' diff --git a/tests/providers/yandex/operators/test_yandexcloud_dataproc.py b/tests/providers/yandex/operators/test_yandexcloud_dataproc.py index f54087c742af9..23cda00e4a887 100644 --- a/tests/providers/yandex/operators/test_yandexcloud_dataproc.py +++ b/tests/providers/yandex/operators/test_yandexcloud_dataproc.py @@ -127,6 +127,11 @@ def test_create_cluster(self, create_cluster_mock, *_): subnet_id='my_subnet_id', zone='ru-central1-c', log_group_id=LOG_GROUP_ID, + properties=None, + enable_ui_proxy=False, + host_group_ids=None, + security_group_ids=None, + initialization_actions=None, ) context['task_instance'].xcom_push.assert_has_calls( [ diff --git a/tests/system/providers/yandex/example_yandexcloud.py b/tests/system/providers/yandex/example_yandexcloud.py new file mode 100644 index 0000000000000..708a6049a75ad --- /dev/null +++ b/tests/system/providers/yandex/example_yandexcloud.py @@ -0,0 +1,197 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import os +from datetime import datetime +from typing import Optional + +import yandex.cloud.dataproc.v1.cluster_pb2 as cluster_pb +import yandex.cloud.dataproc.v1.cluster_service_pb2 as cluster_service_pb +import yandex.cloud.dataproc.v1.cluster_service_pb2_grpc as cluster_service_grpc_pb +import yandex.cloud.dataproc.v1.common_pb2 as common_pb +import yandex.cloud.dataproc.v1.job_pb2 as job_pb +import yandex.cloud.dataproc.v1.job_service_pb2 as job_service_pb +import yandex.cloud.dataproc.v1.job_service_pb2_grpc as job_service_grpc_pb +import yandex.cloud.dataproc.v1.subcluster_pb2 as subcluster_pb +from google.protobuf.json_format import MessageToDict + +from airflow import DAG +from airflow.decorators import task +from airflow.providers.yandex.hooks.yandex import YandexCloudBaseHook + +ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID") +DAG_ID = 'example_yandexcloud_hook' + +# Fill it with your identifiers +YC_S3_BUCKET_NAME = '' # Fill to use S3 instead of HFDS +YC_FOLDER_ID = None # Fill to override default YC folder from connection data +YC_ZONE_NAME = 'ru-central1-b' +YC_SUBNET_ID = None # Fill if you have more than one VPC subnet in given folder and zone +YC_SERVICE_ACCOUNT_ID = None # Fill if you have more than one YC service account in given folder + + +def create_cluster_request( + folder_id: str, + cluster_name: str, + cluster_desc: str, + zone: str, + subnet_id: str, + service_account_id: str, + ssh_public_key: str, + resources: common_pb.Resources, +): + return cluster_service_pb.CreateClusterRequest( + folder_id=folder_id, + name=cluster_name, + description=cluster_desc, + bucket=YC_S3_BUCKET_NAME, + config_spec=cluster_service_pb.CreateClusterConfigSpec( + hadoop=cluster_pb.HadoopConfig( + services=('SPARK', 'YARN'), + ssh_public_keys=[ssh_public_key], + ), + subclusters_spec=[ + cluster_service_pb.CreateSubclusterConfigSpec( + name='master', + role=subcluster_pb.Role.MASTERNODE, + resources=resources, + subnet_id=subnet_id, + hosts_count=1, + ), + cluster_service_pb.CreateSubclusterConfigSpec( + name='compute', + role=subcluster_pb.Role.COMPUTENODE, + resources=resources, + subnet_id=subnet_id, + hosts_count=1, + ), + ], + ), + zone_id=zone, + service_account_id=service_account_id, + ) + + +@task +def create_cluster( + yandex_conn_id: Optional[str] = None, + folder_id: Optional[str] = None, + network_id: Optional[str] = None, + subnet_id: Optional[str] = None, + zone: str = YC_ZONE_NAME, + service_account_id: Optional[str] = None, + ssh_public_key: Optional[str] = None, + *, + dag: Optional[DAG] = None, + ts_nodash: Optional[str] = None, +) -> str: + hook = YandexCloudBaseHook(yandex_conn_id=yandex_conn_id) + folder_id = folder_id or hook.default_folder_id + if subnet_id is None: + network_id = network_id or hook.sdk.helpers.find_network_id(folder_id) + subnet_id = hook.sdk.helpers.find_subnet_id(folder_id=folder_id, zone_id=zone, network_id=network_id) + service_account_id = service_account_id or hook.sdk.helpers.find_service_account_id() + ssh_public_key = ssh_public_key or hook.default_public_ssh_key + + dag_id = dag and dag.dag_id or 'dag' + + request = 
create_cluster_request( + folder_id=folder_id, + subnet_id=subnet_id, + zone=zone, + cluster_name=f'airflow_{dag_id}_{ts_nodash}'[:62], + cluster_desc='Created via Airflow custom hook task', + service_account_id=service_account_id, + ssh_public_key=ssh_public_key, + resources=common_pb.Resources( + resource_preset_id='s2.micro', + disk_type_id='network-ssd', + ), + ) + operation = hook.sdk.client(cluster_service_grpc_pb.ClusterServiceStub).Create(request) + operation_result = hook.sdk.wait_operation_and_get_result( + operation, response_type=cluster_pb.Cluster, meta_type=cluster_service_pb.CreateClusterMetadata + ) + return operation_result.response.id + + +@task +def run_spark_job( + cluster_id: str, + yandex_conn_id: Optional[str] = None, +): + hook = YandexCloudBaseHook(yandex_conn_id=yandex_conn_id) + + request = job_service_pb.CreateJobRequest( + cluster_id=cluster_id, + name='Spark job: Find total urban population in distribution by country', + spark_job=job_pb.SparkJob( + main_jar_file_uri='file:///usr/lib/spark/examples/jars/spark-examples.jar', + main_class='org.apache.spark.examples.SparkPi', + args=['1000'], + ), + ) + operation = hook.sdk.client(job_service_grpc_pb.JobServiceStub).Create(request) + operation_result = hook.sdk.wait_operation_and_get_result( + operation, response_type=job_pb.Job, meta_type=job_service_pb.CreateJobMetadata + ) + return MessageToDict(operation_result.response) + + +@task(trigger_rule='all_done') +def delete_cluster( + cluster_id: str, + yandex_conn_id: Optional[str] = None, +): + hook = YandexCloudBaseHook(yandex_conn_id=yandex_conn_id) + + operation = hook.sdk.client(cluster_service_grpc_pb.ClusterServiceStub).Delete( + cluster_service_pb.DeleteClusterRequest(cluster_id=cluster_id) + ) + hook.sdk.wait_operation_and_get_result( + operation, + meta_type=cluster_service_pb.DeleteClusterMetadata, + ) + + +with DAG( + dag_id=DAG_ID, + schedule_interval=None, + start_date=datetime(2021, 1, 1), + tags=['example'], +) as dag: + cluster_id = create_cluster( + folder_id=YC_FOLDER_ID, + subnet_id=YC_SUBNET_ID, + zone=YC_ZONE_NAME, + service_account_id=YC_SERVICE_ACCOUNT_ID, + ) + spark_job = run_spark_job(cluster_id=cluster_id) + delete_task = delete_cluster(cluster_id=cluster_id) + + spark_job >> delete_task + + from tests.system.utils.watcher import watcher + + # This test needs watcher in order to properly mark success/failure + # when "teardown" task with trigger rule is part of the DAG + list(dag.tasks) >> watcher() + + +from tests.system.utils import get_test_run # noqa: E402 + +# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest) +test_run = get_test_run(dag) diff --git a/tests/system/providers/yandex/example_yandexcloud_dataproc_lightweight.py b/tests/system/providers/yandex/example_yandexcloud_dataproc_lightweight.py new file mode 100644 index 0000000000000..d5faa0865e947 --- /dev/null +++ b/tests/system/providers/yandex/example_yandexcloud_dataproc_lightweight.py @@ -0,0 +1,80 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import os +from datetime import datetime + +from airflow import DAG +from airflow.providers.yandex.operators.yandexcloud_dataproc import ( + DataprocCreateClusterOperator, + DataprocCreateSparkJobOperator, + DataprocDeleteClusterOperator, +) + +# Name of the datacenter where Dataproc cluster will be created +from airflow.utils.trigger_rule import TriggerRule + +# should be filled with appropriate ids + + +AVAILABILITY_ZONE_ID = 'ru-central1-c' + +# Dataproc cluster will use this bucket as distributed storage +S3_BUCKET_NAME = '' + +ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID") +DAG_ID = 'example_yandexcloud_dataproc_lightweight' + +with DAG( + DAG_ID, + schedule_interval=None, + start_date=datetime(2021, 1, 1), + tags=['example'], +) as dag: + create_cluster = DataprocCreateClusterOperator( + task_id='create_cluster', + zone=AVAILABILITY_ZONE_ID, + s3_bucket=S3_BUCKET_NAME, + computenode_count=1, + datanode_count=0, + services=('SPARK', 'YARN'), + ) + + create_spark_job = DataprocCreateSparkJobOperator( + cluster_id=create_cluster.cluster_id, + task_id='create_spark_job', + main_jar_file_uri='file:///usr/lib/spark/examples/jars/spark-examples.jar', + main_class='org.apache.spark.examples.SparkPi', + args=['1000'], + ) + + delete_cluster = DataprocDeleteClusterOperator( + cluster_id=create_cluster.cluster_id, + task_id='delete_cluster', + trigger_rule=TriggerRule.ALL_DONE, + ) + create_spark_job >> delete_cluster + + from tests.system.utils.watcher import watcher + + # This test needs watcher in order to properly mark success/failure + # when "teardown" task with trigger rule is part of the DAG + list(dag.tasks) >> watcher() + +from tests.system.utils import get_test_run # noqa: E402 + +# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest) +test_run = get_test_run(dag)
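
A minimal usage sketch of the cluster-creation options introduced by this patch (properties, security_group_ids, host_group_ids, enable_ui_proxy, and initialization_actions together with the new InitializationAction dataclass). All concrete values below (bucket name, zone, security group id, init-script URI, timeouts) are illustrative placeholders, not values taken from this change:

from datetime import datetime

from airflow import DAG
from airflow.providers.yandex.operators.yandexcloud_dataproc import (
    DataprocCreateClusterOperator,
    InitializationAction,
)

with DAG(
    dag_id='example_yandexcloud_dataproc_new_options',
    schedule_interval=None,
    start_date=datetime(2021, 1, 1),
) as dag:
    create_cluster = DataprocCreateClusterOperator(
        task_id='create_cluster',
        zone='ru-central1-c',
        s3_bucket='my-dataproc-bucket',  # placeholder bucket name
        computenode_count=1,
        # New in this patch: properties passed to cluster software
        properties={'spark:spark.executor.memory': '2g'},
        # New in this patch: user security groups for the cluster VMs
        security_group_ids=['my-security-group-id'],  # placeholder id
        # New in this patch: forward Hadoop component web UIs via UI Proxy
        enable_ui_proxy=True,
        # New in this patch: init-actions run when the cluster starts
        initialization_actions=[
            InitializationAction(
                uri='s3a://my-dataproc-bucket/scripts/init.sh',  # placeholder URI
                args=['--arg', 'value'],
                timeout=300,  # execution timeout, in seconds
            ),
        ],
    )

The cluster id produced by the create operator is also exposed through the new cluster_id property, so downstream job and delete operators can reference create_cluster.cluster_id instead of pulling it from XCom by hand, as the added lightweight system test demonstrates.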