From 0728cfa5c26a279b4a393af7fc749df4b1c65431 Mon Sep 17 00:00:00 2001 From: atrbgithub <14765982+atrbgithub@users.noreply.github.com> Date: Fri, 8 Dec 2023 18:56:23 +0000 Subject: [PATCH] Only force loading if delimiter is specified When a delimiter is passed into list_blobs, blobs.prefixes is not populated until the iterator is advanced, so loading is forced with next(blobs, None). Do not attempt to iterate over blobs again if blobs.prefixes is empty and a delimiter has been passed in. This prevents the "iterator has already started" error. --- airflow/providers/google/cloud/hooks/gcs.py | 10 ++++++---- tests/providers/google/cloud/hooks/test_gcs.py | 5 +++++ 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/airflow/providers/google/cloud/hooks/gcs.py b/airflow/providers/google/cloud/hooks/gcs.py index 45a202124d942..9ce4cb288cf10 100644 --- a/airflow/providers/google/cloud/hooks/gcs.py +++ b/airflow/providers/google/cloud/hooks/gcs.py @@ -821,11 +821,12 @@ def _list( delimiter=delimiter, versions=versions, ) - list(blobs) + if delimiter: + next(blobs, None) if blobs.prefixes: ids.extend(blobs.prefixes) - else: + elif not delimiter or (match_glob and delimiter == "/"): ids.extend(blob.name for blob in blobs) page_token = blobs.next_page_token @@ -933,11 +934,12 @@ def list_by_timespan( delimiter=delimiter, versions=versions, ) - list(blobs) + if delimiter: + next(blobs, None) if blobs.prefixes: ids.extend(blobs.prefixes) - else: + elif not delimiter or (match_glob and delimiter == "/"): ids.extend( blob.name for blob in blobs diff --git a/tests/providers/google/cloud/hooks/test_gcs.py b/tests/providers/google/cloud/hooks/test_gcs.py index 33df98e37b002..bd70dd7aa0d90 100644 --- a/tests/providers/google/cloud/hooks/test_gcs.py +++ b/tests/providers/google/cloud/hooks/test_gcs.py @@ -821,6 +821,7 @@ def test_provide_file_upload(self, mock_upload, mock_temp_file): @mock.patch(GCS_STRING.format("GCSHook.get_conn")) def test_list__delimiter(self, mock_service, prefix, result): 
mock_service.return_value.bucket.return_value.list_blobs.return_value.next_page_token = None + mock_service.return_value.bucket.return_value.list_blobs.return_value.prefixes = None with pytest.deprecated_call(): self.gcs_hook.list( bucket_name="test_bucket", @@ -828,6 +829,10 @@ def test_list__delimiter(self, mock_service, prefix, result): delimiter=",", ) assert mock_service.return_value.bucket.return_value.list_blobs.call_args_list == result + assert ( + mock_service.return_value.bucket.return_value.list_blobs.return_value.__next__.call_count + == len(result) + ) @mock.patch(GCS_STRING.format("GCSHook.get_conn")) @mock.patch("airflow.providers.google.cloud.hooks.gcs.functools")