Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ESQL: Add DriverContext to the construction of Evaluators #99518

Merged
merged 8 commits into from
Sep 13, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -15,6 +15,7 @@
import org.elasticsearch.compute.data.LongBlock;
import org.elasticsearch.compute.data.LongVector;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.compute.operator.EvalOperator;
import org.elasticsearch.compute.operator.Operator;
import org.elasticsearch.core.TimeValue;
@@ -79,14 +80,14 @@ private static EvalOperator.ExpressionEvaluator evaluator(String operation) {
return switch (operation) {
case "abs" -> {
FieldAttribute longField = longField();
yield EvalMapper.toEvaluator(new Abs(Source.EMPTY, longField), layout(longField)).get();
yield EvalMapper.toEvaluator(new Abs(Source.EMPTY, longField), layout(longField)).get(new DriverContext());
}
case "add" -> {
FieldAttribute longField = longField();
yield EvalMapper.toEvaluator(
new Add(Source.EMPTY, longField, new Literal(Source.EMPTY, 1L, DataTypes.LONG)),
layout(longField)
).get();
).get(new DriverContext());
}
case "date_trunc" -> {
FieldAttribute timestamp = new FieldAttribute(
@@ -97,28 +98,28 @@ private static EvalOperator.ExpressionEvaluator evaluator(String operation) {
yield EvalMapper.toEvaluator(
new DateTrunc(Source.EMPTY, new Literal(Source.EMPTY, Duration.ofHours(24), EsqlDataTypes.TIME_DURATION), timestamp),
layout(timestamp)
).get();
).get(new DriverContext());
}
case "equal_to_const" -> {
FieldAttribute longField = longField();
yield EvalMapper.toEvaluator(
new Equals(Source.EMPTY, longField, new Literal(Source.EMPTY, 100_000L, DataTypes.LONG)),
layout(longField)
).get();
).get(new DriverContext());
}
case "long_equal_to_long" -> {
FieldAttribute lhs = longField();
FieldAttribute rhs = longField();
yield EvalMapper.toEvaluator(new Equals(Source.EMPTY, lhs, rhs), layout(lhs, rhs)).get();
yield EvalMapper.toEvaluator(new Equals(Source.EMPTY, lhs, rhs), layout(lhs, rhs)).get(new DriverContext());
}
case "long_equal_to_int" -> {
FieldAttribute lhs = longField();
FieldAttribute rhs = intField();
yield EvalMapper.toEvaluator(new Equals(Source.EMPTY, lhs, rhs), layout(lhs, rhs)).get();
yield EvalMapper.toEvaluator(new Equals(Source.EMPTY, lhs, rhs), layout(lhs, rhs)).get(new DriverContext());
}
case "mv_min", "mv_min_ascending" -> {
FieldAttribute longField = longField();
yield EvalMapper.toEvaluator(new MvMin(Source.EMPTY, longField), layout(longField)).get();
yield EvalMapper.toEvaluator(new MvMin(Source.EMPTY, longField), layout(longField)).get(new DriverContext());
}
default -> throw new UnsupportedOperationException();
};
Original file line number Diff line number Diff line change
@@ -34,6 +34,7 @@
import static org.elasticsearch.compute.gen.Methods.getMethod;
import static org.elasticsearch.compute.gen.Types.BLOCK;
import static org.elasticsearch.compute.gen.Types.BYTES_REF;
import static org.elasticsearch.compute.gen.Types.DRIVER_CONTEXT;
import static org.elasticsearch.compute.gen.Types.EXPRESSION_EVALUATOR;
import static org.elasticsearch.compute.gen.Types.PAGE;
import static org.elasticsearch.compute.gen.Types.SOURCE;
@@ -77,6 +78,7 @@ private TypeSpec type() {
builder.addField(WARNINGS, "warnings", Modifier.PRIVATE, Modifier.FINAL);
}
processFunction.args.stream().forEach(a -> a.declareField(builder));
builder.addField(DRIVER_CONTEXT, "driverContext", Modifier.PRIVATE, Modifier.FINAL);

builder.addMethod(ctor());
builder.addMethod(eval());
@@ -95,6 +97,8 @@ private MethodSpec ctor() {
builder.addStatement("this.warnings = new Warnings(source)");
}
processFunction.args.stream().forEach(a -> a.implementCtor(builder));
builder.addParameter(DRIVER_CONTEXT, "driverContext");
builder.addStatement("this.driverContext = driverContext");
return builder.build();
}

