Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AIRFLOW-3607] Only query DB once per DAG run for TriggerRuleDep #11010

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 4 additions & 21 deletions airflow/jobs/scheduler_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
from airflow.jobs.base_job import BaseJob
from airflow.models import DagRun, SlaMiss, errors
from airflow.settings import Stats
from airflow.ti_deps.dep_context import DepContext, SCHEDULEABLE_STATES, SCHEDULED_DEPS
from airflow.ti_deps.dep_context import DepContext, SCHEDULED_DEPS
from airflow.operators.dummy_operator import DummyOperator
from airflow.ti_deps.deps.pool_slots_available_dep import STATES_TO_COUNT_AS_RUNNING
from airflow.utils import asciiart, helpers, timezone
Expand Down Expand Up @@ -790,27 +790,10 @@ def _process_task_instances(self, dag, task_instances_list, session=None):
run.dag = dag
# todo: preferably the integrity check happens at dag collection time
run.verify_integrity(session=session)
run.update_state(session=session)
ready_tis = run.update_state(session=session)
if run.state == State.RUNNING:
make_transient(run)
active_dag_runs.append(run)

for run in active_dag_runs:
self.log.debug("Examining active DAG run: %s", run)
tis = run.get_task_instances(state=SCHEDULEABLE_STATES)

# this loop is quite slow as it uses are_dependencies_met for
# every task (in ti.is_runnable). This is also called in
# update_state above which has already checked these tasks
for ti in tis:
task = dag.get_task(ti.task_id)

# fixme: ti.task is transient but needs to be set
ti.task = task

if ti.are_dependencies_met(
dep_context=DepContext(flag_upstream_failed=True),
session=session):
self.log.debug("Examining active DAG run: %s", run)
for ti in ready_tis:
self.log.debug('Queuing task: %s', ti)
task_instances_list.append(ti.key)

Expand Down
85 changes: 50 additions & 35 deletions airflow/models/dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from airflow.exceptions import AirflowException
from airflow.models.base import ID_LEN, Base
from airflow.settings import Stats, task_instance_mutation_hook
from airflow.ti_deps.dep_context import DepContext
from airflow.ti_deps.dep_context import SCHEDULEABLE_STATES, DepContext
from airflow.utils import timezone
from airflow.utils.db import provide_session
from airflow.utils.log.logging_mixin import LoggingMixin
Expand Down Expand Up @@ -187,7 +187,6 @@ def get_task_instances(self, state=None, session=None):

if self.dag and self.dag.partial:
tis = tis.filter(TaskInstance.task_id.in_(self.dag.task_ids))

return tis.all()

@provide_session
Expand Down Expand Up @@ -255,48 +254,36 @@ def update_state(self, session=None):
Determines the overall state of the DagRun based on the state
of its TaskInstances.

