AWS DataSync cancel task on exception (#11011) (#16589)
Small improvements to the DataSync operator.

Most notable is the ability of the operator to cancel an in-progress task execution, e.g. if the Airflow task times out or is killed. Previously the AWS DataSync service could be left running a zombie task execution even after Airflow's own task had failed.

Also made some small changes to the polling values. DataSync is a batch-based transfer service that takes several minutes to operate, so I changed the polling interval from 5 seconds to 30 seconds and adjusted max_iterations to what I think is a more reasonable default.
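
For illustration, a minimal DAG using the operator with the new defaults spelled out might look like this (the dag_id and ARN below are placeholders, not part of this change):

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.amazon.aws.operators.datasync import AWSDataSyncOperator

    with DAG(
        dag_id="example_datasync",  # placeholder name
        start_date=datetime(2021, 1, 1),
        schedule_interval=None,
    ) as dag:
        datasync_task = AWSDataSyncOperator(
            task_id="datasync_task",
            task_arn="arn:aws:datasync:us-east-1:111222333444:task/task-0123456789abcdef0",  # placeholder ARN
            wait_interval_seconds=30,  # new default: poll every 30 seconds instead of 5
            max_iterations=60,  # new parameter: stop polling after 60 status checks
        )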

closes: #11011
baolsen authored Jun 24, 2021
1 parent 962c5f4 commit 2543c74
Showing 4 changed files with 30 additions and 8 deletions.
7 changes: 7 additions & 0 deletions airflow/providers/amazon/CHANGELOG.rst
@@ -19,6 +19,13 @@
 Changelog
 ---------
 
+2.1.0
+.....
+
+Bug Fixes
+~~~~~~~~~
+* ``AWS DataSync default polling adjusted from 5s to 30s (#11011)``
+
 2.0.0
 .....
6 changes: 3 additions & 3 deletions airflow/providers/amazon/aws/hooks/datasync.py
@@ -36,7 +36,7 @@ class AWSDataSyncHook(AwsBaseHook):
     :class:`~airflow.providers.amazon.aws.operators.datasync.AWSDataSyncOperator`
     :param wait_interval_seconds: Time to wait between two
-        consecutive calls to check TaskExecution status. Defaults to 5 seconds.
+        consecutive calls to check TaskExecution status. Defaults to 30 seconds.
     :type wait_interval_seconds: Optional[int]
     :raises ValueError: If wait_interval_seconds is not between 0 and 15*60 seconds.
     """
@@ -52,7 +52,7 @@ class AWSDataSyncHook(AwsBaseHook):
     TASK_EXECUTION_FAILURE_STATES = ("ERROR",)
     TASK_EXECUTION_SUCCESS_STATES = ("SUCCESS",)
 
-    def __init__(self, wait_interval_seconds: int = 5, *args, **kwargs) -> None:
+    def __init__(self, wait_interval_seconds: int = 30, *args, **kwargs) -> None:
         super().__init__(client_type='datasync', *args, **kwargs)  # type: ignore[misc]
         self.locations: list = []
         self.tasks: list = []
@@ -279,7 +279,7 @@ def get_current_task_execution_arn(self, task_arn: str) -> Optional[str]:
             return task_description["CurrentTaskExecutionArn"]
         return None
 
-    def wait_for_task_execution(self, task_execution_arn: str, max_iterations: int = 2 * 180) -> bool:
+    def wait_for_task_execution(self, task_execution_arn: str, max_iterations: int = 60) -> bool:
         """
         Wait for Task Execution status to be complete (SUCCESS/ERROR).
         The ``task_execution_arn`` must exist, or a boto3 ClientError will be raised.
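Note that the ceiling on total wait time is roughly wait_interval_seconds * max_iterations, so the new defaults (30 s * 60 = 1800 s) cover about the same 30-minute window as the old ones (5 s * 360 = 1800 s) while making a sixth of the API calls. A rough sketch of the loop's shape, assuming it sleeps wait_interval_seconds between status checks (illustrative only, not the hook's exact implementation):

    import time

    from airflow.exceptions import AirflowTaskTimeout
    from airflow.providers.amazon.aws.hooks.datasync import AWSDataSyncHook

    def wait_sketch(hook: AWSDataSyncHook, task_execution_arn: str, max_iterations: int = 60) -> bool:
        # Poll the TaskExecution until it reaches a terminal state or iterations run out.
        for _ in range(max_iterations):
            status = hook.describe_task_execution(task_execution_arn=task_execution_arn)["Status"]
            if status in hook.TASK_EXECUTION_SUCCESS_STATES:
                return True
            if status in hook.TASK_EXECUTION_FAILURE_STATES:
                return False
            time.sleep(hook.wait_interval_seconds)
        # Assumed failure mode: the real hook raises once iterations are exhausted.
        raise AirflowTaskTimeout("TaskExecution did not reach a terminal state in time")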
23 changes: 19 additions & 4 deletions airflow/providers/amazon/aws/operators/datasync.py
@@ -21,7 +21,7 @@
 import random
 from typing import List, Optional
 
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, AirflowTaskTimeout
 from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.datasync import AWSDataSyncHook
 
@@ -46,6 +46,9 @@ class AWSDataSyncOperator(BaseOperator):
     :param wait_interval_seconds: Time to wait between two
         consecutive calls to check TaskExecution status.
     :type wait_interval_seconds: int
+    :param max_iterations: Maximum number of
+        consecutive calls to check TaskExecution status.
+    :type max_iterations: int
     :param task_arn: AWS DataSync TaskArn to use. If None, then this operator will
         attempt to either search for an existing Task or attempt to create a new Task.
     :type task_arn: str
@@ -128,7 +131,8 @@ def __init__(
         self,
         *,
         aws_conn_id: str = "aws_default",
-        wait_interval_seconds: int = 5,
+        wait_interval_seconds: int = 30,
+        max_iterations: int = 60,
         task_arn: Optional[str] = None,
         source_location_uri: Optional[str] = None,
         destination_location_uri: Optional[str] = None,
@@ -147,6 +151,7 @@ def __init__(
         # Assignments
         self.aws_conn_id = aws_conn_id
         self.wait_interval_seconds = wait_interval_seconds
+        self.max_iterations = max_iterations
 
         self.task_arn = task_arn
 
@@ -355,8 +360,14 @@ def _execute_datasync_task(self) -> None:

         # Wait for task execution to complete
         self.log.info("Waiting for TaskExecutionArn %s", self.task_execution_arn)
-        result = hook.wait_for_task_execution(self.task_execution_arn)
+        try:
+            result = hook.wait_for_task_execution(self.task_execution_arn, max_iterations=self.max_iterations)
+        except (AirflowTaskTimeout, AirflowException) as e:
+            self.log.error('Cancelling TaskExecution after Exception: %s', e)
+            self._cancel_datasync_task_execution()
+            raise
         self.log.info("Completed TaskExecutionArn %s", self.task_execution_arn)
 
         task_execution_description = hook.describe_task_execution(task_execution_arn=self.task_execution_arn)
         self.log.info("task_execution_description=%s", task_execution_description)
 
@@ -371,14 +382,18 @@ def _execute_datasync_task(self) -> None:
         if not result:
             raise AirflowException(f"Failed TaskExecutionArn {self.task_execution_arn}")
 
-    def on_kill(self) -> None:
+    def _cancel_datasync_task_execution(self):
         """Cancel the submitted DataSync task."""
         hook = self.get_hook()
         if self.task_execution_arn:
             self.log.info("Cancelling TaskExecutionArn %s", self.task_execution_arn)
             hook.cancel_task_execution(task_execution_arn=self.task_execution_arn)
             self.log.info("Cancelled TaskExecutionArn %s", self.task_execution_arn)
 
+    def on_kill(self):
+        self.log.error('Cancelling TaskExecution after task was killed')
+        self._cancel_datasync_task_execution()
+
     def _delete_datasync_task(self) -> None:
         """Deletes an AWS DataSync Task."""
         if not self.task_arn:
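Both failure paths, an exception raised while waiting and an external kill, now funnel into _cancel_datasync_task_execution(), which calls the hook's cancel_task_execution(). That hook method ultimately issues the DataSync CancelTaskExecution API call, roughly the following via boto3 (the region and ARN are placeholders, and the hook manages its own client):

    import boto3

    # Sketch of the underlying API call the hook's cancel_task_execution() performs.
    client = boto3.client("datasync", region_name="us-east-1")  # placeholder region
    client.cancel_task_execution(
        TaskExecutionArn="arn:aws:datasync:us-east-1:111222333444:"
        "task/task-0123456789abcdef0/execution/exec-0123456789abcdef0"  # placeholder ARN
    )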
2 changes: 1 addition & 1 deletion tests/providers/amazon/aws/operators/test_datasync.py
@@ -710,7 +710,7 @@ def test_killed_task(self, mock_wait, mock_get_conn):
         # ### Begin tests:
 
         # Kill the task when doing wait_for_task_execution
-        def kill_task(*args):
+        def kill_task(*args, **kwargs):
             self.datasync.on_kill()
             return True
 
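The **kwargs addition matters because the operator now passes max_iterations to wait_for_task_execution by keyword, and the mock forwards call arguments to its side_effect; a side_effect accepting only *args would raise a TypeError. In miniature (a sketch of the mock mechanics, not the full test):

    from unittest import mock

    def kill_task(*args, **kwargs):  # must accept kwargs now that max_iterations is passed by keyword
        return True

    mock_wait = mock.MagicMock(side_effect=kill_task)
    mock_wait("some-task-execution-arn", max_iterations=60)  # works; an *args-only side_effect would fail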