Original file line number Diff line number Diff line change
@@ -35,6 +35,7 @@
import static org.elasticsearch.compute.gen.Types.BLOCK;
import static org.elasticsearch.compute.gen.Types.BYTES_REF;
import static org.elasticsearch.compute.gen.Types.BYTES_REF_ARRAY;
import static org.elasticsearch.compute.gen.Types.DRIVER_CONTEXT;
import static org.elasticsearch.compute.gen.Types.EXPRESSION_EVALUATOR;
import static org.elasticsearch.compute.gen.Types.SOURCE;
import static org.elasticsearch.compute.gen.Types.VECTOR;
@@ -129,6 +130,7 @@ private TypeSpec type() {

builder.addField(WARNINGS, "warnings", Modifier.PRIVATE, Modifier.FINAL);
}
builder.addField(DRIVER_CONTEXT, "driverContext", Modifier.PRIVATE, Modifier.FINAL);

builder.addMethod(ctor());
builder.addMethod(name());
@@ -159,6 +161,8 @@ private MethodSpec ctor() {
if (warnExceptions.isEmpty() == false) {
builder.addStatement("this.warnings = new Warnings(source)");
}
builder.addParameter(DRIVER_CONTEXT, "driverContext");
builder.addStatement("this.driverContext = driverContext");
return builder.build();
}

