From 64c1e09719c1506620da4d5f4b8341d8b2a3386a Mon Sep 17 00:00:00 2001 From: Dan Redding <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 30 Oct 2024 13:12:39 +0000 Subject: [PATCH] feat: Support `datetime.(date|datetime)` in `Expression`(s) (#3654) --- altair/expr/core.py | 48 ++++++++++++++++++++++++++++++++++++++--- tests/expr/test_expr.py | 41 +++++++++++++++++++++++++++++++++++ 2 files changed, 86 insertions(+), 3 deletions(-) diff --git a/altair/expr/core.py b/altair/expr/core.py index 429766b8a..cb89391f4 100644 --- a/altair/expr/core.py +++ b/altair/expr/core.py @@ -1,10 +1,20 @@ from __future__ import annotations -from typing import Any, Union -from typing_extensions import TypeAlias +import datetime as dt +from typing import TYPE_CHECKING, Any, Literal, Union from altair.utils import SchemaBase +if TYPE_CHECKING: + import sys + + from altair.vegalite.v5.schema._typing import Map, PrimitiveValue_T + + if sys.version_info >= (3, 10): + from typing import TypeAlias + else: + from typing_extensions import TypeAlias + class DatumType: """An object to assist in building Vega-Lite Expressions.""" @@ -38,10 +48,40 @@ def _js_repr(val) -> str: return "null" elif isinstance(val, OperatorMixin): return val._to_expr() + elif isinstance(val, dt.date): + return _from_date_datetime(val) else: return repr(val) +def _from_date_datetime(obj: dt.date | dt.datetime, /) -> str: + """ + Parse native `datetime.(date|datetime)` into a `datetime expression`_ string. + + **Month is 0-based** + + .. _datetime expression: + https://vega.github.io/vega/docs/expressions/#datetime + """ + fn_name: Literal["datetime", "utc"] = "datetime" + args: tuple[int, ...] = obj.year, obj.month - 1, obj.day + if isinstance(obj, dt.datetime): + if tzinfo := obj.tzinfo: + if tzinfo is dt.timezone.utc: + fn_name = "utc" + else: + msg = ( + f"Unsupported timezone {tzinfo!r}.\n" + "Only `'UTC'` or naive (local) datetimes are permitted.\n" + "See https://altair-viz.github.io/user_guide/generated/core/altair.DateTime.html" + ) + raise TypeError(msg) + us = obj.microsecond + ms = us if us == 0 else us // 1_000 + args = *args, obj.hour, obj.minute, obj.second, ms + return FunctionExpression(fn_name, args)._to_expr() + + # Designed to work with Expression and VariableParameter class OperatorMixin: def _to_expr(self) -> str: @@ -237,4 +277,6 @@ def __repr__(self) -> str: return f"{self.group}[{self.name!r}]" -IntoExpression: TypeAlias = Union[bool, None, str, float, OperatorMixin, dict[str, Any]] +IntoExpression: TypeAlias = Union[ + "PrimitiveValue_T", dt.date, dt.datetime, OperatorMixin, "Map" +] diff --git a/tests/expr/test_expr.py b/tests/expr/test_expr.py index 15757b09f..b2200184c 100644 --- a/tests/expr/test_expr.py +++ b/tests/expr/test_expr.py @@ -1,5 +1,6 @@ from __future__ import annotations +import datetime as dt import operator import sys from inspect import classify_class_attrs, getmembers, signature @@ -188,3 +189,43 @@ def test_expression_function_nostring(): with pytest.raises(ValidationError): expr(["foo", "bah"]) # pyright: ignore + + +@pytest.mark.parametrize( + ("value", "expected"), + [ + (dt.date(2000, 1, 1), "datetime(2000,0,1)"), + (dt.datetime(2000, 1, 1), "datetime(2000,0,1,0,0,0,0)"), + (dt.datetime(2001, 1, 1, 9, 30, 0, 2999), "datetime(2001,0,1,9,30,0,2)"), + ( + dt.datetime(2003, 5, 1, 1, 30, tzinfo=dt.timezone.utc), + "utc(2003,4,1,1,30,0,0)", + ), + ], + ids=["date", "datetime (no time)", "datetime (microseconds)", "datetime (UTC)"], +) +def test_expr_datetime(value: Any, expected: str) -> None: + r_datum = datum.date >= value + assert isinstance(r_datum, Expression) + assert repr(r_datum) == f"(datum.date >= {expected})" + + +@pytest.mark.parametrize( + "tzinfo", + [ + dt.timezone(dt.timedelta(hours=2), "UTC+2"), + dt.timezone(dt.timedelta(hours=1), "BST"), + dt.timezone(dt.timedelta(hours=-7), "pdt"), + dt.timezone(dt.timedelta(hours=-3), "BRT"), + dt.timezone(dt.timedelta(hours=9), "UTC"), + dt.timezone(dt.timedelta(minutes=60), "utc"), + ], +) +def test_expr_datetime_unsupported_timezone(tzinfo: dt.timezone) -> None: + datetime = dt.datetime(2003, 5, 1, 1, 30) + + result = datum.date == datetime + assert repr(result) == "(datum.date === datetime(2003,4,1,1,30,0,0))" + + with pytest.raises(TypeError, match=r"Unsupported timezone.+\n.+UTC.+local"): + datum.date == datetime.replace(tzinfo=tzinfo) # noqa: B015