From b29d52ba9cdb1f5e4a2abfb632c9066ec982b465 Mon Sep 17 00:00:00 2001 From: HuaiyuXu <391585975@qq.com> Date: Tue, 17 Jul 2018 20:22:13 +0800 Subject: [PATCH] executor: support group_concat under new aggregation evaluation framework (#7032) --- executor/aggfuncs/aggfuncs.go | 3 + executor/aggfuncs/builder.go | 32 +++++- executor/aggfuncs/func_group_concat.go | 144 +++++++++++++++++++++++++ executor/aggregate_test.go | 3 + 4 files changed, 181 insertions(+), 1 deletion(-) create mode 100644 executor/aggfuncs/func_group_concat.go diff --git a/executor/aggfuncs/aggfuncs.go b/executor/aggfuncs/aggfuncs.go index ac9b2ff238ca9..7dd7a3960e584 100644 --- a/executor/aggfuncs/aggfuncs.go +++ b/executor/aggfuncs/aggfuncs.go @@ -55,6 +55,9 @@ var ( _ AggFunc = (*avgOriginal4DistinctFloat64)(nil) // All the AggFunc implementations for "GROUP_CONCAT" are listed here. + _ AggFunc = (*groupConcatDistinct)(nil) + _ AggFunc = (*groupConcat)(nil) + // All the AggFunc implementations for "BIT_OR" are listed here. _ AggFunc = (*bitOrUint64)(nil) diff --git a/executor/aggfuncs/builder.go b/executor/aggfuncs/builder.go index 24579a96fcf7d..e8ded00d0d993 100644 --- a/executor/aggfuncs/builder.go +++ b/executor/aggfuncs/builder.go @@ -14,7 +14,10 @@ package aggfuncs import ( + "fmt" + "github.com/juju/errors" "github.com/pingcap/tidb/ast" + "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/expression/aggregation" "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/types" @@ -187,7 +190,34 @@ func buildMaxMin(aggFuncDesc *aggregation.AggFuncDesc, ordinal int, isMax bool) // buildGroupConcat builds the AggFunc implementation for function "GROUP_CONCAT". func buildGroupConcat(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc { - return nil + // TODO: There might be different kind of types of the args, + // we should add CastAsString upon every arg after cast can be pushed down to coprocessor. + // And this check can be removed at that time. + for _, arg := range aggFuncDesc.Args { + if arg.GetType().EvalType() != types.ETString { + return nil + } + } + switch aggFuncDesc.Mode { + case aggregation.DedupMode: + return nil + default: + base := baseAggFunc{ + args: aggFuncDesc.Args[:len(aggFuncDesc.Args)-1], + ordinal: ordinal, + } + // The last arg is promised to be a not-null string constant, so the error can be ignored. + c, _ := aggFuncDesc.Args[len(aggFuncDesc.Args)-1].(*expression.Constant) + sep, _, err := c.EvalString(nil, nil) + // This err should never happen. + if err != nil { + panic(fmt.Sprintf("Error happened when buildGroupConcat: %s", errors.Trace(err).Error())) + } + if aggFuncDesc.HasDistinct { + return &groupConcatDistinct{baseGroupConcat4String{baseAggFunc: base, sep: sep}} + } + return &groupConcat{baseGroupConcat4String{baseAggFunc: base, sep: sep}} + } } // buildBitOr builds the AggFunc implementation for function "BIT_OR". diff --git a/executor/aggfuncs/func_group_concat.go b/executor/aggfuncs/func_group_concat.go new file mode 100644 index 0000000000000..79c361ddb34ea --- /dev/null +++ b/executor/aggfuncs/func_group_concat.go @@ -0,0 +1,144 @@ +// Copyright 2018 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package aggfuncs + +import ( + "bytes" + + "github.com/juju/errors" + "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/util/chunk" +) + +type baseGroupConcat4String struct { + baseAggFunc + + sep string +} + +func (e *baseGroupConcat4String) AppendFinalResult2Chunk(sctx sessionctx.Context, pr PartialResult, chk *chunk.Chunk) error { + p := (*partialResult4GroupConcat)(pr) + if p.buffer == nil { + chk.AppendNull(e.ordinal) + return nil + } + chk.AppendString(e.ordinal, p.buffer.String()) + return nil +} + +type basePartialResult4GroupConcat struct { + buffer *bytes.Buffer +} + +type partialResult4GroupConcat struct { + basePartialResult4GroupConcat +} + +type groupConcat struct { + baseGroupConcat4String +} + +func (e *groupConcat) AllocPartialResult() PartialResult { + return PartialResult(new(partialResult4GroupConcat)) +} + +func (e *groupConcat) ResetPartialResult(pr PartialResult) { + p := (*partialResult4GroupConcat)(pr) + p.buffer = nil +} + +func (e *groupConcat) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) (err error) { + p := (*partialResult4GroupConcat)(pr) + v, isNull, isWriteSep := "", false, false + for _, row := range rowsInGroup { + isWriteSep = false + for _, arg := range e.args { + v, isNull, err = arg.EvalString(sctx, row) + if err != nil { + return errors.Trace(err) + } + if isNull { + continue + } + isWriteSep = true + if p.buffer == nil { + p.buffer = &bytes.Buffer{} + } + p.buffer.WriteString(v) + } + if isWriteSep { + p.buffer.WriteString(e.sep) + } + } + p.buffer.Truncate(p.buffer.Len() - len(e.sep)) + // TODO: if total length is greater than global var group_concat_max_len, truncate it. + // issue: #7034 + return nil +} + +type partialResult4GroupConcatDistinct struct { + basePartialResult4GroupConcat + valsBuf *bytes.Buffer + valSet stringSet +} + +type groupConcatDistinct struct { + baseGroupConcat4String +} + +func (e *groupConcatDistinct) AllocPartialResult() PartialResult { + p := new(partialResult4GroupConcatDistinct) + p.valsBuf = &bytes.Buffer{} + p.valSet = newStringSet() + return PartialResult(p) +} + +func (e *groupConcatDistinct) ResetPartialResult(pr PartialResult) { + p := (*partialResult4GroupConcatDistinct)(pr) + p.buffer, p.valSet = nil, newStringSet() +} + +func (e *groupConcatDistinct) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) (err error) { + p := (*partialResult4GroupConcatDistinct)(pr) + v, isNull := "", false + for _, row := range rowsInGroup { + p.valsBuf.Reset() + for _, arg := range e.args { + v, isNull, err = arg.EvalString(sctx, row) + if err != nil { + return errors.Trace(err) + } + if isNull { + continue + } + p.valsBuf.WriteString(v) + } + joinedVals := p.valsBuf.String() + if p.valSet.exist(joinedVals) { + continue + } + p.valSet.insert(joinedVals) + // write separator + if p.buffer == nil { + p.buffer = &bytes.Buffer{} + } else { + p.buffer.WriteString(e.sep) + } + // write values + p.buffer.WriteString(joinedVals) + } + // TODO: if total length is greater than global var group_concat_max_len, truncate it. + // issue: #7034 + return nil +} diff --git a/executor/aggregate_test.go b/executor/aggregate_test.go index 107b626c9824f..d5973745c55d3 100644 --- a/executor/aggregate_test.go +++ b/executor/aggregate_test.go @@ -375,6 +375,9 @@ func (s *testSuite) TestGroupConcatAggr(c *C) { result = tk.MustQuery("select id, group_concat(name SEPARATOR '') from test group by id order by id") result.Check(testkit.Rows("1 102030", "2 20", "3 200500")) + + result = tk.MustQuery("select id, group_concat(name SEPARATOR '123') from test group by id order by id") + result.Check(testkit.Rows("1 101232012330", "2 20", "3 200123500")) } func (s *testSuite) TestSelectDistinct(c *C) {