Skip to content

Commit

Permalink
PARQUET-34: implement size() fiter for repeated columns
Browse files Browse the repository at this point in the history
  • Loading branch information
clairemcginty committed Dec 5, 2024
1 parent aec7bc6 commit 51eabd2
Show file tree
Hide file tree
Showing 21 changed files with 696 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import org.apache.parquet.filter2.predicate.Operators.NotEq;
import org.apache.parquet.filter2.predicate.Operators.NotIn;
import org.apache.parquet.filter2.predicate.Operators.Or;
import org.apache.parquet.filter2.predicate.Operators.Size;
import org.apache.parquet.filter2.predicate.Operators.UserDefined;

/**
Expand Down Expand Up @@ -97,6 +98,11 @@ public <T extends Comparable<T>> FilterPredicate visit(Contains<T> contains) {
return contains;
}

@Override
public FilterPredicate visit(Size size) {
return size;
}

@Override
public FilterPredicate visit(And and) {
final FilterPredicate left;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import org.apache.parquet.filter2.predicate.Operators.NotIn;
import org.apache.parquet.filter2.predicate.Operators.Or;
import org.apache.parquet.filter2.predicate.Operators.SingleColumnFilterPredicate;
import org.apache.parquet.filter2.predicate.Operators.Size;
import org.apache.parquet.filter2.predicate.Operators.SupportsEqNotEq;
import org.apache.parquet.filter2.predicate.Operators.SupportsLtGt;
import org.apache.parquet.filter2.predicate.Operators.UserDefined;
Expand Down Expand Up @@ -263,6 +264,10 @@ public static <T extends Comparable<T>, P extends SingleColumnFilterPredicate<T>
return Contains.of(pred);
}

public static Size size(Column<?> column, Size.Operator operator, int value) {
return new Size(column, operator, value);
}

/**
* Keeps records that pass the provided {@link UserDefinedPredicate}
* <p>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import org.apache.parquet.filter2.predicate.Operators.NotEq;
import org.apache.parquet.filter2.predicate.Operators.NotIn;
import org.apache.parquet.filter2.predicate.Operators.Or;
import org.apache.parquet.filter2.predicate.Operators.Size;
import org.apache.parquet.filter2.predicate.Operators.UserDefined;

/**
Expand Down Expand Up @@ -89,6 +90,10 @@ default <T extends Comparable<T>> R visit(Contains<T> contains) {
throw new UnsupportedOperationException("visit Contains is not supported.");
}

default R visit(Size size) {
throw new UnsupportedOperationException("visit Size is not supported.");
}

R visit(And and);

R visit(Or or);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import org.apache.parquet.filter2.predicate.Operators.NotEq;
import org.apache.parquet.filter2.predicate.Operators.NotIn;
import org.apache.parquet.filter2.predicate.Operators.Or;
import org.apache.parquet.filter2.predicate.Operators.Size;
import org.apache.parquet.filter2.predicate.Operators.UserDefined;

/**
Expand Down Expand Up @@ -104,6 +105,11 @@ public <T extends Comparable<T>> FilterPredicate visit(Contains<T> contains) {
return contains;
}

@Override
public FilterPredicate visit(Size size) {
return size;
}

@Override
public FilterPredicate visit(And and) {
return and(and.getLeft().accept(this), and.getRight().accept(this));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import org.apache.parquet.filter2.predicate.Operators.NotEq;
import org.apache.parquet.filter2.predicate.Operators.NotIn;
import org.apache.parquet.filter2.predicate.Operators.Or;
import org.apache.parquet.filter2.predicate.Operators.Size;
import org.apache.parquet.filter2.predicate.Operators.UserDefined;

/**
Expand Down Expand Up @@ -98,6 +99,19 @@ public <T extends Comparable<T>> FilterPredicate visit(Contains<T> contains) {
return contains.not();
}

@Override
public FilterPredicate visit(Size size) {
final long value = size.getValue();
final Operators.Column<?> column = size.getColumn();

return size.filter(
(eq) -> new Or(new Size(column, Size.Operator.LT, value), new Size(column, Size.Operator.GT, value)),
(lt) -> new Size(column, Size.Operator.GTE, value),
(ltEq) -> new Size(column, Size.Operator.GT, value),
(gt) -> new Size(column, Size.Operator.LTE, value),
(gtEq) -> new Size(column, Size.Operator.LT, value));
}

@Override
public FilterPredicate visit(And and) {
return new Or(and.getLeft().accept(this), and.getRight().accept(this));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -505,6 +505,82 @@ public <R> R filter(
}
}

public static final class Size implements FilterPredicate, Serializable {
public enum Operator {
EQ,
LT,
LTE,
GT,
GTE
}

private final Column<?> column;
private final Operator operator;
private final long value;

Size(Column<?> column, Operator operator, long value) {
this.column = column;
this.operator = operator;
if (value < 0) {
throw new IllegalArgumentException("Argument to size() operator cannot be negative: " + value);
}
this.value = value;
}

@Override
public <R> R accept(Visitor<R> visitor) {
return visitor.visit(this);
}

public long getValue() {
return value;
}

public Column<?> getColumn() {
return column;
}

public <R> R filter(
Function<Long, R> onEq,
Function<Long, R> onLt,
Function<Long, R> onLtEq,
Function<Long, R> onGt,
Function<Long, R> onGtEq) {
if (operator == Operator.EQ) {
return onEq.apply(value);
} else if (operator == Operator.LT) {
return onLt.apply(value);
} else if (operator == Operator.LTE) {
return onLtEq.apply(value);
} else if (operator == Operator.GT) {
return onGt.apply(value);
} else if (operator == Operator.GTE) {
return onGtEq.apply(value);
} else {
throw new UnsupportedOperationException("Operator " + operator + " cannot be used with size() filter");
}
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;

return column.equals(((Size) o).column) && operator == ((Size) o).operator && value == ((Size) o).value;
}

@Override
public int hashCode() {
return Objects.hash(column, operator, value);
}

@Override
public String toString() {
String name = Size.class.getSimpleName().toLowerCase(Locale.ENGLISH);
return name + "(" + operator.toString().toLowerCase() + " " + value + ")";
}
}

public static final class NotIn<T extends Comparable<T>> extends SetColumnFilterPredicate<T> {

NotIn(Column<T> column, Set<T> values) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,11 @@
import org.apache.parquet.filter2.predicate.Operators.NotIn;
import org.apache.parquet.filter2.predicate.Operators.Or;
import org.apache.parquet.filter2.predicate.Operators.SetColumnFilterPredicate;
import org.apache.parquet.filter2.predicate.Operators.Size;
import org.apache.parquet.filter2.predicate.Operators.UserDefined;
import org.apache.parquet.hadoop.metadata.ColumnPath;
import org.apache.parquet.schema.MessageType;
import org.apache.parquet.schema.Type;

/**
* Inspects the column types found in the provided {@link FilterPredicate} and compares them
Expand Down Expand Up @@ -135,6 +137,12 @@ public <T extends Comparable<T>> Void visit(Contains<T> pred) {
return null;
}

@Override
public Void visit(Size size) {
validateColumn(size.getColumn(), true, true);
return null;
}

@Override
public Void visit(And and) {
and.getLeft().accept(this);
Expand Down Expand Up @@ -175,14 +183,15 @@ private <T extends Comparable<T>> void validateColumnFilterPredicate(SetColumnFi
}

private <T extends Comparable<T>> void validateColumnFilterPredicate(Contains<T> pred) {
validateColumn(pred.getColumn(), true);
validateColumn(pred.getColumn(), true, false);
}

private <T extends Comparable<T>> void validateColumn(Column<T> column) {
validateColumn(column, false);
validateColumn(column, false, false);
}

private <T extends Comparable<T>> void validateColumn(Column<T> column, boolean shouldBeRepeated) {
private <T extends Comparable<T>> void validateColumn(
Column<T> column, boolean isRepeatedColumn, boolean mustBeRequired) {
ColumnPath path = column.getColumnPath();

Class<?> alreadySeen = columnTypesEncountered.get(path);
Expand All @@ -204,15 +213,21 @@ private <T extends Comparable<T>> void validateColumn(Column<T> column, boolean
return;
}

if (shouldBeRepeated && descriptor.getMaxRepetitionLevel() == 0) {
if (isRepeatedColumn && descriptor.getMaxRepetitionLevel() == 0) {
throw new IllegalArgumentException(
"FilterPredicate for column " + path.toDotString() + " requires a repeated "
+ "schema, but found max repetition level " + descriptor.getMaxRepetitionLevel());
} else if (!shouldBeRepeated && descriptor.getMaxRepetitionLevel() > 0) {
} else if (!isRepeatedColumn && descriptor.getMaxRepetitionLevel() > 0) {
throw new IllegalArgumentException("FilterPredicates do not currently support repeated columns. "
+ "Column " + path.toDotString() + " is repeated.");
}

if (mustBeRequired && descriptor.getPrimitiveType().isRepetition(Type.Repetition.OPTIONAL)) {
throw new IllegalArgumentException("FilterPredicate for column " + path.toDotString()
+ " requires schema to have repetition REQUIRED, but found "
+ descriptor.getPrimitiveType().getRepetition() + ".");
}

ValidTypeMap.assertTypeValid(column, descriptor.getType());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import java.util.Arrays;
import java.util.Objects;
import java.util.function.Function;
import org.apache.parquet.io.api.Binary;

/**
Expand Down Expand Up @@ -223,6 +224,74 @@ public void reset() {
}
}

class CountingValueInspector extends ValueInspector {
private long observedValueCount;
private final ValueInspector delegate;
private final Function<Long, Boolean> shouldUpdateDelegate;

public CountingValueInspector(ValueInspector delegate, Function<Long, Boolean> shouldUpdateDelegate) {
this.observedValueCount = 0;
this.delegate = delegate;
this.shouldUpdateDelegate = shouldUpdateDelegate;
}

@Override
public void updateNull() {
delegate.update(observedValueCount);
if (!delegate.isKnown()) {
delegate.updateNull();
}
setResult(delegate.getResult());
}

@Override
public void update(int value) {
incrementCount();
}

@Override
public void update(long value) {
incrementCount();
}

@Override
public void update(double value) {
incrementCount();
}

@Override
public void update(float value) {
incrementCount();
}

@Override
public void update(boolean value) {
incrementCount();
}

@Override
public void update(Binary value) {
incrementCount();
}

@Override
public void reset() {
super.reset();
delegate.reset();
observedValueCount = 0;
}

private void incrementCount() {
observedValueCount++;
if (!delegate.isKnown() && shouldUpdateDelegate.apply(observedValueCount)) {
delegate.update(observedValueCount);
if (delegate.isKnown()) {
setResult(delegate.getResult());
}
}
}
}

// base class for and / or
abstract static class BinaryLogical implements IncrementallyUpdatedFilterPredicate {
private final IncrementallyUpdatedFilterPredicate left;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,19 @@
import java.util.List;
import java.util.Map;
import org.apache.parquet.column.ColumnDescriptor;
import org.apache.parquet.filter2.predicate.FilterApi;
import org.apache.parquet.filter2.predicate.FilterPredicate;
import org.apache.parquet.filter2.predicate.FilterPredicate.Visitor;
import org.apache.parquet.filter2.predicate.Operators;
import org.apache.parquet.filter2.predicate.Operators.And;
import org.apache.parquet.filter2.predicate.Operators.Not;
import org.apache.parquet.filter2.predicate.Operators.Or;
import org.apache.parquet.filter2.recordlevel.IncrementallyUpdatedFilterPredicate.ValueInspector;
import org.apache.parquet.hadoop.metadata.ColumnPath;
import org.apache.parquet.io.PrimitiveColumnIO;
import org.apache.parquet.schema.PrimitiveComparator;
import org.apache.parquet.schema.PrimitiveType;
import org.apache.parquet.schema.Type;

/**
* The implementation of this abstract class is auto-generated by
Expand All @@ -56,6 +60,8 @@
*/
public abstract class IncrementallyUpdatedFilterPredicateBuilderBase
implements Visitor<IncrementallyUpdatedFilterPredicate> {
static final Operators.LongColumn SIZE_PSUEDOCOLUMN = FilterApi.longColumn("SIZE");

private boolean built = false;
private final Map<ColumnPath, List<ValueInspector>> valueInspectorsByColumn = new HashMap<>();
private final Map<ColumnPath, PrimitiveComparator<?>> comparatorsByColumn = new HashMap<>();
Expand All @@ -70,6 +76,13 @@ public IncrementallyUpdatedFilterPredicateBuilderBase(List<PrimitiveColumnIO> le
PrimitiveComparator<?> comparator = descriptor.getPrimitiveType().comparator();
comparatorsByColumn.put(path, comparator);
}
comparatorsByColumn.put(
SIZE_PSUEDOCOLUMN.getColumnPath(),
new PrimitiveType(
Type.Repetition.REQUIRED,
PrimitiveType.PrimitiveTypeName.INT64,
SIZE_PSUEDOCOLUMN.getColumnPath().toDotString())
.comparator());
}

public final IncrementallyUpdatedFilterPredicate build(FilterPredicate pred) {
Expand All @@ -80,6 +93,11 @@ public final IncrementallyUpdatedFilterPredicate build(FilterPredicate pred) {
}

protected final void addValueInspector(ColumnPath columnPath, ValueInspector valueInspector) {
if (columnPath.equals(SIZE_PSUEDOCOLUMN.getColumnPath())) {
// do not add psuedocolumn to list of value inspectors
return;
}

List<ValueInspector> valueInspectors = valueInspectorsByColumn.get(columnPath);
if (valueInspectors == null) {
valueInspectors = new ArrayList<>();
Expand Down
Loading

0 comments on commit 51eabd2

Please sign in to comment.