Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

executor, util: wrap cast upon the args for AggFunction #7180

Merged
merged 5 commits into from
Jul 31, 2018
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 21 additions & 43 deletions executor/aggfuncs/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,19 +111,17 @@ func buildSum(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc {
case aggregation.DedupMode:
return nil
default:
switch aggFuncDesc.Args[0].GetType().Tp {
case mysql.TypeFloat, mysql.TypeDouble:
if aggFuncDesc.HasDistinct {
return &sum4DistinctFloat64{base}
}
return &sum4Float64{base}
case mysql.TypeNewDecimal:
switch aggFuncDesc.RetTp.EvalType() {
case types.ETDecimal:
if aggFuncDesc.HasDistinct {
return &sum4DistinctDecimal{base}
}
return &sum4Decimal{base}
default:
return nil
if aggFuncDesc.HasDistinct {
return &sum4DistinctFloat64{base}
}
return &sum4Float64{base}
}
}
}
Expand All @@ -143,13 +141,13 @@ func buildAvg(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc {
// Build avg functions which consume the original data and update their
// partial results.
case aggregation.CompleteMode, aggregation.Partial1Mode:
switch aggFuncDesc.Args[0].GetType().Tp {
case mysql.TypeNewDecimal:
switch aggFuncDesc.RetTp.EvalType() {
case types.ETDecimal:
if aggFuncDesc.HasDistinct {
return &avgOriginal4DistinctDecimal{base}
}
return &avgOriginal4Decimal{baseAvgDecimal{base}}
case mysql.TypeFloat, mysql.TypeDouble:
default:
if aggFuncDesc.HasDistinct {
return &avgOriginal4DistinctFloat64{base}
}
Expand Down Expand Up @@ -249,14 +247,6 @@ func buildMaxMin(aggFuncDesc *aggregation.AggFuncDesc, ordinal int, isMax bool)

// buildGroupConcat builds the AggFunc implementation for function "GROUP_CONCAT".
func buildGroupConcat(ctx sessionctx.Context, aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc {
// TODO: There might be different kind of types of the args,
// we should add CastAsString upon every arg after cast can be pushed down to coprocessor.
// And this check can be removed at that time.
for _, arg := range aggFuncDesc.Args {
if arg.GetType().EvalType() != types.ETString {
return nil
}
}
switch aggFuncDesc.Mode {
case aggregation.DedupMode:
return nil
Expand Down Expand Up @@ -291,39 +281,27 @@ func buildGroupConcat(ctx sessionctx.Context, aggFuncDesc *aggregation.AggFuncDe

// buildBitOr builds the AggFunc implementation for function "BIT_OR".
func buildBitOr(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc {
switch aggFuncDesc.Args[0].GetType().EvalType() {
case types.ETInt:
base := baseAggFunc{
args: aggFuncDesc.Args,
ordinal: ordinal,
}
return &bitOrUint64{baseBitAggFunc{base}}
base := baseAggFunc{
args: aggFuncDesc.Args,
ordinal: ordinal,
}
return nil
return &bitOrUint64{baseBitAggFunc{base}}
}

// buildBitXor builds the AggFunc implementation for function "BIT_XOR".
func buildBitXor(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc {
switch aggFuncDesc.Args[0].GetType().EvalType() {
case types.ETInt:
base := baseAggFunc{
args: aggFuncDesc.Args,
ordinal: ordinal,
}
return &bitXorUint64{baseBitAggFunc{base}}
base := baseAggFunc{
args: aggFuncDesc.Args,
ordinal: ordinal,
}
return nil
return &bitXorUint64{baseBitAggFunc{base}}
}

// buildBitAnd builds the AggFunc implementation for function "BIT_AND".
func buildBitAnd(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc {
switch aggFuncDesc.Args[0].GetType().EvalType() {
case types.ETInt:
base := baseAggFunc{
args: aggFuncDesc.Args,
ordinal: ordinal,
}
return &bitAndUint64{baseBitAggFunc{base}}
base := baseAggFunc{
args: aggFuncDesc.Args,
ordinal: ordinal,
}
return nil
return &bitAndUint64{baseBitAggFunc{base}}
}
32 changes: 32 additions & 0 deletions executor/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -825,11 +825,43 @@ func (b *executorBuilder) buildHashJoin(v *plan.PhysicalHashJoin) Executor {
return e
}

// wrapCastForAggArgs wraps the args of an aggregate function with a cast function.
func (b *executorBuilder) wrapCastForAggArgs(funcs []*aggregation.AggFuncDesc) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We'd better make implement a function like InferArgumentType for AggFuncDesc, and call that function inside this routine:

for _, aggFuncDesc := range funcs {
    aggFuncDesc.InferArgumentType()
}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think current function name may be clearer,
since this is not a real type infer.
Type infer for agg function has been done in AggFuncDesc.typeInfer.
This function is more like a supplementary behavior for type infer.

for _, f := range funcs {
// We do not need to wrap cast upon these functions,
// since the EvalXXX method called by the arg is determined by the corresponding arg type.
if f.Name == ast.AggFuncCount || f.Name == ast.AggFuncMin || f.Name == ast.AggFuncMax || f.Name == ast.AggFuncFirstRow {
continue
}
var castFunc func(ctx sessionctx.Context, expr expression.Expression) expression.Expression
switch retTp := f.RetTp; retTp.EvalType() {
case types.ETInt:
castFunc = expression.WrapWithCastAsInt
case types.ETReal:
castFunc = expression.WrapWithCastAsReal
case types.ETString:
castFunc = expression.WrapWithCastAsString
case types.ETDecimal:
castFunc = expression.WrapWithCastAsDecimal
default:
panic("should never happen in executorBuilder.wrapCastForAggArgs")
}
for i := range f.Args {
f.Args[i] = castFunc(b.ctx, f.Args[i])
}
}
}

// buildProjBelowAgg builds a ProjectionExec below AggregationExec.
// If all the args of `aggFuncs`, and all the item of `groupByItems`
// are columns or constants, we do not need to build the `proj`.
func (b *executorBuilder) buildProjBelowAgg(aggFuncs []*aggregation.AggFuncDesc, groupByItems []expression.Expression, src Executor) Executor {
hasScalarFunc := false
// If the mode is FinalMode, we do not need to wrap cast upon the args,
// since the types of the args are already the expected.
if len(aggFuncs) > 0 && aggFuncs[0].Mode != aggregation.FinalMode {
b.wrapCastForAggArgs(aggFuncs)
}
for i := 0; !hasScalarFunc && i < len(aggFuncs); i++ {
f := aggFuncs[i]
for _, arg := range f.Args {
Expand Down
10 changes: 9 additions & 1 deletion util/chunk/row.go
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,15 @@ func (r Row) GetDatum(colIdx int, tp *types.FieldType) types.Datum {
if !r.IsNull(colIdx) {
d.SetMysqlDecimal(r.GetMyDecimal(colIdx))
d.SetLength(tp.Flen)
d.SetFrac(tp.Decimal)
// If tp.Decimal is unspecified(-1), we should set it to the real
// fraction length of the decimal value, if not, the d.Frac will
// be set to MAX_UINT16 which will cause unexpected BadNumber error
// when encoding.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if tp.Decimal == types.UnspecifiedLength {
d.SetFrac(d.Frac())
} else {
d.SetFrac(tp.Decimal)
}
}
case mysql.TypeEnum:
if !r.IsNull(colIdx) {
Expand Down