Skip to content
This repository has been archived by the owner on Aug 2, 2022. It is now read-only.

Support extra number operators: POWER, ATAN2, COT, SIGN #251

Merged
merged 13 commits into from
Oct 25, 2019
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,14 @@ public enum ScalarFunction implements TypeExpression {
ABS(func(T(NUMBER)).to(T)), // translate to Java: <T extends Number> T ABS(T)
ASIN(func(T(NUMBER)).to(T)),
ATAN(func(T(NUMBER)).to(T)),
ATAN2(func(T(NUMBER)).to(T)),
ATAN2(func(T(NUMBER), NUMBER).to(T)),
CBRT(func(T(NUMBER)).to(T)),
CEIL(func(T(NUMBER)).to(T)),
CONCAT(), // TODO: varargs support required
CONCAT_WS(),
COS(func(T(NUMBER)).to(T)),
COSH(func(T(NUMBER)).to(T)),
COT(func(T(NUMBER)).to(T)),
DATE_FORMAT(
func(DATE, STRING).to(STRING),
func(DATE, STRING, STRING).to(STRING)
Expand All @@ -60,14 +61,16 @@ public enum ScalarFunction implements TypeExpression {
func(T(STRING), STRING).to(T)
),
PI(func().to(DOUBLE)),
POW(
POW, POWER(
func(T(NUMBER)).to(T),
func(T(NUMBER), NUMBER).to(T)
),
RADIANS(func(T(NUMBER)).to(T)),
RANDOM(func(T(NUMBER)).to(T)),
RINT(func(T(NUMBER)).to(T)),
ROUND(func(T(NUMBER)).to(T)),
SIGN(func(T(NUMBER)).to(T)),
SIGNUM(func(T(NUMBER)).to(T)),
SIN(func(T(NUMBER)).to(T)),
SINH(func(T(NUMBER)).to(T)),
SQRT(func(T(NUMBER)).to(T)),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,14 @@
public class SQLFunctions {

private static final Set<String> numberOperators = Sets.newHashSet(
"exp", "expm1", "log", "log2", "log10", "sqrt", "cbrt", "ceil", "floor", "rint", "pow",
"round", "random", "abs"
"exp", "expm1", "log", "log2", "log10", "sqrt", "cbrt", "ceil", "floor", "rint", "pow", "power",
"round", "random", "abs", "sign", "signum"
);

private static final Set<String> mathConstants = Sets.newHashSet("e", "pi");

private static final Set<String> trigFunctions = Sets.newHashSet(
"degrees", "radians", "sin", "cos", "tan", "asin", "acos", "atan", "sinh", "cosh"
"degrees", "radians", "sin", "cos", "tan", "asin", "acos", "atan", "atan2", "sinh", "cosh", "cot"
);

private static final Set<String> stringOperators = Sets.newHashSet(
Expand Down Expand Up @@ -205,12 +205,32 @@ public Tuple<String, String> function(String methodName, List<KVValue> paramers,
(SQLExpr) paramers.get(0).value, name);
break;

case "cot":
// ES does not support the function name cot
functionStr = mathSingleValueTemplate("1 / Math.tan", methodName,
(SQLExpr) paramers.get(0).value, name);
break;

case "sign":
case "signum":
methodName = "signum";
functionStr = mathSingleValueTemplate("Math." + methodName, methodName,
(SQLExpr) paramers.get(0).value, name);
break;

case "pow":
case "power":
methodName = "pow";
functionStr = mathDoubleValueTemplate("Math." + methodName, methodName,
(SQLExpr) paramers.get(0).value, Util.expr2Object((SQLExpr) paramers.get(1).value).toString(),
name);
break;

case "atan2":
functionStr = mathDoubleValueTemplate("Math." + methodName, methodName,
(SQLExpr) paramers.get(0).value, (SQLExpr) paramers.get(1).value);
break;

case "substring":
functionStr = substring((SQLExpr) paramers.get(0).value,
Integer.parseInt(Util.expr2Object((SQLExpr) paramers.get(1).value).toString()),
Expand Down Expand Up @@ -544,6 +564,13 @@ private Tuple<String, String> mathDoubleValueTemplate(String methodName, String
}
}

private Tuple<String, String> mathDoubleValueTemplate(String methodName, String fieldName, SQLExpr val1,
SQLExpr val2) {
String name = nextId(fieldName);
return new Tuple<>(name, def(name, func(methodName, false,
getPropertyOrValue(val1), getPropertyOrValue(val2))));
}

private Tuple<String, String> mathSingleValueTemplate(String methodName, String fieldName, SQLExpr field,
String valueName) {
String name = nextId(fieldName);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ public void allSupportedMathFunctionCallInSelectClauseShouldPass() {
" ABS(age), " +
" ASIN(age), " +
" ATAN(age), " +
" ATAN2(age), " +
" ATAN2(age, age), " +
" CBRT(age), " +
" CEIL(age), " +
" COS(age), " +
Expand Down Expand Up @@ -170,7 +170,7 @@ public void allSupportedMathFunctionCallInWhereClauseShouldPass() {
" ABS(age) = 1 AND " +
" ASIN(age) = 1 AND " +
" ATAN(age) = 1 AND " +
" ATAN2(age) = 1 AND " +
" ATAN2(age, age) = 1 AND " +
" CBRT(age) = 1 AND " +
" CEIL(age) = 1 AND " +
" COS(age) = 1 AND " +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

import java.io.IOException;

import static org.hamcrest.Matchers.closeTo;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThanOrEqualTo;

Expand Down Expand Up @@ -138,6 +139,44 @@ public void sinh() throws IOException {
assertThat(sinh, equalTo(Math.sinh(Math.PI)));
}

@Test
public void power() throws IOException {
SearchHit[] hits = query(
"SELECT POWER(age, 2) AS power",
"WHERE (age IS NOT NULL) AND (balance IS NOT NULL) and (POWER(balance, 3) > 0)"
);
double power = (double) getField(hits[0], "power");
assertTrue(power >= 0);
}

@Test
public void atan2() throws IOException {
SearchHit[] hits = query(
"SELECT ATAN2(age, age) AS atan2",
"WHERE (age IS NOT NULL) AND (ATAN2(age, age) > 0)"
);
double atan2 = (double) getField(hits[0], "atan2");
assertThat(atan2, equalTo(Math.atan2(1, 1)));
}

@Test
public void cot() throws IOException {
SearchHit[] hits = query(
"SELECT COT(PI()) AS cot"
);
double cot = (double) getField(hits[0], "cot");
assertThat(cot, closeTo(1 / Math.tan(Math.PI), 0.001));
}

@Test
public void sign() throws IOException {
SearchHit[] hits = query(
"SELECT SIGN(E()) AS sign"
);
double sign = (double) getField(hits[0], "sign");
assertThat(sign, equalTo(Math.signum(Math.E)));
}

private SearchHit[] query(String select, String... statements) throws IOException {
final String response = executeQueryWithStringOutput(select + " " + FROM + " " + String.join(" ", statements));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@

import static org.elasticsearch.search.builder.SearchSourceBuilder.ScriptField;
import static org.junit.Assert.assertTrue;
import static com.amazon.opendistroforelasticsearch.sql.util.CheckScriptContents.scriptContainsString;
import static com.amazon.opendistroforelasticsearch.sql.util.CheckScriptContents.scriptHasPattern;

public class MathFunctionsTest {

Expand Down Expand Up @@ -315,4 +313,66 @@ public void coshWithValueArgument() {
"Math.cosh(0)"));
}

@Test
public void powerWithPropertyArgument() {
String query = "SELECT POWER(age, 2) FROM bank WHERE POWER(balance, 3) > 0";
ScriptField scriptField = CheckScriptContents.getScriptFieldFromQuery(query);
assertTrue(
CheckScriptContents.scriptContainsString(
scriptField,
"Math.pow(doc['age'].value, 2)"));

ScriptFilter scriptFilter = CheckScriptContents.getScriptFilterFromQuery(query, parser);
assertTrue(
CheckScriptContents.scriptContainsString(
scriptFilter,
"Math.pow(doc['balance'].value, 3)"));
}

@Test
public void atan2WithPropertyArgument() {
String query = "SELECT ATAN2(age, 2) FROM bank WHERE ATAN2(balance, 3) > 0";
ScriptField scriptField = CheckScriptContents.getScriptFieldFromQuery(query);
assertTrue(
CheckScriptContents.scriptContainsString(
scriptField,
"Math.atan2(doc['age'].value, 2)"));

ScriptFilter scriptFilter = CheckScriptContents.getScriptFilterFromQuery(query, parser);
assertTrue(
CheckScriptContents.scriptContainsString(
scriptFilter,
"Math.atan2(doc['balance'].value, 3)"));
}

@Test
public void cotWithPropertyArgument() {
String query = "SELECT COT(age) FROM bank WHERE COT(balance) > 0";
ScriptField scriptField = CheckScriptContents.getScriptFieldFromQuery(query);
assertTrue(
CheckScriptContents.scriptContainsString(
scriptField,
"1 / Math.tan(doc['age'].value)"));

ScriptFilter scriptFilter = CheckScriptContents.getScriptFilterFromQuery(query, parser);
assertTrue(
CheckScriptContents.scriptContainsString(
scriptFilter,
"1 / Math.tan(doc['balance'].value)"));
}

@Test
public void signWithFunctionPropertyArgument() {
String query = "SELECT SIGN(age) FROM bank WHERE SIGNUM(balance) = 1";
ScriptField scriptField = CheckScriptContents.getScriptFieldFromQuery(query);
assertTrue(CheckScriptContents.scriptContainsString(
scriptField,
"Math.signum(doc['age'].value)"));

ScriptFilter scriptFilter = CheckScriptContents.getScriptFilterFromQuery(query, parser);
assertTrue(
CheckScriptContents.scriptContainsString(
scriptFilter,
"Math.signum(doc['balance'].value)"));
}
}