Skip to content

Commit

Permalink
Refactor DagRun.verify_integrity (#24114)
Browse files Browse the repository at this point in the history
This refactoring became necessary as there's a necessity to add additional code
to the already exisiting code to handle mapped task immutability during run. The additional
code would make this method difficult to read. Refactoring the code will aid understanding and
help in debugging.

(cherry picked from commit 12638d2)
  • Loading branch information
ephraimbuddy committed Jun 29, 2022
1 parent d7b58db commit 5e174a1
Showing 1 changed file with 88 additions and 14 deletions.
102 changes: 88 additions & 14 deletions airflow/models/dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,15 @@
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Generator,
Iterable,
List,
NamedTuple,
Optional,
Sequence,
Set,
Tuple,
Union,
cast,
Expand Down Expand Up @@ -818,13 +820,50 @@ def verify_integrity(self, session: Session = NEW_SESSION):
"""
from airflow.settings import task_instance_mutation_hook

# Set for the empty default in airflow.settings -- if it's not set this means it has been changed
hook_is_noop = getattr(task_instance_mutation_hook, 'is_noop', False)

dag = self.get_dag()
task_ids = self._check_for_removed_or_restored_tasks(
dag, task_instance_mutation_hook, session=session
)

def task_filter(task: "Operator") -> bool:
return task.task_id not in task_ids and (
self.is_backfill
or task.start_date <= self.execution_date
and (task.end_date is None or self.execution_date <= task.end_date)
)

created_counts: Dict[str, int] = defaultdict(int)

# Get task creator function
task_creator = self._get_task_creator(created_counts, task_instance_mutation_hook, hook_is_noop)

# Create the missing tasks, including mapped tasks
tasks = self._create_missing_tasks(dag, task_creator, task_filter, session=session)

self._create_task_instances(dag.dag_id, tasks, created_counts, hook_is_noop, session=session)

def _check_for_removed_or_restored_tasks(
self, dag: "DAG", ti_mutation_hook, *, session: Session
) -> Set[str]:
"""
Check for removed tasks/restored tasks.
:param dag: DAG object corresponding to the dagrun
:param ti_mutation_hook: task_instance_mutation_hook function
:param session: Sqlalchemy ORM Session
:return: List of task_ids in the dagrun
"""
tis = self.get_task_instances(session=session)

# check for removed or restored tasks
task_ids = set()
for ti in tis:
task_instance_mutation_hook(ti)
ti_mutation_hook(ti)
task_ids.add(ti.task_id)
task = None
try:
Expand Down Expand Up @@ -885,19 +924,21 @@ def verify_integrity(self, session: Session = NEW_SESSION):
)
ti.state = State.REMOVED
...
return task_ids

def task_filter(task: "Operator") -> bool:
return task.task_id not in task_ids and (
self.is_backfill
or task.start_date <= self.execution_date
and (task.end_date is None or self.execution_date <= task.end_date)
)
def _get_task_creator(
self, created_counts: Dict[str, int], ti_mutation_hook: Callable, hook_is_noop: bool
) -> Callable:
"""
Get the task creator function.
created_counts: Dict[str, int] = defaultdict(int)
This function also updates the created_counts dictionary with the number of tasks created.
# Set for the empty default in airflow.settings -- if it's not set this means it has been changed
hook_is_noop = getattr(task_instance_mutation_hook, 'is_noop', False)
:param created_counts: Dictionary of task_type -> count of created TIs
:param ti_mutation_hook: task_instance_mutation_hook function
:param hook_is_noop: Whether the task_instance_mutation_hook is a noop
"""
if hook_is_noop:

def create_ti_mapping(task: "Operator", indexes: Tuple[int, ...]) -> Generator:
Expand All @@ -912,13 +953,25 @@ def create_ti_mapping(task: "Operator", indexes: Tuple[int, ...]) -> Generator:
def create_ti(task: "Operator", indexes: Tuple[int, ...]) -> Generator:
for map_index in indexes:
ti = TI(task, run_id=self.run_id, map_index=map_index)
task_instance_mutation_hook(ti)
ti_mutation_hook(ti)
created_counts[ti.operator] += 1
yield ti

creator = create_ti
return creator

def _create_missing_tasks(
self, dag: "DAG", task_creator: Callable, task_filter: Callable, *, session: Session
) -> Iterable["Operator"]:
"""
Create missing tasks -- and expand any MappedOperator that _only_ have literals as input
:param dag: DAG object corresponding to the dagrun
:param task_creator: a function that creates tasks
:param task_filter: a function that filters tasks to create
:param session: the session to use
"""

# Create missing tasks -- and expand any MappedOperator that _only_ have literals as input
def expand_mapped_literals(task: "Operator") -> Tuple["Operator", Sequence[int]]:
if not task.is_mapped:
return (task, (-1,))
Expand All @@ -931,8 +984,29 @@ def expand_mapped_literals(task: "Operator") -> Tuple["Operator", Sequence[int]]
return (task, range(count))

tasks_and_map_idxs = map(expand_mapped_literals, filter(task_filter, dag.task_dict.values()))
tasks = itertools.chain.from_iterable(itertools.starmap(creator, tasks_and_map_idxs))

tasks = itertools.chain.from_iterable(itertools.starmap(task_creator, tasks_and_map_idxs))
return tasks

def _create_task_instances(
self,
dag_id: str,
tasks: Iterable["Operator"],
created_counts: Dict[str, int],
hook_is_noop: bool,
*,
session: Session,
) -> None:
"""
Create the necessary task instances from the given tasks.
:param dag_id: DAG ID associated with the dagrun
:param tasks: the tasks to create the task instances from
:param created_counts: a dictionary of number of tasks -> total ti created by the task creator
:param hook_is_noop: whether the task_instance_mutation_hook is noop
:param session: the session to use
"""
try:
if hook_is_noop:
session.bulk_insert_mappings(TI, tasks)
Expand All @@ -945,7 +1019,7 @@ def expand_mapped_literals(task: "Operator") -> Tuple["Operator", Sequence[int]]
except IntegrityError:
self.log.info(
'Hit IntegrityError while creating the TIs for %s- %s',
dag.dag_id,
dag_id,
self.run_id,
exc_info=True,
)
Expand Down

0 comments on commit 5e174a1

Please sign in to comment.