Skip to content

Commit

Permalink
*: support aggregate function stddev_samp() and var_samp() (#19810)
Browse files Browse the repository at this point in the history
Co-authored-by: Zhuomin(Charming) Liu <[email protected]>
Co-authored-by: ti-srebot <[email protected]>
  • Loading branch information
3 people authored Sep 16, 2020
1 parent 1bfeff9 commit 205c401
Show file tree
Hide file tree
Showing 9 changed files with 209 additions and 15 deletions.
42 changes: 42 additions & 0 deletions executor/aggfuncs/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ func Build(ctx sessionctx.Context, aggFuncDesc *aggregation.AggFuncDesc, ordinal
return buildApproxCountDistinct(aggFuncDesc, ordinal)
case ast.AggFuncStddevPop:
return buildStdDevPop(aggFuncDesc, ordinal)
case ast.AggFuncVarSamp:
return buildVarSamp(aggFuncDesc, ordinal)
case ast.AggFuncStddevSamp:
return buildStddevSamp(aggFuncDesc, ordinal)
}
return nil
}
Expand Down Expand Up @@ -503,6 +507,44 @@ func buildStdDevPop(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc {
}
}

// buildVarSamp builds the AggFunc implementation for function "VAR_SAMP()"
func buildVarSamp(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc {
base := baseVarPopAggFunc{
baseAggFunc{
args: aggFuncDesc.Args,
ordinal: ordinal,
},
}
switch aggFuncDesc.Mode {
case aggregation.DedupMode:
return nil
default:
if aggFuncDesc.HasDistinct {
return &varSamp4DistinctFloat64{varPop4DistinctFloat64{base}}
}
return &varSamp4Float64{varPop4Float64{base}}
}
}

// buildStddevSamp builds the AggFunc implementation for function "STDDEV_SAMP()"
func buildStddevSamp(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc {
base := baseVarPopAggFunc{
baseAggFunc{
args: aggFuncDesc.Args,
ordinal: ordinal,
},
}
switch aggFuncDesc.Mode {
case aggregation.DedupMode:
return nil
default:
if aggFuncDesc.HasDistinct {
return &stddevSamp4DistinctFloat64{varPop4DistinctFloat64{base}}
}
return &stddevSamp4Float64{varPop4Float64{base}}
}
}

// buildJSONObjectAgg builds the AggFunc implementation for function "json_objectagg".
func buildJSONObjectAgg(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc {
base := baseAggFunc{
Expand Down
51 changes: 51 additions & 0 deletions executor/aggfuncs/func_stddevsamp.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// Copyright 2019 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.

package aggfuncs

import (
"math"

"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/util/chunk"
)

type stddevSamp4Float64 struct {
varPop4Float64
}

func (e *stddevSamp4Float64) AppendFinalResult2Chunk(sctx sessionctx.Context, pr PartialResult, chk *chunk.Chunk) error {
p := (*partialResult4VarPopFloat64)(pr)
if p.count <= 1 {
chk.AppendNull(e.ordinal)
return nil
}
variance := p.variance / float64(p.count-1)
chk.AppendFloat64(e.ordinal, math.Sqrt(variance))
return nil
}

type stddevSamp4DistinctFloat64 struct {
varPop4DistinctFloat64
}

func (e *stddevSamp4DistinctFloat64) AppendFinalResult2Chunk(sctx sessionctx.Context, pr PartialResult, chk *chunk.Chunk) error {
p := (*partialResult4VarPopDistinctFloat64)(pr)
if p.count <= 1 {
chk.AppendNull(e.ordinal)
return nil
}
variance := p.variance / float64(p.count-1)
chk.AppendFloat64(e.ordinal, math.Sqrt(variance))
return nil
}
25 changes: 25 additions & 0 deletions executor/aggfuncs/func_stddevsamp_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package aggfuncs_test

import (
. "github.com/pingcap/check"
"github.com/pingcap/parser/ast"
"github.com/pingcap/parser/mysql"
)

func (s *testSuite) TestMergePartialResult4Stddevsamp(c *C) {
tests := []aggTest{
buildAggTester(ast.AggFuncStddevSamp, mysql.TypeDouble, 5, 1.5811388300841898, 1, 1.407885953173359),
}
for _, test := range tests {
s.testMergePartialResult(c, test)
}
}

func (s *testSuite) TestStddevsamp(c *C) {
tests := []aggTest{
buildAggTester(ast.AggFuncStddevSamp, mysql.TypeDouble, 5, nil, 1.5811388300841898),
}
for _, test := range tests {
s.testAggFunc(c, test)
}
}
49 changes: 49 additions & 0 deletions executor/aggfuncs/func_varsamp.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
// Copyright 2019 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.

package aggfuncs

import (
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/util/chunk"
)

type varSamp4Float64 struct {
varPop4Float64
}

func (e *varSamp4Float64) AppendFinalResult2Chunk(sctx sessionctx.Context, pr PartialResult, chk *chunk.Chunk) error {
p := (*partialResult4VarPopFloat64)(pr)
if p.count <= 1 {
chk.AppendNull(e.ordinal)
return nil
}
variance := p.variance / float64(p.count-1)
chk.AppendFloat64(e.ordinal, variance)
return nil
}

type varSamp4DistinctFloat64 struct {
varPop4DistinctFloat64
}

func (e *varSamp4DistinctFloat64) AppendFinalResult2Chunk(sctx sessionctx.Context, pr PartialResult, chk *chunk.Chunk) error {
p := (*partialResult4VarPopDistinctFloat64)(pr)
if p.count <= 1 {
chk.AppendNull(e.ordinal)
return nil
}
variance := p.variance / float64(p.count-1)
chk.AppendFloat64(e.ordinal, variance)
return nil
}
25 changes: 25 additions & 0 deletions executor/aggfuncs/func_varsamp_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package aggfuncs_test

import (
. "github.com/pingcap/check"
"github.com/pingcap/parser/ast"
"github.com/pingcap/parser/mysql"
)

func (s *testSuite) TestMergePartialResult4Varsamp(c *C) {
tests := []aggTest{
buildAggTester(ast.AggFuncVarSamp, mysql.TypeDouble, 5, 2.5, 1, 1.9821428571428572),
}
for _, test := range tests {
s.testMergePartialResult(c, test)
}
}

func (s *testSuite) TestVarsamp(c *C) {
tests := []aggTest{
buildAggTester(ast.AggFuncVarSamp, mysql.TypeDouble, 5, nil, 2.5),
}
for _, test := range tests {
s.testAggFunc(c, test)
}
}
10 changes: 8 additions & 2 deletions executor/aggregate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -409,8 +409,6 @@ func (s *testSuiteAgg) TestAggregation(c *C) {
_, err = tk.Exec("select std_samp(a) from t")
// TODO: Fix this error message.
c.Assert(errors.Cause(err).Error(), Equals, "[expression:1305]FUNCTION test.std_samp does not exist")
_, err = tk.Exec("select var_samp(a) from t")
c.Assert(errors.Cause(err).Error(), Equals, "unsupported agg function: var_samp")

// For issue #14072: wrong result when using generated column with aggregate statement
tk.MustExec("drop table if exists t1;")
Expand Down Expand Up @@ -464,6 +462,14 @@ func (s *testSuiteAgg) TestAggregation(c *C) {
tk.MustQuery("select stddev_pop(b) from t1 group by a order by a;").Check(testkit.Rows("<nil>", "0", "0"))
tk.MustQuery("select std(b) from t1 group by a order by a;").Check(testkit.Rows("<nil>", "0", "0"))
tk.MustQuery("select stddev(b) from t1 group by a order by a;").Check(testkit.Rows("<nil>", "0", "0"))

//For var_samp()/stddev_samp()
tk.MustExec("drop table if exists t1;")
tk.MustExec("CREATE TABLE t1 (id int(11),value1 float(10,2));")
tk.MustExec("INSERT INTO t1 VALUES (1,0.00),(1,1.00), (1,2.00), (2,10.00), (2,11.00), (2,12.00), (2,13.00);")
result = tk.MustQuery("select id, stddev_pop(value1), var_pop(value1), stddev_samp(value1), var_samp(value1) from t1 group by id order by id;")
result.Check(testkit.Rows("1 0.816496580927726 0.6666666666666666 1 1", "2 1.118033988749895 1.25 1.2909944487358056 1.6666666666666667"))

// For issue #19676 The result of stddev_pop(distinct xxx) is wrong
tk.MustExec("drop table if exists t1;")
tk.MustExec("CREATE TABLE t1 (id int);")
Expand Down
4 changes: 4 additions & 0 deletions expression/aggregation/agg_to_pb.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,10 @@ func AggFuncToPBExpr(sc *stmtctx.StatementContext, client kv.Client, aggFunc *Ag
tp = tipb.ExprType_JsonObjectAgg
case ast.AggFuncStddevPop:
tp = tipb.ExprType_StddevPop
case ast.AggFuncVarSamp:
tp = tipb.ExprType_VarSamp
case ast.AggFuncStddevSamp:
tp = tipb.ExprType_StddevSamp
}
if !client.IsRequestTypeSupported(kv.ReqTypeSelect, int64(tp)) {
return nil
Expand Down
16 changes: 4 additions & 12 deletions expression/aggregation/base_func.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,8 @@ func (a *baseFuncDesc) typeInfer(ctx sessionctx.Context) error {
a.typeInfer4PercentRank()
case ast.WindowFuncLead, ast.WindowFuncLag:
a.typeInfer4LeadLag(ctx)
case ast.AggFuncVarPop:
a.typeInfer4VarPop(ctx)
case ast.AggFuncStddevPop:
a.typeInfer4Std(ctx)
case ast.AggFuncVarPop, ast.AggFuncStddevPop, ast.AggFuncVarSamp, ast.AggFuncStddevSamp:
a.typeInfer4PopOrSamp(ctx)
case ast.AggFuncJsonObjectAgg:
a.typeInfer4JsonFuncs(ctx)
default:
Expand Down Expand Up @@ -255,14 +253,8 @@ func (a *baseFuncDesc) typeInfer4LeadLag(ctx sessionctx.Context) {
}
}

func (a *baseFuncDesc) typeInfer4VarPop(ctx sessionctx.Context) {
//var_pop's return value type is double
a.RetTp = types.NewFieldType(mysql.TypeDouble)
a.RetTp.Flen, a.RetTp.Decimal = mysql.MaxRealWidth, types.UnspecifiedLength
}

func (a *baseFuncDesc) typeInfer4Std(ctx sessionctx.Context) {
//std's return value type is double
func (a *baseFuncDesc) typeInfer4PopOrSamp(ctx sessionctx.Context) {
//var_pop/std/var_samp/stddev_samp's return value type is double
a.RetTp = types.NewFieldType(mysql.TypeDouble)
a.RetTp.Flen, a.RetTp.Decimal = mysql.MaxRealWidth, types.UnspecifiedLength
}
Expand Down
2 changes: 1 addition & 1 deletion planner/core/rule_aggregation_push_down.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func (a *aggregationPushDownSolver) isDecomposableWithJoin(fun *aggregation.AggF
return false
}
switch fun.Name {
case ast.AggFuncAvg, ast.AggFuncGroupConcat, ast.AggFuncVarPop, ast.AggFuncJsonObjectAgg, ast.AggFuncStddevPop:
case ast.AggFuncAvg, ast.AggFuncGroupConcat, ast.AggFuncVarPop, ast.AggFuncJsonObjectAgg, ast.AggFuncStddevPop, ast.AggFuncVarSamp, ast.AggFuncStddevSamp:
// TODO: Support avg push down.
return false
case ast.AggFuncMax, ast.AggFuncMin, ast.AggFuncFirstRow:
Expand Down

0 comments on commit 205c401

Please sign in to comment.