Skip to content

Commit

Permalink
feat: Support datetime.(date|datetime) in Expression(s) (vega#3654)
Browse files Browse the repository at this point in the history
  • Loading branch information
dangotbanned authored Oct 30, 2024
1 parent 74d58f0 commit 64c1e09
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 3 deletions.
48 changes: 45 additions & 3 deletions altair/expr/core.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,20 @@
from __future__ import annotations

from typing import Any, Union
from typing_extensions import TypeAlias
import datetime as dt
from typing import TYPE_CHECKING, Any, Literal, Union

from altair.utils import SchemaBase

if TYPE_CHECKING:
import sys

from altair.vegalite.v5.schema._typing import Map, PrimitiveValue_T

if sys.version_info >= (3, 10):
from typing import TypeAlias
else:
from typing_extensions import TypeAlias


class DatumType:
"""An object to assist in building Vega-Lite Expressions."""
Expand Down Expand Up @@ -38,10 +48,40 @@ def _js_repr(val) -> str:
return "null"
elif isinstance(val, OperatorMixin):
return val._to_expr()
elif isinstance(val, dt.date):
return _from_date_datetime(val)
else:
return repr(val)


def _from_date_datetime(obj: dt.date | dt.datetime, /) -> str:
"""
Parse native `datetime.(date|datetime)` into a `datetime expression`_ string.
**Month is 0-based**
.. _datetime expression:
https://vega.github.io/vega/docs/expressions/#datetime
"""
fn_name: Literal["datetime", "utc"] = "datetime"
args: tuple[int, ...] = obj.year, obj.month - 1, obj.day
if isinstance(obj, dt.datetime):
if tzinfo := obj.tzinfo:
if tzinfo is dt.timezone.utc:
fn_name = "utc"
else:
msg = (
f"Unsupported timezone {tzinfo!r}.\n"
"Only `'UTC'` or naive (local) datetimes are permitted.\n"
"See https://altair-viz.github.io/user_guide/generated/core/altair.DateTime.html"
)
raise TypeError(msg)
us = obj.microsecond
ms = us if us == 0 else us // 1_000
args = *args, obj.hour, obj.minute, obj.second, ms
return FunctionExpression(fn_name, args)._to_expr()


# Designed to work with Expression and VariableParameter
class OperatorMixin:
def _to_expr(self) -> str:
Expand Down Expand Up @@ -237,4 +277,6 @@ def __repr__(self) -> str:
return f"{self.group}[{self.name!r}]"


IntoExpression: TypeAlias = Union[bool, None, str, float, OperatorMixin, dict[str, Any]]
IntoExpression: TypeAlias = Union[
"PrimitiveValue_T", dt.date, dt.datetime, OperatorMixin, "Map"
]
41 changes: 41 additions & 0 deletions tests/expr/test_expr.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import datetime as dt
import operator
import sys
from inspect import classify_class_attrs, getmembers, signature
Expand Down Expand Up @@ -188,3 +189,43 @@ def test_expression_function_nostring():

with pytest.raises(ValidationError):
expr(["foo", "bah"]) # pyright: ignore


@pytest.mark.parametrize(
("value", "expected"),
[
(dt.date(2000, 1, 1), "datetime(2000,0,1)"),
(dt.datetime(2000, 1, 1), "datetime(2000,0,1,0,0,0,0)"),
(dt.datetime(2001, 1, 1, 9, 30, 0, 2999), "datetime(2001,0,1,9,30,0,2)"),
(
dt.datetime(2003, 5, 1, 1, 30, tzinfo=dt.timezone.utc),
"utc(2003,4,1,1,30,0,0)",
),
],
ids=["date", "datetime (no time)", "datetime (microseconds)", "datetime (UTC)"],
)
def test_expr_datetime(value: Any, expected: str) -> None:
r_datum = datum.date >= value
assert isinstance(r_datum, Expression)
assert repr(r_datum) == f"(datum.date >= {expected})"


@pytest.mark.parametrize(
"tzinfo",
[
dt.timezone(dt.timedelta(hours=2), "UTC+2"),
dt.timezone(dt.timedelta(hours=1), "BST"),
dt.timezone(dt.timedelta(hours=-7), "pdt"),
dt.timezone(dt.timedelta(hours=-3), "BRT"),
dt.timezone(dt.timedelta(hours=9), "UTC"),
dt.timezone(dt.timedelta(minutes=60), "utc"),
],
)
def test_expr_datetime_unsupported_timezone(tzinfo: dt.timezone) -> None:
datetime = dt.datetime(2003, 5, 1, 1, 30)

result = datum.date == datetime
assert repr(result) == "(datum.date === datetime(2003,4,1,1,30,0,0))"

with pytest.raises(TypeError, match=r"Unsupported timezone.+\n.+UTC.+local"):
datum.date == datetime.replace(tzinfo=tzinfo) # noqa: B015

0 comments on commit 64c1e09

Please sign in to comment.