Deprecate S3PrefixSensor and S3KeySizeSensor in favor of S3KeySensor #22737

Merged · 11 commits · Apr 12, 2022
103 changes: 38 additions & 65 deletions airflow/providers/amazon/aws/sensors/s3.py
@@ -20,6 +20,7 @@
import os
import re
import sys
import warnings
from datetime import datetime
from typing import TYPE_CHECKING, Callable, List, Optional, Sequence, Set, Union
from urllib.parse import urlparse
@@ -40,15 +41,16 @@

class S3KeySensor(BaseSensorOperator):
"""
Waits for a key (a file-like instance on S3) to be present in a S3 bucket.
Waits for one or multiple keys (file-like instances on S3) to be present in an S3 bucket.
S3 being a key/value store, it does not support folders. The path is just a key,
a resource.

:param bucket_key: The key being waited on. Supports full s3:// style url
:param bucket_key: The key(s) being waited on. Supports full s3:// style url
or relative path from root level. When it's specified as a full s3://
url, please leave bucket_name as `None`.
url, please leave bucket_name as `None`.
:param bucket_name: Name of the S3 bucket. Only needed when ``bucket_key``
is not provided as a full s3:// url.
is not provided as a full s3:// url. When specified, all the keys passed to ``bucket_key``
refer to this bucket.
:param wildcard_match: whether the bucket_key should be interpreted as a
Unix wildcard pattern
:param aws_conn_id: a reference to the s3 connection
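As context for the new list-valued ``bucket_key``, here is a minimal usage sketch; the DAG id, bucket, and key names are illustrative, not taken from this PR:

```python
from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.sensors.s3 import S3KeySensor

with DAG(dag_id="s3_key_sensor_example", start_date=datetime(2022, 4, 1), schedule_interval=None):
    # poke() returns all(...), so the task only succeeds once every listed key exists.
    wait_for_parts = S3KeySensor(
        task_id="wait_for_parts",
        bucket_key=["s3://my-bucket/data/part-1.csv", "s3://my-bucket/data/part-2.csv"],
        # bucket_name is left as None because full s3:// URLs are given above
    )
```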
@@ -69,7 +71,7 @@ class S3KeySensor(BaseSensorOperator):
def __init__(
self,
*,
bucket_key: str,
bucket_key: Union[str, List[str]],
bucket_name: Optional[str] = None,
wildcard_match: bool = False,
aws_conn_id: str = 'aws_default',
@@ -78,27 +80,32 @@ def __init__(
):
super().__init__(**kwargs)
self.bucket_name = bucket_name
self.bucket_key = bucket_key
self.bucket_key = [bucket_key] if isinstance(bucket_key, str) else bucket_key
self.wildcard_match = wildcard_match
self.aws_conn_id = aws_conn_id
self.verify = verify
self.hook: Optional[S3Hook] = None

def _resolve_bucket_and_key(self):
def _resolve_bucket_and_key(self, key):
"""If key is URI, parse bucket"""
if self.bucket_name is None:
self.bucket_name, self.bucket_key = S3Hook.parse_s3_url(self.bucket_key)
return S3Hook.parse_s3_url(key)
else:
parsed_url = urlparse(self.bucket_key)
parsed_url = urlparse(key)
if parsed_url.scheme != '' or parsed_url.netloc != '':
raise AirflowException('If bucket_name provided, bucket_key must be relative path, not URI.')
return self.bucket_name, key

def poke(self, context: 'Context'):
self._resolve_bucket_and_key()
self.log.info('Poking for key : s3://%s/%s', self.bucket_name, self.bucket_key)
def _key_exists(self, key):
bucket_name, key = self._resolve_bucket_and_key(key)
self.log.info('Poking for key : s3://%s/%s', bucket_name, key)
if self.wildcard_match:
return self.get_hook().check_for_wildcard_key(self.bucket_key, self.bucket_name)
return self.get_hook().check_for_key(self.bucket_key, self.bucket_name)
return self.get_hook().check_for_wildcard_key(key, bucket_name)

return self.get_hook().check_for_key(key, bucket_name)

def poke(self, context: 'Context'):
return all(self._key_exists(key) for key in self.bucket_key)

def get_hook(self) -> S3Hook:
"""Create and return an S3Hook"""
@@ -166,13 +173,13 @@ def poke(self, context: 'Context'):

def get_files(self, s3_hook: S3Hook, delimiter: Optional[str] = '/') -> List:
"""Gets a list of files in the bucket"""
prefix = self.bucket_key
prefix = self.bucket_key[0]
config = {
'PageSize': None,
'MaxItems': None,
}
if self.wildcard_match:
prefix = re.split(r'[\[\*\?]', self.bucket_key, 1)[0]
prefix = re.split(r'[\[\*\?]', self.bucket_key[0], 1)[0]

paginator = s3_hook.get_conn().get_paginator('list_objects_v2')
response = paginator.paginate(
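The hunk above is cut off by the diff view; for orientation, a hedged sketch of what a ``list_objects_v2`` paginator call of this shape looks like in boto3 (bucket and prefix values are placeholders):

```python
import boto3

s3 = boto3.client("s3")
paginator = s3.get_paginator("list_objects_v2")
pages = paginator.paginate(
    Bucket="my-bucket",
    Prefix="data/",
    Delimiter="/",
    PaginationConfig={"PageSize": None, "MaxItems": None},  # None means no explicit limit
)
# Flatten the paginated responses into a single list of key names.
keys = [obj["Key"] for page in pages for obj in page.get("Contents", [])]
```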
@@ -332,66 +339,32 @@ def poke(self, context: 'Context'):
return self.is_keys_unchanged(set(self.hook.list_keys(self.bucket_name, prefix=self.prefix)))


class S3PrefixSensor(BaseSensorOperator):
class S3PrefixSensor(S3KeySensor):
"""
Waits for a prefix or all prefixes to exist. A prefix is the first part of a key,
thus enabling checking of constructs similar to glob ``airfl*`` or
SQL LIKE ``'airfl%'``. There is the possibility to precise a delimiter to
indicate the hierarchy or keys, meaning that the match will stop at that
delimiter. Current code accepts sane delimiters, i.e. characters that
are NOT special characters in the Python regex engine.

:param bucket_name: Name of the S3 bucket
:param prefix: The prefix being waited on. Relative path from bucket root level.
:param delimiter: The delimiter intended to show hierarchy.
Defaults to '/'.
:param aws_conn_id: a reference to the s3 connection
:param verify: Whether or not to verify SSL certificates for S3 connection.
By default SSL certificates are verified.
You can provide the following values:

- ``False``: do not validate SSL certificates. SSL will still be used
(unless use_ssl is False), but SSL certificates will not be
verified.
- ``path/to/cert/bundle.pem``: A filename of the CA cert bundle to uses.
You can specify this argument if you want to use a different
CA cert bundle than the one used by botocore.
This class is deprecated.
Please use `airflow.providers.amazon.aws.sensors.s3.S3KeySensor`.
"""

template_fields: Sequence[str] = ('prefix', 'bucket_name')

def __init__(
self,
*,
bucket_name: str,
prefix: Union[str, List[str]],
delimiter: str = '/',
aws_conn_id: str = 'aws_default',
verify: Optional[Union[str, bool]] = None,
**kwargs,
):
super().__init__(**kwargs)
# Parse
self.bucket_name = bucket_name
self.prefix = [prefix] if isinstance(prefix, str) else prefix
self.delimiter = delimiter
self.aws_conn_id = aws_conn_id
self.verify = verify
self.hook: Optional[S3Hook] = None

def poke(self, context: 'Context'):
self.log.info('Poking for prefix : %s in bucket s3://%s', self.prefix, self.bucket_name)
return all(self._check_for_prefix(prefix) for prefix in self.prefix)

def get_hook(self) -> S3Hook:
"""Create and return an S3Hook"""
if self.hook:
return self.hook
warnings.warn(
"""
S3PrefixSensor is deprecated.
Please use `airflow.providers.amazon.aws.sensors.s3.S3KeySensor`.
""",
DeprecationWarning,
stacklevel=2,
)

self.hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
return self.hook
self.prefix = prefix
prefixes = [self.prefix] if isinstance(self.prefix, str) else self.prefix
keys = [pref if pref.endswith(delimiter) else pref + delimiter for pref in prefixes]

def _check_for_prefix(self, prefix: str) -> bool:
return self.get_hook().check_for_prefix(
prefix=prefix, delimiter=self.delimiter, bucket_name=self.bucket_name
)
super().__init__(bucket_key=keys, **kwargs)
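For users migrating off the deprecated class, a before/after sketch that mirrors what the shim now does internally (append the delimiter, then delegate to ``S3KeySensor``); the bucket and prefixes are illustrative:

```python
from airflow.providers.amazon.aws.sensors.s3 import S3KeySensor

# Before (deprecated):
#   S3PrefixSensor(task_id="wait", bucket_name="my-bucket", prefix=["logs/2022", "logs/2023"])
# After, with the trailing delimiter made explicit:
wait = S3KeySensor(
    task_id="wait",
    bucket_name="my-bucket",
    bucket_key=["logs/2022/", "logs/2023/"],
)
```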
4 changes: 2 additions & 2 deletions airflow/sensors/s3_prefix_sensor.py
@@ -15,14 +15,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.sensors.s3_prefix`."""
"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.sensors.s3`."""

import warnings

from airflow.providers.amazon.aws.sensors.s3_prefix import S3PrefixSensor # noqa

warnings.warn(
"This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.s3_prefix`.",
"This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.s3`.",
DeprecationWarning,
stacklevel=2,
)
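This file follows the usual module-level deprecation shim pattern: warn once at import time, then re-export from the new location. A generic sketch of the pattern, with hypothetical module and class names:

```python
"""This module is deprecated. Please use :mod:`new_package.new_module`."""
import warnings

from new_package.new_module import SomeClass  # noqa: F401  (re-export for backwards compatibility)

warnings.warn(
    "This module is deprecated. Please use `new_package.new_module`.",
    DeprecationWarning,
    stacklevel=2,
)
```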
1 change: 0 additions & 1 deletion docs/apache-airflow-providers-amazon/operators/s3.rst
@@ -24,7 +24,6 @@ Airflow to Amazon Simple Storage Service (S3) integration provides several opera
- :class:`~airflow.providers.amazon.aws.sensors.s3.S3KeySensor`
- :class:`~airflow.providers.amazon.aws.sensors.s3.S3KeySizeSensor`
- :class:`~airflow.providers.amazon.aws.sensors.s3.S3KeysUnchangedSensor`
- :class:`~airflow.providers.amazon.aws.sensors.s3.S3PrefixSensor`
- :class:`~airflow.providers.amazon.aws.operators.s3.S3CreateBucketOperator`
- :class:`~airflow.providers.amazon.aws.operators.s3.S3DeleteBucketOperator`
- :class:`~airflow.providers.amazon.aws.operators.s3.S3DeleteBucketTaggingOperator`
98 changes: 78 additions & 20 deletions tests/providers/amazon/aws/sensors/test_s3_key.py
@@ -40,6 +40,18 @@ def test_bucket_name_none_and_bucket_key_as_relative_path(self):
with pytest.raises(AirflowException):
op.poke(None)

@mock.patch('airflow.providers.amazon.aws.sensors.s3.S3Hook.check_for_key')
def test_bucket_name_none_and_bucket_key_is_list_and_contain_relative_path(self, mock_check):
"""
Test if exception is raised when bucket_name is None
and bucket_key is provided with one of the two keys as relative path rather than s3:// url.
:return:
"""
mock_check.return_value = True
op = S3KeySensor(task_id='s3_key_sensor', bucket_key=["s3://test_bucket/file", "file_in_bucket"])
with pytest.raises(AirflowException):
op.poke(None)

def test_bucket_name_provided_and_bucket_key_is_s3_url(self):
"""
Test if exception is raised when bucket_name is provided
@@ -52,15 +64,31 @@ def test_bucket_name_provided_and_bucket_key_is_s3_url(self):
with pytest.raises(AirflowException):
op.poke(None)

@mock.patch('airflow.providers.amazon.aws.sensors.s3.S3Hook.check_for_key')
def test_bucket_name_provided_and_bucket_key_is_list_and_contains_s3_url(self, mock_check):
"""
Test if exception is raised when bucket_name is provided
while bucket_key contains a full s3:// url.
:return:
"""
mock_check.return_value = True
op = S3KeySensor(
task_id='s3_key_sensor',
bucket_key=["test_bucket", "s3://test_bucket/file"],
bucket_name='test_bucket',
)
with pytest.raises(AirflowException):
op.poke(None)

@parameterized.expand(
[
['s3://bucket/key', None, 'key', 'bucket'],
['key', 'bucket', 'key', 'bucket'],
]
)
@mock.patch('airflow.providers.amazon.aws.sensors.s3.S3Hook.check_for_key')
def test_parse_bucket_key(self, key, bucket, parsed_key, parsed_bucket, mock_check):
mock_check.return_value = False
def test_parse_bucket_key(self, key, bucket, parsed_key, parsed_bucket, mock_check_for_key):
mock_check_for_key.return_value = False

op = S3KeySensor(
task_id='s3_key_sensor',
@@ -70,12 +98,11 @@ def test_parse_bucket_key(self, key, bucket, parsed_key, parsed_bucket, mock_che

op.poke(None)

assert op.bucket_key == parsed_key
assert op.bucket_name == parsed_bucket
mock_check_for_key.assert_called_once_with(parsed_key, parsed_bucket)

@mock.patch('airflow.providers.amazon.aws.sensors.s3.S3Hook.check_for_key')
def test_parse_bucket_key_from_jinja(self, mock_check):
mock_check.return_value = False
def test_parse_bucket_key_from_jinja(self, mock_check_for_key):
mock_check_for_key.return_value = False

Variable.set("test_bucket_key", "s3://bucket/key")

@@ -96,45 +123,76 @@ def test_parse_bucket_key_from_jinja(self, mock_check):
ti.render_templates(context)
op.poke(None)

assert op.bucket_key == "key"
assert op.bucket_name == "bucket"
mock_check_for_key.assert_called_once_with("key", "bucket")

@mock.patch('airflow.providers.amazon.aws.sensors.s3.S3Hook.check_for_key')
def test_poke(self, mock_check):
def test_poke(self, mock_check_for_key):
op = S3KeySensor(task_id='s3_key_sensor', bucket_key='s3://test_bucket/file')

mock_check.return_value = False
mock_check_for_key.return_value = False
assert op.poke(None) is False
mock_check.assert_called_once_with(op.bucket_key, op.bucket_name)
mock_check_for_key.assert_called_once_with("file", "test_bucket")

mock_check.return_value = True
mock_check_for_key.return_value = True
assert op.poke(None) is True

@mock.patch('airflow.providers.amazon.aws.sensors.s3.S3Hook.check_for_key')
def test_poke_multiple_files(self, mock_check_for_key):
op = S3KeySensor(
task_id='s3_key_sensor', bucket_key=['s3://test_bucket/file1', 's3://test_bucket/file2']
)

mock_check_for_key.side_effect = [True, False]
assert op.poke(None) is False

mock_check_for_key.side_effect = [True, True]
assert op.poke(None) is True

mock_check_for_key.assert_any_call("file1", "test_bucket")
mock_check_for_key.assert_any_call("file2", "test_bucket")

@mock.patch('airflow.providers.amazon.aws.sensors.s3.S3Hook.check_for_wildcard_key')
def test_poke_wildcard(self, mock_check):
def test_poke_wildcard(self, mock_check_for_wildcard_key):
op = S3KeySensor(task_id='s3_key_sensor', bucket_key='s3://test_bucket/file', wildcard_match=True)

mock_check.return_value = False
mock_check_for_wildcard_key.return_value = False
assert op.poke(None) is False
mock_check.assert_called_once_with(op.bucket_key, op.bucket_name)
mock_check_for_wildcard_key.assert_called_once_with("file", "test_bucket")

mock_check.return_value = True
mock_check_for_wildcard_key.return_value = True
assert op.poke(None) is True

@mock.patch('airflow.providers.amazon.aws.sensors.s3.S3Hook.check_for_wildcard_key')
def test_poke_wildcard_multiple_files(self, mock_check_for_wildcard_key):
op = S3KeySensor(
task_id='s3_key_sensor',
bucket_key=['s3://test_bucket/file1', 's3://test_bucket/file2'],
wildcard_match=True,
)

mock_check_for_wildcard_key.side_effect = [True, False]
assert op.poke(None) is False

mock_check_for_wildcard_key.side_effect = [True, True]
assert op.poke(None) is True

mock_check_for_wildcard_key.assert_any_call("file1", "test_bucket")
mock_check_for_wildcard_key.assert_any_call("file2", "test_bucket")


class TestS3KeySizeSensor(unittest.TestCase):
@mock.patch('airflow.providers.amazon.aws.sensors.s3.S3Hook.check_for_key', return_value=False)
def test_poke_check_for_key_false(self, mock_check_for_key):
op = S3KeySizeSensor(task_id='s3_key_sensor', bucket_key='s3://test_bucket/file')
assert op.poke(None) is False
mock_check_for_key.assert_called_once_with(op.bucket_key, op.bucket_name)
mock_check_for_key.assert_called_once_with("file", "test_bucket")

@mock.patch('airflow.providers.amazon.aws.sensors.s3.S3KeySizeSensor.get_files', return_value=[])
@mock.patch('airflow.providers.amazon.aws.sensors.s3.S3Hook.check_for_key', return_value=True)
def test_poke_get_files_false(self, mock_check_for_key, mock_get_files):
op = S3KeySizeSensor(task_id='s3_key_sensor', bucket_key='s3://test_bucket/file')
assert op.poke(None) is False
mock_check_for_key.assert_called_once_with(op.bucket_key, op.bucket_name)
mock_check_for_key.assert_called_once_with("file", "test_bucket")
mock_get_files.assert_called_once_with(s3_hook=op.get_hook())

@parameterized.expand(
@@ -159,7 +217,7 @@ def test_poke(self, paginate_return_value, poke_return_value, mock_check, mock_g
mock_get_conn.return_value.get_paginator.return_value = mock_paginator
mock_paginator.paginate.return_value = [paginate_return_value]
assert op.poke(None) is poke_return_value
mock_check.assert_called_once_with(op.bucket_key, op.bucket_name)
mock_check.assert_called_once_with("file", "test_bucket")

@mock.patch('airflow.providers.amazon.aws.sensors.s3.S3KeySizeSensor.get_files', return_value=[])
@mock.patch('airflow.providers.amazon.aws.sensors.s3.S3Hook.check_for_wildcard_key')
@@ -168,4 +226,4 @@ def test_poke_wildcard(self, mock_check, mock_get_files):

mock_check.return_value = False
assert op.poke(None) is False
mock_check.assert_called_once_with(op.bucket_key, op.bucket_name)
mock_check.assert_called_once_with("file", "test_bucket")