diff --git a/statistics/BUILD.bazel b/statistics/BUILD.bazel index e6992020197c3..2a26bfb76f34d 100644 --- a/statistics/BUILD.bazel +++ b/statistics/BUILD.bazel @@ -66,6 +66,7 @@ go_test( name = "statistics_test", timeout = "short", srcs = [ + "cmsketch_bench_test.go", "cmsketch_test.go", "feedback_test.go", "fmsketch_test.go", diff --git a/statistics/cmsketch_bench_test.go b/statistics/cmsketch_bench_test.go new file mode 100644 index 0000000000000..08666c4c2c3db --- /dev/null +++ b/statistics/cmsketch_bench_test.go @@ -0,0 +1,161 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package statistics_test + +import ( + "fmt" + "testing" + "time" + + "github.com/pingcap/tidb/parser/mysql" + "github.com/pingcap/tidb/sessionctx/stmtctx" + "github.com/pingcap/tidb/statistics" + "github.com/pingcap/tidb/statistics/handle" + "github.com/pingcap/tidb/types" + "github.com/pingcap/tidb/util/chunk" + "github.com/pingcap/tidb/util/codec" + "github.com/stretchr/testify/require" +) + +// cmd: go test -run=^$ -bench=BenchmarkMergePartTopN2GlobalTopNWithHists -benchmem github.com/pingcap/tidb/statistics +func benchmarkMergePartTopN2GlobalTopNWithHists(partitions int, b *testing.B) { + loc := time.UTC + sc := &stmtctx.StatementContext{TimeZone: loc} + version := 1 + isKilled := uint32(0) + + // Prepare TopNs. 
+ topNs := make([]*statistics.TopN, 0, partitions) + for i := 0; i < partitions; i++ { + // Construct TopN, should be key1 -> 2, key2 -> 2, key3 -> 3. + topN := statistics.NewTopN(3) + { + key1, err := codec.EncodeKey(sc, nil, types.NewIntDatum(1)) + require.NoError(b, err) + topN.AppendTopN(key1, 2) + key2, err := codec.EncodeKey(sc, nil, types.NewIntDatum(2)) + require.NoError(b, err) + topN.AppendTopN(key2, 2) + if i%2 == 0 { + key3, err := codec.EncodeKey(sc, nil, types.NewIntDatum(3)) + require.NoError(b, err) + topN.AppendTopN(key3, 3) + } + } + topNs = append(topNs, topN) + } + + // Prepare Hists. + hists := make([]*statistics.Histogram, 0, partitions) + for i := 0; i < partitions; i++ { + // Construct Hist + h := statistics.NewHistogram(1, 10, 0, 0, types.NewFieldType(mysql.TypeTiny), chunk.InitialCapacity, 0) + h.Bounds.AppendInt64(0, 1) + h.Buckets = append(h.Buckets, statistics.Bucket{Repeat: 10, Count: 20}) + h.Bounds.AppendInt64(0, 2) + h.Buckets = append(h.Buckets, statistics.Bucket{Repeat: 10, Count: 30}) + h.Bounds.AppendInt64(0, 3) + h.Buckets = append(h.Buckets, statistics.Bucket{Repeat: 10, Count: 30}) + h.Bounds.AppendInt64(0, 4) + h.Buckets = append(h.Buckets, statistics.Bucket{Repeat: 10, Count: 40}) + hists = append(hists, h) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + // Benchmark merge 10 topN. + _, _, _, _ = statistics.MergePartTopN2GlobalTopN(loc, version, topNs, 10, hists, false, &isKilled) + } +} + +// cmd: go test -run=^$ -bench=BenchmarkMergeGlobalStatsTopNByConcurrencyWithHists -benchmem github.com/pingcap/tidb/statistics +func benchmarkMergeGlobalStatsTopNByConcurrencyWithHists(partitions int, b *testing.B) { + loc := time.UTC + sc := &stmtctx.StatementContext{TimeZone: loc} + version := 1 + isKilled := uint32(0) + + // Prepare TopNs. + topNs := make([]*statistics.TopN, 0, partitions) + for i := 0; i < partitions; i++ { + // Construct TopN, should be key1 -> 2, key2 -> 2, key3 -> 3. 
+ topN := statistics.NewTopN(3) + { + key1, err := codec.EncodeKey(sc, nil, types.NewIntDatum(1)) + require.NoError(b, err) + topN.AppendTopN(key1, 2) + key2, err := codec.EncodeKey(sc, nil, types.NewIntDatum(2)) + require.NoError(b, err) + topN.AppendTopN(key2, 2) + if i%2 == 0 { + key3, err := codec.EncodeKey(sc, nil, types.NewIntDatum(3)) + require.NoError(b, err) + topN.AppendTopN(key3, 3) + } + } + topNs = append(topNs, topN) + } + + // Prepare Hists. + hists := make([]*statistics.Histogram, 0, partitions) + for i := 0; i < partitions; i++ { + // Construct Hist + h := statistics.NewHistogram(1, 10, 0, 0, types.NewFieldType(mysql.TypeTiny), chunk.InitialCapacity, 0) + h.Bounds.AppendInt64(0, 1) + h.Buckets = append(h.Buckets, statistics.Bucket{Repeat: 10, Count: 20}) + h.Bounds.AppendInt64(0, 2) + h.Buckets = append(h.Buckets, statistics.Bucket{Repeat: 10, Count: 30}) + h.Bounds.AppendInt64(0, 3) + h.Buckets = append(h.Buckets, statistics.Bucket{Repeat: 10, Count: 30}) + h.Bounds.AppendInt64(0, 4) + h.Buckets = append(h.Buckets, statistics.Bucket{Repeat: 10, Count: 40}) + hists = append(hists, h) + } + wrapper := &statistics.StatsWrapper{ + AllTopN: topNs, + AllHg: hists, + } + const mergeConcurrency = 4 + batchSize := len(wrapper.AllTopN) / mergeConcurrency + if batchSize < 1 { + batchSize = 1 + } else if batchSize > handle.MaxPartitionMergeBatchSize { + batchSize = handle.MaxPartitionMergeBatchSize + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + // Benchmark merge 10 topN. 
+ _, _, _, _ = handle.MergeGlobalStatsTopNByConcurrency(mergeConcurrency, batchSize, wrapper, loc, version, 10, false, &isKilled) + } +} + +var benchmarkSizes = []int{100, 1000, 10000, 100000, 1000000, 10000000} +var benchmarkConcurrencySizes = []int{100, 1000, 10000, 100000, 1000000, 10000000, 100000000} + +func BenchmarkMergePartTopN2GlobalTopNWithHists(b *testing.B) { + for _, size := range benchmarkSizes { + b.Run(fmt.Sprintf("Size%d", size), func(b *testing.B) { + benchmarkMergePartTopN2GlobalTopNWithHists(size, b) + }) + } +} + +func BenchmarkMergeGlobalStatsTopNByConcurrencyWithHists(b *testing.B) { + for _, size := range benchmarkConcurrencySizes { + b.Run(fmt.Sprintf("Size%d", size), func(b *testing.B) { + benchmarkMergeGlobalStatsTopNByConcurrencyWithHists(size, b) + }) + } +} diff --git a/statistics/cmsketch_test.go b/statistics/cmsketch_test.go index 9675b0d0dcd50..1585342d8826b 100644 --- a/statistics/cmsketch_test.go +++ b/statistics/cmsketch_test.go @@ -306,3 +306,87 @@ func TestCMSketchCodingTopN(t *testing.T) { _, _, err = DecodeCMSketchAndTopN([]byte{}, rows) require.NoError(t, err) } + +func TestMergePartTopN2GlobalTopNWithoutHists(t *testing.T) { + loc := time.UTC + sc := &stmtctx.StatementContext{TimeZone: loc} + version := 1 + isKilled := uint32(0) + + // Prepare TopNs. + topNs := make([]*TopN, 0, 10) + for i := 0; i < 10; i++ { + // Construct TopN, should be key(1, 1) -> 2, key(1, 2) -> 2, key(1, 3) -> 3. + topN := NewTopN(3) + { + key1, err := codec.EncodeKey(sc, nil, types.NewIntDatum(1), types.NewIntDatum(1)) + require.NoError(t, err) + topN.AppendTopN(key1, 2) + key2, err := codec.EncodeKey(sc, nil, types.NewIntDatum(1), types.NewIntDatum(2)) + require.NoError(t, err) + topN.AppendTopN(key2, 2) + key3, err := codec.EncodeKey(sc, nil, types.NewIntDatum(1), types.NewIntDatum(3)) + require.NoError(t, err) + topN.AppendTopN(key3, 3) + } + topNs = append(topNs, topN) + } + + // Test merge 2 topN with nil hists. 
+ globalTopN, leftTopN, _, err := MergePartTopN2GlobalTopN(loc, version, topNs, 2, nil, false, &isKilled) + require.NoError(t, err) + require.Len(t, globalTopN.TopN, 2, "should only have 2 topN") + require.Equal(t, uint64(50), globalTopN.TotalCount(), "should have 50 rows") + require.Len(t, leftTopN, 1, "should have 1 left topN") +} + +func TestMergePartTopN2GlobalTopNWithHists(t *testing.T) { + loc := time.UTC + sc := &stmtctx.StatementContext{TimeZone: loc} + version := 1 + isKilled := uint32(0) + + // Prepare TopNs. + topNs := make([]*TopN, 0, 10) + for i := 0; i < 10; i++ { + // Construct TopN, should be key1 -> 2, key2 -> 2, key3 -> 3. + topN := NewTopN(3) + { + key1, err := codec.EncodeKey(sc, nil, types.NewIntDatum(1)) + require.NoError(t, err) + topN.AppendTopN(key1, 2) + key2, err := codec.EncodeKey(sc, nil, types.NewIntDatum(2)) + require.NoError(t, err) + topN.AppendTopN(key2, 2) + if i%2 == 0 { + key3, err := codec.EncodeKey(sc, nil, types.NewIntDatum(3)) + require.NoError(t, err) + topN.AppendTopN(key3, 3) + } + } + topNs = append(topNs, topN) + } + + // Prepare Hists. + hists := make([]*Histogram, 0, 10) + for i := 0; i < 10; i++ { + // Construct Hist + h := NewHistogram(1, 10, 0, 0, types.NewFieldType(mysql.TypeTiny), chunk.InitialCapacity, 0) + h.Bounds.AppendInt64(0, 1) + h.Buckets = append(h.Buckets, Bucket{Repeat: 10, Count: 20}) + h.Bounds.AppendInt64(0, 2) + h.Buckets = append(h.Buckets, Bucket{Repeat: 10, Count: 30}) + h.Bounds.AppendInt64(0, 3) + h.Buckets = append(h.Buckets, Bucket{Repeat: 10, Count: 30}) + h.Bounds.AppendInt64(0, 4) + h.Buckets = append(h.Buckets, Bucket{Repeat: 10, Count: 40}) + hists = append(hists, h) + } + + // Test merge 2 topN. 
+ globalTopN, leftTopN, _, err := MergePartTopN2GlobalTopN(loc, version, topNs, 2, hists, false, &isKilled) + require.NoError(t, err) + require.Len(t, globalTopN.TopN, 2, "should only have 2 topN") + require.Equal(t, uint64(55), globalTopN.TotalCount(), "should have 55") + require.Len(t, leftTopN, 1, "should have 1 left topN") +} diff --git a/statistics/handle/handle.go b/statistics/handle/handle.go index a99bcad173dca..26db0a5d4d3de 100644 --- a/statistics/handle/handle.go +++ b/statistics/handle/handle.go @@ -59,8 +59,8 @@ const ( // TiDBGlobalStats represents the global-stats for a partitioned table. TiDBGlobalStats = "global" - // maxPartitionMergeBatchSize indicates the max batch size for a worker to merge partition stats - maxPartitionMergeBatchSize = 256 + // MaxPartitionMergeBatchSize indicates the max batch size for a worker to merge partition stats + MaxPartitionMergeBatchSize = 256 ) // Handle can update stats info periodically. @@ -831,7 +831,7 @@ func (h *Handle) mergePartitionStats2GlobalStats(sc sessionctx.Context, // These remaining topN numbers will be used as a separate bucket for later histogram merging. 
var popedTopN []statistics.TopNMeta wrapper := statistics.NewStatsWrapper(allHg[i], allTopN[i]) - globalStats.TopN[i], popedTopN, allHg[i], err = h.mergeGlobalStatsTopN(sc, wrapper, sc.GetSessionVars().StmtCtx.TimeZone, sc.GetSessionVars().AnalyzeVersion, uint32(opts[ast.AnalyzeOptNumTopN]), isIndex == 1) + globalStats.TopN[i], popedTopN, allHg[i], err = mergeGlobalStatsTopN(sc, wrapper, sc.GetSessionVars().StmtCtx.TimeZone, sc.GetSessionVars().AnalyzeVersion, uint32(opts[ast.AnalyzeOptNumTopN]), isIndex == 1) if err != nil { return } @@ -863,7 +863,7 @@ func (h *Handle) mergePartitionStats2GlobalStats(sc sessionctx.Context, return } -func (h *Handle) mergeGlobalStatsTopN(sc sessionctx.Context, wrapper *statistics.StatsWrapper, +func mergeGlobalStatsTopN(sc sessionctx.Context, wrapper *statistics.StatsWrapper, timeZone *time.Location, version int, n uint32, isIndex bool) (*statistics.TopN, []statistics.TopNMeta, []*statistics.Histogram, error) { mergeConcurrency := sc.GetSessionVars().AnalyzePartitionMergeConcurrency @@ -875,17 +875,17 @@ func (h *Handle) mergeGlobalStatsTopN(sc sessionctx.Context, wrapper *statistics batchSize := len(wrapper.AllTopN) / mergeConcurrency if batchSize < 1 { batchSize = 1 - } else if batchSize > maxPartitionMergeBatchSize { - batchSize = maxPartitionMergeBatchSize + } else if batchSize > MaxPartitionMergeBatchSize { + batchSize = MaxPartitionMergeBatchSize } - return h.mergeGlobalStatsTopNByConcurrency(mergeConcurrency, batchSize, wrapper, timeZone, version, n, isIndex, killed) + return MergeGlobalStatsTopNByConcurrency(mergeConcurrency, batchSize, wrapper, timeZone, version, n, isIndex, killed) } -// mergeGlobalStatsTopNByConcurrency merge partition topN by concurrency +// MergeGlobalStatsTopNByConcurrency merge partition topN by concurrency // To merge global stats topn by concurrency, we will separate the partition topn in concurrency part and deal it with different worker. 
// mergeConcurrency is used to control the total concurrency of the running worker, and mergeBatchSize is used to control // the partition size for each worker to solve it -func (h *Handle) mergeGlobalStatsTopNByConcurrency(mergeConcurrency, mergeBatchSize int, wrapper *statistics.StatsWrapper, +func MergeGlobalStatsTopNByConcurrency(mergeConcurrency, mergeBatchSize int, wrapper *statistics.StatsWrapper, timeZone *time.Location, version int, n uint32, isIndex bool, killed *uint32) (*statistics.TopN, []statistics.TopNMeta, []*statistics.Histogram, error) { if len(wrapper.AllTopN) < mergeConcurrency {