diff --git a/executor/aggfuncs/aggfuncs.go b/executor/aggfuncs/aggfuncs.go index 0e34e8aa65dbd..535c678178c09 100644 --- a/executor/aggfuncs/aggfuncs.go +++ b/executor/aggfuncs/aggfuncs.go @@ -24,6 +24,16 @@ import ( // All the AggFunc implementations are listed here for navigation. var ( // All the AggFunc implementations for "COUNT" are listed here. + _ AggFunc = (*countPartial)(nil) + _ AggFunc = (*countOriginal4Int)(nil) + _ AggFunc = (*countOriginal4Real)(nil) + _ AggFunc = (*countOriginal4Decimal)(nil) + _ AggFunc = (*countOriginal4Time)(nil) + _ AggFunc = (*countOriginal4Duration)(nil) + _ AggFunc = (*countOriginal4JSON)(nil) + _ AggFunc = (*countOriginal4String)(nil) + _ AggFunc = (*countOriginalWithDistinct)(nil) + // All the AggFunc implementations for "SUM" are listed here. // All the AggFunc implementations for "FIRSTROW" are listed here. // All the AggFunc implementations for "MAX"/"MIN" are listed here. diff --git a/executor/aggfuncs/builder.go b/executor/aggfuncs/builder.go index 0019a623655ed..427e8589a0671 100644 --- a/executor/aggfuncs/builder.go +++ b/executor/aggfuncs/builder.go @@ -50,6 +50,45 @@ func Build(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc { // buildCount builds the AggFunc implementation for function "COUNT". func buildCount(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc { + // If mode is DedupMode, we return nil for not implemented. + if aggFuncDesc.Mode == aggregation.DedupMode { + return nil // not implemented yet. + } + + base := baseAggFunc{ + args: aggFuncDesc.Args, + ordinal: ordinal, + } + + // If HasDistinct and mode is CompleteMode or Partial1Mode, we should + // use countOriginalWithDistinct. + if aggFuncDesc.HasDistinct && + (aggFuncDesc.Mode == aggregation.CompleteMode || aggFuncDesc.Mode == aggregation.Partial1Mode) { + return &countOriginalWithDistinct{baseCount{base}} + } + + switch aggFuncDesc.Mode { + case aggregation.CompleteMode, aggregation.Partial1Mode: + switch aggFuncDesc.Args[0].GetType().EvalType() { + case types.ETInt: + return &countOriginal4Int{baseCount{base}} + case types.ETReal: + return &countOriginal4Real{baseCount{base}} + case types.ETDecimal: + return &countOriginal4Decimal{baseCount{base}} + case types.ETTimestamp, types.ETDatetime: + return &countOriginal4Time{baseCount{base}} + case types.ETDuration: + return &countOriginal4Duration{baseCount{base}} + case types.ETJson: + return &countOriginal4JSON{baseCount{base}} + case types.ETString: + return &countOriginal4String{baseCount{base}} + } + case aggregation.Partial2Mode, aggregation.FinalMode: + return &countPartial{baseCount{base}} + } + return nil } diff --git a/executor/aggfuncs/func_count.go b/executor/aggfuncs/func_count.go new file mode 100644 index 0000000000000..0948eeae1bfde --- /dev/null +++ b/executor/aggfuncs/func_count.go @@ -0,0 +1,382 @@ +package aggfuncs + +import ( + "encoding/binary" + "unsafe" + + "github.com/juju/errors" + "github.com/pingcap/tidb/expression" + "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/types" + "github.com/pingcap/tidb/types/json" + "github.com/pingcap/tidb/util/chunk" +) + +type baseCount struct { + baseAggFunc +} + +type partialResult4Count = int64 + +func (e *baseCount) AllocPartialResult() PartialResult { + return PartialResult(new(partialResult4Count)) +} + +func (e *baseCount) ResetPartialResult(pr PartialResult) { + p := (*partialResult4Count)(pr) + *p = 0 +} + +func (e *baseCount) AppendFinalResult2Chunk(sctx sessionctx.Context, pr PartialResult, chk *chunk.Chunk) error { + p 
:= (*partialResult4Count)(pr) + chk.AppendInt64(e.ordinal, *p) + return nil +} + +type countOriginal4Int struct { + baseCount +} + +func (e *countOriginal4Int) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) error { + p := (*partialResult4Count)(pr) + + for _, row := range rowsInGroup { + _, isNull, err := e.args[0].EvalInt(sctx, row) + if err != nil { + return errors.Trace(err) + } + if isNull { + continue + } + + *p++ + } + + return nil +} + +type countOriginal4Real struct { + baseCount +} + +func (e *countOriginal4Real) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) error { + p := (*partialResult4Count)(pr) + + for _, row := range rowsInGroup { + _, isNull, err := e.args[0].EvalReal(sctx, row) + if err != nil { + return errors.Trace(err) + } + if isNull { + continue + } + + *p++ + } + + return nil +} + +type countOriginal4Decimal struct { + baseCount +} + +func (e *countOriginal4Decimal) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) error { + p := (*partialResult4Count)(pr) + + for _, row := range rowsInGroup { + _, isNull, err := e.args[0].EvalDecimal(sctx, row) + if err != nil { + return errors.Trace(err) + } + if isNull { + continue + } + + *p++ + } + + return nil +} + +type countOriginal4Time struct { + baseCount +} + +func (e *countOriginal4Time) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) error { + p := (*partialResult4Count)(pr) + + for _, row := range rowsInGroup { + _, isNull, err := e.args[0].EvalTime(sctx, row) + if err != nil { + return errors.Trace(err) + } + if isNull { + continue + } + + *p++ + } + + return nil +} + +type countOriginal4Duration struct { + baseCount +} + +func (e *countOriginal4Duration) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) error { + p := (*partialResult4Count)(pr) + + for _, row := range rowsInGroup { + _, isNull, err := e.args[0].EvalDuration(sctx, row) + if err != nil { + return errors.Trace(err) + } + if isNull { + continue + } + + *p++ + } + + return nil +} + +type countOriginal4JSON struct { + baseCount +} + +func (e *countOriginal4JSON) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) error { + p := (*partialResult4Count)(pr) + + for _, row := range rowsInGroup { + _, isNull, err := e.args[0].EvalJSON(sctx, row) + if err != nil { + return errors.Trace(err) + } + if isNull { + continue + } + + *p++ + } + + return nil +} + +type countOriginal4String struct { + baseCount +} + +func (e *countOriginal4String) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) error { + p := (*partialResult4Count)(pr) + + for _, row := range rowsInGroup { + _, isNull, err := e.args[0].EvalString(sctx, row) + if err != nil { + return errors.Trace(err) + } + if isNull { + continue + } + + *p++ + } + + return nil +} + +type countPartial struct { + baseCount +} + +func (e *countPartial) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) error { + p := (*partialResult4Count)(pr) + for _, row := range rowsInGroup { + input, isNull, err := e.args[0].EvalInt(sctx, row) + if err != nil { + return errors.Trace(err) + } + if isNull { + continue + } + + *p += input + } + return nil +} + +type countOriginalWithDistinct struct { + baseCount +} + +type partialResult4CountWithDistinct struct { + count int64 + + valSet stringSet +} + +func (e 
*countOriginalWithDistinct) AllocPartialResult() PartialResult { + return PartialResult(&partialResult4CountWithDistinct{ + count: 0, + valSet: newStringSet(), + }) +} + +func (e *countOriginalWithDistinct) ResetPartialResult(pr PartialResult) { + p := (*partialResult4CountWithDistinct)(pr) + p.count = 0 + p.valSet = newStringSet() +} + +func (e *countOriginalWithDistinct) AppendFinalResult2Chunk(sctx sessionctx.Context, pr PartialResult, chk *chunk.Chunk) error { + p := (*partialResult4CountWithDistinct)(pr) + chk.AppendInt64(e.ordinal, p.count) + return nil +} + +func (e *countOriginalWithDistinct) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) (err error) { + p := (*partialResult4CountWithDistinct)(pr) + + hasNull, isNull := false, false + encodedBytes := make([]byte, 0) + // Decimal struct is the biggest type we will use. + buf := make([]byte, types.MyDecimalStructSize) + + for _, row := range rowsInGroup { + hasNull, isNull = false, false + encodedBytes = encodedBytes[:0] + + for i := 0; i < len(e.args) && !hasNull; i++ { + encodedBytes, isNull, err = e.evalAndEncode(sctx, e.args[i], row, buf, encodedBytes) + if err != nil { + return + } + if isNull { + hasNull = true + break + } + } + encodedString := string(encodedBytes) + if hasNull || p.valSet.exist(encodedString) { + continue + } + p.valSet.insert(encodedString) + p.count++ + } + + return nil +} + +// evalAndEncode eval one row with an expression and encode value to bytes. +func (e *countOriginalWithDistinct) evalAndEncode( + sctx sessionctx.Context, arg expression.Expression, + row chunk.Row, buf, encodedBytes []byte, +) ([]byte, bool, error) { + switch tp := arg.GetType().EvalType(); tp { + case types.ETInt: + val, isNull, err := arg.EvalInt(sctx, row) + if err != nil || isNull { + return encodedBytes, isNull, errors.Trace(err) + } + encodedBytes = appendInt64(encodedBytes, buf, val) + case types.ETReal: + val, isNull, err := arg.EvalReal(sctx, row) + if err != nil || isNull { + return encodedBytes, isNull, errors.Trace(err) + } + encodedBytes = appendFloat64(encodedBytes, buf, val) + case types.ETDecimal: + val, isNull, err := arg.EvalDecimal(sctx, row) + if err != nil || isNull { + return encodedBytes, isNull, errors.Trace(err) + } + encodedBytes = appendDecimal(encodedBytes, buf, val) + case types.ETTimestamp, types.ETDatetime: + val, isNull, err := arg.EvalTime(sctx, row) + if err != nil || isNull { + return encodedBytes, isNull, errors.Trace(err) + } + encodedBytes = appendTime(encodedBytes, buf, val) + case types.ETDuration: + val, isNull, err := arg.EvalDuration(sctx, row) + if err != nil || isNull { + return encodedBytes, isNull, errors.Trace(err) + } + encodedBytes = appendDuration(encodedBytes, buf, val) + case types.ETJson: + val, isNull, err := arg.EvalJSON(sctx, row) + if err != nil || isNull { + return encodedBytes, isNull, errors.Trace(err) + } + encodedBytes = appendJSON(encodedBytes, buf, val) + case types.ETString: + val, isNull, err := arg.EvalString(sctx, row) + if err != nil || isNull { + return encodedBytes, isNull, errors.Trace(err) + } + encodedBytes = appendString(encodedBytes, buf, val) + default: + return nil, false, errors.Errorf("unsupported column type for encode %d", tp) + } + return encodedBytes, false, nil +} + +func appendInt64(encodedBytes, buf []byte, val int64) []byte { + *(*int64)(unsafe.Pointer(&buf[0])) = val + buf = buf[:8] + encodedBytes = append(encodedBytes, buf...) 
+ return encodedBytes +} + +func appendFloat64(encodedBytes, buf []byte, val float64) []byte { + *(*float64)(unsafe.Pointer(&buf[0])) = val + buf = buf[:8] + encodedBytes = append(encodedBytes, buf...) + return encodedBytes +} + +func appendDecimal(encodedBytes, buf []byte, val *types.MyDecimal) []byte { + *(*types.MyDecimal)(unsafe.Pointer(&buf[0])) = *val + buf = buf[:types.MyDecimalStructSize] + encodedBytes = append(encodedBytes, buf...) + return encodedBytes +} + +func writeTime(buf []byte, t types.Time) { + binary.BigEndian.PutUint16(buf, uint16(t.Time.Year())) + buf[2] = uint8(t.Time.Month()) + buf[3] = uint8(t.Time.Day()) + buf[4] = uint8(t.Time.Hour()) + buf[5] = uint8(t.Time.Minute()) + buf[6] = uint8(t.Time.Second()) + binary.BigEndian.PutUint32(buf[8:], uint32(t.Time.Microsecond())) + buf[12] = t.Type + buf[13] = uint8(t.Fsp) +} + +func appendTime(encodedBytes, buf []byte, val types.Time) []byte { + writeTime(buf, val) + buf = buf[:16] + encodedBytes = append(encodedBytes, buf...) + return encodedBytes +} + +func appendDuration(encodedBytes, buf []byte, val types.Duration) []byte { + *(*types.Duration)(unsafe.Pointer(&buf[0])) = val + buf = buf[:16] + encodedBytes = append(encodedBytes, buf...) + return encodedBytes +} + +func appendJSON(encodedBytes, _ []byte, val json.BinaryJSON) []byte { + encodedBytes = append(encodedBytes, val.TypeCode) + encodedBytes = append(encodedBytes, val.Value...) + return encodedBytes +} + +func appendString(encodedBytes, _ []byte, val string) []byte { + encodedBytes = append(encodedBytes, val...) + return encodedBytes +} diff --git a/executor/aggfuncs/sets.go b/executor/aggfuncs/sets.go index 4c37db9922454..642b1e5dc1289 100644 --- a/executor/aggfuncs/sets.go +++ b/executor/aggfuncs/sets.go @@ -19,6 +19,7 @@ import ( type decimalSet map[types.MyDecimal]struct{} type float64Set map[float64]struct{} +type stringSet map[string]struct{} func newDecimalSet() decimalSet { return make(map[types.MyDecimal]struct{}) @@ -45,3 +46,16 @@ func (s float64Set) exist(val float64) bool { func (s float64Set) insert(val float64) { s[val] = struct{}{} } + +func newStringSet() stringSet { + return make(map[string]struct{}) +} + +func (s stringSet) exist(val string) bool { + _, ok := s[val] + return ok +} + +func (s stringSet) insert(val string) { + s[val] = struct{}{} +} diff --git a/executor/aggregate_test.go b/executor/aggregate_test.go index 73d13dcce44cd..18ecc92c0da19 100644 --- a/executor/aggregate_test.go +++ b/executor/aggregate_test.go @@ -262,6 +262,11 @@ func (s *testSuite) TestAggregation(c *C) { tk.MustExec("insert into t values(1, 2, 3), (1, 2, 4)") result = tk.MustQuery("select count(distinct c), count(distinct a,b) from t") result.Check(testkit.Rows("2 1")) + tk.MustExec("drop table if exists t") + tk.MustExec("create table t (a float)") + tk.MustExec("insert into t values(966.36), (363.97), (569.99), (453.33), (376.45), (321.93), (12.12), (45.77), (9.66), (612.17)") + result = tk.MustQuery("select distinct count(distinct a) from t") + result.Check(testkit.Rows("10")) tk.MustExec("create table idx_agg (a int, b int, index (b))") tk.MustExec("insert idx_agg values (1, 1), (1, 2), (2, 2)")
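
The DISTINCT path above (countOriginalWithDistinct) works by encoding every argument of a row into bytes with the append* helpers, concatenating the encodings, and inserting the result into a stringSet, so each distinct argument tuple is counted once and rows containing a NULL are skipped. Below is a minimal, self-contained sketch of that idea using only the standard library; the names encodeValue and distinctCount and the plain Go value types are illustrative stand-ins, not part of the patch, which encodes TiDB's own types (MyDecimal, Time, Duration, BinaryJSON) instead.

package main

import (
	"encoding/binary"
	"fmt"
	"math"
)

// encodeValue appends a fixed-width or raw encoding of v to dst, playing the
// role of the appendInt64/appendFloat64/appendString helpers in the patch.
func encodeValue(dst []byte, v interface{}) []byte {
	var buf [8]byte
	switch x := v.(type) {
	case int64:
		binary.LittleEndian.PutUint64(buf[:], uint64(x))
		return append(dst, buf[:]...)
	case float64:
		binary.LittleEndian.PutUint64(buf[:], math.Float64bits(x))
		return append(dst, buf[:]...)
	case string:
		return append(dst, x...)
	default:
		panic("type not covered in this sketch")
	}
}

// distinctCount counts rows whose full argument tuple has not been seen
// before, skipping any row that contains a NULL (nil) argument, mirroring
// what countOriginalWithDistinct.UpdatePartialResult does per group.
func distinctCount(rows [][]interface{}) int64 {
	seen := make(map[string]struct{}) // plays the role of stringSet
	var count int64
	for _, row := range rows {
		encoded := make([]byte, 0, 16)
		hasNull := false
		for _, col := range row {
			if col == nil {
				hasNull = true
				break
			}
			encoded = encodeValue(encoded, col)
		}
		if hasNull {
			continue
		}
		key := string(encoded)
		if _, ok := seen[key]; ok {
			continue
		}
		seen[key] = struct{}{}
		count++
	}
	return count
}

func main() {
	rows := [][]interface{}{
		{int64(1), "a"},
		{int64(1), "a"}, // duplicate tuple, counted once
		{int64(1), nil}, // contains NULL, skipped
		{float64(2.5), "b"},
	}
	fmt.Println(distinctCount(rows)) // prints 2
}

Keying a single map[string]struct{} by the concatenated encodings lets one set handle any combination of argument types, which is why the patch adds stringSet to sets.go instead of one set type per column type; the non-DISTINCT variants never need a set at all, since they only keep the running int64 in partialResult4Count (countOriginal4* counts non-NULL originals, and countPartial sums the partial counts in Partial2Mode/FinalMode).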