diff --git a/airflow/providers/google/cloud/example_dags/example_tasks.py b/airflow/providers/google/cloud/example_dags/example_tasks.py index cfa35622c4567..24874820b8c1b 100644 --- a/airflow/providers/google/cloud/example_dags/example_tasks.py +++ b/airflow/providers/google/cloud/example_dags/example_tasks.py @@ -39,7 +39,7 @@ timestamp = timestamp_pb2.Timestamp() timestamp.FromDatetime(datetime.now() + timedelta(hours=12)) # pylint: disable=no-member -LOCATION = "asia-east2" +LOCATION = "europe-west1" QUEUE_ID = "cloud-tasks-queue" TASK_NAME = "task-to-run" diff --git a/airflow/providers/google/cloud/operators/tasks.py b/airflow/providers/google/cloud/operators/tasks.py index 8c81c07184b43..5dc7dcf16645f 100644 --- a/airflow/providers/google/cloud/operators/tasks.py +++ b/airflow/providers/google/cloud/operators/tasks.py @@ -23,9 +23,11 @@ """ from typing import Dict, Optional, Sequence, Tuple, Union +from google.api_core.exceptions import AlreadyExists from google.api_core.retry import Retry from google.cloud.tasks_v2 import enums from google.cloud.tasks_v2.types import FieldMask, Queue, Task +from google.protobuf.json_format import MessageToDict from airflow.models import BaseOperator from airflow.providers.google.cloud.hooks.tasks import CloudTasksHook @@ -98,15 +100,27 @@ def __init__( def execute(self, context): hook = CloudTasksHook(gcp_conn_id=self.gcp_conn_id) - return hook.create_queue( - location=self.location, - task_queue=self.task_queue, - project_id=self.project_id, - queue_name=self.queue_name, - retry=self.retry, - timeout=self.timeout, - metadata=self.metadata, - ) + try: + queue = hook.create_queue( + location=self.location, + task_queue=self.task_queue, + project_id=self.project_id, + queue_name=self.queue_name, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + except AlreadyExists: + queue = hook.get_queue( + location=self.location, + project_id=self.project_id, + queue_name=self.queue_name, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + + return MessageToDict(queue) class CloudTasksQueueUpdateOperator(BaseOperator): @@ -181,7 +195,7 @@ def __init__( def execute(self, context): hook = CloudTasksHook(gcp_conn_id=self.gcp_conn_id) - return hook.update_queue( + queue = hook.update_queue( task_queue=self.task_queue, project_id=self.project_id, location=self.location, @@ -191,6 +205,7 @@ def execute(self, context): timeout=self.timeout, metadata=self.metadata, ) + return MessageToDict(queue) class CloudTasksQueueGetOperator(BaseOperator): @@ -244,7 +259,7 @@ def __init__( def execute(self, context): hook = CloudTasksHook(gcp_conn_id=self.gcp_conn_id) - return hook.get_queue( + queue = hook.get_queue( location=self.location, queue_name=self.queue_name, project_id=self.project_id, @@ -252,6 +267,7 @@ def execute(self, context): timeout=self.timeout, metadata=self.metadata, ) + return MessageToDict(queue) class CloudTasksQueuesListOperator(BaseOperator): @@ -311,7 +327,7 @@ def __init__( def execute(self, context): hook = CloudTasksHook(gcp_conn_id=self.gcp_conn_id) - return hook.list_queues( + queues = hook.list_queues( location=self.location, project_id=self.project_id, results_filter=self.results_filter, @@ -320,6 +336,7 @@ def execute(self, context): timeout=self.timeout, metadata=self.metadata, ) + return [MessageToDict(q) for q in queues] class CloudTasksQueueDeleteOperator(BaseOperator): @@ -433,7 +450,7 @@ def __init__( def execute(self, context): hook = CloudTasksHook(gcp_conn_id=self.gcp_conn_id) - return hook.purge_queue( + queue = hook.purge_queue( location=self.location, queue_name=self.queue_name, project_id=self.project_id, @@ -441,6 +458,7 @@ def execute(self, context): timeout=self.timeout, metadata=self.metadata, ) + return MessageToDict(queue) class CloudTasksQueuePauseOperator(BaseOperator): @@ -494,7 +512,7 @@ def __init__( def execute(self, context): hook = CloudTasksHook(gcp_conn_id=self.gcp_conn_id) - return hook.pause_queue( + queues = hook.pause_queue( location=self.location, queue_name=self.queue_name, project_id=self.project_id, @@ -502,6 +520,7 @@ def execute(self, context): timeout=self.timeout, metadata=self.metadata, ) + return [MessageToDict(q) for q in queues] class CloudTasksQueueResumeOperator(BaseOperator): @@ -555,7 +574,7 @@ def __init__( def execute(self, context): hook = CloudTasksHook(gcp_conn_id=self.gcp_conn_id) - return hook.resume_queue( + queue = hook.resume_queue( location=self.location, queue_name=self.queue_name, project_id=self.project_id, @@ -563,6 +582,7 @@ def execute(self, context): timeout=self.timeout, metadata=self.metadata, ) + return MessageToDict(queue) class CloudTasksTaskCreateOperator(BaseOperator): @@ -638,7 +658,7 @@ def __init__( # pylint: disable=too-many-arguments def execute(self, context): hook = CloudTasksHook(gcp_conn_id=self.gcp_conn_id) - return hook.create_task( + task = hook.create_task( location=self.location, queue_name=self.queue_name, task=self.task, @@ -649,6 +669,7 @@ def execute(self, context): timeout=self.timeout, metadata=self.metadata, ) + return MessageToDict(task) class CloudTasksTaskGetOperator(BaseOperator): @@ -717,7 +738,7 @@ def __init__( def execute(self, context): hook = CloudTasksHook(gcp_conn_id=self.gcp_conn_id) - return hook.get_task( + task = hook.get_task( location=self.location, queue_name=self.queue_name, task_name=self.task_name, @@ -727,6 +748,7 @@ def execute(self, context): timeout=self.timeout, metadata=self.metadata, ) + return MessageToDict(task) class CloudTasksTasksListOperator(BaseOperator): @@ -790,7 +812,7 @@ def __init__( def execute(self, context): hook = CloudTasksHook(gcp_conn_id=self.gcp_conn_id) - return hook.list_tasks( + tasks = hook.list_tasks( location=self.location, queue_name=self.queue_name, project_id=self.project_id, @@ -800,6 +822,7 @@ def execute(self, context): timeout=self.timeout, metadata=self.metadata, ) + return [MessageToDict(t) for t in tasks] class CloudTasksTaskDeleteOperator(BaseOperator): @@ -939,7 +962,7 @@ def __init__( def execute(self, context): hook = CloudTasksHook(gcp_conn_id=self.gcp_conn_id) - return hook.run_task( + task = hook.run_task( location=self.location, queue_name=self.queue_name, task_name=self.task_name, @@ -949,3 +972,4 @@ def execute(self, context): timeout=self.timeout, metadata=self.metadata, ) + return MessageToDict(task) diff --git a/tests/providers/google/cloud/operators/test_tasks.py b/tests/providers/google/cloud/operators/test_tasks.py index fde9da1a65cf2..24634f8899a1d 100644 --- a/tests/providers/google/cloud/operators/test_tasks.py +++ b/tests/providers/google/cloud/operators/test_tasks.py @@ -44,7 +44,7 @@ class TestCloudTasksQueueCreate(unittest.TestCase): @mock.patch("airflow.providers.google.cloud.operators.tasks.CloudTasksHook") def test_create_queue(self, mock_hook): - mock_hook.return_value.create_queue.return_value = {} + mock_hook.return_value.create_queue.return_value = mock.MagicMock() operator = CloudTasksQueueCreateOperator( location=LOCATION, task_queue=Queue(), task_id="id" ) @@ -64,7 +64,7 @@ def test_create_queue(self, mock_hook): class TestCloudTasksQueueUpdate(unittest.TestCase): @mock.patch("airflow.providers.google.cloud.operators.tasks.CloudTasksHook") def test_update_queue(self, mock_hook): - mock_hook.return_value.update_queue.return_value = {} + mock_hook.return_value.update_queue.return_value = mock.MagicMock() operator = CloudTasksQueueUpdateOperator( task_queue=Queue(name=FULL_QUEUE_PATH), task_id="id" ) @@ -85,7 +85,7 @@ def test_update_queue(self, mock_hook): class TestCloudTasksQueueGet(unittest.TestCase): @mock.patch("airflow.providers.google.cloud.operators.tasks.CloudTasksHook") def test_get_queue(self, mock_hook): - mock_hook.return_value.get_queue.return_value = {} + mock_hook.return_value.get_queue.return_value = mock.MagicMock() operator = CloudTasksQueueGetOperator( location=LOCATION, queue_name=QUEUE_ID, task_id="id" ) @@ -104,7 +104,7 @@ def test_get_queue(self, mock_hook): class TestCloudTasksQueuesList(unittest.TestCase): @mock.patch("airflow.providers.google.cloud.operators.tasks.CloudTasksHook") def test_list_queues(self, mock_hook): - mock_hook.return_value.list_queues.return_value = {} + mock_hook.return_value.list_queues.return_value = mock.MagicMock() operator = CloudTasksQueuesListOperator(location=LOCATION, task_id="id") operator.execute(context=None) mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID) @@ -122,7 +122,7 @@ def test_list_queues(self, mock_hook): class TestCloudTasksQueueDelete(unittest.TestCase): @mock.patch("airflow.providers.google.cloud.operators.tasks.CloudTasksHook") def test_delete_queue(self, mock_hook): - mock_hook.return_value.delete_queue.return_value = {} + mock_hook.return_value.delete_queue.return_value = mock.MagicMock() operator = CloudTasksQueueDeleteOperator( location=LOCATION, queue_name=QUEUE_ID, task_id="id" ) @@ -141,7 +141,7 @@ def test_delete_queue(self, mock_hook): class TestCloudTasksQueuePurge(unittest.TestCase): @mock.patch("airflow.providers.google.cloud.operators.tasks.CloudTasksHook") def test_delete_queue(self, mock_hook): - mock_hook.return_value.purge_queue.return_value = {} + mock_hook.return_value.purge_queue.return_value = mock.MagicMock() operator = CloudTasksQueuePurgeOperator( location=LOCATION, queue_name=QUEUE_ID, task_id="id" ) @@ -160,7 +160,7 @@ def test_delete_queue(self, mock_hook): class TestCloudTasksQueuePause(unittest.TestCase): @mock.patch("airflow.providers.google.cloud.operators.tasks.CloudTasksHook") def test_pause_queue(self, mock_hook): - mock_hook.return_value.pause_queue.return_value = {} + mock_hook.return_value.pause_queue.return_value = mock.MagicMock() operator = CloudTasksQueuePauseOperator( location=LOCATION, queue_name=QUEUE_ID, task_id="id" ) @@ -179,7 +179,7 @@ def test_pause_queue(self, mock_hook): class TestCloudTasksQueueResume(unittest.TestCase): @mock.patch("airflow.providers.google.cloud.operators.tasks.CloudTasksHook") def test_resume_queue(self, mock_hook): - mock_hook.return_value.resume_queue.return_value = {} + mock_hook.return_value.resume_queue.return_value = mock.MagicMock() operator = CloudTasksQueueResumeOperator( location=LOCATION, queue_name=QUEUE_ID, task_id="id" ) @@ -198,7 +198,7 @@ def test_resume_queue(self, mock_hook): class TestCloudTasksTaskCreate(unittest.TestCase): @mock.patch("airflow.providers.google.cloud.operators.tasks.CloudTasksHook") def test_create_task(self, mock_hook): - mock_hook.return_value.create_task.return_value = {} + mock_hook.return_value.create_task.return_value = mock.MagicMock() operator = CloudTasksTaskCreateOperator( location=LOCATION, queue_name=QUEUE_ID, task=Task(), task_id="id" ) @@ -220,7 +220,7 @@ def test_create_task(self, mock_hook): class TestCloudTasksTaskGet(unittest.TestCase): @mock.patch("airflow.providers.google.cloud.operators.tasks.CloudTasksHook") def test_get_task(self, mock_hook): - mock_hook.return_value.get_task.return_value = {} + mock_hook.return_value.get_task.return_value = mock.MagicMock() operator = CloudTasksTaskGetOperator( location=LOCATION, queue_name=QUEUE_ID, task_name=TASK_NAME, task_id="id" ) @@ -241,7 +241,7 @@ def test_get_task(self, mock_hook): class TestCloudTasksTasksList(unittest.TestCase): @mock.patch("airflow.providers.google.cloud.operators.tasks.CloudTasksHook") def test_list_tasks(self, mock_hook): - mock_hook.return_value.list_tasks.return_value = {} + mock_hook.return_value.list_tasks.return_value = mock.MagicMock() operator = CloudTasksTasksListOperator( location=LOCATION, queue_name=QUEUE_ID, task_id="id" ) @@ -262,7 +262,7 @@ def test_list_tasks(self, mock_hook): class TestCloudTasksTaskDelete(unittest.TestCase): @mock.patch("airflow.providers.google.cloud.operators.tasks.CloudTasksHook") def test_delete_task(self, mock_hook): - mock_hook.return_value.delete_task.return_value = {} + mock_hook.return_value.delete_task.return_value = mock.MagicMock() operator = CloudTasksTaskDeleteOperator( location=LOCATION, queue_name=QUEUE_ID, task_name=TASK_NAME, task_id="id" ) @@ -282,7 +282,7 @@ def test_delete_task(self, mock_hook): class TestCloudTasksTaskRun(unittest.TestCase): @mock.patch("airflow.providers.google.cloud.operators.tasks.CloudTasksHook") def test_run_task(self, mock_hook): - mock_hook.return_value.run_task.return_value = {} + mock_hook.return_value.run_task.return_value = mock.MagicMock() operator = CloudTasksTaskRunOperator( location=LOCATION, queue_name=QUEUE_ID, task_name=TASK_NAME, task_id="id" )