Skip to content

Commit

Permalink
Deprecate TimeDelta serialization_type parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
ddelange committed Dec 23, 2024
1 parent 0755fe1 commit 3470680
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 77 deletions.
76 changes: 35 additions & 41 deletions src/marshmallow/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -1435,26 +1435,18 @@ def _make_object_from_format(value, data_format):


class TimeDelta(Field):
"""A field that (de)serializes a :class:`datetime.timedelta` object to an
integer or float. The integer or float can represent any time unit that the
:class:`datetime.timedelta` constructor supports.
"""A field that (de)serializes a :class:`datetime.timedelta` object to a `float`.
The `float` can represent any time unit that the :class:`datetime.timedelta` constructor
supports.
:param precision: The time unit used for (de)serialization. Must be one of 'weeks',
'days', 'hours', 'minutes', 'seconds', 'milliseconds' or 'microseconds'.
:param serialization_type: Whether to serialize to an `int` or `float`.
Ignored during deserialization: both `int` and `float` inputs are supported.
:param kwargs: The same keyword arguments that :class:`Field` receives.
Integer Caveats
---------------
When serializing using ``serialization_type=int`` and depending on the ``precision``
used, any fractional parts might be truncated (downcast to integer).
Float Caveats
-------------
Precision loss may occur when serializing a highly precise :class:`datetime.timedelta`
object using ``serialization_type=float`` and a big ``precision`` unit due to floating
point arithmetics.
object using a big ``precision`` unit due to floating point arithmetics.
When necessary, the :class:`datetime.timedelta` constructor rounds `float` inputs
to whole microseconds during initialization of the object. As a result, deserializing
Expand All @@ -1466,7 +1458,9 @@ class TimeDelta(Field):
Add `precision` parameter.
.. versionchanged:: 3.17.0
Allow serialization to `float` through use of a new `serialization_type` parameter.
Defaults to `int` for backwards compatibility.
Defaults to `int` for backwards compatibility. Also affects deserialization.
.. versionchanged:: 4.0.0
Deprecate `serialization_type` parameter, always serialize to float.
"""

WEEKS = "weeks"
Expand All @@ -1477,6 +1471,17 @@ class TimeDelta(Field):
MILLISECONDS = "milliseconds"
MICROSECONDS = "microseconds"

# cache this mapping on class level for performance
_unit_to_microseconds_mapping = {
WEEKS: 1000000 * 60 * 60 * 24 * 7,
DAYS: 1000000 * 60 * 60 * 24,
HOURS: 1000000 * 60 * 60,
MINUTES: 1000000 * 60,
SECONDS: 1000000,
MILLISECONDS: 1000,
MICROSECONDS: 1,
}

#: Default error messages.
default_error_messages = {
"invalid": "Not a valid period of time.",
Expand All @@ -1486,47 +1491,36 @@ class TimeDelta(Field):
def __init__(
self,
precision: str = SECONDS,
serialization_type: type[int | float] = int,
serialization_type: typing.Any = missing_,
**kwargs,
):
) -> None:
precision = precision.lower()
units = (
self.DAYS,
self.SECONDS,
self.MICROSECONDS,
self.MILLISECONDS,
self.MINUTES,
self.HOURS,
self.WEEKS,
)

if precision not in units:
msg = 'The precision must be {} or "{}".'.format(
", ".join([f'"{each}"' for each in units[:-1]]), units[-1]
)
if precision not in self._unit_to_microseconds_mapping:
units = ", ".join(self._unit_to_microseconds_mapping)
msg = f"The precision must be one of: {units}."
raise ValueError(msg)

if serialization_type not in (int, float):
raise ValueError("The serialization type must be one of int or float")
if serialization_type is not missing_:
warnings.warn(
"The 'serialization_type' argument to TimeDelta is deprecated.",
RemovedInMarshmallow4Warning,
stacklevel=2,
)

self.precision = precision
self.serialization_type = serialization_type
super().__init__(**kwargs)

def _serialize(self, value, attr, obj, **kwargs):
def _serialize(self, value, attr, obj, **kwargs) -> float | None:
if value is None:
return None

base_unit = dt.timedelta(**{self.precision: 1})

if self.serialization_type is int:
delta = utils.timedelta_to_microseconds(value)
unit = utils.timedelta_to_microseconds(base_unit)
return delta // unit
assert self.serialization_type is float
return value.total_seconds() / base_unit.total_seconds()
# limit float arithmetics to a single division to minimize precision loss
microseconds: int = utils.timedelta_to_microseconds(value)
microseconds_per_unit: int = self._unit_to_microseconds_mapping[self.precision]
return microseconds / microseconds_per_unit

def _deserialize(self, value, attr, data, **kwargs):
def _deserialize(self, value, attr, data, **kwargs) -> dt.timedelta:
try:
value = float(value)
except (TypeError, ValueError) as error:
Expand Down
4 changes: 2 additions & 2 deletions src/marshmallow/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,9 +363,9 @@ def resolve_field_instance(cls_or_instance):


def timedelta_to_microseconds(value: dt.timedelta) -> int:
"""Compute the total microseconds of a timedelta
"""Compute the total microseconds of a timedelta.
https://github.com/python/cpython/blob/bb3e0c240bc60fe08d332ff5955d54197f79751c/Lib/datetime.py#L665-L667 # noqa: B950
https://github.com/python/cpython/blob/v3.13.1/Lib/_pydatetime.py#L805-L807
"""
return (value.days * (24 * 3600) + value.seconds) * 1000000 + value.microseconds

