From bf1b47d12c853e0f5e62527c075eb0fc1ea8cc63 Mon Sep 17 00:00:00 2001 From: TengJianPing <18241664+jacktengg@users.noreply.github.com> Date: Wed, 23 Oct 2024 18:14:57 +0800 Subject: [PATCH] [fix](decimal256) support decimal256 for many functions (#42136) ## Proposed changes Issue Number: close #xxx Support decimal256 for the following functions: ``` multi_distinct_sum multi_distinct_count array_sum array_avg array_product array_cum_sum ``` --- .../aggregate_functions/aggregate_function.h | 4 + ...gregate_function_approx_count_distinct.cpp | 3 +- .../aggregate_function_avg.cpp | 15 +- .../aggregate_function_bitmap.cpp | 9 +- .../aggregate_function_bitmap_agg.cpp | 3 +- .../aggregate_function_collect.cpp | 3 +- .../aggregate_function_corr.cpp | 3 +- .../aggregate_function_count.cpp | 9 +- .../aggregate_function_count_by_enum.cpp | 3 +- .../aggregate_function_covar.cpp | 6 +- .../aggregate_function_distinct.cpp | 5 +- .../aggregate_function_foreach.cpp | 7 +- ...gregate_function_group_array_intersect.cpp | 3 +- .../aggregate_function_group_concat.cpp | 3 +- .../aggregate_function_histogram.cpp | 3 +- .../aggregate_function_kurtosis.cpp | 3 +- .../aggregate_function_linear_histogram.cpp | 3 +- .../aggregate_function_map.cpp | 3 +- .../aggregate_function_min_max.cpp | 3 +- .../aggregate_function_min_max.h | 3 +- .../aggregate_function_min_max_by.h | 3 +- .../aggregate_function_orthogonal_bitmap.cpp | 3 +- .../aggregate_function_percentile.cpp | 9 +- .../aggregate_function_product.h | 16 +- .../aggregate_function_quantile_state.cpp | 6 +- .../aggregate_function_quantile_state.h | 6 +- .../aggregate_function_reader_first_last.h | 39 +-- .../aggregate_function_regr_union.cpp | 3 +- .../aggregate_function_sequence_match.cpp | 3 +- .../aggregate_function_simple_factory.h | 15 +- .../aggregate_function_skew.cpp | 3 +- .../aggregate_function_stddev.cpp | 12 +- .../aggregate_function_sum.cpp | 15 +- .../aggregate_function_sum.h | 1 - .../aggregate_function_topn.cpp | 9 +- .../aggregate_function_uniq.cpp | 8 +- ...aggregate_function_uniq_distribute_key.cpp | 3 +- .../aggregate_function_window.cpp | 6 +- .../aggregate_function_window_funnel.cpp | 3 +- be/src/vec/aggregate_functions/helpers.h | 15 +- be/src/vec/core/wide_integer.h | 5 + be/src/vec/core/wide_integer_impl.h | 34 +- be/src/vec/exec/scan/vfile_scanner.cpp | 5 +- be/src/vec/exprs/vcase_expr.cpp | 6 +- be/src/vec/exprs/vcast_expr.cpp | 6 +- be/src/vec/exprs/vectorized_agg_fn.cpp | 4 +- be/src/vec/exprs/vectorized_fn_call.cpp | 3 +- be/src/vec/exprs/vin_predicate.cpp | 5 +- be/src/vec/exprs/vmatch_predicate.cpp | 6 +- be/src/vec/exprs/vtopn_pred.h | 2 +- .../array/function_array_aggregation.cpp | 100 ++++-- .../array/function_array_cum_sum.cpp | 31 +- .../functions/comparison_equal_for_null.cpp | 11 +- be/src/vec/functions/function.h | 4 + be/src/vec/functions/function_coalesce.cpp | 21 +- be/src/vec/functions/function_ifnull.h | 4 +- be/src/vec/functions/nullif.cpp | 11 +- .../vec/functions/simple_function_factory.h | 11 +- .../agg_linear_histogram_test.cpp | 3 +- .../decimalv3}/aggregate_decimal256.out | 8 + .../decimalv3/test_decimal256_array.out | 63 ++++ .../test_decimal256_multi_distinct.out | 33 ++ .../scalar_function/Array.out | 306 ++++++++++++++++++ .../decimalv3}/aggregate_decimal256.groovy | 4 +- .../decimalv3/test_decimal256_array.groovy | 118 +++++++ .../test_decimal256_multi_distinct.groovy | 73 +++++ .../scalar_function/Array.groovy | 24 +- 67 files changed, 956 insertions(+), 217 deletions(-) rename regression-test/data/{query_p0/aggregate => datatype_p0/decimalv3}/aggregate_decimal256.out (95%) create mode 100644 regression-test/data/datatype_p0/decimalv3/test_decimal256_array.out create mode 100644 regression-test/data/datatype_p0/decimalv3/test_decimal256_multi_distinct.out rename regression-test/suites/{query_p0/aggregate => datatype_p0/decimalv3}/aggregate_decimal256.groovy (95%) create mode 100644 regression-test/suites/datatype_p0/decimalv3/test_decimal256_array.groovy create mode 100644 regression-test/suites/datatype_p0/decimalv3/test_decimal256_multi_distinct.groovy diff --git a/be/src/vec/aggregate_functions/aggregate_function.h b/be/src/vec/aggregate_functions/aggregate_function.h index 05f1bd2a602c68..cd1f8922e1b459 100644 --- a/be/src/vec/aggregate_functions/aggregate_function.h +++ b/be/src/vec/aggregate_functions/aggregate_function.h @@ -38,6 +38,10 @@ class Arena; class IColumn; class IDataType; +struct AggregateFunctionAttr { + bool enable_decimal256 {false}; +}; + template class AggregateFunctionBitmapCount; template diff --git a/be/src/vec/aggregate_functions/aggregate_function_approx_count_distinct.cpp b/be/src/vec/aggregate_functions/aggregate_function_approx_count_distinct.cpp index 10616be4258477..18662bf66cf38c 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_approx_count_distinct.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_approx_count_distinct.cpp @@ -31,7 +31,8 @@ namespace doris::vectorized { AggregateFunctionPtr create_aggregate_function_approx_count_distinct( - const std::string& name, const DataTypes& argument_types, const bool result_is_nullable) { + const std::string& name, const DataTypes& argument_types, const bool result_is_nullable, + const AggregateFunctionAttr& attr) { WhichDataType which(remove_nullable(argument_types[0])); #define DISPATCH(TYPE, COLUMN_TYPE) \ diff --git a/be/src/vec/aggregate_functions/aggregate_function_avg.cpp b/be/src/vec/aggregate_functions/aggregate_function_avg.cpp index 0f3d0fd3bdad6b..6a6711f90f983e 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_avg.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_avg.cpp @@ -45,8 +45,17 @@ template using AggregateFuncAvgDecimal256 = typename AvgDecimal256::Function; void register_aggregate_function_avg(AggregateFunctionSimpleFactory& factory) { - factory.register_function_both("avg", creator_with_type::creator); - factory.register_function_both("avg_decimal256", - creator_with_type::creator); + AggregateFunctionCreator creator = [&](const std::string& name, const DataTypes& types, + const bool result_is_nullable, + const AggregateFunctionAttr& attr) { + if (attr.enable_decimal256) { + return creator_with_type::creator(name, types, + result_is_nullable, attr); + } else { + return creator_with_type::creator(name, types, result_is_nullable, + attr); + } + }; + factory.register_function_both("avg", creator); } } // namespace doris::vectorized diff --git a/be/src/vec/aggregate_functions/aggregate_function_bitmap.cpp b/be/src/vec/aggregate_functions/aggregate_function_bitmap.cpp index 0676fd5bc27090..e9c86d4b9556da 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_bitmap.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_bitmap.cpp @@ -40,9 +40,9 @@ AggregateFunctionPtr create_with_int_data_type(const DataTypes& argument_type) { return nullptr; } -AggregateFunctionPtr create_aggregate_function_bitmap_union_count(const std::string& name, - const DataTypes& argument_types, - const bool result_is_nullable) { +AggregateFunctionPtr create_aggregate_function_bitmap_union_count( + const std::string& name, const DataTypes& argument_types, const bool result_is_nullable, + const AggregateFunctionAttr& attr) { const bool arg_is_nullable = argument_types[0]->is_nullable(); if (arg_is_nullable) { return std::make_shared>(argument_types); @@ -53,7 +53,8 @@ AggregateFunctionPtr create_aggregate_function_bitmap_union_count(const std::str AggregateFunctionPtr create_aggregate_function_bitmap_union_int(const std::string& name, const DataTypes& argument_types, - const bool result_is_nullable) { + const bool result_is_nullable, + const AggregateFunctionAttr& attr) { const bool arg_is_nullable = argument_types[0]->is_nullable(); if (arg_is_nullable) { return AggregateFunctionPtr( diff --git a/be/src/vec/aggregate_functions/aggregate_function_bitmap_agg.cpp b/be/src/vec/aggregate_functions/aggregate_function_bitmap_agg.cpp index b8ae4c6530d575..0b95ddfd46f0d5 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_bitmap_agg.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_bitmap_agg.cpp @@ -41,7 +41,8 @@ AggregateFunctionPtr create_with_int_data_type(const DataTypes& argument_types) AggregateFunctionPtr create_aggregate_function_bitmap_agg(const std::string& name, const DataTypes& argument_types, - const bool result_is_nullable) { + const bool result_is_nullable, + const AggregateFunctionAttr& attr) { const bool arg_is_nullable = argument_types[0]->is_nullable(); if (arg_is_nullable) { return AggregateFunctionPtr(create_with_int_data_type(argument_types)); diff --git a/be/src/vec/aggregate_functions/aggregate_function_collect.cpp b/be/src/vec/aggregate_functions/aggregate_function_collect.cpp index 4fcf09b59b33c6..d726b7c6355318 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_collect.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_collect.cpp @@ -96,7 +96,8 @@ AggregateFunctionPtr create_aggregate_function_collect_impl(const std::string& n AggregateFunctionPtr create_aggregate_function_collect(const std::string& name, const DataTypes& argument_types, - const bool result_is_nullable) { + const bool result_is_nullable, + const AggregateFunctionAttr& attr) { if (argument_types.size() == 1) { if (name == "array_agg") { return create_aggregate_function_collect_impl( diff --git a/be/src/vec/aggregate_functions/aggregate_function_corr.cpp b/be/src/vec/aggregate_functions/aggregate_function_corr.cpp index a454afb45f22e0..cdaab6e086f4a5 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_corr.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_corr.cpp @@ -89,7 +89,8 @@ struct CorrMoment { AggregateFunctionPtr create_aggregate_corr_function(const std::string& name, const DataTypes& argument_types, - const bool result_is_nullable) { + const bool result_is_nullable, + const AggregateFunctionAttr& attr) { assert_binary(name, argument_types); return create_with_two_basic_numeric_types(argument_types[0], argument_types[1], argument_types, result_is_nullable); diff --git a/be/src/vec/aggregate_functions/aggregate_function_count.cpp b/be/src/vec/aggregate_functions/aggregate_function_count.cpp index 8c54714b046da1..5cfe5af41982f6 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_count.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_count.cpp @@ -29,15 +29,16 @@ namespace doris::vectorized { AggregateFunctionPtr create_aggregate_function_count(const std::string& name, const DataTypes& argument_types, - const bool result_is_nullable) { + const bool result_is_nullable, + const AggregateFunctionAttr& attr) { assert_arity_at_most<1>(name, argument_types); return std::make_shared(argument_types); } -AggregateFunctionPtr create_aggregate_function_count_not_null_unary(const std::string& name, - const DataTypes& argument_types, - const bool result_is_nullable) { +AggregateFunctionPtr create_aggregate_function_count_not_null_unary( + const std::string& name, const DataTypes& argument_types, const bool result_is_nullable, + const AggregateFunctionAttr& attr) { assert_arity_at_most<1>(name, argument_types); return std::make_shared(argument_types); diff --git a/be/src/vec/aggregate_functions/aggregate_function_count_by_enum.cpp b/be/src/vec/aggregate_functions/aggregate_function_count_by_enum.cpp index 1a0bf2518202f3..093b31d57db554 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_count_by_enum.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_count_by_enum.cpp @@ -29,7 +29,8 @@ namespace doris::vectorized { AggregateFunctionPtr create_aggregate_function_count_by_enum(const std::string& name, const DataTypes& argument_types, - const bool result_is_nullable) { + const bool result_is_nullable, + const AggregateFunctionAttr& attr) { if (argument_types.size() < 1) { LOG(WARNING) << fmt::format("Illegal number {} of argument for aggregate function {}", argument_types.size(), name); diff --git a/be/src/vec/aggregate_functions/aggregate_function_covar.cpp b/be/src/vec/aggregate_functions/aggregate_function_covar.cpp index b02d6ae0e12572..71d09f61de4302 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_covar.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_covar.cpp @@ -53,14 +53,16 @@ AggregateFunctionPtr create_function_single_value(const String& name, AggregateFunctionPtr create_aggregate_function_covariance_samp(const std::string& name, const DataTypes& argument_types, - const bool result_is_nullable) { + const bool result_is_nullable, + const AggregateFunctionAttr& attr) { return create_function_single_value( name, argument_types, result_is_nullable, NOTNULLABLE); } AggregateFunctionPtr create_aggregate_function_covariance_pop(const std::string& name, const DataTypes& argument_types, - const bool result_is_nullable) { + const bool result_is_nullable, + const AggregateFunctionAttr& attr) { return create_function_single_value( name, argument_types, result_is_nullable, NOTNULLABLE); } diff --git a/be/src/vec/aggregate_functions/aggregate_function_distinct.cpp b/be/src/vec/aggregate_functions/aggregate_function_distinct.cpp index 9bb2954207babb..fce58b38688b28 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_distinct.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_distinct.cpp @@ -83,7 +83,8 @@ const std::string DISTINCT_FUNCTION_PREFIX = "multi_distinct_"; void register_aggregate_function_combinator_distinct(AggregateFunctionSimpleFactory& factory) { AggregateFunctionCreator creator = [&](const std::string& name, const DataTypes& types, - const bool result_is_nullable) { + const bool result_is_nullable, + const AggregateFunctionAttr& attr) { // 1. we should get not nullable types; DataTypes nested_types(types.size()); std::transform(types.begin(), types.end(), nested_types.begin(), @@ -92,7 +93,7 @@ void register_aggregate_function_combinator_distinct(AggregateFunctionSimpleFact auto transform_arguments = function_combinator->transform_arguments(nested_types); auto nested_function_name = name.substr(DISTINCT_FUNCTION_PREFIX.size()); auto nested_function = factory.get(nested_function_name, transform_arguments, false, - BeExecVersionManager::get_newest_version()); + BeExecVersionManager::get_newest_version(), attr); return function_combinator->transform_aggregate_function(nested_function, types, result_is_nullable); }; diff --git a/be/src/vec/aggregate_functions/aggregate_function_foreach.cpp b/be/src/vec/aggregate_functions/aggregate_function_foreach.cpp index ab6d0142f6a8c0..c1cbcc89996caf 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_foreach.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_foreach.cpp @@ -34,8 +34,9 @@ namespace doris::vectorized { void register_aggregate_function_combinator_foreach(AggregateFunctionSimpleFactory& factory) { - AggregateFunctionCreator creator = [&](const std::string& name, const DataTypes& types, - const bool result_is_nullable) -> AggregateFunctionPtr { + AggregateFunctionCreator creator = + [&](const std::string& name, const DataTypes& types, const bool result_is_nullable, + const AggregateFunctionAttr& attr) -> AggregateFunctionPtr { const std::string& suffix = AggregateFunctionForEach::AGG_FOREACH_SUFFIX; DataTypes transform_arguments; for (const auto& t : types) { @@ -46,7 +47,7 @@ void register_aggregate_function_combinator_foreach(AggregateFunctionSimpleFacto auto nested_function_name = name.substr(0, name.size() - suffix.size()); auto nested_function = factory.get(nested_function_name, transform_arguments, result_is_nullable, - BeExecVersionManager::get_newest_version(), false); + BeExecVersionManager::get_newest_version(), attr); if (!nested_function) { throw Exception( ErrorCode::INTERNAL_ERROR, diff --git a/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.cpp b/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.cpp index b3b9a8b9af47c6..24faf58b2e1ff9 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.cpp @@ -70,7 +70,8 @@ inline AggregateFunctionPtr create_aggregate_function_group_array_intersect_impl } AggregateFunctionPtr create_aggregate_function_group_array_intersect( - const std::string& name, const DataTypes& argument_types, const bool result_is_nullable) { + const std::string& name, const DataTypes& argument_types, const bool result_is_nullable, + const AggregateFunctionAttr& attr) { assert_unary(name, argument_types); const DataTypePtr& argument_type = remove_nullable(argument_types[0]); diff --git a/be/src/vec/aggregate_functions/aggregate_function_group_concat.cpp b/be/src/vec/aggregate_functions/aggregate_function_group_concat.cpp index 9661b9c89d5700..286795ea2ba70c 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_group_concat.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_group_concat.cpp @@ -28,7 +28,8 @@ const std::string AggregateFunctionGroupConcatImplStr::separator = ","; AggregateFunctionPtr create_aggregate_function_group_concat(const std::string& name, const DataTypes& argument_types, - const bool result_is_nullable) { + const bool result_is_nullable, + const AggregateFunctionAttr& attr) { if (argument_types.size() == 1) { return creator_without_type::create< AggregateFunctionGroupConcat>( diff --git a/be/src/vec/aggregate_functions/aggregate_function_histogram.cpp b/be/src/vec/aggregate_functions/aggregate_function_histogram.cpp index 5b06af28399d71..fb2fa9c2513ec0 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_histogram.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_histogram.cpp @@ -47,7 +47,8 @@ AggregateFunctionPtr create_agg_function_histogram(const DataTypes& argument_typ AggregateFunctionPtr create_aggregate_function_histogram(const std::string& name, const DataTypes& argument_types, - const bool result_is_nullable) { + const bool result_is_nullable, + const AggregateFunctionAttr& attr) { WhichDataType type(remove_nullable(argument_types[0])); #define DISPATCH(TYPE) \ diff --git a/be/src/vec/aggregate_functions/aggregate_function_kurtosis.cpp b/be/src/vec/aggregate_functions/aggregate_function_kurtosis.cpp index 00ad1893eafcf6..a763721f3f4061 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_kurtosis.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_kurtosis.cpp @@ -45,7 +45,8 @@ AggregateFunctionPtr type_dispatch_for_aggregate_function_kurt(const DataTypes& AggregateFunctionPtr create_aggregate_function_kurt(const std::string& name, const DataTypes& argument_types, - const bool result_is_nullable) { + const bool result_is_nullable, + const AggregateFunctionAttr& attr) { if (argument_types.size() != 1) { LOG(WARNING) << "aggregate function " << name << " requires exactly 1 argument"; return nullptr; diff --git a/be/src/vec/aggregate_functions/aggregate_function_linear_histogram.cpp b/be/src/vec/aggregate_functions/aggregate_function_linear_histogram.cpp index 62ce1657526767..683cf1a18f78ba 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_linear_histogram.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_linear_histogram.cpp @@ -41,7 +41,8 @@ AggregateFunctionPtr create_agg_function_linear_histogram(const DataTypes& argum AggregateFunctionPtr create_aggregate_function_linear_histogram(const std::string& name, const DataTypes& argument_types, - const bool result_is_nullable) { + const bool result_is_nullable, + const AggregateFunctionAttr& attr) { WhichDataType type(remove_nullable(argument_types[0])); #define DISPATCH(TYPE) \ diff --git a/be/src/vec/aggregate_functions/aggregate_function_map.cpp b/be/src/vec/aggregate_functions/aggregate_function_map.cpp index bcf3f2d66dfeaf..f289d885f48f52 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_map.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_map.cpp @@ -32,7 +32,8 @@ AggregateFunctionPtr create_agg_function_map_agg(const DataTypes& argument_types AggregateFunctionPtr create_aggregate_function_map_agg(const std::string& name, const DataTypes& argument_types, - const bool result_is_nullable) { + const bool result_is_nullable, + const AggregateFunctionAttr& attr) { WhichDataType type(remove_nullable(argument_types[0])); #define DISPATCH(TYPE) \ diff --git a/be/src/vec/aggregate_functions/aggregate_function_min_max.cpp b/be/src/vec/aggregate_functions/aggregate_function_min_max.cpp index 8aa8850a314d84..c1a72fd52bdd76 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_min_max.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_min_max.cpp @@ -30,7 +30,8 @@ namespace doris::vectorized { template