Skip to content

Commit

Permalink
fix(python, rust): expr parsing date/timestamp (#2357)
Browse files Browse the repository at this point in the history
# Description
We weren't handling all scalar value kinds yet; this change adds parsing/formatting
for Date32, Date64, and microsecond-precision timestamps as well.

# Related Issue(s)
- fixes #2344
  • Loading branch information
ion-elgreco authored Apr 2, 2024
1 parent c223bb6 commit 6f81b80
Show file tree
Hide file tree
Showing 2 changed files with 164 additions and 1 deletion.
73 changes: 72 additions & 1 deletion crates/core/src/delta_datafusion/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,12 @@
//! Utility functions for Datafusion's Expressions

use std::{
fmt::{self, Display, Formatter, Write},
fmt::{self, format, Display, Error, Formatter, Write},
sync::Arc,
};

use arrow_schema::DataType;
use chrono::{Date, NaiveDate, NaiveDateTime, TimeZone};
use datafusion::execution::context::SessionState;
use datafusion_common::Result as DFResult;
use datafusion_common::{config::ConfigOptions, DFSchema, Result, ScalarValue, TableReference};
Expand Down Expand Up @@ -326,6 +327,9 @@ macro_rules! format_option {
}};
}

/// Number of days from the Common Era origin (0001-01-01) to the Unix epoch
/// (1970-01-01). Used to convert an Arrow `Date32` value (days since the Unix
/// epoch) into chrono's days-from-CE representation for
/// `NaiveDate::from_num_days_from_ce_opt`.
pub const EPOCH_DAYS_FROM_CE: i32 = 719_163;

