Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

👽 Implement tzinfo.fromutc for TzInfo #864

Merged
merged 7 commits into from
Aug 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions python/pydantic_core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
SchemaSerializer,
SchemaValidator,
Some,
TzInfo,
Url,
ValidationError,
__version__,
Expand Down Expand Up @@ -59,6 +60,7 @@
'PydanticUseDefault',
'PydanticSerializationError',
'PydanticSerializationUnexpectedValue',
'TzInfo',
'to_json',
'to_jsonable_python',
]
Expand Down
1 change: 1 addition & 0 deletions python/pydantic_core/_pydantic_core.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -798,4 +798,5 @@ class TzInfo(datetime.tzinfo):
def tzname(self, _dt: datetime.datetime | None) -> str | None: ...
def utcoffset(self, _dt: datetime.datetime | None) -> datetime.timedelta: ...
def dst(self, _dt: datetime.datetime | None) -> datetime.timedelta: ...
def fromutc(self, dt: datetime.datetime) -> datetime.datetime: ...
def __deepcopy__(self, _memo: dict[Any, Any]) -> 'TzInfo': ...
60 changes: 52 additions & 8 deletions src/input/datetime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,14 @@ use pyo3::intern;
use pyo3::prelude::*;

use pyo3::exceptions::PyValueError;
use pyo3::pyclass::CompareOp;
use pyo3::types::{PyDate, PyDateTime, PyDelta, PyDeltaAccess, PyDict, PyTime, PyTzInfo};
use speedate::MicrosecondsPrecisionOverflowBehavior;
use speedate::{Date, DateTime, Duration, ParseError, Time, TimeConfig};
use std::borrow::Cow;
use std::collections::hash_map::DefaultHasher;
use std::hash::Hash;
use std::hash::Hasher;

use strum::EnumMessage;

Expand Down Expand Up @@ -222,7 +226,7 @@ impl<'a> EitherTime<'a> {
fn time_as_tzinfo<'py>(py: Python<'py>, time: &Time) -> PyResult<Option<&'py PyTzInfo>> {
match time.tz_offset {
Some(offset) => {
let tz_info = TzInfo::new(offset);
let tz_info: TzInfo = offset.try_into()?;
let py_tz_info = Py::new(py, tz_info)?.to_object(py).into_ref(py);
Ok(Some(py_tz_info.extract()?))
}
Expand Down Expand Up @@ -508,11 +512,11 @@ pub struct TzInfo {
#[pymethods]
impl TzInfo {
#[new]
fn new(seconds: i32) -> Self {
Self { seconds }
fn py_new(seconds: f32) -> PyResult<Self> {
Copy link
Contributor Author

@lig lig Aug 21, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the change from i32 to f32 is to accept the output of timedelta.total_seconds() directly

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Better to add

impl TryFrom<i32> for TzInfo {
    type Error = PyErr;

    fn try_from(seconds: i32) -> PyResult<Self> {
        if seconds.abs() >= 86400 {
            Err(PyValueError::new_err(format!(
                "TzInfo offset must be strictly between -86400 and 86400 (24 hours) seconds, got {seconds}"
            )))
        } else {
            Ok(Self { seconds })
        }
    }
}

Then use (seconds.trunc() as i32).try_into() here, then you can use let tz_info: TzInfo = offset.try_into()?; above and avoid going i32 -> f32 -> i32.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice

Self::try_from(seconds.trunc() as i32)
}

fn utcoffset<'p>(&self, py: Python<'p>, _dt: &PyAny) -> PyResult<&'p PyDelta> {
fn utcoffset<'py>(&self, py: Python<'py>, _dt: &PyAny) -> PyResult<&'py PyDelta> {
PyDelta::new(py, 0, self.seconds, 0, true)
}

Expand All @@ -524,17 +528,43 @@ impl TzInfo {
None
}

fn fromutc<'py>(&self, dt: &'py PyDateTime) -> PyResult<&'py PyAny> {
let py = dt.py();
dt.call_method1("__add__", (self.utcoffset(py, py.None().as_ref(py))?,))
}

fn __repr__(&self) -> String {
format!("TzInfo({})", self.__str__())
}

fn __str__(&self) -> String {
if self.seconds == 0 {
"UTC".to_string()
} else {
let mins = self.seconds / 60;
format!("{:+03}:{:02}", mins / 60, (mins % 60).abs())
return "UTC".to_string();
}

let (mins, seconds) = (self.seconds / 60, self.seconds % 60);
let mut result = format!(
"{}{:02}:{:02}",
if self.seconds.signum() >= 0 { "+" } else { "-" },
(mins / 60).abs(),
(mins % 60).abs()
);

if seconds != 0 {
result.push_str(&format!(":{:02}", seconds.abs()));
}

result
}

