From 74e1f87b09b8347217dad30b29f967b3a1eb619a Mon Sep 17 00:00:00 2001 From: Aurelien FOUCRET Date: Fri, 25 Oct 2024 15:05:08 +0200 Subject: [PATCH] Implements test parsing for the KQL parser. --- .../xpack/kql/parser/ParserUtils.java | 211 +++++++++++++ .../xpack/kql/parser/ParserUtilsTests.java | 277 ++++++++++++++++++ 2 files changed, 488 insertions(+) create mode 100644 x-pack/plugin/kql/src/main/java/org/elasticsearch/xpack/kql/parser/ParserUtils.java create mode 100644 x-pack/plugin/kql/src/test/java/org/elasticsearch/xpack/kql/parser/ParserUtilsTests.java diff --git a/x-pack/plugin/kql/src/main/java/org/elasticsearch/xpack/kql/parser/ParserUtils.java b/x-pack/plugin/kql/src/main/java/org/elasticsearch/xpack/kql/parser/ParserUtils.java new file mode 100644 index 0000000000000..c1ddf18deb572 --- /dev/null +++ b/x-pack/plugin/kql/src/main/java/org/elasticsearch/xpack/kql/parser/ParserUtils.java @@ -0,0 +1,211 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.kql.parser; + +import org.antlr.v4.runtime.ParserRuleContext; +import org.antlr.v4.runtime.Token; +import org.antlr.v4.runtime.tree.ParseTree; +import org.antlr.v4.runtime.tree.ParseTreeVisitor; +import org.antlr.v4.runtime.tree.TerminalNode; +import org.apache.logging.log4j.util.Strings; +import org.apache.lucene.queryparser.classic.QueryParser; + +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +public final class ParserUtils { + + private static final String UNQUOTED_LITERAL_TERM_DELIMITER = " "; + private static final char ESCAPE_CHAR = '\\'; + private static final char QUOTE_CHAR = '"'; + + private ParserUtils() { + + } + + @SuppressWarnings("unchecked") + public static T typedParsing(ParseTreeVisitor visitor, ParserRuleContext ctx, Class type) { + Object result = ctx.accept(visitor); + + if (type.isInstance(result)) { + return (T) result; + } + + throw new KqlParsingException( + "Invalid query '{}'[{}] given; expected {} but found {}", + ctx.start.getLine(), + ctx.start.getCharPositionInLine(), + ctx.getText(), + ctx.getClass().getSimpleName(), + type.getSimpleName(), + (result != null ? result.getClass().getSimpleName() : "null") + ); + } + + public static String extractText(ParserRuleContext ctx) { + return String.join(UNQUOTED_LITERAL_TERM_DELIMITER, extractTextTokems(ctx)); + } + + public static boolean hasWildcard(ParserRuleContext ctx) { + return ctx.children.stream().anyMatch(childNode -> { + if (childNode instanceof TerminalNode terminalNode) { + Token token = terminalNode.getSymbol(); + return switch (token.getType()) { + case KqlBaseParser.WILDCARD -> true; + case KqlBaseParser.UNQUOTED_LITERAL -> token.getText().matches("[^\\\\]*[*].*"); + default -> false; + }; + } + + return false; + }); + } + + public static String escapeQueryString(String queryText, boolean preseveWildcards) { + if (preseveWildcards) { + return Stream.of(queryText.split("[*]]")).map(QueryParser::escape).collect(Collectors.joining("*")); + } + + return QueryParser.escape(queryText); + } + + private static List extractTextTokems(ParserRuleContext ctx) { + assert ctx.children != null; + List textTokens = new ArrayList<>(ctx.children.size()); + + for (ParseTree currentNode : ctx.children) { + if (currentNode instanceof TerminalNode terminalNode) { + textTokens.add(extractText(terminalNode)); + } else { + throw new KqlParsingException("Unable to extract text from ctx", ctx.start.getLine(), ctx.start.getCharPositionInLine()); + } + } + + return textTokens; + } + + private static String extractText(TerminalNode node) { + if (node.getSymbol().getType() == KqlBaseParser.QUOTED_STRING) { + return unescapeQuotedString(node); + } else if (node.getSymbol().getType() == KqlBaseParser.UNQUOTED_LITERAL) { + return unescapeUnquotedLiteral(node); + } + + return node.getText(); + } + + private static String unescapeQuotedString(TerminalNode ctx) { + String inputText = ctx.getText(); + + assert inputText.length() >= 2 && inputText.charAt(0) == QUOTE_CHAR && inputText.charAt(inputText.length() - 1) == QUOTE_CHAR; + StringBuilder sb = new StringBuilder(); + + for (int i = 1; i < inputText.length() - 1;) { + char currentChar = inputText.charAt(i++); + if (currentChar == ESCAPE_CHAR && i + 1 < inputText.length()) { + currentChar = inputText.charAt(i++); + switch (currentChar) { + case 't' -> sb.append('\t'); + case 'n' -> sb.append('\n'); + case 'r' -> sb.append('\r'); + case 'u' -> i = handleUnicodeSequemce(ctx, sb, inputText, i); + case QUOTE_CHAR -> sb.append('\"'); + case ESCAPE_CHAR -> sb.append(ESCAPE_CHAR); + default -> sb.append(ESCAPE_CHAR).append(currentChar); + } + } else { + sb.append(currentChar); + } + } + + return sb.toString(); + } + + private static String unescapeUnquotedLiteral(TerminalNode ctx) { + String inputText = ctx.getText(); + + if (inputText == null || inputText.isEmpty()) { + return inputText; + } + StringBuilder sb = new StringBuilder(inputText.length()); + + for (int i = 0; i < inputText.length();) { + char currentChar = inputText.charAt(i++); + if (currentChar == ESCAPE_CHAR && i < inputText.length()) { + if (isEscapedKeywordSequence(inputText, i)) { + String sequence = handleKeywordSequence(inputText, i); + sb.append(sequence); + i += sequence.length(); + } else { + currentChar = inputText.charAt(i++); + switch (currentChar) { + case 't' -> sb.append('\t'); + case 'n' -> sb.append('\n'); + case 'r' -> sb.append('\r'); + case 'u' -> i = handleUnicodeSequemce(ctx, sb, inputText, i); + case QUOTE_CHAR -> sb.append('\"'); + case ESCAPE_CHAR -> sb.append(ESCAPE_CHAR); + case '(', ')', ':', '<', '>', '*', '{', '}' -> sb.append(currentChar); + default -> sb.append(ESCAPE_CHAR).append(currentChar); + } + } + } else { + sb.append(currentChar); + } + } + + return sb.toString(); + } + + private static boolean isEscapedKeywordSequence(String input, int startIndex) { + if (startIndex + 1 >= input.length()) { + return false; + } + String remaining = Strings.toRootLowerCase(input.substring(startIndex)); + return remaining.startsWith("and") || remaining.startsWith("or") || remaining.startsWith("not"); + } + + private static String handleKeywordSequence(String input, int startIndex) { + String remaining = input.substring(startIndex); + if (Strings.toRootLowerCase(remaining).startsWith("and")) return remaining.substring(0, 3); + if (Strings.toRootLowerCase(remaining).startsWith("or")) return remaining.substring(0, 2); + if (Strings.toRootLowerCase(remaining).startsWith("not")) return remaining.substring(0, 3); + return ""; + } + + private static int handleUnicodeSequemce(TerminalNode ctx, StringBuilder sb, String text, int startIdx) { + int endIdx = startIdx + 4; + String hex = text.substring(startIdx, endIdx); + + try { + int code = Integer.parseInt(hex, 16); + + if (code >= 0xD800 && code <= 0xDFFF) { + // U+D800—U+DFFF can only be used as surrogate pairs and are not valid character codes. + throw new KqlParsingException( + "Invalid unicode character code, [{}] is a surrogate code", + ctx.getSymbol().getLine(), + ctx.getSymbol().getCharPositionInLine() + startIdx, + hex + ); + } + sb.append(String.valueOf(Character.toChars(code))); + } catch (IllegalArgumentException e) { + throw new KqlParsingException( + "Invalid unicode character code [{}]", + ctx.getSymbol().getLine(), + ctx.getSymbol().getCharPositionInLine() + startIdx, + hex + ); + } + + return endIdx; + } +} diff --git a/x-pack/plugin/kql/src/test/java/org/elasticsearch/xpack/kql/parser/ParserUtilsTests.java b/x-pack/plugin/kql/src/test/java/org/elasticsearch/xpack/kql/parser/ParserUtilsTests.java new file mode 100644 index 0000000000000..9cd148f63262a --- /dev/null +++ b/x-pack/plugin/kql/src/test/java/org/elasticsearch/xpack/kql/parser/ParserUtilsTests.java @@ -0,0 +1,277 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.kql.parser; + +import org.antlr.v4.runtime.ParserRuleContext; +import org.antlr.v4.runtime.Token; +import org.antlr.v4.runtime.tree.ParseTree; +import org.antlr.v4.runtime.tree.TerminalNode; +import org.antlr.v4.runtime.tree.TerminalNodeImpl; +import org.elasticsearch.test.ESTestCase; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.stream.Stream; + +import static org.elasticsearch.core.Strings.format; +import static org.elasticsearch.xpack.kql.parser.KqlBaseParser.QUOTED_STRING; +import static org.elasticsearch.xpack.kql.parser.KqlBaseParser.UNQUOTED_LITERAL; +import static org.elasticsearch.xpack.kql.parser.KqlBaseParser.WILDCARD; +import static org.elasticsearch.xpack.kql.parser.ParserUtils.escapeQueryString; +import static org.elasticsearch.xpack.kql.parser.ParserUtils.extractText; +import static org.elasticsearch.xpack.kql.parser.ParserUtils.hasWildcard; +import static org.hamcrest.Matchers.equalTo; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class ParserUtilsTests extends ESTestCase { + + public void testExtractTestWithQuotedString() { + // General case + assertThat(extractText(parserRuleContext(quotedStringNode("foo"))), equalTo("foo")); + + // Empty string + assertThat(extractText(parserRuleContext(quotedStringNode(""))), equalTo("")); + + // Whitespaces are preserved + assertThat(extractText(parserRuleContext(quotedStringNode(" foo bar "))), equalTo(" foo bar ")); + + // Quoted string does not need escaping for KQL keywords (and, or, ...) + assertThat(extractText(parserRuleContext(quotedStringNode("not foo and bar or baz"))), equalTo("not foo and bar or baz")); + + // Quoted string does not need escaping for KQL special chars (e.g: '{', ':', ...) + assertThat(extractText(parserRuleContext(quotedStringNode("foo*:'\u3000{(})"))), equalTo("foo*:'\u3000{(})")); + + // Escaped characters handling + assertThat(extractText(parserRuleContext(quotedStringNode("\\\\"))), equalTo("\\")); + assertThat(extractText(parserRuleContext(quotedStringNode("foo\\\\bar"))), equalTo("foo\\bar")); + assertThat(extractText(parserRuleContext(quotedStringNode("foo\\\\"))), equalTo("foo\\")); + assertThat(extractText(parserRuleContext(quotedStringNode("\\\\foo"))), equalTo("\\foo")); + + assertThat(extractText(parserRuleContext(quotedStringNode("\\\""))), equalTo("\"")); + assertThat(extractText(parserRuleContext(quotedStringNode("foo\\\"bar"))), equalTo("foo\"bar")); + assertThat(extractText(parserRuleContext(quotedStringNode("foo\\\""))), equalTo("foo\"")); + assertThat(extractText(parserRuleContext(quotedStringNode("\\\"foo"))), equalTo("\"foo")); + + assertThat(extractText(parserRuleContext(quotedStringNode("\\t"))), equalTo("\t")); + assertThat(extractText(parserRuleContext(quotedStringNode("foo\\tbar"))), equalTo("foo\tbar")); + assertThat(extractText(parserRuleContext(quotedStringNode("foo\\t"))), equalTo("foo\t")); + assertThat(extractText(parserRuleContext(quotedStringNode("\\tfoo"))), equalTo("\tfoo")); + + assertThat(extractText(parserRuleContext(quotedStringNode("\\n"))), equalTo("\n")); + assertThat(extractText(parserRuleContext(quotedStringNode("foo\\nbar"))), equalTo("foo\nbar")); + assertThat(extractText(parserRuleContext(quotedStringNode("foo\\n"))), equalTo("foo\n")); + assertThat(extractText(parserRuleContext(quotedStringNode("\\nfoo"))), equalTo("\nfoo")); + + assertThat(extractText(parserRuleContext(quotedStringNode("\\r"))), equalTo("\r")); + assertThat(extractText(parserRuleContext(quotedStringNode("foo\\rbar"))), equalTo("foo\rbar")); + assertThat(extractText(parserRuleContext(quotedStringNode("foo\\r"))), equalTo("foo\r")); + assertThat(extractText(parserRuleContext(quotedStringNode("\\rfoo"))), equalTo("\rfoo")); + + // Unicode characters handling (\u0041 is 'A') + assertThat(extractText(parserRuleContext(quotedStringNode(format("\\u0041")))), equalTo("A")); + assertThat(extractText(parserRuleContext(quotedStringNode(format("foo\\u0041bar")))), equalTo("fooAbar")); + assertThat(extractText(parserRuleContext(quotedStringNode(format("foo\\u0041")))), equalTo("fooA")); + assertThat(extractText(parserRuleContext(quotedStringNode(format("\\u0041foo")))), equalTo("Afoo")); + } + + public void testExtractTestWithUnquotedLiteral() { + // General case + assertThat(extractText(parserRuleContext(literalNode("foo"))), equalTo("foo")); + + // KQL keywords unescaping + assertThat(extractText(parserRuleContext(literalNode("\\not foo \\and bar \\or baz"))), equalTo("not foo and bar or baz")); + assertThat( + extractText(parserRuleContext(literalNode("\\\\not foo \\\\and bar \\\\or baz"))), + equalTo("\\not foo \\and bar \\or baz") + ); + + // Escaped characters handling + assertThat(extractText(parserRuleContext(literalNode("\\\\"))), equalTo("\\")); + assertThat(extractText(parserRuleContext(literalNode("foo\\\\bar"))), equalTo("foo\\bar")); + assertThat(extractText(parserRuleContext(literalNode("foo\\\\"))), equalTo("foo\\")); + assertThat(extractText(parserRuleContext(literalNode("\\\\foo"))), equalTo("\\foo")); + + assertThat(extractText(parserRuleContext(literalNode("\\\""))), equalTo("\"")); + assertThat(extractText(parserRuleContext(literalNode("foo\\\"bar"))), equalTo("foo\"bar")); + assertThat(extractText(parserRuleContext(literalNode("foo\\\""))), equalTo("foo\"")); + assertThat(extractText(parserRuleContext(literalNode("\\\"foo"))), equalTo("\"foo")); + + assertThat(extractText(parserRuleContext(literalNode("\\t"))), equalTo("\t")); + assertThat(extractText(parserRuleContext(literalNode("foo\\tbar"))), equalTo("foo\tbar")); + assertThat(extractText(parserRuleContext(literalNode("foo\\t"))), equalTo("foo\t")); + assertThat(extractText(parserRuleContext(literalNode("\\tfoo"))), equalTo("\tfoo")); + + assertThat(extractText(parserRuleContext(literalNode("\\n"))), equalTo("\n")); + assertThat(extractText(parserRuleContext(literalNode("foo\\nbar"))), equalTo("foo\nbar")); + assertThat(extractText(parserRuleContext(literalNode("foo\\n"))), equalTo("foo\n")); + assertThat(extractText(parserRuleContext(literalNode("\\nfoo"))), equalTo("\nfoo")); + + assertThat(extractText(parserRuleContext(literalNode("\\r"))), equalTo("\r")); + assertThat(extractText(parserRuleContext(literalNode("foo\\rbar"))), equalTo("foo\rbar")); + assertThat(extractText(parserRuleContext(literalNode("foo\\r"))), equalTo("foo\r")); + assertThat(extractText(parserRuleContext(literalNode("\\rfoo"))), equalTo("\rfoo")); + + for (String escapedChar : List.of("(", ")", ":", "<", ">", "*", "{", "}")) { + assertThat(extractText(parserRuleContext(literalNode(format("\\%s", escapedChar)))), equalTo(escapedChar)); + assertThat( + extractText(parserRuleContext(literalNode(format("foo\\%sbar", escapedChar)))), + equalTo(format("foo%sbar", escapedChar)) + ); + assertThat(extractText(parserRuleContext(literalNode(format("foo\\%s", escapedChar)))), equalTo(format("foo%s", escapedChar))); + assertThat(extractText(parserRuleContext(literalNode(format("\\%sfoo", escapedChar)))), equalTo(format("%sfoo", escapedChar))); + } + + // Unicode characters handling (\u0041 is 'A') + assertThat(extractText(parserRuleContext(literalNode(format("\\u0041")))), equalTo("A")); + assertThat(extractText(parserRuleContext(literalNode(format("foo\\u0041bar")))), equalTo("fooAbar")); + assertThat(extractText(parserRuleContext(literalNode(format("foo\\u0041")))), equalTo("fooA")); + assertThat(extractText(parserRuleContext(literalNode(format("\\u0041foo")))), equalTo("Afoo")); + } + + public void testHasWildcard() { + // No children + assertFalse(hasWildcard(parserRuleContext(List.of()))); + + // Lone wildcard + assertTrue(hasWildcard(parserRuleContext(wildcardNode()))); + assertTrue(hasWildcard(parserRuleContext(randomTextNodeListWithNode(wildcardNode())))); + + // All children are literals + assertFalse(hasWildcard(parserRuleContext(randomList(1, randomInt(100), ParserUtilsTests::randomLiteralNode)))); + + // Quoted string + assertFalse(hasWildcard(parserRuleContext(randomQuotedStringNode()))); + + // Literal node containing the wildcard character + assertTrue(hasWildcard(parserRuleContext(literalNode("f*oo")))); + assertTrue(hasWildcard(parserRuleContext(literalNode("*foo")))); + assertTrue(hasWildcard(parserRuleContext(literalNode("foo*")))); + + // Literal node containing the wildcard characters (escaped) + assertFalse(hasWildcard(parserRuleContext(literalNode("f\\*oo")))); + assertFalse(hasWildcard(parserRuleContext(literalNode("\\*foo")))); + assertFalse(hasWildcard(parserRuleContext(literalNode("foo\\*")))); + } + + public void testUnquotedLiteralInvalidUnicodeCodeParsing() { + { + // Invalid unicode digit (G) + ParserRuleContext ctx = parserRuleContext(literalNode("\\u0G41")); + KqlParsingException e = assertThrows(KqlParsingException.class, () -> extractText(ctx)); + assertThat(e.getMessage(), equalTo("line 0:4: Invalid unicode character code [0G41]")); + } + + { + // U+D800—U+DFFF can only be used as surrogate pairs and are not valid character codes. + ParserRuleContext ctx = parserRuleContext(literalNode("\\uD900")); + KqlParsingException e = assertThrows(KqlParsingException.class, () -> extractText(ctx)); + assertThat(e.getMessage(), equalTo("line 0:4: Invalid unicode character code, [D900] is a surrogate code")); + } + } + + public void testQuotedStringInvalidUnicodeCodeParsing() { + { + // Invalid unicode digit (G) + ParserRuleContext ctx = parserRuleContext(quotedStringNode("\\u0G41")); + KqlParsingException e = assertThrows(KqlParsingException.class, () -> extractText(ctx)); + assertThat(e.getMessage(), equalTo("line 0:4: Invalid unicode character code [0G41]")); + } + + { + // U+D800—U+DFFF can only be used as surrogate pairs and are not valid character codes. + ParserRuleContext ctx = parserRuleContext(quotedStringNode("\\uD900")); + KqlParsingException e = assertThrows(KqlParsingException.class, () -> extractText(ctx)); + assertThat(e.getMessage(), equalTo("line 0:4: Invalid unicode character code, [D900] is a surrogate code")); + } + } + + public void testEscapeQueryString() { + // Quotes + assertThat(escapeQueryString("\"The Pink Panther\"", true), equalTo("\\\"The Pink Panther\\\"")); + + // Escape chars + assertThat(escapeQueryString("The Pink \\ Panther", true), equalTo("The Pink \\\\ Panther")); + + // Field operations + assertThat(escapeQueryString("title:Do it right", true), equalTo("title\\:Do it right")); + assertThat(escapeQueryString("title:(pink panther)", true), equalTo("title\\:\\(pink panther\\)")); + assertThat(escapeQueryString("title:-pink", true), equalTo("title\\:\\-pink")); + assertThat(escapeQueryString("title:+pink", true), equalTo("title\\:\\+pink")); + assertThat(escapeQueryString("title:pink~", true), equalTo("title\\:pink\\~")); + assertThat(escapeQueryString("title:pink~3.5", true), equalTo("title\\:pink\\~3.5")); + assertThat(escapeQueryString("title:pink panther^4", true), equalTo("title\\:pink panther\\^4")); + assertThat(escapeQueryString("rating:[0 TO 5]", true), equalTo("rating\\:\\[0 TO 5\\]")); + assertThat(escapeQueryString("rating:{0 TO 5}", true), equalTo("rating\\:\\{0 TO 5\\}")); + + // Boolean operators + assertThat(escapeQueryString("foo || bar", true), equalTo("foo \\|\\| bar")); + assertThat(escapeQueryString("foo && bar", true), equalTo("foo \\&\\& bar")); + assertThat(escapeQueryString("!foo", true), equalTo("\\!foo")); + + // Wildcards: + assertThat(escapeQueryString("te?t", true), equalTo("te\\?t")); + assertThat(escapeQueryString("foo*", false), equalTo("foo\\*")); + } + + private static ParserRuleContext parserRuleContext(ParseTree child) { + return parserRuleContext(List.of(child)); + } + + private static ParserRuleContext parserRuleContext(List children) { + ParserRuleContext ctx = new ParserRuleContext(null, randomInt()); + ctx.children = children; + return ctx; + } + + private static TerminalNode terminalNode(int type, String text) { + Token symbol = mock(Token.class); + when(symbol.getType()).thenReturn(type); + when(symbol.getText()).thenReturn(text); + when(symbol.getLine()).thenReturn(0); + when(symbol.getCharPositionInLine()).thenReturn(0); + return new TerminalNodeImpl(symbol); + } + + private static List randomTextNodeListWithNode(TerminalNode node) { + List nodes = new ArrayList<>( + Stream.concat(Stream.generate(ParserUtilsTests::randomTextNode).limit(100), Stream.of(node)).toList() + ); + Collections.shuffle(nodes, random()); + return nodes; + } + + private static TerminalNode randomTextNode() { + return switch (randomInt() % 3) { + case 0 -> wildcardNode(); + case 1 -> randomQuotedStringNode(); + default -> randomLiteralNode(); + }; + } + + private static TerminalNode quotedStringNode(String quotedStringText) { + return terminalNode(QUOTED_STRING, "\"" + quotedStringText + "\""); + } + + private static TerminalNode randomQuotedStringNode() { + return quotedStringNode(randomIdentifier()); + } + + private static TerminalNode literalNode(String literalText) { + return terminalNode(UNQUOTED_LITERAL, literalText); + } + + private static TerminalNode randomLiteralNode() { + return terminalNode(UNQUOTED_LITERAL, randomIdentifier()); + } + + private static TerminalNode wildcardNode() { + return terminalNode(WILDCARD, "*"); + } +}