Skip to content

Commit

Permalink
[FEAT] Enable Comparison between timestamp / dates (#1689)
Browse files Browse the repository at this point in the history
* Enables comparisons between Timestamps of no / same / different
timezones and units
* Enables comparisons between dates and timestamps
* Adds Date / Timestamps to expression fixtures
* However we do not allow compare a TimeStamp with no tz with one that
has one.
  • Loading branch information
samster25 authored Dec 1, 2023
1 parent f0ddb8c commit 2d499c4
Show file tree
Hide file tree
Showing 7 changed files with 155 additions and 34 deletions.
8 changes: 6 additions & 2 deletions daft/datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,13 +194,17 @@ def date(cls) -> DataType:
return cls._from_pydatatype(PyDataType.date())

@classmethod
def timestamp(cls, timeunit: TimeUnit, timezone: str | None = None) -> DataType:
def timestamp(cls, timeunit: TimeUnit | str, timezone: str | None = None) -> DataType:
"""Timestamp DataType."""
if isinstance(timeunit, str):
timeunit = TimeUnit.from_str(timeunit)
return cls._from_pydatatype(PyDataType.timestamp(timeunit._timeunit, timezone))

@classmethod
def duration(cls, timeunit: TimeUnit) -> DataType:
def duration(cls, timeunit: TimeUnit | str) -> DataType:
"""Duration DataType."""
if isinstance(timeunit, str):
timeunit = TimeUnit.from_str(timeunit)
return cls._from_pydatatype(PyDataType.duration(timeunit._timeunit))

@classmethod
Expand Down
42 changes: 26 additions & 16 deletions src/daft-core/src/datatypes/binary_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::ops::{Add, Div, Mul, Rem, Sub};

use common_error::{DaftError, DaftResult};

use crate::impl_binary_trait_by_reference;
use crate::{impl_binary_trait_by_reference, utils::supertype::try_get_supertype};

use super::DataType;

Expand All @@ -24,27 +24,37 @@ impl DataType {
))
})
}
pub fn comparison_op(&self, other: &Self) -> DaftResult<(DataType, DataType)> {
pub fn comparison_op(
&self,
other: &Self,
) -> DaftResult<(DataType, Option<DataType>, DataType)> {
// Whether a comparison op is supported between the two types.
// Returns:
// - the output type,
// - an optional intermediate type
// - the type at which the comparison should be performed.
use DataType::*;
match (self, other) {
(s, o) if s == o => Ok(s.to_physical()),
(s, o) if s.is_physical() && o.is_physical() => {
try_physical_supertype(s, o).map_err(|_| ())
}
// To maintain existing behaviour. TODO: cleanup
(Date, o) | (o, Date) if o.is_physical() && o.clone() != Boolean => {
try_physical_supertype(&Date.to_physical(), o).map_err(|_| ())
let evaluator = || {
use DataType::*;
match (self, other) {
(s, o) if s == o => Ok((Boolean, None, s.to_physical())),
(s, o) if s.is_physical() && o.is_physical() => {
Ok((Boolean, None, try_physical_supertype(s, o)?))
}
(Timestamp(..) | Date, Timestamp(..) | Date) => {
let intermediate_type = try_get_supertype(self, other)?;
let pt = intermediate_type.to_physical();
Ok((Boolean, Some(intermediate_type), pt))
}
_ => Err(DaftError::TypeError(format!(
"Cannot perform comparison on types: {}, {}",
self, other
))),
}
_ => Err(()),
}
.map(|comp_type| (Boolean, comp_type))
.map_err(|()| {
};

evaluator().map_err(|err| {
DaftError::TypeError(format!(
"Cannot perform comparison on types: {}, {}",
"Cannot perform comparison on types: {}, {}\nDetails:\n{err}",
self, other
))
})
Expand Down
13 changes: 10 additions & 3 deletions src/daft-core/src/series/array_impl/binary_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,17 +114,24 @@ macro_rules! physical_logic_op {

macro_rules! physical_compare_op {
($self:expr, $rhs:expr, $op:ident, $pyop:expr) => {{
let (output_type, comp_type) = ($self.data_type().comparison_op($rhs.data_type()))?;
let (output_type, intermediate, comp_type) =
($self.data_type().comparison_op($rhs.data_type()))?;
let lhs = $self.into_series();
let (lhs, rhs) = if let Some(ref it) = intermediate {
(lhs.cast(it)?, $rhs.cast(it)?)
} else {
(lhs, $rhs.clone())
};

use DataType::*;
if let Boolean = output_type {
match comp_type {
#[cfg(feature = "python")]
Python => py_binary_op_bool!(lhs, $rhs, $pyop)
Python => py_binary_op_bool!(lhs, rhs, $pyop)
.downcast::<BooleanArray>()
.cloned(),
_ => with_match_comparable_daft_types!(comp_type, |$T| {
cast_downcast_op!(lhs, $rhs, &comp_type, <$T as DaftDataType>::ArrayType, $op)
cast_downcast_op!(lhs, rhs, &comp_type, <$T as DaftDataType>::ArrayType, $op)
}),
}
} else {
Expand Down
12 changes: 6 additions & 6 deletions src/daft-core/src/utils/supertype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,20 +168,20 @@ pub fn get_supertype(l: &DataType, r: &DataType) -> Option<DataType> {
(Duration(_), Date) | (Date, Duration(_)) => Some(Date),
(Duration(lu), Duration(ru)) => Some(Duration(get_time_units(lu, ru))),

// None and Some("") timezones
// Some() timezones that are non equal
// we cast from more precision to higher precision as that always fits with occasional loss of precision
(Timestamp(tu_l, tz_l), Timestamp(tu_r, tz_r))
if (tz_l.is_none() || tz_l.as_deref() == Some(""))
&& (tz_r.is_none() || tz_r.as_deref() == Some("")) =>
(Timestamp(tu_l, Some(tz_l)), Timestamp(tu_r, Some(tz_r)))
if !tz_l.is_empty()
&& !tz_r.is_empty() && tz_l != tz_r =>
{
let tu = get_time_units(tu_l, tu_r);
Some(Timestamp(tu, None))
Some(Timestamp(tu, Some("UTC".to_string())))
}
// None and Some("<tz>") timezones
// we cast from more precision to higher precision as that always fits with occasional loss of precision
(Timestamp(tu_l, tz_l), Timestamp(tu_r, tz_r)) if
// both are none
tz_l.is_none() && tz_r.is_some()
tz_l.is_none() && tz_r.is_none()
// both have the same time zone
|| (tz_l.is_some() && (tz_l == tz_r)) => {
let tu = get_time_units(tu_l, tu_r);
Expand Down
2 changes: 1 addition & 1 deletion src/daft-dsl/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,7 @@ impl Expr {
| Operator::NotEq
| Operator::LtEq
| Operator::GtEq => {
let (result_type, _comp_type) =
let (result_type, _intermediate, _comp_type) =
left_field.dtype.comparison_op(&right_field.dtype)?;
Ok(Field::new(left_field.name.as_str(), result_type))
}
Expand Down
62 changes: 56 additions & 6 deletions tests/expressions/typing/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import itertools
import sys

import pytz

if sys.version_info < (3, 8):
pass
else:
Expand Down Expand Up @@ -33,17 +35,65 @@
(DataType.bool(), pa.array([True, False, None], type=pa.bool_())),
(DataType.null(), pa.array([None, None, None], type=pa.null())),
(DataType.binary(), pa.array([b"1", b"2", None], type=pa.binary())),
(DataType.date(), pa.array([datetime.date(2021, 1, 1), datetime.date(2021, 1, 2), None], type=pa.date32())),
# TODO(jay): Some of the fixtures are broken/become very complicated when testing against timestamps
# (
# DataType.timestamp(TimeUnit.ms()),
# pa.array([datetime.datetime(2021, 1, 1), datetime.datetime(2021, 1, 2), None], type=pa.timestamp("ms")),
# ),
]

ALL_DATATYPES_BINARY_PAIRS = list(itertools.product(ALL_DTYPES, repeat=2))


ALL_TEMPORAL_DTYPES = [
(DataType.date(), pa.array([datetime.date(2021, 1, 1), datetime.date(2021, 1, 2), None], type=pa.date32())),
*[
(
DataType.timestamp(unit),
pa.array([datetime.datetime(2021, 1, 1), datetime.datetime(2021, 1, 2), None], type=pa.timestamp(unit)),
)
for unit in ["ns", "us", "ms"]
],
*[
(
DataType.timestamp(unit, "US/Eastern"),
pa.array(
[
datetime.datetime(2021, 1, 1).astimezone(pytz.timezone("US/Eastern")),
datetime.datetime(2021, 1, 2).astimezone(pytz.timezone("US/Eastern")),
None,
],
type=pa.timestamp(unit, "US/Eastern"),
),
)
for unit in ["ns", "us", "ms"]
],
*[
(
DataType.timestamp(unit, "Africa/Accra"),
pa.array(
[
datetime.datetime(2021, 1, 1).astimezone(pytz.timezone("Africa/Accra")),
datetime.datetime(2021, 1, 2).astimezone(pytz.timezone("Africa/Accra")),
None,
],
type=pa.timestamp(unit, "Africa/Accra"),
),
)
for unit in ["ns", "us", "ms"]
],
]

ALL_DTYPES += ALL_TEMPORAL_DTYPES

ALL_TEMPORAL_DATATYPES_BINARY_PAIRS = [
((dt1, data1), (dt2, data2))
for (dt1, data1), (dt2, data2) in itertools.product(ALL_TEMPORAL_DTYPES, repeat=2)
if not (
pa.types.is_timestamp(data1.type)
and pa.types.is_timestamp(data2.type)
and (data1.type.tz is None) ^ (data2.type.tz is None)
)
]

ALL_DATATYPES_BINARY_PAIRS += ALL_TEMPORAL_DATATYPES_BINARY_PAIRS


@pytest.fixture(
scope="module",
params=ALL_DATATYPES_BINARY_PAIRS,
Expand Down
50 changes: 50 additions & 0 deletions tests/series/test_comparisons.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@

import itertools
import operator
from datetime import date, datetime

import pyarrow as pa
import pytest
import pytz

from daft import DataType, Series

Expand Down Expand Up @@ -682,3 +684,51 @@ def test_logicalops_pyobjects(op, expected, expected_self) -> None:
assert op(custom_falses, values).datatype() == DataType.bool()
assert op(custom_falses, values).to_pylist() == expected
assert op(custom_falses, custom_falses).to_pylist() == expected_self


@pytest.mark.parametrize("tu1, tu2", itertools.product(["ns", "us", "ms"], repeat=2))
def test_compare_timestamps_no_tz(tu1, tu2):
tz1 = Series.from_pylist([datetime(2022, 1, 1)])
assert (tz1.cast(DataType.timestamp(tu1)) == tz1.cast(DataType.timestamp(tu2))).to_pylist() == [True]


def test_compare_timestamps_no_tz_date():
tz1 = Series.from_pylist([datetime(2022, 1, 1)])
Series.from_pylist([date(2022, 1, 1)])
assert (tz1 == tz1).to_pylist() == [True]


def test_compare_timestamps_one_tz():
tz1 = Series.from_pylist([datetime(2022, 1, 1)])
tz2 = Series.from_pylist([datetime(2022, 1, 1, tzinfo=pytz.utc)])
with pytest.raises(ValueError, match="Cannot perform comparison on types"):
assert (tz1 == tz2).to_pylist() == [True]


def test_compare_timestamps_and_int():
tz1 = Series.from_pylist([datetime(2022, 1, 1)])
tz2 = Series.from_pylist([5])
with pytest.raises(ValueError, match="Cannot perform comparison on types"):
assert (tz1 == tz2).to_pylist() == [True]


def test_compare_timestamps_tz_date():
tz1 = Series.from_pylist([datetime(2022, 1, 1, tzinfo=pytz.utc)])
Series.from_pylist([date(2022, 1, 1)])
assert (tz1 == tz1).to_pylist() == [True]


@pytest.mark.parametrize("tu1, tu2", itertools.product(["ns", "us", "ms"], repeat=2))
def test_compare_timestamps_same_tz(tu1, tu2):
tz1 = Series.from_pylist([datetime(2022, 1, 1, tzinfo=pytz.utc)]).cast(DataType.timestamp(tu1, "UTC"))
tz2 = Series.from_pylist([datetime(2022, 1, 1, tzinfo=pytz.utc)]).cast(DataType.timestamp(tu2, "UTC"))
assert (tz1 == tz2).to_pylist() == [True]


@pytest.mark.parametrize("tu1, tu2", itertools.product(["ns", "us", "ms"], repeat=2))
def test_compare_timestamps_diff_tz(tu1, tu2):
utc = datetime(2022, 1, 1, tzinfo=pytz.utc)
eastern = utc.astimezone(pytz.timezone("US/Eastern"))
tz1 = Series.from_pylist([utc]).cast(DataType.timestamp(tu1, "UTC"))
tz2 = Series.from_pylist([eastern]).cast(DataType.timestamp(tu1, "US/Eastern"))
assert (tz1 == tz2).to_pylist() == [True]

0 comments on commit 2d499c4

Please sign in to comment.