fn __hash__(&self) -> u64 {
samuelcolvin marked this conversation as resolved.
Show resolved Hide resolved
let mut hasher = DefaultHasher::new();
self.seconds.hash(&mut hasher);
hasher.finish()
}

fn __richcmp__(&self, other: &Self, op: CompareOp) -> bool {
op.matches(self.seconds.cmp(&other.seconds))
}

fn __deepcopy__(&self, py: Python, _memo: &PyDict) -> PyResult<Py<Self>> {
Expand All @@ -547,3 +577,17 @@ impl TzInfo {
Ok((cls, args).into_py(py))
}
}

impl TryFrom<i32> for TzInfo {
type Error = PyErr;

fn try_from(seconds: i32) -> PyResult<Self> {
if seconds.abs() >= 86400 {
Err(PyValueError::new_err(format!(
"TzInfo offset must be strictly between -86400 and 86400 (24 hours) seconds, got {seconds}"
)))
} else {
Ok(Self { seconds })
}
}
}
212 changes: 212 additions & 0 deletions tests/test_tzinfo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
import copy
import functools
import pickle
import unittest
from datetime import datetime, timedelta, timezone, tzinfo

from pydantic_core import SchemaValidator, TzInfo, core_schema


class _ALWAYS_EQ:
"""
Object that is equal to anything.
"""

def __eq__(self, other):
return True

def __ne__(self, other):
return False


ALWAYS_EQ = _ALWAYS_EQ()


@functools.total_ordering
class _LARGEST:
"""
Object that is greater than anything (except itself).
"""

def __eq__(self, other):
return isinstance(other, _LARGEST)

def __lt__(self, other):
return False


LARGEST = _LARGEST()


@functools.total_ordering
class _SMALLEST:
"""
Object that is less than anything (except itself).
"""

def __eq__(self, other):
return isinstance(other, _SMALLEST)

def __gt__(self, other):
return False


SMALLEST = _SMALLEST()


pickle_choices = [(pickle, pickle, proto) for proto in range(pickle.HIGHEST_PROTOCOL + 1)]

HOUR = timedelta(hours=1).total_seconds()
ZERO = timedelta(0).total_seconds()


def first_sunday_on_or_after(dt):
days_to_go = 6 - dt.weekday()
if days_to_go:
dt += timedelta(days_to_go)
return dt


DSTSTART = datetime(1, 4, 1, 2)
DSTEND = datetime(1, 10, 25, 1)


class TestTzInfo(unittest.TestCase):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

assuming this is copied from cpython somewhere, please include a link to the original in a docstring.

"""Adapted from CPython `timezone` tests

Original tests are located here https://github.com/python/cpython/blob/a0bb4a39d1ca10e4a75f50a9fbe90cc9db28d29e/Lib/test/datetimetester.py#L256
"""

def setUp(self):
self.ACDT = TzInfo(timedelta(hours=9.5).total_seconds())
self.EST = TzInfo(-timedelta(hours=5).total_seconds())
self.DT = datetime(2010, 1, 1)

def test_str(self):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you can probably use parameterize for these and thereby make any error easier to read.

Ideally these should be rewritten in pytest style, it shouldn't take very long, but I know you've already spend lots of time on this, so don't worry if you can't do it quickly.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I kinda like the idea of keeping them as unittest. @lig can you add a link / note to permalink to the standard library for future selfs?

for tz in [self.ACDT, self.EST]:
self.assertEqual(str(tz), tz.tzname(None))

def test_constructor(self):
for subminute in [timedelta(microseconds=1), timedelta(seconds=1)]:
tz = TzInfo(subminute.total_seconds())
self.assertNotEqual(tz.utcoffset(None) % timedelta(minutes=1), 0)
# invalid offsets
for invalid in [timedelta(1, 1), timedelta(1)]:
self.assertRaises(ValueError, TzInfo, invalid.total_seconds())
self.assertRaises(ValueError, TzInfo, -invalid.total_seconds())

with self.assertRaises(TypeError):
TzInfo(None)
with self.assertRaises(TypeError):
TzInfo(timedelta(seconds=42))
with self.assertRaises(TypeError):
TzInfo(ZERO, None)
with self.assertRaises(TypeError):
TzInfo(ZERO, 42)
with self.assertRaises(TypeError):
TzInfo(ZERO, 'ABC', 'extra')

def test_inheritance(self):
self.assertIsInstance(self.EST, tzinfo)

