diff --git a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/analysis/ExpressionAnalyzer.java b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/analysis/ExpressionAnalyzer.java index 54cd263596..06d90608f6 100644 --- a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/analysis/ExpressionAnalyzer.java +++ b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/analysis/ExpressionAnalyzer.java @@ -22,6 +22,7 @@ import com.amazon.opendistroforelasticsearch.sql.ast.expression.AllFields; import com.amazon.opendistroforelasticsearch.sql.ast.expression.And; import com.amazon.opendistroforelasticsearch.sql.ast.expression.Case; +import com.amazon.opendistroforelasticsearch.sql.ast.expression.Cast; import com.amazon.opendistroforelasticsearch.sql.ast.expression.Compare; import com.amazon.opendistroforelasticsearch.sql.ast.expression.EqualTo; import com.amazon.opendistroforelasticsearch.sql.ast.expression.Field; @@ -67,6 +68,13 @@ public class ExpressionAnalyzer extends AbstractNodeVisitor CONVERTED_TYPE_FUNCTION_NAME_MAP = + new ImmutableMap.Builder() + .put("string", CAST_TO_STRING.getName()) + .put("int", CAST_TO_INT.getName()) + .put("long", CAST_TO_LONG.getName()) + .put("float", CAST_TO_FLOAT.getName()) + .put("double", CAST_TO_DOUBLE.getName()) + .put("boolean", CAST_TO_BOOLEAN.getName()) + .put("date", CAST_TO_DATE.getName()) + .put("time", CAST_TO_TIME.getName()) + .put("timestamp", CAST_TO_TIMESTAMP.getName()) + .build(); + + /** + * The source expression cast from. + */ + private final UnresolvedExpression expression; + + /** + * Expression that represents ELSE statement result. + */ + private final UnresolvedExpression convertedType; + + /** + * Get the converted type. + * + * @return converted type + */ + public FunctionName convertFunctionName() { + String type = convertedType.toString().toLowerCase(Locale.ROOT); + if (CONVERTED_TYPE_FUNCTION_NAME_MAP.containsKey(type)) { + return CONVERTED_TYPE_FUNCTION_NAME_MAP.get(type); + } else { + throw new IllegalStateException("unsupported cast type: " + type); + } + } + + @Override + public List getChild() { + return Collections.singletonList(expression); + } + + @Override + public T accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitCast(this, context); + } +} diff --git a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/DSL.java b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/DSL.java index ddb7f622a5..dbeb5453b5 100644 --- a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/DSL.java +++ b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/DSL.java @@ -523,4 +523,49 @@ public FunctionExpression interval(Expression value, Expression unit) { return (FunctionExpression) repository.compile( BuiltinFunctionName.INTERVAL.getName(), Arrays.asList(value, unit)); } + + public FunctionExpression castString(Expression value) { + return (FunctionExpression) repository + .compile(BuiltinFunctionName.CAST_TO_STRING.getName(), Arrays.asList(value)); + } + + public FunctionExpression castInt(Expression value) { + return (FunctionExpression) repository + .compile(BuiltinFunctionName.CAST_TO_INT.getName(), Arrays.asList(value)); + } + + public FunctionExpression castLong(Expression value) { + return (FunctionExpression) repository + .compile(BuiltinFunctionName.CAST_TO_LONG.getName(), Arrays.asList(value)); + } + + public FunctionExpression castFloat(Expression value) { + return (FunctionExpression) repository + .compile(BuiltinFunctionName.CAST_TO_FLOAT.getName(), Arrays.asList(value)); + } + + public FunctionExpression castDouble(Expression value) { + return (FunctionExpression) repository + .compile(BuiltinFunctionName.CAST_TO_DOUBLE.getName(), Arrays.asList(value)); + } + + public FunctionExpression castBoolean(Expression value) { + return (FunctionExpression) repository + .compile(BuiltinFunctionName.CAST_TO_BOOLEAN.getName(), Arrays.asList(value)); + } + + public FunctionExpression castDate(Expression value) { + return (FunctionExpression) repository + .compile(BuiltinFunctionName.CAST_TO_DATE.getName(), Arrays.asList(value)); + } + + public FunctionExpression castTime(Expression value) { + return (FunctionExpression) repository + .compile(BuiltinFunctionName.CAST_TO_TIME.getName(), Arrays.asList(value)); + } + + public FunctionExpression castTimestamp(Expression value) { + return (FunctionExpression) repository + .compile(BuiltinFunctionName.CAST_TO_TIMESTAMP.getName(), Arrays.asList(value)); + } } diff --git a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/config/ExpressionConfig.java b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/config/ExpressionConfig.java index b52514b447..1053d556cf 100644 --- a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/config/ExpressionConfig.java +++ b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/config/ExpressionConfig.java @@ -22,6 +22,7 @@ import com.amazon.opendistroforelasticsearch.sql.expression.function.BuiltinFunctionRepository; import com.amazon.opendistroforelasticsearch.sql.expression.operator.arthmetic.ArithmeticFunction; import com.amazon.opendistroforelasticsearch.sql.expression.operator.arthmetic.MathematicalFunction; +import com.amazon.opendistroforelasticsearch.sql.expression.operator.convert.TypeCastOperator; import com.amazon.opendistroforelasticsearch.sql.expression.operator.predicate.BinaryPredicateOperator; import com.amazon.opendistroforelasticsearch.sql.expression.operator.predicate.UnaryPredicateOperator; import com.amazon.opendistroforelasticsearch.sql.expression.text.TextFunction; @@ -51,6 +52,7 @@ public BuiltinFunctionRepository functionRepository() { IntervalClause.register(builtinFunctionRepository); WindowFunctions.register(builtinFunctionRepository); TextFunction.register(builtinFunctionRepository); + TypeCastOperator.register(builtinFunctionRepository); return builtinFunctionRepository; } diff --git a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/function/BuiltinFunctionName.java b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/function/BuiltinFunctionName.java index f346620686..8c00c120d0 100644 --- a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/function/BuiltinFunctionName.java +++ b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/function/BuiltinFunctionName.java @@ -142,7 +142,20 @@ public enum BuiltinFunctionName { RANK(FunctionName.of("rank")), DENSE_RANK(FunctionName.of("dense_rank")), - INTERVAL(FunctionName.of("interval")); + INTERVAL(FunctionName.of("interval")), + + /** + * Data Type Convert Function. + */ + CAST_TO_STRING(FunctionName.of("cast_to_string")), + CAST_TO_INT(FunctionName.of("cast_to_int")), + CAST_TO_LONG(FunctionName.of("cast_to_long")), + CAST_TO_FLOAT(FunctionName.of("cast_to_float")), + CAST_TO_DOUBLE(FunctionName.of("cast_to_double")), + CAST_TO_BOOLEAN(FunctionName.of("cast_to_boolean")), + CAST_TO_DATE(FunctionName.of("cast_to_date")), + CAST_TO_TIME(FunctionName.of("cast_to_time")), + CAST_TO_TIMESTAMP(FunctionName.of("cast_to_timestamp")); private final FunctionName name; diff --git a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/operator/convert/TypeCastOperator.java b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/operator/convert/TypeCastOperator.java new file mode 100644 index 0000000000..f49e9cb818 --- /dev/null +++ b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/operator/convert/TypeCastOperator.java @@ -0,0 +1,171 @@ +/* + * + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + * + */ + +package com.amazon.opendistroforelasticsearch.sql.expression.operator.convert; + +import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.BOOLEAN; +import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.BYTE; +import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.DATE; +import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.DATETIME; +import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.DOUBLE; +import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.FLOAT; +import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.INTEGER; +import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.LONG; +import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.SHORT; +import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.STRING; +import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.TIME; +import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.TIMESTAMP; +import static com.amazon.opendistroforelasticsearch.sql.expression.function.FunctionDSL.impl; +import static com.amazon.opendistroforelasticsearch.sql.expression.function.FunctionDSL.nullMissingHandling; + +import com.amazon.opendistroforelasticsearch.sql.data.model.ExprBooleanValue; +import com.amazon.opendistroforelasticsearch.sql.data.model.ExprDateValue; +import com.amazon.opendistroforelasticsearch.sql.data.model.ExprDoubleValue; +import com.amazon.opendistroforelasticsearch.sql.data.model.ExprFloatValue; +import com.amazon.opendistroforelasticsearch.sql.data.model.ExprIntegerValue; +import com.amazon.opendistroforelasticsearch.sql.data.model.ExprLongValue; +import com.amazon.opendistroforelasticsearch.sql.data.model.ExprStringValue; +import com.amazon.opendistroforelasticsearch.sql.data.model.ExprTimeValue; +import com.amazon.opendistroforelasticsearch.sql.data.model.ExprTimestampValue; +import com.amazon.opendistroforelasticsearch.sql.expression.function.BuiltinFunctionName; +import com.amazon.opendistroforelasticsearch.sql.expression.function.BuiltinFunctionRepository; +import com.amazon.opendistroforelasticsearch.sql.expression.function.FunctionDSL; +import com.amazon.opendistroforelasticsearch.sql.expression.function.FunctionResolver; +import java.util.Arrays; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import lombok.experimental.UtilityClass; + +@UtilityClass +public class TypeCastOperator { + /** + * Register Type Cast Operator. + */ + public static void register(BuiltinFunctionRepository repository) { + repository.register(castToString()); + repository.register(castToInt()); + repository.register(castToLong()); + repository.register(castToFloat()); + repository.register(castToDouble()); + repository.register(castToBoolean()); + repository.register(castToDate()); + repository.register(castToTime()); + repository.register(castToTimestamp()); + } + + + private static FunctionResolver castToString() { + return FunctionDSL.define(BuiltinFunctionName.CAST_TO_STRING.getName(), + Stream.concat( + Arrays.asList(BYTE, SHORT, INTEGER, LONG, FLOAT, DOUBLE, BOOLEAN, TIME, DATE, + TIMESTAMP, DATETIME).stream() + .map(type -> impl( + nullMissingHandling((v) -> new ExprStringValue(v.value().toString())), + STRING, type)), + Stream.of(impl(nullMissingHandling((v) -> v), STRING, STRING))) + .collect(Collectors.toList()) + ); + } + + private static FunctionResolver castToInt() { + return FunctionDSL.define(BuiltinFunctionName.CAST_TO_INT.getName(), + impl(nullMissingHandling( + (v) -> new ExprIntegerValue(Integer.valueOf(v.stringValue()))), INTEGER, STRING), + impl(nullMissingHandling( + (v) -> new ExprIntegerValue(v.integerValue())), INTEGER, DOUBLE), + impl(nullMissingHandling( + (v) -> new ExprIntegerValue(v.booleanValue() ? 1 : 0)), INTEGER, BOOLEAN) + ); + } + + private static FunctionResolver castToLong() { + return FunctionDSL.define(BuiltinFunctionName.CAST_TO_LONG.getName(), + impl(nullMissingHandling( + (v) -> new ExprLongValue(Long.valueOf(v.stringValue()))), LONG, STRING), + impl(nullMissingHandling( + (v) -> new ExprLongValue(v.longValue())), LONG, DOUBLE), + impl(nullMissingHandling( + (v) -> new ExprLongValue(v.booleanValue() ? 1L : 0L)), LONG, BOOLEAN) + ); + } + + private static FunctionResolver castToFloat() { + return FunctionDSL.define(BuiltinFunctionName.CAST_TO_FLOAT.getName(), + impl(nullMissingHandling( + (v) -> new ExprFloatValue(Float.valueOf(v.stringValue()))), FLOAT, STRING), + impl(nullMissingHandling( + (v) -> new ExprFloatValue(v.longValue())), FLOAT, DOUBLE), + impl(nullMissingHandling( + (v) -> new ExprFloatValue(v.booleanValue() ? 1f : 0f)), FLOAT, BOOLEAN) + ); + } + + private static FunctionResolver castToDouble() { + return FunctionDSL.define(BuiltinFunctionName.CAST_TO_DOUBLE.getName(), + impl(nullMissingHandling( + (v) -> new ExprDoubleValue(Double.valueOf(v.stringValue()))), DOUBLE, STRING), + impl(nullMissingHandling( + (v) -> new ExprDoubleValue(v.doubleValue())), DOUBLE, DOUBLE), + impl(nullMissingHandling( + (v) -> new ExprDoubleValue(v.booleanValue() ? 1D : 0D)), DOUBLE, BOOLEAN) + ); + } + + private static FunctionResolver castToBoolean() { + return FunctionDSL.define(BuiltinFunctionName.CAST_TO_BOOLEAN.getName(), + impl(nullMissingHandling( + (v) -> ExprBooleanValue.of(Boolean.valueOf(v.stringValue()))), BOOLEAN, STRING), + impl(nullMissingHandling( + (v) -> ExprBooleanValue.of(v.doubleValue() != 0)), BOOLEAN, DOUBLE), + impl(nullMissingHandling((v) -> v), BOOLEAN, BOOLEAN) + ); + } + + private static FunctionResolver castToDate() { + return FunctionDSL.define(BuiltinFunctionName.CAST_TO_DATE.getName(), + impl(nullMissingHandling( + (v) -> new ExprDateValue(v.stringValue())), DATE, STRING), + impl(nullMissingHandling( + (v) -> new ExprDateValue(v.dateValue())), DATE, DATETIME), + impl(nullMissingHandling( + (v) -> new ExprDateValue(v.dateValue())), DATE, TIMESTAMP), + impl(nullMissingHandling((v) -> v), DATE, DATE) + ); + } + + private static FunctionResolver castToTime() { + return FunctionDSL.define(BuiltinFunctionName.CAST_TO_TIME.getName(), + impl(nullMissingHandling( + (v) -> new ExprTimeValue(v.stringValue())), TIME, STRING), + impl(nullMissingHandling( + (v) -> new ExprTimeValue(v.timeValue())), TIME, DATETIME), + impl(nullMissingHandling( + (v) -> new ExprTimeValue(v.timeValue())), TIME, TIMESTAMP), + impl(nullMissingHandling((v) -> v), TIME, TIME) + ); + } + + private static FunctionResolver castToTimestamp() { + return FunctionDSL.define(BuiltinFunctionName.CAST_TO_TIMESTAMP.getName(), + impl(nullMissingHandling( + (v) -> new ExprTimestampValue(v.stringValue())), TIMESTAMP, STRING), + impl(nullMissingHandling( + (v) -> new ExprTimestampValue(v.timestampValue())), TIMESTAMP, DATETIME), + impl(nullMissingHandling((v) -> v), TIMESTAMP, TIMESTAMP) + ); + } +} diff --git a/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/analysis/ExpressionAnalyzerTest.java b/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/analysis/ExpressionAnalyzerTest.java index 76965cd749..ce40c26721 100644 --- a/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/analysis/ExpressionAnalyzerTest.java +++ b/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/analysis/ExpressionAnalyzerTest.java @@ -37,6 +37,9 @@ import com.amazon.opendistroforelasticsearch.sql.expression.Expression; import com.amazon.opendistroforelasticsearch.sql.expression.LiteralExpression; import com.amazon.opendistroforelasticsearch.sql.expression.config.ExpressionConfig; +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.springframework.context.annotation.Configuration; @@ -137,6 +140,17 @@ public void case_conditions() { AstDSL.intLiteral(30)), AstDSL.stringLiteral("Thirty")))); } + @Test + public void castAnalyzer() { + assertAnalyzeEqual( + dsl.castInt(DSL.ref("boolean_value", BOOLEAN)), + AstDSL.cast(AstDSL.unresolvedAttr("boolean_value"), AstDSL.stringLiteral("INT")) + ); + + assertThrows(IllegalStateException.class, () -> analyze(AstDSL.cast(AstDSL.unresolvedAttr( + "boolean_value"), AstDSL.stringLiteral("DATETIME")))); + } + @Test public void case_with_default_result_type_different() { UnresolvedExpression caseWhen = AstDSL.caseWhen( diff --git a/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/expression/operator/convert/TypeCastOperatorTest.java b/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/expression/operator/convert/TypeCastOperatorTest.java new file mode 100644 index 0000000000..269d22e2c0 --- /dev/null +++ b/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/expression/operator/convert/TypeCastOperatorTest.java @@ -0,0 +1,304 @@ +/* + * + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + * + */ + +package com.amazon.opendistroforelasticsearch.sql.expression.operator.convert; + +import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.BOOLEAN; +import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.DATE; +import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.DOUBLE; +import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.FLOAT; +import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.INTEGER; +import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.LONG; +import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.STRING; +import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.TIME; +import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.TIMESTAMP; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import com.amazon.opendistroforelasticsearch.sql.data.model.ExprBooleanValue; +import com.amazon.opendistroforelasticsearch.sql.data.model.ExprByteValue; +import com.amazon.opendistroforelasticsearch.sql.data.model.ExprDateValue; +import com.amazon.opendistroforelasticsearch.sql.data.model.ExprDatetimeValue; +import com.amazon.opendistroforelasticsearch.sql.data.model.ExprDoubleValue; +import com.amazon.opendistroforelasticsearch.sql.data.model.ExprFloatValue; +import com.amazon.opendistroforelasticsearch.sql.data.model.ExprIntegerValue; +import com.amazon.opendistroforelasticsearch.sql.data.model.ExprLongValue; +import com.amazon.opendistroforelasticsearch.sql.data.model.ExprShortValue; +import com.amazon.opendistroforelasticsearch.sql.data.model.ExprStringValue; +import com.amazon.opendistroforelasticsearch.sql.data.model.ExprTimeValue; +import com.amazon.opendistroforelasticsearch.sql.data.model.ExprTimestampValue; +import com.amazon.opendistroforelasticsearch.sql.data.model.ExprValue; +import com.amazon.opendistroforelasticsearch.sql.expression.DSL; +import com.amazon.opendistroforelasticsearch.sql.expression.FunctionExpression; +import com.amazon.opendistroforelasticsearch.sql.expression.config.ExpressionConfig; +import java.util.stream.Stream; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +class TypeCastOperatorTest { + + private final DSL dsl = new ExpressionConfig().dsl(new ExpressionConfig().functionRepository()); + + private static Stream numberData() { + return Stream.of(new ExprByteValue(3), new ExprShortValue(3), + new ExprIntegerValue(3), new ExprLongValue(3L), new ExprFloatValue(3f), + new ExprDoubleValue(3D)); + } + + private static Stream stringData() { + return Stream.of(new ExprStringValue("strV")); + } + + private static Stream boolData() { + return Stream.of(ExprBooleanValue.of(true)); + } + + private static Stream date() { + return Stream.of(new ExprDateValue("2020-12-24")); + } + + private static Stream time() { + return Stream.of(new ExprTimeValue("01:01:01")); + } + + private static Stream timestamp() { + return Stream.of(new ExprTimestampValue("2020-12-24 01:01:01")); + } + + private static Stream datetime() { + return Stream.of(new ExprDatetimeValue("2020-12-24 01:01:01")); + } + + @ParameterizedTest(name = "castString({0})") + @MethodSource({"numberData", "stringData", "boolData", "date", "time", "timestamp", "datetime"}) + void castToString(ExprValue value) { + FunctionExpression expression = dsl.castString(DSL.literal(value)); + assertEquals(STRING, expression.type()); + assertEquals(new ExprStringValue(value.value().toString()), expression.valueOf(null)); + } + + @ParameterizedTest(name = "castToInt({0})") + @MethodSource({"numberData"}) + void castToInt(ExprValue value) { + FunctionExpression expression = dsl.castInt(DSL.literal(value)); + assertEquals(INTEGER, expression.type()); + assertEquals(new ExprIntegerValue(value.integerValue()), expression.valueOf(null)); + } + + @Test + void castStringToInt() { + FunctionExpression expression = dsl.castInt(DSL.literal("100")); + assertEquals(INTEGER, expression.type()); + assertEquals(new ExprIntegerValue(100), expression.valueOf(null)); + } + + @Test + void castStringToIntException() { + FunctionExpression expression = dsl.castInt(DSL.literal("invalid")); + assertThrows(RuntimeException.class, () -> expression.valueOf(null)); + } + + @Test + void castBooleanToInt() { + FunctionExpression expression = dsl.castInt(DSL.literal(true)); + assertEquals(INTEGER, expression.type()); + assertEquals(new ExprIntegerValue(1), expression.valueOf(null)); + + expression = dsl.castInt(DSL.literal(false)); + assertEquals(INTEGER, expression.type()); + assertEquals(new ExprIntegerValue(0), expression.valueOf(null)); + } + + @ParameterizedTest(name = "castToLong({0})") + @MethodSource({"numberData"}) + void castToLong(ExprValue value) { + FunctionExpression expression = dsl.castLong(DSL.literal(value)); + assertEquals(LONG, expression.type()); + assertEquals(new ExprLongValue(value.longValue()), expression.valueOf(null)); + } + + @Test + void castStringToLong() { + FunctionExpression expression = dsl.castLong(DSL.literal("100")); + assertEquals(LONG, expression.type()); + assertEquals(new ExprLongValue(100), expression.valueOf(null)); + } + + @Test + void castStringToLongException() { + FunctionExpression expression = dsl.castLong(DSL.literal("invalid")); + assertThrows(RuntimeException.class, () -> expression.valueOf(null)); + } + + @Test + void castBooleanToLong() { + FunctionExpression expression = dsl.castLong(DSL.literal(true)); + assertEquals(LONG, expression.type()); + assertEquals(new ExprLongValue(1), expression.valueOf(null)); + + expression = dsl.castLong(DSL.literal(false)); + assertEquals(LONG, expression.type()); + assertEquals(new ExprLongValue(0), expression.valueOf(null)); + } + + @ParameterizedTest(name = "castToFloat({0})") + @MethodSource({"numberData"}) + void castToFloat(ExprValue value) { + FunctionExpression expression = dsl.castFloat(DSL.literal(value)); + assertEquals(FLOAT, expression.type()); + assertEquals(new ExprFloatValue(value.floatValue()), expression.valueOf(null)); + } + + @Test + void castStringToFloat() { + FunctionExpression expression = dsl.castFloat(DSL.literal("100.0")); + assertEquals(FLOAT, expression.type()); + assertEquals(new ExprFloatValue(100.0), expression.valueOf(null)); + } + + @Test + void castStringToFloatException() { + FunctionExpression expression = dsl.castFloat(DSL.literal("invalid")); + assertThrows(RuntimeException.class, () -> expression.valueOf(null)); + } + + @Test + void castBooleanToFloat() { + FunctionExpression expression = dsl.castFloat(DSL.literal(true)); + assertEquals(FLOAT, expression.type()); + assertEquals(new ExprFloatValue(1), expression.valueOf(null)); + + expression = dsl.castFloat(DSL.literal(false)); + assertEquals(FLOAT, expression.type()); + assertEquals(new ExprFloatValue(0), expression.valueOf(null)); + } + + @ParameterizedTest(name = "castToDouble({0})") + @MethodSource({"numberData"}) + void castToDouble(ExprValue value) { + FunctionExpression expression = dsl.castDouble(DSL.literal(value)); + assertEquals(DOUBLE, expression.type()); + assertEquals(new ExprDoubleValue(value.doubleValue()), expression.valueOf(null)); + } + + @Test + void castStringToDouble() { + FunctionExpression expression = dsl.castDouble(DSL.literal("100.0")); + assertEquals(DOUBLE, expression.type()); + assertEquals(new ExprDoubleValue(100), expression.valueOf(null)); + } + + @Test + void castStringToDoubleException() { + FunctionExpression expression = dsl.castDouble(DSL.literal("invalid")); + assertThrows(RuntimeException.class, () -> expression.valueOf(null)); + } + + @Test + void castBooleanToDouble() { + FunctionExpression expression = dsl.castDouble(DSL.literal(true)); + assertEquals(DOUBLE, expression.type()); + assertEquals(new ExprDoubleValue(1), expression.valueOf(null)); + + expression = dsl.castDouble(DSL.literal(false)); + assertEquals(DOUBLE, expression.type()); + assertEquals(new ExprDoubleValue(0), expression.valueOf(null)); + } + + @ParameterizedTest(name = "castToBoolean({0})") + @MethodSource({"numberData"}) + void castToBoolean(ExprValue value) { + FunctionExpression expression = dsl.castBoolean(DSL.literal(value)); + assertEquals(BOOLEAN, expression.type()); + assertEquals(ExprBooleanValue.of(true), expression.valueOf(null)); + } + + @Test + void castZeroToBoolean() { + FunctionExpression expression = dsl.castBoolean(DSL.literal(0)); + assertEquals(BOOLEAN, expression.type()); + assertEquals(ExprBooleanValue.of(false), expression.valueOf(null)); + } + + @Test + void castStringToBoolean() { + FunctionExpression expression = dsl.castBoolean(DSL.literal("True")); + assertEquals(BOOLEAN, expression.type()); + assertEquals(ExprBooleanValue.of(true), expression.valueOf(null)); + } + + @Test + void castBooleanToBoolean() { + FunctionExpression expression = dsl.castBoolean(DSL.literal(true)); + assertEquals(BOOLEAN, expression.type()); + assertEquals(ExprBooleanValue.of(true), expression.valueOf(null)); + } + + @Test + void castToDate() { + FunctionExpression expression = dsl.castDate(DSL.literal("2012-08-07")); + assertEquals(DATE, expression.type()); + assertEquals(new ExprDateValue("2012-08-07"), expression.valueOf(null)); + + expression = dsl.castDate(DSL.literal(new ExprDatetimeValue("2012-08-07 01:01:01"))); + assertEquals(DATE, expression.type()); + assertEquals(new ExprDateValue("2012-08-07"), expression.valueOf(null)); + + expression = dsl.castDate(DSL.literal(new ExprTimestampValue("2012-08-07 01:01:01"))); + assertEquals(DATE, expression.type()); + assertEquals(new ExprDateValue("2012-08-07"), expression.valueOf(null)); + + expression = dsl.castDate(DSL.literal(new ExprDateValue("2012-08-07"))); + assertEquals(DATE, expression.type()); + assertEquals(new ExprDateValue("2012-08-07"), expression.valueOf(null)); + } + + @Test + void castToTime() { + FunctionExpression expression = dsl.castTime(DSL.literal("01:01:01")); + assertEquals(TIME, expression.type()); + assertEquals(new ExprTimeValue("01:01:01"), expression.valueOf(null)); + + expression = dsl.castTime(DSL.literal(new ExprDatetimeValue("2012-08-07 01:01:01"))); + assertEquals(TIME, expression.type()); + assertEquals(new ExprTimeValue("01:01:01"), expression.valueOf(null)); + + expression = dsl.castTime(DSL.literal(new ExprTimestampValue("2012-08-07 01:01:01"))); + assertEquals(TIME, expression.type()); + assertEquals(new ExprTimeValue("01:01:01"), expression.valueOf(null)); + + expression = dsl.castTime(DSL.literal(new ExprTimeValue("01:01:01"))); + assertEquals(TIME, expression.type()); + assertEquals(new ExprTimeValue("01:01:01"), expression.valueOf(null)); + } + + @Test + void castToTimestamp() { + FunctionExpression expression = dsl.castTimestamp(DSL.literal("2012-08-07 01:01:01")); + assertEquals(TIMESTAMP, expression.type()); + assertEquals(new ExprTimestampValue("2012-08-07 01:01:01"), expression.valueOf(null)); + + expression = dsl.castTimestamp(DSL.literal(new ExprDatetimeValue("2012-08-07 01:01:01"))); + assertEquals(TIMESTAMP, expression.type()); + assertEquals(new ExprTimestampValue("2012-08-07 01:01:01"), expression.valueOf(null)); + + expression = dsl.castTimestamp(DSL.literal(new ExprTimestampValue("2012-08-07 01:01:01"))); + assertEquals(TIMESTAMP, expression.type()); + assertEquals(new ExprTimestampValue("2012-08-07 01:01:01"), expression.valueOf(null)); + } +} \ No newline at end of file diff --git a/docs/user/dql/functions.rst b/docs/user/dql/functions.rst index 5ed6c38a87..0e3831673f 100644 --- a/docs/user/dql/functions.rst +++ b/docs/user/dql/functions.rst @@ -25,7 +25,65 @@ CAST Description >>>>>>>>>>> -Specification is undefined and type check is skipped for now +Usage: cast(expr as dateType) cast the expr to dataType. return the value of dataType. The following conversion rules are used: + ++------------+--------+--------+---------+-------------+--------+--------+ +| Src/Target | STRING | NUMBER | BOOLEAN | TIMESTAMP | DATE | TIME | ++------------+--------+--------+---------+-------------+--------+--------+ +| STRING | | Note1 | Note1 | TIMESTAMP() | DATE() | TIME() | ++------------+--------+--------+---------+-------------+--------+--------+ +| NUMBER | Note1 | | v!=0 | N/A | N/A | N/A | ++------------+--------+--------+---------+-------------+--------+--------+ +| BOOLEAN | Note1 | v?1:0 | | N/A | N/A | N/A | ++------------+--------+--------+---------+-------------+--------+--------+ +| TIMESTAMP | Note1 | N/A | N/A | | DATE() | TIME() | ++------------+--------+--------+---------+-------------+--------+--------+ +| DATE | Note1 | N/A | N/A | N/A | | N/A | ++------------+--------+--------+---------+-------------+--------+--------+ +| TIME | Note1 | N/A | N/A | N/A | N/A | | ++------------+--------+--------+---------+-------------+--------+--------+ + +Note1: the conversion follow the JDK specification. + +Cast to string example:: + + od> SELECT cast(true as string) as cbool, cast(1 as string) as cint, cast(DATE '2012-08-07' as string) as cdate + fetched rows / total rows = 1/1 + +---------+--------+------------+ + | cbool | cint | cdate | + |---------+--------+------------| + | true | 1 | 2012-08-07 | + +---------+--------+------------+ + +Cast to number example:: + + od> SELECT cast(true as int) as cbool, cast('1' as int) as cstring + fetched rows / total rows = 1/1 + +---------+-----------+ + | cbool | cstring | + |---------+-----------| + | 1 | 1 | + +---------+-----------+ + +Cast to date example:: + + od> SELECT cast('2012-08-07' as date) as cdate, cast('01:01:01' as time) as ctime, cast('2012-08-07 01:01:01' as timestamp) as ctimestamp + fetched rows / total rows = 1/1 + +------------+----------+---------------------+ + | cdate | ctime | ctimestamp | + |------------+----------+---------------------| + | 2012-08-07 | 01:01:01 | 2012-08-07 01:01:01 | + +------------+----------+---------------------+ + +Cast function can be chained:: + + od> SELECT cast(cast(true as string) as boolean) as cbool + fetched rows / total rows = 1/1 + +---------+ + | cbool | + |---------| + | True | + +---------+ Mathematical Functions diff --git a/integ-test/src/test/java/com/amazon/opendistroforelasticsearch/sql/legacy/AggregationExpressionIT.java b/integ-test/src/test/java/com/amazon/opendistroforelasticsearch/sql/legacy/AggregationExpressionIT.java index d2473c3c6b..028b51b8c7 100644 --- a/integ-test/src/test/java/com/amazon/opendistroforelasticsearch/sql/legacy/AggregationExpressionIT.java +++ b/integ-test/src/test/java/com/amazon/opendistroforelasticsearch/sql/legacy/AggregationExpressionIT.java @@ -22,6 +22,7 @@ import static com.amazon.opendistroforelasticsearch.sql.util.MatcherUtils.verifySchema; import org.json.JSONObject; +import org.junit.Assume; import org.junit.Ignore; import org.junit.Test; @@ -257,11 +258,13 @@ public void groupByDateWithAliasShouldPass() { @Test public void aggregateCastStatementShouldNotReturnZero() { + Assume.assumeTrue(isNewQueryEngineEabled()); + JSONObject response = executeJdbcRequest(String.format( "SELECT SUM(CAST(male AS INT)) AS male_sum FROM %s", Index.BANK.getName())); - verifySchema(response, schema("male_sum", "male_sum", "double")); + verifySchema(response, schema("male_sum", "male_sum", "integer")); verifyDataRows(response, rows(4)); } diff --git a/integ-test/src/test/java/com/amazon/opendistroforelasticsearch/sql/legacy/SQLFunctionsIT.java b/integ-test/src/test/java/com/amazon/opendistroforelasticsearch/sql/legacy/SQLFunctionsIT.java index 08292df717..b95d07d11a 100644 --- a/integ-test/src/test/java/com/amazon/opendistroforelasticsearch/sql/legacy/SQLFunctionsIT.java +++ b/integ-test/src/test/java/com/amazon/opendistroforelasticsearch/sql/legacy/SQLFunctionsIT.java @@ -52,6 +52,7 @@ import org.hamcrest.collection.IsMapContaining; import org.json.JSONObject; import org.junit.Assert; +import org.junit.Assume; import org.junit.Ignore; import org.junit.Test; @@ -219,19 +220,23 @@ public void castIntFieldToStringWithAliasTest() throws IOException { @Test public void castIntFieldToFloatWithoutAliasJdbcFormatTest() { + Assume.assumeTrue(isNewQueryEngineEabled()); + JSONObject response = executeJdbcRequest( - "SELECT CAST(balance AS FLOAT) FROM " + TestsConstants.TEST_INDEX_ACCOUNT + + "SELECT CAST(balance AS FLOAT) AS cast_balance FROM " + TestsConstants.TEST_INDEX_ACCOUNT + " ORDER BY balance DESC LIMIT 1"); verifySchema(response, schema("cast_balance", null, "float")); verifyDataRows(response, - rows(49989)); + rows(49989.0)); } @Test public void castIntFieldToFloatWithAliasJdbcFormatTest() { + Assume.assumeTrue(isNewQueryEngineEabled()); + JSONObject response = executeJdbcRequest( "SELECT CAST(balance AS FLOAT) AS jdbc_float_alias " + "FROM " + TestsConstants.TEST_INDEX_ACCOUNT + " ORDER BY jdbc_float_alias LIMIT 1"); @@ -240,7 +245,7 @@ public void castIntFieldToFloatWithAliasJdbcFormatTest() { schema("jdbc_float_alias", null, "float")); verifyDataRows(response, - rows(1011)); + rows(1011.0)); } @Test @@ -370,8 +375,11 @@ public void castFieldToDatetimeWithGroupByJdbcFormatTest() { rows("2019-09-25T02:04:13.469Z")); } + @Test public void castBoolFieldToNumericValueInSelectClause() { + Assume.assumeTrue(isNewQueryEngineEabled()); + JSONObject response = executeJdbcRequest( "SELECT " @@ -392,13 +400,15 @@ public void castBoolFieldToNumericValueInSelectClause() { schema("cast_double", "double") ); verifyDataRows(response, - rows(true, 1, 1, 1, 1), - rows(false, 0, 0, 0, 0) + rows(true, 1, 1, 1.0, 1.0), + rows(false, 0, 0, 0.0, 0.0) ); } @Test public void castBoolFieldToNumericValueWithGroupByAlias() { + Assume.assumeTrue(isNewQueryEngineEabled()); + JSONObject response = executeJdbcRequest( "SELECT " @@ -409,12 +419,12 @@ public void castBoolFieldToNumericValueWithGroupByAlias() { ); verifySchema(response, - schema("cast_int", "cast_int", "double"), //Type is double due to query plan fail to infer + schema("cast_int", "cast_int", "integer"), schema("COUNT(*)", "integer") ); verifyDataRows(response, - rows("0", 3), - rows("1", 4) + rows(0, 3), + rows(1, 4) ); } diff --git a/integ-test/src/test/java/com/amazon/opendistroforelasticsearch/sql/legacy/SQLIntegTestCase.java b/integ-test/src/test/java/com/amazon/opendistroforelasticsearch/sql/legacy/SQLIntegTestCase.java index 98598b1ad0..9f59d93da2 100644 --- a/integ-test/src/test/java/com/amazon/opendistroforelasticsearch/sql/legacy/SQLIntegTestCase.java +++ b/integ-test/src/test/java/com/amazon/opendistroforelasticsearch/sql/legacy/SQLIntegTestCase.java @@ -148,7 +148,7 @@ public static void cleanUpIndices() throws IOException { } private void enableNewQueryEngine() throws IOException { - boolean isEnabled = Boolean.parseBoolean(System.getProperty("enableNewEngine", "false")); + boolean isEnabled = isNewQueryEngineEabled(); if (isEnabled) { com.amazon.opendistroforelasticsearch.sql.util.TestUtils.enableNewQueryEngine(client()); } diff --git a/integ-test/src/test/java/com/amazon/opendistroforelasticsearch/sql/util/MatcherUtils.java b/integ-test/src/test/java/com/amazon/opendistroforelasticsearch/sql/util/MatcherUtils.java index 6d1bbb75f8..39cc74552b 100644 --- a/integ-test/src/test/java/com/amazon/opendistroforelasticsearch/sql/util/MatcherUtils.java +++ b/integ-test/src/test/java/com/amazon/opendistroforelasticsearch/sql/util/MatcherUtils.java @@ -28,6 +28,7 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertThat; +import com.amazon.opendistroforelasticsearch.sql.common.utils.StringUtils; import com.google.common.base.Strings; import com.google.gson.JsonParser; import java.util.ArrayList; @@ -225,8 +226,7 @@ protected boolean matchesSafely(JSONObject jsonObject) { String actualAlias = (String) jsonObject.query("/alias"); String actualType = (String) jsonObject.query("/type"); return expectedName.equals(actualName) && - (Strings.isNullOrEmpty(actualAlias) && Strings.isNullOrEmpty(expectedAlias) || - expectedAlias.equals(actualAlias)) && + (Strings.isNullOrEmpty(expectedAlias) || expectedAlias.equals(actualAlias)) && expectedType.equals(actualType); } }; diff --git a/integ-test/src/test/resources/correctness/expressions/cast.txt b/integ-test/src/test/resources/correctness/expressions/cast.txt new file mode 100644 index 0000000000..8870b503b4 --- /dev/null +++ b/integ-test/src/test/resources/correctness/expressions/cast.txt @@ -0,0 +1,14 @@ +cast('1' as int) as castInt +cast(1 as int) as castInt +cast(true as int) as castInt +cast(false as int) as castInt +cast('1' as double) as castDouble +cast(1 as double) as castDouble +cast(true as double) as castDouble +cast(false as double) as castDouble +cast('2012-08-07 01:01:01' as timestamp) as castTimestamp +cast('2012-08-07' as date) as castDate +cast('01:01:01' as time) as castTime +cast('true' as boolean) as castBool +cast(1 as boolean) as castBool +cast(cast(1 as string) as int) castCombine diff --git a/sql/src/main/antlr/OpenDistroSQLLexer.g4 b/sql/src/main/antlr/OpenDistroSQLLexer.g4 index a14f0ad45a..8f094179d9 100644 --- a/sql/src/main/antlr/OpenDistroSQLLexer.g4 +++ b/sql/src/main/antlr/OpenDistroSQLLexer.g4 @@ -46,6 +46,7 @@ ALL: 'ALL'; AND: 'AND'; AS: 'AS'; ASC: 'ASC'; +BOOLEAN: 'BOOLEAN'; BETWEEN: 'BETWEEN'; BY: 'BY'; CASE: 'CASE'; diff --git a/sql/src/main/antlr/OpenDistroSQLParser.g4 b/sql/src/main/antlr/OpenDistroSQLParser.g4 index b9ff08f46e..5dbc92c86d 100644 --- a/sql/src/main/antlr/OpenDistroSQLParser.g4 +++ b/sql/src/main/antlr/OpenDistroSQLParser.g4 @@ -298,6 +298,19 @@ specificFunction (ELSE elseArg=functionArg)? END #caseFunctionCall | CASE caseFuncAlternative+ (ELSE elseArg=functionArg)? END #caseFunctionCall + | CAST '(' expression AS convertedDataType ')' #dataTypeFunctionCall + ; + +convertedDataType + : typeName=DATE + | typeName=TIME + | typeName=TIMESTAMP + | typeName=INT + | typeName=DOUBLE + | typeName=LONG + | typeName=FLOAT + | typeName=STRING + | typeName=BOOLEAN ; caseFuncAlternative diff --git a/sql/src/main/java/com/amazon/opendistroforelasticsearch/sql/sql/parser/AstExpressionBuilder.java b/sql/src/main/java/com/amazon/opendistroforelasticsearch/sql/sql/parser/AstExpressionBuilder.java index 4f45da4fc1..4c5947a223 100644 --- a/sql/src/main/java/com/amazon/opendistroforelasticsearch/sql/sql/parser/AstExpressionBuilder.java +++ b/sql/src/main/java/com/amazon/opendistroforelasticsearch/sql/sql/parser/AstExpressionBuilder.java @@ -53,6 +53,7 @@ import com.amazon.opendistroforelasticsearch.sql.ast.expression.AllFields; import com.amazon.opendistroforelasticsearch.sql.ast.expression.And; import com.amazon.opendistroforelasticsearch.sql.ast.expression.Case; +import com.amazon.opendistroforelasticsearch.sql.ast.expression.Cast; import com.amazon.opendistroforelasticsearch.sql.ast.expression.Function; import com.amazon.opendistroforelasticsearch.sql.ast.expression.Interval; import com.amazon.opendistroforelasticsearch.sql.ast.expression.IntervalUnit; @@ -315,6 +316,18 @@ public UnresolvedExpression visitCaseFuncAlternative(CaseFuncAlternativeContext return new When(visit(ctx.condition), visit(ctx.consequent)); } + @Override + public UnresolvedExpression visitDataTypeFunctionCall( + OpenDistroSQLParser.DataTypeFunctionCallContext ctx) { + return new Cast(visit(ctx.expression()), visit(ctx.convertedDataType())); + } + + @Override + public UnresolvedExpression visitConvertedDataType( + OpenDistroSQLParser.ConvertedDataTypeContext ctx) { + return AstDSL.stringLiteral(ctx.getText()); + } + private QualifiedName visitIdentifiers(List identifiers) { return new QualifiedName( identifiers.stream() diff --git a/sql/src/test/java/com/amazon/opendistroforelasticsearch/sql/sql/parser/AstExpressionBuilderTest.java b/sql/src/test/java/com/amazon/opendistroforelasticsearch/sql/sql/parser/AstExpressionBuilderTest.java index cab9d0221e..622a5be712 100644 --- a/sql/src/test/java/com/amazon/opendistroforelasticsearch/sql/sql/parser/AstExpressionBuilderTest.java +++ b/sql/src/test/java/com/amazon/opendistroforelasticsearch/sql/sql/parser/AstExpressionBuilderTest.java @@ -36,6 +36,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import com.amazon.opendistroforelasticsearch.sql.ast.Node; +import com.amazon.opendistroforelasticsearch.sql.ast.dsl.AstDSL; import com.amazon.opendistroforelasticsearch.sql.ast.expression.DataType; import com.amazon.opendistroforelasticsearch.sql.common.antlr.CaseInsensitiveCharStream; import com.amazon.opendistroforelasticsearch.sql.common.antlr.SyntaxAnalysisErrorListener; @@ -317,6 +318,22 @@ public void canBuildKeywordsAsIdentInQualifiedName() { ); } + @Test + public void canCastFieldAsString() { + assertEquals( + AstDSL.cast(qualifiedName("state"), stringLiteral("string")), + buildExprAst("cast(state as string)") + ); + } + + @Test + public void canCastValueAsString() { + assertEquals( + AstDSL.cast(intLiteral(1), stringLiteral("string")), + buildExprAst("cast(1 as string)") + ); + } + private Node buildExprAst(String expr) { OpenDistroSQLLexer lexer = new OpenDistroSQLLexer(new CaseInsensitiveCharStream(expr)); OpenDistroSQLParser parser = new OpenDistroSQLParser(new CommonTokenStream(lexer));