diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java index dde7bc09ac615..8b45735458326 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java @@ -78,6 +78,7 @@ import org.elasticsearch.xpack.esql.plan.logical.Project; import org.elasticsearch.xpack.esql.plan.logical.Rename; import org.elasticsearch.xpack.esql.plan.logical.UnresolvedRelation; +import org.elasticsearch.xpack.esql.plan.logical.inference.Completion; import org.elasticsearch.xpack.esql.plan.logical.join.Join; import org.elasticsearch.xpack.esql.plan.logical.join.JoinConfig; import org.elasticsearch.xpack.esql.plan.logical.join.JoinType; @@ -465,9 +466,33 @@ protected LogicalPlan doRule(LogicalPlan plan) { return resolveLookupJoin(j); } + if (plan instanceof Completion p) { + return resolveCompletion(p, childrenOutput); + } + return plan.transformExpressionsOnly(UnresolvedAttribute.class, ua -> maybeResolveAttribute(ua, childrenOutput)); } + private LogicalPlan resolveCompletion(Completion p, List childrenOutput) { + Holder changed = new Holder<>(false); + + Expression resolvedPrompt = p.prompt().transformUp(UnresolvedAttribute.class, ua -> maybeResolveAttribute(ua, childrenOutput)); + if (resolvedPrompt != p.prompt()) { + changed.set(true); + } + + Expression resolvedInferenceId = p.inferenceId().transformUp(UnresolvedAttribute.class, ua -> maybeResolveAttribute(ua, childrenOutput)); + if (resolvedInferenceId != p.inferenceId()) { + changed.set(true); + } + + if (changed.get() == false) { + return p; + } + + return new Completion(p.source(), p.child(), p.target(), resolvedPrompt, resolvedInferenceId); + } + private Aggregate resolveAggregate(Aggregate aggregate, List childrenOutput) { // if the grouping is resolved but the aggs are not, use the former to resolve the latter // e.g. STATS a ... GROUP BY a = x + 1 diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/CompletionInferenceOperator.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/CompletionInferenceOperator.java new file mode 100644 index 0000000000000..c2709c11f7cdd --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/CompletionInferenceOperator.java @@ -0,0 +1,58 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.inference; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.AsyncOperator; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.compute.operator.EvalOperator; +import org.elasticsearch.compute.operator.Operator; + +public class CompletionInferenceOperator extends AsyncOperator { + + // Move to a setting. + private static final int MAX_INFERENCE_WORKER = 1; + + public record Factory(EvalOperator.ExpressionEvaluator.Factory promptEvaluatorFactory) implements OperatorFactory { + public String describe() { + return "CompletionInferenceOperator[]"; + } + + + @Override + public Operator get(DriverContext driverContext) { + return new CompletionInferenceOperator(driverContext, promptEvaluatorFactory.get(driverContext)); + } + } + + private final EvalOperator.ExpressionEvaluator promptEvaluator; + + public CompletionInferenceOperator(DriverContext driverContext, EvalOperator.ExpressionEvaluator promptEvaluator) { + super(driverContext, MAX_INFERENCE_WORKER); + this.promptEvaluator = promptEvaluator; + } + + + @Override + protected void performAsync(Page inputPage, ActionListener listener) { + Block promptBlock = promptEvaluator.eval(inputPage); + listener.onResponse(inputPage.appendBlock(promptBlock)); + } + + @Override + protected void doClose() { + + } + + @Override + public String toString() { + return "CompletionInferenceOperator[]"; + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/LogicalPlanBuilder.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/LogicalPlanBuilder.java index 87292057b1504..6be316b2eff81 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/LogicalPlanBuilder.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/LogicalPlanBuilder.java @@ -25,6 +25,7 @@ import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.expression.MetadataAttribute; import org.elasticsearch.xpack.esql.core.expression.NamedExpression; +import org.elasticsearch.xpack.esql.core.expression.ReferenceAttribute; import org.elasticsearch.xpack.esql.core.expression.UnresolvedAttribute; import org.elasticsearch.xpack.esql.core.expression.UnresolvedStar; import org.elasticsearch.xpack.esql.core.tree.Source; @@ -189,7 +190,7 @@ private void validateGrokPattern(Source source, Grok.Parser grokParser, String p public PlanFactory visitCompletionCommand(EsqlBaseParser.CompletionCommandContext ctx) { return p -> { Source source = source(ctx); - UnresolvedAttribute target = visitQualifiedName(ctx.target); + ReferenceAttribute target = new ReferenceAttribute(source(ctx.target), ctx.target.getText(), DataType.TEXT); return new Completion(source, p, target, expression(ctx.prompt), expression(ctx.inferenceId)); }; } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/PlanWritables.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/PlanWritables.java index b3c273cbfa1bb..24daca201a394 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/PlanWritables.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/PlanWritables.java @@ -22,6 +22,7 @@ import org.elasticsearch.xpack.esql.plan.logical.OrderBy; import org.elasticsearch.xpack.esql.plan.logical.Project; import org.elasticsearch.xpack.esql.plan.logical.TopN; +import org.elasticsearch.xpack.esql.plan.logical.inference.Completion; import org.elasticsearch.xpack.esql.plan.logical.join.InlineJoin; import org.elasticsearch.xpack.esql.plan.logical.join.Join; import org.elasticsearch.xpack.esql.plan.logical.local.EsqlProject; @@ -48,6 +49,7 @@ import org.elasticsearch.xpack.esql.plan.physical.ShowExec; import org.elasticsearch.xpack.esql.plan.physical.SubqueryExec; import org.elasticsearch.xpack.esql.plan.physical.TopNExec; +import org.elasticsearch.xpack.esql.plan.physical.inference.CompletionExec; import java.util.ArrayList; import java.util.List; @@ -65,6 +67,7 @@ public static List logical() { return List.of( Aggregate.ENTRY, Dissect.ENTRY, + Completion.ENTRY, Enrich.ENTRY, EsRelation.ENTRY, EsqlProject.ENTRY, @@ -88,6 +91,7 @@ public static List phsyical() { return List.of( AggregateExec.ENTRY, DissectExec.ENTRY, + CompletionExec.ENTRY, EnrichExec.ENTRY, EsQueryExec.ENTRY, EsSourceExec.ENTRY, diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/Completion.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/Completion.java index 701309d6d2556..0516d6369003d 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/Completion.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/Completion.java @@ -11,12 +11,11 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.xpack.esql.core.expression.Attribute; +import org.elasticsearch.xpack.esql.core.expression.AttributeSet; import org.elasticsearch.xpack.esql.core.expression.Expression; -import org.elasticsearch.xpack.esql.core.expression.NamedExpression; import org.elasticsearch.xpack.esql.core.expression.ReferenceAttribute; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; -import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.expression.NamedExpressions; import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; @@ -30,13 +29,11 @@ public class Completion extends UnaryPlan { public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(LogicalPlan.class, "Completion", Completion::new); - private final NamedExpression target; - + private final ReferenceAttribute target; private final Expression prompt; - private final Expression inferenceId; - public Completion(Source source, LogicalPlan child, NamedExpression target, Expression prompt, Expression inferenceId) { + public Completion(Source source, LogicalPlan child, ReferenceAttribute target, Expression prompt, Expression inferenceId) { super(source, child); this.target = target; this.prompt = prompt; @@ -47,7 +44,7 @@ public Completion(StreamInput in) throws IOException { this( Source.readFrom((PlanStreamInput) in), in.readNamedWriteable(LogicalPlan.class), - in.readNamedWriteable(NamedExpression.class), + in.readNamedWriteable(ReferenceAttribute.class), in.readNamedWriteable(Expression.class), in.readNamedWriteable(Expression.class) ); @@ -62,7 +59,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeNamedWriteable(inferenceId()); } - public NamedExpression target() { + public ReferenceAttribute target() { return target; } @@ -79,11 +76,6 @@ public String getWriteableName() { return ENTRY.name; } - @Override - public int hashCode() { - return Objects.hash(super.hashCode(), target, prompt, inferenceId); - } - @Override public String commandName() { return "COMPLETION"; @@ -91,7 +83,12 @@ public String commandName() { @Override public boolean expressionsResolved() { - return target.resolved() && prompt.resolved() && inferenceId.resolved(); + return prompt.resolved() && inferenceId.resolved(); + } + + @Override + protected AttributeSet computeReferences() { + return prompt.references(); } @Override @@ -101,7 +98,7 @@ public UnaryPlan replaceChild(LogicalPlan newChild) { @Override public List output() { - return NamedExpressions.mergeOutputAttributes(List.of( new ReferenceAttribute(Source.EMPTY, target.name(), DataType.TEXT)), child().output()); + return NamedExpressions.mergeOutputAttributes(List.of(target), child().output()); } @Override @@ -109,6 +106,11 @@ protected NodeInfo info() { return NodeInfo.create(this, Completion::new, child(), target, prompt, inferenceId); } + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), target, prompt, inferenceId); + } + @Override public boolean equals(Object obj) { if (false == super.equals(obj)) { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/inference/CompletionExec.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/inference/CompletionExec.java new file mode 100644 index 0000000000000..47acc82d76ae1 --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/inference/CompletionExec.java @@ -0,0 +1,110 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.plan.physical.inference; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.xpack.esql.core.expression.Attribute; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.NamedExpression; +import org.elasticsearch.xpack.esql.core.expression.ReferenceAttribute; +import org.elasticsearch.xpack.esql.core.tree.NodeInfo; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.expression.NamedExpressions; +import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; +import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan; +import org.elasticsearch.xpack.esql.plan.physical.UnaryExec; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; + +public class CompletionExec extends UnaryExec { + + public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry( + PhysicalPlan.class, + "CompileExec", + CompletionExec::new + ); + private final ReferenceAttribute target; + private final Expression prompt; + private final Expression inferenceId; + + public CompletionExec(Source source, PhysicalPlan child, ReferenceAttribute target, Expression prompt, Expression inferenceId) { + super(source, child); + this.target = target; + this.prompt = prompt; + this.inferenceId = inferenceId; + } + + public CompletionExec(StreamInput in) throws IOException { + this( + Source.readFrom((PlanStreamInput) in), + in.readNamedWriteable(PhysicalPlan.class), + in.readNamedWriteable(ReferenceAttribute.class), + in.readNamedWriteable(Expression.class), + in.readNamedWriteable(Expression.class) + ); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + Source.EMPTY.writeTo(out); + out.writeNamedWriteable(child()); + out.writeNamedWriteable(target()); + out.writeNamedWriteable(prompt()); + out.writeNamedWriteable(inferenceId()); + } + + public NamedExpression target() { + return target; + } + + public Expression prompt() { + return prompt; + } + + public Expression inferenceId() { + return inferenceId; + } + + @Override + public String getWriteableName() { + return ENTRY.name; + } + + @Override + public List output() { + return NamedExpressions.mergeOutputAttributes(List.of(target), child().output()); + } + + @Override + public UnaryExec replaceChild(PhysicalPlan newChild) { + return new CompletionExec(source(), newChild, target, prompt, inferenceId); + } + + @Override + protected NodeInfo info() { + return NodeInfo.create(this, CompletionExec::new, child(), target, prompt, inferenceId); + } + + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), target, prompt, inferenceId); + } + + @Override + public boolean equals(Object obj) { + if (false == super.equals(obj)) { + return false; + } + CompletionExec other = ((CompletionExec) obj); + return Objects.equals(target, other.target) && Objects.equals(prompt, other.prompt) && Objects.equals(inferenceId, other.inferenceId); + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java index 1ffc652e54337..de7f2fd4abc85 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java @@ -68,6 +68,7 @@ import org.elasticsearch.xpack.esql.evaluator.EvalMapper; import org.elasticsearch.xpack.esql.evaluator.command.GrokEvaluatorExtracter; import org.elasticsearch.xpack.esql.expression.Order; +import org.elasticsearch.xpack.esql.inference.CompletionInferenceOperator; import org.elasticsearch.xpack.esql.plan.physical.AggregateExec; import org.elasticsearch.xpack.esql.plan.physical.DissectExec; import org.elasticsearch.xpack.esql.plan.physical.EnrichExec; @@ -90,6 +91,7 @@ import org.elasticsearch.xpack.esql.plan.physical.ProjectExec; import org.elasticsearch.xpack.esql.plan.physical.ShowExec; import org.elasticsearch.xpack.esql.plan.physical.TopNExec; +import org.elasticsearch.xpack.esql.plan.physical.inference.CompletionExec; import org.elasticsearch.xpack.esql.plugin.QueryPragmas; import org.elasticsearch.xpack.esql.session.Configuration; @@ -233,6 +235,10 @@ else if (node instanceof EnrichExec enrich) { } else if (node instanceof LookupJoinExec join) { return planLookupJoin(join, context); } + // inference + else if (node instanceof CompletionExec completion) { + return planCompletion(completion, context); + } // output else if (node instanceof OutputExec outputExec) { return planOutput(outputExec, context); @@ -419,6 +425,16 @@ private PhysicalOperation planEval(EvalExec eval, LocalExecutionPlannerContext c return source; } + private PhysicalOperation planCompletion(CompletionExec completion, LocalExecutionPlannerContext context) { + PhysicalOperation source = plan(completion.child(), context); + ExpressionEvaluator.Factory promptEvaluatorSupplier = EvalMapper.toEvaluator(completion.prompt(), source.layout); + + Layout.Builder layout = source.layout.builder(); + layout.append(completion.target()); + + return source.with(new CompletionInferenceOperator.Factory(promptEvaluatorSupplier), layout.build()); + } + private PhysicalOperation planDissect(DissectExec dissect, LocalExecutionPlannerContext context) { PhysicalOperation source = plan(dissect.child(), context); Layout.Builder layoutBuilder = source.layout.builder(); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/mapper/MapperUtils.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/mapper/MapperUtils.java index e881eabb38c43..8f1e0006af24e 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/mapper/MapperUtils.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/mapper/MapperUtils.java @@ -25,6 +25,7 @@ import org.elasticsearch.xpack.esql.plan.logical.MvExpand; import org.elasticsearch.xpack.esql.plan.logical.Project; import org.elasticsearch.xpack.esql.plan.logical.UnaryPlan; +import org.elasticsearch.xpack.esql.plan.logical.inference.Completion; import org.elasticsearch.xpack.esql.plan.logical.local.LocalRelation; import org.elasticsearch.xpack.esql.plan.logical.show.ShowInfo; import org.elasticsearch.xpack.esql.plan.physical.AggregateExec; @@ -39,6 +40,7 @@ import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan; import org.elasticsearch.xpack.esql.plan.physical.ProjectExec; import org.elasticsearch.xpack.esql.plan.physical.ShowExec; +import org.elasticsearch.xpack.esql.plan.physical.inference.CompletionExec; import org.elasticsearch.xpack.esql.planner.AbstractPhysicalOperationProviders; import java.util.List; @@ -83,6 +85,11 @@ static PhysicalPlan mapUnary(UnaryPlan p, PhysicalPlan child) { return new GrokExec(grok.source(), child, grok.input(), grok.parser(), grok.extractedFields()); } + if (p instanceof Completion completion) { + return new CompletionExec(completion.source(), child, completion.target(), completion.prompt(), completion.inferenceId()); + } + + if (p instanceof Enrich enrich) { return new EnrichExec( enrich.source(),