def test_utcoffset(self):
dummy = self.DT
for h in [0, 1.5, 12]:
offset = h * HOUR
self.assertEqual(timedelta(seconds=offset), TzInfo(offset).utcoffset(dummy))
self.assertEqual(timedelta(seconds=-offset), TzInfo(-offset).utcoffset(dummy))

self.assertEqual(self.EST.utcoffset(''), timedelta(hours=-5))
self.assertEqual(self.EST.utcoffset(5), timedelta(hours=-5))

def test_dst(self):
self.EST.dst('') is None
self.EST.dst(5) is None

def test_tzname(self):
self.assertEqual('-05:00', TzInfo(-5 * HOUR).tzname(None))
self.assertEqual('+09:30', TzInfo(9.5 * HOUR).tzname(None))
self.assertEqual('-00:01', TzInfo(timedelta(minutes=-1).total_seconds()).tzname(None))
# Sub-minute offsets:
self.assertEqual('+01:06:40', TzInfo(timedelta(0, 4000).total_seconds()).tzname(None))
self.assertEqual('-01:06:40', TzInfo(-timedelta(0, 4000).total_seconds()).tzname(None))
self.assertEqual('+01:06:40', TzInfo(timedelta(0, 4000, 1).total_seconds()).tzname(None))
self.assertEqual('-01:06:40', TzInfo(-timedelta(0, 4000, 1).total_seconds()).tzname(None))

self.assertEqual(self.EST.tzname(''), '-05:00')
self.assertEqual(self.EST.tzname(5), '-05:00')

def test_fromutc(self):
for tz in [self.EST, self.ACDT]:
utctime = self.DT.replace(tzinfo=tz)
local = tz.fromutc(utctime)
self.assertEqual(local - utctime, tz.utcoffset(local))
self.assertEqual(local, self.DT.replace(tzinfo=timezone.utc))

def test_comparison(self):
self.assertNotEqual(TzInfo(ZERO), TzInfo(HOUR))
self.assertEqual(TzInfo(HOUR), TzInfo(HOUR))
self.assertFalse(TzInfo(ZERO) < TzInfo(ZERO))
self.assertIn(TzInfo(ZERO), {TzInfo(ZERO)})
self.assertTrue(TzInfo(ZERO) is not None)
self.assertFalse(TzInfo(ZERO) is None)

tz = TzInfo(ZERO)
self.assertTrue(tz == ALWAYS_EQ)
self.assertFalse(tz != ALWAYS_EQ)
self.assertTrue(tz < LARGEST)
self.assertFalse(tz > LARGEST)
self.assertTrue(tz <= LARGEST)
self.assertFalse(tz >= LARGEST)
self.assertFalse(tz < SMALLEST)
self.assertTrue(tz > SMALLEST)
self.assertFalse(tz <= SMALLEST)
self.assertTrue(tz >= SMALLEST)

def test_copy(self):
for tz in self.ACDT, self.EST:
tz_copy = copy.copy(tz)
self.assertEqual(tz_copy, tz)

def test_deepcopy(self):
for tz in self.ACDT, self.EST:
tz_copy = copy.deepcopy(tz)
self.assertEqual(tz_copy, tz)

def test_offset_boundaries(self):
# Test timedeltas close to the boundaries
time_deltas = [timedelta(hours=23, minutes=59), timedelta(hours=23, minutes=59, seconds=59)]
time_deltas.extend([-delta for delta in time_deltas])

for delta in time_deltas:
with self.subTest(test_type='good', delta=delta):
print(delta.total_seconds())
TzInfo(delta.total_seconds())

# Test timedeltas on and outside the boundaries
bad_time_deltas = [timedelta(hours=24), timedelta(hours=24, microseconds=1)]
bad_time_deltas.extend([-delta for delta in bad_time_deltas])

for delta in bad_time_deltas:
with self.subTest(test_type='bad', delta=delta):
with self.assertRaises(ValueError):
TzInfo(delta.total_seconds())


def test_tzinfo_could_be_reused():
class Model:
value: datetime

v = SchemaValidator(
core_schema.model_schema(
Model, core_schema.model_fields_schema({'value': core_schema.model_field(core_schema.datetime_schema())})
)
)

m = v.validate_python({'value': '2015-10-21T15:28:00.000000+01:00'})

target = datetime(1955, 11, 12, 14, 38, tzinfo=m.value.tzinfo)
assert target == datetime(1955, 11, 12, 14, 38, tzinfo=timezone(timedelta(hours=1)))

now = datetime.now(tz=m.value.tzinfo)
assert isinstance(now, datetime)