/// Thin wrapper giving a borrowed [`ScalarValue`] a `fmt::Display`
/// implementation that renders it as a literal usable inside an expression
/// string (see the `Display` impl below).
struct ScalarValueFormat<'a> {
    // The scalar value to be rendered; borrowed so formatting never clones.
    scalar: &'a ScalarValue,
}
Expand All @@ -344,6 +348,46 @@ impl<'a> fmt::Display for ScalarValueFormat<'a> {
ScalarValue::UInt16(e) => format_option!(f, e)?,
ScalarValue::UInt32(e) => format_option!(f, e)?,
ScalarValue::UInt64(e) => format_option!(f, e)?,
ScalarValue::Date32(e) => match e {
Some(e) => write!(
f,
"{}",
NaiveDate::from_num_days_from_ce_opt((EPOCH_DAYS_FROM_CE + (*e)).into())
.ok_or(Error::default())?
)?,
None => write!(f, "NULL")?,
},
ScalarValue::Date64(e) => match e {
Some(e) => write!(
f,
"'{}'::date",
NaiveDateTime::from_timestamp_millis((*e).into())
.ok_or(Error::default())?
.date()
.format("%Y-%m-%d")
)?,
None => write!(f, "NULL")?,
},
ScalarValue::TimestampMicrosecond(e, tz) => match e {
Some(e) => match tz {
Some(tz) => write!(
f,
"arrow_cast('{}', 'Timestamp(Microsecond, Some(\"UTC\"))')",
NaiveDateTime::from_timestamp_micros(*e)
.ok_or(Error::default())?
.and_utc()
.format("%Y-%m-%dT%H:%M:%S%.6f")
)?,
None => write!(
f,
"arrow_cast('{}', 'Timestamp(Microsecond, None)')",
NaiveDateTime::from_timestamp_micros(*e)
.ok_or(Error::default())?
.format("%Y-%m-%dT%H:%M:%S%.6f")
)?,
},
None => write!(f, "NULL")?,
},
ScalarValue::Utf8(e) | ScalarValue::LargeUtf8(e) => match e {
Some(e) => write!(f, "'{}'", escape_quoted_string(e, '\''))?,
None => write!(f, "NULL")?,
Expand Down Expand Up @@ -445,6 +489,11 @@ mod test {
DataType::Primitive(PrimitiveType::Timestamp),
true,
),
StructField::new(
"_timestamp_ntz".to_string(),
DataType::Primitive(PrimitiveType::TimestampNtz),
true,
),
StructField::new(
"_binary".to_string(),
DataType::Primitive(PrimitiveType::Binary),
Expand Down Expand Up @@ -610,6 +659,28 @@ mod test {
cardinality(col("_list").range(col("value"), lit(10_i64))),
"cardinality(_list[value:10:1])".to_string()
),
ParseTest {
expr: col("_timestamp_ntz").gt(lit(ScalarValue::TimestampMicrosecond(Some(1262304000000000), None))),
expected: "_timestamp_ntz > arrow_cast('2010-01-01T00:00:00.000000', 'Timestamp(Microsecond, None)')".to_string(),
override_expected_expr: Some(col("_timestamp_ntz").gt(
datafusion_expr::Expr::Cast( Cast {
expr: Box::new(lit(ScalarValue::Utf8(Some("2010-01-01T00:00:00.000000".into())))),
data_type:ArrowDataType::Timestamp(arrow_schema::TimeUnit::Microsecond, None)
}
))),
},
ParseTest {
expr: col("_timestamp").gt(lit(ScalarValue::TimestampMicrosecond(
Some(1262304000000000),
Some("UTC".into())
))),
expected: "_timestamp > arrow_cast('2010-01-01T00:00:00.000000', 'Timestamp(Microsecond, Some(\"UTC\"))')".to_string(),
override_expected_expr: Some(col("_timestamp").gt(
datafusion_expr::Expr::Cast( Cast {
expr: Box::new(lit(ScalarValue::Utf8(Some("2010-01-01T00:00:00.000000".into())))),
data_type:ArrowDataType::Timestamp(arrow_schema::TimeUnit::Microsecond, Some("UTC".into()))
}))),
},
];

let session: SessionContext = DeltaSessionContext::default().into();
Expand Down
92 changes: 92 additions & 0 deletions python/tests/test_merge.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pathlib

import pyarrow as pa
import pytest

from deltalake import DeltaTable, write_deltalake

Expand Down Expand Up @@ -763,3 +764,94 @@ def test_merge_multiple_when_not_matched_by_source_update_wo_predicate(

assert last_action["operation"] == "MERGE"
assert result == expected


def test_merge_date_partitioned_2344(tmp_path: pathlib.Path):
    """Regression test for #2344: MERGE against a table partitioned by a
    date32 column must produce a parseable partition predicate."""
    from datetime import date

    table_schema = pa.schema(
        [
            ("date", pa.date32()),
            ("foo", pa.string()),
            ("bar", pa.string()),
        ]
    )

    dt = DeltaTable.create(
        tmp_path, schema=table_schema, partition_by=["date"], mode="overwrite"
    )

    source = pa.table(
        {
            "date": pa.array([date(2022, 2, 1)]),
            "foo": pa.array(["hello"]),
            "bar": pa.array(["world"]),
        }
    )

    (
        dt.merge(
            source,
            predicate="s.date = t.date",
            source_alias="s",
            target_alias="t",
        )
        .when_matched_update_all()
        .when_not_matched_insert_all()
        .execute()
    )

    last_action = dt.history(1)[0]

    assert last_action["operation"] == "MERGE"
    assert dt.to_pyarrow_table() == source
    # The date partition value must be rendered into the commit predicate.
    assert last_action["operationParameters"].get("predicate") == "2022-02-01 = date"


@pytest.mark.parametrize(
    "timezone,predicate",
    [
        (
            None,
            "arrow_cast('2022-02-01T00:00:00.000000', 'Timestamp(Microsecond, None)') = datetime",
        ),
        (
            "UTC",
            "arrow_cast('2022-02-01T00:00:00.000000', 'Timestamp(Microsecond, Some(\"UTC\"))') = datetime",
        ),
    ],
)
def test_merge_timestamps_partitioned_2344(tmp_path: pathlib.Path, timezone, predicate):
    """Regression test for #2344: MERGE against a table partitioned by a
    microsecond timestamp column (with and without a timezone) must render the
    partition value as an ``arrow_cast`` expression in the commit predicate."""
    from datetime import datetime

    ts_type = pa.timestamp("us", tz=timezone)

    table_schema = pa.schema(
        [
            ("datetime", ts_type),
            ("foo", pa.string()),
            ("bar", pa.string()),
        ]
    )

    dt = DeltaTable.create(
        tmp_path, schema=table_schema, partition_by=["datetime"], mode="overwrite"
    )

    source = pa.table(
        {
            "datetime": pa.array([datetime(2022, 2, 1)], ts_type),
            "foo": pa.array(["hello"]),
            "bar": pa.array(["world"]),
        }
    )

    (
        dt.merge(
            source,
            predicate="s.datetime = t.datetime",
            source_alias="s",
            target_alias="t",
        )
        .when_matched_update_all()
        .when_not_matched_insert_all()
        .execute()
    )

    last_action = dt.history(1)[0]

    assert last_action["operation"] == "MERGE"
    assert dt.to_pyarrow_table() == source
    # Predicate string depends on whether the timestamp carries a timezone.
    assert last_action["operationParameters"].get("predicate") == predicate

0 comments on commit 6f81b80

Please sign in to comment.