diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/plugin/RestSqlAction.java b/legacy/src/main/java/org/opensearch/sql/legacy/plugin/RestSqlAction.java index 88ed42010b..5249d2d5d0 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/plugin/RestSqlAction.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/plugin/RestSqlAction.java @@ -141,8 +141,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli } } - LOG.info("[{}] Incoming request {}: {}", QueryContext.getRequestId(), request.uri(), - QueryDataAnonymizer.anonymizeData(sqlRequest.getSql())); + LOG.info("[{}] Incoming request {}", QueryContext.getRequestId(), request.uri()); Format format = SqlRequestParam.getFormat(request.params()); @@ -157,6 +156,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli } LOG.debug("[{}] Request {} is not supported and falling back to old SQL engine", QueryContext.getRequestId(), newSqlRequest); + LOG.info("Request Query: {}", QueryDataAnonymizer.anonymizeData(sqlRequest.getSql())); QueryAction queryAction = explainRequest(client, sqlRequest, format); executeSqlRequest(request, queryAction, client, restChannel); } catch (Exception e) { diff --git a/sql/src/main/java/org/opensearch/sql/sql/antlr/AnonymizerListener.java b/sql/src/main/java/org/opensearch/sql/sql/antlr/AnonymizerListener.java new file mode 100644 index 0000000000..7efd71a325 --- /dev/null +++ b/sql/src/main/java/org/opensearch/sql/sql/antlr/AnonymizerListener.java @@ -0,0 +1,107 @@ +package org.opensearch.sql.sql.antlr; + +import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLLexer.BACKTICK_QUOTE_ID; +import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLLexer.BOOLEAN; +import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLLexer.COMMA; +import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLLexer.DECIMAL_LITERAL; +import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLLexer.DOT; +import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLLexer.EQUAL_SYMBOL; +import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLLexer.EXCLAMATION_SYMBOL; +import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLLexer.FALSE; +import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLLexer.FROM; +import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLLexer.GREATER_SYMBOL; +import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLLexer.ID; +import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLLexer.LESS_SYMBOL; +import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLLexer.ONE_DECIMAL; +import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLLexer.REAL_LITERAL; +import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLLexer.STRING_LITERAL; +import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLLexer.TIMESTAMP; +import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLLexer.TRUE; +import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLLexer.TWO_DECIMAL; +import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLLexer.ZERO_DECIMAL; + +import org.antlr.v4.runtime.ParserRuleContext; +import org.antlr.v4.runtime.tree.ErrorNode; +import org.antlr.v4.runtime.tree.ParseTreeListener; +import org.antlr.v4.runtime.tree.TerminalNode; + +/** + * Parse tree listener for anonymizing SQL requests. + */ +public class AnonymizerListener implements ParseTreeListener { + private String anonymizedQueryString = ""; + private static final int NO_TYPE = -1; + private int previousType = NO_TYPE; + + @Override + public void enterEveryRule(ParserRuleContext ctx) { + } + + @Override + public void exitEveryRule(ParserRuleContext ctx) { + } + + @Override + public void visitTerminal(TerminalNode node) { + // In these situations don't add a space prior: + // 1. a DOT between two identifiers + // 2. before a comma + // 3. between equal comparison tokens: e.g <= + // 4. between alt not equals: <> + int token = node.getSymbol().getType(); + boolean isDotIdentifiers = token == DOT || previousType == DOT; + boolean isComma = token == COMMA; + boolean isEqualComparison = ((token == EQUAL_SYMBOL) + && (previousType == LESS_SYMBOL + || previousType == GREATER_SYMBOL + || previousType == EXCLAMATION_SYMBOL)); + boolean isNotEqualComparisonAlternative = + previousType == LESS_SYMBOL && token == GREATER_SYMBOL; + if (!isDotIdentifiers && !isComma && !isEqualComparison && !isNotEqualComparisonAlternative) { + anonymizedQueryString += " "; + } + + // anonymize the following tokens + switch (node.getSymbol().getType()) { + case ID: + case TIMESTAMP: + case BACKTICK_QUOTE_ID: + if (previousType == FROM) { + anonymizedQueryString += "table"; + } else { + anonymizedQueryString += "identifier"; + } + break; + case ZERO_DECIMAL: + case ONE_DECIMAL: + case TWO_DECIMAL: + case DECIMAL_LITERAL: + case REAL_LITERAL: + anonymizedQueryString += "number"; + break; + case STRING_LITERAL: + anonymizedQueryString += "'string_literal'"; + break; + case BOOLEAN: + case TRUE: + case FALSE: + anonymizedQueryString += "boolean_literal"; + break; + case NO_TYPE: + // end of file + break; + default: + anonymizedQueryString += node.getText().toUpperCase(); + } + previousType = node.getSymbol().getType(); + } + + @Override + public void visitErrorNode(ErrorNode node) { + + } + + public String getAnonymizedQueryString() { + return "(" + anonymizedQueryString + ")"; + } +} diff --git a/sql/src/main/java/org/opensearch/sql/sql/antlr/SQLSyntaxParser.java b/sql/src/main/java/org/opensearch/sql/sql/antlr/SQLSyntaxParser.java index ee1e991bd4..b5eea0495c 100644 --- a/sql/src/main/java/org/opensearch/sql/sql/antlr/SQLSyntaxParser.java +++ b/sql/src/main/java/org/opensearch/sql/sql/antlr/SQLSyntaxParser.java @@ -8,6 +8,8 @@ import org.antlr.v4.runtime.CommonTokenStream; import org.antlr.v4.runtime.tree.ParseTree; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import org.opensearch.sql.common.antlr.CaseInsensitiveCharStream; import org.opensearch.sql.common.antlr.Parser; import org.opensearch.sql.common.antlr.SyntaxAnalysisErrorListener; @@ -18,6 +20,12 @@ * SQL syntax parser which encapsulates an ANTLR parser. */ public class SQLSyntaxParser implements Parser { + private SyntaxAnalysisErrorListener syntaxAnalysisErrorListener; + private static final Logger LOG = LogManager.getLogger(SQLSyntaxParser.class); + + public SQLSyntaxParser() { + this.syntaxAnalysisErrorListener = new SyntaxAnalysisErrorListener(); + } /** * Parse a SQL query by ANTLR parser. @@ -26,10 +34,16 @@ public class SQLSyntaxParser implements Parser { */ @Override public ParseTree parse(String query) { + AnonymizerListener anonymizer = new AnonymizerListener(); + OpenSearchSQLLexer lexer = new OpenSearchSQLLexer(new CaseInsensitiveCharStream(query)); OpenSearchSQLParser parser = new OpenSearchSQLParser(new CommonTokenStream(lexer)); - parser.addErrorListener(new SyntaxAnalysisErrorListener()); - return parser.root(); - } + parser.addErrorListener(syntaxAnalysisErrorListener); + parser.addParseListener(anonymizer); + + ParseTree parseTree = parser.root(); + LOG.info("New Engine Request Query: {}", anonymizer.getAnonymizedQueryString()); + return parseTree; + } } diff --git a/sql/src/test/java/org/opensearch/sql/sql/parser/AnonymizerListenerTest.java b/sql/src/test/java/org/opensearch/sql/sql/parser/AnonymizerListenerTest.java new file mode 100644 index 0000000000..d7913d8dcb --- /dev/null +++ b/sql/src/test/java/org/opensearch/sql/sql/parser/AnonymizerListenerTest.java @@ -0,0 +1,211 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + + +package org.opensearch.sql.sql.parser; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.Mockito.mock; + +import org.antlr.v4.runtime.CommonTokenStream; +import org.antlr.v4.runtime.tree.ErrorNode; +import org.junit.jupiter.api.Test; +import org.opensearch.sql.common.antlr.CaseInsensitiveCharStream; +import org.opensearch.sql.common.antlr.SyntaxAnalysisErrorListener; +import org.opensearch.sql.sql.antlr.AnonymizerListener; +import org.opensearch.sql.sql.antlr.parser.OpenSearchSQLLexer; +import org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser; + +public class AnonymizerListenerTest { + + public AnonymizerListener anonymizerListener = new AnonymizerListener(); + + public SyntaxAnalysisErrorListener errorListener = new SyntaxAnalysisErrorListener(); + + /** + * Helper function to parse SQl queries for testing purposes. + * @param query SQL query to be anonymized. + */ + public void parse(String query) { + OpenSearchSQLLexer lexer = new OpenSearchSQLLexer(new CaseInsensitiveCharStream(query)); + OpenSearchSQLParser parser = new OpenSearchSQLParser(new CommonTokenStream(lexer)); + parser.addErrorListener(errorListener); + parser.addParseListener(anonymizerListener); + + parser.root(); + } + + @Test + public void queriesShouldHaveAnonymousFieldAndIndex() { + String query = "SELECT ABS(balance) FROM accounts WHERE age > 30 GROUP BY ABS(balance)"; + String expectedQuery = "( SELECT ABS ( identifier ) FROM table " + + "WHERE identifier > number GROUP BY ABS ( identifier ) )"; + parse(query); + assertEquals(expectedQuery, anonymizerListener.getAnonymizedQueryString()); + } + + @Test + public void queriesShouldAnonymousNumbers() { + String query = "SELECT ABS(20), LOG(20.20) FROM accounts"; + String expectedQuery = "( SELECT ABS ( number ), LOG ( number ) FROM table )"; + parse(query); + assertEquals(expectedQuery, anonymizerListener.getAnonymizedQueryString()); + } + + @Test + public void queriesShouldHaveAnonymousBooleanLiterals() { + String query = "SELECT TRUE FROM accounts"; + String expectedQuery = "( SELECT boolean_literal FROM table )"; + parse(query); + assertEquals(expectedQuery, anonymizerListener.getAnonymizedQueryString()); + } + + @Test + public void queriesShouldHaveAnonymousInputStrings() { + String query = "SELECT * FROM accounts WHERE name = 'Oliver'"; + String expectedQuery = "( SELECT * FROM table WHERE identifier = 'string_literal' )"; + parse(query); + assertEquals(expectedQuery, anonymizerListener.getAnonymizedQueryString()); + } + + @Test + public void queriesWithAliasesShouldAnonymizeSensitiveData() { + String query = "SELECT balance AS b FROM accounts AS a"; + String expectedQuery = "( SELECT identifier AS identifier FROM table AS identifier )"; + parse(query); + assertEquals(expectedQuery, anonymizerListener.getAnonymizedQueryString()); + } + + @Test + public void queriesWithFunctionsShouldAnonymizeSensitiveData() { + String query = "SELECT LTRIM(firstname) FROM accounts"; + String expectedQuery = "( SELECT LTRIM ( identifier ) FROM table )"; + parse(query); + assertEquals(expectedQuery, anonymizerListener.getAnonymizedQueryString()); + } + + @Test + public void queriesWithAggregatesShouldAnonymizeSensitiveData() { + String query = "SELECT MAX(price) - MIN(price) from tickets"; + String expectedQuery = "( SELECT MAX ( identifier ) - MIN ( identifier ) FROM table )"; + parse(query); + assertEquals(expectedQuery, anonymizerListener.getAnonymizedQueryString()); + } + + @Test + public void queriesWithSubqueriesShouldAnonymizeSensitiveData() { + String query = "SELECT a.f, a.l, a.a FROM " + + "(SELECT firstname AS f, lastname AS l, age AS a FROM accounts WHERE age > 30) a"; + String expectedQuery = + "( SELECT identifier.identifier, identifier.identifier, identifier.identifier FROM " + + "( SELECT identifier AS identifier, identifier AS identifier, identifier AS identifier " + + "FROM table WHERE identifier > number ) identifier )"; + parse(query); + assertEquals(expectedQuery, anonymizerListener.getAnonymizedQueryString()); + } + + @Test + public void queriesWithLimitShouldAnonymizeSensitiveData() { + String query = "SELECT balance FROM accounts LIMIT 5"; + String expectedQuery = "( SELECT identifier FROM table LIMIT number )"; + parse(query); + assertEquals(expectedQuery, anonymizerListener.getAnonymizedQueryString()); + } + + @Test + public void queriesWithOrderByShouldAnonymizeSensitiveData() { + String query = "SELECT firstname FROM accounts ORDER BY lastname"; + String expectedQuery = "( SELECT identifier FROM table ORDER BY identifier )"; + parse(query); + assertEquals(expectedQuery, anonymizerListener.getAnonymizedQueryString()); + } + + @Test + public void queriesWithHavingShouldAnonymizeSensitiveData() { + String query = "SELECT SUM(balance) FROM accounts GROUP BY lastname HAVING COUNT(balance) > 2"; + String expectedQuery = "( SELECT SUM ( identifier ) FROM table " + + "GROUP BY identifier HAVING COUNT ( identifier ) > number )"; + parse(query); + assertEquals(expectedQuery, anonymizerListener.getAnonymizedQueryString()); + } + + @Test + public void queriesWithHighlightShouldAnonymizeSensitiveData() { + String query = "SELECT HIGHLIGHT(str0) FROM CALCS WHERE QUERY_STRING(['str0'], 'FURNITURE')"; + String expectedQuery = "( SELECT HIGHLIGHT ( identifier ) FROM table WHERE " + + "QUERY_STRING ( [ 'string_literal' ], 'string_literal' ) )"; + parse(query); + assertEquals(expectedQuery, anonymizerListener.getAnonymizedQueryString()); + } + + @Test + public void queriesWithMatchShouldAnonymizeSensitiveData() { + String query = "SELECT str0 FROM CALCS WHERE MATCH(str0, 'FURNITURE')"; + String expectedQuery = "( SELECT identifier FROM table " + + "WHERE MATCH ( identifier, 'string_literal' ) )"; + parse(query); + assertEquals(expectedQuery, anonymizerListener.getAnonymizedQueryString()); + } + + @Test + public void queriesWithPositionShouldAnonymizeSensitiveData() { + String query = "SELECT POSITION('world' IN 'helloworld')"; + String expectedQuery = "( SELECT POSITION ( 'string_literal' IN 'string_literal' ) )"; + parse(query); + assertEquals(expectedQuery, anonymizerListener.getAnonymizedQueryString()); + } + + @Test + public void queriesWithMatch_Bool_Prefix_ShouldAnonymizeSensitiveData() { + String query = "SELECT firstname, address FROM accounts WHERE " + + "match_bool_prefix(address, 'Bristol Street', minimum_should_match=2)"; + String expectedQuery = "( SELECT identifier, identifier FROM table WHERE MATCH_BOOL_PREFIX " + + "( identifier, 'string_literal', MINIMUM_SHOULD_MATCH = number ) )"; + parse(query); + assertEquals(expectedQuery, anonymizerListener.getAnonymizedQueryString()); + } + + @Test + public void queriesWithGreaterOrEqualShouldAnonymizeSensitiveData() { + String query = "SELECT int0 FROM accounts WHERE int0 >= 0"; + String expectedQuery = "( SELECT identifier FROM table WHERE identifier >= number )"; + parse(query); + assertEquals(expectedQuery, anonymizerListener.getAnonymizedQueryString()); + } + + @Test + public void queriesWithLessOrEqualShouldAnonymizeSensitiveData() { + String query = "SELECT int0 FROM accounts WHERE int0 <= 0"; + String expectedQuery = "( SELECT identifier FROM table WHERE identifier <= number )"; + parse(query); + assertEquals(expectedQuery, anonymizerListener.getAnonymizedQueryString()); + } + + @Test + public void queriesWithNotEqualShouldAnonymizeSensitiveData() { + String query = "SELECT int0 FROM accounts WHERE int0 != 0"; + String expectedQuery = "( SELECT identifier FROM table WHERE identifier != number )"; + parse(query); + assertEquals(expectedQuery, anonymizerListener.getAnonymizedQueryString()); + } + + @Test + public void queriesWithNotEqualAlternateShouldAnonymizeSensitiveData() { + String query = "SELECT int0 FROM calcs WHERE int0 <> 0"; + String expectedQuery = "( SELECT identifier FROM table WHERE identifier <> number )"; + parse(query); + assertEquals(expectedQuery, anonymizerListener.getAnonymizedQueryString()); + } + + + /** + * Test added for coverage, but the errorNode will not be hit normally. + */ + @Test + public void enterErrorNote() { + ErrorNode node = mock(ErrorNode.class); + anonymizerListener.visitErrorNode(node); + } +}