Original file line number Diff line number Diff line change
@@ -88,6 +88,8 @@ public class Types {
static final ClassName INTERMEDIATE_STATE_DESC = ClassName.get(AGGREGATION_PACKAGE, "IntermediateStateDesc");
static final TypeName LIST_AGG_FUNC_DESC = ParameterizedTypeName.get(ClassName.get(List.class), INTERMEDIATE_STATE_DESC);

static final ClassName DRIVER_CONTEXT = ClassName.get(OPERATOR_PACKAGE, "DriverContext");

static final ClassName EXPRESSION_EVALUATOR = ClassName.get(OPERATOR_PACKAGE, "EvalOperator", "ExpressionEvaluator");
static final ClassName ABSTRACT_MULTIVALUE_FUNCTION_EVALUATOR = ClassName.get(
"org.elasticsearch.xpack.esql.expression.function.scalar.multivalue",
Original file line number Diff line number Diff line change
@@ -12,20 +12,21 @@
import org.elasticsearch.compute.data.BytesRefBlock;
import org.elasticsearch.compute.data.ElementType;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator;

import java.util.function.Supplier;

public class ColumnExtractOperator extends AbstractPageMappingOperator {

public record Factory(
ElementType[] types,
Supplier<EvalOperator.ExpressionEvaluator> inputEvalSupplier,
ExpressionEvaluator.Factory inputEvalSupplier,
Supplier<ColumnExtractOperator.Evaluator> evaluatorSupplier
) implements OperatorFactory {

@Override
public Operator get(DriverContext driverContext) {
return new ColumnExtractOperator(types, inputEvalSupplier.get(), evaluatorSupplier.get());
return new ColumnExtractOperator(types, inputEvalSupplier.get(driverContext), evaluatorSupplier.get());
}

@Override
Original file line number Diff line number Diff line change
@@ -7,10 +7,12 @@

package org.elasticsearch.compute.operator;

import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.core.Releasable;

import java.util.Collections;
import java.util.IdentityHashMap;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.atomic.AtomicReference;

@@ -33,11 +35,30 @@
*/
public class DriverContext {

/**
 * A shared driver context constructed with {@link BigArrays#NON_RECYCLING_INSTANCE}, so the
 * {@code BigArrays} it hands out is non-recycling. Declared {@code final} so that the shared
 * instance cannot be reassigned by other code.
 */
public static final DriverContext DEFAULT = new DriverContext(BigArrays.NON_RECYCLING_INSTANCE);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not really a "default" - it's more of a "don't use this in production" kind of thing.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or, like, I guess a "temporary" kind of thing.


// Working set. Only the thread executing the driver will update this set.
Set<Releasable> workingSet = Collections.newSetFromMap(new IdentityHashMap<>());

private final AtomicReference<Snapshot> snapshot = new AtomicReference<>();

private final BigArrays bigArrays;

/**
 * For testing: creates a driver context by delegating to
 * {@link #DriverContext(BigArrays)} with {@link BigArrays#NON_RECYCLING_INSTANCE}.
 */
public DriverContext() {
this(BigArrays.NON_RECYCLING_INSTANCE);
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know it's a pain, but I'd prefer not to have this in our production code. It'd be super ok on some test class, but it's too tempting.


/**
 * Creates a driver context that uses the given {@code BigArrays}.
 *
 * @param bigArrays the instance later returned by {@link #bigArrays()}; must not be {@code null}
 * @throws NullPointerException if {@code bigArrays} is {@code null}
 */
public DriverContext(BigArrays bigArrays) {
    // requireNonNull returns its argument, so the check and the assignment are one step.
    this.bigArrays = Objects.requireNonNull(bigArrays);
}

/**
 * Returns the {@code BigArrays} this context was constructed with.
 *
 * @return the {@code BigArrays} instance held by this context
 */
public BigArrays bigArrays() {
    return this.bigArrays;
}

/** A snapshot of the driver context. */
public record Snapshot(Set<Releasable> releasables) {}

Original file line number Diff line number Diff line change
@@ -10,24 +10,22 @@
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.Page;

import java.util.function.Supplier;

/**
* Evaluates a tree of functions for every position in the block, resulting in a
* new block which is appended to the page.
*/
public class EvalOperator extends AbstractPageMappingOperator {

public record EvalOperatorFactory(Supplier<ExpressionEvaluator> evaluator) implements OperatorFactory {
public record EvalOperatorFactory(ExpressionEvaluator.Factory evaluator) implements OperatorFactory {

@Override
public Operator get(DriverContext driverContext) {
return new EvalOperator(evaluator.get());
return new EvalOperator(evaluator.get(driverContext));
}

@Override
public String describe() {
return "EvalOperator[evaluator=" + evaluator.get() + "]";
return "EvalOperator[evaluator=" + evaluator.get(DriverContext.DEFAULT) + "]";
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we're generating factory could we generate the toString in it so we don't have to pass the driver context? I just want to make sure we don't accidentally do stuff with the context.

}
}

@@ -48,6 +46,12 @@ public String toString() {
}

public interface ExpressionEvaluator {

/** A Factory for creating ExpressionEvaluators. */
interface Factory {
ExpressionEvaluator get(DriverContext driverContext);
ChrisHegarty marked this conversation as resolved.
Show resolved Hide resolved
}

Block eval(Page page);
}

Original file line number Diff line number Diff line change
@@ -10,24 +10,24 @@
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BooleanBlock;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator;

import java.util.Arrays;
import java.util.function.Supplier;

public class FilterOperator extends AbstractPageMappingOperator {

private final EvalOperator.ExpressionEvaluator evaluator;

public record FilterOperatorFactory(Supplier<EvalOperator.ExpressionEvaluator> evaluatorSupplier) implements OperatorFactory {
public record FilterOperatorFactory(ExpressionEvaluator.Factory evaluatorSupplier) implements OperatorFactory {

@Override
public Operator get(DriverContext driverContext) {
return new FilterOperator(evaluatorSupplier.get());
return new FilterOperator(evaluatorSupplier.get(driverContext));
}

@Override
public String describe() {
return "FilterOperator[evaluator=" + evaluatorSupplier.get() + "]";
return "FilterOperator[evaluator=" + evaluatorSupplier.get(DriverContext.DEFAULT) + "]";
}
}

Original file line number Diff line number Diff line change
@@ -15,8 +15,7 @@
import org.elasticsearch.compute.data.IntBlock;
import org.elasticsearch.compute.data.LongBlock;
import org.elasticsearch.compute.data.Page;

import java.util.function.Supplier;
import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator;

/**
* Utilities to remove duplicates from multivalued fields.
@@ -77,42 +76,39 @@ public static Block dedupeToBlockUsingCopyAndSort(Block block) {
* Build an {@link EvalOperator.ExpressionEvaluator} that deduplicates values
* using an adaptive algorithm based on the size of the input list.
*/
public static Supplier<EvalOperator.ExpressionEvaluator> evaluator(
ElementType elementType,
Supplier<EvalOperator.ExpressionEvaluator> nextSupplier
) {
public static ExpressionEvaluator.Factory evaluator(ElementType elementType, ExpressionEvaluator.Factory nextSupplier) {
return switch (elementType) {
case BOOLEAN -> () -> new MvDedupeEvaluator(nextSupplier.get()) {
case BOOLEAN -> dvrCtx -> new MvDedupeEvaluator(nextSupplier.get(dvrCtx)) {
@Override
public Block eval(Page page) {
return new MultivalueDedupeBoolean((BooleanBlock) field.eval(page)).dedupeToBlock();
}
};
case BYTES_REF -> () -> new MvDedupeEvaluator(nextSupplier.get()) {
case BYTES_REF -> dvrCtx -> new MvDedupeEvaluator(nextSupplier.get(dvrCtx)) {
@Override
public Block eval(Page page) {
return new MultivalueDedupeBytesRef((BytesRefBlock) field.eval(page)).dedupeToBlockAdaptive();
}
};
case INT -> () -> new MvDedupeEvaluator(nextSupplier.get()) {
case INT -> dvrCtx -> new MvDedupeEvaluator(nextSupplier.get(dvrCtx)) {
@Override
public Block eval(Page page) {
return new MultivalueDedupeInt((IntBlock) field.eval(page)).dedupeToBlockAdaptive();
}
};
case LONG -> () -> new MvDedupeEvaluator(nextSupplier.get()) {
case LONG -> dvrCtx -> new MvDedupeEvaluator(nextSupplier.get(dvrCtx)) {
@Override
public Block eval(Page page) {
return new MultivalueDedupeLong((LongBlock) field.eval(page)).dedupeToBlockAdaptive();
}
};
case DOUBLE -> () -> new MvDedupeEvaluator(nextSupplier.get()) {
case DOUBLE -> dvrCtx -> new MvDedupeEvaluator(nextSupplier.get(dvrCtx)) {
@Override
public Block eval(Page page) {
return new MultivalueDedupeDouble((DoubleBlock) field.eval(page)).dedupeToBlockAdaptive();
}
};
case NULL -> () -> new MvDedupeEvaluator(nextSupplier.get()) {
case NULL -> dvrCtx -> new MvDedupeEvaluator(nextSupplier.get(dvrCtx)) {
@Override
public Block eval(Page page) {
return field.eval(page); // The page is all nulls and when you dedupe that it's still all nulls
Original file line number Diff line number Diff line change
@@ -13,6 +13,7 @@
import org.elasticsearch.compute.data.BytesRefBlock;
import org.elasticsearch.compute.data.ElementType;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator;

import java.util.Arrays;
import java.util.Map;
@@ -24,13 +25,13 @@ public class StringExtractOperator extends AbstractPageMappingOperator {

public record StringExtractOperatorFactory(
String[] fieldNames,
Supplier<EvalOperator.ExpressionEvaluator> expressionEvaluator,
ExpressionEvaluator.Factory expressionEvaluator,
Supplier<Function<String, Map<String, String>>> parserSupplier
) implements OperatorFactory {

@Override
public Operator get(DriverContext driverContext) {
return new StringExtractOperator(fieldNames, expressionEvaluator.get(), parserSupplier.get());
return new StringExtractOperator(fieldNames, expressionEvaluator.get(driverContext), parserSupplier.get());
}

@Override
Original file line number Diff line number Diff line change
@@ -47,7 +47,7 @@ public String toString() {
@Override
protected Operator.OperatorFactory simple(BigArrays bigArrays) {
Supplier<ColumnExtractOperator.Evaluator> expEval = () -> new FirstWord(0);
return new ColumnExtractOperator.Factory(new ElementType[] { ElementType.BYTES_REF }, () -> page -> page.getBlock(0), expEval);
return new ColumnExtractOperator.Factory(new ElementType[] { ElementType.BYTES_REF }, dvrCtx -> page -> page.getBlock(0), expEval);
}

@Override
Original file line number Diff line number Diff line change
@@ -40,7 +40,7 @@ public Block eval(Page page) {

@Override
protected Operator.OperatorFactory simple(BigArrays bigArrays) {
return new EvalOperator.EvalOperatorFactory(() -> new Addition(0, 1));
return new EvalOperator.EvalOperatorFactory(dvrCtx -> new Addition(0, 1));
}

@Override
Original file line number Diff line number Diff line change
@@ -42,7 +42,7 @@ public Block eval(Page page) {

@Override
protected Operator.OperatorFactory simple(BigArrays bigArrays) {
return new FilterOperator.FilterOperatorFactory(() -> new SameLastDigit(0, 1));
return new FilterOperator.FilterOperatorFactory(dvrCtx -> new SameLastDigit(0, 1));
}

@Override
Original file line number Diff line number Diff line change
@@ -42,7 +42,7 @@ public Map<String, String> apply(String s) {
@Override
protected Operator.OperatorFactory simple(BigArrays bigArrays) {
Supplier<Function<String, Map<String, String>>> expEval = () -> new FirstWord("test");
return new StringExtractOperator.StringExtractOperatorFactory(new String[] { "test" }, () -> page -> page.getBlock(0), expEval);
return new StringExtractOperator.StringExtractOperatorFactory(new String[] { "test" }, dvrCtx -> page -> page.getBlock(0), expEval);
}

@Override
Original file line number Diff line number Diff line change
@@ -10,6 +10,7 @@
import org.elasticsearch.compute.data.BooleanBlock;
import org.elasticsearch.compute.data.BooleanVector;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.compute.operator.EvalOperator;

/**
@@ -21,10 +22,13 @@ public final class EqualsBoolsEvaluator implements EvalOperator.ExpressionEvalua

private final EvalOperator.ExpressionEvaluator rhs;

private final DriverContext driverContext;

public EqualsBoolsEvaluator(EvalOperator.ExpressionEvaluator lhs,
EvalOperator.ExpressionEvaluator rhs) {
EvalOperator.ExpressionEvaluator rhs, DriverContext driverContext) {
this.lhs = lhs;
this.rhs = rhs;
this.driverContext = driverContext;
}

@Override
Loading