Skip to content

Commit

Permalink
Fix Serialization error in TaskCallbackRequest (#25471)
Browse files Browse the repository at this point in the history
How we serialize `SimpleTaskInstance `in `TaskCallbackRequest` class leads to JSON serialization error when there's start_date or end_date in the task instance. Since there's always a start_date on tis, this would always fail.
This PR aims to fix this through a new method on the SimpleTaskInstance that looks for start_date/end_date and converts them to isoformat for serialization.

(cherry picked from commit d7e14ba)
  • Loading branch information
ephraimbuddy committed Aug 15, 2022
1 parent 605126c commit 385f04b
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 5 deletions.
2 changes: 1 addition & 1 deletion airflow/callbacks/callback_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def __init__(

def to_json(self) -> str:
dict_obj = self.__dict__.copy()
dict_obj["simple_task_instance"] = dict_obj["simple_task_instance"].__dict__
dict_obj["simple_task_instance"] = self.simple_task_instance.as_dict()
return json.dumps(dict_obj)

@classmethod
Expand Down
10 changes: 10 additions & 0 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -2650,6 +2650,16 @@ def __eq__(self, other):
return self.__dict__ == other.__dict__
return NotImplemented

def as_dict(self):
new_dict = dict(self.__dict__)
for key in new_dict:
if key in ['start_date', 'end_date']:
val = new_dict[key]
if not val or isinstance(val, str):
continue
new_dict.update({key: val.isoformat()})
return new_dict

@classmethod
def from_ti(cls, ti: TaskInstance):
return cls(
Expand Down
21 changes: 17 additions & 4 deletions tests/callbacks/test_callback_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
# specific language governing permissions and limitations
# under the License.

import unittest
from datetime import datetime

from parameterized import parameterized
Expand All @@ -29,6 +28,7 @@
from airflow.models.dag import DAG
from airflow.models.taskinstance import SimpleTaskInstance, TaskInstance
from airflow.operators.bash import BashOperator
from airflow.utils import timezone
from airflow.utils.state import State

TI = TaskInstance(
Expand All @@ -38,7 +38,7 @@
)


class TestCallbackRequest(unittest.TestCase):
class TestCallbackRequest:
@parameterized.expand(
[
(CallbackRequest(full_filepath="filepath", msg="task_failure"), CallbackRequest),
Expand All @@ -64,7 +64,20 @@ class TestCallbackRequest(unittest.TestCase):
)
def test_from_json(self, input, request_class):
json_str = input.to_json()

result = request_class.from_json(json_str=json_str)
assert result == input

self.assertEqual(result, input)
def test_taskcallback_to_json_with_start_date_and_end_date(self, session, create_task_instance):
ti = create_task_instance()
ti.start_date = timezone.utcnow()
ti.end_date = timezone.utcnow()
session.merge(ti)
session.flush()
input = TaskCallbackRequest(
full_filepath="filepath",
simple_task_instance=SimpleTaskInstance.from_ti(ti),
is_failure_callback=True,
)
json_str = input.to_json()
result = TaskCallbackRequest.from_json(json_str)
assert input == result

0 comments on commit 385f04b

Please sign in to comment.