Skip to content

Commit

Permalink
ESQL: Fix filtering all elements in aggs (elastic#113804)
Browse files Browse the repository at this point in the history
This adds a test to *every* agg for when it's entirely filtered away and
another when filtering is enabled but unused. I'll follow up with
another test later for partial filtering.

That test caught a bug where some aggs would think they'd been `seen`
when they hadn't. This fixes that too.
  • Loading branch information
nik9000 authored Oct 2, 2024
1 parent eb9b897 commit a18b331
Show file tree
Hide file tree
Showing 78 changed files with 804 additions and 299 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.elasticsearch.compute.aggregation.CountAggregatorFunction;
import org.elasticsearch.compute.aggregation.CountDistinctDoubleAggregatorFunctionSupplier;
import org.elasticsearch.compute.aggregation.CountDistinctLongAggregatorFunctionSupplier;
import org.elasticsearch.compute.aggregation.FilteredAggregatorFunctionSupplier;
import org.elasticsearch.compute.aggregation.MaxDoubleAggregatorFunctionSupplier;
import org.elasticsearch.compute.aggregation.MaxLongAggregatorFunctionSupplier;
import org.elasticsearch.compute.aggregation.MinDoubleAggregatorFunctionSupplier;
Expand All @@ -27,6 +28,7 @@
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.BooleanBlock;
import org.elasticsearch.compute.data.BooleanVector;
import org.elasticsearch.compute.data.BytesRefBlock;
import org.elasticsearch.compute.data.DoubleBlock;
import org.elasticsearch.compute.data.ElementType;
Expand All @@ -35,6 +37,7 @@
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.AggregationOperator;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.compute.operator.EvalOperator;
import org.elasticsearch.compute.operator.HashAggregationOperator;
import org.elasticsearch.compute.operator.Operator;
import org.openjdk.jmh.annotations.Benchmark;
Expand Down Expand Up @@ -94,13 +97,20 @@ public class AggregatorBenchmark {

private static final String NONE = "none";

// Values for the `filter` benchmark parameter; mask(String) builds the mask each one names.
// Mask is a constant vector holding true.
private static final String CONSTANT_TRUE = "constant_true";
// Mask is appended position by position and every position is true.
private static final String ALL_TRUE = "all_true";
// Mask is appended position by position and only even positions are true.
private static final String HALF_TRUE = "half_true";
// Mask is a constant vector holding false.
private static final String CONSTANT_FALSE = "constant_false";

static {
// Smoke test all the expected values and force loading subclasses more like prod
try {
for (String grouping : AggregatorBenchmark.class.getField("grouping").getAnnotationsByType(Param.class)[0].value()) {
for (String op : AggregatorBenchmark.class.getField("op").getAnnotationsByType(Param.class)[0].value()) {
for (String blockType : AggregatorBenchmark.class.getField("blockType").getAnnotationsByType(Param.class)[0].value()) {
run(grouping, op, blockType, 50);
for (String filter : AggregatorBenchmark.class.getField("filter").getAnnotationsByType(Param.class)[0].value()) {
run(grouping, op, blockType, filter, 10);
}
}
}
}
Expand All @@ -118,10 +128,14 @@ public class AggregatorBenchmark {
@Param({ VECTOR_LONGS, HALF_NULL_LONGS, VECTOR_DOUBLES, HALF_NULL_DOUBLES })
public String blockType;

private static Operator operator(DriverContext driverContext, String grouping, String op, String dataType) {
/** Which mask, if any, wraps the aggregation; NONE leaves the aggregation unwrapped (see filtered). */
@Param({ NONE, CONSTANT_TRUE, ALL_TRUE, HALF_TRUE, CONSTANT_FALSE })
public String filter;

private static Operator operator(DriverContext driverContext, String grouping, String op, String dataType, String filter) {

if (grouping.equals("none")) {
return new AggregationOperator(
List.of(supplier(op, dataType, 0).aggregatorFactory(AggregatorMode.SINGLE).apply(driverContext)),
List.of(supplier(op, dataType, filter, 0).aggregatorFactory(AggregatorMode.SINGLE).apply(driverContext)),
driverContext
);
}
Expand All @@ -144,14 +158,14 @@ private static Operator operator(DriverContext driverContext, String grouping, S
default -> throw new IllegalArgumentException("unsupported grouping [" + grouping + "]");
};
return new HashAggregationOperator(
List.of(supplier(op, dataType, groups.size()).groupingAggregatorFactory(AggregatorMode.SINGLE)),
List.of(supplier(op, dataType, filter, groups.size()).groupingAggregatorFactory(AggregatorMode.SINGLE)),
() -> BlockHash.build(groups, driverContext.blockFactory(), 16 * 1024, false),
driverContext
);
}

private static AggregatorFunctionSupplier supplier(String op, String dataType, int dataChannel) {
return switch (op) {
private static AggregatorFunctionSupplier supplier(String op, String dataType, String filter, int dataChannel) {
return filtered(switch (op) {
case COUNT -> CountAggregatorFunction.supplier(List.of(dataChannel));
case COUNT_DISTINCT -> switch (dataType) {
case LONGS -> new CountDistinctLongAggregatorFunctionSupplier(List.of(dataChannel), 3000);
Expand All @@ -174,10 +188,22 @@ private static AggregatorFunctionSupplier supplier(String op, String dataType, i
default -> throw new IllegalArgumentException("unsupported data type [" + dataType + "]");
};
default -> throw new IllegalArgumentException("unsupported op [" + op + "]");
};
}, filter);
}

private static void checkExpected(String grouping, String op, String blockType, String dataType, Page page, int opCount) {
private static void checkExpected(
String grouping,
String op,
String blockType,
String filter,
String dataType,
Page page,
int opCount
) {
if (filter.equals(CONSTANT_FALSE) || filter.equals(HALF_TRUE)) {
// We don't verify these because it's hard to get the right answer.
return;
}
String prefix = String.format("[%s][%s][%s] ", grouping, op, blockType);
if (grouping.equals("none")) {
checkUngrouped(prefix, op, dataType, page, opCount);
Expand Down Expand Up @@ -559,27 +585,73 @@ private static BytesRef bytesGroup(int group) {
});
}

/**
 * Wraps {@code agg} so that it is driven through the mask named by {@code filter}.
 *
 * @param agg the aggregation supplier to wrap
 * @param filter one of the {@code filter} parameter values
 * @return {@code agg} itself when {@code filter} is {@link #NONE}, otherwise a
 *         {@link FilteredAggregatorFunctionSupplier} around it
 * @throws IllegalArgumentException from {@code mask(String)}, when an evaluator is created for an
 *         unsupported filter name
 */
private static AggregatorFunctionSupplier filtered(AggregatorFunctionSupplier agg, String filter) {
    if (filter.equals(NONE)) {
        return agg;
    }
    return new FilteredAggregatorFunctionSupplier(agg, context -> {
        // Build the mask per evaluator: every evaluator closes its mask, so sharing one block between
        // evaluators would release it once per evaluator, and never release it if none were created.
        BooleanBlock mask = mask(filter).asBlock();
        return new EvalOperator.ExpressionEvaluator() {
            @Override
            public Block eval(Page page) {
                // Hand out an extra reference on every call; presumably the caller closes the
                // returned block when it is done with the page - TODO confirm.
                mask.incRef();
                return mask;
            }

            @Override
            public void close() {
                mask.close();
            }
        };
    });
}

/**
 * Builds the boolean mask named by {@code filter}.
 *
 * @param filter one of CONSTANT_TRUE, ALL_TRUE, HALF_TRUE or CONSTANT_FALSE
 * @return a vector of {@code BLOCK_LENGTH * 10} positions
 * @throws IllegalArgumentException if {@code filter} is none of the supported names
 */
private static BooleanVector mask(String filter) {
    // Blocks usually hold BLOCK_LENGTH positions, but some are longer, so the mask is oversized.
    final int positions = BLOCK_LENGTH * 10;
    return switch (filter) {
        case CONSTANT_TRUE, CONSTANT_FALSE -> blockFactory.newConstantBooleanVector(filter.equals(CONSTANT_TRUE), positions);
        case ALL_TRUE, HALF_TRUE -> {
            // ALL_TRUE keeps every position; HALF_TRUE keeps only the even ones.
            final boolean keepEverything = filter.equals(ALL_TRUE);
            try (BooleanVector.Builder values = blockFactory.newBooleanVectorFixedBuilder(positions)) {
                int position = 0;
                while (position < positions) {
                    values.appendBoolean(keepEverything || position % 2 == 0);
                    position++;
                }
                yield values.build();
            }
        }
        default -> throw new IllegalArgumentException("unsupported filter [" + filter + "]");
    };
}

/**
 * JMH benchmark entry point: delegates to the static {@code run} using this instance's
 * parameter fields and {@code OP_COUNT} as the number of pages to feed.
 */
@Benchmark
@OperationsPerInvocation(OP_COUNT * BLOCK_LENGTH)
public void run() {
// NOTE(review): this listing is a rendered diff without +/- markers. The first call below looks like
// the pre-change line and the second its replacement that also passes `filter` - confirm against the repository.
run(grouping, op, blockType, OP_COUNT);
run(grouping, op, blockType, filter, OP_COUNT);
}

// NOTE(review): this listing is a rendered diff without +/- markers. Each pair of near-identical adjacent
// lines below (the signature, the operator(...) call, the checkExpected(...) call) appears to be the
// pre-change line followed by the replacement that threads `filter` through - confirm against the repository.
/**
 * Builds the operator for the given parameters, feeds it the same page {@code opCount} times,
 * finishes it and checks the output.
 *
 * @throws IllegalArgumentException if {@code blockType} is not one of the four handled values
 */
private static void run(String grouping, String op, String blockType, int opCount) {
private static void run(String grouping, String op, String blockType, String filter, int opCount) {
// System.err.printf("[%s][%s][%s][%s][%s]\n", grouping, op, blockType, filter, opCount);
// The block type decides which element type the aggregation is built for.
String dataType = switch (blockType) {
case VECTOR_LONGS, HALF_NULL_LONGS -> LONGS;
case VECTOR_DOUBLES, HALF_NULL_DOUBLES -> DOUBLES;
default -> throw new IllegalArgumentException();
};

DriverContext driverContext = driverContext();
// try-with-resources closes the operator once the output has been checked.
try (Operator operator = operator(driverContext, grouping, op, dataType)) {
try (Operator operator = operator(driverContext, grouping, op, dataType, filter)) {
Page page = page(driverContext.blockFactory(), grouping, blockType);
// One page is built and a shallow copy of it is added on every iteration.
for (int i = 0; i < opCount; i++) {
operator.addInput(page.shallowCopy());
}
operator.finish();
checkExpected(grouping, op, blockType, dataType, operator.getOutput(), opCount);
checkExpected(grouping, op, blockType, filter, dataType, operator.getOutput(), opCount);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -353,14 +353,14 @@ private MethodSpec addRawInput() {
builder.addStatement("return");
builder.endControlFlow();
}
builder.beginControlFlow("if (mask.isConstant())");
builder.beginControlFlow("if (mask.allFalse())");
{
builder.addComment("Entire page masked away");
builder.addStatement("return");
}
builder.endControlFlow();
builder.beginControlFlow("if (mask.allTrue())");
{
builder.beginControlFlow("if (mask.getBoolean(0) == false)");
{
builder.addComment("Entire page masked away");
builder.addStatement("return");
}
builder.endControlFlow();
builder.addComment("No masking");
builder.addStatement("$T block = page.getBlock(channels.get(0))", valueBlockType(init, combine));
builder.addStatement("$T vector = block.asVector()", valueVectorType(init, combine));
Expand All @@ -372,6 +372,7 @@ private MethodSpec addRawInput() {
builder.addStatement("return");
}
builder.endControlFlow();

builder.addComment("Some positions masked away, others kept");
builder.addStatement("$T block = page.getBlock(channels.get(0))", valueBlockType(init, combine));
builder.addStatement("$T vector = block.asVector()", valueVectorType(init, combine));
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit a18b331

Please sign in to comment.