Skip to content

Commit

Permalink
Allow to plan the completion operation as an asynchronous operation.
Browse files Browse the repository at this point in the history
  • Loading branch information
afoucret committed Dec 2, 2024
1 parent 55aa259 commit e1f8e3f
Show file tree
Hide file tree
Showing 8 changed files with 239 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@
import org.elasticsearch.xpack.esql.plan.logical.Project;
import org.elasticsearch.xpack.esql.plan.logical.Rename;
import org.elasticsearch.xpack.esql.plan.logical.UnresolvedRelation;
import org.elasticsearch.xpack.esql.plan.logical.inference.Completion;
import org.elasticsearch.xpack.esql.plan.logical.join.Join;
import org.elasticsearch.xpack.esql.plan.logical.join.JoinConfig;
import org.elasticsearch.xpack.esql.plan.logical.join.JoinType;
Expand Down Expand Up @@ -465,9 +466,33 @@ protected LogicalPlan doRule(LogicalPlan plan) {
return resolveLookupJoin(j);
}

if (plan instanceof Completion p) {
return resolveCompletion(p, childrenOutput);
}

return plan.transformExpressionsOnly(UnresolvedAttribute.class, ua -> maybeResolveAttribute(ua, childrenOutput));
}

private LogicalPlan resolveCompletion(Completion p, List<Attribute> childrenOutput) {
Holder<Boolean> changed = new Holder<>(false);

Expression resolvedPrompt = p.prompt().transformUp(UnresolvedAttribute.class, ua -> maybeResolveAttribute(ua, childrenOutput));
if (resolvedPrompt != p.prompt()) {
changed.set(true);
}

Expression resolvedInferenceId = p.inferenceId().transformUp(UnresolvedAttribute.class, ua -> maybeResolveAttribute(ua, childrenOutput));
if (resolvedInferenceId != p.inferenceId()) {
changed.set(true);
}

if (changed.get() == false) {
return p;
}

return new Completion(p.source(), p.child(), p.target(), resolvedPrompt, resolvedInferenceId);
}

