Skip to content

Commit

Permalink
iterate through blobs before checking prefixes (#36202)
Browse files Browse the repository at this point in the history
* fix(providers/google): iterate through blobs before checking prefixes

According to https://github.com/googleapis/python-storage/blob/v2.14.0/google/cloud/storage/client.py#L1213-L1217, the iterator's prefixes are not populated until the blobs have been consumed, so the blobs must be iterated before `blobs.prefixes` is checked.

* test(providers/google): add test cases to check gcs.list result
  • Loading branch information
Lee-W authored Dec 14, 2023
1 parent 1431535 commit e83a986
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 13 deletions.
18 changes: 10 additions & 8 deletions airflow/providers/google/cloud/hooks/gcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -821,12 +821,13 @@ def _list(
delimiter=delimiter,
versions=versions,
)
list(blobs)

blob_names = [blob.name for blob in blobs]

if blobs.prefixes:
ids.extend(blobs.prefixes)
else:
ids.extend(blob.name for blob in blobs)
ids.extend(blob_names)

page_token = blobs.next_page_token
if page_token is None:
Expand Down Expand Up @@ -933,16 +934,17 @@ def list_by_timespan(
delimiter=delimiter,
versions=versions,
)
list(blobs)

blob_names = [
blob.name
for blob in blobs
if timespan_start <= blob.updated.replace(tzinfo=timezone.utc) < timespan_end
]

if blobs.prefixes:
ids.extend(blobs.prefixes)
else:
ids.extend(
blob.name
for blob in blobs
if timespan_start <= blob.updated.replace(tzinfo=timezone.utc) < timespan_end
)
ids.extend(blob_names)

page_token = blobs.next_page_token
if page_token is None:
Expand Down
42 changes: 37 additions & 5 deletions tests/providers/google/cloud/hooks/test_gcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import logging
import os
import re
from collections import namedtuple
from datetime import datetime, timedelta
from io import BytesIO
from unittest import mock
Expand Down Expand Up @@ -799,14 +800,26 @@ def test_provide_file_upload(self, mock_upload, mock_temp_file):
)

@pytest.mark.parametrize(
"prefix, result",
"prefix, blob_names, returned_prefixes, call_args, result",
(
(
"prefix",
["prefix"],
None,
[mock.call(delimiter=",", prefix="prefix", versions=None, max_results=None, page_token=None)],
["prefix"],
),
(
"prefix",
["prefix"],
{"prefix,"},
[mock.call(delimiter=",", prefix="prefix", versions=None, max_results=None, page_token=None)],
["prefix,"],
),
(
["prefix", "prefix_2"],
["prefix", "prefix2"],
None,
[
mock.call(
delimiter=",", prefix="prefix", versions=None, max_results=None, page_token=None
Expand All @@ -815,19 +828,38 @@ def test_provide_file_upload(self, mock_upload, mock_temp_file):
delimiter=",", prefix="prefix_2", versions=None, max_results=None, page_token=None
),
],
["prefix", "prefix2"],
),
),
)
@mock.patch(GCS_STRING.format("GCSHook.get_conn"))
def test_list__delimiter(self, mock_service, prefix, result):
mock_service.return_value.bucket.return_value.list_blobs.return_value.next_page_token = None
def test_list__delimiter(self, mock_service, prefix, blob_names, returned_prefixes, call_args, result):
Blob = namedtuple("Blob", ["name"])

class BlobsIterator:
    """Stand-in for the object returned by the mocked ``list_blobs``.

    Yields one ``Blob`` per entry of ``blob_names``. The ``prefixes`` and
    ``next_page_token`` attributes are assigned only at the moment the
    underlying items run out, so they do not exist on the instance until
    it has been iterated to exhaustion.
    """

    def __init__(self):
        # Lazy, single-use stream of Blob tuples.
        self._item_iter = map(Blob, blob_names)

    def __iter__(self):
        return self

    def __next__(self):
        exhausted = object()
        blob = next(self._item_iter, exhausted)
        if blob is not exhausted:
            return blob
        # Items are used up: only now expose the page metadata.
        self.prefixes = returned_prefixes
        self.next_page_token = None
        raise StopIteration

mock_service.return_value.bucket.return_value.list_blobs.return_value = BlobsIterator()
with pytest.deprecated_call():
self.gcs_hook.list(
blobs = self.gcs_hook.list(
bucket_name="test_bucket",
prefix=prefix,
delimiter=",",
)
assert mock_service.return_value.bucket.return_value.list_blobs.call_args_list == result
assert mock_service.return_value.bucket.return_value.list_blobs.call_args_list == call_args
assert blobs == result

@mock.patch(GCS_STRING.format("GCSHook.get_conn"))
@mock.patch("airflow.providers.google.cloud.hooks.gcs.functools")
Expand Down

0 comments on commit e83a986

Please sign in to comment.