Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support serialization as float in TimeDelta field #1998

Merged
merged 4 commits into from
Jun 26, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions AUTHORS.rst
Original file line number Diff line number Diff line change
Expand Up @@ -169,3 +169,4 @@ Contributors (chronological)
- Kevin Kirsche `@kkirsche <https://github.com/kkirsche>`_
- Isira Seneviratne `@Isira-Seneviratne <https://github.com/Isira-Seneviratne>`_
- Karthikeyan Singaravelan `@tirkarthi <https://github.com/tirkarthi>`_
- Marco Satti `@marcosatti <https://github.com/marcosatti>`_
49 changes: 41 additions & 8 deletions src/marshmallow/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -1434,17 +1434,35 @@ def _make_object_from_format(value, data_format):

class TimeDelta(Field):
"""A field that (de)serializes a :class:`datetime.timedelta` object to an
integer and vice versa. The integer can represent the number of days,
seconds or microseconds.
integer or float and vice versa. The integer or float can represent the
number of days, seconds or microseconds.

:param precision: Influences how the integer is interpreted during
:param precision: Influences how the integer or float is interpreted during
(de)serialization. Must be 'days', 'seconds', 'microseconds',
'milliseconds', 'minutes', 'hours' or 'weeks'.
:param serialization_type: Whether to (de)serialize to a `int` or `float`.
:param kwargs: The same keyword arguments that :class:`Field` receives.

Integer Caveats
---------------
Any fractional parts (which depends on the precision used) will be truncated
when serializing using `int`.

Float Caveats
-------------
Use of `float` when (de)serializing may result in data precision loss due
to the way machines handle floating point values.

Regardless of the precision chosen, the fractional part when using `float`
will always be truncated to microseconds.
For example, `1.12345` interpreted as microseconds will result in `timedelta(microseconds=1)`.

.. versionchanged:: 2.0.0
Always serializes to an integer value to avoid rounding errors.
Add `precision` parameter.
.. versionchanged:: 3.17.0
Allow (de)serialization to `float` through use of a new `serialization_type` parameter.
`int` is the default to retain previous behaviour.
marcosatti marked this conversation as resolved.
Show resolved Hide resolved
"""

DAYS = "days"
Expand All @@ -1461,7 +1479,12 @@ class TimeDelta(Field):
"format": "{input!r} cannot be formatted as a timedelta.",
}

def __init__(self, precision: str = SECONDS, **kwargs):
def __init__(
self,
precision: str = SECONDS,
serialization_type: type[int | float] = int,
**kwargs,
):
precision = precision.lower()
units = (
self.DAYS,
Expand All @@ -1479,20 +1502,30 @@ def __init__(self, precision: str = SECONDS, **kwargs):
)
raise ValueError(msg)

if serialization_type not in (int, float):
raise ValueError("The serialization type must be one of int or float")

self.precision = precision
self.serialization_type = serialization_type
super().__init__(**kwargs)

def _serialize(self, value, attr, obj, **kwargs):
if value is None:
return None

base_unit = dt.timedelta(**{self.precision: 1})
delta = utils.timedelta_to_microseconds(value)
unit = utils.timedelta_to_microseconds(base_unit)
return delta // unit

if self.serialization_type is int:
delta = utils.timedelta_to_microseconds(value)
unit = utils.timedelta_to_microseconds(base_unit)
return delta // unit
else:
assert self.serialization_type is float
return value.total_seconds() / base_unit.total_seconds()

def _deserialize(self, value, attr, data, **kwargs):
try:
value = int(value)
value = self.serialization_type(value)
except (TypeError, ValueError) as error:
raise self.make_error("invalid") from error

Expand Down
61 changes: 61 additions & 0 deletions tests/test_deserialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -692,6 +692,67 @@ def test_timedelta_field_deserialization(self):
assert result.seconds == 123
assert result.microseconds == 456000

total_microseconds_value = 322.0
field = fields.TimeDelta(fields.TimeDelta.MICROSECONDS, float)
result = field.deserialize(total_microseconds_value)
assert isinstance(result, dt.timedelta)
unit_value = dt.timedelta(microseconds=1).total_seconds()
assert math.isclose(
result.total_seconds() / unit_value, total_microseconds_value
)

total_microseconds_value = 322.12345
field = fields.TimeDelta(fields.TimeDelta.MICROSECONDS, float)
result = field.deserialize(total_microseconds_value)
assert isinstance(result, dt.timedelta)
unit_value = dt.timedelta(microseconds=1).total_seconds()
assert math.isclose(
result.total_seconds() / unit_value, math.floor(total_microseconds_value)
)

