From 155be623beb889a7a6fa79098483b1ee45efecb8 Mon Sep 17 00:00:00 2001
From: Will Jones
Date: Sun, 9 Jul 2023 08:47:19 -0700
Subject: [PATCH] fix: handle nulls in file-level stats (#1520)

# Description

Fixes an issue where predicate pushdown wasn't working for null values. This
adds tests for both regular columns and partition columns.

# Related Issue(s)

- closes #1496

# Documentation

---
 .github/workflows/python_build.yml |  4 +-
 python/pyproject.toml              |  2 +-
 python/src/lib.rs                  | 64 +++++++++++++++++++-------
 python/tests/test_table_read.py    | 73 ++++++++++++++++++++++++++++++
 4 files changed, 124 insertions(+), 19 deletions(-)

diff --git a/.github/workflows/python_build.yml b/.github/workflows/python_build.yml
index cd7315a072..45f865dfdf 100644
--- a/.github/workflows/python_build.yml
+++ b/.github/workflows/python_build.yml
@@ -36,7 +36,7 @@ jobs:
         run: make check-rust
 
   test-minimal:
-    name: Python Build (Python 3.7 PyArrow 7.0.0)
+    name: Python Build (Python 3.7 PyArrow 8.0.0)
     runs-on: ubuntu-latest
    env:
       RUSTFLAGS: "-C debuginfo=0"
@@ -70,7 +70,7 @@ jobs:
           source venv/bin/activate
           make setup
           # Install minimum PyArrow version
-          pip install -e .[pandas,devel] pyarrow==7.0.0
+          pip install -e .[pandas,devel] pyarrow==8.0.0
         env:
           RUSTFLAGS: "-C debuginfo=0"
 
diff --git a/python/pyproject.toml b/python/pyproject.toml
index 35fa589f76..dc5062a573 100644
--- a/python/pyproject.toml
+++ b/python/pyproject.toml
@@ -15,7 +15,7 @@ classifiers = [
     "Programming Language :: Python :: 3 :: Only"
 ]
 dependencies = [
-    "pyarrow>=7",
+    "pyarrow>=8",
     'typing-extensions;python_version<"3.8"',
 ]
 
diff --git a/python/src/lib.rs b/python/src/lib.rs
index efd074b64f..a78978cff1 100644
--- a/python/src/lib.rs
+++ b/python/src/lib.rs
@@ -581,6 +581,14 @@ fn json_value_to_py(value: &serde_json::Value, py: Python) -> PyObject {
 ///
 /// PyArrow uses this expression to determine which Dataset fragments may be
 /// skipped during a scan.
+///
+/// Partition values are translated to equality expressions (if they are valid)
+/// or an is_null expression otherwise. For example, if the partition is
+/// {"date": "2021-01-01", "x": null}, then the expression is:
+/// field(date) = "2021-01-01" AND x IS NULL
+///
+/// Statistics are translated into inequalities. If there are null values, then
+/// they must be OR'd with is_null.
 fn filestats_to_expression<'py>(
     py: Python<'py>,
     schema: &PyArrowType<ArrowSchema>,
@@ -616,10 +624,27 @@ fn filestats_to_expression<'py>(
                     .call1((column,))?
                     .call_method1("__eq__", (converted_value,)),
             );
+        } else {
+            expressions.push(field.call1((column,))?.call_method0("is_null"));
         }
     }
 
     if let Some(stats) = stats {
+        let mut has_nulls_set: HashSet<String> = HashSet::new();
+
+        for (col_name, null_count) in stats.null_count.iter().filter_map(|(k, v)| match v {
+            ColumnCountStat::Value(val) => Some((k, val)),
+            _ => None,
+        }) {
+            if *null_count == 0 {
+                expressions.push(field.call1((col_name,))?.call_method0("is_valid"));
+            } else if *null_count == stats.num_records {
+                expressions.push(field.call1((col_name,))?.call_method0("is_null"));
+            } else {
+                has_nulls_set.insert(col_name.clone());
+            }
+        }
+
         for (col_name, minimum) in stats.min_values.iter().filter_map(|(k, v)| match v {
             ColumnValueStat::Value(val) => Some((k.clone(), json_value_to_py(val, py))),
             // TODO(wjones127): Handle nested field statistics.
@@ -628,7 +653,17 @@ fn filestats_to_expression<'py>(
         }) {
             let maybe_minimum = cast_to_type(&col_name, minimum, &schema.0);
             if let Ok(minimum) = maybe_minimum {
-                expressions.push(field.call1((col_name,))?.call_method1("__ge__", (minimum,)));
+                let field_expr = field.call1((&col_name,))?;
+                let expr = field_expr.call_method1("__ge__", (minimum,));
+                let expr = if has_nulls_set.contains(&col_name) {
+                    // col >= min_value OR col is null
+                    let is_null_expr = field_expr.call_method0("is_null");
+                    expr?.call_method1("__or__", (is_null_expr?,))
+                } else {
+                    // col >= min_value
+                    expr
+                };
+                expressions.push(expr);
             }
         }
 
@@ -638,20 +673,17 @@ fn filestats_to_expression<'py>(
         }) {
             let maybe_maximum = cast_to_type(&col_name, maximum, &schema.0);
             if let Ok(maximum) = maybe_maximum {
-                expressions.push(field.call1((col_name,))?.call_method1("__le__", (maximum,)));
-            }
-        }
-
-        for (col_name, null_count) in stats.null_count.iter().filter_map(|(k, v)| match v {
-            ColumnCountStat::Value(val) => Some((k, val)),
-            _ => None,
-        }) {
-            if *null_count == stats.num_records {
-                expressions.push(field.call1((col_name.clone(),))?.call_method0("is_null"));
-            }
-
-            if *null_count == 0 {
-                expressions.push(field.call1((col_name.clone(),))?.call_method0("is_valid"));
+                let field_expr = field.call1((&col_name,))?;
+                let expr = field_expr.call_method1("__le__", (maximum,));
+                let expr = if has_nulls_set.contains(&col_name) {
+                    // col <= max_value OR col is null
+                    let is_null_expr = field_expr.call_method0("is_null");
+                    expr?.call_method1("__or__", (is_null_expr?,))
+                } else {
+                    // col <= max_value
+                    expr
+                };
+                expressions.push(expr);
             }
         }
     }
