From 5066013417573475dd0fcbd1b49a3a1145b3a470 Mon Sep 17 00:00:00 2001 From: Jared Deckard Date: Fri, 20 Aug 2021 09:35:26 -0500 Subject: [PATCH] Use microsecond integers to fix timedelta precision errors --- src/marshmallow/fields.py | 4 +++- src/marshmallow/utils.py | 8 ++++++++ tests/test_serialization.py | 9 +++++++++ 3 files changed, 20 insertions(+), 1 deletion(-) diff --git a/src/marshmallow/fields.py b/src/marshmallow/fields.py index 88c1bc719..ea5d08cff 100644 --- a/src/marshmallow/fields.py +++ b/src/marshmallow/fields.py @@ -1471,7 +1471,9 @@ def _serialize(self, value, attr, obj, **kwargs): if value is None: return None base_unit = dt.timedelta(**{self.precision: 1}) - return int(value.total_seconds() / base_unit.total_seconds()) + delta = utils.timedelta_to_microseconds(value) + unit = utils.timedelta_to_microseconds(base_unit) + return delta // unit def _deserialize(self, value, attr, data, **kwargs): try: diff --git a/src/marshmallow/utils.py b/src/marshmallow/utils.py index cb5e5002d..74ee67392 100644 --- a/src/marshmallow/utils.py +++ b/src/marshmallow/utils.py @@ -323,3 +323,11 @@ def resolve_field_instance(cls_or_instance): if not isinstance(cls_or_instance, FieldABC): raise FieldInstanceResolutionError return cls_or_instance + + +def timedelta_to_microseconds(value: dt.timedelta) -> int: + """Compute the total microseconds of a timedelta + + https://github.com/python/cpython/blob/bb3e0c240bc60fe08d332ff5955d54197f79751c/Lib/datetime.py#L665-L667 # noqa: B950 + """ + return (value.days * (24 * 3600) + value.seconds) * 1000000 + value.microseconds diff --git a/tests/test_serialization.py b/tests/test_serialization.py index ca6fd0b9a..d191dff37 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -712,6 +712,15 @@ def test_timedelta_field(self, user): user.d7 = None assert field.serialize("d7", user) is None + # https://github.com/marshmallow-code/marshmallow/issues/1856 + user.d8 = dt.timedelta(milliseconds=345) + field = fields.TimeDelta(fields.TimeDelta.MILLISECONDS) + assert field.serialize("d8", user) == 345 + + user.d9 = dt.timedelta(milliseconds=1999) + field = fields.TimeDelta(fields.TimeDelta.SECONDS) + assert field.serialize("d9", user) == 1 + def test_datetime_list_field(self): obj = DateTimeList([dt.datetime.utcnow(), dt.datetime.now()]) field = fields.List(fields.DateTime)