Skip to content

Commit

Permalink
Make _run_raw_task AIP-44 compatible (#38992)
Browse files Browse the repository at this point in the history
Migrate _run_raw_task to work with AIP-44. Also, _handle_reschedule, defer_task, xcom_pull and xcom_push.
  • Loading branch information
dstandish authored May 24, 2024
1 parent 6fb9fb7 commit bca2930
Show file tree
Hide file tree
Showing 8 changed files with 592 additions and 316 deletions.
21 changes: 18 additions & 3 deletions airflow/api_internal/endpoints/rpc_api_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from flask import Response

from airflow.jobs.job import Job, most_recent_job
from airflow.models.taskinstance import _get_template_context, _update_rtif
from airflow.sensors.base import _orig_start_date
from airflow.serialization.serialized_objects import BaseSerialization
from airflow.utils.session import create_session
Expand All @@ -42,22 +41,37 @@ def _initialize_map() -> dict[str, Callable]:
from airflow.cli.commands.task_command import _get_ti_db_access
from airflow.dag_processing.manager import DagFileProcessorManager
from airflow.dag_processing.processor import DagFileProcessor
from airflow.datasets.manager import DatasetManager
from airflow.models import Trigger, Variable, XCom
from airflow.models.dag import DAG, DagModel
from airflow.models.dagrun import DagRun
from airflow.models.dagwarning import DagWarning
from airflow.models.serialized_dag import SerializedDagModel
from airflow.models.taskinstance import TaskInstance
from airflow.models.taskinstance import (
TaskInstance,
_add_log,
_defer_task,
_get_template_context,
_handle_failure,
_handle_reschedule,
_update_rtif,
_xcom_pull,
)
from airflow.secrets.metastore import MetastoreBackend
from airflow.utils.cli_action_loggers import _default_action_log_internal
from airflow.utils.log.file_task_handler import FileTaskHandler

functions: list[Callable] = [
_default_action_log_internal,
_defer_task,
_get_template_context,
_get_ti_db_access,
_update_rtif,
_orig_start_date,
_handle_failure,
_handle_reschedule,
_add_log,
_xcom_pull,
DagFileProcessor.update_import_errors,
DagFileProcessor.manage_slas,
DagFileProcessorManager.deactivate_stale_dags,
Expand All @@ -66,6 +80,7 @@ def _initialize_map() -> dict[str, Callable]:
DagModel.get_current,
DagFileProcessorManager.clear_nonexistent_import_errors,
DagWarning.purge_inactive_dag_warnings,
DatasetManager.register_dataset_change,
FileTaskHandler._render_filename_db_access,
Job._add_to_db,
Job._fetch_from_db,
Expand All @@ -79,6 +94,7 @@ def _initialize_map() -> dict[str, Callable]:
XCom.get_one,
XCom.get_many,
XCom.clear,
XCom.set,
Variable.set,
Variable.update,
Variable.delete,
Expand All @@ -94,7 +110,6 @@ def _initialize_map() -> dict[str, Callable]:
TaskInstance.get_task_instance,
TaskInstance._get_dagrun,
TaskInstance._set_state,
TaskInstance.fetch_handle_failure_context,
TaskInstance.save_to_db,
TaskInstance._schedule_downstream_tasks,
TaskInstance._clear_xcom_data,
Expand Down
2 changes: 1 addition & 1 deletion airflow/cli/commands/task_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,7 +673,7 @@ def task_test(args, dag: DAG | None = None, session: Session = NEW_SESSION) -> N
else:
ti.run(ignore_task_deps=True, ignore_ti_state=True, test_mode=True, raise_on_defer=True)
except TaskDeferred as defer:
ti.defer_task(defer=defer, session=session)
ti.defer_task(exception=defer, session=session)
log.info("[TASK TEST] running trigger in line")

event = _run_inline_trigger(defer.trigger)
Expand Down
36 changes: 23 additions & 13 deletions airflow/datasets/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,14 @@
from sqlalchemy import exc, select
from sqlalchemy.orm import joinedload

from airflow.api_internal.internal_api_call import internal_api_call
from airflow.configuration import conf
from airflow.datasets import Dataset
from airflow.listeners.listener import get_listener_manager
from airflow.models.dataset import DagScheduleDatasetReference, DatasetDagRunQueue, DatasetEvent, DatasetModel
from airflow.stats import Stats
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.session import NEW_SESSION, provide_session

if TYPE_CHECKING:
from sqlalchemy.orm.session import Session
Expand Down Expand Up @@ -56,13 +58,16 @@ def create_datasets(self, dataset_models: list[DatasetModel], session: Session)
for dataset_model in dataset_models:
self.notify_dataset_created(dataset=Dataset(uri=dataset_model.uri, extra=dataset_model.extra))

@classmethod
@internal_api_call
@provide_session
def register_dataset_change(
self,
cls,
*,
task_instance: TaskInstance | None = None,
dataset: Dataset,
extra=None,
session: Session,
session: Session = NEW_SESSION,
**kwargs,
) -> DatasetEvent | None:
"""
Expand All @@ -71,13 +76,14 @@ def register_dataset_change(
For local datasets, look them up, record the dataset event, queue dagruns, and broadcast
the dataset event
"""
# todo: add test so that all usages of internal_api_call are added to rpc endpoint
dataset_model = session.scalar(
select(DatasetModel)
.where(DatasetModel.uri == dataset.uri)
.options(joinedload(DatasetModel.consuming_dags).joinedload(DagScheduleDatasetReference.dag))
)
if not dataset_model:
self.log.warning("DatasetModel %s not found", dataset)
cls.logger().warning("DatasetModel %s not found", dataset)
return None

event_kwargs = {
Expand All @@ -97,22 +103,24 @@ def register_dataset_change(
session.add(dataset_event)
session.flush()

self.notify_dataset_changed(dataset=dataset)
cls.notify_dataset_changed(dataset=dataset)

Stats.incr("dataset.updates")
self._queue_dagruns(dataset_model, session)
cls._queue_dagruns(dataset_model, session)
session.flush()
return dataset_event

def notify_dataset_created(self, dataset: Dataset):
"""Run applicable notification actions when a dataset is created."""
get_listener_manager().hook.on_dataset_created(dataset=dataset)

def notify_dataset_changed(self, dataset: Dataset):
@classmethod
def notify_dataset_changed(cls, dataset: Dataset):
"""Run applicable notification actions when a dataset is changed."""
get_listener_manager().hook.on_dataset_changed(dataset=dataset)

def _queue_dagruns(self, dataset: DatasetModel, session: Session) -> None:
@classmethod
def _queue_dagruns(cls, dataset: DatasetModel, session: Session) -> None:
# Possible race condition: if multiple dags or multiple (usually
# mapped) tasks update the same dataset, this can fail with a unique
# constraint violation.
Expand All @@ -123,10 +131,11 @@ def _queue_dagruns(self, dataset: DatasetModel, session: Session) -> None:
# where `ti.state` is changed.

if session.bind.dialect.name == "postgresql":
return self._postgres_queue_dagruns(dataset, session)
return self._slow_path_queue_dagruns(dataset, session)
return cls._postgres_queue_dagruns(dataset, session)
return cls._slow_path_queue_dagruns(dataset, session)

def _slow_path_queue_dagruns(self, dataset: DatasetModel, session: Session) -> None:
@classmethod
def _slow_path_queue_dagruns(cls, dataset: DatasetModel, session: Session) -> None:
def _queue_dagrun_if_needed(dag: DagModel) -> str | None:
if not dag.is_active or dag.is_paused:
return None
Expand All @@ -137,14 +146,15 @@ def _queue_dagrun_if_needed(dag: DagModel) -> str | None:
with session.begin_nested():
session.merge(item)
except exc.IntegrityError:
self.log.debug("Skipping record %s", item, exc_info=True)
cls.logger().debug("Skipping record %s", item, exc_info=True)
return dag.dag_id

queued_results = (_queue_dagrun_if_needed(ref.dag) for ref in dataset.consuming_dags)
if queued_dag_ids := [r for r in queued_results if r is not None]:
self.log.debug("consuming dag ids %s", queued_dag_ids)
cls.logger().debug("consuming dag ids %s", queued_dag_ids)

def _postgres_queue_dagruns(self, dataset: DatasetModel, session: Session) -> None:
@classmethod
def _postgres_queue_dagruns(cls, dataset: DatasetModel, session: Session) -> None:
from sqlalchemy.dialects.postgresql import insert

values = [
Expand Down
2 changes: 1 addition & 1 deletion airflow/models/dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -1548,7 +1548,7 @@ def schedule_tis(
if ti.state != TaskInstanceState.UP_FOR_RESCHEDULE:
ti.try_number += 1
ti.defer_task(
defer=TaskDeferred(trigger=ti.task.start_trigger, method_name=ti.task.next_method),
exception=TaskDeferred(trigger=ti.task.start_trigger, method_name=ti.task.next_method),
session=session,
)
else:
Expand Down
Loading

0 comments on commit bca2930

Please sign in to comment.