diff --git a/bindinfo/bind_test.go b/bindinfo/bind_test.go index e3eda9fd389e1..6c47b96417208 100644 --- a/bindinfo/bind_test.go +++ b/bindinfo/bind_test.go @@ -2466,11 +2466,47 @@ func (s *testSerialSuite) TestIssue26377(c *C) { tk.MustExec("create temporary table tmp2(a int(11), key idx_a(a));") queries := []string{ + "create global binding for with cte1 as (select a from tmp1) select * from cte1 using with cte1 as (select a from tmp1) select * from cte1", "create global binding for select * from t1 inner join tmp1 on t1.a=tmp1.a using select * from t1 inner join tmp1 on t1.a=tmp1.a;", "create global binding for select * from t1 where t1.a in (select a from tmp1) using select * from t1 where t1.a in (select a from tmp1 use index (idx_a));", "create global binding for select a from t1 union select a from tmp1 using select a from t1 union select a from tmp1 use index (idx_a);", "create global binding for select t1.a, (select a from tmp1 where tmp1.a=1) as t2 from t1 using select t1.a, (select a from tmp1 where tmp1.a=1) as t2 from t1;", "create global binding for select * from (select * from tmp1) using select * from (select * from tmp1);", + "create global binding for select * from t1 where t1.a = (select a from tmp1) using select * from t1 where t1.a = (select a from tmp1)", + } + genLocalTemporarySQL := func(sql string) string { + return strings.Replace(sql, "tmp1", "tmp2", -1) + } + for _, query := range queries { + localSQL := genLocalTemporarySQL(query) + queries = append(queries, localSQL) + } + + for _, q := range queries { + tk.MustGetErrCode(q, errno.ErrOptOnTemporaryTable) + } +} + +func (s *testSerialSuite) TestIssue27422(c *C) { + tk := testkit.NewTestKit(c, s.store) + s.cleanBindingEnv(tk) + tk.MustExec("use test") + tk.MustExec("set tidb_enable_global_temporary_table = true") + tk.MustExec("set @@tidb_enable_noop_functions=1;") + tk.MustExec("drop table if exists t1,tmp1,tmp2") + tk.MustExec("create table t1(a int(11))") + tk.MustExec("create global temporary table tmp1(a int(11), key idx_a(a)) on commit delete rows;") + tk.MustExec("create temporary table tmp2(a int(11), key idx_a(a));") + + queries := []string{ + "create global binding for insert into t1 (select * from tmp1) using insert into t1 (select * from tmp1);", + "create global binding for update t1 inner join tmp1 on t1.a=tmp1.a set t1.a=1 using update t1 inner join tmp1 on t1.a=tmp1.a set t1.a=1", + "create global binding for update t1 set t1.a=(select a from tmp1) using update t1 set t1.a=(select a from tmp1)", + "create global binding for update t1 set t1.a=1 where t1.a = (select a from tmp1) using update t1 set t1.a=1 where t1.a = (select a from tmp1)", + "create global binding for with cte1 as (select a from tmp1) update t1 set t1.a=1 where t1.a in (select a from cte1) using with cte1 as (select a from tmp1) update t1 set t1.a=1 where t1.a in (select a from cte1)", + "create global binding for delete from t1 where t1.a in (select a from tmp1) using delete from t1 where t1.a in (select a from tmp1)", + "create global binding for delete from t1 where t1.a = (select a from tmp1) using delete from t1 where t1.a = (select a from tmp1)", + "create global binding for delete t1 from t1,tmp1 using delete t1 from t1,tmp1", } genLocalTemporarySQL := func(sql string) string { return strings.Replace(sql, "tmp1", "tmp2", -1) diff --git a/br/cmd/tidb-lightning-ctl/main.go b/br/cmd/tidb-lightning-ctl/main.go index 7758621ac5bf3..66b616af57e3e 100644 --- a/br/cmd/tidb-lightning-ctl/main.go +++ 
b/br/cmd/tidb-lightning-ctl/main.go @@ -438,7 +438,11 @@ func importEngine(ctx context.Context, cfg *config.Config, tls *common.TLS, engi return errors.Trace(err) } - return errors.Trace(ce.Import(ctx)) + regionSplitSize := int64(cfg.TikvImporter.RegionSplitSize) + if regionSplitSize == 0 { + regionSplitSize = int64(config.SplitRegionSize) + } + return errors.Trace(ce.Import(ctx, regionSplitSize)) } func cleanupEngine(ctx context.Context, cfg *config.Config, tls *common.TLS, engine string) error { diff --git a/br/pkg/lightning/backend/backend.go b/br/pkg/lightning/backend/backend.go index 0a775c4a2b015..29b8981f000cc 100644 --- a/br/pkg/lightning/backend/backend.go +++ b/br/pkg/lightning/backend/backend.go @@ -151,7 +151,7 @@ type AbstractBackend interface { // ImportEngine imports engine data to the backend. If it returns ErrDuplicateDetected, // it means there is duplicate detected. For this situation, all data in the engine must be imported. // It's safe to reset or cleanup this engine. - ImportEngine(ctx context.Context, engineUUID uuid.UUID) error + ImportEngine(ctx context.Context, engineUUID uuid.UUID, regionSplitSize int64) error CleanupEngine(ctx context.Context, engineUUID uuid.UUID) error @@ -310,7 +310,7 @@ func (be Backend) CheckDiskQuota(quota int64) ( // into the target and then reset the engine to empty. This method will not // close the engine. Make sure the engine is flushed manually before calling // this method. -func (be Backend) UnsafeImportAndReset(ctx context.Context, engineUUID uuid.UUID) error { +func (be Backend) UnsafeImportAndReset(ctx context.Context, engineUUID uuid.UUID, regionSplitSize int64) error { // DO NOT call be.abstract.CloseEngine()! The engine should still be writable after // calling UnsafeImportAndReset(). closedEngine := ClosedEngine{ @@ -320,7 +320,7 @@ func (be Backend) UnsafeImportAndReset(ctx context.Context, engineUUID uuid.UUID uuid: engineUUID, }, } - if err := closedEngine.Import(ctx); err != nil { + if err := closedEngine.Import(ctx, regionSplitSize); err != nil { return err } return be.abstract.ResetEngine(ctx, engineUUID) @@ -436,12 +436,12 @@ func (en engine) unsafeClose(ctx context.Context, cfg *EngineConfig) (*ClosedEng } // Import the data written to the engine into the target. -func (engine *ClosedEngine) Import(ctx context.Context) error { +func (engine *ClosedEngine) Import(ctx context.Context, regionSplitSize int64) error { var err error for i := 0; i < importMaxRetryTimes; i++ { task := engine.logger.With(zap.Int("retryCnt", i)).Begin(zap.InfoLevel, "import") - err = engine.backend.ImportEngine(ctx, engine.uuid) + err = engine.backend.ImportEngine(ctx, engine.uuid, regionSplitSize) if !common.IsRetryableError(err) { task.End(zap.ErrorLevel, err) return err diff --git a/br/pkg/lightning/backend/backend_test.go b/br/pkg/lightning/backend/backend_test.go index 42404f5f3a29b..db79f4b28c806 100644 --- a/br/pkg/lightning/backend/backend_test.go +++ b/br/pkg/lightning/backend/backend_test.go @@ -58,7 +58,7 @@ func (s *backendSuite) TestOpenCloseImportCleanUpEngine(c *C) { Return(nil). After(openCall) importCall := s.mockBackend.EXPECT(). - ImportEngine(ctx, engineUUID). + ImportEngine(ctx, engineUUID, gomock.Any()). Return(nil). After(closeCall) s.mockBackend.EXPECT(). 
@@ -70,7 +70,7 @@ func (s *backendSuite) TestOpenCloseImportCleanUpEngine(c *C) { c.Assert(err, IsNil) closedEngine, err := engine.Close(ctx, nil) c.Assert(err, IsNil) - err = closedEngine.Import(ctx) + err = closedEngine.Import(ctx, 1) c.Assert(err, IsNil) err = closedEngine.Cleanup(ctx) c.Assert(err, IsNil) @@ -252,12 +252,12 @@ func (s *backendSuite) TestImportFailedNoRetry(c *C) { s.mockBackend.EXPECT().CloseEngine(ctx, nil, gomock.Any()).Return(nil) s.mockBackend.EXPECT(). - ImportEngine(ctx, gomock.Any()). + ImportEngine(ctx, gomock.Any(), gomock.Any()). Return(errors.Annotate(context.Canceled, "fake unrecoverable import error")) closedEngine, err := s.backend.UnsafeCloseEngine(ctx, nil, "`db`.`table`", 1) c.Assert(err, IsNil) - err = closedEngine.Import(ctx) + err = closedEngine.Import(ctx, 1) c.Assert(err, ErrorMatches, "fake unrecoverable import error.*") } @@ -269,14 +269,14 @@ func (s *backendSuite) TestImportFailedWithRetry(c *C) { s.mockBackend.EXPECT().CloseEngine(ctx, nil, gomock.Any()).Return(nil) s.mockBackend.EXPECT(). - ImportEngine(ctx, gomock.Any()). + ImportEngine(ctx, gomock.Any(), gomock.Any()). Return(errors.New("fake recoverable import error")). MinTimes(2) s.mockBackend.EXPECT().RetryImportDelay().Return(time.Duration(0)).AnyTimes() closedEngine, err := s.backend.UnsafeCloseEngine(ctx, nil, "`db`.`table`", 1) c.Assert(err, IsNil) - err = closedEngine.Import(ctx) + err = closedEngine.Import(ctx, 1) c.Assert(err, ErrorMatches, ".*fake recoverable import error") } @@ -288,16 +288,16 @@ func (s *backendSuite) TestImportFailedRecovered(c *C) { s.mockBackend.EXPECT().CloseEngine(ctx, nil, gomock.Any()).Return(nil) s.mockBackend.EXPECT(). - ImportEngine(ctx, gomock.Any()). + ImportEngine(ctx, gomock.Any(), gomock.Any()). Return(errors.New("fake recoverable import error")) s.mockBackend.EXPECT(). - ImportEngine(ctx, gomock.Any()). + ImportEngine(ctx, gomock.Any(), gomock.Any()). 
Return(nil) s.mockBackend.EXPECT().RetryImportDelay().Return(time.Duration(0)).AnyTimes() closedEngine, err := s.backend.UnsafeCloseEngine(ctx, nil, "`db`.`table`", 1) c.Assert(err, IsNil) - err = closedEngine.Import(ctx) + err = closedEngine.Import(ctx, 1) c.Assert(err, IsNil) } diff --git a/br/pkg/lightning/backend/importer/importer.go b/br/pkg/lightning/backend/importer/importer.go index dc292bdb7e870..f4cb73930d80c 100644 --- a/br/pkg/lightning/backend/importer/importer.go +++ b/br/pkg/lightning/backend/importer/importer.go @@ -201,7 +201,7 @@ func (importer *importer) Flush(_ context.Context, _ uuid.UUID) error { return nil } -func (importer *importer) ImportEngine(ctx context.Context, engineUUID uuid.UUID) error { +func (importer *importer) ImportEngine(ctx context.Context, engineUUID uuid.UUID, _ int64) error { importer.lock.Lock() defer importer.lock.Unlock() req := &import_kvpb.ImportEngineRequest{ diff --git a/br/pkg/lightning/backend/importer/importer_test.go b/br/pkg/lightning/backend/importer/importer_test.go index 524a523d2e31d..5d75d1badc245 100644 --- a/br/pkg/lightning/backend/importer/importer_test.go +++ b/br/pkg/lightning/backend/importer/importer_test.go @@ -219,7 +219,7 @@ func (s *importerSuite) TestCloseImportCleanupEngine(c *C) { engine, err := s.engine.Close(s.ctx, nil) c.Assert(err, IsNil) - err = engine.Import(s.ctx) + err = engine.Import(s.ctx, 1) c.Assert(err, IsNil) err = engine.Cleanup(s.ctx) c.Assert(err, IsNil) diff --git a/br/pkg/lightning/backend/local/local.go b/br/pkg/lightning/backend/local/local.go index 4a78cd6b4c125..ddaad8c357717 100644 --- a/br/pkg/lightning/backend/local/local.go +++ b/br/pkg/lightning/backend/local/local.go @@ -807,9 +807,7 @@ type local struct { pdAddr string g glue.Glue - localStoreDir string - regionSplitSize int64 - regionSplitKeys int64 + localStoreDir string rangeConcurrency *worker.Pool ingestConcurrency *worker.Pool @@ -939,12 +937,6 @@ func NewLocalBackend( } } - regionSplitSize := int64(cfg.RegionSplitSize) - regionSplitKeys := int64(regionMaxKeyCount) - if regionSplitSize > defaultRegionSplitSize { - regionSplitKeys = int64(float64(regionSplitSize) / float64(defaultRegionSplitSize) * float64(regionMaxKeyCount)) - } - local := &local{ engines: sync.Map{}, pdCtl: pdCtl, @@ -953,10 +945,7 @@ func NewLocalBackend( pdAddr: pdAddr, g: g, - localStoreDir: localFile, - regionSplitSize: regionSplitSize, - regionSplitKeys: regionSplitKeys, - + localStoreDir: localFile, rangeConcurrency: worker.NewPool(ctx, rangeConcurrency, "range"), ingestConcurrency: worker.NewPool(ctx, rangeConcurrency*2, "ingest"), tcpConcurrency: rangeConcurrency, @@ -1185,11 +1174,6 @@ func (local *local) RetryImportDelay() time.Duration { return defaultRetryBackoffTime } -func (local *local) MaxChunkSize() int { - // a batch size write to leveldb - return int(local.regionSplitSize) -} - func (local *local) ShouldPostProcess() bool { return true } @@ -1365,6 +1349,8 @@ func (local *local) WriteToTiKV( engineFile *File, region *split.RegionInfo, start, end []byte, + regionSplitSize int64, + regionSplitKeys int64, ) ([]*sst.SSTMeta, Range, rangeStats, error) { for _, peer := range region.Region.GetPeers() { var e error @@ -1463,7 +1449,7 @@ func (local *local) WriteToTiKV( size := int64(0) totalCount := int64(0) firstLoop := true - regionMaxSize := local.regionSplitSize * 4 / 3 + regionMaxSize := regionSplitSize * 4 / 3 for iter.First(); iter.Valid(); iter.Next() { size += int64(len(iter.Key()) + len(iter.Value())) @@ -1492,7 +1478,7 @@ func (local *local) 
WriteToTiKV( bytesBuf.Reset() firstLoop = false } - if size >= regionMaxSize || totalCount >= local.regionSplitKeys { + if size >= regionMaxSize || totalCount >= regionSplitKeys { break } } @@ -1624,7 +1610,7 @@ func splitRangeBySizeProps(fullRange Range, sizeProps *sizeProperties, sizeLimit return ranges } -func (local *local) readAndSplitIntoRange(ctx context.Context, engineFile *File) ([]Range, error) { +func (local *local) readAndSplitIntoRange(ctx context.Context, engineFile *File, regionSplitSize int64, regionSplitKeys int64) ([]Range, error) { iter := newKeyIter(ctx, engineFile, &pebble.IterOptions{}) defer iter.Close() @@ -1653,7 +1639,7 @@ func (local *local) readAndSplitIntoRange(ctx context.Context, engineFile *File) engineFileLength := engineFile.Length.Load() // <= 96MB no need to split into range - if engineFileTotalSize <= local.regionSplitSize && engineFileLength <= local.regionSplitKeys { + if engineFileTotalSize <= regionSplitSize && engineFileLength <= regionSplitKeys { ranges := []Range{{start: firstKey, end: endKey}} return ranges, nil } @@ -1664,7 +1650,7 @@ func (local *local) readAndSplitIntoRange(ctx context.Context, engineFile *File) } ranges := splitRangeBySizeProps(Range{start: firstKey, end: endKey}, sizeProps, - local.regionSplitSize, local.regionSplitKeys) + regionSplitSize, regionSplitKeys) log.L().Info("split engine key ranges", zap.Stringer("engine", engineFile.UUID), zap.Int64("totalSize", engineFileTotalSize), zap.Int64("totalCount", engineFileLength), @@ -1678,6 +1664,8 @@ func (local *local) writeAndIngestByRange( ctxt context.Context, engineFile *File, start, end []byte, + regionSplitSize int64, + regionSplitKeys int64, ) error { ito := &pebble.IterOptions{ LowerBound: start, @@ -1736,7 +1724,7 @@ WriteAndIngest: zap.Binary("end", region.Region.GetEndKey()), zap.Reflect("peers", region.Region.GetPeers())) w := local.ingestConcurrency.Apply() - err = local.writeAndIngestPairs(ctx, engineFile, region, pairStart, end) + err = local.writeAndIngestPairs(ctx, engineFile, region, pairStart, end, regionSplitSize, regionSplitKeys) local.ingestConcurrency.Recycle(w) if err != nil { if common.IsContextCanceledError(err) { @@ -1774,6 +1762,8 @@ func (local *local) writeAndIngestPairs( engineFile *File, region *split.RegionInfo, start, end []byte, + regionSplitSize int64, + regionSplitKeys int64, ) error { var err error @@ -1782,7 +1772,7 @@ loopWrite: var metas []*sst.SSTMeta var finishedRange Range var rangeStats rangeStats - metas, finishedRange, rangeStats, err = local.WriteToTiKV(ctx, engineFile, region, start, end) + metas, finishedRange, rangeStats, err = local.WriteToTiKV(ctx, engineFile, region, start, end, regionSplitSize, regionSplitKeys) if err != nil { if common.IsContextCanceledError(err) { return err @@ -1889,7 +1879,7 @@ loopWrite: return errors.Trace(err) } -func (local *local) writeAndIngestByRanges(ctx context.Context, engineFile *File, ranges []Range) error { +func (local *local) writeAndIngestByRanges(ctx context.Context, engineFile *File, ranges []Range, regionSplitSize int64, regionSplitKeys int64) error { if engineFile.Length.Load() == 0 { // engine is empty, this is likes because it's a index engine but the table contains no index log.L().Info("engine contains no data", zap.Stringer("uuid", engineFile.UUID)) @@ -1921,7 +1911,7 @@ func (local *local) writeAndIngestByRanges(ctx context.Context, engineFile *File // max retry backoff time: 2+4+8+16=30s backOffTime := time.Second for i := 0; i < maxRetryTimes; i++ { - err = 
local.writeAndIngestByRange(ctx, engineFile, startKey, endKey) + err = local.writeAndIngestByRange(ctx, engineFile, startKey, endKey, regionSplitSize, regionSplitKeys) if err == nil || common.IsContextCanceledError(err) { return } @@ -1967,7 +1957,7 @@ func (r *syncedRanges) reset() { r.Unlock() } -func (local *local) ImportEngine(ctx context.Context, engineUUID uuid.UUID) error { +func (local *local) ImportEngine(ctx context.Context, engineUUID uuid.UUID, regionSplitSize int64) error { lf := local.lockEngine(engineUUID, importMutexStateImport) if lf == nil { // skip if engine not exist. See the comment of `CloseEngine` for more detail. @@ -1981,9 +1971,13 @@ func (local *local) ImportEngine(ctx context.Context, engineUUID uuid.UUID) erro log.L().Info("engine contains no kv, skip import", zap.Stringer("engine", engineUUID)) return nil } + regionSplitKeys := int64(regionMaxKeyCount) + if regionSplitSize > defaultRegionSplitSize { + regionSplitKeys = int64(float64(regionSplitSize) / float64(defaultRegionSplitSize) * float64(regionMaxKeyCount)) + } // split sorted file into range by 96MB size per file - ranges, err := local.readAndSplitIntoRange(ctx, lf) + ranges, err := local.readAndSplitIntoRange(ctx, lf, regionSplitSize, regionSplitKeys) if err != nil { return err } @@ -1999,10 +1993,10 @@ func (local *local) ImportEngine(ctx context.Context, engineUUID uuid.UUID) erro // if all the kv can fit in one region, skip split regions. TiDB will split one region for // the table when table is created. - needSplit := len(unfinishedRanges) > 1 || lfTotalSize > local.regionSplitSize || lfLength > local.regionSplitKeys + needSplit := len(unfinishedRanges) > 1 || lfTotalSize > regionSplitSize || lfLength > regionSplitKeys // split region by given ranges for i := 0; i < maxRetryTimes; i++ { - err = local.SplitAndScatterRegionByRanges(ctx, unfinishedRanges, lf.tableInfo, needSplit) + err = local.SplitAndScatterRegionByRanges(ctx, unfinishedRanges, lf.tableInfo, needSplit, regionSplitSize) if err == nil || common.IsContextCanceledError(err) { break } @@ -2016,7 +2010,7 @@ func (local *local) ImportEngine(ctx context.Context, engineUUID uuid.UUID) erro } // start to write to kv and ingest - err = local.writeAndIngestByRanges(ctx, lf, unfinishedRanges) + err = local.writeAndIngestByRanges(ctx, lf, unfinishedRanges, regionSplitSize, regionSplitKeys) if err != nil { log.L().Error("write and ingest engine failed", log.ShortError(err)) return err diff --git a/br/pkg/lightning/backend/local/local_unix.go b/br/pkg/lightning/backend/local/local_unix.go index d2383aca745c1..57b2fcec4149b 100644 --- a/br/pkg/lightning/backend/local/local_unix.go +++ b/br/pkg/lightning/backend/local/local_unix.go @@ -25,8 +25,8 @@ import ( ) const ( - // mininum max open files value - minRLimit = 1024 + // maximum max open files value + maxRLimit = 1000000 ) func GetSystemRLimit() (Rlim_t, error) { @@ -39,8 +39,8 @@ func GetSystemRLimit() (Rlim_t, error) { // In Local-backend, we need to read and write a lot of L0 SST files, so we need // to check system max open files limit. 
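Note: the ImportEngine hunk above now derives regionSplitKeys on the fly from the caller-supplied regionSplitSize, instead of caching both values on the local struct at NewLocalBackend time. A minimal standalone sketch of that scaling rule follows; the constant values are illustrative assumptions, not the ones defined in br/pkg/lightning/backend/local:

package main

import "fmt"

// Illustrative stand-ins for the real constants in the local backend package.
const (
	defaultRegionSplitSize int64 = 96 << 20 // assumed 96 MiB default split size
	regionMaxKeyCount      int64 = 1 << 20  // assumed default per-region key threshold
)

// deriveSplitKeys mirrors the logic added to ImportEngine: when the requested
// split size exceeds the default, the key threshold grows by the same ratio.
func deriveSplitKeys(regionSplitSize int64) int64 {
	if regionSplitSize > defaultRegionSplitSize {
		return int64(float64(regionSplitSize) / float64(defaultRegionSplitSize) * float64(regionMaxKeyCount))
	}
	return regionMaxKeyCount
}

func main() {
	// A 10x split size yields a 10x key threshold, in line with MaxSplitRegionSizeRatio below.
	fmt.Println(deriveSplitKeys(10 * defaultRegionSplitSize))
}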
func VerifyRLimit(estimateMaxFiles Rlim_t) error { - if estimateMaxFiles < minRLimit { - estimateMaxFiles = minRLimit + if estimateMaxFiles > maxRLimit { + estimateMaxFiles = maxRLimit } var rLimit syscall.Rlimit err := syscall.Getrlimit(syscall.RLIMIT_NOFILE, &rLimit) diff --git a/br/pkg/lightning/backend/local/localhelper.go b/br/pkg/lightning/backend/local/localhelper.go index 2d9dc5c48cbdd..bc7a7a65d2a4c 100644 --- a/br/pkg/lightning/backend/local/localhelper.go +++ b/br/pkg/lightning/backend/local/localhelper.go @@ -67,6 +67,7 @@ func (local *local) SplitAndScatterRegionByRanges( ranges []Range, tableInfo *checkpoints.TidbTableInfo, needSplit bool, + regionSplitSize int64, ) error { if len(ranges) == 0 { return nil @@ -270,7 +271,7 @@ func (local *local) SplitAndScatterRegionByRanges( if !ok { log.L().Warn("region stats not found", zap.Uint64("region", regionID)) } - if len(keys) == 1 && regionSize < local.regionSplitSize { + if len(keys) == 1 && regionSize < regionSplitSize { skippedKeys++ } select { diff --git a/br/pkg/lightning/backend/local/localhelper_test.go b/br/pkg/lightning/backend/local/localhelper_test.go index 11149a000fbe2..502f9a1be7d8a 100644 --- a/br/pkg/lightning/backend/local/localhelper_test.go +++ b/br/pkg/lightning/backend/local/localhelper_test.go @@ -424,7 +424,7 @@ func (s *localSuite) doTestBatchSplitRegionByRanges(ctx context.Context, c *C, h start = end } - err = local.SplitAndScatterRegionByRanges(ctx, ranges, nil, true) + err = local.SplitAndScatterRegionByRanges(ctx, ranges, nil, true, 1000) if len(errPat) == 0 { c.Assert(err, IsNil) } else { @@ -643,7 +643,7 @@ func (s *localSuite) doTestBatchSplitByRangesWithClusteredIndex(c *C, hook clien start = e } - err := local.SplitAndScatterRegionByRanges(ctx, ranges, nil, true) + err := local.SplitAndScatterRegionByRanges(ctx, ranges, nil, true, 1000) c.Assert(err, IsNil) startKey := codec.EncodeBytes([]byte{}, rangeKeys[0]) diff --git a/br/pkg/lightning/backend/noop/noop.go b/br/pkg/lightning/backend/noop/noop.go index 37ca4fd8e77a2..ca095844024d8 100644 --- a/br/pkg/lightning/backend/noop/noop.go +++ b/br/pkg/lightning/backend/noop/noop.go @@ -78,7 +78,7 @@ func (b noopBackend) CloseEngine(ctx context.Context, cfg *backend.EngineConfig, return nil } -func (b noopBackend) ImportEngine(ctx context.Context, engineUUID uuid.UUID) error { +func (b noopBackend) ImportEngine(ctx context.Context, engineUUID uuid.UUID, regionSplitSize int64) error { return nil } diff --git a/br/pkg/lightning/backend/tidb/tidb.go b/br/pkg/lightning/backend/tidb/tidb.go index b2259ffb5c8a3..092893ab9d2d9 100644 --- a/br/pkg/lightning/backend/tidb/tidb.go +++ b/br/pkg/lightning/backend/tidb/tidb.go @@ -368,7 +368,7 @@ func (be *tidbBackend) CollectRemoteDuplicateRows(ctx context.Context, tbl table panic("Unsupported Operation") } -func (be *tidbBackend) ImportEngine(context.Context, uuid.UUID) error { +func (be *tidbBackend) ImportEngine(context.Context, uuid.UUID, int64) error { return nil } diff --git a/br/pkg/lightning/config/config.go b/br/pkg/lightning/config/config.go index 4c2fbba4c1c55..a112ed4d67418 100644 --- a/br/pkg/lightning/config/config.go +++ b/br/pkg/lightning/config/config.go @@ -436,7 +436,7 @@ func NewConfig() *Config { OnDuplicate: ReplaceOnDup, MaxKVPairs: 4096, SendKVPairs: 32768, - RegionSplitSize: SplitRegionSize, + RegionSplitSize: 0, DiskQuota: ByteSize(math.MaxInt64), }, PostRestore: PostRestore{ @@ -740,9 +740,6 @@ func (cfg *Config) DefaultVarsForImporterAndLocalBackend(ctx context.Context) { if 
cfg.TikvImporter.RangeConcurrency == 0 { cfg.TikvImporter.RangeConcurrency = 16 } - if cfg.TikvImporter.RegionSplitSize == 0 { - cfg.TikvImporter.RegionSplitSize = SplitRegionSize - } if cfg.TiDB.BuildStatsConcurrency == 0 { cfg.TiDB.BuildStatsConcurrency = defaultBuildStatsConcurrency } diff --git a/br/pkg/lightning/config/const.go b/br/pkg/lightning/config/const.go index 78ad85c2944d7..4f262eaddbcca 100644 --- a/br/pkg/lightning/config/const.go +++ b/br/pkg/lightning/config/const.go @@ -20,10 +20,11 @@ import ( const ( // mydumper - ReadBlockSize ByteSize = 64 * units.KiB - MinRegionSize ByteSize = 256 * units.MiB - MaxRegionSize ByteSize = 256 * units.MiB - SplitRegionSize ByteSize = 96 * units.MiB + ReadBlockSize ByteSize = 64 * units.KiB + MinRegionSize ByteSize = 256 * units.MiB + MaxRegionSize ByteSize = 256 * units.MiB + SplitRegionSize ByteSize = 96 * units.MiB + MaxSplitRegionSizeRatio int = 10 BufferSizeScale = 5 diff --git a/br/pkg/lightning/lightning.go b/br/pkg/lightning/lightning.go index a25751e923f36..2596b7a7ac8d7 100644 --- a/br/pkg/lightning/lightning.go +++ b/br/pkg/lightning/lightning.go @@ -667,7 +667,10 @@ func checkSystemRequirement(cfg *config.Config, dbsMeta []*mydump.MDDatabaseMeta // region-concurrency: number of LocalWriters writing SST files. // 2*totalSize/memCacheSize: number of Pebble MemCache files. - estimateMaxFiles := local.Rlim_t(cfg.App.RegionConcurrency) + local.Rlim_t(topNTotalSize)/local.Rlim_t(cfg.TikvImporter.EngineMemCacheSize)*2 + maxDBFiles := topNTotalSize / int64(cfg.TikvImporter.LocalWriterMemCacheSize) * 2 + // the pebble db and all import routines need up to maxDBFiles fds for reads and writes. + maxOpenDBFiles := maxDBFiles * (1 + int64(cfg.TikvImporter.RangeConcurrency)) + estimateMaxFiles := local.Rlim_t(cfg.App.RegionConcurrency) + local.Rlim_t(maxOpenDBFiles) if err := local.VerifyRLimit(estimateMaxFiles); err != nil { return err } diff --git a/br/pkg/lightning/lightning_test.go b/br/pkg/lightning/lightning_test.go index 337adc96f9882..8bae6d89ad9e9 100644 --- a/br/pkg/lightning/lightning_test.go +++ b/br/pkg/lightning/lightning_test.go @@ -447,7 +447,8 @@ func (s *lightningServerSuite) TestCheckSystemRequirement(c *C) { cfg.App.CheckRequirements = true cfg.App.TableConcurrency = 4 cfg.TikvImporter.Backend = config.BackendLocal - cfg.TikvImporter.EngineMemCacheSize = 512 * units.MiB + cfg.TikvImporter.LocalWriterMemCacheSize = 128 * units.MiB + cfg.TikvImporter.RangeConcurrency = 16 dbMetas := []*mydump.MDDatabaseMeta{ { @@ -485,15 +486,14 @@ func (s *lightningServerSuite) TestCheckSystemRequirement(c *C) { }, } - // with max open files 1024, the max table size will be: 65536MB - err := failpoint.Enable("github.com/pingcap/tidb/br/pkg/lightning/backend/local/GetRlimitValue", "return(2049)") + err := failpoint.Enable("github.com/pingcap/tidb/br/pkg/lightning/backend/local/GetRlimitValue", "return(139439)") c.Assert(err, IsNil) err = failpoint.Enable("github.com/pingcap/tidb/br/pkg/lightning/backend/local/SetRlimitError", "return(true)") c.Assert(err, IsNil) defer func() { _ = failpoint.Disable("github.com/pingcap/tidb/br/pkg/lightning/backend/local/SetRlimitError") }() - // with this dbMetas, the estimated fds will be 2050, so should return error + // with this dbMetas, the estimated fds will be 139440, so should return error err = checkSystemRequirement(cfg, dbMetas) c.Assert(err, NotNil) @@ -501,7 +501,7 @@ func (s *lightningServerSuite) TestCheckSystemRequirement(c *C) { c.Assert(err, IsNil) // the min rlimit should be bigger than the
default min value (16384) - err = failpoint.Enable("github.com/pingcap/tidb/br/pkg/lightning/backend/local/GetRlimitValue", "return(8200)") + err = failpoint.Enable("github.com/pingcap/tidb/br/pkg/lightning/backend/local/GetRlimitValue", "return(139440)") defer func() { _ = failpoint.Disable("github.com/pingcap/tidb/br/pkg/lightning/backend/local/GetRlimitValue") }() diff --git a/br/pkg/lightning/restore/restore.go b/br/pkg/lightning/restore/restore.go index 85757927134aa..04e926219a92e 100644 --- a/br/pkg/lightning/restore/restore.go +++ b/br/pkg/lightning/restore/restore.go @@ -1621,7 +1621,8 @@ func (rc *Controller) enforceDiskQuota(ctx context.Context) { task := logger.Begin(zap.WarnLevel, "importing large engines for disk quota") var importErr error for _, engine := range largeEngines { - if err := rc.backend.UnsafeImportAndReset(ctx, engine); err != nil { + // Use a larger split region size to avoid splitting the same region many times. + if err := rc.backend.UnsafeImportAndReset(ctx, engine, int64(config.SplitRegionSize)*int64(config.MaxSplitRegionSizeRatio)); err != nil { importErr = multierr.Append(importErr, err) } } diff --git a/br/pkg/lightning/restore/restore_test.go b/br/pkg/lightning/restore/restore_test.go index 7758d7d81aa11..a52b61438b21b 100644 --- a/br/pkg/lightning/restore/restore_test.go +++ b/br/pkg/lightning/restore/restore_test.go @@ -852,7 +852,7 @@ func (s *tableRestoreSuite) TestImportKVSuccess(c *C) { importer := backend.MakeBackend(mockBackend) chptCh := make(chan saveCp) defer close(chptCh) - rc := &Controller{saveCpCh: chptCh} + rc := &Controller{saveCpCh: chptCh, cfg: config.NewConfig()} go func() { for range chptCh { } }() @@ -865,7 +865,7 @@ func (s *tableRestoreSuite) TestImportKVSuccess(c *C) { CloseEngine(ctx, nil, engineUUID). Return(nil) mockBackend.EXPECT(). - ImportEngine(ctx, engineUUID). + ImportEngine(ctx, engineUUID, gomock.Any()). Return(nil) mockBackend.EXPECT(). CleanupEngine(ctx, engineUUID). Return(nil) @@ -884,7 +884,7 @@ func (s *tableRestoreSuite) TestImportKVFailure(c *C) { importer := backend.MakeBackend(mockBackend) chptCh := make(chan saveCp) defer close(chptCh) - rc := &Controller{saveCpCh: chptCh} + rc := &Controller{saveCpCh: chptCh, cfg: config.NewConfig()} go func() { for range chptCh { } }() @@ -897,7 +897,7 @@ func (s *tableRestoreSuite) TestImportKVFailure(c *C) { CloseEngine(ctx, nil, engineUUID). Return(nil) mockBackend.EXPECT(). - ImportEngine(ctx, engineUUID). + ImportEngine(ctx, engineUUID, gomock.Any()).
Return(errors.Annotate(context.Canceled, "fake import error")) closedEngine, err := importer.UnsafeCloseEngineWithUUID(ctx, nil, "tag", engineUUID) diff --git a/br/pkg/lightning/restore/table_restore.go b/br/pkg/lightning/restore/table_restore.go index ecc22ead593fe..3b4c681037317 100644 --- a/br/pkg/lightning/restore/table_restore.go +++ b/br/pkg/lightning/restore/table_restore.go @@ -850,8 +850,17 @@ func (tr *TableRestore) importKV( engineID int32, ) error { task := closedEngine.Logger().Begin(zap.InfoLevel, "import and cleanup engine") - - err := closedEngine.Import(ctx) + regionSplitSize := int64(rc.cfg.TikvImporter.RegionSplitSize) + if regionSplitSize == 0 && rc.taskMgr != nil { + regionSplitSize = int64(config.SplitRegionSize) + rc.taskMgr.CheckTasksExclusively(ctx, func(tasks []taskMeta) ([]taskMeta, error) { + if len(tasks) > 0 { + regionSplitSize = int64(config.SplitRegionSize) * int64(utils.MinInt(len(tasks), config.MaxSplitRegionSizeRatio)) + } + return nil, nil + }) + } + err := closedEngine.Import(ctx, regionSplitSize) rc.saveStatusCheckpoint(tr.tableName, engineID, err, checkpoints.CheckpointStatusImported) // Also cleanup engine when encountered ErrDuplicateDetected, since all duplicates kv pairs are recorded. if err == nil { diff --git a/br/pkg/mock/backend.go b/br/pkg/mock/backend.go index 43fc2c2af6395..6c9dc25598794 100644 --- a/br/pkg/mock/backend.go +++ b/br/pkg/mock/backend.go @@ -182,17 +182,17 @@ func (mr *MockBackendMockRecorder) FlushEngine(arg0, arg1 interface{}) *gomock.C } // ImportEngine mocks base method -func (m *MockBackend) ImportEngine(arg0 context.Context, arg1 uuid.UUID) error { +func (m *MockBackend) ImportEngine(arg0 context.Context, arg1 uuid.UUID, arg2 int64) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ImportEngine", arg0, arg1) + ret := m.ctrl.Call(m, "ImportEngine", arg0, arg1, arg2) ret0, _ := ret[0].(error) return ret0 } // ImportEngine indicates an expected call of ImportEngine -func (mr *MockBackendMockRecorder) ImportEngine(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockBackendMockRecorder) ImportEngine(arg0, arg1, arg2 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ImportEngine", reflect.TypeOf((*MockBackend)(nil).ImportEngine), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ImportEngine", reflect.TypeOf((*MockBackend)(nil).ImportEngine), arg0, arg1, arg2) } // LocalWriter mocks base method diff --git a/br/pkg/restore/split_client.go b/br/pkg/restore/split_client.go index 888da89ca179c..6a3ab5a3cdc81 100755 --- a/br/pkg/restore/split_client.go +++ b/br/pkg/restore/split_client.go @@ -282,7 +282,7 @@ func (c *pdClient) sendSplitRegionRequest( return nil, multierr.Append(splitErrors, err) } if resp.RegionError != nil { - log.Error("fail to split region", + log.Warn("fail to split region", logutil.Region(regionInfo.Region), zap.Stringer("regionErr", resp.RegionError)) splitErrors = multierr.Append(splitErrors, diff --git a/cmd/ddltest/column_test.go b/cmd/ddltest/column_test.go index 3bee4582d2148..cf49dcbb65f16 100644 --- a/cmd/ddltest/column_test.go +++ b/cmd/ddltest/column_test.go @@ -226,3 +226,12 @@ func (s *TestDDLSuite) TestCommitWhenSchemaChanged(c *C) { _, err = s1.Execute(ctx, "commit") c.Assert(terror.ErrorEqual(err, plannercore.ErrWrongValueCountOnRow), IsTrue, Commentf("err %v", err)) } + +func (s *TestDDLSuite) TestForIssue24621(c *C) { + s.mustExec(c, "use test") + s.mustExec(c, "drop table if exists t") + s.mustExec(c, "create 
table t(a char(250));") + s.mustExec(c, "insert into t values('0123456789abc');") + _, err := s.exec("alter table t modify a char(12) null;") + c.Assert(err.Error(), Equals, "[types:1265]Data truncated for column 'a', value is '0123456789abc'") +} diff --git a/ddl/column.go b/ddl/column.go index 94339088d8c70..8bd0eb6406be6 100644 --- a/ddl/column.go +++ b/ddl/column.go @@ -1391,7 +1391,7 @@ func (w *updateColumnWorker) getRowRecord(handle kv.Handle, recordKey []byte, ra // reformatErrors casted error because `convertTo` function couldn't package column name and datum value for some errors. func (w *updateColumnWorker) reformatErrors(err error) error { // Since row count is not precious in concurrent reorganization, here we substitute row count with datum value. - if types.ErrTruncated.Equal(err) { + if types.ErrTruncated.Equal(err) || types.ErrDataTooLong.Equal(err) { dStr := datumToStringNoErr(w.rowMap[w.oldColInfo.ID]) err = types.ErrTruncated.GenWithStack("Data truncated for column '%s', value is '%s'", w.oldColInfo.Name, dStr) } diff --git a/ddl/column_type_change_test.go b/ddl/column_type_change_test.go index a30c35507f6cb..7246c63081706 100644 --- a/ddl/column_type_change_test.go +++ b/ddl/column_type_change_test.go @@ -802,13 +802,13 @@ func (s *testColumnTypeChangeSuite) TestColumnTypeChangeFromNumericToOthers(c *C // binary reset(tk) tk.MustExec("insert into t values (-258.12345, 333.33, 2000000.20000002, 323232323.3232323232, -111.11111111, -222222222222.222222222222222, b'10101')") - tk.MustGetErrCode("alter table t modify d binary(10)", mysql.ErrDataTooLong) + tk.MustGetErrCode("alter table t modify d binary(10)", mysql.WarnDataTruncated) tk.MustExec("alter table t modify n binary(10)") - tk.MustGetErrCode("alter table t modify r binary(10)", mysql.ErrDataTooLong) - tk.MustGetErrCode("alter table t modify db binary(10)", mysql.ErrDataTooLong) + tk.MustGetErrCode("alter table t modify r binary(10)", mysql.WarnDataTruncated) + tk.MustGetErrCode("alter table t modify db binary(10)", mysql.WarnDataTruncated) // MySQL will run with no error. 
- tk.MustGetErrCode("alter table t modify f32 binary(10)", mysql.ErrDataTooLong) - tk.MustGetErrCode("alter table t modify f64 binary(10)", mysql.ErrDataTooLong) + tk.MustGetErrCode("alter table t modify f32 binary(10)", mysql.WarnDataTruncated) + tk.MustGetErrCode("alter table t modify f64 binary(10)", mysql.WarnDataTruncated) tk.MustExec("alter table t modify b binary(10)") tk.MustQuery("select * from t").Check(testkit.Rows("-258.1234500 333.33\x00\x00\x00\x00 2000000.20000002 323232323.32323235 -111.111115 -222222222222.22223 21\x00\x00\x00\x00\x00\x00\x00\x00")) diff --git a/ddl/db_test.go b/ddl/db_test.go index 0efe9063dad62..7dd237090b9bf 100644 --- a/ddl/db_test.go +++ b/ddl/db_test.go @@ -4458,7 +4458,7 @@ func (s *testSerialDBSuite) TestModifyColumnBetweenStringTypes(c *C) { mvc := getModifyColumn(c, s.s.(sessionctx.Context), "test", "tt", "a", false) c.Assert(mvc.FieldType.Flen, Equals, 5) tk.MustQuery("select * from tt").Check(testkit.Rows("111", "10000")) - tk.MustGetErrMsg("alter table tt change a a varchar(4);", "[types:1406]Data Too Long, field len 4, data len 5") + tk.MustGetErrMsg("alter table tt change a a varchar(4);", "[types:1265]Data truncated for column 'a', value is '10000'") tk.MustExec("alter table tt change a a varchar(100);") tk.MustQuery("select length(a) from tt").Check(testkit.Rows("3", "5")) @@ -4470,7 +4470,7 @@ func (s *testSerialDBSuite) TestModifyColumnBetweenStringTypes(c *C) { mc := getModifyColumn(c, s.s.(sessionctx.Context), "test", "tt", "a", false) c.Assert(mc.FieldType.Flen, Equals, 5) tk.MustQuery("select * from tt").Check(testkit.Rows("111", "10000")) - tk.MustGetErrMsg("alter table tt change a a char(4);", "[types:1406]Data Too Long, field len 4, data len 5") + tk.MustGetErrMsg("alter table tt change a a char(4);", "[types:1265]Data truncated for column 'a', value is '10000'") tk.MustExec("alter table tt change a a char(100);") tk.MustQuery("select length(a) from tt").Check(testkit.Rows("3", "5")) @@ -4478,11 +4478,11 @@ func (s *testSerialDBSuite) TestModifyColumnBetweenStringTypes(c *C) { tk.MustExec("drop table if exists tt;") tk.MustExec("create table tt (a binary(10));") tk.MustExec("insert into tt values ('111'),('10000');") - tk.MustGetErrMsg("alter table tt change a a binary(5);", "[types:1406]Data Too Long, field len 5, data len 10") + tk.MustGetErrMsg("alter table tt change a a binary(5);", "[types:1265]Data truncated for column 'a', value is '111\x00\x00\x00\x00\x00\x00\x00'") mb := getModifyColumn(c, s.s.(sessionctx.Context), "test", "tt", "a", false) c.Assert(mb.FieldType.Flen, Equals, 10) tk.MustQuery("select * from tt").Check(testkit.Rows("111\x00\x00\x00\x00\x00\x00\x00", "10000\x00\x00\x00\x00\x00")) - tk.MustGetErrMsg("alter table tt change a a binary(4);", "[types:1406]Data Too Long, field len 4, data len 10") + tk.MustGetErrMsg("alter table tt change a a binary(4);", "[types:1265]Data truncated for column 'a', value is '111\x00\x00\x00\x00\x00\x00\x00'") tk.MustExec("alter table tt change a a binary(12);") tk.MustQuery("select * from tt").Check(testkit.Rows("111\x00\x00\x00\x00\x00\x00\x00\x00\x00", "10000\x00\x00\x00\x00\x00\x00\x00")) tk.MustQuery("select length(a) from tt").Check(testkit.Rows("12", "12")) @@ -4495,7 +4495,7 @@ func (s *testSerialDBSuite) TestModifyColumnBetweenStringTypes(c *C) { mvb := getModifyColumn(c, s.s.(sessionctx.Context), "test", "tt", "a", false) c.Assert(mvb.FieldType.Flen, Equals, 5) tk.MustQuery("select * from tt").Check(testkit.Rows("111", "10000")) - tk.MustGetErrMsg("alter table tt change 
a a varbinary(4);", "[types:1406]Data Too Long, field len 4, data len 5") + tk.MustGetErrMsg("alter table tt change a a varbinary(4);", "[types:1265]Data truncated for column 'a', value is '10000'") tk.MustExec("alter table tt change a a varbinary(12);") tk.MustQuery("select * from tt").Check(testkit.Rows("111", "10000")) tk.MustQuery("select length(a) from tt").Check(testkit.Rows("3", "5")) @@ -4510,7 +4510,7 @@ func (s *testSerialDBSuite) TestModifyColumnBetweenStringTypes(c *C) { c.Assert(c2.FieldType.Tp, Equals, mysql.TypeString) c.Assert(c2.FieldType.Flen, Equals, 10) tk.MustQuery("select * from tt").Check(testkit.Rows("111", "10000")) - tk.MustGetErrMsg("alter table tt change a a char(4);", "[types:1406]Data Too Long, field len 4, data len 5") + tk.MustGetErrMsg("alter table tt change a a char(4);", "[types:1265]Data truncated for column 'a', value is '10000'") // char to text tk.MustExec("alter table tt change a a text;") diff --git a/docs/design/2020-06-24-placement-rules-in-sql.md b/docs/design/2020-06-24-placement-rules-in-sql.md index 6003a9ed68870..bd3d68400fa9c 100644 --- a/docs/design/2020-06-24-placement-rules-in-sql.md +++ b/docs/design/2020-06-24-placement-rules-in-sql.md @@ -318,7 +318,7 @@ This is what a label rule may look like: } ``` -It connects the table name `db1/tb` with the key range. +It connects the table name `db1/tb1` with the key range. Now you need to connect the label with the database / table / partition name in the placement rules. diff --git a/domain/infosync/info.go b/domain/infosync/info.go index e885bbeb175e4..dbf62e276d210 100644 --- a/domain/infosync/info.go +++ b/domain/infosync/info.go @@ -87,16 +87,17 @@ var ErrPrometheusAddrIsNotSet = dbterror.ClassDomain.NewStd(errno.ErrPrometheusA // InfoSyncer stores server info to etcd when the tidb-server starts and delete when tidb-server shuts down. type InfoSyncer struct { - etcdCli *clientv3.Client - info *ServerInfo - serverInfoPath string - minStartTS uint64 - minStartTSPath string - manager util2.SessionManager - session *concurrency.Session - topologySession *concurrency.Session - prometheusAddr string - modifyTime time.Time + etcdCli *clientv3.Client + info *ServerInfo + serverInfoPath string + minStartTS uint64 + minStartTSPath string + manager util2.SessionManager + session *concurrency.Session + topologySession *concurrency.Session + prometheusAddr string + modifyTime time.Time + labelRuleManager LabelRuleManager } // ServerInfo is server static information. @@ -175,6 +176,11 @@ func GlobalInfoSyncerInit(ctx context.Context, id string, serverIDGetter func() if err != nil { return nil, err } + if etcdCli != nil { + is.labelRuleManager = initLabelRuleManager(etcdCli.Endpoints()) + } else { + is.labelRuleManager = initLabelRuleManager([]string{}) + } setGlobalInfoSyncer(is) return is, nil } @@ -201,6 +207,13 @@ func (is *InfoSyncer) GetSessionManager() util2.SessionManager { return is.manager } +func initLabelRuleManager(addrs []string) LabelRuleManager { + if len(addrs) == 0 { + return &mockLabelManager{labelRules: map[string]*label.Rule{}} + } + return &PDLabelManager{addrs: addrs} +} + // GetServerInfo gets self server static information. 
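Note: the infosync changes below route every label-rule call through the new LabelRuleManager interface (defined in domain/infosync/label_manager.go further down), so PutLabelRule, UpdateLabelRules, GetAllLabelRules, and GetLabelRules no longer build PD HTTP requests inline. A hedged usage sketch of the in-memory mock; it assumes placement in package infosync and relies only on the label.Rule ID field that this diff itself uses:

// demoMockLabelManager exercises the mock exactly as the InfoSyncer would when
// no etcd client is available (for example in unit tests without PD).
func demoMockLabelManager(ctx context.Context) ([]*label.Rule, error) {
	var lm LabelRuleManager = &mockLabelManager{labelRules: map[string]*label.Rule{}}
	if err := lm.PutLabelRule(ctx, &label.Rule{ID: "schema/db1/tb1"}); err != nil {
		return nil, err
	}
	// A PDLabelManager would issue the region-label HTTP requests here instead.
	return lm.GetLabelRules(ctx, []string{"schema/db1/tb1"})
}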
func GetServerInfo() (*ServerInfo, error) { is, err := getGlobalInfoSyncer() @@ -817,24 +830,10 @@ func PutLabelRule(ctx context.Context, rule *label.Rule) error { if err != nil { return err } - - if is.etcdCli == nil { + if is.labelRuleManager == nil { return nil } - - addrs := is.etcdCli.Endpoints() - - if len(addrs) == 0 { - return errors.Errorf("pd unavailable") - } - - r, err := json.Marshal(rule) - if err != nil { - return err - } - - _, err = doRequest(ctx, addrs, path.Join(pdapi.Config, "region-label", "rule"), "POST", bytes.NewReader(r)) - return err + return is.labelRuleManager.PutLabelRule(ctx, rule) } // UpdateLabelRules synchronizes the label rule to PD. @@ -847,24 +846,10 @@ func UpdateLabelRules(ctx context.Context, patch *label.RulePatch) error { if err != nil { return err } - - if is.etcdCli == nil { + if is.labelRuleManager == nil { return nil } - - addrs := is.etcdCli.Endpoints() - - if len(addrs) == 0 { - return errors.Errorf("pd unavailable") - } - - r, err := json.Marshal(patch) - if err != nil { - return err - } - - _, err = doRequest(ctx, addrs, path.Join(pdapi.Config, "region-label", "rules"), "PATCH", bytes.NewReader(r)) - return err + return is.labelRuleManager.UpdateLabelRules(ctx, patch) } // GetAllLabelRules gets all label rules from PD. @@ -873,24 +858,10 @@ func GetAllLabelRules(ctx context.Context) ([]*label.Rule, error) { if err != nil { return nil, err } - - if is.etcdCli == nil { - return nil, err - } - - addrs := is.etcdCli.Endpoints() - - if len(addrs) == 0 { - return nil, errors.Errorf("pd unavailable") - } - - rules := []*label.Rule{} - res, err := doRequest(ctx, addrs, path.Join(pdapi.Config, "region-label", "rules"), "GET", nil) - - if err == nil && res != nil { - err = json.Unmarshal(res, &rules) + if is.labelRuleManager == nil { + return nil, nil } - return rules, err + return is.labelRuleManager.GetAllLabelRules(ctx) } // GetLabelRules gets the label rules according to the given IDs from PD. @@ -903,27 +874,8 @@ func GetLabelRules(ctx context.Context, ruleIDs []string) ([]*label.Rule, error) if err != nil { return nil, err } - - if is.etcdCli == nil { + if is.labelRuleManager == nil { return nil, nil } - - addrs := is.etcdCli.Endpoints() - - if len(addrs) == 0 { - return nil, errors.Errorf("pd unavailable") - } - - ids, err := json.Marshal(ruleIDs) - if err != nil { - return nil, err - } - - rules := []*label.Rule{} - res, err := doRequest(ctx, addrs, path.Join(pdapi.Config, "region-label", "rules", "ids"), "GET", bytes.NewReader(ids)) - - if err == nil && res != nil { - err = json.Unmarshal(res, &rules) - } - return rules, err + return is.labelRuleManager.GetLabelRules(ctx, ruleIDs) } diff --git a/domain/infosync/label_manager.go b/domain/infosync/label_manager.go new file mode 100644 index 0000000000000..bf6ba634fdf24 --- /dev/null +++ b/domain/infosync/label_manager.go @@ -0,0 +1,155 @@ +// Copyright 2021 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package infosync + +import ( + "bytes" + "context" + "encoding/json" + "path" + "sync" + + "github.com/pingcap/tidb/ddl/label" + "github.com/pingcap/tidb/util/pdapi" +) + +// LabelRuleManager manages label rules +type LabelRuleManager interface { + PutLabelRule(ctx context.Context, rule *label.Rule) error + UpdateLabelRules(ctx context.Context, patch *label.RulePatch) error + GetAllLabelRules(ctx context.Context) ([]*label.Rule, error) + GetLabelRules(ctx context.Context, ruleIDs []string) ([]*label.Rule, error) +} + +// PDLabelManager manages rules with pd +type PDLabelManager struct { + addrs []string +} + +// PutLabelRule implements PutLabelRule +func (lm *PDLabelManager) PutLabelRule(ctx context.Context, rule *label.Rule) error { + r, err := json.Marshal(rule) + if err != nil { + return err + } + _, err = doRequest(ctx, lm.addrs, path.Join(pdapi.Config, "region-label", "rule"), "POST", bytes.NewReader(r)) + return err +} + +// UpdateLabelRules implements UpdateLabelRules +func (lm *PDLabelManager) UpdateLabelRules(ctx context.Context, patch *label.RulePatch) error { + r, err := json.Marshal(patch) + if err != nil { + return err + } + + _, err = doRequest(ctx, lm.addrs, path.Join(pdapi.Config, "region-label", "rules"), "PATCH", bytes.NewReader(r)) + return err +} + +// GetAllLabelRules implements GetAllLabelRules +func (lm *PDLabelManager) GetAllLabelRules(ctx context.Context) ([]*label.Rule, error) { + var rules []*label.Rule + res, err := doRequest(ctx, lm.addrs, path.Join(pdapi.Config, "region-label", "rules"), "GET", nil) + + if err == nil && res != nil { + err = json.Unmarshal(res, &rules) + } + return rules, err +} + +// GetLabelRules implements GetLabelRules +func (lm *PDLabelManager) GetLabelRules(ctx context.Context, ruleIDs []string) ([]*label.Rule, error) { + ids, err := json.Marshal(ruleIDs) + if err != nil { + return nil, err + } + + rules := []*label.Rule{} + res, err := doRequest(ctx, lm.addrs, path.Join(pdapi.Config, "region-label", "rules", "ids"), "GET", bytes.NewReader(ids)) + + if err == nil && res != nil { + err = json.Unmarshal(res, &rules) + } + return rules, err +} + +type mockLabelManager struct { + sync.RWMutex + labelRules map[string]*label.Rule +} + +// PutLabelRule implements PutLabelRule +func (mm *mockLabelManager) PutLabelRule(ctx context.Context, rule *label.Rule) error { + mm.Lock() + defer mm.Unlock() + if rule == nil { + return nil + } + mm.labelRules[rule.ID] = rule + return nil +} + +// UpdateLabelRules implements UpdateLabelRules +func (mm *mockLabelManager) UpdateLabelRules(ctx context.Context, patch *label.RulePatch) error { + mm.Lock() + defer mm.Unlock() + if patch == nil { + return nil + } + for _, p := range patch.DeleteRules { + delete(mm.labelRules, p) + } + for _, p := range patch.SetRules { + if p == nil { + continue + } + mm.labelRules[p.ID] = p + } + return nil +} + +// GetAllLabelRules implements GetAllLabelRules +func (mm *mockLabelManager) GetAllLabelRules(ctx context.Context) ([]*label.Rule, error) { + mm.RLock() + defer mm.RUnlock() + r := make([]*label.Rule, 0, len(mm.labelRules)) + for _, labelRule := range mm.labelRules { + if labelRule == nil { + continue + } + r = append(r, labelRule) + } + return r, nil +} + +// GetLabelRules implements GetLabelRules +func (mm *mockLabelManager) GetLabelRules(ctx context.Context, ruleIDs []string) ([]*label.Rule, error) { + mm.RLock() + defer mm.RUnlock() + r := make([]*label.Rule, 0, len(ruleIDs)) + for _, ruleID := range ruleIDs { + for _, labelRule := range mm.labelRules { + if
labelRule == nil { + continue + } + if labelRule.ID == ruleID { + r = append(r, labelRule) + break + } + } + } + return r, nil +} diff --git a/executor/brie_test.go b/executor/brie_test.go index 87ae456d35b4d..05657b82117c8 100644 --- a/executor/brie_test.go +++ b/executor/brie_test.go @@ -96,7 +96,7 @@ func (s *testBRIESuite) TestFetchShowBRIE(c *C) { p.SetParserConfig(parser.ParserConfig{EnableWindowFunction: true, EnableStrictDoubleTypeCheck: true}) stmt, err := p.ParseOneStmt("show backups", "", "") c.Assert(err, IsNil) - plan, _, err := core.BuildLogicalPlan(ctx, sctx, stmt, infoschema.MockInfoSchema([]*model.TableInfo{core.MockSignedTable(), core.MockUnsignedTable(), core.MockView()})) + plan, _, err := core.BuildLogicalPlanForTest(ctx, sctx, stmt, infoschema.MockInfoSchema([]*model.TableInfo{core.MockSignedTable(), core.MockUnsignedTable(), core.MockView()})) c.Assert(err, IsNil) schema := plan.Schema() diff --git a/expression/integration_test.go b/expression/integration_test.go index d28135608ea08..7f97ee8261a55 100644 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -4905,7 +4905,7 @@ func (s *testIntegrationSuite) TestFilterExtractFromDNF(c *C) { ret := &plannercore.PreprocessorReturn{} err = plannercore.Preprocess(sctx, stmts[0], plannercore.WithPreprocessorReturn(ret)) c.Assert(err, IsNil, Commentf("error %v, for resolve name, expr %s", err, tt.exprStr)) - p, _, err := plannercore.BuildLogicalPlan(ctx, sctx, stmts[0], ret.InfoSchema) + p, _, err := plannercore.BuildLogicalPlanForTest(ctx, sctx, stmts[0], ret.InfoSchema) c.Assert(err, IsNil, Commentf("error %v, for build plan, expr %s", err, tt.exprStr)) selection := p.(plannercore.LogicalPlan).Children()[0].(*plannercore.LogicalSelection) conds := make([]expression.Expression, len(selection.Conditions)) diff --git a/expression/typeinfer_test.go b/expression/typeinfer_test.go index 293ea722a2824..8c21a005c6ad0 100644 --- a/expression/typeinfer_test.go +++ b/expression/typeinfer_test.go @@ -142,7 +142,7 @@ func (s *testInferTypeSuite) TestInferType(c *C) { ret := &plannercore.PreprocessorReturn{} err = plannercore.Preprocess(sctx, stmt, plannercore.WithPreprocessorReturn(ret)) c.Assert(err, IsNil, comment) - p, _, err := plannercore.BuildLogicalPlan(ctx, sctx, stmt, ret.InfoSchema) + p, _, err := plannercore.BuildLogicalPlanForTest(ctx, sctx, stmt, ret.InfoSchema) c.Assert(err, IsNil, comment) tp := p.Schema().Columns[0].RetType diff --git a/planner/cascades/optimize_test.go b/planner/cascades/optimize_test.go index 50bb7d9a18ba1..beb4a6ee93ac0 100644 --- a/planner/cascades/optimize_test.go +++ b/planner/cascades/optimize_test.go @@ -39,7 +39,7 @@ func TestImplGroupZeroCost(t *testing.T) { stmt, err := p.ParseOneStmt("select t1.a, t2.a from t as t1 left join t as t2 on t1.a = t2.a where t1.a < 1.0", "", "") require.NoError(t, err) - plan, _, err := plannercore.BuildLogicalPlan(context.Background(), ctx, stmt, is) + plan, _, err := plannercore.BuildLogicalPlanForTest(context.Background(), ctx, stmt, is) require.NoError(t, err) logic, ok := plan.(plannercore.LogicalPlan) @@ -64,7 +64,7 @@ func TestInitGroupSchema(t *testing.T) { stmt, err := p.ParseOneStmt("select a from t", "", "") require.NoError(t, err) - plan, _, err := plannercore.BuildLogicalPlan(context.Background(), ctx, stmt, is) + plan, _, err := plannercore.BuildLogicalPlanForTest(context.Background(), ctx, stmt, is) require.NoError(t, err) logic, ok := plan.(plannercore.LogicalPlan) @@ -87,7 +87,7 @@ func TestFillGroupStats(t *testing.T) {
stmt, err := p.ParseOneStmt("select * from t t1 join t t2 on t1.a = t2.a", "", "") require.NoError(t, err) - plan, _, err := plannercore.BuildLogicalPlan(context.Background(), ctx, stmt, is) + plan, _, err := plannercore.BuildLogicalPlanForTest(context.Background(), ctx, stmt, is) require.NoError(t, err) logic, ok := plan.(plannercore.LogicalPlan) @@ -119,7 +119,7 @@ func TestPreparePossibleProperties(t *testing.T) { stmt, err := p.ParseOneStmt("select f, sum(a) from t group by f", "", "") require.NoError(t, err) - plan, _, err := plannercore.BuildLogicalPlan(context.Background(), ctx, stmt, is) + plan, _, err := plannercore.BuildLogicalPlanForTest(context.Background(), ctx, stmt, is) require.NoError(t, err) logic, ok := plan.(plannercore.LogicalPlan) @@ -214,7 +214,7 @@ func TestAppliedRuleSet(t *testing.T) { stmt, err := p.ParseOneStmt("select 1", "", "") require.NoError(t, err) - plan, _, err := plannercore.BuildLogicalPlan(context.Background(), ctx, stmt, is) + plan, _, err := plannercore.BuildLogicalPlanForTest(context.Background(), ctx, stmt, is) require.NoError(t, err) logic, ok := plan.(plannercore.LogicalPlan) diff --git a/planner/cascades/stringer_test.go b/planner/cascades/stringer_test.go index dbde4e86ce736..8bd1dab264b15 100644 --- a/planner/cascades/stringer_test.go +++ b/planner/cascades/stringer_test.go @@ -58,7 +58,7 @@ func TestGroupStringer(t *testing.T) { stmt, err := p.ParseOneStmt(sql, "", "") require.NoError(t, err) - plan, _, err := plannercore.BuildLogicalPlan(context.Background(), ctx, stmt, is) + plan, _, err := plannercore.BuildLogicalPlanForTest(context.Background(), ctx, stmt, is) require.NoError(t, err) logic, ok := plan.(plannercore.LogicalPlan) diff --git a/planner/cascades/transformation_rules_test.go b/planner/cascades/transformation_rules_test.go index ee5d93bebcd8f..bf6e9948ad95c 100644 --- a/planner/cascades/transformation_rules_test.go +++ b/planner/cascades/transformation_rules_test.go @@ -39,7 +39,7 @@ func testGroupToString(t *testing.T, input []string, output []struct { stmt, err := p.ParseOneStmt(sql, "", "") require.NoError(t, err) - plan, _, err := plannercore.BuildLogicalPlan(context.Background(), ctx, stmt, is) + plan, _, err := plannercore.BuildLogicalPlanForTest(context.Background(), ctx, stmt, is) require.NoError(t, err) logic, ok := plan.(plannercore.LogicalPlan) @@ -89,7 +89,7 @@ func TestAggPushDownGather(t *testing.T) { stmt, err := p.ParseOneStmt(sql, "", "") require.NoError(t, err) - plan, _, err := plannercore.BuildLogicalPlan(context.Background(), ctx, stmt, is) + plan, _, err := plannercore.BuildLogicalPlanForTest(context.Background(), ctx, stmt, is) require.NoError(t, err) logic, ok := plan.(plannercore.LogicalPlan) diff --git a/planner/core/logical_plan_builder.go b/planner/core/logical_plan_builder.go index 055cbb0390056..5e2d3efe6b94e 100644 --- a/planner/core/logical_plan_builder.go +++ b/planner/core/logical_plan_builder.go @@ -5890,20 +5890,16 @@ func unfoldSelectList(list *ast.SetOprSelectList, unfoldList *ast.SetOprSelectLi // If asName is true, extract AsName prior to OrigName. // Privilege check should use OrigName, while expression may use AsName. 
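Note: the extractTableList rewrite below widens the parameter from ast.ResultSetNode to ast.Node, so DML statements, CTE bodies, and scalar subqueries in WHERE clauses are all walked; this is what lets the temporary-table checks in the bind tests above see tmp1 inside t1.a = (select a from tmp1). A hedged sketch of calling it from inside planner/core (the function stays unexported; the SQL string and the fmt/parser imports are illustrative):

// listTables is a sketch that would live next to extractTableList in planner/core.
func listTables() error {
	stmt, err := parser.New().ParseOneStmt("delete from t1 where t1.a = (select a from tmp1)", "", "")
	if err != nil {
		return err
	}
	// After this change the walk reaches tmp1 through the new BinaryOperationExpr case.
	for _, tbl := range extractTableList(stmt, nil, false) {
		fmt.Println(tbl.Name.O) // expected: t1, tmp1
	}
	return nil
}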
// TODO: extracting all tables by vistor model maybe a better way -func extractTableList(node ast.ResultSetNode, input []*ast.TableName, asName bool) []*ast.TableName { +func extractTableList(node ast.Node, input []*ast.TableName, asName bool) []*ast.TableName { switch x := node.(type) { - case *ast.SubqueryExpr: - input = extractTableList(x.Query, input, asName) case *ast.SelectStmt: input = extractTableList(x.From.TableRefs, input, asName) - switch w := x.Where.(type) { - case *ast.PatternInExpr: - if s, ok := w.Sel.(*ast.SubqueryExpr); ok { - input = extractTableList(s, input, asName) - } - case *ast.ExistsSubqueryExpr: - if s, ok := w.Sel.(*ast.SubqueryExpr); ok { - input = extractTableList(s, input, asName) + if x.Where != nil { + input = extractTableList(x.Where, input, asName) + } + if x.With != nil { + for _, cte := range x.With.CTEs { + input = extractTableList(cte.Query, input, asName) } } for _, f := range x.Fields.Fields { @@ -5911,12 +5907,57 @@ func extractTableList(node ast.ResultSetNode, input []*ast.TableName, asName boo input = extractTableList(s, input, asName) } } + case *ast.DeleteStmt: + input = extractTableList(x.TableRefs.TableRefs, input, asName) + if x.IsMultiTable { + for _, t := range x.Tables.Tables { + input = extractTableList(t, input, asName) + } + } + if x.Where != nil { + input = extractTableList(x.Where, input, asName) + } + if x.With != nil { + for _, cte := range x.With.CTEs { + input = extractTableList(cte.Query, input, asName) + } + } + case *ast.UpdateStmt: + input = extractTableList(x.TableRefs.TableRefs, input, asName) + for _, e := range x.List { + input = extractTableList(e.Expr, input, asName) + } + if x.Where != nil { + input = extractTableList(x.Where, input, asName) + } + if x.With != nil { + for _, cte := range x.With.CTEs { + input = extractTableList(cte.Query, input, asName) + } + } + case *ast.InsertStmt: + input = extractTableList(x.Table.TableRefs, input, asName) + input = extractTableList(x.Select, input, asName) case *ast.SetOprStmt: l := &ast.SetOprSelectList{} unfoldSelectList(x.SelectList, l) for _, s := range l.Selects { input = extractTableList(s.(ast.ResultSetNode), input, asName) } + case *ast.PatternInExpr: + if s, ok := x.Sel.(*ast.SubqueryExpr); ok { + input = extractTableList(s, input, asName) + } + case *ast.ExistsSubqueryExpr: + if s, ok := x.Sel.(*ast.SubqueryExpr); ok { + input = extractTableList(s, input, asName) + } + case *ast.BinaryOperationExpr: + if s, ok := x.R.(*ast.SubqueryExpr); ok { + input = extractTableList(s, input, asName) + } + case *ast.SubqueryExpr: + input = extractTableList(x.Query, input, asName) case *ast.Join: input = extractTableList(x.Left, input, asName) input = extractTableList(x.Right, input, asName) diff --git a/planner/core/logical_plan_test.go b/planner/core/logical_plan_test.go index ce4e64c286c19..f99090ce5439a 100644 --- a/planner/core/logical_plan_test.go +++ b/planner/core/logical_plan_test.go @@ -82,7 +82,7 @@ func (s *testPlanSuite) TestPredicatePushDown(c *C) { comment := Commentf("for %s", ca) stmt, err := s.ParseOneStmt(ca, "", "") c.Assert(err, IsNil, comment) - p, _, err := BuildLogicalPlan(ctx, s.ctx, stmt, s.is) + p, _, err := BuildLogicalPlanForTest(ctx, s.ctx, stmt, s.is) c.Assert(err, IsNil) p, err = logicalOptimize(context.TODO(), flagPredicatePushDown|flagDecorrelate|flagPrunColumns|flagPrunColumnsAgain, p.(LogicalPlan)) c.Assert(err, IsNil) @@ -109,7 +109,7 @@ func (s *testPlanSuite) TestJoinPredicatePushDown(c *C) { comment := Commentf("for %s", ca) stmt, err := 
s.ParseOneStmt(ca, "", "") c.Assert(err, IsNil, comment) - p, _, err := BuildLogicalPlan(ctx, s.ctx, stmt, s.is) + p, _, err := BuildLogicalPlanForTest(ctx, s.ctx, stmt, s.is) c.Assert(err, IsNil, comment) p, err = logicalOptimize(context.TODO(), flagPredicatePushDown|flagDecorrelate|flagPrunColumns|flagPrunColumnsAgain, p.(LogicalPlan)) c.Assert(err, IsNil, comment) @@ -148,7 +148,7 @@ func (s *testPlanSuite) TestOuterWherePredicatePushDown(c *C) { comment := Commentf("for %s", ca) stmt, err := s.ParseOneStmt(ca, "", "") c.Assert(err, IsNil, comment) - p, _, err := BuildLogicalPlan(ctx, s.ctx, stmt, s.is) + p, _, err := BuildLogicalPlanForTest(ctx, s.ctx, stmt, s.is) c.Assert(err, IsNil, comment) p, err = logicalOptimize(context.TODO(), flagPredicatePushDown|flagDecorrelate|flagPrunColumns|flagPrunColumnsAgain, p.(LogicalPlan)) c.Assert(err, IsNil, comment) @@ -193,7 +193,7 @@ func (s *testPlanSuite) TestSimplifyOuterJoin(c *C) { comment := Commentf("for %s", ca) stmt, err := s.ParseOneStmt(ca, "", "") c.Assert(err, IsNil, comment) - p, _, err := BuildLogicalPlan(ctx, s.ctx, stmt, s.is) + p, _, err := BuildLogicalPlanForTest(ctx, s.ctx, stmt, s.is) c.Assert(err, IsNil, comment) p, err = logicalOptimize(context.TODO(), flagPredicatePushDown|flagPrunColumns|flagPrunColumnsAgain, p.(LogicalPlan)) c.Assert(err, IsNil, comment) @@ -233,7 +233,7 @@ func (s *testPlanSuite) TestAntiSemiJoinConstFalse(c *C) { comment := Commentf("for %s", ca.sql) stmt, err := s.ParseOneStmt(ca.sql, "", "") c.Assert(err, IsNil, comment) - p, _, err := BuildLogicalPlan(ctx, s.ctx, stmt, s.is) + p, _, err := BuildLogicalPlanForTest(ctx, s.ctx, stmt, s.is) c.Assert(err, IsNil, comment) p, err = logicalOptimize(context.TODO(), flagDecorrelate|flagPredicatePushDown|flagPrunColumns|flagPrunColumnsAgain, p.(LogicalPlan)) c.Assert(err, IsNil, comment) @@ -260,7 +260,7 @@ func (s *testPlanSuite) TestDeriveNotNullConds(c *C) { comment := Commentf("for %s", ca) stmt, err := s.ParseOneStmt(ca, "", "") c.Assert(err, IsNil, comment) - p, _, err := BuildLogicalPlan(ctx, s.ctx, stmt, s.is) + p, _, err := BuildLogicalPlanForTest(ctx, s.ctx, stmt, s.is) c.Assert(err, IsNil, comment) p, err = logicalOptimize(context.TODO(), flagPredicatePushDown|flagPrunColumns|flagPrunColumnsAgain|flagDecorrelate, p.(LogicalPlan)) c.Assert(err, IsNil, comment) @@ -288,7 +288,7 @@ func (s *testPlanSuite) TestExtraPKNotNullFlag(c *C) { comment := Commentf("for %s", sql) stmt, err := s.ParseOneStmt(sql, "", "") c.Assert(err, IsNil, comment) - p, _, err := BuildLogicalPlan(ctx, s.ctx, stmt, s.is) + p, _, err := BuildLogicalPlanForTest(ctx, s.ctx, stmt, s.is) c.Assert(err, IsNil, comment) ds := p.(*LogicalProjection).children[0].(*LogicalAggregation).children[0].(*DataSource) c.Assert(ds.Columns[2].Name.L, Equals, "_tidb_rowid") @@ -309,7 +309,7 @@ func buildLogicPlan4GroupBy(s *testPlanSuite, c *C, sql string) (Plan, error) { stmt.(*ast.SelectStmt).From.TableRefs.Left.(*ast.TableSource).Source.(*ast.TableName).TableInfo = mockedTableInfo - p, _, err := BuildLogicalPlan(context.Background(), s.ctx, stmt, s.is) + p, _, err := BuildLogicalPlanForTest(context.Background(), s.ctx, stmt, s.is) return p, err } @@ -367,7 +367,7 @@ func (s *testPlanSuite) TestDupRandJoinCondsPushDown(c *C) { comment := Commentf("for %s", sql) stmt, err := s.ParseOneStmt(sql, "", "") c.Assert(err, IsNil, comment) - p, _, err := BuildLogicalPlan(context.Background(), s.ctx, stmt, s.is) + p, _, err := BuildLogicalPlanForTest(context.Background(), s.ctx, stmt, s.is) c.Assert(err, 
IsNil, comment) p, err = logicalOptimize(context.TODO(), flagPredicatePushDown, p.(LogicalPlan)) c.Assert(err, IsNil, comment) @@ -435,7 +435,7 @@ func (s *testPlanSuite) TestTablePartition(c *C) { s.testData.OnRecord(func() { }) - p, _, err := BuildLogicalPlan(ctx, s.ctx, stmt, isChoices[ca.IsIdx]) + p, _, err := BuildLogicalPlanForTest(ctx, s.ctx, stmt, isChoices[ca.IsIdx]) c.Assert(err, IsNil) p, err = logicalOptimize(context.TODO(), flagDecorrelate|flagPrunColumns|flagPrunColumnsAgain|flagPredicatePushDown|flagPartitionProcessor, p.(LogicalPlan)) c.Assert(err, IsNil) @@ -460,7 +460,7 @@ func (s *testPlanSuite) TestSubquery(c *C) { err = Preprocess(s.ctx, stmt, WithPreprocessorReturn(&PreprocessorReturn{InfoSchema: s.is})) c.Assert(err, IsNil) - p, _, err := BuildLogicalPlan(ctx, s.ctx, stmt, s.is) + p, _, err := BuildLogicalPlanForTest(ctx, s.ctx, stmt, s.is) c.Assert(err, IsNil) if lp, ok := p.(LogicalPlan); ok { p, err = logicalOptimize(context.TODO(), flagBuildKeyInfo|flagDecorrelate|flagPrunColumns|flagPrunColumnsAgain, lp) @@ -486,7 +486,7 @@ func (s *testPlanSuite) TestPlanBuilder(c *C) { s.ctx.GetSessionVars().SetHashJoinConcurrency(1) err = Preprocess(s.ctx, stmt, WithPreprocessorReturn(&PreprocessorReturn{InfoSchema: s.is})) c.Assert(err, IsNil) - p, _, err := BuildLogicalPlan(ctx, s.ctx, stmt, s.is) + p, _, err := BuildLogicalPlanForTest(ctx, s.ctx, stmt, s.is) c.Assert(err, IsNil) if lp, ok := p.(LogicalPlan); ok { p, err = logicalOptimize(context.TODO(), flagPrunColumns|flagPrunColumnsAgain, lp) @@ -510,7 +510,7 @@ func (s *testPlanSuite) TestJoinReOrder(c *C) { stmt, err := s.ParseOneStmt(tt, "", "") c.Assert(err, IsNil, comment) - p, _, err := BuildLogicalPlan(ctx, s.ctx, stmt, s.is) + p, _, err := BuildLogicalPlanForTest(ctx, s.ctx, stmt, s.is) c.Assert(err, IsNil) p, err = logicalOptimize(context.TODO(), flagPredicatePushDown|flagJoinReOrder, p.(LogicalPlan)) c.Assert(err, IsNil) @@ -534,7 +534,7 @@ func (s *testPlanSuite) TestEagerAggregation(c *C) { stmt, err := s.ParseOneStmt(tt, "", "") c.Assert(err, IsNil, comment) - p, _, err := BuildLogicalPlan(ctx, s.ctx, stmt, s.is) + p, _, err := BuildLogicalPlanForTest(ctx, s.ctx, stmt, s.is) c.Assert(err, IsNil) p, err = logicalOptimize(context.TODO(), flagBuildKeyInfo|flagPredicatePushDown|flagPrunColumns|flagPrunColumnsAgain|flagPushDownAgg, p.(LogicalPlan)) c.Assert(err, IsNil) @@ -560,7 +560,7 @@ func (s *testPlanSuite) TestColumnPruning(c *C) { stmt, err := s.ParseOneStmt(tt, "", "") c.Assert(err, IsNil, comment) - p, _, err := BuildLogicalPlan(ctx, s.ctx, stmt, s.is) + p, _, err := BuildLogicalPlanForTest(ctx, s.ctx, stmt, s.is) c.Assert(err, IsNil) lp, err := logicalOptimize(ctx, flagPredicatePushDown|flagPrunColumns|flagPrunColumnsAgain, p.(LogicalPlan)) c.Assert(err, IsNil) @@ -588,7 +588,7 @@ func (s *testPlanSuite) TestSortByItemsPruning(c *C) { stmt, err := s.ParseOneStmt(tt, "", "") c.Assert(err, IsNil, comment) - p, _, err := BuildLogicalPlan(ctx, s.ctx, stmt, s.is) + p, _, err := BuildLogicalPlanForTest(ctx, s.ctx, stmt, s.is) c.Assert(err, IsNil) lp, err := logicalOptimize(ctx, flagEliminateProjection|flagPredicatePushDown|flagPrunColumns|flagPrunColumnsAgain, p.(LogicalPlan)) c.Assert(err, IsNil) @@ -614,7 +614,7 @@ func (s *testPlanSuite) TestProjectionEliminator(c *C) { stmt, err := s.ParseOneStmt(tt.sql, "", "") c.Assert(err, IsNil, comment) - p, _, err := BuildLogicalPlan(ctx, s.ctx, stmt, s.is) + p, _, err := BuildLogicalPlanForTest(ctx, s.ctx, stmt, s.is) c.Assert(err, IsNil) p, err = 
logicalOptimize(context.TODO(), flagBuildKeyInfo|flagPrunColumns|flagPrunColumnsAgain|flagEliminateProjection, p.(LogicalPlan)) c.Assert(err, IsNil) @@ -851,7 +851,7 @@ func (s *testPlanSuite) TestValidate(c *C) { c.Assert(err, IsNil, comment) err = Preprocess(s.ctx, stmt, WithPreprocessorReturn(&PreprocessorReturn{InfoSchema: s.is})) c.Assert(err, IsNil) - _, _, err = BuildLogicalPlan(ctx, s.ctx, stmt, s.is) + _, _, err = BuildLogicalPlanForTest(ctx, s.ctx, stmt, s.is) if tt.err == nil { c.Assert(err, IsNil, comment) } else { @@ -902,7 +902,7 @@ func (s *testPlanSuite) TestUniqueKeyInfo(c *C) { stmt, err := s.ParseOneStmt(tt, "", "") c.Assert(err, IsNil, comment) - p, _, err := BuildLogicalPlan(ctx, s.ctx, stmt, s.is) + p, _, err := BuildLogicalPlanForTest(ctx, s.ctx, stmt, s.is) c.Assert(err, IsNil) lp, err := logicalOptimize(context.TODO(), flagPredicatePushDown|flagPrunColumns|flagBuildKeyInfo, p.(LogicalPlan)) c.Assert(err, IsNil) @@ -924,7 +924,7 @@ func (s *testPlanSuite) TestAggPrune(c *C) { stmt, err := s.ParseOneStmt(tt, "", "") c.Assert(err, IsNil, comment) - p, _, err := BuildLogicalPlan(ctx, s.ctx, stmt, s.is) + p, _, err := BuildLogicalPlanForTest(ctx, s.ctx, stmt, s.is) c.Assert(err, IsNil) p, err = logicalOptimize(context.TODO(), flagPredicatePushDown|flagPrunColumns|flagPrunColumnsAgain|flagBuildKeyInfo|flagEliminateAgg|flagEliminateProjection, p.(LogicalPlan)) @@ -1491,7 +1491,7 @@ func (s *testPlanSuite) TestNameResolver(c *C) { c.Assert(err, IsNil, comment) s.ctx.GetSessionVars().SetHashJoinConcurrency(1) - _, _, err = BuildLogicalPlan(ctx, s.ctx, stmt, s.is) + _, _, err = BuildLogicalPlanForTest(ctx, s.ctx, stmt, s.is) if t.err == "" { c.Check(err, IsNil) } else { @@ -1922,7 +1922,7 @@ func (s *testPlanSuite) TestResolvingCorrelatedAggregate(c *C) { c.Assert(err, IsNil, comment) err = Preprocess(s.ctx, stmt, WithPreprocessorReturn(&PreprocessorReturn{InfoSchema: s.is})) c.Assert(err, IsNil, comment) - p, _, err := BuildLogicalPlan(ctx, s.ctx, stmt, s.is) + p, _, err := BuildLogicalPlanForTest(ctx, s.ctx, stmt, s.is) c.Assert(err, IsNil, comment) p, err = logicalOptimize(context.TODO(), flagBuildKeyInfo|flagEliminateProjection|flagPrunColumns|flagPrunColumnsAgain, p.(LogicalPlan)) c.Assert(err, IsNil, comment) @@ -1981,7 +1981,7 @@ func (s *testPlanSuite) TestWindowLogicalPlanAmbiguous(c *C) { for i := 0; i < iterations; i++ { stmt, err := s.ParseOneStmt(sql, "", "") c.Assert(err, IsNil) - p, _, err := BuildLogicalPlan(context.Background(), s.ctx, stmt, s.is) + p, _, err := BuildLogicalPlanForTest(context.Background(), s.ctx, stmt, s.is) c.Assert(err, IsNil) if planString == "" { planString = ToString(p) diff --git a/planner/core/optimizer.go b/planner/core/optimizer.go index d406248ebfac9..001add7c5021b 100644 --- a/planner/core/optimizer.go +++ b/planner/core/optimizer.go @@ -88,11 +88,11 @@ type logicalOptRule interface { name() string } -// BuildLogicalPlan used to build logical plan from ast.Node. +// BuildLogicalPlanForTest builds a logical plan for testing purposes from ast.Node.
+func BuildLogicalPlanForTest(ctx context.Context, sctx sessionctx.Context, node ast.Node, infoSchema infoschema.InfoSchema) (Plan, types.NameSlice, error) { sctx.GetSessionVars().PlanID = 0 sctx.GetSessionVars().PlanColumnID = 0 - builder, _ := NewPlanBuilder().Init(sctx, is, &utilhint.BlockHintProcessor{}) + builder, _ := NewPlanBuilder().Init(sctx, infoSchema, &utilhint.BlockHintProcessor{}) p, err := builder.Build(ctx, node) if err != nil { return nil, nil, err diff --git a/planner/core/preprocess.go b/planner/core/preprocess.go index a2fe1bd413b14..2f38b61f3ac7c 100644 --- a/planner/core/preprocess.go +++ b/planner/core/preprocess.go @@ -422,37 +422,21 @@ func (p *preprocessor) checkBindGrammar(originNode, hintedNode ast.StmtNode, def } // Check the bind operation is not on any temporary table. - var resNode ast.ResultSetNode - switch n := originNode.(type) { - case *ast.SelectStmt: - resNode = n - case *ast.SetOprStmt: - resNode = n - case *ast.DeleteStmt: - resNode = n.TableRefs.TableRefs - case *ast.UpdateStmt: - resNode = n.TableRefs.TableRefs - //TODO: What about insert into (select * from t) - case *ast.InsertStmt: - resNode = n.Table.TableRefs - } - if resNode != nil { - tblNames := extractTableList(resNode, nil, false) - for _, tn := range tblNames { - tbl, err := p.tableByName(tn) - if err != nil { - // If the operation is order is: drop table -> drop binding - // The table doesn't exist, it is not an error. - if terror.ErrorEqual(err, infoschema.ErrTableNotExists) { - continue - } - p.err = err - return - } - if tbl.Meta().TempTableType != model.TempTableNone { - p.err = ddl.ErrOptOnTemporaryTable.GenWithStackByArgs("create binding") - return + tblNames := extractTableList(originNode, nil, false) + for _, tn := range tblNames { + tbl, err := p.tableByName(tn) + if err != nil { + // If the operation order is: drop table -> drop binding, + // the table no longer exists, and that is not an error.
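+ // ErrTableNotExists is therefore skipped here; any other lookup error aborts preprocessing.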
+ if terror.ErrorEqual(err, infoschema.ErrTableNotExists) { + continue } + p.err = err + return + } + if tbl.Meta().TempTableType != model.TempTableNone { + p.err = ddl.ErrOptOnTemporaryTable.GenWithStackByArgs("create binding") + return } } diff --git a/planner/memo/group_test.go b/planner/memo/group_test.go index 73ac5f4e351d1..0c21e48044fcf 100644 --- a/planner/memo/group_test.go +++ b/planner/memo/group_test.go @@ -104,7 +104,7 @@ func TestGroupFingerPrint(t *testing.T) { is := infoschema.MockInfoSchema([]*model.TableInfo{plannercore.MockSignedTable()}) ctx := plannercore.MockContext() - plan, _, err := plannercore.BuildLogicalPlan(context.Background(), ctx, stmt1, is) + plan, _, err := plannercore.BuildLogicalPlanForTest(context.Background(), ctx, stmt1, is) require.NoError(t, err) logic1, ok := plan.(plannercore.LogicalPlan) require.True(t, ok) @@ -250,7 +250,7 @@ func TestBuildKeyInfo(t *testing.T) { // case 1: primary key has constant constraint stmt1, err := p.ParseOneStmt("select a from t where a = 10", "", "") require.NoError(t, err) - p1, _, err := plannercore.BuildLogicalPlan(context.Background(), ctx, stmt1, is) + p1, _, err := plannercore.BuildLogicalPlanForTest(context.Background(), ctx, stmt1, is) require.NoError(t, err) logic1, ok := p1.(plannercore.LogicalPlan) require.True(t, ok) @@ -262,7 +262,7 @@ func TestBuildKeyInfo(t *testing.T) { // case 2: group by column is key stmt2, err := p.ParseOneStmt("select b, sum(a) from t group by b", "", "") require.NoError(t, err) - p2, _, err := plannercore.BuildLogicalPlan(context.Background(), ctx, stmt2, is) + p2, _, err := plannercore.BuildLogicalPlanForTest(context.Background(), ctx, stmt2, is) require.NoError(t, err) logic2, ok := p2.(plannercore.LogicalPlan) require.True(t, ok) diff --git a/statistics/selectivity_test.go b/statistics/selectivity_test.go index 1e4ac7508d5bd..c6ef9cf132519 100644 --- a/statistics/selectivity_test.go +++ b/statistics/selectivity_test.go @@ -307,7 +307,7 @@ func (s *testStatsSuite) TestSelectivity(c *C) { ret := &plannercore.PreprocessorReturn{} err = plannercore.Preprocess(sctx, stmts[0], plannercore.WithPreprocessorReturn(ret)) c.Assert(err, IsNil, comment) - p, _, err := plannercore.BuildLogicalPlan(ctx, sctx, stmts[0], ret.InfoSchema) + p, _, err := plannercore.BuildLogicalPlanForTest(ctx, sctx, stmts[0], ret.InfoSchema) c.Assert(err, IsNil, Commentf("error %v, for building plan, expr %s", err, tt.exprs)) sel := p.(plannercore.LogicalPlan).Children()[0].(*plannercore.LogicalSelection) @@ -584,7 +584,7 @@ func BenchmarkSelectivity(b *testing.B) { ret := &plannercore.PreprocessorReturn{} err = plannercore.Preprocess(sctx, stmts[0], plannercore.WithPreprocessorReturn(ret)) c.Assert(err, IsNil, comment) - p, _, err := plannercore.BuildLogicalPlan(context.Background(), sctx, stmts[0], ret.InfoSchema) + p, _, err := plannercore.BuildLogicalPlanForTest(context.Background(), sctx, stmts[0], ret.InfoSchema) c.Assert(err, IsNil, Commentf("error %v, for building plan, expr %s", err, exprs)) file, err := os.Create("cpu.profile") @@ -848,7 +848,7 @@ func (s *testStatsSuite) TestDNFCondSelectivity(c *C) { ret := &plannercore.PreprocessorReturn{} err = plannercore.Preprocess(sctx, stmts[0], plannercore.WithPreprocessorReturn(ret)) c.Assert(err, IsNil, Commentf("error %v, for sql %s", err, tt)) - p, _, err := plannercore.BuildLogicalPlan(ctx, sctx, stmts[0], ret.InfoSchema) + p, _, err := plannercore.BuildLogicalPlanForTest(ctx, sctx, stmts[0], ret.InfoSchema) c.Assert(err, IsNil, Commentf("error %v, for 
building plan, sql %s", err, tt)) sel := p.(plannercore.LogicalPlan).Children()[0].(*plannercore.LogicalSelection) diff --git a/util/ranger/ranger_test.go b/util/ranger/ranger_test.go index fd71d23028590..cd2d42572e5f5 100644 --- a/util/ranger/ranger_test.go +++ b/util/ranger/ranger_test.go @@ -286,7 +286,7 @@ func TestTableRange(t *testing.T) { ret := &plannercore.PreprocessorReturn{} err = plannercore.Preprocess(sctx, stmts[0], plannercore.WithPreprocessorReturn(ret)) require.NoError(t, err) - p, _, err := plannercore.BuildLogicalPlan(ctx, sctx, stmts[0], ret.InfoSchema) + p, _, err := plannercore.BuildLogicalPlanForTest(ctx, sctx, stmts[0], ret.InfoSchema) require.NoError(t, err) selection := p.(plannercore.LogicalPlan).Children()[0].(*plannercore.LogicalSelection) conds := make([]expression.Expression, len(selection.Conditions)) @@ -632,7 +632,7 @@ create table t( ret := &plannercore.PreprocessorReturn{} err = plannercore.Preprocess(sctx, stmts[0], plannercore.WithPreprocessorReturn(ret)) require.NoError(t, err) - p, _, err := plannercore.BuildLogicalPlan(ctx, sctx, stmts[0], ret.InfoSchema) + p, _, err := plannercore.BuildLogicalPlanForTest(ctx, sctx, stmts[0], ret.InfoSchema) require.NoError(t, err) selection := p.(plannercore.LogicalPlan).Children()[0].(*plannercore.LogicalSelection) tbl := selection.Children()[0].(*plannercore.DataSource).TableInfo() @@ -825,7 +825,7 @@ create table t( ret := &plannercore.PreprocessorReturn{} err = plannercore.Preprocess(sctx, stmts[0], plannercore.WithPreprocessorReturn(ret)) require.NoError(t, err) - p, _, err := plannercore.BuildLogicalPlan(ctx, sctx, stmts[0], ret.InfoSchema) + p, _, err := plannercore.BuildLogicalPlanForTest(ctx, sctx, stmts[0], ret.InfoSchema) require.NoError(t, err) selection := p.(plannercore.LogicalPlan).Children()[0].(*plannercore.LogicalSelection) tbl := selection.Children()[0].(*plannercore.DataSource).TableInfo() @@ -1190,7 +1190,7 @@ func TestColumnRange(t *testing.T) { ret := &plannercore.PreprocessorReturn{} err = plannercore.Preprocess(sctx, stmts[0], plannercore.WithPreprocessorReturn(ret)) require.NoError(t, err) - p, _, err := plannercore.BuildLogicalPlan(ctx, sctx, stmts[0], ret.InfoSchema) + p, _, err := plannercore.BuildLogicalPlanForTest(ctx, sctx, stmts[0], ret.InfoSchema) require.NoError(t, err) sel := p.(plannercore.LogicalPlan).Children()[0].(*plannercore.LogicalSelection) ds, ok := sel.Children()[0].(*plannercore.DataSource) @@ -1615,7 +1615,7 @@ func TestIndexRangeForYear(t *testing.T) { ret := &plannercore.PreprocessorReturn{} err = plannercore.Preprocess(sctx, stmts[0], plannercore.WithPreprocessorReturn(ret)) require.NoError(t, err) - p, _, err := plannercore.BuildLogicalPlan(ctx, sctx, stmts[0], ret.InfoSchema) + p, _, err := plannercore.BuildLogicalPlanForTest(ctx, sctx, stmts[0], ret.InfoSchema) require.NoError(t, err) selection := p.(plannercore.LogicalPlan).Children()[0].(*plannercore.LogicalSelection) tbl := selection.Children()[0].(*plannercore.DataSource).TableInfo() @@ -1688,7 +1688,7 @@ func TestPrefixIndexRangeScan(t *testing.T) { ret := &plannercore.PreprocessorReturn{} err = plannercore.Preprocess(sctx, stmts[0], plannercore.WithPreprocessorReturn(ret)) require.NoError(t, err) - p, _, err := plannercore.BuildLogicalPlan(ctx, sctx, stmts[0], ret.InfoSchema) + p, _, err := plannercore.BuildLogicalPlanForTest(ctx, sctx, stmts[0], ret.InfoSchema) require.NoError(t, err) selection := p.(plannercore.LogicalPlan).Children()[0].(*plannercore.LogicalSelection) tbl := 
selection.Children()[0].(*plannercore.DataSource).TableInfo() diff --git a/util/rowcodec/export_test.go b/util/rowcodec/main_test.go similarity index 87% rename from util/rowcodec/export_test.go rename to util/rowcodec/main_test.go index 47cd7b73eee6f..7a54747ceaf72 100644 --- a/util/rowcodec/export_test.go +++ b/util/rowcodec/main_test.go @@ -1,4 +1,4 @@ -// Copyright 2019 PingCAP, Inc. +// Copyright 2021 PingCAP, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -15,11 +15,20 @@ package rowcodec import ( + "testing" + "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/codec" + "github.com/pingcap/tidb/util/testbridge" + "go.uber.org/goleak" ) +func TestMain(m *testing.M) { + testbridge.WorkaroundGoCheckFlags() + goleak.VerifyTestMain(m) +} + // EncodeFromOldRow encodes a row from an old-format row. // this method will be used in test. func EncodeFromOldRow(encoder *Encoder, sc *stmtctx.StatementContext, oldRow, buf []byte) ([]byte, error) { diff --git a/util/rowcodec/rowcodec_test.go b/util/rowcodec/rowcodec_test.go index c7ff304705721..741c3ec7b3f44 100644 --- a/util/rowcodec/rowcodec_test.go +++ b/util/rowcodec/rowcodec_test.go @@ -20,7 +20,6 @@ import ( "testing" "time" - . "github.com/pingcap/check" "github.com/pingcap/parser/mysql" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/sessionctx/stmtctx" @@ -31,33 +30,28 @@ import ( "github.com/pingcap/tidb/util/codec" "github.com/pingcap/tidb/util/collate" "github.com/pingcap/tidb/util/rowcodec" + "github.com/stretchr/testify/require" ) -func TestT(t *testing.T) { - TestingT(t) -} - -var _ = Suite(&testSuite{}) - -type testSuite struct{} - type testData struct { id int64 ft *types.FieldType - dt types.Datum - bt types.Datum + input types.Datum + output types.Datum def *types.Datum handle bool } -func (s *testSuite) TestEncodeLargeSmallReuseBug(c *C) { +func TestEncodeLargeSmallReuseBug(t *testing.T) { + t.Parallel() + // reuse one rowcodec.Encoder. var encoder rowcodec.Encoder colFt := types.NewFieldType(mysql.TypeString) largeColID := int64(300) b, err := encoder.Encode(&stmtctx.StatementContext{}, []int64{largeColID}, []types.Datum{types.NewBytesDatum([]byte(""))}, nil) - c.Assert(err, IsNil) + require.NoError(t, err) bDecoder := rowcodec.NewDatumMapDecoder([]rowcodec.ColInfo{ { @@ -67,12 +61,12 @@ func (s *testSuite) TestEncodeLargeSmallReuseBug(c *C) { }, }, nil) _, err = bDecoder.DecodeToDatumMap(b, nil) - c.Assert(err, IsNil) + require.NoError(t, err) colFt = types.NewFieldType(mysql.TypeLonglong) smallColID := int64(1) b, err = encoder.Encode(&stmtctx.StatementContext{}, []int64{smallColID}, []types.Datum{types.NewIntDatum(2)}, nil) - c.Assert(err, IsNil) + require.NoError(t, err) bDecoder = rowcodec.NewDatumMapDecoder([]rowcodec.ColInfo{ { @@ -82,172 +76,183 @@ func (s *testSuite) TestEncodeLargeSmallReuseBug(c *C) { }, }, nil) m, err := bDecoder.DecodeToDatumMap(b, nil) - c.Assert(err, IsNil) + require.NoError(t, err) + v := m[smallColID] - c.Assert(v.GetInt64(), Equals, int64(2)) + require.Equal(t, int64(2), v.GetInt64()) } -func (s *testSuite) TestDecodeRowWithHandle(c *C) { +func TestDecodeRowWithHandle(t *testing.T) { + t.Parallel() + handleID := int64(-1) handleValue := int64(10000) - encodeAndDecodeHandle := func(c *C, testData []testData) { - // transform test data into input. 
- colIDs := make([]int64, 0, len(testData)) - dts := make([]types.Datum, 0, len(testData)) - fts := make([]*types.FieldType, 0, len(testData)) - cols := make([]rowcodec.ColInfo, 0, len(testData)) - handleColFtMap := make(map[int64]*types.FieldType) - for i := range testData { - t := testData[i] - if t.handle { - handleColFtMap[handleID] = t.ft - } else { - colIDs = append(colIDs, t.id) - dts = append(dts, t.dt) - } - fts = append(fts, t.ft) - cols = append(cols, rowcodec.ColInfo{ - ID: t.id, - IsPKHandle: t.handle, - Ft: t.ft, - }) - } - - // test encode input. - var encoder rowcodec.Encoder - sc := new(stmtctx.StatementContext) - sc.TimeZone = time.UTC - newRow, err := encoder.Encode(sc, colIDs, dts, nil) - c.Assert(err, IsNil) - - // decode to datum map. - mDecoder := rowcodec.NewDatumMapDecoder(cols, sc.TimeZone) - dm, err := mDecoder.DecodeToDatumMap(newRow, nil) - c.Assert(err, IsNil) - dm, err = tablecodec.DecodeHandleToDatumMap(kv.IntHandle(handleValue), - []int64{handleID}, handleColFtMap, sc.TimeZone, dm) - c.Assert(err, IsNil) - for _, t := range testData { - d, exists := dm[t.id] - c.Assert(exists, IsTrue) - c.Assert(d, DeepEquals, t.dt) - } - - // decode to chunk. - cDecoder := rowcodec.NewChunkDecoder(cols, []int64{-1}, nil, sc.TimeZone) - chk := chunk.New(fts, 1, 1) - err = cDecoder.DecodeToChunk(newRow, kv.IntHandle(handleValue), chk) - c.Assert(err, IsNil) - chkRow := chk.GetRow(0) - cdt := chkRow.GetDatumRow(fts) - for i, t := range testData { - d := cdt[i] - if d.Kind() == types.KindMysqlDecimal { - c.Assert(d.GetMysqlDecimal(), DeepEquals, t.bt.GetMysqlDecimal()) - } else { - c.Assert(d, DeepEquals, t.bt) - } - } - - // decode to old row bytes. - colOffset := make(map[int64]int) - for i, t := range testData { - colOffset[t.id] = i - } - bDecoder := rowcodec.NewByteDecoder(cols, []int64{-1}, nil, nil) - oldRow, err := bDecoder.DecodeToBytes(colOffset, kv.IntHandle(handleValue), newRow, nil) - c.Assert(err, IsNil) - for i, t := range testData { - remain, d, err := codec.DecodeOne(oldRow[i]) - c.Assert(err, IsNil) - c.Assert(len(remain), Equals, 0) - if d.Kind() == types.KindMysqlDecimal { - c.Assert(d.GetMysqlDecimal(), DeepEquals, t.bt.GetMysqlDecimal()) - } else { - c.Assert(d, DeepEquals, t.bt) - } - } - } - - // encode & decode signed int. - testDataSigned := []testData{ + tests := []struct { + name string + testData []testData + }{ { - handleID, - types.NewFieldType(mysql.TypeLonglong), - types.NewIntDatum(handleValue), - types.NewIntDatum(handleValue), - nil, - true, + "signed int", + []testData{ + { + handleID, + types.NewFieldType(mysql.TypeLonglong), + types.NewIntDatum(handleValue), + types.NewIntDatum(handleValue), + nil, + true, + }, + { + 10, + types.NewFieldType(mysql.TypeLonglong), + types.NewIntDatum(1), + types.NewIntDatum(1), + nil, + false, + }, + }, }, { - 10, - types.NewFieldType(mysql.TypeLonglong), - types.NewIntDatum(1), - types.NewIntDatum(1), - nil, - false, + "unsigned int", + []testData{ + { + handleID, + withUnsigned(types.NewFieldType(mysql.TypeLonglong)), + types.NewUintDatum(uint64(handleValue)), + types.NewUintDatum(uint64(handleValue)), // decoded as bytes, an unsigned handle comes back as a uint datum. + nil, + true, + }, + { + 10, + types.NewFieldType(mysql.TypeLonglong), + types.NewIntDatum(1), + types.NewIntDatum(1), + nil, + false, + }, + }, }, } - encodeAndDecodeHandle(c, testDataSigned) - // encode & decode unsigned int.
- testDataUnsigned := []testData{ - { - handleID, - withUnsigned(types.NewFieldType(mysql.TypeLonglong)), - types.NewUintDatum(uint64(handleValue)), - types.NewUintDatum(uint64(handleValue)), // decode as bytes will uint if unsigned. - nil, - true, - }, - { - 10, - types.NewFieldType(mysql.TypeLonglong), - types.NewIntDatum(1), - types.NewIntDatum(1), - nil, - false, - }, + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + td := test.testData + + // transform test data into input. + colIDs := make([]int64, 0, len(td)) + dts := make([]types.Datum, 0, len(td)) + fts := make([]*types.FieldType, 0, len(td)) + cols := make([]rowcodec.ColInfo, 0, len(td)) + handleColFtMap := make(map[int64]*types.FieldType) + for _, d := range td { + if d.handle { + handleColFtMap[handleID] = d.ft + } else { + colIDs = append(colIDs, d.id) + dts = append(dts, d.input) + } + fts = append(fts, d.ft) + cols = append(cols, rowcodec.ColInfo{ + ID: d.id, + IsPKHandle: d.handle, + Ft: d.ft, + }) + } + + // test encode input. + var encoder rowcodec.Encoder + sc := new(stmtctx.StatementContext) + sc.TimeZone = time.UTC + newRow, err := encoder.Encode(sc, colIDs, dts, nil) + require.NoError(t, err) + + // decode to datum map. + mDecoder := rowcodec.NewDatumMapDecoder(cols, sc.TimeZone) + dm, err := mDecoder.DecodeToDatumMap(newRow, nil) + require.NoError(t, err) + + dm, err = tablecodec.DecodeHandleToDatumMap(kv.IntHandle(handleValue), []int64{handleID}, handleColFtMap, sc.TimeZone, dm) + require.NoError(t, err) + + for _, d := range td { + dat, exists := dm[d.id] + require.True(t, exists) + require.Equal(t, d.input, dat) + } + + // decode to chunk. + cDecoder := rowcodec.NewChunkDecoder(cols, []int64{-1}, nil, sc.TimeZone) + chk := chunk.New(fts, 1, 1) + err = cDecoder.DecodeToChunk(newRow, kv.IntHandle(handleValue), chk) + require.NoError(t, err) + + chkRow := chk.GetRow(0) + cdt := chkRow.GetDatumRow(fts) + for i, d := range td { + dat := cdt[i] + if dat.Kind() == types.KindMysqlDecimal { + require.Equal(t, d.output.GetMysqlDecimal(), dat.GetMysqlDecimal()) + } else { + require.Equal(t, d.output, dat) + } + } + + // decode to old row bytes. 
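+ // colOffset maps each column ID to its position in td so DecodeToBytes can order the output;
+ // the handle column (id -1) is materialized from kv.IntHandle(handleValue) rather than the row bytes.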
+ colOffset := make(map[int64]int) + for i, t := range td { + colOffset[t.id] = i + } + bDecoder := rowcodec.NewByteDecoder(cols, []int64{-1}, nil, nil) + oldRow, err := bDecoder.DecodeToBytes(colOffset, kv.IntHandle(handleValue), newRow, nil) + require.NoError(t, err) + + for i, d := range td { + remain, dat, err := codec.DecodeOne(oldRow[i]) + require.NoError(t, err) + require.Len(t, remain, 0) + if dat.Kind() == types.KindMysqlDecimal { + require.Equal(t, d.output.GetMysqlDecimal(), dat.GetMysqlDecimal()) + } else { + require.Equal(t, d.output, dat) + } + } + }) } - encodeAndDecodeHandle(c, testDataUnsigned) } -func (s *testSuite) TestEncodeKindNullDatum(c *C) { +func TestEncodeKindNullDatum(t *testing.T) { + t.Parallel() + var encoder rowcodec.Encoder sc := new(stmtctx.StatementContext) sc.TimeZone = time.UTC - colIDs := []int64{ - 1, - 2, - } + colIDs := []int64{1, 2} + var nilDt types.Datum nilDt.SetNull() dts := []types.Datum{nilDt, types.NewIntDatum(2)} ft := types.NewFieldType(mysql.TypeLonglong) fts := []*types.FieldType{ft, ft} newRow, err := encoder.Encode(sc, colIDs, dts, nil) - c.Assert(err, IsNil) + require.NoError(t, err) - cols := []rowcodec.ColInfo{{ - ID: 1, - Ft: ft, - }, - { - ID: 2, - Ft: ft, - }} + cols := []rowcodec.ColInfo{{ID: 1, Ft: ft}, {ID: 2, Ft: ft}} cDecoder := rowcodec.NewChunkDecoder(cols, []int64{-1}, nil, sc.TimeZone) chk := chunk.New(fts, 1, 1) err = cDecoder.DecodeToChunk(newRow, kv.IntHandle(-1), chk) - c.Assert(err, IsNil) + require.NoError(t, err) + chkRow := chk.GetRow(0) cdt := chkRow.GetDatumRow(fts) - c.Assert(cdt[0].IsNull(), Equals, true) - c.Assert(cdt[1].GetInt64(), Equals, int64(2)) + require.True(t, cdt[0].IsNull()) + require.Equal(t, int64(2), cdt[1].GetInt64()) } -func (s *testSuite) TestDecodeDecimalFspNotMatch(c *C) { +func TestDecodeDecimalFspNotMatch(t *testing.T) { + t.Parallel() + var encoder rowcodec.Encoder sc := new(stmtctx.StatementContext) sc.TimeZone = time.UTC @@ -260,7 +265,7 @@ func (s *testSuite) TestDecodeDecimalFspNotMatch(c *C) { ft.Decimal = 4 fts := []*types.FieldType{ft} newRow, err := encoder.Encode(sc, colIDs, dts, nil) - c.Assert(err, IsNil) + require.NoError(t, err) // decode to chunk. 
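 // the decoder side uses a smaller fsp than the encoder, so the stored 11.9900 is expected back as 11.990.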
ft = types.NewFieldType(mysql.TypeNewDecimal) @@ -273,17 +278,20 @@ func (s *testSuite) TestDecodeDecimalFspNotMatch(c *C) { cDecoder := rowcodec.NewChunkDecoder(cols, []int64{-1}, nil, sc.TimeZone) chk := chunk.New(fts, 1, 1) err = cDecoder.DecodeToChunk(newRow, kv.IntHandle(-1), chk) - c.Assert(err, IsNil) + require.NoError(t, err) + chkRow := chk.GetRow(0) cdt := chkRow.GetDatumRow(fts) dec = withFrac(3)(withLen(6)(types.NewDecimalDatum(types.NewDecFromStringForTest("11.990")))) - c.Assert(cdt[0].GetMysqlDecimal().String(), DeepEquals, dec.GetMysqlDecimal().String()) + require.Equal(t, dec.GetMysqlDecimal().String(), cdt[0].GetMysqlDecimal().String()) } -func (s *testSuite) TestTypesNewRowCodec(c *C) { +func TestTypesNewRowCodec(t *testing.T) { + t.Parallel() + getJSONDatum := func(value string) types.Datum { j, err := json.ParseBinaryFromString(value) - c.Assert(err, IsNil) + require.NoError(t, err) var d types.Datum d.SetMysqlJSON(j) return d @@ -294,85 +302,12 @@ func (s *testSuite) TestTypesNewRowCodec(c *C) { return d } getTime := func(value string) types.Time { - t, err := types.ParseTime(&stmtctx.StatementContext{TimeZone: time.UTC}, value, mysql.TypeTimestamp, 6) - c.Assert(err, IsNil) - return t - } - - var encoder rowcodec.Encoder - encodeAndDecode := func(c *C, testData []testData) { - // transform test data into input. - colIDs := make([]int64, 0, len(testData)) - dts := make([]types.Datum, 0, len(testData)) - fts := make([]*types.FieldType, 0, len(testData)) - cols := make([]rowcodec.ColInfo, 0, len(testData)) - for i := range testData { - t := testData[i] - colIDs = append(colIDs, t.id) - dts = append(dts, t.dt) - fts = append(fts, t.ft) - cols = append(cols, rowcodec.ColInfo{ - ID: t.id, - IsPKHandle: t.handle, - Ft: t.ft, - }) - } - - // test encode input. - sc := new(stmtctx.StatementContext) - sc.TimeZone = time.UTC - newRow, err := encoder.Encode(sc, colIDs, dts, nil) - c.Assert(err, IsNil) - - // decode to datum map. - mDecoder := rowcodec.NewDatumMapDecoder(cols, sc.TimeZone) - dm, err := mDecoder.DecodeToDatumMap(newRow, nil) - c.Assert(err, IsNil) - for _, t := range testData { - d, exists := dm[t.id] - c.Assert(exists, IsTrue) - c.Assert(d, DeepEquals, t.dt) - } - - // decode to chunk. - cDecoder := rowcodec.NewChunkDecoder(cols, []int64{-1}, nil, sc.TimeZone) - chk := chunk.New(fts, 1, 1) - err = cDecoder.DecodeToChunk(newRow, kv.IntHandle(-1), chk) - c.Assert(err, IsNil) - chkRow := chk.GetRow(0) - cdt := chkRow.GetDatumRow(fts) - for i, t := range testData { - d := cdt[i] - if d.Kind() == types.KindMysqlDecimal { - c.Assert(d.GetMysqlDecimal(), DeepEquals, t.bt.GetMysqlDecimal()) - } else { - c.Assert(d, DeepEquals, t.dt) - } - } - - // decode to old row bytes. 
- colOffset := make(map[int64]int) - for i, t := range testData { - colOffset[t.id] = i - } - bDecoder := rowcodec.NewByteDecoder(cols, []int64{-1}, nil, nil) - oldRow, err := bDecoder.DecodeToBytes(colOffset, kv.IntHandle(-1), newRow, nil) - c.Assert(err, IsNil) - for i, t := range testData { - remain, d, err := codec.DecodeOne(oldRow[i]) - c.Assert(err, IsNil) - c.Assert(len(remain), Equals, 0) - if d.Kind() == types.KindMysqlDecimal { - c.Assert(d.GetMysqlDecimal(), DeepEquals, t.bt.GetMysqlDecimal()) - } else if d.Kind() == types.KindBytes { - c.Assert(d.GetBytes(), DeepEquals, t.bt.GetBytes()) - } else { - c.Assert(d, DeepEquals, t.bt) - } - } + d, err := types.ParseTime(&stmtctx.StatementContext{TimeZone: time.UTC}, value, mysql.TypeTimestamp, 6) + require.NoError(t, err) + return d } - testData := []testData{ + smallTestDataList := []testData{ { 1, types.NewFieldType(mysql.TypeLonglong), @@ -519,128 +454,118 @@ func (s *testSuite) TestTypesNewRowCodec(c *C) { }, } - // test small - encodeAndDecode(c, testData) + largeColIDTestDataList := make([]testData, len(smallTestDataList)) + copy(largeColIDTestDataList, smallTestDataList) + largeColIDTestDataList[0].id = 300 - // test large colID - testData[0].id = 300 - encodeAndDecode(c, testData) - testData[0].id = 1 + largeTestDataList := make([]testData, len(smallTestDataList)) + copy(largeTestDataList, smallTestDataList) + largeTestDataList[3].input = types.NewStringDatum(strings.Repeat("a", math.MaxUint16+1)) + largeTestDataList[3].output = types.NewStringDatum(strings.Repeat("a", math.MaxUint16+1)) - // test large data - testData[3].dt = types.NewStringDatum(strings.Repeat("a", math.MaxUint16+1)) - testData[3].bt = types.NewStringDatum(strings.Repeat("a", math.MaxUint16+1)) - encodeAndDecode(c, testData) -} + var encoder rowcodec.Encoder -func (s *testSuite) TestNilAndDefault(c *C) { - encodeAndDecode := func(c *C, testData []testData) { - // transform test data into input. - colIDs := make([]int64, 0, len(testData)) - dts := make([]types.Datum, 0, len(testData)) - cols := make([]rowcodec.ColInfo, 0, len(testData)) - fts := make([]*types.FieldType, 0, len(testData)) - for i := range testData { - t := testData[i] - if t.def == nil { - colIDs = append(colIDs, t.id) - dts = append(dts, t.dt) - } - fts = append(fts, t.ft) - cols = append(cols, rowcodec.ColInfo{ - ID: t.id, - IsPKHandle: t.handle, - Ft: t.ft, - }) - } - ddf := func(i int, chk *chunk.Chunk) error { - t := testData[i] - if t.def == nil { - chk.AppendNull(i) - return nil - } - chk.AppendDatum(i, t.def) - return nil - } - bdf := func(i int) ([]byte, error) { - t := testData[i] - if t.def == nil { - return nil, nil - } - return getOldDatumByte(*t.def), nil - } - // test encode input. - var encoder rowcodec.Encoder - sc := new(stmtctx.StatementContext) - sc.TimeZone = time.UTC - newRow, err := encoder.Encode(sc, colIDs, dts, nil) - c.Assert(err, IsNil) - - // decode to datum map. - mDecoder := rowcodec.NewDatumMapDecoder(cols, sc.TimeZone) - dm, err := mDecoder.DecodeToDatumMap(newRow, nil) - c.Assert(err, IsNil) - for _, t := range testData { - d, exists := dm[t.id] - if t.def != nil { - // for datum should not fill default value. 
- c.Assert(exists, IsFalse) - } else { - c.Assert(exists, IsTrue) - c.Assert(d, DeepEquals, t.bt) + tests := []struct { + name string + testData []testData + }{ + { + "small", + smallTestDataList, + }, + { + "largeColID", + largeColIDTestDataList, + }, + { + "largeData", + largeTestDataList, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + td := test.testData + + // transform test data into input. + colIDs := make([]int64, 0, len(td)) + dts := make([]types.Datum, 0, len(td)) + fts := make([]*types.FieldType, 0, len(td)) + cols := make([]rowcodec.ColInfo, 0, len(td)) + for _, d := range td { + colIDs = append(colIDs, d.id) + dts = append(dts, d.input) + fts = append(fts, d.ft) + cols = append(cols, rowcodec.ColInfo{ + ID: d.id, + IsPKHandle: d.handle, + Ft: d.ft, + }) } - } - // decode to chunk. - chk := chunk.New(fts, 1, 1) - cDecoder := rowcodec.NewChunkDecoder(cols, []int64{-1}, ddf, sc.TimeZone) - err = cDecoder.DecodeToChunk(newRow, kv.IntHandle(-1), chk) - c.Assert(err, IsNil) - chkRow := chk.GetRow(0) - cdt := chkRow.GetDatumRow(fts) - for i, t := range testData { - d := cdt[i] - if d.Kind() == types.KindMysqlDecimal { - c.Assert(d.GetMysqlDecimal(), DeepEquals, t.bt.GetMysqlDecimal()) - } else { - c.Assert(d, DeepEquals, t.bt) + // test encode input. + sc := new(stmtctx.StatementContext) + sc.TimeZone = time.UTC + newRow, err := encoder.Encode(sc, colIDs, dts, nil) + require.NoError(t, err) + + // decode to datum map. + mDecoder := rowcodec.NewDatumMapDecoder(cols, sc.TimeZone) + dm, err := mDecoder.DecodeToDatumMap(newRow, nil) + require.NoError(t, err) + + for _, d := range td { + dat, exists := dm[d.id] + require.True(t, exists) + require.Equal(t, d.input, dat) } - } - chk = chunk.New(fts, 1, 1) - cDecoder = rowcodec.NewChunkDecoder(cols, []int64{-1}, nil, sc.TimeZone) - err = cDecoder.DecodeToChunk(newRow, kv.IntHandle(-1), chk) - c.Assert(err, IsNil) - chkRow = chk.GetRow(0) - cdt = chkRow.GetDatumRow(fts) - for i := range testData { - if i == 0 { - continue + // decode to chunk. + cDecoder := rowcodec.NewChunkDecoder(cols, []int64{-1}, nil, sc.TimeZone) + chk := chunk.New(fts, 1, 1) + err = cDecoder.DecodeToChunk(newRow, kv.IntHandle(-1), chk) + require.NoError(t, err) + + chkRow := chk.GetRow(0) + cdt := chkRow.GetDatumRow(fts) + for i, d := range td { + dat := cdt[i] + if dat.Kind() == types.KindMysqlDecimal { + require.Equal(t, d.output.GetMysqlDecimal(), dat.GetMysqlDecimal()) + } else { + require.Equal(t, d.input, dat) + } } - d := cdt[i] - c.Assert(d.IsNull(), Equals, true) - } - // decode to old row bytes. - colOffset := make(map[int64]int) - for i, t := range testData { - colOffset[t.id] = i - } - bDecoder := rowcodec.NewByteDecoder(cols, []int64{-1}, bdf, sc.TimeZone) - oldRow, err := bDecoder.DecodeToBytes(colOffset, kv.IntHandle(-1), newRow, nil) - c.Assert(err, IsNil) - for i, t := range testData { - remain, d, err := codec.DecodeOne(oldRow[i]) - c.Assert(err, IsNil) - c.Assert(len(remain), Equals, 0) - if d.Kind() == types.KindMysqlDecimal { - c.Assert(d.GetMysqlDecimal(), DeepEquals, t.bt.GetMysqlDecimal()) - } else { - c.Assert(d, DeepEquals, t.bt) + // decode to old row bytes. 
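+ // the old-format bytes are decoded column by column with codec.DecodeOne and
+ // checked against the expected output datums (decimals and raw bytes by value).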
+ colOffset := make(map[int64]int) + for i, t := range td { + colOffset[t.id] = i } - } + bDecoder := rowcodec.NewByteDecoder(cols, []int64{-1}, nil, nil) + oldRow, err := bDecoder.DecodeToBytes(colOffset, kv.IntHandle(-1), newRow, nil) + require.NoError(t, err) + + for i, d := range td { + remain, dat, err := codec.DecodeOne(oldRow[i]) + require.NoError(t, err) + require.Len(t, remain, 0) + if dat.Kind() == types.KindMysqlDecimal { + require.Equal(t, d.output.GetMysqlDecimal(), dat.GetMysqlDecimal()) + } else if dat.Kind() == types.KindBytes { + require.Equal(t, d.output.GetBytes(), dat.GetBytes()) + } else { + require.Equal(t, d.output, dat) + } + } + }) } - dtNilData := []testData{ +} + +func TestNilAndDefault(t *testing.T) { + t.Parallel() + + td := []testData{ { 1, types.NewFieldType(mysql.TypeLonglong), @@ -658,48 +583,122 @@ func (s *testSuite) TestNilAndDefault(c *C) { false, }, } - encodeAndDecode(c, dtNilData) -} -func (s *testSuite) TestVarintCompatibility(c *C) { - encodeAndDecodeByte := func(c *C, testData []testData) { - // transform test data into input. - colIDs := make([]int64, 0, len(testData)) - dts := make([]types.Datum, 0, len(testData)) - cols := make([]rowcodec.ColInfo, 0, len(testData)) - for i := range testData { - t := testData[i] - colIDs = append(colIDs, t.id) - dts = append(dts, t.dt) - cols = append(cols, rowcodec.ColInfo{ - ID: t.id, - IsPKHandle: t.handle, - Ft: t.ft, - }) + // transform test data into input. + colIDs := make([]int64, 0, len(td)) + dts := make([]types.Datum, 0, len(td)) + cols := make([]rowcodec.ColInfo, 0, len(td)) + fts := make([]*types.FieldType, 0, len(td)) + for i := range td { + d := td[i] + if d.def == nil { + colIDs = append(colIDs, d.id) + dts = append(dts, d.input) } + fts = append(fts, d.ft) + cols = append(cols, rowcodec.ColInfo{ + ID: d.id, + IsPKHandle: d.handle, + Ft: d.ft, + }) + } + ddf := func(i int, chk *chunk.Chunk) error { + d := td[i] + if d.def == nil { + chk.AppendNull(i) + return nil + } + chk.AppendDatum(i, d.def) + return nil + } + bdf := func(i int) ([]byte, error) { + d := td[i] + if d.def == nil { + return nil, nil + } + return getOldDatumByte(*d.def), nil + } - // test encode input. - var encoder rowcodec.Encoder - sc := new(stmtctx.StatementContext) - sc.TimeZone = time.UTC - newRow, err := encoder.Encode(sc, colIDs, dts, nil) - c.Assert(err, IsNil) - decoder := rowcodec.NewByteDecoder(cols, []int64{-1}, nil, sc.TimeZone) - // decode to old row bytes. - colOffset := make(map[int64]int) - for i, t := range testData { - colOffset[t.id] = i + // test encode input. + var encoder rowcodec.Encoder + sc := new(stmtctx.StatementContext) + sc.TimeZone = time.UTC + newRow, err := encoder.Encode(sc, colIDs, dts, nil) + require.NoError(t, err) + + // decode to datum map. + mDecoder := rowcodec.NewDatumMapDecoder(cols, sc.TimeZone) + dm, err := mDecoder.DecodeToDatumMap(newRow, nil) + require.NoError(t, err) + + for _, d := range td { + dat, exists := dm[d.id] + if d.def != nil { + // for datum should not fill default value. + require.False(t, exists) + } else { + require.True(t, exists) + require.Equal(t, d.output, dat) } - oldRow, err := decoder.DecodeToBytes(colOffset, kv.IntHandle(1), newRow, nil) - c.Assert(err, IsNil) - for i, t := range testData { - oldVarint, err := tablecodec.EncodeValue(nil, nil, t.bt) // tablecodec will encode as varint/varuint - c.Assert(err, IsNil) - c.Assert(oldVarint, DeepEquals, oldRow[i]) + } + + // decode to chunk. 
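+ // first with ddf installed: columns skipped at encode time (d.def != nil) are
+ // filled from their default value; a second pass without ddf expects NULLs instead.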
+ chk := chunk.New(fts, 1, 1) + cDecoder := rowcodec.NewChunkDecoder(cols, []int64{-1}, ddf, sc.TimeZone) + err = cDecoder.DecodeToChunk(newRow, kv.IntHandle(-1), chk) + require.NoError(t, err) + + chkRow := chk.GetRow(0) + cdt := chkRow.GetDatumRow(fts) + for i, d := range td { + dat := cdt[i] + if dat.Kind() == types.KindMysqlDecimal { + require.Equal(t, d.output.GetMysqlDecimal(), dat.GetMysqlDecimal()) + } else { + require.Equal(t, d.output, dat) } } - testDataValue := []testData{ + chk = chunk.New(fts, 1, 1) + cDecoder = rowcodec.NewChunkDecoder(cols, []int64{-1}, nil, sc.TimeZone) + err = cDecoder.DecodeToChunk(newRow, kv.IntHandle(-1), chk) + require.NoError(t, err) + + chkRow = chk.GetRow(0) + cdt = chkRow.GetDatumRow(fts) + for i := range td { + if i == 0 { + continue + } + require.True(t, cdt[i].IsNull()) + } + + // decode to old row bytes. + colOffset := make(map[int64]int) + for i, t := range td { + colOffset[t.id] = i + } + bDecoder := rowcodec.NewByteDecoder(cols, []int64{-1}, bdf, sc.TimeZone) + oldRow, err := bDecoder.DecodeToBytes(colOffset, kv.IntHandle(-1), newRow, nil) + require.NoError(t, err) + + for i, d := range td { + remain, dat, err := codec.DecodeOne(oldRow[i]) + require.NoError(t, err) + require.Len(t, remain, 0) + + if dat.Kind() == types.KindMysqlDecimal { + require.Equal(t, d.output.GetMysqlDecimal(), dat.GetMysqlDecimal()) + } else { + require.Equal(t, d.output, dat) + } + } +} + +func TestVarintCompatibility(t *testing.T) { + t.Parallel() + + td := []testData{ { 1, types.NewFieldType(mysql.TypeLonglong), @@ -717,10 +716,47 @@ func (s *testSuite) TestVarintCompatibility(c *C) { false, }, } - encodeAndDecodeByte(c, testDataValue) + + // transform test data into input. + colIDs := make([]int64, 0, len(td)) + dts := make([]types.Datum, 0, len(td)) + cols := make([]rowcodec.ColInfo, 0, len(td)) + for _, d := range td { + colIDs = append(colIDs, d.id) + dts = append(dts, d.input) + cols = append(cols, rowcodec.ColInfo{ + ID: d.id, + IsPKHandle: d.handle, + Ft: d.ft, + }) + } + + // test encode input. + var encoder rowcodec.Encoder + sc := new(stmtctx.StatementContext) + sc.TimeZone = time.UTC + newRow, err := encoder.Encode(sc, colIDs, dts, nil) + require.NoError(t, err) + + decoder := rowcodec.NewByteDecoder(cols, []int64{-1}, nil, sc.TimeZone) + // decode to old row bytes. 
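+ // each decoded column is compared against tablecodec.EncodeValue below, which
+ // still emits varint/varuint, to prove the two encodings stay compatible.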
+ colOffset := make(map[int64]int) + for i, t := range td { + colOffset[t.id] = i + } + oldRow, err := decoder.DecodeToBytes(colOffset, kv.IntHandle(1), newRow, nil) + require.NoError(t, err) + + for i, d := range td { + oldVarint, err := tablecodec.EncodeValue(nil, nil, d.output) // tablecodec will encode as varint/varuint + require.NoError(t, err) + require.Equal(t, oldRow[i], oldVarint) + } } -func (s *testSuite) TestCodecUtil(c *C) { +func TestCodecUtil(t *testing.T) { + t.Parallel() + colIDs := []int64{1, 2, 3, 4} tps := make([]*types.FieldType, 4) for i := 0; i < 3; i++ { @@ -729,15 +765,16 @@ func (s *testSuite) TestCodecUtil(c *C) { tps[3] = types.NewFieldType(mysql.TypeNull) sc := new(stmtctx.StatementContext) oldRow, err := tablecodec.EncodeOldRow(sc, types.MakeDatums(1, 2, 3, nil), colIDs, nil, nil) - c.Check(err, IsNil) + require.NoError(t, err) + var ( rb rowcodec.Encoder newRow []byte ) newRow, err = rowcodec.EncodeFromOldRow(&rb, nil, oldRow, nil) - c.Assert(err, IsNil) - c.Assert(rowcodec.IsNewFormat(newRow), IsTrue) - c.Assert(rowcodec.IsNewFormat(oldRow), IsFalse) + require.NoError(t, err) + require.True(t, rowcodec.IsNewFormat(newRow)) + require.False(t, rowcodec.IsNewFormat(oldRow)) // test stringer for decoder. var cols = make([]rowcodec.ColInfo, 0, len(tps)) @@ -751,25 +788,27 @@ func (s *testSuite) TestCodecUtil(c *C) { d := rowcodec.NewDecoder(cols, []int64{-1}, nil) // test ColumnIsNull - isNil, err := d.ColumnIsNull(newRow, 4, nil) - c.Assert(err, IsNil) - c.Assert(isNil, IsTrue) - isNil, err = d.ColumnIsNull(newRow, 1, nil) - c.Assert(err, IsNil) - c.Assert(isNil, IsFalse) - isNil, err = d.ColumnIsNull(newRow, 5, nil) - c.Assert(err, IsNil) - c.Assert(isNil, IsTrue) - isNil, err = d.ColumnIsNull(newRow, 5, []byte{1}) - c.Assert(err, IsNil) - c.Assert(isNil, IsFalse) + isNull, err := d.ColumnIsNull(newRow, 4, nil) + require.NoError(t, err) + require.True(t, isNull) + isNull, err = d.ColumnIsNull(newRow, 1, nil) + require.NoError(t, err) + require.False(t, isNull) + isNull, err = d.ColumnIsNull(newRow, 5, nil) + require.NoError(t, err) + require.True(t, isNull) + isNull, err = d.ColumnIsNull(newRow, 5, []byte{1}) + require.NoError(t, err) + require.False(t, isNull) // test isRowKey - c.Assert(rowcodec.IsRowKey([]byte{'b', 't'}), IsFalse) - c.Assert(rowcodec.IsRowKey([]byte{'t', 'r'}), IsFalse) + require.False(t, rowcodec.IsRowKey([]byte{'b', 't'})) + require.False(t, rowcodec.IsRowKey([]byte{'t', 'r'})) } -func (s *testSuite) TestOldRowCodec(c *C) { +func TestOldRowCodec(t *testing.T) { + t.Parallel() + colIDs := []int64{1, 2, 3, 4} tps := make([]*types.FieldType, 4) for i := 0; i < 3; i++ { @@ -778,14 +817,15 @@ func (s *testSuite) TestOldRowCodec(c *C) { tps[3] = types.NewFieldType(mysql.TypeNull) sc := new(stmtctx.StatementContext) oldRow, err := tablecodec.EncodeOldRow(sc, types.MakeDatums(1, 2, 3, nil), colIDs, nil, nil) - c.Check(err, IsNil) + require.NoError(t, err) var ( rb rowcodec.Encoder newRow []byte ) newRow, err = rowcodec.EncodeFromOldRow(&rb, nil, oldRow, nil) - c.Check(err, IsNil) + require.NoError(t, err) + cols := make([]rowcodec.ColInfo, len(tps)) for i, tp := range tps { cols[i] = rowcodec.ColInfo{ @@ -796,14 +836,16 @@ func (s *testSuite) TestOldRowCodec(c *C) { rd := rowcodec.NewChunkDecoder(cols, []int64{-1}, nil, time.Local) chk := chunk.NewChunkWithCapacity(tps, 1) err = rd.DecodeToChunk(newRow, kv.IntHandle(-1), chk) - c.Assert(err, IsNil) + require.NoError(t, err) row := chk.GetRow(0) for i := 0; i < 3; i++ { - c.Assert(row.GetInt64(i), 
Equals, int64(i)+1) + require.Equal(t, int64(i+1), row.GetInt64(i)) } } -func (s *testSuite) Test65535Bug(c *C) { +func Test65535Bug(t *testing.T) { + t.Parallel() + colIds := []int64{1} tps := make([]*types.FieldType, 1) tps[0] = types.NewFieldType(mysql.TypeString) @@ -811,7 +853,7 @@ func (s *testSuite) Test65535Bug(c *C) { text65535 := strings.Repeat("a", 65535) encode := rowcodec.Encoder{} bd, err := encode.Encode(sc, colIds, []types.Datum{types.NewStringDatum(text65535)}, nil) - c.Check(err, IsNil) + require.NoError(t, err) cols := make([]rowcodec.ColInfo, 1) cols[0] = rowcodec.ColInfo{ @@ -820,9 +862,10 @@ func (s *testSuite) Test65535Bug(c *C) { } dc := rowcodec.NewDatumMapDecoder(cols, nil) result, err := dc.DecodeToDatumMap(bd, nil) - c.Check(err, IsNil) + require.NoError(t, err) + rs := result[1] - c.Check(rs.GetString(), Equals, text65535) + require.Equal(t, text65535, rs.GetString()) } var (