From 63ee866ed6f7e27ddd3364660c5606ea20ab73e8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Cea=20Fontenla?= Date: Mon, 9 Dec 2024 14:05:48 +0100 Subject: [PATCH] ESQL: Categorize grouping function testing improvements (#118013) Added some extra tests on the CategorizeBlockHash. Added FoldNull rule comments, and forced nullable() to TRUE on Categorize. --- .../esql/core/expression/Nullability.java | 17 +- .../blockhash/CategorizeBlockHashTests.java | 235 ++++++++++++++---- .../src/main/resources/categorize.csv-spec | 23 ++ .../function/grouping/Categorize.java | 7 + .../optimizer/rules/logical/FoldNull.java | 5 +- 5 files changed, 229 insertions(+), 58 deletions(-) diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/Nullability.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/Nullability.java index b08024a707774..d9f136a357208 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/Nullability.java +++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/Nullability.java @@ -7,7 +7,18 @@ package org.elasticsearch.xpack.esql.core.expression; public enum Nullability { - TRUE, // Whether the expression can become null - FALSE, // The expression can never become null - UNKNOWN // Cannot determine if the expression supports possible null folding + /** + * The expression can become null + */ + TRUE, + + /** + * The expression can never become null + */ + FALSE, + + /** + * Cannot determine if the expression supports possible null folding + */ + UNKNOWN } diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeBlockHashTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeBlockHashTests.java index 3c47e85a4a9c8..f8428b7c33568 100644 --- 
a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeBlockHashTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeBlockHashTests.java @@ -50,11 +50,11 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.Set; import java.util.stream.Collectors; import java.util.stream.IntStream; import static org.elasticsearch.compute.operator.OperatorTestCase.runDriver; +import static org.hamcrest.Matchers.arrayWithSize; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; @@ -95,41 +95,114 @@ public void testCategorizeRaw() { page = new Page(builder.build()); } - try (BlockHash hash = new CategorizeBlockHash(blockFactory, 0, AggregatorMode.INITIAL, analysisRegistry)) { - hash.add(page, new GroupingAggregatorFunction.AddInput() { - @Override - public void add(int positionOffset, IntBlock groupIds) { - assertEquals(groupIds.getPositionCount(), positions); + try (var hash = new CategorizeBlockHash(blockFactory, 0, AggregatorMode.SINGLE, analysisRegistry)) { + for (int i = randomInt(2); i < 3; i++) { + hash.add(page, new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntBlock groupIds) { + assertEquals(groupIds.getPositionCount(), positions); + + assertEquals(1, groupIds.getInt(0)); + assertEquals(2, groupIds.getInt(1)); + assertEquals(2, groupIds.getInt(2)); + assertEquals(2, groupIds.getInt(3)); + assertEquals(3, groupIds.getInt(4)); + assertEquals(1, groupIds.getInt(5)); + assertEquals(1, groupIds.getInt(6)); + if (withNull) { + assertEquals(0, groupIds.getInt(7)); + } + } - assertEquals(1, groupIds.getInt(0)); - assertEquals(2, groupIds.getInt(1)); - assertEquals(2, groupIds.getInt(2)); - assertEquals(2, groupIds.getInt(3)); - assertEquals(3, groupIds.getInt(4)); - assertEquals(1, groupIds.getInt(5)); - assertEquals(1, groupIds.getInt(6)); - if 
(withNull) { - assertEquals(0, groupIds.getInt(7)); + @Override + public void add(int positionOffset, IntVector groupIds) { + add(positionOffset, groupIds.asBlock()); } - } - @Override - public void add(int positionOffset, IntVector groupIds) { - add(positionOffset, groupIds.asBlock()); - } + @Override + public void close() { + fail("hashes should not close AddInput"); + } + }); - @Override - public void close() { - fail("hashes should not close AddInput"); - } - }); + assertHashState(hash, withNull, ".*?Connected.+?to.*?", ".*?Connection.+?error.*?", ".*?Disconnected.*?"); + } } finally { page.releaseBlocks(); } - // TODO: randomize and try multiple pages. - // TODO: assert the state of the BlockHash after adding pages. Including the categorizer state. - // TODO: also test the lookup method and other stuff. + // TODO: randomize values? May give wrong results + // TODO: assert the categorizer state after adding pages. + } + + public void testCategorizeRawMultivalue() { + final Page page; + boolean withNull = randomBoolean(); + final int positions = 3 + (withNull ? 
1 : 0); + try (BytesRefBlock.Builder builder = blockFactory.newBytesRefBlockBuilder(positions)) { + builder.beginPositionEntry(); + builder.appendBytesRef(new BytesRef("Connected to 10.1.0.1")); + builder.appendBytesRef(new BytesRef("Connection error")); + builder.appendBytesRef(new BytesRef("Connection error")); + builder.appendBytesRef(new BytesRef("Connection error")); + builder.endPositionEntry(); + builder.appendBytesRef(new BytesRef("Disconnected")); + builder.beginPositionEntry(); + builder.appendBytesRef(new BytesRef("Connected to 10.1.0.2")); + builder.appendBytesRef(new BytesRef("Connected to 10.1.0.3")); + builder.endPositionEntry(); + if (withNull) { + if (randomBoolean()) { + builder.appendNull(); + } else { + builder.appendBytesRef(new BytesRef("")); + } + } + page = new Page(builder.build()); + } + + try (var hash = new CategorizeBlockHash(blockFactory, 0, AggregatorMode.SINGLE, analysisRegistry)) { + for (int i = randomInt(2); i < 3; i++) { + hash.add(page, new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntBlock groupIds) { + assertEquals(groupIds.getPositionCount(), positions); + + assertThat(groupIds.getFirstValueIndex(0), equalTo(0)); + assertThat(groupIds.getValueCount(0), equalTo(4)); + assertThat(groupIds.getFirstValueIndex(1), equalTo(4)); + assertThat(groupIds.getValueCount(1), equalTo(1)); + assertThat(groupIds.getFirstValueIndex(2), equalTo(5)); + assertThat(groupIds.getValueCount(2), equalTo(2)); + + assertEquals(1, groupIds.getInt(0)); + assertEquals(2, groupIds.getInt(1)); + assertEquals(2, groupIds.getInt(2)); + assertEquals(2, groupIds.getInt(3)); + assertEquals(3, groupIds.getInt(4)); + assertEquals(1, groupIds.getInt(5)); + assertEquals(1, groupIds.getInt(6)); + if (withNull) { + assertEquals(0, groupIds.getInt(7)); + } + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + add(positionOffset, groupIds.asBlock()); + } + + @Override + public void close() { + 
fail("hashes should not close AddInput"); + } + }); + + assertHashState(hash, withNull, ".*?Connected.+?to.*?", ".*?Connection.+?error.*?", ".*?Disconnected.*?"); + } + } finally { + page.releaseBlocks(); + } } public void testCategorizeIntermediate() { @@ -226,18 +299,18 @@ public void close() { page2.releaseBlocks(); } - try (BlockHash intermediateHash = new CategorizeBlockHash(blockFactory, 0, AggregatorMode.INTERMEDIATE, null)) { + try (var intermediateHash = new CategorizeBlockHash(blockFactory, 0, AggregatorMode.FINAL, null)) { intermediateHash.add(intermediatePage1, new GroupingAggregatorFunction.AddInput() { @Override public void add(int positionOffset, IntBlock groupIds) { - Set values = IntStream.range(0, groupIds.getPositionCount()) + List values = IntStream.range(0, groupIds.getPositionCount()) .map(groupIds::getInt) .boxed() - .collect(Collectors.toSet()); + .collect(Collectors.toList()); if (withNull) { - assertEquals(Set.of(0, 1, 2), values); + assertEquals(List.of(0, 1, 2), values); } else { - assertEquals(Set.of(1, 2), values); + assertEquals(List.of(1, 2), values); } } @@ -252,28 +325,39 @@ public void close() { } }); - intermediateHash.add(intermediatePage2, new GroupingAggregatorFunction.AddInput() { - @Override - public void add(int positionOffset, IntBlock groupIds) { - Set values = IntStream.range(0, groupIds.getPositionCount()) - .map(groupIds::getInt) - .boxed() - .collect(Collectors.toSet()); - // The category IDs {0, 1, 2} should map to groups {0, 2, 3}, because - // 0 matches an existing category (Connected to ...), and the others are new. 
- assertEquals(Set.of(1, 3, 4), values); - } + for (int i = randomInt(2); i < 3; i++) { + intermediateHash.add(intermediatePage2, new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntBlock groupIds) { + List values = IntStream.range(0, groupIds.getPositionCount()) + .map(groupIds::getInt) + .boxed() + .collect(Collectors.toList()); + // The category IDs {1, 2, 3} should map to groups {1, 3, 4}, because + // 1 matches an existing category (Connected to ...), and the others are new. + assertEquals(List.of(3, 1, 4), values); + } - @Override - public void add(int positionOffset, IntVector groupIds) { - add(positionOffset, groupIds.asBlock()); - } + @Override + public void add(int positionOffset, IntVector groupIds) { + add(positionOffset, groupIds.asBlock()); + } - @Override - public void close() { - fail("hashes should not close AddInput"); - } - }); + @Override + public void close() { + fail("hashes should not close AddInput"); + } + }); + + assertHashState( + intermediateHash, + withNull, + ".*?Connected.+?to.*?", + ".*?Connection.+?error.*?", + ".*?Disconnected.*?", + ".*?System.+?shutdown.*?" + ); + } } finally { intermediatePage1.releaseBlocks(); intermediatePage2.releaseBlocks(); @@ -457,4 +541,49 @@ public void testCategorize_withDriver() { private BlockHash.GroupSpec makeGroupSpec() { return new BlockHash.GroupSpec(0, ElementType.BYTES_REF, true); } + + private void assertHashState(CategorizeBlockHash hash, boolean withNull, String... expectedKeys) { + // Check the keys + Block[] blocks = null; + try { + blocks = hash.getKeys(); + assertThat(blocks, arrayWithSize(1)); + + var keysBlock = (BytesRefBlock) blocks[0]; + assertThat(keysBlock.getPositionCount(), equalTo(expectedKeys.length + (withNull ? 1 : 0))); + + if (withNull) { + assertTrue(keysBlock.isNull(0)); + } + + for (int i = 0; i < expectedKeys.length; i++) { + int position = i + (withNull ? 
1 : 0); + String key = keysBlock.getBytesRef(position, new BytesRef()).utf8ToString(); + assertThat(key, equalTo(expectedKeys[i])); + } + } finally { + if (blocks != null) { + Releasables.close(blocks); + } + } + + // Check the nonEmpty() result + try (IntVector nonEmptyKeys = hash.nonEmpty()) { + int oneIfNull = withNull ? 1 : 0; + assertThat(nonEmptyKeys.getPositionCount(), equalTo(expectedKeys.length + oneIfNull)); + + for (int i = 0; i < expectedKeys.length + oneIfNull; i++) { + assertThat(nonEmptyKeys.getInt(i), equalTo(i + 1 - oneIfNull)); + } + } + + // Check seenGroupIds() + try (var seenGroupIds = hash.seenGroupIds(blockFactory.bigArrays())) { + assertThat(seenGroupIds.get(0), equalTo(withNull)); + + for (int i = 1; i <= expectedKeys.length; i++) { + assertThat(seenGroupIds.get(i), equalTo(true)); + } + } + } } diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/categorize.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/categorize.csv-spec index 804c1c56a1eb5..4ce43961a7077 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/categorize.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/categorize.csv-spec @@ -374,6 +374,29 @@ COUNT():long | category:keyword 7 | null ; +on const null +required_capability: categorize_v5 + +FROM sample_data + | STATS COUNT(), SUM(event_duration) BY category=CATEGORIZE(null) + | SORT category +; + +COUNT():long | SUM(event_duration):long | category:keyword + 7 | 23231327 | null +; + +on null row +required_capability: categorize_v5 + +ROW message = null, str = ["a", "b", "c"] +| STATS COUNT(), VALUES(str) BY category=CATEGORIZE(message) +; + +COUNT():long | VALUES(str):keyword | category:keyword + 1 | [a, b, c] | null +; + filtering out all data required_capability: categorize_v5 diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/Categorize.java 
b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/Categorize.java index e2c04ecb15b59..ded913a78bdf1 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/Categorize.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/Categorize.java @@ -13,6 +13,7 @@ import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator; import org.elasticsearch.xpack.esql.capabilities.Validatable; import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.Nullability; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; @@ -92,6 +93,12 @@ public boolean foldable() { return false; } + @Override + public Nullability nullable() { + // Both nulls and empty strings result in null values + return Nullability.TRUE; + } + @Override public ExpressionEvaluator.Factory toEvaluator(ToEvaluator toEvaluator) { throw new UnsupportedOperationException("CATEGORIZE is only evaluated during aggregations"); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/FoldNull.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/FoldNull.java index 4f97bf60bd863..747864625e65c 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/FoldNull.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/FoldNull.java @@ -41,8 +41,9 @@ public Expression rule(Expression e) { if (Expressions.isGuaranteedNull(in.value())) { return Literal.of(in, null); } - } else if (e instanceof Alias == false - && e.nullable() == Nullability.TRUE + } else if (e instanceof Alias == false && e.nullable() == Nullability.TRUE + // Categorize function stays as 
a STATS grouping (It isn't moved to an early EVAL like other groupings), + // so folding it to null would currently break the plan, as we don't create an attribute/channel for that null value. && e instanceof Categorize == false && Expressions.anyMatch(e.children(), Expressions::isGuaranteedNull)) { return Literal.of(e, null);