diff --git a/airflow/providers/google/cloud/example_dags/example_datafusion.py b/airflow/providers/google/cloud/example_dags/example_datafusion.py
index 93f238a22305a..70fc60a987698 100644
--- a/airflow/providers/google/cloud/example_dags/example_datafusion.py
+++ b/airflow/providers/google/cloud/example_dags/example_datafusion.py
@@ -34,6 +34,7 @@
     CloudDataFusionStopPipelineOperator,
     CloudDataFusionUpdateInstanceOperator,
 )
+from airflow.providers.google.cloud.sensors.datafusion import CloudDataFusionPipelineStateSensor
 from airflow.utils import dates
 from airflow.utils.state import State
 
@@ -205,6 +206,28 @@
     )
     # [END howto_cloud_data_fusion_start_pipeline]
 
+    # [START howto_cloud_data_fusion_start_pipeline_async]
+    start_pipeline_async = CloudDataFusionStartPipelineOperator(
+        location=LOCATION,
+        pipeline_name=PIPELINE_NAME,
+        instance_name=INSTANCE_NAME,
+        asynchronous=True,
+        task_id="start_pipeline_async",
+    )
+
+    # [END howto_cloud_data_fusion_start_pipeline_async]
+
+    # [START howto_cloud_data_fusion_start_pipeline_sensor]
+    start_pipeline_sensor = CloudDataFusionPipelineStateSensor(
+        task_id="pipeline_state_sensor",
+        pipeline_name=PIPELINE_NAME,
+        pipeline_id=start_pipeline_async.output,
+        expected_statuses=["COMPLETED"],
+        instance_name=INSTANCE_NAME,
+        location=LOCATION,
+    )
+    # [END howto_cloud_data_fusion_start_pipeline_sensor]
+
     # [START howto_cloud_data_fusion_stop_pipeline]
     stop_pipeline = CloudDataFusionStopPipelineOperator(
         location=LOCATION,
@@ -233,7 +256,16 @@
     sleep = BashOperator(task_id="sleep", bash_command="sleep 60")
 
     create_instance >> get_instance >> restart_instance >> update_instance >> sleep
-    sleep >> create_pipeline >> list_pipelines >> start_pipeline >> stop_pipeline >> delete_pipeline
+    (
+        sleep
+        >> create_pipeline
+        >> list_pipelines
+        >> start_pipeline_async
+        >> start_pipeline_sensor
+        >> start_pipeline
+        >> stop_pipeline
+        >> delete_pipeline
+    )
     delete_pipeline >> delete_instance
 
 if __name__ == "__main__":
diff --git a/airflow/providers/google/cloud/hooks/datafusion.py b/airflow/providers/google/cloud/hooks/datafusion.py
index 00582239c7578..fefd5ce6d84ed 100644
--- a/airflow/providers/google/cloud/hooks/datafusion.py
+++ b/airflow/providers/google/cloud/hooks/datafusion.py
@@ -102,12 +102,13 @@ def wait_for_pipeline_state(
         current_state = None
         while monotonic() - start_time < timeout:
             try:
-                current_state = self._get_workflow_state(
+                workflow = self.get_pipeline_workflow(
                     pipeline_name=pipeline_name,
                     pipeline_id=pipeline_id,
                     instance_url=instance_url,
                     namespace=namespace,
                 )
+                current_state = workflow["status"]
             except AirflowException:
                 pass  # Because the pipeline may not be visible in system yet
             if current_state in success_states:
@@ -398,7 +399,7 @@ def list_pipelines(
             raise AirflowException(f"Listing pipelines failed with code {response.status}")
         return json.loads(response.data)
 
-    def _get_workflow_state(
+    def get_pipeline_workflow(
         self,
         pipeline_name: str,
         instance_url: str,
@@ -417,7 +418,7 @@ def _get_workflow_state(
         if response.status != 200:
             raise AirflowException(f"Retrieving a pipeline state failed with code {response.status}")
         workflow = json.loads(response.data)
-        return workflow["status"]
+        return workflow
 
     def start_pipeline(
         self,
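For context, a minimal sketch (not part of the patch) of how the renamed hook method might be called on its own; the connection id, endpoint, pipeline name, and run id below are placeholder values:

    from airflow.providers.google.cloud.hooks.datafusion import DataFusionHook

    hook = DataFusionHook(gcp_conn_id="google_cloud_default")
    # get_pipeline_workflow() now returns the whole workflow dict instead of only the status,
    # so callers such as wait_for_pipeline_state() read the state via workflow["status"].
    workflow = hook.get_pipeline_workflow(
        pipeline_name="example_pipeline",  # placeholder pipeline name
        instance_url="https://instance.example.com/api",  # placeholder instance endpoint
        pipeline_id="11111111-2222-3333-4444-555555555555",  # placeholder run id
        namespace="default",
    )
    print(workflow["status"])
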
diff --git a/airflow/providers/google/cloud/operators/datafusion.py b/airflow/providers/google/cloud/operators/datafusion.py
index b115437c73453..8bc29ff65a69a 100644
--- a/airflow/providers/google/cloud/operators/datafusion.py
+++ b/airflow/providers/google/cloud/operators/datafusion.py
@@ -780,6 +780,10 @@ class CloudDataFusionStartPipelineOperator(BaseOperator):
         Service Account Token Creator IAM role to the directly preceding identity, with first
         account from the list granting this role to the originating account (templated).
     :type impersonation_chain: Union[str, Sequence[str]]
+    :param asynchronous: Flag to return after submitting the pipeline ID to the Data Fusion API.
+        This is useful for submitting long-running pipelines and waiting on them
+        asynchronously using the CloudDataFusionPipelineStateSensor.
+    :type asynchronous: bool
     """
 
     template_fields = (
@@ -804,6 +808,7 @@ def __init__(
         gcp_conn_id: str = "google_cloud_default",
         delegate_to: Optional[str] = None,
         impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
+        asynchronous: bool = False,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -817,6 +822,7 @@ def __init__(
         self.gcp_conn_id = gcp_conn_id
         self.delegate_to = delegate_to
         self.impersonation_chain = impersonation_chain
+        self.asynchronous = asynchronous
 
         if success_states:
             self.success_states = success_states
@@ -825,7 +831,7 @@ def __init__(
             self.success_states = SUCCESS_STATES + [PipelineStates.RUNNING]
             self.pipeline_timeout = 5 * 60
 
-    def execute(self, context: dict) -> None:
+    def execute(self, context: dict) -> str:
         hook = DataFusionHook(
             gcp_conn_id=self.gcp_conn_id,
             delegate_to=self.delegate_to,
@@ -845,15 +851,20 @@ def execute(self, context: dict) -> None:
             namespace=self.namespace,
             runtime_args=self.runtime_args,
         )
-        hook.wait_for_pipeline_state(
-            success_states=self.success_states,
-            pipeline_id=pipeline_id,
-            pipeline_name=self.pipeline_name,
-            namespace=self.namespace,
-            instance_url=api_url,
-            timeout=self.pipeline_timeout,
-        )
-        self.log.info("Pipeline started")
+        self.log.info("Pipeline %s submitted successfully.", pipeline_id)
+
+        if not self.asynchronous:
+            self.log.info("Waiting for pipeline %s to reach one of the success states", pipeline_id)
+            hook.wait_for_pipeline_state(
+                success_states=self.success_states,
+                pipeline_id=pipeline_id,
+                pipeline_name=self.pipeline_name,
+                namespace=self.namespace,
+                instance_url=api_url,
+                timeout=self.pipeline_timeout,
+            )
+            self.log.info("Pipeline %s reached a success state.", pipeline_id)
+        return pipeline_id
 
 
 class CloudDataFusionStopPipelineOperator(BaseOperator):
diff --git a/airflow/providers/google/cloud/sensors/datafusion.py b/airflow/providers/google/cloud/sensors/datafusion.py
new file mode 100644
index 0000000000000..c57d3f0c35a03
--- /dev/null
+++ b/airflow/providers/google/cloud/sensors/datafusion.py
@@ -0,0 +1,125 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""This module contains a Google Cloud Data Fusion sensors.""" +from typing import Optional, Sequence, Set, Union + +from airflow.exceptions import AirflowException +from airflow.providers.google.cloud.hooks.datafusion import DataFusionHook +from airflow.sensors.base import BaseSensorOperator + + +class CloudDataFusionPipelineStateSensor(BaseSensorOperator): + """ + Check the status of the pipeline in the Google Cloud Data Fusion + + :param pipeline_name: Your pipeline name. + :type pipeline_name: str + :param pipeline_id: Your pipeline ID. + :type pipeline_name: str + :param expected_statuses: State that is expected + :type expected_statuses: set[str] + :param instance_name: The name of the instance. + :type instance_name: str + :param location: The Cloud Data Fusion location in which to handle the request. + :type location: str + :param project_id: The ID of the Google Cloud project that the instance belongs to. + :type project_id: str + :param namespace: If your pipeline belongs to a Basic edition instance, the namespace ID + is always default. If your pipeline belongs to an Enterprise edition instance, you + can create a namespace. + :type namespace: str + :param gcp_conn_id: The connection ID to use when fetching connection info. + :type gcp_conn_id: str + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: str + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+    :type impersonation_chain: Union[str, Sequence[str]]
+
+    """
+
+    template_fields = ['pipeline_id']
+
+    def __init__(
+        self,
+        pipeline_name: str,
+        pipeline_id: str,
+        expected_statuses: Set[str],
+        instance_name: str,
+        location: str,
+        project_id: Optional[str] = None,
+        namespace: str = "default",
+        gcp_conn_id: str = 'google_cloud_default',
+        delegate_to: Optional[str] = None,
+        impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.pipeline_name = pipeline_name
+        self.pipeline_id = pipeline_id
+        self.expected_statuses = expected_statuses
+        self.instance_name = instance_name
+        self.location = location
+        self.project_id = project_id
+        self.namespace = namespace
+        self.gcp_conn_id = gcp_conn_id
+        self.delegate_to = delegate_to
+        self.impersonation_chain = impersonation_chain
+
+    def poke(self, context: dict) -> bool:
+        self.log.info(
+            "Waiting for pipeline %s to be in one of the states: %s.",
+            self.pipeline_id,
+            ", ".join(self.expected_statuses),
+        )
+        hook = DataFusionHook(
+            gcp_conn_id=self.gcp_conn_id,
+            delegate_to=self.delegate_to,
+            impersonation_chain=self.impersonation_chain,
+        )
+
+        instance = hook.get_instance(
+            instance_name=self.instance_name,
+            location=self.location,
+            project_id=self.project_id,
+        )
+        api_url = instance["apiEndpoint"]
+        pipeline_status = None
+        try:
+            pipeline_workflow = hook.get_pipeline_workflow(
+                pipeline_name=self.pipeline_name,
+                instance_url=api_url,
+                pipeline_id=self.pipeline_id,
+                namespace=self.namespace,
+            )
+            pipeline_status = pipeline_workflow["status"]
+        except AirflowException:
+            pass  # Because the pipeline may not be visible in the system yet
+
+        self.log.debug(
+            "Current status of the pipeline workflow for %s: %s.", self.pipeline_id, pipeline_status
+        )
+        return pipeline_status in self.expected_statuses
diff --git a/docs/apache-airflow-providers-google/operators/cloud/datafusion.rst b/docs/apache-airflow-providers-google/operators/cloud/datafusion.rst
index b32f8a3d0c093..5f6cf27369f8b 100644
--- a/docs/apache-airflow-providers-google/operators/cloud/datafusion.rst
+++ b/docs/apache-airflow-providers-google/operators/cloud/datafusion.rst
@@ -158,7 +158,7 @@ The result is saved to :ref:`XCom `, which allows it to be used b
 Start a DataFusion pipeline
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
-To start Data Fusion pipeline use:
+To start a Data Fusion pipeline in synchronous mode, use:
 :class:`~airflow.providers.google.cloud.operators.datafusion.CloudDataFusionStartPipelineOperator`.
 
 .. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_datafusion.py
@@ -167,6 +167,15 @@ To start Data Fusion pipeline use:
    :start-after: [START howto_cloud_data_fusion_start_pipeline]
    :end-before: [END howto_cloud_data_fusion_start_pipeline]
 
+To start a Data Fusion pipeline in asynchronous mode, use:
+:class:`~airflow.providers.google.cloud.operators.datafusion.CloudDataFusionStartPipelineOperator`.
+
+.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_datafusion.py
+   :language: python
+   :dedent: 4
+   :start-after: [START howto_cloud_data_fusion_start_pipeline_async]
+   :end-before: [END howto_cloud_data_fusion_start_pipeline_async]
+
 You can use :ref:`Jinja templating ` with
 :template-fields:`airflow.providers.google.cloud.operators.datafusion.CloudDataFusionStartPipelineOperator`
 parameters which allows you to dynamically determine values.
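The example DAG in this patch wires the asynchronous start and the sensor together; as a further illustration (not part of the patch), a minimal sketch of the same pairing with the sensor tuned via the standard poke_interval and reschedule mode arguments of BaseSensorOperator; the location, instance, and pipeline names are placeholders:

    from datetime import datetime

    from airflow import models
    from airflow.providers.google.cloud.operators.datafusion import CloudDataFusionStartPipelineOperator
    from airflow.providers.google.cloud.sensors.datafusion import CloudDataFusionPipelineStateSensor

    LOCATION = "europe-west1"  # placeholder region
    INSTANCE_NAME = "example-instance"  # placeholder Data Fusion instance
    PIPELINE_NAME = "example-pipeline"  # placeholder pipeline

    with models.DAG(
        "example_datafusion_async_sensor",
        start_date=datetime(2021, 1, 1),
        schedule_interval=None,
    ) as dag:
        start_pipeline_async = CloudDataFusionStartPipelineOperator(
            task_id="start_pipeline_async",
            location=LOCATION,
            instance_name=INSTANCE_NAME,
            pipeline_name=PIPELINE_NAME,
            asynchronous=True,  # return the run id immediately instead of polling
        )

        wait_for_pipeline = CloudDataFusionPipelineStateSensor(
            task_id="wait_for_pipeline",
            location=LOCATION,
            instance_name=INSTANCE_NAME,
            pipeline_name=PIPELINE_NAME,
            pipeline_id=start_pipeline_async.output,  # run id pushed to XCom by the operator
            expected_statuses={"COMPLETED"},
            mode="reschedule",  # release the worker slot between pokes
            poke_interval=60,
        )

        start_pipeline_async >> wait_for_pipeline
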
@@ -229,3 +238,16 @@ You can use :ref:`Jinja templating ` with
 :template-fields:`airflow.providers.google.cloud.operators.datafusion.CloudDataFusionListPipelinesOperator`
 parameters which allows you to dynamically determine values.
 The result is saved to :ref:`XCom `, which allows it to be used by other operators.
+
+Sensors
+^^^^^^^
+
+When a pipeline is started asynchronously, a sensor may be used to check whether it has reached the expected state:
+
+:class:`~airflow.providers.google.cloud.sensors.datafusion.CloudDataFusionPipelineStateSensor`.
+
+.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_datafusion.py
+   :language: python
+   :dedent: 4
+   :start-after: [START howto_cloud_data_fusion_start_pipeline_sensor]
+   :end-before: [END howto_cloud_data_fusion_start_pipeline_sensor]
diff --git a/tests/providers/google/cloud/operators/test_datafusion.py b/tests/providers/google/cloud/operators/test_datafusion.py
index 466f67061d4d5..5ff72bdd24244 100644
--- a/tests/providers/google/cloud/operators/test_datafusion.py
+++ b/tests/providers/google/cloud/operators/test_datafusion.py
@@ -231,6 +231,38 @@ def test_execute(self, mock_hook):
             timeout=300,
         )
 
+    @mock.patch(HOOK_STR)
+    def test_execute_async(self, mock_hook):
+        PIPELINE_ID = "test_pipeline_id"
+        mock_hook.return_value.get_instance.return_value = {"apiEndpoint": INSTANCE_URL}
+        mock_hook.return_value.start_pipeline.return_value = PIPELINE_ID
+
+        op = CloudDataFusionStartPipelineOperator(
+            task_id="test_task",
+            pipeline_name=PIPELINE_NAME,
+            instance_name=INSTANCE_NAME,
+            namespace=NAMESPACE,
+            location=LOCATION,
+            project_id=PROJECT_ID,
+            runtime_args=RUNTIME_ARGS,
+            asynchronous=True,
+        )
+        op.dag = mock.MagicMock(spec=DAG, task_dict={}, dag_id="test")
+
+        op.execute({})
+        mock_hook.return_value.get_instance.assert_called_once_with(
+            instance_name=INSTANCE_NAME, location=LOCATION, project_id=PROJECT_ID
+        )
+
+        mock_hook.return_value.start_pipeline.assert_called_once_with(
+            instance_url=INSTANCE_URL,
+            pipeline_name=PIPELINE_NAME,
+            namespace=NAMESPACE,
+            runtime_args=RUNTIME_ARGS,
+        )
+
+        mock_hook.return_value.wait_for_pipeline_state.assert_not_called()
+
 
 class TestCloudDataFusionStopPipelineOperator:
     @mock.patch(HOOK_STR)
diff --git a/tests/providers/google/cloud/sensors/test_datafusion.py b/tests/providers/google/cloud/sensors/test_datafusion.py
new file mode 100644
index 0000000000000..aeff6a57dbced
--- /dev/null
+++ b/tests/providers/google/cloud/sensors/test_datafusion.py
@@ -0,0 +1,75 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import unittest
+from unittest import mock
+
+from parameterized.parameterized import parameterized
+
+from airflow.providers.google.cloud.hooks.datafusion import PipelineStates
+from airflow.providers.google.cloud.sensors.datafusion import CloudDataFusionPipelineStateSensor
+
+LOCATION = "test-location"
+INSTANCE_NAME = "airflow-test-instance"
+INSTANCE_URL = "http://datafusion.instance.com"
+PIPELINE_NAME = "shrubberyPipeline"
+PIPELINE_ID = "test_pipeline_id"
+PROJECT_ID = "test_project_id"
+GCP_CONN_ID = "test_conn_id"
+DELEGATE_TO = "test_delegate_to"
+IMPERSONATION_CHAIN = ["ACCOUNT_1", "ACCOUNT_2", "ACCOUNT_3"]
+
+
+class TestCloudDataFusionPipelineStateSensor(unittest.TestCase):
+    @parameterized.expand(
+        [
+            (PipelineStates.COMPLETED, PipelineStates.COMPLETED, True),
+            (PipelineStates.COMPLETED, PipelineStates.RUNNING, False),
+        ]
+    )
+    @mock.patch("airflow.providers.google.cloud.sensors.datafusion.DataFusionHook")
+    def test_poke(self, expected_status, current_status, sensor_return, mock_hook):
+        mock_hook.return_value.get_instance.return_value = {"apiEndpoint": INSTANCE_URL}
+
+        task = CloudDataFusionPipelineStateSensor(
+            task_id="test_task_id",
+            pipeline_name=PIPELINE_NAME,
+            pipeline_id=PIPELINE_ID,
+            project_id=PROJECT_ID,
+            expected_statuses=[expected_status],
+            instance_name=INSTANCE_NAME,
+            location=LOCATION,
+            gcp_conn_id=GCP_CONN_ID,
+            delegate_to=DELEGATE_TO,
+            impersonation_chain=IMPERSONATION_CHAIN,
+        )
+
+        mock_hook.return_value.get_pipeline_workflow.return_value = {"status": current_status}
+        result = task.poke(mock.MagicMock())
+
+        assert sensor_return == result
+
+        mock_hook.assert_called_once_with(
+            gcp_conn_id=GCP_CONN_ID,
+            delegate_to=DELEGATE_TO,
+            impersonation_chain=IMPERSONATION_CHAIN,
+        )
+
+        mock_hook.return_value.get_instance.assert_called_once_with(
+            instance_name=INSTANCE_NAME, location=LOCATION, project_id=PROJECT_ID
+        )
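As an illustration only (not part of the patch), the same mocking pattern extends to a failed run, where poke() returns False because the reported status is not among the expected ones:

    from unittest import mock

    from airflow.providers.google.cloud.sensors.datafusion import CloudDataFusionPipelineStateSensor

    @mock.patch("airflow.providers.google.cloud.sensors.datafusion.DataFusionHook")
    def test_poke_returns_false_for_failed_run(mock_hook):
        # The hook is fully mocked, so no Google Cloud credentials are needed.
        mock_hook.return_value.get_instance.return_value = {"apiEndpoint": "http://datafusion.instance.com"}
        mock_hook.return_value.get_pipeline_workflow.return_value = {"status": "FAILED"}

        task = CloudDataFusionPipelineStateSensor(
            task_id="test_task_id",
            pipeline_name="shrubberyPipeline",
            pipeline_id="test_pipeline_id",
            expected_statuses={"COMPLETED"},
            instance_name="airflow-test-instance",
            location="test-location",
        )

        # "FAILED" is not in the expected statuses, so the sensor keeps poking.
        assert task.poke(mock.MagicMock()) is False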