Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added logical_date parameter #39285

Merged
merged 3 commits into from
Apr 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 37 additions & 25 deletions airflow/operators/trigger_dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,15 @@
import datetime
import json
import time
import warnings
from typing import TYPE_CHECKING, Any, Sequence, cast

from sqlalchemy import select
from sqlalchemy.orm.exc import NoResultFound

from airflow.api.common.trigger_dag import trigger_dag
from airflow.configuration import conf
from airflow.exceptions import AirflowException, DagNotFound, DagRunAlreadyExists
from airflow.exceptions import AirflowException, DagNotFound, DagRunAlreadyExists, RemovedInAirflow3Warning
from airflow.models.baseoperator import BaseOperator
from airflow.models.baseoperatorlink import BaseOperatorLink
from airflow.models.dag import DagModel
Expand All @@ -41,7 +42,7 @@
from airflow.utils.state import DagRunState
from airflow.utils.types import DagRunType

XCOM_EXECUTION_DATE_ISO = "trigger_execution_date_iso"
XCOM_LOGICAL_DATE_ISO = "trigger_logical_date_iso"
XCOM_RUN_ID = "trigger_run_id"


Expand All @@ -64,7 +65,7 @@ class TriggerDagRunLink(BaseOperatorLink):
def get_link(self, operator: BaseOperator, *, ti_key: TaskInstanceKey) -> str:
# Fetch the correct logical date (formerly "execution date") for the triggerED dag, which is
# stored in xcom during execution of the triggerING task.
when = XCom.get_value(ti_key=ti_key, key=XCOM_EXECUTION_DATE_ISO)
when = XCom.get_value(ti_key=ti_key, key=XCOM_LOGICAL_DATE_ISO)
query = {"dag_id": cast(TriggerDagRunOperator, operator).trigger_dag_id, "base_date": when}
return build_airflow_url_with_query(query)

