Skip to content

Commit

Permalink
Mypy fixes to DagRun, TaskInstance, and db utils (#20163)
Browse files Browse the repository at this point in the history
  • Loading branch information
uranusjr authored Dec 9, 2021
1 parent 985bb06 commit b297fc7
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 22 deletions.
8 changes: 5 additions & 3 deletions airflow/models/dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def __init__(
self,
dag_id: Optional[str] = None,
run_id: Optional[str] = None,
queued_at: Union[datetime, None, ArgNotSet] = NOTSET, # type: ignore
queued_at: Union[datetime, None, ArgNotSet] = NOTSET,
execution_date: Optional[datetime] = None,
start_date: Optional[datetime] = None,
external_trigger: Optional[bool] = None,
Expand Down Expand Up @@ -398,7 +398,9 @@ def generate_run_id(run_type: DagRunType, execution_date: datetime) -> str:

@provide_session
def get_task_instances(
self, state: Optional[Iterable[TaskInstanceState]] = None, session=None
self,
state: Optional[Iterable[Optional[TaskInstanceState]]] = None,
session: Session = NEW_SESSION,
) -> Iterable[TI]:
"""Returns the task instances for this dag run"""
tis = (
Expand Down Expand Up @@ -805,7 +807,7 @@ def verify_integrity(self, session: Session = NEW_SESSION):

if task.task_id not in task_ids:
Stats.incr(f"task_instance_created-{task.task_type}", 1, 1)
ti = TI(task, execution_date=None, run_id=self.run_id)
ti = TI(task, run_id=self.run_id)
task_instance_mutation_hook(ti)
session.add(ti)

Expand Down
24 changes: 13 additions & 11 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def load_error_file(fd: IO[bytes]) -> Optional[Union[str, Exception]]:
return "Failed to load task run error"


def set_error_file(error_file: str, error: Union[str, Exception]) -> None:
def set_error_file(error_file: str, error: Union[str, BaseException]) -> None:
"""Write error into error file by path"""
with open(error_file, "wb") as fd:
try:
Expand Down Expand Up @@ -877,7 +877,7 @@ def are_dependents_done(self, session=NEW_SESSION):
@provide_session
def get_previous_dagrun(
self,
state: Optional[str] = None,
state: Optional[DagRunState] = None,
session: Optional[Session] = None,
) -> Optional["DagRun"]:
"""The DagRun that ran before this task instance's DagRun.
Expand Down Expand Up @@ -909,7 +909,9 @@ def get_previous_dagrun(

@provide_session
def get_previous_ti(
self, state: Optional[str] = None, session: Session = NEW_SESSION
self,
state: Optional[DagRunState] = None,
session: Session = NEW_SESSION,
) -> Optional['TaskInstance']:
"""
The task instance for the task that ran before this task instance.
Expand Down Expand Up @@ -952,12 +954,12 @@ def previous_ti_success(self) -> Optional['TaskInstance']:
DeprecationWarning,
stacklevel=2,
)
return self.get_previous_ti(state=State.SUCCESS)
return self.get_previous_ti(state=DagRunState.SUCCESS)

@provide_session
def get_previous_execution_date(
self,
state: Optional[str] = None,
state: Optional[DagRunState] = None,
session: Session = NEW_SESSION,
) -> Optional[pendulum.DateTime]:
"""
Expand All @@ -972,7 +974,7 @@ def get_previous_execution_date(

@provide_session
def get_previous_start_date(
self, state: Optional[str] = None, session: Session = NEW_SESSION
self, state: Optional[DagRunState] = None, session: Session = NEW_SESSION
) -> Optional[pendulum.DateTime]:
"""
The start date from property previous_ti_success.
Expand All @@ -999,7 +1001,7 @@ def previous_start_date_success(self) -> Optional[pendulum.DateTime]:
DeprecationWarning,
stacklevel=2,
)
return self.get_previous_start_date(state=State.SUCCESS)
return self.get_previous_start_date(state=DagRunState.SUCCESS)

@provide_session
def are_dependencies_met(self, dep_context=None, session=NEW_SESSION, verbose=False):
Expand Down Expand Up @@ -1689,7 +1691,7 @@ def _handle_reschedule(
@provide_session
def handle_failure(
self,
error: Union[str, Exception],
error: Union[str, BaseException],
test_mode: Optional[bool] = None,
force_fail: bool = False,
error_file: Optional[str] = None,
Expand All @@ -1700,7 +1702,7 @@ def handle_failure(
test_mode = self.test_mode

if error:
if isinstance(error, Exception):
if isinstance(error, BaseException):
self.log.error("Task failed with exception", exc_info=error)
else:
self.log.error("%s", error)
Expand Down Expand Up @@ -1814,7 +1816,7 @@ def get_template_context(

@cache # Prevent multiple database access.
def _get_previous_dagrun_success() -> Optional["DagRun"]:
return self.get_previous_dagrun(state=State.SUCCESS, session=session)
return self.get_previous_dagrun(state=DagRunState.SUCCESS, session=session)

def _get_previous_dagrun_data_interval_success() -> Optional["DataInterval"]:
dagrun = _get_previous_dagrun_success()
Expand Down Expand Up @@ -1926,7 +1928,7 @@ def get_prev_ds_nodash() -> Optional[str]:
'prev_ds_nodash': get_prev_ds_nodash(),
'prev_execution_date': get_prev_execution_date(),
'prev_execution_date_success': self.get_previous_execution_date(
state=State.SUCCESS,
state=DagRunState.SUCCESS,
session=session,
),
'prev_start_date_success': get_prev_start_date_success(),
Expand Down
17 changes: 9 additions & 8 deletions airflow/utils/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@
import sys
import time
from tempfile import gettempdir
from typing import Iterable
from typing import Any, Iterable, List

from sqlalchemy import Table, exc, func, inspect, or_, text
from sqlalchemy.orm.session import Session

from airflow import settings
from airflow.configuration import conf
Expand Down Expand Up @@ -57,7 +58,7 @@
from airflow.utils import helpers

# TODO: remove create_session once we decide to break backward compatibility
from airflow.utils.session import create_session, provide_session # noqa: F401
from airflow.utils.session import NEW_SESSION, create_session, provide_session # noqa: F401
from airflow.version import version

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -719,7 +720,7 @@ def check_and_run_migrations():
sys.exit(1)


def check_conn_id_duplicates(session=None) -> Iterable[str]:
def check_conn_id_duplicates(session: Session) -> Iterable[str]:
"""
Check unique conn_id in connection table
Expand All @@ -742,7 +743,7 @@ def check_conn_id_duplicates(session=None) -> Iterable[str]:
)


def check_conn_type_null(session=None) -> Iterable[str]:
def check_conn_type_null(session: Session) -> Iterable[str]:
"""
Check nullable conn_type column in Connection table
Expand Down Expand Up @@ -840,7 +841,7 @@ def _move_dangling_table(session, source_table: "Table", target_table_name: str,
)


def check_run_id_null(session) -> Iterable[str]:
def check_run_id_null(session: Session) -> Iterable[str]:
import sqlalchemy.schema

metadata = sqlalchemy.schema.MetaData(session.bind)
Expand Down Expand Up @@ -882,12 +883,12 @@ def _move_dangling_task_data_to_new_table(session, source_table: "Table", target
_move_dangling_table(session, source_table, target_table_name, where_clause)


def check_task_tables_without_matching_dagruns(session) -> Iterable[str]:
def check_task_tables_without_matching_dagruns(session: Session) -> Iterable[str]:
import sqlalchemy.schema
from sqlalchemy import and_, outerjoin

metadata = sqlalchemy.schema.MetaData(session.bind)
models_to_dagrun = [TaskInstance, TaskReschedule]
models_to_dagrun: List[Any] = [TaskInstance, TaskReschedule]
for model in models_to_dagrun + [DagRun]:
try:
metadata.reflect(only=[model.__tablename__], extend_existing=True, resolve_fks=False)
Expand Down Expand Up @@ -950,7 +951,7 @@ def check_task_tables_without_matching_dagruns(session) -> Iterable[str]:


@provide_session
def _check_migration_errors(session=None) -> Iterable[str]:
def _check_migration_errors(session: Session = NEW_SESSION) -> Iterable[str]:
"""
:session: session of the sqlalchemy
:rtype: list[str]
Expand Down

0 comments on commit b297fc7

Please sign in to comment.