diff --git a/executor/aggfuncs/builder.go b/executor/aggfuncs/builder.go index 02e1570b4263d..f89b549e1b214 100644 --- a/executor/aggfuncs/builder.go +++ b/executor/aggfuncs/builder.go @@ -424,7 +424,7 @@ func buildLeadLag(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) baseLeadLag args: aggFuncDesc.Args, ordinal: ordinal, } - return baseLeadLag{baseAggFunc: base, offset: offset, defaultExpr: defaultExpr, valueEvaluator: buildValueEvaluator(aggFuncDesc.RetTp)} + return baseLeadLag{baseAggFunc: base, offset: offset, defaultExpr: defaultExpr, retTp: aggFuncDesc.RetTp} } func buildLead(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc { diff --git a/executor/aggfuncs/func_cume_dist.go b/executor/aggfuncs/func_cume_dist.go index 37e1ffb1636a5..9637f11cdc76e 100644 --- a/executor/aggfuncs/func_cume_dist.go +++ b/executor/aggfuncs/func_cume_dist.go @@ -24,35 +24,33 @@ type cumeDist struct { } type partialResult4CumeDist struct { - curIdx int - lastRank int - rows []chunk.Row + partialResult4Rank + cum int64 } func (r *cumeDist) AllocPartialResult() PartialResult { - return PartialResult(&partialResult4Rank{}) + return PartialResult(&partialResult4CumeDist{}) } func (r *cumeDist) ResetPartialResult(pr PartialResult) { - p := (*partialResult4Rank)(pr) - p.curIdx = 0 - p.lastRank = 0 - p.rows = p.rows[:0] + p := (*partialResult4CumeDist)(pr) + p.partialResult4Rank.reset() + p.cum = 0 } func (r *cumeDist) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) error { p := (*partialResult4CumeDist)(pr) - p.rows = append(p.rows, rowsInGroup...) + p.partialResult4Rank.updatePartialResult(rowsInGroup, false, r.compareRows) return nil } func (r *cumeDist) AppendFinalResult2Chunk(sctx sessionctx.Context, pr PartialResult, chk *chunk.Chunk) error { p := (*partialResult4CumeDist)(pr) - numRows := len(p.rows) - for p.lastRank < numRows && r.compareRows(p.rows[p.curIdx], p.rows[p.lastRank]) == 0 { - p.lastRank++ + numRows := int64(len(p.results)) + for p.cum < numRows && p.results[p.cum] == p.results[p.curIdx] { + p.cum++ } p.curIdx++ - chk.AppendFloat64(r.ordinal, float64(p.lastRank)/float64(numRows)) + chk.AppendFloat64(r.ordinal, float64(p.cum)/float64(numRows)) return nil } diff --git a/executor/aggfuncs/func_lead_lag.go b/executor/aggfuncs/func_lead_lag.go index ba53e9eb47809..35150de8871fe 100644 --- a/executor/aggfuncs/func_lead_lag.go +++ b/executor/aggfuncs/func_lead_lag.go @@ -14,76 +14,200 @@ package aggfuncs import ( + "github.com/cznic/mathutil" "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/chunk" ) type baseLeadLag struct { baseAggFunc - valueEvaluator // TODO: move it to partial result when parallel execution is supported. defaultExpr expression.Expression offset uint64 + retTp *types.FieldType } -type partialResult4LeadLag struct { - rows []chunk.Row - curIdx uint64 +type circleBuf struct { + buf []valueExtractor + head, tail int + size int } -func (v *baseLeadLag) AllocPartialResult() PartialResult { - return PartialResult(&partialResult4LeadLag{}) +func (cb *circleBuf) reset() { + cb.buf = cb.buf[:0] + cb.head, cb.tail = 0, 0 } -func (v *baseLeadLag) ResetPartialResult(pr PartialResult) { - p := (*partialResult4LeadLag)(pr) - p.rows = p.rows[:0] - p.curIdx = 0 +func (cb *circleBuf) append(e valueExtractor) { + if len(cb.buf) < cb.size { + cb.buf = append(cb.buf, e) + cb.tail++ + } else { + if cb.tail >= cb.size { + cb.tail = 0 + } + cb.buf[cb.tail] = e + cb.tail++ + } } -func (v *baseLeadLag) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) error { - p := (*partialResult4LeadLag)(pr) - p.rows = append(p.rows, rowsInGroup...) - return nil +func (cb *circleBuf) get() (e valueExtractor) { + if len(cb.buf) < cb.size { + e = cb.buf[cb.head] + cb.head++ + } else { + if cb.tail >= cb.size { + cb.tail = 0 + } + e = cb.buf[cb.tail] + cb.tail++ + } + return e +} + +type partialResult4Lead struct { + seenRows uint64 + curIdx int + extractors []valueExtractor + defaultExtractors circleBuf + defaultConstExtractor valueExtractor } +const maxDefaultExtractorBufferSize = 1000 + type lead struct { baseLeadLag } +func (v *lead) AllocPartialResult() PartialResult { + return PartialResult(&partialResult4Lead{ + defaultExtractors: circleBuf{ + // Do not use v.offset directly since v.offset is defined by user + // and may larger than a table size. + buf: make([]valueExtractor, 0, mathutil.MinUint64(v.offset, maxDefaultExtractorBufferSize)), + size: int(v.offset), + }, + }) +} + +func (v *lead) ResetPartialResult(pr PartialResult) { + p := (*partialResult4Lead)(pr) + p.seenRows = 0 + p.curIdx = 0 + p.extractors = p.extractors[:0] + p.defaultExtractors.reset() + p.defaultConstExtractor = nil +} + +func (v *lead) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) (err error) { + p := (*partialResult4Lead)(pr) + for _, row := range rowsInGroup { + p.seenRows++ + if p.seenRows > v.offset { + e := buildValueExtractor(v.retTp) + err = e.extractRow(sctx, v.args[0], row) + if err != nil { + return err + } + p.extractors = append(p.extractors, e) + } + if v.offset > 0 { + if !v.defaultExpr.ConstItem() { + // We must cache the results of last v.offset lines. + e := buildValueExtractor(v.retTp) + err = e.extractRow(sctx, v.defaultExpr, row) + if err != nil { + return err + } + p.defaultExtractors.append(e) + } else if p.defaultConstExtractor == nil { + e := buildValueExtractor(v.retTp) + err = e.extractRow(sctx, v.defaultExpr, chunk.Row{}) + if err != nil { + return err + } + p.defaultConstExtractor = e + } + } + } + return nil +} + func (v *lead) AppendFinalResult2Chunk(sctx sessionctx.Context, pr PartialResult, chk *chunk.Chunk) error { - p := (*partialResult4LeadLag)(pr) - var err error - if p.curIdx+v.offset < uint64(len(p.rows)) { - err = v.evaluateRow(sctx, v.args[0], p.rows[p.curIdx+v.offset]) + p := (*partialResult4Lead)(pr) + var e valueExtractor + if p.curIdx < len(p.extractors) { + e = p.extractors[p.curIdx] } else { - err = v.evaluateRow(sctx, v.defaultExpr, p.rows[p.curIdx]) + if !v.defaultExpr.ConstItem() { + e = p.defaultExtractors.get() + } else { + e = p.defaultConstExtractor + } } - if err != nil { - return err - } - v.appendResult(chk, v.ordinal) + e.appendResult(chk, v.ordinal) p.curIdx++ return nil } +type partialResult4Lag struct { + seenRows uint64 + curIdx uint64 + extractors []valueExtractor + defaultExtractors []valueExtractor +} + type lag struct { baseLeadLag } +func (v *lag) AllocPartialResult() PartialResult { + return PartialResult(&partialResult4Lag{ + defaultExtractors: make([]valueExtractor, 0, mathutil.MinUint64(v.offset, maxDefaultExtractorBufferSize)), + }) +} + +func (v *lag) ResetPartialResult(pr PartialResult) { + p := (*partialResult4Lag)(pr) + p.seenRows = 0 + p.curIdx = 0 + p.extractors = p.extractors[:0] + p.defaultExtractors = p.defaultExtractors[:0] +} + +func (v *lag) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) (err error) { + p := (*partialResult4Lag)(pr) + for _, row := range rowsInGroup { + p.seenRows++ + if p.seenRows <= v.offset { + e := buildValueExtractor(v.retTp) + err = e.extractRow(sctx, v.defaultExpr, row) + if err != nil { + return err + } + p.defaultExtractors = append(p.defaultExtractors, e) + } + e := buildValueExtractor(v.retTp) + err = e.extractRow(sctx, v.args[0], row) + if err != nil { + return err + } + p.extractors = append(p.extractors, e) + } + return nil +} + func (v *lag) AppendFinalResult2Chunk(sctx sessionctx.Context, pr PartialResult, chk *chunk.Chunk) error { - p := (*partialResult4LeadLag)(pr) - var err error - if p.curIdx >= v.offset { - err = v.evaluateRow(sctx, v.args[0], p.rows[p.curIdx-v.offset]) + p := (*partialResult4Lag)(pr) + var e valueExtractor + if p.curIdx < v.offset { + e = p.defaultExtractors[p.curIdx] } else { - err = v.evaluateRow(sctx, v.defaultExpr, p.rows[p.curIdx]) - } - if err != nil { - return err + e = p.extractors[p.curIdx-v.offset] } - v.appendResult(chk, v.ordinal) + e.appendResult(chk, v.ordinal) p.curIdx++ return nil } diff --git a/executor/aggfuncs/func_lead_lag_test.go b/executor/aggfuncs/func_lead_lag_test.go new file mode 100644 index 0000000000000..fd4e5aa23dfcb --- /dev/null +++ b/executor/aggfuncs/func_lead_lag_test.go @@ -0,0 +1,114 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package aggfuncs_test + +import ( + . "github.com/pingcap/check" + "github.com/pingcap/parser/ast" + "github.com/pingcap/parser/mysql" + "github.com/pingcap/tidb/expression" + "github.com/pingcap/tidb/types" +) + +func (s *testSuite) TestLeadLag(c *C) { + zero := expression.Zero + one := expression.One + two := &expression.Constant{ + Value: types.NewDatum(2), + RetType: types.NewFieldType(mysql.TypeTiny), + } + three := &expression.Constant{ + Value: types.NewDatum(3), + RetType: types.NewFieldType(mysql.TypeTiny), + } + million := &expression.Constant{ + Value: types.NewDatum(1000000), + RetType: types.NewFieldType(mysql.TypeLong), + } + defaultArg := &expression.Column{RetType: types.NewFieldType(mysql.TypeLonglong), Index: 0} + + numRows := 3 + tests := []windowTest{ + // lag(field0, N) + buildWindowTesterWithArgs(ast.WindowFuncLag, mysql.TypeLonglong, + []expression.Expression{zero}, 0, numRows, 0, 1, 2), + buildWindowTesterWithArgs(ast.WindowFuncLag, mysql.TypeLonglong, + []expression.Expression{one}, 0, numRows, nil, 0, 1), + buildWindowTesterWithArgs(ast.WindowFuncLag, mysql.TypeLonglong, + []expression.Expression{two}, 0, numRows, nil, nil, 0), + buildWindowTesterWithArgs(ast.WindowFuncLag, mysql.TypeLonglong, + []expression.Expression{three}, 0, numRows, nil, nil, nil), + buildWindowTesterWithArgs(ast.WindowFuncLag, mysql.TypeLonglong, + []expression.Expression{million}, 0, numRows, nil, nil, nil), + // lag(field0, N, 1000000) + buildWindowTesterWithArgs(ast.WindowFuncLag, mysql.TypeLonglong, + []expression.Expression{zero, million}, 0, numRows, 0, 1, 2), + buildWindowTesterWithArgs(ast.WindowFuncLag, mysql.TypeLonglong, + []expression.Expression{one, million}, 0, numRows, 1000000, 0, 1), + buildWindowTesterWithArgs(ast.WindowFuncLag, mysql.TypeLonglong, + []expression.Expression{two, million}, 0, numRows, 1000000, 1000000, 0), + buildWindowTesterWithArgs(ast.WindowFuncLag, mysql.TypeLonglong, + []expression.Expression{three, million}, 0, numRows, 1000000, 1000000, 1000000), + buildWindowTesterWithArgs(ast.WindowFuncLag, mysql.TypeLonglong, + []expression.Expression{million, million}, 0, numRows, 1000000, 1000000, 1000000), + // lag(field0, N, field0) + buildWindowTesterWithArgs(ast.WindowFuncLag, mysql.TypeLonglong, + []expression.Expression{zero, defaultArg}, 0, numRows, 0, 1, 2), + buildWindowTesterWithArgs(ast.WindowFuncLag, mysql.TypeLonglong, + []expression.Expression{one, defaultArg}, 0, numRows, 0, 0, 1), + buildWindowTesterWithArgs(ast.WindowFuncLag, mysql.TypeLonglong, + []expression.Expression{two, defaultArg}, 0, numRows, 0, 1, 0), + buildWindowTesterWithArgs(ast.WindowFuncLag, mysql.TypeLonglong, + []expression.Expression{three, defaultArg}, 0, numRows, 0, 1, 2), + buildWindowTesterWithArgs(ast.WindowFuncLag, mysql.TypeLonglong, + []expression.Expression{million, defaultArg}, 0, numRows, 0, 1, 2), + + // lead(field0, N) + buildWindowTesterWithArgs(ast.WindowFuncLead, mysql.TypeLonglong, + []expression.Expression{zero}, 0, numRows, 0, 1, 2), + buildWindowTesterWithArgs(ast.WindowFuncLead, mysql.TypeLonglong, + []expression.Expression{one}, 0, numRows, 1, 2, nil), + buildWindowTesterWithArgs(ast.WindowFuncLead, mysql.TypeLonglong, + []expression.Expression{two}, 0, numRows, 2, nil, nil), + buildWindowTesterWithArgs(ast.WindowFuncLead, mysql.TypeLonglong, + []expression.Expression{three}, 0, numRows, nil, nil, nil), + buildWindowTesterWithArgs(ast.WindowFuncLead, mysql.TypeLonglong, + []expression.Expression{million}, 0, numRows, nil, nil, nil), + // lead(field0, N, 1000000) + buildWindowTesterWithArgs(ast.WindowFuncLead, mysql.TypeLonglong, + []expression.Expression{zero, million}, 0, numRows, 0, 1, 2), + buildWindowTesterWithArgs(ast.WindowFuncLead, mysql.TypeLonglong, + []expression.Expression{one, million}, 0, numRows, 1, 2, 1000000), + buildWindowTesterWithArgs(ast.WindowFuncLead, mysql.TypeLonglong, + []expression.Expression{two, million}, 0, numRows, 2, 1000000, 1000000), + buildWindowTesterWithArgs(ast.WindowFuncLead, mysql.TypeLonglong, + []expression.Expression{three, million}, 0, numRows, 1000000, 1000000, 1000000), + buildWindowTesterWithArgs(ast.WindowFuncLead, mysql.TypeLonglong, + []expression.Expression{million, million}, 0, numRows, 1000000, 1000000, 1000000), + // lead(field0, N, field0) + buildWindowTesterWithArgs(ast.WindowFuncLead, mysql.TypeLonglong, + []expression.Expression{zero, defaultArg}, 0, numRows, 0, 1, 2), + buildWindowTesterWithArgs(ast.WindowFuncLead, mysql.TypeLonglong, + []expression.Expression{one, defaultArg}, 0, numRows, 1, 2, 2), + buildWindowTesterWithArgs(ast.WindowFuncLead, mysql.TypeLonglong, + []expression.Expression{two, defaultArg}, 0, numRows, 2, 1, 2), + buildWindowTesterWithArgs(ast.WindowFuncLead, mysql.TypeLonglong, + []expression.Expression{three, defaultArg}, 0, numRows, 0, 1, 2), + buildWindowTesterWithArgs(ast.WindowFuncLead, mysql.TypeLonglong, + []expression.Expression{million, defaultArg}, 0, numRows, 0, 1, 2), + } + for _, test := range tests { + s.testWindowFunc(c, test) + } +} diff --git a/executor/aggfuncs/func_ntile.go b/executor/aggfuncs/func_ntile.go index 1adbb326d7609..eefa0a9cd6b83 100644 --- a/executor/aggfuncs/func_ntile.go +++ b/executor/aggfuncs/func_ntile.go @@ -31,7 +31,7 @@ type partialResult4Ntile struct { curGroupIdx uint64 remainder uint64 quotient uint64 - rows []chunk.Row + numRows uint64 } func (n *ntile) AllocPartialResult() PartialResult { @@ -42,16 +42,16 @@ func (n *ntile) ResetPartialResult(pr PartialResult) { p := (*partialResult4Ntile)(pr) p.curIdx = 0 p.curGroupIdx = 1 - p.rows = p.rows[:0] + p.numRows = 0 } func (n *ntile) UpdatePartialResult(_ sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) error { p := (*partialResult4Ntile)(pr) - p.rows = append(p.rows, rowsInGroup...) + p.numRows += uint64(len(rowsInGroup)) // Update the quotient and remainder. if n.n != 0 { - p.quotient = uint64(len(p.rows)) / n.n - p.remainder = uint64(len(p.rows)) % n.n + p.quotient = p.numRows / n.n + p.remainder = p.numRows % n.n } return nil } diff --git a/executor/aggfuncs/func_percent_rank.go b/executor/aggfuncs/func_percent_rank.go index 0ea863a048bb6..86be2a4a0579a 100644 --- a/executor/aggfuncs/func_percent_rank.go +++ b/executor/aggfuncs/func_percent_rank.go @@ -31,31 +31,23 @@ func (pr *percentRank) AllocPartialResult() PartialResult { func (pr *percentRank) ResetPartialResult(partial PartialResult) { p := (*partialResult4Rank)(partial) - p.curIdx = 0 - p.lastRank = 0 - p.rows = p.rows[:0] + p.reset() } func (pr *percentRank) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, partial PartialResult) error { p := (*partialResult4Rank)(partial) - p.rows = append(p.rows, rowsInGroup...) + p.updatePartialResult(rowsInGroup, false, pr.compareRows) return nil } func (pr *percentRank) AppendFinalResult2Chunk(sctx sessionctx.Context, partial PartialResult, chk *chunk.Chunk) error { p := (*partialResult4Rank)(partial) - numRows := int64(len(p.rows)) - p.curIdx++ - if p.curIdx == 1 { - p.lastRank = 1 + numRows := len(p.results) + if numRows == 1 { chk.AppendFloat64(pr.ordinal, 0) - return nil - } - if pr.compareRows(p.rows[p.curIdx-2], p.rows[p.curIdx-1]) == 0 { - chk.AppendFloat64(pr.ordinal, float64(p.lastRank-1)/float64(numRows-1)) - return nil + } else { + chk.AppendFloat64(pr.ordinal, float64(p.results[p.curIdx]-1)/float64(numRows-1)) } - p.lastRank = p.curIdx - chk.AppendFloat64(pr.ordinal, float64(p.lastRank-1)/float64(numRows-1)) + p.curIdx++ return nil } diff --git a/executor/aggfuncs/func_rank.go b/executor/aggfuncs/func_rank.go index 02267cd2146a6..a70bdaf14a0eb 100644 --- a/executor/aggfuncs/func_rank.go +++ b/executor/aggfuncs/func_rank.go @@ -27,8 +27,45 @@ type rank struct { type partialResult4Rank struct { curIdx int64 - lastRank int64 - rows []chunk.Row + seenRows int64 + results []int64 + lastRow chunk.Row +} + +func (p *partialResult4Rank) reset() { + p.curIdx = 0 + p.seenRows = 0 + p.results = p.results[:0] +} + +func (p *partialResult4Rank) updatePartialResult( + rowsInGroup []chunk.Row, + isDense bool, + compareRows func(prev, curr chunk.Row) int, +) { + if len(rowsInGroup) == 0 { + return + } + lastRow := p.lastRow + for _, row := range rowsInGroup { + p.seenRows++ + if p.seenRows == 1 { + p.results = append(p.results, 1) + lastRow = row + continue + } + var rank int64 + if compareRows(lastRow, row) == 0 { + rank = p.results[len(p.results)-1] + } else if isDense { + rank = p.results[len(p.results)-1] + 1 + } else { + rank = p.seenRows + } + p.results = append(p.results, rank) + lastRow = row + } + p.lastRow = rowsInGroup[len(rowsInGroup)-1].CopyConstruct() } func (r *rank) AllocPartialResult() PartialResult { @@ -37,35 +74,19 @@ func (r *rank) AllocPartialResult() PartialResult { func (r *rank) ResetPartialResult(pr PartialResult) { p := (*partialResult4Rank)(pr) - p.curIdx = 0 - p.lastRank = 0 - p.rows = p.rows[:0] + p.reset() } func (r *rank) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) error { p := (*partialResult4Rank)(pr) - p.rows = append(p.rows, rowsInGroup...) + p.updatePartialResult(rowsInGroup, r.isDense, r.compareRows) return nil } func (r *rank) AppendFinalResult2Chunk(sctx sessionctx.Context, pr PartialResult, chk *chunk.Chunk) error { p := (*partialResult4Rank)(pr) + chk.AppendInt64(r.ordinal, p.results[p.curIdx]) p.curIdx++ - if p.curIdx == 1 { - p.lastRank = 1 - chk.AppendInt64(r.ordinal, p.lastRank) - return nil - } - if r.compareRows(p.rows[p.curIdx-2], p.rows[p.curIdx-1]) == 0 { - chk.AppendInt64(r.ordinal, p.lastRank) - return nil - } - if r.isDense { - p.lastRank++ - } else { - p.lastRank = p.curIdx - } - chk.AppendInt64(r.ordinal, p.lastRank) return nil } diff --git a/executor/aggfuncs/func_value.go b/executor/aggfuncs/func_value.go index 88c8f76e8615a..9e8f63a1c2ab8 100644 --- a/executor/aggfuncs/func_value.go +++ b/executor/aggfuncs/func_value.go @@ -20,13 +20,14 @@ import ( "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/types/json" "github.com/pingcap/tidb/util/chunk" + "github.com/pingcap/tidb/util/stringutil" ) -// valueEvaluator is used to evaluate values for `first_value`, `last_value`, `nth_value`, +// valueExtractor is used to extract values for `first_value`, `last_value`, `nth_value`, // `lead` and `lag`. -type valueEvaluator interface { - // evaluateRow evaluates the expression using row and stores the result inside. - evaluateRow(ctx sessionctx.Context, expr expression.Expression, row chunk.Row) error +type valueExtractor interface { + // extractRow extracts the expression using row and stores the result inside. + extractRow(ctx sessionctx.Context, expr expression.Expression, row chunk.Row) error // appendResult appends the result to chunk. appendResult(chk *chunk.Chunk, colIdx int) } @@ -36,7 +37,7 @@ type value4Int struct { isNull bool } -func (v *value4Int) evaluateRow(ctx sessionctx.Context, expr expression.Expression, row chunk.Row) error { +func (v *value4Int) extractRow(ctx sessionctx.Context, expr expression.Expression, row chunk.Row) error { var err error v.val, v.isNull, err = expr.EvalInt(ctx, row) return err @@ -55,7 +56,7 @@ type value4Float32 struct { isNull bool } -func (v *value4Float32) evaluateRow(ctx sessionctx.Context, expr expression.Expression, row chunk.Row) error { +func (v *value4Float32) extractRow(ctx sessionctx.Context, expr expression.Expression, row chunk.Row) error { var err error var val float64 val, v.isNull, err = expr.EvalReal(ctx, row) @@ -76,9 +77,10 @@ type value4Decimal struct { isNull bool } -func (v *value4Decimal) evaluateRow(ctx sessionctx.Context, expr expression.Expression, row chunk.Row) error { +func (v *value4Decimal) extractRow(ctx sessionctx.Context, expr expression.Expression, row chunk.Row) error { var err error v.val, v.isNull, err = expr.EvalDecimal(ctx, row) + v.val = v.val.Copy() return err } @@ -95,7 +97,7 @@ type value4Float64 struct { isNull bool } -func (v *value4Float64) evaluateRow(ctx sessionctx.Context, expr expression.Expression, row chunk.Row) error { +func (v *value4Float64) extractRow(ctx sessionctx.Context, expr expression.Expression, row chunk.Row) error { var err error v.val, v.isNull, err = expr.EvalReal(ctx, row) return err @@ -114,12 +116,12 @@ type value4String struct { isNull bool } -func (v *value4String) evaluateRow(ctx sessionctx.Context, expr expression.Expression, row chunk.Row) error { +func (v *value4String) extractRow(ctx sessionctx.Context, expr expression.Expression, row chunk.Row) error { var err error v.val, v.isNull, err = expr.EvalString(ctx, row) + v.val = stringutil.Copy(v.val) return err } - func (v *value4String) appendResult(chk *chunk.Chunk, colIdx int) { if v.isNull { chk.AppendNull(colIdx) @@ -133,7 +135,7 @@ type value4Time struct { isNull bool } -func (v *value4Time) evaluateRow(ctx sessionctx.Context, expr expression.Expression, row chunk.Row) error { +func (v *value4Time) extractRow(ctx sessionctx.Context, expr expression.Expression, row chunk.Row) error { var err error v.val, v.isNull, err = expr.EvalTime(ctx, row) return err @@ -152,7 +154,7 @@ type value4Duration struct { isNull bool } -func (v *value4Duration) evaluateRow(ctx sessionctx.Context, expr expression.Expression, row chunk.Row) error { +func (v *value4Duration) extractRow(ctx sessionctx.Context, expr expression.Expression, row chunk.Row) error { var err error v.val, v.isNull, err = expr.EvalDuration(ctx, row) return err @@ -171,7 +173,7 @@ type value4JSON struct { isNull bool } -func (v *value4JSON) evaluateRow(ctx sessionctx.Context, expr expression.Expression, row chunk.Row) error { +func (v *value4JSON) extractRow(ctx sessionctx.Context, expr expression.Expression, row chunk.Row) error { var err error v.val, v.isNull, err = expr.EvalJSON(ctx, row) v.val = v.val.Copy() // deep copy to avoid content change. @@ -186,7 +188,7 @@ func (v *value4JSON) appendResult(chk *chunk.Chunk, colIdx int) { } } -func buildValueEvaluator(tp *types.FieldType) valueEvaluator { +func buildValueExtractor(tp *types.FieldType) valueExtractor { evalType := tp.EvalType() if tp.Tp == mysql.TypeBit { evalType = types.ETString @@ -223,11 +225,11 @@ type firstValue struct { type partialResult4FirstValue struct { gotFirstValue bool - evaluator valueEvaluator + extractor valueExtractor } func (v *firstValue) AllocPartialResult() PartialResult { - return PartialResult(&partialResult4FirstValue{evaluator: buildValueEvaluator(v.tp)}) + return PartialResult(&partialResult4FirstValue{extractor: buildValueExtractor(v.tp)}) } func (v *firstValue) ResetPartialResult(pr PartialResult) { @@ -242,7 +244,7 @@ func (v *firstValue) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup [] } if len(rowsInGroup) > 0 { p.gotFirstValue = true - err := p.evaluator.evaluateRow(sctx, v.args[0], rowsInGroup[0]) + err := p.extractor.extractRow(sctx, v.args[0], rowsInGroup[0]) if err != nil { return err } @@ -255,7 +257,7 @@ func (v *firstValue) AppendFinalResult2Chunk(sctx sessionctx.Context, pr Partial if !p.gotFirstValue { chk.AppendNull(v.ordinal) } else { - p.evaluator.appendResult(chk, v.ordinal) + p.extractor.appendResult(chk, v.ordinal) } return nil } @@ -268,11 +270,11 @@ type lastValue struct { type partialResult4LastValue struct { gotLastValue bool - evaluator valueEvaluator + extractor valueExtractor } func (v *lastValue) AllocPartialResult() PartialResult { - return PartialResult(&partialResult4LastValue{evaluator: buildValueEvaluator(v.tp)}) + return PartialResult(&partialResult4LastValue{extractor: buildValueExtractor(v.tp)}) } func (v *lastValue) ResetPartialResult(pr PartialResult) { @@ -284,7 +286,7 @@ func (v *lastValue) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []c p := (*partialResult4LastValue)(pr) if len(rowsInGroup) > 0 { p.gotLastValue = true - err := p.evaluator.evaluateRow(sctx, v.args[0], rowsInGroup[len(rowsInGroup)-1]) + err := p.extractor.extractRow(sctx, v.args[0], rowsInGroup[len(rowsInGroup)-1]) if err != nil { return err } @@ -297,7 +299,7 @@ func (v *lastValue) AppendFinalResult2Chunk(sctx sessionctx.Context, pr PartialR if !p.gotLastValue { chk.AppendNull(v.ordinal) } else { - p.evaluator.appendResult(chk, v.ordinal) + p.extractor.appendResult(chk, v.ordinal) } return nil } @@ -311,11 +313,11 @@ type nthValue struct { type partialResult4NthValue struct { seenRows uint64 - evaluator valueEvaluator + extractor valueExtractor } func (v *nthValue) AllocPartialResult() PartialResult { - return PartialResult(&partialResult4NthValue{evaluator: buildValueEvaluator(v.tp)}) + return PartialResult(&partialResult4NthValue{extractor: buildValueExtractor(v.tp)}) } func (v *nthValue) ResetPartialResult(pr PartialResult) { @@ -330,7 +332,7 @@ func (v *nthValue) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []ch p := (*partialResult4NthValue)(pr) numRows := uint64(len(rowsInGroup)) if v.nth > p.seenRows && v.nth-p.seenRows <= numRows { - err := p.evaluator.evaluateRow(sctx, v.args[0], rowsInGroup[v.nth-p.seenRows-1]) + err := p.extractor.extractRow(sctx, v.args[0], rowsInGroup[v.nth-p.seenRows-1]) if err != nil { return err } @@ -344,7 +346,7 @@ func (v *nthValue) AppendFinalResult2Chunk(sctx sessionctx.Context, pr PartialRe if v.nth == 0 || p.seenRows < v.nth { chk.AppendNull(v.ordinal) } else { - p.evaluator.appendResult(chk, v.ordinal) + p.extractor.appendResult(chk, v.ordinal) } return nil } diff --git a/executor/aggfuncs/window_func_test.go b/executor/aggfuncs/window_func_test.go index 1108fdc92e019..ae150668c3983 100644 --- a/executor/aggfuncs/window_func_test.go +++ b/executor/aggfuncs/window_func_test.go @@ -14,6 +14,7 @@ package aggfuncs_test import ( + "math/rand" "time" . "github.com/pingcap/check" @@ -37,6 +38,12 @@ type windowTest struct { } func (s *testSuite) testWindowFunc(c *C, p windowTest) { + s._testWindowFunc(c, p, false) + s._testWindowFunc(c, p, true) +} + +func (s *testSuite) _testWindowFunc(c *C, p windowTest, pollute bool) { + srcChk := chunk.NewChunkWithCapacity([]*types.FieldType{p.dataType}, p.numRows) dataGen := getDataGenFunc(p.dataType) for i := 0; i < p.numRows; i++ { @@ -52,11 +59,25 @@ func (s *testSuite) testWindowFunc(c *C, p windowTest) { iter := chunk.NewIterator4Chunk(srcChk) for row := iter.Begin(); row != iter.End(); row = iter.Next() { - finalFunc.UpdatePartialResult(s.ctx, []chunk.Row{row}, finalPr) + err = finalFunc.UpdatePartialResult(s.ctx, []chunk.Row{row}, finalPr) + c.Assert(err, IsNil) + } + // Reset or pollute srcChk on purpose to make sure any AggFunc does + // not hold any memory referenced by srcChk after UpdatePartialResult. + // See issue #11614 and #11626 + srcChk.Reset() + if pollute { + rand.Seed(time.Now().Unix()) + for i := 0; i < p.numRows; i++ { + dt := dataGen(rand.Int()) + srcChk.AppendDatum(0, &dt) + } } + c.Assert(p.numRows, Equals, len(p.results)) for i := 0; i < p.numRows; i++ { - finalFunc.AppendFinalResult2Chunk(s.ctx, finalPr, resultChk) + err = finalFunc.AppendFinalResult2Chunk(s.ctx, finalPr, resultChk) + c.Assert(err, IsNil) dt := resultChk.GetRow(0).GetDatum(0, desc.RetTp) result, err := dt.CompareDatum(s.ctx.GetSessionVars().StmtCtx, &p.results[i]) c.Assert(err, IsNil) @@ -66,6 +87,26 @@ func (s *testSuite) testWindowFunc(c *C, p windowTest) { finalFunc.ResetPartialResult(finalPr) } +func buildWindowTesterWithArgs(funcName string, tp byte, args []expression.Expression, orderByCols int, numRows int, results ...interface{}) windowTest { + pt := windowTest{ + dataType: types.NewFieldType(tp), + numRows: numRows, + funcName: funcName, + } + if funcName != ast.WindowFuncNtile { + pt.args = append(pt.args, &expression.Column{RetType: pt.dataType, Index: 0}) + } + pt.args = append(pt.args, args...) + if orderByCols > 0 { + pt.orderByCols = append(pt.orderByCols, &expression.Column{RetType: pt.dataType, Index: 0}) + } + + for _, result := range results { + pt.results = append(pt.results, types.NewDatum(result)) + } + return pt +} + func buildWindowTester(funcName string, tp byte, constantArg uint64, orderByCols int, numRows int, results ...interface{}) windowTest { pt := windowTest{ dataType: types.NewFieldType(tp), @@ -90,6 +131,7 @@ func buildWindowTester(funcName string, tp byte, constantArg uint64, orderByCols func (s *testSuite) TestWindowFunctions(c *C) { tests := []windowTest{ + buildWindowTester(ast.WindowFuncCumeDist, mysql.TypeLonglong, 0, 1, 1, 1), buildWindowTester(ast.WindowFuncCumeDist, mysql.TypeLonglong, 0, 0, 2, 1, 1), buildWindowTester(ast.WindowFuncCumeDist, mysql.TypeLonglong, 0, 1, 4, 0.25, 0.5, 0.75, 1), @@ -105,23 +147,19 @@ func (s *testSuite) TestWindowFunctions(c *C) { buildWindowTester(ast.WindowFuncFirstValue, mysql.TypeDuration, 0, 1, 2, types.Duration{Duration: time.Duration(0)}, types.Duration{Duration: time.Duration(0)}), buildWindowTester(ast.WindowFuncFirstValue, mysql.TypeJSON, 0, 1, 2, json.CreateBinary(int64(0)), json.CreateBinary(int64(0))), - buildWindowTester(ast.WindowFuncLag, mysql.TypeLonglong, 1, 0, 3, nil, 0, 1), - buildWindowTester(ast.WindowFuncLag, mysql.TypeLonglong, 2, 1, 4, nil, nil, 0, 1), - buildWindowTester(ast.WindowFuncLastValue, mysql.TypeLonglong, 1, 0, 2, 1, 1), - buildWindowTester(ast.WindowFuncLead, mysql.TypeLonglong, 1, 0, 3, 1, 2, nil), - buildWindowTester(ast.WindowFuncLead, mysql.TypeLonglong, 2, 0, 4, 2, 3, nil, nil), - buildWindowTester(ast.WindowFuncNthValue, mysql.TypeLonglong, 2, 0, 3, 1, 1, 1), buildWindowTester(ast.WindowFuncNthValue, mysql.TypeLonglong, 5, 0, 3, nil, nil, nil), buildWindowTester(ast.WindowFuncNtile, mysql.TypeLonglong, 3, 0, 4, 1, 1, 2, 3), buildWindowTester(ast.WindowFuncNtile, mysql.TypeLonglong, 5, 0, 3, 1, 2, 3), + buildWindowTester(ast.WindowFuncPercentRank, mysql.TypeLonglong, 0, 1, 1, 0), buildWindowTester(ast.WindowFuncPercentRank, mysql.TypeLonglong, 0, 0, 3, 0, 0, 0), buildWindowTester(ast.WindowFuncPercentRank, mysql.TypeLonglong, 0, 1, 4, 0, 0.3333333333333333, 0.6666666666666666, 1), + buildWindowTester(ast.WindowFuncRank, mysql.TypeLonglong, 0, 1, 1, 1), buildWindowTester(ast.WindowFuncRank, mysql.TypeLonglong, 0, 0, 3, 1, 1, 1), buildWindowTester(ast.WindowFuncRank, mysql.TypeLonglong, 0, 1, 4, 1, 2, 3, 4), diff --git a/executor/benchmark_test.go b/executor/benchmark_test.go index 2f9d945bedd1c..aff3cd4e375a3 100644 --- a/executor/benchmark_test.go +++ b/executor/benchmark_test.go @@ -18,6 +18,7 @@ import ( "fmt" "math/rand" "sort" + "strings" "testing" "github.com/pingcap/parser/ast" @@ -25,6 +26,7 @@ import ( "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/expression/aggregation" "github.com/pingcap/tidb/planner/core" + "github.com/pingcap/tidb/planner/property" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/types" @@ -87,6 +89,8 @@ func (mds *mockDataSource) genColDatums(col int) (results []interface{}) { return results[i].(int64) < results[j].(int64) case mysql.TypeDouble: return results[i].(float64) < results[j].(float64) + case mysql.TypeVarString: + return results[i].(string) < results[j].(string) default: panic("not implement") } @@ -102,6 +106,8 @@ func (mds *mockDataSource) randDatum(typ *types.FieldType) interface{} { return int64(rand.Int()) case mysql.TypeDouble: return rand.Float64() + case mysql.TypeVarString: + return rawData default: panic("not implement") } @@ -149,6 +155,8 @@ func buildMockDataSource(opt mockDataSourceParameters) *mockDataSource { m.genData[idx].AppendInt64(colIdx, colData[colIdx][i].(int64)) case mysql.TypeDouble: m.genData[idx].AppendFloat64(colIdx, colData[colIdx][i].(float64)) + case mysql.TypeVarString: + m.genData[idx].AppendString(colIdx, colData[colIdx][i].(string)) default: panic("not implement") } @@ -171,11 +179,12 @@ type aggTestCase struct { func (a aggTestCase) columns() []*expression.Column { return []*expression.Column{ {Index: 0, RetType: types.NewFieldType(mysql.TypeDouble)}, - {Index: 1, RetType: types.NewFieldType(mysql.TypeLonglong)}} + {Index: 1, RetType: types.NewFieldType(mysql.TypeLonglong)}, + } } func (a aggTestCase) String() string { - return fmt.Sprintf("(execType:%v, aggFunc:%v, groupByNDV:%v, hasDistinct:%v, rows:%v, concruuency:%v)", + return fmt.Sprintf("(execType:%v, aggFunc:%v, ndv:%v, hasDistinct:%v, rows:%v, concruuency:%v)", a.execType, a.aggFunc, a.groupByNDV, a.hasDistinct, a.rows, a.concurrency) } @@ -350,3 +359,147 @@ func BenchmarkAggDistinct(b *testing.B) { } } } + +func buildWindowExecutor(ctx sessionctx.Context, windowFunc string, src Executor, schema *expression.Schema, partitionBy []*expression.Column) Executor { + plan := new(core.PhysicalWindow) + + var args []expression.Expression + switch windowFunc { + case ast.WindowFuncNtile: + args = append(args, &expression.Constant{Value: types.NewUintDatum(2)}) + case ast.WindowFuncNthValue: + args = append(args, partitionBy[0], &expression.Constant{Value: types.NewUintDatum(2)}) + default: + args = append(args, partitionBy[0]) + } + desc, _ := aggregation.NewWindowFuncDesc(ctx, windowFunc, args) + plan.WindowFuncDescs = []*aggregation.WindowFuncDesc{desc} + for _, col := range partitionBy { + plan.PartitionBy = append(plan.PartitionBy, property.Item{Col: col}) + } + plan.OrderBy = nil + plan.SetSchema(schema) + plan.Init(ctx, nil) + plan.SetChildren(nil) + b := newExecutorBuilder(ctx, nil) + exec := b.build(plan) + window := exec.(*WindowExec) + window.children[0] = src + return exec +} + +type windowTestCase struct { + // The test table's schema is fixed (col Double, partitionBy LongLong, rawData VarString(5128), col LongLong). + windowFunc string + ndv int // the number of distinct group-by keys + rows int + ctx sessionctx.Context +} + +var rawData = strings.Repeat("x", 5*1024) + +func (a windowTestCase) columns() []*expression.Column { + rawDataTp := new(types.FieldType) + types.DefaultTypeForValue(rawData, rawDataTp) + return []*expression.Column{ + {Index: 0, RetType: types.NewFieldType(mysql.TypeDouble)}, + {Index: 1, RetType: types.NewFieldType(mysql.TypeLonglong)}, + {Index: 2, RetType: rawDataTp}, + {Index: 3, RetType: types.NewFieldType(mysql.TypeLonglong)}, + } +} + +func (a windowTestCase) String() string { + return fmt.Sprintf("(func:%v, ndv:%v, rows:%v)", + a.windowFunc, a.ndv, a.rows) +} + +func defaultWindowTestCase() *windowTestCase { + ctx := mock.NewContext() + ctx.GetSessionVars().InitChunkSize = variable.DefInitChunkSize + ctx.GetSessionVars().MaxChunkSize = variable.DefMaxChunkSize + return &windowTestCase{ast.WindowFuncRowNumber, 1000, 10000000, ctx} +} + +func benchmarkWindowExecWithCase(b *testing.B, casTest *windowTestCase) { + cols := casTest.columns() + dataSource := buildMockDataSource(mockDataSourceParameters{ + schema: expression.NewSchema(cols...), + ndvs: []int{0, casTest.ndv, 0, 0}, + orders: []bool{false, true, false, false}, + rows: casTest.rows, + ctx: casTest.ctx, + }) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + b.StopTimer() // prepare a new window-executor + childCols := casTest.columns() + schema := expression.NewSchema(childCols...) + windowExec := buildWindowExecutor(casTest.ctx, casTest.windowFunc, dataSource, schema, childCols[1:2]) + tmpCtx := context.Background() + chk := newFirstChunk(windowExec) + dataSource.prepareChunks() + + b.StartTimer() + if err := windowExec.Open(tmpCtx); err != nil { + b.Fatal(err) + } + for { + if err := windowExec.Next(tmpCtx, chk); err != nil { + b.Fatal(b) + } + if chk.NumRows() == 0 { + break + } + } + + if err := windowExec.Close(); err != nil { + b.Fatal(err) + } + b.StopTimer() + } +} + +func BenchmarkWindowRows(b *testing.B) { + b.ReportAllocs() + rows := []int{1000, 100000} + ndvs := []int{10, 1000} + for _, row := range rows { + for _, ndv := range ndvs { + cas := defaultWindowTestCase() + cas.rows = row + cas.ndv = ndv + cas.windowFunc = ast.WindowFuncRowNumber // cheapest + b.Run(fmt.Sprintf("%v", cas), func(b *testing.B) { + benchmarkWindowExecWithCase(b, cas) + }) + } + } +} + +func BenchmarkWindowFunctions(b *testing.B) { + b.ReportAllocs() + windowFuncs := []string{ + ast.WindowFuncRowNumber, + ast.WindowFuncRank, + ast.WindowFuncDenseRank, + ast.WindowFuncCumeDist, + ast.WindowFuncPercentRank, + ast.WindowFuncNtile, + ast.WindowFuncLead, + ast.WindowFuncLag, + ast.WindowFuncFirstValue, + ast.WindowFuncLastValue, + ast.WindowFuncNthValue, + } + for _, windowFunc := range windowFuncs { + cas := defaultWindowTestCase() + cas.rows = 100000 + cas.ndv = 1000 + cas.windowFunc = windowFunc + b.Run(fmt.Sprintf("%v", cas), func(b *testing.B) { + benchmarkWindowExecWithCase(b, cas) + }) + } +} diff --git a/executor/window_test.go b/executor/window_test.go index a5b2880550583..736be720cbbad 100644 --- a/executor/window_test.go +++ b/executor/window_test.go @@ -160,6 +160,12 @@ func (s *testSuite4) TestWindowFunctions(c *C) { ), ) + tk.MustExec("CREATE TABLE td_dec (id DECIMAL(10,2), sex CHAR(1));") + tk.MustExec("insert into td_dec value (2.0, 'F'), (NULL, 'F'), (1.0, 'F')") + tk.MustQuery("SELECT id, FIRST_VALUE(id) OVER w FROM td_dec WINDOW w AS (ORDER BY id);").Check( + testkit.Rows(" ", "1.00 ", "2.00 "), + ) + result = tk.MustQuery("select sum(a) over w, sum(b) over w from t window w as (order by a)") result.Check(testkit.Rows("2 3", "2 3", "6 6", "6 6")) result = tk.MustQuery("select row_number() over w, sum(b) over w from t window w as (order by a)") @@ -171,3 +177,19 @@ func (s *testSuite4) TestWindowFunctions(c *C) { result = tk.MustQuery("select a, row_number() over (partition by a) from t") result.Check(testkit.Rows("1 1", "1 2", "2 1", "2 2")) } + +func (s *testSuite4) TestWindowFunctionsIssue11614(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a int, b int)") + tk.MustExec("insert into t values (2,1),(2,2),(2,3)") + + tk.Se.GetSessionVars().MaxChunkSize = 2 + result := tk.MustQuery("select a, b, rank() over (partition by a order by b) from t") + result.Check(testkit.Rows("2 1 1", "2 2 2", "2 3 3")) + result = tk.MustQuery("select a, b, PERCENT_RANK() over (partition by a order by b) from t") + result.Check(testkit.Rows("2 1 0", "2 2 0.5", "2 3 1")) + result = tk.MustQuery("select a, b, CUME_DIST() over (partition by a order by b) from t") + result.Check(testkit.Rows("2 1 0.3333333333333333", "2 2 0.6666666666666666", "2 3 1")) +} diff --git a/session/tidb.go b/session/tidb.go index 1bcfed18ee25f..eaa5aedfbe39d 100644 --- a/session/tidb.go +++ b/session/tidb.go @@ -275,10 +275,8 @@ func GetRows4Test(ctx context.Context, sctx sessionctx.Context, rs sqlexec.Recor } var rows []chunk.Row req := rs.NewChunk() + // Must reuse `req` for imitating server.(*clientConn).writeChunks for { - // Since we collect all the rows, we can not reuse the chunk. - iter := chunk.NewIterator4Chunk(req) - err := rs.Next(ctx, req) if err != nil { return nil, err @@ -287,10 +285,10 @@ func GetRows4Test(ctx context.Context, sctx sessionctx.Context, rs sqlexec.Recor break } + iter := chunk.NewIterator4Chunk(req.CopyConstruct()) for row := iter.Begin(); row != iter.End(); row = iter.Next() { rows = append(rows, row) } - req = chunk.Renew(req, sctx.GetSessionVars().MaxChunkSize) } return rows, nil } diff --git a/types/mydecimal.go b/types/mydecimal.go index 996323fa8706f..1823bacdb2b25 100644 --- a/types/mydecimal.go +++ b/types/mydecimal.go @@ -250,6 +250,21 @@ func (d *MyDecimal) GetDigitsFrac() int8 { return d.digitsFrac } +// Copy copies a new *MyDecimal from itself. +func (d *MyDecimal) Copy() *MyDecimal { + if d == nil { + return nil + } + dst := &MyDecimal{ + digitsInt: d.digitsInt, + digitsFrac: d.digitsFrac, + resultFrac: d.resultFrac, + negative: d.negative, + } + copy(dst.wordBuf[:], d.wordBuf[:]) + return dst +} + // String returns the decimal string representation rounded to resultFrac. func (d *MyDecimal) String() string { tmp := *d diff --git a/types/mydecimal_test.go b/types/mydecimal_test.go index 551105987b3d0..b2cb0acfd6b3c 100644 --- a/types/mydecimal_test.go +++ b/types/mydecimal_test.go @@ -14,6 +14,7 @@ package types import ( + "reflect" "strings" "testing" @@ -540,6 +541,33 @@ func (s *testMyDecimalSuite) TestToString(c *C) { } } +func (s *testMyDecimalSuite) TestCopy(c *C) { + type tcase struct { + input string + } + tests := []tcase{ + {".0"}, + {".123"}, + {"123.123"}, + {"123."}, + {"123"}, + {"123.1230"}, + {"-123.1230"}, + {"00123.123"}, + } + for _, ca := range tests { + var dec MyDecimal + err := dec.FromString([]byte(ca.input)) + c.Assert(err, IsNil) + + dec2 := dec.Copy() + c.Assert(reflect.DeepEqual(dec, *dec2), IsTrue) + } + + var dec *MyDecimal + c.Assert(dec.Copy(), IsNil) +} + func (s *testMyDecimalSuite) TestToBinFromBin(c *C) { type tcase struct { input string