Skip to content

Commit

Permalink
fix array func with decimal type
Browse files Browse the repository at this point in the history
  • Loading branch information
amorynan committed Sep 14, 2024
1 parent 83e12a8 commit 01b188d
Show file tree
Hide file tree
Showing 7 changed files with 997 additions and 3 deletions.
13 changes: 13 additions & 0 deletions be/src/vec/functions/array/function_array_aggregation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,18 @@ struct ArrayAggregateImpl {
using Function = AggregateFunction<AggregateFunctionImpl<operation>>;
const DataTypeArray* data_type_array =
static_cast<const DataTypeArray*>(remove_nullable(arguments[0]).get());
if constexpr (operation == AggregateOperation::AVERAGE ||
operation == AggregateOperation::SUM ||
operation == AggregateOperation::PRODUCT) {
if (is_decimal(remove_nullable(data_type_array->get_nested_type()))) {
const auto decimal_type = remove_nullable(data_type_array->get_nested_type());
if (check_decimal<Decimal256>(*decimal_type)) {
throw doris::Exception(
ErrorCode::INVALID_ARGUMENT, "Unexpected type {} for aggregation {}",
data_type_array->get_nested_type()->get_name(), operation);
}
}
}
auto function = Function::create(data_type_array->get_nested_type());
if (function) {
return function->get_return_type();
Expand Down Expand Up @@ -175,6 +187,7 @@ struct ArrayAggregateImpl {
execute_type<Decimal64>(res, type, data, offsets) ||
execute_type<Decimal128V2>(res, type, data, offsets) ||
execute_type<Decimal128V3>(res, type, data, offsets) ||
execute_type<Decimal256>(res, type, data, offsets) ||
execute_type<Date>(res, type, data, offsets) ||
execute_type<DateTime>(res, type, data, offsets) ||
execute_type<DateV2>(res, type, data, offsets) ||
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.apache.doris.nereids.types.ArrayType;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.DecimalV3Type;
import org.apache.doris.qe.ConnectContext;

/** ComputePrecisionForSum */
public interface ComputePrecisionForArrayItemAgg extends ComputePrecision {
Expand All @@ -29,8 +30,15 @@ default FunctionSignature computePrecision(FunctionSignature signature) {
if (getArgumentType(0) instanceof ArrayType) {
DataType itemType = ((ArrayType) getArgument(0).getDataType()).getItemType();
if (itemType instanceof DecimalV3Type) {
boolean enableDecimal256 = false;
ConnectContext connectContext = ConnectContext.get();
if (connectContext != null) {
enableDecimal256 = connectContext.getSessionVariable().isEnableDecimal256();
}
DecimalV3Type returnType = DecimalV3Type.createDecimalV3Type(
DecimalV3Type.MAX_DECIMAL128_PRECISION, ((DecimalV3Type) itemType).getScale());
enableDecimal256 ? DecimalV3Type.MAX_DECIMAL256_PRECISION
: DecimalV3Type.MAX_DECIMAL128_PRECISION,
((DecimalV3Type) itemType).getScale());
if (signature.returnType instanceof ArrayType) {
signature = signature.withReturnType(ArrayType.of(returnType));
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import org.apache.doris.nereids.types.ArrayType;
import org.apache.doris.nereids.types.BooleanType;
import org.apache.doris.nereids.types.coercion.AnyDataType;
import org.apache.doris.nereids.types.coercion.FollowToAnyDataType;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
Expand All @@ -41,7 +40,7 @@ public class ArraysOverlap extends ScalarFunction implements ExplicitlyCastableS

public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(BooleanType.INSTANCE)
.args(ArrayType.of(new AnyDataType(0)), ArrayType.of(new FollowToAnyDataType(0)))
.args(ArrayType.of(new AnyDataType(0)), ArrayType.of(new AnyDataType(0)))
);

/**
Expand Down
Loading

0 comments on commit 01b188d

Please sign in to comment.