Expand Down
2 changes: 1 addition & 1 deletion tests/test_deserialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -698,7 +698,7 @@ def test_iso_time_field_deserialization(self, fmt, value, expected):
assert field.deserialize(value) == expected

def test_invalid_timedelta_precision(self):
with pytest.raises(ValueError, match='The precision must be "days",'):
with pytest.raises(ValueError, match="The precision must be one of: weeks,"):
fields.TimeDelta("invalid")

def test_timedelta_field_deserialization(self):
Expand Down
55 changes: 22 additions & 33 deletions tests/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -704,25 +704,25 @@ def test_timedelta_field(self, user):
)

field = fields.TimeDelta(fields.TimeDelta.DAYS)
assert field.serialize("d1", user) == 1
assert field.serialize("d1", user) == 1.0000115740856481
field = fields.TimeDelta(fields.TimeDelta.SECONDS)
assert field.serialize("d1", user) == 86401
assert field.serialize("d1", user) == 86401.000001
field = fields.TimeDelta(fields.TimeDelta.MICROSECONDS)
assert field.serialize("d1", user) == 86401000001
field = fields.TimeDelta(fields.TimeDelta.HOURS)
assert field.serialize("d1", user) == 24
assert field.serialize("d1", user) == 24.000277778055555

field = fields.TimeDelta(fields.TimeDelta.DAYS)
assert field.serialize("d2", user) == 1
assert field.serialize("d2", user) == 1.0000115740856481
field = fields.TimeDelta(fields.TimeDelta.SECONDS)
assert field.serialize("d2", user) == 86401
assert field.serialize("d2", user) == 86401.000001
field = fields.TimeDelta(fields.TimeDelta.MICROSECONDS)
assert field.serialize("d2", user) == 86401000001

field = fields.TimeDelta(fields.TimeDelta.DAYS)
assert field.serialize("d3", user) == 1
assert field.serialize("d3", user) == 1.0000115740856481
field = fields.TimeDelta(fields.TimeDelta.SECONDS)
assert field.serialize("d3", user) == 86401
assert field.serialize("d3", user) == 86401.000001
field = fields.TimeDelta(fields.TimeDelta.MICROSECONDS)
assert field.serialize("d3", user) == 86401000001

Expand All @@ -741,38 +741,30 @@ def test_timedelta_field(self, user):
assert field.serialize("d5", user) == -86400000000

