diff --git a/airflow/providers/google/cloud/operators/dataproc.py b/airflow/providers/google/cloud/operators/dataproc.py index 647f485ee135c..a5d185a967746 100644 --- a/airflow/providers/google/cloud/operators/dataproc.py +++ b/airflow/providers/google/cloud/operators/dataproc.py @@ -2377,7 +2377,7 @@ def __init__( region: str, project_id: str | None = None, request_id: str | None = None, - retry: Retry | _MethodDefault = DEFAULT, + retry: AsyncRetry | _MethodDefault = DEFAULT, timeout: float | None = None, metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", @@ -2431,7 +2431,7 @@ def execute(self, context: Context): ) if not self.deferrable: self.log.info("Template instantiated. Workflow Id : %s", workflow_id) - operation.result() + hook.wait_for_operation(timeout=self.timeout, result_retry=self.retry, operation=operation) self.log.info("Workflow %s completed successfully", workflow_id) else: self.defer( diff --git a/tests/providers/google/cloud/operators/test_dataproc.py b/tests/providers/google/cloud/operators/test_dataproc.py index 68795357e7a5b..663aa2ac454af 100644 --- a/tests/providers/google/cloud/operators/test_dataproc.py +++ b/tests/providers/google/cloud/operators/test_dataproc.py @@ -23,6 +23,7 @@ import pytest from google.api_core.exceptions import AlreadyExists, NotFound from google.api_core.retry import Retry +from google.api_core.retry_async import AsyncRetry from google.cloud import dataproc from google.cloud.dataproc_v1 import Batch, Cluster, JobStatus @@ -2068,6 +2069,7 @@ def test_execute_call_defer_method(self, mock_trigger_hook, mock_hook): mock_hook.return_value.instantiate_inline_workflow_template.assert_called_once() + mock_hook.return_value.wait_for_operation.assert_not_called() assert isinstance(exc.value.trigger, DataprocOperationTrigger) assert exc.value.method_name == GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME @@ -2102,6 +2104,35 @@ def test_on_kill(self, mock_hook): name=operation_name ) + 
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
+    def test_wait_for_operation_on_execute(self, mock_hook):
+        """Check that execute() waits through the hook's wait_for_operation.
+
+        The operator's custom timeout and retry must be forwarded to the hook,
+        and the operation returned by the hook must not be waited on directly.
+        """
+        template = {}
+
+        custom_timeout = 10800
+        custom_retry = mock.MagicMock(AsyncRetry)
+        op = DataprocInstantiateInlineWorkflowTemplateOperator(
+            task_id=TASK_ID,
+            template=template,
+            region=GCP_REGION,
+            project_id=GCP_PROJECT,
+            request_id=REQUEST_ID,
+            retry=custom_retry,
+            timeout=custom_timeout,
+            metadata=METADATA,
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+        )
+
+        mock_op = MagicMock()
+        mock_hook.return_value.instantiate_inline_workflow_template.return_value = mock_op
+
+        op.execute(context=MagicMock())
+        mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
+        mock_hook.return_value.wait_for_operation.assert_called_once_with(
+            timeout=custom_timeout, result_retry=custom_retry, operation=mock_op
+        )
+        # mock_op is the operation object itself, so the direct wait that must not
+        # happen is mock_op.result(); mock_op.return_value.result could never be hit.
+        mock_op.result.assert_not_called()
+
     @pytest.mark.db_test
     @pytest.mark.need_serialized_dag