diff --git a/airflow/models.py b/airflow/models.py index 46cc21928d941..fa336098522d9 100755 --- a/airflow/models.py +++ b/airflow/models.py @@ -1686,14 +1686,6 @@ def signal_handler(signum, frame): self.handle_failure(e, test_mode, context) raise - # Recording SUCCESS - self.end_date = timezone.utcnow() - self.set_duration() - if not test_mode: - session.add(Log(self.state, self)) - session.merge(self) - session.commit() - # Success callback try: if task.on_success_callback: @@ -1702,6 +1694,12 @@ def signal_handler(signum, frame): self.log.error("Failed when executing success callback") self.log.exception(e3) + # Recording SUCCESS + self.end_date = timezone.utcnow() + self.set_duration() + if not test_mode: + session.add(Log(self.state, self)) + session.merge(self) session.commit() @provide_session diff --git a/tests/models.py b/tests/models.py index c14b42c86e605..b16055b380abb 100644 --- a/tests/models.py +++ b/tests/models.py @@ -2558,6 +2558,40 @@ def test_set_duration_empty_dates(self): ti.set_duration() self.assertIsNone(ti.duration) + def test_success_callbak_no_race_condition(self): + class CallbackWrapper(object): + def wrap_task_instance(self, ti): + self.task_id = ti.task_id + self.dag_id = ti.dag_id + self.execution_date = ti.execution_date + self.task_state_in_callback = "" + self.callback_ran = False + + def success_handler(self, context): + self.callback_ran = True + session = settings.Session() + temp_instance = session.query(TI).filter( + TI.task_id == self.task_id).filter( + TI.dag_id == self.dag_id).filter( + TI.execution_date == self.execution_date).one() + self.task_state_in_callback = temp_instance.state + cw = CallbackWrapper() + dag = DAG('test_success_callbak_no_race_condition', start_date=DEFAULT_DATE, + end_date=DEFAULT_DATE + datetime.timedelta(days=10)) + task = DummyOperator(task_id='op', email='test@test.test', + on_success_callback=cw.success_handler, dag=dag) + ti = TI(task=task, execution_date=datetime.datetime.now()) + ti.state = State.RUNNING + session = settings.Session() + session.merge(ti) + session.commit() + cw.wrap_task_instance(ti) + ti._run_raw_task() + self.assertTrue(cw.callback_ran) + self.assertEqual(cw.task_state_in_callback, State.RUNNING) + ti.refresh_from_db() + self.assertEqual(ti.state, State.SUCCESS) + class ClearTasksTest(unittest.TestCase):