Skip to content

Commit

Permalink
Only force loading if delimiter is specified
Browse files Browse the repository at this point in the history
When delimiter is passed into list_blobs, prefixes is not populated, loading is forced with next(blobs, None).

Do not attempt to iterate again over blobs if blobs.prefixes is empty and delimiter has been passed in.

This prevents the "iterator has already started" error message.
  • Loading branch information
atrbgithub committed Dec 12, 2023
1 parent 67453a6 commit 0728cfa
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 4 deletions.
10 changes: 6 additions & 4 deletions airflow/providers/google/cloud/hooks/gcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -821,11 +821,12 @@ def _list(
delimiter=delimiter,
versions=versions,
)
list(blobs)
if delimiter:
next(blobs, None)

if blobs.prefixes:
ids.extend(blobs.prefixes)
else:
elif not delimiter or (match_glob and delimiter == "/"):
ids.extend(blob.name for blob in blobs)

page_token = blobs.next_page_token
Expand Down Expand Up @@ -933,11 +934,12 @@ def list_by_timespan(
delimiter=delimiter,
versions=versions,
)
list(blobs)
if delimiter:
next(blobs, None)

if blobs.prefixes:
ids.extend(blobs.prefixes)
else:
elif not delimiter or (match_glob and delimiter == "/"):
ids.extend(
blob.name
for blob in blobs
Expand Down
5 changes: 5 additions & 0 deletions tests/providers/google/cloud/hooks/test_gcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -821,13 +821,18 @@ def test_provide_file_upload(self, mock_upload, mock_temp_file):
@mock.patch(GCS_STRING.format("GCSHook.get_conn"))
def test_list__delimiter(self, mock_service, prefix, result):
mock_service.return_value.bucket.return_value.list_blobs.return_value.next_page_token = None
mock_service.return_value.bucket.return_value.list_blobs.return_value.prefixes = None
with pytest.deprecated_call():
self.gcs_hook.list(
bucket_name="test_bucket",
prefix=prefix,
delimiter=",",
)
assert mock_service.return_value.bucket.return_value.list_blobs.call_args_list == result
assert (
mock_service.return_value.bucket.return_value.list_blobs.return_value.__next__.call_count
== len(result)
)

@mock.patch(GCS_STRING.format("GCSHook.get_conn"))
@mock.patch("airflow.providers.google.cloud.hooks.gcs.functools")
Expand Down

0 comments on commit 0728cfa

Please sign in to comment.