add use_regex argument for allowing S3KeySensor to check s3 keys with regular expression #36578

Merged 5 commits on Jan 10, 2024
22 changes: 18 additions & 4 deletions airflow/providers/amazon/aws/hooks/s3.py
@@ -462,15 +462,17 @@ async def list_prefixes_async(
return prefixes

@provide_bucket_name_async
async def get_file_metadata_async(self, client: AioBaseClient, bucket_name: str, key: str) -> list[Any]:
async def get_file_metadata_async(
self, client: AioBaseClient, bucket_name: str, key: str | None = None
) -> list[Any]:
"""
Get a list of files in a bucket whose keys match a wildcard expression, asynchronously.

:param client: aiobotocore client
:param bucket_name: the name of the bucket
:param key: the path to the key
"""
prefix = re.split(r"[\[*?]", key, 1)[0]
prefix = re.split(r"[\[\*\?]", key, 1)[0] if key else ""
delimiter = ""
paginator = client.get_paginator("list_objects_v2")
response = paginator.paginate(Bucket=bucket_name, Prefix=prefix, Delimiter=delimiter)
@@ -486,6 +488,7 @@ async def _check_key_async(
bucket_val: str,
wildcard_match: bool,
key: str,
use_regex: bool = False,
) -> bool:
"""
Get a list of files matching a wildcard expression, or get the head object for the key.
@@ -498,13 +501,19 @@ async def _check_key_async(
:param bucket_val: the name of the bucket
:param key: S3 key that will point to the file
:param wildcard_match: whether the key should be interpreted as a Unix wildcard pattern
:param use_regex: whether to match the key as a regular expression
"""
bucket_name, key = self.get_s3_bucket_key(bucket_val, key, "bucket_name", "bucket_key")
if wildcard_match:
keys = await self.get_file_metadata_async(client, bucket_name, key)
key_matches = [k for k in keys if fnmatch.fnmatch(k["Key"], key)]
if not key_matches:
return False
elif use_regex:
keys = await self.get_file_metadata_async(client, bucket_name)
key_matches = [k for k in keys if re.match(pattern=key, string=k["Key"])]
if not key_matches:
return False
else:
obj = await self.get_head_object_async(client, key, bucket_name)
if obj is None:
@@ -518,6 +527,7 @@ async def check_key_async(
bucket: str,
bucket_keys: str | list[str],
wildcard_match: bool,
use_regex: bool = False,
) -> bool:
"""
Get a list of files matching a wildcard expression, or get the head object for each key.
@@ -530,14 +540,18 @@
:param bucket: the name of the bucket
:param bucket_keys: S3 keys that will point to the file
:param wildcard_match: whether the keys should be interpreted as Unix wildcard patterns
:param use_regex: whether to match the keys as regular expressions
"""
if isinstance(bucket_keys, list):
return all(
await asyncio.gather(
*(self._check_key_async(client, bucket, wildcard_match, key) for key in bucket_keys)
*(
self._check_key_async(client, bucket, wildcard_match, key, use_regex)
for key in bucket_keys
)
)
)
return await self._check_key_async(client, bucket, wildcard_match, bucket_keys)
return await self._check_key_async(client, bucket, wildcard_match, bucket_keys, use_regex)

async def check_for_prefix_async(
self, client: AioBaseClient, prefix: str, delimiter: str, bucket_name: str | None = None
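
For orientation (not part of the diff): the existing wildcard branch filters keys with fnmatch, while the new use_regex branch filters them with re.match, which anchors at the start of the key but not the end. A minimal standalone sketch of the two matching styles, using a made-up key list and patterns:

import fnmatch
import re

# Hypothetical object keys, standing in for a list_objects_v2 result.
keys = ["data/2024/run_1.csv", "data/2024/run_2.txt", "logs/run_1.csv"]

# Wildcard branch: Unix-style globbing over the whole key (fnmatch lets "*" cross "/").
wildcard_matches = [k for k in keys if fnmatch.fnmatch(k, "data/*/run_*.csv")]
# -> ['data/2024/run_1.csv']

# use_regex branch: re.match anchors at the beginning of the key only,
# so the pattern itself decides how much of the key must match.
regex_matches = [k for k in keys if re.match(r"data/\d{4}/run_\d+\.csv", k)]
# -> ['data/2024/run_1.csv']
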
21 changes: 13 additions & 8 deletions airflow/providers/amazon/aws/sensors/s3.py
@@ -65,7 +65,6 @@ class S3KeySensor(BaseSensorOperator):
def check_fn(files: List) -> bool:
return any(f.get('Size', 0) > 1048576 for f in files)
:param aws_conn_id: a reference to the s3 connection
:param deferrable: Run operator in the deferrable mode
:param verify: Whether to verify SSL certificates for S3 connection.
By default, SSL certificates are verified.
You can provide the following values:
@@ -76,6 +75,8 @@ def check_fn(files: List) -> bool:
- ``path/to/cert/bundle.pem``: A filename of the CA cert bundle to use.
You can specify this argument if you want to use a different
CA cert bundle than the one used by botocore.
:param deferrable: Run operator in the deferrable mode
:param use_regex: whether to use regular expressions to match the bucket key(s)
"""

template_fields: Sequence[str] = ("bucket_key", "bucket_name")
@@ -90,6 +91,7 @@ def __init__(
aws_conn_id: str = "aws_default",
verify: str | bool | None = None,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
use_regex: bool = False,
**kwargs,
):
super().__init__(**kwargs)
@@ -100,6 +102,7 @@ def __init__(
self.aws_conn_id = aws_conn_id
self.verify = verify
self.deferrable = deferrable
self.use_regex = use_regex

def _check_key(self, key):
bucket_name, key = S3Hook.get_s3_bucket_key(self.bucket_name, key, "bucket_name", "bucket_key")
@@ -121,6 +124,11 @@ def _check_key(self, key):

# Reduce the set of metadata to size only
files = [{"Size": f["Size"]} for f in key_matches]
elif self.use_regex:
keys = self.hook.get_file_metadata("", bucket_name)
key_matches = [k for k in keys if re.match(pattern=key, string=k["Key"])]
if not key_matches:
return False
else:
obj = self.hook.head_object(key, bucket_name)
if obj is None:
@@ -158,29 +166,26 @@ def _defer(self) -> None:
verify=self.verify,
poke_interval=self.poke_interval,
should_check_fn=bool(self.check_fn),
use_regex=self.use_regex,
),
method_name="execute_complete",
)

def execute_complete(self, context: Context, event: dict[str, Any]) -> bool | None:
def execute_complete(self, context: Context, event: dict[str, Any]) -> None:
"""
Execute when the trigger fires - returns immediately.

Relies on trigger to throw an exception, otherwise it assumes execution was successful.
"""
if event["status"] == "running":
found_keys = self.check_fn(event["files"]) # type: ignore[misc]
if found_keys:
return None
else:
if not found_keys:
self._defer()

if event["status"] == "error":
elif event["status"] == "error":
# TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
if self.soft_fail:
raise AirflowSkipException(event["message"])
raise AirflowException(event["message"])
return None

@deprecated(reason="use `hook` property instead.", category=AirflowProviderDeprecationWarning)
def get_hook(self) -> S3Hook:
13 changes: 9 additions & 4 deletions airflow/providers/amazon/aws/triggers/s3.py
@@ -39,6 +39,7 @@ class S3KeyTrigger(BaseTrigger):
:param wildcard_match: whether the bucket_key should be interpreted as a
Unix wildcard pattern
:param aws_conn_id: reference to the s3 connection
:param use_regex: whether to match the bucket key as a regular expression
:param hook_params: optional params for the hook
"""

@@ -50,6 +51,7 @@ def __init__(
aws_conn_id: str = "aws_default",
poke_interval: float = 5.0,
should_check_fn: bool = False,
use_regex: bool = False,
**hook_params: Any,
):
super().__init__()
@@ -60,6 +62,7 @@ def __init__(
self.hook_params = hook_params
self.poke_interval = poke_interval
self.should_check_fn = should_check_fn
self.use_regex = use_regex

def serialize(self) -> tuple[str, dict[str, Any]]:
"""Serialize S3KeyTrigger arguments and classpath."""
@@ -73,6 +76,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
"hook_params": self.hook_params,
"poke_interval": self.poke_interval,
"should_check_fn": self.should_check_fn,
"use_regex": self.use_regex,
},
)

@@ -86,18 +90,19 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
async with self.hook.async_conn as client:
while True:
if await self.hook.check_key_async(
client, self.bucket_name, self.bucket_key, self.wildcard_match
client, self.bucket_name, self.bucket_key, self.wildcard_match, self.use_regex
):
if self.should_check_fn:
s3_objects = await self.hook.get_files_async(
client, self.bucket_name, self.bucket_key, self.wildcard_match
)
await asyncio.sleep(self.poke_interval)
yield TriggerEvent({"status": "running", "files": s3_objects})
else:
yield TriggerEvent({"status": "success"})
await asyncio.sleep(self.poke_interval)

yield TriggerEvent({"status": "success"})

self.log.info("Sleeping for %s seconds", self.poke_interval)
await asyncio.sleep(self.poke_interval)
except Exception as e:
yield TriggerEvent({"status": "error", "message": str(e)})

16 changes: 16 additions & 0 deletions docs/apache-airflow-providers-amazon/operators/s3/s3.rst
@@ -222,6 +222,14 @@ To check multiple files:
:start-after: [START howto_sensor_s3_key_multiple_keys]
:end-before: [END howto_sensor_s3_key_multiple_keys]

To check a file with a regular expression:

.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_s3.py
:language: python
:dedent: 4
:start-after: [START howto_sensor_s3_key_regex]
:end-before: [END howto_sensor_s3_key_regex]
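
The included snippet lives in tests/system/providers/amazon/aws/example_s3.py and is not reproduced here. As an illustrative sketch only (bucket name, key pattern, and task id below are placeholders, not taken from that file), such a task could look like:

from airflow.providers.amazon.aws.sensors.s3 import S3KeySensor

# Hypothetical task: wait for any CSV object under the "incoming/" prefix,
# matching the key with a regular expression instead of a Unix wildcard.
wait_for_csv = S3KeySensor(
    task_id="wait_for_csv",
    bucket_name="my-example-bucket",
    bucket_key=r"incoming/[a-z0-9_]+\.csv",
    use_regex=True,
)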

To check with an additional custom check you can define a function which receives a list of matched S3 object
attributes and returns a boolean:

@@ -268,6 +276,14 @@ To check multiple files:
:start-after: [START howto_sensor_s3_key_multiple_keys_deferrable]
:end-before: [END howto_sensor_s3_key_multiple_keys_deferrable]

To check a file with a regular expression:

.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_s3.py
:language: python
:dedent: 4
:start-after: [START howto_sensor_s3_key_regex_deferrable]
:end-before: [END howto_sensor_s3_key_regex_deferrable]
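
Again a sketch rather than the actual example file content (same placeholder names as above): the deferrable variant only adds deferrable=True, after which the sensor defers to S3KeyTrigger, which forwards use_regex to S3Hook.check_key_async.

from airflow.providers.amazon.aws.sensors.s3 import S3KeySensor

# Hypothetical deferrable task: the regex check runs in the triggerer.
wait_for_csv_deferrable = S3KeySensor(
    task_id="wait_for_csv_deferrable",
    bucket_name="my-example-bucket",
    bucket_key=r"incoming/[a-z0-9_]+\.csv",
    use_regex=True,
    deferrable=True,
)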

.. _howto/sensor:S3KeysUnchangedSensor:

Wait on Amazon S3 prefix changes
47 changes: 41 additions & 6 deletions tests/providers/amazon/aws/hooks/test_s3.py
@@ -635,10 +635,10 @@ async def test_s3_prefix_sensor_hook_check_for_prefix_async(
@async_mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_s3_bucket_key")
@async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook.get_head_object_async")
@async_mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.async_conn")
async def test_s3__check_key_without_wild_card_async(
async def test__check_key_async_without_wildcard_match(
self, mock_client, mock_head_object, mock_get_bucket_key
):
"""Test _check_key function"""
"""Test _check_key_async function without using wildcard_match"""
mock_get_bucket_key.return_value = "test_bucket", "test.txt"
mock_head_object.return_value = {"ContentLength": 0}
s3_hook_async = S3Hook(client_type="S3", resource_type="S3")
@@ -651,10 +651,10 @@ async def test_s3__check_key_without_wild_card_async(
@async_mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_s3_bucket_key")
@async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook.get_head_object_async")
@async_mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.async_conn")
async def test_s3__check_key_none_without_wild_card_async(
async def test_s3__check_key_async_without_wildcard_match_and_get_none(
self, mock_client, mock_head_object, mock_get_bucket_key
):
"""Test _check_key function when get head object returns none"""
"""Test _check_key_async function when get head object returns none"""
mock_get_bucket_key.return_value = "test_bucket", "test.txt"
mock_head_object.return_value = None
s3_hook_async = S3Hook(client_type="S3", resource_type="S3")
@@ -667,10 +667,10 @@ async def test_s3__check_key_none_without_wild_card_async(
@async_mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_s3_bucket_key")
@async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook.get_file_metadata_async")
@async_mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.async_conn")
async def test_s3__check_key_with_wild_card_async(
async def test_s3__check_key_async_with_wildcard_match(
self, mock_client, mock_get_file_metadata, mock_get_bucket_key
):
"""Test _check_key function"""
"""Test _check_key_async function"""
mock_get_bucket_key.return_value = "test_bucket", "test"
mock_get_file_metadata.return_value = [
{
@@ -692,6 +692,41 @@ async def test_s3__check_key_with_wild_card_async(
)
assert response is False

@pytest.mark.parametrize(
"key, pattern, expected",
[
("test.csv", r"[a-z]+\.csv", True),
("test.txt", r"test/[a-z]+\.csv", False),
("test/test.csv", r"test/[a-z]+\.csv", True),
],
)
@pytest.mark.asyncio
@async_mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_s3_bucket_key")
@async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook.get_file_metadata_async")
@async_mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.async_conn")
async def test__check_key_async_with_use_regex(
self, mock_client, mock_get_file_metadata, mock_get_bucket_key, key, pattern, expected
):
"""Match AWS S3 key with regex expression"""
mock_get_bucket_key.return_value = "test_bucket", pattern
mock_get_file_metadata.return_value = [
{
"Key": key,
"ETag": "etag1",
"LastModified": datetime(2020, 8, 14, 17, 19, 34),
"Size": 0,
},
]
s3_hook_async = S3Hook(client_type="S3", resource_type="S3")
response = await s3_hook_async._check_key_async(
client=mock_client.return_value,
bucket_val="test_bucket",
wildcard_match=False,
key=pattern,
use_regex=True,
)
assert response is expected

@pytest.mark.asyncio
@async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook.async_conn")
@async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook._list_keys_async")
19 changes: 19 additions & 0 deletions tests/providers/amazon/aws/sensors/test_s3.py
@@ -232,6 +232,25 @@ def check_fn(files: list) -> bool:
mock_head_object.return_value = {"ContentLength": 1}
assert op.poke(None) is True

@pytest.mark.parametrize(
"key, pattern, expected",
[
("test.csv", r"[a-z]+\.csv", True),
("test.txt", r"test/[a-z]+\.csv", False),
("test/test.csv", r"test/[a-z]+\.csv", True),
],
)
@mock.patch("airflow.providers.amazon.aws.sensors.s3.S3Hook.get_file_metadata")
def test_poke_with_use_regex(self, mock_get_file_metadata, key, pattern, expected):
op = S3KeySensor(
task_id="s3_key_sensor_async",
bucket_key=pattern,
bucket_name="test_bucket",
use_regex=True,
)
mock_get_file_metadata.return_value = [{"Key": key, "Size": 0}]
assert op.poke(None) is expected

@mock.patch("airflow.providers.amazon.aws.sensors.s3.S3KeySensor.poke", return_value=False)
def test_s3_key_sensor_execute_complete_success_with_keys(self, mock_poke):
"""
1 change: 1 addition & 0 deletions tests/providers/amazon/aws/triggers/test_s3.py
@@ -45,6 +45,7 @@ def test_serialization(self):
"hook_params": {},
"poke_interval": 5.0,
"should_check_fn": False,
"use_regex": False,
}

@pytest.mark.asyncio