From 95a2cd8a410d279f049f813b316f73d8d20a80ef Mon Sep 17 00:00:00 2001 From: Matthew Wells Date: Mon, 29 May 2023 15:34:30 -0700 Subject: [PATCH 1/5] Create new anonymizer for new engine (#266) * Created anonymizer listener for anonymizing SQL queries through the new engine Signed-off-by: Matthew Wells --- .../sql/legacy/plugin/RestSqlAction.java | 4 +- .../sql/sql/antlr/AnonymizerListener.java | 107 +++++++++ .../sql/sql/antlr/SQLSyntaxParser.java | 20 +- .../sql/parser/AnonymizerListenerTest.java | 211 ++++++++++++++++++ 4 files changed, 337 insertions(+), 5 deletions(-) create mode 100644 sql/src/main/java/org/opensearch/sql/sql/antlr/AnonymizerListener.java create mode 100644 sql/src/test/java/org/opensearch/sql/sql/parser/AnonymizerListenerTest.java diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/plugin/RestSqlAction.java b/legacy/src/main/java/org/opensearch/sql/legacy/plugin/RestSqlAction.java index 88ed42010b..5249d2d5d0 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/plugin/RestSqlAction.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/plugin/RestSqlAction.java @@ -141,8 +141,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli } } - LOG.info("[{}] Incoming request {}: {}", QueryContext.getRequestId(), request.uri(), - QueryDataAnonymizer.anonymizeData(sqlRequest.getSql())); + LOG.info("[{}] Incoming request {}", QueryContext.getRequestId(), request.uri()); Format format = SqlRequestParam.getFormat(request.params()); @@ -157,6 +156,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli } LOG.debug("[{}] Request {} is not supported and falling back to old SQL engine", QueryContext.getRequestId(), newSqlRequest); + LOG.info("Request Query: {}", QueryDataAnonymizer.anonymizeData(sqlRequest.getSql())); QueryAction queryAction = explainRequest(client, sqlRequest, format); executeSqlRequest(request, queryAction, client, restChannel); } catch (Exception e) { diff --git a/sql/src/main/java/org/opensearch/sql/sql/antlr/AnonymizerListener.java b/sql/src/main/java/org/opensearch/sql/sql/antlr/AnonymizerListener.java new file mode 100644 index 0000000000..7efd71a325 --- /dev/null +++ b/sql/src/main/java/org/opensearch/sql/sql/antlr/AnonymizerListener.java @@ -0,0 +1,107 @@ +package org.opensearch.sql.sql.antlr; + +import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLLexer.BACKTICK_QUOTE_ID; +import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLLexer.BOOLEAN; +import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLLexer.COMMA; +import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLLexer.DECIMAL_LITERAL; +import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLLexer.DOT; +import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLLexer.EQUAL_SYMBOL; +import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLLexer.EXCLAMATION_SYMBOL; +import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLLexer.FALSE; +import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLLexer.FROM; +import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLLexer.GREATER_SYMBOL; +import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLLexer.ID; +import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLLexer.LESS_SYMBOL; +import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLLexer.ONE_DECIMAL; +import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLLexer.REAL_LITERAL; +import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLLexer.STRING_LITERAL; +import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLLexer.TIMESTAMP; +import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLLexer.TRUE; +import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLLexer.TWO_DECIMAL; +import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLLexer.ZERO_DECIMAL; + +import org.antlr.v4.runtime.ParserRuleContext; +import org.antlr.v4.runtime.tree.ErrorNode; +import org.antlr.v4.runtime.tree.ParseTreeListener; +import org.antlr.v4.runtime.tree.TerminalNode; + +/** + * Parse tree listener for anonymizing SQL requests. + */ +public class AnonymizerListener implements ParseTreeListener { + private String anonymizedQueryString = ""; + private static final int NO_TYPE = -1; + private int previousType = NO_TYPE; + + @Override + public void enterEveryRule(ParserRuleContext ctx) { + } + + @Override + public void exitEveryRule(ParserRuleContext ctx) { + } + + @Override + public void visitTerminal(TerminalNode node) { + // In these situations don't add a space prior: + // 1. a DOT between two identifiers + // 2. before a comma + // 3. between equal comparison tokens: e.g <= + // 4. between alt not equals: <> + int token = node.getSymbol().getType(); + boolean isDotIdentifiers = token == DOT || previousType == DOT; + boolean isComma = token == COMMA; + boolean isEqualComparison = ((token == EQUAL_SYMBOL) + && (previousType == LESS_SYMBOL + || previousType == GREATER_SYMBOL + || previousType == EXCLAMATION_SYMBOL)); + boolean isNotEqualComparisonAlternative = + previousType == LESS_SYMBOL && token == GREATER_SYMBOL; + if (!isDotIdentifiers && !isComma && !isEqualComparison && !isNotEqualComparisonAlternative) { + anonymizedQueryString += " "; + } + + // anonymize the following tokens + switch (node.getSymbol().getType()) { + case ID: + case TIMESTAMP: + case BACKTICK_QUOTE_ID: + if (previousType == FROM) { + anonymizedQueryString += "table"; + } else { + anonymizedQueryString += "identifier"; + } + break; + case ZERO_DECIMAL: + case ONE_DECIMAL: + case TWO_DECIMAL: + case DECIMAL_LITERAL: + case REAL_LITERAL: + anonymizedQueryString += "number"; + break; + case STRING_LITERAL: + anonymizedQueryString += "'string_literal'"; + break; + case BOOLEAN: + case TRUE: + case FALSE: + anonymizedQueryString += "boolean_literal"; + break; + case NO_TYPE: + // end of file + break; + default: + anonymizedQueryString += node.getText().toUpperCase(); + } + previousType = node.getSymbol().getType(); + } + + @Override + public void visitErrorNode(ErrorNode node) { + + } + + public String getAnonymizedQueryString() { + return "(" + anonymizedQueryString + ")"; + } +} diff --git a/sql/src/main/java/org/opensearch/sql/sql/antlr/SQLSyntaxParser.java b/sql/src/main/java/org/opensearch/sql/sql/antlr/SQLSyntaxParser.java index ee1e991bd4..b5eea0495c 100644 --- a/sql/src/main/java/org/opensearch/sql/sql/antlr/SQLSyntaxParser.java +++ b/sql/src/main/java/org/opensearch/sql/sql/antlr/SQLSyntaxParser.java @@ -8,6 +8,8 @@ import org.antlr.v4.runtime.CommonTokenStream; import org.antlr.v4.runtime.tree.ParseTree; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import org.opensearch.sql.common.antlr.CaseInsensitiveCharStream; import org.opensearch.sql.common.antlr.Parser; import org.opensearch.sql.common.antlr.SyntaxAnalysisErrorListener; @@ -18,6 +20,12 @@ * SQL syntax parser which encapsulates an ANTLR parser. */ public class SQLSyntaxParser implements Parser { + private SyntaxAnalysisErrorListener syntaxAnalysisErrorListener; + private static final Logger LOG = LogManager.getLogger(SQLSyntaxParser.class); + + public SQLSyntaxParser() { + this.syntaxAnalysisErrorListener = new SyntaxAnalysisErrorListener(); + } /** * Parse a SQL query by ANTLR parser. @@ -26,10 +34,16 @@ public class SQLSyntaxParser implements Parser { */ @Override public ParseTree parse(String query) { + AnonymizerListener anonymizer = new AnonymizerListener(); + OpenSearchSQLLexer lexer = new OpenSearchSQLLexer(new CaseInsensitiveCharStream(query)); OpenSearchSQLParser parser = new OpenSearchSQLParser(new CommonTokenStream(lexer)); - parser.addErrorListener(new SyntaxAnalysisErrorListener()); - return parser.root(); - } + parser.addErrorListener(syntaxAnalysisErrorListener); + parser.addParseListener(anonymizer); + + ParseTree parseTree = parser.root(); + LOG.info("New Engine Request Query: {}", anonymizer.getAnonymizedQueryString()); + return parseTree; + } } diff --git a/sql/src/test/java/org/opensearch/sql/sql/parser/AnonymizerListenerTest.java b/sql/src/test/java/org/opensearch/sql/sql/parser/AnonymizerListenerTest.java new file mode 100644 index 0000000000..d7913d8dcb --- /dev/null +++ b/sql/src/test/java/org/opensearch/sql/sql/parser/AnonymizerListenerTest.java @@ -0,0 +1,211 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + + +package org.opensearch.sql.sql.parser; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.Mockito.mock; + +import org.antlr.v4.runtime.CommonTokenStream; +import org.antlr.v4.runtime.tree.ErrorNode; +import org.junit.jupiter.api.Test; +import org.opensearch.sql.common.antlr.CaseInsensitiveCharStream; +import org.opensearch.sql.common.antlr.SyntaxAnalysisErrorListener; +import org.opensearch.sql.sql.antlr.AnonymizerListener; +import org.opensearch.sql.sql.antlr.parser.OpenSearchSQLLexer; +import org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser; + +public class AnonymizerListenerTest { + + public AnonymizerListener anonymizerListener = new AnonymizerListener(); + + public SyntaxAnalysisErrorListener errorListener = new SyntaxAnalysisErrorListener(); + + /** + * Helper function to parse SQl queries for testing purposes. + * @param query SQL query to be anonymized. + */ + public void parse(String query) { + OpenSearchSQLLexer lexer = new OpenSearchSQLLexer(new CaseInsensitiveCharStream(query)); + OpenSearchSQLParser parser = new OpenSearchSQLParser(new CommonTokenStream(lexer)); + parser.addErrorListener(errorListener); + parser.addParseListener(anonymizerListener); + + parser.root(); + } + + @Test + public void queriesShouldHaveAnonymousFieldAndIndex() { + String query = "SELECT ABS(balance) FROM accounts WHERE age > 30 GROUP BY ABS(balance)"; + String expectedQuery = "( SELECT ABS ( identifier ) FROM table " + + "WHERE identifier > number GROUP BY ABS ( identifier ) )"; + parse(query); + assertEquals(expectedQuery, anonymizerListener.getAnonymizedQueryString()); + } + + @Test + public void queriesShouldAnonymousNumbers() { + String query = "SELECT ABS(20), LOG(20.20) FROM accounts"; + String expectedQuery = "( SELECT ABS ( number ), LOG ( number ) FROM table )"; + parse(query); + assertEquals(expectedQuery, anonymizerListener.getAnonymizedQueryString()); + } + + @Test + public void queriesShouldHaveAnonymousBooleanLiterals() { + String query = "SELECT TRUE FROM accounts"; + String expectedQuery = "( SELECT boolean_literal FROM table )"; + parse(query); + assertEquals(expectedQuery, anonymizerListener.getAnonymizedQueryString()); + } + + @Test + public void queriesShouldHaveAnonymousInputStrings() { + String query = "SELECT * FROM accounts WHERE name = 'Oliver'"; + String expectedQuery = "( SELECT * FROM table WHERE identifier = 'string_literal' )"; + parse(query); + assertEquals(expectedQuery, anonymizerListener.getAnonymizedQueryString()); + } + + @Test + public void queriesWithAliasesShouldAnonymizeSensitiveData() { + String query = "SELECT balance AS b FROM accounts AS a"; + String expectedQuery = "( SELECT identifier AS identifier FROM table AS identifier )"; + parse(query); + assertEquals(expectedQuery, anonymizerListener.getAnonymizedQueryString()); + } + + @Test + public void queriesWithFunctionsShouldAnonymizeSensitiveData() { + String query = "SELECT LTRIM(firstname) FROM accounts"; + String expectedQuery = "( SELECT LTRIM ( identifier ) FROM table )"; + parse(query); + assertEquals(expectedQuery, anonymizerListener.getAnonymizedQueryString()); + } + + @Test + public void queriesWithAggregatesShouldAnonymizeSensitiveData() { + String query = "SELECT MAX(price) - MIN(price) from tickets"; + String expectedQuery = "( SELECT MAX ( identifier ) - MIN ( identifier ) FROM table )"; + parse(query); + assertEquals(expectedQuery, anonymizerListener.getAnonymizedQueryString()); + } + + @Test + public void queriesWithSubqueriesShouldAnonymizeSensitiveData() { + String query = "SELECT a.f, a.l, a.a FROM " + + "(SELECT firstname AS f, lastname AS l, age AS a FROM accounts WHERE age > 30) a"; + String expectedQuery = + "( SELECT identifier.identifier, identifier.identifier, identifier.identifier FROM " + + "( SELECT identifier AS identifier, identifier AS identifier, identifier AS identifier " + + "FROM table WHERE identifier > number ) identifier )"; + parse(query); + assertEquals(expectedQuery, anonymizerListener.getAnonymizedQueryString()); + } + + @Test + public void queriesWithLimitShouldAnonymizeSensitiveData() { + String query = "SELECT balance FROM accounts LIMIT 5"; + String expectedQuery = "( SELECT identifier FROM table LIMIT number )"; + parse(query); + assertEquals(expectedQuery, anonymizerListener.getAnonymizedQueryString()); + } + + @Test + public void queriesWithOrderByShouldAnonymizeSensitiveData() { + String query = "SELECT firstname FROM accounts ORDER BY lastname"; + String expectedQuery = "( SELECT identifier FROM table ORDER BY identifier )"; + parse(query); + assertEquals(expectedQuery, anonymizerListener.getAnonymizedQueryString()); + } + + @Test + public void queriesWithHavingShouldAnonymizeSensitiveData() { + String query = "SELECT SUM(balance) FROM accounts GROUP BY lastname HAVING COUNT(balance) > 2"; + String expectedQuery = "( SELECT SUM ( identifier ) FROM table " + + "GROUP BY identifier HAVING COUNT ( identifier ) > number )"; + parse(query); + assertEquals(expectedQuery, anonymizerListener.getAnonymizedQueryString()); + } + + @Test + public void queriesWithHighlightShouldAnonymizeSensitiveData() { + String query = "SELECT HIGHLIGHT(str0) FROM CALCS WHERE QUERY_STRING(['str0'], 'FURNITURE')"; + String expectedQuery = "( SELECT HIGHLIGHT ( identifier ) FROM table WHERE " + + "QUERY_STRING ( [ 'string_literal' ], 'string_literal' ) )"; + parse(query); + assertEquals(expectedQuery, anonymizerListener.getAnonymizedQueryString()); + } + + @Test + public void queriesWithMatchShouldAnonymizeSensitiveData() { + String query = "SELECT str0 FROM CALCS WHERE MATCH(str0, 'FURNITURE')"; + String expectedQuery = "( SELECT identifier FROM table " + + "WHERE MATCH ( identifier, 'string_literal' ) )"; + parse(query); + assertEquals(expectedQuery, anonymizerListener.getAnonymizedQueryString()); + } + + @Test + public void queriesWithPositionShouldAnonymizeSensitiveData() { + String query = "SELECT POSITION('world' IN 'helloworld')"; + String expectedQuery = "( SELECT POSITION ( 'string_literal' IN 'string_literal' ) )"; + parse(query); + assertEquals(expectedQuery, anonymizerListener.getAnonymizedQueryString()); + } + + @Test + public void queriesWithMatch_Bool_Prefix_ShouldAnonymizeSensitiveData() { + String query = "SELECT firstname, address FROM accounts WHERE " + + "match_bool_prefix(address, 'Bristol Street', minimum_should_match=2)"; + String expectedQuery = "( SELECT identifier, identifier FROM table WHERE MATCH_BOOL_PREFIX " + + "( identifier, 'string_literal', MINIMUM_SHOULD_MATCH = number ) )"; + parse(query); + assertEquals(expectedQuery, anonymizerListener.getAnonymizedQueryString()); + } + + @Test + public void queriesWithGreaterOrEqualShouldAnonymizeSensitiveData() { + String query = "SELECT int0 FROM accounts WHERE int0 >= 0"; + String expectedQuery = "( SELECT identifier FROM table WHERE identifier >= number )"; + parse(query); + assertEquals(expectedQuery, anonymizerListener.getAnonymizedQueryString()); + } + + @Test + public void queriesWithLessOrEqualShouldAnonymizeSensitiveData() { + String query = "SELECT int0 FROM accounts WHERE int0 <= 0"; + String expectedQuery = "( SELECT identifier FROM table WHERE identifier <= number )"; + parse(query); + assertEquals(expectedQuery, anonymizerListener.getAnonymizedQueryString()); + } + + @Test + public void queriesWithNotEqualShouldAnonymizeSensitiveData() { + String query = "SELECT int0 FROM accounts WHERE int0 != 0"; + String expectedQuery = "( SELECT identifier FROM table WHERE identifier != number )"; + parse(query); + assertEquals(expectedQuery, anonymizerListener.getAnonymizedQueryString()); + } + + @Test + public void queriesWithNotEqualAlternateShouldAnonymizeSensitiveData() { + String query = "SELECT int0 FROM calcs WHERE int0 <> 0"; + String expectedQuery = "( SELECT identifier FROM table WHERE identifier <> number )"; + parse(query); + assertEquals(expectedQuery, anonymizerListener.getAnonymizedQueryString()); + } + + + /** + * Test added for coverage, but the errorNode will not be hit normally. + */ + @Test + public void enterErrorNote() { + ErrorNode node = mock(ErrorNode.class); + anonymizerListener.visitErrorNode(node); + } +} From da0359fd2da379af4aececde7cd4bcfe06a5384f Mon Sep 17 00:00:00 2001 From: Andrew Carbonetto Date: Mon, 29 May 2023 16:41:30 -0700 Subject: [PATCH 2/5] Update for review comments Signed-off-by: Andrew Carbonetto --- .../java/org/opensearch/sql/sql/antlr/SQLSyntaxParser.java | 7 +------ .../opensearch/sql/sql/parser/AnonymizerListenerTest.java | 7 ++----- 2 files changed, 3 insertions(+), 11 deletions(-) diff --git a/sql/src/main/java/org/opensearch/sql/sql/antlr/SQLSyntaxParser.java b/sql/src/main/java/org/opensearch/sql/sql/antlr/SQLSyntaxParser.java index b5eea0495c..4f7b925718 100644 --- a/sql/src/main/java/org/opensearch/sql/sql/antlr/SQLSyntaxParser.java +++ b/sql/src/main/java/org/opensearch/sql/sql/antlr/SQLSyntaxParser.java @@ -20,13 +20,8 @@ * SQL syntax parser which encapsulates an ANTLR parser. */ public class SQLSyntaxParser implements Parser { - private SyntaxAnalysisErrorListener syntaxAnalysisErrorListener; private static final Logger LOG = LogManager.getLogger(SQLSyntaxParser.class); - public SQLSyntaxParser() { - this.syntaxAnalysisErrorListener = new SyntaxAnalysisErrorListener(); - } - /** * Parse a SQL query by ANTLR parser. * @param query a SQL query @@ -38,7 +33,7 @@ public ParseTree parse(String query) { OpenSearchSQLLexer lexer = new OpenSearchSQLLexer(new CaseInsensitiveCharStream(query)); OpenSearchSQLParser parser = new OpenSearchSQLParser(new CommonTokenStream(lexer)); - parser.addErrorListener(syntaxAnalysisErrorListener); + parser.addErrorListener(new SyntaxAnalysisErrorListener()); parser.addParseListener(anonymizer); ParseTree parseTree = parser.root(); diff --git a/sql/src/test/java/org/opensearch/sql/sql/parser/AnonymizerListenerTest.java b/sql/src/test/java/org/opensearch/sql/sql/parser/AnonymizerListenerTest.java index d7913d8dcb..9ef5fc70d1 100644 --- a/sql/src/test/java/org/opensearch/sql/sql/parser/AnonymizerListenerTest.java +++ b/sql/src/test/java/org/opensearch/sql/sql/parser/AnonymizerListenerTest.java @@ -20,18 +20,15 @@ public class AnonymizerListenerTest { - public AnonymizerListener anonymizerListener = new AnonymizerListener(); - - public SyntaxAnalysisErrorListener errorListener = new SyntaxAnalysisErrorListener(); + private AnonymizerListener anonymizerListener = new AnonymizerListener(); /** * Helper function to parse SQl queries for testing purposes. * @param query SQL query to be anonymized. */ - public void parse(String query) { + private void parse(String query) { OpenSearchSQLLexer lexer = new OpenSearchSQLLexer(new CaseInsensitiveCharStream(query)); OpenSearchSQLParser parser = new OpenSearchSQLParser(new CommonTokenStream(lexer)); - parser.addErrorListener(errorListener); parser.addParseListener(anonymizerListener); parser.root(); From 872b26232ece0c0d6e551cea92a42d50a81f2596 Mon Sep 17 00:00:00 2001 From: Matthew Wells Date: Tue, 30 May 2023 08:54:28 -0700 Subject: [PATCH 3/5] added missing file header, change public variable to private Signed-off-by: Matthew Wells --- .../org/opensearch/sql/sql/antlr/AnonymizerListener.java | 6 ++++++ .../opensearch/sql/sql/parser/AnonymizerListenerTest.java | 3 +-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/sql/src/main/java/org/opensearch/sql/sql/antlr/AnonymizerListener.java b/sql/src/main/java/org/opensearch/sql/sql/antlr/AnonymizerListener.java index 7efd71a325..bd7b5cbedf 100644 --- a/sql/src/main/java/org/opensearch/sql/sql/antlr/AnonymizerListener.java +++ b/sql/src/main/java/org/opensearch/sql/sql/antlr/AnonymizerListener.java @@ -1,3 +1,9 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + + package org.opensearch.sql.sql.antlr; import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLLexer.BACKTICK_QUOTE_ID; diff --git a/sql/src/test/java/org/opensearch/sql/sql/parser/AnonymizerListenerTest.java b/sql/src/test/java/org/opensearch/sql/sql/parser/AnonymizerListenerTest.java index 9ef5fc70d1..59d723e3a2 100644 --- a/sql/src/test/java/org/opensearch/sql/sql/parser/AnonymizerListenerTest.java +++ b/sql/src/test/java/org/opensearch/sql/sql/parser/AnonymizerListenerTest.java @@ -13,14 +13,13 @@ import org.antlr.v4.runtime.tree.ErrorNode; import org.junit.jupiter.api.Test; import org.opensearch.sql.common.antlr.CaseInsensitiveCharStream; -import org.opensearch.sql.common.antlr.SyntaxAnalysisErrorListener; import org.opensearch.sql.sql.antlr.AnonymizerListener; import org.opensearch.sql.sql.antlr.parser.OpenSearchSQLLexer; import org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser; public class AnonymizerListenerTest { - private AnonymizerListener anonymizerListener = new AnonymizerListener(); + private final AnonymizerListener anonymizerListener = new AnonymizerListener(); /** * Helper function to parse SQl queries for testing purposes. From 6185cfb22b7dfb50a54c5c36fb9d07b84fb78ecd Mon Sep 17 00:00:00 2001 From: Matthew Wells Date: Thu, 1 Jun 2023 14:11:01 -0700 Subject: [PATCH 4/5] fixed bug, updated tests to cover bytes divided by zero Signed-off-by: Matthew Wells --- .../operator/arthmetic/ArithmeticFunction.java | 6 ++++-- .../operator/arthmetic/ArithmeticFunctionTest.java | 10 +++++----- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/core/src/main/java/org/opensearch/sql/expression/operator/arthmetic/ArithmeticFunction.java b/core/src/main/java/org/opensearch/sql/expression/operator/arthmetic/ArithmeticFunction.java index 7130ac105a..20cb6dd51c 100644 --- a/core/src/main/java/org/opensearch/sql/expression/operator/arthmetic/ArithmeticFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/operator/arthmetic/ArithmeticFunction.java @@ -106,7 +106,8 @@ private static DefaultFunctionResolver addFunction() { private static DefaultFunctionResolver divideBase(FunctionName functionName) { return define(functionName, impl(nullMissingHandling( - (v1, v2) -> new ExprByteValue(v1.byteValue() / v2.byteValue())), + (v1, v2) -> v2.byteValue() == 0 ? ExprNullValue.of() : + new ExprByteValue(v1.byteValue() / v2.byteValue())), BYTE, BYTE, BYTE), impl(nullMissingHandling( (v1, v2) -> v2.shortValue() == 0 ? ExprNullValue.of() : @@ -149,7 +150,8 @@ private static DefaultFunctionResolver divideFunction() { private static DefaultFunctionResolver modulusBase(FunctionName functionName) { return define(functionName, impl(nullMissingHandling( - (v1, v2) -> new ExprByteValue(v1.byteValue() % v2.byteValue())), + (v1, v2) -> v2.byteValue() == 0 ? ExprNullValue.of() : + new ExprByteValue(v1.byteValue() % v2.byteValue())), BYTE, BYTE, BYTE), impl(nullMissingHandling( (v1, v2) -> v2.shortValue() == 0 ? ExprNullValue.of() : diff --git a/core/src/test/java/org/opensearch/sql/expression/operator/arthmetic/ArithmeticFunctionTest.java b/core/src/test/java/org/opensearch/sql/expression/operator/arthmetic/ArithmeticFunctionTest.java index b28bea8b89..028ace6231 100644 --- a/core/src/test/java/org/opensearch/sql/expression/operator/arthmetic/ArithmeticFunctionTest.java +++ b/core/src/test/java/org/opensearch/sql/expression/operator/arthmetic/ArithmeticFunctionTest.java @@ -113,7 +113,7 @@ public void mod(ExprValue op1, ExprValue op2) { assertEquals(String.format("mod(%s, %s)", op1.toString(), op2.toString()), expression.toString()); - expression = DSL.mod(literal(op1), literal(new ExprShortValue(0))); + expression = DSL.mod(literal(op1), literal(new ExprByteValue(0))); assertTrue(expression.valueOf(valueEnv()).isNull()); assertEquals(String.format("mod(%s, 0)", op1.toString()), expression.toString()); } @@ -128,7 +128,7 @@ public void modulus(ExprValue op1, ExprValue op2) { assertEquals(String.format("%%(%s, %s)", op1.toString(), op2.toString()), expression.toString()); - expression = DSL.modulus(literal(op1), literal(new ExprShortValue(0))); + expression = DSL.modulus(literal(op1), literal(new ExprByteValue(0))); assertTrue(expression.valueOf(valueEnv()).isNull()); assertEquals(String.format("%%(%s, 0)", op1.toString()), expression.toString()); } @@ -144,7 +144,7 @@ public void modulusFunction(ExprValue op1, ExprValue op2) { assertEquals(String.format("modulus(%s, %s)", op1.toString(), op2.toString()), expression.toString()); - expression = DSL.modulusFunction(literal(op1), literal(new ExprShortValue(0))); + expression = DSL.modulusFunction(literal(op1), literal(new ExprByteValue(0))); assertTrue(expression.valueOf(valueEnv()).isNull()); assertEquals(String.format("modulus(%s, 0)", op1.toString()), expression.toString()); } @@ -183,7 +183,7 @@ public void divide(ExprValue op1, ExprValue op2) { assertEquals(String.format("/(%s, %s)", op1.toString(), op2.toString()), expression.toString()); - expression = DSL.divide(literal(op1), literal(new ExprShortValue(0))); + expression = DSL.divide(literal(op1), literal(new ExprByteValue(0))); assertTrue(expression.valueOf(valueEnv()).isNull()); assertEquals(String.format("/(%s, 0)", op1.toString()), expression.toString()); } @@ -199,7 +199,7 @@ public void divideFunction(ExprValue op1, ExprValue op2) { assertEquals(String.format("divide(%s, %s)", op1.toString(), op2.toString()), expression.toString()); - expression = DSL.divideFunction(literal(op1), literal(new ExprShortValue(0))); + expression = DSL.divideFunction(literal(op1), literal(new ExprByteValue(0))); assertTrue(expression.valueOf(valueEnv()).isNull()); assertEquals(String.format("divide(%s, 0)", op1.toString()), expression.toString()); } From 1d5ea1e62c17855747456bfee3deb8d0e438e19d Mon Sep 17 00:00:00 2001 From: Matthew Wells Date: Thu, 1 Jun 2023 14:30:19 -0700 Subject: [PATCH 5/5] corrected function name from modoulo to modulus Signed-off-by: Matthew Wells --- .../sql/expression/operator/arthmetic/ArithmeticFunction.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/java/org/opensearch/sql/expression/operator/arthmetic/ArithmeticFunction.java b/core/src/main/java/org/opensearch/sql/expression/operator/arthmetic/ArithmeticFunction.java index 20cb6dd51c..1f4ac3943c 100644 --- a/core/src/main/java/org/opensearch/sql/expression/operator/arthmetic/ArithmeticFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/operator/arthmetic/ArithmeticFunction.java @@ -141,7 +141,7 @@ private static DefaultFunctionResolver divideFunction() { } /** - * Definition of modulo(x, y) function. + * Definition of modulus(x, y) function. * Returns the number x modulo by number y * The supported signature of modulo function is * (x: BYTE/SHORT/INTEGER/LONG/FLOAT/DOUBLE, y: BYTE/SHORT/INTEGER/LONG/FLOAT/DOUBLE)