YandexCloud provider: Support new Yandex SDK features for DataProc #25158

Merged 5 commits on Jul 29, 2022
4 changes: 2 additions & 2 deletions airflow/providers/yandex/hooks/yandex.py
@@ -17,7 +17,7 @@

import json
import warnings
from typing import Any, Dict, Optional, Union
from typing import Any, Dict, Optional

import yandexcloud

@@ -107,7 +107,7 @@ def __init__(
# Connection id is deprecated. Use yandex_conn_id instead
connection_id: Optional[str] = None,
yandex_conn_id: Optional[str] = None,
default_folder_id: Union[dict, bool, None] = None,
default_folder_id: Optional[str] = None,
default_public_ssh_key: Optional[str] = None,
) -> None:
super().__init__()
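The corrected annotation matches how the hook is actually used: default_folder_id is a plain folder id string, not Union[dict, bool, None]. A minimal usage sketch, not part of the PR; the connection id, folder id and SSH key below are placeholders:

from airflow.providers.yandex.hooks.yandex import YandexCloudBaseHook

hook = YandexCloudBaseHook(
    yandex_conn_id='yandexcloud_default',       # assumed Airflow connection id
    default_folder_id='b1g-example-folder-id',  # hypothetical Yandex Cloud folder id (a str)
    default_public_ssh_key='ssh-rsa AAAA... user@host',
)
folder_id = hook.default_folder_id  # Optional[str], used as a fallback by the operators
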
190 changes: 109 additions & 81 deletions airflow/providers/yandex/operators/yandexcloud_dataproc.py

Large diffs are not rendered by default.
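Since the full operator diff is not rendered here, the sketch below illustrates how DataprocCreateClusterOperator can be parameterized with the arguments the updated tests expect it to forward to the Dataproc hook (properties, enable_ui_proxy, host_group_ids, security_group_ids, initialization_actions). This is an illustration only: the authoritative signature lives in the unrendered file, and all values are placeholders.

from airflow.providers.yandex.operators.yandexcloud_dataproc import DataprocCreateClusterOperator

# Sketch only: parameter names are taken from the test assertions further down;
# all values are hypothetical.
create_cluster = DataprocCreateClusterOperator(
    task_id='create_cluster',
    zone='ru-central1-c',
    s3_bucket='my-dataproc-bucket',        # placeholder bucket name
    computenode_count=1,
    datanode_count=0,
    services=('SPARK', 'YARN'),
    properties={'spark:spark.executor.memory': '2g'},  # hypothetical cluster properties
    enable_ui_proxy=False,
    host_group_ids=None,
    security_group_ids=['my-security-group-id'],        # hypothetical security group id
    initialization_actions=None,
)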

2 changes: 1 addition & 1 deletion airflow/providers/yandex/provider.yaml
@@ -34,7 +34,7 @@ versions:

dependencies:
- apache-airflow>=2.2.0
- yandexcloud>=0.146.0
- yandexcloud>=0.173.0

integrations:
- integration-name: Yandex.Cloud
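The provider now assumes yandexcloud 0.173.0 or newer, which ships the expanded DataProc API used above. A quick, non-authoritative way to check which SDK version an Airflow environment actually has installed:

from importlib.metadata import version

print(version('yandexcloud'))  # expected to print 0.173.0 or newer after upgrading the provider
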
2 changes: 2 additions & 0 deletions generated/README.md
@@ -20,6 +20,8 @@
NOTE! The files in this folder are generated by pre-commit based on airflow sources. They are not
supposed to be manually modified.

You can read more about pre-commit hooks [here](../STATIC_CODE_CHECKS.rst#pre-commit-hooks).

* `provider_dependencies.json` - is generated based on `provider.yaml` files in `airflow/providers` and
based on the imports in the provider code. If you want to add new dependency to a provider, you
need to modify the corresponding `provider.yaml` file
2 changes: 1 addition & 1 deletion generated/provider_dependencies.json
@@ -721,7 +721,7 @@
"yandex": {
"deps": [
"apache-airflow>=2.2.0",
"yandexcloud>=0.146.0"
"yandexcloud>=0.173.0"
],
"cross-providers-deps": []
},
18 changes: 15 additions & 3 deletions tests/providers/yandex/hooks/test_yandex.py
@@ -43,7 +43,11 @@ def test_client_created_without_exceptions(self, get_credentials_mock, get_conne
)
get_credentials_mock.return_value = {"token": 122323}

hook = YandexCloudBaseHook(None, default_folder_id, default_public_ssh_key)
hook = YandexCloudBaseHook(
yandex_conn_id=None,
default_folder_id=default_folder_id,
default_public_ssh_key=default_public_ssh_key,
)
assert hook.client is not None

@mock.patch('airflow.hooks.base.BaseHook.get_connection')
@@ -63,7 +67,11 @@ def test_get_credentials_raise_exception(self, get_connection_mock):
)

with pytest.raises(AirflowException):
YandexCloudBaseHook(None, default_folder_id, default_public_ssh_key)
YandexCloudBaseHook(
yandex_conn_id=None,
default_folder_id=default_folder_id,
default_public_ssh_key=default_public_ssh_key,
)

@mock.patch('airflow.hooks.base.BaseHook.get_connection')
@mock.patch('airflow.providers.yandex.hooks.yandex.YandexCloudBaseHook._get_credentials')
@@ -80,6 +88,10 @@ def test_get_field(self, get_credentials_mock, get_connection_mock):
)
get_credentials_mock.return_value = {"token": 122323}

hook = YandexCloudBaseHook(None, default_folder_id, default_public_ssh_key)
hook = YandexCloudBaseHook(
yandex_conn_id=None,
default_folder_id=default_folder_id,
default_public_ssh_key=default_public_ssh_key,
)

assert hook._get_field('one') == 'value_one'
5 changes: 5 additions & 0 deletions tests/providers/yandex/operators/test_yandexcloud_dataproc.py
@@ -127,6 +127,11 @@ def test_create_cluster(self, create_cluster_mock, *_):
subnet_id='my_subnet_id',
zone='ru-central1-c',
log_group_id=LOG_GROUP_ID,
properties=None,
enable_ui_proxy=False,
host_group_ids=None,
security_group_ids=None,
initialization_actions=None,
)
context['task_instance'].xcom_push.assert_has_calls(
[
197 changes: 197 additions & 0 deletions tests/system/providers/yandex/example_yandexcloud.py
@@ -0,0 +1,197 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import os
from datetime import datetime
from typing import Optional

import yandex.cloud.dataproc.v1.cluster_pb2 as cluster_pb
import yandex.cloud.dataproc.v1.cluster_service_pb2 as cluster_service_pb
import yandex.cloud.dataproc.v1.cluster_service_pb2_grpc as cluster_service_grpc_pb
import yandex.cloud.dataproc.v1.common_pb2 as common_pb
import yandex.cloud.dataproc.v1.job_pb2 as job_pb
import yandex.cloud.dataproc.v1.job_service_pb2 as job_service_pb
import yandex.cloud.dataproc.v1.job_service_pb2_grpc as job_service_grpc_pb
import yandex.cloud.dataproc.v1.subcluster_pb2 as subcluster_pb
from google.protobuf.json_format import MessageToDict

from airflow import DAG
from airflow.decorators import task
from airflow.providers.yandex.hooks.yandex import YandexCloudBaseHook

ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID")
DAG_ID = 'example_yandexcloud_hook'

# Fill in your identifiers
YC_S3_BUCKET_NAME = ''  # Fill in to use S3 instead of HDFS
YC_FOLDER_ID = None  # Fill in to override the default YC folder from connection data
YC_ZONE_NAME = 'ru-central1-b'
YC_SUBNET_ID = None  # Fill in if you have more than one VPC subnet in the given folder and zone
YC_SERVICE_ACCOUNT_ID = None  # Fill in if you have more than one YC service account in the given folder


def create_cluster_request(
folder_id: str,
cluster_name: str,
cluster_desc: str,
zone: str,
subnet_id: str,
service_account_id: str,
ssh_public_key: str,
resources: common_pb.Resources,
):
return cluster_service_pb.CreateClusterRequest(
folder_id=folder_id,
name=cluster_name,
description=cluster_desc,
bucket=YC_S3_BUCKET_NAME,
config_spec=cluster_service_pb.CreateClusterConfigSpec(
hadoop=cluster_pb.HadoopConfig(
services=('SPARK', 'YARN'),
ssh_public_keys=[ssh_public_key],
),
subclusters_spec=[
cluster_service_pb.CreateSubclusterConfigSpec(
name='master',
role=subcluster_pb.Role.MASTERNODE,
resources=resources,
subnet_id=subnet_id,
hosts_count=1,
),
cluster_service_pb.CreateSubclusterConfigSpec(
name='compute',
role=subcluster_pb.Role.COMPUTENODE,
resources=resources,
subnet_id=subnet_id,
hosts_count=1,
),
],
),
zone_id=zone,
service_account_id=service_account_id,
)


@task
def create_cluster(
yandex_conn_id: Optional[str] = None,
folder_id: Optional[str] = None,
network_id: Optional[str] = None,
subnet_id: Optional[str] = None,
zone: str = YC_ZONE_NAME,
service_account_id: Optional[str] = None,
ssh_public_key: Optional[str] = None,
*,
dag: Optional[DAG] = None,
ts_nodash: Optional[str] = None,
) -> str:
hook = YandexCloudBaseHook(yandex_conn_id=yandex_conn_id)
folder_id = folder_id or hook.default_folder_id
if subnet_id is None:
network_id = network_id or hook.sdk.helpers.find_network_id(folder_id)
subnet_id = hook.sdk.helpers.find_subnet_id(folder_id=folder_id, zone_id=zone, network_id=network_id)
service_account_id = service_account_id or hook.sdk.helpers.find_service_account_id()
ssh_public_key = ssh_public_key or hook.default_public_ssh_key

dag_id = dag and dag.dag_id or 'dag'

request = create_cluster_request(
folder_id=folder_id,
subnet_id=subnet_id,
zone=zone,
cluster_name=f'airflow_{dag_id}_{ts_nodash}'[:62],
cluster_desc='Created via Airflow custom hook task',
service_account_id=service_account_id,
ssh_public_key=ssh_public_key,
resources=common_pb.Resources(
resource_preset_id='s2.micro',
disk_type_id='network-ssd',
),
)
operation = hook.sdk.client(cluster_service_grpc_pb.ClusterServiceStub).Create(request)
operation_result = hook.sdk.wait_operation_and_get_result(
operation, response_type=cluster_pb.Cluster, meta_type=cluster_service_pb.CreateClusterMetadata
)
return operation_result.response.id


@task
def run_spark_job(
cluster_id: str,
yandex_conn_id: Optional[str] = None,
):
hook = YandexCloudBaseHook(yandex_conn_id=yandex_conn_id)

request = job_service_pb.CreateJobRequest(
cluster_id=cluster_id,
name='Spark job: Find total urban population in distribution by country',
spark_job=job_pb.SparkJob(
main_jar_file_uri='file:///usr/lib/spark/examples/jars/spark-examples.jar',
main_class='org.apache.spark.examples.SparkPi',
args=['1000'],
),
)
operation = hook.sdk.client(job_service_grpc_pb.JobServiceStub).Create(request)
operation_result = hook.sdk.wait_operation_and_get_result(
operation, response_type=job_pb.Job, meta_type=job_service_pb.CreateJobMetadata
)
return MessageToDict(operation_result.response)


@task(trigger_rule='all_done')
def delete_cluster(
cluster_id: str,
yandex_conn_id: Optional[str] = None,
):
hook = YandexCloudBaseHook(yandex_conn_id=yandex_conn_id)

operation = hook.sdk.client(cluster_service_grpc_pb.ClusterServiceStub).Delete(
cluster_service_pb.DeleteClusterRequest(cluster_id=cluster_id)
)
hook.sdk.wait_operation_and_get_result(
operation,
meta_type=cluster_service_pb.DeleteClusterMetadata,
)


with DAG(
dag_id=DAG_ID,
schedule_interval=None,
start_date=datetime(2021, 1, 1),
tags=['example'],
) as dag:
cluster_id = create_cluster(
folder_id=YC_FOLDER_ID,
subnet_id=YC_SUBNET_ID,
zone=YC_ZONE_NAME,
service_account_id=YC_SERVICE_ACCOUNT_ID,
)
spark_job = run_spark_job(cluster_id=cluster_id)
delete_task = delete_cluster(cluster_id=cluster_id)

spark_job >> delete_task

from tests.system.utils.watcher import watcher

# This test needs watcher in order to properly mark success/failure
# when "teardown" task with trigger rule is part of the DAG
list(dag.tasks) >> watcher()


from tests.system.utils import get_test_run # noqa: E402

# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
test_run = get_test_run(dag)
80 changes: 80 additions & 0 deletions tests/system/providers/yandex/example_yandexcloud_dataproc_lightweight.py
@@ -0,0 +1,80 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import os
from datetime import datetime

from airflow import DAG
from airflow.providers.yandex.operators.yandexcloud_dataproc import (
DataprocCreateClusterOperator,
DataprocCreateSparkJobOperator,
DataprocDeleteClusterOperator,
)

from airflow.utils.trigger_rule import TriggerRule

# Should be filled in with appropriate ids

# Name of the availability zone (datacenter) where the Dataproc cluster will be created
AVAILABILITY_ZONE_ID = 'ru-central1-c'

# Dataproc cluster will use this bucket as distributed storage
S3_BUCKET_NAME = ''

ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID")
DAG_ID = 'example_yandexcloud_dataproc_lightweight'

with DAG(
DAG_ID,
schedule_interval=None,
start_date=datetime(2021, 1, 1),
tags=['example'],
) as dag:
create_cluster = DataprocCreateClusterOperator(
task_id='create_cluster',
zone=AVAILABILITY_ZONE_ID,
s3_bucket=S3_BUCKET_NAME,
computenode_count=1,
datanode_count=0,
services=('SPARK', 'YARN'),
)

create_spark_job = DataprocCreateSparkJobOperator(
cluster_id=create_cluster.cluster_id,
task_id='create_spark_job',
main_jar_file_uri='file:///usr/lib/spark/examples/jars/spark-examples.jar',
main_class='org.apache.spark.examples.SparkPi',
args=['1000'],
)

delete_cluster = DataprocDeleteClusterOperator(
cluster_id=create_cluster.cluster_id,
task_id='delete_cluster',
trigger_rule=TriggerRule.ALL_DONE,
)
create_spark_job >> delete_cluster

from tests.system.utils.watcher import watcher

# This test needs watcher in order to properly mark success/failure
# when "teardown" task with trigger rule is part of the DAG
list(dag.tasks) >> watcher()

from tests.system.utils import get_test_run # noqa: E402

# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
test_run = get_test_run(dag)