From 7c86a22a4e9cd8a3c0549f5fb1eaeef5eeb6e527 Mon Sep 17 00:00:00 2001 From: Aurelien FOUCRET Date: Thu, 24 Oct 2024 14:08:03 +0200 Subject: [PATCH] Adding text extraction unit tests. --- .../xpack/kql/parser/ParserUtils.java | 87 ++++---- .../parser/KqlParserFieldlessQueryTests.java | 4 +- .../xpack/kql/parser/ParserUtilsTests.java | 211 ++++++++++++++++++ 3 files changed, 253 insertions(+), 49 deletions(-) create mode 100644 x-pack/plugin/kql/src/test/java/org/elasticsearch/xpack/kql/parser/ParserUtilsTests.java diff --git a/x-pack/plugin/kql/src/main/java/org/elasticsearch/xpack/kql/parser/ParserUtils.java b/x-pack/plugin/kql/src/main/java/org/elasticsearch/xpack/kql/parser/ParserUtils.java index 92ed45b972a49..7b07bd9188ad0 100644 --- a/x-pack/plugin/kql/src/main/java/org/elasticsearch/xpack/kql/parser/ParserUtils.java +++ b/x-pack/plugin/kql/src/main/java/org/elasticsearch/xpack/kql/parser/ParserUtils.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.kql.parser; import org.antlr.v4.runtime.ParserRuleContext; +import org.antlr.v4.runtime.Token; import org.antlr.v4.runtime.tree.ParseTree; import org.antlr.v4.runtime.tree.ParseTreeVisitor; import org.antlr.v4.runtime.tree.TerminalNode; @@ -55,7 +56,22 @@ public static String extractText(ParserRuleContext ctx) { return String.join(UNQUOTED_LITERAL_TERM_DELIMITER, extractTextTokems(ctx)); } - public static List extractTextTokems(ParserRuleContext ctx) { + public static boolean hasWildcard(ParserRuleContext ctx) { + return ctx.children.stream().anyMatch(childNode -> { + if (childNode instanceof TerminalNode terminalNode) { + Token token = terminalNode.getSymbol(); + return switch (token.getType()) { + case KqlBaseParser.WILDCARD -> true; + case KqlBaseParser.UNQUOTED_LITERAL -> token.getText().matches("[^\\\\]*[*].*"); + default -> false; + }; + } + + return false; + }); + } + + private static List extractTextTokems(ParserRuleContext ctx) { assert ctx.children != null; List textTokens = new ArrayList<>(ctx.children.size()); @@ -71,7 +87,7 @@ public static List extractTextTokems(ParserRuleContext ctx) { return textTokens; } - public static String extractText(TerminalNode node) { + private static String extractText(TerminalNode node) { if (node.getSymbol().getType() == KqlBaseParser.QUOTED_STRING) { return unescapeQuotedString(node); } else if (node.getSymbol().getType() == KqlBaseParser.UNQUOTED_LITERAL) { @@ -81,45 +97,27 @@ public static String extractText(TerminalNode node) { return node.getText(); } - public static boolean hasWildcard(ParserRuleContext ctx) { - return ctx.children.stream().anyMatch(childNode -> { - if (childNode instanceof TerminalNode terminalNode) { - return switch (terminalNode.getSymbol().getType()) { - case KqlBaseParser.WILDCARD -> true; - case KqlBaseParser.UNQUOTED_LITERAL -> terminalNode.getText().matches("[^\\\\]*[*].*"); - default -> false; - }; - } - - return false; - }); - } - private static String unescapeQuotedString(TerminalNode ctx) { String inputText = ctx.getText(); assert inputText.length() >= 2 && inputText.charAt(0) == QUOTE_CHAR && inputText.charAt(inputText.length() - 1) == QUOTE_CHAR; StringBuilder sb = new StringBuilder(); - for (int i = 1; i < inputText.length() - 1;) { + for (int i = 1; i < inputText.length() - 1; i++) { if (inputText.charAt(i) == ESCAPE_CHAR && i + 1 < inputText.length()) { switch (inputText.charAt(++i)) { case 't' -> sb.append('\t'); - case 'b' -> sb.append('\b'); - case 'f' -> sb.append('\f'); case 'n' -> sb.append('\n'); case 'r' -> sb.append('\r'); case '"' -> sb.append('\"'); - case '\'' -> sb.append('\''); - case 'u' -> i = handleUnicodePoints(ctx, sb, inputText, ++i); + case 'u' -> i = handleUnicodePoints(ctx, sb, inputText, i); case '\\' -> sb.append('\\'); - default -> { + default -> // For quoted strings, unknown escape sequences are passed through as-is - sb.append(ESCAPE_CHAR).append(inputText.charAt(i++)); - } + sb.append(ESCAPE_CHAR).append(inputText.charAt(i)); } } else { - sb.append(inputText.charAt(i++)); + sb.append(inputText.charAt(i)); } } @@ -138,25 +136,20 @@ private static String unescapeUnquotedLiteral(TerminalNode ctx) { char currentChar = inputText.charAt(i); if (currentChar == '\\' && i + 1 < inputText.length()) { - switch (inputText.charAt(++i)) { - case 't' -> sb.append('\t'); - case 'b' -> sb.append('\b'); - case 'f' -> sb.append('\f'); - case 'n' -> sb.append('\n'); - case 'r' -> sb.append('\r'); - case '"' -> sb.append('\"'); - case '\'' -> sb.append('\''); - case 'u' -> i = handleUnicodePoints(ctx, sb, inputText, ++i); - case '\\' -> sb.append('\\'); - case '(', ')', ':', '<', '>', '*', '{', '}' -> sb.append(inputText.charAt(i++)); - default -> { - if (isEscapedKeywordSequence(inputText, i)) { - String sequence = handleKeywordSequence(inputText, i); - sb.append(sequence); - i += sequence.length(); - } else { - sb.append('\\').append(inputText.charAt(i++)); - } + if (isEscapedKeywordSequence(inputText, ++i)) { + String sequence = handleKeywordSequence(inputText, i); + sb.append(sequence); + i += sequence.length(); + } else { + switch (currentChar = inputText.charAt(i++)) { + case 't' -> sb.append('\t'); + case 'n' -> sb.append('\n'); + case 'r' -> sb.append('\r'); + case '"' -> sb.append('\"'); + case 'u' -> i = handleUnicodePoints(ctx, sb, inputText, i); + case '\\' -> sb.append('\\'); + case '(', ')', ':', '<', '>', '*', '{', '}' -> sb.append(currentChar); + default -> sb.append('\\').append(currentChar); } } } else { @@ -177,9 +170,9 @@ private static boolean isEscapedKeywordSequence(String input, int startIndex) { private static String handleKeywordSequence(String input, int startIndex) { String remaining = input.substring(startIndex); - if (Strings.toRootLowerCase(remaining).startsWith("and")) return remaining.substring(0, 2); - if (Strings.toRootLowerCase(remaining).startsWith("or")) return remaining.substring(0, 1); - if (Strings.toRootLowerCase(remaining).startsWith("not")) return remaining.substring(0, 2); + if (Strings.toRootLowerCase(remaining).startsWith("and")) return remaining.substring(0, 3); + if (Strings.toRootLowerCase(remaining).startsWith("or")) return remaining.substring(0, 2); + if (Strings.toRootLowerCase(remaining).startsWith("not")) return remaining.substring(0, 3); return ""; } diff --git a/x-pack/plugin/kql/src/test/java/org/elasticsearch/xpack/kql/parser/KqlParserFieldlessQueryTests.java b/x-pack/plugin/kql/src/test/java/org/elasticsearch/xpack/kql/parser/KqlParserFieldlessQueryTests.java index 2eb91ac1a05b0..1fe0c6effa89a 100644 --- a/x-pack/plugin/kql/src/test/java/org/elasticsearch/xpack/kql/parser/KqlParserFieldlessQueryTests.java +++ b/x-pack/plugin/kql/src/test/java/org/elasticsearch/xpack/kql/parser/KqlParserFieldlessQueryTests.java @@ -31,9 +31,9 @@ public void testMatchPhraseQuery() { assertMultiMatchQuery(parseKqlQuery("\"foo\""), "foo", MultiMatchQueryBuilder.Type.PHRASE); // Multiple words assertMultiMatchQuery(parseKqlQuery("\"foo bar\""), "foo bar", MultiMatchQueryBuilder.Type.PHRASE); - // Containing unescaped language reserved keyword + // Containing unescaped KQL reserved keyword assertMultiMatchQuery(parseKqlQuery("\"not foo and bar or baz\""), "not foo and bar or baz", MultiMatchQueryBuilder.Type.PHRASE); - // Containing unescaped language reserved characters + // Containing unescaped KQL reserved characters assertMultiMatchQuery(parseKqlQuery("\"foo*: {(})\""), "foo*: {(})", MultiMatchQueryBuilder.Type.PHRASE); } diff --git a/x-pack/plugin/kql/src/test/java/org/elasticsearch/xpack/kql/parser/ParserUtilsTests.java b/x-pack/plugin/kql/src/test/java/org/elasticsearch/xpack/kql/parser/ParserUtilsTests.java new file mode 100644 index 0000000000000..f5f68c393cb41 --- /dev/null +++ b/x-pack/plugin/kql/src/test/java/org/elasticsearch/xpack/kql/parser/ParserUtilsTests.java @@ -0,0 +1,211 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.kql.parser; + +import org.antlr.v4.runtime.ParserRuleContext; +import org.antlr.v4.runtime.Token; +import org.antlr.v4.runtime.tree.ParseTree; +import org.antlr.v4.runtime.tree.TerminalNode; +import org.antlr.v4.runtime.tree.TerminalNodeImpl; +import org.elasticsearch.test.ESTestCase; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.stream.Stream; + +import static org.elasticsearch.core.Strings.format; +import static org.elasticsearch.xpack.kql.parser.KqlBaseParser.QUOTED_STRING; +import static org.elasticsearch.xpack.kql.parser.KqlBaseParser.UNQUOTED_LITERAL; +import static org.elasticsearch.xpack.kql.parser.KqlBaseParser.WILDCARD; +import static org.elasticsearch.xpack.kql.parser.ParserUtils.extractText; +import static org.elasticsearch.xpack.kql.parser.ParserUtils.hasWildcard; +import static org.hamcrest.Matchers.equalTo; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class ParserUtilsTests extends ESTestCase { + + public void testExtractTestWithQuotedString() { + // General case + assertThat(extractText(parserRuleContext(quotedStringNode("foo"))), equalTo("foo")); + + // Whitespaces are preserved + assertThat(extractText(parserRuleContext(quotedStringNode(" foo bar "))), equalTo(" foo bar ")); + + // Quoted string does not need escaping for KQL keywords (and, or, ...) + assertThat(extractText(parserRuleContext(quotedStringNode("not foo and bar or baz"))), equalTo("not foo and bar or baz")); + + // Quoted string does not need escaping for KQL special chars (e.g: '{', ':', ...) + assertThat(extractText(parserRuleContext(quotedStringNode("foo*:'\u3000{(})"))), equalTo("foo*:'\u3000{(})")); + + // Escaped characters handling + assertThat(extractText(parserRuleContext(quotedStringNode("\\\\"))), equalTo("\\")); + assertThat(extractText(parserRuleContext(quotedStringNode("foo\\\\bar"))), equalTo("foo\\bar")); + assertThat(extractText(parserRuleContext(quotedStringNode("foo\\\\"))), equalTo("foo\\")); + assertThat(extractText(parserRuleContext(quotedStringNode("\\\\foo"))), equalTo("\\foo")); + + assertThat(extractText(parserRuleContext(quotedStringNode("\\\""))), equalTo("\"")); + assertThat(extractText(parserRuleContext(quotedStringNode("foo\\\"bar"))), equalTo("foo\"bar")); + assertThat(extractText(parserRuleContext(quotedStringNode("foo\\\""))), equalTo("foo\"")); + assertThat(extractText(parserRuleContext(quotedStringNode("\\\"foo"))), equalTo("\"foo")); + + assertThat(extractText(parserRuleContext(quotedStringNode("\\t"))), equalTo("\t")); + assertThat(extractText(parserRuleContext(quotedStringNode("foo\\tbar"))), equalTo("foo\tbar")); + assertThat(extractText(parserRuleContext(quotedStringNode("foo\\t"))), equalTo("foo\t")); + assertThat(extractText(parserRuleContext(quotedStringNode("\\tfoo"))), equalTo("\tfoo")); + + assertThat(extractText(parserRuleContext(quotedStringNode("\\n"))), equalTo("\n")); + assertThat(extractText(parserRuleContext(quotedStringNode("foo\\nbar"))), equalTo("foo\nbar")); + assertThat(extractText(parserRuleContext(quotedStringNode("foo\\n"))), equalTo("foo\n")); + assertThat(extractText(parserRuleContext(quotedStringNode("\\nfoo"))), equalTo("\nfoo")); + + assertThat(extractText(parserRuleContext(quotedStringNode("\\r"))), equalTo("\r")); + assertThat(extractText(parserRuleContext(quotedStringNode("foo\\rbar"))), equalTo("foo\rbar")); + assertThat(extractText(parserRuleContext(quotedStringNode("foo\\r"))), equalTo("foo\r")); + assertThat(extractText(parserRuleContext(quotedStringNode("\\rfoo"))), equalTo("\rfoo")); + + // Unicode characters handling + assertThat(extractText(parserRuleContext(quotedStringNode(format("\u0041")))), equalTo("A")); + assertThat(extractText(parserRuleContext(quotedStringNode(format("foo\u0041bar")))), equalTo("fooAbar")); + assertThat(extractText(parserRuleContext(quotedStringNode(format("foo\u0041")))), equalTo("fooA")); + assertThat(extractText(parserRuleContext(quotedStringNode(format("\u0041foo")))), equalTo("Afoo")); + } + + public void testExtractTestWithUnquotedLiteral() { + // General case + assertThat(extractText(parserRuleContext(literalNode("foo"))), equalTo("foo")); + + // KQL keywords unescaping + assertThat(extractText(parserRuleContext(literalNode("\\not foo \\and bar \\or baz"))), equalTo("not foo and bar or baz")); + assertThat( + extractText(parserRuleContext(literalNode("\\\\not foo \\\\and bar \\\\or baz"))), + equalTo("\\not foo \\and bar \\or baz") + ); + + // Escaped characters handling + assertThat(extractText(parserRuleContext(literalNode("\\\\"))), equalTo("\\")); + assertThat(extractText(parserRuleContext(literalNode("foo\\\\bar"))), equalTo("foo\\bar")); + assertThat(extractText(parserRuleContext(literalNode("foo\\\\"))), equalTo("foo\\")); + assertThat(extractText(parserRuleContext(literalNode("\\\\foo"))), equalTo("\\foo")); + + assertThat(extractText(parserRuleContext(literalNode("\\\""))), equalTo("\"")); + assertThat(extractText(parserRuleContext(literalNode("foo\\\"bar"))), equalTo("foo\"bar")); + assertThat(extractText(parserRuleContext(literalNode("foo\\\""))), equalTo("foo\"")); + assertThat(extractText(parserRuleContext(literalNode("\\\"foo"))), equalTo("\"foo")); + + assertThat(extractText(parserRuleContext(literalNode("\\t"))), equalTo("\t")); + assertThat(extractText(parserRuleContext(literalNode("foo\\tbar"))), equalTo("foo\tbar")); + assertThat(extractText(parserRuleContext(literalNode("foo\\t"))), equalTo("foo\t")); + assertThat(extractText(parserRuleContext(literalNode("\\tfoo"))), equalTo("\tfoo")); + + assertThat(extractText(parserRuleContext(literalNode("\\n"))), equalTo("\n")); + assertThat(extractText(parserRuleContext(literalNode("foo\\nbar"))), equalTo("foo\nbar")); + assertThat(extractText(parserRuleContext(literalNode("foo\\n"))), equalTo("foo\n")); + assertThat(extractText(parserRuleContext(literalNode("\\nfoo"))), equalTo("\nfoo")); + + assertThat(extractText(parserRuleContext(literalNode("\\r"))), equalTo("\r")); + assertThat(extractText(parserRuleContext(literalNode("foo\\rbar"))), equalTo("foo\rbar")); + assertThat(extractText(parserRuleContext(literalNode("foo\\r"))), equalTo("foo\r")); + assertThat(extractText(parserRuleContext(literalNode("\\rfoo"))), equalTo("\rfoo")); + + for (String escapedChar : List.of("(", ")", ":", "<", ">", "*", "{", "}")) { + assertThat(extractText(parserRuleContext(literalNode(format("\\%s", escapedChar)))), equalTo(escapedChar)); + assertThat( + extractText(parserRuleContext(literalNode(format("foo\\%sbar", escapedChar)))), + equalTo(format("foo%sbar", escapedChar)) + ); + assertThat(extractText(parserRuleContext(literalNode(format("foo\\%s", escapedChar)))), equalTo(format("foo%s", escapedChar))); + assertThat(extractText(parserRuleContext(literalNode(format("\\%sfoo", escapedChar)))), equalTo(format("%sfoo", escapedChar))); + } + + // Unicode characters handling + assertThat(extractText(parserRuleContext(literalNode(format("\u0041")))), equalTo("A")); + assertThat(extractText(parserRuleContext(literalNode(format("foo\u0041bar")))), equalTo("fooAbar")); + assertThat(extractText(parserRuleContext(literalNode(format("foo\u0041")))), equalTo("fooA")); + assertThat(extractText(parserRuleContext(literalNode(format("\u0041foo")))), equalTo("Afoo")); + } + + public void testHasWildcard() { + { + // No children + assertFalse(hasWildcard(parserRuleContext(List.of()))); + + // Lone wildcard + assertTrue(hasWildcard(parserRuleContext(wildcardNode()))); + assertTrue(hasWildcard(parserRuleContext(randomTextNodeListWithNode(wildcardNode())))); + + // All children are literals + assertFalse(hasWildcard(parserRuleContext(randomList(1, randomInt(100), this::randomLiteralNode)))); + + // Quoted string + assertFalse(hasWildcard(parserRuleContext(randomQuotedStringNode()))); + + // Literal node containing the wildcard character + assertTrue(hasWildcard(parserRuleContext(terminalNode(UNQUOTED_LITERAL, "f*oo")))); + assertTrue(hasWildcard(parserRuleContext(terminalNode(UNQUOTED_LITERAL, "*foo")))); + assertTrue(hasWildcard(parserRuleContext(terminalNode(UNQUOTED_LITERAL, "foo*")))); + + // Literal node containing the wildcard characters (escaped) + assertFalse(hasWildcard(parserRuleContext(terminalNode(UNQUOTED_LITERAL, "f\\*oo")))); + assertFalse(hasWildcard(parserRuleContext(terminalNode(UNQUOTED_LITERAL, "\\*foo")))); + assertFalse(hasWildcard(parserRuleContext(terminalNode(UNQUOTED_LITERAL, "foo\\*")))); + } + } + + private ParserRuleContext parserRuleContext(ParseTree child) { + return parserRuleContext(List.of(child)); + } + + private ParserRuleContext parserRuleContext(List children) { + ParserRuleContext ctx = new ParserRuleContext(null, randomInt()); + ctx.children = children; + return ctx; + } + + private TerminalNode terminalNode(int type, String text) { + Token symbol = mock(Token.class); + when(symbol.getType()).thenReturn(type); + when(symbol.getText()).thenReturn(text); + return new TerminalNodeImpl(symbol); + } + + private List randomTextNodeListWithNode(TerminalNode node) { + List nodes = new ArrayList<>(Stream.concat(Stream.generate(this::randomTextNode).limit(100), Stream.of(node)).toList()); + Collections.shuffle(nodes, random()); + return nodes; + } + + private TerminalNode randomTextNode() { + return switch (randomInt() % 3) { + case 0 -> wildcardNode(); + case 1 -> randomQuotedStringNode(); + default -> randomLiteralNode(); + }; + } + + private TerminalNode quotedStringNode(String quotedStringText) { + return terminalNode(QUOTED_STRING, "\"" + quotedStringText + "\""); + } + + private TerminalNode randomQuotedStringNode() { + return quotedStringNode(randomIdentifier()); + } + + private TerminalNode literalNode(String literalText) { + return terminalNode(UNQUOTED_LITERAL, literalText); + } + + private TerminalNode randomLiteralNode() { + return terminalNode(UNQUOTED_LITERAL, randomIdentifier()); + } + + private TerminalNode wildcardNode() { + return terminalNode(WILDCARD, "*"); + } +}