Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ESQL: Categorize grouping function testing improvements #118013

Merged
merged 7 commits into from
Dec 9, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,18 @@
package org.elasticsearch.xpack.esql.core.expression;

public enum Nullability {
TRUE, // Whether the expression can become null
FALSE, // The expression can never become null
UNKNOWN // Cannot determine if the expression supports possible null folding
/**
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

++

* Whether the expression can become null
*/
TRUE,

/**
* The expression can never become null
*/
FALSE,

/**
* Cannot determine if the expression supports possible null folding
*/
UNKNOWN
}
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,11 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import static org.elasticsearch.compute.operator.OperatorTestCase.runDriver;
import static org.hamcrest.Matchers.arrayWithSize;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize;

Expand Down Expand Up @@ -95,41 +95,114 @@ public void testCategorizeRaw() {
page = new Page(builder.build());
}

try (BlockHash hash = new CategorizeBlockHash(blockFactory, 0, AggregatorMode.INITIAL, analysisRegistry)) {
hash.add(page, new GroupingAggregatorFunction.AddInput() {
@Override
public void add(int positionOffset, IntBlock groupIds) {
assertEquals(groupIds.getPositionCount(), positions);
try (var hash = new CategorizeBlockHash(blockFactory, 0, AggregatorMode.SINGLE, analysisRegistry)) {
for (int i = randomInt(2); i < 3; i++) {
hash.add(page, new GroupingAggregatorFunction.AddInput() {
@Override
public void add(int positionOffset, IntBlock groupIds) {
assertEquals(groupIds.getPositionCount(), positions);

assertEquals(1, groupIds.getInt(0));
assertEquals(2, groupIds.getInt(1));
assertEquals(2, groupIds.getInt(2));
assertEquals(2, groupIds.getInt(3));
assertEquals(3, groupIds.getInt(4));
assertEquals(1, groupIds.getInt(5));
assertEquals(1, groupIds.getInt(6));
if (withNull) {
assertEquals(0, groupIds.getInt(7));
}
}

assertEquals(1, groupIds.getInt(0));
assertEquals(2, groupIds.getInt(1));
assertEquals(2, groupIds.getInt(2));
assertEquals(2, groupIds.getInt(3));
assertEquals(3, groupIds.getInt(4));
assertEquals(1, groupIds.getInt(5));
assertEquals(1, groupIds.getInt(6));
if (withNull) {
assertEquals(0, groupIds.getInt(7));
@Override
public void add(int positionOffset, IntVector groupIds) {
add(positionOffset, groupIds.asBlock());
}
}

@Override
public void add(int positionOffset, IntVector groupIds) {
add(positionOffset, groupIds.asBlock());
}
@Override
public void close() {
fail("hashes should not close AddInput");
}
});

@Override
public void close() {
fail("hashes should not close AddInput");
}
});
assertHashState(hash, withNull, ".*?Connected.+?to.*?", ".*?Connection.+?error.*?", ".*?Disconnected.*?");
}
} finally {
page.releaseBlocks();
}

// TODO: randomize and try multiple pages.
// TODO: assert the state of the BlockHash after adding pages. Including the categorizer state.
// TODO: also test the lookup method and other stuff.
// TODO: randomize values? May give wrong results
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I could see randomizing the values in a separate test and asserting that we get expected results/don't crash out.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But later. This is fine now.

// TODO: assert the categorizer state after adding pages.
}

public void testCategorizeRawMultivalue() {
final Page page;
boolean withNull = randomBoolean();
final int positions = 3 + (withNull ? 1 : 0);
try (BytesRefBlock.Builder builder = blockFactory.newBytesRefBlockBuilder(positions)) {
builder.beginPositionEntry();
builder.appendBytesRef(new BytesRef("Connected to 10.1.0.1"));
builder.appendBytesRef(new BytesRef("Connection error"));
builder.appendBytesRef(new BytesRef("Connection error"));
builder.appendBytesRef(new BytesRef("Connection error"));
builder.endPositionEntry();
builder.appendBytesRef(new BytesRef("Disconnected"));
builder.beginPositionEntry();
builder.appendBytesRef(new BytesRef("Connected to 10.1.0.2"));
builder.appendBytesRef(new BytesRef("Connected to 10.1.0.3"));
builder.endPositionEntry();
if (withNull) {
if (randomBoolean()) {
builder.appendNull();
} else {
builder.appendBytesRef(new BytesRef(""));
}
}
page = new Page(builder.build());
}

try (var hash = new CategorizeBlockHash(blockFactory, 0, AggregatorMode.SINGLE, analysisRegistry)) {
for (int i = randomInt(2); i < 3; i++) {
hash.add(page, new GroupingAggregatorFunction.AddInput() {
@Override
public void add(int positionOffset, IntBlock groupIds) {
assertEquals(groupIds.getPositionCount(), positions);

assertThat(groupIds.getFirstValueIndex(0), equalTo(0));
assertThat(groupIds.getValueCount(0), equalTo(4));
assertThat(groupIds.getFirstValueIndex(1), equalTo(4));
assertThat(groupIds.getValueCount(1), equalTo(1));
assertThat(groupIds.getFirstValueIndex(2), equalTo(5));
assertThat(groupIds.getValueCount(2), equalTo(2));

assertEquals(1, groupIds.getInt(0));
assertEquals(2, groupIds.getInt(1));
assertEquals(2, groupIds.getInt(2));
assertEquals(2, groupIds.getInt(3));
assertEquals(3, groupIds.getInt(4));
assertEquals(1, groupIds.getInt(5));
assertEquals(1, groupIds.getInt(6));
if (withNull) {
assertEquals(0, groupIds.getInt(7));
}
}

@Override
public void add(int positionOffset, IntVector groupIds) {
add(positionOffset, groupIds.asBlock());
}

@Override
public void close() {
fail("hashes should not close AddInput");
}
});

assertHashState(hash, withNull, ".*?Connected.+?to.*?", ".*?Connection.+?error.*?", ".*?Disconnected.*?");
}
} finally {
page.releaseBlocks();
}
}

public void testCategorizeIntermediate() {
Expand Down Expand Up @@ -226,18 +299,18 @@ public void close() {
page2.releaseBlocks();
}

try (BlockHash intermediateHash = new CategorizeBlockHash(blockFactory, 0, AggregatorMode.INTERMEDIATE, null)) {
try (var intermediateHash = new CategorizeBlockHash(blockFactory, 0, AggregatorMode.FINAL, null)) {
intermediateHash.add(intermediatePage1, new GroupingAggregatorFunction.AddInput() {
@Override
public void add(int positionOffset, IntBlock groupIds) {
Set<Integer> values = IntStream.range(0, groupIds.getPositionCount())
List<Integer> values = IntStream.range(0, groupIds.getPositionCount())
.map(groupIds::getInt)
.boxed()
.collect(Collectors.toSet());
.collect(Collectors.toList());
if (withNull) {
assertEquals(Set.of(0, 1, 2), values);
assertEquals(List.of(0, 1, 2), values);
} else {
assertEquals(Set.of(1, 2), values);
assertEquals(List.of(1, 2), values);
}
}

Expand All @@ -252,28 +325,39 @@ public void close() {
}
});

intermediateHash.add(intermediatePage2, new GroupingAggregatorFunction.AddInput() {
@Override
public void add(int positionOffset, IntBlock groupIds) {
Set<Integer> values = IntStream.range(0, groupIds.getPositionCount())
.map(groupIds::getInt)
.boxed()
.collect(Collectors.toSet());
// The category IDs {0, 1, 2} should map to groups {0, 2, 3}, because
// 0 matches an existing category (Connected to ...), and the others are new.
assertEquals(Set.of(1, 3, 4), values);
}
for (int i = randomInt(2); i < 3; i++) {
intermediateHash.add(intermediatePage2, new GroupingAggregatorFunction.AddInput() {
@Override
public void add(int positionOffset, IntBlock groupIds) {
List<Integer> values = IntStream.range(0, groupIds.getPositionCount())
.map(groupIds::getInt)
.boxed()
.collect(Collectors.toList());
// The category IDs {1, 2, 3} should map to groups {1, 3, 4}, because
// 1 matches an existing category (Connected to ...), and the others are new.
assertEquals(List.of(3, 1, 4), values);
}

@Override
public void add(int positionOffset, IntVector groupIds) {
add(positionOffset, groupIds.asBlock());
}
@Override
public void add(int positionOffset, IntVector groupIds) {
add(positionOffset, groupIds.asBlock());
}

@Override
public void close() {
fail("hashes should not close AddInput");
}
});
@Override
public void close() {
fail("hashes should not close AddInput");
}
});

assertHashState(
intermediateHash,
withNull,
".*?Connected.+?to.*?",
".*?Connection.+?error.*?",
".*?Disconnected.*?",
".*?System.+?shutdown.*?"
);
}
} finally {
intermediatePage1.releaseBlocks();
intermediatePage2.releaseBlocks();
Expand Down Expand Up @@ -457,4 +541,49 @@ public void testCategorize_withDriver() {
private BlockHash.GroupSpec makeGroupSpec() {
return new BlockHash.GroupSpec(0, ElementType.BYTES_REF, true);
}

private void assertHashState(CategorizeBlockHash hash, boolean withNull, String... expectedKeys) {
// Check the keys
Block[] blocks = null;
try {
blocks = hash.getKeys();
assertThat(blocks, arrayWithSize(1));

var keysBlock = (BytesRefBlock) blocks[0];
assertThat(keysBlock.getPositionCount(), equalTo(expectedKeys.length + (withNull ? 1 : 0)));

if (withNull) {
assertTrue(keysBlock.isNull(0));
}

for (int i = 0; i < expectedKeys.length; i++) {
int position = i + (withNull ? 1 : 0);
String key = keysBlock.getBytesRef(position, new BytesRef()).utf8ToString();
assertThat(key, equalTo(expectedKeys[i]));
}
} finally {
if (blocks != null) {
Releasables.close(blocks);
}
}

// Check the nonEmpty() result
try (IntVector nonEmptyKeys = hash.nonEmpty()) {
int oneIfNull = withNull ? 1 : 0;
assertThat(nonEmptyKeys.getPositionCount(), equalTo(expectedKeys.length + oneIfNull));

for (int i = 0; i < expectedKeys.length + oneIfNull; i++) {
assertThat(nonEmptyKeys.getInt(i), equalTo(i + 1 - oneIfNull));
}
}

// Check seenGroupIds()
try (var seenGroupIds = hash.seenGroupIds(blockFactory.bigArrays())) {
assertThat(seenGroupIds.get(0), equalTo(withNull));

for (int i = 1; i <= expectedKeys.length; i++) {
assertThat(seenGroupIds.get(i), equalTo(true));
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,29 @@ COUNT():long | category:keyword
7 | null
;

on const null
required_capability: categorize_v5

FROM sample_data
| STATS COUNT(), SUM(event_duration) BY category=CATEGORIZE(null)
| SORT category
;

COUNT():long | SUM(event_duration):long | category:keyword
7 | 23231327 | null
;

on null row
required_capability: categorize_v3
ivancea marked this conversation as resolved.
Show resolved Hide resolved

ROW message = null, str = ["a", "b", "c"]
| STATS COUNT(), VALUES(str) BY category=CATEGORIZE(message)
;

COUNT():long | VALUES(str):keyword | category:keyword
1 | [a, b, c] | null
;

filtering out all data
required_capability: categorize_v5

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator;
import org.elasticsearch.xpack.esql.capabilities.Validatable;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.Nullability;
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
Expand Down Expand Up @@ -92,6 +93,12 @@ public boolean foldable() {
return false;
}

@Override
public Nullability nullable() {
// Both nulls and empty strings result in null values
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

++ for the comment, thanks!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To be precisely: any string for which the analyzer returns zero tokens results in a null category. The analyzer discards tokens like numbers, hex.numbers, and stopwords like Jan, Feb, Mon, Tue, ...

return Nullability.TRUE;
}

@Override
public ExpressionEvaluator.Factory toEvaluator(ToEvaluator toEvaluator) {
throw new UnsupportedOperationException("CATEGORIZE is only evaluated during aggregations");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,9 @@ public Expression rule(Expression e) {
if (Expressions.isGuaranteedNull(in.value())) {
return Literal.of(in, null);
}
} else if (e instanceof Alias == false
&& e.nullable() == Nullability.TRUE
} else if (e instanceof Alias == false && e.nullable() == Nullability.TRUE
// Categorize function stays as a STATS grouping (It isn't moved to an early EVAL like other groupings),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The story of null folding is not complete (PropagateEvalFoldables has missing pieces), but I appreciate adding this comment here. I was kind of confused and surprised to see here this out-of-nowhere instanceof Categorize.

// so folding it to null would currently break the plan, as we don't create an attribute/channel for that null value.
&& e instanceof Categorize == false
&& Expressions.anyMatch(e.children(), Expressions::isGuaranteedNull)) {
return Literal.of(e, null);
Expand Down