From f883dd98566c1f8ffa34779c9949eaeb27596014 Mon Sep 17 00:00:00 2001
From: Costin Leau
Date: Sat, 30 Sep 2023 21:09:14 +0300
Subject: [PATCH] ESQL: Syntax support and operator for count all (#99602)
Introduce physical plan for representing query stats
Use internal aggs when pushing down count
Add support for count all outside Lucene
---
.../aggregation/CountAggregatorFunction.java | 15 +-
.../CountGroupingAggregatorFunction.java | 64 +-
.../compute/lucene/LuceneCountOperator.java | 163 ++
.../compute/lucene/LuceneOperator.java | 12 +-
.../lucene/LuceneCountOperatorTests.java | 155 ++
.../xpack/esql/qa/rest/EsqlSpecTestCase.java | 37 +-
.../xpack/esql/CsvTestUtils.java | 2 +-
.../src/main/resources/stats.csv-spec | 36 +
.../xpack/esql/action/EsqlActionIT.java | 43 +-
.../esql/src/main/antlr/EsqlBaseParser.g4 | 6 +-
.../function/EsqlFunctionRegistry.java | 4 +
.../optimizer/LocalPhysicalPlanOptimizer.java | 75 +
.../xpack/esql/parser/EsqlBaseLexer.java | 2 +-
.../xpack/esql/parser/EsqlBaseParser.interp | 3 +-
.../xpack/esql/parser/EsqlBaseParser.java | 1509 +++++++++--------
.../parser/EsqlBaseParserBaseListener.java | 12 +
.../parser/EsqlBaseParserBaseVisitor.java | 7 +
.../esql/parser/EsqlBaseParserListener.java | 18 +-
.../esql/parser/EsqlBaseParserVisitor.java | 10 +-
.../xpack/esql/parser/ExpressionBuilder.java | 17 +-
.../esql/plan/physical/EsStatsQueryExec.java | 128 ++
.../AbstractPhysicalOperationProviders.java | 34 +-
.../planner/EsPhysicalOperationProviders.java | 22 +-
.../esql/planner/LocalExecutionPlanner.java | 55 +-
.../optimizer/PhysicalPlanOptimizerTests.java | 93 +-
.../esql/tree/EsqlNodeSubclassTests.java | 6 +
26 files changed, 1724 insertions(+), 804 deletions(-)
create mode 100644 x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneCountOperator.java
create mode 100644 x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/LuceneCountOperatorTests.java
create mode 100644 x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/EsStatsQueryExec.java
diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/CountAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/CountAggregatorFunction.java
index 25ff4a2a3ab6..c9374b78ba5a 100644
--- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/CountAggregatorFunction.java
+++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/CountAggregatorFunction.java
@@ -49,6 +49,7 @@ public static List intermediateStateDesc() {
private final LongState state;
private final List channels;
+ private final boolean countAll;
public static CountAggregatorFunction create(List inputChannels) {
return new CountAggregatorFunction(inputChannels, new LongState());
@@ -57,6 +58,8 @@ public static CountAggregatorFunction create(List inputChannels) {
private CountAggregatorFunction(List channels, LongState state) {
this.channels = channels;
this.state = state;
+ // no channels specified means count-all/count(*)
+ this.countAll = channels.isEmpty();
}
@Override
@@ -64,17 +67,23 @@ public int intermediateBlockCount() {
return intermediateStateDesc().size();
}
+ private int blockIndex() {
+ return countAll ? 0 : channels.get(0);
+ }
+
@Override
public void addRawInput(Page page) {
- Block block = page.getBlock(channels.get(0));
+ Block block = page.getBlock(blockIndex());
LongState state = this.state;
- state.longValue(state.longValue() + block.getTotalValueCount());
+ int count = countAll ? block.getPositionCount() : block.getTotalValueCount();
+ state.longValue(state.longValue() + count);
}
@Override
public void addIntermediateInput(Page page) {
assert channels.size() == intermediateBlockCount();
- assert page.getBlockCount() >= channels.get(0) + intermediateStateDesc().size();
+ var blockIndex = blockIndex();
+ assert page.getBlockCount() >= blockIndex + intermediateStateDesc().size();
LongVector count = page.getBlock(channels.get(0)).asVector();
BooleanVector seen = page.getBlock(channels.get(1)).asVector();
assert count.getPositionCount() == 1;
diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/CountGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/CountGroupingAggregatorFunction.java
index 078e0cff99da..cc33a8de8bf6 100644
--- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/CountGroupingAggregatorFunction.java
+++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/CountGroupingAggregatorFunction.java
@@ -30,6 +30,7 @@ public class CountGroupingAggregatorFunction implements GroupingAggregatorFuncti
private final LongArrayState state;
private final List channels;
+ private final boolean countAll;
public static CountGroupingAggregatorFunction create(BigArrays bigArrays, List inputChannels) {
return new CountGroupingAggregatorFunction(inputChannels, new LongArrayState(bigArrays, 0));
@@ -42,6 +43,11 @@ public static List intermediateStateDesc() {
private CountGroupingAggregatorFunction(List channels, LongArrayState state) {
this.channels = channels;
this.state = state;
+ this.countAll = channels.isEmpty();
+ }
+
+ private int blockIndex() {
+ return countAll ? 0 : channels.get(0);
}
@Override
@@ -51,33 +57,35 @@ public int intermediateBlockCount() {
@Override
public AddInput prepareProcessPage(SeenGroupIds seenGroupIds, Page page) {
- Block valuesBlock = page.getBlock(channels.get(0));
- if (valuesBlock.areAllValuesNull()) {
- state.enableGroupIdTracking(seenGroupIds);
- return new AddInput() { // TODO return null meaning "don't collect me" and skip those
- @Override
- public void add(int positionOffset, IntBlock groupIds) {}
-
- @Override
- public void add(int positionOffset, IntVector groupIds) {}
- };
- }
- Vector valuesVector = valuesBlock.asVector();
- if (valuesVector == null) {
- if (valuesBlock.mayHaveNulls()) {
+ Block valuesBlock = page.getBlock(blockIndex());
+ if (countAll == false) {
+ if (valuesBlock.areAllValuesNull()) {
state.enableGroupIdTracking(seenGroupIds);
- }
- return new AddInput() {
- @Override
- public void add(int positionOffset, IntBlock groupIds) {
- addRawInput(positionOffset, groupIds, valuesBlock);
- }
+ return new AddInput() { // TODO return null meaning "don't collect me" and skip those
+ @Override
+ public void add(int positionOffset, IntBlock groupIds) {}
- @Override
- public void add(int positionOffset, IntVector groupIds) {
- addRawInput(positionOffset, groupIds, valuesBlock);
+ @Override
+ public void add(int positionOffset, IntVector groupIds) {}
+ };
+ }
+ Vector valuesVector = valuesBlock.asVector();
+ if (valuesVector == null) {
+ if (valuesBlock.mayHaveNulls()) {
+ state.enableGroupIdTracking(seenGroupIds);
}
- };
+ return new AddInput() {
+ @Override
+ public void add(int positionOffset, IntBlock groupIds) {
+ addRawInput(positionOffset, groupIds, valuesBlock);
+ }
+
+ @Override
+ public void add(int positionOffset, IntVector groupIds) {
+ addRawInput(positionOffset, groupIds, valuesBlock);
+ }
+ };
+ }
}
return new AddInput() {
@Override
@@ -121,6 +129,9 @@ private void addRawInput(int positionOffset, IntBlock groups, Block values) {
}
}
+ /**
+ * This method is called for count all.
+ */
private void addRawInput(IntVector groups) {
for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
int groupId = Math.toIntExact(groups.getInt(groupPosition));
@@ -128,6 +139,9 @@ private void addRawInput(IntVector groups) {
}
}
+ /**
+ * This method is called for count all.
+ */
private void addRawInput(IntBlock groups) {
for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
// TODO remove the check one we don't emit null anymore
@@ -146,7 +160,7 @@ private void addRawInput(IntBlock groups) {
@Override
public void addIntermediateInput(int positionOffset, IntVector groups, Page page) {
assert channels.size() == intermediateBlockCount();
- assert page.getBlockCount() >= channels.get(0) + intermediateStateDesc().size();
+ assert page.getBlockCount() >= blockIndex() + intermediateStateDesc().size();
state.enableGroupIdTracking(new SeenGroupIds.Empty());
LongVector count = page.getBlock(channels.get(0)).asVector();
BooleanVector seen = page.getBlock(channels.get(1)).asVector();
diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneCountOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneCountOperator.java
new file mode 100644
index 000000000000..e1e5b11c5b8c
--- /dev/null
+++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneCountOperator.java
@@ -0,0 +1,163 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.compute.lucene;
+
+import org.apache.lucene.search.LeafCollector;
+import org.apache.lucene.search.Query;
+import org.apache.lucene.search.Scorable;
+import org.apache.lucene.search.ScoreMode;
+import org.apache.lucene.search.Weight;
+import org.elasticsearch.compute.data.BooleanBlock;
+import org.elasticsearch.compute.data.LongBlock;
+import org.elasticsearch.compute.data.Page;
+import org.elasticsearch.compute.operator.DriverContext;
+import org.elasticsearch.compute.operator.SourceOperator;
+import org.elasticsearch.search.internal.SearchContext;
+
+import java.io.IOException;
+import java.io.UncheckedIOException;
+import java.util.List;
+import java.util.function.Function;
+
+/**
+ * Source operator that incrementally counts the results in Lucene searches
+ * Returns always one entry that mimics the Count aggregation internal state:
+ * 1. the count as a long (0 if no doc is seen)
+ * 2. a bool flag (seen) that's always true meaning that the group (all items) always exists
+ */
+public class LuceneCountOperator extends LuceneOperator {
+
+ private static final int PAGE_SIZE = 1;
+
+ private int totalHits = 0;
+ private int remainingDocs;
+
+ private final LeafCollector leafCollector;
+
+ public static class Factory implements LuceneOperator.Factory {
+ private final DataPartitioning dataPartitioning;
+ private final int taskConcurrency;
+ private final int limit;
+ private final LuceneSliceQueue sliceQueue;
+
+ public Factory(
+ List searchContexts,
+ Function queryFunction,
+ DataPartitioning dataPartitioning,
+ int taskConcurrency,
+ int limit
+ ) {
+ this.limit = limit;
+ this.dataPartitioning = dataPartitioning;
+ var weightFunction = weightFunction(queryFunction, ScoreMode.COMPLETE_NO_SCORES);
+ this.sliceQueue = LuceneSliceQueue.create(searchContexts, weightFunction, dataPartitioning, taskConcurrency);
+ this.taskConcurrency = Math.min(sliceQueue.totalSlices(), taskConcurrency);
+ }
+
+ @Override
+ public SourceOperator get(DriverContext driverContext) {
+ return new LuceneCountOperator(sliceQueue, limit);
+ }
+
+ @Override
+ public int taskConcurrency() {
+ return taskConcurrency;
+ }
+
+ public int limit() {
+ return limit;
+ }
+
+ @Override
+ public String describe() {
+ return "LuceneCountOperator[dataPartitioning = " + dataPartitioning + ", limit = " + limit + "]";
+ }
+ }
+
+ public LuceneCountOperator(LuceneSliceQueue sliceQueue, int limit) {
+ super(PAGE_SIZE, sliceQueue);
+ this.remainingDocs = limit;
+ this.leafCollector = new LeafCollector() {
+ @Override
+ public void setScorer(Scorable scorer) {}
+
+ @Override
+ public void collect(int doc) {
+ if (remainingDocs > 0) {
+ remainingDocs--;
+ totalHits++;
+ }
+ }
+ };
+ }
+
+ @Override
+ public boolean isFinished() {
+ return doneCollecting || remainingDocs == 0;
+ }
+
+ @Override
+ public void finish() {
+ doneCollecting = true;
+ }
+
+ @Override
+ public Page getOutput() {
+ if (isFinished()) {
+ assert remainingDocs <= 0 : remainingDocs;
+ return null;
+ }
+ try {
+ final LuceneScorer scorer = getCurrentOrLoadNextScorer();
+ // no scorer means no more docs
+ if (scorer == null) {
+ remainingDocs = 0;
+ } else {
+ Weight weight = scorer.weight();
+ var leafReaderContext = scorer.leafReaderContext();
+ // see org.apache.lucene.search.TotalHitCountCollector
+ int leafCount = weight == null ? -1 : weight.count(leafReaderContext);
+ if (leafCount != -1) {
+ // make sure to NOT multi count as the count _shortcut_ (which is segment wide)
+ // handle doc partitioning where the same leaf can be seen multiple times
+ // since the count is global, consider it only for the first partition and skip the rest
+ // SHARD, SEGMENT and the first DOC_ reader in data partitioning contain the first doc (position 0)
+ if (scorer.position() == 0) {
+ // check to not count over the desired number of docs/limit
+ var count = Math.min(leafCount, remainingDocs);
+ totalHits += count;
+ remainingDocs -= count;
+ scorer.markAsDone();
+ }
+ } else {
+ // could not apply shortcut, trigger the search
+ scorer.scoreNextRange(leafCollector, leafReaderContext.reader().getLiveDocs(), remainingDocs);
+ }
+ }
+
+ Page page = null;
+ // emit only one page
+ if (remainingDocs <= 0 && pagesEmitted == 0) {
+ pagesEmitted++;
+ page = new Page(
+ PAGE_SIZE,
+ LongBlock.newConstantBlockWith(totalHits, PAGE_SIZE),
+ BooleanBlock.newConstantBlockWith(true, PAGE_SIZE)
+ );
+ }
+ return page;
+ } catch (IOException e) {
+ throw new UncheckedIOException(e);
+ }
+ }
+
+ @Override
+ protected void describe(StringBuilder sb) {
+ sb.append(", remainingDocs=").append(remainingDocs);
+ }
+}
diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneOperator.java
index e7ba2f0d5589..74baecf154fe 100644
--- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneOperator.java
+++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneOperator.java
@@ -59,9 +59,7 @@ public interface Factory extends SourceOperator.SourceOperatorFactory {
}
@Override
- public void close() {
-
- }
+ public void close() {}
LuceneScorer getCurrentOrLoadNextScorer() {
while (currentScorer == null || currentScorer.isDone()) {
@@ -150,6 +148,14 @@ int shardIndex() {
SearchContext searchContext() {
return searchContext;
}
+
+ Weight weight() {
+ return weight;
+ }
+
+ int position() {
+ return position;
+ }
}
@Override
diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/LuceneCountOperatorTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/LuceneCountOperatorTests.java
new file mode 100644
index 000000000000..9893cd2b2a02
--- /dev/null
+++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/LuceneCountOperatorTests.java
@@ -0,0 +1,155 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.compute.lucene;
+
+import org.apache.lucene.document.SortedNumericDocValuesField;
+import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.index.IndexableField;
+import org.apache.lucene.index.NoMergePolicy;
+import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.search.MatchAllDocsQuery;
+import org.apache.lucene.search.Query;
+import org.apache.lucene.store.Directory;
+import org.apache.lucene.tests.index.RandomIndexWriter;
+import org.elasticsearch.common.util.BigArrays;
+import org.elasticsearch.compute.data.BooleanBlock;
+import org.elasticsearch.compute.data.LongBlock;
+import org.elasticsearch.compute.data.Page;
+import org.elasticsearch.compute.operator.AnyOperatorTestCase;
+import org.elasticsearch.compute.operator.Driver;
+import org.elasticsearch.compute.operator.DriverContext;
+import org.elasticsearch.compute.operator.OperatorTestCase;
+import org.elasticsearch.compute.operator.PageConsumerOperator;
+import org.elasticsearch.core.IOUtils;
+import org.elasticsearch.index.cache.query.TrivialQueryCachingPolicy;
+import org.elasticsearch.index.query.SearchExecutionContext;
+import org.elasticsearch.search.internal.ContextIndexSearcher;
+import org.elasticsearch.search.internal.SearchContext;
+import org.junit.After;
+
+import java.io.IOException;
+import java.io.UncheckedIOException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.function.Function;
+
+import static org.hamcrest.Matchers.hasSize;
+import static org.hamcrest.Matchers.is;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+public class LuceneCountOperatorTests extends AnyOperatorTestCase {
+ private Directory directory = newDirectory();
+ private IndexReader reader;
+
+ @After
+ public void closeIndex() throws IOException {
+ IOUtils.close(reader, directory);
+ }
+
+ @Override
+ protected LuceneCountOperator.Factory simple(BigArrays bigArrays) {
+ return simple(bigArrays, randomFrom(DataPartitioning.values()), between(1, 10_000), 100);
+ }
+
+ private LuceneCountOperator.Factory simple(BigArrays bigArrays, DataPartitioning dataPartitioning, int numDocs, int limit) {
+ int commitEvery = Math.max(1, numDocs / 10);
+ try (
+ RandomIndexWriter writer = new RandomIndexWriter(
+ random(),
+ directory,
+ newIndexWriterConfig().setMergePolicy(NoMergePolicy.INSTANCE)
+ )
+ ) {
+ for (int d = 0; d < numDocs; d++) {
+ List doc = new ArrayList<>();
+ doc.add(new SortedNumericDocValuesField("s", d));
+ writer.addDocument(doc);
+ if (d % commitEvery == 0) {
+ writer.commit();
+ }
+ }
+ reader = writer.getReader();
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+
+ SearchContext ctx = mockSearchContext(reader);
+ SearchExecutionContext ectx = mock(SearchExecutionContext.class);
+ when(ctx.getSearchExecutionContext()).thenReturn(ectx);
+ when(ectx.getIndexReader()).thenReturn(reader);
+ Function queryFunction = c -> new MatchAllDocsQuery();
+ return new LuceneCountOperator.Factory(List.of(ctx), queryFunction, dataPartitioning, 1, limit);
+ }
+
+ @Override
+ protected String expectedToStringOfSimple() {
+ assumeFalse("can't support variable maxPageSize", true); // TODO allow testing this
+ return "LuceneCountOperator[shardId=0, maxPageSize=**random**]";
+ }
+
+ @Override
+ protected String expectedDescriptionOfSimple() {
+ assumeFalse("can't support variable maxPageSize", true); // TODO allow testing this
+ return """
+ LuceneCountOperator[dataPartitioning = SHARD, maxPageSize = **random**, limit = 100, sorts = [{"s":{"order":"asc"}}]]""";
+ }
+
+ // TODO tests for the other data partitioning configurations
+
+ public void testShardDataPartitioning() {
+ int size = between(1_000, 20_000);
+ int limit = between(10, size);
+ testCount(size, limit);
+ }
+
+ public void testEmpty() {
+ testCount(0, between(10, 10_000));
+ }
+
+ private void testCount(int size, int limit) {
+ DriverContext ctx = driverContext();
+ LuceneCountOperator.Factory factory = simple(nonBreakingBigArrays(), DataPartitioning.SHARD, size, limit);
+
+ List results = new ArrayList<>();
+ OperatorTestCase.runDriver(new Driver(ctx, factory.get(ctx), List.of(), new PageConsumerOperator(results::add), () -> {}));
+ OperatorTestCase.assertDriverContext(ctx);
+
+ assertThat(results, hasSize(1));
+ Page page = results.get(0);
+
+ assertThat(page.getPositionCount(), is(1));
+ assertThat(page.getBlockCount(), is(2));
+ LongBlock lb = page.getBlock(0);
+ assertThat(lb.getPositionCount(), is(1));
+ assertThat(lb.getLong(0), is((long) Math.min(size, limit)));
+ BooleanBlock bb = page.getBlock(1);
+ assertThat(bb.getBoolean(0), is(true));
+ }
+
+ /**
+ * Creates a mock search context with the given index reader.
+ * The returned mock search context can be used to test with {@link LuceneOperator}.
+ */
+ public static SearchContext mockSearchContext(IndexReader reader) {
+ try {
+ ContextIndexSearcher searcher = new ContextIndexSearcher(
+ reader,
+ IndexSearcher.getDefaultSimilarity(),
+ IndexSearcher.getDefaultQueryCache(),
+ TrivialQueryCachingPolicy.NEVER,
+ true
+ );
+ SearchContext searchContext = mock(SearchContext.class);
+ when(searchContext.searcher()).thenReturn(searcher);
+ return searchContext;
+ } catch (IOException e) {
+ throw new UncheckedIOException(e);
+ }
+ }
+}
diff --git a/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/EsqlSpecTestCase.java b/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/EsqlSpecTestCase.java
index 4fc89e11b7b6..776a2e732e5e 100644
--- a/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/EsqlSpecTestCase.java
+++ b/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/EsqlSpecTestCase.java
@@ -32,6 +32,7 @@
import static org.elasticsearch.test.MapMatcher.matchesMap;
import static org.elasticsearch.xpack.esql.CsvAssert.assertData;
import static org.elasticsearch.xpack.esql.CsvAssert.assertMetadata;
+import static org.elasticsearch.xpack.esql.CsvTestUtils.ExpectedResults;
import static org.elasticsearch.xpack.esql.CsvTestUtils.isEnabled;
import static org.elasticsearch.xpack.esql.CsvTestUtils.loadCsvSpecValues;
import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.CSV_DATASET_MAP;
@@ -83,6 +84,10 @@ public static void wipeTestData() throws IOException {
}
}
+ public boolean logResults() {
+ return false;
+ }
+
public final void test() throws Throwable {
try {
assumeTrue("Test " + testName + " is not enabled", isEnabled(testName));
@@ -97,21 +102,29 @@ protected final void doTest() throws Throwable {
Map answer = runEsql(builder.query(testCase.query).build(), testCase.expectedWarnings);
var expectedColumnsWithValues = loadCsvSpecValues(testCase.expectedResults);
- assertNotNull(answer.get("columns"));
+ var metadata = answer.get("columns");
+ assertNotNull(metadata);
@SuppressWarnings("unchecked")
- var actualColumns = (List
*/
@Override public T visitDereference(EsqlBaseParser.DereferenceContext ctx) { return visitChildren(ctx); }
+ /**
+ * {@inheritDoc}
+ *
+ * The default implementation returns the result of calling
+ * {@link #visitChildren} on {@code ctx}.
+ */
+ @Override public T visitFunction(EsqlBaseParser.FunctionContext ctx) { return visitChildren(ctx); }
/**
* {@inheritDoc}
*
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParserListener.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParserListener.java
index 04f0d6da3dbe..dd6cdaacddbe 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParserListener.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParserListener.java
@@ -237,6 +237,18 @@ public interface EsqlBaseParserListener extends ParseTreeListener {
* @param ctx the parse tree
*/
void exitDereference(EsqlBaseParser.DereferenceContext ctx);
+ /**
+ * Enter a parse tree produced by the {@code function}
+ * labeled alternative in {@link EsqlBaseParser#primaryExpression}.
+ * @param ctx the parse tree
+ */
+ void enterFunction(EsqlBaseParser.FunctionContext ctx);
+ /**
+ * Exit a parse tree produced by the {@code function}
+ * labeled alternative in {@link EsqlBaseParser#primaryExpression}.
+ * @param ctx the parse tree
+ */
+ void exitFunction(EsqlBaseParser.FunctionContext ctx);
/**
* Enter a parse tree produced by the {@code parenthesizedExpression}
* labeled alternative in {@link EsqlBaseParser#primaryExpression}.
@@ -250,14 +262,12 @@ public interface EsqlBaseParserListener extends ParseTreeListener {
*/
void exitParenthesizedExpression(EsqlBaseParser.ParenthesizedExpressionContext ctx);
/**
- * Enter a parse tree produced by the {@code functionExpression}
- * labeled alternative in {@link EsqlBaseParser#primaryExpression}.
+ * Enter a parse tree produced by {@link EsqlBaseParser#functionExpression}.
* @param ctx the parse tree
*/
void enterFunctionExpression(EsqlBaseParser.FunctionExpressionContext ctx);
/**
- * Exit a parse tree produced by the {@code functionExpression}
- * labeled alternative in {@link EsqlBaseParser#primaryExpression}.
+ * Exit a parse tree produced by {@link EsqlBaseParser#functionExpression}.
* @param ctx the parse tree
*/
void exitFunctionExpression(EsqlBaseParser.FunctionExpressionContext ctx);
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParserVisitor.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParserVisitor.java
index 681de2590d57..35297f3d4f33 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParserVisitor.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParserVisitor.java
@@ -145,6 +145,13 @@ public interface EsqlBaseParserVisitor extends ParseTreeVisitor {
* @return the visitor result
*/
T visitDereference(EsqlBaseParser.DereferenceContext ctx);
+ /**
+ * Visit a parse tree produced by the {@code function}
+ * labeled alternative in {@link EsqlBaseParser#primaryExpression}.
+ * @param ctx the parse tree
+ * @return the visitor result
+ */
+ T visitFunction(EsqlBaseParser.FunctionContext ctx);
/**
* Visit a parse tree produced by the {@code parenthesizedExpression}
* labeled alternative in {@link EsqlBaseParser#primaryExpression}.
@@ -153,8 +160,7 @@ public interface EsqlBaseParserVisitor extends ParseTreeVisitor {
*/
T visitParenthesizedExpression(EsqlBaseParser.ParenthesizedExpressionContext ctx);
/**
- * Visit a parse tree produced by the {@code functionExpression}
- * labeled alternative in {@link EsqlBaseParser#primaryExpression}.
+ * Visit a parse tree produced by {@link EsqlBaseParser#functionExpression}.
* @param ctx the parse tree
* @return the visitor result
*/
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/ExpressionBuilder.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/ExpressionBuilder.java
index aa653d36d141..a7c8d6dd49cc 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/ExpressionBuilder.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/ExpressionBuilder.java
@@ -20,6 +20,7 @@
import org.elasticsearch.xpack.esql.evaluator.predicate.operator.regex.RLike;
import org.elasticsearch.xpack.esql.evaluator.predicate.operator.regex.WildcardLike;
import org.elasticsearch.xpack.esql.expression.Order;
+import org.elasticsearch.xpack.esql.expression.function.EsqlFunctionRegistry;
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Add;
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Div;
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Mod;
@@ -62,6 +63,7 @@
import java.util.function.BiFunction;
import static java.util.Collections.emptyList;
+import static java.util.Collections.singletonList;
import static org.elasticsearch.xpack.esql.type.EsqlDataTypes.DATE_PERIOD;
import static org.elasticsearch.xpack.esql.type.EsqlDataTypes.TIME_DURATION;
import static org.elasticsearch.xpack.ql.parser.ParserUtils.source;
@@ -312,12 +314,15 @@ public UnresolvedAttribute visitDereference(EsqlBaseParser.DereferenceContext ct
@Override
public Expression visitFunctionExpression(EsqlBaseParser.FunctionExpressionContext ctx) {
- return new UnresolvedFunction(
- source(ctx),
- visitIdentifier(ctx.identifier()),
- FunctionResolutionStrategy.DEFAULT,
- ctx.booleanExpression().stream().map(this::expression).toList()
- );
+ String name = visitIdentifier(ctx.identifier());
+ List args = expressions(ctx.booleanExpression());
+ if ("count".equals(EsqlFunctionRegistry.normalizeName(name))) {
+ // to simplify the registration, handle in the parser the special count cases
+ if (args.isEmpty() || ctx.ASTERISK() != null) {
+ args = singletonList(new Literal(source(ctx), "*", DataTypes.KEYWORD));
+ }
+ }
+ return new UnresolvedFunction(source(ctx), name, FunctionResolutionStrategy.DEFAULT, args);
}
@Override
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/EsStatsQueryExec.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/EsStatsQueryExec.java
new file mode 100644
index 000000000000..8e65e66e3045
--- /dev/null
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/EsStatsQueryExec.java
@@ -0,0 +1,128 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.plan.physical;
+
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.index.query.QueryBuilder;
+import org.elasticsearch.xpack.ql.expression.Attribute;
+import org.elasticsearch.xpack.ql.expression.Expression;
+import org.elasticsearch.xpack.ql.index.EsIndex;
+import org.elasticsearch.xpack.ql.tree.NodeInfo;
+import org.elasticsearch.xpack.ql.tree.NodeUtils;
+import org.elasticsearch.xpack.ql.tree.Source;
+
+import java.util.List;
+import java.util.Objects;
+
+/**
+ * Specialized query class for retrieving statistics about the underlying data and not the actual documents.
+ * For that see {@link EsQueryExec}
+ */
+public class EsStatsQueryExec extends LeafExec implements EstimatesRowSize {
+
+ public enum StatsType {
+ COUNT,
+ MIN,
+ MAX,
+ EXISTS;
+ }
+
+ public record Stat(String name, StatsType type) {}
+
+ private final EsIndex index;
+ private final QueryBuilder query;
+ private final Expression limit;
+ private final List attrs;
+ private final List stats;
+
+ public EsStatsQueryExec(
+ Source source,
+ EsIndex index,
+ QueryBuilder query,
+ Expression limit,
+ List attributes,
+ List stats
+ ) {
+ super(source);
+ this.index = index;
+ this.query = query;
+ this.limit = limit;
+ this.attrs = attributes;
+ this.stats = stats;
+ }
+
+ @Override
+ protected NodeInfo info() {
+ return NodeInfo.create(this, EsStatsQueryExec::new, index, query, limit, attrs, stats);
+ }
+
+ public EsIndex index() {
+ return index;
+ }
+
+ public QueryBuilder query() {
+ return query;
+ }
+
+ @Override
+ public List output() {
+ return attrs;
+ }
+
+ public Expression limit() {
+ return limit;
+ }
+
+ @Override
+ // TODO - get the estimation outside the plan so it doesn't touch the plan
+ public PhysicalPlan estimateRowSize(State state) {
+ int size;
+ state.add(false, attrs);
+ size = state.consumeAllFields(false);
+ return this;
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(index, query, limit, attrs, stats);
+ }
+
+ @Override
+ public boolean equals(Object obj) {
+ if (this == obj) {
+ return true;
+ }
+
+ if (obj == null || getClass() != obj.getClass()) {
+ return false;
+ }
+
+ EsStatsQueryExec other = (EsStatsQueryExec) obj;
+ return Objects.equals(index, other.index)
+ && Objects.equals(attrs, other.attrs)
+ && Objects.equals(query, other.query)
+ && Objects.equals(limit, other.limit)
+ && Objects.equals(stats, other.stats);
+ }
+
+ @Override
+ public String nodeString() {
+ return nodeName()
+ + "["
+ + index
+ + "], stats["
+ + stats
+ + "], query["
+ + (query != null ? Strings.toString(query, false, true) : "")
+ + "]"
+ + NodeUtils.limitedToString(attrs)
+ + ", limit["
+ + (limit != null ? limit.toString() : "")
+ + "], ";
+ }
+}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AbstractPhysicalOperationProviders.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AbstractPhysicalOperationProviders.java
index 0e984b3b85b0..113e4b91232a 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AbstractPhysicalOperationProviders.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AbstractPhysicalOperationProviders.java
@@ -18,6 +18,7 @@
import org.elasticsearch.compute.operator.HashAggregationOperator.HashAggregationOperatorFactory;
import org.elasticsearch.compute.operator.Operator;
import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException;
+import org.elasticsearch.xpack.esql.expression.function.aggregate.Count;
import org.elasticsearch.xpack.esql.plan.physical.AggregateExec;
import org.elasticsearch.xpack.esql.planner.LocalExecutionPlanner.LocalExecutionPlannerContext;
import org.elasticsearch.xpack.esql.planner.LocalExecutionPlanner.PhysicalOperation;
@@ -35,7 +36,9 @@
import java.util.Set;
import java.util.function.Consumer;
-abstract class AbstractPhysicalOperationProviders implements PhysicalOperationProviders {
+import static java.util.Collections.emptyList;
+
+public abstract class AbstractPhysicalOperationProviders implements PhysicalOperationProviders {
private final AggregateMapper aggregateMapper = new AggregateMapper();
@@ -235,7 +238,30 @@ private void aggregatesToFactory(
if (mode == AggregateExec.Mode.PARTIAL) {
aggMode = AggregatorMode.INITIAL;
// TODO: this needs to be made more reliable - use casting to blow up when dealing with expressions (e+1)
- sourceAttr = List.of(Expressions.attribute(aggregateFunction.field()));
+ Expression field = aggregateFunction.field();
+ // Only Count supports literals for now - all other aggs over constants should have been optimized away
+ if (field.foldable()) {
+ if (aggregateFunction instanceof Count) {
+ sourceAttr = emptyList();
+ } else {
+ throw new EsqlIllegalArgumentException(
+ "Aggregations over constants are not supported yet - [{}]",
+ aggregateFunction.sourceText()
+ );
+ }
+ } else {
+ Attribute attr = Expressions.attribute(field);
+ // cannot determine attribute
+ if (attr == null) {
+ throw new EsqlIllegalArgumentException(
+ "Cannot work with target field [{}] for agg [{}]",
+ field.sourceText(),
+ aggregateFunction.sourceText()
+ );
+ }
+ sourceAttr = List.of(attr);
+ }
+
} else if (mode == AggregateExec.Mode.FINAL) {
aggMode = AggregatorMode.FINAL;
if (grouping) {
@@ -253,7 +279,9 @@ private void aggregatesToFactory(
}
List inputChannels = sourceAttr.stream().map(attr -> layout.get(attr.id()).channel()).toList();
- assert inputChannels != null && inputChannels.size() > 0 && inputChannels.stream().allMatch(i -> i >= 0);
+ if (inputChannels.size() > 0) {
+ assert inputChannels.stream().allMatch(i -> i >= 0);
+ }
if (aggregateFunction instanceof ToAggregator agg) {
consumer.accept(new AggFunctionSupplierContext(agg.supplier(bigArrays, inputChannels), aggMode));
} else {
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/EsPhysicalOperationProviders.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/EsPhysicalOperationProviders.java
index ac62c45d4d1f..ce5e277deaad 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/EsPhysicalOperationProviders.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/EsPhysicalOperationProviders.java
@@ -20,6 +20,8 @@
import org.elasticsearch.compute.operator.Operator;
import org.elasticsearch.compute.operator.OrdinalsGroupingOperator;
import org.elasticsearch.index.mapper.NestedLookup;
+import org.elasticsearch.index.query.QueryBuilder;
+import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.index.query.SearchExecutionContext;
import org.elasticsearch.index.search.NestedHelper;
import org.elasticsearch.logging.LogManager;
@@ -54,6 +56,10 @@ public EsPhysicalOperationProviders(List searchContexts) {
this.searchContexts = searchContexts;
}
+ public List searchContexts() {
+ return searchContexts;
+ }
+
@Override
public final PhysicalOperation fieldExtractPhysicalOperation(FieldExtractExec fieldExtractExec, PhysicalOperation source) {
Layout.Builder layout = source.layout.builder();
@@ -85,12 +91,12 @@ public final PhysicalOperation fieldExtractPhysicalOperation(FieldExtractExec fi
return op;
}
- @Override
- public final PhysicalOperation sourcePhysicalOperation(EsQueryExec esQueryExec, LocalExecutionPlannerContext context) {
- final LuceneOperator.Factory luceneFactory;
- Function querySupplier = searchContext -> {
+ public static Function querySupplier(QueryBuilder queryBuilder) {
+ final QueryBuilder qb = queryBuilder == null ? QueryBuilders.matchAllQuery() : queryBuilder;
+
+ return searchContext -> {
SearchExecutionContext ctx = searchContext.getSearchExecutionContext();
- Query query = ctx.toQuery(esQueryExec.query()).query();
+ Query query = ctx.toQuery(qb).query();
NestedLookup nestedLookup = ctx.nestedLookup();
if (nestedLookup != NestedLookup.EMPTY) {
NestedHelper nestedHelper = new NestedHelper(nestedLookup, ctx::isFieldMapped);
@@ -110,6 +116,12 @@ public final PhysicalOperation sourcePhysicalOperation(EsQueryExec esQueryExec,
}
return query;
};
+ }
+
+ @Override
+ public final PhysicalOperation sourcePhysicalOperation(EsQueryExec esQueryExec, LocalExecutionPlannerContext context) {
+ Function querySupplier = querySupplier(esQueryExec.query());
+ final LuceneOperator.Factory luceneFactory;
List sorts = esQueryExec.sorts();
List> fieldSorts = null;
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java
index 18fad8cecb01..156b93e1551c 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java
@@ -7,6 +7,7 @@
package org.elasticsearch.xpack.esql.planner;
+import org.apache.lucene.search.Query;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.iterable.Iterables;
import org.elasticsearch.compute.Describable;
@@ -15,6 +16,8 @@
import org.elasticsearch.compute.data.ElementType;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.lucene.DataPartitioning;
+import org.elasticsearch.compute.lucene.LuceneCountOperator;
+import org.elasticsearch.compute.lucene.LuceneOperator;
import org.elasticsearch.compute.operator.ColumnExtractOperator;
import org.elasticsearch.compute.operator.Driver;
import org.elasticsearch.compute.operator.DriverContext;
@@ -43,7 +46,7 @@
import org.elasticsearch.compute.operator.topn.TopNOperator.TopNOperatorFactory;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.core.TimeValue;
-import org.elasticsearch.index.query.MatchAllQueryBuilder;
+import org.elasticsearch.search.internal.SearchContext;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException;
import org.elasticsearch.xpack.esql.enrich.EnrichLookupOperator;
@@ -54,6 +57,7 @@
import org.elasticsearch.xpack.esql.plan.physical.DissectExec;
import org.elasticsearch.xpack.esql.plan.physical.EnrichExec;
import org.elasticsearch.xpack.esql.plan.physical.EsQueryExec;
+import org.elasticsearch.xpack.esql.plan.physical.EsStatsQueryExec;
import org.elasticsearch.xpack.esql.plan.physical.EvalExec;
import org.elasticsearch.xpack.esql.plan.physical.ExchangeExec;
import org.elasticsearch.xpack.esql.plan.physical.ExchangeSinkExec;
@@ -96,6 +100,7 @@
import java.util.stream.Stream;
import static java.util.stream.Collectors.joining;
+import static org.elasticsearch.compute.lucene.LuceneOperator.NO_LIMIT;
import static org.elasticsearch.compute.operator.LimitOperator.Factory;
import static org.elasticsearch.compute.operator.ProjectOperator.ProjectOperatorFactory;
@@ -196,6 +201,8 @@ private PhysicalOperation plan(PhysicalPlan node, LocalExecutionPlannerContext c
// source nodes
else if (node instanceof EsQueryExec esQuery) {
return planEsQueryNode(esQuery, context);
+ } else if (node instanceof EsStatsQueryExec statsQuery) {
+ return planEsStats(statsQuery, context);
} else if (node instanceof RowExec row) {
return planRow(row, context);
} else if (node instanceof LocalSourceExec localSource) {
@@ -224,19 +231,33 @@ private PhysicalOperation planAggregation(AggregateExec aggregate, LocalExecutio
return physicalOperationProviders.groupingPhysicalOperation(aggregate, source, context);
}
- private PhysicalOperation planEsQueryNode(EsQueryExec esQuery, LocalExecutionPlannerContext context) {
- if (esQuery.query() == null) {
- esQuery = new EsQueryExec(
- esQuery.source(),
- esQuery.index(),
- esQuery.output(),
- new MatchAllQueryBuilder(),
- esQuery.limit(),
- esQuery.sorts(),
- esQuery.estimatedRowSize()
- );
+ private PhysicalOperation planEsQueryNode(EsQueryExec esQueryExec, LocalExecutionPlannerContext context) {
+ return physicalOperationProviders.sourcePhysicalOperation(esQueryExec, context);
+ }
+
+ private PhysicalOperation planEsStats(EsStatsQueryExec statsQuery, LocalExecutionPlannerContext context) {
+ if (physicalOperationProviders instanceof EsPhysicalOperationProviders == false) {
+ throw new EsqlIllegalArgumentException("EsStatsQuery should only occur against a Lucene backend");
}
- return physicalOperationProviders.sourcePhysicalOperation(esQuery, context);
+ EsPhysicalOperationProviders esProvider = (EsPhysicalOperationProviders) physicalOperationProviders;
+
+ Function querySupplier = EsPhysicalOperationProviders.querySupplier(statsQuery.query());
+
+ Expression limitExp = statsQuery.limit();
+ int limit = limitExp != null ? (Integer) limitExp.fold() : NO_LIMIT;
+ final LuceneOperator.Factory luceneFactory = new LuceneCountOperator.Factory(
+ esProvider.searchContexts(),
+ querySupplier,
+ context.dataPartitioning(),
+ context.taskConcurrency(),
+ limit
+ );
+
+ Layout.Builder layout = new Layout.Builder();
+ layout.append(statsQuery.outputSet());
+ int instanceCount = Math.max(1, luceneFactory.taskConcurrency());
+ context.driverParallelism(new DriverParallelism(DriverParallelism.Type.DATA_PARALLELISM, instanceCount));
+ return PhysicalOperation.fromSource(luceneFactory, layout.build());
}
private PhysicalOperation planFieldExtractNode(LocalExecutionPlannerContext context, FieldExtractExec fieldExtractExec) {
@@ -318,11 +339,11 @@ private PhysicalOperation planExchange(ExchangeExec exchangeExec, LocalExecution
private PhysicalOperation planExchangeSink(ExchangeSinkExec exchangeSink, LocalExecutionPlannerContext context) {
Objects.requireNonNull(exchangeSinkHandler, "ExchangeSinkHandler wasn't provided");
- PhysicalOperation source = plan(exchangeSink.child(), context);
+ var child = exchangeSink.child();
+ PhysicalOperation source = plan(child, context);
- Function transformer = exchangeSink.child() instanceof AggregateExec
- ? Function.identity()
- : alignPageToAttributes(exchangeSink.output(), source.layout);
+ boolean isAgg = child instanceof AggregateExec || child instanceof EsStatsQueryExec;
+ Function transformer = isAgg ? Function.identity() : alignPageToAttributes(exchangeSink.output(), source.layout);
return source.withSink(new ExchangeSinkOperatorFactory(exchangeSinkHandler::createExchangeSink, transformer), source.layout);
}
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/PhysicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/PhysicalPlanOptimizerTests.java
index 0bb43539dba7..e20ba72b82e5 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/PhysicalPlanOptimizerTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/PhysicalPlanOptimizerTests.java
@@ -14,12 +14,14 @@
import org.elasticsearch.core.Tuple;
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
+import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.index.query.RangeQueryBuilder;
import org.elasticsearch.index.query.RegexpQueryBuilder;
import org.elasticsearch.index.query.TermQueryBuilder;
import org.elasticsearch.index.query.TermsQueryBuilder;
import org.elasticsearch.index.query.WildcardQueryBuilder;
import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.test.junit.annotations.TestLogging;
import org.elasticsearch.xpack.core.enrich.EnrichPolicy;
import org.elasticsearch.xpack.esql.EsqlTestUtils;
import org.elasticsearch.xpack.esql.analysis.Analyzer;
@@ -42,6 +44,7 @@
import org.elasticsearch.xpack.esql.plan.physical.EsQueryExec;
import org.elasticsearch.xpack.esql.plan.physical.EsQueryExec.FieldSort;
import org.elasticsearch.xpack.esql.plan.physical.EsSourceExec;
+import org.elasticsearch.xpack.esql.plan.physical.EsStatsQueryExec;
import org.elasticsearch.xpack.esql.plan.physical.EstimatesRowSize;
import org.elasticsearch.xpack.esql.plan.physical.EvalExec;
import org.elasticsearch.xpack.esql.plan.physical.ExchangeExec;
@@ -54,6 +57,7 @@
import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan;
import org.elasticsearch.xpack.esql.plan.physical.ProjectExec;
import org.elasticsearch.xpack.esql.plan.physical.TopNExec;
+import org.elasticsearch.xpack.esql.planner.FilterTests;
import org.elasticsearch.xpack.esql.planner.Mapper;
import org.elasticsearch.xpack.esql.planner.PhysicalVerificationException;
import org.elasticsearch.xpack.esql.planner.PlannerUtils;
@@ -91,6 +95,7 @@
import static org.elasticsearch.xpack.esql.EsqlTestUtils.loadMapping;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.statsForMissingField;
import static org.elasticsearch.xpack.esql.SerializationTestUtils.assertSerialization;
+import static org.elasticsearch.xpack.esql.plan.physical.AggregateExec.Mode.FINAL;
import static org.elasticsearch.xpack.ql.expression.Expressions.name;
import static org.elasticsearch.xpack.ql.expression.Expressions.names;
import static org.elasticsearch.xpack.ql.expression.Order.OrderDirection.ASC;
@@ -103,7 +108,7 @@
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.nullValue;
-//@TestLogging(value = "org.elasticsearch.xpack.esql.optimizer.LocalLogicalPlanOptimizer:TRACE", reason = "debug")
+//@TestLogging(value = "org.elasticsearch.xpack.esql.optimizer.LocalLogicalPlanOptimizer:TRACE", reason = "debug")
public class PhysicalPlanOptimizerTests extends ESTestCase {
private static final String PARAM_FORMATTING = "%1$s";
@@ -1844,7 +1849,7 @@ public void testAvgSurrogateFunctionAfterRenameAndLimit() {
assertThat(limit.limit(), instanceOf(Literal.class));
assertThat(limit.limit().fold(), equalTo(10000));
var aggFinal = as(limit.child(), AggregateExec.class);
- assertThat(aggFinal.getMode(), equalTo(AggregateExec.Mode.FINAL));
+ assertThat(aggFinal.getMode(), equalTo(FINAL));
var aggPartial = as(aggFinal.child(), AggregateExec.class);
assertThat(aggPartial.getMode(), equalTo(AggregateExec.Mode.PARTIAL));
limit = as(aggPartial.child(), LimitExec.class);
@@ -1861,6 +1866,86 @@ public void testAvgSurrogateFunctionAfterRenameAndLimit() {
assertThat(source.limit().fold(), equalTo(10));
}
+ // optimized doesn't know yet how to push down count over field
+ public void testCountOneFieldWithFilter() {
+ var plan = optimizedPlan(physicalPlan("""
+ from test
+ | where salary > 1000
+ | stats c = count(salary)
+ """));
+ assertThat(plan.anyMatch(EsQueryExec.class::isInstance), is(true));
+ }
+
+ // optimized doesn't know yet how to push down count over field
+ public void testCountOneFieldWithFilterAndLimit() {
+ var plan = optimizedPlan(physicalPlan("""
+ from test
+ | where salary > 1000
+ | limit 10
+ | stats c = count(salary)
+ """));
+ assertThat(plan.anyMatch(EsQueryExec.class::isInstance), is(true));
+ }
+
+ // optimized doesn't know yet how to break down different multi count
+ public void testCountMultipleFieldsWithFilter() {
+ var plan = optimizedPlan(physicalPlan("""
+ from test
+ | where salary > 1000 and emp_no > 10010
+ | stats cs = count(salary), ce = count(emp_no)
+ """));
+ assertThat(plan.anyMatch(EsQueryExec.class::isInstance), is(true));
+ }
+
+ public void testCountAllWithFilter() {
+ var plan = optimizedPlan(physicalPlan("""
+ from test
+ | where emp_no > 10010
+ | stats c = count()
+ """));
+
+ var limit = as(plan, LimitExec.class);
+ var agg = as(limit.child(), AggregateExec.class);
+ assertThat(agg.getMode(), is(FINAL));
+ assertThat(Expressions.names(agg.aggregates()), contains("c"));
+ var exchange = as(agg.child(), ExchangeExec.class);
+ var esStatsQuery = as(exchange.child(), EsStatsQueryExec.class);
+ assertThat(esStatsQuery.limit(), is(nullValue()));
+ assertThat(Expressions.names(esStatsQuery.output()), contains("count", "seen"));
+ var expected = wrapWithSingleQuery(QueryBuilders.rangeQuery("emp_no").gt(10010), "emp_no");
+ assertThat(expected.toString(), is(esStatsQuery.query().toString()));
+ }
+
+ @AwaitsFix(bugUrl = "intermediateAgg does proper reduction but the agg itself does not - the optimizer needs to improve")
+ public void testMultiCountAllWithFilter() {
+ var plan = optimizedPlan(physicalPlan("""
+ from test
+ | where emp_no > 10010
+ | stats c = count(), call = count(*), c_literal = count(1)
+ """));
+
+ var limit = as(plan, LimitExec.class);
+ var agg = as(limit.child(), AggregateExec.class);
+ assertThat(agg.getMode(), is(FINAL));
+ assertThat(Expressions.names(agg.aggregates()), contains("c", "call", "c_literal"));
+ var exchange = as(agg.child(), ExchangeExec.class);
+ var esStatsQuery = as(exchange.child(), EsStatsQueryExec.class);
+ assertThat(esStatsQuery.limit(), is(nullValue()));
+ assertThat(Expressions.names(esStatsQuery.output()), contains("count", "seen"));
+ var expected = wrapWithSingleQuery(QueryBuilders.rangeQuery("emp_no").gt(10010), "emp_no");
+ assertThat(expected.toString(), is(esStatsQuery.query().toString()));
+ }
+
+ // optimized doesn't know yet how to break down different multi count
+ public void testCountFieldsAndAllWithFilter() {
+ var plan = optimizedPlan(physicalPlan("""
+ from test
+ | where emp_no > 10010
+ | stats c = count(), cs = count(salary), ce = count(emp_no)
+ """));
+ assertThat(plan.anyMatch(EsQueryExec.class::isInstance), is(true));
+ }
+
private static EsQueryExec source(PhysicalPlan plan) {
if (plan instanceof ExchangeExec exchange) {
plan = exchange.child();
@@ -1915,4 +2000,8 @@ private QueryBuilder sv(QueryBuilder builder, String fieldName) {
assertThat(sv.field(), equalTo(fieldName));
return sv.next();
}
+
+ private QueryBuilder wrapWithSingleQuery(QueryBuilder inner, String fieldName) {
+ return FilterTests.singleValueQuery(inner, fieldName);
+ }
}
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/tree/EsqlNodeSubclassTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/tree/EsqlNodeSubclassTests.java
index 937488d2ed54..640dd410d857 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/tree/EsqlNodeSubclassTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/tree/EsqlNodeSubclassTests.java
@@ -19,6 +19,8 @@
import org.elasticsearch.xpack.esql.plan.logical.Dissect;
import org.elasticsearch.xpack.esql.plan.logical.Grok;
import org.elasticsearch.xpack.esql.plan.physical.EsQueryExec;
+import org.elasticsearch.xpack.esql.plan.physical.EsStatsQueryExec.Stat;
+import org.elasticsearch.xpack.esql.plan.physical.EsStatsQueryExec.StatsType;
import org.elasticsearch.xpack.esql.plan.physical.OutputExec;
import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan;
import org.elasticsearch.xpack.esql.type.EsqlDataTypes;
@@ -97,6 +99,10 @@ protected Object pluggableMakeArg(Class extends Node>> toBuildClass, Class
),
IndexResolution.invalid(randomAlphaOfLength(5))
);
+
+ } else if (argClass == Stat.class) {
+ // record field
+ return new Stat(randomRealisticUnicodeOfLength(10), randomFrom(StatsType.values()));
} else if (argClass == Integer.class) {
return randomInt();
}