Skip to content

Commit

Permalink
Use microsecond integers to fix timedelta precision errors
Browse files Browse the repository at this point in the history
  • Loading branch information
deckar01 committed Aug 24, 2021
1 parent e17b7d5 commit 5066013
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 1 deletion.
4 changes: 3 additions & 1 deletion src/marshmallow/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -1471,7 +1471,9 @@ def _serialize(self, value, attr, obj, **kwargs):
if value is None:
return None
base_unit = dt.timedelta(**{self.precision: 1})
return int(value.total_seconds() / base_unit.total_seconds())
delta = utils.timedelta_to_microseconds(value)
unit = utils.timedelta_to_microseconds(base_unit)
return delta // unit

def _deserialize(self, value, attr, data, **kwargs):
try:
Expand Down
8 changes: 8 additions & 0 deletions src/marshmallow/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,3 +323,11 @@ def resolve_field_instance(cls_or_instance):
if not isinstance(cls_or_instance, FieldABC):
raise FieldInstanceResolutionError
return cls_or_instance


def timedelta_to_microseconds(value: dt.timedelta) -> int:
"""Compute the total microseconds of a timedelta
https://github.com/python/cpython/blob/bb3e0c240bc60fe08d332ff5955d54197f79751c/Lib/datetime.py#L665-L667 # noqa: B950
"""
return (value.days * (24 * 3600) + value.seconds) * 1000000 + value.microseconds
9 changes: 9 additions & 0 deletions tests/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -712,6 +712,15 @@ def test_timedelta_field(self, user):
user.d7 = None
assert field.serialize("d7", user) is None

# https://github.com/marshmallow-code/marshmallow/issues/1856
user.d8 = dt.timedelta(milliseconds=345)
field = fields.TimeDelta(fields.TimeDelta.MILLISECONDS)
assert field.serialize("d8", user) == 345

user.d9 = dt.timedelta(milliseconds=1999)
field = fields.TimeDelta(fields.TimeDelta.SECONDS)
assert field.serialize("d9", user) == 1

def test_datetime_list_field(self):
obj = DateTimeList([dt.datetime.utcnow(), dt.datetime.now()])
field = fields.List(fields.DateTime)
Expand Down

0 comments on commit 5066013

Please sign in to comment.