diff --git a/airflow/providers/amazon/aws/operators/datasync.py b/airflow/providers/amazon/aws/operators/datasync.py index f990ffbb708a7..c280b531028d5 100644 --- a/airflow/providers/amazon/aws/operators/datasync.py +++ b/airflow/providers/amazon/aws/operators/datasync.py @@ -19,20 +19,20 @@ import logging import random -from functools import cached_property -from typing import TYPE_CHECKING, Sequence +from typing import TYPE_CHECKING, Any, Sequence from deprecated.classic import deprecated from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowTaskTimeout -from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.datasync import DataSyncHook +from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator +from airflow.providers.amazon.aws.utils.mixins import aws_template_fields if TYPE_CHECKING: from airflow.utils.context import Context -class DataSyncOperator(BaseOperator): +class DataSyncOperator(AwsBaseOperator[DataSyncHook]): """Find, Create, Update, Execute and Delete AWS DataSync Tasks. If ``do_xcom_push`` is True, then the DataSync TaskArn and TaskExecutionArn @@ -46,7 +46,6 @@ class DataSyncOperator(BaseOperator): environment. The default behavior is to create a new Task if there are 0, or execute the Task if there was 1 Task, or fail if there were many Tasks. - :param aws_conn_id: AWS connection to use. :param wait_interval_seconds: Time to wait between two consecutive calls to check TaskExecution status. :param max_iterations: Maximum number of @@ -91,6 +90,16 @@ class DataSyncOperator(BaseOperator): ``boto3.start_task_execution(TaskArn=task_arn, **task_execution_kwargs)`` :param delete_task_after_execution: If True then the TaskArn which was executed will be deleted from AWS DataSync on successful completion. + :param aws_conn_id: The Airflow connection used for AWS credentials. + If this is ``None`` or empty then the default boto3 behaviour is used. If + running Airflow in a distributed manner and aws_conn_id is None or + empty, then default boto3 configuration would be used (and must be + maintained on each worker node). + :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. + :param verify: Whether or not to verify SSL certificates. See: + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html + :param botocore_config: Configuration dictionary (key-values) for botocore client. See: + https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html :raises AirflowException: If ``task_arn`` was not specified, or if either ``source_location_uri`` or ``destination_location_uri`` were not specified. @@ -100,7 +109,8 @@ class DataSyncOperator(BaseOperator): :raises AirflowException: If Task creation, update, execution or delete fails. """ - template_fields: Sequence[str] = ( + aws_hook_class = DataSyncHook + template_fields: Sequence[str] = aws_template_fields( "task_arn", "source_location_uri", "destination_location_uri", @@ -122,7 +132,6 @@ class DataSyncOperator(BaseOperator): def __init__( self, *, - aws_conn_id: str = "aws_default", wait_interval_seconds: int = 30, max_iterations: int = 60, wait_for_completion: bool = True, @@ -142,7 +151,6 @@ def __init__( super().__init__(**kwargs) # Assignments - self.aws_conn_id = aws_conn_id self.wait_interval_seconds = wait_interval_seconds self.max_iterations = max_iterations self.wait_for_completion = wait_for_completion @@ -185,16 +193,9 @@ def __init__( self.destination_location_arn: str | None = None self.task_execution_arn: str | None = None - @cached_property - def hook(self) -> DataSyncHook: - """Create and return DataSyncHook. - - :return DataSyncHook: An DataSyncHook instance. - """ - return DataSyncHook( - aws_conn_id=self.aws_conn_id, - wait_interval_seconds=self.wait_interval_seconds, - ) + @property + def _hook_parameters(self) -> dict[str, Any]: + return {**super()._hook_parameters, "wait_interval_seconds": self.wait_interval_seconds} @deprecated(reason="use `hook` property instead.", category=AirflowProviderDeprecationWarning) def get_hook(self) -> DataSyncHook: diff --git a/docs/apache-airflow-providers-amazon/operators/datasync.rst b/docs/apache-airflow-providers-amazon/operators/datasync.rst index aca6d5d755145..16e65db42dbdc 100644 --- a/docs/apache-airflow-providers-amazon/operators/datasync.rst +++ b/docs/apache-airflow-providers-amazon/operators/datasync.rst @@ -28,6 +28,11 @@ Prerequisite Tasks .. include:: ../_partials/prerequisite_tasks.rst +Generic Parameters +------------------ + +.. include:: ../_partials/generic_parameters.rst + Operators --------- diff --git a/tests/providers/amazon/aws/operators/test_datasync.py b/tests/providers/amazon/aws/operators/test_datasync.py index 829dca7082226..fa666dd4766bb 100644 --- a/tests/providers/amazon/aws/operators/test_datasync.py +++ b/tests/providers/amazon/aws/operators/test_datasync.py @@ -108,6 +108,42 @@ def teardown_method(self, method): self.client = None +def test_generic_params(): + op = DataSyncOperator( + task_id="generic-task", + task_arn="arn:fake", + source_location_uri="fake://source", + destination_location_uri="fake://destination", + aws_conn_id="fake-conn-id", + region_name="cn-north-1", + verify=False, + botocore_config={"read_timeout": 42}, + # Non-generic hook params + wait_interval_seconds=42, + ) + + assert op.hook.client_type == "datasync" + assert op.hook.resource_type is None + assert op.hook.aws_conn_id == "fake-conn-id" + assert op.hook._region_name == "cn-north-1" + assert op.hook._verify is False + assert op.hook._config is not None + assert op.hook._config.read_timeout == 42 + assert op.hook.wait_interval_seconds == 42 + + op = DataSyncOperator( + task_id="generic-task", + task_arn="arn:fake", + source_location_uri="fake://source", + destination_location_uri="fake://destination", + ) + assert op.hook.aws_conn_id == "aws_default" + assert op.hook._region_name is None + assert op.hook._verify is None + assert op.hook._config is None + assert op.hook.wait_interval_seconds is not None + + @mock_datasync @mock.patch.object(DataSyncHook, "get_conn") class TestDataSyncOperatorCreate(DataSyncTestCaseBase):