Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[fix](decimal256) support decimal256 for many functions (#42136) #42356

Merged
merged 1 commit into from
Oct 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion be/src/runtime/runtime_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ class RuntimeState {
_query_options.check_overflow_for_decimal;
}

bool enable_decima256() const {
bool enable_decimal256() const {
return _query_options.__isset.enable_decimal256 && _query_options.enable_decimal256;
}

Expand Down
4 changes: 4 additions & 0 deletions be/src/vec/aggregate_functions/aggregate_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ class Arena;
class IColumn;
class IDataType;

struct AggregateFunctionAttr {
bool enable_decimal256 {false};
};

template <bool nullable, typename ColVecType>
class AggregateFunctionBitmapCount;
template <typename Op>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@
namespace doris::vectorized {

AggregateFunctionPtr create_aggregate_function_approx_count_distinct(
const std::string& name, const DataTypes& argument_types, const bool result_is_nullable) {
const std::string& name, const DataTypes& argument_types, const bool result_is_nullable,
const AggregateFunctionAttr& attr) {
WhichDataType which(remove_nullable(argument_types[0]));

#define DISPATCH(TYPE, COLUMN_TYPE) \
Expand Down
15 changes: 12 additions & 3 deletions be/src/vec/aggregate_functions/aggregate_function_avg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,17 @@ template <typename T>
using AggregateFuncAvgDecimal256 = typename AvgDecimal256<T>::Function;

void register_aggregate_function_avg(AggregateFunctionSimpleFactory& factory) {
factory.register_function_both("avg", creator_with_type::creator<AggregateFuncAvg>);
factory.register_function_both("avg_decimal256",
creator_with_type::creator<AggregateFuncAvgDecimal256>);
AggregateFunctionCreator creator = [&](const std::string& name, const DataTypes& types,
const bool result_is_nullable,
const AggregateFunctionAttr& attr) {
if (attr.enable_decimal256) {
return creator_with_type::creator<AggregateFuncAvgDecimal256>(name, types,
result_is_nullable, attr);
} else {
return creator_with_type::creator<AggregateFuncAvg>(name, types, result_is_nullable,
attr);
}
};
factory.register_function_both("avg", creator);
}
} // namespace doris::vectorized
9 changes: 5 additions & 4 deletions be/src/vec/aggregate_functions/aggregate_function_bitmap.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@ AggregateFunctionPtr create_with_int_data_type(const DataTypes& argument_type) {
return nullptr;
}

AggregateFunctionPtr create_aggregate_function_bitmap_union_count(const std::string& name,
const DataTypes& argument_types,
const bool result_is_nullable) {
AggregateFunctionPtr create_aggregate_function_bitmap_union_count(
const std::string& name, const DataTypes& argument_types, const bool result_is_nullable,
const AggregateFunctionAttr& attr) {
const bool arg_is_nullable = argument_types[0]->is_nullable();
if (arg_is_nullable) {
return std::make_shared<AggregateFunctionBitmapCount<true, ColumnBitmap>>(argument_types);
Expand All @@ -53,7 +53,8 @@ AggregateFunctionPtr create_aggregate_function_bitmap_union_count(const std::str

AggregateFunctionPtr create_aggregate_function_bitmap_union_int(const std::string& name,
const DataTypes& argument_types,
const bool result_is_nullable) {
const bool result_is_nullable,
const AggregateFunctionAttr& attr) {
const bool arg_is_nullable = argument_types[0]->is_nullable();
if (arg_is_nullable) {
return AggregateFunctionPtr(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ AggregateFunctionPtr create_with_int_data_type(const DataTypes& argument_types)

AggregateFunctionPtr create_aggregate_function_bitmap_agg(const std::string& name,
const DataTypes& argument_types,
const bool result_is_nullable) {
const bool result_is_nullable,
const AggregateFunctionAttr& attr) {
const bool arg_is_nullable = argument_types[0]->is_nullable();
if (arg_is_nullable) {
return AggregateFunctionPtr(create_with_int_data_type<true>(argument_types));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@ AggregateFunctionPtr create_aggregate_function_collect_impl(const std::string& n

AggregateFunctionPtr create_aggregate_function_collect(const std::string& name,
const DataTypes& argument_types,
const bool result_is_nullable) {
const bool result_is_nullable,
const AggregateFunctionAttr& attr) {
if (argument_types.size() == 1) {
if (name == "array_agg") {
return create_aggregate_function_collect_impl<std::false_type, std::true_type>(
Expand Down
3 changes: 2 additions & 1 deletion be/src/vec/aggregate_functions/aggregate_function_corr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,8 @@ struct CorrMoment {

AggregateFunctionPtr create_aggregate_corr_function(const std::string& name,
const DataTypes& argument_types,
const bool result_is_nullable) {
const bool result_is_nullable,
const AggregateFunctionAttr& attr) {
assert_binary(name, argument_types);
return create_with_two_basic_numeric_types<CorrMoment>(argument_types[0], argument_types[1],
argument_types, result_is_nullable);
Expand Down
9 changes: 5 additions & 4 deletions be/src/vec/aggregate_functions/aggregate_function_count.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,16 @@ namespace doris::vectorized {

AggregateFunctionPtr create_aggregate_function_count(const std::string& name,
const DataTypes& argument_types,
const bool result_is_nullable) {
const bool result_is_nullable,
const AggregateFunctionAttr& attr) {
assert_arity_at_most<1>(name, argument_types);

return std::make_shared<AggregateFunctionCount>(argument_types);
}

AggregateFunctionPtr create_aggregate_function_count_not_null_unary(const std::string& name,
const DataTypes& argument_types,
const bool result_is_nullable) {
AggregateFunctionPtr create_aggregate_function_count_not_null_unary(
const std::string& name, const DataTypes& argument_types, const bool result_is_nullable,
const AggregateFunctionAttr& attr) {
assert_arity_at_most<1>(name, argument_types);

return std::make_shared<AggregateFunctionCountNotNullUnary>(argument_types);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ namespace doris::vectorized {

AggregateFunctionPtr create_aggregate_function_count_by_enum(const std::string& name,
const DataTypes& argument_types,
const bool result_is_nullable) {
const bool result_is_nullable,
const AggregateFunctionAttr& attr) {
if (argument_types.size() < 1) {
LOG(WARNING) << fmt::format("Illegal number {} of argument for aggregate function {}",
argument_types.size(), name);
Expand Down
6 changes: 4 additions & 2 deletions be/src/vec/aggregate_functions/aggregate_function_covar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,16 @@ AggregateFunctionPtr create_function_single_value(const String& name,

AggregateFunctionPtr create_aggregate_function_covariance_samp(const std::string& name,
const DataTypes& argument_types,
const bool result_is_nullable) {
const bool result_is_nullable,
const AggregateFunctionAttr& attr) {
return create_function_single_value<AggregateFunctionSamp, CovarSampName, SampData>(
name, argument_types, result_is_nullable, NOTNULLABLE);
}

AggregateFunctionPtr create_aggregate_function_covariance_pop(const std::string& name,
const DataTypes& argument_types,
const bool result_is_nullable) {
const bool result_is_nullable,
const AggregateFunctionAttr& attr) {
return create_function_single_value<AggregateFunctionPop, CovarName, PopData>(
name, argument_types, result_is_nullable, NOTNULLABLE);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,8 @@ const std::string DISTINCT_FUNCTION_PREFIX = "multi_distinct_";

void register_aggregate_function_combinator_distinct(AggregateFunctionSimpleFactory& factory) {
AggregateFunctionCreator creator = [&](const std::string& name, const DataTypes& types,
const bool result_is_nullable) {
const bool result_is_nullable,
const AggregateFunctionAttr& attr) {
// 1. we should get not nullable types;
DataTypes nested_types(types.size());
std::transform(types.begin(), types.end(), nested_types.begin(),
Expand All @@ -92,7 +93,7 @@ void register_aggregate_function_combinator_distinct(AggregateFunctionSimpleFact
auto transform_arguments = function_combinator->transform_arguments(nested_types);
auto nested_function_name = name.substr(DISTINCT_FUNCTION_PREFIX.size());
auto nested_function = factory.get(nested_function_name, transform_arguments, false,
BeExecVersionManager::get_newest_version());
BeExecVersionManager::get_newest_version(), attr);
return function_combinator->transform_aggregate_function(nested_function, types,
result_is_nullable);
};
Expand Down
7 changes: 4 additions & 3 deletions be/src/vec/aggregate_functions/aggregate_function_foreach.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,9 @@
namespace doris::vectorized {

void register_aggregate_function_combinator_foreach(AggregateFunctionSimpleFactory& factory) {
AggregateFunctionCreator creator = [&](const std::string& name, const DataTypes& types,
const bool result_is_nullable) -> AggregateFunctionPtr {
AggregateFunctionCreator creator =
[&](const std::string& name, const DataTypes& types, const bool result_is_nullable,
const AggregateFunctionAttr& attr) -> AggregateFunctionPtr {
const std::string& suffix = AggregateFunctionForEach::AGG_FOREACH_SUFFIX;
DataTypes transform_arguments;
for (const auto& t : types) {
Expand All @@ -46,7 +47,7 @@ void register_aggregate_function_combinator_foreach(AggregateFunctionSimpleFacto
auto nested_function_name = name.substr(0, name.size() - suffix.size());
auto nested_function =
factory.get(nested_function_name, transform_arguments, result_is_nullable,
BeExecVersionManager::get_newest_version(), false);
BeExecVersionManager::get_newest_version(), attr);
if (!nested_function) {
throw Exception(
ErrorCode::INTERNAL_ERROR,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,8 @@ inline AggregateFunctionPtr create_aggregate_function_group_array_intersect_impl
}

AggregateFunctionPtr create_aggregate_function_group_array_intersect(
const std::string& name, const DataTypes& argument_types, const bool result_is_nullable) {
const std::string& name, const DataTypes& argument_types, const bool result_is_nullable,
const AggregateFunctionAttr& attr) {
assert_unary(name, argument_types);
const DataTypePtr& argument_type = remove_nullable(argument_types[0]);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ const std::string AggregateFunctionGroupConcatImplStr::separator = ",";

AggregateFunctionPtr create_aggregate_function_group_concat(const std::string& name,
const DataTypes& argument_types,
const bool result_is_nullable) {
const bool result_is_nullable,
const AggregateFunctionAttr& attr) {
if (argument_types.size() == 1) {
return creator_without_type::create<
AggregateFunctionGroupConcat<AggregateFunctionGroupConcatImplStr>>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ AggregateFunctionPtr create_agg_function_histogram(const DataTypes& argument_typ

AggregateFunctionPtr create_aggregate_function_histogram(const std::string& name,
const DataTypes& argument_types,
const bool result_is_nullable) {
const bool result_is_nullable,
const AggregateFunctionAttr& attr) {
WhichDataType type(remove_nullable(argument_types[0]));

#define DISPATCH(TYPE) \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ AggregateFunctionPtr type_dispatch_for_aggregate_function_kurt(const DataTypes&

AggregateFunctionPtr create_aggregate_function_kurt(const std::string& name,
const DataTypes& argument_types,
const bool result_is_nullable) {
const bool result_is_nullable,
const AggregateFunctionAttr& attr) {
if (argument_types.size() != 1) {
LOG(WARNING) << "aggregate function " << name << " requires exactly 1 argument";
return nullptr;
Expand Down
3 changes: 2 additions & 1 deletion be/src/vec/aggregate_functions/aggregate_function_map.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ AggregateFunctionPtr create_agg_function_map_agg(const DataTypes& argument_types

AggregateFunctionPtr create_aggregate_function_map_agg(const std::string& name,
const DataTypes& argument_types,
const bool result_is_nullable) {
const bool result_is_nullable,
const AggregateFunctionAttr& attr) {
WhichDataType type(remove_nullable(argument_types[0]));

#define DISPATCH(TYPE) \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ namespace doris::vectorized {
template <template <typename> class Data>
AggregateFunctionPtr create_aggregate_function_single_value(const String& name,
const DataTypes& argument_types,
const bool result_is_nullable) {
const bool result_is_nullable,
const AggregateFunctionAttr& attr) {
assert_unary(name, argument_types);

AggregateFunctionPtr res(creator_with_numeric_type::create<AggregateFunctionsSingleValue, Data,
Expand Down
3 changes: 2 additions & 1 deletion be/src/vec/aggregate_functions/aggregate_function_min_max.h
Original file line number Diff line number Diff line change
Expand Up @@ -714,5 +714,6 @@ class AggregateFunctionsSingleValue final
template <template <typename> class Data>
AggregateFunctionPtr create_aggregate_function_single_value(const String& name,
const DataTypes& argument_types,
const bool result_is_nullable);
const bool result_is_nullable,
const AggregateFunctionAttr& attr = {});
} // namespace doris::vectorized
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,8 @@ template <template <typename> class AggregateFunctionTemplate,
template <typename, typename> class Data>
AggregateFunctionPtr create_aggregate_function_min_max_by(const String& name,
const DataTypes& argument_types,
const bool result_is_nullable) {
const bool result_is_nullable,
const AggregateFunctionAttr& attr) {
if (argument_types.size() != 2) {
return nullptr;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ template <template <typename> class Impl>
AggregateFunctionPtr create_aggregate_function_orthogonal(const std::string& name,
const DataTypes& argument_types,

const bool result_is_nullable) {
const bool result_is_nullable,
const AggregateFunctionAttr& attr) {
if (argument_types.empty()) {
LOG(WARNING) << "Incorrect number of arguments for aggregate function " << name;
return nullptr;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@

namespace doris::vectorized {

AggregateFunctionPtr create_aggregate_function_percentile_approx(const std::string& name,
const DataTypes& argument_types,
const bool result_is_nullable) {
AggregateFunctionPtr create_aggregate_function_percentile_approx(
const std::string& name, const DataTypes& argument_types, const bool result_is_nullable,
const AggregateFunctionAttr& attr) {
const DataTypePtr& argument_type = remove_nullable(argument_types[0]);
WhichDataType which(argument_type);
if (which.idx != TypeIndex::Float64) {
Expand All @@ -43,7 +43,8 @@ AggregateFunctionPtr create_aggregate_function_percentile_approx(const std::stri
}

AggregateFunctionPtr create_aggregate_function_percentile_approx_weighted(
const std::string& name, const DataTypes& argument_types, const bool result_is_nullable) {
const std::string& name, const DataTypes& argument_types, const bool result_is_nullable,
const AggregateFunctionAttr& attr) {
const DataTypePtr& argument_type = remove_nullable(argument_types[0]);
WhichDataType which(argument_type);
if (which.idx != TypeIndex::Float64) {
Expand Down
Loading
Loading