diff --git a/airflow/providers/amazon/aws/hooks/s3.py b/airflow/providers/amazon/aws/hooks/s3.py
index 9d4ecaa38d464..1a961e583cd52 100644
--- a/airflow/providers/amazon/aws/hooks/s3.py
+++ b/airflow/providers/amazon/aws/hooks/s3.py
@@ -462,7 +462,9 @@ async def list_prefixes_async(
         return prefixes

     @provide_bucket_name_async
-    async def get_file_metadata_async(self, client: AioBaseClient, bucket_name: str, key: str) -> list[Any]:
+    async def get_file_metadata_async(
+        self, client: AioBaseClient, bucket_name: str, key: str | None = None
+    ) -> list[Any]:
         """
         Get a list of files that a key matching a wildcard expression exists in a bucket asynchronously.

@@ -470,7 +472,7 @@ async def get_file_metadata_async(self, client: AioBaseClient, bucket_name: str,
         :param bucket_name: the name of the bucket
         :param key: the path to the key
         """
-        prefix = re.split(r"[\[*?]", key, 1)[0]
+        prefix = re.split(r"[\[\*\?]", key, 1)[0] if key else ""
         delimiter = ""
         paginator = client.get_paginator("list_objects_v2")
         response = paginator.paginate(Bucket=bucket_name, Prefix=prefix, Delimiter=delimiter)
@@ -486,6 +488,7 @@ async def _check_key_async(
         bucket_val: str,
         wildcard_match: bool,
         key: str,
+        use_regex: bool = False,
     ) -> bool:
         """
         Get a list of files that a key matching a wildcard expression or get the head object.
@@ -498,6 +501,7 @@ async def _check_key_async(
         :param bucket_val: the name of the bucket
         :param key: S3 keys that will point to the file
         :param wildcard_match: the path to the key
+        :param use_regex: whether to use regex to check bucket
         """
         bucket_name, key = self.get_s3_bucket_key(bucket_val, key, "bucket_name", "bucket_key")
         if wildcard_match:
@@ -505,6 +509,11 @@ async def _check_key_async(
             key_matches = [k for k in keys if fnmatch.fnmatch(k["Key"], key)]
             if not key_matches:
                 return False
+        elif use_regex:
+            keys = await self.get_file_metadata_async(client, bucket_name)
+            key_matches = [k for k in keys if re.match(pattern=key, string=k["Key"])]
+            if not key_matches:
+                return False
         else:
             obj = await self.get_head_object_async(client, key, bucket_name)
             if obj is None:
@@ -518,6 +527,7 @@ async def check_key_async(
         bucket: str,
         bucket_keys: str | list[str],
         wildcard_match: bool,
+        use_regex: bool = False,
     ) -> bool:
         """
         Get a list of files that a key matching a wildcard expression or get the head object.
@@ -530,14 +540,18 @@ async def check_key_async(
         :param bucket: the name of the bucket
         :param bucket_keys: S3 keys that will point to the file
         :param wildcard_match: the path to the key
+        :param use_regex: whether to use regex to check bucket
         """
         if isinstance(bucket_keys, list):
             return all(
                 await asyncio.gather(
-                    *(self._check_key_async(client, bucket, wildcard_match, key) for key in bucket_keys)
+                    *(
+                        self._check_key_async(client, bucket, wildcard_match, key, use_regex)
+                        for key in bucket_keys
+                    )
                 )
             )
-        return await self._check_key_async(client, bucket, wildcard_match, bucket_keys)
+        return await self._check_key_async(client, bucket, wildcard_match, bucket_keys, use_regex)

     async def check_for_prefix_async(
         self, client: AioBaseClient, prefix: str, delimiter: str, bucket_name: str | None = None
diff --git a/airflow/providers/amazon/aws/sensors/s3.py b/airflow/providers/amazon/aws/sensors/s3.py
index 519aa49d6f557..6d55a724af6da 100644
--- a/airflow/providers/amazon/aws/sensors/s3.py
+++ b/airflow/providers/amazon/aws/sensors/s3.py
@@ -65,7 +65,6 @@ class S3KeySensor(BaseSensorOperator):
             def check_fn(files: List) -> bool:
                 return any(f.get('Size', 0) > 1048576 for f in files)
     :param aws_conn_id: a reference to the s3 connection
-    :param deferrable: Run operator in the deferrable mode
     :param verify: Whether to verify SSL certificates for S3 connection.
         By default, SSL certificates are verified.
         You can provide the following values:
@@ -76,6 +75,8 @@ def check_fn(files: List) -> bool:
         - ``path/to/cert/bundle.pem``: A filename of the CA cert bundle to uses.
             You can specify this argument if you want to use a different
             CA cert bundle than the one used by botocore.
+    :param deferrable: Run operator in the deferrable mode
+    :param use_regex: whether to use regex to check bucket
     """

     template_fields: Sequence[str] = ("bucket_key", "bucket_name")
@@ -90,6 +91,7 @@ def __init__(
         aws_conn_id: str = "aws_default",
         verify: str | bool | None = None,
         deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+        use_regex: bool = False,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -100,6 +102,7 @@ def __init__(
         self.aws_conn_id = aws_conn_id
         self.verify = verify
         self.deferrable = deferrable
+        self.use_regex = use_regex

     def _check_key(self, key):
         bucket_name, key = S3Hook.get_s3_bucket_key(self.bucket_name, key, "bucket_name", "bucket_key")
@@ -121,6 +124,11 @@ def _check_key(self, key):

             # Reduce the set of metadata to size only
             files = [{"Size": f["Size"]} for f in key_matches]
+        elif self.use_regex:
+            keys = self.hook.get_file_metadata("", bucket_name)
+            key_matches = [k for k in keys if re.match(pattern=key, string=k["Key"])]
+            if not key_matches:
+                return False
         else:
             obj = self.hook.head_object(key, bucket_name)
             if obj is None:
@@ -158,11 +166,12 @@ def _defer(self) -> None:
                 verify=self.verify,
                 poke_interval=self.poke_interval,
                 should_check_fn=bool(self.check_fn),
+                use_regex=self.use_regex,
             ),
             method_name="execute_complete",
         )

-    def execute_complete(self, context: Context, event: dict[str, Any]) -> bool | None:
+    def execute_complete(self, context: Context, event: dict[str, Any]) -> None:
         """
         Execute when the trigger fires - returns immediately.

@@ -170,17 +179,13 @@ def execute_complete(self, context: Context, event: dict[str, Any]) -> bool | No
         """
         if event["status"] == "running":
             found_keys = self.check_fn(event["files"])  # type: ignore[misc]
-            if found_keys:
-                return None
-            else:
+            if not found_keys:
                 self._defer()
-
-        if event["status"] == "error":
+        elif event["status"] == "error":
             # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
             if self.soft_fail:
                 raise AirflowSkipException(event["message"])
             raise AirflowException(event["message"])
-        return None

     @deprecated(reason="use `hook` property instead.", category=AirflowProviderDeprecationWarning)
     def get_hook(self) -> S3Hook:
diff --git a/airflow/providers/amazon/aws/triggers/s3.py b/airflow/providers/amazon/aws/triggers/s3.py
index 864e1fe14c0e2..df079ff0f8952 100644
--- a/airflow/providers/amazon/aws/triggers/s3.py
+++ b/airflow/providers/amazon/aws/triggers/s3.py
@@ -39,6 +39,7 @@ class S3KeyTrigger(BaseTrigger):
     :param wildcard_match: whether the bucket_key should be interpreted as a
         Unix wildcard pattern
     :param aws_conn_id: reference to the s3 connection
+    :param use_regex: whether to use regex to check bucket
    :param hook_params: params for hook its optional
    """

@@ -50,6 +51,7 @@ def __init__(
         aws_conn_id: str = "aws_default",
         poke_interval: float = 5.0,
         should_check_fn: bool = False,
+        use_regex: bool = False,
         **hook_params: Any,
     ):
         super().__init__()
@@ -60,6 +62,7 @@ def __init__(
         self.hook_params = hook_params
         self.poke_interval = poke_interval
         self.should_check_fn = should_check_fn
+        self.use_regex = use_regex

     def serialize(self) -> tuple[str, dict[str, Any]]:
         """Serialize S3KeyTrigger arguments and classpath."""
@@ -73,6 +76,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
                 "hook_params": self.hook_params,
                 "poke_interval": self.poke_interval,
                 "should_check_fn": self.should_check_fn,
+                "use_regex": self.use_regex,
             },
         )

@@ -86,7 +90,7 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
             async with self.hook.async_conn as client:
                 while True:
                     if await self.hook.check_key_async(
-                        client, self.bucket_name, self.bucket_key, self.wildcard_match
+                        client, self.bucket_name, self.bucket_key, self.wildcard_match, self.use_regex
                     ):
                         if self.should_check_fn:
                             s3_objects = await self.hook.get_files_async(
@@ -94,10 +98,11 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
                             )
                             await asyncio.sleep(self.poke_interval)
                             yield TriggerEvent({"status": "running", "files": s3_objects})
-                        else:
-                            yield TriggerEvent({"status": "success"})
-                    await asyncio.sleep(self.poke_interval)
+                        yield TriggerEvent({"status": "success"})
+
+                    self.log.info("Sleeping for %s seconds", self.poke_interval)
+                    await asyncio.sleep(self.poke_interval)

         except Exception as e:
             yield TriggerEvent({"status": "error", "message": str(e)})
diff --git a/docs/apache-airflow-providers-amazon/operators/s3/s3.rst b/docs/apache-airflow-providers-amazon/operators/s3/s3.rst
index 516b1712446e0..41e5c7149bb18 100644
--- a/docs/apache-airflow-providers-amazon/operators/s3/s3.rst
+++ b/docs/apache-airflow-providers-amazon/operators/s3/s3.rst
@@ -222,6 +222,14 @@ To check multiple files:
     :start-after: [START howto_sensor_s3_key_multiple_keys]
     :end-before: [END howto_sensor_s3_key_multiple_keys]

+To check a file with regular expression:
+
+.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_s3.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_sensor_s3_key_regex]
+    :end-before: [END howto_sensor_s3_key_regex]
+
 To check with an additional custom check you can define a function which receives a list of matched S3 object
 attributes and returns a boolean:

@@ -268,6 +276,14 @@ To check multiple files:
     :start-after: [START howto_sensor_s3_key_multiple_keys_deferrable]
     :end-before: [END howto_sensor_s3_key_multiple_keys_deferrable]

+To check a file with regular expression:
+
+.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_s3.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_sensor_s3_key_regex_deferrable]
+    :end-before: [END howto_sensor_s3_key_regex_deferrable]
+
 .. _howto/sensor:S3KeysUnchangedSensor:

 Wait on Amazon S3 prefix changes
diff --git a/tests/providers/amazon/aws/hooks/test_s3.py b/tests/providers/amazon/aws/hooks/test_s3.py
index b1b52b35576b9..21bcdcabd71fb 100644
--- a/tests/providers/amazon/aws/hooks/test_s3.py
+++ b/tests/providers/amazon/aws/hooks/test_s3.py
@@ -635,10 +635,10 @@ async def test_s3_prefix_sensor_hook_check_for_prefix_async(
     @async_mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_s3_bucket_key")
     @async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook.get_head_object_async")
     @async_mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.async_conn")
-    async def test_s3__check_key_without_wild_card_async(
+    async def test__check_key_async_without_wildcard_match(
         self, mock_client, mock_head_object, mock_get_bucket_key
     ):
-        """Test _check_key function"""
+        """Test _check_key_async function without using wildcard_match"""
         mock_get_bucket_key.return_value = "test_bucket", "test.txt"
         mock_head_object.return_value = {"ContentLength": 0}
         s3_hook_async = S3Hook(client_type="S3", resource_type="S3")
@@ -651,10 +651,10 @@ async def test_s3__check_key_without_wild_card_async(
     @async_mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_s3_bucket_key")
     @async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook.get_head_object_async")
     @async_mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.async_conn")
-    async def test_s3__check_key_none_without_wild_card_async(
+    async def test_s3__check_key_async_without_wildcard_match_and_get_none(
         self, mock_client, mock_head_object, mock_get_bucket_key
     ):
-        """Test _check_key function when get head object returns none"""
+        """Test _check_key_async function when get head object returns none"""
         mock_get_bucket_key.return_value = "test_bucket", "test.txt"
         mock_head_object.return_value = None
         s3_hook_async = S3Hook(client_type="S3", resource_type="S3")
@@ -667,10 +667,10 @@ async def test_s3__check_key_none_without_wild_card_async(
     @async_mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_s3_bucket_key")
     @async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook.get_file_metadata_async")
     @async_mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.async_conn")
-    async def test_s3__check_key_with_wild_card_async(
+    async def test_s3__check_key_async_with_wildcard_match(
         self, mock_client, mock_get_file_metadata, mock_get_bucket_key
     ):
-        """Test _check_key function"""
+        """Test _check_key_async function"""
         mock_get_bucket_key.return_value = "test_bucket", "test"
         mock_get_file_metadata.return_value = [
             {
@@ -692,6 +692,41 @@ async def test_s3__check_key_with_wild_card_async(
         )
         assert response is False

+    @pytest.mark.parametrize(
+        "key, pattern, expected",
+        [
+            ("test.csv", r"[a-z]+\.csv", True),
+            ("test.txt", r"test/[a-z]+\.csv", False),
+            ("test/test.csv", r"test/[a-z]+\.csv", True),
+        ],
+    )
+    @pytest.mark.asyncio
+    @async_mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_s3_bucket_key")
+    @async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook.get_file_metadata_async")
+    @async_mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.async_conn")
+    async def test__check_key_async_with_use_regex(
+        self, mock_client, mock_get_file_metadata, mock_get_bucket_key, key, pattern, expected
+    ):
+        """Match AWS S3 key with regex expression"""
+        mock_get_bucket_key.return_value = "test_bucket", pattern
+        mock_get_file_metadata.return_value = [
+            {
+                "Key": key,
+                "ETag": "etag1",
+                "LastModified": datetime(2020, 8, 14, 17, 19, 34),
+                "Size": 0,
+            },
+        ]
+        s3_hook_async = S3Hook(client_type="S3", resource_type="S3")
+        response = await s3_hook_async._check_key_async(
+            client=mock_client.return_value,
+            bucket_val="test_bucket",
+            wildcard_match=False,
+            key=pattern,
+            use_regex=True,
+        )
+        assert response is expected
+
     @pytest.mark.asyncio
     @async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook.async_conn")
     @async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook._list_keys_async")
diff --git a/tests/providers/amazon/aws/sensors/test_s3.py b/tests/providers/amazon/aws/sensors/test_s3.py
index b326a2d0894d1..2fa2e458a9840 100644
--- a/tests/providers/amazon/aws/sensors/test_s3.py
+++ b/tests/providers/amazon/aws/sensors/test_s3.py
@@ -232,6 +232,25 @@ def check_fn(files: list) -> bool:
         mock_head_object.return_value = {"ContentLength": 1}
         assert op.poke(None) is True

+    @pytest.mark.parametrize(
+        "key, pattern, expected",
+        [
+            ("test.csv", r"[a-z]+\.csv", True),
+            ("test.txt", r"test/[a-z]+\.csv", False),
+            ("test/test.csv", r"test/[a-z]+\.csv", True),
+        ],
+    )
+    @mock.patch("airflow.providers.amazon.aws.sensors.s3.S3Hook.get_file_metadata")
+    def test_poke_with_use_regex(self, mock_get_file_metadata, key, pattern, expected):
+        op = S3KeySensor(
+            task_id="s3_key_sensor_async",
+            bucket_key=pattern,
+            bucket_name="test_bucket",
+            use_regex=True,
+        )
+        mock_get_file_metadata.return_value = [{"Key": key, "Size": 0}]
+        assert op.poke(None) is expected
+
     @mock.patch("airflow.providers.amazon.aws.sensors.s3.S3KeySensor.poke", return_value=False)
     def test_s3_key_sensor_execute_complete_success_with_keys(self, mock_poke):
         """
diff --git a/tests/providers/amazon/aws/triggers/test_s3.py b/tests/providers/amazon/aws/triggers/test_s3.py
index a73223f15ce89..01533d298875b 100644
--- a/tests/providers/amazon/aws/triggers/test_s3.py
+++ b/tests/providers/amazon/aws/triggers/test_s3.py
@@ -45,6 +45,7 @@ def test_serialization(self):
             "hook_params": {},
             "poke_interval": 5.0,
             "should_check_fn": False,
+            "use_regex": False,
         }

     @pytest.mark.asyncio
diff --git a/tests/system/providers/amazon/aws/example_s3.py b/tests/system/providers/amazon/aws/example_s3.py
index 87806a0bf0a5a..3b4a4bd38bf6f 100644
--- a/tests/system/providers/amazon/aws/example_s3.py
+++ b/tests/system/providers/amazon/aws/example_s3.py
@@ -71,6 +71,8 @@
     key = f"{env_id}-key"
     key_2 = f"{env_id}-key2"

+    key_regex_pattern = ".*-key"
+
     # [START howto_sensor_s3_key_function_definition]
     def check_fn(files: list) -> bool:
         """
@@ -191,7 +193,7 @@ def check_fn(files: list) -> bool:
     )
     # [END howto_sensor_s3_key_multiple_keys_deferrable]

-    # [START howto_sensor_s3_key_function]
+    # [START howto_sensor_s3_key_function_deferrable]
     # Check if a file exists and match a certain pattern defined in check_fn
     sensor_key_with_function_deferrable = S3KeySensor(
         task_id="sensor_key_with_function_deferrable",
@@ -200,7 +202,18 @@ def check_fn(files: list) -> bool:
         check_fn=check_fn,
         deferrable=True,
     )
-    # [END howto_sensor_s3_key_function]
+    # [END howto_sensor_s3_key_function_deferrable]
+
+    # [START howto_sensor_s3_key_regex_deferrable]
+    # Check if a file exists and match a certain regular expression pattern
+    sensor_key_with_regex_deferrable = S3KeySensor(
+        task_id="sensor_key_with_regex_deferrable",
+        bucket_name=bucket_name,
+        bucket_key=key_regex_pattern,
+        use_regex=True,
+        deferrable=True,
+    )
+    # [END howto_sensor_s3_key_regex_deferrable]

     # [START howto_sensor_s3_key_function]
     # Check if a file exists and match a certain pattern defined in check_fn
@@ -212,6 +225,13 @@ def check_fn(files: list) -> bool:
     )
     # [END howto_sensor_s3_key_function]

+    # [START howto_sensor_s3_key_regex]
+    # Check if a file exists and match a certain regular expression pattern
+    sensor_key_with_regex = S3KeySensor(
+        task_id="sensor_key_with_regex", bucket_name=bucket_name, bucket_key=key_regex_pattern, use_regex=True
+    )
+    # [END howto_sensor_s3_key_regex]
+
     # [START howto_operator_s3_copy_object]
     copy_object = S3CopyObjectOperator(
         task_id="copy_object",
@@ -286,8 +306,13 @@ def check_fn(files: list) -> bool:
         create_object_2,
         list_prefixes,
         list_keys,
-        [sensor_one_key, sensor_two_keys, sensor_key_with_function],
-        [sensor_one_key_deferrable, sensor_two_keys_deferrable, sensor_key_with_function_deferrable],
+        [sensor_one_key, sensor_two_keys, sensor_key_with_function, sensor_key_with_regex],
+        [
+            sensor_one_key_deferrable,
+            sensor_two_keys_deferrable,
+            sensor_key_with_function_deferrable,
+            sensor_key_with_regex_deferrable,
+        ],
         copy_object,
         file_transform,
         branching,
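
For quick reference, a minimal usage sketch of the new ``use_regex`` flag, modelled on the ``sensor_key_with_regex`` task this change adds to ``example_s3.py``. The DAG id, bucket name, and key pattern below are illustrative placeholders, not part of the diff.

from __future__ import annotations

from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.sensors.s3 import S3KeySensor

with DAG(dag_id="example_s3_key_regex_usage", start_date=datetime(2024, 1, 1), schedule=None):
    # With use_regex=True the bucket_key is treated as a regular expression and
    # matched (via re.match) against the keys listed in the bucket, instead of
    # being read as a literal key or a Unix wildcard pattern.
    wait_for_csv_report = S3KeySensor(
        task_id="wait_for_csv_report",
        bucket_name="my-example-bucket",  # placeholder bucket
        bucket_key=r"reports/[a-z]+\.csv",
        use_regex=True,
    )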