diff --git a/src/marshmallow/fields.py b/src/marshmallow/fields.py index e83ba198e..44ddc3692 100644 --- a/src/marshmallow/fields.py +++ b/src/marshmallow/fields.py @@ -1435,45 +1435,52 @@ def _make_object_from_format(value, data_format): class TimeDelta(Field): - """A field that (de)serializes a :class:`datetime.timedelta` object to an - integer or float and vice versa. The integer or float can represent the - number of days, seconds or microseconds. - - :param precision: Influences how the integer or float is interpreted during - (de)serialization. Must be 'days', 'seconds', 'microseconds', - 'milliseconds', 'minutes', 'hours' or 'weeks'. - :param serialization_type: Whether to (de)serialize to a `int` or `float`. - :param kwargs: The same keyword arguments that :class:`Field` receives. + """A field that (de)serializes a :class:`datetime.timedelta` object to a `float`. + The `float` can represent any time unit that the :class:`datetime.timedelta` constructor + supports. - Integer Caveats - --------------- - Any fractional parts (which depends on the precision used) will be truncated - when serializing using `int`. + :param precision: The time unit used for (de)serialization. Must be one of 'weeks', + 'days', 'hours', 'minutes', 'seconds', 'milliseconds' or 'microseconds'. + :param kwargs: The same keyword arguments that :class:`Field` receives. Float Caveats ------------- - Use of `float` when (de)serializing may result in data precision loss due - to the way machines handle floating point values. + Precision loss may occur when serializing a highly precise :class:`datetime.timedelta` + object using a big ``precision`` unit due to floating point arithmetics. - Regardless of the precision chosen, the fractional part when using `float` - will always be truncated to microseconds. - For example, `1.12345` interpreted as microseconds will result in `timedelta(microseconds=1)`. + When necessary, the :class:`datetime.timedelta` constructor rounds `float` inputs + to whole microseconds during initialization of the object. As a result, deserializing + a `float` might be subject to rounding, regardless of `precision`. For example, + ``TimeDelta().deserialize("1.1234567") == timedelta(seconds=1, microseconds=123457)``. .. versionchanged:: 2.0.0 Always serializes to an integer value to avoid rounding errors. Add `precision` parameter. .. versionchanged:: 3.17.0 - Allow (de)serialization to `float` through use of a new `serialization_type` parameter. - `int` is the default to retain previous behaviour. + Allow serialization to `float` through use of a new `serialization_type` parameter. + Defaults to `int` for backwards compatibility. Also affects deserialization. + .. versionchanged:: 4.0.0 + Deprecate `serialization_type` parameter, always serialize to float. """ + WEEKS = "weeks" DAYS = "days" + HOURS = "hours" + MINUTES = "minutes" SECONDS = "seconds" - MICROSECONDS = "microseconds" MILLISECONDS = "milliseconds" - MINUTES = "minutes" - HOURS = "hours" - WEEKS = "weeks" + MICROSECONDS = "microseconds" + + # cache this mapping on class level for performance + _unit_to_microseconds_mapping = { + WEEKS: 1000000 * 60 * 60 * 24 * 7, + DAYS: 1000000 * 60 * 60 * 24, + HOURS: 1000000 * 60 * 60, + MINUTES: 1000000 * 60, + SECONDS: 1000000, + MILLISECONDS: 1000, + MICROSECONDS: 1, + } #: Default error messages. default_error_messages = { @@ -1484,49 +1491,38 @@ class TimeDelta(Field): def __init__( self, precision: str = SECONDS, - serialization_type: type[int | float] = int, + serialization_type: typing.Any = missing_, **kwargs, - ): + ) -> None: precision = precision.lower() - units = ( - self.DAYS, - self.SECONDS, - self.MICROSECONDS, - self.MILLISECONDS, - self.MINUTES, - self.HOURS, - self.WEEKS, - ) - if precision not in units: - msg = 'The precision must be {} or "{}".'.format( - ", ".join([f'"{each}"' for each in units[:-1]]), units[-1] - ) + if precision not in self._unit_to_microseconds_mapping: + units = ", ".join(self._unit_to_microseconds_mapping) + msg = f"The precision must be one of: {units}." raise ValueError(msg) - if serialization_type not in (int, float): - raise ValueError("The serialization type must be one of int or float") + if serialization_type is not missing_: + warnings.warn( + "The 'serialization_type' argument to TimeDelta is deprecated.", + RemovedInMarshmallow4Warning, + stacklevel=2, + ) self.precision = precision - self.serialization_type = serialization_type super().__init__(**kwargs) - def _serialize(self, value, attr, obj, **kwargs): + def _serialize(self, value, attr, obj, **kwargs) -> float | None: if value is None: return None - base_unit = dt.timedelta(**{self.precision: 1}) + # limit float arithmetics to a single division to minimize precision loss + microseconds: int = utils.timedelta_to_microseconds(value) + microseconds_per_unit: int = self._unit_to_microseconds_mapping[self.precision] + return microseconds / microseconds_per_unit - if self.serialization_type is int: - delta = utils.timedelta_to_microseconds(value) - unit = utils.timedelta_to_microseconds(base_unit) - return delta // unit - assert self.serialization_type is float - return value.total_seconds() / base_unit.total_seconds() - - def _deserialize(self, value, attr, data, **kwargs): + def _deserialize(self, value, attr, data, **kwargs) -> dt.timedelta: try: - value = self.serialization_type(value) + value = float(value) except (TypeError, ValueError) as error: raise self.make_error("invalid") from error diff --git a/src/marshmallow/utils.py b/src/marshmallow/utils.py index a5fe72624..59769dd4d 100644 --- a/src/marshmallow/utils.py +++ b/src/marshmallow/utils.py @@ -363,9 +363,9 @@ def resolve_field_instance(cls_or_instance): def timedelta_to_microseconds(value: dt.timedelta) -> int: - """Compute the total microseconds of a timedelta + """Compute the total microseconds of a timedelta. - https://github.com/python/cpython/blob/bb3e0c240bc60fe08d332ff5955d54197f79751c/Lib/datetime.py#L665-L667 # noqa: B950 + https://github.com/python/cpython/blob/v3.13.1/Lib/_pydatetime.py#L805-L807 """ return (value.days * (24 * 3600) + value.seconds) * 1000000 + value.microseconds diff --git a/tests/test_deserialization.py b/tests/test_deserialization.py index 80d44272f..ff20bbf6b 100644 --- a/tests/test_deserialization.py +++ b/tests/test_deserialization.py @@ -698,7 +698,7 @@ def test_iso_time_field_deserialization(self, fmt, value, expected): assert field.deserialize(value) == expected def test_invalid_timedelta_precision(self): - with pytest.raises(ValueError, match='The precision must be "days",'): + with pytest.raises(ValueError, match="The precision must be one of: weeks,"): fields.TimeDelta("invalid") def test_timedelta_field_deserialization(self): @@ -709,6 +709,13 @@ def test_timedelta_field_deserialization(self): assert result.seconds == 42 assert result.microseconds == 0 + field = fields.TimeDelta() + result = field.deserialize("42.9") + assert isinstance(result, dt.timedelta) + assert result.days == 0 + assert result.seconds == 42 + assert result.microseconds == 900000 + field = fields.TimeDelta(fields.TimeDelta.SECONDS) result = field.deserialize(100000) assert result.days == 1 @@ -741,7 +748,7 @@ def test_timedelta_field_deserialization(self): assert isinstance(result, dt.timedelta) assert result.days == 0 assert result.seconds == 12 - assert result.microseconds == 0 + assert result.microseconds == 900000 field = fields.TimeDelta(fields.TimeDelta.WEEKS) result = field.deserialize(1) @@ -772,7 +779,7 @@ def test_timedelta_field_deserialization(self): assert result.microseconds == 456000 total_microseconds_value = 322.0 - field = fields.TimeDelta(fields.TimeDelta.MICROSECONDS, float) + field = fields.TimeDelta(fields.TimeDelta.MICROSECONDS) result = field.deserialize(total_microseconds_value) assert isinstance(result, dt.timedelta) unit_value = dt.timedelta(microseconds=1).total_seconds() @@ -781,7 +788,7 @@ def test_timedelta_field_deserialization(self): ) total_microseconds_value = 322.12345 - field = fields.TimeDelta(fields.TimeDelta.MICROSECONDS, float) + field = fields.TimeDelta(fields.TimeDelta.MICROSECONDS) result = field.deserialize(total_microseconds_value) assert isinstance(result, dt.timedelta) unit_value = dt.timedelta(microseconds=1).total_seconds() @@ -790,7 +797,7 @@ def test_timedelta_field_deserialization(self): ) total_milliseconds_value = 322.223 - field = fields.TimeDelta(fields.TimeDelta.MILLISECONDS, float) + field = fields.TimeDelta(fields.TimeDelta.MILLISECONDS) result = field.deserialize(total_milliseconds_value) assert isinstance(result, dt.timedelta) unit_value = dt.timedelta(milliseconds=1).total_seconds() @@ -799,34 +806,34 @@ def test_timedelta_field_deserialization(self): ) total_seconds_value = 322.223 - field = fields.TimeDelta(fields.TimeDelta.SECONDS, float) + field = fields.TimeDelta(fields.TimeDelta.SECONDS) result = field.deserialize(total_seconds_value) assert isinstance(result, dt.timedelta) assert math.isclose(result.total_seconds(), total_seconds_value) total_minutes_value = 322.223 - field = fields.TimeDelta(fields.TimeDelta.MINUTES, float) + field = fields.TimeDelta(fields.TimeDelta.MINUTES) result = field.deserialize(total_minutes_value) assert isinstance(result, dt.timedelta) unit_value = dt.timedelta(minutes=1).total_seconds() assert math.isclose(result.total_seconds() / unit_value, total_minutes_value) total_hours_value = 322.223 - field = fields.TimeDelta(fields.TimeDelta.HOURS, float) + field = fields.TimeDelta(fields.TimeDelta.HOURS) result = field.deserialize(total_hours_value) assert isinstance(result, dt.timedelta) unit_value = dt.timedelta(hours=1).total_seconds() assert math.isclose(result.total_seconds() / unit_value, total_hours_value) total_days_value = 322.223 - field = fields.TimeDelta(fields.TimeDelta.DAYS, float) + field = fields.TimeDelta(fields.TimeDelta.DAYS) result = field.deserialize(total_days_value) assert isinstance(result, dt.timedelta) unit_value = dt.timedelta(days=1).total_seconds() assert math.isclose(result.total_seconds() / unit_value, total_days_value) total_weeks_value = 322.223 - field = fields.TimeDelta(fields.TimeDelta.WEEKS, float) + field = fields.TimeDelta(fields.TimeDelta.WEEKS) result = field.deserialize(total_weeks_value) assert isinstance(result, dt.timedelta) unit_value = dt.timedelta(weeks=1).total_seconds() diff --git a/tests/test_serialization.py b/tests/test_serialization.py index df778ce82..d43defb97 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -704,25 +704,25 @@ def test_timedelta_field(self, user): ) field = fields.TimeDelta(fields.TimeDelta.DAYS) - assert field.serialize("d1", user) == 1 + assert field.serialize("d1", user) == 1.0000115740856481 field = fields.TimeDelta(fields.TimeDelta.SECONDS) - assert field.serialize("d1", user) == 86401 + assert field.serialize("d1", user) == 86401.000001 field = fields.TimeDelta(fields.TimeDelta.MICROSECONDS) assert field.serialize("d1", user) == 86401000001 field = fields.TimeDelta(fields.TimeDelta.HOURS) - assert field.serialize("d1", user) == 24 + assert field.serialize("d1", user) == 24.000277778055555 field = fields.TimeDelta(fields.TimeDelta.DAYS) - assert field.serialize("d2", user) == 1 + assert field.serialize("d2", user) == 1.0000115740856481 field = fields.TimeDelta(fields.TimeDelta.SECONDS) - assert field.serialize("d2", user) == 86401 + assert field.serialize("d2", user) == 86401.000001 field = fields.TimeDelta(fields.TimeDelta.MICROSECONDS) assert field.serialize("d2", user) == 86401000001 field = fields.TimeDelta(fields.TimeDelta.DAYS) - assert field.serialize("d3", user) == 1 + assert field.serialize("d3", user) == 1.0000115740856481 field = fields.TimeDelta(fields.TimeDelta.SECONDS) - assert field.serialize("d3", user) == 86401 + assert field.serialize("d3", user) == 86401.000001 field = fields.TimeDelta(fields.TimeDelta.MICROSECONDS) assert field.serialize("d3", user) == 86401000001 @@ -741,38 +741,30 @@ def test_timedelta_field(self, user): assert field.serialize("d5", user) == -86400000000 field = fields.TimeDelta(fields.TimeDelta.WEEKS) - assert field.serialize("d6", user) == 1 + assert field.serialize("d6", user) == 1.1489103852529763 field = fields.TimeDelta(fields.TimeDelta.DAYS) - assert field.serialize("d6", user) == 7 + 1 + assert field.serialize("d6", user) == 8.042372696770833 field = fields.TimeDelta(fields.TimeDelta.HOURS) - assert field.serialize("d6", user) == 7 * 24 + 24 + 1 + assert field.serialize("d6", user) == 193.0169447225 field = fields.TimeDelta(fields.TimeDelta.MINUTES) - assert field.serialize("d6", user) == 7 * 24 * 60 + 24 * 60 + 60 + 1 - d6_seconds = ( - 7 * 24 * 60 * 60 - + 24 * 60 * 60 # 1 week - + 60 * 60 # 1 day - + 60 # 1 hour - + 1 # 1 minute - ) + assert field.serialize("d6", user) == 11581.01668335 field = fields.TimeDelta(fields.TimeDelta.SECONDS) - assert field.serialize("d6", user) == d6_seconds + assert field.serialize("d6", user) == 694861.001001 field = fields.TimeDelta(fields.TimeDelta.MILLISECONDS) - assert field.serialize("d6", user) == d6_seconds * 1000 + 1 + assert field.serialize("d6", user) == 694861001.001 field = fields.TimeDelta(fields.TimeDelta.MICROSECONDS) - assert field.serialize("d6", user) == d6_seconds * 10**6 + 1000 + 1 + assert field.serialize("d6", user) == 694861001001 user.d7 = None assert field.serialize("d7", user) is None - # https://github.com/marshmallow-code/marshmallow/issues/1856 user.d8 = dt.timedelta(milliseconds=345) field = fields.TimeDelta(fields.TimeDelta.MILLISECONDS) assert field.serialize("d8", user) == 345 user.d9 = dt.timedelta(milliseconds=1999) field = fields.TimeDelta(fields.TimeDelta.SECONDS) - assert field.serialize("d9", user) == 1 + assert field.serialize("d9", user) == 1.999 user.d10 = dt.timedelta( weeks=1, @@ -784,48 +776,45 @@ def test_timedelta_field(self, user): microseconds=742, ) - field = fields.TimeDelta(fields.TimeDelta.MICROSECONDS, float) + field = fields.TimeDelta(fields.TimeDelta.MICROSECONDS) unit_value = dt.timedelta(microseconds=1).total_seconds() assert math.isclose( field.serialize("d10", user), user.d10.total_seconds() / unit_value ) - field = fields.TimeDelta(fields.TimeDelta.MILLISECONDS, float) + field = fields.TimeDelta(fields.TimeDelta.MILLISECONDS) unit_value = dt.timedelta(milliseconds=1).total_seconds() assert math.isclose( field.serialize("d10", user), user.d10.total_seconds() / unit_value ) - field = fields.TimeDelta(fields.TimeDelta.SECONDS, float) + field = fields.TimeDelta(fields.TimeDelta.SECONDS) assert math.isclose(field.serialize("d10", user), user.d10.total_seconds()) - field = fields.TimeDelta(fields.TimeDelta.MINUTES, float) + field = fields.TimeDelta(fields.TimeDelta.MINUTES) unit_value = dt.timedelta(minutes=1).total_seconds() assert math.isclose( field.serialize("d10", user), user.d10.total_seconds() / unit_value ) - field = fields.TimeDelta(fields.TimeDelta.HOURS, float) + field = fields.TimeDelta(fields.TimeDelta.HOURS) unit_value = dt.timedelta(hours=1).total_seconds() assert math.isclose( field.serialize("d10", user), user.d10.total_seconds() / unit_value ) - field = fields.TimeDelta(fields.TimeDelta.DAYS, float) + field = fields.TimeDelta(fields.TimeDelta.DAYS) unit_value = dt.timedelta(days=1).total_seconds() assert math.isclose( field.serialize("d10", user), user.d10.total_seconds() / unit_value ) - field = fields.TimeDelta(fields.TimeDelta.WEEKS, float) + field = fields.TimeDelta(fields.TimeDelta.WEEKS) unit_value = dt.timedelta(weeks=1).total_seconds() assert math.isclose( field.serialize("d10", user), user.d10.total_seconds() / unit_value ) - with pytest.raises(ValueError): - fields.TimeDelta(fields.TimeDelta.SECONDS, str) - def test_datetime_list_field(self): obj = DateTimeList([dt.datetime.now(dt.timezone.utc), dt.datetime.now()]) field = fields.List(fields.DateTime)