From 3470680cdfea01ddfdd5a1257b5553cf5153b1eb Mon Sep 17 00:00:00 2001 From: ddelange <14880945+ddelange@users.noreply.github.com> Date: Mon, 23 Dec 2024 16:24:10 +0100 Subject: [PATCH] Deprecate TimeDelta serialization_type parameter --- src/marshmallow/fields.py | 76 ++++++++++++++++------------------- src/marshmallow/utils.py | 4 +- tests/test_deserialization.py | 2 +- tests/test_serialization.py | 55 ++++++++++--------------- 4 files changed, 60 insertions(+), 77 deletions(-) diff --git a/src/marshmallow/fields.py b/src/marshmallow/fields.py index b564aee9b..44ddc3692 100644 --- a/src/marshmallow/fields.py +++ b/src/marshmallow/fields.py @@ -1435,26 +1435,18 @@ def _make_object_from_format(value, data_format): class TimeDelta(Field): - """A field that (de)serializes a :class:`datetime.timedelta` object to an - integer or float. The integer or float can represent any time unit that the - :class:`datetime.timedelta` constructor supports. + """A field that (de)serializes a :class:`datetime.timedelta` object to a `float`. + The `float` can represent any time unit that the :class:`datetime.timedelta` constructor + supports. :param precision: The time unit used for (de)serialization. Must be one of 'weeks', 'days', 'hours', 'minutes', 'seconds', 'milliseconds' or 'microseconds'. - :param serialization_type: Whether to serialize to an `int` or `float`. - Ignored during deserialization: both `int` and `float` inputs are supported. :param kwargs: The same keyword arguments that :class:`Field` receives. - Integer Caveats - --------------- - When serializing using ``serialization_type=int`` and depending on the ``precision`` - used, any fractional parts might be truncated (downcast to integer). - Float Caveats ------------- Precision loss may occur when serializing a highly precise :class:`datetime.timedelta` - object using ``serialization_type=float`` and a big ``precision`` unit due to floating - point arithmetics. + object using a big ``precision`` unit due to floating point arithmetics. When necessary, the :class:`datetime.timedelta` constructor rounds `float` inputs to whole microseconds during initialization of the object. As a result, deserializing @@ -1466,7 +1458,9 @@ class TimeDelta(Field): Add `precision` parameter. .. versionchanged:: 3.17.0 Allow serialization to `float` through use of a new `serialization_type` parameter. - Defaults to `int` for backwards compatibility. + Defaults to `int` for backwards compatibility. Also affects deserialization. + .. versionchanged:: 4.0.0 + Deprecate `serialization_type` parameter, always serialize to float. """ WEEKS = "weeks" @@ -1477,6 +1471,17 @@ class TimeDelta(Field): MILLISECONDS = "milliseconds" MICROSECONDS = "microseconds" + # cache this mapping on class level for performance + _unit_to_microseconds_mapping = { + WEEKS: 1000000 * 60 * 60 * 24 * 7, + DAYS: 1000000 * 60 * 60 * 24, + HOURS: 1000000 * 60 * 60, + MINUTES: 1000000 * 60, + SECONDS: 1000000, + MILLISECONDS: 1000, + MICROSECONDS: 1, + } + #: Default error messages. default_error_messages = { "invalid": "Not a valid period of time.", @@ -1486,47 +1491,36 @@ class TimeDelta(Field): def __init__( self, precision: str = SECONDS, - serialization_type: type[int | float] = int, + serialization_type: typing.Any = missing_, **kwargs, - ): + ) -> None: precision = precision.lower() - units = ( - self.DAYS, - self.SECONDS, - self.MICROSECONDS, - self.MILLISECONDS, - self.MINUTES, - self.HOURS, - self.WEEKS, - ) - if precision not in units: - msg = 'The precision must be {} or "{}".'.format( - ", ".join([f'"{each}"' for each in units[:-1]]), units[-1] - ) + if precision not in self._unit_to_microseconds_mapping: + units = ", ".join(self._unit_to_microseconds_mapping) + msg = f"The precision must be one of: {units}." raise ValueError(msg) - if serialization_type not in (int, float): - raise ValueError("The serialization type must be one of int or float") + if serialization_type is not missing_: + warnings.warn( + "The 'serialization_type' argument to TimeDelta is deprecated.", + RemovedInMarshmallow4Warning, + stacklevel=2, + ) self.precision = precision - self.serialization_type = serialization_type super().__init__(**kwargs) - def _serialize(self, value, attr, obj, **kwargs): + def _serialize(self, value, attr, obj, **kwargs) -> float | None: if value is None: return None - base_unit = dt.timedelta(**{self.precision: 1}) - - if self.serialization_type is int: - delta = utils.timedelta_to_microseconds(value) - unit = utils.timedelta_to_microseconds(base_unit) - return delta // unit - assert self.serialization_type is float - return value.total_seconds() / base_unit.total_seconds() + # limit float arithmetics to a single division to minimize precision loss + microseconds: int = utils.timedelta_to_microseconds(value) + microseconds_per_unit: int = self._unit_to_microseconds_mapping[self.precision] + return microseconds / microseconds_per_unit - def _deserialize(self, value, attr, data, **kwargs): + def _deserialize(self, value, attr, data, **kwargs) -> dt.timedelta: try: value = float(value) except (TypeError, ValueError) as error: diff --git a/src/marshmallow/utils.py b/src/marshmallow/utils.py index a5fe72624..59769dd4d 100644 --- a/src/marshmallow/utils.py +++ b/src/marshmallow/utils.py @@ -363,9 +363,9 @@ def resolve_field_instance(cls_or_instance): def timedelta_to_microseconds(value: dt.timedelta) -> int: - """Compute the total microseconds of a timedelta + """Compute the total microseconds of a timedelta. - https://github.com/python/cpython/blob/bb3e0c240bc60fe08d332ff5955d54197f79751c/Lib/datetime.py#L665-L667 # noqa: B950 + https://github.com/python/cpython/blob/v3.13.1/Lib/_pydatetime.py#L805-L807 """ return (value.days * (24 * 3600) + value.seconds) * 1000000 + value.microseconds diff --git a/tests/test_deserialization.py b/tests/test_deserialization.py index e4a03d5db..ff20bbf6b 100644 --- a/tests/test_deserialization.py +++ b/tests/test_deserialization.py @@ -698,7 +698,7 @@ def test_iso_time_field_deserialization(self, fmt, value, expected): assert field.deserialize(value) == expected def test_invalid_timedelta_precision(self): - with pytest.raises(ValueError, match='The precision must be "days",'): + with pytest.raises(ValueError, match="The precision must be one of: weeks,"): fields.TimeDelta("invalid") def test_timedelta_field_deserialization(self): diff --git a/tests/test_serialization.py b/tests/test_serialization.py index df778ce82..d43defb97 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -704,25 +704,25 @@ def test_timedelta_field(self, user): ) field = fields.TimeDelta(fields.TimeDelta.DAYS) - assert field.serialize("d1", user) == 1 + assert field.serialize("d1", user) == 1.0000115740856481 field = fields.TimeDelta(fields.TimeDelta.SECONDS) - assert field.serialize("d1", user) == 86401 + assert field.serialize("d1", user) == 86401.000001 field = fields.TimeDelta(fields.TimeDelta.MICROSECONDS) assert field.serialize("d1", user) == 86401000001 field = fields.TimeDelta(fields.TimeDelta.HOURS) - assert field.serialize("d1", user) == 24 + assert field.serialize("d1", user) == 24.000277778055555 field = fields.TimeDelta(fields.TimeDelta.DAYS) - assert field.serialize("d2", user) == 1 + assert field.serialize("d2", user) == 1.0000115740856481 field = fields.TimeDelta(fields.TimeDelta.SECONDS) - assert field.serialize("d2", user) == 86401 + assert field.serialize("d2", user) == 86401.000001 field = fields.TimeDelta(fields.TimeDelta.MICROSECONDS) assert field.serialize("d2", user) == 86401000001 field = fields.TimeDelta(fields.TimeDelta.DAYS) - assert field.serialize("d3", user) == 1 + assert field.serialize("d3", user) == 1.0000115740856481 field = fields.TimeDelta(fields.TimeDelta.SECONDS) - assert field.serialize("d3", user) == 86401 + assert field.serialize("d3", user) == 86401.000001 field = fields.TimeDelta(fields.TimeDelta.MICROSECONDS) assert field.serialize("d3", user) == 86401000001 @@ -741,38 +741,30 @@ def test_timedelta_field(self, user): assert field.serialize("d5", user) == -86400000000 field = fields.TimeDelta(fields.TimeDelta.WEEKS) - assert field.serialize("d6", user) == 1 + assert field.serialize("d6", user) == 1.1489103852529763 field = fields.TimeDelta(fields.TimeDelta.DAYS) - assert field.serialize("d6", user) == 7 + 1 + assert field.serialize("d6", user) == 8.042372696770833 field = fields.TimeDelta(fields.TimeDelta.HOURS) - assert field.serialize("d6", user) == 7 * 24 + 24 + 1 + assert field.serialize("d6", user) == 193.0169447225 field = fields.TimeDelta(fields.TimeDelta.MINUTES) - assert field.serialize("d6", user) == 7 * 24 * 60 + 24 * 60 + 60 + 1 - d6_seconds = ( - 7 * 24 * 60 * 60 - + 24 * 60 * 60 # 1 week - + 60 * 60 # 1 day - + 60 # 1 hour - + 1 # 1 minute - ) + assert field.serialize("d6", user) == 11581.01668335 field = fields.TimeDelta(fields.TimeDelta.SECONDS) - assert field.serialize("d6", user) == d6_seconds + assert field.serialize("d6", user) == 694861.001001 field = fields.TimeDelta(fields.TimeDelta.MILLISECONDS) - assert field.serialize("d6", user) == d6_seconds * 1000 + 1 + assert field.serialize("d6", user) == 694861001.001 field = fields.TimeDelta(fields.TimeDelta.MICROSECONDS) - assert field.serialize("d6", user) == d6_seconds * 10**6 + 1000 + 1 + assert field.serialize("d6", user) == 694861001001 user.d7 = None assert field.serialize("d7", user) is None - # https://github.com/marshmallow-code/marshmallow/issues/1856 user.d8 = dt.timedelta(milliseconds=345) field = fields.TimeDelta(fields.TimeDelta.MILLISECONDS) assert field.serialize("d8", user) == 345 user.d9 = dt.timedelta(milliseconds=1999) field = fields.TimeDelta(fields.TimeDelta.SECONDS) - assert field.serialize("d9", user) == 1 + assert field.serialize("d9", user) == 1.999 user.d10 = dt.timedelta( weeks=1, @@ -784,48 +776,45 @@ def test_timedelta_field(self, user): microseconds=742, ) - field = fields.TimeDelta(fields.TimeDelta.MICROSECONDS, float) + field = fields.TimeDelta(fields.TimeDelta.MICROSECONDS) unit_value = dt.timedelta(microseconds=1).total_seconds() assert math.isclose( field.serialize("d10", user), user.d10.total_seconds() / unit_value ) - field = fields.TimeDelta(fields.TimeDelta.MILLISECONDS, float) + field = fields.TimeDelta(fields.TimeDelta.MILLISECONDS) unit_value = dt.timedelta(milliseconds=1).total_seconds() assert math.isclose( field.serialize("d10", user), user.d10.total_seconds() / unit_value ) - field = fields.TimeDelta(fields.TimeDelta.SECONDS, float) + field = fields.TimeDelta(fields.TimeDelta.SECONDS) assert math.isclose(field.serialize("d10", user), user.d10.total_seconds()) - field = fields.TimeDelta(fields.TimeDelta.MINUTES, float) + field = fields.TimeDelta(fields.TimeDelta.MINUTES) unit_value = dt.timedelta(minutes=1).total_seconds() assert math.isclose( field.serialize("d10", user), user.d10.total_seconds() / unit_value ) - field = fields.TimeDelta(fields.TimeDelta.HOURS, float) + field = fields.TimeDelta(fields.TimeDelta.HOURS) unit_value = dt.timedelta(hours=1).total_seconds() assert math.isclose( field.serialize("d10", user), user.d10.total_seconds() / unit_value ) - field = fields.TimeDelta(fields.TimeDelta.DAYS, float) + field = fields.TimeDelta(fields.TimeDelta.DAYS) unit_value = dt.timedelta(days=1).total_seconds() assert math.isclose( field.serialize("d10", user), user.d10.total_seconds() / unit_value ) - field = fields.TimeDelta(fields.TimeDelta.WEEKS, float) + field = fields.TimeDelta(fields.TimeDelta.WEEKS) unit_value = dt.timedelta(weeks=1).total_seconds() assert math.isclose( field.serialize("d10", user), user.d10.total_seconds() / unit_value ) - with pytest.raises(ValueError): - fields.TimeDelta(fields.TimeDelta.SECONDS, str) - def test_datetime_list_field(self): obj = DateTimeList([dt.datetime.now(dt.timezone.utc), dt.datetime.now()]) field = fields.List(fields.DateTime)