From f6f676657ac9187c196a2c1d74732e7f30fde819 Mon Sep 17 00:00:00 2001
From: Tzu-ping Chung
Date: Wed, 8 Sep 2021 21:39:51 +0800
Subject: [PATCH] Make next_dagrun_info take a data interval

This modifies the Timetable.next_dagrun_info() to take the previous
run's data interval instead of logical date. DAG.next_dagrun_info() is
modified to accept 'datetime | DataInterval | None' for compatibility,
but the datetime form is deprecated. When receiving a datetime, a
compatibility method DAG.infer_automated_data_interval() is called to
reverse-infer the data interval from the logical date, which should
work for all DAGs prior to AIP-39 implementation.

All code paths that need DAG.infer_automated_data_interval() become
deprecated:

* DAG.following_schedule()
* DAG.next_dagrun_info() with a datetime argument
* DAG.create_dagrun() without passing in a data interval
* DagModel.calculate_dagrun_date_fields() with a datetime argument

All existing usages of the newly deprecated code paths are rewritten
to avoid triggering the deprecation warning, the most significant one
being DAG.following_schedule(). Most of its usages are to calculate
the end of the interval, replaced by data_interval.end instead (except
those already deprecated, which continue to use
DAG.following_schedule() to avoid any unintended change).

For clarity, Timetable.infer_data_interval() is renamed to
Timetable.infer_manual_data_interval().
---
 .../endpoints/dag_run_endpoint.py             |   4 +-
 airflow/cli/commands/dag_command.py           |  47 ++-
 airflow/dag_processing/processor.py           |  25 +-
 airflow/jobs/scheduler_job.py                 |  25 +-
 airflow/models/dag.py                         | 248 ++++++++++++---
 airflow/models/dagrun.py                      |  11 +-
 airflow/models/taskinstance.py                |  40 ++-
 airflow/operators/latest_only.py              |  18 +-
 .../providers/google/cloud/operators/gcs.py   |  21 +-
 airflow/providers/google/cloud/sensors/gcs.py |  10 +-
 airflow/sensors/time_delta.py                 |  15 +-
 airflow/timetables/base.py                    |   6 +-
 airflow/timetables/interval.py                |  10 +-
 airflow/timetables/simple.py                  |   8 +-
 airflow/utils/session.py                      |   6 +-
 airflow/www/views.py                          |   7 +-
 docs/apache-airflow/templates-ref.rst         |   4 +-
 .../endpoints/test_log_endpoint.py            |  14 +-
 tests/cli/commands/test_dag_command.py        |   5 +-
 tests/conftest.py                             |  59 +++-
 tests/jobs/test_local_task_job.py             |   6 +-
 tests/jobs/test_scheduler_job.py              | 300 ++++++++----------
 tests/models/test_dag.py                      |  31 +-
 tests/models/test_dagrun.py                   |  24 +-
 tests/models/test_taskinstance.py             |   9 +-
 .../apache/druid/operators/test_druid.py      |   5 +-
 .../apache/kylin/operators/test_kylin_cube.py |   2 +-
 .../spark/operators/test_spark_submit.py      |   2 +-
 tests/sensors/test_external_task_sensor.py    |   3 +-
 .../perf/scheduler_dag_execution_timing.py    |  19 +-
 .../deps/test_runnable_exec_date_dep.py       |   3 +
 tests/utils/log/test_log_reader.py            |   2 +
 tests/utils/test_log_handlers.py              |   1 +
 tests/www/views/test_views_log.py             |   4 +-
 34 files changed, 607 insertions(+), 387 deletions(-)

diff --git a/airflow/api_connexion/endpoints/dag_run_endpoint.py b/airflow/api_connexion/endpoints/dag_run_endpoint.py
index e816aac27e157..ab7805e306463 100644
--- a/airflow/api_connexion/endpoints/dag_run_endpoint.py
+++ b/airflow/api_connexion/endpoints/dag_run_endpoint.py
@@ -257,10 +257,12 @@ def post_dag_run(dag_id, session):
         .first()
     )
     if not dagrun_instance:
