Skip to content

Commit

Permalink
Bug fix GCSToS3Operator: avoid ValueError when replace=False with…
Browse files Browse the repository at this point in the history
… files already in S3 (#32322)
  • Loading branch information
Adaverse authored Jul 4, 2023
1 parent d162518 commit 575bf2f
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 17 deletions.
5 changes: 5 additions & 0 deletions airflow/providers/amazon/aws/transfers/gcs_to_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,11 @@ def execute(self, context: Context) -> list[str]:
# and only keep those files which are present in
# Google Cloud Storage and not in S3
bucket_name, prefix = S3Hook.parse_s3_url(self.dest_s3_key)
# if prefix is empty, do not add "/" at end since it would
# filter all the objects (return empty list) instead of empty
# prefix returning all the objects
if prefix:
prefix = prefix if prefix.endswith("/") else f"{prefix}/"
# look for the bucket and the prefix to avoid look into
# parent directories/keys
existing_files = s3_hook.list_keys(bucket_name, prefix=prefix)
Expand Down
69 changes: 52 additions & 17 deletions tests/providers/amazon/aws/transfers/test_gcs_to_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
S3_BUCKET = "s3://bucket/"
MOCK_FILES = ["TEST1.csv", "TEST2.csv", "TEST3.csv"]
S3_ACL_POLICY = "private-read"
deprecated_call_match = "Usage of 'delimiter' is deprecated, please use 'match_glob' instead"


def _create_test_bucket():
Expand All @@ -47,8 +48,6 @@ def _create_test_bucket():

@mock_s3
class TestGCSToS3Operator:

# Test0: match_glob
@mock.patch("airflow.providers.amazon.aws.transfers.gcs_to_s3.GCSHook")
def test_execute__match_glob(self, mock_hook):
mock_hook.return_value.list.return_value = MOCK_FILES
Expand All @@ -73,15 +72,14 @@ def test_execute__match_glob(self, mock_hook):
bucket_name=GCS_BUCKET, delimiter=None, match_glob=f"**/*{DELIMITER}", prefix=PREFIX
)

# Test1: incremental behaviour (just some files missing)
@mock.patch("airflow.providers.amazon.aws.transfers.gcs_to_s3.GCSHook")
def test_execute_incremental(self, mock_hook):
mock_hook.return_value.list.return_value = MOCK_FILES
with NamedTemporaryFile() as f:
gcs_provide_file = mock_hook.return_value.provide_file
gcs_provide_file.return_value.__enter__.return_value.name = f.name