@@ -661,7 +693,7 @@ fn filestats_to_expression<'py>(
     } else {
         expressions
             .into_iter()
-            .reduce(|accum, item| accum?.getattr("__and__")?.call1((item?,)))
+            .reduce(|accum, item| accum?.call_method1("__and__", (item?,)))
             .transpose()
     }
 }
diff --git a/python/tests/test_table_read.py b/python/tests/test_table_read.py
index 99986763cd..0096ac2cc8 100644
--- a/python/tests/test_table_read.py
+++ b/python/tests/test_table_read.py
@@ -8,6 +8,7 @@
 
 from deltalake.exceptions import DeltaProtocolError
 from deltalake.table import ProtocolVersions
+from deltalake.writer import write_deltalake
 
 try:
     import pandas as pd
@@ -543,3 +544,75 @@ def read_table():
         t.start()
     for t in threads:
         t.join()
+
+
+def assert_num_fragments(table, predicate, count):
+    frags = table.to_pyarrow_dataset().get_fragments(filter=predicate)
+    assert len(list(frags)) == count
+
+
+def test_filter_nulls(tmp_path: Path):
+    def assert_scan_equals(table, predicate, expected):
+        data = table.to_pyarrow_dataset().to_table(filter=predicate).sort_by("part")
+        assert data == expected
+
+    # 1 all-valid part, 1 all-null part, and 1 mixed part.
+    data = pa.table(
+        {"part": ["a", "a", "b", "b", "c", "c"], "value": [1, 1, None, None, 2, None]}
+    )
+
+    write_deltalake(tmp_path, data, partition_by="part")
+
+    table = DeltaTable(tmp_path)
+
+    # Note: we assert the number of fragments returned because that verifies
+    # that file skipping is working properly.
+ + # is valid predicate + predicate = ds.field("value").is_valid() + assert_num_fragments(table, predicate, 2) + expected = pa.table({"part": ["a", "a", "c"], "value": [1, 1, 2]}) + assert_scan_equals(table, predicate, expected) + + # is null predicate + predicate = ds.field("value").is_null() + assert_num_fragments(table, predicate, 2) + expected = pa.table( + {"part": ["b", "b", "c"], "value": pa.array([None, None, None], pa.int64())} + ) + assert_scan_equals(table, predicate, expected) + + # inequality predicate + predicate = ds.field("value") > 1 + assert_num_fragments(table, predicate, 1) + expected = pa.table({"part": ["c"], "value": pa.array([2], pa.int64())}) + assert_scan_equals(table, predicate, expected) + + # also test nulls in partition values + data = pa.table({"part": pa.array([None], pa.string()), "value": [3]}) + write_deltalake( + table, + data, + mode="append", + partition_by="part", + ) + + # null predicate + predicate = ds.field("part").is_null() + assert_num_fragments(table, predicate, 1) + expected = pa.table({"part": pa.array([None], pa.string()), "value": [3]}) + assert_scan_equals(table, predicate, expected) + + # valid predicate + predicate = ds.field("part").is_valid() + assert_num_fragments(table, predicate, 3) + expected = pa.table( + {"part": ["a", "a", "b", "b", "c", "c"], "value": [1, 1, None, None, 2, None]} + ) + assert_scan_equals(table, predicate, expected) + + # inequality predicate + predicate = ds.field("part") < "c" + assert_num_fragments(table, predicate, 2) + expected = pa.table({"part": ["a", "a", "b", "b"], "value": [1, 1, None, None]}) + assert_scan_equals(table, predicate, expected)