Skip to content

Commit

Permalink
Add StrictMetricsEvaluator (apache#518)
Browse files Browse the repository at this point in the history
* Add StrictMetricsEvaluator

This will enable use to delete whole datafiles by evaluating
the metrics, and not needing to open the Parquet files.

* Split out field check
  • Loading branch information
Fokko authored Mar 15, 2024
1 parent 781096e commit b447461
Show file tree
Hide file tree
Showing 2 changed files with 847 additions and 34 deletions.
352 changes: 319 additions & 33 deletions pyiceberg/expressions/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
DoubleType,
FloatType,
IcebergType,
NestedField,
PrimitiveType,
StructType,
TimestampType,
Expand Down Expand Up @@ -534,7 +535,9 @@ def visit_or(self, left_result: bool, right_result: bool) -> bool:


ROWS_MIGHT_MATCH = True
ROWS_MUST_MATCH = True
ROWS_CANNOT_MATCH = False
ROWS_MIGHT_NOT_MATCH = False
IN_PREDICATE_LIMIT = 200


Expand Down Expand Up @@ -1089,16 +1092,52 @@ def expression_to_plain_format(
return [visit(expression, visitor) for expression in expressions]


class _InclusiveMetricsEvaluator(BoundBooleanExpressionVisitor[bool]):
struct: StructType
expr: BooleanExpression

class _MetricsEvaluator(BoundBooleanExpressionVisitor[bool], ABC):
value_counts: Dict[int, int]
null_counts: Dict[int, int]
nan_counts: Dict[int, int]
lower_bounds: Dict[int, bytes]
upper_bounds: Dict[int, bytes]

def visit_true(self) -> bool:
# all rows match
return ROWS_MIGHT_MATCH

def visit_false(self) -> bool:
# all rows fail
return ROWS_CANNOT_MATCH

def visit_not(self, child_result: bool) -> bool:
raise ValueError(f"NOT should be rewritten: {child_result}")

def visit_and(self, left_result: bool, right_result: bool) -> bool:
return left_result and right_result

def visit_or(self, left_result: bool, right_result: bool) -> bool:
return left_result or right_result

def _contains_nulls_only(self, field_id: int) -> bool:
if (value_count := self.value_counts.get(field_id)) and (null_count := self.null_counts.get(field_id)):
return value_count == null_count
return False

def _contains_nans_only(self, field_id: int) -> bool:
if (nan_count := self.nan_counts.get(field_id)) and (value_count := self.value_counts.get(field_id)):
return nan_count == value_count
return False

def _is_nan(self, val: Any) -> bool:
try:
return math.isnan(val)
except TypeError:
# In the case of None or other non-numeric types
return False


class _InclusiveMetricsEvaluator(_MetricsEvaluator):
struct: StructType
expr: BooleanExpression

def __init__(
self, schema: Schema, expr: BooleanExpression, case_sensitive: bool = True, include_empty_files: bool = False
) -> None:
Expand Down Expand Up @@ -1128,40 +1167,11 @@ def eval(self, file: DataFile) -> bool:
def _may_contain_null(self, field_id: int) -> bool:
return self.null_counts is None or (field_id in self.null_counts and self.null_counts.get(field_id) is not None)

def _contains_nulls_only(self, field_id: int) -> bool:
if (value_count := self.value_counts.get(field_id)) and (null_count := self.null_counts.get(field_id)):
return value_count == null_count
return False

def _contains_nans_only(self, field_id: int) -> bool:
if (nan_count := self.nan_counts.get(field_id)) and (value_count := self.value_counts.get(field_id)):
return nan_count == value_count
return False

def _is_nan(self, val: Any) -> bool:
try:
return math.isnan(val)
except TypeError:
# In the case of None or other non-numeric types
return False

def visit_true(self) -> bool:
# all rows match
return ROWS_MIGHT_MATCH

def visit_false(self) -> bool:
# all rows fail
return ROWS_CANNOT_MATCH

def visit_not(self, child_result: bool) -> bool:
raise ValueError(f"NOT should be rewritten: {child_result}")

def visit_and(self, left_result: bool, right_result: bool) -> bool:
return left_result and right_result

def visit_or(self, left_result: bool, right_result: bool) -> bool:
return left_result or right_result

def visit_is_null(self, term: BoundTerm[L]) -> bool:
field_id = term.ref().field.field_id

Expand Down Expand Up @@ -1421,3 +1431,279 @@ def visit_not_starts_with(self, term: BoundTerm[L], literal: Literal[L]) -> bool
return ROWS_CANNOT_MATCH

return ROWS_MIGHT_MATCH


class _StrictMetricsEvaluator(_MetricsEvaluator):
struct: StructType
expr: BooleanExpression

def __init__(
self, schema: Schema, expr: BooleanExpression, case_sensitive: bool = True, include_empty_files: bool = False
) -> None:
self.struct = schema.as_struct()
self.include_empty_files = include_empty_files
self.expr = bind(schema, rewrite_not(expr), case_sensitive)

def eval(self, file: DataFile) -> bool:
"""Test whether all records within the file match the expression.
Args:
file: A data file
Returns: false if the file may contain any row that doesn't match
the expression, true otherwise.
"""
if file.record_count <= 0:
# Older version don't correctly implement record count from avro file and thus
# set record count -1 when importing avro tables to iceberg tables. This should
# be updated once we implemented and set correct record count.
return ROWS_MUST_MATCH

self.value_counts = file.value_counts or EMPTY_DICT
self.null_counts = file.null_value_counts or EMPTY_DICT
self.nan_counts = file.nan_value_counts or EMPTY_DICT
self.lower_bounds = file.lower_bounds or EMPTY_DICT
self.upper_bounds = file.upper_bounds or EMPTY_DICT

return visit(self.expr, self)

def visit_is_null(self, term: BoundTerm[L]) -> bool:
# no need to check whether the field is required because binding evaluates that case
# if the column has any non-null values, the expression does not match
field_id = term.ref().field.field_id

if self._contains_nulls_only(field_id):
return ROWS_MUST_MATCH
else:
return ROWS_MIGHT_NOT_MATCH

def visit_not_null(self, term: BoundTerm[L]) -> bool:
# no need to check whether the field is required because binding evaluates that case
# if the column has any non-null values, the expression does not match
field_id = term.ref().field.field_id

if (null_count := self.null_counts.get(field_id)) is not None and null_count == 0:
return ROWS_MUST_MATCH
else:
return ROWS_MIGHT_NOT_MATCH

def visit_is_nan(self, term: BoundTerm[L]) -> bool:
field_id = term.ref().field.field_id

if self._contains_nans_only(field_id):
return ROWS_MUST_MATCH
else:
return ROWS_MIGHT_NOT_MATCH

def visit_not_nan(self, term: BoundTerm[L]) -> bool:
field_id = term.ref().field.field_id

if (nan_count := self.nan_counts.get(field_id)) is not None and nan_count == 0:
return ROWS_MUST_MATCH

if self._contains_nulls_only(field_id):
return ROWS_MUST_MATCH

return ROWS_MIGHT_NOT_MATCH

def visit_less_than(self, term: BoundTerm[L], literal: Literal[L]) -> bool:
# Rows must match when: <----------Min----Max---X------->

field_id = term.ref().field.field_id

if self._can_contain_nulls(field_id) or self._can_contain_nans(field_id):
return ROWS_MIGHT_NOT_MATCH

if upper_bytes := self.upper_bounds.get(field_id):
field = self._get_field(field_id)
upper = _from_byte_buffer(field.field_type, upper_bytes)

if upper < literal.value:
return ROWS_MUST_MATCH

return ROWS_MIGHT_NOT_MATCH

def visit_less_than_or_equal(self, term: BoundTerm[L], literal: Literal[L]) -> bool:
# Rows must match when: <----------Min----Max---X------->

field_id = term.ref().field.field_id

if self._can_contain_nulls(field_id) or self._can_contain_nans(field_id):
return ROWS_MIGHT_NOT_MATCH

if upper_bytes := self.upper_bounds.get(field_id):
field = self._get_field(field_id)
upper = _from_byte_buffer(field.field_type, upper_bytes)

if upper <= literal.value:
return ROWS_MUST_MATCH

return ROWS_MIGHT_NOT_MATCH

def visit_greater_than(self, term: BoundTerm[L], literal: Literal[L]) -> bool:
# Rows must match when: <-------X---Min----Max---------->

field_id = term.ref().field.field_id

if self._can_contain_nulls(field_id) or self._can_contain_nans(field_id):
return ROWS_MIGHT_NOT_MATCH

if lower_bytes := self.lower_bounds.get(field_id):
field = self._get_field(field_id)
lower = _from_byte_buffer(field.field_type, lower_bytes)

if self._is_nan(lower):
# NaN indicates unreliable bounds.
# See the _StrictMetricsEvaluator docs for more.
return ROWS_MIGHT_NOT_MATCH

if lower > literal.value:
return ROWS_MUST_MATCH

return ROWS_MIGHT_NOT_MATCH

def visit_greater_than_or_equal(self, term: BoundTerm[L], literal: Literal[L]) -> bool:
# Rows must match when: <-------X---Min----Max---------->
field_id = term.ref().field.field_id

if self._can_contain_nulls(field_id) or self._can_contain_nans(field_id):
return ROWS_MIGHT_NOT_MATCH

if lower_bytes := self.lower_bounds.get(field_id):
field = self._get_field(field_id)
lower = _from_byte_buffer(field.field_type, lower_bytes)

if self._is_nan(lower):
# NaN indicates unreliable bounds.
# See the _StrictMetricsEvaluator docs for more.
return ROWS_MIGHT_NOT_MATCH

if lower >= literal.value:
return ROWS_MUST_MATCH

return ROWS_MIGHT_NOT_MATCH

def visit_equal(self, term: BoundTerm[L], literal: Literal[L]) -> bool:
# Rows must match when Min == X == Max
field_id = term.ref().field.field_id

if self._can_contain_nulls(field_id) or self._can_contain_nans(field_id):
return ROWS_MIGHT_NOT_MATCH

if (lower_bytes := self.lower_bounds.get(field_id)) and (upper_bytes := self.upper_bounds.get(field_id)):
field = self._get_field(field_id)
lower = _from_byte_buffer(field.field_type, lower_bytes)
upper = _from_byte_buffer(field.field_type, upper_bytes)

if lower != literal.value or upper != literal.value:
return ROWS_MIGHT_NOT_MATCH
else:
return ROWS_MUST_MATCH

return ROWS_MIGHT_NOT_MATCH

def visit_not_equal(self, term: BoundTerm[L], literal: Literal[L]) -> bool:
# Rows must match when X < Min or Max < X because it is not in the range
field_id = term.ref().field.field_id

if self._can_contain_nulls(field_id) or self._can_contain_nans(field_id):
return ROWS_MUST_MATCH

field = self._get_field(field_id)

if lower_bytes := self.lower_bounds.get(field_id):
lower = _from_byte_buffer(field.field_type, lower_bytes)

if self._is_nan(lower):
# NaN indicates unreliable bounds.
# See the _StrictMetricsEvaluator docs for more.
return ROWS_MIGHT_NOT_MATCH

if lower > literal.value:
return ROWS_MUST_MATCH

if upper_bytes := self.upper_bounds.get(field_id):
upper = _from_byte_buffer(field.field_type, upper_bytes)

if upper < literal.value:
return ROWS_MUST_MATCH

return ROWS_MIGHT_NOT_MATCH

def visit_in(self, term: BoundTerm[L], literals: Set[L]) -> bool:
field_id = term.ref().field.field_id

if self._can_contain_nulls(field_id) or self._can_contain_nans(field_id):
return ROWS_MIGHT_NOT_MATCH

field = self._get_field(field_id)

if (lower_bytes := self.lower_bounds.get(field_id)) and (upper_bytes := self.upper_bounds.get(field_id)):
# similar to the implementation in eq, first check if the lower bound is in the set
lower = _from_byte_buffer(field.field_type, lower_bytes)
if lower not in literals:
return ROWS_MIGHT_NOT_MATCH

# check if the upper bound is in the set
upper = _from_byte_buffer(field.field_type, upper_bytes)
if upper not in literals:
return ROWS_MIGHT_NOT_MATCH

# finally check if the lower bound and the upper bound are equal
if lower != upper:
return ROWS_MIGHT_NOT_MATCH

# All values must be in the set if the lower bound and the upper bound are
# in the set and are equal.
return ROWS_MUST_MATCH

return ROWS_MIGHT_NOT_MATCH

def visit_not_in(self, term: BoundTerm[L], literals: Set[L]) -> bool:
field_id = term.ref().field.field_id

if self._can_contain_nulls(field_id) or self._can_contain_nans(field_id):
return ROWS_MUST_MATCH

field = self._get_field(field_id)

if lower_bytes := self.lower_bounds.get(field_id):
lower = _from_byte_buffer(field.field_type, lower_bytes)

if self._is_nan(lower):
# NaN indicates unreliable bounds.
# See the StrictMetricsEvaluator docs for more.
return ROWS_MIGHT_NOT_MATCH

literals = {val for val in literals if lower <= val}
if len(literals) == 0:
return ROWS_MUST_MATCH

if upper_bytes := self.upper_bounds.get(field_id):
upper = _from_byte_buffer(field.field_type, upper_bytes)

literals = {val for val in literals if upper >= val}

if len(literals) == 0:
return ROWS_MUST_MATCH

return ROWS_MIGHT_NOT_MATCH

def visit_starts_with(self, term: BoundTerm[L], literal: Literal[L]) -> bool:
return ROWS_MIGHT_NOT_MATCH

def visit_not_starts_with(self, term: BoundTerm[L], literal: Literal[L]) -> bool:
return ROWS_MIGHT_NOT_MATCH

def _get_field(self, field_id: int) -> NestedField:
field = self.struct.field(field_id=field_id)
if field is None:
raise ValueError(f"Cannot find field, might be nested or missing: {field_id}")

return field

def _can_contain_nulls(self, field_id: int) -> bool:
return (null_count := self.null_counts.get(field_id)) is not None and null_count > 0

def _can_contain_nans(self, field_id: int) -> bool:
return (nan_count := self.nan_counts.get(field_id)) is not None and nan_count > 0
Loading

0 comments on commit b447461

Please sign in to comment.