Skip to content

Commit

Permalink
Modified returning NaN to NULL (opensearch-project#225) (opensearch-p…
Browse files Browse the repository at this point in the history
…roject#1341)

Signed-off-by: Guian Gumpac <[email protected]>
  • Loading branch information
GumpacG authored Feb 15, 2023
1 parent 72547f4 commit de40f42
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,8 @@ private static DefaultFunctionResolver floor() {
*/
private static DefaultFunctionResolver ln() {
return baseMathFunction(BuiltinFunctionName.LN.getName(),
v -> new ExprDoubleValue(Math.log(v.doubleValue())), DOUBLE);
v -> v.doubleValue() <= 0 ? ExprNullValue.of() :
new ExprDoubleValue(Math.log(v.doubleValue())), DOUBLE);
}

/**
Expand All @@ -255,15 +256,17 @@ private static DefaultFunctionResolver log() {
// build unary log(x), SHORT/INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE
for (ExprType type : ExprCoreType.numberTypes()) {
builder.add(FunctionDSL.impl(FunctionDSL
.nullMissingHandling(v -> new ExprDoubleValue(Math.log(v.doubleValue()))),
.nullMissingHandling(v -> v.doubleValue() <= 0 ? ExprNullValue.of() :
new ExprDoubleValue(Math.log(v.doubleValue()))),
DOUBLE, type));
}

// build binary function log(b, x)
for (ExprType baseType : ExprCoreType.numberTypes()) {
for (ExprType numberType : ExprCoreType.numberTypes()) {
builder.add(FunctionDSL.impl(FunctionDSL
.nullMissingHandling((b, x) -> new ExprDoubleValue(
.nullMissingHandling((b, x) -> b.doubleValue() <= 0 || x.doubleValue() <= 0
? ExprNullValue.of() : new ExprDoubleValue(
Math.log(x.doubleValue()) / Math.log(b.doubleValue()))),
DOUBLE, baseType, numberType));
}
Expand All @@ -278,7 +281,8 @@ private static DefaultFunctionResolver log() {
*/
private static DefaultFunctionResolver log10() {
return baseMathFunction(BuiltinFunctionName.LOG10.getName(),
v -> new ExprDoubleValue(Math.log10(v.doubleValue())), DOUBLE);
v -> v.doubleValue() <= 0 ? ExprNullValue.of() :
new ExprDoubleValue(Math.log10(v.doubleValue())), DOUBLE);
}

/**
Expand All @@ -287,7 +291,8 @@ private static DefaultFunctionResolver log10() {
*/
private static DefaultFunctionResolver log2() {
return baseMathFunction(BuiltinFunctionName.LOG2.getName(),
v -> new ExprDoubleValue(Math.log(v.doubleValue()) / Math.log(2)), DOUBLE);
v -> v.doubleValue() <= 0 ? ExprNullValue.of() :
new ExprDoubleValue(Math.log(v.doubleValue()) / Math.log(2)), DOUBLE);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,12 @@ private static Stream<Arguments> testLogDoubleArguments() {
return builder.add(Arguments.of(2D, 2D)).build();
}

private static Stream<Arguments> testLogInvalidDoubleArguments() {
return Stream.of(Arguments.of(0D, -2D),
Arguments.of(0D, 2D),
Arguments.of(2D, 0D));
}

