From e9264acdacab839e4a0ccc1e98af78a9bbc91adc Mon Sep 17 00:00:00 2001 From: coufon Date: Sat, 23 Dec 2023 06:14:26 +0000 Subject: [PATCH] Add falsifiable filter on index manifest files --- python/pyproject.toml | 15 +- .../core/manifests/falsifiable_filters.py | 168 +++++++++++++++ python/src/space/core/manifests/index.py | 26 +-- python/src/space/core/ops/append.py | 7 +- python/src/space/core/schema/arrow.py | 21 +- python/src/space/core/schema/constants.py | 5 + python/src/space/core/schema/utils.py | 43 ++++ .../manifests/test_falsifiable_filters.py | 194 ++++++++++++++++++ python/tests/core/schema/test_arrow.py | 23 ++- 9 files changed, 446 insertions(+), 56 deletions(-) create mode 100644 python/src/space/core/manifests/falsifiable_filters.py create mode 100644 python/src/space/core/schema/utils.py create mode 100644 python/tests/core/manifests/test_falsifiable_filters.py diff --git a/python/pyproject.toml b/python/pyproject.toml index 5ffe16e..64e597a 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -14,6 +14,7 @@ classifiers = [ ] requires-python = ">=3.8" dependencies = [ + "absl-py", "array-record", "numpy", "protobuf", @@ -39,14 +40,14 @@ pythonpath = ["src"] [tool.pylint.format] max-line-length = 80 -indent-string = ' ' -disable = ['fixme'] +indent-string = " " +disable = ["fixme", "no-else-return"] [tool.pylint.MAIN] -ignore = 'space/core/proto' +ignore = "space/core/proto" ignored-modules = [ - 'space.core.proto', - 'google.protobuf', - 'substrait', - 'array_record', + "space.core.proto", + "google.protobuf", + "substrait", + "array_record", ] diff --git a/python/src/space/core/manifests/falsifiable_filters.py b/python/src/space/core/manifests/falsifiable_filters.py new file mode 100644 index 0000000..c0a0164 --- /dev/null +++ b/python/src/space/core/manifests/falsifiable_filters.py @@ -0,0 +1,168 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except 
def substrait_expr(schema: pa.Schema,
                   arrow_expr: pc.Expression) -> ExtendedExpression:
  """Convert an expression from Arrow to Substrait format.

  PyArrow does not expose enough methods for processing expressions, thus we
  convert it to Substrait format for processing.

  Args:
    schema: Arrow schema that the expression is serialized against.
    arrow_expr: the Arrow compute expression to convert.

  Returns:
    The Substrait extended expression parsed from PyArrow's serialized bytes;
    the single expression is named "expr".
  """
  buf = ps.serialize_expressions(  # type: ignore[attr-defined]
      [arrow_expr], ['expr'], schema)

  expr = ExtendedExpression()
  expr.ParseFromString(buf.to_pybytes())
  return expr


class ExpressionException(Exception):
  """Raise for exceptions in expressions."""


def falsifiable_filter(
    filter_: ExtendedExpression,
    field_name_to_id_dict: Dict[str, int]) -> Optional[pc.Expression]:
  """Build a falsifiable filter.

  A falsifiable filter is an expression on the column statistics (min, max)
  of index manifest rows. When it evaluates to true for a data file, no row of
  that file can satisfy `filter_`, so the file can be pruned.

  Args:
    filter_: a filter on data fields.
    field_name_to_id_dict: a dict of field names to IDs mapping.

  Returns:
    Falsifiable filter, or None if not convertible.
  """
  if len(filter_.referred_expr) != 1:
    logging.warning(
        f"Expect 1 referred expr, found: {len(filter_.referred_expr)}; "
        "Falsifiable filter is not used.")
    return None

  return _falsifiable_filter(
      filter_.extensions,  # type: ignore[arg-type]
      filter_.base_schema,
      field_name_to_id_dict,
      filter_.referred_expr[0].expression.scalar_function)


# Comparison function to use after swapping the two operands so that the field
# is on the left hand side, e.g., "10 < a" is equivalent to "a > 10".
_SWAPPED_COMPARISON_FN = {"gt": "lt", "lt": "gt", "equal": "equal"}


# pylint: disable=too-many-locals,too-many-return-statements,too-many-branches
def _falsifiable_filter(
    extensions: List[SimpleExtensionDeclaration], base_schema: NamedStruct,
    field_name_to_id_dict: Dict[str, int],
    root: Expression.ScalarFunction) -> Optional[pc.Expression]:
  """Recursively convert a binary scalar function to a falsifiable filter.

  Supported shapes are `and`/`or` of two scalar functions, and `gt`/`lt`/
  `equal` between one field reference and one literal (in either order).

  Args:
    extensions: extension declarations of the expression, used to resolve the
      function name of `root`.
    base_schema: schema of the expression, used to resolve field names.
    field_name_to_id_dict: a dict of field names to IDs mapping.
    root: the scalar function to convert.

  Returns:
    Falsifiable filter, or None if `root` is not convertible.
  """
  if len(root.arguments) != 2:
    logging.warning(f"Invalid number of arguments: {root.arguments}; "
                    "Falsifiable filter is not used.")
    return None

  # NOTE(review): indexes the declarations by position, which assumes function
  # anchors are the positions in `extensions` -- confirm against the producer.
  fn = extensions[root.function_reference].extension_function.name
  lhs = root.arguments[0].value
  rhs = root.arguments[1].value

  if _has_scalar_function(lhs) and _has_scalar_function(rhs):
    # TODO: to support more functions.
    if fn not in ("and", "or"):
      logging.warning(f"Unsupported fn: {fn}; Falsifiable filter is not used.")
      return None

    falsifiable_filter_fn = partial(_falsifiable_filter, extensions,
                                    base_schema, field_name_to_id_dict)
    lhs_filter = falsifiable_filter_fn(lhs.scalar_function)
    rhs_filter = falsifiable_filter_fn(rhs.scalar_function)

    if fn == "and":
      # "A and B" is falsified when either side is falsified, so one
      # convertible side alone is still a valid (weaker) falsifiable filter.
      if lhs_filter is None:
        return rhs_filter
      if rhs_filter is None:
        return lhs_filter
      return lhs_filter | rhs_filter

    # "A or B" is falsified only when both sides are falsified; it is not
    # convertible if either side is unknown.
    if lhs_filter is None or rhs_filter is None:
      return None
    return lhs_filter & rhs_filter

  if _has_selection(lhs) and _has_selection(rhs):
    logging.warning(f"Both args are fields: {root.arguments}; "
                    "Falsifiable filter is not used.")
    return None

  if _has_literal(lhs) and _has_literal(rhs):
    logging.warning(f"Both args are constants: {root.arguments}; "
                    "Falsifiable filter is not used.")
    return None

  # Move literal to rhs. The comparison must be mirrored together with the
  # operands, otherwise "10 < a" would be handled as "a < 10".
  if _has_selection(rhs):
    lhs, rhs = rhs, lhs
    fn = _SWAPPED_COMPARISON_FN.get(fn, fn)

  # Anything else (e.g., a cast or a function mixed with a field) would be
  # read as unset protobuf defaults, i.e., silently as field 0.
  if not (_has_selection(lhs) and _has_literal(rhs)):
    logging.warning(f"Unsupported args: {root.arguments}; "
                    "Falsifiable filter is not used.")
    return None

  literal_type = rhs.literal.WhichOneof("literal_type")
  if literal_type is None:
    logging.warning(f"Literal has no value: {root.arguments}; "
                    "Falsifiable filter is not used.")
    return None

  # NOTE(review): uses the top-level field index to look up `names`; presumably
  # only flat schemas are expected here -- confirm for nested struct fields.
  field_index = lhs.selection.direct_reference.struct_field.field
  field_name = base_schema.names[field_index]
  field_id = field_name_to_id_dict[field_name]
  field_min, field_max = _stats_field_min(field_id), _stats_field_max(field_id)
  value = pc.scalar(getattr(rhs.literal,
                            literal_type))  # type: ignore[arg-type]

  # TODO: to support more functions.
  if fn == "gt":
    # "field > value" can't match when all values are <= value.
    return field_max <= value
  elif fn == "lt":
    # "field < value" can't match when all values are >= value.
    return field_min >= value
  elif fn == "equal":
    # "field == value" can't match when value is outside [min, max].
    return (field_min > value) | (field_max < value)

  logging.warning(f"Unsupported fn: {fn}; Falsifiable filter is not used.")
  return None


def _stats_field_min(field_id: int) -> pc.Expression:
  """Reference to the min sub-field of a field's column stats struct."""
  return pc.field(schema_utils.stats_field_name(field_id), constants.MIN_FIELD)


def _stats_field_max(field_id: int) -> pc.Expression:
  """Reference to the max sub-field of a field's column stats struct."""
  return pc.field(schema_utils.stats_field_name(field_id), constants.MAX_FIELD)


def _has_scalar_function(msg: Expression) -> bool:
  """Whether the expression is a scalar function call."""
  return msg.HasField("scalar_function")


def _has_selection(msg: Expression) -> bool:
  """Whether the expression is a field reference."""
  return msg.HasField("selection")


def _has_literal(msg: Expression) -> bool:
  """Whether the expression is a literal."""
  return msg.HasField("literal")
- """ - return f"{_STATS_FIELD}_f{field_id_}" - def _stats_subfields(type_: pa.DataType) -> List[pa.Field]: """Column stats struct field sub-fields.""" - return [pa.field(_MIN_FIELD, type_), pa.field(_MAX_FIELD, type_)] + return [ + pa.field(constants.MIN_FIELD, type_), + pa.field(constants.MAX_FIELD, type_) + ] def _manifest_schema( @@ -70,8 +58,8 @@ def _manifest_schema( continue field_id_ = field_id(f) - fields.append( - (_stats_field_name(field_id_), pa.struct(_stats_subfields(f.type)))) + fields.append((schema_utils.stats_field_name(field_id_), + pa.struct(_stats_subfields(f.type)))) stats_fields.append((field_id_, f.type)) return pa.schema(fields), stats_fields # type: ignore[arg-type] diff --git a/python/src/space/core/ops/append.py b/python/src/space/core/ops/append.py index 0f4e531..541736e 100644 --- a/python/src/space/core/ops/append.py +++ b/python/src/space/core/ops/append.py @@ -29,6 +29,7 @@ from space.core.proto import metadata_pb2 as meta from space.core.proto import runtime_pb2 as runtime from space.core.schema import arrow +from space.core.schema import utils as schema_utils from space.core.utils import paths from space.core.utils.lazy_imports_utils import array_record_module as ar from space.core.utils.paths import StoragePaths @@ -161,7 +162,7 @@ def _append_arrow(self, data: pa.Table) -> None: index_data = data self._maybe_create_index_writer() - index_data = data.select(arrow.field_names(self._index_fields)) + index_data = data.select(schema_utils.field_names(self._index_fields)) # Write record fields into files. # TODO: to parallelize it. @@ -220,7 +221,7 @@ def _finish_index_writer(self) -> None: self._cached_index_file_bytes = 0 def _write_record_column( - self, field: arrow.Field, + self, field: schema_utils.Field, column: pa.ChunkedArray) -> Tuple[str, pa.StructArray]: """Write record field into files. 
@@ -259,7 +260,7 @@ def _write_record_column( return field_name, address_column - def _finish_record_writer(self, field: arrow.Field, + def _finish_record_writer(self, field: schema_utils.Field, writer_info: _RecordWriterInfo) -> None: """Materialize a new record file (ArrayRecord), update metadata and stats. diff --git a/python/src/space/core/schema/arrow.py b/python/src/space/core/schema/arrow.py index 8f7961b..0740e02 100644 --- a/python/src/space/core/schema/arrow.py +++ b/python/src/space/core/schema/arrow.py @@ -22,6 +22,7 @@ from space.core.schema import constants from space.core.schema.types import TfFeatures +from space.core.schema import utils from space.core.utils.constants import UTF_8 _PARQUET_FIELD_ID_KEY = b"PARQUET:field_id" @@ -153,18 +154,11 @@ def field_id_to_column_id_dict(schema: pa.Schema) -> Dict[int, int]: } -@dataclass -class Field: - """Information of a field.""" - name: str - field_id: int - - def classify_fields( schema: pa.Schema, record_fields: Set[str], selected_fields: Optional[Set[str]] = None -) -> Tuple[List[Field], List[Field]]: +) -> Tuple[List[utils.Field], List[utils.Field]]: """Classify fields into indexes and records. Args: @@ -175,14 +169,14 @@ def classify_fields( Returns: A tuple (index_fields, record_fields). 
""" - index_fields: List[Field] = [] - record_fields_: List[Field] = [] + index_fields: List[utils.Field] = [] + record_fields_: List[utils.Field] = [] for f in schema: if selected_fields is not None and f.name not in selected_fields: continue - field = Field(f.name, field_id(f)) + field = utils.Field(f.name, field_id(f)) if f.name in record_fields: record_fields_.append(field) else: @@ -191,11 +185,6 @@ def classify_fields( return index_fields, record_fields_ -def field_names(fields: List[Field]) -> List[str]: - """Extract field names from a list of fields.""" - return list(map(lambda f: f.name, fields)) - - def record_address_types() -> List[Tuple[str, pa.DataType]]: """Returns Arrow fields of record addresses.""" return [(constants.FILE_PATH_FIELD, pa.string()), diff --git a/python/src/space/core/schema/constants.py b/python/src/space/core/schema/constants.py index 91daa6a..51ec864 100644 --- a/python/src/space/core/schema/constants.py +++ b/python/src/space/core/schema/constants.py @@ -23,3 +23,8 @@ NUM_ROWS_FIELD = "_NUM_ROWS" UNCOMPRESSED_BYTES_FIELD = "_UNCOMPRESSED_BYTES" + +# Constants for building column statistics field name. +STATS_FIELD = "_STATS" +MIN_FIELD = "_MIN" +MAX_FIELD = "_MAX" diff --git a/python/src/space/core/schema/utils.py b/python/src/space/core/schema/utils.py new file mode 100644 index 0000000..a6b132c --- /dev/null +++ b/python/src/space/core/schema/utils.py @@ -0,0 +1,43 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@dataclass
class Field:
  """A table field described by its name and its field ID."""
  # Name of the field.
  name: str
  # ID of the field.
  field_id: int


def field_names(fields: List[Field]) -> List[str]:
  """Return the names of the given fields, keeping their order."""
  return [field.name for field in fields]


def stats_field_name(field_id_: int) -> str:
  """Build the name of the column stats struct field of a field.

  The name is derived from the field ID instead of the field name. A manifest
  file covers all Parquet files and is not tied to one Parquet schema, so
  table field names can't be projected to file field names. Using the field ID
  ensures that a field is always uniquely identified.
  """
  prefix = constants.STATS_FIELD
  return f"{prefix}_f{field_id_}"
+ +from google.protobuf import text_format +import pyarrow as pa +import pyarrow.compute as pc +import pytest +from substrait.extended_expression_pb2 import ExtendedExpression + +from space.core.manifests import falsifiable_filters as ff + +_SAMPLE_SUBSTRAIT_EXPR = """ +extension_uris { + uri: "https://github.com/substrait-io/substrait/blob/main/extensions/functions_comparison.yaml" +} +extension_uris { + extension_uri_anchor: 1 + uri: "https://github.com/substrait-io/substrait/blob/main/extensions/functions_comparison.yaml" +} +extension_uris { + extension_uri_anchor: 2 + uri: "https://github.com/substrait-io/substrait/blob/main/extensions/functions_boolean.yaml" +} +extension_uris { + extension_uri_anchor: 3 + uri: "https://github.com/substrait-io/substrait/blob/main/extensions/functions_comparison.yaml" +} +extension_uris { + extension_uri_anchor: 4 + uri: "https://github.com/substrait-io/substrait/blob/main/extensions/functions_comparison.yaml" +} +extensions { + extension_function { + extension_uri_reference: 4 + name: "gt" + } +} +extensions { + extension_function { + extension_uri_reference: 4 + function_anchor: 1 + name: "equal" + } +} +extensions { + extension_function { + extension_uri_reference: 2 + function_anchor: 2 + name: "and" + } +} +referred_expr { + expression { + scalar_function { + function_reference: 2 + output_type { + bool { + nullability: NULLABILITY_NULLABLE + } + } + arguments { + value { + scalar_function { + output_type { + bool { + nullability: NULLABILITY_NULLABLE + } + } + arguments { + value { + selection { + direct_reference { + struct_field { + } + } + root_reference { + } + } + } + } + arguments { + value { + literal { + i64: 10 + } + } + } + } + } + } + arguments { + value { + scalar_function { + function_reference: 1 + output_type { + bool { + nullability: NULLABILITY_NULLABLE + } + } + arguments { + value { + selection { + direct_reference { + struct_field { + field: 1 + } + } + root_reference { + } + } + } + } + arguments { + 
def test_substrait_expr():
  """Arrow expression is serialized to the expected Substrait expression."""
  arrow_schema = pa.schema([("a", pa.int64()), ("b", pa.float64())])  # pylint: disable=too-few-public-methods
  arrow_expr = (pc.field("a") > 10) & (pc.field("b") == 1)

  substrait_expr = ff.substrait_expr(arrow_schema, arrow_expr)
  # The version field is cleared before comparing with the sample.
  substrait_expr.ClearField("version")

  expected_expr = text_format.Parse(_SAMPLE_SUBSTRAIT_EXPR,
                                    ExtendedExpression())
  assert substrait_expr == expected_expr


@pytest.mark.parametrize("filter_,expected_falsifiable_filter",
                         [((pc.field("a") < 10) | (pc.field("b") > 1),
                           (pc.field("_STATS_f0", "_MIN") >= 10) &
                           (pc.field("_STATS_f1", "_MAX") <= 1)),
                          ((pc.field("a") > 10) & (pc.field("b") == 1),
                           (pc.field("_STATS_f0", "_MAX") <= 10) |
                           ((pc.field("_STATS_f1", "_MIN") > 1) |
                            (pc.field("_STATS_f1", "_MAX") < 1)))])
def test_falsifiable_filter(filter_, expected_falsifiable_filter):
  """Supported filters are converted to filters on column statistics."""
  arrow_schema = pa.schema([("a", pa.int64()), ("b", pa.float64())])  # pylint: disable=too-few-public-methods
  field_name_to_id_dict = {"a": 0, "b": 1}
  substrait_expr = ff.substrait_expr(arrow_schema, filter_)

  # The result must not reuse the name of the expected value, otherwise the
  # assertion compares the result with itself and always passes.
  falsifiable_filter = ff.falsifiable_filter(substrait_expr,
                                             field_name_to_id_dict)
  # Expressions are compared by their string form.
  assert str(falsifiable_filter) == str(expected_falsifiable_filter)


@pytest.mark.parametrize("filter_", [(pc.field("a") != 10),
                                     (~(pc.field("a") > 10))])
def test_falsifiable_filter_not_supported_return_none(filter_):
  """Unsupported filters are not converted."""
  arrow_schema = pa.schema([("a", pa.int64()), ("b", pa.float64())])  # pylint: disable=too-few-public-methods
  field_name_to_id_dict = {"a": 0, "b": 1}
  substrait_expr = ff.substrait_expr(arrow_schema, filter_)

  assert ff.falsifiable_filter(substrait_expr, field_name_to_id_dict) is None
b/python/tests/core/schema/test_arrow.py index 9d7c409..543ce80 100644 --- a/python/tests/core/schema/test_arrow.py +++ b/python/tests/core/schema/test_arrow.py @@ -15,6 +15,7 @@ import pyarrow as pa from space.core.schema import arrow +from space.core.schema import utils from space.core.schema.arrow import field_metadata @@ -82,13 +83,13 @@ def test_classify_fields(sample_arrow_schema): ["float32", "list"]) assert index_fields == [ - arrow.Field("struct", 150), - arrow.Field("list_struct", 220), - arrow.Field("struct_list", 260) + utils.Field("struct", 150), + utils.Field("list_struct", 220), + utils.Field("struct_list", 260) ] assert record_fields == [ - arrow.Field("float32", 100), - arrow.Field("list", 120) + utils.Field("float32", 100), + utils.Field("list", 120) ] @@ -97,13 +98,13 @@ def test_classify_fields_with_selected_fields(sample_arrow_schema): ["float32", "list"], ["list", "struct"]) - assert index_fields == [arrow.Field("struct", 150)] - assert record_fields == [arrow.Field("list", 120)] + assert index_fields == [utils.Field("struct", 150)] + assert record_fields == [utils.Field("list", 120)] def test_field_names(): - assert arrow.field_names([ - arrow.Field("struct", 150), - arrow.Field("list_struct", 220), - arrow.Field("struct_list", 260) + assert utils.field_names([ + utils.Field("struct", 150), + utils.Field("list_struct", 220), + utils.Field("struct_list", 260) ]) == ["struct", "list_struct", "struct_list"]