From 2543c74c1927b751e7492df81d762e61d2a4d5f6 Mon Sep 17 00:00:00 2001
From: Bjorn Olsen
Date: Thu, 24 Jun 2021 14:59:27 +0200
Subject: [PATCH] AWS DataSync cancel task on exception (#11011) (#16589)

Small improvements to the DataSync operator. Most notable is the ability
of the operator to cancel an in-progress task execution, e.g. if the
Airflow task times out or is killed. This avoids a zombie problem, where
the AWS DataSync service could otherwise keep a task execution running
even after the Airflow task has failed.

Also made some small changes to the polling values. DataSync is a
batch-based uploading service that takes several minutes to operate, so I
changed the polling interval from 5 seconds to 30 seconds and adjusted
max_iterations to what I think is a more reasonable default.

closes: #11011
---
 airflow/providers/amazon/CHANGELOG.rst        |  7 +++++++
 .../providers/amazon/aws/hooks/datasync.py    |  6 +++---
 .../amazon/aws/operators/datasync.py          | 23 +++++++++++++++++----
 .../amazon/aws/operators/test_datasync.py     |  2 +-
 4 files changed, 30 insertions(+), 8 deletions(-)

diff --git a/airflow/providers/amazon/CHANGELOG.rst b/airflow/providers/amazon/CHANGELOG.rst
index 6535e974519b4..792b42e7498aa 100644
--- a/airflow/providers/amazon/CHANGELOG.rst
+++ b/airflow/providers/amazon/CHANGELOG.rst
@@ -19,6 +19,13 @@
 Changelog
 ---------
 
+2.1.0
+.....
+
+Bug Fixes
+~~~~~~~~~
+* ``AWS DataSync default polling adjusted from 5s to 30s (#11011)``
+
 2.0.0
 .....
 
diff --git a/airflow/providers/amazon/aws/hooks/datasync.py b/airflow/providers/amazon/aws/hooks/datasync.py
index ec92bd739d0fa..d157f92f016e8 100644
--- a/airflow/providers/amazon/aws/hooks/datasync.py
+++ b/airflow/providers/amazon/aws/hooks/datasync.py
@@ -36,7 +36,7 @@ class AWSDataSyncHook(AwsBaseHook):
     :class:`~airflow.providers.amazon.aws.operators.datasync.AWSDataSyncOperator`
 
     :param wait_interval_seconds: Time to wait between two
-        consecutive calls to check TaskExecution status. Defaults to 5 seconds.
+        consecutive calls to check TaskExecution status. Defaults to 30 seconds.
     :type wait_interval_seconds: Optional[int]
     :raises ValueError: If wait_interval_seconds is not between 0 and 15*60 seconds.
     """
@@ -52,7 +52,7 @@ class AWSDataSyncHook(AwsBaseHook):
     TASK_EXECUTION_FAILURE_STATES = ("ERROR",)
     TASK_EXECUTION_SUCCESS_STATES = ("SUCCESS",)
 
-    def __init__(self, wait_interval_seconds: int = 5, *args, **kwargs) -> None:
+    def __init__(self, wait_interval_seconds: int = 30, *args, **kwargs) -> None:
         super().__init__(client_type='datasync', *args, **kwargs)  # type: ignore[misc]
         self.locations: list = []
         self.tasks: list = []
@@ -279,7 +279,7 @@ def get_current_task_execution_arn(self, task_arn: str) -> Optional[str]:
             return task_description["CurrentTaskExecutionArn"]
         return None
 
-    def wait_for_task_execution(self, task_execution_arn: str, max_iterations: int = 2 * 180) -> bool:
+    def wait_for_task_execution(self, task_execution_arn: str, max_iterations: int = 60) -> bool:
         """
         Wait for Task Execution status to be complete (SUCCESS/ERROR).
         The ``task_execution_arn`` must exist, or a boto3 ClientError will be raised.
diff --git a/airflow/providers/amazon/aws/operators/datasync.py b/airflow/providers/amazon/aws/operators/datasync.py
index 6c88eb1c339cd..88381a85edc40 100644
--- a/airflow/providers/amazon/aws/operators/datasync.py
+++ b/airflow/providers/amazon/aws/operators/datasync.py
@@ -21,7 +21,7 @@
 import random
 from typing import List, Optional
 
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, AirflowTaskTimeout
 from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.datasync import AWSDataSyncHook
 
@@ -46,6 +46,9 @@ class AWSDataSyncOperator(BaseOperator):
     :param wait_interval_seconds: Time to wait between two
         consecutive calls to check TaskExecution status.
     :type wait_interval_seconds: int
+    :param max_iterations: Maximum number of
+        consecutive calls to check TaskExecution status.
+    :type max_iterations: int
     :param task_arn: AWS DataSync TaskArn to use. If None, then this operator will
         attempt to either search for an existing Task or attempt to create a new Task.
     :type task_arn: str
@@ -128,7 +131,8 @@ def __init__(
         self,
         *,
         aws_conn_id: str = "aws_default",
-        wait_interval_seconds: int = 5,
+        wait_interval_seconds: int = 30,
+        max_iterations: int = 60,
         task_arn: Optional[str] = None,
         source_location_uri: Optional[str] = None,
         destination_location_uri: Optional[str] = None,
@@ -147,6 +151,7 @@ def __init__(
         # Assignments
         self.aws_conn_id = aws_conn_id
         self.wait_interval_seconds = wait_interval_seconds
+        self.max_iterations = max_iterations
 
         self.task_arn = task_arn
 
@@ -355,8 +360,14 @@ def _execute_datasync_task(self) -> None:
 
         # Wait for task execution to complete
         self.log.info("Waiting for TaskExecutionArn %s", self.task_execution_arn)
-        result = hook.wait_for_task_execution(self.task_execution_arn)
+        try:
+            result = hook.wait_for_task_execution(self.task_execution_arn, max_iterations=self.max_iterations)
+        except (AirflowTaskTimeout, AirflowException) as e:
+            self.log.error('Cancelling TaskExecution after Exception: %s', e)
+            self._cancel_datasync_task_execution()
+            raise
         self.log.info("Completed TaskExecutionArn %s", self.task_execution_arn)
+
         task_execution_description = hook.describe_task_execution(task_execution_arn=self.task_execution_arn)
         self.log.info("task_execution_description=%s", task_execution_description)
 
@@ -371,7 +382,7 @@ def _execute_datasync_task(self) -> None:
         if not result:
             raise AirflowException(f"Failed TaskExecutionArn {self.task_execution_arn}")
 
-    def on_kill(self) -> None:
+    def _cancel_datasync_task_execution(self):
         """Cancel the submitted DataSync task."""
         hook = self.get_hook()
         if self.task_execution_arn:
@@ -379,6 +390,10 @@ def on_kill(self) -> None:
             hook.cancel_task_execution(task_execution_arn=self.task_execution_arn)
             self.log.info("Cancelled TaskExecutionArn %s", self.task_execution_arn)
 
+    def on_kill(self):
+        self.log.error('Cancelling TaskExecution after task was killed')
+        self._cancel_datasync_task_execution()
+
     def _delete_datasync_task(self) -> None:
         """Deletes an AWS DataSync Task."""
         if not self.task_arn:
diff --git a/tests/providers/amazon/aws/operators/test_datasync.py b/tests/providers/amazon/aws/operators/test_datasync.py
index 587196cba1c13..6a4f9111ca7ae 100644
--- a/tests/providers/amazon/aws/operators/test_datasync.py
+++ b/tests/providers/amazon/aws/operators/test_datasync.py
@@ -710,7 +710,7 @@ def test_killed_task(self, mock_wait, mock_get_conn):
         # ### Begin tests:
 
         # Kill the task when doing wait_for_task_execution
-        def kill_task(*args):
+        def kill_task(*args, **kwargs):
             self.datasync.on_kill()
             return True
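
For reference, a minimal usage sketch of the operator with the new defaults.
This is a sketch only: the DAG id, task_id and location URIs below are
hypothetical, while ``wait_interval_seconds``, ``max_iterations`` and the
cancel-on-timeout behaviour come from this patch.

    from datetime import datetime, timedelta

    from airflow import DAG
    from airflow.providers.amazon.aws.operators.datasync import AWSDataSyncOperator

    with DAG(
        dag_id="example_datasync",  # hypothetical DAG id
        start_date=datetime(2021, 6, 1),
        schedule_interval=None,
        catchup=False,
    ) as dag:
        transfer = AWSDataSyncOperator(
            task_id="datasync_transfer",  # hypothetical task id
            # Hypothetical location URIs; the operator searches for (or
            # creates) a DataSync Task matching this source/destination pair.
            source_location_uri="smb://fileserver/share",
            destination_location_uri="s3://mybucket/prefix",
            # New defaults from this patch: poll every 30 seconds, at most
            # 60 times, i.e. wait up to 30 minutes for the TaskExecution.
            wait_interval_seconds=30,
            max_iterations=60,
            # If this timeout fires (or the task is killed), the operator
            # now cancels the in-progress TaskExecution instead of leaving
            # a zombie running in AWS.
            execution_timeout=timedelta(hours=1),
        )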