diff --git a/executor/aggfuncs/builder.go b/executor/aggfuncs/builder.go index fdfc446381362..1a47a50fa1822 100644 --- a/executor/aggfuncs/builder.go +++ b/executor/aggfuncs/builder.go @@ -112,19 +112,17 @@ func buildSum(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc { case aggregation.DedupMode: return nil default: - switch aggFuncDesc.Args[0].GetType().Tp { - case mysql.TypeFloat, mysql.TypeDouble: - if aggFuncDesc.HasDistinct { - return &sum4DistinctFloat64{base} - } - return &sum4Float64{base} - case mysql.TypeNewDecimal: + switch aggFuncDesc.RetTp.EvalType() { + case types.ETDecimal: if aggFuncDesc.HasDistinct { return &sum4DistinctDecimal{base} } return &sum4Decimal{base} default: - return nil + if aggFuncDesc.HasDistinct { + return &sum4DistinctFloat64{base} + } + return &sum4Float64{base} } } } @@ -144,13 +142,13 @@ func buildAvg(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc { // Build avg functions which consume the original data and update their // partial results. case aggregation.CompleteMode, aggregation.Partial1Mode: - switch aggFuncDesc.Args[0].GetType().Tp { - case mysql.TypeNewDecimal: + switch aggFuncDesc.RetTp.EvalType() { + case types.ETDecimal: if aggFuncDesc.HasDistinct { return &avgOriginal4DistinctDecimal{base} } return &avgOriginal4Decimal{baseAvgDecimal{base}} - case mysql.TypeFloat, mysql.TypeDouble: + default: if aggFuncDesc.HasDistinct { return &avgOriginal4DistinctFloat64{base} } @@ -250,14 +248,6 @@ func buildMaxMin(aggFuncDesc *aggregation.AggFuncDesc, ordinal int, isMax bool) // buildGroupConcat builds the AggFunc implementation for function "GROUP_CONCAT". func buildGroupConcat(ctx sessionctx.Context, aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc { - // TODO: There might be different kind of types of the args, - // we should add CastAsString upon every arg after cast can be pushed down to coprocessor. - // And this check can be removed at that time. - for _, arg := range aggFuncDesc.Args { - if arg.GetType().EvalType() != types.ETString { - return nil - } - } switch aggFuncDesc.Mode { case aggregation.DedupMode: return nil @@ -292,39 +282,27 @@ func buildGroupConcat(ctx sessionctx.Context, aggFuncDesc *aggregation.AggFuncDe // buildBitOr builds the AggFunc implementation for function "BIT_OR". func buildBitOr(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc { - switch aggFuncDesc.Args[0].GetType().EvalType() { - case types.ETInt: - base := baseAggFunc{ - args: aggFuncDesc.Args, - ordinal: ordinal, - } - return &bitOrUint64{baseBitAggFunc{base}} + base := baseAggFunc{ + args: aggFuncDesc.Args, + ordinal: ordinal, } - return nil + return &bitOrUint64{baseBitAggFunc{base}} } // buildBitXor builds the AggFunc implementation for function "BIT_XOR". func buildBitXor(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc { - switch aggFuncDesc.Args[0].GetType().EvalType() { - case types.ETInt: - base := baseAggFunc{ - args: aggFuncDesc.Args, - ordinal: ordinal, - } - return &bitXorUint64{baseBitAggFunc{base}} + base := baseAggFunc{ + args: aggFuncDesc.Args, + ordinal: ordinal, } - return nil + return &bitXorUint64{baseBitAggFunc{base}} } // buildBitAnd builds the AggFunc implementation for function "BIT_AND". func buildBitAnd(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc { - switch aggFuncDesc.Args[0].GetType().EvalType() { - case types.ETInt: - base := baseAggFunc{ - args: aggFuncDesc.Args, - ordinal: ordinal, - } - return &bitAndUint64{baseBitAggFunc{base}} + base := baseAggFunc{ + args: aggFuncDesc.Args, + ordinal: ordinal, } - return nil + return &bitAndUint64{baseBitAggFunc{base}} } diff --git a/executor/builder.go b/executor/builder.go index eefd10fc5cf8f..3bd91329e8a26 100644 --- a/executor/builder.go +++ b/executor/builder.go @@ -827,11 +827,43 @@ func (b *executorBuilder) buildHashJoin(v *plan.PhysicalHashJoin) Executor { return e } +// wrapCastForAggArgs wraps the args of an aggregate function with a cast function. +func (b *executorBuilder) wrapCastForAggArgs(funcs []*aggregation.AggFuncDesc) { + for _, f := range funcs { + // We do not need to wrap cast upon these functions, + // since the EvalXXX method called by the arg is determined by the corresponding arg type. + if f.Name == ast.AggFuncCount || f.Name == ast.AggFuncMin || f.Name == ast.AggFuncMax || f.Name == ast.AggFuncFirstRow { + continue + } + var castFunc func(ctx sessionctx.Context, expr expression.Expression) expression.Expression + switch retTp := f.RetTp; retTp.EvalType() { + case types.ETInt: + castFunc = expression.WrapWithCastAsInt + case types.ETReal: + castFunc = expression.WrapWithCastAsReal + case types.ETString: + castFunc = expression.WrapWithCastAsString + case types.ETDecimal: + castFunc = expression.WrapWithCastAsDecimal + default: + panic("should never happen in executorBuilder.wrapCastForAggArgs") + } + for i := range f.Args { + f.Args[i] = castFunc(b.ctx, f.Args[i]) + } + } +} + // buildProjBelowAgg builds a ProjectionExec below AggregationExec. // If all the args of `aggFuncs`, and all the item of `groupByItems` // are columns or constants, we do not need to build the `proj`. func (b *executorBuilder) buildProjBelowAgg(aggFuncs []*aggregation.AggFuncDesc, groupByItems []expression.Expression, src Executor) Executor { hasScalarFunc := false + // If the mode is FinalMode, we do not need to wrap cast upon the args, + // since the types of the args are already the expected. + if len(aggFuncs) > 0 && aggFuncs[0].Mode != aggregation.FinalMode { + b.wrapCastForAggArgs(aggFuncs) + } for i := 0; !hasScalarFunc && i < len(aggFuncs); i++ { f := aggFuncs[i] for _, arg := range f.Args { diff --git a/util/chunk/row.go b/util/chunk/row.go index 55dcf102c2740..dcf573eea66b5 100644 --- a/util/chunk/row.go +++ b/util/chunk/row.go @@ -187,7 +187,15 @@ func (r Row) GetDatum(colIdx int, tp *types.FieldType) types.Datum { if !r.IsNull(colIdx) { d.SetMysqlDecimal(r.GetMyDecimal(colIdx)) d.SetLength(tp.Flen) - d.SetFrac(tp.Decimal) + // If tp.Decimal is unspecified(-1), we should set it to the real + // fraction length of the decimal value, if not, the d.Frac will + // be set to MAX_UINT16 which will cause unexpected BadNumber error + // when encoding. + if tp.Decimal == types.UnspecifiedLength { + d.SetFrac(d.Frac()) + } else { + d.SetFrac(tp.Decimal) + } } case mysql.TypeEnum: if !r.IsNull(colIdx) {