diff --git a/executor/aggfuncs/aggfunc_test.go b/executor/aggfuncs/aggfunc_test.go index 7f9d582d5bf81..c47578975654c 100644 --- a/executor/aggfuncs/aggfunc_test.go +++ b/executor/aggfuncs/aggfunc_test.go @@ -168,6 +168,18 @@ func getDataGenFunc(ft *types.FieldType) func(i int) types.Datum { return func(i int) types.Datum { return types.NewDurationDatum(types.Duration{Duration: time.Duration(i)}) } case mysql.TypeJSON: return func(i int) types.Datum { return types.NewDatum(json.CreateBinary(int64(i))) } + case mysql.TypeEnum: + elems := []string{"a", "b", "c", "d", "e"} + return func(i int) types.Datum { + e, _ := types.ParseEnumValue(elems, uint64(i+1)) + return types.NewMysqlEnumDatum(e) + } + case mysql.TypeSet: + elems := []string{"a", "b", "c", "d", "e"} + return func(i int) types.Datum { + e, _ := types.ParseSetValue(elems, uint64(i+1)) + return types.NewMysqlSetDatum(e) + } } return nil } diff --git a/executor/aggfuncs/aggfuncs.go b/executor/aggfuncs/aggfuncs.go index 9bc105d2c1d24..e5db2b389d345 100644 --- a/executor/aggfuncs/aggfuncs.go +++ b/executor/aggfuncs/aggfuncs.go @@ -43,6 +43,8 @@ var ( _ AggFunc = (*firstRow4Float32)(nil) _ AggFunc = (*firstRow4Float64)(nil) _ AggFunc = (*firstRow4JSON)(nil) + _ AggFunc = (*firstRow4Enum)(nil) + _ AggFunc = (*firstRow4Set)(nil) // All the AggFunc implementations for "MAX"/"MIN" are listed here. _ AggFunc = (*maxMin4Int)(nil) diff --git a/executor/aggfuncs/builder.go b/executor/aggfuncs/builder.go index 02e1570b4263d..67502c9ea9c9e 100644 --- a/executor/aggfuncs/builder.go +++ b/executor/aggfuncs/builder.go @@ -211,6 +211,13 @@ func buildFirstRow(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc { switch aggFuncDesc.Mode { case aggregation.DedupMode: default: + switch fieldType.Tp { + case mysql.TypeEnum: + return &firstRow4Enum{base} + case mysql.TypeSet: + return &firstRow4Set{base} + } + switch evalType { case types.ETInt: return &firstRow4Int{base} diff --git a/executor/aggfuncs/func_first_row.go b/executor/aggfuncs/func_first_row.go index b87121fb51ae5..b37c41ec128a9 100644 --- a/executor/aggfuncs/func_first_row.go +++ b/executor/aggfuncs/func_first_row.go @@ -77,6 +77,18 @@ type partialResult4FirstRowJSON struct { val json.BinaryJSON } +type partialResult4FirstRowEnum struct { + basePartialResult4FirstRow + + val types.Enum +} + +type partialResult4FirstRowSet struct { + basePartialResult4FirstRow + + val types.Set +} + type firstRow4Int struct { baseAggFunc } @@ -451,3 +463,95 @@ func (*firstRow4Decimal) MergePartialResult(sctx sessionctx.Context, src Partial } return nil } + +type firstRow4Enum struct { + baseAggFunc +} + +func (e *firstRow4Enum) AllocPartialResult() PartialResult { + return PartialResult(new(partialResult4FirstRowEnum)) +} + +func (e *firstRow4Enum) ResetPartialResult(pr PartialResult) { + p := (*partialResult4FirstRowEnum)(pr) + p.isNull, p.gotFirstRow = false, false +} + +func (e *firstRow4Enum) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) error { + p := (*partialResult4FirstRowEnum)(pr) + if p.gotFirstRow { + return nil + } + for _, row := range rowsInGroup { + d, err := e.args[0].Eval(row) + if err != nil { + return err + } + p.gotFirstRow, p.isNull, p.val = true, d.IsNull(), d.GetMysqlEnum() + break + } + return nil +} + +func (*firstRow4Enum) MergePartialResult(sctx sessionctx.Context, src PartialResult, dst PartialResult) error { + p1, p2 := (*partialResult4FirstRowEnum)(src), (*partialResult4FirstRowEnum)(dst) + if !p2.gotFirstRow { + *p2 = *p1 + } + return nil +} +func (e *firstRow4Enum) AppendFinalResult2Chunk(sctx sessionctx.Context, pr PartialResult, chk *chunk.Chunk) error { + p := (*partialResult4FirstRowEnum)(pr) + if p.isNull || !p.gotFirstRow { + chk.AppendNull(e.ordinal) + return nil + } + chk.AppendEnum(e.ordinal, p.val) + return nil +} + +type firstRow4Set struct { + baseAggFunc +} + +func (e *firstRow4Set) AllocPartialResult() PartialResult { + return PartialResult(new(partialResult4FirstRowSet)) +} + +func (e *firstRow4Set) ResetPartialResult(pr PartialResult) { + p := (*partialResult4FirstRowSet)(pr) + p.isNull, p.gotFirstRow = false, false +} + +func (e *firstRow4Set) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) error { + p := (*partialResult4FirstRowSet)(pr) + if p.gotFirstRow { + return nil + } + for _, row := range rowsInGroup { + d, err := e.args[0].Eval(row) + if err != nil { + return err + } + p.gotFirstRow, p.isNull, p.val = true, d.IsNull(), d.GetMysqlSet() + break + } + return nil +} + +func (*firstRow4Set) MergePartialResult(sctx sessionctx.Context, src PartialResult, dst PartialResult) error { + p1, p2 := (*partialResult4FirstRowSet)(src), (*partialResult4FirstRowSet)(dst) + if !p2.gotFirstRow { + *p2 = *p1 + } + return nil +} +func (e *firstRow4Set) AppendFinalResult2Chunk(sctx sessionctx.Context, pr PartialResult, chk *chunk.Chunk) error { + p := (*partialResult4FirstRowSet)(pr) + if p.isNull || !p.gotFirstRow { + chk.AppendNull(e.ordinal) + return nil + } + chk.AppendSet(e.ordinal, p.val) + return nil +} diff --git a/executor/aggfuncs/func_first_row_test.go b/executor/aggfuncs/func_first_row_test.go index daf7f63d8fb75..357f651beb26c 100644 --- a/executor/aggfuncs/func_first_row_test.go +++ b/executor/aggfuncs/func_first_row_test.go @@ -24,6 +24,13 @@ import ( ) func (s *testSuite) TestMergePartialResult4FirstRow(c *C) { + elems := []string{"a", "b", "c", "d", "e"} + enumA, _ := types.ParseEnumName(elems, "a") + enumC, _ := types.ParseEnumName(elems, "c") + + setA, _ := types.ParseSetName(elems, "a") + setAB, _ := types.ParseSetName(elems, "a,b") + tests := []aggTest{ buildAggTester(ast.AggFuncFirstRow, mysql.TypeLonglong, 5, 0, 2, 0), buildAggTester(ast.AggFuncFirstRow, mysql.TypeFloat, 5, 0.0, 2.0, 0.0), @@ -33,6 +40,8 @@ func (s *testSuite) TestMergePartialResult4FirstRow(c *C) { buildAggTester(ast.AggFuncFirstRow, mysql.TypeDate, 5, types.TimeFromDays(365), types.TimeFromDays(367), types.TimeFromDays(365)), buildAggTester(ast.AggFuncFirstRow, mysql.TypeDuration, 5, types.Duration{Duration: time.Duration(0)}, types.Duration{Duration: time.Duration(2)}, types.Duration{Duration: time.Duration(0)}), buildAggTester(ast.AggFuncFirstRow, mysql.TypeJSON, 5, json.CreateBinary(int64(0)), json.CreateBinary(int64(2)), json.CreateBinary(int64(0))), + buildAggTester(ast.AggFuncFirstRow, mysql.TypeEnum, 5, enumA, enumC, enumA), + buildAggTester(ast.AggFuncFirstRow, mysql.TypeSet, 5, setA, setAB, setA), } for _, test := range tests { s.testMergePartialResult(c, test) diff --git a/expression/aggregation/base_func.go b/expression/aggregation/base_func.go index 6706eeea99d9d..362ae62b31e9e 100644 --- a/expression/aggregation/base_func.go +++ b/expression/aggregation/base_func.go @@ -188,7 +188,8 @@ func (a *baseFuncDesc) typeInfer4MaxMin(ctx sessionctx.Context) { a.RetTp = a.Args[0].GetType().Clone() a.RetTp.Flag &^= mysql.NotNullFlag } - if a.RetTp.Tp == mysql.TypeEnum || a.RetTp.Tp == mysql.TypeSet { + // TODO: fix other aggFuncs for TypeEnum & TypeSet + if (a.RetTp.Tp == mysql.TypeEnum || a.RetTp.Tp == mysql.TypeSet) && a.Name != ast.AggFuncFirstRow { a.RetTp = &types.FieldType{Tp: mysql.TypeString, Flen: mysql.MaxFieldCharLength} } } diff --git a/types/datum.go b/types/datum.go index d6421cce35730..9df170ddc3e48 100644 --- a/types/datum.go +++ b/types/datum.go @@ -1716,6 +1716,12 @@ func NewMysqlEnumDatum(e Enum) (d Datum) { return d } +// NewMysqlSetDatum creates a new MysqlSet Datum for a Enum value. +func NewMysqlSetDatum(e Set) (d Datum) { + d.SetMysqlSet(e) + return d +} + // MakeDatums creates datum slice from interfaces. func MakeDatums(args ...interface{}) []Datum { datums := make([]Datum, len(args))