From 95a930bc0a720c5548e4fa2e1f74e25f12e9ae1d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miroslav=20=C5=A0ediv=C3=BD?= <6774676+eumiro@users.noreply.github.com> Date: Mon, 21 Aug 2023 05:36:29 +0000 Subject: [PATCH] Consolidate import and usage of itertools (#33479) --- airflow/configuration.py | 6 +++--- airflow/decorators/base.py | 4 ++-- airflow/lineage/__init__.py | 3 +-- airflow/providers/amazon/aws/hooks/batch_client.py | 4 ++-- airflow/providers/amazon/aws/triggers/batch.py | 4 ++-- .../providers/cncf/kubernetes/utils/pod_manager.py | 4 ++-- airflow/utils/helpers.py | 6 +++--- airflow/www/decorators.py | 6 +++--- dev/check_files.py | 4 ++-- docs/build_docs.py | 6 +++--- docs/exts/docs_build/fetch_inventories.py | 4 ++-- docs/exts/docs_build/lint_checks.py | 6 +++--- .../pre_commit_check_deferrable_default.py | 2 +- .../pre_commit_sort_installed_providers.py | 3 +-- .../pre_commit_sort_spelling_wordlist.py | 3 +-- .../in_container/run_provider_yaml_files_check.py | 14 +++++++------- tests/always/test_project_structure.py | 10 ++++------ tests/models/test_taskmixin.py | 10 +++++++--- .../amazon/aws/hooks/test_batch_waiters.py | 8 +++----- .../apache/hive/transfers/test_s3_to_hive.py | 6 +++--- .../providers/apache/spark/hooks/test_spark_sql.py | 4 ++-- .../test_cloud_storage_transfer_service.py | 9 ++++----- tests/system/conftest.py | 4 ++-- tests/utils/test_helpers.py | 6 +++--- 24 files changed, 66 insertions(+), 70 deletions(-) diff --git a/airflow/configuration.py b/airflow/configuration.py index 8f6f703ed6a4..3c9c6975eabf 100644 --- a/airflow/configuration.py +++ b/airflow/configuration.py @@ -19,7 +19,7 @@ import datetime import functools import io -import itertools as it +import itertools import json import logging import multiprocessing @@ -473,7 +473,7 @@ def get_sections_including_defaults(self) -> list[str]: :return: list of section names """ - return list(dict.fromkeys(it.chain(self.configuration_description, self.sections()))) + return list(dict.fromkeys(itertools.chain(self.configuration_description, self.sections()))) def get_options_including_defaults(self, section: str) -> list[str]: """ @@ -485,7 +485,7 @@ def get_options_including_defaults(self, section: str) -> list[str]: """ my_own_options = self.options(section) if self.has_section(section) else [] all_options_from_defaults = self.configuration_description.get(section, {}).get("options", {}) - return list(dict.fromkeys(it.chain(all_options_from_defaults, my_own_options))) + return list(dict.fromkeys(itertools.chain(all_options_from_defaults, my_own_options))) def optionxform(self, optionstr: str) -> str: """ diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py index c4a4b0ed6143..750e1fa1e760 100644 --- a/airflow/decorators/base.py +++ b/airflow/decorators/base.py @@ -17,9 +17,9 @@ from __future__ import annotations import inspect +import itertools import warnings from functools import cached_property -from itertools import chain from textwrap import dedent from typing import ( Any, @@ -226,7 +226,7 @@ def __init__( def execute(self, context: Context): # todo make this more generic (move to prepare_lineage) so it deals with non taskflow operators # as well - for arg in chain(self.op_args, self.op_kwargs.values()): + for arg in itertools.chain(self.op_args, self.op_kwargs.values()): if isinstance(arg, Dataset): self.inlets.append(arg) return_value = super().execute(context) diff --git a/airflow/lineage/__init__.py b/airflow/lineage/__init__.py index a2fcdf4ed5cd..4843da12fcb7 100644 --- a/airflow/lineage/__init__.py +++ b/airflow/lineage/__init__.py @@ -18,7 +18,6 @@ """Provides lineage support functions.""" from __future__ import annotations -import itertools import logging from functools import wraps from typing import TYPE_CHECKING, Any, Callable, TypeVar, cast @@ -142,7 +141,7 @@ def wrapper(self, context, *args, **kwargs): _inlets = self.xcom_pull( context, task_ids=task_ids, dag_id=self.dag_id, key=PIPELINE_OUTLETS, session=session ) - self.inlets.extend(itertools.chain.from_iterable(_inlets)) + self.inlets.extend(i for it in _inlets for i in it) elif self.inlets: raise AttributeError("inlets is not a list, operator, string or attr annotated object") diff --git a/airflow/providers/amazon/aws/hooks/batch_client.py b/airflow/providers/amazon/aws/hooks/batch_client.py index 74dbef1eac3d..26304ed3673e 100644 --- a/airflow/providers/amazon/aws/hooks/batch_client.py +++ b/airflow/providers/amazon/aws/hooks/batch_client.py @@ -26,7 +26,7 @@ """ from __future__ import annotations -import itertools as it +import itertools from random import uniform from time import sleep from typing import Callable @@ -488,7 +488,7 @@ def get_job_all_awslogs_info(self, job_id: str) -> list[dict[str, str]]: # cross stream names with options (i.e. attempts X nodes) to generate all log infos result = [] - for stream, option in it.product(stream_names, log_options): + for stream, option in itertools.product(stream_names, log_options): result.append( { "awslogs_stream_name": stream, diff --git a/airflow/providers/amazon/aws/triggers/batch.py b/airflow/providers/amazon/aws/triggers/batch.py index 774ce3c4bfa1..900040afa231 100644 --- a/airflow/providers/amazon/aws/triggers/batch.py +++ b/airflow/providers/amazon/aws/triggers/batch.py @@ -17,7 +17,7 @@ from __future__ import annotations import asyncio -import itertools as it +import itertools from functools import cached_property from typing import Any @@ -162,7 +162,7 @@ async def run(self): """ async with self.hook.async_conn as client: waiter = self.hook.get_waiter("batch_job_complete", deferrable=True, client=client) - for attempt in it.count(1): + for attempt in itertools.count(1): try: await waiter.wait( jobs=[self.job_id], diff --git a/airflow/providers/cncf/kubernetes/utils/pod_manager.py b/airflow/providers/cncf/kubernetes/utils/pod_manager.py index 139befdbff6a..1c2e0ab59719 100644 --- a/airflow/providers/cncf/kubernetes/utils/pod_manager.py +++ b/airflow/providers/cncf/kubernetes/utils/pod_manager.py @@ -18,7 +18,7 @@ from __future__ import annotations import enum -import itertools as it +import itertools import json import logging import math @@ -628,7 +628,7 @@ def read_pod(self, pod: V1Pod) -> V1Pod: def await_xcom_sidecar_container_start(self, pod: V1Pod) -> None: self.log.info("Checking if xcom sidecar container is started.") - for attempt in it.count(): + for attempt in itertools.count(): if self.container_is_running(pod, PodDefaults.SIDECAR_CONTAINER_NAME): self.log.info("The xcom sidecar container is started.") break diff --git a/airflow/utils/helpers.py b/airflow/utils/helpers.py index e07608030d42..e55a8e00442e 100644 --- a/airflow/utils/helpers.py +++ b/airflow/utils/helpers.py @@ -18,12 +18,12 @@ from __future__ import annotations import copy +import itertools import re import signal import warnings from datetime import datetime from functools import reduce -from itertools import filterfalse, tee from typing import TYPE_CHECKING, Any, Callable, Generator, Iterable, Mapping, MutableMapping, TypeVar, cast from lazy_object_proxy import Proxy @@ -216,8 +216,8 @@ def merge_dicts(dict1: dict, dict2: dict) -> dict: def partition(pred: Callable[[T], bool], iterable: Iterable[T]) -> tuple[Iterable[T], Iterable[T]]: """Use a predicate to partition entries into false entries and true entries.""" - iter_1, iter_2 = tee(iterable) - return filterfalse(pred, iter_1), filter(pred, iter_2) + iter_1, iter_2 = itertools.tee(iterable) + return itertools.filterfalse(pred, iter_1), filter(pred, iter_2) def chain(*args, **kwargs): diff --git a/airflow/www/decorators.py b/airflow/www/decorators.py index 975910fe5077..94c1c34921ce 100644 --- a/airflow/www/decorators.py +++ b/airflow/www/decorators.py @@ -19,10 +19,10 @@ import functools import gzip +import itertools import json import logging from io import BytesIO as IO -from itertools import chain from typing import Callable, TypeVar, cast import pendulum @@ -94,7 +94,7 @@ def wrapper(*args, **kwargs): fields_skip_logging = {"csrf_token", "_csrf_token"} extra_fields = [ (k, secrets_masker.redact(v, k)) - for k, v in chain(request.values.items(multi=True), request.view_args.items()) + for k, v in itertools.chain(request.values.items(multi=True), request.view_args.items()) if k not in fields_skip_logging ] if event and event.startswith("variable."): @@ -102,7 +102,7 @@ def wrapper(*args, **kwargs): if event and event.startswith("connection."): extra_fields = _mask_connection_fields(extra_fields) - params = {k: v for k, v in chain(request.values.items(), request.view_args.items())} + params = {k: v for k, v in itertools.chain(request.values.items(), request.view_args.items())} log = Log( event=event or f.__name__, diff --git a/dev/check_files.py b/dev/check_files.py index 52260c1da293..f50875cc015c 100644 --- a/dev/check_files.py +++ b/dev/check_files.py @@ -16,9 +16,9 @@ # under the License. from __future__ import annotations +import itertools import os import re -from itertools import product import rich_click as click from rich import print @@ -141,7 +141,7 @@ def check_release(files: list[str], version: str): def expand_name_variations(files): - return sorted(base + suffix for base, suffix in product(files, ["", ".asc", ".sha512"])) + return sorted(base + suffix for base, suffix in itertools.product(files, ["", ".asc", ".sha512"])) def check_upgrade_check(files: list[str], version: str): diff --git a/docs/build_docs.py b/docs/build_docs.py index 84ddb4d22cac..417854c8d188 100755 --- a/docs/build_docs.py +++ b/docs/build_docs.py @@ -23,11 +23,11 @@ from __future__ import annotations import argparse +import itertools import multiprocessing import os import sys from collections import defaultdict -from itertools import filterfalse, tee from typing import Callable, Iterable, NamedTuple, TypeVar from rich.console import Console @@ -74,8 +74,8 @@ def partition(pred: Callable[[T], bool], iterable: Iterable[T]) -> tuple[Iterable[T], Iterable[T]]: """Use a predicate to partition entries into false entries and true entries""" - iter_1, iter_2 = tee(iterable) - return filterfalse(pred, iter_1), filter(pred, iter_2) + iter_1, iter_2 = itertools.tee(iterable) + return itertools.filterfalse(pred, iter_1), filter(pred, iter_2) def _promote_new_flags(): diff --git a/docs/exts/docs_build/fetch_inventories.py b/docs/exts/docs_build/fetch_inventories.py index 6db368f0d0b5..9576a82b32c4 100644 --- a/docs/exts/docs_build/fetch_inventories.py +++ b/docs/exts/docs_build/fetch_inventories.py @@ -19,11 +19,11 @@ import concurrent import concurrent.futures import datetime +import itertools import os import shutil import sys import traceback -from itertools import repeat from tempfile import NamedTemporaryFile from typing import Iterator @@ -142,7 +142,7 @@ def fetch_inventories(): with requests.Session() as session, concurrent.futures.ThreadPoolExecutor(DEFAULT_POOLSIZE) as pool: download_results: Iterator[tuple[str, bool]] = pool.map( _fetch_file, - repeat(session, len(to_download)), + itertools.repeat(session, len(to_download)), (pkg_name for pkg_name, _, _ in to_download), (url for _, url, _ in to_download), (path for _, _, path in to_download), diff --git a/docs/exts/docs_build/lint_checks.py b/docs/exts/docs_build/lint_checks.py index ef254e097625..e536feb68088 100644 --- a/docs/exts/docs_build/lint_checks.py +++ b/docs/exts/docs_build/lint_checks.py @@ -17,10 +17,10 @@ from __future__ import annotations import ast +import itertools import os import re from glob import glob -from itertools import chain from typing import Iterable from docs.exts.docs_build.docs_builder import ALL_PROVIDER_YAMLS @@ -87,7 +87,7 @@ def check_guide_links_in_operator_descriptions() -> list[DocBuildError]: operator_names=find_existing_guide_operator_names( f"{DOCS_DIR}/apache-airflow/howto/operator/**/*.rst" ), - python_module_paths=chain( + python_module_paths=itertools.chain( glob(f"{ROOT_PACKAGE_DIR}/operators/*.py"), glob(f"{ROOT_PACKAGE_DIR}/sensors/*.py"), ), @@ -101,7 +101,7 @@ def check_guide_links_in_operator_descriptions() -> list[DocBuildError]: } # Extract all potential python modules that can contain operators - python_module_paths = chain( + python_module_paths = itertools.chain( glob(f"{provider['package-dir']}/**/operators/*.py", recursive=True), glob(f"{provider['package-dir']}/**/sensors/*.py", recursive=True), glob(f"{provider['package-dir']}/**/transfers/*.py", recursive=True), diff --git a/scripts/ci/pre_commit/pre_commit_check_deferrable_default.py b/scripts/ci/pre_commit/pre_commit_check_deferrable_default.py index 784d25d5221e..8373385f0d7f 100755 --- a/scripts/ci/pre_commit/pre_commit_check_deferrable_default.py +++ b/scripts/ci/pre_commit/pre_commit_check_deferrable_default.py @@ -76,7 +76,7 @@ def iter_check_deferrable_default_errors(module_filename: str) -> Iterator[str]: args = node.args arguments = reversed([*args.args, *args.kwonlyargs]) defaults = reversed([*args.defaults, *args.kw_defaults]) - for argument, default in itertools.zip_longest(arguments, defaults, fillvalue=None): + for argument, default in zip(arguments, defaults): if argument is None or default is None: continue if argument.arg != "deferrable" or _is_valid_deferrable_default(default): diff --git a/scripts/ci/pre_commit/pre_commit_sort_installed_providers.py b/scripts/ci/pre_commit/pre_commit_sort_installed_providers.py index e2bd0e2921cb..7ab17c5dd80e 100755 --- a/scripts/ci/pre_commit/pre_commit_sort_installed_providers.py +++ b/scripts/ci/pre_commit/pre_commit_sort_installed_providers.py @@ -17,7 +17,6 @@ # under the License. from __future__ import annotations -import itertools from pathlib import Path if __name__ not in ("__main__", "__mp_main__"): @@ -35,7 +34,7 @@ def stable_sort(x): def sort_uniq(sequence): - return (x[0] for x in itertools.groupby(sorted(sequence, key=stable_sort))) + return sorted(set(sequence), key=stable_sort) if __name__ == "__main__": diff --git a/scripts/ci/pre_commit/pre_commit_sort_spelling_wordlist.py b/scripts/ci/pre_commit/pre_commit_sort_spelling_wordlist.py index f9eb8b4a06ed..41d7a3ce428d 100755 --- a/scripts/ci/pre_commit/pre_commit_sort_spelling_wordlist.py +++ b/scripts/ci/pre_commit/pre_commit_sort_spelling_wordlist.py @@ -17,7 +17,6 @@ # under the License. from __future__ import annotations -import itertools from pathlib import Path if __name__ not in ("__main__", "__mp_main__"): @@ -35,7 +34,7 @@ def stable_sort(x): def sort_uniq(sequence): - return (x[0] for x in itertools.groupby(sorted(sequence, key=stable_sort))) + return sorted(set(sequence), key=stable_sort) if __name__ == "__main__": diff --git a/scripts/in_container/run_provider_yaml_files_check.py b/scripts/in_container/run_provider_yaml_files_check.py index ae523eb4babc..3b5b3211352c 100755 --- a/scripts/in_container/run_provider_yaml_files_check.py +++ b/scripts/in_container/run_provider_yaml_files_check.py @@ -19,6 +19,7 @@ import importlib import inspect +import itertools import json import os import pathlib @@ -27,7 +28,6 @@ import textwrap from collections import Counter from enum import Enum -from itertools import chain, product from typing import Any, Iterable import jsonschema @@ -219,7 +219,7 @@ def check_if_objects_exist_and_belong_to_package( def parse_module_data(provider_data, resource_type, yaml_file_path): package_dir = ROOT_DIR.joinpath(yaml_file_path).parent provider_package = pathlib.Path(yaml_file_path).parent.as_posix().replace("/", ".") - py_files = chain( + py_files = itertools.chain( package_dir.glob(f"**/{resource_type}/*.py"), package_dir.glob(f"{resource_type}/*.py"), package_dir.glob(f"**/{resource_type}/**/*.py"), @@ -233,7 +233,7 @@ def parse_module_data(provider_data, resource_type, yaml_file_path): def check_correctness_of_list_of_sensors_operators_hook_modules(yaml_files: dict[str, dict]): print("Checking completeness of list of {sensors, hooks, operators, triggers}") print(" -- {sensors, hooks, operators, triggers} - Expected modules (left) : Current modules (right)") - for (yaml_file_path, provider_data), resource_type in product( + for (yaml_file_path, provider_data), resource_type in itertools.product( yaml_files.items(), ["sensors", "operators", "hooks", "triggers"] ): expected_modules, provider_package, resource_data = parse_module_data( @@ -257,7 +257,7 @@ def check_correctness_of_list_of_sensors_operators_hook_modules(yaml_files: dict def check_duplicates_in_integrations_names_of_hooks_sensors_operators(yaml_files: dict[str, dict]): print("Checking for duplicates in list of {sensors, hooks, operators, triggers}") - for (yaml_file_path, provider_data), resource_type in product( + for (yaml_file_path, provider_data), resource_type in itertools.product( yaml_files.items(), ["sensors", "operators", "hooks", "triggers"] ): resource_data = provider_data.get(resource_type, []) @@ -362,7 +362,7 @@ def check_invalid_integration(yaml_files: dict[str, dict]): print("Detect unregistered integrations") all_integration_names = set(get_all_integration_names(yaml_files)) - for (yaml_file_path, provider_data), resource_type in product( + for (yaml_file_path, provider_data), resource_type in itertools.product( yaml_files.items(), ["sensors", "operators", "hooks", "triggers"] ): resource_data = provider_data.get(resource_type, []) @@ -374,7 +374,7 @@ def check_invalid_integration(yaml_files: dict[str, dict]): f"Invalid values: {invalid_names}" ) - for (yaml_file_path, provider_data), key in product( + for (yaml_file_path, provider_data), key in itertools.product( yaml_files.items(), ["source-integration-name", "target-integration-name"] ): resource_data = provider_data.get("transfers", []) @@ -409,7 +409,7 @@ def check_doc_files(yaml_files: dict[str, dict]): console.print("[yellow]Suspended providers:[/]") console.print(suspended_providers) - expected_doc_files = chain( + expected_doc_files = itertools.chain( DOCS_DIR.glob("apache-airflow-providers-*/operators/**/*.rst"), DOCS_DIR.glob("apache-airflow-providers-*/transfer/**/*.rst"), ) diff --git a/tests/always/test_project_structure.py b/tests/always/test_project_structure.py index 116f8f99d2d5..a518c9f3d2be 100644 --- a/tests/always/test_project_structure.py +++ b/tests/always/test_project_structure.py @@ -455,12 +455,10 @@ class TestDockerProviderProjectStructure(ExampleCoverageTest): class TestOperatorsHooks: def test_no_illegal_suffixes(self): illegal_suffixes = ["_operator.py", "_hook.py", "_sensor.py"] - files = itertools.chain( - *( - glob.glob(f"{ROOT_FOLDER}/{part}/providers/**/{resource_type}/*.py", recursive=True) - for resource_type in ["operators", "hooks", "sensors", "example_dags"] - for part in ["airflow", "tests"] - ) + files = itertools.chain.from_iterable( + glob.glob(f"{ROOT_FOLDER}/{part}/providers/**/{resource_type}/*.py", recursive=True) + for resource_type in ["operators", "hooks", "sensors", "example_dags"] + for part in ["airflow", "tests"] ) invalid_files = [f for f in files if f.endswith(tuple(illegal_suffixes))] diff --git a/tests/models/test_taskmixin.py b/tests/models/test_taskmixin.py index 95aefd0faaa1..2435d6711a76 100644 --- a/tests/models/test_taskmixin.py +++ b/tests/models/test_taskmixin.py @@ -17,7 +17,7 @@ from __future__ import annotations -from itertools import product +import itertools import pytest @@ -67,7 +67,9 @@ def my_task(): return my_task.override(task_id=name)() -@pytest.mark.parametrize("setup_type, work_type, teardown_type", product(*3 * [["classic", "taskflow"]])) +@pytest.mark.parametrize( + "setup_type, work_type, teardown_type", itertools.product(["classic", "taskflow"], repeat=3) +) def test_as_teardown(dag_maker, setup_type, work_type, teardown_type): """ Check that as_teardown works properly as implemented in PlainXComArg @@ -98,7 +100,9 @@ def test_as_teardown(dag_maker, setup_type, work_type, teardown_type): assert get_task_attr(t1, "upstream_task_ids") == {"w1", "s1"} -@pytest.mark.parametrize("setup_type, work_type, teardown_type", product(*3 * [["classic", "taskflow"]])) +@pytest.mark.parametrize( + "setup_type, work_type, teardown_type", itertools.product(["classic", "taskflow"], repeat=3) +) def test_as_teardown_oneline(dag_maker, setup_type, work_type, teardown_type): """ Check that as_teardown implementations work properly. Tests all combinations of taskflow and classic. diff --git a/tests/providers/amazon/aws/hooks/test_batch_waiters.py b/tests/providers/amazon/aws/hooks/test_batch_waiters.py index cdf581417c4f..206ce6885709 100644 --- a/tests/providers/amazon/aws/hooks/test_batch_waiters.py +++ b/tests/providers/amazon/aws/hooks/test_batch_waiters.py @@ -274,11 +274,9 @@ def test_job_running_waiter_change_to_waited_state(self, status): self.mock_describe_jobs.side_effect = [ # Emulate change job status before one of expected states. # SUBMITTED -> PENDING -> RUNNABLE -> STARTING - *itertools.chain( - *[ - itertools.repeat(self.describe_jobs_response(job_id=job_id, status=inter_status), 3) - for inter_status in INTERMEDIATE_STATES - ] + *itertools.chain.from_iterable( + itertools.repeat(self.describe_jobs_response(job_id=job_id, status=inter_status), 3) + for inter_status in INTERMEDIATE_STATES ), # Expected status self.describe_jobs_response(job_id=job_id, status=status), diff --git a/tests/providers/apache/hive/transfers/test_s3_to_hive.py b/tests/providers/apache/hive/transfers/test_s3_to_hive.py index c84a78828ea8..3f674ec3fa54 100644 --- a/tests/providers/apache/hive/transfers/test_s3_to_hive.py +++ b/tests/providers/apache/hive/transfers/test_s3_to_hive.py @@ -20,10 +20,10 @@ import bz2 import errno import filecmp +import itertools import logging import shutil from gzip import GzipFile -from itertools import product from tempfile import NamedTemporaryFile, mkdtemp from unittest import mock @@ -204,7 +204,7 @@ def test_execute(self, mock_hiveclihook): ) # Testing txt, zip, bz2 files with and without header row - for (ext, has_header) in product([".txt", ".gz", ".bz2", ".GZ"], [True, False]): + for ext, has_header in itertools.product([".txt", ".gz", ".bz2", ".GZ"], [True, False]): self.kwargs["headers"] = has_header self.kwargs["check_headers"] = has_header logging.info("Testing %s format %s header", ext, "with" if has_header else "without") @@ -242,7 +242,7 @@ def test_execute_with_select_expression(self, mock_hiveclihook): # Only testing S3ToHiveTransfer calls S3Hook.select_key with # the right parameters and its execute method succeeds here, # since Moto doesn't support select_object_content as of 1.3.2. - for (ext, has_header) in product([".txt", ".gz", ".GZ"], [True, False]): + for ext, has_header in itertools.product([".txt", ".gz", ".GZ"], [True, False]): input_compressed = ext.lower() != ".txt" key = self.s3_key + ext diff --git a/tests/providers/apache/spark/hooks/test_spark_sql.py b/tests/providers/apache/spark/hooks/test_spark_sql.py index 1666c51946e6..9bd46e6ce3ae 100644 --- a/tests/providers/apache/spark/hooks/test_spark_sql.py +++ b/tests/providers/apache/spark/hooks/test_spark_sql.py @@ -18,7 +18,7 @@ from __future__ import annotations import io -from itertools import dropwhile +import itertools from unittest.mock import call, patch import pytest @@ -32,7 +32,7 @@ def get_after(sentinel, iterable): """Get the value after `sentinel` in an `iterable`""" - truncated = dropwhile(lambda el: el != sentinel, iterable) + truncated = itertools.dropwhile(lambda el: el != sentinel, iterable) next(truncated) return next(truncated) diff --git a/tests/providers/google/cloud/operators/test_cloud_storage_transfer_service.py b/tests/providers/google/cloud/operators/test_cloud_storage_transfer_service.py index 7684d8476cf8..3b6325151020 100644 --- a/tests/providers/google/cloud/operators/test_cloud_storage_transfer_service.py +++ b/tests/providers/google/cloud/operators/test_cloud_storage_transfer_service.py @@ -17,7 +17,6 @@ # under the License. from __future__ import annotations -import itertools from copy import deepcopy from datetime import date, time from unittest import mock @@ -220,10 +219,10 @@ def test_should_raise_exception_when_body_empty(self): @pytest.mark.parametrize( "transfer_spec", [ - dict(itertools.chain(SOURCE_AWS.items(), SOURCE_GCS.items(), SOURCE_HTTP.items())), - dict(itertools.chain(SOURCE_AWS.items(), SOURCE_GCS.items())), - dict(itertools.chain(SOURCE_AWS.items(), SOURCE_HTTP.items())), - dict(itertools.chain(SOURCE_GCS.items(), SOURCE_HTTP.items())), + {**SOURCE_AWS, **SOURCE_GCS, **SOURCE_HTTP}, + {**SOURCE_AWS, **SOURCE_GCS}, + {**SOURCE_AWS, **SOURCE_HTTP}, + {**SOURCE_GCS, **SOURCE_HTTP}, ], ) def test_verify_data_source(self, transfer_spec): diff --git a/tests/system/conftest.py b/tests/system/conftest.py index 154e7c208f81..58eca1287c32 100644 --- a/tests/system/conftest.py +++ b/tests/system/conftest.py @@ -16,9 +16,9 @@ # under the License. from __future__ import annotations +import itertools import os import re -from itertools import chain from pathlib import Path from unittest import mock @@ -41,7 +41,7 @@ def provider_env_vars(): @pytest.fixture(autouse=True) def skip_if_env_var_not_set(provider_env_vars): - for env in chain(REQUIRED_ENV_VARS, provider_env_vars): + for env in itertools.chain(REQUIRED_ENV_VARS, provider_env_vars): if env not in os.environ: pytest.skip(f"Missing required environment variable {env}") return diff --git a/tests/utils/test_helpers.py b/tests/utils/test_helpers.py index 9d6020874c86..c3c370060a66 100644 --- a/tests/utils/test_helpers.py +++ b/tests/utils/test_helpers.py @@ -17,8 +17,8 @@ # under the License. from __future__ import annotations +import itertools import re -from itertools import product import pytest @@ -264,7 +264,7 @@ def assert_exactly_one(true=0, truthy=0, false=0, falsy=0): expected = True if true + truthy == 1 else False assert exactly_one(*sample) is expected - for row in product(range(4), range(4), range(4), range(4)): + for row in itertools.product(range(4), repeat=4): assert_exactly_one(*row) def test_exactly_one_should_fail(self): @@ -295,7 +295,7 @@ def assert_at_most_one(true=0, truthy=0, false=0, falsy=0, notset=0): expected = True if true + truthy in (0, 1) else False assert at_most_one(*sample) is expected - for row in product(range(4), range(4), range(4), range(4), range(4)): + for row in itertools.product(range(4), repeat=4): print(row) assert_at_most_one(*row)