From 7947b72eee61a4596c5d8667f8442d32dcbf3f6d Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Sat, 8 Jan 2022 17:09:02 +0800 Subject: [PATCH] Rewrite DAG run retrieval in task command (#20737) --- airflow/cli/commands/task_command.py | 87 ++++++++++++++++++++-------- 1 file changed, 63 insertions(+), 24 deletions(-) diff --git a/airflow/cli/commands/task_command.py b/airflow/cli/commands/task_command.py index 9ea4f4daf5bbe..9458d05aed53d 100644 --- a/airflow/cli/commands/task_command.py +++ b/airflow/cli/commands/task_command.py @@ -16,24 +16,27 @@ # specific language governing permissions and limitations # under the License. """Task sub-commands""" +import datetime import importlib import json import logging import os import textwrap -from contextlib import contextmanager, redirect_stderr, redirect_stdout, suppress +from contextlib import contextmanager, redirect_stderr, redirect_stdout from typing import List, Optional from pendulum.parsing.exceptions import ParserError from sqlalchemy.orm.exc import NoResultFound +from sqlalchemy.orm.session import Session from airflow import settings from airflow.cli.simple_table import AirflowConsole from airflow.configuration import conf -from airflow.exceptions import AirflowException, DagRunNotFound +from airflow.exceptions import AirflowException, DagRunNotFound, TaskInstanceNotFound from airflow.executors.executor_loader import ExecutorLoader from airflow.jobs.local_task_job import LocalTaskJob from airflow.models import DagPickle, TaskInstance +from airflow.models.baseoperator import BaseOperator from airflow.models.dag import DAG from airflow.models.dagrun import DagRun from airflow.models.xcom import IN_MEMORY_DAGRUN_ID @@ -50,46 +53,82 @@ from airflow.utils.dates import timezone from airflow.utils.log.logging_mixin import StreamLogWriter from airflow.utils.net import get_hostname -from airflow.utils.session import create_session, provide_session - - -def _get_dag_run(dag, exec_date_or_run_id, create_if_necessary, session): +from airflow.utils.session import NEW_SESSION, create_session, provide_session + + +def _get_dag_run( + *, + dag: DAG, + exec_date_or_run_id: str, + create_if_necessary: bool, + session: Session, +) -> DagRun: + """Try to retrieve a DAG run from a string representing either a run ID or logical date. + + This checks DAG runs like this: + + 1. If the input ``exec_date_or_run_id`` matches a DAG run ID, return the run. + 2. Try to parse the input as a date. If that works, and the resulting + date matches a DAG run's logical date, return the run. + 3. If ``create_if_necessary`` is *False* and the input works for neither of + the above, raise ``DagRunNotFound``. + 4. Try to create a new DAG run. If the input looks like a date, use it as + the logical date; otherwise use it as a run ID and set the logical date + to the current time. + """ dag_run = dag.get_dagrun(run_id=exec_date_or_run_id, session=session) if dag_run: return dag_run - execution_date = None - with suppress(ParserError, TypeError): - execution_date = timezone.parse(exec_date_or_run_id) + try: + execution_date: Optional[datetime.datetime] = timezone.parse(exec_date_or_run_id) + except (ParserError, TypeError): + execution_date = None - if create_if_necessary and not execution_date: - return DagRun(dag_id=dag.dag_id, run_id=exec_date_or_run_id) try: return ( session.query(DagRun) - .filter( - DagRun.dag_id == dag.dag_id, - DagRun.execution_date == execution_date, - ) + .filter(DagRun.dag_id == dag.dag_id, DagRun.execution_date == execution_date) .one() ) except NoResultFound: - if create_if_necessary: - return DagRun(dag.dag_id, run_id=IN_MEMORY_DAGRUN_ID, execution_date=execution_date) - raise DagRunNotFound( - f"DagRun for {dag.dag_id} with run_id or execution_date of {exec_date_or_run_id!r} not found" - ) from None + if not create_if_necessary: + raise DagRunNotFound( + f"DagRun for {dag.dag_id} with run_id or execution_date of {exec_date_or_run_id!r} not found" + ) from None + + if execution_date is not None: + return DagRun(dag.dag_id, run_id=IN_MEMORY_DAGRUN_ID, execution_date=execution_date) + return DagRun(dag.dag_id, run_id=exec_date_or_run_id, execution_date=timezone.utcnow()) @provide_session -def _get_ti(task, exec_date_or_run_id, create_if_necessary=False, session=None): +def _get_ti( + task: BaseOperator, + exec_date_or_run_id: str, + *, + create_if_necessary: bool = False, + session: Session = NEW_SESSION, +) -> TaskInstance: """Get the task instance through DagRun.run_id, if that fails, get the TI the old way""" - dag_run = _get_dag_run(task.dag, exec_date_or_run_id, create_if_necessary, session) + dag_run = _get_dag_run( + dag=task.dag, + exec_date_or_run_id=exec_date_or_run_id, + create_if_necessary=create_if_necessary, + session=session, + ) - ti = dag_run.get_task_instance(task.task_id) - if not ti and create_if_necessary: + ti_or_none = dag_run.get_task_instance(task.task_id) + if ti_or_none is None: + if not create_if_necessary: + raise TaskInstanceNotFound( + f"TaskInstance for {task.dag.dag_id}, {task.task_id} with " + f"run_id or execution_date of {exec_date_or_run_id!r} not found" + ) ti = TaskInstance(task, run_id=dag_run.run_id) ti.dag_run = dag_run + else: + ti = ti_or_none ti.refresh_from_task(task) return ti