Skip to content

Commit

Permalink
Set end_date and duration for triggers completed with end_from_trigge…
Browse files Browse the repository at this point in the history
…r as True. (#41834)

Co-authored-by: Karthikeyan Singaravelan <[email protected]>
  • Loading branch information
2 people authored and utkarsharma2 committed Sep 2, 2024
1 parent 1bcf94b commit 5699007
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 2 deletions.
2 changes: 1 addition & 1 deletion airflow/triggers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ def handle_submit(self, *, task_instance: TaskInstance, session: Session = NEW_S
"""
# Mark the task with terminal state and prevent it from resuming on worker
task_instance.trigger_id = None
task_instance.state = self.task_instance_state
task_instance.set_state(self.task_instance_state, session=session)
self._submit_callback_if_necessary(task_instance=task_instance, session=session)
self._push_xcoms_if_necessary(task_instance=task_instance)

Expand Down
10 changes: 9 additions & 1 deletion tests/models/test_trigger.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
import datetime
import json
from typing import Any, AsyncIterator
from unittest.mock import patch

import pendulum
import pytest
import pytz
from cryptography.fernet import Fernet
Expand Down Expand Up @@ -161,11 +163,15 @@ def test_submit_failure(session, create_task_instance):
(TaskSkippedEvent, "skipped"),
],
)
def test_submit_event_task_end(session, create_task_instance, event_cls, expected):
@patch("airflow.utils.timezone.utcnow")
def test_submit_event_task_end(mock_utcnow, session, create_task_instance, event_cls, expected):
"""
Tests that events inheriting BaseTaskEndEvent *don't* re-wake their dependent
but mark them in the appropriate terminal state and send xcom
"""
now = pendulum.now("UTC")
mock_utcnow.return_value = now

# Make a trigger
trigger = Trigger(classpath="does.not.matter", kwargs={})
trigger.id = 1
Expand Down Expand Up @@ -199,6 +205,8 @@ def get_xcoms(ti):
ti = session.query(TaskInstance).one()
assert ti.state == expected
assert ti.next_kwargs is None
assert ti.end_date == now
assert ti.duration is not None
actual_xcoms = {x.key: x.value for x in get_xcoms(ti)}
assert actual_xcoms == {"return_value": "xcomret", "a": "b", "c": "d"}

Expand Down

0 comments on commit 5699007

Please sign in to comment.