-        dag_run = current_app.dag_bag.get_dag(dag_id).create_dagrun(
+        dag = current_app.dag_bag.get_dag(dag_id)
+        dag_run = dag.create_dagrun(
             run_type=DagRunType.MANUAL,
             run_id=run_id,
             execution_date=logical_date,
+
data_interval=dag.timetable.infer_manual_data_interval(run_after=logical_date), state=State.QUEUED, conf=post_body.get("conf"), external_trigger=True, diff --git a/airflow/cli/commands/dag_command.py b/airflow/cli/commands/dag_command.py index fffcf68d57541..75503e780bcf0 100644 --- a/airflow/cli/commands/dag_command.py +++ b/airflow/cli/commands/dag_command.py @@ -23,8 +23,10 @@ import signal import subprocess import sys +from typing import Optional from graphviz.dot import Dot +from sqlalchemy.sql.functions import func from airflow import settings from airflow.api.client import get_current_api_client @@ -255,26 +257,37 @@ def dag_next_execution(args): if dag.get_is_paused(): print("[INFO] Please be reminded this DAG is PAUSED now.", file=sys.stderr) - latest_execution_date = dag.get_latest_execution_date() - if latest_execution_date: - next_execution_dttm = dag.following_schedule(latest_execution_date) + with create_session() as session: + max_date_subq = ( + session.query(func.max(DagRun.execution_date).label("max_date")) + .filter(DagRun.dag_id == dag.dag_id) + .subquery() + ) + max_date_run: Optional[DagRun] = ( + session.query(DagRun) + .filter(DagRun.dag_id == dag.dag_id, DagRun.execution_date == max_date_subq.c.max_date) + .one_or_none() + ) - if next_execution_dttm is None: - print( - "[WARN] No following schedule can be found. " - + "This DAG may have schedule interval '@once' or `None`.", - file=sys.stderr, - ) + if max_date_run is None: + print("[WARN] Only applicable when there is execution record found for the DAG.", file=sys.stderr) print(None) - else: - print(next_execution_dttm.isoformat()) - - for _ in range(1, args.num_executions): - next_execution_dttm = dag.following_schedule(next_execution_dttm) - print(next_execution_dttm.isoformat()) - else: - print("[WARN] Only applicable when there is execution record found for the DAG.", file=sys.stderr) + return + + next_info = dag.next_dagrun_info(dag.get_run_data_interval(max_date_run), restricted=False) + if next_info is None: + print( + "[WARN] No following schedule can be found. 
" + "This DAG may have schedule interval '@once' or `None`.", + file=sys.stderr, + ) print(None) + return + + print(next_info.logical_date.isoformat()) + for _ in range(1, args.num_executions): + next_info = dag.next_dagrun_info(next_info.data_interval, restricted=False) + print(next_info.logical_date.isoformat()) @cli_utils.action_logging diff --git a/airflow/dag_processing/processor.py b/airflow/dag_processing/processor.py index ce6b9553ec462..53338da0f6daa 100644 --- a/airflow/dag_processing/processor.py +++ b/airflow/dag_processing/processor.py @@ -28,12 +28,14 @@ from setproctitle import setproctitle from sqlalchemy import func, or_ +from sqlalchemy.orm import eagerload from sqlalchemy.orm.session import Session from airflow import models, settings from airflow.configuration import conf from airflow.exceptions import AirflowException, TaskNotFound -from airflow.models import DAG, DagModel, SlaMiss, errors +from airflow.models import SlaMiss, errors +from airflow.models.dag import DAG, DagModel from airflow.models.dagbag import DagBag from airflow.stats import Stats from airflow.utils import timezone @@ -391,6 +393,7 @@ def manage_slas(self, dag: DAG, session: Session = None) -> None: max_tis: Iterator[TI] = ( session.query(TI) + .options(eagerload(TI.dag_run)) .join(TI.dag_run) .filter( TI.dag_id == dag.dag_id, @@ -411,14 +414,20 @@ def manage_slas(self, dag: DAG, session: Session = None) -> None: f"{type(task.sla)} in {task.dag_id}:{task.task_id}" ) - dttm = dag.following_schedule(ti.execution_date) - while dttm < ts: - following_schedule = dag.following_schedule(dttm) - if following_schedule + task.sla < ts: - session.merge( - SlaMiss(task_id=ti.task_id, dag_id=ti.dag_id, execution_date=dttm, timestamp=ts) + sla_misses = [] + next_info = dag.next_dagrun_info(dag.get_run_data_interval(ti.dag_run), restricted=False) + while next_info.logical_date < ts: + next_info = dag.next_dagrun_info(next_info.data_interval, restricted=False) + if next_info.logical_date + task.sla < ts: + sla_miss = SlaMiss( + task_id=ti.task_id, + dag_id=ti.dag_id, + execution_date=next_info.logical_date, + timestamp=ts, ) - dttm = dag.following_schedule(dttm) + sla_misses.append(sla_miss) + if sla_misses: + session.add_all(sla_misses) session.commit() slas: List[SlaMiss] = ( diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py index ec10c054f1d32..cd5103479bd14 100644 --- a/airflow/jobs/scheduler_job.py +++ b/airflow/jobs/scheduler_job.py @@ -849,6 +849,8 @@ def _create_dag_runs(self, dag_models: Iterable[DagModel], session: Session) -> self.log.exception("DAG '%s' not found in serialized_dag table", dag_model.dag_id) continue dag_hash = self.dagbag.dags_hash.get(dag.dag_id) + + data_interval = dag.get_next_data_interval(dag_model) # Explicitly check if the DagRun already exists. This is an edge case # where a Dag Run is created but `DagModel.next_dagrun` and `DagModel.next_dagrun_create_after` # are not updated. @@ -858,19 +860,18 @@ def _create_dag_runs(self, dag_models: Iterable[DagModel], session: Session) -> # create a new one. This is so that in the next Scheduling loop we try to create new runs # instead of falling in a loop of Integrity Error. 
if (dag.dag_id, dag_model.next_dagrun) not in existing_dagruns: - dag.create_dagrun( run_type=DagRunType.SCHEDULED, execution_date=dag_model.next_dagrun, state=State.QUEUED, - data_interval=dag_model.next_dagrun_data_interval, + data_interval=data_interval, external_trigger=False, session=session, dag_hash=dag_hash, creating_job_id=self.id, ) queued_runs_of_dags[dag_model.dag_id] += 1 - dag_model.calculate_dagrun_date_fields(dag, dag_model.next_dagrun) + dag_model.calculate_dagrun_date_fields(dag, data_interval) # TODO[HA]: Should we do a session.flush() so we don't have to keep lots of state/object in # memory for larger dags? or expunge_all() @@ -894,16 +895,18 @@ def _start_queued_dagruns( .all(), ) - def _update_state(dag_run): + def _update_state(dag: DAG, dag_run: DagRun): dag_run.state = State.RUNNING dag_run.start_date = timezone.utcnow() - expected_start_date = dag.following_schedule(dag_run.execution_date) - if expected_start_date: + if dag.timetable.periodic: + # TODO: Logically, this should be DagRunInfo.run_after, but the + # information is not stored on a DagRun, only before the actual + # execution on DagModel.next_dagrun_create_after. We should add + # a field on DagRun for this instead of relying on the run + # always happening immediately after the data interval. + expected_start_date = dag.get_run_data_interval(dag_run).end schedule_delay = dag_run.start_date - expected_start_date - Stats.timing( - f'dagrun.schedule_delay.{dag.dag_id}', - schedule_delay, - ) + Stats.timing(f'dagrun.schedule_delay.{dag.dag_id}', schedule_delay) for dag_run in dag_runs: @@ -923,7 +926,7 @@ def _update_state(dag_run): ) else: active_runs_of_dags[dag_run.dag_id] += 1 - _update_state(dag_run) + _update_state(dag, dag_run) def _schedule_dag_run( self, diff --git a/airflow/models/dag.py b/airflow/models/dag.py index dc15d217e88fd..510396d776999 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -31,6 +31,7 @@ from inspect import signature from typing import ( TYPE_CHECKING, + Any, Callable, Collection, Dict, @@ -70,7 +71,7 @@ from airflow.models.taskinstance import Context, TaskInstance, TaskInstanceKey, clear_task_instances from airflow.security import permissions from airflow.stats import Stats -from airflow.timetables.base import DagRunInfo, TimeRestriction, Timetable +from airflow.timetables.base import DagRunInfo, DataInterval, TimeRestriction, Timetable from airflow.timetables.interval import CronDataIntervalTimetable, DeltaDataIntervalTimetable from airflow.timetables.simple import NullTimetable, OnceTimetable from airflow.typing_compat import Literal, RePatternType @@ -106,6 +107,44 @@ DEFAULT_SCHEDULE_INTERVAL = timedelta(days=1) +class InconsistentDataInterval(AirflowException): + """Exception raised when a model populates data interval fields incorrectly. + + The data interval fields should either both be None (for runs scheduled + prior to AIP-39), or both be datetime (for runs scheduled after AIP-39 is + implemented). This is raised if exactly one of the fields is None. 
+ """ + + _template = ( + "Inconsistent {cls}: {start[0]}={start[1]!r}, {end[0]}={end[1]!r}, " + "they must be either both None or both datetime" + ) + + def __init__(self, instance: Any, start_field_name: str, end_field_name: str) -> None: + self._class_name = type(instance).__name__ + self._start_field = (start_field_name, getattr(instance, start_field_name)) + self._end_field = (end_field_name, getattr(instance, end_field_name)) + + def __str__(self) -> str: + return self._template.format(cls=self._class_name, start=self._start_field, end=self._end_field) + + +def _get_model_data_interval( + instance: Any, + start_field_name: str, + end_field_name: str, +) -> Optional[DataInterval]: + start = timezone.coerce_datetime(getattr(instance, start_field_name)) + end = timezone.coerce_datetime(getattr(instance, end_field_name)) + if start is None: + if end is not None: + raise InconsistentDataInterval(instance, start_field_name, end_field_name) + return None + elif end is None: + raise InconsistentDataInterval(instance, start_field_name, end_field_name) + return DataInterval(start, end) + + def create_timetable(interval: ScheduleIntervalArg, timezone: tzinfo) -> Timetable: """Create a Timetable instance from a ``schedule_interval`` argument.""" if interval is ScheduleIntervalArgNotSet: @@ -539,10 +578,13 @@ def following_schedule(self, dttm): :param dttm: utc datetime :return: utc datetime """ - next_info = self.timetable.next_dagrun_info( - last_automated_dagrun=pendulum.instance(dttm), - restriction=TimeRestriction(earliest=None, latest=None, catchup=True), + warnings.warn( + "`DAG.following_schedule()` is deprecated. Use `DAG.next_dagrun_info(restricted=False)` instead.", + category=DeprecationWarning, + stacklevel=2, ) + data_interval = self.infer_automated_data_interval(timezone.coerce_datetime(dttm)) + next_info = self.next_dagrun_info(data_interval, restricted=False) if next_info is None: return None return next_info.data_interval.start @@ -557,11 +599,77 @@ def previous_schedule(self, dttm): ) if not isinstance(self.timetable, _DataIntervalTimetable): return None - return self.timetable._get_prev(pendulum.instance(dttm)) + return self.timetable._get_prev(timezone.coerce_datetime(dttm)) + + def get_next_data_interval(self, dag_model: "DagModel") -> DataInterval: + """Get the data interval of the next scheduled run. + + For compatibility, this method infers the data interval from the DAG's + schedule if the run does not have an explicit one set, which is possible for + runs created prior to AIP-39. + + This function is private to Airflow core and should not be depended as a + part of the Python API. + + :meta private: + """ + if self.dag_id != dag_model.dag_id: + raise ValueError(f"Arguments refer to different DAGs: {self.dag_id} != {dag_model.dag_id}") + data_interval = dag_model.next_dagrun_data_interval + if data_interval is not None: + return data_interval + # Compatibility: runs scheduled before AIP-39 implementation don't have an + # explicit data interval. Try to infer from the logical date. + return self.infer_automated_data_interval(dag_model.next_dagrun) + + def get_run_data_interval(self, run: DagRun) -> DataInterval: + """Get the data interval of this run. + + For compatibility, this method infers the data interval from the DAG's + schedule if the run does not have an explicit one set, which is possible for + runs created prior to AIP-39. + + This function is private to Airflow core and should not be depended as a + part of the Python API. 
+ + :meta private: + """ + if run.dag_id is not None and run.dag_id != self.dag_id: + raise ValueError(f"Arguments refer to different DAGs: {self.dag_id} != {run.dag_id}") + data_interval = _get_model_data_interval(run, "data_interval_start", "data_interval_end") + if data_interval is not None: + return data_interval + # Compatibility: runs created before AIP-39 implementation don't have an + # explicit data interval. Try to infer from the logical date. + return self.infer_automated_data_interval(run.execution_date) + + def infer_automated_data_interval(self, logical_date: datetime) -> DataInterval: + """Infer a data interval for a run against this DAG. + + This method is used to bridge runs created prior to AIP-39 + implementation, which do not have an explicit data interval. Therefore, + this method only considers ``schedule_interval`` values valid prior to + Airflow 2.2. + + DO NOT use this method is there is a known data interval. + """ + timetable_type = type(self.timetable) + if issubclass(timetable_type, (NullTimetable, OnceTimetable)): + return DataInterval.exact(timezone.coerce_datetime(logical_date)) + start = timezone.coerce_datetime(logical_date) + if issubclass(timetable_type, CronDataIntervalTimetable): + end = cast(CronDataIntervalTimetable, self.timetable)._get_next(start) + elif issubclass(timetable_type, DeltaDataIntervalTimetable): + end = cast(DeltaDataIntervalTimetable, self.timetable)._get_next(start) + else: + raise ValueError(f"Not a valid timetable: {self.timetable!r}") + return DataInterval(start, end) def next_dagrun_info( self, - date_last_automated_dagrun: Optional[pendulum.DateTime], + last_automated_dagrun: Union[None, datetime, DataInterval], + *, + restricted: bool = True, ) -> Optional[DagRunInfo]: """Get information about the next DagRun of this dag after ``date_last_automated_dagrun``. @@ -575,19 +683,33 @@ def next_dagrun_info( :param date_last_automated_dagrun: The ``max(execution_date)`` of existing "automated" DagRuns for this dag (scheduled or backfill, but not manual). + :param restricted: If set to *False* (default is *True*), ignore + ``start_date``, ``end_date``, and ``catchup`` specified on the DAG + or tasks. :return: DagRunInfo of the next dagrun, or None if a dagrun is not going to be scheduled. """ # Never schedule a subdag. It will be scheduled by its parent dag. if self.is_subdag: return None - # XXX: The timezone.coerce_datetime calls should not be necessary since - # the function annotation suggests it only accepts pendulum.DateTime, - # and someone is passing datetime.datetime into this function. We should - # fix whatever is doing that. + if isinstance(last_automated_dagrun, datetime): + warnings.warn( + "Passing a datetime to DAG.next_dagrun_info is deprecated. 
Use a DataInterval instead.", + DeprecationWarning, + stacklevel=2, + ) + data_interval = self.infer_automated_data_interval( + timezone.coerce_datetime(last_automated_dagrun) + ) + else: + data_interval = last_automated_dagrun + if restricted: + restriction = self._time_restriction + else: + restriction = TimeRestriction(earliest=None, latest=None, catchup=True) return self.timetable.next_dagrun_info( - last_automated_dagrun=timezone.coerce_datetime(date_last_automated_dagrun), - restriction=self._time_restriction, + last_automated_data_interval=data_interval, + restriction=restriction, ) def next_dagrun_after_date(self, date_last_automated_dagrun: Optional[pendulum.DateTime]): @@ -596,7 +718,11 @@ def next_dagrun_after_date(self, date_last_automated_dagrun: Optional[pendulum.D category=DeprecationWarning, stacklevel=2, ) - info = self.next_dagrun_info(date_last_automated_dagrun) + if date_last_automated_dagrun is None: + data_interval = None + else: + data_interval = self.infer_automated_data_interval(date_last_automated_dagrun) + info = self.next_dagrun_info(data_interval) if info is None: return None return info.run_after @@ -658,7 +784,7 @@ def iter_dagrun_infos_between( if self.is_subdag: align = False - info = self.timetable.next_dagrun_info(last_automated_dagrun=None, restriction=restriction) + info = self.timetable.next_dagrun_info(last_automated_data_interval=None, restriction=restriction) if info is None: # No runs to be scheduled between the user-supplied timeframe. But # if align=False, "invent" a data interval for the timeframe itself. @@ -675,7 +801,7 @@ def iter_dagrun_infos_between( while info is not None: yield info info = self.timetable.next_dagrun_info( - last_automated_dagrun=info.logical_date, + last_automated_data_interval=info.data_interval, restriction=restriction, ) @@ -700,7 +826,7 @@ def get_run_dates(self, start_date, end_date=None): if end_date is None: latest = pendulum.now(timezone.utc) else: - latest = pendulum.instance(end_date) + latest = timezone.coerce_datetime(end_date) return [info.logical_date for info in self.iter_dagrun_infos_between(earliest, latest)] def normalize_schedule(self, dttm): @@ -709,14 +835,16 @@ def normalize_schedule(self, dttm): category=DeprecationWarning, stacklevel=2, ) - following = self.following_schedule(dttm) - - # in case of @once - if not following: + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + following = self.following_schedule(dttm) + if not following: # in case of @once return dttm - if self.previous_schedule(following) != dttm: + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + previous_of_following = self.previous_schedule(following) + if previous_of_following != dttm: return following - return dttm @provide_session @@ -2150,15 +2278,22 @@ def create_dagrun( "Creating DagRun needs either `run_id` or both `run_type` and `execution_date`" ) - if run_type == DagRunType.MANUAL and data_interval is None and execution_date is not None: - data_interval = self.timetable.infer_data_interval( - run_after=timezone.coerce_datetime(execution_date), + logical_date = timezone.coerce_datetime(execution_date) + if data_interval is None and logical_date is not None: + warnings.warn( + "Calling `DAG.create_dagrun()` without an explicit data interval is deprecated", + DeprecationWarning, + stacklevel=3, ) + if run_type == DagRunType.MANUAL: + data_interval = self.timetable.infer_manual_data_interval(run_after=logical_date) + else: + data_interval = 
self.infer_automated_data_interval(logical_date) run = DagRun( dag_id=self.dag_id, run_id=run_id, - execution_date=execution_date, + execution_date=logical_date, start_date=start_date, external_trigger=external_trigger, conf=conf, @@ -2214,7 +2349,7 @@ def bulk_write_to_db(cls, dags: Collection["DAG"], session=None): .options(joinedload(DagModel.tags, innerjoin=False)) .filter(DagModel.dag_id.in_(dag_ids)) ) - orm_dags = with_row_locks(query, of=DagModel, session=session).all() + orm_dags: List[DagModel] = with_row_locks(query, of=DagModel, session=session).all() existing_dag_ids = {orm_dag.dag_id for orm_dag in orm_dags} missing_dag_ids = dag_ids.difference(existing_dag_ids) @@ -2230,18 +2365,20 @@ def bulk_write_to_db(cls, dags: Collection["DAG"], session=None): orm_dags.append(orm_dag) # Get the latest dag run for each existing dag as a single query (avoid n+1 query) - most_recent_dag_runs = dict( - session.query(DagRun.dag_id, func.max_(DagRun.execution_date)) + most_recent_subq = ( + session.query(DagRun.dag_id, func.max(DagRun.execution_date).label("max_execution_date")) .filter( DagRun.dag_id.in_(existing_dag_ids), - or_( - DagRun.run_type == DagRunType.BACKFILL_JOB, - DagRun.run_type == DagRunType.SCHEDULED, - ), + or_(DagRun.run_type == DagRunType.BACKFILL_JOB, DagRun.run_type == DagRunType.SCHEDULED), ) .group_by(DagRun.dag_id) - .all() + .subquery() ) + most_recent_runs_iter = session.query(DagRun).filter( + DagRun.dag_id == most_recent_subq.c.dag_id, + DagRun.execution_date == most_recent_subq.c.max_execution_date, + ) + most_recent_runs = {run.dag_id: run for run in most_recent_runs_iter} filelocs = [] @@ -2266,10 +2403,12 @@ def bulk_write_to_db(cls, dags: Collection["DAG"], session=None): orm_dag.max_active_runs = dag.max_active_runs orm_dag.has_task_concurrency_limits = any(t.max_active_tis_per_dag is not None for t in dag.tasks) - orm_dag.calculate_dagrun_date_fields( - dag, - most_recent_dag_runs.get(dag.dag_id), - ) + run: Optional[DagRun] = most_recent_runs.get(dag.dag_id) + if run is None: + data_interval = None + else: + data_interval = dag.get_run_data_interval(run) + orm_dag.calculate_dagrun_date_fields(dag, data_interval) for orm_tag in list(orm_dag.tags): if orm_tag.name not in orm_dag.tags: @@ -2541,17 +2680,12 @@ def __repr__(self): return f"" @property - def next_dagrun_data_interval(self) -> Optional[Tuple[datetime, datetime]]: - if self.next_dagrun_data_interval_start is None: - if self.next_dagrun_data_interval_end is not None: - raise AirflowException( - f"Inconsistent DagModel: " - f"next_dagrun_data_interval_start={self.next_dagrun_data_interval_start!r}, " - f"next_dagrun_data_interval_end={self.next_dagrun_data_interval_end!r}; " - f"they must be either both None or both datetime" - ) - return None - return (self.next_dagrun_data_interval_start, self.next_dagrun_data_interval_end) + def next_dagrun_data_interval(self) -> Optional[DataInterval]: + return _get_model_data_interval( + self, + "next_dagrun_data_interval_start", + "next_dagrun_data_interval_end", + ) @next_dagrun_data_interval.setter def next_dagrun_data_interval(self, value: Optional[Tuple[datetime, datetime]]) -> None: @@ -2694,7 +2828,9 @@ def dags_needing_dagruns(cls, session: Session): return with_row_locks(query, of=cls, session=session, **skip_locked(session=session)) def calculate_dagrun_date_fields( - self, dag: DAG, most_recent_dag_run: Optional[pendulum.DateTime] + self, + dag: DAG, + most_recent_dag_run: Union[None, datetime, DataInterval], ) -> None: """ Calculate ``next_dagrun`` 
and `next_dagrun_create_after`` @@ -2702,7 +2838,17 @@ def calculate_dagrun_date_fields( :param dag: The DAG object :param most_recent_dag_run: DateTime of most recent run of this dag, or none if not yet scheduled. """ - next_dagrun_info = dag.next_dagrun_info(most_recent_dag_run) + if isinstance(most_recent_dag_run, datetime): + warnings.warn( + "Passing a datetime to `DagModel.calculate_dagrun_date_fields` is deprecated. " + "Provide a data interval instead.", + DeprecationWarning, + stacklevel=2, + ) + most_recent_data_interval = dag.infer_automated_data_interval(most_recent_dag_run) + else: + most_recent_data_interval = most_recent_dag_run + next_dagrun_info = dag.next_dagrun_info(most_recent_data_interval) if next_dagrun_info is None: self.next_dagrun_data_interval = self.next_dagrun = self.next_dagrun_create_after = None else: diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py index 59b01a7f99501..1dacf5213438b 100644 --- a/airflow/models/dagrun.py +++ b/airflow/models/dagrun.py @@ -649,10 +649,13 @@ def _emit_true_scheduling_delay_stats_for_finished_state(self, finished_tis): ordered_tis_by_start_date.sort(key=lambda ti: ti.start_date, reverse=False) first_start_date = ordered_tis_by_start_date[0].start_date if first_start_date: - # dag.following_schedule calculates the expected start datetime for a scheduled dagrun - # i.e. a daily flow for execution date 1/1/20 actually runs on 1/2/20 hh:mm:ss, - # and ti.start_date will be 1/2/20 hh:mm:ss so the following schedule is comparison - true_delay = first_start_date - dag.following_schedule(self.execution_date) + # TODO: Logically, this should be DagRunInfo.run_after, but the + # information is not stored on a DagRun, only before the actual + # execution on DagModel.next_dagrun_create_after. We should add + # a field on DagRun for this instead of relying on the run + # always happening immediately after the data interval. + data_interval_end = dag.get_run_data_interval(self).end + true_delay = first_start_date - data_interval_end if true_delay.total_seconds() > 0: Stats.timing(f'dagrun.{dag.dag_id}.first_task_scheduling_delay', true_delay) except Exception as e: diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index e7aaea1af0d73..2a74bde0807c0 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -84,6 +84,7 @@ from airflow.stats import Stats from airflow.ti_deps.dep_context import DepContext from airflow.ti_deps.dependencies_deps import REQUEUEABLE_DEPS, RUNNING_DEPS +from airflow.timetables.base import DataInterval from airflow.typing_compat import Literal from airflow.utils import timezone from airflow.utils.email import send_email @@ -1748,6 +1749,7 @@ def get_template_context(self, session: Session = None) -> Context: if not session: session = settings.Session() task = self.task + dag: DAG = task.dag from airflow import macros integrate_macros_plugins() @@ -1757,36 +1759,40 @@ def get_template_context(self, session: Session = None) -> Context: params = {} # type: Dict[str, Any] with contextlib.suppress(AttributeError): - params.update(task.dag.params) + params.update(dag.params) if task.params: params.update(task.params) if conf.getboolean('core', 'dag_run_conf_overrides_params'): self.overwrite_params_with_dag_run_conf(params=params, dag_run=dag_run) - # DagRuns scheduled prior to Airflow 2.2 and by tests don't always have - # a data interval, and we default to execution_date for compatibility. 
- compat_interval_start = timezone.coerce_datetime(dag_run.data_interval_start or self.execution_date) - ds = compat_interval_start.strftime('%Y-%m-%d') + interval_start = dag.get_run_data_interval(dag_run).start + ds = interval_start.strftime('%Y-%m-%d') ds_nodash = ds.replace('-', '') - ts = compat_interval_start.isoformat() - ts_nodash = compat_interval_start.strftime('%Y%m%dT%H%M%S') + ts = interval_start.isoformat() + ts_nodash = interval_start.strftime('%Y%m%dT%H%M%S') ts_nodash_with_tz = ts.replace('-', '').replace(':', '') @cache # Prevent multiple database access. def _get_previous_dagrun_success() -> Optional["DagRun"]: return self.get_previous_dagrun(state=State.SUCCESS, session=session) - def get_prev_data_interval_start_success() -> Optional[pendulum.DateTime]: + def _get_previous_dagrun_data_interval_success() -> Optional["DataInterval"]: dagrun = _get_previous_dagrun_success() if dagrun is None: return None - return timezone.coerce_datetime(dagrun.data_interval_start) + return dag.get_run_data_interval(dagrun) + + def get_prev_data_interval_start_success() -> Optional[pendulum.DateTime]: + data_interval = _get_previous_dagrun_data_interval_success() + if data_interval is None: + return None + return data_interval.start def get_prev_data_interval_end_success() -> Optional[pendulum.DateTime]: - dagrun = _get_previous_dagrun_success() - if dagrun is None: + data_interval = _get_previous_dagrun_data_interval_success() + if data_interval is None: return None - return timezone.coerce_datetime(dagrun.data_interval_end) + return data_interval.end def get_prev_start_date_success() -> Optional[pendulum.DateTime]: dagrun = _get_previous_dagrun_success() @@ -1912,9 +1918,11 @@ def get_next_execution_date() -> Optional[pendulum.DateTime]: # to execution date for consistency with how execution_date is set # for manually triggered tasks, i.e. triggered_date == execution_date. 
if dag_run.external_trigger: - next_execution_date = self.execution_date + next_execution_date = dag_run.execution_date else: - next_execution_date = task.dag.following_schedule(self.execution_date) + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + next_execution_date = dag.following_schedule(self.execution_date) if next_execution_date is None: return None return timezone.coerce_datetime(next_execution_date) @@ -1937,7 +1945,7 @@ def get_prev_execution_date(): return timezone.coerce_datetime(self.execution_date) with warnings.catch_warnings(): warnings.simplefilter("ignore", DeprecationWarning) - return task.dag.previous_schedule(self.execution_date) + return dag.previous_schedule(self.execution_date) @cache def get_prev_ds() -> Optional[str]: @@ -1954,7 +1962,7 @@ def get_prev_ds_nodash() -> Optional[str]: return { 'conf': conf, - 'dag': task.dag, + 'dag': dag, 'dag_run': dag_run, 'data_interval_end': timezone.coerce_datetime(dag_run.data_interval_end), 'data_interval_start': timezone.coerce_datetime(dag_run.data_interval_start), diff --git a/airflow/operators/latest_only.py b/airflow/operators/latest_only.py index 1d44a53ffea8d..fb563e3230b63 100644 --- a/airflow/operators/latest_only.py +++ b/airflow/operators/latest_only.py @@ -19,12 +19,15 @@ This module contains an operator to run downstream tasks only for the latest scheduled DagRun """ -from typing import Dict, Iterable, Union +from typing import TYPE_CHECKING, Dict, Iterable, Union import pendulum from airflow.operators.branch import BaseBranchOperator +if TYPE_CHECKING: + from airflow.models import DAG, DagRun + class LatestOnlyOperator(BaseBranchOperator): """ @@ -43,13 +46,20 @@ class LatestOnlyOperator(BaseBranchOperator): def choose_branch(self, context: Dict) -> Union[str, Iterable[str]]: # If the DAG Run is externally triggered, then return without # skipping downstream tasks - if context['dag_run'] and context['dag_run'].external_trigger: + dag_run: "DagRun" = context["dag_run"] + if dag_run.external_trigger: self.log.info("Externally triggered DAG_Run: allowing execution to proceed.") return list(context['task'].get_direct_relative_ids(upstream=False)) + dag: "DAG" = context["dag"] + next_info = dag.next_dagrun_info(dag.get_run_data_interval(dag_run), restricted=False) now = pendulum.now('UTC') - left_window = context['dag'].following_schedule(context['execution_date']) - right_window = context['dag'].following_schedule(left_window) + + if next_info is None: + self.log.info("Last scheduled execution: allowing execution to proceed.") + return list(context['task'].get_direct_relative_ids(upstream=False)) + + left_window, right_window = next_info.data_interval self.log.info( 'Checking latest only with left_window: %s right_window: %s now: %s', left_window, diff --git a/airflow/providers/google/cloud/operators/gcs.py b/airflow/providers/google/cloud/operators/gcs.py index f38bea9237f69..e2937c5bbfcca 100644 --- a/airflow/providers/google/cloud/operators/gcs.py +++ b/airflow/providers/google/cloud/operators/gcs.py @@ -26,6 +26,7 @@ from google.api_core.exceptions import Conflict from google.cloud.exceptions import GoogleCloudError +from pendulum.datetime import DateTime from airflow.exceptions import AirflowException from airflow.models import BaseOperator @@ -793,14 +794,22 @@ def __init__( def execute(self, context: dict) -> None: # Define intervals and prefixes. 
- timespan_start = context["execution_date"] - timespan_end = context["dag"].following_schedule(timespan_start) - if timespan_end is None: + try: + timespan_start = context["data_interval_start"] + timespan_end = context["data_interval_end"] + except KeyError: # Data interval context variables are only available in Airflow 2.2+ + timespan_start = timezone.coerce_datetime(context["execution_date"]) + timespan_end = timezone.coerce_datetime(context["dag"].following_schedule(timespan_start)) + + if timespan_end is None: # Only possible in Airflow before 2.2. self.log.warning("No following schedule found, setting timespan end to max %s", timespan_end) - timespan_end = datetime.datetime.max + timespan_end = DateTime.max + elif timespan_start >= timespan_end: # Airflow 2.2 sets start == end for non-perodic schedules. + self.log.warning("DAG schedule not periodic, setting timespan end to max %s", timespan_end) + timespan_end = DateTime.max - timespan_start = timespan_start.replace(tzinfo=timezone.utc) - timespan_end = timespan_end.replace(tzinfo=timezone.utc) + timespan_start = timespan_start.in_timezone(timezone.utc) + timespan_end = timespan_end.in_timezone(timezone.utc) source_prefix_interp = GCSTimeSpanFileTransformOperator.interpolate_prefix( self.source_prefix, diff --git a/airflow/providers/google/cloud/sensors/gcs.py b/airflow/providers/google/cloud/sensors/gcs.py index b715022b32dc8..4b30f0fb92ab6 100644 --- a/airflow/providers/google/cloud/sensors/gcs.py +++ b/airflow/providers/google/cloud/sensors/gcs.py @@ -92,10 +92,14 @@ def poke(self, context: dict) -> bool: def ts_function(context): """ Default callback for the GoogleCloudStorageObjectUpdatedSensor. The default - behaviour is check for the object being updated after execution_date + - schedule_interval. + behaviour is check for the object being updated after the data interval's + end, or execution_date + interval on Airflow versions prior to 2.2 (before + AIP-39 implementation). """ - return context['dag'].following_schedule(context['execution_date']) + try: + return context["data_interval_end"] + except KeyError: + return context["dag"].following_schedule(context["execution_date"]) class GCSObjectUpdateSensor(BaseSensorOperator): diff --git a/airflow/sensors/time_delta.py b/airflow/sensors/time_delta.py index 8e08c532e6aff..8c53ff8202ec2 100644 --- a/airflow/sensors/time_delta.py +++ b/airflow/sensors/time_delta.py @@ -23,12 +23,9 @@ class TimeDeltaSensor(BaseSensorOperator): """ - Waits for a timedelta after the task's execution_date + schedule_interval. - In Airflow, the daily task stamped with ``execution_date`` - 2016-01-01 can only start running on 2016-01-02. The timedelta here - represents the time after the execution period has closed. + Waits for a timedelta after the run's data interval. - :param delta: time length to wait after execution_date before succeeding + :param delta: time length to wait after the data interval before succeeding. :type delta: datetime.timedelta """ @@ -37,8 +34,7 @@ def __init__(self, *, delta, **kwargs): self.delta = delta def poke(self, context): - dag = context['dag'] - target_dttm = dag.following_schedule(context['execution_date']) + target_dttm = context['data_interval_end'] target_dttm += self.delta self.log.info('Checking if the time (%s) has come', target_dttm) return timezone.utcnow() > target_dttm @@ -49,13 +45,12 @@ class TimeDeltaSensorAsync(TimeDeltaSensor): A drop-in replacement for TimeDeltaSensor that defers itself to avoid taking up a worker slot while it is waiting. 
 
-    :param delta: time length to wait after execution_date before succeeding
+    :param delta: time length to wait after the data interval before succeeding.
     :type delta: datetime.timedelta
     """
 
     def execute(self, context):
-        dag = context['dag']
-        target_dttm = dag.following_schedule(context['execution_date'])
+        target_dttm = context['data_interval_end']
         target_dttm += self.delta
         self.defer(trigger=DateTimeTrigger(moment=target_dttm), method_name="execute_complete")
 
diff --git a/airflow/timetables/base.py b/airflow/timetables/base.py
index 2f1d0980f6416..1ef62eaa468c4 100644
--- a/airflow/timetables/base.py
+++ b/airflow/timetables/base.py
@@ -148,7 +148,7 @@ def summary(self) -> str:
         """
         return type(self).__name__
 
-    def infer_data_interval(self, *, run_after: DateTime) -> DataInterval:
+    def infer_manual_data_interval(self, *, run_after: DateTime) -> DataInterval:
         """When a DAG run is manually triggered, infer a data interval for it.
 
         This is used for e.g. manually-triggered runs, where ``run_after`` would
@@ -160,14 +160,14 @@ def next_dagrun_info(
         self,
         *,
-        last_automated_dagrun: Optional[DateTime],
+        last_automated_data_interval: Optional[DataInterval],
         restriction: TimeRestriction,
     ) -> Optional[DagRunInfo]:
         """Provide information to schedule the next DagRun.
 
         The default implementation raises ``NotImplementedError``.
 
-        :param last_automated_dagrun: The ``execution_date`` of the associated
+        :param last_automated_data_interval: The data interval of the associated
             DAG's last scheduled or backfilled run (manual runs not considered).
         :param restriction: Restriction to apply when scheduling the DAG run. See
             documentation of :class:`TimeRestriction` for details.
diff --git a/airflow/timetables/interval.py b/airflow/timetables/interval.py
index de8a566a6966c..8f6132ae170bd 100644
--- a/airflow/timetables/interval.py
+++ b/airflow/timetables/interval.py
@@ -68,13 +68,13 @@ def _get_prev(self, current: DateTime) -> DateTime:
     def next_dagrun_info(
         self,
         *,
-        last_automated_dagrun: Optional[DateTime],
+        last_automated_data_interval: Optional[DataInterval],
         restriction: TimeRestriction,
     ) -> Optional[DagRunInfo]:
         earliest = restriction.earliest
         if not restriction.catchup:
             earliest = self._skip_to_latest(earliest)
-        if last_automated_dagrun is None:
+        if last_automated_data_interval is None:
             # First run; schedule the run at the first available time matching
             # the schedule, and retrospectively create a data interval for it.
             if earliest is None:
@@ -83,7 +83,7 @@
         else:
             # There's a previous run. Create a data interval starting from when
             # the end of the previous interval.
-            start = self._get_next(last_automated_dagrun)
+            start = last_automated_data_interval.end
             if restriction.latest is not None and start > restriction.latest:
                 return None
             end = self._get_next(start)
@@ -214,7 +214,7 @@ def _skip_to_latest(self, earliest: Optional[DateTime]) -> DateTime:
             return new_start
         return max(new_start, earliest)
 
-    def infer_data_interval(self, *, run_after: DateTime) -> DataInterval:
+    def infer_manual_data_interval(self, *, run_after: DateTime) -> DataInterval:
         # Get the last complete period before run_after, e.g. if a DAG run is
         # scheduled at each midnight, the data interval of a manually triggered
         # run at 1am 25th is between 0am 24th and 0am 25th.
@@ -290,5 +290,5 @@ def _skip_to_latest(self, earliest: Optional[DateTime]) -> DateTime: return new_start return max(new_start, earliest) - def infer_data_interval(self, run_after: DateTime) -> DataInterval: + def infer_manual_data_interval(self, run_after: DateTime) -> DataInterval: return DataInterval(start=self._get_prev(run_after), end=run_after) diff --git a/airflow/timetables/simple.py b/airflow/timetables/simple.py index a71755d40d6d1..17e1223ad69bb 100644 --- a/airflow/timetables/simple.py +++ b/airflow/timetables/simple.py @@ -44,7 +44,7 @@ def __eq__(self, other: Any) -> bool: def serialize(self) -> Dict[str, Any]: return {} - def infer_data_interval(self, *, run_after: DateTime) -> DataInterval: + def infer_manual_data_interval(self, *, run_after: DateTime) -> DataInterval: return DataInterval.exact(run_after) @@ -61,7 +61,7 @@ def summary(self) -> str: def next_dagrun_info( self, *, - last_automated_dagrun: Optional[DateTime], + last_automated_data_interval: Optional[DataInterval], restriction: TimeRestriction, ) -> Optional[DagRunInfo]: return None @@ -80,10 +80,10 @@ def summary(self) -> str: def next_dagrun_info( self, *, - last_automated_dagrun: Optional[DateTime], + last_automated_data_interval: Optional[DataInterval], restriction: TimeRestriction, ) -> Optional[DagRunInfo]: - if last_automated_dagrun is not None: + if last_automated_data_interval is not None: return None # Already run, no more scheduling. if restriction.earliest is None: # No start date, won't run. return None diff --git a/airflow/utils/session.py b/airflow/utils/session.py index adde7477c2c5c..9636fc401e6cc 100644 --- a/airflow/utils/session.py +++ b/airflow/utils/session.py @@ -18,15 +18,15 @@ import contextlib from functools import wraps from inspect import signature -from typing import Callable, TypeVar +from typing import Callable, Iterator, TypeVar from airflow import settings @contextlib.contextmanager -def create_session(): +def create_session() -> Iterator[settings.SASession]: """Contextmanager that will create and teardown a session.""" - session = settings.Session() + session: settings.SASession = settings.Session() try: yield session session.commit() diff --git a/airflow/www/views.py b/airflow/www/views.py index 7b4d42a82cb9e..f603df85c6ee0 100644 --- a/airflow/www/views.py +++ b/airflow/www/views.py @@ -1694,6 +1694,7 @@ def trigger(self, session=None): dag.create_dagrun( run_type=DagRunType.MANUAL, execution_date=execution_date, + data_interval=dag.timetable.infer_manual_data_interval(run_after=execution_date), state=State.QUEUED, conf=run_conf, external_trigger=True, @@ -2685,7 +2686,7 @@ def landing_times(self, session=None): """Shows landing times.""" default_dag_run = conf.getint('webserver', 'default_dag_run_display_number') dag_id = request.args.get('dag_id') - dag = current_app.dag_bag.get_dag(dag_id) + dag: DAG = current_app.dag_bag.get_dag(dag_id) dag_model = DagModel.get_dagmodel(dag_id) base_date = request.args.get('base_date') num_runs = request.args.get('num_runs', default=default_dag_run, type=int) @@ -2714,9 +2715,7 @@ def landing_times(self, session=None): y_points[task_id] = [] x_points[task_id] = [] for ti in tis: - ts = ti.execution_date - if dag.following_schedule(ts): - ts = dag.following_schedule(ts) + ts = dag.get_run_data_interval(ti.dag_run).end if ti.end_date: dttm = wwwutils.epoch(ti.execution_date) secs = (ti.end_date - ts).total_seconds() diff --git a/docs/apache-airflow/templates-ref.rst b/docs/apache-airflow/templates-ref.rst index b5ce0b6191614..aa1f4434b45b1 
100644 --- a/docs/apache-airflow/templates-ref.rst +++ b/docs/apache-airflow/templates-ref.rst @@ -36,8 +36,8 @@ in all templates ========================================== ==================================== Variable Description ========================================== ==================================== -``{{ data_interval_start }}`` Start of the data interval (`pendulum.Pendulum`_ or ``None``). -``{{ data_interval_end }}`` End of the data interval (`pendulum.Pendulum`_ or ``None``). +``{{ data_interval_start }}`` Start of the data interval (`pendulum.Pendulum`_). +``{{ data_interval_end }}`` End of the data interval (`pendulum.Pendulum`_). ``{{ ds }}`` Start of the data interval as ``YYYY-MM-DD``. Same as ``{{ data_interval_start | ds }}``. ``{{ ds_nodash }}`` Start of the data interval as ``YYYYMMDD``. diff --git a/tests/api_connexion/endpoints/test_log_endpoint.py b/tests/api_connexion/endpoints/test_log_endpoint.py index 87963c0b1ea6a..9b7972f29a5c2 100644 --- a/tests/api_connexion/endpoints/test_log_endpoint.py +++ b/tests/api_connexion/endpoints/test_log_endpoint.py @@ -74,7 +74,7 @@ def setup_attrs(self, configured_app, configure_loggers, dag_maker, session) -> DummyOperator(task_id=self.TASK_ID) dr = dag_maker.create_dagrun( run_id='TEST_DAG_RUN_ID', - run_type=DagRunType.MANUAL, + run_type=DagRunType.SCHEDULED, execution_date=timezone.parse(self.default_time), start_date=timezone.parse(self.default_time), ) @@ -111,7 +111,7 @@ def configure_loggers(self, tmp_path): def teardown_method(self): clear_db_runs() - def test_should_respond_200_json(self, session): + def test_should_respond_200_json(self): key = self.app.config["SECRET_KEY"] serializer = URLSafeSerializer(key) token = serializer.dumps({"download_logs": False}) @@ -132,7 +132,7 @@ def test_should_respond_200_json(self, session): assert info == {'end_of_log': True} assert 200 == response.status_code - def test_should_respond_200_text_plain(self, session): + def test_should_respond_200_text_plain(self): key = self.app.config["SECRET_KEY"] serializer = URLSafeSerializer(key) token = serializer.dumps({"download_logs": True}) @@ -152,7 +152,7 @@ def test_should_respond_200_text_plain(self, session): == f"\n*** Reading local file: {expected_filename}\nLog for testing.\n" ) - def test_get_logs_of_removed_task(self, session): + def test_get_logs_of_removed_task(self): # Recreate DAG without tasks dagbag = self.app.dag_bag dag = DAG(self.DAG_ID, start_date=timezone.parse(self.default_time)) @@ -178,7 +178,7 @@ def test_get_logs_of_removed_task(self, session): == f"\n*** Reading local file: {expected_filename}\nLog for testing.\n" ) - def test_get_logs_response_with_ti_equal_to_none(self, session): + def test_get_logs_response_with_ti_equal_to_none(self): key = self.app.config["SECRET_KEY"] serializer = URLSafeSerializer(key) token = serializer.dumps({"download_logs": True}) @@ -196,7 +196,7 @@ def test_get_logs_response_with_ti_equal_to_none(self, session): 'type': EXCEPTIONS_LINK_MAP[404], } - def test_get_logs_with_metadata_as_download_large_file(self, session): + def test_get_logs_with_metadata_as_download_large_file(self): with mock.patch("airflow.utils.log.file_task_handler.FileTaskHandler.read") as read_mock: first_return = ([[('', '1st line')]], [{}]) second_return = ([[('', '2nd line')]], [{'end_of_log': False}]) @@ -234,7 +234,7 @@ def test_get_logs_for_handler_without_read_method(self, mock_log_reader): assert 400 == response.status_code assert 'Task log handler does not support read logs.' 
in response.data.decode('utf-8') - def test_bad_signature_raises(self, session): + def test_bad_signature_raises(self): token = {"download_logs": False} response = self.client.get( diff --git a/tests/cli/commands/test_dag_command.py b/tests/cli/commands/test_dag_command.py index 4bd7013f93e9a..c67dd6b4e1c6d 100644 --- a/tests/cli/commands/test_dag_command.py +++ b/tests/cli/commands/test_dag_command.py @@ -315,7 +315,10 @@ def test_next_execution(self): dag = self.dagbag.dags[dag_id] # Create a DagRun for each DAG, to prepare for next step dag.create_dagrun( - run_type=DagRunType.MANUAL, execution_date=now, start_date=now, state=State.FAILED + run_type=DagRunType.SCHEDULED, + execution_date=now, + start_date=now, + state=State.FAILED, ) # Test num-executions = 1 (default) diff --git a/tests/conftest.py b/tests/conftest.py index b18a472dc03c6..50c2bb8679c89 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -521,13 +521,13 @@ def __exit__(self, type, value, traceback): self.dagbag.bag_dag(self.dag, self.dag) def create_dagrun(self, **kwargs): - from airflow.timetables.base import DataInterval + from airflow.utils import timezone from airflow.utils.state import State + from airflow.utils.types import DagRunType dag = self.dag kwargs = { "state": State.RUNNING, - "execution_date": self.start_date, "start_date": self.start_date, "session": self.session, **kwargs, @@ -536,15 +536,37 @@ def create_dagrun(self, **kwargs): # explicitly, or pass run_type for inference in dag.create_dagrun(). if "run_id" not in kwargs and "run_type" not in kwargs: kwargs["run_id"] = "test" - # Fill data_interval is not provided. - if not kwargs.get("data_interval"): - kwargs["data_interval"] = DataInterval.exact(kwargs["execution_date"]) + + if "run_type" not in kwargs: + kwargs["run_type"] = DagRunType.from_run_id(kwargs["run_id"]) + if "execution_date" not in kwargs: + if kwargs["run_type"] == DagRunType.MANUAL: + kwargs["execution_date"] = self.start_date + else: + kwargs["execution_date"] = dag.next_dagrun_info(None).logical_date + if "data_interval" not in kwargs: + logical_date = timezone.coerce_datetime(kwargs["execution_date"]) + if kwargs["run_type"] == DagRunType.MANUAL: + data_interval = dag.timetable.infer_manual_data_interval(run_after=logical_date) + else: + data_interval = dag.infer_automated_data_interval(logical_date) + kwargs["data_interval"] = data_interval self.dag_run = dag.create_dagrun(**kwargs) for ti in self.dag_run.task_instances: ti.refresh_from_task(dag.get_task(ti.task_id)) return self.dag_run + def create_dagrun_after(self, dagrun, **kwargs): + next_info = self.dag.next_dagrun_info(self.dag.get_run_data_interval(dagrun)) + if next_info is None: + raise ValueError(f"cannot create run after {dagrun}") + return self.create_dagrun( + execution_date=next_info.logical_date, + data_interval=next_info.data_interval, + **kwargs, + ) + def __call__( self, dag_id='test_dag', serialized=want_serialized, fileloc=None, session=None, **kwargs ): @@ -640,7 +662,7 @@ def create_dag( on_failure_callback=None, on_retry_callback=None, email=None, - with_dagrun=True, + with_dagrun_type=DagRunType.SCHEDULED, **kwargs, ): with dag_maker(dag_id, **kwargs) as dag: @@ -656,8 +678,8 @@ def create_dag( pool=pool, trigger_rule=trigger_rule, ) - if with_dagrun: - dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED) + if with_dagrun_type is not None: + dag_maker.create_dagrun(run_type=with_dagrun_type) return dag, op return create_dag @@ -671,15 +693,20 @@ def create_task_instance(dag_maker, 
create_dummy_dag): Uses ``create_dummy_dag`` to create the dag structure. """ - def maker(execution_date=None, dagrun_state=None, state=None, run_id='test', **kwargs): + def maker(execution_date=None, dagrun_state=None, state=None, run_id=None, run_type=None, **kwargs): if execution_date is None: from airflow.utils import timezone execution_date = timezone.utcnow() - create_dummy_dag(with_dagrun=False, **kwargs) - - dr = dag_maker.create_dagrun(execution_date=execution_date, state=dagrun_state, run_id=run_id) - ti = dr.task_instances[0] + create_dummy_dag(with_dagrun_type=None, **kwargs) + + dagrun_kwargs = {"execution_date": execution_date, "state": dagrun_state} + if run_id is not None: + dagrun_kwargs["run_id"] = run_id + if run_type is not None: + dagrun_kwargs["run_type"] = run_type + dagrun = dag_maker.create_dagrun(**dagrun_kwargs) + (ti,) = dagrun.task_instances ti.state = state return ti @@ -699,7 +726,11 @@ def _create_task_instance( ): with dag_maker(dag_id=dag_id, session=session): operator_class(**operator_kwargs) - (ti,) = dag_maker.create_dagrun(execution_date=execution_date).task_instances + if execution_date is None: + dagrun_kwargs = {} + else: + dagrun_kwargs = {"execution_date": execution_date} + (ti,) = dag_maker.create_dagrun(**dagrun_kwargs).task_instances return ti return _create_task_instance diff --git a/tests/jobs/test_local_task_job.py b/tests/jobs/test_local_task_job.py index 1574fd6e4a9b2..75124db6915e4 100644 --- a/tests/jobs/test_local_task_job.py +++ b/tests/jobs/test_local_task_job.py @@ -16,6 +16,7 @@ # specific language governing permissions and limitations # under the License. # +import datetime import os import signal import time @@ -782,12 +783,13 @@ def task_function(ti): def test_task_exit_should_update_state_of_finished_dagruns_with_dag_paused(self, dag_maker): """Test that with DAG paused, DagRun state will update when the tasks finishes the run""" - with dag_maker(dag_id='test_dags') as dag: + schedule_interval = datetime.timedelta(days=1) + with dag_maker(dag_id='test_dags', schedule_interval=schedule_interval) as dag: op1 = PythonOperator(task_id='dummy', python_callable=lambda: True) session = settings.Session() dagmodel = dag_maker.dag_model - dagmodel.next_dagrun_create_after = dag.following_schedule(DEFAULT_DATE) + dagmodel.next_dagrun_create_after = DEFAULT_DATE + schedule_interval dagmodel.is_paused = True session.merge(dagmodel) session.flush() diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py index fe443417fd9f4..cf4ba026a9be3 100644 --- a/tests/jobs/test_scheduler_job.py +++ b/tests/jobs/test_scheduler_job.py @@ -309,18 +309,14 @@ def test_execute_task_instances_backfill_tasks_wont_execute(self, dag_maker): def test_find_executable_task_instances_backfill(self, dag_maker): dag_id = 'SchedulerJobTest.test_find_executable_task_instances_backfill' task_id_1 = 'dummy' - with dag_maker(dag_id=dag_id, max_active_tasks=16) as dag: + with dag_maker(dag_id=dag_id, max_active_tasks=16): task1 = DummyOperator(task_id=task_id_1) self.scheduler_job = SchedulerJob(subdir=os.devnull) session = settings.Session() - dr1 = dag_maker.create_dagrun() - dr2 = dag.create_dagrun( - run_type=DagRunType.BACKFILL_JOB, - execution_date=dag.following_schedule(dr1.execution_date), - state=State.RUNNING, - ) + dr1 = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED) + dr2 = dag_maker.create_dagrun_after(dr1, run_type=DagRunType.BACKFILL_JOB, state=State.RUNNING) ti_backfill = dr2.get_task_instance(task1.task_id) ti_with_dagrun = 
dr1.get_task_instance(task1.task_id) @@ -344,18 +340,14 @@ def test_find_executable_task_instances_pool(self, dag_maker): task_id_1 = 'dummy' task_id_2 = 'dummydummy' session = settings.Session() - with dag_maker(dag_id=dag_id, max_active_tasks=16, session=session) as dag: + with dag_maker(dag_id=dag_id, max_active_tasks=16, session=session): DummyOperator(task_id=task_id_1, pool='a') DummyOperator(task_id=task_id_2, pool='b') self.scheduler_job = SchedulerJob(subdir=os.devnull) - dr1 = dag_maker.create_dagrun() - dr2 = dag_maker.create_dagrun( - run_type=DagRunType.SCHEDULED, - execution_date=dag.following_schedule(dr1.execution_date), - state=State.RUNNING, - ) + dr1 = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED) + dr2 = dag_maker.create_dagrun_after(dr1, run_type=DagRunType.SCHEDULED) tis = dr1.task_instances + dr2.task_instances for ti in tis: @@ -470,7 +462,7 @@ def test_find_executable_task_instances_in_default_pool(self, dag_maker): set_default_pool_slots(1) dag_id = 'SchedulerJobTest.test_find_executable_task_instances_in_default_pool' - with dag_maker(dag_id=dag_id) as dag: + with dag_maker(dag_id=dag_id): op1 = DummyOperator(task_id='dummy1') op2 = DummyOperator(task_id='dummy2') @@ -478,12 +470,8 @@ def test_find_executable_task_instances_in_default_pool(self, dag_maker): self.scheduler_job = SchedulerJob(executor=executor) session = settings.Session() - dr1 = dag_maker.create_dagrun() - dr2 = dag.create_dagrun( - run_type=DagRunType.SCHEDULED, - execution_date=dag.following_schedule(dr1.execution_date), - state=State.RUNNING, - ) + dr1 = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED) + dr2 = dag_maker.create_dagrun_after(dr1, run_type=DagRunType.SCHEDULED, state=State.RUNNING) ti1 = dr1.get_task_instance(op1.task_id, session) ti2 = dr2.get_task_instance(op2.task_id, session) @@ -566,16 +554,14 @@ def test_tis_for_queued_dagruns_are_not_run(self, dag_maker): dag_id = "test_tis_for_queued_dagruns_are_not_run" task_id_1 = 'dummy' - with dag_maker(dag_id) as dag: + with dag_maker(dag_id): task1 = DummyOperator(task_id=task_id_1) - dr1 = dag_maker.create_dagrun(state=State.QUEUED) - dr2 = dag_maker.create_dagrun( - run_id='test2', execution_date=dag.following_schedule(dr1.execution_date) - ) + dr1 = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED, state=State.QUEUED) + dr2 = dag_maker.create_dagrun_after(dr1, run_type=DagRunType.SCHEDULED) self.scheduler_job = SchedulerJob(subdir=os.devnull) session = settings.Session() - ti1 = TaskInstance(task1, dr1.execution_date) - ti2 = TaskInstance(task1, dr2.execution_date) + ti1 = TaskInstance(task1, run_id=dr1.run_id) + ti2 = TaskInstance(task1, run_id=dr2.run_id) ti1.state = State.SCHEDULED ti2.state = State.SCHEDULED session.merge(ti1) @@ -593,22 +579,14 @@ def test_tis_for_queued_dagruns_are_not_run(self, dag_maker): def test_find_executable_task_instances_concurrency(self, dag_maker): dag_id = 'SchedulerJobTest.test_find_executable_task_instances_concurrency' session = settings.Session() - with dag_maker(dag_id=dag_id, max_active_tasks=2, session=session) as dag: + with dag_maker(dag_id=dag_id, max_active_tasks=2, session=session): DummyOperator(task_id='dummy') self.scheduler_job = SchedulerJob(subdir=os.devnull) - dr1 = dag_maker.create_dagrun() - dr2 = dag_maker.create_dagrun( - run_type=DagRunType.SCHEDULED, - execution_date=dag.following_schedule(dr1.execution_date), - state=State.RUNNING, - ) - dr3 = dag_maker.create_dagrun( - run_type=DagRunType.SCHEDULED, - 
execution_date=dag.following_schedule(dr2.execution_date), - state=State.RUNNING, - ) + dr1 = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED) + dr2 = dag_maker.create_dagrun_after(dr1, run_type=DagRunType.SCHEDULED) + dr3 = dag_maker.create_dagrun_after(dr2, run_type=DagRunType.SCHEDULED) ti1 = dr1.task_instances[0] ti2 = dr2.task_instances[0] @@ -673,7 +651,7 @@ def test_find_executable_task_instances_max_active_tis_per_dag(self, dag_maker): dag_id = 'SchedulerJobTest.test_find_executable_task_instances_max_active_tis_per_dag' task_id_1 = 'dummy' task_id_2 = 'dummy2' - with dag_maker(dag_id=dag_id, max_active_tasks=16) as dag: + with dag_maker(dag_id=dag_id, max_active_tasks=16): task1 = DummyOperator(task_id=task_id_1, max_active_tis_per_dag=2) task2 = DummyOperator(task_id=task_id_2) @@ -681,17 +659,9 @@ def test_find_executable_task_instances_max_active_tis_per_dag(self, dag_maker): self.scheduler_job = SchedulerJob(executor=executor) session = settings.Session() - dr1 = dag_maker.create_dagrun() - dr2 = dag.create_dagrun( - run_type=DagRunType.SCHEDULED, - execution_date=dag.following_schedule(dr1.execution_date), - state=State.RUNNING, - ) - dr3 = dag.create_dagrun( - run_type=DagRunType.SCHEDULED, - execution_date=dag.following_schedule(dr2.execution_date), - state=State.RUNNING, - ) + dr1 = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED) + dr2 = dag_maker.create_dagrun_after(dr1, run_type=DagRunType.SCHEDULED) + dr3 = dag_maker.create_dagrun_after(dr2, run_type=DagRunType.SCHEDULED) ti1_1 = dr1.get_task_instance(task1.task_id) ti2 = dr1.get_task_instance(task2.task_id) @@ -758,26 +728,15 @@ def test_find_executable_task_instances_max_active_tis_per_dag(self, dag_maker): def test_change_state_for_executable_task_instances_no_tis_with_state(self, dag_maker): dag_id = 'SchedulerJobTest.test_change_state_for__no_tis_with_state' task_id_1 = 'dummy' - with dag_maker(dag_id=dag_id, max_active_tasks=2) as dag: + with dag_maker(dag_id=dag_id, max_active_tasks=2): task1 = DummyOperator(task_id=task_id_1) self.scheduler_job = SchedulerJob(subdir=os.devnull) session = settings.Session() - date = DEFAULT_DATE - dr1 = dag_maker.create_dagrun() - date = dag.following_schedule(date) - dr2 = dag.create_dagrun( - run_type=DagRunType.SCHEDULED, - execution_date=date, - state=State.RUNNING, - ) - date = dag.following_schedule(date) - dr3 = dag.create_dagrun( - run_type=DagRunType.SCHEDULED, - execution_date=date, - state=State.RUNNING, - ) + dr1 = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED) + dr2 = dag_maker.create_dagrun_after(dr1, run_type=DagRunType.SCHEDULED) + dr3 = dag_maker.create_dagrun_after(dr2, run_type=DagRunType.SCHEDULED) ti1 = dr1.get_task_instance(task1.task_id) ti2 = dr2.get_task_instance(task1.task_id) @@ -831,7 +790,7 @@ def test_critical_section_execute_task_instances(self, dag_maker): # create first dag run with 1 running and 1 queued - dr1 = dag_maker.create_dagrun() + dr1 = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED) ti1 = dr1.get_task_instance(task1.task_id, session) ti2 = dr1.get_task_instance(task2.task_id, session) @@ -843,11 +802,7 @@ def test_critical_section_execute_task_instances(self, dag_maker): assert 2 == DAG.get_num_task_instances(dag_id, dag.task_ids, states=[State.RUNNING], session=session) # create second dag run - dr2 = dag_maker.create_dagrun( - run_type=DagRunType.SCHEDULED, - execution_date=dag.following_schedule(dr1.execution_date), - state=State.RUNNING, - ) + dr2 = dag_maker.create_dagrun_after(dr1, 
run_type=DagRunType.SCHEDULED) ti3 = dr2.get_task_instance(task1.task_id, session) ti4 = dr2.get_task_instance(task2.task_id, session) # manually set to scheduled so we can pick them up @@ -881,21 +836,25 @@ def test_execute_task_instances_limit(self, dag_maker): # because before scheduler._execute_task_instances would only # check the num tasks once so if max_active_tasks was 3, # we could execute arbitrarily many tasks in the second run - with dag_maker(dag_id=dag_id, max_active_tasks=16, session=session) as dag: + with dag_maker(dag_id=dag_id, max_active_tasks=16, session=session): task1 = DummyOperator(task_id=task_id_1) task2 = DummyOperator(task_id=task_id_2) self.scheduler_job = SchedulerJob(subdir=os.devnull) - date = dag.start_date + def _create_dagruns(): + dagrun = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED, state=State.RUNNING) + yield dagrun + for _ in range(0, 3): + dagrun = dag_maker.create_dagrun_after( + dagrun, + run_type=DagRunType.SCHEDULED, + state=State.RUNNING, + ) + yield dagrun + tis = [] - for _ in range(0, 4): - date = dag.following_schedule(date) - dr = dag_maker.create_dagrun( - run_type=DagRunType.SCHEDULED, - execution_date=date, - state=State.RUNNING, - ) + for dr in _create_dagruns(): ti1 = dr.get_task_instance(task1.task_id, session) ti2 = dr.get_task_instance(task2.task_id, session) ti1.state = State.SCHEDULED @@ -928,21 +887,24 @@ def test_execute_task_instances_unlimited(self, dag_maker): task_id_2 = 'dummy_task_2' session = settings.Session() - with dag_maker(dag_id=dag_id, max_active_tasks=1024, session=session) as dag: + with dag_maker(dag_id=dag_id, max_active_tasks=1024, session=session): task1 = DummyOperator(task_id=task_id_1) task2 = DummyOperator(task_id=task_id_2) self.scheduler_job = SchedulerJob(subdir=os.devnull) - date = dag.start_date - for _ in range(0, 20): - date = dag.following_schedule(date) - dr = dag_maker.create_dagrun( - run_type=DagRunType.SCHEDULED, - execution_date=date, - state=State.RUNNING, - ) - date = dag.following_schedule(date) + def _create_dagruns(): + dagrun = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED, state=State.RUNNING) + yield dagrun + for _ in range(0, 19): + dagrun = dag_maker.create_dagrun_after( + dagrun, + run_type=DagRunType.SCHEDULED, + state=State.RUNNING, + ) + yield dagrun + + for dr in _create_dagruns(): ti1 = dr.get_task_instance(task1.task_id, session) ti2 = dr.get_task_instance(task2.task_id, session) ti1.state = State.SCHEDULED @@ -1284,7 +1246,11 @@ def test_dagrun_callbacks_are_not_added_when_callbacks_are_not_defined(self, sta session.close() def test_do_not_schedule_removed_task(self, dag_maker): - with dag_maker(dag_id='test_scheduler_do_not_schedule_removed_task') as dag: + schedule_interval = datetime.timedelta(days=1) + with dag_maker( + dag_id='test_scheduler_do_not_schedule_removed_task', + schedule_interval=schedule_interval, + ): DummyOperator(task_id='dummy') session = settings.Session() @@ -1297,7 +1263,8 @@ def test_do_not_schedule_removed_task(self, dag_maker): session.query(DagModel).delete() with dag_maker( dag_id='test_scheduler_do_not_schedule_removed_task', - start_date=dag.following_schedule(DEFAULT_DATE), + schedule_interval=schedule_interval, + start_date=DEFAULT_DATE + schedule_interval, ): pass @@ -1333,14 +1300,16 @@ def evaluate_dagrun( run_type=DagRunType.SCHEDULED, execution_date=dagrun_info.logical_date, state=State.RUNNING, + session=session, ) if advance_execution_date: # run a second time to schedule a dagrun after the start_date dr = 
dag.create_dagrun( run_type=DagRunType.SCHEDULED, - execution_date=dag.following_schedule(dr.execution_date), + execution_date=dr.data_interval_end, state=State.RUNNING, + session=session, ) ex_date = dr.execution_date @@ -1358,7 +1327,7 @@ def evaluate_dagrun( pass # load dagrun - dr = DagRun.find(dag_id=dag_id, execution_date=ex_date) + dr = DagRun.find(dag_id=dag_id, execution_date=ex_date, session=session) dr = dr[0] dr.dag = dag @@ -1620,7 +1589,7 @@ def test_scheduler_verify_pool_full(self, dag_maker): """ Test task instances not queued when pool is full """ - with dag_maker(dag_id='test_scheduler_verify_pool_full') as dag: + with dag_maker(dag_id='test_scheduler_verify_pool_full'): BashOperator( task_id='dummy', pool='test_scheduler_verify_pool_full', @@ -1640,11 +1609,7 @@ def test_scheduler_verify_pool_full(self, dag_maker): run_type=DagRunType.SCHEDULED, ) self.scheduler_job._schedule_dag_run(dr, session) - dr = dag.create_dagrun( - run_type=DagRunType.SCHEDULED, - execution_date=dag.following_schedule(dr.execution_date), - state=State.RUNNING, - ) + dr = dag_maker.create_dagrun_after(dr, run_type=DagRunType.SCHEDULED, state=State.RUNNING) self.scheduler_job._schedule_dag_run(dr, session) task_instances_list = self.scheduler_job._executable_task_instances_to_queued( max_tis=32, session=session @@ -1659,7 +1624,11 @@ def test_scheduler_verify_pool_full_2_slots_per_task(self, dag_maker, session): Variation with non-default pool_slots """ - with dag_maker(dag_id='test_scheduler_verify_pool_full_2_slots_per_task', session=session) as dag: + with dag_maker( + dag_id='test_scheduler_verify_pool_full_2_slots_per_task', + start_date=DEFAULT_DATE, + session=session, + ): BashOperator( task_id='dummy', pool='test_scheduler_verify_pool_full_2_slots_per_task', @@ -1675,15 +1644,14 @@ def test_scheduler_verify_pool_full_2_slots_per_task(self, dag_maker, session): self.scheduler_job.processor_agent = mock.MagicMock() # Create 5 dagruns, which will create 5 task instances. - date = DEFAULT_DATE - for _ in range(5): - date = dag.following_schedule(date) - dr = dag.create_dagrun( - run_type=DagRunType.SCHEDULED, - execution_date=date, - state=State.RUNNING, - session=session, - ) + def _create_dagruns(): + dr = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED) + yield dr + for _ in range(4): + dr = dag_maker.create_dagrun_after(dr, run_type=DagRunType.SCHEDULED) + yield dr + + for dr in _create_dagruns(): self.scheduler_job._schedule_dag_run(dr, session) task_instances_list = self.scheduler_job._executable_task_instances_to_queued( @@ -1698,14 +1666,20 @@ def test_scheduler_keeps_scheduling_pool_full(self, dag_maker): """ Test task instances in a pool that isn't full keep getting scheduled even when a pool is full. 
""" - with dag_maker(dag_id='test_scheduler_keeps_scheduling_pool_full_d1') as dag_d1: + with dag_maker( + dag_id='test_scheduler_keeps_scheduling_pool_full_d1', + start_date=DEFAULT_DATE, + ) as dag_d1: BashOperator( task_id='test_scheduler_keeps_scheduling_pool_full_t1', pool='test_scheduler_keeps_scheduling_pool_full_p1', bash_command='echo hi', ) - with dag_maker(dag_id='test_scheduler_keeps_scheduling_pool_full_d2') as dag_d2: + with dag_maker( + dag_id='test_scheduler_keeps_scheduling_pool_full_d2', + start_date=DEFAULT_DATE, + ) as dag_d2: BashOperator( task_id='test_scheduler_keeps_scheduling_pool_full_t2', pool='test_scheduler_keeps_scheduling_pool_full_p2', @@ -1722,28 +1696,24 @@ def test_scheduler_keeps_scheduling_pool_full(self, dag_maker): scheduler = SchedulerJob(executor=self.null_exec) scheduler.processor_agent = mock.MagicMock() + def _create_dagruns(dag: DAG): + next_info = dag.next_dagrun_info(None) + for _ in range(5): + yield dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=next_info.logical_date, + data_interval=next_info.data_interval, + state=State.RUNNING, + ) + next_info = dag.next_dagrun_info(next_info.data_interval) + # Create 5 dagruns for each DAG. # To increase the chances the TIs from the "full" pool will get retrieved first, we schedule all # TIs from the first dag first. - date = DEFAULT_DATE - for _ in range(5): - dr = dag_d1.create_dagrun( - run_type=DagRunType.SCHEDULED, - execution_date=date, - state=State.RUNNING, - ) + for dr in _create_dagruns(dag_d1): scheduler._schedule_dag_run(dr, session) - date = dag_d1.following_schedule(date) - - date = DEFAULT_DATE - for _ in range(5): - dr = dag_d2.create_dagrun( - run_type=DagRunType.SCHEDULED, - execution_date=date, - state=State.RUNNING, - ) + for dr in _create_dagruns(dag_d2): scheduler._schedule_dag_run(dr, session) - date = dag_d2.following_schedule(date) scheduler._executable_task_instances_to_queued(max_tis=2, session=session) task_instances_list2 = scheduler._executable_task_instances_to_queued(max_tis=2, session=session) @@ -2679,20 +2649,28 @@ def test_do_schedule_max_active_runs_and_manual_trigger(self, dag_maker): def test_max_active_runs_in_a_dag_doesnt_stop_running_dagruns_in_otherdags(self, dag_maker): session = settings.Session() - with dag_maker('test_dag1', max_active_runs=1) as dag: + with dag_maker( + 'test_dag1', + start_date=DEFAULT_DATE, + schedule_interval=timedelta(hours=1), + max_active_runs=1, + ): DummyOperator(task_id='mytask') - date = dag.following_schedule(DEFAULT_DATE) - for _ in range(30): - dr = dag.create_dagrun(run_type=DagRunType.SCHEDULED, state=State.QUEUED, execution_date=date) - date = dr.execution_date + timedelta(hours=1) - date = timezone.datetime(2020, 1, 1) - with dag_maker('test_dag2', start_date=date) as dag2: + dr = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED, state=State.QUEUED) + for _ in range(29): + dr = dag_maker.create_dagrun_after(dr, run_type=DagRunType.SCHEDULED, state=State.QUEUED) + + with dag_maker( + 'test_dag2', + start_date=timezone.datetime(2020, 1, 1), + schedule_interval=timedelta(hours=1), + ): DummyOperator(task_id='mytask') - for _ in range(10): - dr = dag2.create_dagrun(run_type=DagRunType.SCHEDULED, state=State.QUEUED, execution_date=date) - date = dr.execution_date + timedelta(hours=1) + dr = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED, state=State.QUEUED) + for _ in range(9): + dr = dag_maker.create_dagrun_after(dr, run_type=DagRunType.SCHEDULED, state=State.QUEUED) self.scheduler_job = 
SchedulerJob(subdir=os.devnull) self.scheduler_job.executor = MockExecutor(do_update=False) @@ -2703,16 +2681,14 @@ def test_max_active_runs_in_a_dag_doesnt_stop_running_dagruns_in_otherdags(self, self.scheduler_job._start_queued_dagruns(session) session.flush() - assert ( - len( - session.query(DagRun) - .filter(DagRun.dag_id == dag.dag_id) - .filter(DagRun.state == State.RUNNING) - .all() - ) - == 1 + dag1_running_count = ( + session.query(func.count(DagRun.id)) + .filter(DagRun.dag_id == "test_dag1", DagRun.state == State.RUNNING) + .scalar() ) - assert len(session.query(DagRun).filter(DagRun.state == State.RUNNING).all()) == 11 + running_count = session.query(func.count(DagRun.id)).filter(DagRun.state == State.RUNNING).scalar() + assert dag1_running_count == 1 + assert running_count == 11 def test_start_queued_dagruns_do_follow_execution_date_order(self, dag_maker): session = settings.Session() @@ -2999,7 +2975,9 @@ def test_runs_respected_after_clear(self, dag_maker): Test dag after dag.clear, max_active_runs is respected """ with dag_maker( - dag_id='test_scheduler_max_active_runs_respected_after_clear', max_active_runs=1 + dag_id='test_scheduler_max_active_runs_respected_after_clear', + start_date=DEFAULT_DATE, + max_active_runs=1, ) as dag: BashOperator(task_id='dummy', bash_command='echo Hi') @@ -3007,23 +2985,9 @@ def test_runs_respected_after_clear(self, dag_maker): self.scheduler_job.processor_agent = mock.MagicMock() session = settings.Session() - date = DEFAULT_DATE - dag_maker.create_dagrun( - run_type=DagRunType.SCHEDULED, - state=State.QUEUED, - ) - date = dag.following_schedule(date) - dag.create_dagrun( - run_type=DagRunType.SCHEDULED, - execution_date=date, - state=State.QUEUED, - ) - date = dag.following_schedule(date) - dag.create_dagrun( - run_type=DagRunType.SCHEDULED, - execution_date=date, - state=State.QUEUED, - ) + dr = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED, state=State.QUEUED) + dr = dag_maker.create_dagrun_after(dr, run_type=DagRunType.SCHEDULED, state=State.QUEUED) + dag_maker.create_dagrun_after(dr, run_type=DagRunType.SCHEDULED, state=State.QUEUED) dag.clear() assert len(DagRun.find(dag_id=dag.dag_id, state=State.QUEUED, session=session)) == 3 diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py index 61cbceedf061e..0e1c5492087b4 100644 --- a/tests/models/test_dag.py +++ b/tests/models/test_dag.py @@ -772,9 +772,8 @@ def test_bulk_write_to_db_max_active_runs(self): model = session.query(DagModel).get((dag.dag_id,)) - period_end = dag.following_schedule(DEFAULT_DATE) assert model.next_dagrun == DEFAULT_DATE - assert model.next_dagrun_create_after == period_end + assert model.next_dagrun_create_after == DEFAULT_DATE + timedelta(days=1) dr = dag.create_dagrun( state=State.RUNNING, @@ -786,7 +785,7 @@ def test_bulk_write_to_db_max_active_runs(self): DAG.bulk_write_to_db([dag]) model = session.query(DagModel).get((dag.dag_id,)) - assert model.next_dagrun == period_end + assert model.next_dagrun == DEFAULT_DATE + timedelta(days=1) # Next dagrun after is not None because the dagrun would be in queued state assert model.next_dagrun_create_after is not None @@ -1148,7 +1147,7 @@ def test_next_dagrun_after_fake_scheduled_previous(self): # Even though there is a run for this date already, it is marked as manual/external, so we should # create a scheduled one anyway! 
assert model.next_dagrun == DEFAULT_DATE - assert model.next_dagrun_create_after == dag.following_schedule(DEFAULT_DATE) + assert model.next_dagrun_create_after == DEFAULT_DATE + delta self._clean_up(dag_id) @@ -1503,7 +1502,7 @@ def test_next_dagrun_info_once(self): next_info = dag.next_dagrun_info(None) assert next_info and next_info.logical_date == timezone.datetime(2015, 1, 1) - next_info = dag.next_dagrun_info(next_info.logical_date) + next_info = dag.next_dagrun_info(next_info.data_interval) assert next_info is None def test_next_dagrun_info_start_end_dates(self): @@ -1521,18 +1520,18 @@ def test_next_dagrun_info_start_end_dates(self): # Create and schedule the dag runs dates = [] - date = None + interval = None for _ in range(runs): - next_info = dag.next_dagrun_info(date) + next_info = dag.next_dagrun_info(interval) if next_info is None: dates.append(None) else: - date = next_info.logical_date - dates.append(date) + interval = next_info.data_interval + dates.append(interval.start) assert all(date is not None for date in dates) assert dates[-1] == end_date - assert dag.next_dagrun_info(date) is None + assert dag.next_dagrun_info(interval.start) is None def test_next_dagrun_info_catchup(self): """ @@ -1618,7 +1617,7 @@ def test_next_dagrun_info_timedelta_schedule_and_catchup_false(self): assert next_info and next_info.logical_date == timezone.datetime(2020, 1, 4) # The date to create is in the future, this is handled by "DagModel.dags_needing_dagruns" - next_info = dag.next_dagrun_info(next_info.logical_date) + next_info = dag.next_dagrun_info(next_info.data_interval) assert next_info and next_info.logical_date == timezone.datetime(2020, 1, 5) @freeze_time(timezone.datetime(2020, 5, 4)) @@ -1637,14 +1636,14 @@ def test_next_dagrun_info_timedelta_schedule_and_catchup_true(self): next_info = dag.next_dagrun_info(None) assert next_info and next_info.logical_date == timezone.datetime(2020, 5, 1) - next_info = dag.next_dagrun_info(next_info.logical_date) + next_info = dag.next_dagrun_info(next_info.data_interval) assert next_info and next_info.logical_date == timezone.datetime(2020, 5, 2) - next_info = dag.next_dagrun_info(next_info.logical_date) + next_info = dag.next_dagrun_info(next_info.data_interval) assert next_info and next_info.logical_date == timezone.datetime(2020, 5, 3) # The date to create is in the future, this is handled by "DagModel.dags_needing_dagruns" - next_info = dag.next_dagrun_info(next_info.logical_date) + next_info = dag.next_dagrun_info(next_info.data_interval) assert next_info and next_info.logical_date == timezone.datetime(2020, 5, 4) def test_next_dagrun_after_auto_align(self): @@ -1790,8 +1789,8 @@ def test_dags_needing_dagruns_only_unpaused(self): orm_dag = DagModel( dag_id=dag.dag_id, has_task_concurrency_limits=False, - next_dagrun=dag.start_date, - next_dagrun_create_after=dag.following_schedule(DEFAULT_DATE), + next_dagrun=DEFAULT_DATE, + next_dagrun_create_after=DEFAULT_DATE + timedelta(days=1), is_active=True, ) session.add(orm_dag) diff --git a/tests/models/test_dagrun.py b/tests/models/test_dagrun.py index c9b9bcae35955..0da62cb649bfa 100644 --- a/tests/models/test_dagrun.py +++ b/tests/models/test_dagrun.py @@ -709,8 +709,8 @@ def test_next_dagruns_to_examine_only_unpaused(self, state): orm_dag = DagModel( dag_id=dag.dag_id, has_task_concurrency_limits=False, - next_dagrun=dag.start_date, - next_dagrun_create_after=dag.following_schedule(DEFAULT_DATE), + next_dagrun=DEFAULT_DATE, + next_dagrun_create_after=DEFAULT_DATE + datetime.timedelta(days=1), 
is_active=True, ) session.add(orm_dag) @@ -767,13 +767,17 @@ def test_emit_scheduling_delay(self, schedule_interval, expected): session = settings.Session() try: - orm_dag = DagModel( - dag_id=dag.dag_id, - has_task_concurrency_limits=False, - next_dagrun=dag.start_date, - next_dagrun_create_after=dag.following_schedule(dag.start_date), - is_active=True, - ) + info = dag.next_dagrun_info(None) + orm_dag_kwargs = {"dag_id": dag.dag_id, "has_task_concurrency_limits": False, "is_active": True} + if info is not None: + orm_dag_kwargs.update( + { + "next_dagrun": info.logical_date, + "next_dagrun_data_interval": info.data_interval, + "next_dagrun_create_after": info.run_after, + }, + ) + orm_dag = DagModel(**orm_dag_kwargs) session.add(orm_dag) session.flush() dag_run = dag.create_dagrun( @@ -793,7 +797,7 @@ def test_emit_scheduling_delay(self, schedule_interval, expected): metric_name = f'dagrun.{dag.dag_id}.first_task_scheduling_delay' if expected: - true_delay = ti.start_date - dag.following_schedule(dag_run.execution_date) + true_delay = ti.start_date - dag_run.data_interval_end sched_delay_stat_call = call(metric_name, true_delay) assert sched_delay_stat_call in stats_mock.mock_calls else: diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py index a2f2ec5d60a24..e22a44748e602 100644 --- a/tests/models/test_taskinstance.py +++ b/tests/models/test_taskinstance.py @@ -1588,11 +1588,18 @@ def test_handle_failure(self, create_dummy_dag, session=None): schedule_interval=None, start_date=start_date, task_id="test_handle_failure_on_failure", + with_dagrun_type=DagRunType.MANUAL, on_failure_callback=mock_on_failure_1, on_retry_callback=mock_on_retry_1, session=session, ) - dr = dag.create_dagrun(run_id="test2", execution_date=timezone.utcnow(), state=None, session=session) + dr = dag.create_dagrun( + run_id="test2", + run_type=DagRunType.MANUAL, + execution_date=timezone.utcnow(), + state=None, + session=session, + ) ti1 = dr.get_task_instance(task1.task_id, session=session) ti1.task = task1 diff --git a/tests/providers/apache/druid/operators/test_druid.py b/tests/providers/apache/druid/operators/test_druid.py index 55867d5a34fa3..2195f7007be09 100644 --- a/tests/providers/apache/druid/operators/test_druid.py +++ b/tests/providers/apache/druid/operators/test_druid.py @@ -20,6 +20,7 @@ from airflow.providers.apache.druid.operators.druid import DruidOperator from airflow.utils import timezone +from airflow.utils.types import DagRunType DEFAULT_DATE = timezone.datetime(2017, 1, 1) @@ -52,7 +53,7 @@ def test_render_template(dag_maker): params={"index_type": "index_hadoop", "datasource": "datasource_prd"}, ) - dag_maker.create_dagrun().task_instances[0].render_templates() + dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED).task_instances[0].render_templates() assert RENDERED_INDEX == json.loads(operator.json_index_file) @@ -71,5 +72,5 @@ def test_render_template_from_file(tmp_path, dag_maker): params={"index_type": "index_hadoop", "datasource": "datasource_prd"}, ) - dag_maker.create_dagrun().task_instances[0].render_templates() + dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED).task_instances[0].render_templates() assert RENDERED_INDEX == json.loads(operator.json_index_file) diff --git a/tests/providers/apache/kylin/operators/test_kylin_cube.py b/tests/providers/apache/kylin/operators/test_kylin_cube.py index 7d306ad84e9d2..bb17a40567671 100644 --- a/tests/providers/apache/kylin/operators/test_kylin_cube.py +++ 
b/tests/providers/apache/kylin/operators/test_kylin_cube.py @@ -167,7 +167,7 @@ def test_render_template(self): }, ) ti = TaskInstance(operator, run_id="kylin_test") - ti.dag_run = DagRun(run_id="kylin_test", execution_date=DEFAULT_DATE) + ti.dag_run = DagRun(dag_id=self.dag.dag_id, run_id="kylin_test", execution_date=DEFAULT_DATE) ti.render_templates() assert 'learn_kylin' == getattr(operator, 'project') assert 'kylin_sales_cube' == getattr(operator, 'cube') diff --git a/tests/providers/apache/spark/operators/test_spark_submit.py b/tests/providers/apache/spark/operators/test_spark_submit.py index 746bcd5b6c628..6cbdd0cf512d2 100644 --- a/tests/providers/apache/spark/operators/test_spark_submit.py +++ b/tests/providers/apache/spark/operators/test_spark_submit.py @@ -148,7 +148,7 @@ def test_render_template(self): # Given operator = SparkSubmitOperator(task_id='spark_submit_job', dag=self.dag, **self._config) ti = TaskInstance(operator, run_id="spark_test") - ti.dag_run = DagRun(run_id="spark_test", execution_date=DEFAULT_DATE) + ti.dag_run = DagRun(dag_id=self.dag.dag_id, run_id="spark_test", execution_date=DEFAULT_DATE) # When ti.render_templates() diff --git a/tests/sensors/test_external_task_sensor.py b/tests/sensors/test_external_task_sensor.py index 7387e1ce8e3d1..97179f925e725 100644 --- a/tests/sensors/test_external_task_sensor.py +++ b/tests/sensors/test_external_task_sensor.py @@ -415,7 +415,8 @@ def test_external_task_sensor_templated(dag_maker): external_task_id='task_{{ ds }}', ) - (instance,) = dag_maker.create_dagrun(execution_date=DEFAULT_DATE).task_instances + dagrun = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED, execution_date=DEFAULT_DATE) + (instance,) = dagrun.task_instances instance.render_templates() assert instance.task.external_dag_id == f"dag_{DEFAULT_DATE.date()}" diff --git a/tests/test_utils/perf/scheduler_dag_execution_timing.py b/tests/test_utils/perf/scheduler_dag_execution_timing.py index 2adbd5f20246f..bb5acc6cb8b03 100755 --- a/tests/test_utils/perf/scheduler_dag_execution_timing.py +++ b/tests/test_utils/perf/scheduler_dag_execution_timing.py @@ -163,18 +163,19 @@ def create_dag_runs(dag, num_runs, session): id_prefix = DagRun.ID_PREFIX - last_dagrun_at = None + last_dagrun_data_interval = None for _ in range(num_runs): - next_info = dag.next_dagrun_info(last_dagrun_at) - last_dagrun_at = next_info.logical_date + next_info = dag.next_dagrun_info(last_dagrun_data_interval) + logical_date = next_info.logical_date dag.create_dagrun( - run_id=f"{id_prefix}{last_dagrun_at.isoformat()}", - execution_date=last_dagrun_at, + run_id=f"{id_prefix}{logical_date.isoformat()}", + execution_date=logical_date, start_date=timezone.utcnow(), state=State.RUNNING, external_trigger=False, session=session, ) + last_dagrun_data_interval = next_info.data_interval @click.command() @@ -253,17 +254,17 @@ def main(num_runs, repeat, pre_create_dag_runs, executor_class, dag_ids): dags.append(dag) reset_dag(dag, session) - next_run_date = dag.normalize_schedule(dag.start_date or min(t.start_date for t in dag.tasks)) + next_info = dag.next_dagrun_info(None) for _ in range(num_runs - 1): - next_run_date = dag.following_schedule(next_run_date) + next_info = dag.next_dagrun_info(next_info.data_interval) end_date = dag.end_date or dag.default_args.get('end_date') - if end_date != next_run_date: + if end_date != next_info.logical_date: message = ( f"DAG {dag_id} has incorrect end_date ({end_date}) for number of runs! 
" f"It should be " - f" {next_run_date}" + f" {next_info.logical_date}" ) sys.exit(message) diff --git a/tests/ti_deps/deps/test_runnable_exec_date_dep.py b/tests/ti_deps/deps/test_runnable_exec_date_dep.py index 5a5ca3b669050..0577e67e500e8 100644 --- a/tests/ti_deps/deps/test_runnable_exec_date_dep.py +++ b/tests/ti_deps/deps/test_runnable_exec_date_dep.py @@ -26,6 +26,7 @@ from airflow.models import DagRun, TaskInstance from airflow.ti_deps.deps.runnable_exec_date_dep import RunnableExecDateDep from airflow.utils.timezone import datetime +from airflow.utils.types import DagRunType @pytest.fixture(autouse=True, scope="function") @@ -66,6 +67,7 @@ def test_exec_date_dep( start_date=datetime(2015, 1, 1), end_date=datetime(2016, 11, 5), schedule_interval=schedule_interval, + with_dagrun_type=DagRunType.MANUAL, session=session, ) (ti,) = dag_maker.create_dagrun(execution_date=execution_date).task_instances @@ -82,6 +84,7 @@ def test_exec_date_after_end_date(session, dag_maker, create_dummy_dag): start_date=datetime(2015, 1, 1), end_date=datetime(2016, 11, 5), schedule_interval=None, + with_dagrun_type=DagRunType.MANUAL, session=session, ) (ti,) = dag_maker.create_dagrun(execution_date=datetime(2016, 11, 2)).task_instances diff --git a/tests/utils/log/test_log_reader.py b/tests/utils/log/test_log_reader.py index 67e7d5a162521..f891b2fdbdd31 100644 --- a/tests/utils/log/test_log_reader.py +++ b/tests/utils/log/test_log_reader.py @@ -30,6 +30,7 @@ from airflow.utils.log.log_reader import TaskLogReader from airflow.utils.log.logging_mixin import ExternalLoggingMixin from airflow.utils.state import TaskInstanceState +from airflow.utils.types import DagRunType from tests.test_utils.config import conf_vars from tests.test_utils.db import clear_db_dags, clear_db_runs @@ -91,6 +92,7 @@ def prepare_db(self, session, create_task_instance): dag_id=self.DAG_ID, task_id=self.TASK_ID, start_date=self.DEFAULT_DATE, + run_type=DagRunType.SCHEDULED, execution_date=self.DEFAULT_DATE, state=TaskInstanceState.RUNNING, ) diff --git a/tests/utils/test_log_handlers.py b/tests/utils/test_log_handlers.py index 6d5403c3d9f5f..4503dd80303e3 100644 --- a/tests/utils/test_log_handlers.py +++ b/tests/utils/test_log_handlers.py @@ -223,6 +223,7 @@ def filename_rendering_ti(session, create_task_instance): return create_task_instance( dag_id='dag_for_testing_filename_rendering', task_id='task_for_testing_filename_rendering', + run_type=DagRunType.SCHEDULED, execution_date=DEFAULT_DATE, session=session, ) diff --git a/tests/www/views/test_views_log.py b/tests/www/views/test_views_log.py index af3d452e415cd..2b42033d0dbc3 100644 --- a/tests/www/views/test_views_log.py +++ b/tests/www/views/test_views_log.py @@ -108,14 +108,14 @@ def dags(log_app, create_dummy_dag, session): dag_id=DAG_ID, task_id=TASK_ID, start_date=DEFAULT_DATE, - with_dagrun=False, + with_dagrun_type=None, session=session, ) dag_removed, _ = create_dummy_dag( dag_id=DAG_ID_REMOVED, task_id=TASK_ID, start_date=DEFAULT_DATE, - with_dagrun=False, + with_dagrun_type=None, session=session, )