Skip to content

Commit

Permalink
Create new anonymizer for new engine (opensearch-project#1665)
Browse files Browse the repository at this point in the history
* Create new anonymizer for new engine (#266)

* Created anonymizer listener for anonymizing SQL queries through the new engine
Signed-off-by: Matthew Wells <[email protected]>

* Update for review comments

Signed-off-by: Andrew Carbonetto <[email protected]>

* added missing file header, change public variable to private

Signed-off-by: Matthew Wells <[email protected]>

---------

Signed-off-by: Andrew Carbonetto <[email protected]>
Signed-off-by: Matthew Wells <[email protected]>
Co-authored-by: Andrew Carbonetto <[email protected]>
  • Loading branch information
matthewryanwells and acarbonetto authored May 30, 2023
1 parent 6d796ee commit 62120fd
Show file tree
Hide file tree
Showing 4 changed files with 333 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
}
}

LOG.info("[{}] Incoming request {}: {}", QueryContext.getRequestId(), request.uri(),
QueryDataAnonymizer.anonymizeData(sqlRequest.getSql()));
LOG.info("[{}] Incoming request {}", QueryContext.getRequestId(), request.uri());

Format format = SqlRequestParam.getFormat(request.params());

Expand All @@ -157,6 +156,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
}
LOG.debug("[{}] Request {} is not supported and falling back to old SQL engine",
QueryContext.getRequestId(), newSqlRequest);
LOG.info("Request Query: {}", QueryDataAnonymizer.anonymizeData(sqlRequest.getSql()));
QueryAction queryAction = explainRequest(client, sqlRequest, format);
executeSqlRequest(request, queryAction, client, restChannel);
} catch (Exception e) {
Expand Down
113 changes: 113 additions & 0 deletions sql/src/main/java/org/opensearch/sql/sql/antlr/AnonymizerListener.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/


package org.opensearch.sql.sql.antlr;

import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLLexer.BACKTICK_QUOTE_ID;
import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLLexer.BOOLEAN;
import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLLexer.COMMA;
import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLLexer.DECIMAL_LITERAL;
import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLLexer.DOT;
import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLLexer.EQUAL_SYMBOL;
import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLLexer.EXCLAMATION_SYMBOL;
import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLLexer.FALSE;
import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLLexer.FROM;
import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLLexer.GREATER_SYMBOL;
import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLLexer.ID;
import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLLexer.LESS_SYMBOL;
import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLLexer.ONE_DECIMAL;
import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLLexer.REAL_LITERAL;
import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLLexer.STRING_LITERAL;
import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLLexer.TIMESTAMP;
import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLLexer.TRUE;
import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLLexer.TWO_DECIMAL;
import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLLexer.ZERO_DECIMAL;

import org.antlr.v4.runtime.ParserRuleContext;
import org.antlr.v4.runtime.tree.ErrorNode;
import org.antlr.v4.runtime.tree.ParseTreeListener;
import org.antlr.v4.runtime.tree.TerminalNode;

/**
* Parse tree listener for anonymizing SQL requests.
*/
public class AnonymizerListener implements ParseTreeListener {
private String anonymizedQueryString = "";
private static final int NO_TYPE = -1;
private int previousType = NO_TYPE;

@Override
public void enterEveryRule(ParserRuleContext ctx) {
}

@Override
public void exitEveryRule(ParserRuleContext ctx) {
}

@Override
public void visitTerminal(TerminalNode node) {
// In these situations don't add a space prior:
// 1. a DOT between two identifiers
// 2. before a comma
// 3. between equal comparison tokens: e.g <=
// 4. between alt not equals: <>
int token = node.getSymbol().getType();
boolean isDotIdentifiers = token == DOT || previousType == DOT;
boolean isComma = token == COMMA;
boolean isEqualComparison = ((token == EQUAL_SYMBOL)
&& (previousType == LESS_SYMBOL
|| previousType == GREATER_SYMBOL
|| previousType == EXCLAMATION_SYMBOL));
boolean isNotEqualComparisonAlternative =
previousType == LESS_SYMBOL && token == GREATER_SYMBOL;
if (!isDotIdentifiers && !isComma && !isEqualComparison && !isNotEqualComparisonAlternative) {
anonymizedQueryString += " ";
}

// anonymize the following tokens
switch (node.getSymbol().getType()) {
case ID:
case TIMESTAMP:
case BACKTICK_QUOTE_ID:
if (previousType == FROM) {
anonymizedQueryString += "table";
} else {
anonymizedQueryString += "identifier";
}
break;
case ZERO_DECIMAL:
case ONE_DECIMAL:
case TWO_DECIMAL:
case DECIMAL_LITERAL:
case REAL_LITERAL:
anonymizedQueryString += "number";
break;
case STRING_LITERAL:
anonymizedQueryString += "'string_literal'";
break;
case BOOLEAN:
case TRUE:
case FALSE:
anonymizedQueryString += "boolean_literal";
break;
case NO_TYPE:
// end of file
break;
default:
anonymizedQueryString += node.getText().toUpperCase();
}
previousType = node.getSymbol().getType();
}

@Override
public void visitErrorNode(ErrorNode node) {

}

public String getAnonymizedQueryString() {
return "(" + anonymizedQueryString + ")";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

import org.antlr.v4.runtime.CommonTokenStream;
import org.antlr.v4.runtime.tree.ParseTree;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.sql.common.antlr.CaseInsensitiveCharStream;
import org.opensearch.sql.common.antlr.Parser;
import org.opensearch.sql.common.antlr.SyntaxAnalysisErrorListener;
Expand All @@ -18,6 +20,7 @@
* SQL syntax parser which encapsulates an ANTLR parser.
*/
public class SQLSyntaxParser implements Parser {
private static final Logger LOG = LogManager.getLogger(SQLSyntaxParser.class);

/**
* Parse a SQL query by ANTLR parser.
Expand All @@ -26,10 +29,16 @@ public class SQLSyntaxParser implements Parser {
*/
@Override
public ParseTree parse(String query) {
AnonymizerListener anonymizer = new AnonymizerListener();

OpenSearchSQLLexer lexer = new OpenSearchSQLLexer(new CaseInsensitiveCharStream(query));
OpenSearchSQLParser parser = new OpenSearchSQLParser(new CommonTokenStream(lexer));
parser.addErrorListener(new SyntaxAnalysisErrorListener());
return parser.root();
}
parser.addParseListener(anonymizer);

ParseTree parseTree = parser.root();
LOG.info("New Engine Request Query: {}", anonymizer.getAnonymizedQueryString());

return parseTree;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/


package org.opensearch.sql.sql.parser;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.Mockito.mock;

import org.antlr.v4.runtime.CommonTokenStream;
import org.antlr.v4.runtime.tree.ErrorNode;
import org.junit.jupiter.api.Test;
import org.opensearch.sql.common.antlr.CaseInsensitiveCharStream;
import org.opensearch.sql.sql.antlr.AnonymizerListener;
import org.opensearch.sql.sql.antlr.parser.OpenSearchSQLLexer;
import org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser;

public class AnonymizerListenerTest {

private final AnonymizerListener anonymizerListener = new AnonymizerListener();

/**
* Helper function to parse SQl queries for testing purposes.
* @param query SQL query to be anonymized.
*/
private void parse(String query) {
OpenSearchSQLLexer lexer = new OpenSearchSQLLexer(new CaseInsensitiveCharStream(query));
OpenSearchSQLParser parser = new OpenSearchSQLParser(new CommonTokenStream(lexer));
parser.addParseListener(anonymizerListener);

parser.root();
}

@Test
public void queriesShouldHaveAnonymousFieldAndIndex() {
String query = "SELECT ABS(balance) FROM accounts WHERE age > 30 GROUP BY ABS(balance)";
String expectedQuery = "( SELECT ABS ( identifier ) FROM table "
+ "WHERE identifier > number GROUP BY ABS ( identifier ) )";
parse(query);
assertEquals(expectedQuery, anonymizerListener.getAnonymizedQueryString());
}

@Test
public void queriesShouldAnonymousNumbers() {
String query = "SELECT ABS(20), LOG(20.20) FROM accounts";
String expectedQuery = "( SELECT ABS ( number ), LOG ( number ) FROM table )";
parse(query);
assertEquals(expectedQuery, anonymizerListener.getAnonymizedQueryString());
}

@Test
public void queriesShouldHaveAnonymousBooleanLiterals() {
String query = "SELECT TRUE FROM accounts";
String expectedQuery = "( SELECT boolean_literal FROM table )";
parse(query);
assertEquals(expectedQuery, anonymizerListener.getAnonymizedQueryString());
}

@Test
public void queriesShouldHaveAnonymousInputStrings() {
String query = "SELECT * FROM accounts WHERE name = 'Oliver'";
String expectedQuery = "( SELECT * FROM table WHERE identifier = 'string_literal' )";
parse(query);
assertEquals(expectedQuery, anonymizerListener.getAnonymizedQueryString());
}

@Test
public void queriesWithAliasesShouldAnonymizeSensitiveData() {
String query = "SELECT balance AS b FROM accounts AS a";
String expectedQuery = "( SELECT identifier AS identifier FROM table AS identifier )";
parse(query);
assertEquals(expectedQuery, anonymizerListener.getAnonymizedQueryString());
}

@Test
public void queriesWithFunctionsShouldAnonymizeSensitiveData() {
String query = "SELECT LTRIM(firstname) FROM accounts";
String expectedQuery = "( SELECT LTRIM ( identifier ) FROM table )";
parse(query);
assertEquals(expectedQuery, anonymizerListener.getAnonymizedQueryString());
}

@Test
public void queriesWithAggregatesShouldAnonymizeSensitiveData() {
String query = "SELECT MAX(price) - MIN(price) from tickets";
String expectedQuery = "( SELECT MAX ( identifier ) - MIN ( identifier ) FROM table )";
parse(query);
assertEquals(expectedQuery, anonymizerListener.getAnonymizedQueryString());
}

@Test
public void queriesWithSubqueriesShouldAnonymizeSensitiveData() {
String query = "SELECT a.f, a.l, a.a FROM "
+ "(SELECT firstname AS f, lastname AS l, age AS a FROM accounts WHERE age > 30) a";
String expectedQuery =
"( SELECT identifier.identifier, identifier.identifier, identifier.identifier FROM "
+ "( SELECT identifier AS identifier, identifier AS identifier, identifier AS identifier "
+ "FROM table WHERE identifier > number ) identifier )";
parse(query);
assertEquals(expectedQuery, anonymizerListener.getAnonymizedQueryString());
}

@Test
public void queriesWithLimitShouldAnonymizeSensitiveData() {
String query = "SELECT balance FROM accounts LIMIT 5";
String expectedQuery = "( SELECT identifier FROM table LIMIT number )";
parse(query);
assertEquals(expectedQuery, anonymizerListener.getAnonymizedQueryString());
}

@Test
public void queriesWithOrderByShouldAnonymizeSensitiveData() {
String query = "SELECT firstname FROM accounts ORDER BY lastname";
String expectedQuery = "( SELECT identifier FROM table ORDER BY identifier )";
parse(query);
assertEquals(expectedQuery, anonymizerListener.getAnonymizedQueryString());
}

@Test
public void queriesWithHavingShouldAnonymizeSensitiveData() {
String query = "SELECT SUM(balance) FROM accounts GROUP BY lastname HAVING COUNT(balance) > 2";
String expectedQuery = "( SELECT SUM ( identifier ) FROM table "
+ "GROUP BY identifier HAVING COUNT ( identifier ) > number )";
parse(query);
assertEquals(expectedQuery, anonymizerListener.getAnonymizedQueryString());
}

@Test
public void queriesWithHighlightShouldAnonymizeSensitiveData() {
String query = "SELECT HIGHLIGHT(str0) FROM CALCS WHERE QUERY_STRING(['str0'], 'FURNITURE')";
String expectedQuery = "( SELECT HIGHLIGHT ( identifier ) FROM table WHERE "
+ "QUERY_STRING ( [ 'string_literal' ], 'string_literal' ) )";
parse(query);
assertEquals(expectedQuery, anonymizerListener.getAnonymizedQueryString());
}

@Test
public void queriesWithMatchShouldAnonymizeSensitiveData() {
String query = "SELECT str0 FROM CALCS WHERE MATCH(str0, 'FURNITURE')";
String expectedQuery = "( SELECT identifier FROM table "
+ "WHERE MATCH ( identifier, 'string_literal' ) )";
parse(query);
assertEquals(expectedQuery, anonymizerListener.getAnonymizedQueryString());
}

@Test
public void queriesWithPositionShouldAnonymizeSensitiveData() {
String query = "SELECT POSITION('world' IN 'helloworld')";
String expectedQuery = "( SELECT POSITION ( 'string_literal' IN 'string_literal' ) )";
parse(query);
assertEquals(expectedQuery, anonymizerListener.getAnonymizedQueryString());
}

@Test
public void queriesWithMatch_Bool_Prefix_ShouldAnonymizeSensitiveData() {
String query = "SELECT firstname, address FROM accounts WHERE "
+ "match_bool_prefix(address, 'Bristol Street', minimum_should_match=2)";
String expectedQuery = "( SELECT identifier, identifier FROM table WHERE MATCH_BOOL_PREFIX "
+ "( identifier, 'string_literal', MINIMUM_SHOULD_MATCH = number ) )";
parse(query);
assertEquals(expectedQuery, anonymizerListener.getAnonymizedQueryString());
}

@Test
public void queriesWithGreaterOrEqualShouldAnonymizeSensitiveData() {
String query = "SELECT int0 FROM accounts WHERE int0 >= 0";
String expectedQuery = "( SELECT identifier FROM table WHERE identifier >= number )";
parse(query);
assertEquals(expectedQuery, anonymizerListener.getAnonymizedQueryString());
}

@Test
public void queriesWithLessOrEqualShouldAnonymizeSensitiveData() {
String query = "SELECT int0 FROM accounts WHERE int0 <= 0";
String expectedQuery = "( SELECT identifier FROM table WHERE identifier <= number )";
parse(query);
assertEquals(expectedQuery, anonymizerListener.getAnonymizedQueryString());
}

@Test
public void queriesWithNotEqualShouldAnonymizeSensitiveData() {
String query = "SELECT int0 FROM accounts WHERE int0 != 0";
String expectedQuery = "( SELECT identifier FROM table WHERE identifier != number )";
parse(query);
assertEquals(expectedQuery, anonymizerListener.getAnonymizedQueryString());
}

@Test
public void queriesWithNotEqualAlternateShouldAnonymizeSensitiveData() {
String query = "SELECT int0 FROM calcs WHERE int0 <> 0";
String expectedQuery = "( SELECT identifier FROM table WHERE identifier <> number )";
parse(query);
assertEquals(expectedQuery, anonymizerListener.getAnonymizedQueryString());
}


/**
* Test added for coverage, but the errorNode will not be hit normally.
*/
@Test
public void enterErrorNote() {
ErrorNode node = mock(ErrorNode.class);
anonymizerListener.visitErrorNode(node);
}
}

0 comments on commit 62120fd

Please sign in to comment.