Expand All @@ -77,7 +78,7 @@ class TriggerDagRunOperator(BaseOperator):
:param trigger_run_id: The run ID to use for the triggered DAG run (templated).
If not provided, a run ID will be automatically generated.
:param conf: Configuration for the DAG run (templated).
:param execution_date: Execution date for the dag (templated).
:param logical_date: Logical date for the dag (templated).
:param reset_dag_run: Whether to clear the existing dag run if it already exists.
This is useful when backfilling or rerunning an existing dag run.
This only resets (does not recreate) the dag run.
Expand All @@ -91,12 +92,13 @@ class TriggerDagRunOperator(BaseOperator):
:param failed_states: List of failed or dis-allowed states, default is ``None``.
:param deferrable: If waiting for completion, whether or not to defer the task until done,
default is ``False``.
:param execution_date: Deprecated parameter; same as ``logical_date``.
"""

template_fields: Sequence[str] = (
"trigger_dag_id",
"trigger_run_id",
"execution_date",
"logical_date",
"conf",
"wait_for_completion",
)
Expand All @@ -110,13 +112,14 @@ def __init__(
trigger_dag_id: str,
trigger_run_id: str | None = None,
conf: dict | None = None,
execution_date: str | datetime.datetime | None = None,
logical_date: str | datetime.datetime | None = None,
reset_dag_run: bool = False,
wait_for_completion: bool = False,
poke_interval: int = 60,
allowed_states: list[str] | None = None,
failed_states: list[str] | None = None,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
execution_date: str | datetime.datetime | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
Expand All @@ -136,20 +139,29 @@ def __init__(
self.failed_states = [DagRunState.FAILED]
self._defer = deferrable

if execution_date is not None and not isinstance(execution_date, (str, datetime.datetime)):
if execution_date is not None:
warnings.warn(
"Parameter 'execution_date' is deprecated. Use 'logical_date' instead.",
RemovedInAirflow3Warning,
stacklevel=2,
)
logical_date = execution_date

if logical_date is not None and not isinstance(logical_date, (str, datetime.datetime)):
type_name = type(logical_date).__name__
raise TypeError(
f"Expected str or datetime.datetime type for execution_date.Got {type(execution_date)}"
f"Expected str or datetime.datetime type for parameter 'logical_date'. Got {type_name}"
)

self.execution_date = execution_date
self.logical_date = logical_date

def execute(self, context: Context):
if isinstance(self.execution_date, datetime.datetime):
parsed_execution_date = self.execution_date
elif isinstance(self.execution_date, str):
parsed_execution_date = timezone.parse(self.execution_date)
if isinstance(self.logical_date, datetime.datetime):
parsed_logical_date = self.logical_date
elif isinstance(self.logical_date, str):
parsed_logical_date = timezone.parse(self.logical_date)
else:
parsed_execution_date = timezone.utcnow()
parsed_logical_date = timezone.utcnow()

try:
json.dumps(self.conf)
Expand All @@ -159,20 +171,20 @@ def execute(self, context: Context):
if self.trigger_run_id:
run_id = str(self.trigger_run_id)
else:
run_id = DagRun.generate_run_id(DagRunType.MANUAL, parsed_execution_date)
run_id = DagRun.generate_run_id(DagRunType.MANUAL, parsed_logical_date)

try:
dag_run = trigger_dag(
dag_id=self.trigger_dag_id,
run_id=run_id,
conf=self.conf,
execution_date=parsed_execution_date,
execution_date=parsed_logical_date,
replace_microseconds=False,
)

except DagRunAlreadyExists as e:
if self.reset_dag_run:
self.log.info("Clearing %s on %s", self.trigger_dag_id, parsed_execution_date)
self.log.info("Clearing %s on %s", self.trigger_dag_id, parsed_logical_date)

# Get target dag object and call clear()
dag_model = DagModel.get_current(self.trigger_dag_id)
Expand All @@ -182,15 +194,15 @@ def execute(self, context: Context):
dag_bag = DagBag(dag_folder=dag_model.fileloc, read_dags_from_db=True)
dag = dag_bag.get_dag(self.trigger_dag_id)
dag_run = e.dag_run
dag.clear(start_date=dag_run.execution_date, end_date=dag_run.execution_date)
dag.clear(start_date=dag_run.logical_date, end_date=dag_run.logical_date)
else:
raise e
if dag_run is None:
raise RuntimeError("The dag_run should be set here!")
# Store the logical date from the dag run (either created or found above) so it can
# be used when creating the extra link on the webserver.
ti = context["task_instance"]
ti.xcom_push(key=XCOM_EXECUTION_DATE_ISO, value=dag_run.execution_date.isoformat())
ti.xcom_push(key=XCOM_LOGICAL_DATE_ISO, value=dag_run.logical_date.isoformat())
ti.xcom_push(key=XCOM_RUN_ID, value=dag_run.run_id)

if self.wait_for_completion:
Expand All @@ -200,7 +212,7 @@ def execute(self, context: Context):
trigger=DagStateTrigger(
dag_id=self.trigger_dag_id,
states=self.allowed_states + self.failed_states,
execution_dates=[parsed_execution_date],
execution_dates=[parsed_logical_date],
poll_interval=self.poke_interval,
),
method_name="execute_complete",
Expand All @@ -210,7 +222,7 @@ def execute(self, context: Context):
self.log.info(
"Waiting for %s on %s to become allowed state %s ...",
self.trigger_dag_id,
dag_run.execution_date,
dag_run.logical_date,
self.allowed_states,
)
time.sleep(self.poke_interval)
Expand All @@ -225,17 +237,17 @@ def execute(self, context: Context):

@provide_session
def execute_complete(self, context: Context, session: Session, event: tuple[str, dict[str, Any]]):
# This execution date is parsed from the return trigger event
provided_execution_date = event[1]["execution_dates"][0]
# This logical_date is parsed from the return trigger event
provided_logical_date = event[1]["execution_dates"][0]
try:
dag_run = session.execute(
select(DagRun).where(
DagRun.dag_id == self.trigger_dag_id, DagRun.execution_date == provided_execution_date
DagRun.dag_id == self.trigger_dag_id, DagRun.execution_date == provided_logical_date
)
).scalar_one()
except NoResultFound:
raise AirflowException(
f"No DAG run found for DAG {self.trigger_dag_id} and execution date {self.execution_date}"
f"No DAG run found for DAG {self.trigger_dag_id} and logical date {self.logical_date}"
)

state = dag_run.state
Expand Down
Loading