diff --git a/executor/aggregate_test.go b/executor/aggregate_test.go index a99d7e71a69f4..976775ce5dc4c 100644 --- a/executor/aggregate_test.go +++ b/executor/aggregate_test.go @@ -1483,6 +1483,7 @@ func TestAvgDecimal(t *testing.T) { tk.MustExec("insert into td values (0,29815);") tk.MustExec("insert into td values (10017,-32661);") tk.MustQuery(" SELECT AVG( col_bigint / col_smallint) AS field1 FROM td;").Sort().Check(testkit.Rows("25769363061037.62077260")) + tk.MustQuery(" SELECT AVG(col_bigint) OVER (PARTITION BY col_smallint) as field2 FROM td where col_smallint = -23828;").Sort().Check(testkit.Rows("4.0000")) tk.MustExec("drop table td;") } diff --git a/executor/builder.go b/executor/builder.go index 73d20b93eb9bb..78c94fd02f1aa 100644 --- a/executor/builder.go +++ b/executor/builder.go @@ -4192,7 +4192,7 @@ func (b *executorBuilder) buildWindow(v *plannercore.PhysicalWindow) Executor { partialResults := make([]aggfuncs.PartialResult, 0, len(v.WindowFuncDescs)) resultColIdx := v.Schema().Len() - len(v.WindowFuncDescs) for _, desc := range v.WindowFuncDescs { - aggDesc, err := aggregation.NewAggFuncDesc(b.ctx, desc.Name, desc.Args, false) + aggDesc, err := aggregation.NewAggFuncDescForWindowFunc(b.ctx, desc, false) if err != nil { b.err = err return nil diff --git a/executor/tiflash_test.go b/executor/tiflash_test.go index 10d51bb1a27d7..14e7c13e9c17e 100644 --- a/executor/tiflash_test.go +++ b/executor/tiflash_test.go @@ -829,6 +829,7 @@ func (s *tiflashTestSuite) TestAvgOverflow(c *C) { tk.MustExec("set @@session.tidb_isolation_read_engines=\"tiflash\"") tk.MustExec("set @@session.tidb_enforce_mpp=ON") tk.MustQuery(" SELECT AVG( col_bigint / col_smallint) AS field1 FROM td;").Sort().Check(testkit.Rows("25769363061037.62077260")) + tk.MustQuery(" SELECT AVG(col_bigint) OVER (PARTITION BY col_smallint) as field2 FROM td where col_smallint = -23828;").Sort().Check(testkit.Rows("4.0000")) tk.MustExec("drop table if exists td;") } diff --git a/expression/aggregation/descriptor.go b/expression/aggregation/descriptor.go index 30f020e7dfdf2..8882b93ec05d7 100644 --- a/expression/aggregation/descriptor.go +++ b/expression/aggregation/descriptor.go @@ -43,6 +43,7 @@ type AggFuncDesc struct { } // NewAggFuncDesc creates an aggregation function signature descriptor. +// this func cannot be called twice as the TypeInfer has changed the type of args in the first time. func NewAggFuncDesc(ctx sessionctx.Context, name string, args []expression.Expression, hasDistinct bool) (*AggFuncDesc, error) { b, err := newBaseFuncDesc(ctx, name, args) if err != nil { @@ -51,6 +52,14 @@ func NewAggFuncDesc(ctx sessionctx.Context, name string, args []expression.Expre return &AggFuncDesc{baseFuncDesc: b, HasDistinct: hasDistinct}, nil } +// NewAggFuncDescForWindowFunc creates an aggregation function from window functions, where baseFuncDesc may be ready. +func NewAggFuncDescForWindowFunc(ctx sessionctx.Context, Desc *WindowFuncDesc, hasDistinct bool) (*AggFuncDesc, error) { + if Desc.RetTp == nil { // safety check + return NewAggFuncDesc(ctx, Desc.Name, Desc.Args, hasDistinct) + } + return &AggFuncDesc{baseFuncDesc: baseFuncDesc{Desc.Name, Desc.Args, Desc.RetTp}, HasDistinct: hasDistinct}, nil +} + // String implements the fmt.Stringer interface. func (a *AggFuncDesc) String() string { buffer := bytes.NewBufferString(a.Name)