total_milliseconds_value = 322.223
field = fields.TimeDelta(fields.TimeDelta.MILLISECONDS, float)
result = field.deserialize(total_milliseconds_value)
assert isinstance(result, dt.timedelta)
unit_value = dt.timedelta(milliseconds=1).total_seconds()
assert math.isclose(
result.total_seconds() / unit_value, total_milliseconds_value
)

total_seconds_value = 322.223
field = fields.TimeDelta(fields.TimeDelta.SECONDS, float)
result = field.deserialize(total_seconds_value)
assert isinstance(result, dt.timedelta)
assert math.isclose(result.total_seconds(), total_seconds_value)

total_minutes_value = 322.223
field = fields.TimeDelta(fields.TimeDelta.MINUTES, float)
result = field.deserialize(total_minutes_value)
assert isinstance(result, dt.timedelta)
unit_value = dt.timedelta(minutes=1).total_seconds()
assert math.isclose(result.total_seconds() / unit_value, total_minutes_value)

total_hours_value = 322.223
field = fields.TimeDelta(fields.TimeDelta.HOURS, float)
result = field.deserialize(total_hours_value)
assert isinstance(result, dt.timedelta)
unit_value = dt.timedelta(hours=1).total_seconds()
assert math.isclose(result.total_seconds() / unit_value, total_hours_value)

total_days_value = 322.223
field = fields.TimeDelta(fields.TimeDelta.DAYS, float)
result = field.deserialize(total_days_value)
assert isinstance(result, dt.timedelta)
unit_value = dt.timedelta(days=1).total_seconds()
assert math.isclose(result.total_seconds() / unit_value, total_days_value)

total_weeks_value = 322.223
field = fields.TimeDelta(fields.TimeDelta.WEEKS, float)
result = field.deserialize(total_weeks_value)
assert isinstance(result, dt.timedelta)
unit_value = dt.timedelta(weeks=1).total_seconds()
assert math.isclose(result.total_seconds() / unit_value, total_weeks_value)

@pytest.mark.parametrize("in_value", ["", "badvalue", [], 9999999999])
def test_invalid_timedelta_field_deserialization(self, in_value):
field = fields.TimeDelta(fields.TimeDelta.DAYS)
Expand Down
53 changes: 53 additions & 0 deletions tests/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import decimal
import uuid
import ipaddress
import math

import pytest

Expand Down Expand Up @@ -721,6 +722,58 @@ def test_timedelta_field(self, user):
field = fields.TimeDelta(fields.TimeDelta.SECONDS)
assert field.serialize("d9", user) == 1

user.d10 = dt.timedelta(
weeks=1,
days=6,
hours=2,
minutes=5,
seconds=51,
milliseconds=10,
microseconds=742,
)

field = fields.TimeDelta(fields.TimeDelta.MICROSECONDS, float)
unit_value = dt.timedelta(microseconds=1).total_seconds()
assert math.isclose(
field.serialize("d10", user), user.d10.total_seconds() / unit_value
)

field = fields.TimeDelta(fields.TimeDelta.MILLISECONDS, float)
unit_value = dt.timedelta(milliseconds=1).total_seconds()
assert math.isclose(
field.serialize("d10", user), user.d10.total_seconds() / unit_value
)

field = fields.TimeDelta(fields.TimeDelta.SECONDS, float)
assert math.isclose(field.serialize("d10", user), user.d10.total_seconds())

field = fields.TimeDelta(fields.TimeDelta.MINUTES, float)
unit_value = dt.timedelta(minutes=1).total_seconds()
assert math.isclose(
field.serialize("d10", user), user.d10.total_seconds() / unit_value
)

field = fields.TimeDelta(fields.TimeDelta.HOURS, float)
unit_value = dt.timedelta(hours=1).total_seconds()
assert math.isclose(
field.serialize("d10", user), user.d10.total_seconds() / unit_value
)

field = fields.TimeDelta(fields.TimeDelta.DAYS, float)
unit_value = dt.timedelta(days=1).total_seconds()
assert math.isclose(
field.serialize("d10", user), user.d10.total_seconds() / unit_value
)

field = fields.TimeDelta(fields.TimeDelta.WEEKS, float)
unit_value = dt.timedelta(weeks=1).total_seconds()
assert math.isclose(
field.serialize("d10", user), user.d10.total_seconds() / unit_value
)

with pytest.raises(ValueError):
fields.TimeDelta(fields.TimeDelta.SECONDS, str)

def test_datetime_list_field(self):
obj = DateTimeList([dt.datetime.utcnow(), dt.datetime.now()])
field = fields.List(fields.DateTime)
Expand Down