:return: State
:return: ready_tis: the tis that can be scheduled in the current loop
:rtype ready_tis: list[airflow.models.TaskInstance]
"""

dag = self.get_dag()

tis = self.get_task_instances(session=session)
self.log.debug("Updating state for %s considering %s task(s)", self, len(tis))

ready_tis = []
tis = [ti for ti in self.get_task_instances(session=session,
state=State.task_states + (State.SHUTDOWN,))]
self.log.debug("number of tis tasks for %s: %s task(s)", self, len(tis))
for ti in list(tis):
# skip in db?
if ti.state == State.REMOVED:
tis.remove(ti)
else:
ti.task = dag.get_task(ti.task_id)
ti.task = dag.get_task(ti.task_id)

# pre-calculate
# db is faster
start_dttm = timezone.utcnow()
unfinished_tasks = self.get_task_instances(
state=State.unfinished(),
session=session
)
unfinished_tasks = [t for t in tis if t.state in State.unfinished()]
finished_tasks = [t for t in tis if t.state in State.finished() + [State.UPSTREAM_FAILED]]
none_depends_on_past = all(not t.task.depends_on_past for t in unfinished_tasks)
none_task_concurrency = all(t.task.task_concurrency is None
for t in unfinished_tasks)
# small speed up
if unfinished_tasks and none_depends_on_past and none_task_concurrency:
# todo: this can actually get pretty slow: one task costs between 0.01-015s
no_dependencies_met = True
for ut in unfinished_tasks:
# We need to flag upstream and check for changes because upstream
# failures/re-schedules can result in deadlock false positives
old_state = ut.state
deps_met = ut.are_dependencies_met(
dep_context=DepContext(
flag_upstream_failed=True,
ignore_in_retry_period=True,
ignore_in_reschedule_period=True),
session=session)
if deps_met or old_state != ut.current_state(session=session):
no_dependencies_met = False
break
scheduleable_tasks = [ut for ut in unfinished_tasks if ut.state in SCHEDULEABLE_STATES]
self.log.debug(
"number of scheduleable tasks for %s: %s task(s)",
self, len(scheduleable_tasks))
ready_tis, changed_tis = self._get_ready_tis(scheduleable_tasks, finished_tasks, session)
self.log.debug("ready tis length for %s: %s task(s)", self, len(ready_tis))
if none_depends_on_past and none_task_concurrency:
# small speed up
are_runnable_tasks = ready_tis or self._are_premature_tis(
unfinished_tasks, finished_tasks, session) or changed_tis

duration = (timezone.utcnow() - start_dttm).total_seconds() * 1000
Stats.timing("dagrun.dependency-check.{}".format(self.dag_id), duration)
Expand All @@ -323,7 +310,7 @@ def update_state(self, session=None):

# if *all tasks* are deadlocked, the run failed
elif (unfinished_tasks and none_depends_on_past and
none_task_concurrency and no_dependencies_met):
none_task_concurrency and not are_runnable_tasks):
self.log.info('Deadlock; marking run %s failed', self)
self.set_state(State.FAILED)
dag.handle_callback(self, success=False, reason='all_tasks_deadlocked',
Expand All @@ -339,7 +326,35 @@ def update_state(self, session=None):
session.merge(self)
session.commit()

return self.state
return ready_tis

def _get_ready_tis(self, scheduleable_tasks, finished_tasks, session):
ready_tis = []
changed_tis = False
for st in scheduleable_tasks:
st_old_state = st.state
if st.are_dependencies_met(
dep_context=DepContext(
flag_upstream_failed=True,
finished_tasks=finished_tasks),
session=session):
ready_tis.append(st)
elif st_old_state != st.current_state(session=session):
changed_tis = True
return ready_tis, changed_tis

def _are_premature_tis(self, unfinished_tasks, finished_tasks, session):
# there might be runnable tasks that are up for retry and from some reason(retry delay, etc) are
# not ready yet so we set the flags to count them in
for ut in unfinished_tasks:
if ut.are_dependencies_met(
dep_context=DepContext(
flag_upstream_failed=True,
ignore_in_retry_period=True,
ignore_in_reschedule_period=True,
finished_tasks=finished_tasks),
session=session):
return True

def _emit_duration_stats_for_finished_state(self):
if self.state == State.RUNNING:
Expand Down
58 changes: 28 additions & 30 deletions airflow/ti_deps/deps/trigger_rule_dep.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# specific language governing permissions and limitations
# under the License.

from sqlalchemy import case, func
from collections import Counter

import airflow
from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
Expand All @@ -34,11 +34,32 @@ class TriggerRuleDep(BaseTIDep):
IGNOREABLE = True
IS_TASK_DEP = True

@staticmethod
@provide_session
def _get_states_count_upstream_ti(ti, finished_tasks, session):
"""
This function returns the states of the upstream tis for a specific ti in order to determine
whether this ti can run in this iteration