field = fields.TimeDelta(fields.TimeDelta.WEEKS)
assert field.serialize("d6", user) == 1
assert field.serialize("d6", user) == 1.1489103852529763
field = fields.TimeDelta(fields.TimeDelta.DAYS)
assert field.serialize("d6", user) == 7 + 1
assert field.serialize("d6", user) == 8.042372696770833
field = fields.TimeDelta(fields.TimeDelta.HOURS)
assert field.serialize("d6", user) == 7 * 24 + 24 + 1
assert field.serialize("d6", user) == 193.0169447225
field = fields.TimeDelta(fields.TimeDelta.MINUTES)
assert field.serialize("d6", user) == 7 * 24 * 60 + 24 * 60 + 60 + 1
d6_seconds = (
7 * 24 * 60 * 60
+ 24 * 60 * 60 # 1 week
+ 60 * 60 # 1 day
+ 60 # 1 hour
+ 1 # 1 minute
)
assert field.serialize("d6", user) == 11581.01668335
field = fields.TimeDelta(fields.TimeDelta.SECONDS)
assert field.serialize("d6", user) == d6_seconds
assert field.serialize("d6", user) == 694861.001001
field = fields.TimeDelta(fields.TimeDelta.MILLISECONDS)
assert field.serialize("d6", user) == d6_seconds * 1000 + 1
assert field.serialize("d6", user) == 694861001.001
field = fields.TimeDelta(fields.TimeDelta.MICROSECONDS)
assert field.serialize("d6", user) == d6_seconds * 10**6 + 1000 + 1
assert field.serialize("d6", user) == 694861001001

user.d7 = None
assert field.serialize("d7", user) is None

# https://github.com/marshmallow-code/marshmallow/issues/1856
user.d8 = dt.timedelta(milliseconds=345)
field = fields.TimeDelta(fields.TimeDelta.MILLISECONDS)
assert field.serialize("d8", user) == 345

user.d9 = dt.timedelta(milliseconds=1999)
field = fields.TimeDelta(fields.TimeDelta.SECONDS)
assert field.serialize("d9", user) == 1
assert field.serialize("d9", user) == 1.999

user.d10 = dt.timedelta(
weeks=1,
Expand All @@ -784,48 +776,45 @@ def test_timedelta_field(self, user):
microseconds=742,
)

field = fields.TimeDelta(fields.TimeDelta.MICROSECONDS, float)
field = fields.TimeDelta(fields.TimeDelta.MICROSECONDS)
unit_value = dt.timedelta(microseconds=1).total_seconds()
assert math.isclose(
field.serialize("d10", user), user.d10.total_seconds() / unit_value
)

field = fields.TimeDelta(fields.TimeDelta.MILLISECONDS, float)
field = fields.TimeDelta(fields.TimeDelta.MILLISECONDS)
unit_value = dt.timedelta(milliseconds=1).total_seconds()
assert math.isclose(
field.serialize("d10", user), user.d10.total_seconds() / unit_value
)

field = fields.TimeDelta(fields.TimeDelta.SECONDS, float)
field = fields.TimeDelta(fields.TimeDelta.SECONDS)
assert math.isclose(field.serialize("d10", user), user.d10.total_seconds())

field = fields.TimeDelta(fields.TimeDelta.MINUTES, float)
field = fields.TimeDelta(fields.TimeDelta.MINUTES)
unit_value = dt.timedelta(minutes=1).total_seconds()
assert math.isclose(
field.serialize("d10", user), user.d10.total_seconds() / unit_value
)

field = fields.TimeDelta(fields.TimeDelta.HOURS, float)
field = fields.TimeDelta(fields.TimeDelta.HOURS)
unit_value = dt.timedelta(hours=1).total_seconds()
assert math.isclose(
field.serialize("d10", user), user.d10.total_seconds() / unit_value
)

field = fields.TimeDelta(fields.TimeDelta.DAYS, float)
field = fields.TimeDelta(fields.TimeDelta.DAYS)
unit_value = dt.timedelta(days=1).total_seconds()
assert math.isclose(
field.serialize("d10", user), user.d10.total_seconds() / unit_value
)

field = fields.TimeDelta(fields.TimeDelta.WEEKS, float)
field = fields.TimeDelta(fields.TimeDelta.WEEKS)
unit_value = dt.timedelta(weeks=1).total_seconds()
assert math.isclose(
field.serialize("d10", user), user.d10.total_seconds() / unit_value
)

with pytest.raises(ValueError):
fields.TimeDelta(fields.TimeDelta.SECONDS, str)

def test_datetime_list_field(self):
obj = DateTimeList([dt.datetime.now(dt.timezone.utc), dt.datetime.now()])
field = fields.List(fields.DateTime)
Expand Down

0 comments on commit 3470680

Please sign in to comment.