Skip to content

Commit

Permalink
Use base aws classes in AWS Datasync Operators (#36766)
Browse files Browse the repository at this point in the history
  • Loading branch information
Taragolis authored Jan 14, 2024
1 parent 270b112 commit c7f518f
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 18 deletions.
37 changes: 19 additions & 18 deletions airflow/providers/amazon/aws/operators/datasync.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,20 @@

import logging
import random
from functools import cached_property
from typing import TYPE_CHECKING, Sequence
from typing import TYPE_CHECKING, Any, Sequence

from deprecated.classic import deprecated

from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowTaskTimeout
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.datasync import DataSyncHook
from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
from airflow.providers.amazon.aws.utils.mixins import aws_template_fields

if TYPE_CHECKING:
from airflow.utils.context import Context


class DataSyncOperator(BaseOperator):
class DataSyncOperator(AwsBaseOperator[DataSyncHook]):
"""Find, Create, Update, Execute and Delete AWS DataSync Tasks.
If ``do_xcom_push`` is True, then the DataSync TaskArn and TaskExecutionArn
Expand All @@ -46,7 +46,6 @@ class DataSyncOperator(BaseOperator):
environment. The default behavior is to create a new Task if there are 0, or
execute the Task if there was 1 Task, or fail if there were many Tasks.
:param aws_conn_id: AWS connection to use.
:param wait_interval_seconds: Time to wait between two
consecutive calls to check TaskExecution status.
:param max_iterations: Maximum number of
Expand Down Expand Up @@ -91,6 +90,16 @@ class DataSyncOperator(BaseOperator):
``boto3.start_task_execution(TaskArn=task_arn, **task_execution_kwargs)``
:param delete_task_after_execution: If True then the TaskArn which was executed
will be deleted from AWS DataSync on successful completion.
:param aws_conn_id: The Airflow connection used for AWS credentials.
If this is ``None`` or empty then the default boto3 behaviour is used. If
running Airflow in a distributed manner and aws_conn_id is None or
empty, then default boto3 configuration would be used (and must be
maintained on each worker node).
:param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
:param verify: Whether or not to verify SSL certificates. See:
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
:param botocore_config: Configuration dictionary (key-values) for botocore client. See:
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
:raises AirflowException: If ``task_arn`` was not specified, or if
either ``source_location_uri`` or ``destination_location_uri`` were
not specified.
Expand All @@ -100,7 +109,8 @@ class DataSyncOperator(BaseOperator):
:raises AirflowException: If Task creation, update, execution or delete fails.
"""

template_fields: Sequence[str] = (
aws_hook_class = DataSyncHook
template_fields: Sequence[str] = aws_template_fields(
"task_arn",
"source_location_uri",
"destination_location_uri",
Expand All @@ -122,7 +132,6 @@ class DataSyncOperator(BaseOperator):
def __init__(
self,
*,
aws_conn_id: str = "aws_default",
wait_interval_seconds: int = 30,
max_iterations: int = 60,
wait_for_completion: bool = True,
Expand All @@ -142,7 +151,6 @@ def __init__(
super().__init__(**kwargs)

# Assignments
self.aws_conn_id = aws_conn_id
self.wait_interval_seconds = wait_interval_seconds
self.max_iterations = max_iterations
self.wait_for_completion = wait_for_completion
Expand Down Expand Up @@ -185,16 +193,9 @@ def __init__(
self.destination_location_arn: str | None = None
self.task_execution_arn: str | None = None

@cached_property
def hook(self) -> DataSyncHook:
"""Create and return DataSyncHook.
:return DataSyncHook: An DataSyncHook instance.
"""
return DataSyncHook(
aws_conn_id=self.aws_conn_id,
wait_interval_seconds=self.wait_interval_seconds,
)
@property
def _hook_parameters(self) -> dict[str, Any]:
return {**super()._hook_parameters, "wait_interval_seconds": self.wait_interval_seconds}

@deprecated(reason="use `hook` property instead.", category=AirflowProviderDeprecationWarning)
def get_hook(self) -> DataSyncHook:
Expand Down
5 changes: 5 additions & 0 deletions docs/apache-airflow-providers-amazon/operators/datasync.rst
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ Prerequisite Tasks

.. include:: ../_partials/prerequisite_tasks.rst

Generic Parameters
------------------

.. include:: ../_partials/generic_parameters.rst

Operators
---------

Expand Down
36 changes: 36 additions & 0 deletions tests/providers/amazon/aws/operators/test_datasync.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,42 @@ def teardown_method(self, method):
self.client = None


def test_generic_params():
op = DataSyncOperator(
task_id="generic-task",
task_arn="arn:fake",
source_location_uri="fake://source",
destination_location_uri="fake://destination",
aws_conn_id="fake-conn-id",
region_name="cn-north-1",
verify=False,
botocore_config={"read_timeout": 42},
# Non-generic hook params
wait_interval_seconds=42,
)

assert op.hook.client_type == "datasync"
assert op.hook.resource_type is None
assert op.hook.aws_conn_id == "fake-conn-id"
assert op.hook._region_name == "cn-north-1"
assert op.hook._verify is False
assert op.hook._config is not None
assert op.hook._config.read_timeout == 42
assert op.hook.wait_interval_seconds == 42

op = DataSyncOperator(
task_id="generic-task",
task_arn="arn:fake",
source_location_uri="fake://source",
destination_location_uri="fake://destination",
)
assert op.hook.aws_conn_id == "aws_default"
assert op.hook._region_name is None
assert op.hook._verify is None
assert op.hook._config is None
assert op.hook.wait_interval_seconds is not None


@mock_datasync
@mock.patch.object(DataSyncHook, "get_conn")
class TestDataSyncOperatorCreate(DataSyncTestCaseBase):
Expand Down

0 comments on commit c7f518f

Please sign in to comment.