Skip to content

Commit

Permalink
Consolidate import and usage of itertools (#33479)
Browse files Browse the repository at this point in the history
  • Loading branch information
eumiro authored Aug 21, 2023
1 parent 1cdd823 commit 95a930b
Show file tree
Hide file tree
Showing 24 changed files with 66 additions and 70 deletions.
6 changes: 3 additions & 3 deletions airflow/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import datetime
import functools
import io
import itertools as it
import itertools
import json
import logging
import multiprocessing
Expand Down Expand Up @@ -473,7 +473,7 @@ def get_sections_including_defaults(self) -> list[str]:
:return: list of section names
"""
return list(dict.fromkeys(it.chain(self.configuration_description, self.sections())))
return list(dict.fromkeys(itertools.chain(self.configuration_description, self.sections())))

def get_options_including_defaults(self, section: str) -> list[str]:
"""
Expand All @@ -485,7 +485,7 @@ def get_options_including_defaults(self, section: str) -> list[str]:
"""
my_own_options = self.options(section) if self.has_section(section) else []
all_options_from_defaults = self.configuration_description.get(section, {}).get("options", {})
return list(dict.fromkeys(it.chain(all_options_from_defaults, my_own_options)))
return list(dict.fromkeys(itertools.chain(all_options_from_defaults, my_own_options)))

def optionxform(self, optionstr: str) -> str:
"""
Expand Down
4 changes: 2 additions & 2 deletions airflow/decorators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
from __future__ import annotations

import inspect
import itertools
import warnings
from functools import cached_property
from itertools import chain
from textwrap import dedent
from typing import (
Any,
Expand Down Expand Up @@ -226,7 +226,7 @@ def __init__(
def execute(self, context: Context):
# todo make this more generic (move to prepare_lineage) so it deals with non taskflow operators
# as well
for arg in chain(self.op_args, self.op_kwargs.values()):
for arg in itertools.chain(self.op_args, self.op_kwargs.values()):
if isinstance(arg, Dataset):
self.inlets.append(arg)
return_value = super().execute(context)
Expand Down
3 changes: 1 addition & 2 deletions airflow/lineage/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
"""Provides lineage support functions."""
from __future__ import annotations

import itertools
import logging
from functools import wraps
from typing import TYPE_CHECKING, Any, Callable, TypeVar, cast
Expand Down Expand Up @@ -142,7 +141,7 @@ def wrapper(self, context, *args, **kwargs):
_inlets = self.xcom_pull(
context, task_ids=task_ids, dag_id=self.dag_id, key=PIPELINE_OUTLETS, session=session
)
self.inlets.extend(itertools.chain.from_iterable(_inlets))
self.inlets.extend(i for it in _inlets for i in it)

elif self.inlets:
raise AttributeError("inlets is not a list, operator, string or attr annotated object")
Expand Down
4 changes: 2 additions & 2 deletions airflow/providers/amazon/aws/hooks/batch_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
"""
from __future__ import annotations

import itertools as it
import itertools
from random import uniform
from time import sleep
from typing import Callable
Expand Down Expand Up @@ -488,7 +488,7 @@ def get_job_all_awslogs_info(self, job_id: str) -> list[dict[str, str]]:

# cross stream names with options (i.e. attempts X nodes) to generate all log infos
result = []
for stream, option in it.product(stream_names, log_options):
for stream, option in itertools.product(stream_names, log_options):
result.append(
{
"awslogs_stream_name": stream,
Expand Down
4 changes: 2 additions & 2 deletions airflow/providers/amazon/aws/triggers/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from __future__ import annotations

import asyncio
import itertools as it
import itertools
from functools import cached_property
from typing import Any

Expand Down Expand Up @@ -162,7 +162,7 @@ async def run(self):
"""
async with self.hook.async_conn as client:
waiter = self.hook.get_waiter("batch_job_complete", deferrable=True, client=client)
for attempt in it.count(1):
for attempt in itertools.count(1):
try:
await waiter.wait(
jobs=[self.job_id],
Expand Down
4 changes: 2 additions & 2 deletions airflow/providers/cncf/kubernetes/utils/pod_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from __future__ import annotations

import enum
import itertools as it
import itertools
import json
import logging
import math
Expand Down Expand Up @@ -628,7 +628,7 @@ def read_pod(self, pod: V1Pod) -> V1Pod:

def await_xcom_sidecar_container_start(self, pod: V1Pod) -> None:
self.log.info("Checking if xcom sidecar container is started.")
for attempt in it.count():
for attempt in itertools.count():
if self.container_is_running(pod, PodDefaults.SIDECAR_CONTAINER_NAME):
self.log.info("The xcom sidecar container is started.")
break
Expand Down
6 changes: 3 additions & 3 deletions airflow/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@
from __future__ import annotations

import copy
import itertools
import re
import signal
import warnings
from datetime import datetime
from functools import reduce
from itertools import filterfalse, tee
from typing import TYPE_CHECKING, Any, Callable, Generator, Iterable, Mapping, MutableMapping, TypeVar, cast

from lazy_object_proxy import Proxy
Expand Down Expand Up @@ -216,8 +216,8 @@ def merge_dicts(dict1: dict, dict2: dict) -> dict:

def partition(pred: Callable[[T], bool], iterable: Iterable[T]) -> tuple[Iterable[T], Iterable[T]]:
"""Use a predicate to partition entries into false entries and true entries."""
iter_1, iter_2 = tee(iterable)
return filterfalse(pred, iter_1), filter(pred, iter_2)
iter_1, iter_2 = itertools.tee(iterable)
return itertools.filterfalse(pred, iter_1), filter(pred, iter_2)


def chain(*args, **kwargs):
Expand Down
6 changes: 3 additions & 3 deletions airflow/www/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@

import functools
import gzip
import itertools
import json
import logging
from io import BytesIO as IO
from itertools import chain
from typing import Callable, TypeVar, cast

import pendulum
Expand Down Expand Up @@ -94,15 +94,15 @@ def wrapper(*args, **kwargs):
fields_skip_logging = {"csrf_token", "_csrf_token"}
extra_fields = [
(k, secrets_masker.redact(v, k))
for k, v in chain(request.values.items(multi=True), request.view_args.items())
for k, v in itertools.chain(request.values.items(multi=True), request.view_args.items())
if k not in fields_skip_logging
]
if event and event.startswith("variable."):
extra_fields = _mask_variable_fields(extra_fields)
if event and event.startswith("connection."):
extra_fields = _mask_connection_fields(extra_fields)

params = {k: v for k, v in chain(request.values.items(), request.view_args.items())}
params = {k: v for k, v in itertools.chain(request.values.items(), request.view_args.items())}

log = Log(
event=event or f.__name__,
Expand Down
4 changes: 2 additions & 2 deletions dev/check_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
# under the License.
from __future__ import annotations

import itertools
import os
import re
from itertools import product

import rich_click as click
from rich import print
Expand Down Expand Up @@ -141,7 +141,7 @@ def check_release(files: list[str], version: str):


def expand_name_variations(files):
return sorted(base + suffix for base, suffix in product(files, ["", ".asc", ".sha512"]))
return sorted(base + suffix for base, suffix in itertools.product(files, ["", ".asc", ".sha512"]))


def check_upgrade_check(files: list[str], version: str):
Expand Down
6 changes: 3 additions & 3 deletions docs/build_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@
from __future__ import annotations

import argparse
import itertools
import multiprocessing
import os
import sys
from collections import defaultdict
from itertools import filterfalse, tee
from typing import Callable, Iterable, NamedTuple, TypeVar

from rich.console import Console
Expand Down Expand Up @@ -74,8 +74,8 @@

def partition(pred: Callable[[T], bool], iterable: Iterable[T]) -> tuple[Iterable[T], Iterable[T]]:
"""Use a predicate to partition entries into false entries and true entries"""
iter_1, iter_2 = tee(iterable)
return filterfalse(pred, iter_1), filter(pred, iter_2)
iter_1, iter_2 = itertools.tee(iterable)
return itertools.filterfalse(pred, iter_1), filter(pred, iter_2)


def _promote_new_flags():
Expand Down
4 changes: 2 additions & 2 deletions docs/exts/docs_build/fetch_inventories.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@
import concurrent
import concurrent.futures
import datetime
import itertools
import os
import shutil
import sys
import traceback
from itertools import repeat
from tempfile import NamedTemporaryFile
from typing import Iterator

Expand Down Expand Up @@ -142,7 +142,7 @@ def fetch_inventories():
with requests.Session() as session, concurrent.futures.ThreadPoolExecutor(DEFAULT_POOLSIZE) as pool:
download_results: Iterator[tuple[str, bool]] = pool.map(
_fetch_file,
repeat(session, len(to_download)),
itertools.repeat(session, len(to_download)),
(pkg_name for pkg_name, _, _ in to_download),
(url for _, url, _ in to_download),
(path for _, _, path in to_download),
Expand Down
6 changes: 3 additions & 3 deletions docs/exts/docs_build/lint_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@
from __future__ import annotations

import ast
import itertools
import os
import re
from glob import glob
from itertools import chain
from typing import Iterable

from docs.exts.docs_build.docs_builder import ALL_PROVIDER_YAMLS
Expand Down Expand Up @@ -87,7 +87,7 @@ def check_guide_links_in_operator_descriptions() -> list[DocBuildError]:
operator_names=find_existing_guide_operator_names(
f"{DOCS_DIR}/apache-airflow/howto/operator/**/*.rst"
),
python_module_paths=chain(
python_module_paths=itertools.chain(
glob(f"{ROOT_PACKAGE_DIR}/operators/*.py"),
glob(f"{ROOT_PACKAGE_DIR}/sensors/*.py"),
),
Expand All @@ -101,7 +101,7 @@ def check_guide_links_in_operator_descriptions() -> list[DocBuildError]:
}

# Extract all potential python modules that can contain operators
python_module_paths = chain(
python_module_paths = itertools.chain(
glob(f"{provider['package-dir']}/**/operators/*.py", recursive=True),
glob(f"{provider['package-dir']}/**/sensors/*.py", recursive=True),
glob(f"{provider['package-dir']}/**/transfers/*.py", recursive=True),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def iter_check_deferrable_default_errors(module_filename: str) -> Iterator[str]:
args = node.args
arguments = reversed([*args.args, *args.kwonlyargs])
defaults = reversed([*args.defaults, *args.kw_defaults])
for argument, default in itertools.zip_longest(arguments, defaults, fillvalue=None):
for argument, default in zip(arguments, defaults):
if argument is None or default is None:
continue
if argument.arg != "deferrable" or _is_valid_deferrable_default(default):
Expand Down
3 changes: 1 addition & 2 deletions scripts/ci/pre_commit/pre_commit_sort_installed_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
# under the License.
from __future__ import annotations

import itertools
from pathlib import Path

if __name__ not in ("__main__", "__mp_main__"):
Expand All @@ -35,7 +34,7 @@ def stable_sort(x):


def sort_uniq(sequence):
return (x[0] for x in itertools.groupby(sorted(sequence, key=stable_sort)))
return sorted(set(sequence), key=stable_sort)


if __name__ == "__main__":
Expand Down
3 changes: 1 addition & 2 deletions scripts/ci/pre_commit/pre_commit_sort_spelling_wordlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
# under the License.
from __future__ import annotations

import itertools
from pathlib import Path

if __name__ not in ("__main__", "__mp_main__"):
Expand All @@ -35,7 +34,7 @@ def stable_sort(x):


def sort_uniq(sequence):
return (x[0] for x in itertools.groupby(sorted(sequence, key=stable_sort)))
return sorted(set(sequence), key=stable_sort)


if __name__ == "__main__":
Expand Down
14 changes: 7 additions & 7 deletions scripts/in_container/run_provider_yaml_files_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import importlib
import inspect
import itertools
import json
import os
import pathlib
Expand All @@ -27,7 +28,6 @@
import textwrap
from collections import Counter
from enum import Enum
from itertools import chain, product
from typing import Any, Iterable

import jsonschema
Expand Down Expand Up @@ -219,7 +219,7 @@ def check_if_objects_exist_and_belong_to_package(
def parse_module_data(provider_data, resource_type, yaml_file_path):
package_dir = ROOT_DIR.joinpath(yaml_file_path).parent
provider_package = pathlib.Path(yaml_file_path).parent.as_posix().replace("/", ".")
py_files = chain(
py_files = itertools.chain(
package_dir.glob(f"**/{resource_type}/*.py"),
package_dir.glob(f"{resource_type}/*.py"),
package_dir.glob(f"**/{resource_type}/**/*.py"),
Expand All @@ -233,7 +233,7 @@ def parse_module_data(provider_data, resource_type, yaml_file_path):
def check_correctness_of_list_of_sensors_operators_hook_modules(yaml_files: dict[str, dict]):
print("Checking completeness of list of {sensors, hooks, operators, triggers}")
print(" -- {sensors, hooks, operators, triggers} - Expected modules (left) : Current modules (right)")
for (yaml_file_path, provider_data), resource_type in product(
for (yaml_file_path, provider_data), resource_type in itertools.product(
yaml_files.items(), ["sensors", "operators", "hooks", "triggers"]
):
expected_modules, provider_package, resource_data = parse_module_data(
Expand All @@ -257,7 +257,7 @@ def check_correctness_of_list_of_sensors_operators_hook_modules(yaml_files: dict

def check_duplicates_in_integrations_names_of_hooks_sensors_operators(yaml_files: dict[str, dict]):
print("Checking for duplicates in list of {sensors, hooks, operators, triggers}")
for (yaml_file_path, provider_data), resource_type in product(
for (yaml_file_path, provider_data), resource_type in itertools.product(
yaml_files.items(), ["sensors", "operators", "hooks", "triggers"]
):
resource_data = provider_data.get(resource_type, [])
Expand Down Expand Up @@ -362,7 +362,7 @@ def check_invalid_integration(yaml_files: dict[str, dict]):
print("Detect unregistered integrations")
all_integration_names = set(get_all_integration_names(yaml_files))

for (yaml_file_path, provider_data), resource_type in product(
for (yaml_file_path, provider_data), resource_type in itertools.product(
yaml_files.items(), ["sensors", "operators", "hooks", "triggers"]
):
resource_data = provider_data.get(resource_type, [])
Expand All @@ -374,7 +374,7 @@ def check_invalid_integration(yaml_files: dict[str, dict]):
f"Invalid values: {invalid_names}"
)

for (yaml_file_path, provider_data), key in product(
for (yaml_file_path, provider_data), key in itertools.product(
yaml_files.items(), ["source-integration-name", "target-integration-name"]
):
resource_data = provider_data.get("transfers", [])
Expand Down Expand Up @@ -409,7 +409,7 @@ def check_doc_files(yaml_files: dict[str, dict]):
console.print("[yellow]Suspended providers:[/]")
console.print(suspended_providers)

expected_doc_files = chain(
expected_doc_files = itertools.chain(
DOCS_DIR.glob("apache-airflow-providers-*/operators/**/*.rst"),
DOCS_DIR.glob("apache-airflow-providers-*/transfer/**/*.rst"),
)
Expand Down
10 changes: 4 additions & 6 deletions tests/always/test_project_structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,12 +455,10 @@ class TestDockerProviderProjectStructure(ExampleCoverageTest):
class TestOperatorsHooks:
def test_no_illegal_suffixes(self):
illegal_suffixes = ["_operator.py", "_hook.py", "_sensor.py"]
files = itertools.chain(
*(
glob.glob(f"{ROOT_FOLDER}/{part}/providers/**/{resource_type}/*.py", recursive=True)
for resource_type in ["operators", "hooks", "sensors", "example_dags"]
for part in ["airflow", "tests"]
)
files = itertools.chain.from_iterable(
glob.glob(f"{ROOT_FOLDER}/{part}/providers/**/{resource_type}/*.py", recursive=True)
for resource_type in ["operators", "hooks", "sensors", "example_dags"]
for part in ["airflow", "tests"]
)

invalid_files = [f for f in files if f.endswith(tuple(illegal_suffixes))]
Expand Down
Loading

0 comments on commit 95a930b

Please sign in to comment.