:param ti: the ti that we want to calculate deps for
:type ti: airflow.models.TaskInstance
:param finished_tasks: all the finished tasks of the dag_run
:type finished_tasks: list[airflow.models.TaskInstance]
"""
if finished_tasks is None:
# this is for the strange feature of running tasks without dag_run
finished_tasks = ti.task.dag.get_task_instances(
start_date=ti.execution_date,
end_date=ti.execution_date,
state=State.finished() + [State.UPSTREAM_FAILED],
session=session)
counter = Counter(task.state for task in finished_tasks if task.task_id in ti.task.upstream_task_ids)
return counter.get(State.SUCCESS, 0), counter.get(State.SKIPPED, 0), counter.get(State.FAILED, 0), \
counter.get(State.UPSTREAM_FAILED, 0), sum(counter.values())

@provide_session
def _get_dep_statuses(self, ti, session, dep_context):
TI = airflow.models.TaskInstance
TR = airflow.utils.trigger_rule.TriggerRule

# Checking that all upstream dependencies have succeeded
if not ti.task.upstream_list:
yield self._passing_status(
Expand All @@ -48,34 +69,11 @@ def _get_dep_statuses(self, ti, session, dep_context):
if ti.task.trigger_rule == TR.DUMMY:
yield self._passing_status(reason="The task had a dummy trigger rule set.")
return
# see if the task name is in the task upstream for our task
successes, skipped, failed, upstream_failed, done = self._get_states_count_upstream_ti(
ti=ti,
finished_tasks=dep_context.finished_tasks)

# TODO(unknown): this query becomes quite expensive with dags that have many
# tasks. It should be refactored to let the task report to the dag run and get the
# aggregates from there.
qry = (
session
.query(
func.coalesce(func.sum(
case([(TI.state == State.SUCCESS, 1)], else_=0)), 0),
func.coalesce(func.sum(
case([(TI.state == State.SKIPPED, 1)], else_=0)), 0),
func.coalesce(func.sum(
case([(TI.state == State.FAILED, 1)], else_=0)), 0),
func.coalesce(func.sum(
case([(TI.state == State.UPSTREAM_FAILED, 1)], else_=0)), 0),
func.count(TI.task_id),
)
.filter(
TI.dag_id == ti.dag_id,
TI.task_id.in_(ti.task.upstream_task_ids),
TI.execution_date == ti.execution_date,
TI.state.in_([
State.SUCCESS, State.FAILED,
State.UPSTREAM_FAILED, State.SKIPPED]),
)
)

successes, skipped, failed, upstream_failed, done = qry.first()
for dep_status in self._evaluate_trigger_rule(
ti=ti,
successes=successes,
Expand Down
1 change: 1 addition & 0 deletions tests/jobs/test_backfill_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -1300,6 +1300,7 @@ def test_backfill_execute_subdag_with_removed_task(self):

session = settings.Session()
session.merge(removed_task_ti)
session.commit()

with timeout(seconds=30):
job.run()
Expand Down
4 changes: 2 additions & 2 deletions tests/jobs/test_scheduler_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -1220,8 +1220,8 @@ def test_dagrun_root_fail_unfinished(self):
ti = dr.get_task_instance('test_dagrun_unfinished', session=session)
ti.state = State.NONE
session.commit()
dr_state = dr.update_state()
self.assertEqual(dr_state, State.RUNNING)
dr.update_state()
self.assertEqual(dr.state, State.RUNNING)

def test_dagrun_root_after_dagrun_unfinished(self):
"""
Expand Down
20 changes: 10 additions & 10 deletions tests/models/test_dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,8 @@ def test_dagrun_success_when_all_skipped(self):
dag_run = self.create_dag_run(dag=dag,
state=State.RUNNING,
task_states=initial_task_states)
updated_dag_state = dag_run.update_state()
self.assertEqual(State.SUCCESS, updated_dag_state)
dag_run.update_state()
self.assertEqual(State.SUCCESS, dag_run.state)

def test_dagrun_success_conditions(self):
session = settings.Session()
Expand Down Expand Up @@ -202,15 +202,15 @@ def test_dagrun_success_conditions(self):
ti_op4 = dr.get_task_instance(task_id=op4.task_id)

# root is successful, but unfinished tasks
state = dr.update_state()
self.assertEqual(State.RUNNING, state)
dr.update_state()
self.assertEqual(State.RUNNING, dr.state)

# one has failed, but root is successful
ti_op2.set_state(state=State.FAILED, session=session)
ti_op3.set_state(state=State.SUCCESS, session=session)
ti_op4.set_state(state=State.SUCCESS, session=session)
state = dr.update_state()
self.assertEqual(State.SUCCESS, state)
dr.update_state()
self.assertEqual(State.SUCCESS, dr.state)

def test_dagrun_deadlock(self):
session = settings.Session()
Expand Down Expand Up @@ -325,8 +325,8 @@ def on_success_callable(context):
dag_run = self.create_dag_run(dag=dag,
state=State.RUNNING,
task_states=initial_task_states)
updated_dag_state = dag_run.update_state()
self.assertEqual(State.SUCCESS, updated_dag_state)
dag_run.update_state()
self.assertEqual(State.SUCCESS, dag_run.state)

def test_dagrun_failure_callback(self):
def on_failure_callable(context):
Expand Down Expand Up @@ -356,8 +356,8 @@ def on_failure_callable(context):
dag_run = self.create_dag_run(dag=dag,
state=State.RUNNING,
task_states=initial_task_states)
updated_dag_state = dag_run.update_state()
self.assertEqual(State.FAILED, updated_dag_state)
dag_run.update_state()
self.assertEqual(State.FAILED, dag_run.state)

def test_dagrun_set_state_end_date(self):
session = settings.Session()
Expand Down
Loading