diff --git a/python/pydantic_core/__init__.py b/python/pydantic_core/__init__.py index 7953efcbd..a916f43d8 100644 --- a/python/pydantic_core/__init__.py +++ b/python/pydantic_core/__init__.py @@ -18,6 +18,7 @@ SchemaSerializer, SchemaValidator, Some, + TzInfo, Url, ValidationError, __version__, @@ -59,6 +60,7 @@ 'PydanticUseDefault', 'PydanticSerializationError', 'PydanticSerializationUnexpectedValue', + 'TzInfo', 'to_json', 'to_jsonable_python', ] diff --git a/python/pydantic_core/_pydantic_core.pyi b/python/pydantic_core/_pydantic_core.pyi index 145d4aa0c..5aae6df8a 100644 --- a/python/pydantic_core/_pydantic_core.pyi +++ b/python/pydantic_core/_pydantic_core.pyi @@ -798,4 +798,5 @@ class TzInfo(datetime.tzinfo): def tzname(self, _dt: datetime.datetime | None) -> str | None: ... def utcoffset(self, _dt: datetime.datetime | None) -> datetime.timedelta: ... def dst(self, _dt: datetime.datetime | None) -> datetime.timedelta: ... + def fromutc(self, dt: datetime.datetime) -> datetime.datetime: ... def __deepcopy__(self, _memo: dict[Any, Any]) -> 'TzInfo': ... diff --git a/src/input/datetime.rs b/src/input/datetime.rs index 1faad049d..b89b76c1e 100644 --- a/src/input/datetime.rs +++ b/src/input/datetime.rs @@ -2,10 +2,14 @@ use pyo3::intern; use pyo3::prelude::*; use pyo3::exceptions::PyValueError; +use pyo3::pyclass::CompareOp; use pyo3::types::{PyDate, PyDateTime, PyDelta, PyDeltaAccess, PyDict, PyTime, PyTzInfo}; use speedate::MicrosecondsPrecisionOverflowBehavior; use speedate::{Date, DateTime, Duration, ParseError, Time, TimeConfig}; use std::borrow::Cow; +use std::collections::hash_map::DefaultHasher; +use std::hash::Hash; +use std::hash::Hasher; use strum::EnumMessage; @@ -222,7 +226,7 @@ impl<'a> EitherTime<'a> { fn time_as_tzinfo<'py>(py: Python<'py>, time: &Time) -> PyResult> { match time.tz_offset { Some(offset) => { - let tz_info = TzInfo::new(offset); + let tz_info: TzInfo = offset.try_into()?; let py_tz_info = Py::new(py, tz_info)?.to_object(py).into_ref(py); Ok(Some(py_tz_info.extract()?)) } @@ -508,11 +512,11 @@ pub struct TzInfo { #[pymethods] impl TzInfo { #[new] - fn new(seconds: i32) -> Self { - Self { seconds } + fn py_new(seconds: f32) -> PyResult { + Self::try_from(seconds.trunc() as i32) } - fn utcoffset<'p>(&self, py: Python<'p>, _dt: &PyAny) -> PyResult<&'p PyDelta> { + fn utcoffset<'py>(&self, py: Python<'py>, _dt: &PyAny) -> PyResult<&'py PyDelta> { PyDelta::new(py, 0, self.seconds, 0, true) } @@ -524,17 +528,43 @@ impl TzInfo { None } + fn fromutc<'py>(&self, dt: &'py PyDateTime) -> PyResult<&'py PyAny> { + let py = dt.py(); + dt.call_method1("__add__", (self.utcoffset(py, py.None().as_ref(py))?,)) + } + fn __repr__(&self) -> String { format!("TzInfo({})", self.__str__()) } fn __str__(&self) -> String { if self.seconds == 0 { - "UTC".to_string() - } else { - let mins = self.seconds / 60; - format!("{:+03}:{:02}", mins / 60, (mins % 60).abs()) + return "UTC".to_string(); + } + + let (mins, seconds) = (self.seconds / 60, self.seconds % 60); + let mut result = format!( + "{}{:02}:{:02}", + if self.seconds.signum() >= 0 { "+" } else { "-" }, + (mins / 60).abs(), + (mins % 60).abs() + ); + + if seconds != 0 { + result.push_str(&format!(":{:02}", seconds.abs())); } + + result + } + + fn __hash__(&self) -> u64 { + let mut hasher = DefaultHasher::new(); + self.seconds.hash(&mut hasher); + hasher.finish() + } + + fn __richcmp__(&self, other: &Self, op: CompareOp) -> bool { + op.matches(self.seconds.cmp(&other.seconds)) } fn __deepcopy__(&self, py: Python, _memo: &PyDict) -> PyResult> { @@ -547,3 +577,17 @@ impl TzInfo { Ok((cls, args).into_py(py)) } } + +impl TryFrom for TzInfo { + type Error = PyErr; + + fn try_from(seconds: i32) -> PyResult { + if seconds.abs() >= 86400 { + Err(PyValueError::new_err(format!( + "TzInfo offset must be strictly between -86400 and 86400 (24 hours) seconds, got {seconds}" + ))) + } else { + Ok(Self { seconds }) + } + } +} diff --git a/tests/test_tzinfo.py b/tests/test_tzinfo.py new file mode 100644 index 000000000..cb67b737e --- /dev/null +++ b/tests/test_tzinfo.py @@ -0,0 +1,212 @@ +import copy +import functools +import pickle +import unittest +from datetime import datetime, timedelta, timezone, tzinfo + +from pydantic_core import SchemaValidator, TzInfo, core_schema + + +class _ALWAYS_EQ: + """ + Object that is equal to anything. + """ + + def __eq__(self, other): + return True + + def __ne__(self, other): + return False + + +ALWAYS_EQ = _ALWAYS_EQ() + + +@functools.total_ordering +class _LARGEST: + """ + Object that is greater than anything (except itself). + """ + + def __eq__(self, other): + return isinstance(other, _LARGEST) + + def __lt__(self, other): + return False + + +LARGEST = _LARGEST() + + +@functools.total_ordering +class _SMALLEST: + """ + Object that is less than anything (except itself). + """ + + def __eq__(self, other): + return isinstance(other, _SMALLEST) + + def __gt__(self, other): + return False + + +SMALLEST = _SMALLEST() + + +pickle_choices = [(pickle, pickle, proto) for proto in range(pickle.HIGHEST_PROTOCOL + 1)] + +HOUR = timedelta(hours=1).total_seconds() +ZERO = timedelta(0).total_seconds() + + +def first_sunday_on_or_after(dt): + days_to_go = 6 - dt.weekday() + if days_to_go: + dt += timedelta(days_to_go) + return dt + + +DSTSTART = datetime(1, 4, 1, 2) +DSTEND = datetime(1, 10, 25, 1) + + +class TestTzInfo(unittest.TestCase): + """Adapted from CPython `timezone` tests + + Original tests are located here https://github.com/python/cpython/blob/a0bb4a39d1ca10e4a75f50a9fbe90cc9db28d29e/Lib/test/datetimetester.py#L256 + """ + + def setUp(self): + self.ACDT = TzInfo(timedelta(hours=9.5).total_seconds()) + self.EST = TzInfo(-timedelta(hours=5).total_seconds()) + self.DT = datetime(2010, 1, 1) + + def test_str(self): + for tz in [self.ACDT, self.EST]: + self.assertEqual(str(tz), tz.tzname(None)) + + def test_constructor(self): + for subminute in [timedelta(microseconds=1), timedelta(seconds=1)]: + tz = TzInfo(subminute.total_seconds()) + self.assertNotEqual(tz.utcoffset(None) % timedelta(minutes=1), 0) + # invalid offsets + for invalid in [timedelta(1, 1), timedelta(1)]: + self.assertRaises(ValueError, TzInfo, invalid.total_seconds()) + self.assertRaises(ValueError, TzInfo, -invalid.total_seconds()) + + with self.assertRaises(TypeError): + TzInfo(None) + with self.assertRaises(TypeError): + TzInfo(timedelta(seconds=42)) + with self.assertRaises(TypeError): + TzInfo(ZERO, None) + with self.assertRaises(TypeError): + TzInfo(ZERO, 42) + with self.assertRaises(TypeError): + TzInfo(ZERO, 'ABC', 'extra') + + def test_inheritance(self): + self.assertIsInstance(self.EST, tzinfo) + + def test_utcoffset(self): + dummy = self.DT + for h in [0, 1.5, 12]: + offset = h * HOUR + self.assertEqual(timedelta(seconds=offset), TzInfo(offset).utcoffset(dummy)) + self.assertEqual(timedelta(seconds=-offset), TzInfo(-offset).utcoffset(dummy)) + + self.assertEqual(self.EST.utcoffset(''), timedelta(hours=-5)) + self.assertEqual(self.EST.utcoffset(5), timedelta(hours=-5)) + + def test_dst(self): + self.EST.dst('') is None + self.EST.dst(5) is None + + def test_tzname(self): + self.assertEqual('-05:00', TzInfo(-5 * HOUR).tzname(None)) + self.assertEqual('+09:30', TzInfo(9.5 * HOUR).tzname(None)) + self.assertEqual('-00:01', TzInfo(timedelta(minutes=-1).total_seconds()).tzname(None)) + # Sub-minute offsets: + self.assertEqual('+01:06:40', TzInfo(timedelta(0, 4000).total_seconds()).tzname(None)) + self.assertEqual('-01:06:40', TzInfo(-timedelta(0, 4000).total_seconds()).tzname(None)) + self.assertEqual('+01:06:40', TzInfo(timedelta(0, 4000, 1).total_seconds()).tzname(None)) + self.assertEqual('-01:06:40', TzInfo(-timedelta(0, 4000, 1).total_seconds()).tzname(None)) + + self.assertEqual(self.EST.tzname(''), '-05:00') + self.assertEqual(self.EST.tzname(5), '-05:00') + + def test_fromutc(self): + for tz in [self.EST, self.ACDT]: + utctime = self.DT.replace(tzinfo=tz) + local = tz.fromutc(utctime) + self.assertEqual(local - utctime, tz.utcoffset(local)) + self.assertEqual(local, self.DT.replace(tzinfo=timezone.utc)) + + def test_comparison(self): + self.assertNotEqual(TzInfo(ZERO), TzInfo(HOUR)) + self.assertEqual(TzInfo(HOUR), TzInfo(HOUR)) + self.assertFalse(TzInfo(ZERO) < TzInfo(ZERO)) + self.assertIn(TzInfo(ZERO), {TzInfo(ZERO)}) + self.assertTrue(TzInfo(ZERO) is not None) + self.assertFalse(TzInfo(ZERO) is None) + + tz = TzInfo(ZERO) + self.assertTrue(tz == ALWAYS_EQ) + self.assertFalse(tz != ALWAYS_EQ) + self.assertTrue(tz < LARGEST) + self.assertFalse(tz > LARGEST) + self.assertTrue(tz <= LARGEST) + self.assertFalse(tz >= LARGEST) + self.assertFalse(tz < SMALLEST) + self.assertTrue(tz > SMALLEST) + self.assertFalse(tz <= SMALLEST) + self.assertTrue(tz >= SMALLEST) + + def test_copy(self): + for tz in self.ACDT, self.EST: + tz_copy = copy.copy(tz) + self.assertEqual(tz_copy, tz) + + def test_deepcopy(self): + for tz in self.ACDT, self.EST: + tz_copy = copy.deepcopy(tz) + self.assertEqual(tz_copy, tz) + + def test_offset_boundaries(self): + # Test timedeltas close to the boundaries + time_deltas = [timedelta(hours=23, minutes=59), timedelta(hours=23, minutes=59, seconds=59)] + time_deltas.extend([-delta for delta in time_deltas]) + + for delta in time_deltas: + with self.subTest(test_type='good', delta=delta): + print(delta.total_seconds()) + TzInfo(delta.total_seconds()) + + # Test timedeltas on and outside the boundaries + bad_time_deltas = [timedelta(hours=24), timedelta(hours=24, microseconds=1)] + bad_time_deltas.extend([-delta for delta in bad_time_deltas]) + + for delta in bad_time_deltas: + with self.subTest(test_type='bad', delta=delta): + with self.assertRaises(ValueError): + TzInfo(delta.total_seconds()) + + +def test_tzinfo_could_be_reused(): + class Model: + value: datetime + + v = SchemaValidator( + core_schema.model_schema( + Model, core_schema.model_fields_schema({'value': core_schema.model_field(core_schema.datetime_schema())}) + ) + ) + + m = v.validate_python({'value': '2015-10-21T15:28:00.000000+01:00'}) + + target = datetime(1955, 11, 12, 14, 38, tzinfo=m.value.tzinfo) + assert target == datetime(1955, 11, 12, 14, 38, tzinfo=timezone(timedelta(hours=1))) + + now = datetime.now(tz=m.value.tzinfo) + assert isinstance(now, datetime)