diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java index cac7ec392a6e78..9abd91acad7c26 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java @@ -1191,6 +1191,21 @@ private void analyzeArrayFunction(Analyzer analyzer) throws AnalysisException { } fn.setReturnType(getChild(0).getType()); } + + // make nested type with function param can be Compatible otherwise be will not deal with type + if (fnName.getFunction().equalsIgnoreCase("array_position") + || fnName.getFunction().equalsIgnoreCase("array_contains") + || fnName.getFunction().equalsIgnoreCase("countequal")) { + Type[] childTypes = collectChildReturnTypes(); + Type compatibleType = ((ArrayType) childTypes[0]).getItemType(); + for (int i = 1; i < childTypes.length; ++i) { + compatibleType = Type.getAssignmentCompatibleType(compatibleType, childTypes[i], true); + if (compatibleType == Type.INVALID) { + throw new AnalysisException(getFunctionNotFoundError(collectChildReturnTypes())); + } + uncheckedCastChild(compatibleType, i); + } + } } // Provide better error message for some aggregate builtins. These can be @@ -1698,8 +1713,6 @@ && collectChildReturnTypes()[0].isDecimalV3()) { || fnName.getFunction().equalsIgnoreCase("array_shuffle") || fnName.getFunction().equalsIgnoreCase("shuffle") || fnName.getFunction().equalsIgnoreCase("array_except") - || fnName.getFunction().equalsIgnoreCase("array_contains") - || fnName.getFunction().equalsIgnoreCase("array_position") || fnName.getFunction().equalsIgnoreCase("width_bucket")) && (args[ix].isDecimalV3() || (children.get(0).getType().isArrayType() && (((ArrayType) children.get(0).getType()).getItemType().isDecimalV3()) diff --git a/regression-test/data/query_p0/sql_functions/array_functions/test_array_functions.out b/regression-test/data/query_p0/sql_functions/array_functions/test_array_functions.out index 79a8da76a26f72..086a172e3e3d09 100644 --- a/regression-test/data/query_p0/sql_functions/array_functions/test_array_functions.out +++ b/regression-test/data/query_p0/sql_functions/array_functions/test_array_functions.out @@ -1055,6 +1055,17 @@ 8 \N 9 \N +-- !select -- +1 1 +2 0 +3 \N +4 \N +5 \N +6 \N +7 \N +8 \N +9 \N + -- !select_array -- 1 [1] true 2 [2] true @@ -1655,5 +1666,19 @@ [2023-01-19 18:22:22.222, 2023-01-19 18:33:33.333, 2023-01-19 18:44:44.444] -- !select_array_datetimev2_4 -- -[2023-01-19 18:11:11.111111, 2023-01-19 18:22:22.222222, 2023-01-19 18:33:33.333333] +[2023-01-19 18:11:11.111, 2023-01-19 18:22:22.222, 2023-01-19 18:33:33.333] + +-- !sql -- +1 [0.100000000, 0.100000000] 0.100 +1 [0.200000000, 0.200000000] 0.200 +1 [0.300000000, 0.300000000] 0.300 +1 [0.400000000, 0.400000000] 0.400 +1 [0.500000000, 0.500000000] 0.500 +1 [0.600000000, 0.600000000] 0.600 +1 [0.700000000, 0.700000000] 0.700 +1 [0.800000000, 0.800000000] 0.800 +1 [0.900000000, 0.900000000] 0.900 +1 [1.000000000, 1.000000000] 1.000 +1 [1.100000000, 1.100000000] 1.100 +1 [1.200000000, 1.200000000] 1.200 diff --git a/regression-test/suites/query_p0/sql_functions/array_functions/test_array_functions.groovy b/regression-test/suites/query_p0/sql_functions/array_functions/test_array_functions.groovy index 910a4d4f5e1181..6a770daa799a3c 100644 --- a/regression-test/suites/query_p0/sql_functions/array_functions/test_array_functions.groovy +++ b/regression-test/suites/query_p0/sql_functions/array_functions/test_array_functions.groovy @@ -16,6 +16,7 @@ // under the License. suite("test_array_functions") { + sql "SET enable_nereids_planner=false" def tableName = "tbl_test_array_functions" // array functions only supported in vectorized engine sql """DROP TABLE IF EXISTS ${tableName}""" @@ -159,6 +160,9 @@ suite("test_array_functions") { qt_select "SELECT k1, array_position(k8, cast('2023-02-05' as datev2)) FROM ${tableName} ORDER BY k1" qt_select "SELECT k1, array_position(k10, cast('2022-10-15 10:30:00.999' as datetimev2(3))) FROM ${tableName} ORDER BY k1" qt_select "SELECT k1, array_position(k12, cast(111.111 as decimalv3(6,3))) FROM ${tableName} ORDER BY k1" + // array_position without cast function + qt_select "SELECT k1, array_position(k12, 111.111) FROM ${tableName} ORDER BY k1" + qt_select_array "SELECT k1, array(k1), array_contains(array(k1), k1) from ${tableName} ORDER BY k1" qt_select "SELECT k1, array_concat(k2, k4) FROM ${tableName} ORDER BY k1" qt_select "SELECT k1, array_concat(k2, [1, null, 2], k4, [null]) FROM ${tableName} ORDER BY k1" @@ -295,4 +299,37 @@ suite("test_array_functions") { qt_select_array_datetimev2_2 "SELECT if(1,k2,k3) FROM ${tableName4}" qt_select_array_datetimev2_3 "SELECT if(0,k2,k3) FROM ${tableName4}" qt_select_array_datetimev2_4 "SELECT if(0,k2,k4) FROM ${tableName4}" + + // array with decimal + sql "drop table if exists fn_test" + + sql """ + CREATE TABLE IF NOT EXISTS `fn_test` ( + `id` int null, + `kdcmls1` decimal(9, 3) null, + `kdcmls3` decimal(27, 9) null, + `kdcmlv3s1` decimalv3(9, 3) null, + `kdcmlv3s2` decimalv3(15, 5) null, + `kdcmlv3s3` decimalv3(27, 9) null, + `kadcml` array null, + ) engine=olap + DISTRIBUTED BY HASH(`id`) BUCKETS 1 + properties("replication_num" = "1") + """ + sql """ insert into `fn_test` values + (0, 0.100, 0.100000000 , 0.100, 0.10000, 0.100000000, [0.100000000, 0.100000000]), + (1, 0.200, 0.200000000 , 0.200, 0.20000, 0.200000000, [0.200000000, 0.200000000]), + (3, 0.300, 0.300000000 , 0.300, 0.30000, 0.300000000, [0.300000000, 0.300000000]), + (4, 0.400, 0.400000000 , 0.400, 0.40000, 0.400000000, [0.400000000, 0.400000000]), + (5, 0.500, 0.500000000 , 0.500, 0.50000, 0.500000000, [0.500000000, 0.500000000]), + (6, 0.600, 0.600000000 , 0.600, 0.60000, 0.600000000, [0.600000000, 0.600000000]), + (7, 0.700, 0.700000000 , 0.700, 0.70000, 0.700000000, [0.700000000, 0.700000000]), + (8, 0.800, 0.800000000 , 0.800, 0.80000, 0.800000000, [0.800000000, 0.800000000]), + (9, 0.900, 0.900000000 , 0.900, 0.90000, 0.900000000, [0.900000000, 0.900000000]), + (10, 1.000, 1.000000000 , 1.000, 1.00000, 1.000000000, [1.000000000, 1.000000000]), + (11, 1.100, 1.100000000 , 1.100, 1.10000, 1.100000000, [1.100000000, 1.100000000]), + (12, 1.200, 1.200000000 , 1.200, 1.20000, 1.200000000, [1.200000000, 1.200000000]); """ + + qt_sql """ select array_position(kadcml, kdcmls1), kadcml, kdcmls1 from fn_test;""" + }