ESQL: Categorize grouping function testing improvements (#118013)
Added some extra tests on the CategorizeBlockHash.

Added NullFold rule comments, and forced nullable() to TRUE on Categorize.
ivancea authored Dec 9, 2024
1 parent 64e0902 commit 63ee866
Showing 5 changed files with 229 additions and 58 deletions.
x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/Nullability.java
@@ -7,7 +7,18 @@
package org.elasticsearch.xpack.esql.core.expression;

public enum Nullability {
TRUE, // Whether the expression can become null
FALSE, // The expression can never become null
UNKNOWN // Cannot determine if the expression supports possible null folding
/**
* The expression can become null
*/
TRUE,

/**
* The expression can never become null
*/
FALSE,

/**
* Cannot determine if the expression supports possible null folding
*/
UNKNOWN
}
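
For orientation: the FoldNull hunk at the bottom of this commit branches on these values. A minimal, self-contained sketch of the decision they drive (plain Java with hypothetical names, not the actual ESQL classes):

enum Nullability { TRUE, FALSE, UNKNOWN }

class FoldDecisionSketch {
    // Only a provably-nullable expression may be folded to a null literal.
    // UNKNOWN must be handled as conservatively as FALSE: the fold is unsafe
    // unless nullability is known to be TRUE.
    static boolean mayFoldToNull(Nullability n) {
        return n == Nullability.TRUE;
    }

    public static void main(String[] args) {
        System.out.println(mayFoldToNull(Nullability.TRUE));    // true
        System.out.println(mayFoldToNull(Nullability.UNKNOWN)); // false
    }
}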
x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeBlockHashTests.java
@@ -50,11 +50,11 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import static org.elasticsearch.compute.operator.OperatorTestCase.runDriver;
import static org.hamcrest.Matchers.arrayWithSize;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize;

@@ -95,41 +95,114 @@ public void testCategorizeRaw() {
page = new Page(builder.build());
}

try (BlockHash hash = new CategorizeBlockHash(blockFactory, 0, AggregatorMode.INITIAL, analysisRegistry)) {
hash.add(page, new GroupingAggregatorFunction.AddInput() {
@Override
public void add(int positionOffset, IntBlock groupIds) {
assertEquals(groupIds.getPositionCount(), positions);
try (var hash = new CategorizeBlockHash(blockFactory, 0, AggregatorMode.SINGLE, analysisRegistry)) {
for (int i = randomInt(2); i < 3; i++) {
hash.add(page, new GroupingAggregatorFunction.AddInput() {
@Override
public void add(int positionOffset, IntBlock groupIds) {
assertEquals(groupIds.getPositionCount(), positions);

assertEquals(1, groupIds.getInt(0));
assertEquals(2, groupIds.getInt(1));
assertEquals(2, groupIds.getInt(2));
assertEquals(2, groupIds.getInt(3));
assertEquals(3, groupIds.getInt(4));
assertEquals(1, groupIds.getInt(5));
assertEquals(1, groupIds.getInt(6));
if (withNull) {
assertEquals(0, groupIds.getInt(7));
}
}

assertEquals(1, groupIds.getInt(0));
assertEquals(2, groupIds.getInt(1));
assertEquals(2, groupIds.getInt(2));
assertEquals(2, groupIds.getInt(3));
assertEquals(3, groupIds.getInt(4));
assertEquals(1, groupIds.getInt(5));
assertEquals(1, groupIds.getInt(6));
if (withNull) {
assertEquals(0, groupIds.getInt(7));
@Override
public void add(int positionOffset, IntVector groupIds) {
add(positionOffset, groupIds.asBlock());
}
}

@Override
public void add(int positionOffset, IntVector groupIds) {
add(positionOffset, groupIds.asBlock());
}
@Override
public void close() {
fail("hashes should not close AddInput");
}
});

@Override
public void close() {
fail("hashes should not close AddInput");
}
});
assertHashState(hash, withNull, ".*?Connected.+?to.*?", ".*?Connection.+?error.*?", ".*?Disconnected.*?");
}
} finally {
page.releaseBlocks();
}

// TODO: randomize and try multiple pages.
// TODO: assert the state of the BlockHash after adding pages. Including the categorizer state.
// TODO: also test the lookup method and other stuff.
// TODO: randomize values? May give wrong results
// TODO: assert the categorizer state after adding pages.
}

public void testCategorizeRawMultivalue() {
final Page page;
boolean withNull = randomBoolean();
final int positions = 3 + (withNull ? 1 : 0);
try (BytesRefBlock.Builder builder = blockFactory.newBytesRefBlockBuilder(positions)) {
builder.beginPositionEntry();
builder.appendBytesRef(new BytesRef("Connected to 10.1.0.1"));
builder.appendBytesRef(new BytesRef("Connection error"));
builder.appendBytesRef(new BytesRef("Connection error"));
builder.appendBytesRef(new BytesRef("Connection error"));
builder.endPositionEntry();
builder.appendBytesRef(new BytesRef("Disconnected"));
builder.beginPositionEntry();
builder.appendBytesRef(new BytesRef("Connected to 10.1.0.2"));
builder.appendBytesRef(new BytesRef("Connected to 10.1.0.3"));
builder.endPositionEntry();
if (withNull) {
if (randomBoolean()) {
builder.appendNull();
} else {
builder.appendBytesRef(new BytesRef(""));
}
}
page = new Page(builder.build());
}

try (var hash = new CategorizeBlockHash(blockFactory, 0, AggregatorMode.SINGLE, analysisRegistry)) {
for (int i = randomInt(2); i < 3; i++) {
hash.add(page, new GroupingAggregatorFunction.AddInput() {
@Override
public void add(int positionOffset, IntBlock groupIds) {
assertEquals(groupIds.getPositionCount(), positions);

assertThat(groupIds.getFirstValueIndex(0), equalTo(0));
assertThat(groupIds.getValueCount(0), equalTo(4));
assertThat(groupIds.getFirstValueIndex(1), equalTo(4));
assertThat(groupIds.getValueCount(1), equalTo(1));
assertThat(groupIds.getFirstValueIndex(2), equalTo(5));
assertThat(groupIds.getValueCount(2), equalTo(2));

assertEquals(1, groupIds.getInt(0));
assertEquals(2, groupIds.getInt(1));
assertEquals(2, groupIds.getInt(2));
assertEquals(2, groupIds.getInt(3));
assertEquals(3, groupIds.getInt(4));
assertEquals(1, groupIds.getInt(5));
assertEquals(1, groupIds.getInt(6));
if (withNull) {
assertEquals(0, groupIds.getInt(7));
}
}

@Override
public void add(int positionOffset, IntVector groupIds) {
add(positionOffset, groupIds.asBlock());
}

@Override
public void close() {
fail("hashes should not close AddInput");
}
});

assertHashState(hash, withNull, ".*?Connected.+?to.*?", ".*?Connection.+?error.*?", ".*?Disconnected.*?");
}
} finally {
page.releaseBlocks();
}
}

public void testCategorizeIntermediate() {
@@ -226,18 +299,18 @@ public void close() {
page2.releaseBlocks();
}

try (BlockHash intermediateHash = new CategorizeBlockHash(blockFactory, 0, AggregatorMode.INTERMEDIATE, null)) {
try (var intermediateHash = new CategorizeBlockHash(blockFactory, 0, AggregatorMode.FINAL, null)) {
intermediateHash.add(intermediatePage1, new GroupingAggregatorFunction.AddInput() {
@Override
public void add(int positionOffset, IntBlock groupIds) {
Set<Integer> values = IntStream.range(0, groupIds.getPositionCount())
List<Integer> values = IntStream.range(0, groupIds.getPositionCount())
.map(groupIds::getInt)
.boxed()
.collect(Collectors.toSet());
.collect(Collectors.toList());
if (withNull) {
assertEquals(Set.of(0, 1, 2), values);
assertEquals(List.of(0, 1, 2), values);
} else {
assertEquals(Set.of(1, 2), values);
assertEquals(List.of(1, 2), values);
}
}

@@ -252,28 +325,39 @@ public void close() {
}
});

intermediateHash.add(intermediatePage2, new GroupingAggregatorFunction.AddInput() {
@Override
public void add(int positionOffset, IntBlock groupIds) {
Set<Integer> values = IntStream.range(0, groupIds.getPositionCount())
.map(groupIds::getInt)
.boxed()
.collect(Collectors.toSet());
// The category IDs {0, 1, 2} should map to groups {0, 2, 3}, because
// 0 matches an existing category (Connected to ...), and the others are new.
assertEquals(Set.of(1, 3, 4), values);
}
for (int i = randomInt(2); i < 3; i++) {
intermediateHash.add(intermediatePage2, new GroupingAggregatorFunction.AddInput() {
@Override
public void add(int positionOffset, IntBlock groupIds) {
List<Integer> values = IntStream.range(0, groupIds.getPositionCount())
.map(groupIds::getInt)
.boxed()
.collect(Collectors.toList());
// The category IDs {1, 2, 3} should map to groups {1, 3, 4}, because
// 1 matches an existing category (Connected to ...), and the others are new.
assertEquals(List.of(3, 1, 4), values);
}

@Override
public void add(int positionOffset, IntVector groupIds) {
add(positionOffset, groupIds.asBlock());
}
@Override
public void add(int positionOffset, IntVector groupIds) {
add(positionOffset, groupIds.asBlock());
}

@Override
public void close() {
fail("hashes should not close AddInput");
}
});
@Override
public void close() {
fail("hashes should not close AddInput");
}
});

assertHashState(
intermediateHash,
withNull,
".*?Connected.+?to.*?",
".*?Connection.+?error.*?",
".*?Disconnected.*?",
".*?System.+?shutdown.*?"
);
}
} finally {
intermediatePage1.releaseBlocks();
intermediatePage2.releaseBlocks();
@@ -457,4 +541,49 @@ public void testCategorize_withDriver() {
private BlockHash.GroupSpec makeGroupSpec() {
return new BlockHash.GroupSpec(0, ElementType.BYTES_REF, true);
}

private void assertHashState(CategorizeBlockHash hash, boolean withNull, String... expectedKeys) {
// Check the keys
Block[] blocks = null;
try {
blocks = hash.getKeys();
assertThat(blocks, arrayWithSize(1));

var keysBlock = (BytesRefBlock) blocks[0];
assertThat(keysBlock.getPositionCount(), equalTo(expectedKeys.length + (withNull ? 1 : 0)));

if (withNull) {
assertTrue(keysBlock.isNull(0));
}

for (int i = 0; i < expectedKeys.length; i++) {
int position = i + (withNull ? 1 : 0);
String key = keysBlock.getBytesRef(position, new BytesRef()).utf8ToString();
assertThat(key, equalTo(expectedKeys[i]));
}
} finally {
if (blocks != null) {
Releasables.close(blocks);
}
}

// Check the nonEmpty() result
try (IntVector nonEmptyKeys = hash.nonEmpty()) {
int oneIfNull = withNull ? 1 : 0;
assertThat(nonEmptyKeys.getPositionCount(), equalTo(expectedKeys.length + oneIfNull));

for (int i = 0; i < expectedKeys.length + oneIfNull; i++) {
assertThat(nonEmptyKeys.getInt(i), equalTo(i + 1 - oneIfNull));
}
}

// Check seenGroupIds()
try (var seenGroupIds = hash.seenGroupIds(blockFactory.bigArrays())) {
assertThat(seenGroupIds.get(0), equalTo(withNull));

for (int i = 1; i <= expectedKeys.length; i++) {
assertThat(seenGroupIds.get(i), equalTo(true));
}
}
}
}
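
The intermediate tests above hinge on id remapping: the downstream hash re-keys category ids arriving from multiple upstream hashes, so ids that denote the same pattern collapse onto one group. A minimal, self-contained sketch of that idea (hypothetical plain Java, not the real CategorizeBlockHash):

import java.util.LinkedHashMap;
import java.util.Map;

class RemapSketch {
    // Group id 0 is reserved for null inputs; real categories get ids
    // 1, 2, 3, ... in first-seen order, regardless of which upstream
    // hash produced them.
    private final Map<String, Integer> ids = new LinkedHashMap<>();

    int groupId(String categoryKey) {
        return ids.computeIfAbsent(categoryKey, k -> ids.size() + 1);
    }

    public static void main(String[] args) {
        RemapSketch r = new RemapSketch();
        // First intermediate page: two categories -> groups 1 and 2.
        r.groupId("Connected to");      // -> 1
        r.groupId("Connection error");  // -> 2
        // Second page: one repeated and two new categories -> 3, 1, 4,
        // matching the List.of(3, 1, 4) assertion in the test above.
        System.out.println(r.groupId("Disconnected"));     // 3 (new)
        System.out.println(r.groupId("Connected to"));     // 1 (existing)
        System.out.println(r.groupId("System shutdown"));  // 4 (new)
    }
}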
x-pack/plugin/esql/qa/testFixtures/src/main/resources/categorize.csv-spec
@@ -374,6 +374,29 @@ COUNT():long | category:keyword
7 | null
;

on const null
required_capability: categorize_v5

FROM sample_data
| STATS COUNT(), SUM(event_duration) BY category=CATEGORIZE(null)
| SORT category
;

COUNT():long | SUM(event_duration):long | category:keyword
7 | 23231327 | null
;

on null row
required_capability: categorize_v5

ROW message = null, str = ["a", "b", "c"]
| STATS COUNT(), VALUES(str) BY category=CATEGORIZE(message)
;

COUNT():long | VALUES(str):keyword | category:keyword
1 | [a, b, c] | null
;

filtering out all data
required_capability: categorize_v5

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/Categorize.java
@@ -13,6 +13,7 @@
import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator;
import org.elasticsearch.xpack.esql.capabilities.Validatable;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.Nullability;
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
@@ -92,6 +93,12 @@ public boolean foldable() {
return false;
}

@Override
public Nullability nullable() {
// Both nulls and empty strings result in null values
return Nullability.TRUE;
}

@Override
public ExpressionEvaluator.Factory toEvaluator(ToEvaluator toEvaluator) {
throw new UnsupportedOperationException("CATEGORIZE is only evaluated during aggregations");
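
The nullable() override above encodes the contract stated in its comment: both null and empty-string messages produce a null category, as the csv-spec tests earlier in this commit show. A tiny sketch of that contract (hypothetical helper, not the real implementation):

class NullabilityContractSketch {
    // Hypothetical stand-in, not the real categorizer: null and empty messages
    // both yield a null category, which is why Categorize.nullable() must
    // report Nullability.TRUE even when the input expression is never null.
    static String categoryKeyOrNull(String message) {
        if (message == null || message.isEmpty()) {
            return null; // surfaces as a null category key (group id 0)
        }
        return message; // stand-in for the real pattern-matching step
    }
}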
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/FoldNull.java
@@ -41,8 +41,9 @@ public Expression rule(Expression e) {
if (Expressions.isGuaranteedNull(in.value())) {
return Literal.of(in, null);
}
} else if (e instanceof Alias == false
&& e.nullable() == Nullability.TRUE
} else if (e instanceof Alias == false && e.nullable() == Nullability.TRUE
// Categorize function stays as a STATS grouping (it isn't moved to an early EVAL like other groupings),
// so folding it to null would currently break the plan, as we don't create an attribute/channel for that null value.
&& e instanceof Categorize == false
&& Expressions.anyMatch(e.children(), Expressions::isGuaranteedNull)) {
return Literal.of(e, null);
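
The shape of the guarded fold, as a self-contained sketch (hypothetical stand-in types, reusing the Nullability sketch near the top of this commit; the real rule operates on ESQL Expression trees): an expression with a guaranteed-null child is normally folded to a null literal, unless it is a Categorize, which must survive as a STATS grouping.

import java.util.List;

record ExprSketch(boolean isAlias, boolean isCategorize, Nullability nullability,
                  List<ExprSketch> children, boolean guaranteedNull) {

    boolean shouldFoldToNull() {
        return isAlias == false
            && nullability == Nullability.TRUE
            // Categorize must stay a STATS grouping: folding it to a null
            // literal would leave no attribute/channel for the grouping value.
            && isCategorize == false
            && children.stream().anyMatch(ExprSketch::guaranteedNull);
    }
}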
