ESQL: Syntax support and operator for count all (#99602)
Introduce physical plan for representing query stats
Use internal aggs when pushing down count
Add support for count all outside Lucene
costin authored Sep 30, 2023
1 parent ccc896d commit f883dd9
Showing 26 changed files with 1,724 additions and 804 deletions.
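
For orientation before the diffs: the new count-all mode is keyed off an empty input-channel list, as the constructor comment below puts it ("no channels specified means count-all/count(*)"). A minimal usage sketch of that convention — the import path and channel values are assumptions for illustration, not part of this commit:

import java.util.List;
import org.elasticsearch.compute.aggregation.CountAggregatorFunction; // assumed package

// Hedged sketch: an empty channel list selects count-all (COUNT(*)),
// while a non-empty list counts the values arriving on that channel.
CountAggregatorFunction countAll = CountAggregatorFunction.create(List.of());  // COUNT(*)
CountAggregatorFunction countOne = CountAggregatorFunction.create(List.of(0)); // COUNT(<field in channel 0>)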
CountAggregatorFunction.java
@@ -49,6 +49,7 @@ public static List<IntermediateStateDesc> intermediateStateDesc() {

private final LongState state;
private final List<Integer> channels;
private final boolean countAll;

public static CountAggregatorFunction create(List<Integer> inputChannels) {
return new CountAggregatorFunction(inputChannels, new LongState());
@@ -57,24 +58,32 @@ public static CountAggregatorFunction create(List<Integer> inputChannels) {
private CountAggregatorFunction(List<Integer> channels, LongState state) {
this.channels = channels;
this.state = state;
// no channels specified means count-all/count(*)
this.countAll = channels.isEmpty();
}

@Override
public int intermediateBlockCount() {
return intermediateStateDesc().size();
}

private int blockIndex() {
return countAll ? 0 : channels.get(0);
}

@Override
public void addRawInput(Page page) {
- Block block = page.getBlock(channels.get(0));
+ Block block = page.getBlock(blockIndex());
LongState state = this.state;
- state.longValue(state.longValue() + block.getTotalValueCount());
+ int count = countAll ? block.getPositionCount() : block.getTotalValueCount();
+ state.longValue(state.longValue() + count);
}

@Override
public void addIntermediateInput(Page page) {
assert channels.size() == intermediateBlockCount();
- assert page.getBlockCount() >= channels.get(0) + intermediateStateDesc().size();
+ var blockIndex = blockIndex();
+ assert page.getBlockCount() >= blockIndex + intermediateStateDesc().size();
LongVector count = page.<LongBlock>getBlock(channels.get(0)).asVector();
BooleanVector seen = page.<BooleanBlock>getBlock(channels.get(1)).asVector();
assert count.getPositionCount() == 1;
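
The countAll branch above counts positions (rows) via getPositionCount(), while the per-field count sums values via getTotalValueCount(), which skips nulls and counts every value of a multi-valued row. A self-contained analogy under assumed semantics — plain arrays standing in for the real Block API:

import java.util.List;

// Analogy only: each int[] is one position (row); an empty array models null.
List<int[]> rows = List.of(new int[] { 7 }, new int[] {}, new int[] { 1 });

int positionCount = rows.size();                                   // 3 -> what count-all sees
int totalValueCount = rows.stream().mapToInt(r -> r.length).sum(); // 2 -> what COUNT(field) sees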
CountGroupingAggregatorFunction.java
@@ -30,6 +30,7 @@ public class CountGroupingAggregatorFunction implements GroupingAggregatorFunction {

private final LongArrayState state;
private final List<Integer> channels;
private final boolean countAll;

public static CountGroupingAggregatorFunction create(BigArrays bigArrays, List<Integer> inputChannels) {
return new CountGroupingAggregatorFunction(inputChannels, new LongArrayState(bigArrays, 0));
@@ -42,6 +43,11 @@ public static List<IntermediateStateDesc> intermediateStateDesc() {
private CountGroupingAggregatorFunction(List<Integer> channels, LongArrayState state) {
this.channels = channels;
this.state = state;
this.countAll = channels.isEmpty();
}

private int blockIndex() {
return countAll ? 0 : channels.get(0);
}

@Override
@@ -51,33 +57,35 @@ public int intermediateBlockCount() {

@Override
public AddInput prepareProcessPage(SeenGroupIds seenGroupIds, Page page) {
- Block valuesBlock = page.getBlock(channels.get(0));
- if (valuesBlock.areAllValuesNull()) {
- state.enableGroupIdTracking(seenGroupIds);
- return new AddInput() { // TODO return null meaning "don't collect me" and skip those
- @Override
- public void add(int positionOffset, IntBlock groupIds) {}
-
- @Override
- public void add(int positionOffset, IntVector groupIds) {}
- };
- }
- Vector valuesVector = valuesBlock.asVector();
- if (valuesVector == null) {
- if (valuesBlock.mayHaveNulls()) {
- state.enableGroupIdTracking(seenGroupIds);
- }
- return new AddInput() {
- @Override
- public void add(int positionOffset, IntBlock groupIds) {
- addRawInput(positionOffset, groupIds, valuesBlock);
- }
-
- @Override
- public void add(int positionOffset, IntVector groupIds) {
- addRawInput(positionOffset, groupIds, valuesBlock);
- }
- };
- }
+ Block valuesBlock = page.getBlock(blockIndex());
+ if (countAll == false) {
+ if (valuesBlock.areAllValuesNull()) {
+ state.enableGroupIdTracking(seenGroupIds);
+ return new AddInput() { // TODO return null meaning "don't collect me" and skip those
+ @Override
+ public void add(int positionOffset, IntBlock groupIds) {}
+
+ @Override
+ public void add(int positionOffset, IntVector groupIds) {}
+ };
+ }
+ Vector valuesVector = valuesBlock.asVector();
+ if (valuesVector == null) {
+ if (valuesBlock.mayHaveNulls()) {
+ state.enableGroupIdTracking(seenGroupIds);
+ }
+ return new AddInput() {
+ @Override
+ public void add(int positionOffset, IntBlock groupIds) {
+ addRawInput(positionOffset, groupIds, valuesBlock);
+ }
+
+ @Override
+ public void add(int positionOffset, IntVector groupIds) {
+ addRawInput(positionOffset, groupIds, valuesBlock);
+ }
+ };
+ }
+ }
return new AddInput() {
@Override
@@ -121,13 +129,19 @@ private void addRawInput(int positionOffset, IntBlock groups, Block values) {
}
}

/**
* This method is called for count all.
*/
private void addRawInput(IntVector groups) {
for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
int groupId = Math.toIntExact(groups.getInt(groupPosition));
state.increment(groupId, 1);
}
}

/**
* This method is called for count all.
*/
private void addRawInput(IntBlock groups) {
for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
// TODO remove the check once we don't emit null anymore
@@ -146,7 +160,7 @@ private void addRawInput(IntBlock groups) {
@Override
public void addIntermediateInput(int positionOffset, IntVector groups, Page page) {
assert channels.size() == intermediateBlockCount();
- assert page.getBlockCount() >= channels.get(0) + intermediateStateDesc().size();
+ assert page.getBlockCount() >= blockIndex() + intermediateStateDesc().size();
state.enableGroupIdTracking(new SeenGroupIds.Empty());
LongVector count = page.<LongBlock>getBlock(channels.get(0)).asVector();
BooleanVector seen = page.<BooleanBlock>getBlock(channels.get(1)).asVector();
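
As the added Javadoc above notes, the count-all variants of addRawInput never inspect values: every row adds exactly 1 to its group. A toy sketch of that accumulation, with plain arrays standing in for LongArrayState and the group vector (all names assumed, not this class's API):

// Toy model: groups[i] is the group id of row i; counts accumulates one per row.
int[] groups = { 0, 1, 0, 2, 1, 0 };
long[] counts = new long[3];
for (int position = 0; position < groups.length; position++) {
    counts[groups[position]]++; // count-all: increment by 1, regardless of values
}
// counts is now { 3, 2, 1 }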
LuceneCountOperator.java (new file)
@@ -0,0 +1,163 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.compute.lucene;

import org.apache.lucene.search.LeafCollector;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.Scorable;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Weight;
import org.elasticsearch.compute.data.BooleanBlock;
import org.elasticsearch.compute.data.LongBlock;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.compute.operator.SourceOperator;
import org.elasticsearch.search.internal.SearchContext;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.List;
import java.util.function.Function;

/**
* Source operator that incrementally counts the results of a Lucene search.
* It always returns a single entry that mimics the Count aggregation's internal state:
* 1. the count, as a long (0 if no doc is seen)
* 2. a boolean flag (seen) that is always true, meaning the group (all items) always exists
*/
public class LuceneCountOperator extends LuceneOperator {

private static final int PAGE_SIZE = 1;

private int totalHits = 0;
private int remainingDocs;

private final LeafCollector leafCollector;

public static class Factory implements LuceneOperator.Factory {
private final DataPartitioning dataPartitioning;
private final int taskConcurrency;
private final int limit;
private final LuceneSliceQueue sliceQueue;

public Factory(
List<SearchContext> searchContexts,
Function<SearchContext, Query> queryFunction,
DataPartitioning dataPartitioning,
int taskConcurrency,
int limit
) {
this.limit = limit;
this.dataPartitioning = dataPartitioning;
var weightFunction = weightFunction(queryFunction, ScoreMode.COMPLETE_NO_SCORES);
this.sliceQueue = LuceneSliceQueue.create(searchContexts, weightFunction, dataPartitioning, taskConcurrency);
this.taskConcurrency = Math.min(sliceQueue.totalSlices(), taskConcurrency);
}

@Override
public SourceOperator get(DriverContext driverContext) {
return new LuceneCountOperator(sliceQueue, limit);
}

@Override
public int taskConcurrency() {
return taskConcurrency;
}

public int limit() {
return limit;
}

@Override
public String describe() {
return "LuceneCountOperator[dataPartitioning = " + dataPartitioning + ", limit = " + limit + "]";
}
}

public LuceneCountOperator(LuceneSliceQueue sliceQueue, int limit) {
super(PAGE_SIZE, sliceQueue);
this.remainingDocs = limit;
this.leafCollector = new LeafCollector() {
@Override
public void setScorer(Scorable scorer) {}

@Override
public void collect(int doc) {
if (remainingDocs > 0) {
remainingDocs--;
totalHits++;
}
}
};
}

@Override
public boolean isFinished() {
return doneCollecting || remainingDocs == 0;
}

@Override
public void finish() {
doneCollecting = true;
}

@Override
public Page getOutput() {
if (isFinished()) {
assert remainingDocs <= 0 : remainingDocs;
return null;
}
try {
final LuceneScorer scorer = getCurrentOrLoadNextScorer();
// no scorer means no more docs
if (scorer == null) {
remainingDocs = 0;
} else {
Weight weight = scorer.weight();
var leafReaderContext = scorer.leafReaderContext();
// see org.apache.lucene.search.TotalHitCountCollector
int leafCount = weight == null ? -1 : weight.count(leafReaderContext);
if (leafCount != -1) {
// Make sure NOT to multi-count: the count _shortcut_ is segment-wide, while doc
// partitioning can hand the same leaf to several partitions. Since the count is
// global, apply it only to the first partition and skip the rest (the SHARD,
// SEGMENT, and first DOC_ readers all contain the first doc, position 0).
if (scorer.position() == 0) {
// Cap the count so it never exceeds the number of docs still wanted (the limit).
var count = Math.min(leafCount, remainingDocs);
totalHits += count;
remainingDocs -= count;
scorer.markAsDone();
}
} else {
// could not apply shortcut, trigger the search
scorer.scoreNextRange(leafCollector, leafReaderContext.reader().getLiveDocs(), remainingDocs);
}
}

Page page = null;
// emit only one page
if (remainingDocs <= 0 && pagesEmitted == 0) {
pagesEmitted++;
page = new Page(
PAGE_SIZE,
LongBlock.newConstantBlockWith(totalHits, PAGE_SIZE),
BooleanBlock.newConstantBlockWith(true, PAGE_SIZE)
);
}
return page;
} catch (IOException e) {
throw new UncheckedIOException(e);
}
}
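
// Assumed wiring, not shown in this diff: the single page built above feeds
// CountAggregatorFunction.addIntermediateInput, which reads
//   block 0 as a LongBlock    -> the running count (totalHits here)
//   block 1 as a BooleanBlock -> the seen flag (constant true here)
// so the Lucene-side count shortcut and the generic aggregator stay interchangeable.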

@Override
protected void describe(StringBuilder sb) {
sb.append(", remainingDocs=").append(remainingDocs);
}
}
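
The weight.count(leafReaderContext) call above is the same shortcut Lucene's TotalHitCountCollector relies on: a Weight can sometimes report its exact per-leaf hit count without visiting documents, and returns -1 when it cannot. A minimal standalone sketch of that pattern — simplified in that, unlike the operator above, it ignores live docs and any remaining-docs limit:

import java.io.IOException;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.Weight;

static int countLeafHits(Weight weight, LeafReaderContext leaf) throws IOException {
    int shortcut = weight.count(leaf); // exact count, or -1 if not cheaply computable
    if (shortcut != -1) {
        return shortcut;
    }
    // Fallback: score the leaf and count matches one by one.
    int hits = 0;
    Scorer scorer = weight.scorer(leaf);
    if (scorer != null) {
        DocIdSetIterator it = scorer.iterator();
        while (it.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) {
            hits++;
        }
    }
    return hits;
}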
LuceneOperator.java
@@ -59,9 +59,7 @@ public interface Factory extends SourceOperator.SourceOperatorFactory {
}

@Override
- public void close() {
-
- }
+ public void close() {}

LuceneScorer getCurrentOrLoadNextScorer() {
while (currentScorer == null || currentScorer.isDone()) {
@@ -150,6 +148,14 @@ int shardIndex() {
SearchContext searchContext() {
return searchContext;
}

Weight weight() {
return weight;
}

int position() {
return position;
}
}

@Override