Skip to content

Commit

Permalink
executor: fix wrong behavior when ENUM column meet Aggregati… (#14364)
Browse files Browse the repository at this point in the history
  • Loading branch information
sre-bot authored and zz-jason committed Jan 7, 2020
1 parent 00f0015 commit b89108b
Show file tree
Hide file tree
Showing 7 changed files with 142 additions and 1 deletion.
12 changes: 12 additions & 0 deletions executor/aggfuncs/aggfunc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,18 @@ func getDataGenFunc(ft *types.FieldType) func(i int) types.Datum {
return func(i int) types.Datum { return types.NewDurationDatum(types.Duration{Duration: time.Duration(i)}) }
case mysql.TypeJSON:
return func(i int) types.Datum { return types.NewDatum(json.CreateBinary(int64(i))) }
case mysql.TypeEnum:
elems := []string{"a", "b", "c", "d", "e"}
return func(i int) types.Datum {
e, _ := types.ParseEnumValue(elems, uint64(i+1))
return types.NewMysqlEnumDatum(e)
}
case mysql.TypeSet:
elems := []string{"a", "b", "c", "d", "e"}
return func(i int) types.Datum {
e, _ := types.ParseSetValue(elems, uint64(i+1))
return types.NewMysqlSetDatum(e)
}
}
return nil
}
Expand Down
2 changes: 2 additions & 0 deletions executor/aggfuncs/aggfuncs.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ var (
_ AggFunc = (*firstRow4Float32)(nil)
_ AggFunc = (*firstRow4Float64)(nil)
_ AggFunc = (*firstRow4JSON)(nil)
_ AggFunc = (*firstRow4Enum)(nil)
_ AggFunc = (*firstRow4Set)(nil)

// All the AggFunc implementations for "MAX"/"MIN" are listed here.
_ AggFunc = (*maxMin4Int)(nil)
Expand Down
7 changes: 7 additions & 0 deletions executor/aggfuncs/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,13 @@ func buildFirstRow(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc {
switch aggFuncDesc.Mode {
case aggregation.DedupMode:
default:
switch fieldType.Tp {
case mysql.TypeEnum:
return &firstRow4Enum{base}
case mysql.TypeSet:
return &firstRow4Set{base}
}

switch evalType {
case types.ETInt:
return &firstRow4Int{base}
Expand Down
104 changes: 104 additions & 0 deletions executor/aggfuncs/func_first_row.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,18 @@ type partialResult4FirstRowJSON struct {
val json.BinaryJSON
}

type partialResult4FirstRowEnum struct {
basePartialResult4FirstRow

val types.Enum
}

type partialResult4FirstRowSet struct {
basePartialResult4FirstRow

val types.Set
}

type firstRow4Int struct {
baseAggFunc
}
Expand Down Expand Up @@ -451,3 +463,95 @@ func (*firstRow4Decimal) MergePartialResult(sctx sessionctx.Context, src Partial
}
return nil
}

type firstRow4Enum struct {
baseAggFunc
}

func (e *firstRow4Enum) AllocPartialResult() PartialResult {
return PartialResult(new(partialResult4FirstRowEnum))
}

func (e *firstRow4Enum) ResetPartialResult(pr PartialResult) {
p := (*partialResult4FirstRowEnum)(pr)
p.isNull, p.gotFirstRow = false, false
}

func (e *firstRow4Enum) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) error {
p := (*partialResult4FirstRowEnum)(pr)
if p.gotFirstRow {
return nil
}
for _, row := range rowsInGroup {
d, err := e.args[0].Eval(row)
if err != nil {
return err
}
p.gotFirstRow, p.isNull, p.val = true, d.IsNull(), d.GetMysqlEnum()
break
}
return nil
}

func (*firstRow4Enum) MergePartialResult(sctx sessionctx.Context, src PartialResult, dst PartialResult) error {
p1, p2 := (*partialResult4FirstRowEnum)(src), (*partialResult4FirstRowEnum)(dst)
if !p2.gotFirstRow {
*p2 = *p1
}
return nil
}
func (e *firstRow4Enum) AppendFinalResult2Chunk(sctx sessionctx.Context, pr PartialResult, chk *chunk.Chunk) error {
p := (*partialResult4FirstRowEnum)(pr)
if p.isNull || !p.gotFirstRow {
chk.AppendNull(e.ordinal)
return nil
}
chk.AppendEnum(e.ordinal, p.val)
return nil
}

type firstRow4Set struct {
baseAggFunc
}

func (e *firstRow4Set) AllocPartialResult() PartialResult {
return PartialResult(new(partialResult4FirstRowSet))
}

func (e *firstRow4Set) ResetPartialResult(pr PartialResult) {
p := (*partialResult4FirstRowSet)(pr)
p.isNull, p.gotFirstRow = false, false
}

func (e *firstRow4Set) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) error {
p := (*partialResult4FirstRowSet)(pr)
if p.gotFirstRow {
return nil
}
for _, row := range rowsInGroup {
d, err := e.args[0].Eval(row)
if err != nil {
return err
}
p.gotFirstRow, p.isNull, p.val = true, d.IsNull(), d.GetMysqlSet()
break
}
return nil
}

func (*firstRow4Set) MergePartialResult(sctx sessionctx.Context, src PartialResult, dst PartialResult) error {
p1, p2 := (*partialResult4FirstRowSet)(src), (*partialResult4FirstRowSet)(dst)
if !p2.gotFirstRow {
*p2 = *p1
}
return nil
}
func (e *firstRow4Set) AppendFinalResult2Chunk(sctx sessionctx.Context, pr PartialResult, chk *chunk.Chunk) error {
p := (*partialResult4FirstRowSet)(pr)
if p.isNull || !p.gotFirstRow {
chk.AppendNull(e.ordinal)
return nil
}
chk.AppendSet(e.ordinal, p.val)
return nil
}
9 changes: 9 additions & 0 deletions executor/aggfuncs/func_first_row_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,13 @@ import (
)

func (s *testSuite) TestMergePartialResult4FirstRow(c *C) {
elems := []string{"a", "b", "c", "d", "e"}
enumA, _ := types.ParseEnumName(elems, "a")
enumC, _ := types.ParseEnumName(elems, "c")

setA, _ := types.ParseSetName(elems, "a")
setAB, _ := types.ParseSetName(elems, "a,b")

tests := []aggTest{
buildAggTester(ast.AggFuncFirstRow, mysql.TypeLonglong, 5, 0, 2, 0),
buildAggTester(ast.AggFuncFirstRow, mysql.TypeFloat, 5, 0.0, 2.0, 0.0),
Expand All @@ -33,6 +40,8 @@ func (s *testSuite) TestMergePartialResult4FirstRow(c *C) {
buildAggTester(ast.AggFuncFirstRow, mysql.TypeDate, 5, types.TimeFromDays(365), types.TimeFromDays(367), types.TimeFromDays(365)),
buildAggTester(ast.AggFuncFirstRow, mysql.TypeDuration, 5, types.Duration{Duration: time.Duration(0)}, types.Duration{Duration: time.Duration(2)}, types.Duration{Duration: time.Duration(0)}),
buildAggTester(ast.AggFuncFirstRow, mysql.TypeJSON, 5, json.CreateBinary(int64(0)), json.CreateBinary(int64(2)), json.CreateBinary(int64(0))),
buildAggTester(ast.AggFuncFirstRow, mysql.TypeEnum, 5, enumA, enumC, enumA),
buildAggTester(ast.AggFuncFirstRow, mysql.TypeSet, 5, setA, setAB, setA),
}
for _, test := range tests {
s.testMergePartialResult(c, test)
Expand Down
3 changes: 2 additions & 1 deletion expression/aggregation/base_func.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,8 @@ func (a *baseFuncDesc) typeInfer4MaxMin(ctx sessionctx.Context) {
a.RetTp = a.Args[0].GetType().Clone()
a.RetTp.Flag &^= mysql.NotNullFlag
}
if a.RetTp.Tp == mysql.TypeEnum || a.RetTp.Tp == mysql.TypeSet {
// TODO: fix other aggFuncs for TypeEnum & TypeSet
if (a.RetTp.Tp == mysql.TypeEnum || a.RetTp.Tp == mysql.TypeSet) && a.Name != ast.AggFuncFirstRow {
a.RetTp = &types.FieldType{Tp: mysql.TypeString, Flen: mysql.MaxFieldCharLength}
}
}
Expand Down
6 changes: 6 additions & 0 deletions types/datum.go
Original file line number Diff line number Diff line change
Expand Up @@ -1716,6 +1716,12 @@ func NewMysqlEnumDatum(e Enum) (d Datum) {
return d
}

// NewMysqlSetDatum creates a new MysqlSet Datum for a Enum value.
func NewMysqlSetDatum(e Set) (d Datum) {
d.SetMysqlSet(e)
return d
}

// MakeDatums creates datum slice from interfaces.
func MakeDatums(args ...interface{}) []Datum {
datums := make([]Datum, len(args))
Expand Down

0 comments on commit b89108b

Please sign in to comment.