Skip to content

Commit

Permalink
fix: handle nulls in file-level stats
Browse files Browse the repository at this point in the history
  • Loading branch information
wjones127 committed Jul 9, 2023
1 parent 5172dd6 commit f901757
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 16 deletions.
64 changes: 48 additions & 16 deletions python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -581,6 +581,14 @@ fn json_value_to_py(value: &serde_json::Value, py: Python) -> PyObject {
///
/// PyArrow uses this expression to determine which Dataset fragments may be
/// skipped during a scan.
///
/// Partition values are translated to equality expressions (if they are valid)
/// or is_null expression otherwise. For example, if the partition is
/// {"date": "2021-01-01", "x": null}, then the expression is:
/// field(date) = "2021-01-01" AND x IS NULL
///
/// Statistics are translated into inequalities. If there are null values, then
/// they must be OR'd with is_null.
fn filestats_to_expression<'py>(
py: Python<'py>,
schema: &PyArrowType<ArrowSchema>,
Expand Down Expand Up @@ -616,10 +624,27 @@ fn filestats_to_expression<'py>(
.call1((column,))?
.call_method1("__eq__", (converted_value,)),
);
} else {
expressions.push(field.call1((column,))?.call_method0("is_null"));
}
}

if let Some(stats) = stats {
let mut has_nulls_set: HashSet<String> = HashSet::new();

for (col_name, null_count) in stats.null_count.iter().filter_map(|(k, v)| match v {
ColumnCountStat::Value(val) => Some((k, val)),
_ => None,
}) {
if *null_count == 0 {
expressions.push(field.call1((col_name,))?.call_method0("is_valid"));
} else if *null_count == stats.num_records {
expressions.push(field.call1((col_name,))?.call_method0("is_null"));
} else {
has_nulls_set.insert(col_name.clone());
}
}

for (col_name, minimum) in stats.min_values.iter().filter_map(|(k, v)| match v {
ColumnValueStat::Value(val) => Some((k.clone(), json_value_to_py(val, py))),
// TODO(wjones127): Handle nested field statistics.
Expand All @@ -628,7 +653,17 @@ fn filestats_to_expression<'py>(
}) {
let maybe_minimum = cast_to_type(&col_name, minimum, &schema.0);
if let Ok(minimum) = maybe_minimum {
expressions.push(field.call1((col_name,))?.call_method1("__ge__", (minimum,)));
let field_expr = field.call1((&col_name,))?;
let expr = field_expr.call_method1("__ge__", (minimum,));
let expr = if has_nulls_set.contains(&col_name) {
// col >= min_value OR col is null
let is_null_expr = field_expr.call_method0("is_null");
expr?.call_method1("__or__", (is_null_expr?,))
} else {
// col >= min_value
expr
};
expressions.push(expr);
}
}

Expand All @@ -638,20 +673,17 @@ fn filestats_to_expression<'py>(
}) {
let maybe_maximum = cast_to_type(&col_name, maximum, &schema.0);
if let Ok(maximum) = maybe_maximum {
expressions.push(field.call1((col_name,))?.call_method1("__le__", (maximum,)));
}
}

for (col_name, null_count) in stats.null_count.iter().filter_map(|(k, v)| match v {
ColumnCountStat::Value(val) => Some((k, val)),
_ => None,
}) {
if *null_count == stats.num_records {
expressions.push(field.call1((col_name.clone(),))?.call_method0("is_null"));
}

if *null_count == 0 {
expressions.push(field.call1((col_name.clone(),))?.call_method0("is_valid"));
let field_expr = field.call1((&col_name,))?;
let expr = field_expr.call_method1("__le__", (maximum,));
let expr = if has_nulls_set.contains(&col_name) {
// col <= max_value OR col is null
let is_null_expr = field_expr.call_method0("is_null");
expr?.call_method1("__or__", (is_null_expr?,))
} else {
// col <= max_value
expr
};
expressions.push(expr);
}
}
}
Expand All @@ -661,7 +693,7 @@ fn filestats_to_expression<'py>(
} else {
expressions
.into_iter()
.reduce(|accum, item| accum?.getattr("__and__")?.call1((item?,)))
.reduce(|accum, item| accum?.call_method1("__and__", (item?,)))
.transpose()
}
}
Expand Down
73 changes: 73 additions & 0 deletions python/tests/test_table_read.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from deltalake.exceptions import DeltaProtocolError
from deltalake.table import ProtocolVersions
from deltalake.writer import write_deltalake

try:
import pandas as pd
Expand Down Expand Up @@ -543,3 +544,75 @@ def read_table():
t.start()
for t in threads:
t.join()


def assert_num_fragments(table, predicate, count):
frags = table.to_pyarrow_dataset().get_fragments(filter=predicate)
assert len(list(frags)) == count


def test_filter_nulls(tmp_path: Path):
def assert_scan_equals(table, predicate, expected):
data = table.to_pyarrow_dataset().to_table(filter=predicate).sort_by("part")
assert data == expected

# 1 all-valid part, 1 all-null part, and 1 mixed part.
data = pa.table(
{"part": ["a", "a", "b", "b", "c", "c"], "value": [1, 1, None, None, 2, None]}
)

write_deltalake(tmp_path, data, partition_by="part")

table = DeltaTable(tmp_path)

# Note: we assert number of fragments returned because that verifies
# that file skipping is working properly.

# is valid predicate
predicate = ds.field("value").is_valid()
assert_num_fragments(table, predicate, 2)
expected = pa.table({"part": ["a", "a", "c"], "value": [1, 1, 2]})
assert_scan_equals(table, predicate, expected)

# is null predicate
predicate = ds.field("value").is_null()
assert_num_fragments(table, predicate, 2)
expected = pa.table(
{"part": ["b", "b", "c"], "value": pa.array([None, None, None], pa.int64())}
)
assert_scan_equals(table, predicate, expected)

# inequality predicate
predicate = ds.field("value") > 1
assert_num_fragments(table, predicate, 1)
expected = pa.table({"part": ["c"], "value": pa.array([2], pa.int64())})
assert_scan_equals(table, predicate, expected)

# also test nulls in partition values
data = pa.table({"part": pa.array([None], pa.string()), "value": [3]})
write_deltalake(
table,
data,
mode="append",
partition_by="part",
)

# null predicate
predicate = ds.field("part").is_null()
assert_num_fragments(table, predicate, 1)
expected = pa.table({"part": pa.array([None], pa.string()), "value": [3]})
assert_scan_equals(table, predicate, expected)

# valid predicate
predicate = ds.field("part").is_valid()
assert_num_fragments(table, predicate, 3)
expected = pa.table(
{"part": ["a", "a", "b", "b", "c", "c"], "value": [1, 1, None, None, 2, None]}
)
assert_scan_equals(table, predicate, expected)

# inequality predicate
predicate = ds.field("part") < "c"
assert_num_fragments(table, predicate, 2)
expected = pa.table({"part": ["a", "a", "b", "b"], "value": [1, 1, None, None]})
assert_scan_equals(table, predicate, expected)

0 comments on commit f901757

Please sign in to comment.