Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ESQL: Syntax support and operator for count all #99602

Merged
merged 15 commits into from
Sep 30, 2023
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ public static List<IntermediateStateDesc> intermediateStateDesc() {

private final LongState state;
private final List<Integer> channels;
private final boolean countAll;

public static CountAggregatorFunction create(List<Integer> inputChannels) {
return new CountAggregatorFunction(inputChannels, new LongState());
Expand All @@ -57,24 +58,31 @@ public static CountAggregatorFunction create(List<Integer> inputChannels) {
private CountAggregatorFunction(List<Integer> channels, LongState state) {
this.channels = channels;
this.state = state;
this.countAll = channels.isEmpty();
costin marked this conversation as resolved.
Show resolved Hide resolved
}

@Override
public int intermediateBlockCount() {
return intermediateStateDesc().size();
}

private int blockIndex() {
return countAll ? 0 : channels.get(0);
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@nik9000 not sure if this is the best way to count things in a block but it seems to be working.

}

@Override
public void addRawInput(Page page) {
Block block = page.getBlock(channels.get(0));
Block block = page.getBlock(blockIndex());
LongState state = this.state;
state.longValue(state.longValue() + block.getTotalValueCount());
int count = countAll ? block.getPositionCount() : block.getTotalValueCount();
state.longValue(state.longValue() + count);
}

@Override
public void addIntermediateInput(Page page) {
assert channels.size() == intermediateBlockCount();
assert page.getBlockCount() >= channels.get(0) + intermediateStateDesc().size();
var blockIndex = blockIndex();
assert page.getBlockCount() >= blockIndex + intermediateStateDesc().size();
LongVector count = page.<LongBlock>getBlock(channels.get(0)).asVector();
BooleanVector seen = page.<BooleanBlock>getBlock(channels.get(1)).asVector();
assert count.getPositionCount() == 1;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ public class CountGroupingAggregatorFunction implements GroupingAggregatorFuncti

private final LongArrayState state;
private final List<Integer> channels;
private final boolean countAll;

public static CountGroupingAggregatorFunction create(BigArrays bigArrays, List<Integer> inputChannels) {
return new CountGroupingAggregatorFunction(inputChannels, new LongArrayState(bigArrays, 0));
Expand All @@ -42,6 +43,11 @@ public static List<IntermediateStateDesc> intermediateStateDesc() {
private CountGroupingAggregatorFunction(List<Integer> channels, LongArrayState state) {
this.channels = channels;
this.state = state;
this.countAll = channels.isEmpty();
}

private int blockIndex() {
return countAll ? 0 : channels.get(0);
}

@Override
Expand All @@ -51,33 +57,35 @@ public int intermediateBlockCount() {

@Override
public AddInput prepareProcessPage(SeenGroupIds seenGroupIds, Page page) {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@nik9000 likewise here

Block valuesBlock = page.getBlock(channels.get(0));
if (valuesBlock.areAllValuesNull()) {
state.enableGroupIdTracking(seenGroupIds);
return new AddInput() { // TODO return null meaning "don't collect me" and skip those
@Override
public void add(int positionOffset, IntBlock groupIds) {}

@Override
public void add(int positionOffset, IntVector groupIds) {}
};
}
Vector valuesVector = valuesBlock.asVector();
if (valuesVector == null) {
if (valuesBlock.mayHaveNulls()) {
Block valuesBlock = page.getBlock(blockIndex());
if (countAll == false) {
if (valuesBlock.areAllValuesNull()) {
state.enableGroupIdTracking(seenGroupIds);
}
return new AddInput() {
@Override
public void add(int positionOffset, IntBlock groupIds) {
addRawInput(positionOffset, groupIds, valuesBlock);
}
return new AddInput() { // TODO return null meaning "don't collect me" and skip those
@Override
public void add(int positionOffset, IntBlock groupIds) {}

@Override
public void add(int positionOffset, IntVector groupIds) {
addRawInput(positionOffset, groupIds, valuesBlock);
@Override
public void add(int positionOffset, IntVector groupIds) {}
};
}
Vector valuesVector = valuesBlock.asVector();
if (valuesVector == null) {
if (valuesBlock.mayHaveNulls()) {
state.enableGroupIdTracking(seenGroupIds);
}
};
return new AddInput() {
@Override
public void add(int positionOffset, IntBlock groupIds) {
addRawInput(positionOffset, groupIds, valuesBlock);
}

@Override
public void add(int positionOffset, IntVector groupIds) {
addRawInput(positionOffset, groupIds, valuesBlock);
}
};
}
}
return new AddInput() {
@Override
Expand Down Expand Up @@ -121,13 +129,19 @@ private void addRawInput(int positionOffset, IntBlock groups, Block values) {
}
}

/**
* This method is called for count all.
*/
private void addRawInput(IntVector groups) {
for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
int groupId = Math.toIntExact(groups.getInt(groupPosition));
state.increment(groupId, 1);
}
}

/**
* This method is called for count all.
*/
private void addRawInput(IntBlock groups) {
for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
// TODO remove the check one we don't emit null anymore
Expand All @@ -146,7 +160,7 @@ private void addRawInput(IntBlock groups) {
@Override
public void addIntermediateInput(int positionOffset, IntVector groups, Page page) {
assert channels.size() == intermediateBlockCount();
assert page.getBlockCount() >= channels.get(0) + intermediateStateDesc().size();
assert page.getBlockCount() >= blockIndex() + intermediateStateDesc().size();
state.enableGroupIdTracking(new SeenGroupIds.Empty());
LongVector count = page.<LongBlock>getBlock(channels.get(0)).asVector();
BooleanVector seen = page.<BooleanBlock>getBlock(channels.get(1)).asVector();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.compute.lucene;

import org.apache.lucene.search.LeafCollector;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.Scorable;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Weight;
import org.elasticsearch.compute.data.BooleanBlock;
import org.elasticsearch.compute.data.LongBlock;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.compute.operator.SourceOperator;
import org.elasticsearch.search.internal.SearchContext;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.List;
import java.util.function.Function;

/**
* Source operator that incrementally counts the results in Lucene searches
* Returns always one entry that mimics the Count aggregation internal state:
* 1. the count as a long (0 if no doc is seen)
* 2. a bool flag (seen) that's always true meaning that the group (all items) always exists
*/
public class LuceneCountOperator extends LuceneOperator {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This class is very similar to LuceneSourceOperator - I tried combining the two but there's not much code reuse and the inner state and semantics are fairly different.


private static final int PAGE_SIZE = 1;

private int totalHits = 0;
private int remainingDocs;

private final LeafCollector leafCollector;

public static class Factory implements LuceneOperator.Factory {
private final DataPartitioning dataPartitioning;
private final int taskConcurrency;
private final int limit;
private final LuceneSliceQueue sliceQueue;

public Factory(
List<SearchContext> searchContexts,
Function<SearchContext, Query> queryFunction,
DataPartitioning dataPartitioning,
int taskConcurrency,
int limit
) {
this.limit = limit;
costin marked this conversation as resolved.
Show resolved Hide resolved
this.dataPartitioning = dataPartitioning;
var weightFunction = weightFunction(queryFunction, ScoreMode.COMPLETE_NO_SCORES);
this.sliceQueue = LuceneSliceQueue.create(searchContexts, weightFunction, dataPartitioning, taskConcurrency);
this.taskConcurrency = Math.min(sliceQueue.totalSlices(), taskConcurrency);
}

@Override
public SourceOperator get(DriverContext driverContext) {
return new LuceneCountOperator(sliceQueue, limit);
}

@Override
public int taskConcurrency() {
return taskConcurrency;
}

public int limit() {
return limit;
}

@Override
public String describe() {
return "LuceneCountOperator[dataPartitioning = " + dataPartitioning + ", limit = " + limit + "]";
}
}

public LuceneCountOperator(LuceneSliceQueue sliceQueue, int limit) {
super(PAGE_SIZE, sliceQueue);
this.remainingDocs = limit;
this.leafCollector = new LeafCollector() {
@Override
public void setScorer(Scorable scorer) {}

@Override
public void collect(int doc) {
if (remainingDocs > 0) {
remainingDocs--;
totalHits++;
}
}
};
}

@Override
public boolean isFinished() {
return doneCollecting || remainingDocs == 0;
}

@Override
public void finish() {
doneCollecting = true;
}

@Override
public Page getOutput() {
if (isFinished()) {
assert remainingDocs <= 0 : remainingDocs;
return null;
}
try {
final LuceneScorer scorer = getCurrentOrLoadNextScorer();
// no scorer means no more docs
if (scorer == null) {
remainingDocs = 0;
} else {
Weight weight = scorer.weight();
var leafReaderContext = scorer.leafReaderContext();
// see org.apache.lucene.search.TotalHitCountCollector
int leafCount = weight == null ? -1 : weight.count(leafReaderContext);
if (leafCount != -1) {
// make sure to NOT multi count as the count _shortcut_ (which is segment wide)
// handle doc partitioning where the same leaf can be seen multiple times
// since the count is global, consider it only for the first partition and skip the rest
// SHARD, SEGMENT and the first DOC_ reader in data partitioning contain the first doc (position 0)
if (scorer.position() == 0) {
// check to not count over the desired number of docs/limit
var count = Math.min(leafCount, remainingDocs);
totalHits += count;
remainingDocs -= count;
scorer.markAsDone();
}
} else {
// could not apply shortcut, trigger the search
scorer.scoreNextRange(leafCollector, leafReaderContext.reader().getLiveDocs(), remainingDocs);
}
}

Page page = null;
// emit only one page
if (remainingDocs <= 0 && pagesEmitted == 0) {
pagesEmitted++;
page = new Page(
PAGE_SIZE,
LongBlock.newConstantBlockWith(totalHits, PAGE_SIZE),
BooleanBlock.newConstantBlockWith(true, PAGE_SIZE)
);
}
return page;
} catch (IOException e) {
throw new UncheckedIOException(e);
}
}

@Override
protected void describe(StringBuilder sb) {
sb.append(", remainingDocs=").append(remainingDocs);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,7 @@ public interface Factory extends SourceOperator.SourceOperatorFactory {
}

@Override
public void close() {

}
public void close() {}

LuceneScorer getCurrentOrLoadNextScorer() {
while (currentScorer == null || currentScorer.isDone()) {
Expand Down Expand Up @@ -146,6 +144,14 @@ int shardIndex() {
SearchContext searchContext() {
return searchContext;
}

Weight weight() {
return weight;
}

int position() {
return position;
}
Comment on lines +152 to +158
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Made these accessible for the count source.

}

@Override
Expand Down
Loading