From 210bc420e2131753340fd7b0ac7a4039ab0847c5 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 19 Dec 2023 06:19:55 +0000 Subject: [PATCH 1/9] build(deps): bump github.com/prometheus/prometheus from 0.48.0 to 0.48.1 (#49488) --- DEPS.bzl | 12 ++++++------ go.mod | 2 +- go.sum | 4 ++-- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/DEPS.bzl b/DEPS.bzl index 26a3771a8d075..b678f49f170de 100644 --- a/DEPS.bzl +++ b/DEPS.bzl @@ -6109,13 +6109,13 @@ def go_deps(): name = "com_github_prometheus_prometheus", build_file_proto_mode = "disable_global", importpath = "github.com/prometheus/prometheus", - sha256 = "57ac0b06c05da5d42f831e52250f3bc63d2fc6785cd9f21ca79534f1900aeb19", - strip_prefix = "github.com/prometheus/prometheus@v0.48.0", + sha256 = "942dba743bc78a6933cc9c2fbcc3d1d301254d3fd343975476ecd73573866f6e", + strip_prefix = "github.com/prometheus/prometheus@v0.48.1", urls = [ - "http://bazel-cache.pingcap.net:8080/gomod/github.com/prometheus/prometheus/com_github_prometheus_prometheus-v0.48.0.zip", - "http://ats.apps.svc/gomod/github.com/prometheus/prometheus/com_github_prometheus_prometheus-v0.48.0.zip", - "https://cache.hawkingrei.com/gomod/github.com/prometheus/prometheus/com_github_prometheus_prometheus-v0.48.0.zip", - "https://storage.googleapis.com/pingcapmirror/gomod/github.com/prometheus/prometheus/com_github_prometheus_prometheus-v0.48.0.zip", + "http://bazel-cache.pingcap.net:8080/gomod/github.com/prometheus/prometheus/com_github_prometheus_prometheus-v0.48.1.zip", + "http://ats.apps.svc/gomod/github.com/prometheus/prometheus/com_github_prometheus_prometheus-v0.48.1.zip", + "https://cache.hawkingrei.com/gomod/github.com/prometheus/prometheus/com_github_prometheus_prometheus-v0.48.1.zip", + "https://storage.googleapis.com/pingcapmirror/gomod/github.com/prometheus/prometheus/com_github_prometheus_prometheus-v0.48.1.zip", ], ) go_repository( diff --git a/go.mod b/go.mod index b3c440ebf4b75..e969f60fcaa6e 100644 --- a/go.mod +++ b/go.mod @@ -89,7 +89,7 @@ require ( github.com/prometheus/client_golang v1.17.0 github.com/prometheus/client_model v0.5.0 github.com/prometheus/common v0.45.0 - github.com/prometheus/prometheus v0.48.0 + github.com/prometheus/prometheus v0.48.1 github.com/robfig/cron/v3 v3.0.1 github.com/sasha-s/go-deadlock v0.2.0 github.com/shirou/gopsutil/v3 v3.23.10 diff --git a/go.sum b/go.sum index 0428b3dd7e46e..aeafaaa2a5b28 100644 --- a/go.sum +++ b/go.sum @@ -750,8 +750,8 @@ github.com/prometheus/common/sigv4 v0.1.0/go.mod h1:2Jkxxk9yYvCkE5G1sQT7GuEXm57J github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= github.com/prometheus/procfs v0.12.0 h1:jluTpSng7V9hY0O2R9DzzJHYb2xULk9VTR1V1R/k6Bo= github.com/prometheus/procfs v0.12.0/go.mod h1:pcuDEFsWDnvcgNzo4EEweacyhjeA9Zk3cnaOZAZEfOo= -github.com/prometheus/prometheus v0.48.0 h1:yrBloImGQ7je4h8M10ujGh4R6oxYQJQKlMuETwNskGk= -github.com/prometheus/prometheus v0.48.0/go.mod h1:SRw624aMAxTfryAcP8rOjg4S/sHHaetx2lyJJ2nM83g= +github.com/prometheus/prometheus v0.48.1 h1:CTszphSNTXkuCG6O0IfpKdHcJkvvnAAE1GbELKS+NFk= +github.com/prometheus/prometheus v0.48.1/go.mod h1:SRw624aMAxTfryAcP8rOjg4S/sHHaetx2lyJJ2nM83g= github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 h1:N/ElC8H3+5XpJzTSTfLsJV/mx9Q9g7kxmchpfZyxgzM= github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4= github.com/remyoudompheng/bigfft 
v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= From 971b611fbb3f2c4884fc5d912621a8b97132c701 Mon Sep 17 00:00:00 2001 From: EasonBall <592838129@qq.com> Date: Tue, 19 Dec 2023 16:01:52 +0800 Subject: [PATCH 2/9] global sort: new merge with 1 writer (#49474) ref pingcap/tidb#48779 --- br/pkg/lightning/backend/external/BUILD.bazel | 4 + .../lightning/backend/external/bench_test.go | 5 +- br/pkg/lightning/backend/external/engine.go | 2 +- br/pkg/lightning/backend/external/merge.go | 66 +----- br/pkg/lightning/backend/external/merge_v2.go | 190 ++++++++++++++++++ .../backend/external/onefile_writer.go | 3 +- .../backend/external/onefile_writer_test.go | 45 ++--- .../lightning/backend/external/reader_test.go | 3 +- .../lightning/backend/external/sort_test.go | 106 +++++++++- br/pkg/lightning/backend/external/testutil.go | 36 +++- br/pkg/lightning/backend/external/writer.go | 2 - .../lightning/backend/external/writer_test.go | 62 ++++++ 12 files changed, 419 insertions(+), 105 deletions(-) create mode 100644 br/pkg/lightning/backend/external/merge_v2.go diff --git a/br/pkg/lightning/backend/external/BUILD.bazel b/br/pkg/lightning/backend/external/BUILD.bazel index fd540a71a8f2e..89950d8348413 100644 --- a/br/pkg/lightning/backend/external/BUILD.bazel +++ b/br/pkg/lightning/backend/external/BUILD.bazel @@ -11,6 +11,7 @@ go_library( "iter.go", "kv_reader.go", "merge.go", + "merge_v2.go", "onefile_writer.go", "reader.go", "split.go", @@ -73,6 +74,7 @@ go_test( deps = [ "//br/pkg/lightning/backend/kv", "//br/pkg/lightning/common", + "//br/pkg/lightning/log", "//br/pkg/membuf", "//br/pkg/storage", "//pkg/kv", @@ -91,10 +93,12 @@ go_test( "@com_github_johannesboyne_gofakes3//:gofakes3", "@com_github_johannesboyne_gofakes3//backend/s3mem", "@com_github_pingcap_errors//:errors", + "@com_github_pingcap_failpoint//:failpoint", "@com_github_pingcap_kvproto//pkg/brpb", "@com_github_stretchr_testify//require", "@org_golang_x_exp//rand", "@org_golang_x_sync//errgroup", "@org_uber_go_atomic//:atomic", + "@org_uber_go_zap//:zap", ], ) diff --git a/br/pkg/lightning/backend/external/bench_test.go b/br/pkg/lightning/backend/external/bench_test.go index 95b00aafa50d0..fe7cfcd6327b1 100644 --- a/br/pkg/lightning/backend/external/bench_test.go +++ b/br/pkg/lightning/backend/external/bench_test.go @@ -361,7 +361,7 @@ func writeExternalOneFile(s *writeTestSuite) { } writer := builder.BuildOneFile( s.store, filePath, "writerID") - _ = writer.Init(ctx, 20*1024*1024) + intest.AssertNoError(writer.Init(ctx, 20*1024*1024)) key, val, _ := s.source.next() for key != nil { err := writer.WriteRow(ctx, key, val) @@ -371,8 +371,7 @@ func writeExternalOneFile(s *writeTestSuite) { if s.beforeWriterClose != nil { s.beforeWriterClose() } - err := writer.Close(ctx) - intest.AssertNoError(err) + intest.AssertNoError(writer.Close(ctx)) if s.afterWriterClose != nil { s.afterWriterClose() } diff --git a/br/pkg/lightning/backend/external/engine.go b/br/pkg/lightning/backend/external/engine.go index f51289b96b405..8ff4c5236bd6b 100644 --- a/br/pkg/lightning/backend/external/engine.go +++ b/br/pkg/lightning/backend/external/engine.go @@ -193,7 +193,7 @@ func getFilesReadConcurrency( for i := range statsFiles { result[i] = (endOffs[i] - startOffs[i]) / uint64(ConcurrentReaderBufferSizePerConc) result[i] = max(result[i], 1) - logutil.Logger(ctx).Info("found hotspot file in getFilesReadConcurrency", + logutil.Logger(ctx).Debug("found hotspot file in getFilesReadConcurrency", zap.String("filename", 
statsFiles[i]), zap.Uint64("startOffset", startOffs[i]), zap.Uint64("endOffset", endOffs[i]), diff --git a/br/pkg/lightning/backend/external/merge.go b/br/pkg/lightning/backend/external/merge.go index 3a931d55a0000..03c1a57ea0cad 100644 --- a/br/pkg/lightning/backend/external/merge.go +++ b/br/pkg/lightning/backend/external/merge.go @@ -53,7 +53,7 @@ func MergeOverlappingFiles( for _, files := range dataFilesSlice { files := files eg.Go(func() error { - return mergeOverlappingFilesV2( + return mergeOverlappingFilesInternal( egCtx, files, store, @@ -74,68 +74,9 @@ func MergeOverlappingFiles( return eg.Wait() } -// unused for now. -func mergeOverlappingFilesImpl(ctx context.Context, - paths []string, - store storage.ExternalStorage, - readBufferSize int, - newFilePrefix string, - writerID string, - memSizeLimit uint64, - blockSize int, - writeBatchCount uint64, - propSizeDist uint64, - propKeysDist uint64, - onClose OnCloseFunc, - checkHotspot bool, -) (err error) { - task := log.BeginTask(logutil.Logger(ctx).With( - zap.String("writer-id", writerID), - zap.Int("file-count", len(paths)), - ), "merge overlapping files") - defer func() { - task.End(zap.ErrorLevel, err) - }() - - zeroOffsets := make([]uint64, len(paths)) - iter, err := NewMergeKVIter(ctx, paths, zeroOffsets, store, readBufferSize, checkHotspot, 0) - if err != nil { - return err - } - defer func() { - err := iter.Close() - if err != nil { - logutil.Logger(ctx).Warn("close iterator failed", zap.Error(err)) - } - }() - - writer := NewWriterBuilder(). - SetMemorySizeLimit(memSizeLimit). - SetBlockSize(blockSize). - SetOnCloseFunc(onClose). - SetWriterBatchCount(writeBatchCount). - SetPropSizeDistance(propSizeDist). - SetPropKeysDistance(propKeysDist). - Build(store, newFilePrefix, writerID) - - // currently use same goroutine to do read and write. The main advantage is - // there's no KV copy and iter can reuse the buffer. - for iter.Next() { - err = writer.WriteRow(ctx, iter.Key(), iter.Value(), nil) - if err != nil { - return err - } - } - err = iter.Error() - if err != nil { - return err - } - return writer.Close(ctx) -} - -// mergeOverlappingFilesV2 reads from given files whose key range may overlap +// mergeOverlappingFilesInternal reads from given files whose key range may overlap // and writes to one new sorted, nonoverlapping files. -func mergeOverlappingFilesV2( +func mergeOverlappingFilesInternal( ctx context.Context, paths []string, store storage.ExternalStorage, @@ -177,7 +118,6 @@ func mergeOverlappingFilesV2( SetWriterBatchCount(writeBatchCount). SetPropKeysDistance(propKeysDist). SetPropSizeDistance(propSizeDist). - SetOnCloseFunc(onClose). BuildOneFile(store, newFilePrefix, writerID) err = writer.Init(ctx, partSize) if err != nil { diff --git a/br/pkg/lightning/backend/external/merge_v2.go b/br/pkg/lightning/backend/external/merge_v2.go new file mode 100644 index 0000000000000..24ee8e5c57afa --- /dev/null +++ b/br/pkg/lightning/backend/external/merge_v2.go @@ -0,0 +1,190 @@ +package external + +import ( + "bytes" + "context" + "math" + "time" + + "github.com/jfcg/sorty/v2" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/br/pkg/lightning/log" + "github.com/pingcap/tidb/br/pkg/membuf" + "github.com/pingcap/tidb/br/pkg/storage" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/util/logutil" + "github.com/pingcap/tidb/pkg/util/size" + "go.uber.org/zap" +) + +// MergeOverlappingFilesV2 reads from given files whose key range may overlap +// and writes to new sorted, nonoverlapping files. 
+// Using 1 readAllData and 1 writer. +func MergeOverlappingFilesV2( + ctx context.Context, + dataFiles []string, + statFiles []string, + store storage.ExternalStorage, + startKey []byte, + endKey []byte, + partSize int64, + newFilePrefix string, + writerID string, + blockSize int, + writeBatchCount uint64, + propSizeDist uint64, + propKeysDist uint64, + onClose OnCloseFunc, + concurrency int, + checkHotspot bool, +) (err error) { + task := log.BeginTask(logutil.Logger(ctx).With( + zap.Int("data-file-count", len(dataFiles)), + zap.Int("stat-file-count", len(statFiles)), + zap.Binary("start-key", startKey), + zap.Binary("end-key", endKey), + zap.String("new-file-prefix", newFilePrefix), + zap.Int("concurrency", concurrency), + ), "merge overlapping files") + defer func() { + task.End(zap.ErrorLevel, err) + }() + + rangesGroupSize := 4 * size.GB + failpoint.Inject("mockRangesGroupSize", func(val failpoint.Value) { + rangesGroupSize = uint64(val.(int)) + }) + + splitter, err := NewRangeSplitter( + ctx, + dataFiles, + statFiles, + store, + int64(rangesGroupSize), + math.MaxInt64, + int64(4*size.GB), + math.MaxInt64, + checkHotspot, + ) + if err != nil { + return err + } + + writer := NewWriterBuilder(). + SetMemorySizeLimit(DefaultMemSizeLimit). + SetBlockSize(blockSize). + SetPropKeysDistance(propKeysDist). + SetPropSizeDistance(propSizeDist). + SetOnCloseFunc(onClose). + BuildOneFile( + store, + newFilePrefix, + writerID) + defer func() { + err = splitter.Close() + if err != nil { + logutil.Logger(ctx).Warn("close range splitter failed", zap.Error(err)) + } + err = writer.Close(ctx) + if err != nil { + logutil.Logger(ctx).Warn("close writer failed", zap.Error(err)) + } + }() + + err = writer.Init(ctx, partSize) + if err != nil { + logutil.Logger(ctx).Warn("init writer failed", zap.Error(err)) + return + } + + bufPool := membuf.NewPool() + loaded := &memKVsAndBuffers{} + curStart := kv.Key(startKey).Clone() + var curEnd kv.Key + var maxKey, minKey kv.Key + + for { + endKeyOfGroup, dataFilesOfGroup, statFilesOfGroup, _, err1 := splitter.SplitOneRangesGroup() + if err1 != nil { + logutil.Logger(ctx).Warn("split one ranges group failed", zap.Error(err1)) + return + } + curEnd = kv.Key(endKeyOfGroup).Clone() + if len(endKeyOfGroup) == 0 { + curEnd = kv.Key(endKey).Clone() + } + now := time.Now() + err1 = readAllData( + ctx, + store, + dataFilesOfGroup, + statFilesOfGroup, + curStart, + curEnd, + bufPool, + loaded, + ) + if err1 != nil { + logutil.Logger(ctx).Warn("read all data failed", zap.Error(err1)) + return + } + readTime := time.Since(now) + now = time.Now() + sorty.MaxGor = uint64(concurrency) + sorty.Sort(len(loaded.keys), func(i, k, r, s int) bool { + if bytes.Compare(loaded.keys[i], loaded.keys[k]) < 0 { // strict comparator like < or > + if r != s { + loaded.keys[r], loaded.keys[s] = loaded.keys[s], loaded.keys[r] + loaded.values[r], loaded.values[s] = loaded.values[s], loaded.values[r] + } + return true + } + return false + }) + sortTime := time.Since(now) + now = time.Now() + for i, key := range loaded.keys { + err1 = writer.WriteRow(ctx, key, loaded.values[i]) + if err1 != nil { + logutil.Logger(ctx).Warn("write one row to writer failed", zap.Error(err1)) + return + } + } + writeTime := time.Since(now) + logutil.Logger(ctx).Info("sort one group in MergeOverlappingFiles", + zap.Duration("read time", readTime), + zap.Duration("sort time", sortTime), + zap.Duration("write time", writeTime), + zap.Int("key len", len(loaded.keys))) + + if len(minKey) == 0 { + minKey = 
kv.Key(loaded.keys[0]).Clone() + } + maxKey = kv.Key(loaded.keys[len(loaded.keys)-1]).Clone() + curStart = curEnd.Clone() + loaded.keys = nil + loaded.values = nil + loaded.memKVBuffers = nil + + if len(endKeyOfGroup) == 0 { + break + } + } + + var stat MultipleFilesStat + stat.Filenames = append(stat.Filenames, + [2]string{writer.dataFile, writer.statFile}) + + stat.build([]kv.Key{minKey}, []kv.Key{curEnd}) + if onClose != nil { + onClose(&WriterSummary{ + WriterID: writer.writerID, + Seq: 0, + Min: minKey, + Max: maxKey, + TotalSize: writer.totalSize, + MultipleFilesStats: []MultipleFilesStat{stat}, + }) + } + return +} diff --git a/br/pkg/lightning/backend/external/onefile_writer.go b/br/pkg/lightning/backend/external/onefile_writer.go index 996f336909fbc..06d2f5f86df95 100644 --- a/br/pkg/lightning/backend/external/onefile_writer.go +++ b/br/pkg/lightning/backend/external/onefile_writer.go @@ -47,8 +47,7 @@ type OneFileWriter struct { dataWriter storage.ExternalFileWriter statWriter storage.ExternalFileWriter - onClose OnCloseFunc - closed bool + closed bool logger *zap.Logger } diff --git a/br/pkg/lightning/backend/external/onefile_writer_test.go b/br/pkg/lightning/backend/external/onefile_writer_test.go index 3999543a82ae7..c347dd1b3395d 100644 --- a/br/pkg/lightning/backend/external/onefile_writer_test.go +++ b/br/pkg/lightning/backend/external/onefile_writer_test.go @@ -50,8 +50,7 @@ func TestOnefileWriterBasic(t *testing.T) { SetPropKeysDistance(2). BuildOneFile(memStore, "/test", "0") - err := writer.Init(ctx, 5*1024*1024) - require.NoError(t, err) + require.NoError(t, writer.Init(ctx, 5*1024*1024)) kvCnt := 100 kvs := make([]common.KvPair, kvCnt) @@ -67,12 +66,10 @@ func TestOnefileWriterBasic(t *testing.T) { } for _, item := range kvs { - err := writer.WriteRow(ctx, item.Key, item.Val) - require.NoError(t, err) + require.NoError(t, writer.WriteRow(ctx, item.Key, item.Val)) } - err = writer.Close(ctx) - require.NoError(t, err) + require.NoError(t, writer.Close(ctx)) bufSize := rand.Intn(100) + 1 kvReader, err := newKVReader(ctx, "/test/0/one-file", memStore, 0, bufSize) @@ -124,8 +121,7 @@ func checkOneFileWriterStatWithDistance(t *testing.T, kvCnt int, keysDistance ui SetPropKeysDistance(keysDistance). BuildOneFile(memStore, "/"+prefix, "0") - err := writer.Init(ctx, 5*1024*1024) - require.NoError(t, err) + require.NoError(t, writer.Init(ctx, 5*1024*1024)) kvs := make([]common.KvPair, 0, kvCnt) for i := 0; i < kvCnt; i++ { kvs = append(kvs, common.KvPair{ @@ -134,11 +130,9 @@ func checkOneFileWriterStatWithDistance(t *testing.T, kvCnt int, keysDistance ui }) } for _, item := range kvs { - err := writer.WriteRow(ctx, item.Key, item.Val) - require.NoError(t, err) + require.NoError(t, writer.WriteRow(ctx, item.Key, item.Val)) } - err = writer.Close(ctx) - require.NoError(t, err) + require.NoError(t, writer.Close(ctx)) bufSize := rand.Intn(100) + 1 kvReader, err := newKVReader(ctx, "/"+prefix+"/0/one-file", memStore, 0, bufSize) @@ -177,7 +171,7 @@ func checkOneFileWriterStatWithDistance(t *testing.T, kvCnt int, keysDistance ui require.NoError(t, statReader.Close()) } -func TestMergeOverlappingFilesV2(t *testing.T) { +func TestMergeOverlappingFilesInternal(t *testing.T) { // 1. Write to 5 files. // 2. merge 5 files into one file. // 3. read one file and check result. @@ -197,13 +191,11 @@ func TestMergeOverlappingFilesV2(t *testing.T) { v-- // insert a duplicate key. 
} key, val := []byte{byte(v)}, []byte{byte(v)} - err := writer.WriteRow(ctx, key, val, dbkv.IntHandle(i)) - require.NoError(t, err) + require.NoError(t, writer.WriteRow(ctx, key, val, dbkv.IntHandle(i))) } - err := writer.Close(ctx) - require.NoError(t, err) + require.NoError(t, writer.Close(ctx)) - err = mergeOverlappingFilesV2( + require.NoError(t, mergeOverlappingFilesInternal( ctx, []string{"/test/0/0", "/test/0/1", "/test/0/2", "/test/0/3", "/test/0/4"}, memStore, @@ -218,8 +210,7 @@ func TestMergeOverlappingFilesV2(t *testing.T) { 2, nil, true, - ) - require.NoError(t, err) + )) keys := make([][]byte, 0, kvCount) values := make([][]byte, 0, kvCount) @@ -276,8 +267,7 @@ func TestOnefileWriterManyRows(t *testing.T) { SetMemorySizeLimit(1000). BuildOneFile(memStore, "/test", "0") - err := writer.Init(ctx, 5*1024*1024) - require.NoError(t, err) + require.NoError(t, writer.Init(ctx, 5*1024*1024)) kvCnt := 100000 expectedTotalSize := 0 @@ -301,17 +291,15 @@ func TestOnefileWriterManyRows(t *testing.T) { }) for _, item := range kvs { - err := writer.WriteRow(ctx, item.Key, item.Val) - require.NoError(t, err) + require.NoError(t, writer.WriteRow(ctx, item.Key, item.Val)) } - err = writer.Close(ctx) - require.NoError(t, err) + require.NoError(t, writer.Close(ctx)) var resSummary *WriterSummary onClose := func(summary *WriterSummary) { resSummary = summary } - err = mergeOverlappingFilesV2( + require.NoError(t, mergeOverlappingFilesInternal( ctx, []string{"/test/0/one-file"}, memStore, @@ -326,8 +314,7 @@ func TestOnefileWriterManyRows(t *testing.T) { 2, onClose, true, - ) - require.NoError(t, err) + )) bufSize := rand.Intn(100) + 1 kvReader, err := newKVReader(ctx, "/test2/mergeID/one-file", memStore, 0, bufSize) diff --git a/br/pkg/lightning/backend/external/reader_test.go b/br/pkg/lightning/backend/external/reader_test.go index e7fcd6c5e0586..329ac48a478b2 100644 --- a/br/pkg/lightning/backend/external/reader_test.go +++ b/br/pkg/lightning/backend/external/reader_test.go @@ -95,8 +95,7 @@ func TestReadAllOneFile(t *testing.T) { require.NoError(t, w.WriteRow(ctx, kvs[i].Key, kvs[i].Val)) } - err := w.Close(ctx) - require.NoError(t, err) + require.NoError(t, w.Close(ctx)) slices.SortFunc(kvs, func(i, j common.KvPair) int { return bytes.Compare(i.Key, j.Key) diff --git a/br/pkg/lightning/backend/external/sort_test.go b/br/pkg/lightning/backend/external/sort_test.go index 7ff2f9e20853a..dce2d75ef03d7 100644 --- a/br/pkg/lightning/backend/external/sort_test.go +++ b/br/pkg/lightning/backend/external/sort_test.go @@ -18,10 +18,12 @@ import ( "bytes" "context" "slices" + "strconv" "testing" "time" "github.com/google/uuid" + "github.com/pingcap/failpoint" "github.com/pingcap/tidb/br/pkg/lightning/backend/kv" "github.com/pingcap/tidb/br/pkg/lightning/common" "github.com/pingcap/tidb/br/pkg/storage" @@ -148,7 +150,7 @@ func TestGlobalSortLocalWithMerge(t *testing.T) { } for _, group := range dataGroup { - MergeOverlappingFiles( + require.NoError(t, MergeOverlappingFiles( ctx, group, memStore, @@ -162,7 +164,107 @@ func TestGlobalSortLocalWithMerge(t *testing.T) { closeFn, 1, true, - ) + )) + } + + // 3. read and sort step + testReadAndCompare(ctx, t, kvs, memStore, lastStepDatas, lastStepStats, startKey, memSizeLimit) +} + +func TestGlobalSortLocalWithMergeV2(t *testing.T) { + // 1. 
write data step + seed := time.Now().Unix() + rand.Seed(uint64(seed)) + t.Logf("seed: %d", seed) + ctx := context.Background() + memStore := storage.NewMemStorage() + memSizeLimit := (rand.Intn(10) + 1) * 400 + multiStats := make([]MultipleFilesStat, 0, 100) + randomSize := (rand.Intn(500) + 1) * 1000 + + failpoint.Enable("github.com/pingcap/tidb/br/pkg/lightning/backend/external/mockRangesGroupSize", + "return("+strconv.Itoa(randomSize)+")") + t.Cleanup(func() { + failpoint.Disable("github.com/pingcap/tidb/br/pkg/lightning/backend/external/mockRangesGroupSize") + }) + datas := make([]string, 0, 100) + stats := make([]string, 0, 100) + // prepare meta for merge step. + closeFn := func(s *WriterSummary) { + multiStats = append(multiStats, s.MultipleFilesStats...) + for _, stat := range s.MultipleFilesStats { + for i := range stat.Filenames { + datas = append(datas, stat.Filenames[i][0]) + stats = append(stats, stat.Filenames[i][1]) + } + } + } + + w := NewWriterBuilder(). + SetPropSizeDistance(100). + SetPropKeysDistance(2). + SetMemorySizeLimit(uint64(memSizeLimit)). + SetBlockSize(memSizeLimit). + SetOnCloseFunc(closeFn). + Build(memStore, "/test", "0") + + writer := NewEngineWriter(w) + kvCnt := rand.Intn(10) + 10000 + kvs := make([]common.KvPair, kvCnt) + for i := 0; i < kvCnt; i++ { + kvs[i] = common.KvPair{ + Key: []byte(uuid.New().String()), + Val: []byte("56789"), + } + } + slices.SortFunc(kvs, func(i, j common.KvPair) int { + return bytes.Compare(i.Key, j.Key) + }) + require.NoError(t, writer.AppendRows(ctx, nil, kv.MakeRowsFromKvPairs(kvs))) + _, err := writer.Close(ctx) + require.NoError(t, err) + + // 2. merge step + dataGroup, statGroup, startKeys, endKeys := splitDataStatAndKeys(datas, stats, multiStats) + lastStepDatas := make([]string, 0, 10) + lastStepStats := make([]string, 0, 10) + var startKey, endKey dbkv.Key + + // prepare meta for last step. + closeFn1 := func(s *WriterSummary) { + for _, stat := range s.MultipleFilesStats { + for i := range stat.Filenames { + lastStepDatas = append(lastStepDatas, stat.Filenames[i][0]) + lastStepStats = append(lastStepStats, stat.Filenames[i][1]) + } + + } + if len(startKey) == 0 && len(endKey) == 0 { + startKey = s.Min.Clone() + endKey = s.Max.Clone().Next() + } + startKey = BytesMin(startKey, s.Min.Clone()) + endKey = BytesMax(endKey, s.Max.Clone().Next()) + } + + for i, group := range dataGroup { + require.NoError(t, MergeOverlappingFilesV2( + ctx, + group, + statGroup[i], + memStore, + startKeys[i], + endKeys[i], + int64(5*size.MB), + "/test2", + uuid.NewString(), + 100, + 8*1024, + 100, + 2, + closeFn1, + 1, + true)) } // 3. read and sort step diff --git a/br/pkg/lightning/backend/external/testutil.go b/br/pkg/lightning/backend/external/testutil.go index 0879021a32917..c2768fa00c4a7 100644 --- a/br/pkg/lightning/backend/external/testutil.go +++ b/br/pkg/lightning/backend/external/testutil.go @@ -93,12 +93,12 @@ func testReadAndCompare( require.EqualValues(t, kvs[kvIdx].Val, loaded.values[i]) kvIdx++ } + curStart = curEnd.Clone() // release loaded.keys = nil loaded.values = nil loaded.memKVBuffers = nil - curStart = curEnd.Clone() if len(endKeyOfGroup) == 0 { break @@ -127,3 +127,37 @@ func splitDataAndStatFiles(datas []string, stats []string) ([][]string, [][]stri } return dataGroup, statGroup } + +// split data&stat files, startKeys and endKeys into groups for new merge step. 
+func splitDataStatAndKeys(datas []string, stats []string, multiStats []MultipleFilesStat) ([][]string, [][]string, []dbkv.Key, []dbkv.Key) { + startKeys := make([]dbkv.Key, 0, 10) + endKeys := make([]dbkv.Key, 0, 10) + i := 0 + for ; i < len(multiStats)-1; i += 2 { + startKey := BytesMin(multiStats[i].MinKey, multiStats[i+1].MinKey) + endKey := BytesMax(multiStats[i].MaxKey, multiStats[i+1].MaxKey) + endKey = dbkv.Key(endKey).Next().Clone() + startKeys = append(startKeys, startKey) + endKeys = append(endKeys, endKey) + } + if i == len(multiStats)-1 { + startKeys = append(startKeys, multiStats[i].MinKey.Clone()) + endKeys = append(endKeys, dbkv.Key(multiStats[i].MaxKey).Next().Clone()) + } + + dataGroup := make([][]string, 0, 10) + statGroup := make([][]string, 0, 10) + + start := 0 + step := 2 * multiFileStatNum + for start < len(datas) { + end := start + step + if end > len(datas) { + end = len(datas) + } + dataGroup = append(dataGroup, datas[start:end]) + statGroup = append(statGroup, stats[start:end]) + start = end + } + return dataGroup, statGroup, startKeys, endKeys +} diff --git a/br/pkg/lightning/backend/external/writer.go b/br/pkg/lightning/backend/external/writer.go index 334cacb9e135a..25b15f23699a5 100644 --- a/br/pkg/lightning/backend/external/writer.go +++ b/br/pkg/lightning/backend/external/writer.go @@ -231,7 +231,6 @@ func (b *WriterBuilder) BuildOneFile( filenamePrefix: filenamePrefix, writerID: writerID, kvStore: nil, - onClose: b.onClose, closed: false, } return ret @@ -488,7 +487,6 @@ func (w *Writer) flushKVs(ctx context.Context, fromClose bool) (err error) { w.recordMinMax(minKey, maxKey, uint64(w.kvSize)) // maintain 500-batch statistics - l := len(w.multiFileStats) w.multiFileStats[l-1].Filenames = append(w.multiFileStats[l-1].Filenames, [2]string{dataFile, statFile}, diff --git a/br/pkg/lightning/backend/external/writer_test.go b/br/pkg/lightning/backend/external/writer_test.go index 7bc853f63dc1c..feee9536ea913 100644 --- a/br/pkg/lightning/backend/external/writer_test.go +++ b/br/pkg/lightning/backend/external/writer_test.go @@ -30,14 +30,76 @@ import ( "github.com/jfcg/sorty/v2" "github.com/pingcap/tidb/br/pkg/lightning/backend/kv" "github.com/pingcap/tidb/br/pkg/lightning/common" + "github.com/pingcap/tidb/br/pkg/lightning/log" "github.com/pingcap/tidb/br/pkg/membuf" "github.com/pingcap/tidb/br/pkg/storage" dbkv "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/util/logutil" "github.com/pingcap/tidb/pkg/util/size" "github.com/stretchr/testify/require" + "go.uber.org/zap" "golang.org/x/exp/rand" ) +// only used in testing for now. +func mergeOverlappingFilesImpl(ctx context.Context, + paths []string, + store storage.ExternalStorage, + readBufferSize int, + newFilePrefix string, + writerID string, + memSizeLimit uint64, + blockSize int, + writeBatchCount uint64, + propSizeDist uint64, + propKeysDist uint64, + onClose OnCloseFunc, + checkHotspot bool, +) (err error) { + task := log.BeginTask(logutil.Logger(ctx).With( + zap.String("writer-id", writerID), + zap.Int("file-count", len(paths)), + ), "merge overlapping files") + defer func() { + task.End(zap.ErrorLevel, err) + }() + + zeroOffsets := make([]uint64, len(paths)) + iter, err := NewMergeKVIter(ctx, paths, zeroOffsets, store, readBufferSize, checkHotspot, 0) + if err != nil { + return err + } + defer func() { + err := iter.Close() + if err != nil { + logutil.Logger(ctx).Warn("close iterator failed", zap.Error(err)) + } + }() + + writer := NewWriterBuilder(). + SetMemorySizeLimit(memSizeLimit). 
+ SetBlockSize(blockSize). + SetOnCloseFunc(onClose). + SetWriterBatchCount(writeBatchCount). + SetPropSizeDistance(propSizeDist). + SetPropKeysDistance(propKeysDist). + Build(store, newFilePrefix, writerID) + + // currently use same goroutine to do read and write. The main advantage is + // there's no KV copy and iter can reuse the buffer. + for iter.Next() { + err = writer.WriteRow(ctx, iter.Key(), iter.Value(), nil) + if err != nil { + return err + } + } + err = iter.Error() + if err != nil { + return err + } + return writer.Close(ctx) +} + func TestWriter(t *testing.T) { seed := time.Now().Unix() rand.Seed(uint64(seed)) From a9317bebab1783a88701503cb835658ac600b25e Mon Sep 17 00:00:00 2001 From: JmPotato Date: Tue, 19 Dec 2023 16:42:22 +0800 Subject: [PATCH 3/9] driver: update the PD HTTP client to support source mark (#49578) ref pingcap/tidb#35319 --- DEPS.bzl | 24 +++++++++---------- go.mod | 4 ++-- go.sum | 8 +++---- pkg/executor/infoschema_cluster_table_test.go | 2 +- pkg/executor/tikv_regions_peers_table_test.go | 2 +- pkg/store/driver/tikv_driver.go | 2 +- pkg/store/helper/helper_test.go | 2 +- 7 files changed, 22 insertions(+), 22 deletions(-) diff --git a/DEPS.bzl b/DEPS.bzl index b678f49f170de..59d24ee8beb02 100644 --- a/DEPS.bzl +++ b/DEPS.bzl @@ -7006,26 +7006,26 @@ def go_deps(): name = "com_github_tikv_client_go_v2", build_file_proto_mode = "disable_global", importpath = "github.com/tikv/client-go/v2", - sha256 = "04da7d520727a9140c0d472c0f0a054837aae1da3fa49101c0f52279c7d78094", - strip_prefix = "github.com/tikv/client-go/v2@v2.0.8-0.20231204074048-e80e9ca1fe66", + sha256 = "8ff835049b1a8bf797c8d38b7081d761e985893c9b78ac9264ccca7b428fddf7", + strip_prefix = "github.com/tikv/client-go/v2@v2.0.8-0.20231219052137-6f9ba8327b75", urls = [ - "http://bazel-cache.pingcap.net:8080/gomod/github.com/tikv/client-go/v2/com_github_tikv_client_go_v2-v2.0.8-0.20231204074048-e80e9ca1fe66.zip", - "http://ats.apps.svc/gomod/github.com/tikv/client-go/v2/com_github_tikv_client_go_v2-v2.0.8-0.20231204074048-e80e9ca1fe66.zip", - "https://cache.hawkingrei.com/gomod/github.com/tikv/client-go/v2/com_github_tikv_client_go_v2-v2.0.8-0.20231204074048-e80e9ca1fe66.zip", - "https://storage.googleapis.com/pingcapmirror/gomod/github.com/tikv/client-go/v2/com_github_tikv_client_go_v2-v2.0.8-0.20231204074048-e80e9ca1fe66.zip", + "http://bazel-cache.pingcap.net:8080/gomod/github.com/tikv/client-go/v2/com_github_tikv_client_go_v2-v2.0.8-0.20231219052137-6f9ba8327b75.zip", + "http://ats.apps.svc/gomod/github.com/tikv/client-go/v2/com_github_tikv_client_go_v2-v2.0.8-0.20231219052137-6f9ba8327b75.zip", + "https://cache.hawkingrei.com/gomod/github.com/tikv/client-go/v2/com_github_tikv_client_go_v2-v2.0.8-0.20231219052137-6f9ba8327b75.zip", + "https://storage.googleapis.com/pingcapmirror/gomod/github.com/tikv/client-go/v2/com_github_tikv_client_go_v2-v2.0.8-0.20231219052137-6f9ba8327b75.zip", ], ) go_repository( name = "com_github_tikv_pd_client", build_file_proto_mode = "disable_global", importpath = "github.com/tikv/pd/client", - sha256 = "5b31b38e151e03117ef9878c2dbac2b1f22c890e72ebb70935795ac5682c77c1", - strip_prefix = "github.com/tikv/pd/client@v0.0.0-20231213112719-f51f9134558e", + sha256 = "929a41d5e836a8ef04ed5bad9500023bd58cb17381388b9c4f1ae07ebc039287", + strip_prefix = "github.com/tikv/pd/client@v0.0.0-20231219031951-25f48f0bdd27", urls = [ - "http://bazel-cache.pingcap.net:8080/gomod/github.com/tikv/pd/client/com_github_tikv_pd_client-v0.0.0-20231213112719-f51f9134558e.zip", - 
"http://ats.apps.svc/gomod/github.com/tikv/pd/client/com_github_tikv_pd_client-v0.0.0-20231213112719-f51f9134558e.zip", - "https://cache.hawkingrei.com/gomod/github.com/tikv/pd/client/com_github_tikv_pd_client-v0.0.0-20231213112719-f51f9134558e.zip", - "https://storage.googleapis.com/pingcapmirror/gomod/github.com/tikv/pd/client/com_github_tikv_pd_client-v0.0.0-20231213112719-f51f9134558e.zip", + "http://bazel-cache.pingcap.net:8080/gomod/github.com/tikv/pd/client/com_github_tikv_pd_client-v0.0.0-20231219031951-25f48f0bdd27.zip", + "http://ats.apps.svc/gomod/github.com/tikv/pd/client/com_github_tikv_pd_client-v0.0.0-20231219031951-25f48f0bdd27.zip", + "https://cache.hawkingrei.com/gomod/github.com/tikv/pd/client/com_github_tikv_pd_client-v0.0.0-20231219031951-25f48f0bdd27.zip", + "https://storage.googleapis.com/pingcapmirror/gomod/github.com/tikv/pd/client/com_github_tikv_pd_client-v0.0.0-20231219031951-25f48f0bdd27.zip", ], ) go_repository( diff --git a/go.mod b/go.mod index e969f60fcaa6e..02ab1ea21f3ae 100644 --- a/go.mod +++ b/go.mod @@ -102,8 +102,8 @@ require ( github.com/stretchr/testify v1.8.4 github.com/tdakkota/asciicheck v0.2.0 github.com/tiancaiamao/appdash v0.0.0-20181126055449-889f96f722a2 - github.com/tikv/client-go/v2 v2.0.8-0.20231204074048-e80e9ca1fe66 - github.com/tikv/pd/client v0.0.0-20231213112719-f51f9134558e + github.com/tikv/client-go/v2 v2.0.8-0.20231219052137-6f9ba8327b75 + github.com/tikv/pd/client v0.0.0-20231219031951-25f48f0bdd27 github.com/timakin/bodyclose v0.0.0-20230421092635-574207250966 github.com/twmb/murmur3 v1.1.6 github.com/uber/jaeger-client-go v2.22.1+incompatible diff --git a/go.sum b/go.sum index aeafaaa2a5b28..aa001d8a15894 100644 --- a/go.sum +++ b/go.sum @@ -855,10 +855,10 @@ github.com/tiancaiamao/appdash v0.0.0-20181126055449-889f96f722a2 h1:mbAskLJ0oJf github.com/tiancaiamao/appdash v0.0.0-20181126055449-889f96f722a2/go.mod h1:2PfKggNGDuadAa0LElHrByyrz4JPZ9fFx6Gs7nx7ZZU= github.com/tiancaiamao/gp v0.0.0-20221230034425-4025bc8a4d4a h1:J/YdBZ46WKpXsxsW93SG+q0F8KI+yFrcIDT4c/RNoc4= github.com/tiancaiamao/gp v0.0.0-20221230034425-4025bc8a4d4a/go.mod h1:h4xBhSNtOeEosLJ4P7JyKXX7Cabg7AVkWCK5gV2vOrM= -github.com/tikv/client-go/v2 v2.0.8-0.20231204074048-e80e9ca1fe66 h1:+bCtNxUSYVmY/wmN8Zhf0UVl9mF+04/OjoseguM2aWY= -github.com/tikv/client-go/v2 v2.0.8-0.20231204074048-e80e9ca1fe66/go.mod h1:IE0/o4zWJW5Wpvp15CK2jpbu49DSLLSJBMNmwjv6I6o= -github.com/tikv/pd/client v0.0.0-20231213112719-f51f9134558e h1:FKyIJ13O2iIgbWmG3xNN+E2D10XAptcXYUrJY4S9AmI= -github.com/tikv/pd/client v0.0.0-20231213112719-f51f9134558e/go.mod h1:AwjTSpM7CgAynYwB6qTG5R5fVC9/eXlQXiTO6zDL1HI= +github.com/tikv/client-go/v2 v2.0.8-0.20231219052137-6f9ba8327b75 h1:Q+dA/vxdfAoydEaWMSveKXAivkB5sUQk49FUP/q7+ec= +github.com/tikv/client-go/v2 v2.0.8-0.20231219052137-6f9ba8327b75/go.mod h1:p5afyVTeuzh3u9qb4nEMhNRYl80FuCZDv2ZPa3fHw0w= +github.com/tikv/pd/client v0.0.0-20231219031951-25f48f0bdd27 h1:U8jPVwFu9Zu8tXYlmOxO/Zv3OcsgoJ/COSwMNWvED9c= +github.com/tikv/pd/client v0.0.0-20231219031951-25f48f0bdd27/go.mod h1:AwjTSpM7CgAynYwB6qTG5R5fVC9/eXlQXiTO6zDL1HI= github.com/timakin/bodyclose v0.0.0-20230421092635-574207250966 h1:quvGphlmUVU+nhpFa4gg4yJyTRJ13reZMDHrKwYw53M= github.com/timakin/bodyclose v0.0.0-20230421092635-574207250966/go.mod h1:27bSVNWSBOHm+qRp1T9qzaIpsWEP6TbUnei/43HK+PQ= github.com/tklauser/go-sysconf v0.3.9/go.mod h1:11DU/5sG7UexIrp/O6g35hrWzu0JxlwQ3LSFUzyeuhs= diff --git a/pkg/executor/infoschema_cluster_table_test.go b/pkg/executor/infoschema_cluster_table_test.go index 
3e530831e5982..86894e802d66c 100644 --- a/pkg/executor/infoschema_cluster_table_test.go +++ b/pkg/executor/infoschema_cluster_table_test.go @@ -59,7 +59,7 @@ func createInfosSchemaClusterTableSuite(t *testing.T) *infosSchemaClusterTableSu s.httpServer, s.mockAddr = s.setUpMockPDHTTPServer() s.store, s.dom = testkit.CreateMockStoreAndDomain( t, - mockstore.WithTiKVOptions(tikv.WithPDHTTPClient([]string{s.mockAddr})), + mockstore.WithTiKVOptions(tikv.WithPDHTTPClient("infoschema-cluster-table-test", []string{s.mockAddr})), ) s.rpcServer, s.listenAddr = setUpRPCService(t, s.dom, "127.0.0.1:0") s.startTime = time.Now() diff --git a/pkg/executor/tikv_regions_peers_table_test.go b/pkg/executor/tikv_regions_peers_table_test.go index dd8ffb5757cca..5319960318b3f 100644 --- a/pkg/executor/tikv_regions_peers_table_test.go +++ b/pkg/executor/tikv_regions_peers_table_test.go @@ -112,7 +112,7 @@ func TestTikvRegionPeers(t *testing.T) { defer server.Close() store := testkit.CreateMockStore(t, - mockstore.WithTiKVOptions(tikv.WithPDHTTPClient([]string{mockAddr}))) + mockstore.WithTiKVOptions(tikv.WithPDHTTPClient("tikv-regions-peers-table-test", []string{mockAddr}))) store = &mockStore{ store.(helper.Storage), diff --git a/pkg/store/driver/tikv_driver.go b/pkg/store/driver/tikv_driver.go index 199661d56f95d..cd8bfa2cec98c 100644 --- a/pkg/store/driver/tikv_driver.go +++ b/pkg/store/driver/tikv_driver.go @@ -215,7 +215,7 @@ func (d TiKVDriver) OpenWithOptions(path string, options ...Option) (resStore kv ) s, err = tikv.NewKVStore(uuid, pdClient, spkv, &injectTraceClient{Client: rpcClient}, - tikv.WithPDHTTPClient(etcdAddrs, pdhttp.WithTLSConfig(tlsConfig), pdhttp.WithMetrics(metrics.PDAPIRequestCounter, metrics.PDAPIExecutionHistogram))) + tikv.WithPDHTTPClient("tikv-driver", etcdAddrs, pdhttp.WithTLSConfig(tlsConfig), pdhttp.WithMetrics(metrics.PDAPIRequestCounter, metrics.PDAPIExecutionHistogram))) if err != nil { return nil, errors.Trace(err) } diff --git a/pkg/store/helper/helper_test.go b/pkg/store/helper/helper_test.go index 21774cfc13ebe..0702928091984 100644 --- a/pkg/store/helper/helper_test.go +++ b/pkg/store/helper/helper_test.go @@ -152,7 +152,7 @@ func createMockStore(t *testing.T) (store helper.Storage) { mockstore.WithClusterInspector(func(c testutils.Cluster) { mockstore.BootstrapWithMultiRegions(c, []byte("x")) }), - mockstore.WithTiKVOptions(tikv.WithPDHTTPClient(pdAddrs)), + mockstore.WithTiKVOptions(tikv.WithPDHTTPClient("store-helper-test", pdAddrs)), ) require.NoError(t, err) From abc54d0339b9615fea45191268b7f5d9fc00034f Mon Sep 17 00:00:00 2001 From: SeaRise Date: Tue, 19 Dec 2023 17:19:52 +0800 Subject: [PATCH 4/9] expression: Add `json_valid` and `json_keys` push down support for tiflash (#49347) close pingcap/tidb#49345 --- pkg/expression/expr_to_pb_test.go | 24 ++++++++++++++++++++++++ pkg/expression/expression.go | 4 ++-- 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/pkg/expression/expr_to_pb_test.go b/pkg/expression/expr_to_pb_test.go index e93b20488ff3a..332e7a3c4bcb9 100644 --- a/pkg/expression/expr_to_pb_test.go +++ b/pkg/expression/expr_to_pb_test.go @@ -555,6 +555,30 @@ func TestExprPushDownToFlash(t *testing.T) { require.NoError(t, err) exprs = append(exprs, function) + // json_valid + /// json_valid_others + function, err = NewFunction(mock.NewContext(), ast.JSONValid, types.NewFieldType(mysql.TypeLonglong), intColumn) + require.NoError(t, err) + exprs = append(exprs, function) + /// json_valid_json + function, err = NewFunction(mock.NewContext(), 
ast.JSONValid, types.NewFieldType(mysql.TypeLonglong), jsonColumn) + require.NoError(t, err) + exprs = append(exprs, function) + /// json_valid_string + function, err = NewFunction(mock.NewContext(), ast.JSONValid, types.NewFieldType(mysql.TypeLonglong), stringColumn) + require.NoError(t, err) + exprs = append(exprs, function) + + // json_keys + /// 1 arg + function, err = NewFunction(mock.NewContext(), ast.JSONKeys, types.NewFieldType(mysql.TypeJSON), jsonColumn) + require.NoError(t, err) + exprs = append(exprs, function) + /// 2 args + function, err = NewFunction(mock.NewContext(), ast.JSONKeys, types.NewFieldType(mysql.TypeJSON), jsonColumn, stringColumn) + require.NoError(t, err) + exprs = append(exprs, function) + // lpad function, err = NewFunction(mock.NewContext(), ast.Lpad, types.NewFieldType(mysql.TypeString), stringColumn, int32Column, stringColumn) require.NoError(t, err) diff --git a/pkg/expression/expression.go b/pkg/expression/expression.go index fff11cb64e0fc..e32e1de6e6c70 100644 --- a/pkg/expression/expression.go +++ b/pkg/expression/expression.go @@ -1160,8 +1160,8 @@ func scalarExprSupportedByFlash(function *ScalarFunction) bool { ast.Sqrt, ast.Log, ast.Log2, ast.Log10, ast.Ln, ast.Exp, ast.Pow, ast.Sign, ast.Radians, ast.Degrees, ast.Conv, ast.CRC32, - ast.JSONLength, ast.JSONDepth, ast.JSONExtract, ast.JSONUnquote, ast.JSONArray, ast.Repeat, - ast.InetNtoa, ast.InetAton, ast.Inet6Ntoa, ast.Inet6Aton, + ast.JSONLength, ast.JSONDepth, ast.JSONExtract, ast.JSONUnquote, ast.JSONArray, ast.JSONValid, ast.JSONKeys, + ast.Repeat, ast.InetNtoa, ast.InetAton, ast.Inet6Ntoa, ast.Inet6Aton, ast.Coalesce, ast.ASCII, ast.Length, ast.Trim, ast.Position, ast.Format, ast.Elt, ast.LTrim, ast.RTrim, ast.Lpad, ast.Rpad, ast.Hour, ast.Minute, ast.Second, ast.MicroSecond, From b850d26e7fe7c380dab633d26278b5949c149ae6 Mon Sep 17 00:00:00 2001 From: YangKeao Date: Tue, 19 Dec 2023 17:54:23 +0800 Subject: [PATCH 5/9] pkg/executor: refine the Executor interface (#49494) close pingcap/tidb#49490 --- pkg/executor/adapter.go | 6 +-- pkg/executor/admin.go | 2 +- pkg/executor/batch_point_get.go | 2 +- pkg/executor/builder.go | 12 ++--- pkg/executor/coprocessor.go | 4 +- pkg/executor/cte.go | 4 +- pkg/executor/executor.go | 7 ++- pkg/executor/index_lookup_join.go | 8 +++- pkg/executor/index_lookup_merge_join.go | 4 +- pkg/executor/index_merge_reader.go | 6 +-- pkg/executor/insert.go | 2 +- pkg/executor/internal/exec/executor.go | 64 ++++++++++++++++--------- pkg/executor/join.go | 4 +- pkg/executor/merge_join.go | 2 +- pkg/executor/point_get.go | 6 +-- 15 files changed, 77 insertions(+), 56 deletions(-) diff --git a/pkg/executor/adapter.go b/pkg/executor/adapter.go index 4bb4e5c2d483c..ae239e1a81cac 100644 --- a/pkg/executor/adapter.go +++ b/pkg/executor/adapter.go @@ -177,8 +177,7 @@ func (a *recordSet) NewChunk(alloc chunk.Allocator) *chunk.Chunk { return exec.NewFirstChunk(a.executor) } - base := a.executor.Base() - return alloc.Alloc(base.RetFieldTypes(), base.InitCap(), base.MaxChunkSize()) + return alloc.Alloc(a.executor.RetFieldTypes(), a.executor.InitCap(), a.executor.MaxChunkSize()) } func (a *recordSet) Finish() error { @@ -886,8 +885,7 @@ func (c *chunkRowRecordSet) NewChunk(alloc chunk.Allocator) *chunk.Chunk { return exec.NewFirstChunk(c.e) } - base := c.e.Base() - return alloc.Alloc(base.RetFieldTypes(), base.InitCap(), base.MaxChunkSize()) + return alloc.Alloc(c.e.RetFieldTypes(), c.e.InitCap(), c.e.MaxChunkSize()) } func (c *chunkRowRecordSet) Close() error { diff --git 
a/pkg/executor/admin.go b/pkg/executor/admin.go index a340de24058de..2ae551c1eced8 100644 --- a/pkg/executor/admin.go +++ b/pkg/executor/admin.go @@ -216,7 +216,7 @@ func (e *RecoverIndexExec) columnsTypes() []*types.FieldType { // Open implements the Executor Open interface. func (e *RecoverIndexExec) Open(ctx context.Context) error { - if err := exec.Open(ctx, e.BaseExecutor.Base()); err != nil { + if err := exec.Open(ctx, &e.BaseExecutor); err != nil { return err } diff --git a/pkg/executor/batch_point_get.go b/pkg/executor/batch_point_get.go index 06afa7285fa8f..d1bd094e29b6d 100644 --- a/pkg/executor/batch_point_get.go +++ b/pkg/executor/batch_point_get.go @@ -186,7 +186,7 @@ func (e *BatchPointGetExec) Next(ctx context.Context, req *chunk.Chunk) error { } for !req.IsFull() && e.index < len(e.values) { handle, val := e.handles[e.index], e.values[e.index] - err := DecodeRowValToChunk(e.Base().Ctx(), e.Schema(), e.tblInfo, handle, val, req, e.rowDecoder) + err := DecodeRowValToChunk(e.BaseExecutor.Ctx(), e.Schema(), e.tblInfo, handle, val, req, e.rowDecoder) if err != nil { return err } diff --git a/pkg/executor/builder.go b/pkg/executor/builder.go index 2f65723b82bf3..d5b5b7ad35d97 100644 --- a/pkg/executor/builder.go +++ b/pkg/executor/builder.go @@ -3463,7 +3463,7 @@ func (b *executorBuilder) buildTableReader(v *plannercore.PhysicalTableReader) e } if len(partitions) == 0 { - return &TableDualExec{BaseExecutor: *ret.Base()} + return &TableDualExec{BaseExecutor: ret.BaseExecutor} } // Sort the partition is necessary to make the final multiple partition key ranges ordered. @@ -4467,7 +4467,7 @@ func (builder *dataReaderBuilder) buildIndexReaderForIndexJoin(ctx context.Conte } return e, nil } - ret := &TableDualExec{BaseExecutor: *e.Base()} + ret := &TableDualExec{BaseExecutor: e.BaseExecutor} err = exec.Open(ctx, ret) return ret, err } @@ -4544,7 +4544,7 @@ func (builder *dataReaderBuilder) buildIndexLookUpReaderForIndexJoin(ctx context } return e, err } - ret := &TableDualExec{BaseExecutor: *e.Base()} + ret := &TableDualExec{BaseExecutor: e.BaseExecutor} err = exec.Open(ctx, ret) return ret, err } @@ -5114,8 +5114,8 @@ func (b *executorBuilder) buildBatchPointGet(plan *plannercore.BatchPointGetPlan } capacity = len(e.handles) } - e.Base().SetInitCap(capacity) - e.Base().SetMaxChunkSize(capacity) + e.SetInitCap(capacity) + e.SetMaxChunkSize(capacity) e.buildVirtualColumnInfo() return e } @@ -5297,7 +5297,7 @@ func (b *executorBuilder) buildCTE(v *plannercore.PhysicalCTE) exec.Executor { } // Setup storages. 
- tps := seedExec.Base().RetFieldTypes() + tps := seedExec.RetFieldTypes() resTbl = cteutil.NewStorageRowContainer(tps, chkSize) if err := resTbl.OpenAndRef(); err != nil { b.err = err diff --git a/pkg/executor/coprocessor.go b/pkg/executor/coprocessor.go index fd563965eb461..c77565aa35181 100644 --- a/pkg/executor/coprocessor.go +++ b/pkg/executor/coprocessor.go @@ -82,7 +82,7 @@ func (h *CoprocessorDAGHandler) HandleRequest(ctx context.Context, req *coproces } chk := exec.TryNewCacheChunk(e) - tps := e.Base().RetFieldTypes() + tps := e.RetFieldTypes() var totalChunks, partChunks []tipb.Chunk memTracker := h.sctx.GetSessionVars().StmtCtx.MemTracker for { @@ -125,7 +125,7 @@ func (h *CoprocessorDAGHandler) HandleStreamRequest(ctx context.Context, req *co } chk := exec.TryNewCacheChunk(e) - tps := e.Base().RetFieldTypes() + tps := e.RetFieldTypes() for { chk.Reset() if err = exec.Next(ctx, e, chk); err != nil { diff --git a/pkg/executor/cte.go b/pkg/executor/cte.go index 97fe2bb2c49cd..f820ea391b509 100644 --- a/pkg/executor/cte.go +++ b/pkg/executor/cte.go @@ -204,7 +204,7 @@ func (p *cteProducer) openProducer(ctx context.Context, cteExec *CTEExec) (err e // For non-recursive CTE, the result will be put into resTbl directly. // So no need to build iterOutTbl. // Construct iterOutTbl in Open() instead of buildCTE(), because its destruct is in Close(). - recursiveTypes := p.recursiveExec.Base().RetFieldTypes() + recursiveTypes := p.recursiveExec.RetFieldTypes() p.iterOutTbl = cteutil.NewStorageRowContainer(recursiveTypes, cteExec.MaxChunkSize()) if err = p.iterOutTbl.OpenAndRef(); err != nil { return err @@ -214,7 +214,7 @@ func (p *cteProducer) openProducer(ctx context.Context, cteExec *CTEExec) (err e if p.isDistinct { p.hashTbl = newConcurrentMapHashTable() p.hCtx = &hashContext{ - allTypes: cteExec.Base().RetFieldTypes(), + allTypes: cteExec.RetFieldTypes(), } // We use all columns to compute hash. p.hCtx.keyColIdx = make([]int, len(p.hCtx.allTypes)) diff --git a/pkg/executor/executor.go b/pkg/executor/executor.go index 885193f9d2a04..16eba4e2447d3 100644 --- a/pkg/executor/executor.go +++ b/pkg/executor/executor.go @@ -212,8 +212,7 @@ func (*globalPanicOnExceed) GetPriority() int64 { // newList creates a new List to buffer current executor's result. 
func newList(e exec.Executor) *chunk.List { - base := e.Base() - return chunk.NewList(base.RetFieldTypes(), base.InitCap(), base.MaxChunkSize()) + return chunk.NewList(e.RetFieldTypes(), e.InitCap(), e.MaxChunkSize()) } // CommandDDLJobsExec is the general struct for Cancel/Pause/Resume commands on @@ -2377,7 +2376,7 @@ func (w *checkIndexWorker) HandleTask(task checkIndexTask, _ func(workerpool.Non w.e.err.CompareAndSwap(nil, &err) } - se, err := w.e.Base().GetSysSession() + se, err := w.e.BaseExecutor.GetSysSession() if err != nil { trySaveErr(err) return @@ -2385,7 +2384,7 @@ func (w *checkIndexWorker) HandleTask(task checkIndexTask, _ func(workerpool.Non restoreCtx := w.initSessCtx(se) defer func() { restoreCtx() - w.e.Base().ReleaseSysSession(ctx, se) + w.e.BaseExecutor.ReleaseSysSession(ctx, se) }() var pkCols []string diff --git a/pkg/executor/index_lookup_join.go b/pkg/executor/index_lookup_join.go index f178e99f88080..afd9fbbf7e7e2 100644 --- a/pkg/executor/index_lookup_join.go +++ b/pkg/executor/index_lookup_join.go @@ -437,7 +437,7 @@ func (ow *outerWorker) buildTask(ctx context.Context) (*lookUpJoinTask, error) { } maxChunkSize := ow.ctx.GetSessionVars().MaxChunkSize for requiredRows > task.outerResult.Len() { - chk := ow.ctx.GetSessionVars().GetNewChunkWithCapacity(ow.outerCtx.rowTypes, maxChunkSize, maxChunkSize, ow.executor.Base().AllocPool) + chk := ow.executor.NewChunkWithCapacity(ow.outerCtx.rowTypes, maxChunkSize, maxChunkSize) chk = chk.SetRequiredRows(requiredRows, maxChunkSize) err := exec.Next(ctx, ow.executor, chk) if err != nil { @@ -468,7 +468,11 @@ func (ow *outerWorker) buildTask(ctx context.Context) (*lookUpJoinTask, error) { } task.encodedLookUpKeys = make([]*chunk.Chunk, task.outerResult.NumChunks()) for i := range task.encodedLookUpKeys { - task.encodedLookUpKeys[i] = ow.ctx.GetSessionVars().GetNewChunkWithCapacity([]*types.FieldType{types.NewFieldType(mysql.TypeBlob)}, task.outerResult.GetChunk(i).NumRows(), task.outerResult.GetChunk(i).NumRows(), ow.executor.Base().AllocPool) + task.encodedLookUpKeys[i] = ow.executor.NewChunkWithCapacity( + []*types.FieldType{types.NewFieldType(mysql.TypeBlob)}, + task.outerResult.GetChunk(i).NumRows(), + task.outerResult.GetChunk(i).NumRows(), + ) } return task, nil } diff --git a/pkg/executor/index_lookup_merge_join.go b/pkg/executor/index_lookup_merge_join.go index c28a61eddb1a2..f526058962998 100644 --- a/pkg/executor/index_lookup_merge_join.go +++ b/pkg/executor/index_lookup_merge_join.go @@ -343,7 +343,7 @@ func (*outerMergeWorker) pushToChan(ctx context.Context, task *lookUpMergeJoinTa func (omw *outerMergeWorker) buildTask(ctx context.Context) (*lookUpMergeJoinTask, error) { task := &lookUpMergeJoinTask{ results: make(chan *indexMergeJoinResult, numResChkHold), - outerResult: chunk.NewList(omw.rowTypes, omw.executor.Base().InitCap(), omw.executor.Base().MaxChunkSize()), + outerResult: chunk.NewList(omw.rowTypes, omw.executor.InitCap(), omw.executor.MaxChunkSize()), } task.memTracker = memory.NewTracker(memory.LabelForSimpleTask, -1) task.memTracker.AttachTo(omw.parentMemTracker) @@ -712,7 +712,7 @@ func (imw *innerMergeWorker) dedupDatumLookUpKeys(lookUpContents []*indexJoinLoo // fetchNextInnerResult collects a chunk of inner results from inner child executor. 
func (imw *innerMergeWorker) fetchNextInnerResult(ctx context.Context, task *lookUpMergeJoinTask) (beginRow chunk.Row, err error) { - task.innerResult = imw.ctx.GetSessionVars().GetNewChunkWithCapacity(exec.RetTypes(imw.innerExec), imw.ctx.GetSessionVars().MaxChunkSize, imw.ctx.GetSessionVars().MaxChunkSize, imw.innerExec.Base().AllocPool) + task.innerResult = imw.innerExec.NewChunkWithCapacity(imw.innerExec.RetFieldTypes(), imw.innerExec.MaxChunkSize(), imw.innerExec.MaxChunkSize()) err = exec.Next(ctx, imw.innerExec, task.innerResult) task.innerIter = chunk.NewIterator4Chunk(task.innerResult) beginRow = task.innerIter.Begin() diff --git a/pkg/executor/index_merge_reader.go b/pkg/executor/index_merge_reader.go index 2c22a983d7b5e..64077e7068c45 100644 --- a/pkg/executor/index_merge_reader.go +++ b/pkg/executor/index_merge_reader.go @@ -631,7 +631,7 @@ func (w *partialTableWorker) needPartitionHandle() (bool, error) { func (w *partialTableWorker) fetchHandles(ctx context.Context, exitCh <-chan struct{}, fetchCh chan<- *indexMergeTableTask, finished <-chan struct{}, handleCols plannercore.HandleCols, parTblIdx int, partialPlanIndex int) (count int64, err error) { - chk := w.sc.GetSessionVars().GetNewChunkWithCapacity(w.getRetTpsForTableScan(), w.maxChunkSize, w.maxChunkSize, w.tableReader.Base().AllocPool) + chk := w.tableReader.NewChunkWithCapacity(w.getRetTpsForTableScan(), w.maxChunkSize, w.maxBatchSize) for { start := time.Now() handles, retChunk, err := w.extractTaskHandles(ctx, chk, handleCols) @@ -686,8 +686,8 @@ func (w *partialTableWorker) extractTaskHandles(ctx context.Context, chk *chunk. if err != nil { return nil, nil, err } - if be := w.tableReader.Base(); be != nil && be.RuntimeStats() != nil { - be.RuntimeStats().Record(time.Since(start), chk.NumRows()) + if w.tableReader != nil && w.tableReader.RuntimeStats() != nil { + w.tableReader.RuntimeStats().Record(time.Since(start), chk.NumRows()) } if chk.NumRows() == 0 { failpoint.Inject("testIndexMergeErrorPartialTableWorker", func(v failpoint.Value) { diff --git a/pkg/executor/insert.go b/pkg/executor/insert.go index f3e304b3324c7..1c4903891090b 100644 --- a/pkg/executor/insert.go +++ b/pkg/executor/insert.go @@ -367,7 +367,7 @@ func (e *InsertExec) initEvalBuffer4Dup() { evalBufferTypes = append(evalBufferTypes, &(col.FieldType)) } if extraLen > 0 { - evalBufferTypes = append(evalBufferTypes, e.SelectExec.Base().RetFieldTypes()[e.rowLen:]...) + evalBufferTypes = append(evalBufferTypes, e.SelectExec.RetFieldTypes()[e.rowLen:]...) } for _, col := range e.Table.Cols() { evalBufferTypes = append(evalBufferTypes, &(col.FieldType)) diff --git a/pkg/executor/internal/exec/executor.go b/pkg/executor/internal/exec/executor.go index 6295523263e4a..c3440bed6f670 100644 --- a/pkg/executor/internal/exec/executor.go +++ b/pkg/executor/internal/exec/executor.go @@ -45,7 +45,15 @@ import ( // return a batch of rows, other than a single row in Volcano. // NOTE: Executors must call "chk.Reset()" before appending their results to it. 
type Executor interface { - Base() *BaseExecutor + NewChunk() *chunk.Chunk + NewChunkWithCapacity(fields []*types.FieldType, capacity int, maxCachesize int) *chunk.Chunk + + RuntimeStats() *execdetails.BasicRuntimeStats + + HandleSQLKillerSignal() error + RegisterSQLAndPlanInExecForTopSQL() + + AllChildren() []Executor Open(context.Context) error Next(ctx context.Context, req *chunk.Chunk) error Close() error @@ -156,11 +164,6 @@ func (e *BaseExecutor) SetMaxChunkSize(size int) { e.maxChunkSize = size } -// Base returns the BaseExecutor of an executor, don't override this method! -func (e *BaseExecutor) Base() *BaseExecutor { - return e -} - // Open initializes children recursively and "childrenResults" according to children's schemas. func (e *BaseExecutor) Open(ctx context.Context) error { for _, child := range e.children { @@ -239,23 +242,43 @@ func (e *BaseExecutor) ReleaseSysSession(ctx context.Context, sctx sessionctx.Co sysSessionPool.Put(sctx.(pools.Resource)) } +// NewChunk creates a new chunk according to the executor configuration +func (e *BaseExecutor) NewChunk() *chunk.Chunk { + return e.NewChunkWithCapacity(e.RetFieldTypes(), e.InitCap(), e.MaxChunkSize()) +} + +// NewChunkWithCapacity allows the caller to allocate the chunk with any types, capacity and max size in the pool +func (e *BaseExecutor) NewChunkWithCapacity(fields []*types.FieldType, capacity int, maxCachesize int) *chunk.Chunk { + return e.ctx.GetSessionVars().GetNewChunkWithCapacity(fields, capacity, maxCachesize, e.AllocPool) +} + +// HandleSQLKillerSignal handles the signal sent by SQLKiller +func (e *BaseExecutor) HandleSQLKillerSignal() error { + return e.ctx.GetSessionVars().SQLKiller.HandleSignal() +} + +// RegisterSQLAndPlanInExecForTopSQL registers the current SQL and Plan on top sql +// TODO: consider whether it's appropriate to have this on executor +func (e *BaseExecutor) RegisterSQLAndPlanInExecForTopSQL() { + sessVars := e.ctx.GetSessionVars() + if topsqlstate.TopSQLEnabled() && sessVars.StmtCtx.IsSQLAndPlanRegistered.CompareAndSwap(false, true) { + RegisterSQLAndPlanInExecForTopSQL(sessVars) + } +} + // TryNewCacheChunk tries to get a cached chunk func TryNewCacheChunk(e Executor) *chunk.Chunk { - base := e.Base() - s := base.Ctx().GetSessionVars() - return s.GetNewChunkWithCapacity(base.RetFieldTypes(), base.InitCap(), base.MaxChunkSize(), base.AllocPool) + return e.NewChunk() } // RetTypes returns all output column types. func RetTypes(e Executor) []*types.FieldType { - base := e.Base() - return base.RetFieldTypes() + return e.RetFieldTypes() } // NewFirstChunk creates a new chunk to buffer current executor's result. func NewFirstChunk(e Executor) *chunk.Chunk { - base := e.Base() - return chunk.New(base.RetFieldTypes(), base.InitCap(), base.MaxChunkSize()) + return chunk.New(e.RetFieldTypes(), e.InitCap(), e.MaxChunkSize()) } // Open is a wrapper function on e.Open(), it handles some common codes. @@ -270,29 +293,26 @@ func Open(ctx context.Context, e Executor) (err error) { // Next is a wrapper function on e.Next(), it handles some common codes. 
func Next(ctx context.Context, e Executor, req *chunk.Chunk) error { - base := e.Base() - if base.RuntimeStats() != nil { + if e.RuntimeStats() != nil { start := time.Now() - defer func() { base.RuntimeStats().Record(time.Since(start), req.NumRows()) }() + defer func() { e.RuntimeStats().Record(time.Since(start), req.NumRows()) }() } - sessVars := base.Ctx().GetSessionVars() - if err := sessVars.SQLKiller.HandleSignal(); err != nil { + + if err := e.HandleSQLKillerSignal(); err != nil { return err } r, ctx := tracing.StartRegionEx(ctx, fmt.Sprintf("%T.Next", e)) defer r.End() - if topsqlstate.TopSQLEnabled() && sessVars.StmtCtx.IsSQLAndPlanRegistered.CompareAndSwap(false, true) { - RegisterSQLAndPlanInExecForTopSQL(sessVars) - } + e.RegisterSQLAndPlanInExecForTopSQL() err := e.Next(ctx, req) if err != nil { return err } // recheck whether the session/query is killed during the Next() - return sessVars.SQLKiller.HandleSignal() + return e.HandleSQLKillerSignal() } // Close is a wrapper function on e.Close(), it handles some common codes. diff --git a/pkg/executor/join.go b/pkg/executor/join.go index 0495e22d6b5f0..e06db25ac1815 100644 --- a/pkg/executor/join.go +++ b/pkg/executor/join.go @@ -332,7 +332,7 @@ func (w *buildWorker) fetchBuildSideRows(ctx context.Context, chkCh chan<- *chun if w.hashJoinCtx.finished.Load() { return } - chk := sessVars.GetNewChunkWithCapacity(w.buildSideExec.Base().RetFieldTypes(), sessVars.MaxChunkSize, sessVars.MaxChunkSize, w.hashJoinCtx.allocPool) + chk := sessVars.GetNewChunkWithCapacity(w.buildSideExec.RetFieldTypes(), sessVars.MaxChunkSize, sessVars.MaxChunkSize, w.hashJoinCtx.allocPool) err = exec.Next(ctx, w.buildSideExec, chk) if err != nil { errCh <- errors.Trace(err) @@ -1355,7 +1355,7 @@ func (e *NestedLoopApplyExec) Open(ctx context.Context) error { // aggExecutorTreeInputEmpty checks whether the executor tree returns empty if without aggregate operators. // Note that, the prerequisite is that this executor tree has been executed already and it returns one row. 
func aggExecutorTreeInputEmpty(e exec.Executor) bool { - children := e.Base().AllChildren() + children := e.AllChildren() if len(children) == 0 { return false } diff --git a/pkg/executor/merge_join.go b/pkg/executor/merge_join.go index ef38c2486ee2e..2f129b1babeb0 100644 --- a/pkg/executor/merge_join.go +++ b/pkg/executor/merge_join.go @@ -90,7 +90,7 @@ func (t *mergeJoinTable) init(executor *MergeJoinExec) { t.groupRowsIter = chunk.NewIterator4Chunk(t.childChunk) if t.isInner { - t.rowContainer = chunk.NewRowContainer(child.Base().RetFieldTypes(), t.childChunk.Capacity()) + t.rowContainer = chunk.NewRowContainer(child.RetFieldTypes(), t.childChunk.Capacity()) t.rowContainer.GetMemTracker().AttachTo(executor.memTracker) t.rowContainer.GetMemTracker().SetLabel(memory.LabelForInnerTable) t.rowContainer.GetDiskTracker().AttachTo(executor.diskTracker) diff --git a/pkg/executor/point_get.go b/pkg/executor/point_get.go index 42b1d8930c68b..450012b2df1f0 100644 --- a/pkg/executor/point_get.go +++ b/pkg/executor/point_get.go @@ -63,8 +63,8 @@ func (b *executorBuilder) buildPointGet(p *plannercore.PointGetPlan) exec.Execut isStaleness: b.isStaleness, } - e.Base().SetInitCap(1) - e.Base().SetMaxChunkSize(1) + e.SetInitCap(1) + e.SetMaxChunkSize(1) e.Init(p) e.snapshot, err = b.getSnapshot() @@ -333,7 +333,7 @@ func (e *PointGetExecutor) Next(ctx context.Context, req *chunk.Chunk) error { } return nil } - err = DecodeRowValToChunk(e.Base().Ctx(), e.Schema(), e.tblInfo, e.handle, val, req, e.rowDecoder) + err = DecodeRowValToChunk(e.BaseExecutor.Ctx(), e.Schema(), e.tblInfo, e.handle, val, req, e.rowDecoder) if err != nil { return err } From b8fe33aa472dfdcdb5060123dd03e9cdffc8af7d Mon Sep 17 00:00:00 2001 From: lance6716 Date: Tue, 19 Dec 2023 18:39:22 +0800 Subject: [PATCH 6/9] storage: support parallel write for gcs (#49545) close pingcap/tidb#48443 --- br/pkg/lightning/backend/external/BUILD.bazel | 1 + .../lightning/backend/external/byte_reader.go | 3 + br/pkg/lightning/backend/external/reader.go | 11 +- br/pkg/lightning/backend/external/writer.go | 165 ++++--- br/pkg/storage/BUILD.bazel | 2 + br/pkg/storage/gcs.go | 32 +- br/pkg/storage/gcs_extra.go | 419 ++++++++++++++++++ br/pkg/storage/gcs_test.go | 39 ++ br/pkg/storage/parse_test.go | 6 + go.mod | 1 + go.sum | 3 + 11 files changed, 602 insertions(+), 80 deletions(-) create mode 100644 br/pkg/storage/gcs_extra.go diff --git a/br/pkg/lightning/backend/external/BUILD.bazel b/br/pkg/lightning/backend/external/BUILD.bazel index 89950d8348413..f2f3b0eb6317c 100644 --- a/br/pkg/lightning/backend/external/BUILD.bazel +++ b/br/pkg/lightning/backend/external/BUILD.bazel @@ -33,6 +33,7 @@ go_library( "//br/pkg/storage", "//pkg/kv", "//pkg/metrics", + "//pkg/util", "//pkg/util/hack", "//pkg/util/logutil", "//pkg/util/size", diff --git a/br/pkg/lightning/backend/external/byte_reader.go b/br/pkg/lightning/backend/external/byte_reader.go index 3ba4a978390c4..a982c4349513d 100644 --- a/br/pkg/lightning/backend/external/byte_reader.go +++ b/br/pkg/lightning/backend/external/byte_reader.go @@ -209,6 +209,9 @@ func (r *byteReader) readNBytes(n int) ([]byte, error) { return bs[0], nil } // need to flatten bs + if n <= 0 { + return nil, errors.Errorf("illegal n (%d) when reading from external storage", n) + } if n > int(size.GB) { return nil, errors.Errorf("read %d bytes from external storage, exceed max limit %d", n, size.GB) } diff --git a/br/pkg/lightning/backend/external/reader.go b/br/pkg/lightning/backend/external/reader.go index 
34370d005fca9..d5f3d05ff55c6 100644 --- a/br/pkg/lightning/backend/external/reader.go +++ b/br/pkg/lightning/backend/external/reader.go @@ -20,13 +20,14 @@ import ( "io" "time" + "github.com/pingcap/errors" "github.com/pingcap/tidb/br/pkg/lightning/log" "github.com/pingcap/tidb/br/pkg/membuf" "github.com/pingcap/tidb/br/pkg/storage" "github.com/pingcap/tidb/pkg/metrics" + "github.com/pingcap/tidb/pkg/util" "github.com/pingcap/tidb/pkg/util/logutil" "go.uber.org/zap" - "golang.org/x/sync/errgroup" ) func readAllData( @@ -65,12 +66,13 @@ func readAllData( if err != nil { return err } - var eg errgroup.Group + eg, egCtx := util.NewErrorGroupWithRecoverWithCtx(ctx) + // TODO(lance6716): limit the concurrency of eg to 30 does not help for i := range dataFiles { i := i eg.Go(func() error { - return readOneFile( - ctx, + err2 := readOneFile( + egCtx, storage, dataFiles[i], startKey, @@ -80,6 +82,7 @@ func readAllData( bufPool, output, ) + return errors.Annotatef(err2, "failed to read file %s", dataFiles[i]) }) } return eg.Wait() diff --git a/br/pkg/lightning/backend/external/writer.go b/br/pkg/lightning/backend/external/writer.go index 25b15f23699a5..514937393978b 100644 --- a/br/pkg/lightning/backend/external/writer.go +++ b/br/pkg/lightning/backend/external/writer.go @@ -197,7 +197,6 @@ func (b *WriterBuilder) Build( filenamePrefix: filenamePrefix, keyAdapter: keyAdapter, writerID: writerID, - kvStore: nil, onClose: b.onClose, closed: false, multiFileStats: make([]MultipleFilesStat, 1), @@ -293,8 +292,7 @@ type Writer struct { filenamePrefix string keyAdapter common.KeyAdapter - kvStore *KeyValueStore - rc *rangePropertiesCollector + rc *rangePropertiesCollector memSizeLimit uint64 @@ -400,88 +398,53 @@ func (w *Writer) recordMinMax(newMin, newMax tidbkv.Key, size uint64) { w.totalSize += size } +const flushKVsRetryTimes = 3 + func (w *Writer) flushKVs(ctx context.Context, fromClose bool) (err error) { if len(w.kvLocations) == 0 { return nil } - logger := logutil.Logger(ctx) - dataFile, statFile, dataWriter, statWriter, err := w.createStorageWriter(ctx) - if err != nil { - return err - } - - var ( - savedBytes uint64 - statSize int - sortDuration, writeDuration time.Duration - writeStartTime time.Time + logger := logutil.Logger(ctx).With( + zap.String("writer-id", w.writerID), + zap.Int("sequence-number", w.currentSeq), ) - savedBytes = w.batchSize - startTs := time.Now() - - kvCnt := len(w.kvLocations) - defer func() { - w.currentSeq++ - err1, err2 := dataWriter.Close(ctx), statWriter.Close(ctx) - if err != nil { - return - } - if err1 != nil { - logger.Error("close data writer failed", zap.Error(err1)) - err = err1 - return - } - if err2 != nil { - logger.Error("close stat writer failed", zap.Error(err2)) - err = err2 - return - } - writeDuration = time.Since(writeStartTime) - logger.Info("flush kv", - zap.Uint64("bytes", savedBytes), - zap.Int("kv-cnt", kvCnt), - zap.Int("stat-size", statSize), - zap.Duration("sort-time", sortDuration), - zap.Duration("write-time", writeDuration), - zap.String("sort-speed(kv/s)", getSpeed(uint64(kvCnt), sortDuration.Seconds(), false)), - zap.String("write-speed(bytes/s)", getSpeed(savedBytes, writeDuration.Seconds(), true)), - zap.String("writer-id", w.writerID), - ) - metrics.GlobalSortWriteToCloudStorageDuration.WithLabelValues("write").Observe(writeDuration.Seconds()) - metrics.GlobalSortWriteToCloudStorageRate.WithLabelValues("write").Observe(float64(savedBytes) / 1024.0 / 1024.0 / writeDuration.Seconds()) - 
metrics.GlobalSortWriteToCloudStorageDuration.WithLabelValues("sort_and_write").Observe(time.Since(startTs).Seconds()) - metrics.GlobalSortWriteToCloudStorageRate.WithLabelValues("sort_and_write").Observe(float64(savedBytes) / 1024.0 / 1024.0 / time.Since(startTs).Seconds()) - }() - sortStart := time.Now() slices.SortFunc(w.kvLocations, func(i, j membuf.SliceLocation) int { return bytes.Compare(w.getKeyByLoc(i), w.getKeyByLoc(j)) }) - sortDuration = time.Since(sortStart) - - writeStartTime = time.Now() + sortDuration := time.Since(sortStart) metrics.GlobalSortWriteToCloudStorageDuration.WithLabelValues("sort").Observe(sortDuration.Seconds()) - metrics.GlobalSortWriteToCloudStorageRate.WithLabelValues("sort").Observe(float64(savedBytes) / 1024.0 / 1024.0 / sortDuration.Seconds()) - w.kvStore, err = NewKeyValueStore(ctx, dataWriter, w.rc) - if err != nil { - return err - } - - for _, pair := range w.kvLocations { - err = w.kvStore.addEncodedData(w.kvBuffer.GetSlice(pair)) - if err != nil { - return err + metrics.GlobalSortWriteToCloudStorageRate.WithLabelValues("sort").Observe(float64(w.batchSize) / 1024.0 / 1024.0 / sortDuration.Seconds()) + + writeStartTime := time.Now() + var dataFile, statFile string + for i := 0; i < flushKVsRetryTimes; i++ { + dataFile, statFile, err = w.flushSortedKVs(ctx) + if err == nil { + break } + logger.Warn("flush sorted kv failed", + zap.Error(err), + zap.Int("retry-count", i), + ) } - - w.kvStore.Close() - encodedStat := w.rc.encode() - statSize = len(encodedStat) - _, err = statWriter.Write(ctx, encodedStat) if err != nil { return err } + writeDuration := time.Since(writeStartTime) + kvCnt := len(w.kvLocations) + logger.Info("flush kv", + zap.Uint64("bytes", w.batchSize), + zap.Int("kv-cnt", kvCnt), + zap.Duration("sort-time", sortDuration), + zap.Duration("write-time", writeDuration), + zap.String("sort-speed(kv/s)", getSpeed(uint64(kvCnt), sortDuration.Seconds(), false)), + zap.String("writer-id", w.writerID), + ) + totalDuration := time.Since(sortStart) + metrics.GlobalSortWriteToCloudStorageDuration.WithLabelValues("sort_and_write").Observe(totalDuration.Seconds()) + metrics.GlobalSortWriteToCloudStorageRate.WithLabelValues("sort_and_write").Observe(float64(w.batchSize) / 1024.0 / 1024.0 / totalDuration.Seconds()) minKey, maxKey := w.getKeyByLoc(w.kvLocations[0]), w.getKeyByLoc(w.kvLocations[len(w.kvLocations)-1]) w.recordMinMax(minKey, maxKey, uint64(w.kvSize)) @@ -507,9 +470,73 @@ func (w *Writer) flushKVs(ctx context.Context, fromClose bool) (err error) { w.kvBuffer.Reset() w.rc.reset() w.batchSize = 0 + w.currentSeq++ return nil } +func (w *Writer) flushSortedKVs(ctx context.Context) (string, string, error) { + logger := logutil.Logger(ctx).With( + zap.String("writer-id", w.writerID), + zap.Int("sequence-number", w.currentSeq), + ) + writeStartTime := time.Now() + dataFile, statFile, dataWriter, statWriter, err := w.createStorageWriter(ctx) + if err != nil { + return "", "", err + } + defer func() { + // close the writers when meet error. If no error happens, writers will + // be closed outside and assigned to nil. 
+ if dataWriter != nil { + _ = dataWriter.Close(ctx) + } + if statWriter != nil { + _ = statWriter.Close(ctx) + } + }() + kvStore, err := NewKeyValueStore(ctx, dataWriter, w.rc) + if err != nil { + return "", "", err + } + + for _, pair := range w.kvLocations { + err = kvStore.addEncodedData(w.kvBuffer.GetSlice(pair)) + if err != nil { + return "", "", err + } + } + + kvStore.Close() + encodedStat := w.rc.encode() + statSize := len(encodedStat) + _, err = statWriter.Write(ctx, encodedStat) + if err != nil { + return "", "", err + } + err = dataWriter.Close(ctx) + dataWriter = nil + if err != nil { + return "", "", err + } + err = statWriter.Close(ctx) + statWriter = nil + if err != nil { + return "", "", err + } + + writeDuration := time.Since(writeStartTime) + logger.Info("flush sorted kv", + zap.Uint64("bytes", w.batchSize), + zap.Int("stat-size", statSize), + zap.Duration("write-time", writeDuration), + zap.String("write-speed(bytes/s)", getSpeed(w.batchSize, writeDuration.Seconds(), true)), + ) + metrics.GlobalSortWriteToCloudStorageDuration.WithLabelValues("write").Observe(writeDuration.Seconds()) + metrics.GlobalSortWriteToCloudStorageRate.WithLabelValues("write").Observe(float64(w.batchSize) / 1024.0 / 1024.0 / writeDuration.Seconds()) + + return dataFile, statFile, nil +} + func (w *Writer) getKeyByLoc(loc membuf.SliceLocation) []byte { block := w.kvBuffer.GetSlice(loc) keyLen := binary.BigEndian.Uint64(block[:lengthBytes]) diff --git a/br/pkg/storage/BUILD.bazel b/br/pkg/storage/BUILD.bazel index 3c5dd4d662705..42e96b6126158 100644 --- a/br/pkg/storage/BUILD.bazel +++ b/br/pkg/storage/BUILD.bazel @@ -7,6 +7,7 @@ go_library( "compress.go", "flags.go", "gcs.go", + "gcs_extra.go", "hdfs.go", "helper.go", "ks3.go", @@ -49,6 +50,7 @@ go_library( "@com_github_azure_azure_sdk_for_go_sdk_storage_azblob//bloberror", "@com_github_azure_azure_sdk_for_go_sdk_storage_azblob//blockblob", "@com_github_azure_azure_sdk_for_go_sdk_storage_azblob//container", + "@com_github_go_resty_resty_v2//:resty", "@com_github_google_uuid//:uuid", "@com_github_klauspost_compress//gzip", "@com_github_klauspost_compress//snappy", diff --git a/br/pkg/storage/gcs.go b/br/pkg/storage/gcs.go index 915537eec9166..5afe739260b57 100644 --- a/br/pkg/storage/gcs.go +++ b/br/pkg/storage/gcs.go @@ -99,6 +99,7 @@ func (options *GCSBackendOptions) parseFromFlags(flags *pflag.FlagSet) error { type GCSStorage struct { gcs *backuppb.GCS bucket *storage.BucketHandle + cli *storage.Client } // GetBucketHandle gets the handle to the GCS API on the bucket. @@ -272,12 +273,29 @@ func (s *GCSStorage) URI() string { } // Create implements ExternalStorage interface. -func (s *GCSStorage) Create(ctx context.Context, name string, _ *WriterOption) (ExternalFileWriter, error) { - object := s.objectName(name) - wc := s.bucket.Object(object).NewWriter(ctx) - wc.StorageClass = s.gcs.StorageClass - wc.PredefinedACL = s.gcs.PredefinedAcl - return newFlushStorageWriter(wc, &emptyFlusher{}, wc), nil +func (s *GCSStorage) Create(ctx context.Context, name string, wo *WriterOption) (ExternalFileWriter, error) { + // NewGCSWriter requires real testing environment on Google Cloud. 
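+ // A hedged clarifying note: the plain sequential GCS writer below is used when the caller does not request parallel write (wo == nil or wo.Concurrency <= 1) or when running against a local mock GCS endpoint in tests; only the remaining case goes through the XML multipart upload path.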
+ mockGCS := intest.InTest && strings.Contains(s.gcs.GetEndpoint(), "127.0.0.1") + if wo == nil || wo.Concurrency <= 1 || mockGCS { + object := s.objectName(name) + wc := s.bucket.Object(object).NewWriter(ctx) + wc.StorageClass = s.gcs.StorageClass + wc.PredefinedACL = s.gcs.PredefinedAcl + return newFlushStorageWriter(wc, &emptyFlusher{}, wc), nil + } + uri := s.objectName(name) + // 5MB is the minimum part size for GCS. + partSize := int64(gcsMinimumChunkSize) + if wo.PartSize > partSize { + partSize = wo.PartSize + } + w, err := NewGCSWriter(ctx, s.cli, uri, partSize, wo.Concurrency, s.gcs.Bucket) + if err != nil { + return nil, errors.Trace(err) + } + fw := newFlushStorageWriter(w, &emptyFlusher{}, w) + bw := newBufferedWriter(fw, int(partSize), NoCompression) + return bw, nil } // Rename file name from oldFileName to newFileName. @@ -371,7 +389,7 @@ skipHandleCred: // so we need find sst in slash directory gcs.Prefix += "//" } - return &GCSStorage{gcs: gcs, bucket: bucket}, nil + return &GCSStorage{gcs: gcs, bucket: bucket, cli: client}, nil } func hasSSTFiles(ctx context.Context, bucket *storage.BucketHandle, prefix string) bool { diff --git a/br/pkg/storage/gcs_extra.go b/br/pkg/storage/gcs_extra.go new file mode 100644 index 0000000000000..15b3ef5715d5a --- /dev/null +++ b/br/pkg/storage/gcs_extra.go @@ -0,0 +1,419 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Learned from https://github.com/liqiuqing/gcsmpu + +package storage + +import ( + "bytes" + "context" + "encoding/xml" + "fmt" + "net" + "net/http" + "net/url" + "runtime" + "slices" + "strconv" + "sync" + "time" + + "cloud.google.com/go/storage" + "github.com/go-resty/resty/v2" + "go.uber.org/atomic" +) + +// GCSWriter uses XML multipart upload API to upload a single file. +// https://cloud.google.com/storage/docs/multipart-uploads. +// GCSWriter will attempt to cancel uploads that fail due to an exception. +// If the upload fails in a way that precludes cancellation, such as a +// hardware failure, process termination, or power outage, then the incomplete +// upload may persist indefinitely. To mitigate this, set the +// `AbortIncompleteMultipartUpload` with a nonzero `Age` in bucket lifecycle +// rules, or refer to the XML API documentation linked above to learn more +// about how to list and delete individual downloads. +type GCSWriter struct { + uploadBase + mutex sync.Mutex + xmlMPUParts []*xmlMPUPart + wg sync.WaitGroup + err atomic.Error + chunkSize int64 + workers int + totalSize int64 + uploadID string + chunkCh chan chunk + curPart int +} + +// NewGCSWriter returns a GCSWriter which uses GCS multipart upload API behind the scene. +func NewGCSWriter( + ctx context.Context, + cli *storage.Client, + uri string, + partSize int64, + parallelCnt int, + bucketName string, +) (*GCSWriter, error) { + if partSize < gcsMinimumChunkSize || partSize > gcsMaximumChunkSize { + return nil, fmt.Errorf( + "invalid chunk size: %d. 
Chunk size must be between %d and %d", + partSize, gcsMinimumChunkSize, gcsMaximumChunkSize, + ) + } + + w := &GCSWriter{ + uploadBase: uploadBase{ + ctx: ctx, + cli: cli, + bucket: bucketName, + blob: uri, + retry: defaultRetry, + signedURLExpiry: defaultSignedURLExpiry, + }, + chunkSize: partSize, + workers: parallelCnt, + } + if err := w.init(); err != nil { + return nil, fmt.Errorf("failed to initiate GCSWriter: %w", err) + } + + return w, nil +} + +func (w *GCSWriter) init() error { + opts := &storage.SignedURLOptions{ + Scheme: storage.SigningSchemeV4, + Method: "POST", + Expires: time.Now().Add(w.signedURLExpiry), + QueryParameters: url.Values{mpuInitiateQuery: []string{""}}, + } + u, err := w.cli.Bucket(w.bucket).SignedURL(w.blob, opts) + if err != nil { + return fmt.Errorf("Bucket(%q).SignedURL: %s", w.bucket, err) + } + + client := resty.New() + resp, err := client.R().Post(u) + if err != nil { + return fmt.Errorf("POST request failed: %s", err) + } + + if resp.StatusCode() != http.StatusOK { + return fmt.Errorf("POST request returned non-OK status: %d", resp.StatusCode()) + } + body := resp.Body() + + result := InitiateMultipartUploadResult{} + err = xml.Unmarshal(body, &result) + if err != nil { + return fmt.Errorf("failed to unmarshal response body: %s", err) + } + + uploadID := result.UploadId + w.uploadID = uploadID + w.chunkCh = make(chan chunk) + for i := 0; i < w.workers; i++ { + w.wg.Add(1) + go w.readChunk(w.chunkCh) + } + w.curPart = 1 + return nil +} + +func (w *GCSWriter) readChunk(ch chan chunk) { + defer w.wg.Done() + for { + data, ok := <-ch + if !ok { + break + } + + select { + case <-w.ctx.Done(): + data.cleanup() + return + default: + part := &xmlMPUPart{ + uploadBase: w.uploadBase, + uploadID: w.uploadID, + buf: data.buf, + partNumber: data.num, + } + if w.err.Load() == nil { + if err := part.Upload(); err != nil { + w.err.Store(err) + } + } + part.buf = nil + w.appendMPUPart(part) + data.cleanup() + } + } +} + +// Write uploads given bytes as a part to Google Cloud Storage. Write is not +// concurrent safe. +func (w *GCSWriter) Write(p []byte) (n int, err error) { + if w.curPart > gcsMaximumParts { + err = fmt.Errorf("exceed maximum parts %d", gcsMaximumParts) + if w.err.Load() == nil { + w.err.Store(err) + } + return 0, err + } + buf := make([]byte, len(p)) + copy(buf, p) + w.chunkCh <- chunk{ + buf: buf, + num: w.curPart, + cleanup: func() {}, + } + w.curPart++ + return len(p), nil +} + +// Close finishes the upload. 
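+// Close stops accepting new parts, waits for the in-flight part uploads to finish, and then completes the multipart upload; if a part upload or the finalization failed, it attempts to cancel the whole upload and returns the error.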
+func (w *GCSWriter) Close() error { + close(w.chunkCh) + w.wg.Wait() + + if err := w.err.Load(); err != nil { + return err + } + + err := w.finalizeXMLMPU() + if err == nil { + return nil + } + errC := w.cancel() + if errC != nil { + return fmt.Errorf("failed to finalize multipart upload: %s, Failed to cancel multipart upload: %s", err, errC) + } + return fmt.Errorf("failed to finalize multipart upload: %s", err) +} + +const ( + mpuInitiateQuery = "uploads" + mpuPartNumberQuery = "partNumber" + mpuUploadIDQuery = "uploadId" +) + +type uploadBase struct { + cli *storage.Client + ctx context.Context + bucket string + blob string + retry int + signedURLExpiry time.Duration +} + +const ( + defaultRetry = 3 + defaultSignedURLExpiry = 6 * time.Hour + + gcsMinimumChunkSize = 5 * 1024 * 1024 // 5 MB + gcsMaximumChunkSize = 5 * 1024 * 1024 * 1024 // 5 GB + gcsMaximumParts = 10000 +) + +type InitiateMultipartUploadResult struct { + XMLName xml.Name `xml:"InitiateMultipartUploadResult"` + Text string `xml:",chardata"` + Xmlns string `xml:"xmlns,attr"` + Bucket string `xml:"Bucket"` + Key string `xml:"Key"` + UploadId string `xml:"UploadId"` +} + +type Part struct { + Text string `xml:",chardata"` + PartNumber int `xml:"PartNumber"` + ETag string `xml:"ETag"` +} + +type CompleteMultipartUpload struct { + XMLName xml.Name `xml:"CompleteMultipartUpload"` + Text string `xml:",chardata"` + Parts []Part `xml:"Part"` +} + +func (w *GCSWriter) finalizeXMLMPU() error { + finalXMLRoot := CompleteMultipartUpload{ + Parts: make([]Part, 0, len(w.xmlMPUParts)), + } + slices.SortFunc(w.xmlMPUParts, func(a, b *xmlMPUPart) int { + return a.partNumber - b.partNumber + }) + for _, part := range w.xmlMPUParts { + part := Part{ + PartNumber: part.partNumber, + ETag: part.etag, + } + finalXMLRoot.Parts = append(finalXMLRoot.Parts, part) + } + + xmlBytes, err := xml.Marshal(finalXMLRoot) + if err != nil { + return fmt.Errorf("failed to encode XML: %v", err) + } + + opts := &storage.SignedURLOptions{ + Scheme: storage.SigningSchemeV4, + Method: "POST", + Expires: time.Now().Add(w.signedURLExpiry), + QueryParameters: url.Values{mpuUploadIDQuery: []string{w.uploadID}}, + } + u, err := w.cli.Bucket(w.bucket).SignedURL(w.blob, opts) + if err != nil { + return fmt.Errorf("Bucket(%q).SignedURL: %s", w.bucket, err) + } + + client := resty.New() + resp, err := client.R().SetBody(xmlBytes).Post(u) + if err != nil { + return fmt.Errorf("POST request failed: %s", err) + } + + if resp.StatusCode() != http.StatusOK { + return fmt.Errorf("POST request returned non-OK status: %d, body: %s", resp.StatusCode(), resp.String()) + } + return nil +} + +type chunk struct { + buf []byte + num int + cleanup func() +} + +func (w *GCSWriter) appendMPUPart(part *xmlMPUPart) { + w.mutex.Lock() + defer w.mutex.Unlock() + + w.xmlMPUParts = append(w.xmlMPUParts, part) +} + +func (w *GCSWriter) cancel() error { + opts := &storage.SignedURLOptions{ + Scheme: storage.SigningSchemeV4, + Method: "DELETE", + Expires: time.Now().Add(w.signedURLExpiry), + QueryParameters: url.Values{mpuUploadIDQuery: []string{w.uploadID}}, + } + u, err := w.cli.Bucket(w.bucket).SignedURL(w.blob, opts) + if err != nil { + return fmt.Errorf("Bucket(%q).SignedURL: %s", w.bucket, err) + } + + client := resty.New() + resp, err := client.R().Delete(u) + if err != nil { + return fmt.Errorf("DELETE request failed: %s", err) + } + + if resp.StatusCode() != http.StatusNoContent { + return fmt.Errorf("DELETE request returned non-204 status: %d", resp.StatusCode()) + } + + return nil +} + 
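+// xmlMPUPart holds the buffer, part number and resulting ETag of a single part of the XML multipart upload; Upload sends the part through a V4-signed PUT URL and retries up to the configured retry count before reporting failure.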
+type xmlMPUPart struct { + uploadBase + buf []byte + uploadID string + partNumber int + etag string +} + +func (p *xmlMPUPart) Clone() *xmlMPUPart { + return &xmlMPUPart{ + uploadBase: p.uploadBase, + uploadID: p.uploadID, + buf: p.buf, + partNumber: p.partNumber, + } +} + +func (p *xmlMPUPart) Upload() error { + var err error + for i := 0; i < p.retry; i++ { + err = p.upload() + if err == nil { + return nil + } + } + + return fmt.Errorf("failed to upload part %d: %w", p.partNumber, err) +} + +func (p *xmlMPUPart) upload() error { + opts := &storage.SignedURLOptions{ + Scheme: storage.SigningSchemeV4, + Method: "PUT", + Expires: time.Now().Add(p.signedURLExpiry), + QueryParameters: url.Values{ + mpuUploadIDQuery: []string{p.uploadID}, + mpuPartNumberQuery: []string{strconv.Itoa(p.partNumber)}, + }, + } + + u, err := p.cli.Bucket(p.bucket).SignedURL(p.blob, opts) + if err != nil { + return fmt.Errorf("Bucket(%q).SignedURL: %s", p.bucket, err) + } + + req, err := http.NewRequest("PUT", u, bytes.NewReader(p.buf)) + if err != nil { + return fmt.Errorf("PUT request failed: %s", err) + } + req = req.WithContext(p.ctx) + + client := &http.Client{ + Transport: createTransport(nil), + } + resp, err := client.Do(req) + if err != nil { + return fmt.Errorf("PUT request failed: %s", err) + } + defer resp.Body.Close() + + p.etag = resp.Header.Get("ETag") + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("PUT request returned non-OK status: %d", resp.StatusCode) + } + return nil +} + +func createTransport(localAddr net.Addr) *http.Transport { + dialer := &net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + } + if localAddr != nil { + dialer.LocalAddr = localAddr + } + return &http.Transport{ + Proxy: http.ProxyFromEnvironment, + DialContext: dialer.DialContext, + MaxIdleConns: 100, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + MaxIdleConnsPerHost: runtime.GOMAXPROCS(0) + 1, + } +} diff --git a/br/pkg/storage/gcs_test.go b/br/pkg/storage/gcs_test.go index f9346caffde80..6f9cbfa7ef687 100644 --- a/br/pkg/storage/gcs_test.go +++ b/br/pkg/storage/gcs_test.go @@ -3,7 +3,10 @@ package storage import ( + "bytes" "context" + "crypto/rand" + "flag" "fmt" "io" "os" @@ -460,3 +463,39 @@ func TestReadRange(t *testing.T) { require.NoError(t, err) require.Equal(t, []byte("234"), content[:n]) } + +var testingStorageURI = flag.String("testing-storage-uri", "", "the URI of the storage used for testing") + +func openTestingStorage(t *testing.T) ExternalStorage { + if *testingStorageURI == "" { + t.Skip("testingStorageURI is not set") + } + s, err := NewFromURL(context.Background(), *testingStorageURI) + require.NoError(t, err) + return s +} + +func TestMultiPartUpload(t *testing.T) { + ctx := context.Background() + + s := openTestingStorage(t) + if _, ok := s.(*GCSStorage); !ok { + t.Skipf("only test GCSStorage, got %T", s) + } + + filename := "TestMultiPartUpload" + // just get some random content, use any seed is enough + data := make([]byte, 100*1024*1024) + rand.Read(data) + w, err := s.Create(ctx, filename, &WriterOption{Concurrency: 10}) + require.NoError(t, err) + _, err = w.Write(ctx, data) + require.NoError(t, err) + err = w.Close(ctx) + require.NoError(t, err) + + got, err := s.ReadFile(ctx, filename) + require.NoError(t, err) + cmp := bytes.Compare(data, got) + require.Zero(t, cmp) +} diff --git a/br/pkg/storage/parse_test.go b/br/pkg/storage/parse_test.go index 4e8884b557961..0669564961c77 100644 --- 
a/br/pkg/storage/parse_test.go +++ b/br/pkg/storage/parse_test.go @@ -147,6 +147,12 @@ func TestCreateStorage(t *testing.T) { require.Equal(t, "https://gcs.example.com/", gcs.Endpoint) require.Equal(t, "fakeCredentials", gcs.CredentialsBlob) + s, err = ParseBackend("gcs://bucket?endpoint=http://127.0.0.1/", gcsOpt) + require.NoError(t, err) + gcs = s.GetGcs() + require.NotNil(t, gcs) + require.Equal(t, "http://127.0.0.1/", gcs.Endpoint) + err = os.WriteFile(fakeCredentialsFile, []byte("fakeCreds2"), credFilePerm) require.NoError(t, err) s, err = ParseBackend("gs://bucket4/backup/?credentials-file="+url.QueryEscape(fakeCredentialsFile), nil) diff --git a/go.mod b/go.mod index 02ab1ea21f3ae..fbdb378bf6563 100644 --- a/go.mod +++ b/go.mod @@ -41,6 +41,7 @@ require ( github.com/fatih/color v1.15.0 github.com/fsouza/fake-gcs-server v1.44.0 github.com/go-ldap/ldap/v3 v3.4.4 + github.com/go-resty/resty/v2 v2.7.0 github.com/go-sql-driver/mysql v1.7.1 github.com/gogo/protobuf v1.3.2 github.com/golang/protobuf v1.5.3 diff --git a/go.sum b/go.sum index aa001d8a15894..1ff7a08d4dc50 100644 --- a/go.sum +++ b/go.sum @@ -295,6 +295,8 @@ github.com/go-martini/martini v0.0.0-20170121215854-22fa46961aab/go.mod h1:/P9AE github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= github.com/go-ole/go-ole v1.3.0 h1:Dt6ye7+vXGIKZ7Xtk4s6/xVdGDQynvom7xCFEdWr6uE= github.com/go-ole/go-ole v1.3.0/go.mod h1:5LS6F96DhAwUc7C+1HLexzMXY1xGRSryjyPPKW6zv78= +github.com/go-resty/resty/v2 v2.7.0 h1:me+K9p3uhSmXtrBZ4k9jcEAfJmuC8IivWHwaLZwPrFY= +github.com/go-resty/resty/v2 v2.7.0/go.mod h1:9PWDzw47qPphMRFfhsyk0NnSgvluHcljSMVIq3w7q0I= github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= github.com/go-sql-driver/mysql v1.7.1 h1:lUIinVbN1DY0xBg0eMOzmmtGoHwWBbvnWubQUrtU8EI= github.com/go-sql-driver/mysql v1.7.1/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= @@ -1095,6 +1097,7 @@ golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= golang.org/x/net v0.0.0-20210427231257-85d9c07bbe3a/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk= golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/net v0.0.0-20211029224645-99673261e6eb/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/net v0.0.0-20220517181318-183a9ca12b87/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= From 57c26c0f9769275baefb9fb12e9c7eea5b5fb543 Mon Sep 17 00:00:00 2001 From: Arenatlx <314806019@qq.com> Date: Tue, 19 Dec 2023 19:22:53 +0800 Subject: [PATCH 7/9] ddl: refactor ddl pkg's warning and note generation (#49581) close pingcap/tidb#49291 --- pkg/ddl/ddl_api.go | 58 +++++++++++++++-------------- pkg/ddl/partition.go | 6 +-- pkg/ddl/schematracker/dm_tracker.go | 2 +- 3 files changed, 34 insertions(+), 32 deletions(-) diff --git a/pkg/ddl/ddl_api.go b/pkg/ddl/ddl_api.go index 0d513fd2c9d6d..66b394a3c1172 100644 --- a/pkg/ddl/ddl_api.go +++ b/pkg/ddl/ddl_api.go @@ -181,6 +181,7 @@ func (d *ddl) CreateSchemaWithInfo( is := d.GetInfoSchemaWithInterceptor(ctx) _, ok := is.SchemaByName(dbInfo.Name) if ok { + // since this error may be seen as error, keep it stack 
info. err := infoschema.ErrDatabaseExists.GenWithStackByArgs(dbInfo.Name) switch onExist { case OnExistIgnore: @@ -775,13 +776,13 @@ func buildColumnsAndConstraints( // No warning for BOOL-like tinyint(1) if colDef.Tp.GetFlen() != types.UnspecifiedLength && colDef.Tp.GetFlen() != 1 { ctx.GetSessionVars().StmtCtx.AppendWarning( - dbterror.ErrWarnDeprecatedIntegerDisplayWidth.GenWithStackByArgs(), + dbterror.ErrWarnDeprecatedIntegerDisplayWidth.FastGenByArgs(), ) } case mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong: if colDef.Tp.GetFlen() != types.UnspecifiedLength { ctx.GetSessionVars().StmtCtx.AppendWarning( - dbterror.ErrWarnDeprecatedIntegerDisplayWidth.GenWithStackByArgs(), + dbterror.ErrWarnDeprecatedIntegerDisplayWidth.FastGenByArgs(), ) } } @@ -793,7 +794,7 @@ func buildColumnsAndConstraints( col.State = model.StatePublic if mysql.HasZerofillFlag(col.GetFlag()) { ctx.GetSessionVars().StmtCtx.AppendWarning( - dbterror.ErrWarnDeprecatedZerofill.GenWithStackByArgs(), + dbterror.ErrWarnDeprecatedZerofill.FastGenByArgs(), ) } constraints = append(constraints, cts...) @@ -1009,7 +1010,7 @@ func checkColumnDefaultValue(ctx sessionctx.Context, col *table.Column, value in value = `null` } sc := ctx.GetSessionVars().StmtCtx - sc.AppendWarning(dbterror.ErrBlobCantHaveDefault.GenWithStackByArgs(col.Name.O)) + sc.AppendWarning(dbterror.ErrBlobCantHaveDefault.FastGenByArgs(col.Name.O)) return hasDefaultValue, value, nil } // In strict SQL mode or default value is not an empty string. @@ -1223,10 +1224,10 @@ func columnDefToCol(ctx sessionctx.Context, offset int, colDef *ast.ColumnDef, o col.FieldType.SetCollate(v.StrValue) } case ast.ColumnOptionFulltext: - ctx.GetSessionVars().StmtCtx.AppendWarning(dbterror.ErrTableCantHandleFt.GenWithStackByArgs()) + ctx.GetSessionVars().StmtCtx.AppendWarning(dbterror.ErrTableCantHandleFt.FastGenByArgs()) case ast.ColumnOptionCheck: if !variable.EnableCheckConstraint.Load() { - ctx.GetSessionVars().StmtCtx.AppendWarning(errors.New("the switch of check constraint is off")) + ctx.GetSessionVars().StmtCtx.AppendWarning(errors.NewNoStackError("the switch of check constraint is off")) } else { // Check the column CHECK constraint dependency lazily, after fill all the name. // Extract column constraint from column option. @@ -1922,7 +1923,7 @@ func setTableAutoRandomBits(ctx sessionctx.Context, tbInfo *model.TableInfo, col return dbterror.ErrInvalidAutoRandom.FastGenByArgs(autoid.AutoRandomIncrementalBitsTooSmall) } msg := fmt.Sprintf(autoid.AutoRandomAvailableAllocTimesNote, shardFmt.IncrementalBitsCapacity()) - ctx.GetSessionVars().StmtCtx.AppendNote(errors.Errorf(msg)) + ctx.GetSessionVars().StmtCtx.AppendNote(errors.NewNoStackError(msg)) } } return nil @@ -2033,7 +2034,7 @@ func BuildTableInfo( } if constr.Tp == ast.ConstraintFulltext { - ctx.GetSessionVars().StmtCtx.AppendWarning(dbterror.ErrTableCantHandleFt.GenWithStackByArgs()) + ctx.GetSessionVars().StmtCtx.AppendWarning(dbterror.ErrTableCantHandleFt.FastGenByArgs()) continue } @@ -2055,7 +2056,7 @@ func BuildTableInfo( // check constraint if constr.Tp == ast.ConstraintCheck { if !variable.EnableCheckConstraint.Load() { - ctx.GetSessionVars().StmtCtx.AppendWarning(errors.New("the switch of check constraint is off")) + ctx.GetSessionVars().StmtCtx.AppendWarning(errors.NewNoStackError("the switch of check constraint is off")) continue } // Since column check constraint dependency has been done in columnDefToCol. 
@@ -2938,7 +2939,7 @@ func (d *ddl) FlashbackCluster(ctx sessionctx.Context, flashbackTS uint64) error } gap := time.Until(oracle.GetTimeFromTS(nowTS)).Abs() if gap > 1*time.Second { - ctx.GetSessionVars().StmtCtx.AppendWarning(errors.Errorf("Gap between local time and PD TSO is %s, please check PD/system time", gap)) + ctx.GetSessionVars().StmtCtx.AppendWarning(errors.NewNoStackErrorf("Gap between local time and PD TSO is %s, please check PD/system time", gap)) } job := &model.Job{ Type: model.ActionFlashbackCluster, @@ -3729,7 +3730,7 @@ func (d *ddl) AlterTable(ctx context.Context, sctx sessionctx.Context, stmt *ast sctx.GetSessionVars().StmtCtx.AppendWarning(dbterror.ErrTableCantHandleFt) case ast.ConstraintCheck: if !variable.EnableCheckConstraint.Load() { - sctx.GetSessionVars().StmtCtx.AppendWarning(errors.New("the switch of check constraint is off")) + sctx.GetSessionVars().StmtCtx.AppendWarning(errors.NewNoStackError("the switch of check constraint is off")) } else { err = d.CreateCheckConstraint(sctx, ident, model.NewCIStr(constr.Name), spec.Constraint) } @@ -3830,13 +3831,13 @@ func (d *ddl) AlterTable(ctx context.Context, sctx sessionctx.Context, stmt *ast err = d.AlterIndexVisibility(sctx, ident, spec.IndexName, spec.Visibility) case ast.AlterTableAlterCheck: if !variable.EnableCheckConstraint.Load() { - sctx.GetSessionVars().StmtCtx.AppendWarning(errors.New("the switch of check constraint is off")) + sctx.GetSessionVars().StmtCtx.AppendWarning(errors.NewNoStackError("the switch of check constraint is off")) } else { err = d.AlterCheckConstraint(sctx, ident, model.NewCIStr(spec.Constraint.Name), spec.Constraint.Enforced) } case ast.AlterTableDropCheck: if !variable.EnableCheckConstraint.Load() { - sctx.GetSessionVars().StmtCtx.AppendWarning(errors.New("the switch of check constraint is off")) + sctx.GetSessionVars().StmtCtx.AppendWarning(errors.NewNoStackError("the switch of check constraint is off")) } else { err = d.DropCheckConstraint(sctx, ident, model.NewCIStr(spec.Constraint.Name)) } @@ -3917,7 +3918,7 @@ func (d *ddl) RebaseAutoID(ctx sessionctx.Context, ident ast.Ident, newBase int6 } if newBase != newBaseTemp { ctx.GetSessionVars().StmtCtx.AppendWarning( - fmt.Errorf("Can't reset AUTO_INCREMENT to %d without FORCE option, using %d instead", + errors.NewNoStackErrorf("Can't reset AUTO_INCREMENT to %d without FORCE option, using %d instead", newBase, newBaseTemp, )) } @@ -4411,7 +4412,7 @@ func (d *ddl) AlterTablePartitioning(ctx sessionctx.Context, ident ast.Ident, sp err = d.DoDDLJob(ctx, job) err = d.callHookOnChanged(job, err) if err == nil { - ctx.GetSessionVars().StmtCtx.AppendWarning(errors.New("The statistics of new partitions will be outdated after reorganizing partitions. Please use 'ANALYZE TABLE' statement if you want to update it now")) + ctx.GetSessionVars().StmtCtx.AppendWarning(errors.NewNoStackError("The statistics of new partitions will be outdated after reorganizing partitions. Please use 'ANALYZE TABLE' statement if you want to update it now")) } return errors.Trace(err) } @@ -4475,7 +4476,7 @@ func (d *ddl) ReorganizePartitions(ctx sessionctx.Context, ident ast.Ident, spec err = d.DoDDLJob(ctx, job) err = d.callHookOnChanged(job, err) if err == nil { - ctx.GetSessionVars().StmtCtx.AppendWarning(errors.New("The statistics of related partitions will be outdated after reorganizing partitions. 
Please use 'ANALYZE TABLE' statement if you want to update it now")) + ctx.GetSessionVars().StmtCtx.AppendWarning(errors.NewNoStackError("The statistics of related partitions will be outdated after reorganizing partitions. Please use 'ANALYZE TABLE' statement if you want to update it now")) } return errors.Trace(err) } @@ -5049,7 +5050,7 @@ func (d *ddl) ExchangeTablePartition(ctx sessionctx.Context, ident ast.Ident, sp if err != nil { return errors.Trace(err) } - ctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("after the exchange, please analyze related table of the exchange to update statistics")) + ctx.GetSessionVars().StmtCtx.AppendWarning(errors.NewNoStackError("after the exchange, please analyze related table of the exchange to update statistics")) err = d.callHookOnChanged(job, err) return errors.Trace(err) } @@ -5890,7 +5891,7 @@ func (d *ddl) ChangeColumn(ctx context.Context, sctx sessionctx.Context, ident a job, err := d.getModifiableColumnJob(ctx, sctx, ident, spec.OldColumnName.Name, spec) if err != nil { if infoschema.ErrColumnNotExists.Equal(err) && spec.IfExists { - sctx.GetSessionVars().StmtCtx.AppendNote(infoschema.ErrColumnNotExists.GenWithStackByArgs(spec.OldColumnName.Name, ident.Name)) + sctx.GetSessionVars().StmtCtx.AppendNote(infoschema.ErrColumnNotExists.FastGenByArgs(spec.OldColumnName.Name, ident.Name)) return nil } return errors.Trace(err) @@ -5981,7 +5982,7 @@ func (d *ddl) ModifyColumn(ctx context.Context, sctx sessionctx.Context, ident a job, err := d.getModifiableColumnJob(ctx, sctx, ident, originalColName, spec) if err != nil { if infoschema.ErrColumnNotExists.Equal(err) && spec.IfExists { - sctx.GetSessionVars().StmtCtx.AppendNote(infoschema.ErrColumnNotExists.GenWithStackByArgs(originalColName, ident.Name)) + sctx.GetSessionVars().StmtCtx.AppendNote(infoschema.ErrColumnNotExists.FastGenByArgs(originalColName, ident.Name)) return nil } return errors.Trace(err) @@ -6381,7 +6382,7 @@ func (d *ddl) AlterTableAddStatistics(ctx sessionctx.Context, ident ast.Ident, s return infoschema.ErrColumnNotExists.GenWithStackByArgs(colName.Name, ident.Name) } if stats.StatsType == ast.StatsTypeCorrelation && tblInfo.PKIsHandle && mysql.HasPriKeyFlag(col.GetFlag()) { - ctx.GetSessionVars().StmtCtx.AppendWarning(errors.New("No need to create correlation statistics on the integer primary key column")) + ctx.GetSessionVars().StmtCtx.AppendWarning(errors.NewNoStackError("No need to create correlation statistics on the integer primary key column")) return nil } if _, exist := colIDSet[col.ID]; exist { @@ -7735,6 +7736,7 @@ func validateCommentLength(vars *variable.SessionVars, name string, comment *str if len(*comment) > maxLen { err := errTooLongComment.GenWithStackByArgs(name, maxLen) if vars.StrictSQLMode { + // may be treated like an error. 
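+ // In strict SQL mode the over-length comment is returned as an error; otherwise it is only appended as a warning below.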
return "", err } vars.StmtCtx.AppendWarning(err) @@ -8163,7 +8165,7 @@ func (d *ddl) OrderByColumns(ctx sessionctx.Context, ident ast.Ident) error { return errors.Trace(err) } if tb.Meta().GetPkColInfo() != nil { - ctx.GetSessionVars().StmtCtx.AppendWarning(errors.Errorf("ORDER BY ignored as there is a user-defined clustered index in the table '%s'", ident.Name)) + ctx.GetSessionVars().StmtCtx.AppendWarning(errors.NewNoStackErrorf("ORDER BY ignored as there is a user-defined clustered index in the table '%s'", ident.Name)) } return nil } @@ -8463,7 +8465,7 @@ func handleDatabasePlacement(ctx sessionctx.Context, dbInfo *model.DBInfo) error if sessVars.PlacementMode == variable.PlacementModeIgnore { dbInfo.PlacementPolicyRef = nil sessVars.StmtCtx.AppendNote( - fmt.Errorf("Placement is ignored when TIDB_PLACEMENT_MODE is '%s'", variable.PlacementModeIgnore), + errors.NewNoStackErrorf("Placement is ignored when TIDB_PLACEMENT_MODE is '%s'", variable.PlacementModeIgnore), ) return nil } @@ -8477,7 +8479,7 @@ func handleTablePlacement(ctx sessionctx.Context, tbInfo *model.TableInfo) error sessVars := ctx.GetSessionVars() if sessVars.PlacementMode == variable.PlacementModeIgnore && removeTablePlacement(tbInfo) { sessVars.StmtCtx.AppendNote( - fmt.Errorf("Placement is ignored when TIDB_PLACEMENT_MODE is '%s'", variable.PlacementModeIgnore), + errors.NewNoStackErrorf("Placement is ignored when TIDB_PLACEMENT_MODE is '%s'", variable.PlacementModeIgnore), ) return nil } @@ -8504,7 +8506,7 @@ func handlePartitionPlacement(ctx sessionctx.Context, partInfo *model.PartitionI sessVars := ctx.GetSessionVars() if sessVars.PlacementMode == variable.PlacementModeIgnore && removePartitionPlacement(partInfo) { sessVars.StmtCtx.AppendNote( - fmt.Errorf("Placement is ignored when TIDB_PLACEMENT_MODE is '%s'", variable.PlacementModeIgnore), + errors.NewNoStackErrorf("Placement is ignored when TIDB_PLACEMENT_MODE is '%s'", variable.PlacementModeIgnore), ) return nil } @@ -8524,7 +8526,7 @@ func checkIgnorePlacementDDL(ctx sessionctx.Context) bool { sessVars := ctx.GetSessionVars() if sessVars.PlacementMode == variable.PlacementModeIgnore { sessVars.StmtCtx.AppendNote( - fmt.Errorf("Placement is ignored when TIDB_PLACEMENT_MODE is '%s'", variable.PlacementModeIgnore), + errors.NewNoStackErrorf("Placement is ignored when TIDB_PLACEMENT_MODE is '%s'", variable.PlacementModeIgnore), ) return true } @@ -8542,7 +8544,7 @@ func (d *ddl) AddResourceGroup(ctx sessionctx.Context, stmt *ast.CreateResourceG if _, ok := d.GetInfoSchemaWithInterceptor(ctx).ResourceGroupByName(groupName); ok { if stmt.IfNotExists { - err = infoschema.ErrResourceGroupExists.GenWithStackByArgs(groupName) + err = infoschema.ErrResourceGroupExists.FastGenByArgs(groupName) ctx.GetSessionVars().StmtCtx.AppendNote(err) return nil } @@ -8889,9 +8891,9 @@ func checkTooBigFieldLengthAndTryAutoConvert(tp *types.FieldType, colName string return err } if tp.GetCharset() == charset.CharsetBin { - sessVars.StmtCtx.AppendWarning(dbterror.ErrAutoConvert.GenWithStackByArgs(colName, "VARBINARY", "BLOB")) + sessVars.StmtCtx.AppendWarning(dbterror.ErrAutoConvert.FastGenByArgs(colName, "VARBINARY", "BLOB")) } else { - sessVars.StmtCtx.AppendWarning(dbterror.ErrAutoConvert.GenWithStackByArgs(colName, "VARCHAR", "TEXT")) + sessVars.StmtCtx.AppendWarning(dbterror.ErrAutoConvert.FastGenByArgs(colName, "VARCHAR", "TEXT")) } } } diff --git a/pkg/ddl/partition.go b/pkg/ddl/partition.go index 082c731c38c72..dfd7a27adf2c9 100644 --- a/pkg/ddl/partition.go +++ 
b/pkg/ddl/partition.go @@ -534,7 +534,7 @@ func buildTablePartitionInfo(ctx sessionctx.Context, s *ast.PartitionOptions, tb } // Note that linear hash is simply ignored, and creates non-linear hash/key. if s.Linear { - ctx.GetSessionVars().StmtCtx.AppendWarning(dbterror.ErrUnsupportedCreatePartition.GenWithStack(fmt.Sprintf("LINEAR %s is not supported, using non-linear %s instead", s.Tp.String(), s.Tp.String()))) + ctx.GetSessionVars().StmtCtx.AppendWarning(dbterror.ErrUnsupportedCreatePartition.FastGen(fmt.Sprintf("LINEAR %s is not supported, using non-linear %s instead", s.Tp.String(), s.Tp.String()))) } if s.Tp == model.PartitionTypeHash || len(s.ColumnNames) != 0 { enable = true @@ -542,11 +542,11 @@ func buildTablePartitionInfo(ctx sessionctx.Context, s *ast.PartitionOptions, tb } if !enable { - ctx.GetSessionVars().StmtCtx.AppendWarning(dbterror.ErrUnsupportedCreatePartition.GenWithStack(fmt.Sprintf("Unsupported partition type %v, treat as normal table", s.Tp))) + ctx.GetSessionVars().StmtCtx.AppendWarning(dbterror.ErrUnsupportedCreatePartition.FastGen(fmt.Sprintf("Unsupported partition type %v, treat as normal table", s.Tp))) return nil } if s.Sub != nil { - ctx.GetSessionVars().StmtCtx.AppendWarning(dbterror.ErrUnsupportedCreatePartition.GenWithStack(fmt.Sprintf("Unsupported subpartitioning, only using %v partitioning", s.Tp))) + ctx.GetSessionVars().StmtCtx.AppendWarning(dbterror.ErrUnsupportedCreatePartition.FastGen(fmt.Sprintf("Unsupported subpartitioning, only using %v partitioning", s.Tp))) } pi := &model.PartitionInfo{ diff --git a/pkg/ddl/schematracker/dm_tracker.go b/pkg/ddl/schematracker/dm_tracker.go index 467f35ffcaad8..eb8076c4be0dd 100644 --- a/pkg/ddl/schematracker/dm_tracker.go +++ b/pkg/ddl/schematracker/dm_tracker.go @@ -714,7 +714,7 @@ func (d SchemaTracker) handleModifyColumn( job, err := ddl.GetModifiableColumnJob(ctx, sctx, nil, ident, originalColName, schema, t, spec) if err != nil { if infoschema.ErrColumnNotExists.Equal(err) && spec.IfExists { - sctx.GetSessionVars().StmtCtx.AppendNote(infoschema.ErrColumnNotExists.GenWithStackByArgs(originalColName, ident.Name)) + sctx.GetSessionVars().StmtCtx.AppendNote(infoschema.ErrColumnNotExists.FastGenByArgs(originalColName, ident.Name)) return nil } return errors.Trace(err) From 17bdcea34722150bf3b49350fa38b338aaca5d12 Mon Sep 17 00:00:00 2001 From: Rustin Liu Date: Tue, 19 Dec 2023 19:23:00 +0800 Subject: [PATCH 8/9] statistics: add more test cases for handling DDL event (#49582) close pingcap/tidb#48126 --- pkg/statistics/handle/ddl/BUILD.bazel | 2 +- pkg/statistics/handle/ddl/ddl_test.go | 208 ++++++++++++++++++++++++++ 2 files changed, 209 insertions(+), 1 deletion(-) diff --git a/pkg/statistics/handle/ddl/BUILD.bazel b/pkg/statistics/handle/ddl/BUILD.bazel index 1cbc60d88a54b..8a7a2421aeca9 100644 --- a/pkg/statistics/handle/ddl/BUILD.bazel +++ b/pkg/statistics/handle/ddl/BUILD.bazel @@ -31,7 +31,7 @@ go_test( timeout = "short", srcs = ["ddl_test.go"], flaky = True, - shard_count = 10, + shard_count = 13, deps = [ "//pkg/parser/model", "//pkg/planner/cardinality", diff --git a/pkg/statistics/handle/ddl/ddl_test.go b/pkg/statistics/handle/ddl/ddl_test.go index 6d4ce7a9d7471..9607853687205 100644 --- a/pkg/statistics/handle/ddl/ddl_test.go +++ b/pkg/statistics/handle/ddl/ddl_test.go @@ -326,6 +326,151 @@ func TestReorgPartitions(t *testing.T) { require.NotEqual(t, versionP1, rows[1][0].(string)) } +func TestIncreasePartitionCountOfHashPartitionTable(t *testing.T) { + store, do := 
testkit.CreateMockStoreAndDomain(t) + testKit := testkit.NewTestKit(t, store) + h := do.StatsHandle() + + testKit.MustExec("use test") + testKit.MustExec("drop table if exists t") + testKit.MustExec("create table t (a int, b int) partition by hash(a) partitions 2") + testKit.MustExec("insert into t values (1,2),(2,2),(6,2),(11,2),(16,2)") + testKit.MustExec("analyze table t") + is := do.InfoSchema() + tbl, err := is.TableByName( + model.NewCIStr("test"), model.NewCIStr("t"), + ) + require.NoError(t, err) + tableInfo := tbl.Meta() + pi := tableInfo.GetPartitionInfo() + for _, def := range pi.Definitions { + statsTbl := h.GetPartitionStats(tableInfo, def.ID) + require.False(t, statsTbl.Pseudo) + } + err = h.Update(is) + require.NoError(t, err) + + // Get partition p0 and p1's stats update version. + partitionP0ID := pi.Definitions[0].ID + partitionP1ID := pi.Definitions[1].ID + // Get it from stats_meta first. + rows := testKit.MustQuery( + "select version from mysql.stats_meta where table_id in (?, ?) order by table_id", partitionP0ID, partitionP1ID, + ).Rows() + require.Len(t, rows, 2) + versionP0 := rows[0][0].(string) + versionP1 := rows[1][0].(string) + + // Increase the partition count to 4. + testKit.MustExec("alter table t add partition partitions 2") + // Find the reorganize partition event. + reorganizePartitionEvent := findEvent(h.DDLEventCh(), model.ActionReorganizePartition) + err = h.HandleDDLEvent(reorganizePartitionEvent) + require.NoError(t, err) + require.Nil(t, h.Update(is)) + + // Check new partitions are added. + is = do.InfoSchema() + tbl, err = is.TableByName( + model.NewCIStr("test"), model.NewCIStr("t"), + ) + require.NoError(t, err) + tableInfo = tbl.Meta() + pi = tableInfo.GetPartitionInfo() + require.Len(t, pi.Definitions, 4) + // Check the stats meta. + rows = testKit.MustQuery( + "select version from mysql.stats_meta where table_id in (?, ?, ?, ?) order by table_id", + pi.Definitions[0].ID, pi.Definitions[1].ID, pi.Definitions[2].ID, pi.Definitions[3].ID, + ).Rows() + require.Len(t, rows, 4) + + // Check the old partitions' stats version is changed. + rows = testKit.MustQuery( + "select version from mysql.stats_meta where table_id in (?, ?) order by table_id", partitionP0ID, partitionP1ID, + ).Rows() + require.Len(t, rows, 2) + require.NotEqual(t, versionP0, rows[0][0].(string)) + require.NotEqual(t, versionP1, rows[1][0].(string)) +} + +func TestDecreasePartitionCountOfHashPartitionTable(t *testing.T) { + store, do := testkit.CreateMockStoreAndDomain(t) + testKit := testkit.NewTestKit(t, store) + h := do.StatsHandle() + + testKit.MustExec("use test") + testKit.MustExec("drop table if exists t") + testKit.MustExec("create table t (a int, b int) partition by hash(a) partitions 4") + testKit.MustExec("insert into t values (1,2),(2,2),(6,2),(11,2),(16,2)") + testKit.MustExec("analyze table t") + is := do.InfoSchema() + tbl, err := is.TableByName( + model.NewCIStr("test"), model.NewCIStr("t"), + ) + require.NoError(t, err) + tableInfo := tbl.Meta() + pi := tableInfo.GetPartitionInfo() + require.Len(t, pi.Definitions, 4) + for _, def := range pi.Definitions { + statsTbl := h.GetPartitionStats(tableInfo, def.ID) + require.False(t, statsTbl.Pseudo) + } + err = h.Update(is) + require.NoError(t, err) + + // Get partitions p0 to p3's stats update versions. + partitionP0ID := pi.Definitions[0].ID + partitionP1ID := pi.Definitions[1].ID + partitionP2ID := pi.Definitions[2].ID + partitionP3ID := pi.Definitions[3].ID + // Get it from stats_meta first. 
+ rows := testKit.MustQuery( + "select version from mysql.stats_meta where table_id in (?, ?, ?, ?) order by table_id", + partitionP0ID, partitionP1ID, partitionP2ID, partitionP3ID, + ).Rows() + require.Len(t, rows, 4) + versionP0 := rows[0][0].(string) + versionP1 := rows[1][0].(string) + versionP2 := rows[2][0].(string) + versionP3 := rows[3][0].(string) + + // Decrease the partition count to 2. + testKit.MustExec("alter table t coalesce partition 2") + // Find the reorganize partition event. + reorganizePartitionEvent := findEvent(h.DDLEventCh(), model.ActionReorganizePartition) + err = h.HandleDDLEvent(reorganizePartitionEvent) + require.NoError(t, err) + require.Nil(t, h.Update(is)) + + // Check the partition count is reduced to 2. + is = do.InfoSchema() + tbl, err = is.TableByName( + model.NewCIStr("test"), model.NewCIStr("t"), + ) + require.NoError(t, err) + tableInfo = tbl.Meta() + pi = tableInfo.GetPartitionInfo() + require.Len(t, pi.Definitions, 2) + // Check the stats meta. + rows = testKit.MustQuery( + "select version from mysql.stats_meta where table_id in (?, ?) order by table_id", + pi.Definitions[0].ID, pi.Definitions[1].ID, + ).Rows() + require.Len(t, rows, 2) + + // Check the old partitions' stats version is changed. + rows = testKit.MustQuery( + "select version from mysql.stats_meta where table_id in (?, ?, ?, ?) order by table_id", + partitionP0ID, partitionP1ID, partitionP2ID, partitionP3ID, + ).Rows() + require.Len(t, rows, 4) + require.NotEqual(t, versionP0, rows[0][0].(string)) + require.NotEqual(t, versionP1, rows[1][0].(string)) + require.NotEqual(t, versionP2, rows[2][0].(string)) + require.NotEqual(t, versionP3, rows[3][0].(string)) +} + func TestTruncateAPartition(t *testing.T) { store, do := testkit.CreateMockStoreAndDomain(t) testKit := testkit.NewTestKit(t, store) @@ -393,6 +538,69 @@ func TestTruncateAPartition(t *testing.T) { require.NotEqual(t, version, rows[0][0].(string)) } +func TestTruncateAHashPartition(t *testing.T) { + store, do := testkit.CreateMockStoreAndDomain(t) + testKit := testkit.NewTestKit(t, store) + h := do.StatsHandle() + testKit.MustExec("use test") + testKit.MustExec("drop table if exists t") + testKit.MustExec(` + create table t ( + a bigint, + b int, + primary key(a), + index idx(b) + ) + partition by hash(a) partitions 4 + `) + testKit.MustExec("insert into t values (1,2),(2,2),(6,2),(11,2),(16,2)") + testKit.MustExec("analyze table t") + is := do.InfoSchema() + tbl, err := is.TableByName( + model.NewCIStr("test"), model.NewCIStr("t"), + ) + require.NoError(t, err) + tableInfo := tbl.Meta() + pi := tableInfo.GetPartitionInfo() + require.NotNil(t, pi) + for _, def := range pi.Definitions { + statsTbl := h.GetPartitionStats(tableInfo, def.ID) + require.False(t, statsTbl.Pseudo) + } + err = h.Update(is) + require.NoError(t, err) + + // Get partition p0's stats update version. + partitionID := pi.Definitions[0].ID + // Get it from stats_meta first. + rows := testKit.MustQuery( + "select version from mysql.stats_meta where table_id = ?", partitionID, + ).Rows() + require.Len(t, rows, 1) + version := rows[0][0].(string) + + testKit.MustExec("alter table t truncate partition p0") + // Find the truncate partition event. + truncatePartitionEvent := findEvent(h.DDLEventCh(), model.ActionTruncateTablePartition) + err = h.HandleDDLEvent(truncatePartitionEvent) + require.NoError(t, err) + // Check global stats meta. + // Because we have truncated a partition, the count should be 5 - 1 = 4 and the modify count should be 1. 
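+ // With hash(a) partitions 4, only the row with a=16 lands in p0, so truncating p0 removes exactly one row.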
+ testKit.MustQuery( + "select count, modify_count from mysql.stats_meta where table_id = ?", tableInfo.ID, + ).Check( + testkit.Rows("4 1"), + ) + + // Check the version again. + rows = testKit.MustQuery( + "select version from mysql.stats_meta where table_id = ?", partitionID, + ).Rows() + require.Len(t, rows, 1) + // Version gets updated after truncating the partition. + require.NotEqual(t, version, rows[0][0].(string)) +} + func TestTruncatePartitions(t *testing.T) { store, do := testkit.CreateMockStoreAndDomain(t) testKit := testkit.NewTestKit(t, store) From ad1efe451080ec60c1c592b141e780ed3cc190fb Mon Sep 17 00:00:00 2001 From: Yuanjia Zhang Date: Tue, 19 Dec 2023 20:07:52 +0800 Subject: [PATCH 9/9] planner: match hints from universal bindings correctly (#49585) ref pingcap/tidb#48875 --- pkg/bindinfo/BUILD.bazel | 2 +- pkg/bindinfo/bind_record.go | 7 +- pkg/bindinfo/universal_binding_test.go | 175 ++++++++++++++++++++++- pkg/planner/core/logical_plan_builder.go | 2 +- pkg/planner/core/planbuilder.go | 12 +- pkg/planner/core/rule_join_reorder.go | 2 +- 6 files changed, 192 insertions(+), 8 deletions(-) diff --git a/pkg/bindinfo/BUILD.bazel b/pkg/bindinfo/BUILD.bazel index 3874650f82b48..a2cff9a8afbcd 100644 --- a/pkg/bindinfo/BUILD.bazel +++ b/pkg/bindinfo/BUILD.bazel @@ -59,7 +59,7 @@ go_test( embed = [":bindinfo"], flaky = True, race = "on", - shard_count = 46, + shard_count = 47, deps = [ "//pkg/bindinfo/internal", "//pkg/config", diff --git a/pkg/bindinfo/bind_record.go b/pkg/bindinfo/bind_record.go index 241ea24cbae95..16f11399ec233 100644 --- a/pkg/bindinfo/bind_record.go +++ b/pkg/bindinfo/bind_record.go @@ -187,7 +187,12 @@ func (br *BindRecord) prepareHints(sctx sessionctx.Context) error { if (bind.Hint != nil && bind.ID != "") || bind.Status == deleted { continue } - hintsSet, stmt, warns, err := hint.ParseHintsSet(p, bind.BindSQL, bind.Charset, bind.Collation, br.Db) + dbName := br.Db + if bind.Type == TypeUniversal { + dbName = "*" // use '*' for universal bindings + } + + hintsSet, stmt, warns, err := hint.ParseHintsSet(p, bind.BindSQL, bind.Charset, bind.Collation, dbName) if err != nil { return err } diff --git a/pkg/bindinfo/universal_binding_test.go b/pkg/bindinfo/universal_binding_test.go index 21312804a8654..d87611803db4a 100644 --- a/pkg/bindinfo/universal_binding_test.go +++ b/pkg/bindinfo/universal_binding_test.go @@ -37,6 +37,21 @@ func showBinding(tk *testkit.TestKit, showStmt string) [][]interface{} { return result } +func removeAllBindings(tk *testkit.TestKit, global bool) { + scope := "session" + if global { + scope = "global" + } + res := showBinding(tk, fmt.Sprintf("show %v bindings", scope)) + for _, r := range res { + if r[4] == "builtin" { + continue + } + tk.MustExec(fmt.Sprintf("drop %v binding for sql digest '%v'", scope, r[6])) + } + tk.MustQuery(fmt.Sprintf("show %v bindings", scope)).Check(testkit.Rows()) // empty +} + func TestUniversalBindingBasic(t *testing.T) { store := testkit.CreateMockStore(t) tk1 := testkit.NewTestKit(t, store) @@ -64,10 +79,12 @@ func TestUniversalBindingBasic(t *testing.T) { tk.MustExec("use " + useDB) for _, testDB := range []string{"", "test.", "test1.", "test2."} { tk.MustExec(`set @@tidb_opt_enable_universal_binding=1`) // enabled - tk.MustUseIndex(fmt.Sprintf("select * from %vt", testDB), idx) + require.True(t, tk.MustUseIndex(fmt.Sprintf("select * from %vt", testDB), idx)) tk.MustQuery(`select @@last_plan_from_binding`).Check(testkit.Rows("1")) + require.True(t, tk.MustUseIndex(fmt.Sprintf("select * from %vt", 
testDB), idx)) + tk.MustQuery(`show warnings`).Check(testkit.Rows()) // no warning tk.MustExec(`set @@tidb_opt_enable_universal_binding=0`) // disabled - tk.MustUseIndex(fmt.Sprintf("select * from %vt", testDB), idx) + tk.MustQuery(fmt.Sprintf("select * from %vt", testDB)) tk.MustQuery(`select @@last_plan_from_binding`).Check(testkit.Rows("0")) } } @@ -238,3 +255,157 @@ func TestUniversalBindingGC(t *testing.T) { require.NoError(t, bindHandle.GCGlobalBinding()) tk.MustQuery(`select bind_sql, status, type from mysql.bind_info where source != 'builtin'`).Check(testkit.Rows()) // empty after GC } + +func TestUniversalBindingHints(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec(`use test`) + + for _, db := range []string{"db1", "db2", "db3"} { + tk.MustExec(`create database ` + db) + tk.MustExec(`use ` + db) + tk.MustExec(`create table t1 (a int, b int, c int, d int, key(a), key(b), key(c), key(d))`) + tk.MustExec(`create table t2 (a int, b int, c int, d int, key(a), key(b), key(c), key(d))`) + tk.MustExec(`create table t3 (a int, b int, c int, d int, key(a), key(b), key(c), key(d))`) + } + tk.MustExec(`set @@tidb_opt_enable_universal_binding=1`) + + for _, c := range []struct { + binding string + qTemplate string + }{ + // use index + {`create universal binding using select /*+ use_index(t1, c) */ * from t1 where a=1`, + `select * from %st1 where a=1000`}, + {`create universal binding using select /*+ use_index(t1, c) */ * from t1 where d<1`, + `select * from %st1 where d<10000`}, + {`create universal binding using select /*+ use_index(t1, c) */ * from t1, t2 where t1.d<1`, + `select * from %st1, t2 where t1.d<100`}, + {`create universal binding using select /*+ use_index(t1, c) */ * from t1, t2 where t1.d<1`, + `select * from t1, %st2 where t1.d<100`}, + {`create universal binding using select /*+ use_index(t1, c), use_index(t2, a) */ * from t1, t2 where t1.d<1`, + `select * from %st1, t2 where t1.d<100`}, + {`create universal binding using select /*+ use_index(t1, c), use_index(t2, a) */ * from t1, t2 where t1.d<1`, + `select * from t1, %st2 where t1.d<100`}, + {`create universal binding using select /*+ use_index(t1, c), use_index(t2, a) */ * from t1, t2, t3 where t1.d<1`, + `select * from %st1, t2, t3 where t1.d<100`}, + {`create universal binding using select /*+ use_index(t1, c), use_index(t2, a) */ * from t1, t2, t3 where t1.d<1`, + `select * from t1, t2, %st3 where t1.d<100`}, + + // ignore index + {`create universal binding using select /*+ ignore_index(t1, b) */ * from t1 where b=1`, + `select * from %st1 where b=1000`}, + {`create universal binding using select /*+ ignore_index(t1, b) */ * from t1 where b>1`, + `select * from %st1 where b>1000`}, + {`create universal binding using select /*+ ignore_index(t1, b) */ * from t1 where b in (1,2)`, + `select * from %st1 where b in (1)`}, + {`create universal binding using select /*+ ignore_index(t1, b) */ * from t1 where b in (1,2)`, + `select * from %st1 where b in (1,2,3,4,5)`}, + + // order index hint + {`create universal binding using select /*+ order_index(t1, a) */ a from t1 where a<10 order by a limit 10`, + `select a from %st1 where a<10000 order by a limit 10`}, + {`create universal binding using select /*+ order_index(t1, b) */ b from t1 where b>10 order by b limit 1111`, + `select b from %st1 where b>2 order by b limit 10`}, + + // no order index hint + {`create universal binding using select /*+ no_order_index(t1, c) */ c from t1 where c<10 order by c limit 10`, + 
`select c from %st1 where c<10000 order by c limit 10`}, + {`create universal binding using select /*+ no_order_index(t1, d) */ d from t1 where d>10 order by d limit 1111`, + `select d from %st1 where d>2 order by d limit 10`}, + + // agg hint + {`create universal binding using select /*+ hash_agg() */ count(*) from t1 group by a`, + `select count(*) from %st1 group by a`}, + {`create universal binding using select /*+ stream_agg() */ count(*) from t1 group by b`, + `select count(*) from %st1 group by b`}, + + // to_cop hint + {`create universal binding using select /*+ agg_to_cop() */ sum(a) from t1`, + `select sum(a) from %st1`}, + {`create universal binding using select /*+ limit_to_cop() */ a from t1 limit 10`, + `select a from %st1 limit 101`}, + + // index merge hint + {`create universal binding using select /*+ use_index_merge(t1, c, d) */ * from t1 where c=1 or d=1`, + `select * from %st1 where c=1000 or d=1000`}, + {`create universal binding using select /*+ no_index_merge() */ * from t1 where a=1 or b=1`, + `select * from %st1 where a=1000 or b=1000`}, + + // join type hint + {`create universal binding using select /*+ hash_join(t1) */ * from t1, t2 where t1.a=t2.a`, + `select * from %st1, t2 where t1.a=t2.a`}, + {`create universal binding using select /*+ hash_join(t2) */ * from t1, t2 where t1.a=t2.a`, + `select * from t1, %st2 where t1.a=t2.a`}, + {`create universal binding using select /*+ hash_join(t2) */ * from t1, t2, t3 where t1.a=t2.a and t3.b=t2.b`, + `select * from t1, %st2, t3 where t1.a=t2.a and t3.b=t2.b`}, + {`create universal binding using select /*+ hash_join_build(t1) */ * from t1, t2 where t1.a=t2.a`, + `select * from t1, %st2 where t1.a=t2.a`}, + {`create universal binding using select /*+ hash_join_probe(t1) */ * from t1, t2 where t1.a=t2.a`, + `select * from t1, %st2 where t1.a=t2.a`}, + {`create universal binding using select /*+ merge_join(t1) */ * from t1, t2 where t1.a=t2.a`, + `select * from %st1, t2 where t1.a=t2.a`}, + {`create universal binding using select /*+ merge_join(t2) */ * from t1, t2 where t1.a=t2.a`, + `select * from t1, %st2 where t1.a=t2.a`}, + {`create universal binding using select /*+ merge_join(t2) */ * from t1, t2, t3 where t1.a=t2.a and t3.b=t2.b`, + `select * from t1, %st2, t3 where t1.a=t2.a and t3.b=t2.b`}, + {`create universal binding using select /*+ inl_join(t1) */ * from t1, t2 where t1.a=t2.a`, + `select * from %st1, t2 where t1.a=t2.a`}, + {`create universal binding using select /*+ inl_join(t2) */ * from t1, t2 where t1.a=t2.a`, + `select * from t1, %st2 where t1.a=t2.a`}, + {`create universal binding using select /*+ inl_join(t2) */ * from t1, t2, t3 where t1.a=t2.a and t3.b=t2.b`, + `select * from t1, %st2, t3 where t1.a=t2.a and t3.b=t2.b`}, + + // TODO: no_xxx_join hints are not compatible with bindings, fix later. 
+ // no join type hint + //{`create universal binding using select /*+ no_hash_join(t1) */ * from t1, t2 where t1.b=t2.b`, + // `select * from %st1, t2 where t1.b=t2.b`}, + //{`create universal binding using select /*+ no_hash_join(t2) */ * from t1, t2 where t1.c=t2.c`, + // `select * from t1, %st2 where t1.c=t2.c`}, + //{`create universal binding using select /*+ no_hash_join(t2) */ * from t1, t2, t3 where t1.a=t2.a and t3.b=t2.b`, + // `select * from t1, %st2, t3 where t1.a=t2.a and t3.b=t2.b`}, + //{`create universal binding using select /*+ no_merge_join(t1) */ * from t1, t2 where t1.b=t2.b`, + // `select * from %st1, t2 where t1.b=t2.b`}, + //{`create universal binding using select /*+ no_merge_join(t2) */ * from t1, t2 where t1.c=t2.c`, + // `select * from t1, %st2 where t1.c=t2.c`}, + //{`create universal binding using select /*+ no_merge_join(t2) */ * from t1, t2, t3 where t1.a=t2.a and t3.b=t2.b`, + // `select * from t1, %st2, t3 where t1.a=t2.a and t3.b=t2.b`}, + //{`create universal binding using select /*+ no_index_join(t1) */ * from t1, t2 where t1.b=t2.b`, + // `select * from %st1, t2 where t1.b=t2.b`}, + //{`create universal binding using select /*+ no_index_join(t2) */ * from t1, t2 where t1.c=t2.c`, + // `select * from t1, %st2 where t1.c=t2.c`}, + //{`create universal binding using select /*+ no_index_join(t2) */ * from t1, t2, t3 where t1.a=t2.a and t3.b=t2.b`, + // `select * from t1, %st2, t3 where t1.a=t2.a and t3.b=t2.b`}, + + // join order hint + {`create universal binding using select /*+ leading(t2) */ * from t1, t2 where t1.b=t2.b`, + `select * from %st1, t2 where t1.b=t2.b`}, + {`create universal binding using select /*+ leading(t2) */ * from t1, t2 where t1.c=t2.c`, + `select * from t1, %st2 where t1.c=t2.c`}, + {`create universal binding using select /*+ leading(t2, t1) */ * from t1, t2 where t1.c=t2.c`, + `select * from t1, %st2 where t1.c=t2.c`}, + {`create universal binding using select /*+ leading(t1, t2) */ * from t1, t2 where t1.c=t2.c`, + `select * from t1, %st2 where t1.c=t2.c`}, + {`create universal binding using select /*+ leading(t1) */ * from t1, t2, t3 where t1.a=t2.a and t3.b=t2.b`, + `select * from t1, %st2, t3 where t1.a=t2.a and t3.b=t2.b`}, + {`create universal binding using select /*+ leading(t2) */ * from t1, t2, t3 where t1.a=t2.a and t3.b=t2.b`, + `select * from t1, %st2, t3 where t1.a=t2.a and t3.b=t2.b`}, + {`create universal binding using select /*+ leading(t2,t3) */ * from t1, t2, t3 where t1.a=t2.a and t3.b=t2.b`, + `select * from t1, %st2, t3 where t1.a=t2.a and t3.b=t2.b`}, + {`create universal binding using select /*+ leading(t2,t3,t1) */ * from t1, t2, t3 where t1.a=t2.a and t3.b=t2.b`, + `select * from t1, %st2, t3 where t1.a=t2.a and t3.b=t2.b`}, + } { + removeAllBindings(tk, false) + tk.MustExec(c.binding) + for _, currentDB := range []string{"db1", "db2", "db3"} { + tk.MustExec(`use ` + currentDB) + for _, db := range []string{"db1.", "db2.", "db3.", ""} { + query := fmt.Sprintf(c.qTemplate, db) + tk.MustExec(query) + tk.MustQuery(`show warnings`).Check(testkit.Rows()) // no warning + tk.MustExec(query) + tk.MustQuery(`select @@last_plan_from_binding`).Check(testkit.Rows("1")) + } + } + } +} diff --git a/pkg/planner/core/logical_plan_builder.go b/pkg/planner/core/logical_plan_builder.go index 50e6745014e05..2d241056f2415 100644 --- a/pkg/planner/core/logical_plan_builder.go +++ b/pkg/planner/core/logical_plan_builder.go @@ -5195,7 +5195,7 @@ func (b *PlanBuilder) buildDataSource(ctx context.Context, tn *ast.TableName, as var 
indexMergeHints []indexHintInfo if hints := b.TableHints(); hints != nil { for i, hint := range hints.indexMergeHintList { - if hint.tblName.L == tblName.L && hint.dbName.L == dbName.L { + if hint.match(dbName, tblName) { hints.indexMergeHintList[i].matched = true // check whether the index names in IndexMergeHint are valid. invalidIdxNames := make([]string, 0, len(hint.indexHint.IndexNames)) diff --git a/pkg/planner/core/planbuilder.go b/pkg/planner/core/planbuilder.go index 89c5bdcfea2c2..f0f644a4b1d5c 100644 --- a/pkg/planner/core/planbuilder.go +++ b/pkg/planner/core/planbuilder.go @@ -139,6 +139,12 @@ type indexHintInfo struct { matched bool } +func (hint *indexHintInfo) match(dbName, tblName model.CIStr) bool { + return hint.tblName.L == tblName.L && + (hint.dbName.L == dbName.L || + hint.dbName.L == "*") // for universal bindings, e.g. *.t +} + func (hint *indexHintInfo) hintTypeString() string { switch hint.indexHint.HintType { case ast.HintUse: @@ -322,7 +328,9 @@ func (*tableHintInfo) matchTableName(tables []*hintTableInfo, hintTables []hintT if table == nil { continue } - if curEntry.dbName.L == table.dbName.L && curEntry.tblName.L == table.tblName.L && table.selectOffset == curEntry.selectOffset { + if (curEntry.dbName.L == table.dbName.L || curEntry.dbName.L == "*") && + curEntry.tblName.L == table.tblName.L && + table.selectOffset == curEntry.selectOffset { hintTables[i].matched = true hintMatched = true break @@ -1454,7 +1462,7 @@ func getPossibleAccessPaths(ctx sessionctx.Context, tableHints *tableHintInfo, i indexHintsLen := len(indexHints) if tableHints != nil { for i, hint := range tableHints.indexHintList { - if hint.dbName.L == dbName.L && hint.tblName.L == tblName.L { + if hint.match(dbName, tblName) { indexHints = append(indexHints, hint.indexHint) tableHints.indexHintList[i].matched = true } diff --git a/pkg/planner/core/rule_join_reorder.go b/pkg/planner/core/rule_join_reorder.go index 98f066ed877d8..0a67b68c0e795 100644 --- a/pkg/planner/core/rule_join_reorder.go +++ b/pkg/planner/core/rule_join_reorder.go @@ -411,7 +411,7 @@ func (s *baseSingleGroupJoinOrderSolver) generateLeadingJoinGroup(curJoinGroup [ if tableAlias == nil { continue } - if hintTbl.dbName.L == tableAlias.dbName.L && hintTbl.tblName.L == tableAlias.tblName.L && hintTbl.selectOffset == tableAlias.selectOffset { + if (hintTbl.dbName.L == tableAlias.dbName.L || hintTbl.dbName.L == "*") && hintTbl.tblName.L == tableAlias.tblName.L && hintTbl.selectOffset == tableAlias.selectOffset { match = true leadingJoinGroup = append(leadingJoinGroup, joinGroup) leftJoinGroup = append(leftJoinGroup[:i], leftJoinGroup[i+1:]...)