private Aggregate resolveAggregate(Aggregate aggregate, List<Attribute> childrenOutput) {
// if the grouping is resolved but the aggs are not, use the former to resolve the latter
// e.g. STATS a ... GROUP BY a = x + 1
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.esql.inference;

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.AsyncOperator;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.compute.operator.EvalOperator;
import org.elasticsearch.compute.operator.Operator;

public class CompletionInferenceOperator extends AsyncOperator {

// Move to a setting.
private static final int MAX_INFERENCE_WORKER = 1;

public record Factory(EvalOperator.ExpressionEvaluator.Factory promptEvaluatorFactory) implements OperatorFactory {
public String describe() {
return "CompletionInferenceOperator[]";
}


@Override
public Operator get(DriverContext driverContext) {
return new CompletionInferenceOperator(driverContext, promptEvaluatorFactory.get(driverContext));
}
}

private final EvalOperator.ExpressionEvaluator promptEvaluator;

public CompletionInferenceOperator(DriverContext driverContext, EvalOperator.ExpressionEvaluator promptEvaluator) {
super(driverContext, MAX_INFERENCE_WORKER);
this.promptEvaluator = promptEvaluator;
}


@Override
protected void performAsync(Page inputPage, ActionListener<Page> listener) {
Block promptBlock = promptEvaluator.eval(inputPage);
listener.onResponse(inputPage.appendBlock(promptBlock));
}

@Override
protected void doClose() {

}

@Override
public String toString() {
return "CompletionInferenceOperator[]";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.expression.MetadataAttribute;
import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
import org.elasticsearch.xpack.esql.core.expression.ReferenceAttribute;
import org.elasticsearch.xpack.esql.core.expression.UnresolvedAttribute;
import org.elasticsearch.xpack.esql.core.expression.UnresolvedStar;
import org.elasticsearch.xpack.esql.core.tree.Source;
Expand Down Expand Up @@ -189,7 +190,7 @@ private void validateGrokPattern(Source source, Grok.Parser grokParser, String p
public PlanFactory visitCompletionCommand(EsqlBaseParser.CompletionCommandContext ctx) {
return p -> {
Source source = source(ctx);
UnresolvedAttribute target = visitQualifiedName(ctx.target);
ReferenceAttribute target = new ReferenceAttribute(source(ctx.target), ctx.target.getText(), DataType.TEXT);
return new Completion(source, p, target, expression(ctx.prompt), expression(ctx.inferenceId));
};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.elasticsearch.xpack.esql.plan.logical.OrderBy;
import org.elasticsearch.xpack.esql.plan.logical.Project;
import org.elasticsearch.xpack.esql.plan.logical.TopN;
import org.elasticsearch.xpack.esql.plan.logical.inference.Completion;
import org.elasticsearch.xpack.esql.plan.logical.join.InlineJoin;
import org.elasticsearch.xpack.esql.plan.logical.join.Join;
import org.elasticsearch.xpack.esql.plan.logical.local.EsqlProject;
Expand All @@ -48,6 +49,7 @@
import org.elasticsearch.xpack.esql.plan.physical.ShowExec;
import org.elasticsearch.xpack.esql.plan.physical.SubqueryExec;
import org.elasticsearch.xpack.esql.plan.physical.TopNExec;
import org.elasticsearch.xpack.esql.plan.physical.inference.CompletionExec;

import java.util.ArrayList;
import java.util.List;
Expand All @@ -65,6 +67,7 @@ public static List<NamedWriteableRegistry.Entry> logical() {
return List.of(
Aggregate.ENTRY,
Dissect.ENTRY,
Completion.ENTRY,
Enrich.ENTRY,
EsRelation.ENTRY,
EsqlProject.ENTRY,
Expand All @@ -88,6 +91,7 @@ public static List<NamedWriteableRegistry.Entry> phsyical() {
return List.of(
AggregateExec.ENTRY,
DissectExec.ENTRY,
CompletionExec.ENTRY,
EnrichExec.ENTRY,
EsQueryExec.ENTRY,
EsSourceExec.ENTRY,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,11 @@
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.xpack.esql.core.expression.Attribute;
import org.elasticsearch.xpack.esql.core.expression.AttributeSet;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
import org.elasticsearch.xpack.esql.core.expression.ReferenceAttribute;
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.expression.NamedExpressions;
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
Expand All @@ -30,13 +29,11 @@ public class Completion extends UnaryPlan {

public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(LogicalPlan.class, "Completion", Completion::new);

private final NamedExpression target;

private final ReferenceAttribute target;
private final Expression prompt;

private final Expression inferenceId;

public Completion(Source source, LogicalPlan child, NamedExpression target, Expression prompt, Expression inferenceId) {
public Completion(Source source, LogicalPlan child, ReferenceAttribute target, Expression prompt, Expression inferenceId) {
super(source, child);
this.target = target;
this.prompt = prompt;
Expand All @@ -47,7 +44,7 @@ public Completion(StreamInput in) throws IOException {
this(
Source.readFrom((PlanStreamInput) in),
in.readNamedWriteable(LogicalPlan.class),
in.readNamedWriteable(NamedExpression.class),
in.readNamedWriteable(ReferenceAttribute.class),
in.readNamedWriteable(Expression.class),
in.readNamedWriteable(Expression.class)
);
Expand All @@ -62,7 +59,7 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeNamedWriteable(inferenceId());
}

public NamedExpression target() {
public ReferenceAttribute target() {
return target;
}

Expand All @@ -79,19 +76,19 @@ public String getWriteableName() {
return ENTRY.name;
}

@Override
public int hashCode() {
return Objects.hash(super.hashCode(), target, prompt, inferenceId);
}

@Override
public String commandName() {
return "COMPLETION";
}

@Override
public boolean expressionsResolved() {
return target.resolved() && prompt.resolved() && inferenceId.resolved();
return prompt.resolved() && inferenceId.resolved();
}

@Override
protected AttributeSet computeReferences() {
return prompt.references();
}

@Override
Expand All @@ -101,14 +98,19 @@ public UnaryPlan replaceChild(LogicalPlan newChild) {

@Override
public List<Attribute> output() {
return NamedExpressions.mergeOutputAttributes(List.of( new ReferenceAttribute(Source.EMPTY, target.name(), DataType.TEXT)), child().output());
return NamedExpressions.mergeOutputAttributes(List.of(target), child().output());
}

@Override
protected NodeInfo<? extends LogicalPlan> info() {
return NodeInfo.create(this, Completion::new, child(), target, prompt, inferenceId);
}

@Override
public int hashCode() {
return Objects.hash(super.hashCode(), target, prompt, inferenceId);
}

@Override
public boolean equals(Object obj) {
if (false == super.equals(obj)) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.esql.plan.physical.inference;

import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.xpack.esql.core.expression.Attribute;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
import org.elasticsearch.xpack.esql.core.expression.ReferenceAttribute;
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.expression.NamedExpressions;
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan;
import org.elasticsearch.xpack.esql.plan.physical.UnaryExec;

import java.io.IOException;
import java.util.List;
import java.util.Objects;

public class CompletionExec extends UnaryExec {

public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(
PhysicalPlan.class,
"CompileExec",
CompletionExec::new
);
private final ReferenceAttribute target;
private final Expression prompt;
private final Expression inferenceId;

public CompletionExec(Source source, PhysicalPlan child, ReferenceAttribute target, Expression prompt, Expression inferenceId) {
super(source, child);
this.target = target;
this.prompt = prompt;
this.inferenceId = inferenceId;
}

public CompletionExec(StreamInput in) throws IOException {
this(
Source.readFrom((PlanStreamInput) in),
in.readNamedWriteable(PhysicalPlan.class),
in.readNamedWriteable(ReferenceAttribute.class),
in.readNamedWriteable(Expression.class),
in.readNamedWriteable(Expression.class)
);
}

@Override
public void writeTo(StreamOutput out) throws IOException {
Source.EMPTY.writeTo(out);
out.writeNamedWriteable(child());
out.writeNamedWriteable(target());
out.writeNamedWriteable(prompt());
out.writeNamedWriteable(inferenceId());
}

public NamedExpression target() {
return target;
}

public Expression prompt() {
return prompt;
}

public Expression inferenceId() {
return inferenceId;
}

@Override
public String getWriteableName() {
return ENTRY.name;
}

@Override
public List<Attribute> output() {
return NamedExpressions.mergeOutputAttributes(List.of(target), child().output());
}

@Override
public UnaryExec replaceChild(PhysicalPlan newChild) {
return new CompletionExec(source(), newChild, target, prompt, inferenceId);
}

@Override
protected NodeInfo<? extends PhysicalPlan> info() {
return NodeInfo.create(this, CompletionExec::new, child(), target, prompt, inferenceId);
}

@Override
public int hashCode() {
return Objects.hash(super.hashCode(), target, prompt, inferenceId);
}

@Override
public boolean equals(Object obj) {
if (false == super.equals(obj)) {
return false;
}
CompletionExec other = ((CompletionExec) obj);
return Objects.equals(target, other.target) && Objects.equals(prompt, other.prompt) && Objects.equals(inferenceId, other.inferenceId);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
import org.elasticsearch.xpack.esql.evaluator.EvalMapper;
import org.elasticsearch.xpack.esql.evaluator.command.GrokEvaluatorExtracter;
import org.elasticsearch.xpack.esql.expression.Order;
import org.elasticsearch.xpack.esql.inference.CompletionInferenceOperator;
import org.elasticsearch.xpack.esql.plan.physical.AggregateExec;
import org.elasticsearch.xpack.esql.plan.physical.DissectExec;
import org.elasticsearch.xpack.esql.plan.physical.EnrichExec;
Expand All @@ -90,6 +91,7 @@
import org.elasticsearch.xpack.esql.plan.physical.ProjectExec;
import org.elasticsearch.xpack.esql.plan.physical.ShowExec;
import org.elasticsearch.xpack.esql.plan.physical.TopNExec;
import org.elasticsearch.xpack.esql.plan.physical.inference.CompletionExec;
import org.elasticsearch.xpack.esql.plugin.QueryPragmas;
import org.elasticsearch.xpack.esql.session.Configuration;

Expand Down Expand Up @@ -233,6 +235,10 @@ else if (node instanceof EnrichExec enrich) {
} else if (node instanceof LookupJoinExec join) {
return planLookupJoin(join, context);
}
// inference
else if (node instanceof CompletionExec completion) {
return planCompletion(completion, context);
}
// output
else if (node instanceof OutputExec outputExec) {
return planOutput(outputExec, context);
Expand Down Expand Up @@ -419,6 +425,16 @@ private PhysicalOperation planEval(EvalExec eval, LocalExecutionPlannerContext c
return source;
}

private PhysicalOperation planCompletion(CompletionExec completion, LocalExecutionPlannerContext context) {
PhysicalOperation source = plan(completion.child(), context);
ExpressionEvaluator.Factory promptEvaluatorSupplier = EvalMapper.toEvaluator(completion.prompt(), source.layout);

Layout.Builder layout = source.layout.builder();
layout.append(completion.target());

return source.with(new CompletionInferenceOperator.Factory(promptEvaluatorSupplier), layout.build());
}

private PhysicalOperation planDissect(DissectExec dissect, LocalExecutionPlannerContext context) {
PhysicalOperation source = plan(dissect.child(), context);
Layout.Builder layoutBuilder = source.layout.builder();
Expand Down
Loading

0 comments on commit e1f8e3f

Please sign in to comment.