private static Stream<Arguments> trigonometricArguments() {
Stream.Builder<Arguments> builder = Stream.builder();
return builder
Expand Down Expand Up @@ -725,7 +731,7 @@ public void floor_missing_value() {
* Test ln with integer value.
*/
@ParameterizedTest(name = "ln({0})")
@ValueSource(ints = {2, -2})
@ValueSource(ints = {2, 3})
public void ln_int_value(Integer value) {
FunctionExpression ln = DSL.ln(DSL.literal(value));
assertThat(
Expand All @@ -738,7 +744,7 @@ public void ln_int_value(Integer value) {
* Test ln with long value.
*/
@ParameterizedTest(name = "ln({0})")
@ValueSource(longs = {2L, -2L})
@ValueSource(longs = {2L, 3L})
public void ln_long_value(Long value) {
FunctionExpression ln = DSL.ln(DSL.literal(value));
assertThat(
Expand All @@ -751,7 +757,7 @@ public void ln_long_value(Long value) {
* Test ln with float value.
*/
@ParameterizedTest(name = "ln({0})")
@ValueSource(floats = {2F, -2F})
@ValueSource(floats = {2F, 3F})
public void ln_float_value(Float value) {
FunctionExpression ln = DSL.ln(DSL.literal(value));
assertThat(
Expand All @@ -764,7 +770,7 @@ public void ln_float_value(Float value) {
* Test ln with double value.
*/
@ParameterizedTest(name = "ln({0})")
@ValueSource(doubles = {2D, -2D})
@ValueSource(doubles = {2D, 3D})
public void ln_double_value(Double value) {
FunctionExpression ln = DSL.ln(DSL.literal(value));
assertThat(
Expand All @@ -773,6 +779,17 @@ public void ln_double_value(Double value) {
assertEquals(String.format("ln(%s)", value.toString()), ln.toString());
}

/**
* Test ln with invalid value.
*/
@ParameterizedTest(name = "ln({0})")
@ValueSource(doubles = {0D, -3D})
public void ln_invalid_value(Double value) {
FunctionExpression ln = DSL.ln(DSL.literal(value));
assertEquals(DOUBLE, ln.type());
assertTrue(ln.valueOf(valueEnv()).isNull());
}

/**
* Test ln with null value.
*/
Expand Down Expand Up @@ -853,6 +870,17 @@ public void log_double_value(Double v) {
assertEquals(String.format("log(%s)", v.toString()), log.toString());
}

/**
* Test log with 1 invalid value.
*/
@ParameterizedTest(name = "log({0})")
@ValueSource(doubles = {0D, -3D})
public void log_invalid_value(Double value) {
FunctionExpression log = DSL.log(DSL.literal(value));
assertEquals(DOUBLE, log.type());
assertTrue(log.valueOf(valueEnv()).isNull());
}

/**
* Test log with 1 null value argument.
*/
Expand Down Expand Up @@ -931,6 +959,17 @@ public void log_two_double_value(Double v1, Double v2) {
assertEquals(String.format("log(%s, %s)", v1.toString(), v2.toString()), log.toString());
}

/**
* Test log with 2 invalid double arguments.
*/
@ParameterizedTest(name = "log({0}, {2})")
@MethodSource("testLogInvalidDoubleArguments")
public void log_two_invalid_double_value(Double v1, Double v2) {
FunctionExpression log = DSL.log(DSL.literal(v1), DSL.literal(v2));
assertEquals(log.type(), DOUBLE);
assertTrue(log.valueOf(valueEnv()).isNull());
}

/**
* Test log with 2 null value arguments.
*/
Expand Down Expand Up @@ -1051,6 +1090,17 @@ public void log10_double_value(Double v) {
assertEquals(String.format("log10(%s)", v.toString()), log.toString());
}

/**
* Test log10 with 1 invalid double argument.
*/
@ParameterizedTest(name = "log10({0})")
@ValueSource(doubles = {0D, -3D})
public void log10_two_invalid_value(Double v) {
FunctionExpression log = DSL.log10(DSL.literal(v));
assertEquals(log.type(), DOUBLE);
assertTrue(log.valueOf(valueEnv()).isNull());
}

/**
* Test log10 with null value.
*/
Expand Down Expand Up @@ -1133,6 +1183,17 @@ public void log2_double_value(Double v) {
assertEquals(String.format("log2(%s)", v.toString()), log.toString());
}

/**
* Test log2 with an invalid double value.
*/
@ParameterizedTest(name = "log2({0})")
@ValueSource(doubles = {0D, -2D})
public void log2_invalid_double_value(Double v) {
FunctionExpression log = DSL.log2(DSL.literal(v));
assertEquals(log.type(), DOUBLE);
assertTrue(log.valueOf(valueEnv()).isNull());
}

/**
* Test log2 with null value.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -197,19 +197,6 @@ public void testAtan() throws IOException {
verifyDataRows(result, rows(Math.atan2(2, 3)));
}

protected JSONObject executeQuery(String query) throws IOException {
Request request = new Request("POST", QUERY_API_ENDPOINT);
request.setJsonEntity(String.format(Locale.ROOT, "{\n" + " \"query\": \"%s\"\n" + "}", query));

RequestOptions.Builder restOptionsBuilder = RequestOptions.DEFAULT.toBuilder();
restOptionsBuilder.addHeader("Content-Type", "application/json");
request.setOptions(restOptionsBuilder);

Response response = client().performRequest(request);
return new JSONObject(getResponseBody(response));
}


@Test
public void testCbrt() throws IOException {
JSONObject result = executeQuery("select cbrt(8)");
Expand All @@ -224,4 +211,52 @@ public void testCbrt() throws IOException {
verifySchema(result, schema("cbrt(-27)", "double"));
verifyDataRows(result, rows(-3.0));
}

@Test
public void testLnReturnsNull() throws IOException {
JSONObject result = executeQuery("select ln(0), ln(-2)");
verifySchema(result,
schema("ln(0)", "double"),
schema("ln(-2)", "double"));
verifyDataRows(result, rows(null, null));
}

@Test
public void testLogReturnsNull() throws IOException {
JSONObject result = executeQuery("select log(0), log(-2)");
verifySchema(result,
schema("log(0)", "double"),
schema("log(-2)", "double"));
verifyDataRows(result, rows(null, null));
}

@Test
public void testLog10ReturnsNull() throws IOException {
JSONObject result = executeQuery("select log10(0), log10(-2)");
verifySchema(result,
schema("log10(0)", "double"),
schema("log10(-2)", "double"));
verifyDataRows(result, rows(null, null));
}

@Test
public void testLog2ReturnsNull() throws IOException {
JSONObject result = executeQuery("select log2(0), log2(-2)");
verifySchema(result,
schema("log2(0)", "double"),
schema("log2(-2)", "double"));
verifyDataRows(result, rows(null, null));
}

protected JSONObject executeQuery(String query) throws IOException {
Request request = new Request("POST", QUERY_API_ENDPOINT);
request.setJsonEntity(String.format(Locale.ROOT, "{\n" + " \"query\": \"%s\"\n" + "}", query));

RequestOptions.Builder restOptionsBuilder = RequestOptions.DEFAULT.toBuilder();
restOptionsBuilder.addHeader("Content-Type", "application/json");
request.setOptions(restOptionsBuilder);

Response response = client().performRequest(request);
return new JSONObject(getResponseBody(response));
}
}

0 comments on commit de40f42

Please sign in to comment.