Skip to content

Commit

Permalink
executor, util: wrap cast upon the args for AggFunction (#7180)
Browse files Browse the repository at this point in the history
  • Loading branch information
XuHuaiyu authored Jul 31, 2018
1 parent 42bba99 commit 0ca4cc6
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 44 deletions.
64 changes: 21 additions & 43 deletions executor/aggfuncs/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,19 +112,17 @@ func buildSum(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc {
case aggregation.DedupMode:
return nil
default:
switch aggFuncDesc.Args[0].GetType().Tp {
case mysql.TypeFloat, mysql.TypeDouble:
if aggFuncDesc.HasDistinct {
return &sum4DistinctFloat64{base}
}
return &sum4Float64{base}
case mysql.TypeNewDecimal:
switch aggFuncDesc.RetTp.EvalType() {
case types.ETDecimal:
if aggFuncDesc.HasDistinct {
return &sum4DistinctDecimal{base}
}
return &sum4Decimal{base}
default:
return nil
if aggFuncDesc.HasDistinct {
return &sum4DistinctFloat64{base}
}
return &sum4Float64{base}
}
}
}
Expand All @@ -144,13 +142,13 @@ func buildAvg(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc {
// Build avg functions which consume the original data and update their
// partial results.
case aggregation.CompleteMode, aggregation.Partial1Mode:
switch aggFuncDesc.Args[0].GetType().Tp {
case mysql.TypeNewDecimal:
switch aggFuncDesc.RetTp.EvalType() {
case types.ETDecimal:
if aggFuncDesc.HasDistinct {
return &avgOriginal4DistinctDecimal{base}
}
return &avgOriginal4Decimal{baseAvgDecimal{base}}
case mysql.TypeFloat, mysql.TypeDouble:
default:
if aggFuncDesc.HasDistinct {
return &avgOriginal4DistinctFloat64{base}
}
Expand Down Expand Up @@ -250,14 +248,6 @@ func buildMaxMin(aggFuncDesc *aggregation.AggFuncDesc, ordinal int, isMax bool)

// buildGroupConcat builds the AggFunc implementation for function "GROUP_CONCAT".
func buildGroupConcat(ctx sessionctx.Context, aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc {
// TODO: There might be different kind of types of the args,
// we should add CastAsString upon every arg after cast can be pushed down to coprocessor.
// And this check can be removed at that time.
for _, arg := range aggFuncDesc.Args {
if arg.GetType().EvalType() != types.ETString {
return nil
}
}
switch aggFuncDesc.Mode {
case aggregation.DedupMode:
return nil
Expand Down Expand Up @@ -292,39 +282,27 @@ func buildGroupConcat(ctx sessionctx.Context, aggFuncDesc *aggregation.AggFuncDe

// buildBitOr builds the AggFunc implementation for function "BIT_OR".
func buildBitOr(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc {
switch aggFuncDesc.Args[0].GetType().EvalType() {
case types.ETInt:
base := baseAggFunc{
args: aggFuncDesc.Args,
ordinal: ordinal,
}
return &bitOrUint64{baseBitAggFunc{base}}
base := baseAggFunc{
args: aggFuncDesc.Args,
ordinal: ordinal,
}
return nil
return &bitOrUint64{baseBitAggFunc{base}}
}

// buildBitXor builds the AggFunc implementation for function "BIT_XOR".
func buildBitXor(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc {
switch aggFuncDesc.Args[0].GetType().EvalType() {
case types.ETInt:
base := baseAggFunc{
args: aggFuncDesc.Args,
ordinal: ordinal,
}
return &bitXorUint64{baseBitAggFunc{base}}
base := baseAggFunc{
args: aggFuncDesc.Args,
ordinal: ordinal,
}
return nil
return &bitXorUint64{baseBitAggFunc{base}}
}

// buildBitAnd builds the AggFunc implementation for function "BIT_AND".
func buildBitAnd(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc {
switch aggFuncDesc.Args[0].GetType().EvalType() {
case types.ETInt:
base := baseAggFunc{
args: aggFuncDesc.Args,
ordinal: ordinal,
}
return &bitAndUint64{baseBitAggFunc{base}}
base := baseAggFunc{
args: aggFuncDesc.Args,
ordinal: ordinal,
}
return nil
return &bitAndUint64{baseBitAggFunc{base}}
}
32 changes: 32 additions & 0 deletions executor/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -827,11 +827,43 @@ func (b *executorBuilder) buildHashJoin(v *plan.PhysicalHashJoin) Executor {
return e
}

// wrapCastForAggArgs wraps the args of an aggregate function with a cast function.
func (b *executorBuilder) wrapCastForAggArgs(funcs []*aggregation.AggFuncDesc) {
for _, f := range funcs {
// We do not need to wrap cast upon these functions,
// since the EvalXXX method called by the arg is determined by the corresponding arg type.
if f.Name == ast.AggFuncCount || f.Name == ast.AggFuncMin || f.Name == ast.AggFuncMax || f.Name == ast.AggFuncFirstRow {
continue
}
var castFunc func(ctx sessionctx.Context, expr expression.Expression) expression.Expression
switch retTp := f.RetTp; retTp.EvalType() {
case types.ETInt:
castFunc = expression.WrapWithCastAsInt
case types.ETReal:
castFunc = expression.WrapWithCastAsReal
case types.ETString:
castFunc = expression.WrapWithCastAsString
case types.ETDecimal:
castFunc = expression.WrapWithCastAsDecimal
default:
panic("should never happen in executorBuilder.wrapCastForAggArgs")
}
for i := range f.Args {
f.Args[i] = castFunc(b.ctx, f.Args[i])
}
}
}

// buildProjBelowAgg builds a ProjectionExec below AggregationExec.
// If all the args of `aggFuncs`, and all the item of `groupByItems`
// are columns or constants, we do not need to build the `proj`.
func (b *executorBuilder) buildProjBelowAgg(aggFuncs []*aggregation.AggFuncDesc, groupByItems []expression.Expression, src Executor) Executor {
hasScalarFunc := false
// If the mode is FinalMode, we do not need to wrap cast upon the args,
// since the types of the args are already the expected.
if len(aggFuncs) > 0 && aggFuncs[0].Mode != aggregation.FinalMode {
b.wrapCastForAggArgs(aggFuncs)
}
for i := 0; !hasScalarFunc && i < len(aggFuncs); i++ {
f := aggFuncs[i]
for _, arg := range f.Args {
Expand Down
10 changes: 9 additions & 1 deletion util/chunk/row.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,15 @@ func (r Row) GetDatum(colIdx int, tp *types.FieldType) types.Datum {
if !r.IsNull(colIdx) {
d.SetMysqlDecimal(r.GetMyDecimal(colIdx))
d.SetLength(tp.Flen)
d.SetFrac(tp.Decimal)
// If tp.Decimal is unspecified(-1), we should set it to the real
// fraction length of the decimal value, if not, the d.Frac will
// be set to MAX_UINT16 which will cause unexpected BadNumber error
// when encoding.
if tp.Decimal == types.UnspecifiedLength {
d.SetFrac(d.Frac())
} else {
d.SetFrac(tp.Decimal)
}
}
case mysql.TypeEnum:
if !r.IsNull(colIdx) {
Expand Down

0 comments on commit 0ca4cc6

Please sign in to comment.