with pytest.deprecated_call():
with pytest.deprecated_call(match=deprecated_call_match):
operator = GCSToS3Operator(
task_id=TASK_ID,
bucket=GCS_BUCKET,
Expand All @@ -100,15 +98,17 @@ def test_execute_incremental(self, mock_hook):
assert sorted(MOCK_FILES[1:]) == sorted(uploaded_files)
assert sorted(MOCK_FILES) == sorted(hook.list_keys("bucket", delimiter="/"))

# Test2: All the files are already in origin and destination without replace
@mock.patch("airflow.providers.amazon.aws.transfers.gcs_to_s3.GCSHook")
def test_execute_without_replace(self, mock_hook):
"""
Tests scenario where all the files are already in origin and destination without replace
"""
mock_hook.return_value.list.return_value = MOCK_FILES
with NamedTemporaryFile() as f:
gcs_provide_file = mock_hook.return_value.provide_file
gcs_provide_file.return_value.__enter__.return_value.name = f.name

with pytest.deprecated_call():
with pytest.deprecated_call(match=deprecated_call_match):
operator = GCSToS3Operator(
task_id=TASK_ID,
bucket=GCS_BUCKET,
Expand All @@ -128,15 +128,53 @@ def test_execute_without_replace(self, mock_hook):
assert [] == uploaded_files
assert sorted(MOCK_FILES) == sorted(hook.list_keys("bucket", delimiter="/"))

# Test3: There are no files in destination bucket
@pytest.mark.parametrize(
argnames="dest_s3_url",
argvalues=[f"{S3_BUCKET}/test/", f"{S3_BUCKET}/test"],
)
@mock.patch("airflow.providers.amazon.aws.transfers.gcs_to_s3.GCSHook")
def test_execute_without_replace_with_folder_structure(self, mock_hook, dest_s3_url):
mock_files_gcs = [f"test{idx}/{mock_file}" for idx, mock_file in enumerate(MOCK_FILES)]
mock_files_s3 = [f"test/test{idx}/{mock_file}" for idx, mock_file in enumerate(MOCK_FILES)]
mock_hook.return_value.list.return_value = mock_files_gcs

hook, bucket = _create_test_bucket()
for mock_file_s3 in mock_files_s3:
bucket.put_object(Key=mock_file_s3, Body=b"testing")

with NamedTemporaryFile() as f:
gcs_provide_file = mock_hook.return_value.provide_file
gcs_provide_file.return_value.__enter__.return_value.name = f.name

with pytest.deprecated_call(match=deprecated_call_match):
operator = GCSToS3Operator(
task_id=TASK_ID,
bucket=GCS_BUCKET,
prefix=PREFIX,
delimiter=DELIMITER,
dest_aws_conn_id="aws_default",
dest_s3_key=dest_s3_url,
replace=False,
)

# we expect nothing to be uploaded
# and all the MOCK_FILES to be present at the S3 bucket
uploaded_files = operator.execute(None)

assert [] == uploaded_files
assert sorted(mock_files_s3) == sorted(hook.list_keys("bucket", prefix="test/"))

@mock.patch("airflow.providers.amazon.aws.transfers.gcs_to_s3.GCSHook")
def test_execute(self, mock_hook):
"""
Tests the scenario where there are no files in destination bucket
"""
mock_hook.return_value.list.return_value = MOCK_FILES
with NamedTemporaryFile() as f:
gcs_provide_file = mock_hook.return_value.provide_file
gcs_provide_file.return_value.__enter__.return_value.name = f.name

with pytest.deprecated_call():
with pytest.deprecated_call(match=deprecated_call_match):
operator = GCSToS3Operator(
task_id=TASK_ID,
bucket=GCS_BUCKET,
Expand All @@ -154,15 +192,14 @@ def test_execute(self, mock_hook):
assert sorted(MOCK_FILES) == sorted(uploaded_files)
assert sorted(MOCK_FILES) == sorted(hook.list_keys("bucket", delimiter="/"))

# Test4: Destination and Origin are in sync but replace all files in destination
@mock.patch("airflow.providers.amazon.aws.transfers.gcs_to_s3.GCSHook")
def test_execute_with_replace(self, mock_hook):
mock_hook.return_value.list.return_value = MOCK_FILES
with NamedTemporaryFile() as f:
gcs_provide_file = mock_hook.return_value.provide_file
gcs_provide_file.return_value.__enter__.return_value.name = f.name

with pytest.deprecated_call():
with pytest.deprecated_call(match=deprecated_call_match):
operator = GCSToS3Operator(
task_id=TASK_ID,
bucket=GCS_BUCKET,
Expand All @@ -182,15 +219,14 @@ def test_execute_with_replace(self, mock_hook):
assert sorted(MOCK_FILES) == sorted(uploaded_files)
assert sorted(MOCK_FILES) == sorted(hook.list_keys("bucket", delimiter="/"))

# Test5: Incremental sync with replace
@mock.patch("airflow.providers.amazon.aws.transfers.gcs_to_s3.GCSHook")
def test_execute_incremental_with_replace(self, mock_hook):
mock_hook.return_value.list.return_value = MOCK_FILES
with NamedTemporaryFile() as f:
gcs_provide_file = mock_hook.return_value.provide_file
gcs_provide_file.return_value.__enter__.return_value.name = f.name

with pytest.deprecated_call():
with pytest.deprecated_call(match=deprecated_call_match):
operator = GCSToS3Operator(
task_id=TASK_ID,
bucket=GCS_BUCKET,
Expand Down Expand Up @@ -218,7 +254,7 @@ def test_execute_should_handle_with_default_dest_s3_extra_args(self, s3_mock_hoo
s3_mock_hook.return_value = mock.Mock()
s3_mock_hook.parse_s3_url.return_value = mock.Mock()

with pytest.deprecated_call():
with pytest.deprecated_call(match=deprecated_call_match):
operator = GCSToS3Operator(
task_id=TASK_ID,
bucket=GCS_BUCKET,
Expand All @@ -241,7 +277,7 @@ def test_execute_should_pass_dest_s3_extra_args_to_s3_hook(self, s3_mock_hook, m
s3_mock_hook.return_value = mock.Mock()
s3_mock_hook.parse_s3_url.return_value = mock.Mock()

with pytest.deprecated_call():
with pytest.deprecated_call(match=deprecated_call_match):
operator = GCSToS3Operator(
task_id=TASK_ID,
bucket=GCS_BUCKET,
Expand All @@ -259,7 +295,6 @@ def test_execute_should_pass_dest_s3_extra_args_to_s3_hook(self, s3_mock_hook, m
aws_conn_id="aws_default", extra_args={"ContentLanguage": "value"}, verify=None
)

# Test6: s3_acl_policy parameter is set
@mock.patch("airflow.providers.amazon.aws.transfers.gcs_to_s3.GCSHook")
@mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.load_file")
def test_execute_with_s3_acl_policy(self, mock_load_file, mock_gcs_hook):
Expand All @@ -268,7 +303,7 @@ def test_execute_with_s3_acl_policy(self, mock_load_file, mock_gcs_hook):
gcs_provide_file = mock_gcs_hook.return_value.provide_file
gcs_provide_file.return_value.__enter__.return_value.name = f.name

with pytest.deprecated_call():
with pytest.deprecated_call(match=deprecated_call_match):
operator = GCSToS3Operator(
task_id=TASK_ID,
bucket=GCS_BUCKET,
Expand All @@ -293,7 +328,7 @@ def test_execute_without_keep_director_structure(self, mock_hook):
gcs_provide_file = mock_hook.return_value.provide_file
gcs_provide_file.return_value.__enter__.return_value.name = f.name

with pytest.deprecated_call():
with pytest.deprecated_call(match=deprecated_call_match):
operator = GCSToS3Operator(
task_id=TASK_ID,
bucket=GCS_BUCKET,
Expand Down

0 comments on commit 575bf2f

Please sign in to comment.