AIP45 Remove dag parsing in airflow run local (#21877)
pingzh authored May 12, 2022
1 parent 75c6092 commit 3138604
Showing 25 changed files with 668 additions and 817 deletions.

airflow/cli/cli_parser.py (2 changes: 0 additions & 2 deletions)
@@ -512,7 +512,6 @@ def string_lower_type(val):
("--ship-dag",), help="Pickles (serializes) the DAG and ships it to the worker", action="store_true"
)
ARG_PICKLE = Arg(("-p", "--pickle"), help="Serialized pickle object of the entire dag (used internally)")
ARG_ERROR_FILE = Arg(("--error-file",), help="File to store task failure error")
ARG_JOB_ID = Arg(("-j", "--job-id"), help=argparse.SUPPRESS)
ARG_CFG_PATH = Arg(("--cfg-path",), help="Path to config file to use instead of airflow.cfg")
ARG_MAP_INDEX = Arg(('--map-index',), type=int, default=-1, help="Mapped task index")
@@ -1264,7 +1263,6 @@ class GroupCommand(NamedTuple):
ARG_PICKLE,
ARG_JOB_ID,
ARG_INTERACTIVE,
ARG_ERROR_FILE,
ARG_SHUT_DOWN_LOGGING,
ARG_MAP_INDEX,
),

airflow/cli/commands/task_command.py (7 changes: 5 additions & 2 deletions)
@@ -45,6 +45,7 @@
from airflow.utils import cli as cli_utils
from airflow.utils.cli import (
get_dag,
get_dag_by_deserialization,
get_dag_by_file_location,
get_dag_by_pickle,
get_dags,
@@ -258,7 +259,6 @@ def _run_raw_task(args, ti: TaskInstance) -> None:
mark_success=args.mark_success,
job_id=args.job_id,
pool=args.pool,
error_file=args.error_file,
)


@@ -357,7 +357,10 @@ def task_run(args, dag=None):
print(f'Loading pickle id: {args.pickle}')
dag = get_dag_by_pickle(args.pickle)
elif not dag:
dag = get_dag(args.subdir, args.dag_id)
if args.local:
dag = get_dag_by_deserialization(args.dag_id)
else:
dag = get_dag(args.subdir, args.dag_id)
else:
# Use DAG from parameter
pass
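
The new --local branch above calls get_dag_by_deserialization, whose definition is not part of the hunks shown on this page. As a rough orientation, here is a minimal sketch of what such a helper can look like, built only on the SerializedDagModel.get_dag classmethod added later in this commit; the exact error handling and message are assumptions, not the committed code:

    from airflow.exceptions import AirflowException
    from airflow.models.dag import DAG
    from airflow.models.serialized_dag import SerializedDagModel


    def get_dag_by_deserialization(dag_id: str) -> DAG:
        """Load a DAG from the serialized_dag table instead of re-parsing DAG files."""
        dag = SerializedDagModel.get_dag(dag_id)  # classmethod added in airflow/models/serialized_dag.py below
        if dag is None:
            # Assumed error handling: the DAG must already be serialized in the metadata DB.
            raise AirflowException(f"Serialized DAG {dag_id!r} could not be found")
        return dag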

airflow/dag_processing/processor.py (2 changes: 1 addition & 1 deletion)
@@ -607,7 +607,7 @@ def _execute_task_callbacks(self, dagbag: DagBag, request: TaskCallbackRequest):
ti = TI(task, run_id=simple_ti.run_id, map_index=simple_ti.map_index)
# TODO: Use simple_ti to improve performance here in the future
ti.refresh_from_db()
ti.handle_failure_with_callback(error=request.msg, test_mode=self.UNIT_TEST_MODE)
ti.handle_failure(error=request.msg, test_mode=self.UNIT_TEST_MODE)
self.log.info('Executed failure callback for %s in state %s', ti, ti.state)

@provide_session

airflow/executors/debug_executor.py (3 changes: 0 additions & 3 deletions)
@@ -66,7 +66,6 @@ def sync(self) -> None:
self.log.info("Executor is terminated! Stopping %s to %s", ti.key, State.FAILED)
ti.set_state(State.FAILED)
self.change_state(ti.key, State.FAILED)
ti._run_finished_callback()
continue

task_succeeded = self._run_task(ti)
@@ -78,12 +77,10 @@ def _run_task(self, ti: TaskInstance) -> bool:
params = self.tasks_params.pop(ti.key, {})
ti._run_raw_task(job_id=ti.job_id, **params)
self.change_state(key, State.SUCCESS)
ti._run_finished_callback()
return True
except Exception as e:
ti.set_state(State.FAILED)
self.change_state(key, State.FAILED)
ti._run_finished_callback(error=e)
self.log.exception("Failed to execute task: %s.", str(e))
return False


airflow/jobs/backfill_job.py (2 changes: 1 addition & 1 deletion)
@@ -261,7 +261,7 @@ def _manage_executor_state(
f"{ti.state}. Was the task killed externally? Info: {info}"
)
self.log.error(msg)
ti.handle_failure_with_callback(error=msg)
ti.handle_failure(error=msg)
continue
if ti.state not in self.STATES_COUNT_AS_RUNNING:
# Don't use ti.task; if this task is mapped, that attribute

airflow/jobs/local_task_job.py (73 changes: 26 additions & 47 deletions)
@@ -73,6 +73,8 @@ def __init__(
# terminate multiple times
self.terminating = False

self._state_change_checks = 0

super().__init__(*args, **kwargs)

def _execute(self):
@@ -84,7 +86,6 @@ def signal_handler(signum, frame):
self.log.error("Received SIGTERM. Terminating subprocesses")
self.task_runner.terminate()
self.handle_task_exit(128 + signum)
return

signal.signal(signal.SIGTERM, signal_handler)

@@ -106,13 +107,15 @@ def signal_handler(signum, frame):

heartbeat_time_limit = conf.getint('scheduler', 'scheduler_zombie_task_threshold')

# task callback invocation happens either here or in
# self.heartbeat() instead of taskinstance._run_raw_task to
# avoid race conditions
#
# When self.terminating is set to True by heartbeat_callback, this
# loop should not be restarted. Otherwise self.handle_task_exit
# will be invoked and we will end up with duplicated callbacks
# LocalTaskJob should not run callbacks; those are handled by TaskInstance._run_raw_task:
# 1. LocalTaskJob does not parse the DAG, so it cannot run callbacks.
# 2. The run_as_user of LocalTaskJob is likely not the same as that of TaskInstance._run_raw_task.
# When run_as_user is specified, the process owner of the LocalTaskJob must be sudoable,
# and it is not secure to run callbacks with sudoable users.

# If _run_raw_task receives SIGKILL, the scheduler will mark it as a zombie and invoke callbacks
# If LocalTaskJob receives SIGTERM, LocalTaskJob passes SIGTERM on to _run_raw_task
# If the state of the task instance is changed externally, LocalTaskJob sends SIGTERM to _run_raw_task
while not self.terminating:
# Monitor the task to see if it's done. Wait in a syscall
# (`os.wait`) for as long as possible so we notice the
@@ -150,26 +153,18 @@ def signal_handler(signum, frame):
self.on_kill()

def handle_task_exit(self, return_code: int) -> None:
"""Handle case where self.task_runner exits by itself or is externally killed"""
"""
Handle case where self.task_runner exits by itself or is externally killed.
Don't run any callbacks.
"""
# Without setting this, heartbeat may get us
self.terminating = True
self.log.info("Task exited with return code %s", return_code)
self.task_instance.refresh_from_db()

if self.task_instance.state == State.RUNNING:
# This is for a case where the task received a SIGKILL
# while running or the task runner received a sigterm
self.task_instance.handle_failure(error=None)
# We need to check for error file
# in case it failed due to runtime exception/error
error = None
if self.task_instance.state != State.SUCCESS:
error = self.task_runner.deserialize_run_error()
self.task_instance._run_finished_callback(error=error)
if not self.task_instance.test_mode:
if conf.getboolean('scheduler', 'schedule_after_task_execution', fallback=True):
self._run_mini_scheduler_on_child_tasks()
self._update_dagrun_state_for_paused_dag()

def on_kill(self):
self.task_runner.terminate()
@@ -217,19 +212,16 @@ def heartbeat_callback(self, session=None):
dagrun_timeout = ti.task.dag.dagrun_timeout
if dagrun_timeout and execution_time > dagrun_timeout:
self.log.warning("DagRun timed out after %s.", str(execution_time))
self.log.warning(
"State of this instance has been externally set to %s. Terminating instance.", ti.state
)
self.task_runner.terminate()
if ti.state == State.SUCCESS:
error = None
else:
# if ti.state is not set by taskinstance.handle_failure, then
# error file will not be populated and it must be updated by
# external source such as web UI
error = self.task_runner.deserialize_run_error() or "task marked as failed externally"
ti._run_finished_callback(error=error)
self.terminating = True

# Potential race condition: _run_raw_task may commit `success` or another state,
# but the task_runner does not exit right away (e.g. due to slow process shutdown).
# Throttle here; if that is the case, handle_task_exit will take care of it.
if self._state_change_checks >= 1: # defer to next round of heartbeat
self.log.warning(
"State of this instance has been externally set to %s. Terminating instance.", ti.state
)
self.terminating = True
self._state_change_checks += 1

@provide_session
@Sentry.enrich_errors
@@ -282,19 +274,6 @@ def _run_mini_scheduler_on_child_tasks(self, session=None) -> None:
)
session.rollback()

@provide_session
def _update_dagrun_state_for_paused_dag(self, session=None):
"""
Checks for paused dags with DagRuns in the running state and
update the DagRun state if possible
"""
dag = self.task_instance.task.dag
if dag.get_is_paused():
dag_run = self.task_instance.get_dagrun(session=session)
if dag_run:
dag_run.dag = dag
dag_run.update_state(session=session, execute_callbacks=True)

@staticmethod
def _enable_task_listeners():
"""
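
The _state_change_checks counter added in heartbeat_callback above gives the task a one-heartbeat grace period: when the task instance's state looks externally changed, termination is deferred to the next heartbeat, so a slow-exiting _run_raw_task that already committed its final state is not killed spuriously and handle_task_exit can deal with it instead. A self-contained sketch of that pattern, with illustrative names that are not Airflow APIs:

    class StateChangeThrottle:
        """Terminate only when an external state change persists across heartbeats."""

        def __init__(self) -> None:
            self._state_change_checks = 0
            self.terminating = False

        def heartbeat(self, state_changed_externally: bool) -> None:
            if not state_changed_externally:
                return
            # Defer on the first observation; terminate once the change persists.
            if self._state_change_checks >= 1:
                self.terminating = True
            self._state_change_checks += 1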

airflow/jobs/scheduler_job.py (22 changes: 22 additions & 0 deletions)
@@ -149,6 +149,7 @@ def __init__(
self.processor_agent: Optional[DagFileProcessorAgent] = None

self.dagbag = DagBag(dag_folder=self.subdir, read_dags_from_db=True, load_op_links=False)
self._paused_dag_without_running_dagruns: Set = set()

if conf.getboolean('smart_sensor', 'use_smart_sensor'):
compatible_sensors = set(
@@ -764,6 +765,26 @@ def _execute(self) -> None:
self.log.exception("Exception when executing DagFileProcessorAgent.end")
self.log.info("Exited execute loop")

def _update_dag_run_state_for_paused_dags(self):
try:
paused_dag_ids = DagModel.get_all_paused_dag_ids()
for dag_id in paused_dag_ids:
if dag_id in self._paused_dag_without_running_dagruns:
continue

dag = SerializedDagModel.get_dag(dag_id)
if dag is None:
continue
dag_runs = DagRun.find(dag_id=dag_id, state=State.RUNNING)
for dag_run in dag_runs:
dag_run.dag = dag
_, callback_to_run = dag_run.update_state(execute_callbacks=False)
if callback_to_run:
self._send_dag_callbacks_to_processor(dag, callback_to_run)
self._paused_dag_without_running_dagruns.add(dag_id)
except Exception as e: # should not fail the scheduler
self.log.exception('Failed to update dag run state for paused dags due to %s', str(e))

def _run_scheduler_loop(self) -> None:
"""
The actual scheduler loop. The main steps in the loop are:
@@ -809,6 +830,7 @@ def _run_scheduler_loop(self) -> None:
conf.getfloat('scheduler', 'zombie_detection_interval', fallback=10.0),
self._find_zombies,
)
timers.call_regular_interval(60.0, self._update_dag_run_state_for_paused_dags)

for loop_count in itertools.count(start=1):
with Stats.timer() as timer:

airflow/models/dag.py (9 changes: 9 additions & 0 deletions)
@@ -2770,6 +2770,15 @@ def get_dagmodel(dag_id, session=NEW_SESSION):
def get_current(cls, dag_id, session=NEW_SESSION):
return session.query(cls).filter(cls.dag_id == dag_id).first()

@staticmethod
@provide_session
def get_all_paused_dag_ids(session: Session = NEW_SESSION) -> Set[str]:
"""Get a set of paused DAG ids"""
paused_dag_ids = session.query(DagModel.dag_id).filter(DagModel.is_paused == expression.true()).all()

paused_dag_ids = {paused_dag_id for paused_dag_id, in paused_dag_ids}
return paused_dag_ids

@provide_session
def get_last_dagrun(self, session=NEW_SESSION, include_externally_triggered=False):
return get_last_dagrun(

airflow/models/serialized_dag.py (8 changes: 8 additions & 0 deletions)
@@ -250,6 +250,14 @@ def has_dag(cls, dag_id: str, session: Session = None) -> bool:
"""
return session.query(literal(True)).filter(cls.dag_id == dag_id).first() is not None

@classmethod
@provide_session
def get_dag(cls, dag_id: str, session: Session = None) -> Optional['SerializedDAG']:
row = cls.get(dag_id, session=session)
if row:
return row.dag
return None

@classmethod
@provide_session
def get(cls, dag_id: str, session: Session = None) -> Optional['SerializedDagModel']:
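
Taken together, SerializedDagModel.get_dag and the --local branch in task_run mean a task process can hydrate its DAG from the metadata database without parsing DAG files. A hedged usage sketch; the dag_id and task_id are illustrative and a serialized row must already exist:

    from airflow.models.serialized_dag import SerializedDagModel

    # Returns a SerializedDAG (a DAG subclass) rebuilt from the serialized_dag table,
    # or None if the DAG has never been serialized.
    dag = SerializedDagModel.get_dag("example_bash_operator")
    if dag is not None:
        task = dag.get_task("also_run_this")  # the usual DAG APIs work on the deserialized object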