Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(python, rust): expr parsing date/timestamp #2357

Merged
73 changes: 72 additions & 1 deletion crates/core/src/delta_datafusion/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,12 @@
//! Utility functions for Datafusion's Expressions

use std::{
fmt::{self, Display, Formatter, Write},
fmt::{self, format, Display, Error, Formatter, Write},
sync::Arc,
};

use arrow_schema::DataType;
use chrono::{Date, NaiveDate, NaiveDateTime, TimeZone};
use datafusion::execution::context::SessionState;
use datafusion_common::Result as DFResult;
use datafusion_common::{config::ConfigOptions, DFSchema, Result, ScalarValue, TableReference};
Expand Down Expand Up @@ -326,6 +327,9 @@ macro_rules! format_option {
}};
}

/// Number of days from the Common Era (0001-01-01) to the Unix epoch (1970-01-01).
///
/// Arrow's `Date32` stores days since the Unix epoch, while chrono's
/// `NaiveDate::from_num_days_from_ce_opt` counts days since the Common Era,
/// so this offset is added when converting between the two.
pub const EPOCH_DAYS_FROM_CE: i32 = 719_163;

/// Borrow-only wrapper around a [`ScalarValue`] whose `Display` impl renders
/// the value as a SQL-literal fragment (e.g. for building predicate strings).
struct ScalarValueFormat<'a> {
    // The scalar to be formatted; borrowed so formatting never clones the value.
    scalar: &'a ScalarValue,
}
Expand All @@ -344,6 +348,46 @@ impl<'a> fmt::Display for ScalarValueFormat<'a> {
ScalarValue::UInt16(e) => format_option!(f, e)?,
ScalarValue::UInt32(e) => format_option!(f, e)?,
ScalarValue::UInt64(e) => format_option!(f, e)?,
ScalarValue::Date32(e) => match e {
Some(e) => write!(
f,
"{}",
NaiveDate::from_num_days_from_ce_opt((EPOCH_DAYS_FROM_CE + (*e)).into())
.ok_or(Error::default())?
)?,
None => write!(f, "NULL")?,
},
ScalarValue::Date64(e) => match e {
Some(e) => write!(
f,
"'{}'::date",
NaiveDateTime::from_timestamp_millis((*e).into())
.ok_or(Error::default())?
.date()
.format("%Y-%m-%d")
)?,
None => write!(f, "NULL")?,
},
ScalarValue::TimestampMicrosecond(e, tz) => match e {
Some(e) => match tz {
Some(tz) => write!(
f,
"arrow_cast('{}', 'Timestamp(Microsecond, Some(\"UTC\"))')",
NaiveDateTime::from_timestamp_micros(*e)
.ok_or(Error::default())?
.and_utc()
.format("%Y-%m-%dT%H:%M:%S%.6f")
)?,
None => write!(
f,
"arrow_cast('{}', 'Timestamp(Microsecond, None)')",
NaiveDateTime::from_timestamp_micros(*e)
.ok_or(Error::default())?
.format("%Y-%m-%dT%H:%M:%S%.6f")
)?,
},
None => write!(f, "NULL")?,
},
ScalarValue::Utf8(e) | ScalarValue::LargeUtf8(e) => match e {
Some(e) => write!(f, "'{}'", escape_quoted_string(e, '\''))?,
None => write!(f, "NULL")?,
Expand Down Expand Up @@ -445,6 +489,11 @@ mod test {
DataType::Primitive(PrimitiveType::Timestamp),
true,
),
StructField::new(
"_timestamp_ntz".to_string(),
DataType::Primitive(PrimitiveType::TimestampNtz),
true,
),
StructField::new(
"_binary".to_string(),
DataType::Primitive(PrimitiveType::Binary),
Expand Down Expand Up @@ -610,6 +659,28 @@ mod test {
cardinality(col("_list").range(col("value"), lit(10_i64))),
"cardinality(_list[value:10:1])".to_string()
),
ParseTest {
expr: col("_timestamp_ntz").gt(lit(ScalarValue::TimestampMicrosecond(Some(1262304000000000), None))),
expected: "_timestamp_ntz > arrow_cast('2010-01-01T00:00:00.000000', 'Timestamp(Microsecond, None)')".to_string(),
override_expected_expr: Some(col("_timestamp_ntz").gt(
datafusion_expr::Expr::Cast( Cast {
expr: Box::new(lit(ScalarValue::Utf8(Some("2010-01-01T00:00:00.000000".into())))),
data_type:ArrowDataType::Timestamp(arrow_schema::TimeUnit::Microsecond, None)
}
))),
},
ParseTest {
expr: col("_timestamp").gt(lit(ScalarValue::TimestampMicrosecond(
Some(1262304000000000),
Some("UTC".into())
))),
expected: "_timestamp > arrow_cast('2010-01-01T00:00:00.000000', 'Timestamp(Microsecond, Some(\"UTC\"))')".to_string(),
override_expected_expr: Some(col("_timestamp").gt(
datafusion_expr::Expr::Cast( Cast {
expr: Box::new(lit(ScalarValue::Utf8(Some("2010-01-01T00:00:00.000000".into())))),
data_type:ArrowDataType::Timestamp(arrow_schema::TimeUnit::Microsecond, Some("UTC".into()))
}))),
},
];

let session: SessionContext = DeltaSessionContext::default().into();
Expand Down
92 changes: 92 additions & 0 deletions python/tests/test_merge.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pathlib

import pyarrow as pa
import pytest

from deltalake import DeltaTable, write_deltalake

Expand Down Expand Up @@ -763,3 +764,94 @@ def test_merge_multiple_when_not_matched_by_source_update_wo_predicate(

assert last_action["operation"] == "MERGE"
assert result == expected


def test_merge_date_partitioned_2344(tmp_path: pathlib.Path):
    """Regression test for the issue referenced in the test name: a merge
    predicate on a date32 partition column must serialize the partition value
    as a plain ISO date literal."""
    from datetime import date

    fields = [
        ("date", pa.date32()),
        ("foo", pa.string()),
        ("bar", pa.string()),
    ]
    table_schema = pa.schema(fields)

    delta_table = DeltaTable.create(
        tmp_path, schema=table_schema, partition_by=["date"], mode="overwrite"
    )

    source = pa.table(
        {
            "date": pa.array([date(2022, 2, 1)]),
            "foo": pa.array(["hello"]),
            "bar": pa.array(["world"]),
        }
    )

    merger = delta_table.merge(
        source,
        predicate="s.date = t.date",
        source_alias="s",
        target_alias="t",
    )
    merger.when_matched_update_all().when_not_matched_insert_all().execute()

    merged = delta_table.to_pyarrow_table()
    history_entry = delta_table.history(1)[0]

    # The commit must be a MERGE whose recorded predicate renders the date32
    # partition value as an unquoted ISO date.
    assert history_entry["operation"] == "MERGE"
    assert merged == source
    assert history_entry["operationParameters"].get("predicate") == "2022-02-01 = date"


@pytest.mark.parametrize(
    "timezone,predicate",
    [
        (
            None,
            "arrow_cast('2022-02-01T00:00:00.000000', 'Timestamp(Microsecond, None)') = datetime",
        ),
        (
            "UTC",
            "arrow_cast('2022-02-01T00:00:00.000000', 'Timestamp(Microsecond, Some(\"UTC\"))') = datetime",
        ),
    ],
)
def test_merge_timestamps_partitioned_2344(tmp_path: pathlib.Path, timezone, predicate):
    """Regression test for the issue referenced in the test name: a merge
    predicate on a microsecond-timestamp partition column (with and without a
    timezone) must serialize the value via an ``arrow_cast`` expression."""
    from datetime import datetime

    fields = [
        ("datetime", pa.timestamp("us", tz=timezone)),
        ("foo", pa.string()),
        ("bar", pa.string()),
    ]
    table_schema = pa.schema(fields)

    delta_table = DeltaTable.create(
        tmp_path, schema=table_schema, partition_by=["datetime"], mode="overwrite"
    )

    source = pa.table(
        {
            "datetime": pa.array(
                [datetime(2022, 2, 1)], pa.timestamp("us", tz=timezone)
            ),
            "foo": pa.array(["hello"]),
            "bar": pa.array(["world"]),
        }
    )

    merger = delta_table.merge(
        source,
        predicate="s.datetime = t.datetime",
        source_alias="s",
        target_alias="t",
    )
    merger.when_matched_update_all().when_not_matched_insert_all().execute()

    merged = delta_table.to_pyarrow_table()
    history_entry = delta_table.history(1)[0]

    # The commit must be a MERGE whose recorded predicate matches the
    # parametrized arrow_cast rendering for this timezone.
    assert history_entry["operation"] == "MERGE"
    assert merged == source
    assert history_entry["operationParameters"].get("predicate") == predicate
Loading