Skip to content
This repository has been archived by the owner on Aug 2, 2022. It is now read-only.

Enable sql function ifnull, nullif and isnull #962

Merged
merged 13 commits into from
Jan 14, 2021
Original file line number Diff line number Diff line change
Expand Up @@ -507,13 +507,25 @@ private Aggregator aggregate(BuiltinFunctionName functionName, Expression... exp
}

public FunctionExpression isnull(Expression... expressions) {
return function(BuiltinFunctionName.ISNULL, expressions);
}

public FunctionExpression is_null(Expression... expressions) {
return function(BuiltinFunctionName.IS_NULL, expressions);
}

public FunctionExpression isnotnull(Expression... expressions) {
return function(BuiltinFunctionName.IS_NOT_NULL, expressions);
}

public FunctionExpression ifnull(Expression... expressions) {
return function(BuiltinFunctionName.IFNULL, expressions);
}

public FunctionExpression nullif(Expression... expressions) {
return function(BuiltinFunctionName.NULLIF, expressions);
}

public static Expression cases(Expression defaultResult,
WhenClause... whenClauses) {
return new CaseClause(Arrays.asList(whenClauses), defaultResult);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,9 @@ public enum BuiltinFunctionName {
*/
IS_NULL(FunctionName.of("is null")),
IS_NOT_NULL(FunctionName.of("is not null")),
IFNULL(FunctionName.of("ifnull")),
NULLIF(FunctionName.of("nullif")),
ISNULL(FunctionName.of("isnull")),
harold-wang marked this conversation as resolved.
Show resolved Hide resolved

ROW_NUMBER(FunctionName.of("row_number")),
RANK(FunctionName.of("rank")),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,31 @@

package com.amazon.opendistroforelasticsearch.sql.expression.operator.predicate;

import static com.amazon.opendistroforelasticsearch.sql.data.model.ExprValueUtils.LITERAL_FALSE;
import static com.amazon.opendistroforelasticsearch.sql.data.model.ExprValueUtils.LITERAL_MISSING;
import static com.amazon.opendistroforelasticsearch.sql.data.model.ExprValueUtils.LITERAL_NULL;
import static com.amazon.opendistroforelasticsearch.sql.data.model.ExprValueUtils.LITERAL_TRUE;

import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.BOOLEAN;
import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.UNKNOWN;

import static com.amazon.opendistroforelasticsearch.sql.expression.function.FunctionDSL.impl;

import com.amazon.opendistroforelasticsearch.sql.data.model.ExprBooleanValue;
import com.amazon.opendistroforelasticsearch.sql.data.model.ExprValue;
import com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType;
import com.amazon.opendistroforelasticsearch.sql.data.type.ExprType;
import com.amazon.opendistroforelasticsearch.sql.expression.function.BuiltinFunctionName;
import com.amazon.opendistroforelasticsearch.sql.expression.function.BuiltinFunctionRepository;
import com.amazon.opendistroforelasticsearch.sql.expression.function.FunctionBuilder;
import com.amazon.opendistroforelasticsearch.sql.expression.function.FunctionDSL;
import com.amazon.opendistroforelasticsearch.sql.expression.function.FunctionName;
import com.amazon.opendistroforelasticsearch.sql.expression.function.FunctionResolver;
import com.amazon.opendistroforelasticsearch.sql.expression.function.FunctionSignature;
import com.amazon.opendistroforelasticsearch.sql.expression.function.SerializableFunction;

import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import lombok.experimental.UtilityClass;

Expand All @@ -39,8 +54,11 @@ public class UnaryPredicateOperator {
*/
public static void register(BuiltinFunctionRepository repository) {
repository.register(not());
repository.register(isNull());
repository.register(isNotNull());
repository.register(ifNull());
repository.register(nullIf());
repository.register(isNull(BuiltinFunctionName.IS_NULL));
repository.register(isNull(BuiltinFunctionName.ISNULL));
}

private static FunctionResolver not() {
Expand All @@ -64,10 +82,9 @@ public ExprValue not(ExprValue v) {
}
}

private static FunctionResolver isNull() {

private static FunctionResolver isNull(BuiltinFunctionName funcName) {
return FunctionDSL
.define(BuiltinFunctionName.IS_NULL.getName(), Arrays.stream(ExprCoreType.values())
.define(funcName.getName(), Arrays.stream(ExprCoreType.values())
.map(type -> FunctionDSL
.impl((v) -> ExprBooleanValue.of(v.isNull()), BOOLEAN, type))
.collect(
Expand All @@ -82,4 +99,54 @@ private static FunctionResolver isNotNull() {
.collect(
Collectors.toList()));
}

private static FunctionResolver ifNull() {
FunctionName functionName = BuiltinFunctionName.IFNULL.getName();
List<ExprType> typeList = ExprCoreType.coreTypes();

List<SerializableFunction<FunctionName, org.apache.commons.lang3.tuple.Pair<FunctionSignature,
FunctionBuilder>>> functionsOne = typeList.stream().map(v ->
impl((UnaryPredicateOperator::exprIfNull), v, v, v))
harold-wang marked this conversation as resolved.
Show resolved Hide resolved
.collect(Collectors.toList());

List<SerializableFunction<FunctionName, org.apache.commons.lang3.tuple.Pair<FunctionSignature,
FunctionBuilder>>> functionsTwo = typeList.stream().map(v ->
impl((UnaryPredicateOperator::exprIfNull), v, UNKNOWN, v))
.collect(Collectors.toList());

functionsOne.addAll(functionsTwo);
FunctionResolver functionResolver = FunctionDSL.define(functionName, functionsOne);
return functionResolver;
}

private static FunctionResolver nullIf() {
FunctionName functionName = BuiltinFunctionName.NULLIF.getName();
List<ExprType> typeList = ExprCoreType.coreTypes();

FunctionResolver functionResolver =
FunctionDSL.define(functionName,
typeList.stream().map(v ->
impl((UnaryPredicateOperator::exprNullIf), v, v, v))
.collect(Collectors.toList()));
return functionResolver;
}

/** v2 if v1 is null.
* @param v1 varable 1
* @param v2 varable 2
* @return v2 if v1 is null
*/
public static ExprValue exprIfNull(ExprValue v1, ExprValue v2) {
return (v1.isNull() || v1.isMissing()) ? v2 : v1;
}

/** return null if v1 equls to v2.
* @param v1 varable 1
* @param v2 varable 2
* @return null if v1 equls to v2
*/
public static ExprValue exprNullIf(ExprValue v1, ExprValue v2) {
return v1.equals(v2) ? LITERAL_NULL : v1;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ public class ExpressionTestBase {
protected Environment<Expression, ExprType> typeEnv;

@Bean
protected Environment<Expression, ExprValue> valueEnv() {
protected static Environment<Expression, ExprValue> valueEnv() {
return var -> {
if (var instanceof ReferenceExpression) {
switch (((ReferenceExpression) var).getAttr()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,29 @@
import static com.amazon.opendistroforelasticsearch.sql.data.model.ExprValueUtils.LITERAL_TRUE;
import static com.amazon.opendistroforelasticsearch.sql.data.model.ExprValueUtils.booleanValue;
import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.BOOLEAN;
import static java.lang.Enum.valueOf;
import static org.junit.jupiter.api.Assertions.assertEquals;

import com.amazon.opendistroforelasticsearch.sql.data.model.ExprNullValue;
import com.amazon.opendistroforelasticsearch.sql.data.model.ExprValue;
import com.amazon.opendistroforelasticsearch.sql.data.model.ExprValueUtils;
import com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType;
import com.amazon.opendistroforelasticsearch.sql.expression.DSL;
import com.amazon.opendistroforelasticsearch.sql.expression.Expression;
import com.amazon.opendistroforelasticsearch.sql.expression.ExpressionTestBase;
import com.amazon.opendistroforelasticsearch.sql.expression.FunctionExpression;
import com.google.common.collect.Lists;
import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource;

class UnaryPredicateOperatorTest extends ExpressionTestBase {
Expand All @@ -44,6 +58,100 @@ public void test_not(Boolean v) {
assertEquals(String.format("not(%s)", v.toString()), not.toString());
}

private static Stream<Arguments> isNullArguments() {
ArrayList<Expression> expressions = new ArrayList<>();
expressions.add(DSL.literal("test"));
expressions.add(DSL.literal(100));
expressions.add(DSL.literal(""));
expressions.add(DSL.literal(LITERAL_NULL));

return Lists.cartesianProduct(expressions, expressions).stream()
.map(list -> {
Expression e1 = list.get(0);
if (e1.valueOf(valueEnv()).isNull()
|| e1.valueOf(valueEnv()).isMissing()) {
return Arguments.of(e1, DSL.literal(LITERAL_TRUE));
} else {
return Arguments.of(e1, DSL.literal(LITERAL_FALSE));
}
});
}

private static Stream<Arguments> ifNullArguments() {
ArrayList<Expression> exprValueArrayList = new ArrayList<>();
exprValueArrayList.add(DSL.literal(123));
exprValueArrayList.add(DSL.literal("test"));
exprValueArrayList.add(DSL.literal(321));
exprValueArrayList.add(DSL.literal(""));

return Lists.cartesianProduct(exprValueArrayList, exprValueArrayList).stream()
.map(list -> {
Expression e1 = list.get(0);
Expression e2 = list.get(1);
if (e1.valueOf(valueEnv()).value() == LITERAL_NULL.value()
|| e1.valueOf(valueEnv()).value() == LITERAL_MISSING) {
return Arguments.of(e1, e2, e2);
} else {
return Arguments.of(e1, e2, e1);
}
});
}

private static Stream<Arguments> nullIfArguments() {
ArrayList<Expression> exprValueArrayList = new ArrayList<>();
exprValueArrayList.add(DSL.literal(123));
exprValueArrayList.add(DSL.literal(321));

return Lists.cartesianProduct(exprValueArrayList, exprValueArrayList).stream()
.map(list -> {
Expression e1 = list.get(0);
Expression e2 = list.get(1);

if (e1.equals(e2)) {
return Arguments.of(e1, e2, DSL.literal(LITERAL_NULL));
} else {
return Arguments.of(e1, e2, e1);
}
});
}

private static Stream<Arguments> exprIfNullArguments() {
ArrayList<ExprValue> exprValues = new ArrayList<>();
exprValues.add(LITERAL_NULL);
exprValues.add(LITERAL_MISSING);
exprValues.add(ExprValueUtils.integerValue(123));
exprValues.add(ExprValueUtils.stringValue("test"));

return Lists.cartesianProduct(exprValues, exprValues).stream()
.map(list -> {
ExprValue e1 = list.get(0);
ExprValue e2 = list.get(1);
if (e1.isNull() || e1.isMissing()) {
return Arguments.of(e1, e2, e2);
} else {
return Arguments.of(e1, e2, e1);
}
});
}

private static Stream<Arguments> exprNullIfArguments() {
ArrayList<ExprValue> exprValues = new ArrayList<>();
exprValues.add(LITERAL_NULL);
exprValues.add(LITERAL_MISSING);
exprValues.add(ExprValueUtils.integerValue(123));

return Lists.cartesianProduct(exprValues, exprValues).stream()
.map(list -> {
ExprValue e1 = list.get(0);
ExprValue e2 = list.get(1);
if (e1.equals(e2)) {
return Arguments.of(e1, e2, LITERAL_NULL);
} else {
return Arguments.of(e1, e2, e1);
}
});
}

@Test
public void test_not_null() {
FunctionExpression expression = dsl.not(DSL.ref(BOOL_TYPE_NULL_VALUE_FIELD, BOOLEAN));
Expand All @@ -59,18 +167,18 @@ public void test_not_missing() {
}

@Test
public void is_null_predicate() {
FunctionExpression expression = dsl.isnull(DSL.literal(1));
public void test_is_null_predicate() {
FunctionExpression expression = dsl.is_null(DSL.literal(1));
assertEquals(BOOLEAN, expression.type());
assertEquals(LITERAL_FALSE, expression.valueOf(valueEnv()));

expression = dsl.isnull(DSL.literal(ExprNullValue.of()));
expression = dsl.is_null(DSL.literal(ExprNullValue.of()));
assertEquals(BOOLEAN, expression.type());
assertEquals(LITERAL_TRUE, expression.valueOf(valueEnv()));
}

@Test
public void is_not_null_predicate() {
public void test_is_not_null_predicate() {
FunctionExpression expression = dsl.isnotnull(DSL.literal(1));
assertEquals(BOOLEAN, expression.type());
assertEquals(LITERAL_TRUE, expression.valueOf(valueEnv()));
Expand All @@ -79,4 +187,35 @@ public void is_not_null_predicate() {
assertEquals(BOOLEAN, expression.type());
assertEquals(LITERAL_FALSE, expression.valueOf(valueEnv()));
}

@ParameterizedTest
@MethodSource("isNullArguments")
public void test_isnull_predicate(Expression v1, Expression expected) {
assertEquals(expected.valueOf(valueEnv()), dsl.isnull(v1).valueOf(valueEnv()));
}

@ParameterizedTest
@MethodSource("ifNullArguments")
public void test_ifnull_predicate(Expression v1, Expression v2, Expression expected) {
assertEquals(expected.valueOf(valueEnv()), dsl.ifnull(v1, v2).valueOf(valueEnv()));
}

@ParameterizedTest
@MethodSource("exprIfNullArguments")
public void test_exprIfNull_predicate(ExprValue v1, ExprValue v2, ExprValue expected) {
assertEquals(expected.value(), UnaryPredicateOperator.exprIfNull(v1, v2).value());
}

@ParameterizedTest
@MethodSource("nullIfArguments")
public void test_nullif_predicate(Expression v1, Expression v2, Expression expected) {
assertEquals(expected.valueOf(valueEnv()), dsl.nullif(v1, v2).valueOf(valueEnv()));
}

@ParameterizedTest
@MethodSource("exprNullIfArguments")
public void test_exprNullIf_predicate(ExprValue v1, ExprValue v2, ExprValue expected) {
assertEquals(expected.value(), UnaryPredicateOperator.exprNullIf(v1, v2).value());
}

}
Loading