From b297fc78a9a33234f7708642f62641cc20c8e322 Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Fri, 10 Dec 2021 04:36:06 +0800 Subject: [PATCH] Mypy fixes to DagRun, TaskInstance, and db utils (#20163) --- airflow/models/dagrun.py | 8 +++++--- airflow/models/taskinstance.py | 24 +++++++++++++----------- airflow/utils/db.py | 17 +++++++++-------- 3 files changed, 27 insertions(+), 22 deletions(-) diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py index f9ced3a1f0b50..e0210eb615aee 100644 --- a/airflow/models/dagrun.py +++ b/airflow/models/dagrun.py @@ -141,7 +141,7 @@ def __init__( self, dag_id: Optional[str] = None, run_id: Optional[str] = None, - queued_at: Union[datetime, None, ArgNotSet] = NOTSET, # type: ignore + queued_at: Union[datetime, None, ArgNotSet] = NOTSET, execution_date: Optional[datetime] = None, start_date: Optional[datetime] = None, external_trigger: Optional[bool] = None, @@ -398,7 +398,9 @@ def generate_run_id(run_type: DagRunType, execution_date: datetime) -> str: @provide_session def get_task_instances( - self, state: Optional[Iterable[TaskInstanceState]] = None, session=None + self, + state: Optional[Iterable[Optional[TaskInstanceState]]] = None, + session: Session = NEW_SESSION, ) -> Iterable[TI]: """Returns the task instances for this dag run""" tis = ( @@ -805,7 +807,7 @@ def verify_integrity(self, session: Session = NEW_SESSION): if task.task_id not in task_ids: Stats.incr(f"task_instance_created-{task.task_type}", 1, 1) - ti = TI(task, execution_date=None, run_id=self.run_id) + ti = TI(task, run_id=self.run_id) task_instance_mutation_hook(ti) session.add(ti) diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index b02a588b64595..7691f4cd71efc 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -161,7 +161,7 @@ def load_error_file(fd: IO[bytes]) -> Optional[Union[str, Exception]]: return "Failed to load task run error" -def set_error_file(error_file: str, error: Union[str, Exception]) -> None: +def set_error_file(error_file: str, error: Union[str, BaseException]) -> None: """Write error into error file by path""" with open(error_file, "wb") as fd: try: @@ -877,7 +877,7 @@ def are_dependents_done(self, session=NEW_SESSION): @provide_session def get_previous_dagrun( self, - state: Optional[str] = None, + state: Optional[DagRunState] = None, session: Optional[Session] = None, ) -> Optional["DagRun"]: """The DagRun that ran before this task instance's DagRun. @@ -909,7 +909,9 @@ def get_previous_dagrun( @provide_session def get_previous_ti( - self, state: Optional[str] = None, session: Session = NEW_SESSION + self, + state: Optional[DagRunState] = None, + session: Session = NEW_SESSION, ) -> Optional['TaskInstance']: """ The task instance for the task that ran before this task instance. @@ -952,12 +954,12 @@ def previous_ti_success(self) -> Optional['TaskInstance']: DeprecationWarning, stacklevel=2, ) - return self.get_previous_ti(state=State.SUCCESS) + return self.get_previous_ti(state=DagRunState.SUCCESS) @provide_session def get_previous_execution_date( self, - state: Optional[str] = None, + state: Optional[DagRunState] = None, session: Session = NEW_SESSION, ) -> Optional[pendulum.DateTime]: """ @@ -972,7 +974,7 @@ def get_previous_execution_date( @provide_session def get_previous_start_date( - self, state: Optional[str] = None, session: Session = NEW_SESSION + self, state: Optional[DagRunState] = None, session: Session = NEW_SESSION ) -> Optional[pendulum.DateTime]: """ The start date from property previous_ti_success. @@ -999,7 +1001,7 @@ def previous_start_date_success(self) -> Optional[pendulum.DateTime]: DeprecationWarning, stacklevel=2, ) - return self.get_previous_start_date(state=State.SUCCESS) + return self.get_previous_start_date(state=DagRunState.SUCCESS) @provide_session def are_dependencies_met(self, dep_context=None, session=NEW_SESSION, verbose=False): @@ -1689,7 +1691,7 @@ def _handle_reschedule( @provide_session def handle_failure( self, - error: Union[str, Exception], + error: Union[str, BaseException], test_mode: Optional[bool] = None, force_fail: bool = False, error_file: Optional[str] = None, @@ -1700,7 +1702,7 @@ def handle_failure( test_mode = self.test_mode if error: - if isinstance(error, Exception): + if isinstance(error, BaseException): self.log.error("Task failed with exception", exc_info=error) else: self.log.error("%s", error) @@ -1814,7 +1816,7 @@ def get_template_context( @cache # Prevent multiple database access. def _get_previous_dagrun_success() -> Optional["DagRun"]: - return self.get_previous_dagrun(state=State.SUCCESS, session=session) + return self.get_previous_dagrun(state=DagRunState.SUCCESS, session=session) def _get_previous_dagrun_data_interval_success() -> Optional["DataInterval"]: dagrun = _get_previous_dagrun_success() @@ -1926,7 +1928,7 @@ def get_prev_ds_nodash() -> Optional[str]: 'prev_ds_nodash': get_prev_ds_nodash(), 'prev_execution_date': get_prev_execution_date(), 'prev_execution_date_success': self.get_previous_execution_date( - state=State.SUCCESS, + state=DagRunState.SUCCESS, session=session, ), 'prev_start_date_success': get_prev_start_date_success(), diff --git a/airflow/utils/db.py b/airflow/utils/db.py index 24d100bebd742..bd620f5eabdbd 100644 --- a/airflow/utils/db.py +++ b/airflow/utils/db.py @@ -22,9 +22,10 @@ import sys import time from tempfile import gettempdir -from typing import Iterable +from typing import Any, Iterable, List from sqlalchemy import Table, exc, func, inspect, or_, text +from sqlalchemy.orm.session import Session from airflow import settings from airflow.configuration import conf @@ -57,7 +58,7 @@ from airflow.utils import helpers # TODO: remove create_session once we decide to break backward compatibility -from airflow.utils.session import create_session, provide_session # noqa: F401 +from airflow.utils.session import NEW_SESSION, create_session, provide_session # noqa: F401 from airflow.version import version log = logging.getLogger(__name__) @@ -719,7 +720,7 @@ def check_and_run_migrations(): sys.exit(1) -def check_conn_id_duplicates(session=None) -> Iterable[str]: +def check_conn_id_duplicates(session: Session) -> Iterable[str]: """ Check unique conn_id in connection table @@ -742,7 +743,7 @@ def check_conn_id_duplicates(session=None) -> Iterable[str]: ) -def check_conn_type_null(session=None) -> Iterable[str]: +def check_conn_type_null(session: Session) -> Iterable[str]: """ Check nullable conn_type column in Connection table @@ -840,7 +841,7 @@ def _move_dangling_table(session, source_table: "Table", target_table_name: str, ) -def check_run_id_null(session) -> Iterable[str]: +def check_run_id_null(session: Session) -> Iterable[str]: import sqlalchemy.schema metadata = sqlalchemy.schema.MetaData(session.bind) @@ -882,12 +883,12 @@ def _move_dangling_task_data_to_new_table(session, source_table: "Table", target _move_dangling_table(session, source_table, target_table_name, where_clause) -def check_task_tables_without_matching_dagruns(session) -> Iterable[str]: +def check_task_tables_without_matching_dagruns(session: Session) -> Iterable[str]: import sqlalchemy.schema from sqlalchemy import and_, outerjoin metadata = sqlalchemy.schema.MetaData(session.bind) - models_to_dagrun = [TaskInstance, TaskReschedule] + models_to_dagrun: List[Any] = [TaskInstance, TaskReschedule] for model in models_to_dagrun + [DagRun]: try: metadata.reflect(only=[model.__tablename__], extend_existing=True, resolve_fks=False) @@ -950,7 +951,7 @@ def check_task_tables_without_matching_dagruns(session) -> Iterable[str]: @provide_session -def _check_migration_errors(session=None) -> Iterable[str]: +def _check_migration_errors(session: Session = NEW_SESSION) -> Iterable[str]: """ :session: session of the sqlalchemy :rtype: list[str]