From d2b766f8cc779fa1b22a6200c52782dad453e02b Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Mon, 30 May 2022 03:24:08 -0400 Subject: [PATCH] Ensure @contextmanager decorates generator func (#23103) (cherry picked from commit e58985598f202395098e15b686aec33645a906ff) --- airflow/cli/commands/task_command.py | 4 ++-- airflow/models/taskinstance.py | 3 +-- airflow/providers/google/cloud/hooks/gcs.py | 19 ++++++++++++++++--- .../cloud/utils/credentials_provider.py | 9 ++++++--- .../google/common/hooks/base_google.py | 10 +++++----- .../providers/microsoft/psrp/hooks/psrp.py | 4 ++-- airflow/utils/db.py | 11 ++++++++--- airflow/utils/process_utils.py | 4 ++-- airflow/utils/session.py | 4 ++-- .../src/airflow_breeze/utils/run_utils.py | 4 ++-- .../prepare_provider_packages.py | 4 ++-- 11 files changed, 48 insertions(+), 28 deletions(-) diff --git a/airflow/cli/commands/task_command.py b/airflow/cli/commands/task_command.py index ea20ebb64615e..2b743b91fe039 100644 --- a/airflow/cli/commands/task_command.py +++ b/airflow/cli/commands/task_command.py @@ -23,7 +23,7 @@ import os import textwrap from contextlib import contextmanager, redirect_stderr, redirect_stdout -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, Generator, List, Optional, Tuple, Union from pendulum.parsing.exceptions import ParserError from sqlalchemy.orm.exc import NoResultFound @@ -269,7 +269,7 @@ def _extract_external_executor_id(args) -> Optional[str]: @contextmanager -def _capture_task_logs(ti): +def _capture_task_logs(ti: TaskInstance) -> Generator[None, None, None]: """Manage logging context for a task run - Replace the root logger configuration with the airflow.task configuration diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 2885d56b54ff8..5cd582ce3e34d 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -40,7 +40,6 @@ Dict, Generator, Iterable, - Iterator, List, NamedTuple, Optional, @@ -142,7 +141,7 @@ @contextlib.contextmanager -def set_current_context(context: Context) -> Iterator[Context]: +def set_current_context(context: Context) -> Generator[Context, None, None]: """ Sets the current execution context to the provided context object. This method should be called once per Task execution, before calling operator.execute. diff --git a/airflow/providers/google/cloud/hooks/gcs.py b/airflow/providers/google/cloud/hooks/gcs.py index 29ad6ac4386e7..93717e00e9a5a 100644 --- a/airflow/providers/google/cloud/hooks/gcs.py +++ b/airflow/providers/google/cloud/hooks/gcs.py @@ -29,7 +29,20 @@ from io import BytesIO from os import path from tempfile import NamedTemporaryFile -from typing import Callable, List, Optional, Sequence, Set, Tuple, TypeVar, Union, cast, overload +from typing import ( + IO, + Callable, + Generator, + List, + Optional, + Sequence, + Set, + Tuple, + TypeVar, + Union, + cast, + overload, +) from urllib.parse import urlparse from google.api_core.exceptions import NotFound @@ -385,7 +398,7 @@ def provide_file( object_name: Optional[str] = None, object_url: Optional[str] = None, dir: Optional[str] = None, - ): + ) -> Generator[IO[bytes], None, None]: """ Downloads the file to a temporary directory and returns a file handle @@ -413,7 +426,7 @@ def provide_file_and_upload( bucket_name: str = PROVIDE_BUCKET, object_name: Optional[str] = None, object_url: Optional[str] = None, - ): + ) -> Generator[IO[bytes], None, None]: """ Creates temporary file, returns a file handle and uploads the files content on close. diff --git a/airflow/providers/google/cloud/utils/credentials_provider.py b/airflow/providers/google/cloud/utils/credentials_provider.py index 0a8143ceae782..1cf33ea70b056 100644 --- a/airflow/providers/google/cloud/utils/credentials_provider.py +++ b/airflow/providers/google/cloud/utils/credentials_provider.py @@ -74,7 +74,10 @@ def build_gcp_conn( @contextmanager -def provide_gcp_credentials(key_file_path: Optional[str] = None, key_file_dict: Optional[Dict] = None): +def provide_gcp_credentials( + key_file_path: Optional[str] = None, + key_file_dict: Optional[Dict] = None, +) -> Generator[None, None, None]: """ Context manager that provides a Google Cloud credentials for application supporting `Application Default Credentials (ADC) strategy`__. @@ -111,7 +114,7 @@ def provide_gcp_connection( key_file_path: Optional[str] = None, scopes: Optional[Sequence] = None, project_id: Optional[str] = None, -) -> Generator: +) -> Generator[None, None, None]: """ Context manager that provides a temporary value of :envvar:`AIRFLOW_CONN_GOOGLE_CLOUD_DEFAULT` connection. It build a new connection that includes path to provided service json, @@ -135,7 +138,7 @@ def provide_gcp_conn_and_credentials( key_file_path: Optional[str] = None, scopes: Optional[Sequence] = None, project_id: Optional[str] = None, -) -> Generator: +) -> Generator[None, None, None]: """ Context manager that provides both: diff --git a/airflow/providers/google/common/hooks/base_google.py b/airflow/providers/google/common/hooks/base_google.py index f2c0d5157a2e6..d9fe5daba5443 100644 --- a/airflow/providers/google/common/hooks/base_google.py +++ b/airflow/providers/google/common/hooks/base_google.py @@ -25,7 +25,7 @@ import warnings from contextlib import ExitStack, contextmanager from subprocess import check_output -from typing import Any, Callable, Dict, Optional, Sequence, Tuple, TypeVar, Union, cast +from typing import Any, Callable, Dict, Generator, Optional, Sequence, Tuple, TypeVar, Union, cast import google.auth import google.auth.credentials @@ -459,7 +459,7 @@ def wrapper(self: GoogleBaseHook, *args, **kwargs): return cast(T, wrapper) @contextmanager - def provide_gcp_credential_file_as_context(self): + def provide_gcp_credential_file_as_context(self) -> Generator[Optional[str], None, None]: """ Context manager that provides a Google Cloud credentials for application supporting `Application Default Credentials (ADC) strategy `__. @@ -467,8 +467,8 @@ def provide_gcp_credential_file_as_context(self): It can be used to provide credentials for external programs (e.g. gcloud) that expect authorization file in ``GOOGLE_APPLICATION_CREDENTIALS`` environment variable. """ - key_path = self._get_field('key_path', None) # type: Optional[str] # - keyfile_dict = self._get_field('keyfile_dict', None) # type: Optional[Dict] + key_path: Optional[str] = self._get_field('key_path', None) + keyfile_dict: Optional[str] = self._get_field('keyfile_dict', None) if key_path and keyfile_dict: raise AirflowException( "The `keyfile_dict` and `key_path` fields are mutually exclusive. " @@ -490,7 +490,7 @@ def provide_gcp_credential_file_as_context(self): yield None @contextmanager - def provide_authorized_gcloud(self): + def provide_authorized_gcloud(self) -> Generator[None, None, None]: """ Provides a separate gcloud configuration with current credentials. diff --git a/airflow/providers/microsoft/psrp/hooks/psrp.py b/airflow/providers/microsoft/psrp/hooks/psrp.py index 005f1e215d74d..0aebe63d0319e 100644 --- a/airflow/providers/microsoft/psrp/hooks/psrp.py +++ b/airflow/providers/microsoft/psrp/hooks/psrp.py @@ -19,7 +19,7 @@ from contextlib import contextmanager from copy import copy from logging import DEBUG, ERROR, INFO, WARNING -from typing import Any, Callable, Dict, Iterator, Optional +from typing import Any, Callable, Dict, Generator, Optional from weakref import WeakKeyDictionary from pypsrp.host import PSHost @@ -155,7 +155,7 @@ def apply_extra(d, keys): return pool @contextmanager - def invoke(self) -> Iterator[PowerShell]: + def invoke(self) -> Generator[PowerShell, None, None]: """ Context manager that yields a PowerShell object to which commands can be added. Upon exit, the commands will be invoked. diff --git a/airflow/utils/db.py b/airflow/utils/db.py index 03c0848e6c722..f606dfc332813 100644 --- a/airflow/utils/db.py +++ b/airflow/utils/db.py @@ -24,7 +24,7 @@ import warnings from dataclasses import dataclass from tempfile import gettempdir -from typing import TYPE_CHECKING, Callable, Iterable, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Callable, Generator, Iterable, List, Optional, Tuple, Union from sqlalchemy import Table, and_, column, exc, func, inspect, or_, select, table, text, tuple_ from sqlalchemy.orm.session import Session @@ -68,6 +68,7 @@ from airflow.version import version if TYPE_CHECKING: + from alembic.runtime.environment import EnvironmentContext from alembic.script import ScriptDirectory from sqlalchemy.orm import Query @@ -709,7 +710,7 @@ def check_migrations(timeout): @contextlib.contextmanager -def _configured_alembic_environment(): +def _configured_alembic_environment() -> Generator["EnvironmentContext", None, None]: from alembic.runtime.environment import EnvironmentContext config = _get_alembic_config() @@ -1606,7 +1607,11 @@ def __str__(self): @contextlib.contextmanager -def create_global_lock(session: Session, lock: DBLocks, lock_timeout=1800): +def create_global_lock( + session: Session, + lock: DBLocks, + lock_timeout: int = 1800, +) -> Generator[None, None, None]: """Contextmanager that will create and teardown a global db lock.""" conn = session.get_bind().connect() dialect = conn.dialect diff --git a/airflow/utils/process_utils.py b/airflow/utils/process_utils.py index fd63f3e959abd..2ec782df665af 100644 --- a/airflow/utils/process_utils.py +++ b/airflow/utils/process_utils.py @@ -34,7 +34,7 @@ import pty from contextlib import contextmanager -from typing import Dict, List, Optional +from typing import Dict, Generator, List, Optional import psutil from lockfile.pidlockfile import PIDLockFile @@ -258,7 +258,7 @@ def kill_child_processes_by_pids(pids_to_kill: List[int], timeout: int = 5) -> N @contextmanager -def patch_environ(new_env_variables: Dict[str, str]): +def patch_environ(new_env_variables: Dict[str, str]) -> Generator[None, None, None]: """ Sets environment variables in context. After leaving the context, it restores its original state. diff --git a/airflow/utils/session.py b/airflow/utils/session.py index 3565e216a2a0c..377ff55cbf00e 100644 --- a/airflow/utils/session.py +++ b/airflow/utils/session.py @@ -17,13 +17,13 @@ import contextlib from functools import wraps from inspect import signature -from typing import Callable, Iterator, TypeVar, cast +from typing import Callable, Generator, TypeVar, cast from airflow import settings @contextlib.contextmanager -def create_session() -> Iterator[settings.SASession]: +def create_session() -> Generator[settings.SASession, None, None]: """Contextmanager that will create and teardown a session.""" if not settings.Session: raise RuntimeError("Session must be set before!") diff --git a/dev/breeze/src/airflow_breeze/utils/run_utils.py b/dev/breeze/src/airflow_breeze/utils/run_utils.py index 86b84be4c076f..03f3c0532d205 100644 --- a/dev/breeze/src/airflow_breeze/utils/run_utils.py +++ b/dev/breeze/src/airflow_breeze/utils/run_utils.py @@ -25,7 +25,7 @@ from functools import lru_cache from pathlib import Path from re import match -from typing import Dict, List, Mapping, Optional, Union +from typing import Dict, Generator, List, Mapping, Optional, Union from airflow_breeze.branch_defaults import AIRFLOW_BRANCH from airflow_breeze.params._common_build_params import _CommonBuildParams @@ -213,7 +213,7 @@ def instruct_build_image(python: str): @contextlib.contextmanager -def working_directory(source_path: Path): +def working_directory(source_path: Path) -> Generator[None, None, None]: """ # Equivalent of pushd and popd in bash script. # https://stackoverflow.com/a/42441759/3101838 diff --git a/dev/provider_packages/prepare_provider_packages.py b/dev/provider_packages/prepare_provider_packages.py index 0091dee5c6adc..5ecf9a4850c3d 100755 --- a/dev/provider_packages/prepare_provider_packages.py +++ b/dev/provider_packages/prepare_provider_packages.py @@ -38,7 +38,7 @@ from os.path import dirname, relpath from pathlib import Path from shutil import copyfile -from typing import Any, Dict, Iterable, List, NamedTuple, Optional, Set, Tuple, Union +from typing import Any, Dict, Generator, Iterable, List, NamedTuple, Optional, Set, Tuple, Union import jsonschema import rich_click as click @@ -195,7 +195,7 @@ def cli(): @contextmanager -def with_group(title): +def with_group(title: str) -> Generator[None, None, None]: """ If used in GitHub Action, creates an expandable group in the GitHub Action log. Otherwise, display simple text groups.