Skip to content

Commit

Permalink
Rewrite DAG run retrieval in task command (#20737)
Browse files Browse the repository at this point in the history
  • Loading branch information
uranusjr authored Jan 8, 2022
1 parent 384fa4a commit 7947b72
Showing 1 changed file with 63 additions and 24 deletions.
87 changes: 63 additions & 24 deletions airflow/cli/commands/task_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,24 +16,27 @@
# specific language governing permissions and limitations
# under the License.
"""Task sub-commands"""
import datetime
import importlib
import json
import logging
import os
import textwrap
from contextlib import contextmanager, redirect_stderr, redirect_stdout, suppress
from contextlib import contextmanager, redirect_stderr, redirect_stdout
from typing import List, Optional

from pendulum.parsing.exceptions import ParserError
from sqlalchemy.orm.exc import NoResultFound
from sqlalchemy.orm.session import Session

from airflow import settings
from airflow.cli.simple_table import AirflowConsole
from airflow.configuration import conf
from airflow.exceptions import AirflowException, DagRunNotFound
from airflow.exceptions import AirflowException, DagRunNotFound, TaskInstanceNotFound
from airflow.executors.executor_loader import ExecutorLoader
from airflow.jobs.local_task_job import LocalTaskJob
from airflow.models import DagPickle, TaskInstance
from airflow.models.baseoperator import BaseOperator
from airflow.models.dag import DAG
from airflow.models.dagrun import DagRun
from airflow.models.xcom import IN_MEMORY_DAGRUN_ID
Expand All @@ -50,46 +53,82 @@
from airflow.utils.dates import timezone
from airflow.utils.log.logging_mixin import StreamLogWriter
from airflow.utils.net import get_hostname
from airflow.utils.session import create_session, provide_session


def _get_dag_run(dag, exec_date_or_run_id, create_if_necessary, session):
from airflow.utils.session import NEW_SESSION, create_session, provide_session


def _get_dag_run(
    *,
    dag: DAG,
    exec_date_or_run_id: str,
    create_if_necessary: bool,
    session: Session,
) -> DagRun:
    """Try to retrieve a DAG run from a string representing either a run ID or logical date.

    This checks DAG runs like this:

    1. If the input ``exec_date_or_run_id`` matches a DAG run ID, return the run.
    2. Try to parse the input as a date. If that works, and the resulting
       date matches a DAG run's logical date, return the run.
    3. If ``create_if_necessary`` is *False* and the input works for neither of
       the above, raise ``DagRunNotFound``.
    4. Try to create a new DAG run. If the input looks like a date, use it as
       the logical date; otherwise use it as a run ID and set the logical date
       to the current time.

    :param dag: DAG whose runs are searched.
    :param exec_date_or_run_id: Either a run ID or a string parseable as a date.
    :param create_if_necessary: If *True*, build a new ``DagRun`` object when no
        existing run matches instead of raising. The new object is only
        constructed here; this function does not add it to ``session``.
    :param session: Database session used for the look-ups.
    :return: The matching DAG run, or a newly constructed one.
    :raises DagRunNotFound: If nothing matches and ``create_if_necessary`` is *False*.
    """
    # Step 1: the input is tried as a run ID first.
    dag_run = dag.get_dagrun(run_id=exec_date_or_run_id, session=session)
    if dag_run:
        return dag_run

    # Step 2: fall back to interpreting the input as a logical date. A string
    # that does not parse is not an error yet; it may still be used as the run
    # ID of a new run in step 4.
    try:
        execution_date: Optional[datetime.datetime] = timezone.parse(exec_date_or_run_id)
    except (ParserError, TypeError):
        execution_date = None

    # NOTE(review): when parsing failed, ``execution_date`` is None here and the
    # query is still issued; it is expected to match nothing and fall through
    # to NoResultFound -- confirm against the DagRun column definition.
    try:
        return (
            session.query(DagRun)
            .filter(DagRun.dag_id == dag.dag_id, DagRun.execution_date == execution_date)
            .one()
        )
    except NoResultFound:
        # Step 3: nothing matched and the caller did not ask for a new run.
        # ``from None`` hides the SQLAlchemy exception from the traceback.
        if not create_if_necessary:
            raise DagRunNotFound(
                f"DagRun for {dag.dag_id} with run_id or execution_date of {exec_date_or_run_id!r} not found"
            ) from None

    # Step 4: construct a new run from whichever interpretation of the input applies.
    if execution_date is not None:
        return DagRun(dag.dag_id, run_id=IN_MEMORY_DAGRUN_ID, execution_date=execution_date)
    return DagRun(dag.dag_id, run_id=exec_date_or_run_id, execution_date=timezone.utcnow())


@provide_session
def _get_ti(
    task: BaseOperator,
    exec_date_or_run_id: str,
    *,
    create_if_necessary: bool = False,
    session: Session = NEW_SESSION,
) -> TaskInstance:
    """Get the task instance of ``task`` for the DAG run given by run ID or logical date.

    The DAG run is resolved with ``_get_dag_run``, and the task instance is then
    looked up on that run.

    :param task: Task whose task instance is wanted; its ``dag`` is used to
        find the DAG run.
    :param exec_date_or_run_id: Either a run ID or a string parseable as a date.
    :param create_if_necessary: If *True*, construct a new DAG run and/or task
        instance when none exists instead of raising.
    :param session: Database session passed on to the DAG run look-up.
    :return: The existing or newly constructed task instance, refreshed from ``task``.
    :raises DagRunNotFound: Propagated from ``_get_dag_run`` when no DAG run
        matches and ``create_if_necessary`` is *False*.
    :raises TaskInstanceNotFound: If the DAG run has no task instance for
        ``task`` and ``create_if_necessary`` is *False*.
    """
    dag_run = _get_dag_run(
        dag=task.dag,
        exec_date_or_run_id=exec_date_or_run_id,
        create_if_necessary=create_if_necessary,
        session=session,
    )

    # The look-up result is kept in a separate, possibly-None name so that
    # ``ti`` is always a TaskInstance below.
    ti_or_none = dag_run.get_task_instance(task.task_id)
    if ti_or_none is None:
        if not create_if_necessary:
            raise TaskInstanceNotFound(
                f"TaskInstance for {task.dag.dag_id}, {task.task_id} with "
                f"run_id or execution_date of {exec_date_or_run_id!r} not found"
            )
        # Build a new task instance and attach the DAG run directly, since the
        # run may itself have just been constructed by _get_dag_run.
        ti = TaskInstance(task, run_id=dag_run.run_id)
        ti.dag_run = dag_run
    else:
        ti = ti_or_none
    # Done on both paths, for existing and newly built task instances alike.
    ti.refresh_from_task(task)
    return ti

Expand Down

0 comments on commit 7947b72

Please sign in to comment.