diff --git a/cmd/explaintest/main.go b/cmd/explaintest/main.go index fa5265f7af871..a85c8ce82dd3c 100644 --- a/cmd/explaintest/main.go +++ b/cmd/explaintest/main.go @@ -30,7 +30,6 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/log" "github.com/pingcap/parser/ast" - "github.com/pingcap/tidb/domain" "github.com/pingcap/tidb/session" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/util/logutil" @@ -663,8 +662,6 @@ func main() { log.Fatal(fmt.Sprintf("%s failed", sql), zap.Error(err)) } } - // Wait global variables to reload. - time.Sleep(domain.GlobalVariableCacheExpiry) if _, err = mdb.Exec("set sql_mode='STRICT_TRANS_TABLES'"); err != nil { log.Fatal("set sql_mode='STRICT_TRANS_TABLES' failed", zap.Error(err)) diff --git a/ddl/column_change_test.go b/ddl/column_change_test.go index 94e8787a2bdc4..6bd5a94f7235e 100644 --- a/ddl/column_change_test.go +++ b/ddl/column_change_test.go @@ -47,15 +47,18 @@ type testColumnChangeSuite struct { func (s *testColumnChangeSuite) SetUpSuite(c *C) { SetWaitTimeWhenErrorOccurred(1 * time.Microsecond) s.store = testCreateStore(c, "test_column_change") - s.dbInfo = &model.DBInfo{ - Name: model.NewCIStr("test_column_change"), - ID: 1, - } - err := kv.RunInNewTxn(context.Background(), s.store, true, func(ctx context.Context, txn kv.Transaction) error { - t := meta.NewMeta(txn) - return errors.Trace(t.CreateDatabase(s.dbInfo)) - }) - c.Check(err, IsNil) + d := testNewDDLAndStart( + context.Background(), + c, + WithStore(s.store), + WithLease(testLease), + ) + defer func() { + err := d.Stop() + c.Assert(err, IsNil) + }() + s.dbInfo = testSchemaInfo(c, d, "test_column_change") + testCreateSchema(c, testNewContext(d), d, s.dbInfo) } func (s *testColumnChangeSuite) TearDownSuite(c *C) { diff --git a/ddl/column_test.go b/ddl/column_test.go index 862fb4aa04c59..f3eaa26d22385 100644 --- a/ddl/column_test.go +++ b/ddl/column_test.go @@ -54,8 +54,7 @@ func (s *testColumnSuite) SetUpSuite(c *C) { s.dbInfo = testSchemaInfo(c, d, "test_column") testCreateSchema(c, testNewContext(d), d, s.dbInfo) - err := d.Stop() - c.Assert(err, IsNil) + c.Assert(d.Stop(), IsNil) } func (s *testColumnSuite) TearDownSuite(c *C) { diff --git a/ddl/db_change_test.go b/ddl/db_change_test.go index 7c3a0f9ad970f..041f35c7734a8 100644 --- a/ddl/db_change_test.go +++ b/ddl/db_change_test.go @@ -1063,7 +1063,6 @@ func (s *testStateChangeSuite) TestParallelAddGeneratedColumnAndAlterModifyColum _, err = s.se.Execute(context.Background(), "set global tidb_enable_change_column_type = 0") c.Assert(err, IsNil) }() - domain.GetDomain(s.se).GetGlobalVarsCache().Disable() sql1 := "ALTER TABLE t ADD COLUMN f INT GENERATED ALWAYS AS(a+1);" sql2 := "ALTER TABLE t MODIFY COLUMN a tinyint;" @@ -1083,7 +1082,6 @@ func (s *testStateChangeSuite) TestParallelAlterModifyColumnAndAddPK(c *C) { _, err = s.se.Execute(context.Background(), "set global tidb_enable_change_column_type = 0") c.Assert(err, IsNil) }() - domain.GetDomain(s.se).GetGlobalVarsCache().Disable() sql1 := "ALTER TABLE t ADD PRIMARY KEY (b) NONCLUSTERED;" sql2 := "ALTER TABLE t MODIFY COLUMN b tinyint;" diff --git a/ddl/ddl.go b/ddl/ddl.go index 6f20fe25ccc07..9eb05b86741ed 100644 --- a/ddl/ddl.go +++ b/ddl/ddl.go @@ -202,7 +202,7 @@ type ddlCtx struct { ddlEventCh chan<- *util.Event lease time.Duration // lease is schema lease. binlogCli *pumpcli.PumpsClient // binlogCli is used for Binlog.
- infoHandle *infoschema.Handle + infoCache *infoschema.InfoCache statsHandle *handle.Handle tableLockCkr util.DeadTableLockChecker etcdCli *clientv3.Client @@ -282,6 +291,15 @@ func newDDL(ctx context.Context, options ...Option) *ddl { deadLockCkr = util.NewDeadTableLockChecker(etcdCli) } + // TODO: make store and infoCache explicit arguments; + // both must be non-nil. + if opt.Store == nil { + panic("store should not be nil") + } + if opt.InfoCache == nil { + panic("infoCache should not be nil") + } + ddlCtx := &ddlCtx{ uuid: id, store: opt.Store, @@ -290,7 +299,7 @@ func newDDL(ctx context.Context, options ...Option) *ddl { ownerManager: manager, schemaSyncer: syncer, binlogCli: binloginfo.GetPumpsClient(), - infoHandle: opt.InfoHandle, + infoCache: opt.InfoCache, tableLockCkr: deadLockCkr, etcdCli: opt.EtcdCli, } @@ -411,7 +420,7 @@ func (d *ddl) GetLease() time.Duration { // Please don't use this function; it is only used by TestParallelDDLBeforeRunDDLJob to intercept the call to d.infoCache.GetLatest(). Use d.infoCache.GetLatest() instead. // Otherwise, TestParallelDDLBeforeRunDDLJob will hang forever. func (d *ddl) GetInfoSchemaWithInterceptor(ctx sessionctx.Context) infoschema.InfoSchema { - is := d.infoHandle.Get() + is := d.infoCache.GetLatest() d.mu.RLock() defer d.mu.RUnlock() @@ -649,10 +658,7 @@ func (d *ddl) startCleanDeadTableLock() { if !d.ownerManager.IsOwner() { continue } - if d.infoHandle == nil || !d.infoHandle.IsValid() { - continue - } - deadLockTables, err := d.tableLockCkr.GetDeadLockedTables(d.ctx, d.infoHandle.Get().AllSchemas()) + deadLockTables, err := d.tableLockCkr.GetDeadLockedTables(d.ctx, d.infoCache.GetLatest().AllSchemas()) if err != nil { logutil.BgLogger().Info("[ddl] get dead table lock failed.", zap.Error(err)) continue diff --git a/ddl/ddl_api.go b/ddl/ddl_api.go index a3f8bb7f9c622..d0289dc19e39f 100644 --- a/ddl/ddl_api.go +++ b/ddl/ddl_api.go @@ -2367,7 +2367,7 @@ func (d *ddl) AlterTable(ctx sessionctx.Context, ident ast.Ident, specs []*ast.A return errors.Trace(err) } - is := d.infoHandle.Get() + is := d.infoCache.GetLatest() if is.TableIsView(ident.Schema, ident.Name) || is.TableIsSequence(ident.Schema, ident.Name) { return ErrWrongObject.GenWithStackByArgs(ident.Schema, ident.Name, "BASE TABLE") } @@ -2898,7 +2898,7 @@ func (d *ddl) AddColumns(ctx sessionctx.Context, ti ast.Ident, specs []*ast.Alte // AddTablePartitions will add a new partition to the table. func (d *ddl) AddTablePartitions(ctx sessionctx.Context, ident ast.Ident, spec *ast.AlterTableSpec) error { - is := d.infoHandle.Get() + is := d.infoCache.GetLatest() schema, ok := is.SchemaByName(ident.Schema) if !ok { return errors.Trace(infoschema.ErrDatabaseNotExists.GenWithStackByArgs(schema)) } @@ -2959,7 +2959,7 @@ func (d *ddl) AddTablePartitions(ctx sessionctx.Context, ident ast.Ident, spec * // CoalescePartitions can be used with a table that is partitioned by hash or key to reduce the number of partitions by the given number.
func (d *ddl) CoalescePartitions(ctx sessionctx.Context, ident ast.Ident, spec *ast.AlterTableSpec) error { - is := d.infoHandle.Get() + is := d.infoCache.GetLatest() schema, ok := is.SchemaByName(ident.Schema) if !ok { return errors.Trace(infoschema.ErrDatabaseNotExists.GenWithStackByArgs(schema)) @@ -2991,7 +2991,7 @@ func (d *ddl) CoalescePartitions(ctx sessionctx.Context, ident ast.Ident, spec * } func (d *ddl) TruncateTablePartition(ctx sessionctx.Context, ident ast.Ident, spec *ast.AlterTableSpec) error { - is := d.infoHandle.Get() + is := d.infoCache.GetLatest() schema, ok := is.SchemaByName(ident.Schema) if !ok { return errors.Trace(infoschema.ErrDatabaseNotExists.GenWithStackByArgs(schema)) @@ -3039,7 +3039,7 @@ func (d *ddl) TruncateTablePartition(ctx sessionctx.Context, ident ast.Ident, sp } func (d *ddl) DropTablePartition(ctx sessionctx.Context, ident ast.Ident, spec *ast.AlterTableSpec) error { - is := d.infoHandle.Get() + is := d.infoCache.GetLatest() schema, ok := is.SchemaByName(ident.Schema) if !ok { return errors.Trace(infoschema.ErrDatabaseNotExists.GenWithStackByArgs(schema)) @@ -3752,7 +3752,7 @@ func processAndCheckDefaultValueAndColumn(ctx sessionctx.Context, col *table.Col func (d *ddl) getModifiableColumnJob(ctx sessionctx.Context, ident ast.Ident, originalColName model.CIStr, spec *ast.AlterTableSpec) (*model.Job, error) { specNewColumn := spec.NewColumns[0] - is := d.infoHandle.Get() + is := d.infoCache.GetLatest() schema, ok := is.SchemaByName(ident.Schema) if !ok { return nil, errors.Trace(infoschema.ErrDatabaseNotExists) @@ -4203,7 +4203,7 @@ func (d *ddl) ModifyColumn(ctx sessionctx.Context, ident ast.Ident, spec *ast.Al func (d *ddl) AlterColumn(ctx sessionctx.Context, ident ast.Ident, spec *ast.AlterTableSpec) error { specNewColumn := spec.NewColumns[0] - is := d.infoHandle.Get() + is := d.infoCache.GetLatest() schema, ok := is.SchemaByName(ident.Schema) if !ok { return infoschema.ErrTableNotExists.GenWithStackByArgs(ident.Schema, ident.Name) @@ -4257,7 +4257,7 @@ func (d *ddl) AlterColumn(ctx sessionctx.Context, ident ast.Ident, spec *ast.Alt // AlterTableComment updates the table comment information. func (d *ddl) AlterTableComment(ctx sessionctx.Context, ident ast.Ident, spec *ast.AlterTableSpec) error { - is := d.infoHandle.Get() + is := d.infoCache.GetLatest() schema, ok := is.SchemaByName(ident.Schema) if !ok { return infoschema.ErrDatabaseNotExists.GenWithStackByArgs(ident.Schema) @@ -4310,7 +4310,7 @@ func (d *ddl) AlterTableCharsetAndCollate(ctx sessionctx.Context, ident ast.Iden return ErrUnknownCharacterSet.GenWithStackByArgs(toCharset) } - is := d.infoHandle.Get() + is := d.infoCache.GetLatest() schema, ok := is.SchemaByName(ident.Schema) if !ok { return infoschema.ErrDatabaseNotExists.GenWithStackByArgs(ident.Schema) @@ -4471,7 +4471,7 @@ func (d *ddl) AlterTableDropStatistics(ctx sessionctx.Context, ident ast.Ident, // UpdateTableReplicaInfo updates the table flash replica infos. 
func (d *ddl) UpdateTableReplicaInfo(ctx sessionctx.Context, physicalID int64, available bool) error { - is := d.infoHandle.Get() + is := d.infoCache.GetLatest() tb, ok := is.TableByID(physicalID) if !ok { tb, _, _ = is.FindTableByPartitionID(physicalID) @@ -4574,7 +4574,7 @@ func checkAlterTableCharset(tblInfo *model.TableInfo, dbInfo *model.DBInfo, toCh // In TiDB, indexes are case-insensitive (so index 'a' and 'A' are considered the same index), // but index names are case-sensitive (we can rename index 'a' to 'A') func (d *ddl) RenameIndex(ctx sessionctx.Context, ident ast.Ident, spec *ast.AlterTableSpec) error { - is := d.infoHandle.Get() + is := d.infoCache.GetLatest() schema, ok := is.SchemaByName(ident.Schema) if !ok { return infoschema.ErrDatabaseNotExists.GenWithStackByArgs(ident.Schema) @@ -5232,7 +5232,7 @@ func buildFKInfo(fkName model.CIStr, keys []*ast.IndexPartSpecification, refer * } func (d *ddl) CreateForeignKey(ctx sessionctx.Context, ti ast.Ident, fkName model.CIStr, keys []*ast.IndexPartSpecification, refer *ast.ReferenceDef) error { - is := d.infoHandle.Get() + is := d.infoCache.GetLatest() schema, ok := is.SchemaByName(ti.Schema) if !ok { return infoschema.ErrDatabaseNotExists.GenWithStackByArgs(ti.Schema) @@ -5264,7 +5264,7 @@ func (d *ddl) CreateForeignKey(ctx sessionctx.Context, ti ast.Ident, fkName mode } func (d *ddl) DropForeignKey(ctx sessionctx.Context, ti ast.Ident, fkName model.CIStr) error { - is := d.infoHandle.Get() + is := d.infoCache.GetLatest() schema, ok := is.SchemaByName(ti.Schema) if !ok { return infoschema.ErrDatabaseNotExists.GenWithStackByArgs(ti.Schema) @@ -5290,7 +5290,7 @@ func (d *ddl) DropForeignKey(ctx sessionctx.Context, ti ast.Ident, fkName model. } func (d *ddl) DropIndex(ctx sessionctx.Context, ti ast.Ident, indexName model.CIStr, ifExists bool) error { - is := d.infoHandle.Get() + is := d.infoCache.GetLatest() schema, ok := is.SchemaByName(ti.Schema) if !ok { return errors.Trace(infoschema.ErrDatabaseNotExists) @@ -6036,7 +6036,7 @@ func (d *ddl) AlterTableAlterPartition(ctx sessionctx.Context, ident ast.Ident, return errors.Trace(err) } - oldBundle := infoschema.GetBundle(d.infoHandle.Get(), []int64{partitionID, meta.ID, schema.ID}) + oldBundle := infoschema.GetBundle(d.infoCache.GetLatest(), []int64{partitionID, meta.ID, schema.ID}) oldBundle.ID = placement.GroupID(partitionID) diff --git a/ddl/ddl_test.go b/ddl/ddl_test.go index b77c3300c2700..79635bfc0933b 100644 --- a/ddl/ddl_test.go +++ b/ddl/ddl_test.go @@ -24,6 +24,7 @@ import ( "github.com/pingcap/parser/model" "github.com/pingcap/tidb/config" "github.com/pingcap/tidb/domain/infosync" + "github.com/pingcap/tidb/infoschema" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/meta" "github.com/pingcap/tidb/meta/autoid" @@ -86,6 +87,10 @@ func TestT(t *testing.T) { } func testNewDDLAndStart(ctx context.Context, c *C, options ...Option) *ddl { + // init infoCache and a stub infoSchema + ic := infoschema.NewCache(2) + ic.Insert(infoschema.MockInfoSchemaWithSchemaVer(nil, 0)) + options = append(options, WithInfoCache(ic)) d := newDDL(ctx, options...) err := d.Start(nil) c.Assert(err, IsNil) diff --git a/ddl/ddl_worker_test.go b/ddl/ddl_worker_test.go index 6e745820b04b9..72fef7c96c19e 100644 --- a/ddl/ddl_worker_test.go +++ b/ddl/ddl_worker_test.go @@ -247,7 +247,7 @@ func (s *testDDLSuite) TestTableError(c *C) { // Schema ID is wrong, so dropping table is failed. doDDLJobErr(c, -1, 1, model.ActionDropTable, nil, ctx, d) // Table ID is wrong, so dropping table is failed.
- dbInfo := testSchemaInfo(c, d, "test") + dbInfo := testSchemaInfo(c, d, "test_ddl") testCreateSchema(c, testNewContext(d), d, dbInfo) job := doDDLJobErr(c, dbInfo.ID, -1, model.ActionDropTable, nil, ctx, d) @@ -295,7 +295,7 @@ func (s *testDDLSuite) TestViewError(c *C) { c.Assert(err, IsNil) }() ctx := testNewContext(d) - dbInfo := testSchemaInfo(c, d, "test") + dbInfo := testSchemaInfo(c, d, "test_ddl") testCreateSchema(c, testNewContext(d), d, dbInfo) // Table ID or schema ID is wrong, so getting table is failed. @@ -363,7 +363,7 @@ func (s *testDDLSuite) TestForeignKeyError(c *C) { doDDLJobErr(c, -1, 1, model.ActionAddForeignKey, nil, ctx, d) doDDLJobErr(c, -1, 1, model.ActionDropForeignKey, nil, ctx, d) - dbInfo := testSchemaInfo(c, d, "test") + dbInfo := testSchemaInfo(c, d, "test_ddl") tblInfo := testTableInfo(c, d, "t", 3) testCreateSchema(c, ctx, d, dbInfo) testCreateTable(c, ctx, d, dbInfo, tblInfo) @@ -393,7 +393,7 @@ func (s *testDDLSuite) TestIndexError(c *C) { doDDLJobErr(c, -1, 1, model.ActionAddIndex, nil, ctx, d) doDDLJobErr(c, -1, 1, model.ActionDropIndex, nil, ctx, d) - dbInfo := testSchemaInfo(c, d, "test") + dbInfo := testSchemaInfo(c, d, "test_ddl") tblInfo := testTableInfo(c, d, "t", 3) testCreateSchema(c, ctx, d, dbInfo) testCreateTable(c, ctx, d, dbInfo, tblInfo) @@ -435,7 +435,7 @@ func (s *testDDLSuite) TestColumnError(c *C) { }() ctx := testNewContext(d) - dbInfo := testSchemaInfo(c, d, "test") + dbInfo := testSchemaInfo(c, d, "test_ddl") tblInfo := testTableInfo(c, d, "t", 3) testCreateSchema(c, ctx, d, dbInfo) testCreateTable(c, ctx, d, dbInfo, tblInfo) diff --git a/ddl/index_change_test.go b/ddl/index_change_test.go index dfdfc7111c372..6a34599137c10 100644 --- a/ddl/index_change_test.go +++ b/ddl/index_change_test.go @@ -22,7 +22,6 @@ import ( "github.com/pingcap/parser/model" "github.com/pingcap/parser/mysql" "github.com/pingcap/tidb/kv" - "github.com/pingcap/tidb/meta" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/table" "github.com/pingcap/tidb/table/tables" @@ -38,15 +37,18 @@ type testIndexChangeSuite struct { func (s *testIndexChangeSuite) SetUpSuite(c *C) { s.store = testCreateStore(c, "test_index_change") - s.dbInfo = &model.DBInfo{ - Name: model.NewCIStr("test_index_change"), - ID: 1, - } - err := kv.RunInNewTxn(context.Background(), s.store, true, func(ctx context.Context, txn kv.Transaction) error { - t := meta.NewMeta(txn) - return errors.Trace(t.CreateDatabase(s.dbInfo)) - }) - c.Check(err, IsNil, Commentf("err %v", errors.ErrorStack(err))) + d := testNewDDLAndStart( + context.Background(), + c, + WithStore(s.store), + WithLease(testLease), + ) + defer func() { + err := d.Stop() + c.Assert(err, IsNil) + }() + s.dbInfo = testSchemaInfo(c, d, "test_index_change") + testCreateSchema(c, testNewContext(d), d, s.dbInfo) } func (s *testIndexChangeSuite) TearDownSuite(c *C) { diff --git a/ddl/options.go b/ddl/options.go index 8613a8e9affa9..9238a7c8542ff 100644 --- a/ddl/options.go +++ b/ddl/options.go @@ -26,11 +26,11 @@ type Option func(*Options) // Options represents all the options of the DDL module needs type Options struct { - EtcdCli *clientv3.Client - Store kv.Storage - InfoHandle *infoschema.Handle - Hook Callback - Lease time.Duration + EtcdCli *clientv3.Client + Store kv.Storage + InfoCache *infoschema.InfoCache + Hook Callback + Lease time.Duration } // WithEtcdClient specifies the `clientv3.Client` of DDL used to request the etcd service @@ -47,10 +47,10 @@ func WithStore(store kv.Storage) Option { } } -// WithInfoHandle 
specifies the `infoschema.Handle` -func WithInfoHandle(ih *infoschema.Handle) Option { +// WithInfoCache specifies the `infoschema.InfoCache` +func WithInfoCache(ic *infoschema.InfoCache) Option { return func(options *Options) { - options.InfoHandle = ih + options.InfoCache = ic } } diff --git a/ddl/options_test.go b/ddl/options_test.go index 294d68731e4c3..22a451d622c71 100644 --- a/ddl/options_test.go +++ b/ddl/options_test.go @@ -33,14 +33,14 @@ func (s *ddlOptionsSuite) TestOptions(c *C) { callback := &ddl.BaseCallback{} lease := time.Second * 3 store := &mock.Store{} - infoHandle := infoschema.NewHandle(store) + infoHandle := infoschema.NewCache(16) options := []ddl.Option{ ddl.WithEtcdClient(client), ddl.WithHook(callback), ddl.WithLease(lease), ddl.WithStore(store), - ddl.WithInfoHandle(infoHandle), + ddl.WithInfoCache(infoHandle), } opt := &ddl.Options{} @@ -52,5 +52,5 @@ func (s *ddlOptionsSuite) TestOptions(c *C) { c.Assert(opt.Hook, Equals, callback) c.Assert(opt.Lease, Equals, lease) c.Assert(opt.Store, Equals, store) - c.Assert(opt.InfoHandle, Equals, infoHandle) + c.Assert(opt.InfoCache, Equals, infoHandle) } diff --git a/ddl/partition.go b/ddl/partition.go index 0cafa9d2ff525..4e55ec1779e21 100644 --- a/ddl/partition.go +++ b/ddl/partition.go @@ -911,18 +911,15 @@ func getTableInfoWithDroppingPartitions(t *model.TableInfo) *model.TableInfo { } func dropRuleBundles(d *ddlCtx, physicalTableIDs []int64) error { - if d.infoHandle != nil && d.infoHandle.IsValid() { - bundles := make([]*placement.Bundle, 0, len(physicalTableIDs)) - for _, ID := range physicalTableIDs { - oldBundle, ok := d.infoHandle.Get().BundleByName(placement.GroupID(ID)) - if ok && !oldBundle.IsEmpty() { - bundles = append(bundles, placement.BuildPlacementDropBundle(ID)) - } + bundles := make([]*placement.Bundle, 0, len(physicalTableIDs)) + for _, ID := range physicalTableIDs { + oldBundle, ok := d.infoCache.GetLatest().BundleByName(placement.GroupID(ID)) + if ok && !oldBundle.IsEmpty() { + bundles = append(bundles, placement.BuildPlacementDropBundle(ID)) } - err := infosync.PutRuleBundles(context.TODO(), bundles) - return err } - return nil + err := infosync.PutRuleBundles(context.TODO(), bundles) + return err } // onDropTablePartition deletes old partition meta. 
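// A minimal sketch (not part of the patch) of the infoschema.InfoCache contract
// that the hunks above and below rely on. NewCache, Insert, GetLatest and
// GetByVersion are taken from this diff; the standalone setup mirrors
// testNewDDLAndStart and is illustrative only:
//
//	ic := infoschema.NewCache(2)                              // keep the two most recent schema versions
//	ic.Insert(infoschema.MockInfoSchemaWithSchemaVer(nil, 0)) // seed a stub schema at version 0
//	is := ic.GetLatest()                                      // replaces infoHandle.Get(); non-nil once seeded,
//	_ = is.SchemaMetaVersion()                                // which is why the IsValid() guards can be dropped
//	if hit := ic.GetByVersion(0); hit != nil {
//		// cache hit: reuse the snapshot instead of reloading from the store,
//		// the fast path taken by Domain.loadInfoSchema
//	}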
@@ -1095,22 +1092,20 @@ func onTruncateTablePartition(d *ddlCtx, t *meta.Meta, job *model.Job) (int64, e } } - if d.infoHandle != nil && d.infoHandle.IsValid() { - bundles := make([]*placement.Bundle, 0, len(oldIDs)) + bundles := make([]*placement.Bundle, 0, len(oldIDs)) - for i, oldID := range oldIDs { - oldBundle, ok := d.infoHandle.Get().BundleByName(placement.GroupID(oldID)) - if ok && !oldBundle.IsEmpty() { - bundles = append(bundles, placement.BuildPlacementDropBundle(oldID)) - bundles = append(bundles, placement.BuildPlacementCopyBundle(oldBundle, newPartitions[i].ID)) - } + for i, oldID := range oldIDs { + oldBundle, ok := d.infoCache.GetLatest().BundleByName(placement.GroupID(oldID)) + if ok && !oldBundle.IsEmpty() { + bundles = append(bundles, placement.BuildPlacementDropBundle(oldID)) + bundles = append(bundles, placement.BuildPlacementCopyBundle(oldBundle, newPartitions[i].ID)) } + } - err = infosync.PutRuleBundles(context.TODO(), bundles) - if err != nil { - job.State = model.JobStateCancelled - return ver, errors.Wrapf(err, "failed to notify PD the placement rules") - } + err = infosync.PutRuleBundles(context.TODO(), bundles) + if err != nil { + job.State = model.JobStateCancelled + return ver, errors.Wrapf(err, "failed to notify PD the placement rules") } newIDs := make([]int64, len(oldIDs)) @@ -1299,27 +1294,25 @@ func (w *worker) onExchangeTablePartition(d *ddlCtx, t *meta.Meta, job *model.Jo // the follow code is a swap function for rules of two partitions // though partitions has exchanged their ID, swap still take effect - if d.infoHandle != nil && d.infoHandle.IsValid() { - bundles := make([]*placement.Bundle, 0, 2) - ptBundle, ptOK := d.infoHandle.Get().BundleByName(placement.GroupID(partDef.ID)) - ptOK = ptOK && !ptBundle.IsEmpty() - ntBundle, ntOK := d.infoHandle.Get().BundleByName(placement.GroupID(nt.ID)) - ntOK = ntOK && !ntBundle.IsEmpty() - if ptOK && ntOK { - bundles = append(bundles, placement.BuildPlacementCopyBundle(ptBundle, nt.ID)) - bundles = append(bundles, placement.BuildPlacementCopyBundle(ntBundle, partDef.ID)) - } else if ptOK { - bundles = append(bundles, placement.BuildPlacementDropBundle(partDef.ID)) - bundles = append(bundles, placement.BuildPlacementCopyBundle(ptBundle, nt.ID)) - } else if ntOK { - bundles = append(bundles, placement.BuildPlacementDropBundle(nt.ID)) - bundles = append(bundles, placement.BuildPlacementCopyBundle(ntBundle, partDef.ID)) - } - err = infosync.PutRuleBundles(context.TODO(), bundles) - if err != nil { - job.State = model.JobStateCancelled - return ver, errors.Wrapf(err, "failed to notify PD the placement rules") - } + bundles := make([]*placement.Bundle, 0, 2) + ptBundle, ptOK := d.infoCache.GetLatest().BundleByName(placement.GroupID(partDef.ID)) + ptOK = ptOK && !ptBundle.IsEmpty() + ntBundle, ntOK := d.infoCache.GetLatest().BundleByName(placement.GroupID(nt.ID)) + ntOK = ntOK && !ntBundle.IsEmpty() + if ptOK && ntOK { + bundles = append(bundles, placement.BuildPlacementCopyBundle(ptBundle, nt.ID)) + bundles = append(bundles, placement.BuildPlacementCopyBundle(ntBundle, partDef.ID)) + } else if ptOK { + bundles = append(bundles, placement.BuildPlacementDropBundle(partDef.ID)) + bundles = append(bundles, placement.BuildPlacementCopyBundle(ptBundle, nt.ID)) + } else if ntOK { + bundles = append(bundles, placement.BuildPlacementDropBundle(nt.ID)) + bundles = append(bundles, placement.BuildPlacementCopyBundle(ntBundle, partDef.ID)) + } + err = infosync.PutRuleBundles(context.TODO(), bundles) + if err != nil { + job.State 
= model.JobStateCancelled + return ver, errors.Wrapf(err, "failed to notify PD the placement rules") } ver, err = updateSchemaVersion(t, job) diff --git a/ddl/reorg_test.go b/ddl/reorg_test.go index 18dd9a975fceb..4c28540e7ad3b 100644 --- a/ddl/reorg_test.go +++ b/ddl/reorg_test.go @@ -217,7 +217,7 @@ func (s *testDDLSuite) TestReorgOwner(c *C) { c.Assert(err, IsNil) }() - dbInfo := testSchemaInfo(c, d1, "test") + dbInfo := testSchemaInfo(c, d1, "test_reorg") testCreateSchema(c, ctx, d1, dbInfo) tblInfo := testTableInfo(c, d1, "t", 3) diff --git a/ddl/restart_test.go b/ddl/restart_test.go index b587d54b80cc8..b7791ef7679bd 100644 --- a/ddl/restart_test.go +++ b/ddl/restart_test.go @@ -120,7 +120,7 @@ func (s *testSchemaSuite) TestSchemaResume(c *C) { testCheckOwner(c, d1, true) - dbInfo := testSchemaInfo(c, d1, "test") + dbInfo := testSchemaInfo(c, d1, "test_restart") job := &model.Job{ SchemaID: dbInfo.ID, Type: model.ActionCreateSchema, @@ -157,7 +157,7 @@ func (s *testStatSuite) TestStat(c *C) { c.Assert(err, IsNil) }() - dbInfo := testSchemaInfo(c, d, "test") + dbInfo := testSchemaInfo(c, d, "test_restart") testCreateSchema(c, testNewContext(d), d, dbInfo) // TODO: Get this information from etcd. diff --git a/ddl/schema.go b/ddl/schema.go index 823e12a551900..a4b14a49bdbc3 100644 --- a/ddl/schema.go +++ b/ddl/schema.go @@ -68,16 +68,12 @@ func onCreateSchema(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, _ error } func checkSchemaNotExists(d *ddlCtx, t *meta.Meta, schemaID int64, dbInfo *model.DBInfo) error { - // d.infoHandle maybe nil in some test. - if d.infoHandle == nil { - return checkSchemaNotExistsFromStore(t, schemaID, dbInfo) - } // Try to use memory schema info to check first. currVer, err := t.GetSchemaVersion() if err != nil { return err } - is := d.infoHandle.Get() + is := d.infoCache.GetLatest() if is.SchemaMetaVersion() == currVer { return checkSchemaNotExistsFromInfoSchema(is, schemaID, dbInfo) } @@ -169,7 +165,7 @@ func onDropSchema(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, _ error) oldIDs := getIDs(tables) bundles := make([]*placement.Bundle, 0, len(oldIDs)+1) for _, ID := range append(oldIDs, dbInfo.ID) { - oldBundle, ok := d.infoHandle.Get().BundleByName(placement.GroupID(ID)) + oldBundle, ok := d.infoCache.GetLatest().BundleByName(placement.GroupID(ID)) if ok && !oldBundle.IsEmpty() { bundles = append(bundles, placement.BuildPlacementDropBundle(ID)) } diff --git a/ddl/schema_test.go b/ddl/schema_test.go index c70a0b793bb35..b4c8efee7b089 100644 --- a/ddl/schema_test.go +++ b/ddl/schema_test.go @@ -139,7 +139,7 @@ func (s *testSchemaSuite) TestSchema(c *C) { c.Assert(err, IsNil) }() ctx := testNewContext(d) - dbInfo := testSchemaInfo(c, d, "test") + dbInfo := testSchemaInfo(c, d, "test_schema") // create a database. job := testCreateSchema(c, ctx, d, dbInfo) @@ -228,7 +228,7 @@ func (s *testSchemaSuite) TestSchemaWaitJob(c *C) { // d2 must not be owner. 
d2.ownerManager.RetireOwner() - dbInfo := testSchemaInfo(c, d2, "test") + dbInfo := testSchemaInfo(c, d2, "test_schema") testCreateSchema(c, ctx, d2, dbInfo) testCheckSchemaState(c, d2, dbInfo, model.StatePublic) diff --git a/ddl/stat_test.go b/ddl/stat_test.go index fe562a0ae0fb8..1ed3cbfe4c7fc 100644 --- a/ddl/stat_test.go +++ b/ddl/stat_test.go @@ -61,7 +61,7 @@ func (s *testSerialStatSuite) TestDDLStatsInfo(c *C) { c.Assert(err, IsNil) }() - dbInfo := testSchemaInfo(c, d, "test") + dbInfo := testSchemaInfo(c, d, "test_stat") testCreateSchema(c, testNewContext(d), d, dbInfo) tblInfo := testTableInfo(c, d, "t", 2) ctx := testNewContext(d) diff --git a/ddl/table.go b/ddl/table.go index 668de3ac41c05..424dd040a0de9 100644 --- a/ddl/table.go +++ b/ddl/table.go @@ -487,34 +487,32 @@ func onTruncateTable(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, _ erro } } - if d.infoHandle != nil && d.infoHandle.IsValid() { - is := d.infoHandle.Get() - - bundles := make([]*placement.Bundle, 0, len(oldPartitionIDs)+1) - if oldBundle, ok := is.BundleByName(placement.GroupID(tableID)); ok { - bundles = append(bundles, placement.BuildPlacementCopyBundle(oldBundle, newTableID)) - } - - if pi := tblInfo.GetPartitionInfo(); pi != nil { - oldIDs := make([]int64, 0, len(oldPartitionIDs)) - newIDs := make([]int64, 0, len(oldPartitionIDs)) - newDefs := pi.Definitions - for i := range oldPartitionIDs { - newID := newDefs[i].ID - if oldBundle, ok := is.BundleByName(placement.GroupID(oldPartitionIDs[i])); ok && !oldBundle.IsEmpty() { - oldIDs = append(oldIDs, oldPartitionIDs[i]) - newIDs = append(newIDs, newID) - bundles = append(bundles, placement.BuildPlacementCopyBundle(oldBundle, newID)) - } + is := d.infoCache.GetLatest() + + bundles := make([]*placement.Bundle, 0, len(oldPartitionIDs)+1) + if oldBundle, ok := is.BundleByName(placement.GroupID(tableID)); ok { + bundles = append(bundles, placement.BuildPlacementCopyBundle(oldBundle, newTableID)) + } + + if pi := tblInfo.GetPartitionInfo(); pi != nil { + oldIDs := make([]int64, 0, len(oldPartitionIDs)) + newIDs := make([]int64, 0, len(oldPartitionIDs)) + newDefs := pi.Definitions + for i := range oldPartitionIDs { + newID := newDefs[i].ID + if oldBundle, ok := is.BundleByName(placement.GroupID(oldPartitionIDs[i])); ok && !oldBundle.IsEmpty() { + oldIDs = append(oldIDs, oldPartitionIDs[i]) + newIDs = append(newIDs, newID) + bundles = append(bundles, placement.BuildPlacementCopyBundle(oldBundle, newID)) } - job.CtxVars = []interface{}{oldIDs, newIDs} } + job.CtxVars = []interface{}{oldIDs, newIDs} + } - err = infosync.PutRuleBundles(context.TODO(), bundles) - if err != nil { - job.State = model.JobStateCancelled - return 0, errors.Wrapf(err, "failed to notify PD the placement rules") - } + err = infosync.PutRuleBundles(context.TODO(), bundles) + if err != nil { + job.State = model.JobStateCancelled + return 0, errors.Wrapf(err, "failed to notify PD the placement rules") } // Clear the tiflash replica available status. @@ -967,16 +965,12 @@ func onUpdateFlashReplicaStatus(t *meta.Meta, job *model.Job) (ver int64, _ erro } func checkTableNotExists(d *ddlCtx, t *meta.Meta, schemaID int64, tableName string) error { - // d.infoHandle maybe nil in some test. - if d.infoHandle == nil || !d.infoHandle.IsValid() { - return checkTableNotExistsFromStore(t, schemaID, tableName) - } // Try to use memory schema info to check first. 
currVer, err := t.GetSchemaVersion() if err != nil { return err } - is := d.infoHandle.Get() + is := d.infoCache.GetLatest() if is.SchemaMetaVersion() == currVer { return checkTableNotExistsFromInfoSchema(is, schemaID, tableName) } diff --git a/ddl/table_test.go b/ddl/table_test.go index 5760fc2b152b5..10927908f5289 100644 --- a/ddl/table_test.go +++ b/ddl/table_test.go @@ -355,7 +355,7 @@ func (s *testTableSuite) SetUpSuite(c *C) { WithLease(testLease), ) - s.dbInfo = testSchemaInfo(c, s.d, "test") + s.dbInfo = testSchemaInfo(c, s.d, "test_table") testCreateSchema(c, testNewContext(s.d), s.d, s.dbInfo) } diff --git a/ddl/util/syncer_test.go b/ddl/util/syncer_test.go index b552488ad49de..5a9d41d47e3b8 100644 --- a/ddl/util/syncer_test.go +++ b/ddl/util/syncer_test.go @@ -26,6 +26,7 @@ import ( "github.com/pingcap/parser/terror" . "github.com/pingcap/tidb/ddl" . "github.com/pingcap/tidb/ddl/util" + "github.com/pingcap/tidb/infoschema" "github.com/pingcap/tidb/owner" "github.com/pingcap/tidb/store/mockstore" "go.etcd.io/etcd/clientv3" @@ -69,11 +70,14 @@ func TestSyncerSimple(t *testing.T) { defer clus.Terminate(t) cli := clus.RandClient() ctx := goctx.Background() + ic := infoschema.NewCache(2) + ic.Insert(infoschema.MockInfoSchemaWithSchemaVer(nil, 0)) d := NewDDL( ctx, WithEtcdClient(cli), WithStore(store), WithLease(testLease), + WithInfoCache(ic), ) err = d.Start(nil) if err != nil { @@ -110,11 +114,14 @@ func TestSyncerSimple(t *testing.T) { t.Fatalf("client get global version result not match, err %v", err) } + ic2 := infoschema.NewCache(2) + ic2.Insert(infoschema.MockInfoSchemaWithSchemaVer(nil, 0)) d1 := NewDDL( ctx, WithEtcdClient(cli), WithStore(store), WithLease(testLease), + WithInfoCache(ic2), ) err = d1.Start(nil) if err != nil { diff --git a/domain/domain.go b/domain/domain.go index f4b0ac8900f24..44f6df1aa9086 100644 --- a/domain/domain.go +++ b/domain/domain.go @@ -67,7 +67,7 @@ import ( // Multiple domains can be used in parallel without synchronization. type Domain struct { store kv.Storage - infoHandle *infoschema.Handle + infoCache *infoschema.InfoCache privHandle *privileges.Handle bindHandle *bindinfo.BindHandle statsHandle unsafe.Pointer @@ -79,7 +79,7 @@ type Domain struct { sysSessionPool *sessionPool exit chan struct{} etcdClient *clientv3.Client - gvc GlobalVariableCache + sysVarCache SysVarCache // replaces GlobalVariableCache slowQuery *topNSlowQueries expensiveQueryHandle *expensivequery.Handle wg sync.WaitGroup @@ -92,78 +92,75 @@ type Domain struct { isLostConnectionToPD sync2.AtomicInt32 // !0: true, 0: false. } -// loadInfoSchema loads infoschema at startTS into handle, usedSchemaVersion is the currently used -// infoschema version, if it is the same as the schema version at startTS, we don't need to reload again. -// It returns the latest schema version, the changed table IDs, whether it's a full load and an error. -func (do *Domain) loadInfoSchema(handle *infoschema.Handle, usedSchemaVersion int64, - startTS uint64) (neededSchemaVersion int64, change *tikv.RelatedSchemaChange, fullLoad bool, err error) { +// loadInfoSchema loads infoschema at startTS. +// It returns: +// 1. the needed infoschema +// 2. cache hit indicator +// 3. currentSchemaVersion(before loading) +// 4. the changed table IDs if it is not full load +// 5. 
an error, if any +func (do *Domain) loadInfoSchema(startTS uint64) (infoschema.InfoSchema, bool, int64, *tikv.RelatedSchemaChange, error) { snapshot := do.store.GetSnapshot(kv.NewVersion(startTS)) m := meta.NewSnapshotMeta(snapshot) - neededSchemaVersion, err = m.GetSchemaVersion() + neededSchemaVersion, err := m.GetSchemaVersion() if err != nil { - return 0, nil, fullLoad, err - } - if usedSchemaVersion != 0 && usedSchemaVersion == neededSchemaVersion { - return neededSchemaVersion, nil, fullLoad, nil + return nil, false, 0, nil, err } - // Update self schema version to etcd. - defer func() { - // There are two possibilities for not updating the self schema version to etcd. - // 1. Failed to loading schema information. - // 2. When users use history read feature, the neededSchemaVersion isn't the latest schema version. - if err != nil || neededSchemaVersion < do.InfoSchema().SchemaMetaVersion() { - logutil.BgLogger().Info("do not update self schema version to etcd", - zap.Int64("usedSchemaVersion", usedSchemaVersion), - zap.Int64("neededSchemaVersion", neededSchemaVersion), zap.Error(err)) - return - } + if is := do.infoCache.GetByVersion(neededSchemaVersion); is != nil { + return is, true, 0, nil, nil + } - err = do.ddl.SchemaSyncer().UpdateSelfVersion(context.Background(), neededSchemaVersion) - if err != nil { - logutil.BgLogger().Info("update self version failed", - zap.Int64("usedSchemaVersion", usedSchemaVersion), - zap.Int64("neededSchemaVersion", neededSchemaVersion), zap.Error(err)) - } - }() + currentSchemaVersion := int64(0) + if oldInfoSchema := do.infoCache.GetLatest(); oldInfoSchema != nil { + currentSchemaVersion = oldInfoSchema.SchemaMetaVersion() + } + // TODO: tryLoadSchemaDiffs has potential risks of failure. And it becomes worse in history reading cases. + // It is only kept because there is no alternative diff/partial loading solution. + // And it is only used to diff-upgrade the current latest infoschema, if: + // 1. Not first time bootstrap loading, which needs a full load. + // 2. It is newer than the current one, so it will be "the current one" after this function call. + // 3. There are fewer than 100 diffs. startTime := time.Now() - ok, relatedChanges, err := do.tryLoadSchemaDiffs(m, usedSchemaVersion, neededSchemaVersion) - if err != nil { + if currentSchemaVersion != 0 && neededSchemaVersion > currentSchemaVersion && neededSchemaVersion-currentSchemaVersion < 100 { + is, relatedChanges, err := do.tryLoadSchemaDiffs(m, currentSchemaVersion, neededSchemaVersion) + if err == nil { + do.infoCache.Insert(is) + logutil.BgLogger().Info("diff load InfoSchema success", + zap.Int64("currentSchemaVersion", currentSchemaVersion), + zap.Int64("neededSchemaVersion", neededSchemaVersion), + zap.Duration("start time", time.Since(startTime)), + zap.Int64s("phyTblIDs", relatedChanges.PhyTblIDS), + zap.Uint64s("actionTypes", relatedChanges.ActionTypes)) + return is, false, currentSchemaVersion, relatedChanges, nil + } // We can fall back to full load, don't need to return the error.
logutil.BgLogger().Error("failed to load schema diff", zap.Error(err)) } - if ok { - logutil.BgLogger().Info("diff load InfoSchema success", - zap.Int64("usedSchemaVersion", usedSchemaVersion), - zap.Int64("neededSchemaVersion", neededSchemaVersion), - zap.Duration("start time", time.Since(startTime)), - zap.Int64s("phyTblIDs", relatedChanges.PhyTblIDS), - zap.Uint64s("actionTypes", relatedChanges.ActionTypes)) - return neededSchemaVersion, relatedChanges, fullLoad, nil - } - fullLoad = true schemas, err := do.fetchAllSchemasWithTables(m) if err != nil { - return 0, nil, fullLoad, err + return nil, false, currentSchemaVersion, nil, err } bundles, err := infosync.GetAllRuleBundles(context.TODO()) if err != nil { - return 0, nil, fullLoad, err + return nil, false, currentSchemaVersion, nil, err } - newISBuilder, err := infoschema.NewBuilder(handle).InitWithDBInfos(schemas, bundles, neededSchemaVersion) + newISBuilder, err := infoschema.NewBuilder(do.Store()).InitWithDBInfos(schemas, bundles, neededSchemaVersion) if err != nil { - return 0, nil, fullLoad, err + return nil, false, currentSchemaVersion, nil, err } logutil.BgLogger().Info("full load InfoSchema success", - zap.Int64("usedSchemaVersion", usedSchemaVersion), + zap.Int64("currentSchemaVersion", currentSchemaVersion), zap.Int64("neededSchemaVersion", neededSchemaVersion), zap.Duration("start time", time.Since(startTime))) - newISBuilder.Build() - return neededSchemaVersion, nil, fullLoad, nil + + is := newISBuilder.Build() + do.infoCache.Insert(is) + return is, false, currentSchemaVersion, nil, nil } func (do *Domain) fetchAllSchemasWithTables(m *meta.Meta) ([]*model.DBInfo, error) { @@ -238,48 +235,31 @@ func (do *Domain) fetchSchemasWithTables(schemas []*model.DBInfo, m *meta.Meta, done <- nil } -const ( - initialVersion = 0 - maxNumberOfDiffsToLoad = 100 -) - -func isTooOldSchema(usedVersion, newVersion int64) bool { - if usedVersion == initialVersion || newVersion-usedVersion > maxNumberOfDiffsToLoad { - return true - } - return false -} - // tryLoadSchemaDiffs tries to only load the latest schema changes. // It returns the built InfoSchema if the diffs are applied successfully, or an error if the schema // cannot be loaded from schema diffs, in which case the caller falls back to a full load. // The second return value is the delta of updated table and partition IDs. -func (do *Domain) tryLoadSchemaDiffs(m *meta.Meta, usedVersion, newVersion int64) (bool, *tikv.RelatedSchemaChange, error) { - // If there isn't any used version, or used version is too old, we do full load. - // And when users use history read feature, we will set usedVersion to initialVersion, then full load is needed. - if isTooOldSchema(usedVersion, newVersion) { - return false, nil, nil - } +func (do *Domain) tryLoadSchemaDiffs(m *meta.Meta, usedVersion, newVersion int64) (infoschema.InfoSchema, *tikv.RelatedSchemaChange, error) { var diffs []*model.SchemaDiff for usedVersion < newVersion { usedVersion++ diff, err := m.GetSchemaDiff(usedVersion) if err != nil { - return false, nil, err + return nil, nil, err } if diff == nil { // If diff is missing for any version between used and new version, we fall back to full reload.
- return false, nil, nil + return nil, nil, fmt.Errorf("failed to get schemadiff") } diffs = append(diffs, diff) } - builder := infoschema.NewBuilder(do.infoHandle).InitWithOldInfoSchema() + builder := infoschema.NewBuilder(do.Store()).InitWithOldInfoSchema(do.infoCache.GetLatest()) phyTblIDs := make([]int64, 0, len(diffs)) actions := make([]uint64, 0, len(diffs)) for _, diff := range diffs { IDs, err := builder.ApplyDiff(m, diff) if err != nil { - return false, nil, err + return nil, nil, err } if canSkipSchemaCheckerDDL(diff.Type) { continue } @@ -289,11 +269,11 @@ func (do *Domain) tryLoadSchemaDiffs(m *meta.Meta, usedVersion, newVersion int64 actions = append(actions, uint64(1<<diff.Type)) } } - builder.Build() + is := builder.Build() relatedChange := tikv.RelatedSchemaChange{} relatedChange.PhyTblIDS = phyTblIDs relatedChange.ActionTypes = actions - return true, &relatedChange, nil + return is, &relatedChange, nil } +// LoadSysVarCacheLoop creates a goroutine that loads the sysvar cache in a loop. +// It should be called only once in BootstrapSession. +func (do *Domain) LoadSysVarCacheLoop(ctx sessionctx.Context) error { + err := do.sysVarCache.RebuildSysVarCache(ctx) + if err != nil { + return err + } + var watchCh clientv3.WatchChan + duration := 30 * time.Second + if do.etcdClient != nil { + watchCh = do.etcdClient.Watch(context.Background(), sysVarCacheKey) + } + do.wg.Add(1) + go func() { + defer func() { + do.wg.Done() + logutil.BgLogger().Info("LoadSysVarCacheLoop exited.") + util.Recover(metrics.LabelDomain, "LoadSysVarCacheLoop", nil, false) + }() + var count int + for { + ok := true + select { + case <-do.exit: + return + case _, ok = <-watchCh: + case <-time.After(duration): + } + if !ok { + logutil.BgLogger().Error("LoadSysVarCacheLoop watch channel closed") + watchCh = do.etcdClient.Watch(context.Background(), sysVarCacheKey) + count++ + if count > 10 { + time.Sleep(time.Duration(count) * time.Second) + } + continue + } + count = 0 + logutil.BgLogger().Debug("Rebuilding sysvar cache from etcd watch event.") + err := do.sysVarCache.RebuildSysVarCache(ctx) + metrics.LoadSysVarCacheCounter.WithLabelValues(metrics.RetLabel(err)).Inc() + if err != nil { + logutil.BgLogger().Error("LoadSysVarCacheLoop failed", zap.Error(err)) + } + } + }() + return nil +} + // PrivilegeHandle returns the MySQLPrivilege. func (do *Domain) PrivilegeHandle() *privileges.Handle { return do.privHandle @@ -1300,7 +1327,10 @@ func (do *Domain) ExpensiveQueryHandle() *expensivequery.Handle { return do.expensiveQueryHandle } -const privilegeKey = "/tidb/privilege" +const ( + privilegeKey = "/tidb/privilege" + sysVarCacheKey = "/tidb/sysvars" +) // NotifyUpdatePrivilege updates the privilege key in etcd; TiDB clients that watch // the key will get a notification. @@ -1322,6 +1352,23 @@ func (do *Domain) NotifyUpdatePrivilege(ctx sessionctx.Context) { } } +// NotifyUpdateSysVarCache updates the sysvar cache key in etcd, which other TiDB +// clients are subscribed to for updates. For the caller, the cache is also built +// synchronously so that the effect is immediate. +func (do *Domain) NotifyUpdateSysVarCache(ctx sessionctx.Context) { + if do.etcdClient != nil { + row := do.etcdClient.KV + _, err := row.Put(context.Background(), sysVarCacheKey, "") + if err != nil { + logutil.BgLogger().Warn("notify update sysvar cache failed", zap.Error(err)) + } + } + // update locally + if err := do.sysVarCache.RebuildSysVarCache(ctx); err != nil { + logutil.BgLogger().Error("rebuilding sysvar cache failed", zap.Error(err)) + } +} + // ServerID gets serverID. func (do *Domain) ServerID() uint64 { return atomic.LoadUint64(&do.serverID) diff --git a/domain/domain_test.go b/domain/domain_test.go index a4432b0fb1fe6..82a583866aad3 100644 --- a/domain/domain_test.go +++ b/domain/domain_test.go @@ -128,7 +128,7 @@ func TestInfo(t *testing.T) { goCtx, ddl.WithEtcdClient(dom.GetEtcdClient()), ddl.WithStore(s), - ddl.WithInfoHandle(dom.infoHandle), + ddl.WithInfoCache(dom.infoCache), ddl.WithLease(ddlLease), ) err = dom.ddl.Start(nil) diff --git a/domain/global_vars_cache.go b/domain/global_vars_cache.go deleted file mode 100644 index 52aa12a5ac955..0000000000000 --- a/domain/global_vars_cache.go +++ /dev/null @@ -1,135 +0,0 @@ -// Copyright 2018 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// See the License for the specific language governing permissions and -// limitations under the License. - -package domain - -import ( - "fmt" - "sync" - "time" - - "github.com/pingcap/parser/ast" - "github.com/pingcap/tidb/sessionctx/variable" - "github.com/pingcap/tidb/util/chunk" - "github.com/pingcap/tidb/util/logutil" - "github.com/pingcap/tidb/util/stmtsummary" - "go.uber.org/zap" - "golang.org/x/sync/singleflight" -) - -// GlobalVariableCache caches global variables. -type GlobalVariableCache struct { - sync.RWMutex - lastModify time.Time - rows []chunk.Row - fields []*ast.ResultField - - // Unit test may like to disable it. - disable bool - SingleFight singleflight.Group -} - -// GlobalVariableCacheExpiry is the global variable cache TTL. -const GlobalVariableCacheExpiry = 2 * time.Second - -// Update updates the global variable cache. -func (gvc *GlobalVariableCache) Update(rows []chunk.Row, fields []*ast.ResultField) { - gvc.Lock() - gvc.lastModify = time.Now() - gvc.rows = rows - gvc.fields = fields - gvc.Unlock() - - checkEnableServerGlobalVar(rows) -} - -// Get gets the global variables from cache. -func (gvc *GlobalVariableCache) Get() (succ bool, rows []chunk.Row, fields []*ast.ResultField) { - gvc.RLock() - defer gvc.RUnlock() - if time.Since(gvc.lastModify) < GlobalVariableCacheExpiry { - succ, rows, fields = !gvc.disable, gvc.rows, gvc.fields - return - } - succ = false - return -} - -type loadResult struct { - rows []chunk.Row - fields []*ast.ResultField -} - -// LoadGlobalVariables will load from global cache first, loadFn will be executed if cache is not valid -func (gvc *GlobalVariableCache) LoadGlobalVariables(loadFn func() ([]chunk.Row, []*ast.ResultField, error)) ([]chunk.Row, []*ast.ResultField, error) { - succ, rows, fields := gvc.Get() - if succ { - return rows, fields, nil - } - fn := func() (interface{}, error) { - resRows, resFields, loadErr := loadFn() - if loadErr != nil { - return nil, loadErr - } - gvc.Update(resRows, resFields) - return &loadResult{resRows, resFields}, nil - } - res, err, _ := gvc.SingleFight.Do("loadGlobalVariable", fn) - if err != nil { - return nil, nil, err - } - loadRes := res.(*loadResult) - return loadRes.rows, loadRes.fields, nil -} - -// Disable disables the global variable cache, used in test only. -func (gvc *GlobalVariableCache) Disable() { - gvc.Lock() - defer gvc.Unlock() - gvc.disable = true -} - -// checkEnableServerGlobalVar processes variables that acts in server and global level. 
-func checkEnableServerGlobalVar(rows []chunk.Row) { - for _, row := range rows { - sVal := "" - if !row.IsNull(1) { - sVal = row.GetString(1) - } - var err error - switch row.GetString(0) { - case variable.TiDBEnableStmtSummary: - err = stmtsummary.StmtSummaryByDigestMap.SetEnabled(sVal, false) - case variable.TiDBStmtSummaryInternalQuery: - err = stmtsummary.StmtSummaryByDigestMap.SetEnabledInternalQuery(sVal, false) - case variable.TiDBStmtSummaryRefreshInterval: - err = stmtsummary.StmtSummaryByDigestMap.SetRefreshInterval(sVal, false) - case variable.TiDBStmtSummaryHistorySize: - err = stmtsummary.StmtSummaryByDigestMap.SetHistorySize(sVal, false) - case variable.TiDBStmtSummaryMaxStmtCount: - err = stmtsummary.StmtSummaryByDigestMap.SetMaxStmtCount(sVal, false) - case variable.TiDBStmtSummaryMaxSQLLength: - err = stmtsummary.StmtSummaryByDigestMap.SetMaxSQLLength(sVal, false) - case variable.TiDBCapturePlanBaseline: - variable.CapturePlanBaseline.Set(sVal, false) - } - if err != nil { - logutil.BgLogger().Error(fmt.Sprintf("load global variable %s error", row.GetString(0)), zap.Error(err)) - } - } -} - -// GetGlobalVarsCache gets the global variable cache. -func (do *Domain) GetGlobalVarsCache() *GlobalVariableCache { - return &do.gvc -} diff --git a/domain/global_vars_cache_test.go b/domain/global_vars_cache_test.go deleted file mode 100644 index 7358d709986af..0000000000000 --- a/domain/global_vars_cache_test.go +++ /dev/null @@ -1,221 +0,0 @@ -// Copyright 2019 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// See the License for the specific language governing permissions and -// limitations under the License. - -package domain - -import ( - "sync" - "sync/atomic" - "time" - - . "github.com/pingcap/check" - "github.com/pingcap/parser/ast" - "github.com/pingcap/parser/charset" - "github.com/pingcap/parser/model" - "github.com/pingcap/parser/mysql" - "github.com/pingcap/tidb/sessionctx/variable" - "github.com/pingcap/tidb/store/mockstore" - "github.com/pingcap/tidb/types" - "github.com/pingcap/tidb/util/chunk" - "github.com/pingcap/tidb/util/stmtsummary" - "github.com/pingcap/tidb/util/testleak" -) - -var _ = SerialSuites(&testGVCSuite{}) - -type testGVCSuite struct{} - -func (gvcSuite *testGVCSuite) TestSimple(c *C) { - defer testleak.AfterTest(c)() - testleak.BeforeTest() - - store, err := mockstore.NewMockStore() - c.Assert(err, IsNil) - defer func() { - err := store.Close() - c.Assert(err, IsNil) - }() - ddlLease := 50 * time.Millisecond - dom := NewDomain(store, ddlLease, 0, 0, mockFactory) - err = dom.Init(ddlLease, sysMockFactory) - c.Assert(err, IsNil) - defer dom.Close() - - // Get empty global vars cache. - gvc := dom.GetGlobalVarsCache() - succ, rows, fields := gvc.Get() - c.Assert(succ, IsFalse) - c.Assert(rows, IsNil) - c.Assert(fields, IsNil) - // Get a variable from global vars cache. 
- rf := getResultField("c", 1, 0) - rf1 := getResultField("c1", 2, 1) - ft := &types.FieldType{ - Tp: mysql.TypeString, - Charset: charset.CharsetBin, - Collate: charset.CollationBin, - } - ft1 := &types.FieldType{ - Tp: mysql.TypeString, - Charset: charset.CharsetBin, - Collate: charset.CollationBin, - } - ck := chunk.NewChunkWithCapacity([]*types.FieldType{ft, ft1}, 1024) - ck.AppendString(0, "variable1") - ck.AppendString(1, "value1") - row := ck.GetRow(0) - gvc.Update([]chunk.Row{row}, []*ast.ResultField{rf, rf1}) - succ, rows, fields = gvc.Get() - c.Assert(succ, IsTrue) - c.Assert(rows[0], Equals, row) - c.Assert(fields, DeepEquals, []*ast.ResultField{rf, rf1}) - // Disable the cache. - gvc.Disable() - succ, rows, fields = gvc.Get() - c.Assert(succ, IsFalse) - c.Assert(rows[0], Equals, row) - c.Assert(fields, DeepEquals, []*ast.ResultField{rf, rf1}) -} - -func getResultField(colName string, id, offset int) *ast.ResultField { - return &ast.ResultField{ - Column: &model.ColumnInfo{ - Name: model.NewCIStr(colName), - ID: int64(id), - Offset: offset, - FieldType: types.FieldType{ - Tp: mysql.TypeString, - Charset: charset.CharsetUTF8, - Collate: charset.CollationUTF8, - }, - }, - TableAsName: model.NewCIStr("tbl"), - DBName: model.NewCIStr("test"), - } -} - -func (gvcSuite *testGVCSuite) TestConcurrentOneFlight(c *C) { - defer testleak.AfterTest(c)() - testleak.BeforeTest() - gvc := &GlobalVariableCache{} - succ, rows, fields := gvc.Get() - c.Assert(succ, IsFalse) - c.Assert(rows, IsNil) - c.Assert(fields, IsNil) - - // Get a variable from global vars cache. - rf := getResultField("c", 1, 0) - rf1 := getResultField("c1", 2, 1) - ft := &types.FieldType{ - Tp: mysql.TypeString, - Charset: charset.CharsetBin, - Collate: charset.CollationBin, - } - ft1 := &types.FieldType{ - Tp: mysql.TypeString, - Charset: charset.CharsetBin, - Collate: charset.CollationBin, - } - ckLow := chunk.NewChunkWithCapacity([]*types.FieldType{ft, ft1}, 1024) - val := "fromStorage" - val1 := "fromStorage1" - ckLow.AppendString(0, val) - ckLow.AppendString(1, val1) - - // Let cache become invalid, and try concurrent load - counter := int32(0) - waitToStart := new(sync.WaitGroup) - waitToStart.Add(1) - gvc.lastModify = time.Now().Add(time.Duration(-10) * time.Second) - loadFunc := func() ([]chunk.Row, []*ast.ResultField, error) { - time.Sleep(100 * time.Millisecond) - atomic.AddInt32(&counter, 1) - return []chunk.Row{ckLow.GetRow(0)}, []*ast.ResultField{rf, rf1}, nil - } - wg := new(sync.WaitGroup) - worker := 100 - resArray := make([]loadResult, worker) - for i := 0; i < worker; i++ { - wg.Add(1) - go func(idx int) { - defer wg.Done() - waitToStart.Wait() - resRow, resField, _ := gvc.LoadGlobalVariables(loadFunc) - resArray[idx].rows = resRow - resArray[idx].fields = resField - }(i) - } - waitToStart.Done() - wg.Wait() - succ, rows, fields = gvc.Get() - c.Assert(counter, Equals, int32(1)) - c.Assert(resArray[0].rows[0].GetString(0), Equals, val) - c.Assert(resArray[0].rows[0].GetString(1), Equals, val1) - for i := 0; i < worker; i++ { - c.Assert(resArray[0].rows[0], Equals, resArray[i].rows[0]) - c.Assert(resArray[i].rows[0].GetString(0), Equals, val) - c.Assert(resArray[i].rows[0].GetString(1), Equals, val1) - } - // Validate cache - c.Assert(succ, IsTrue) - c.Assert(rows[0], Equals, resArray[0].rows[0]) - c.Assert(fields, DeepEquals, []*ast.ResultField{rf, rf1}) -} - -func (gvcSuite *testGVCSuite) TestCheckEnableStmtSummary(c *C) { - defer testleak.AfterTest(c)() - testleak.BeforeTest() - - store, err := 
mockstore.NewMockStore() - c.Assert(err, IsNil) - defer func() { - err := store.Close() - c.Assert(err, IsNil) - }() - ddlLease := 50 * time.Millisecond - dom := NewDomain(store, ddlLease, 0, 0, mockFactory) - err = dom.Init(ddlLease, sysMockFactory) - c.Assert(err, IsNil) - defer dom.Close() - - gvc := dom.GetGlobalVarsCache() - - rf := getResultField("c", 1, 0) - rf1 := getResultField("c1", 2, 1) - ft := &types.FieldType{ - Tp: mysql.TypeString, - Charset: charset.CharsetBin, - Collate: charset.CollationBin, - } - ft1 := &types.FieldType{ - Tp: mysql.TypeString, - Charset: charset.CharsetBin, - Collate: charset.CollationBin, - } - - err = stmtsummary.StmtSummaryByDigestMap.SetEnabled("0", false) - c.Assert(err, IsNil) - ck := chunk.NewChunkWithCapacity([]*types.FieldType{ft, ft1}, 1024) - ck.AppendString(0, variable.TiDBEnableStmtSummary) - ck.AppendString(1, "1") - row := ck.GetRow(0) - gvc.Update([]chunk.Row{row}, []*ast.ResultField{rf, rf1}) - c.Assert(stmtsummary.StmtSummaryByDigestMap.Enabled(), Equals, true) - - ck = chunk.NewChunkWithCapacity([]*types.FieldType{ft, ft1}, 1024) - ck.AppendString(0, variable.TiDBEnableStmtSummary) - ck.AppendString(1, "0") - row = ck.GetRow(0) - gvc.Update([]chunk.Row{row}, []*ast.ResultField{rf, rf1}) - c.Assert(stmtsummary.StmtSummaryByDigestMap.Enabled(), Equals, false) -} diff --git a/domain/schema_validator.go b/domain/schema_validator.go index b983eff1d6203..a8baa49db93b9 100644 --- a/domain/schema_validator.go +++ b/domain/schema_validator.go @@ -234,8 +234,9 @@ func (s *schemaValidator) Check(txnTS uint64, schemaVer int64, relatedPhysicalTa // Schema changed, result decided by whether related tables change. if schemaVer < s.latestSchemaVer { - // The DDL relatedPhysicalTableIDs is empty. - if len(relatedPhysicalTableIDs) == 0 { + // When a transaction executes a DDL, relatedPhysicalTableIDs is nil. + // When a transaction only contains DML on temporary tables, relatedPhysicalTableIDs is []. + if relatedPhysicalTableIDs == nil { logutil.BgLogger().Info("the related physical table ID is empty", zap.Int64("schemaVer", schemaVer), zap.Int64("latestSchemaVer", s.latestSchemaVer)) return nil, ResultFail diff --git a/domain/sysvar_cache.go b/domain/sysvar_cache.go new file mode 100644 index 0000000000000..23c9688ea2f81 --- /dev/null +++ b/domain/sysvar_cache.go @@ -0,0 +1,167 @@ +// Copyright 2021 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package domain + +import ( + "context" + "fmt" + "sync" + + "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/sessionctx/variable" + "github.com/pingcap/tidb/util/logutil" + "github.com/pingcap/tidb/util/sqlexec" + "github.com/pingcap/tidb/util/stmtsummary" + "go.uber.org/zap" +) + +// The sysvar cache replaces the GlobalVariableCache. +// It is an improvement because it operates similar to privilege cache, +// where it caches for 5 minutes instead of 2 seconds, plus it listens on etcd +// for updates from other servers. + +// SysVarCache represents the cache of system variables broken up into session and global scope. 
+type SysVarCache struct { + sync.RWMutex + global map[string]string + session map[string]string +} + +// GetSysVarCache gets the domain's sysvar cache. +func (do *Domain) GetSysVarCache() *SysVarCache { + return &do.sysVarCache +} + +func (svc *SysVarCache) rebuildCacheIfNeeded(ctx sessionctx.Context) (err error) { + svc.RLock() + cacheNeedsRebuild := len(svc.session) == 0 || len(svc.global) == 0 + svc.RUnlock() + if cacheNeedsRebuild { + logutil.BgLogger().Warn("sysvar cache is empty, triggering rebuild") + if err = svc.RebuildSysVarCache(ctx); err != nil { + logutil.BgLogger().Error("rebuilding sysvar cache failed", zap.Error(err)) + } + } + return err +} + +// GetSessionCache gets a copy of the session sysvar cache. +// The intention is to copy it directly to the systems[] map +// on creating a new session. +func (svc *SysVarCache) GetSessionCache(ctx sessionctx.Context) (map[string]string, error) { + if err := svc.rebuildCacheIfNeeded(ctx); err != nil { + return nil, err + } + svc.RLock() + defer svc.RUnlock() + // Perform a deep copy since this will be assigned directly to the session + newMap := make(map[string]string, len(svc.session)) + for k, v := range svc.session { + newMap[k] = v + } + return newMap, nil +} + +// GetGlobalVar gets an individual global var from the sysvar cache. +func (svc *SysVarCache) GetGlobalVar(ctx sessionctx.Context, name string) (string, error) { + if err := svc.rebuildCacheIfNeeded(ctx); err != nil { + return "", err + } + svc.RLock() + defer svc.RUnlock() + + if val, ok := svc.global[name]; ok { + return val, nil + } + logutil.BgLogger().Warn("could not find key in global cache", zap.String("name", name)) + return "", variable.ErrUnknownSystemVar.GenWithStackByArgs(name) +} + +func (svc *SysVarCache) fetchTableValues(ctx sessionctx.Context) (map[string]string, error) { + tableContents := make(map[string]string) + // Copy all variables from the table to tableContents + exec := ctx.(sqlexec.RestrictedSQLExecutor) + stmt, err := exec.ParseWithParams(context.Background(), `SELECT variable_name, variable_value FROM mysql.global_variables`) + if err != nil { + return tableContents, err + } + rows, _, err := exec.ExecRestrictedStmt(context.TODO(), stmt) + if err != nil { + return nil, err + } + for _, row := range rows { + name := row.GetString(0) + val := row.GetString(1) + tableContents[name] = val + } + return tableContents, nil +} + +// RebuildSysVarCache rebuilds the sysvar cache both globally and for session vars. +// It needs to be called when sysvars are added or removed. +func (svc *SysVarCache) RebuildSysVarCache(ctx sessionctx.Context) error { + newSessionCache := make(map[string]string) + newGlobalCache := make(map[string]string) + tableContents, err := svc.fetchTableValues(ctx) + if err != nil { + return err + } + + for _, sv := range variable.GetSysVars() { + sVal := sv.Value + if _, ok := tableContents[sv.Name]; ok { + sVal = tableContents[sv.Name] + } + if sv.HasSessionScope() { + newSessionCache[sv.Name] = sVal + } + if sv.HasGlobalScope() { + newGlobalCache[sv.Name] = sVal + } + // Propagate any changes to the server scoped variables + checkEnableServerGlobalVar(sv.Name, sVal) + } + + logutil.BgLogger().Debug("rebuilding sysvar cache") + + svc.Lock() + defer svc.Unlock() + svc.session = newSessionCache + svc.global = newGlobalCache + return nil +} + +// checkEnableServerGlobalVar processes variables that act at both the server and global level.
+func checkEnableServerGlobalVar(name, sVal string) { + var err error + switch name { + case variable.TiDBEnableStmtSummary: + err = stmtsummary.StmtSummaryByDigestMap.SetEnabled(sVal, false) + case variable.TiDBStmtSummaryInternalQuery: + err = stmtsummary.StmtSummaryByDigestMap.SetEnabledInternalQuery(sVal, false) + case variable.TiDBStmtSummaryRefreshInterval: + err = stmtsummary.StmtSummaryByDigestMap.SetRefreshInterval(sVal, false) + case variable.TiDBStmtSummaryHistorySize: + err = stmtsummary.StmtSummaryByDigestMap.SetHistorySize(sVal, false) + case variable.TiDBStmtSummaryMaxStmtCount: + err = stmtsummary.StmtSummaryByDigestMap.SetMaxStmtCount(sVal, false) + case variable.TiDBStmtSummaryMaxSQLLength: + err = stmtsummary.StmtSummaryByDigestMap.SetMaxSQLLength(sVal, false) + case variable.TiDBCapturePlanBaseline: + variable.CapturePlanBaseline.Set(sVal, false) + } + if err != nil { + logutil.BgLogger().Error(fmt.Sprintf("load global variable %s error", name), zap.Error(err)) + } +} diff --git a/executor/executor_test.go b/executor/executor_test.go index 5e6b4490f5eb6..7fa0d7b0d10bd 100644 --- a/executor/executor_test.go +++ b/executor/executor_test.go @@ -2267,8 +2267,6 @@ func (s *testSuiteP2) TestSQLMode(c *C) { tk.MustExec("set sql_mode = 'STRICT_TRANS_TABLES'") tk.MustExec("set @@global.sql_mode = ''") - // Disable global variable cache, so load global session variable take effect immediate. - s.domain.GetGlobalVarsCache().Disable() tk2 := testkit.NewTestKit(c, s.store) tk2.MustExec("use test") tk2.MustExec("drop table if exists t2") diff --git a/executor/partition_table_test.go b/executor/partition_table_test.go index 038dbc6d6d47e..f7840d290e5e1 100644 --- a/executor/partition_table_test.go +++ b/executor/partition_table_test.go @@ -24,6 +24,7 @@ import ( "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/util/israce" "github.com/pingcap/tidb/util/testkit" + "github.com/pingcap/tidb/util/testutil" ) func (s *partitionTableSuite) TestFourReader(c *C) { @@ -572,6 +573,240 @@ func (s *partitionTableSuite) TestGlobalStatsAndSQLBinding(c *C) { tk.MustIndexLookup("select * from tlist where a<1") } +func (s *partitionTableSuite) TestPartitionTableWithDifferentJoin(c *C) { + if israce.RaceEnabled { + c.Skip("exhaustive types test, skip race test") + } + + tk := testkit.NewTestKitWithInit(c, s.store) + tk.MustExec("create database test_partition_joins") + tk.MustExec("use test_partition_joins") + tk.MustExec("set @@tidb_partition_prune_mode = 'dynamic'") + + // hash and range partition + tk.MustExec("create table thash(a int, b int, key(a)) partition by hash(a) partitions 4") + tk.MustExec("create table tregular1(a int, b int, key(a))") + + tk.MustExec(`create table trange(a int, b int, key(a)) partition by range(a) ( + partition p0 values less than (200), + partition p1 values less than (400), + partition p2 values less than (600), + partition p3 values less than (800), + partition p4 values less than (1001))`) + tk.MustExec("create table tregular2(a int, b int, key(a))") + + vals := make([]string, 0, 2000) + for i := 0; i < 2000; i++ { + vals = append(vals, fmt.Sprintf("(%v, %v)", rand.Intn(1000), rand.Intn(1000))) + } + tk.MustExec("insert into thash values " + strings.Join(vals, ",")) + tk.MustExec("insert into tregular1 values " + strings.Join(vals, ",")) + + vals = make([]string, 0, 2000) + for i := 0; i < 2000; i++ { + vals = append(vals, fmt.Sprintf("(%v, %v)", rand.Intn(1000), rand.Intn(1000))) + } + tk.MustExec("insert into trange values " + 
strings.Join(vals, ",")) + tk.MustExec("insert into tregular2 values " + strings.Join(vals, ",")) + + // random params + x1 := rand.Intn(1000) + x2 := rand.Intn(1000) + x3 := rand.Intn(1000) + x4 := rand.Intn(1000) + + // group 1 + // hash_join range partition and hash partition + queryHash := fmt.Sprintf("select /*+ hash_join(trange, thash) */ * from trange, thash where trange.b=thash.b and thash.a = %v and trange.a > %v;", x1, x2) + queryRegular := fmt.Sprintf("select /*+ hash_join(tregular2, tregular1) */ * from tregular2, tregular1 where tregular2.b=tregular1.b and tregular1.a = %v and tregular2.a > %v;", x1, x2) + c.Assert(tk.HasPlan(queryHash, "HashJoin"), IsTrue) + tk.MustQuery(queryHash).Sort().Check(tk.MustQuery(queryRegular).Sort().Rows()) + + queryHash = fmt.Sprintf("select /*+ hash_join(trange, thash) */ * from trange, thash where trange.a=thash.a and thash.a > %v;", x1) + queryRegular = fmt.Sprintf("select /*+ hash_join(tregular2, tregular1) */ * from tregular2, tregular1 where tregular2.a=tregular1.a and tregular1.a > %v;", x1) + c.Assert(tk.HasPlan(queryHash, "HashJoin"), IsTrue) + tk.MustQuery(queryHash).Sort().Check(tk.MustQuery(queryRegular).Sort().Rows()) + + queryHash = fmt.Sprintf("select /*+ hash_join(trange, thash) */ * from trange, thash where trange.a=thash.a and trange.b = thash.b and thash.a > %v;", x1) + queryRegular = fmt.Sprintf("select /*+ hash_join(tregular2, tregular1) */ * from tregular2, tregular1 where tregular2.a=tregular1.a and tregular1.b = tregular2.b and tregular1.a > %v;", x1) + c.Assert(tk.HasPlan(queryHash, "HashJoin"), IsTrue) + tk.MustQuery(queryHash).Sort().Check(tk.MustQuery(queryRegular).Sort().Rows()) + + queryHash = fmt.Sprintf("select /*+ hash_join(trange, thash) */ * from trange, thash where trange.a=thash.a and thash.a = %v;", x1) + queryRegular = fmt.Sprintf("select /*+ hash_join(tregular2, tregular1) */ * from tregular2, tregular1 where tregular2.a=tregular1.a and tregular1.a = %v;", x1) + c.Assert(tk.HasPlan(queryHash, "HashJoin"), IsTrue) + tk.MustQuery(queryHash).Sort().Check(tk.MustQuery(queryRegular).Sort().Rows()) + + // group 2 + // hash_join range partition and regular table + queryHash = fmt.Sprintf("select /*+ hash_join(trange, tregular1) */ * from trange, tregular1 where trange.a = tregular1.a and trange.a >= %v and tregular1.a > %v;", x1, x2) + queryRegular = fmt.Sprintf("select /*+ hash_join(tregular2, tregular1) */ * from tregular2, tregular1 where tregular2.a = tregular1.a and tregular2.a >= %v and tregular1.a > %v;", x1, x2) + c.Assert(tk.HasPlan(queryHash, "HashJoin"), IsTrue) + tk.MustQuery(queryHash).Sort().Check(tk.MustQuery(queryRegular).Sort().Rows()) + + queryHash = fmt.Sprintf("select /*+ hash_join(trange, tregular1) */ * from trange, tregular1 where trange.a = tregular1.a and trange.a in (%v, %v, %v);", x1, x2, x3) + queryRegular = fmt.Sprintf("select /*+ hash_join(tregular2, tregular1) */ * from tregular2, tregular1 where tregular2.a = tregular1.a and tregular2.a in (%v, %v, %v);", x1, x2, x3) + c.Assert(tk.HasPlan(queryHash, "HashJoin"), IsTrue) + tk.MustQuery(queryHash).Sort().Check(tk.MustQuery(queryRegular).Sort().Rows()) + + queryHash = fmt.Sprintf("select /*+ hash_join(trange, tregular1) */ * from trange, tregular1 where trange.a = tregular1.a and tregular1.a >= %v;", x1) + queryRegular = fmt.Sprintf("select /*+ hash_join(tregular2, tregular1) */ * from tregular2, tregular1 where tregular2.a = tregular1.a and tregular1.a >= %v;", x1) + c.Assert(tk.HasPlan(queryHash, "HashJoin"), IsTrue) + 
tk.MustQuery(queryHash).Sort().Check(tk.MustQuery(queryRegular).Sort().Rows()) + + // group 3 + // merge_join range partition and hash partition + queryHash = fmt.Sprintf("select /*+ merge_join(trange, thash) */ * from trange, thash where trange.b=thash.b and thash.a = %v and trange.a > %v;", x1, x2) + queryRegular = fmt.Sprintf("select /*+ merge_join(tregular2, tregular1) */ * from tregular2, tregular1 where tregular2.b=tregular1.b and tregular1.a = %v and tregular2.a > %v;", x1, x2) + c.Assert(tk.HasPlan(queryHash, "MergeJoin"), IsTrue) + tk.MustQuery(queryHash).Sort().Check(tk.MustQuery(queryRegular).Sort().Rows()) + + queryHash = fmt.Sprintf("select /*+ merge_join(trange, thash) */ * from trange, thash where trange.a=thash.a and thash.a > %v;", x1) + queryRegular = fmt.Sprintf("select /*+ merge_join(tregular2, tregular1) */ * from tregular2, tregular1 where tregular2.a=tregular1.a and tregular1.a > %v;", x1) + c.Assert(tk.HasPlan(queryHash, "MergeJoin"), IsTrue) + tk.MustQuery(queryHash).Sort().Check(tk.MustQuery(queryRegular).Sort().Rows()) + + queryHash = fmt.Sprintf("select /*+ merge_join(trange, thash) */ * from trange, thash where trange.a=thash.a and trange.b = thash.b and thash.a > %v;", x1) + queryRegular = fmt.Sprintf("select /*+ merge_join(tregular2, tregular1) */ * from tregular2, tregular1 where tregular2.a=tregular1.a and tregular1.b = tregular2.b and tregular1.a > %v;", x1) + c.Assert(tk.HasPlan(queryHash, "MergeJoin"), IsTrue) + tk.MustQuery(queryHash).Sort().Check(tk.MustQuery(queryRegular).Sort().Rows()) + + queryHash = fmt.Sprintf("select /*+ merge_join(trange, thash) */ * from trange, thash where trange.a=thash.a and thash.a = %v;", x1) + queryRegular = fmt.Sprintf("select /*+ merge_join(tregular2, tregular1) */ * from tregular2, tregular1 where tregular2.a=tregular1.a and tregular1.a = %v;", x1) + c.Assert(tk.HasPlan(queryHash, "MergeJoin"), IsTrue) + tk.MustQuery(queryHash).Sort().Check(tk.MustQuery(queryRegular).Sort().Rows()) + + // group 4 + // merge_join range partition and regular table + queryHash = fmt.Sprintf("select /*+ merge_join(trange, tregular1) */ * from trange, tregular1 where trange.a = tregular1.a and trange.a >= %v and tregular1.a > %v;", x1, x2) + queryRegular = fmt.Sprintf("select /*+ merge_join(tregular2, tregular1) */ * from tregular2, tregular1 where tregular2.a = tregular1.a and tregular2.a >= %v and tregular1.a > %v;", x1, x2) + c.Assert(tk.HasPlan(queryHash, "MergeJoin"), IsTrue) + tk.MustQuery(queryHash).Sort().Check(tk.MustQuery(queryRegular).Sort().Rows()) + + queryHash = fmt.Sprintf("select /*+ merge_join(trange, tregular1) */ * from trange, tregular1 where trange.a = tregular1.a and trange.a in (%v, %v, %v);", x1, x2, x3) + queryRegular = fmt.Sprintf("select /*+ merge_join(tregular2, tregular1) */ * from tregular2, tregular1 where tregular2.a = tregular1.a and tregular2.a in (%v, %v, %v);", x1, x2, x3) + c.Assert(tk.HasPlan(queryHash, "MergeJoin"), IsTrue) + tk.MustQuery(queryHash).Sort().Check(tk.MustQuery(queryRegular).Sort().Rows()) + + queryHash = fmt.Sprintf("select /*+ merge_join(trange, tregular1) */ * from trange, tregular1 where trange.a = tregular1.a and tregular1.a >= %v;", x1) + queryRegular = fmt.Sprintf("select /*+ merge_join(tregular2, tregular1) */ * from tregular2, tregular1 where tregular2.a = tregular1.a and tregular1.a >= %v;", x1) + c.Assert(tk.HasPlan(queryHash, "MergeJoin"), IsTrue) + tk.MustQuery(queryHash).Sort().Check(tk.MustQuery(queryRegular).Sort().Rows()) + + // new table instances + tk.MustExec("create 
table thash2(a int, b int, index idx(a)) partition by hash(a) partitions 4") + tk.MustExec("create table tregular3(a int, b int, index idx(a))") + + tk.MustExec(`create table trange2(a int, b int, index idx(a)) partition by range(a) ( + partition p0 values less than (200), + partition p1 values less than (400), + partition p2 values less than (600), + partition p3 values less than (800), + partition p4 values less than (1001))`) + tk.MustExec("create table tregular4(a int, b int, index idx(a))") + + vals = make([]string, 0, 2000) + for i := 0; i < 2000; i++ { + vals = append(vals, fmt.Sprintf("(%v, %v)", rand.Intn(1000), rand.Intn(1000))) + } + tk.MustExec("insert into thash2 values " + strings.Join(vals, ",")) + tk.MustExec("insert into tregular3 values " + strings.Join(vals, ",")) + + vals = make([]string, 0, 2000) + for i := 0; i < 2000; i++ { + vals = append(vals, fmt.Sprintf("(%v, %v)", rand.Intn(1000), rand.Intn(1000))) + } + tk.MustExec("insert into trange2 values " + strings.Join(vals, ",")) + tk.MustExec("insert into tregular4 values " + strings.Join(vals, ",")) + + // group 5 + // index_merge_join range partition and range partition + // Currently don't support index merge join on two partition tables. Only test warning. + queryHash = fmt.Sprintf("select /*+ inl_merge_join(trange, trange2) */ * from trange, trange2 where trange.a=trange2.a and trange.a > %v;", x1) + // queryRegular = fmt.Sprintf("select /*+ inl_merge_join(tregular2, tregular4) */ * from tregular2, tregular4 where tregular2.a=tregular4.a and tregular2.a > %v;", x1) + // c.Assert(tk.HasPlan(queryHash, "IndexMergeJoin"), IsTrue) + // tk.MustQuery(queryHash).Sort().Check(tk.MustQuery(queryRegular).Sort().Rows()) + tk.MustQuery(queryHash) + tk.MustQuery("show warnings").Check(testutil.RowsWithSep("|", "Warning|1815|Optimizer Hint /*+ INL_MERGE_JOIN(trange, trange2) */ is inapplicable")) + + queryHash = fmt.Sprintf("select /*+ inl_merge_join(trange, trange2) */ * from trange, trange2 where trange.a=trange2.a and trange.a > %v and trange2.a > %v;", x1, x2) + // queryRegular = fmt.Sprintf("select /*+ inl_merge_join(tregular2, tregular4) */ * from tregular2, tregular4 where tregular2.a=tregular4.a and tregular2.a > %v and tregular4.a > %v;", x1, x2) + // c.Assert(tk.HasPlan(queryHash, "IndexMergeJoin"), IsTrue) + // tk.MustQuery(queryHash).Sort().Check(tk.MustQuery(queryRegular).Sort().Rows()) + tk.MustQuery(queryHash) + tk.MustQuery("show warnings").Check(testutil.RowsWithSep("|", "Warning|1815|Optimizer Hint /*+ INL_MERGE_JOIN(trange, trange2) */ is inapplicable")) + + queryHash = fmt.Sprintf("select /*+ inl_merge_join(trange, trange2) */ * from trange, trange2 where trange.a=trange2.a and trange.a > %v and trange.b > %v;", x1, x2) + // queryRegular = fmt.Sprintf("select /*+ inl_merge_join(tregular2, tregular4) */ * from tregular2, tregular4 where tregular2.a=tregular4.a and tregular2.a > %v and tregular2.b > %v;", x1, x2) + // c.Assert(tk.HasPlan(queryHash, "IndexMergeJoin"), IsTrue) + // tk.MustQuery(queryHash).Sort().Check(tk.MustQuery(queryRegular).Sort().Rows()) + tk.MustQuery(queryHash) + tk.MustQuery("show warnings").Check(testutil.RowsWithSep("|", "Warning|1815|Optimizer Hint /*+ INL_MERGE_JOIN(trange, trange2) */ is inapplicable")) + + queryHash = fmt.Sprintf("select /*+ inl_merge_join(trange, trange2) */ * from trange, trange2 where trange.a=trange2.a and trange.a > %v and trange2.b > %v;", x1, x2) + // queryRegular = fmt.Sprintf("select /*+ inl_merge_join(tregular2, tregular4) */ * from tregular2, tregular4 
where tregular2.a=tregular4.a and tregular2.a > %v and tregular4.b > %v;", x1, x2)
+	// c.Assert(tk.HasPlan(queryHash, "IndexMergeJoin"), IsTrue)
+	// tk.MustQuery(queryHash).Sort().Check(tk.MustQuery(queryRegular).Sort().Rows())
+	tk.MustQuery(queryHash)
+	tk.MustQuery("show warnings").Check(testutil.RowsWithSep("|", "Warning|1815|Optimizer Hint /*+ INL_MERGE_JOIN(trange, trange2) */ is inapplicable"))
+
+	// group 6
+	// index_merge_join range partition and regular table
+	queryHash = fmt.Sprintf("select /*+ inl_merge_join(trange, tregular4) */ * from trange, tregular4 where trange.a=tregular4.a and trange.a > %v;", x1)
+	queryRegular = fmt.Sprintf("select /*+ inl_merge_join(tregular2, tregular4) */ * from tregular2, tregular4 where tregular2.a=tregular4.a and tregular2.a > %v;", x1)
+	c.Assert(tk.HasPlan(queryHash, "IndexMergeJoin"), IsTrue)
+	tk.MustQuery(queryHash).Sort().Check(tk.MustQuery(queryRegular).Sort().Rows())
+
+	queryHash = fmt.Sprintf("select /*+ inl_merge_join(trange, tregular4) */ * from trange, tregular4 where trange.a=tregular4.a and trange.a > %v and tregular4.a > %v;", x1, x2)
+	queryRegular = fmt.Sprintf("select /*+ inl_merge_join(tregular2, tregular4) */ * from tregular2, tregular4 where tregular2.a=tregular4.a and tregular2.a > %v and tregular4.a > %v;", x1, x2)
+	c.Assert(tk.HasPlan(queryHash, "IndexMergeJoin"), IsTrue)
+	tk.MustQuery(queryHash).Sort().Check(tk.MustQuery(queryRegular).Sort().Rows())
+
+	queryHash = fmt.Sprintf("select /*+ inl_merge_join(trange, tregular4) */ * from trange, tregular4 where trange.a=tregular4.a and trange.a > %v and trange.b > %v;", x1, x2)
+	queryRegular = fmt.Sprintf("select /*+ inl_merge_join(tregular2, tregular4) */ * from tregular2, tregular4 where tregular2.a=tregular4.a and tregular2.a > %v and tregular2.b > %v;", x1, x2)
+	c.Assert(tk.HasPlan(queryHash, "IndexMergeJoin"), IsTrue)
+	tk.MustQuery(queryHash).Sort().Check(tk.MustQuery(queryRegular).Sort().Rows())
+
+	queryHash = fmt.Sprintf("select /*+ inl_merge_join(trange, tregular4) */ * from trange, tregular4 where trange.a=tregular4.a and trange.a > %v and tregular4.b > %v;", x1, x2)
+	queryRegular = fmt.Sprintf("select /*+ inl_merge_join(tregular2, tregular4) */ * from tregular2, tregular4 where tregular2.a=tregular4.a and tregular2.a > %v and tregular4.b > %v;", x1, x2)
+	c.Assert(tk.HasPlan(queryHash, "IndexMergeJoin"), IsTrue)
+	tk.MustQuery(queryHash).Sort().Check(tk.MustQuery(queryRegular).Sort().Rows())
+
+	// group 7
+	// index_hash_join hash partition and hash partition
+	queryHash = fmt.Sprintf("select /*+ inl_hash_join(thash, thash2) */ * from thash, thash2 where thash.a = thash2.a and thash.a in (%v, %v);", x1, x2)
+	queryRegular = fmt.Sprintf("select /*+ inl_hash_join(tregular1, tregular3) */ * from tregular1, tregular3 where tregular1.a = tregular3.a and tregular1.a in (%v, %v);", x1, x2)
+	c.Assert(tk.HasPlan(queryHash, "IndexHashJoin"), IsTrue)
+	tk.MustQuery(queryHash).Sort().Check(tk.MustQuery(queryRegular).Sort().Rows())
+
+	queryHash = fmt.Sprintf("select /*+ inl_hash_join(thash, thash2) */ * from thash, thash2 where thash.a = thash2.a and thash.a in (%v, %v) and thash2.a in (%v, %v);", x1, x2, x3, x4)
+	queryRegular = fmt.Sprintf("select /*+ inl_hash_join(tregular1, tregular3) */ * from tregular1, tregular3 where tregular1.a = tregular3.a and tregular1.a in (%v, %v) and tregular3.a in (%v, %v);", x1, x2, x3, x4)
+	c.Assert(tk.HasPlan(queryHash, "IndexHashJoin"), IsTrue)
+	tk.MustQuery(queryHash).Sort().Check(tk.MustQuery(queryRegular).Sort().Rows())
+
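Every group in this test follows the same oracle pattern: identical rows go into a partitioned table and a regular table, a hint pins the join algorithm, the plan is asserted, and the two result sets are compared order-insensitively. A minimal sketch of the comparison step, with plain string slices standing in for testkit rows:

package main

import (
	"fmt"
	"sort"
)

// sameRows reports whether two result sets are equal ignoring row order,
// which is what tk.MustQuery(q).Sort().Check(expected) effectively asserts.
func sameRows(got, want []string) bool {
	sort.Strings(got)
	sort.Strings(want)
	if len(got) != len(want) {
		return false
	}
	for i := range got {
		if got[i] != want[i] {
			return false
		}
	}
	return true
}

func main() {
	partitioned := []string{"3 x", "1 y", "2 z"}
	regular := []string{"1 y", "2 z", "3 x"}
	fmt.Println(sameRows(partitioned, regular)) // true
}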
+	queryHash = fmt.Sprintf("select /*+ inl_hash_join(thash, thash2) */ * from thash, thash2 where thash.a = thash2.a and thash.a > %v and thash2.b > %v;", x1, x2)
+	queryRegular = fmt.Sprintf("select /*+ inl_hash_join(tregular1, tregular3) */ * from tregular1, tregular3 where tregular1.a = tregular3.a and tregular1.a > %v and tregular3.b > %v;", x1, x2)
+	c.Assert(tk.HasPlan(queryHash, "IndexHashJoin"), IsTrue)
+	tk.MustQuery(queryHash).Sort().Check(tk.MustQuery(queryRegular).Sort().Rows())
+
+	// group 8
+	// index_hash_join hash partition and regular table
+	queryHash = fmt.Sprintf("select /*+ inl_hash_join(thash, tregular3) */ * from thash, tregular3 where thash.a = tregular3.a and thash.a in (%v, %v);", x1, x2)
+	queryRegular = fmt.Sprintf("select /*+ inl_hash_join(tregular1, tregular3) */ * from tregular1, tregular3 where tregular1.a = tregular3.a and tregular1.a in (%v, %v);", x1, x2)
+	c.Assert(tk.HasPlan(queryHash, "IndexHashJoin"), IsTrue)
+	tk.MustQuery(queryHash).Sort().Check(tk.MustQuery(queryRegular).Sort().Rows())
+
+	queryHash = fmt.Sprintf("select /*+ inl_hash_join(thash, tregular3) */ * from thash, tregular3 where thash.a = tregular3.a and thash.a in (%v, %v) and tregular3.a in (%v, %v);", x1, x2, x3, x4)
+	queryRegular = fmt.Sprintf("select /*+ inl_hash_join(tregular1, tregular3) */ * from tregular1, tregular3 where tregular1.a = tregular3.a and tregular1.a in (%v, %v) and tregular3.a in (%v, %v);", x1, x2, x3, x4)
+	c.Assert(tk.HasPlan(queryHash, "IndexHashJoin"), IsTrue)
+	tk.MustQuery(queryHash).Sort().Check(tk.MustQuery(queryRegular).Sort().Rows())
+
+	queryHash = fmt.Sprintf("select /*+ inl_hash_join(thash, tregular3) */ * from thash, tregular3 where thash.a = tregular3.a and thash.a > %v and tregular3.b > %v;", x1, x2)
+	queryRegular = fmt.Sprintf("select /*+ inl_hash_join(tregular1, tregular3) */ * from tregular1, tregular3 where tregular1.a = tregular3.a and tregular1.a > %v and tregular3.b > %v;", x1, x2)
+	c.Assert(tk.HasPlan(queryHash, "IndexHashJoin"), IsTrue)
+	tk.MustQuery(queryHash).Sort().Check(tk.MustQuery(queryRegular).Sort().Rows())
+}
+
 func createTable4DynamicPruneModeTestWithExpression(tk *testkit.TestKit) {
 	tk.MustExec("create table trange(a int, b int) partition by range(a) (partition p0 values less than(3), partition p1 values less than (5), partition p2 values less than(11));")
 	tk.MustExec("create table thash(a int, b int) partition by hash(a) partitions 4;")
@@ -1233,6 +1468,101 @@ func (s *partitionTableSuite) TestDirectReadingWithAgg(c *C) {
 	}
 }
 
+func (s *partitionTableSuite) TestIndexMerge(c *C) {
+	if israce.RaceEnabled {
+		c.Skip("exhaustive types test, skip race test")
+	}
+
+	tk := testkit.NewTestKitWithInit(c, s.store)
+	tk.MustExec("create database test_idx_merge")
+	tk.MustExec("use test_idx_merge")
+	tk.MustExec("set @@tidb_partition_prune_mode = 'dynamic'")
+
+	// list partition table
+	tk.MustExec(`create table tlist(a int, b int, primary key(a) clustered, index idx_b(b)) partition by list(a)(
+		partition p0 values in (1, 2, 3, 4),
+		partition p1 values in (5, 6, 7, 8),
+		partition p2 values in (9, 10, 11, 12));`)
+
+	// range partition table
+	tk.MustExec(`create table trange(a int, b int, primary key(a) clustered, index idx_b(b)) partition by range(a) (
+		partition p0 values less than(300),
+		partition p1 values less than (500),
+		partition p2 values less than(1100));`)
+
+	// hash partition table
+	tk.MustExec(`create table thash(a int, b int, primary key(a) clustered, index idx_b(b)) partition by hash(a) partitions 4;`)
+
+	// regular table
+	tk.MustExec("create table tregular1(a int, b int, primary key(a) clustered)")
+	tk.MustExec("create table tregular2(a int, b int, primary key(a) clustered)")
+
+	// generate some random data to be inserted
+	vals := make([]string, 0, 2000)
+	for i := 0; i < 2000; i++ {
+		vals = append(vals, fmt.Sprintf("(%v, %v)", rand.Intn(1100), rand.Intn(2000)))
+	}
+
+	tk.MustExec("insert ignore into trange values " + strings.Join(vals, ","))
+	tk.MustExec("insert ignore into thash values " + strings.Join(vals, ","))
+	tk.MustExec("insert ignore into tregular1 values " + strings.Join(vals, ","))
+
+	vals = make([]string, 0, 2000)
+	for i := 0; i < 2000; i++ {
+		vals = append(vals, fmt.Sprintf("(%v, %v)", rand.Intn(12)+1, rand.Intn(20)))
+	}
+
+	tk.MustExec("insert ignore into tlist values " + strings.Join(vals, ","))
+	tk.MustExec("insert ignore into tregular2 values " + strings.Join(vals, ","))
+
+	// test range partition
+	for i := 0; i < 100; i++ {
+		x1 := rand.Intn(1099)
+		x2 := rand.Intn(1099)
+
+		queryPartition1 := fmt.Sprintf("select /*+ use_index_merge(trange) */ * from trange where a > %v or b < %v;", x1, x2)
+		queryRegular1 := fmt.Sprintf("select /*+ use_index_merge(tregular1) */ * from tregular1 where a > %v or b < %v;", x1, x2)
+		c.Assert(tk.HasPlan(queryPartition1, "IndexMerge"), IsTrue) // check if IndexMerge is used
+		tk.MustQuery(queryPartition1).Sort().Check(tk.MustQuery(queryRegular1).Sort().Rows())
+
+		queryPartition2 := fmt.Sprintf("select /*+ use_index_merge(trange) */ * from trange where a > %v or b > %v;", x1, x2)
+		queryRegular2 := fmt.Sprintf("select /*+ use_index_merge(tregular1) */ * from tregular1 where a > %v or b > %v;", x1, x2)
+		c.Assert(tk.HasPlan(queryPartition2, "IndexMerge"), IsTrue) // check if IndexMerge is used
+		tk.MustQuery(queryPartition2).Sort().Check(tk.MustQuery(queryRegular2).Sort().Rows())
+	}
+
+	// test hash partition
+	for i := 0; i < 100; i++ {
+		x1 := rand.Intn(1099)
+		x2 := rand.Intn(1099)
+
+		queryPartition1 := fmt.Sprintf("select /*+ use_index_merge(thash) */ * from thash where a > %v or b < %v;", x1, x2)
+		queryRegular1 := fmt.Sprintf("select /*+ use_index_merge(tregular1) */ * from tregular1 where a > %v or b < %v;", x1, x2)
+		c.Assert(tk.HasPlan(queryPartition1, "IndexMerge"), IsTrue) // check if IndexMerge is used
+		tk.MustQuery(queryPartition1).Sort().Check(tk.MustQuery(queryRegular1).Sort().Rows())
+
+		queryPartition2 := fmt.Sprintf("select /*+ use_index_merge(thash) */ * from thash where a > %v or b > %v;", x1, x2)
+		queryRegular2 := fmt.Sprintf("select /*+ use_index_merge(tregular1) */ * from tregular1 where a > %v or b > %v;", x1, x2)
+		c.Assert(tk.HasPlan(queryPartition2, "IndexMerge"), IsTrue) // check if IndexMerge is used
+		tk.MustQuery(queryPartition2).Sort().Check(tk.MustQuery(queryRegular2).Sort().Rows())
+	}
+
+	// test list partition
+	for i := 0; i < 100; i++ {
+		x1 := rand.Intn(12) + 1
+		x2 := rand.Intn(12) + 1
+		queryPartition1 := fmt.Sprintf("select /*+ use_index_merge(tlist) */ * from tlist where a > %v or b < %v;", x1, x2)
+		queryRegular1 := fmt.Sprintf("select /*+ use_index_merge(tregular2) */ * from tregular2 where a > %v or b < %v;", x1, x2)
+		c.Assert(tk.HasPlan(queryPartition1, "IndexMerge"), IsTrue) // check if IndexMerge is used
+		tk.MustQuery(queryPartition1).Sort().Check(tk.MustQuery(queryRegular1).Sort().Rows())
+
+		queryPartition2 := fmt.Sprintf("select /*+ use_index_merge(tlist) */ * from tlist where a > %v or b > %v;", x1, x2)
+		queryRegular2 := fmt.Sprintf("select /*+ 
use_index_merge(tregular2) */ * from tregular2 where a > %v or b > %v;", x1, x2)
+		c.Assert(tk.HasPlan(queryPartition2, "IndexMerge"), IsTrue) // check if IndexMerge is used
+		tk.MustQuery(queryPartition2).Sort().Check(tk.MustQuery(queryRegular2).Sort().Rows())
+	}
+}
+
 func (s *globalIndexSuite) TestGlobalIndexScan(c *C) {
 	tk := testkit.NewTestKitWithInit(c, s.store)
 	tk.MustExec("drop table if exists p")
@@ -1262,3 +1592,80 @@ func (s *globalIndexSuite) TestIssue21731(c *C) {
 	tk.MustExec("drop table if exists p, t")
 	tk.MustExec("create table t (a int, b int, unique index idx(a)) partition by list columns(b) (partition p0 values in (1), partition p1 values in (2));")
 }
+
+func (s *testSuiteWithData) TestRangePartitionBoundariesEq(c *C) {
+	tk := testkit.NewTestKit(c, s.store)
+
+	tk.MustExec("SET @@tidb_partition_prune_mode = 'dynamic'")
+	tk.MustExec("CREATE DATABASE TestRangePartitionBoundaries")
+	defer tk.MustExec("DROP DATABASE TestRangePartitionBoundaries")
+	tk.MustExec("USE TestRangePartitionBoundaries")
+	tk.MustExec("DROP TABLE IF EXISTS t")
+	tk.MustExec(`CREATE TABLE t
+(a INT, b varchar(255))
+PARTITION BY RANGE (a) (
+ PARTITION p0 VALUES LESS THAN (1000000),
+ PARTITION p1 VALUES LESS THAN (2000000),
+ PARTITION p2 VALUES LESS THAN (3000000));
+`)
+
+	var input []string
+	var output []testOutput
+	s.testData.GetTestCases(c, &input, &output)
+	s.verifyPartitionResult(tk, input, output)
+}
+
+type testOutput struct {
+	SQL  string
+	Plan []string
+	Res  []string
+}
+
+func (s *testSuiteWithData) TestRangePartitionBoundariesNe(c *C) {
+	tk := testkit.NewTestKit(c, s.store)
+
+	tk.MustExec("SET @@tidb_partition_prune_mode = 'dynamic'")
+	tk.MustExec("CREATE DATABASE TestRangePartitionBoundariesNe")
+	defer tk.MustExec("DROP DATABASE TestRangePartitionBoundariesNe")
+	tk.MustExec("USE TestRangePartitionBoundariesNe")
+	tk.MustExec("DROP TABLE IF EXISTS t")
+	tk.MustExec(`CREATE TABLE t
+(a INT, b varchar(255))
+PARTITION BY RANGE (a) (
+ PARTITION p0 VALUES LESS THAN (1),
+ PARTITION p1 VALUES LESS THAN (2),
+ PARTITION p2 VALUES LESS THAN (3),
+ PARTITION p3 VALUES LESS THAN (4),
+ PARTITION p4 VALUES LESS THAN (5),
+ PARTITION p5 VALUES LESS THAN (6),
+ PARTITION p6 VALUES LESS THAN (7))`)
+
+	var input []string
+	var output []testOutput
+	s.testData.GetTestCases(c, &input, &output)
+	s.verifyPartitionResult(tk, input, output)
+}
+
+func (s *testSuiteWithData) verifyPartitionResult(tk *testkit.TestKit, input []string, output []testOutput) {
+	for i, tt := range input {
+		isSelect := false
+		if strings.HasPrefix(strings.ToLower(tt), "select ") {
+			isSelect = true
+		}
+		s.testData.OnRecord(func() {
+			output[i].SQL = tt
+			if isSelect {
+				output[i].Plan = s.testData.ConvertRowsToStrings(tk.UsedPartitions(tt).Rows())
+				output[i].Res = s.testData.ConvertRowsToStrings(tk.MustQuery(tt).Sort().Rows())
+			} else {
+				// to avoid double execution of INSERT (and INSERT does not return anything)
+				output[i].Res = nil
+				output[i].Plan = nil
+			}
+		})
+		if isSelect {
+			tk.UsedPartitions(tt).Check(testkit.Rows(output[i].Plan...))
+		}
+		tk.MayQuery(tt).Sort().Check(testkit.Rows(output[i].Res...))
+	}
+}
diff --git a/executor/seqtest/seq_executor_test.go b/executor/seqtest/seq_executor_test.go
index 061e09dcc1315..bcecfc8d52ad4 100644
--- a/executor/seqtest/seq_executor_test.go
+++ b/executor/seqtest/seq_executor_test.go
@@ -1473,8 +1473,6 @@ func (s *seqTestSuite) TestMaxDeltaSchemaCount(c *C) {
 	tk := testkit.NewTestKit(c, s.store)
 	tk.MustExec("use test")
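verifyPartitionResult above follows the repository's record/replay testdata convention: in record mode the observed plans and results are written into the *_out.json file, otherwise the test asserts against what was recorded. A standalone sketch of that convention, with illustrative names:

package main

import "fmt"

type testCase struct {
	SQL string
	Res []string
}

// run either records the observed output into the case (record mode) or
// asserts the observed output against the recorded one (replay mode).
func run(record bool, tc *testCase, observe func(sql string) []string) error {
	got := observe(tc.SQL)
	if record {
		tc.Res = got
		return nil
	}
	if fmt.Sprint(got) != fmt.Sprint(tc.Res) {
		return fmt.Errorf("%s: got %v, want %v", tc.SQL, got, tc.Res)
	}
	return nil
}

func main() {
	tc := &testCase{SQL: "SELECT * FROM t WHERE a = 0"}
	obs := func(string) []string { return []string{"0 0 Filler..."} }
	_ = run(true, tc, obs)           // record pass
	fmt.Println(run(false, tc, obs)) // replay pass: <nil>
}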
c.Assert(variable.GetMaxDeltaSchemaCount(), Equals, int64(variable.DefTiDBMaxDeltaSchemaCount)) - gvc := domain.GetDomain(tk.Se).GetGlobalVarsCache() - gvc.Disable() tk.MustExec("set @@global.tidb_max_delta_schema_count= -1") tk.MustQuery("show warnings;").Check(testkit.Rows("Warning 1292 Truncated incorrect tidb_max_delta_schema_count value: '-1'")) diff --git a/executor/show_test.go b/executor/show_test.go index a343779245a3f..ea6d6734159b6 100644 --- a/executor/show_test.go +++ b/executor/show_test.go @@ -1102,9 +1102,10 @@ func (s *testSuite5) TestShowBuiltin(c *C) { res := tk.MustQuery("show builtins;") c.Assert(res, NotNil) rows := res.Rows() - c.Assert(268, Equals, len(rows)) + const builtinFuncNum = 269 + c.Assert(builtinFuncNum, Equals, len(rows)) c.Assert("abs", Equals, rows[0][0].(string)) - c.Assert("yearweek", Equals, rows[267][0].(string)) + c.Assert("yearweek", Equals, rows[builtinFuncNum-1][0].(string)) } func (s *testSuite5) TestShowClusterConfig(c *C) { diff --git a/executor/simple.go b/executor/simple.go index 74063b2429c06..7270f12aecdd0 100644 --- a/executor/simple.go +++ b/executor/simple.go @@ -22,7 +22,6 @@ import ( "github.com/ngaut/pools" "github.com/pingcap/errors" - "github.com/pingcap/failpoint" "github.com/pingcap/parser/ast" "github.com/pingcap/parser/auth" "github.com/pingcap/parser/model" @@ -674,26 +673,6 @@ func (e *SimpleExec) executeStartTransactionReadOnlyWithTimestampBound(ctx conte if err != nil { return err } - dom := domain.GetDomain(e.ctx) - m, err := dom.GetSnapshotMeta(e.ctx.GetSessionVars().TxnCtx.StartTS) - if err != nil { - return err - } - staleVer, err := m.GetSchemaVersion() - if err != nil { - return err - } - failpoint.Inject("mockStalenessTxnSchemaVer", func(val failpoint.Value) { - if val.(bool) { - staleVer = e.ctx.GetSessionVars().GetInfoSchema().SchemaMetaVersion() - 1 - } else { - staleVer = e.ctx.GetSessionVars().GetInfoSchema().SchemaMetaVersion() - } - }) - // TODO: currently we directly check the schema version. In future, we can cache the stale infoschema instead. - if e.ctx.GetSessionVars().GetInfoSchema().SchemaMetaVersion() > staleVer { - return errors.New("schema version changed after the staleness startTS") - } // With START TRANSACTION, autocommit remains disabled until you end // the transaction with COMMIT or ROLLBACK. The autocommit mode then diff --git a/executor/stale_txn_test.go b/executor/stale_txn_test.go index 493bda06c5de2..7cf235bd3c0f7 100644 --- a/executor/stale_txn_test.go +++ b/executor/stale_txn_test.go @@ -18,7 +18,6 @@ import ( "time" . 
"github.com/pingcap/check" - "github.com/pingcap/errors" "github.com/pingcap/failpoint" "github.com/pingcap/tidb/ddl/placement" "github.com/pingcap/tidb/store/tikv/oracle" @@ -26,12 +25,6 @@ import ( ) func (s *testStaleTxnSerialSuite) TestExactStalenessTransaction(c *C) { - c.Assert(failpoint.Enable("github.com/pingcap/tidb/executor/mockStalenessTxnSchemaVer", "return(false)"), IsNil) - defer func() { - err := failpoint.Disable("github.com/pingcap/tidb/executor/mockStalenessTxnSchemaVer") - c.Assert(err, IsNil) - }() - testcases := []struct { name string preSQL string @@ -117,8 +110,6 @@ func (s *testStaleTxnSerialSuite) TestExactStalenessTransaction(c *C) { } func (s *testStaleTxnSerialSuite) TestStaleReadKVRequest(c *C) { - c.Assert(failpoint.Enable("github.com/pingcap/tidb/executor/mockStalenessTxnSchemaVer", "return(false)"), IsNil) - defer failpoint.Disable("github.com/pingcap/tidb/executor/mockStalenessTxnSchemaVer") tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") tk.MustExec("drop table if exists t") @@ -155,7 +146,7 @@ func (s *testStaleTxnSerialSuite) TestStaleReadKVRequest(c *C) { failpoint.Enable("github.com/pingcap/tidb/config/injectTxnScope", fmt.Sprintf(`return("%v")`, testcase.zone)) failpoint.Enable("github.com/pingcap/tidb/store/tikv/assertStoreLabels", fmt.Sprintf(`return("%v_%v")`, placement.DCLabelKey, testcase.txnScope)) failpoint.Enable("github.com/pingcap/tidb/store/tikv/assertStaleReadFlag", `return(true)`) - tk.MustExec(`START TRANSACTION READ ONLY WITH TIMESTAMP BOUND EXACT STALENESS '00:00:20';`) + tk.MustExec(`START TRANSACTION READ ONLY WITH TIMESTAMP BOUND EXACT STALENESS '00:00:00';`) tk.MustQuery(testcase.sql) tk.MustExec(`commit`) failpoint.Disable("github.com/pingcap/tidb/config/injectTxnScope") @@ -165,12 +156,6 @@ func (s *testStaleTxnSerialSuite) TestStaleReadKVRequest(c *C) { } func (s *testStaleTxnSerialSuite) TestStalenessAndHistoryRead(c *C) { - c.Assert(failpoint.Enable("github.com/pingcap/tidb/executor/mockStalenessTxnSchemaVer", "return(false)"), IsNil) - defer func() { - err := failpoint.Disable("github.com/pingcap/tidb/executor/mockStalenessTxnSchemaVer") - c.Assert(err, IsNil) - }() - tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") // For mocktikv, safe point is not initialized, we manually insert it for snapshot to use. 
@@ -193,62 +178,7 @@ func (s *testStaleTxnSerialSuite) TestStalenessAndHistoryRead(c *C) { tk.MustExec("commit") } -func (s *testStaleTxnSerialSuite) TestStalenessTransactionSchemaVer(c *C) { - testcases := []struct { - name string - sql string - expectErr error - }{ - { - name: "ddl change before stale txn", - sql: `START TRANSACTION READ ONLY WITH TIMESTAMP BOUND EXACT STALENESS '00:00:03'`, - expectErr: errors.New("schema version changed after the staleness startTS"), - }, - { - name: "ddl change before stale txn", - sql: fmt.Sprintf("START TRANSACTION READ ONLY WITH TIMESTAMP BOUND READ TIMESTAMP '%v'", - time.Now().Truncate(3*time.Second).Format("2006-01-02 15:04:05")), - expectErr: errors.New(".*schema version changed after the staleness startTS.*"), - }, - { - name: "ddl change before stale txn", - sql: `START TRANSACTION READ ONLY WITH TIMESTAMP BOUND EXACT STALENESS '00:00:03'`, - expectErr: nil, - }, - } - tk := testkit.NewTestKitWithInit(c, s.store) - for _, testcase := range testcases { - check := func() { - if testcase.expectErr != nil { - c.Assert(failpoint.Enable("github.com/pingcap/tidb/executor/mockStalenessTxnSchemaVer", "return(true)"), IsNil) - defer func() { - err := failpoint.Disable("github.com/pingcap/tidb/executor/mockStalenessTxnSchemaVer") - c.Assert(err, IsNil) - }() - - } else { - c.Assert(failpoint.Enable("github.com/pingcap/tidb/executor/mockStalenessTxnSchemaVer", "return(false)"), IsNil) - defer func() { - err := failpoint.Disable("github.com/pingcap/tidb/executor/mockStalenessTxnSchemaVer") - c.Assert(err, IsNil) - }() - - } - _, err := tk.Exec(testcase.sql) - if testcase.expectErr != nil { - c.Assert(err, NotNil) - c.Assert(err.Error(), Matches, testcase.expectErr.Error()) - } else { - c.Assert(err, IsNil) - } - } - check() - } -} - func (s *testStaleTxnSerialSuite) TestTimeBoundedStalenessTxn(c *C) { - c.Assert(failpoint.Enable("github.com/pingcap/tidb/executor/mockStalenessTxnSchemaVer", "return(false)"), IsNil) - defer failpoint.Disable("github.com/pingcap/tidb/executor/mockStalenessTxnSchemaVer") tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") tk.MustExec("drop table if exists t") @@ -318,3 +248,22 @@ func (s *testStaleTxnSerialSuite) TestTimeBoundedStalenessTxn(c *C) { failpoint.Disable("github.com/pingcap/tidb/store/tikv/injectSafeTS") } } + +func (s *testStaleTxnSerialSuite) TestStalenessTransactionSchemaVer(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t (id int primary key);") + + schemaVer1 := tk.Se.GetSessionVars().GetInfoSchema().SchemaMetaVersion() + time.Sleep(time.Second) + tk.MustExec("drop table if exists t") + schemaVer2 := tk.Se.GetSessionVars().GetInfoSchema().SchemaMetaVersion() + // confirm schema changed + c.Assert(schemaVer1, Less, schemaVer2) + + tk.MustExec(`START TRANSACTION READ ONLY WITH TIMESTAMP BOUND EXACT STALENESS '00:00:01'`) + schemaVer3 := tk.Se.GetSessionVars().GetInfoSchema().SchemaMetaVersion() + // got an old infoSchema + c.Assert(schemaVer3, Equals, schemaVer1) +} diff --git a/executor/testdata/executor_suite_in.json b/executor/testdata/executor_suite_in.json index 6abd20c740a80..fff3187717f0a 100644 --- a/executor/testdata/executor_suite_in.json +++ b/executor/testdata/executor_suite_in.json @@ -51,5 +51,96 @@ "select count(*) from t as t1 left join t as t2 on t1.c1 = t2.c1 where t1.c1 != NULL", "select * from t as t1 left join t as t2 on t1.c1 = t2.c1 where t1.c1 is not NULL" ] + }, + { + "name": 
"TestRangePartitionBoundariesEq", + "cases": [ + "INSERT INTO t VALUES (999998, '999998 Filler ...'), (999999, '999999 Filler ...'), (1000000, '1000000 Filler ...'), (1000001, '1000001 Filler ...'), (1000002, '1000002 Filler ...')", + "INSERT INTO t VALUES (1999998, '1999998 Filler ...'), (1999999, '1999999 Filler ...'), (2000000, '2000000 Filler ...'), (2000001, '2000001 Filler ...'), (2000002, '2000002 Filler ...')", + "INSERT INTO t VALUES (2999998, '2999998 Filler ...'), (2999999, '2999999 Filler ...')", + "INSERT INTO t VALUES (-2147483648, 'MIN_INT filler...'), (0, '0 Filler...')", + "ANALYZE TABLE t", + "SELECT * FROM t WHERE a = -2147483648", + "SELECT * FROM t WHERE a IN (-2147483648)", + "SELECT * FROM t WHERE a = 0", + "SELECT * FROM t WHERE a IN (0)", + "SELECT * FROM t WHERE a = 999998", + "SELECT * FROM t WHERE a IN (999998)", + "SELECT * FROM t WHERE a = 999999", + "SELECT * FROM t WHERE a IN (999999)", + "SELECT * FROM t WHERE a = 1000000", + "SELECT * FROM t WHERE a IN (1000000)", + "SELECT * FROM t WHERE a = 1000001", + "SELECT * FROM t WHERE a IN (1000001)", + "SELECT * FROM t WHERE a = 1000002", + "SELECT * FROM t WHERE a IN (1000002)", + "SELECT * FROM t WHERE a = 3000000", + "SELECT * FROM t WHERE a IN (3000000)", + "SELECT * FROM t WHERE a = 3000001", + "SELECT * FROM t WHERE a IN (3000001)", + "SELECT * FROM t WHERE a IN (-2147483648, -2147483647)", + "SELECT * FROM t WHERE a IN (-2147483647, -2147483646)", + "SELECT * FROM t WHERE a IN (999997, 999998, 999999)", + "SELECT * FROM t WHERE a IN (999998, 999999, 1000000)", + "SELECT * FROM t WHERE a IN (999999, 1000000, 1000001)", + "SELECT * FROM t WHERE a IN (1000000, 1000001, 1000002)", + "SELECT * FROM t WHERE a IN (1999997, 1999998, 1999999)", + "SELECT * FROM t WHERE a IN (1999998, 1999999, 2000000)", + "SELECT * FROM t WHERE a IN (1999999, 2000000, 2000001)", + "SELECT * FROM t WHERE a IN (2000000, 2000001, 2000002)", + "SELECT * FROM t WHERE a IN (2999997, 2999998, 2999999)", + "SELECT * FROM t WHERE a IN (2999998, 2999999, 3000000)", + "SELECT * FROM t WHERE a IN (2999999, 3000000, 3000001)", + "SELECT * FROM t WHERE a IN (3000000, 3000001, 3000002)" + ] + }, + { + "name": "TestRangePartitionBoundariesNe", + "cases": [ + "INSERT INTO t VALUES (0, '0 Filler...')", + "INSERT INTO t VALUES (1, '1 Filler...')", + "INSERT INTO t VALUES (2, '2 Filler...')", + "INSERT INTO t VALUES (3, '3 Filler...')", + "INSERT INTO t VALUES (4, '4 Filler...')", + "INSERT INTO t VALUES (5, '5 Filler...')", + "INSERT INTO t VALUES (6, '6 Filler...')", + "ANALYZE TABLE t", + "SELECT * FROM t WHERE a != -1", + "SELECT * FROM t WHERE 1 = 1 AND a != -1", + "SELECT * FROM t WHERE a NOT IN (-2, -1)", + "SELECT * FROM t WHERE 1 = 0 OR a = -1", + "SELECT * FROM t WHERE a != 0", + "SELECT * FROM t WHERE 1 = 1 AND a != -1 AND a != 0", + "SELECT * FROM t WHERE a NOT IN (-2, -1, 0)", + "SELECT * FROM t WHERE 1 = 0 OR a = -1 OR a = 0", + "SELECT * FROM t WHERE a != 1", + "SELECT * FROM t WHERE 1 = 1 AND a != -1 AND a != 0 AND a != 1", + "SELECT * FROM t WHERE a NOT IN (-2, -1, 0, 1)", + "SELECT * FROM t WHERE 1 = 0 OR a = -1 OR a = 0 OR a = 1", + "SELECT * FROM t WHERE a != 2", + "SELECT * FROM t WHERE 1 = 1 AND a != -1 AND a != 0 AND a != 1 AND a != 2", + "SELECT * FROM t WHERE a NOT IN (-2, -1, 0, 1, 2)", + "SELECT * FROM t WHERE 1 = 0 OR a = -1 OR a = 0 OR a = 1 OR a = 2", + "SELECT * FROM t WHERE a != 3", + "SELECT * FROM t WHERE 1 = 1 AND a != -1 AND a != 0 AND a != 1 AND a != 2 AND a != 3", + "SELECT * FROM t WHERE a NOT IN (-2, -1, 0, 1, 
2, 3)", + "SELECT * FROM t WHERE 1 = 0 OR a = -1 OR a = 0 OR a = 1 OR a = 2 OR a = 3", + "SELECT * FROM t WHERE a != 4", + "SELECT * FROM t WHERE 1 = 1 AND a != -1 AND a != 0 AND a != 1 AND a != 2 AND a != 3 AND a != 4", + "SELECT * FROM t WHERE a NOT IN (-2, -1, 0, 1, 2, 3, 4)", + "SELECT * FROM t WHERE 1 = 0 OR a = -1 OR a = 0 OR a = 1 OR a = 2 OR a = 3 OR a = 4", + "SELECT * FROM t WHERE a != 5", + "SELECT * FROM t WHERE 1 = 1 AND a != -1 AND a != 0 AND a != 1 AND a != 2 AND a != 3 AND a != 4 AND a != 5", + "SELECT * FROM t WHERE a NOT IN (-2, -1, 0, 1, 2, 3, 4, 5)", + "SELECT * FROM t WHERE 1 = 0 OR a = -1 OR a = 0 OR a = 1 OR a = 2 OR a = 3 OR a = 4 OR a = 5", + "SELECT * FROM t WHERE a != 6", + "SELECT * FROM t WHERE 1 = 1 AND a != -1 AND a != 0 AND a != 1 AND a != 2 AND a != 3 AND a != 4 AND a != 5 AND a != 6", + "SELECT * FROM t WHERE a NOT IN (-2, -1, 0, 1, 2, 3, 4, 5, 6)", + "SELECT * FROM t WHERE 1 = 0 OR a = -1 OR a = 0 OR a = 1 OR a = 2 OR a = 3 OR a = 4 OR a = 5 OR a = 6", + "SELECT * FROM t WHERE a != 7", + "SELECT * FROM t WHERE 1 = 1 AND a != -1 AND a != 0 AND a != 1 AND a != 2 AND a != 3 AND a != 4 AND a != 5 AND a != 6 AND a != 7", + "SELECT * FROM t WHERE a NOT IN (-2, -1, 0, 1, 2, 3, 4, 5, 6, 7)", + "SELECT * FROM t WHERE 1 = 0 OR a = -1 OR a = 0 OR a = 1 OR a = 2 OR a = 3 OR a = 4 OR a = 5 OR a = 6 OR a = 7" + ] } ] diff --git a/executor/testdata/executor_suite_out.json b/executor/testdata/executor_suite_out.json index 2be3c8ea4894f..caa5c4f948966 100644 --- a/executor/testdata/executor_suite_out.json +++ b/executor/testdata/executor_suite_out.json @@ -598,5 +598,802 @@ ] } ] + }, + { + "Name": "TestRangePartitionBoundariesEq", + "Cases": [ + { + "SQL": "INSERT INTO t VALUES (999998, '999998 Filler ...'), (999999, '999999 Filler ...'), (1000000, '1000000 Filler ...'), (1000001, '1000001 Filler ...'), (1000002, '1000002 Filler ...')", + "Plan": null, + "Res": null + }, + { + "SQL": "INSERT INTO t VALUES (1999998, '1999998 Filler ...'), (1999999, '1999999 Filler ...'), (2000000, '2000000 Filler ...'), (2000001, '2000001 Filler ...'), (2000002, '2000002 Filler ...')", + "Plan": null, + "Res": null + }, + { + "SQL": "INSERT INTO t VALUES (2999998, '2999998 Filler ...'), (2999999, '2999999 Filler ...')", + "Plan": null, + "Res": null + }, + { + "SQL": "INSERT INTO t VALUES (-2147483648, 'MIN_INT filler...'), (0, '0 Filler...')", + "Plan": null, + "Res": null + }, + { + "SQL": "ANALYZE TABLE t", + "Plan": null, + "Res": null + }, + { + "SQL": "SELECT * FROM t WHERE a = -2147483648", + "Plan": [ + "p0" + ], + "Res": [ + "-2147483648 MIN_INT filler..." + ] + }, + { + "SQL": "SELECT * FROM t WHERE a IN (-2147483648)", + "Plan": [ + "p0" + ], + "Res": [ + "-2147483648 MIN_INT filler..." + ] + }, + { + "SQL": "SELECT * FROM t WHERE a = 0", + "Plan": [ + "p0" + ], + "Res": [ + "0 0 Filler..." + ] + }, + { + "SQL": "SELECT * FROM t WHERE a IN (0)", + "Plan": [ + "p0" + ], + "Res": [ + "0 0 Filler..." + ] + }, + { + "SQL": "SELECT * FROM t WHERE a = 999998", + "Plan": [ + "p0" + ], + "Res": [ + "999998 999998 Filler ..." + ] + }, + { + "SQL": "SELECT * FROM t WHERE a IN (999998)", + "Plan": [ + "p0" + ], + "Res": [ + "999998 999998 Filler ..." + ] + }, + { + "SQL": "SELECT * FROM t WHERE a = 999999", + "Plan": [ + "p0" + ], + "Res": [ + "999999 999999 Filler ..." + ] + }, + { + "SQL": "SELECT * FROM t WHERE a IN (999999)", + "Plan": [ + "p0" + ], + "Res": [ + "999999 999999 Filler ..." 
+ ] + }, + { + "SQL": "SELECT * FROM t WHERE a = 1000000", + "Plan": [ + "p1" + ], + "Res": [ + "1000000 1000000 Filler ..." + ] + }, + { + "SQL": "SELECT * FROM t WHERE a IN (1000000)", + "Plan": [ + "p1" + ], + "Res": [ + "1000000 1000000 Filler ..." + ] + }, + { + "SQL": "SELECT * FROM t WHERE a = 1000001", + "Plan": [ + "p1" + ], + "Res": [ + "1000001 1000001 Filler ..." + ] + }, + { + "SQL": "SELECT * FROM t WHERE a IN (1000001)", + "Plan": [ + "p1" + ], + "Res": [ + "1000001 1000001 Filler ..." + ] + }, + { + "SQL": "SELECT * FROM t WHERE a = 1000002", + "Plan": [ + "p1" + ], + "Res": [ + "1000002 1000002 Filler ..." + ] + }, + { + "SQL": "SELECT * FROM t WHERE a IN (1000002)", + "Plan": [ + "p1" + ], + "Res": [ + "1000002 1000002 Filler ..." + ] + }, + { + "SQL": "SELECT * FROM t WHERE a = 3000000", + "Plan": [ + "dual" + ], + "Res": null + }, + { + "SQL": "SELECT * FROM t WHERE a IN (3000000)", + "Plan": [ + "dual" + ], + "Res": null + }, + { + "SQL": "SELECT * FROM t WHERE a = 3000001", + "Plan": [ + "dual" + ], + "Res": null + }, + { + "SQL": "SELECT * FROM t WHERE a IN (3000001)", + "Plan": [ + "dual" + ], + "Res": null + }, + { + "SQL": "SELECT * FROM t WHERE a IN (-2147483648, -2147483647)", + "Plan": [ + "p0" + ], + "Res": [ + "-2147483648 MIN_INT filler..." + ] + }, + { + "SQL": "SELECT * FROM t WHERE a IN (-2147483647, -2147483646)", + "Plan": [ + "p0" + ], + "Res": null + }, + { + "SQL": "SELECT * FROM t WHERE a IN (999997, 999998, 999999)", + "Plan": [ + "p0" + ], + "Res": [ + "999998 999998 Filler ...", + "999999 999999 Filler ..." + ] + }, + { + "SQL": "SELECT * FROM t WHERE a IN (999998, 999999, 1000000)", + "Plan": [ + "p0 p1" + ], + "Res": [ + "1000000 1000000 Filler ...", + "999998 999998 Filler ...", + "999999 999999 Filler ..." + ] + }, + { + "SQL": "SELECT * FROM t WHERE a IN (999999, 1000000, 1000001)", + "Plan": [ + "p0 p1" + ], + "Res": [ + "1000000 1000000 Filler ...", + "1000001 1000001 Filler ...", + "999999 999999 Filler ..." + ] + }, + { + "SQL": "SELECT * FROM t WHERE a IN (1000000, 1000001, 1000002)", + "Plan": [ + "p1" + ], + "Res": [ + "1000000 1000000 Filler ...", + "1000001 1000001 Filler ...", + "1000002 1000002 Filler ..." + ] + }, + { + "SQL": "SELECT * FROM t WHERE a IN (1999997, 1999998, 1999999)", + "Plan": [ + "p1" + ], + "Res": [ + "1999998 1999998 Filler ...", + "1999999 1999999 Filler ..." + ] + }, + { + "SQL": "SELECT * FROM t WHERE a IN (1999998, 1999999, 2000000)", + "Plan": [ + "p1 p2" + ], + "Res": [ + "1999998 1999998 Filler ...", + "1999999 1999999 Filler ...", + "2000000 2000000 Filler ..." + ] + }, + { + "SQL": "SELECT * FROM t WHERE a IN (1999999, 2000000, 2000001)", + "Plan": [ + "p1 p2" + ], + "Res": [ + "1999999 1999999 Filler ...", + "2000000 2000000 Filler ...", + "2000001 2000001 Filler ..." + ] + }, + { + "SQL": "SELECT * FROM t WHERE a IN (2000000, 2000001, 2000002)", + "Plan": [ + "p2" + ], + "Res": [ + "2000000 2000000 Filler ...", + "2000001 2000001 Filler ...", + "2000002 2000002 Filler ..." + ] + }, + { + "SQL": "SELECT * FROM t WHERE a IN (2999997, 2999998, 2999999)", + "Plan": [ + "p2" + ], + "Res": [ + "2999998 2999998 Filler ...", + "2999999 2999999 Filler ..." + ] + }, + { + "SQL": "SELECT * FROM t WHERE a IN (2999998, 2999999, 3000000)", + "Plan": [ + "p2" + ], + "Res": [ + "2999998 2999998 Filler ...", + "2999999 2999999 Filler ..." + ] + }, + { + "SQL": "SELECT * FROM t WHERE a IN (2999999, 3000000, 3000001)", + "Plan": [ + "p2" + ], + "Res": [ + "2999999 2999999 Filler ..." 
+ ] + }, + { + "SQL": "SELECT * FROM t WHERE a IN (3000000, 3000001, 3000002)", + "Plan": [ + "dual" + ], + "Res": null + } + ] + }, + { + "Name": "TestRangePartitionBoundariesNe", + "Cases": [ + { + "SQL": "INSERT INTO t VALUES (0, '0 Filler...')", + "Plan": null, + "Res": null + }, + { + "SQL": "INSERT INTO t VALUES (1, '1 Filler...')", + "Plan": null, + "Res": null + }, + { + "SQL": "INSERT INTO t VALUES (2, '2 Filler...')", + "Plan": null, + "Res": null + }, + { + "SQL": "INSERT INTO t VALUES (3, '3 Filler...')", + "Plan": null, + "Res": null + }, + { + "SQL": "INSERT INTO t VALUES (4, '4 Filler...')", + "Plan": null, + "Res": null + }, + { + "SQL": "INSERT INTO t VALUES (5, '5 Filler...')", + "Plan": null, + "Res": null + }, + { + "SQL": "INSERT INTO t VALUES (6, '6 Filler...')", + "Plan": null, + "Res": null + }, + { + "SQL": "ANALYZE TABLE t", + "Plan": null, + "Res": null + }, + { + "SQL": "SELECT * FROM t WHERE a != -1", + "Plan": [ + "all" + ], + "Res": [ + "0 0 Filler...", + "1 1 Filler...", + "2 2 Filler...", + "3 3 Filler...", + "4 4 Filler...", + "5 5 Filler...", + "6 6 Filler..." + ] + }, + { + "SQL": "SELECT * FROM t WHERE 1 = 1 AND a != -1", + "Plan": [ + "all" + ], + "Res": [ + "0 0 Filler...", + "1 1 Filler...", + "2 2 Filler...", + "3 3 Filler...", + "4 4 Filler...", + "5 5 Filler...", + "6 6 Filler..." + ] + }, + { + "SQL": "SELECT * FROM t WHERE a NOT IN (-2, -1)", + "Plan": [ + "all" + ], + "Res": [ + "0 0 Filler...", + "1 1 Filler...", + "2 2 Filler...", + "3 3 Filler...", + "4 4 Filler...", + "5 5 Filler...", + "6 6 Filler..." + ] + }, + { + "SQL": "SELECT * FROM t WHERE 1 = 0 OR a = -1", + "Plan": [ + "p0" + ], + "Res": null + }, + { + "SQL": "SELECT * FROM t WHERE a != 0", + "Plan": [ + "all" + ], + "Res": [ + "1 1 Filler...", + "2 2 Filler...", + "3 3 Filler...", + "4 4 Filler...", + "5 5 Filler...", + "6 6 Filler..." + ] + }, + { + "SQL": "SELECT * FROM t WHERE 1 = 1 AND a != -1 AND a != 0", + "Plan": [ + "all" + ], + "Res": [ + "1 1 Filler...", + "2 2 Filler...", + "3 3 Filler...", + "4 4 Filler...", + "5 5 Filler...", + "6 6 Filler..." + ] + }, + { + "SQL": "SELECT * FROM t WHERE a NOT IN (-2, -1, 0)", + "Plan": [ + "all" + ], + "Res": [ + "1 1 Filler...", + "2 2 Filler...", + "3 3 Filler...", + "4 4 Filler...", + "5 5 Filler...", + "6 6 Filler..." + ] + }, + { + "SQL": "SELECT * FROM t WHERE 1 = 0 OR a = -1 OR a = 0", + "Plan": [ + "p0" + ], + "Res": [ + "0 0 Filler..." + ] + }, + { + "SQL": "SELECT * FROM t WHERE a != 1", + "Plan": [ + "all" + ], + "Res": [ + "0 0 Filler...", + "2 2 Filler...", + "3 3 Filler...", + "4 4 Filler...", + "5 5 Filler...", + "6 6 Filler..." + ] + }, + { + "SQL": "SELECT * FROM t WHERE 1 = 1 AND a != -1 AND a != 0 AND a != 1", + "Plan": [ + "all" + ], + "Res": [ + "2 2 Filler...", + "3 3 Filler...", + "4 4 Filler...", + "5 5 Filler...", + "6 6 Filler..." + ] + }, + { + "SQL": "SELECT * FROM t WHERE a NOT IN (-2, -1, 0, 1)", + "Plan": [ + "all" + ], + "Res": [ + "2 2 Filler...", + "3 3 Filler...", + "4 4 Filler...", + "5 5 Filler...", + "6 6 Filler..." + ] + }, + { + "SQL": "SELECT * FROM t WHERE 1 = 0 OR a = -1 OR a = 0 OR a = 1", + "Plan": [ + "p0 p1" + ], + "Res": [ + "0 0 Filler...", + "1 1 Filler..." + ] + }, + { + "SQL": "SELECT * FROM t WHERE a != 2", + "Plan": [ + "all" + ], + "Res": [ + "0 0 Filler...", + "1 1 Filler...", + "3 3 Filler...", + "4 4 Filler...", + "5 5 Filler...", + "6 6 Filler..." 
+ ] + }, + { + "SQL": "SELECT * FROM t WHERE 1 = 1 AND a != -1 AND a != 0 AND a != 1 AND a != 2", + "Plan": [ + "all" + ], + "Res": [ + "3 3 Filler...", + "4 4 Filler...", + "5 5 Filler...", + "6 6 Filler..." + ] + }, + { + "SQL": "SELECT * FROM t WHERE a NOT IN (-2, -1, 0, 1, 2)", + "Plan": [ + "all" + ], + "Res": [ + "3 3 Filler...", + "4 4 Filler...", + "5 5 Filler...", + "6 6 Filler..." + ] + }, + { + "SQL": "SELECT * FROM t WHERE 1 = 0 OR a = -1 OR a = 0 OR a = 1 OR a = 2", + "Plan": [ + "p0 p1 p2" + ], + "Res": [ + "0 0 Filler...", + "1 1 Filler...", + "2 2 Filler..." + ] + }, + { + "SQL": "SELECT * FROM t WHERE a != 3", + "Plan": [ + "all" + ], + "Res": [ + "0 0 Filler...", + "1 1 Filler...", + "2 2 Filler...", + "4 4 Filler...", + "5 5 Filler...", + "6 6 Filler..." + ] + }, + { + "SQL": "SELECT * FROM t WHERE 1 = 1 AND a != -1 AND a != 0 AND a != 1 AND a != 2 AND a != 3", + "Plan": [ + "all" + ], + "Res": [ + "4 4 Filler...", + "5 5 Filler...", + "6 6 Filler..." + ] + }, + { + "SQL": "SELECT * FROM t WHERE a NOT IN (-2, -1, 0, 1, 2, 3)", + "Plan": [ + "all" + ], + "Res": [ + "4 4 Filler...", + "5 5 Filler...", + "6 6 Filler..." + ] + }, + { + "SQL": "SELECT * FROM t WHERE 1 = 0 OR a = -1 OR a = 0 OR a = 1 OR a = 2 OR a = 3", + "Plan": [ + "p0 p1 p2 p3" + ], + "Res": [ + "0 0 Filler...", + "1 1 Filler...", + "2 2 Filler...", + "3 3 Filler..." + ] + }, + { + "SQL": "SELECT * FROM t WHERE a != 4", + "Plan": [ + "all" + ], + "Res": [ + "0 0 Filler...", + "1 1 Filler...", + "2 2 Filler...", + "3 3 Filler...", + "5 5 Filler...", + "6 6 Filler..." + ] + }, + { + "SQL": "SELECT * FROM t WHERE 1 = 1 AND a != -1 AND a != 0 AND a != 1 AND a != 2 AND a != 3 AND a != 4", + "Plan": [ + "all" + ], + "Res": [ + "5 5 Filler...", + "6 6 Filler..." + ] + }, + { + "SQL": "SELECT * FROM t WHERE a NOT IN (-2, -1, 0, 1, 2, 3, 4)", + "Plan": [ + "all" + ], + "Res": [ + "5 5 Filler...", + "6 6 Filler..." + ] + }, + { + "SQL": "SELECT * FROM t WHERE 1 = 0 OR a = -1 OR a = 0 OR a = 1 OR a = 2 OR a = 3 OR a = 4", + "Plan": [ + "p0 p1 p2 p3 p4" + ], + "Res": [ + "0 0 Filler...", + "1 1 Filler...", + "2 2 Filler...", + "3 3 Filler...", + "4 4 Filler..." + ] + }, + { + "SQL": "SELECT * FROM t WHERE a != 5", + "Plan": [ + "all" + ], + "Res": [ + "0 0 Filler...", + "1 1 Filler...", + "2 2 Filler...", + "3 3 Filler...", + "4 4 Filler...", + "6 6 Filler..." + ] + }, + { + "SQL": "SELECT * FROM t WHERE 1 = 1 AND a != -1 AND a != 0 AND a != 1 AND a != 2 AND a != 3 AND a != 4 AND a != 5", + "Plan": [ + "all" + ], + "Res": [ + "6 6 Filler..." + ] + }, + { + "SQL": "SELECT * FROM t WHERE a NOT IN (-2, -1, 0, 1, 2, 3, 4, 5)", + "Plan": [ + "all" + ], + "Res": [ + "6 6 Filler..." + ] + }, + { + "SQL": "SELECT * FROM t WHERE 1 = 0 OR a = -1 OR a = 0 OR a = 1 OR a = 2 OR a = 3 OR a = 4 OR a = 5", + "Plan": [ + "p0 p1 p2 p3 p4 p5" + ], + "Res": [ + "0 0 Filler...", + "1 1 Filler...", + "2 2 Filler...", + "3 3 Filler...", + "4 4 Filler...", + "5 5 Filler..." + ] + }, + { + "SQL": "SELECT * FROM t WHERE a != 6", + "Plan": [ + "all" + ], + "Res": [ + "0 0 Filler...", + "1 1 Filler...", + "2 2 Filler...", + "3 3 Filler...", + "4 4 Filler...", + "5 5 Filler..." 
+ ] + }, + { + "SQL": "SELECT * FROM t WHERE 1 = 1 AND a != -1 AND a != 0 AND a != 1 AND a != 2 AND a != 3 AND a != 4 AND a != 5 AND a != 6", + "Plan": [ + "all" + ], + "Res": null + }, + { + "SQL": "SELECT * FROM t WHERE a NOT IN (-2, -1, 0, 1, 2, 3, 4, 5, 6)", + "Plan": [ + "all" + ], + "Res": null + }, + { + "SQL": "SELECT * FROM t WHERE 1 = 0 OR a = -1 OR a = 0 OR a = 1 OR a = 2 OR a = 3 OR a = 4 OR a = 5 OR a = 6", + "Plan": [ + "all" + ], + "Res": [ + "0 0 Filler...", + "1 1 Filler...", + "2 2 Filler...", + "3 3 Filler...", + "4 4 Filler...", + "5 5 Filler...", + "6 6 Filler..." + ] + }, + { + "SQL": "SELECT * FROM t WHERE a != 7", + "Plan": [ + "all" + ], + "Res": [ + "0 0 Filler...", + "1 1 Filler...", + "2 2 Filler...", + "3 3 Filler...", + "4 4 Filler...", + "5 5 Filler...", + "6 6 Filler..." + ] + }, + { + "SQL": "SELECT * FROM t WHERE 1 = 1 AND a != -1 AND a != 0 AND a != 1 AND a != 2 AND a != 3 AND a != 4 AND a != 5 AND a != 6 AND a != 7", + "Plan": [ + "all" + ], + "Res": null + }, + { + "SQL": "SELECT * FROM t WHERE a NOT IN (-2, -1, 0, 1, 2, 3, 4, 5, 6, 7)", + "Plan": [ + "all" + ], + "Res": null + }, + { + "SQL": "SELECT * FROM t WHERE 1 = 0 OR a = -1 OR a = 0 OR a = 1 OR a = 2 OR a = 3 OR a = 4 OR a = 5 OR a = 6 OR a = 7", + "Plan": [ + "all" + ], + "Res": [ + "0 0 Filler...", + "1 1 Filler...", + "2 2 Filler...", + "3 3 Filler...", + "4 4 Filler...", + "5 5 Filler...", + "6 6 Filler..." + ] + } + ] } ] diff --git a/executor/write_test.go b/executor/write_test.go index b832e52a9935c..cf7a51985a450 100644 --- a/executor/write_test.go +++ b/executor/write_test.go @@ -3930,6 +3930,25 @@ func (s *testSerialSuite) TestIssue20840(c *C) { tk.MustExec("drop table t1") } +func (s *testSerialSuite) TestIssueInsertPrefixIndexForNonUTF8Collation(c *C) { + collate.SetNewCollationEnabledForTest(true) + defer collate.SetNewCollationEnabledForTest(false) + + tk := testkit.NewTestKitWithInit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t1, t2, t3") + tk.MustExec("create table t1 ( c_int int, c_str varchar(40) character set ascii collate ascii_bin, primary key(c_int, c_str(8)) clustered , unique key(c_str))") + tk.MustExec("create table t2 ( c_int int, c_str varchar(40) character set latin1 collate latin1_bin, primary key(c_int, c_str(8)) clustered , unique key(c_str))") + tk.MustExec("insert into t1 values (3, 'fervent brattain')") + tk.MustExec("insert into t2 values (3, 'fervent brattain')") + tk.MustExec("admin check table t1") + tk.MustExec("admin check table t2") + + tk.MustExec("create table t3 (x varchar(40) CHARACTER SET ascii COLLATE ascii_bin, UNIQUE KEY uk(x(4)))") + tk.MustExec("insert into t3 select 'abc '") + tk.MustGetErrCode("insert into t3 select 'abc d'", 1062) +} + func (s *testSerialSuite) TestIssue22496(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") diff --git a/expression/builtin.go b/expression/builtin.go index 9c530f92949e1..a33650eef7b1f 100644 --- a/expression/builtin.go +++ b/expression/builtin.go @@ -687,6 +687,9 @@ var funcs = map[string]functionClass{ ast.Year: &yearFunctionClass{baseFunctionClass{ast.Year, 1, 1}}, ast.YearWeek: &yearWeekFunctionClass{baseFunctionClass{ast.YearWeek, 1, 2}}, ast.LastDay: &lastDayFunctionClass{baseFunctionClass{ast.LastDay, 1, 1}}, + // TSO functions + ast.TiDBBoundedStaleness: &tidbBoundedStalenessFunctionClass{baseFunctionClass{ast.TiDBBoundedStaleness, 2, 2}}, + ast.TiDBParseTso: &tidbParseTsoFunctionClass{baseFunctionClass{ast.TiDBParseTso, 1, 1}}, // string functions 
ast.ASCII: &asciiFunctionClass{baseFunctionClass{ast.ASCII, 1, 1}}, @@ -881,7 +884,6 @@ var funcs = map[string]functionClass{ // This function is used to show tidb-server version info. ast.TiDBVersion: &tidbVersionFunctionClass{baseFunctionClass{ast.TiDBVersion, 0, 0}}, ast.TiDBIsDDLOwner: &tidbIsDDLOwnerFunctionClass{baseFunctionClass{ast.TiDBIsDDLOwner, 0, 0}}, - ast.TiDBParseTso: &tidbParseTsoFunctionClass{baseFunctionClass{ast.TiDBParseTso, 1, 1}}, ast.TiDBDecodePlan: &tidbDecodePlanFunctionClass{baseFunctionClass{ast.TiDBDecodePlan, 1, 1}}, // TiDB Sequence function. diff --git a/expression/builtin_cast.go b/expression/builtin_cast.go index 0f2d5827c91d9..9d6becab44cbe 100644 --- a/expression/builtin_cast.go +++ b/expression/builtin_cast.go @@ -1804,6 +1804,7 @@ func BuildCastFunction4Union(ctx sessionctx.Context, expr Expression, tp *types. // BuildCastFunction builds a CAST ScalarFunction from the Expression. func BuildCastFunction(ctx sessionctx.Context, expr Expression, tp *types.FieldType) (res Expression) { + expr = TryPushCastIntoControlFunctionForHybridType(ctx, expr, tp) var fc functionClass switch tp.EvalType() { case types.ETInt: @@ -1983,3 +1984,92 @@ func WrapWithCastAsJSON(ctx sessionctx.Context, expr Expression) Expression { } return BuildCastFunction(ctx, expr, tp) } + +// TryPushCastIntoControlFunctionForHybridType tries to push a cast into a control function for a hybrid type. +// If necessary, it rebuilds the control function with the changed args. +// When a control function outputs a hybrid type, the result may be used as a numeric value in subsequent calculations, +// so we should perform the `Cast` early to avoid calculating with the wrong type. +// For example, in the condition `if(1, e, 'a') = 1`, the `if` function outputs `e`, which is then compared with `1`. +// If the eval type is ETString, the comparison gives a wrong result, so we rewrite the condition to +// `IfInt(1, cast(e as int), cast('a' as int)) = 1` to get the correct result. +func TryPushCastIntoControlFunctionForHybridType(ctx sessionctx.Context, expr Expression, tp *types.FieldType) (res Expression) { + sf, ok := expr.(*ScalarFunction) + if !ok { + return expr + } + + var wrapCastFunc func(ctx sessionctx.Context, expr Expression) Expression + switch tp.EvalType() { + case types.ETInt: + wrapCastFunc = WrapWithCastAsInt + case types.ETReal: + wrapCastFunc = WrapWithCastAsReal + default: + return expr + } + + isHybrid := func(ft *types.FieldType) bool { + // TODO: be compatible with MySQL control functions using bit type. See
issue 24725 + return ft.Hybrid() && ft.Tp != mysql.TypeBit + } + + args := sf.GetArgs() + switch sf.FuncName.L { + case ast.If: + if isHybrid(args[1].GetType()) || isHybrid(args[2].GetType()) { + args[1] = wrapCastFunc(ctx, args[1]) + args[2] = wrapCastFunc(ctx, args[2]) + f, err := funcs[ast.If].getFunction(ctx, args) + if err != nil { + return expr + } + sf.RetType, sf.Function = f.getRetTp(), f + return sf + } + case ast.Case: + hasHybrid := false + for i := 0; i < len(args)-1; i += 2 { + hasHybrid = hasHybrid || isHybrid(args[i+1].GetType()) + } + if len(args)%2 == 1 { + hasHybrid = hasHybrid || isHybrid(args[len(args)-1].GetType()) + } + if !hasHybrid { + return expr + } + + for i := 0; i < len(args)-1; i += 2 { + args[i+1] = wrapCastFunc(ctx, args[i+1]) + } + if len(args)%2 == 1 { + args[len(args)-1] = wrapCastFunc(ctx, args[len(args)-1]) + } + f, err := funcs[ast.Case].getFunction(ctx, args) + if err != nil { + return expr + } + sf.RetType, sf.Function = f.getRetTp(), f + return sf + case ast.Elt: + hasHybrid := false + for i := 1; i < len(args); i++ { + hasHybrid = hasHybrid || isHybrid(args[i].GetType()) + } + if !hasHybrid { + return expr + } + + for i := 1; i < len(args); i++ { + args[i] = wrapCastFunc(ctx, args[i]) + } + f, err := funcs[ast.Elt].getFunction(ctx, args) + if err != nil { + return expr + } + sf.RetType, sf.Function = f.getRetTp(), f + return sf + default: + return expr + } + return expr +} diff --git a/expression/builtin_time.go b/expression/builtin_time.go index 1d52cf6adc2c3..67413dab11374 100644 --- a/expression/builtin_time.go +++ b/expression/builtin_time.go @@ -27,6 +27,7 @@ import ( "github.com/cznic/mathutil" "github.com/pingcap/errors" + "github.com/pingcap/failpoint" "github.com/pingcap/parser/ast" "github.com/pingcap/parser/mysql" "github.com/pingcap/parser/terror" @@ -2768,20 +2769,20 @@ func (b *builtinExtractDurationSig) evalInt(row chunk.Row) (int64, bool, error) return res, err != nil, err } -// baseDateArithmitical is the base class for all "builtinAddDateXXXSig" and "builtinSubDateXXXSig", +// baseDateArithmetical is the base class for all "builtinAddDateXXXSig" and "builtinSubDateXXXSig", // which provides parameter getter and date arithmetical calculate functions. -type baseDateArithmitical struct { +type baseDateArithmetical struct { // intervalRegexp is "*Regexp" used to extract string interval for "DAY" unit. 
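The wrapping loops in TryPushCastIntoControlFunctionForHybridType depend on the argument layout of each control function class. A short illustrative helper (not part of the PR; layouts assumed from the loops above: CASE is passed [when1, then1, when2, then2, ..., else?] and ELT is passed [n, str1, str2, ...]) that returns the positions a cast is pushed into:

// castPositions mirrors the index arithmetic of the Case and Elt branches:
// for CASE, only the THEN results (odd indexes) plus the optional trailing
// ELSE are wrapped; for ELT, everything except the selector args[0] is.
func castPositions(funcName string, argCount int) []int {
	var idx []int
	switch funcName {
	case "case":
		for i := 0; i < argCount-1; i += 2 {
			idx = append(idx, i+1) // THEN results
		}
		if argCount%2 == 1 {
			idx = append(idx, argCount-1) // trailing ELSE
		}
	case "elt":
		for i := 1; i < argCount; i++ {
			idx = append(idx, i) // all candidate results
		}
	}
	return idx
}

For example, castPositions("case", 5) returns [1 3 4], matching CASE WHEN c1 THEN r1 WHEN c2 THEN r2 ELSE r3 END, where only r1, r2, and r3 receive the pushed cast.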
intervalRegexp *regexp.Regexp } -func newDateArighmeticalUtil() baseDateArithmitical { - return baseDateArithmitical{ +func newDateArighmeticalUtil() baseDateArithmetical { + return baseDateArithmetical{ intervalRegexp: regexp.MustCompile(`-?[\d]+`), } } -func (du *baseDateArithmitical) getDateFromString(ctx sessionctx.Context, args []Expression, row chunk.Row, unit string) (types.Time, bool, error) { +func (du *baseDateArithmetical) getDateFromString(ctx sessionctx.Context, args []Expression, row chunk.Row, unit string) (types.Time, bool, error) { dateStr, isNull, err := args[0].EvalString(ctx, row) if isNull || err != nil { return types.ZeroTime, true, err @@ -2806,7 +2807,7 @@ func (du *baseDateArithmitical) getDateFromString(ctx sessionctx.Context, args [ return date, false, handleInvalidTimeError(ctx, err) } -func (du *baseDateArithmitical) getDateFromInt(ctx sessionctx.Context, args []Expression, row chunk.Row, unit string) (types.Time, bool, error) { +func (du *baseDateArithmetical) getDateFromInt(ctx sessionctx.Context, args []Expression, row chunk.Row, unit string) (types.Time, bool, error) { dateInt, isNull, err := args[0].EvalInt(ctx, row) if isNull || err != nil { return types.ZeroTime, true, err @@ -2826,7 +2827,7 @@ func (du *baseDateArithmitical) getDateFromInt(ctx sessionctx.Context, args []Ex return date, false, nil } -func (du *baseDateArithmitical) getDateFromDatetime(ctx sessionctx.Context, args []Expression, row chunk.Row, unit string) (types.Time, bool, error) { +func (du *baseDateArithmetical) getDateFromDatetime(ctx sessionctx.Context, args []Expression, row chunk.Row, unit string) (types.Time, bool, error) { date, isNull, err := args[0].EvalTime(ctx, row) if isNull || err != nil { return types.ZeroTime, true, err @@ -2838,7 +2839,7 @@ func (du *baseDateArithmitical) getDateFromDatetime(ctx sessionctx.Context, args return date, false, nil } -func (du *baseDateArithmitical) getIntervalFromString(ctx sessionctx.Context, args []Expression, row chunk.Row, unit string) (string, bool, error) { +func (du *baseDateArithmetical) getIntervalFromString(ctx sessionctx.Context, args []Expression, row chunk.Row, unit string) (string, bool, error) { interval, isNull, err := args[1].EvalString(ctx, row) if isNull || err != nil { return "", true, err @@ -2856,7 +2857,7 @@ func (du *baseDateArithmitical) getIntervalFromString(ctx sessionctx.Context, ar return interval, false, nil } -func (du *baseDateArithmitical) getIntervalFromDecimal(ctx sessionctx.Context, args []Expression, row chunk.Row, unit string) (string, bool, error) { +func (du *baseDateArithmetical) getIntervalFromDecimal(ctx sessionctx.Context, args []Expression, row chunk.Row, unit string) (string, bool, error) { decimal, isNull, err := args[1].EvalDecimal(ctx, row) if isNull || err != nil { return "", true, err @@ -2910,7 +2911,7 @@ func (du *baseDateArithmitical) getIntervalFromDecimal(ctx sessionctx.Context, a return interval, false, nil } -func (du *baseDateArithmitical) getIntervalFromInt(ctx sessionctx.Context, args []Expression, row chunk.Row, unit string) (string, bool, error) { +func (du *baseDateArithmetical) getIntervalFromInt(ctx sessionctx.Context, args []Expression, row chunk.Row, unit string) (string, bool, error) { interval, isNull, err := args[1].EvalInt(ctx, row) if isNull || err != nil { return "", true, err @@ -2918,7 +2919,7 @@ func (du *baseDateArithmitical) getIntervalFromInt(ctx sessionctx.Context, args return strconv.FormatInt(interval, 10), false, nil } -func (du *baseDateArithmitical) 
getIntervalFromReal(ctx sessionctx.Context, args []Expression, row chunk.Row, unit string) (string, bool, error) { +func (du *baseDateArithmetical) getIntervalFromReal(ctx sessionctx.Context, args []Expression, row chunk.Row, unit string) (string, bool, error) { interval, isNull, err := args[1].EvalReal(ctx, row) if isNull || err != nil { return "", true, err @@ -2926,7 +2927,7 @@ func (du *baseDateArithmitical) getIntervalFromReal(ctx sessionctx.Context, args return strconv.FormatFloat(interval, 'f', args[1].GetType().Decimal, 64), false, nil } -func (du *baseDateArithmitical) add(ctx sessionctx.Context, date types.Time, interval string, unit string) (types.Time, bool, error) { +func (du *baseDateArithmetical) add(ctx sessionctx.Context, date types.Time, interval string, unit string) (types.Time, bool, error) { year, month, day, nano, err := types.ParseDurationValue(unit, interval) if err := handleInvalidTimeError(ctx, err); err != nil { return types.ZeroTime, true, err @@ -2934,7 +2935,7 @@ func (du *baseDateArithmitical) add(ctx sessionctx.Context, date types.Time, int return du.addDate(ctx, date, year, month, day, nano) } -func (du *baseDateArithmitical) addDate(ctx sessionctx.Context, date types.Time, year, month, day, nano int64) (types.Time, bool, error) { +func (du *baseDateArithmetical) addDate(ctx sessionctx.Context, date types.Time, year, month, day, nano int64) (types.Time, bool, error) { goTime, err := date.GoTime(time.UTC) if err := handleInvalidTimeError(ctx, err); err != nil { return types.ZeroTime, true, err @@ -2971,7 +2972,7 @@ func (du *baseDateArithmitical) addDate(ctx sessionctx.Context, date types.Time, return date, false, nil } -func (du *baseDateArithmitical) addDuration(ctx sessionctx.Context, d types.Duration, interval string, unit string) (types.Duration, bool, error) { +func (du *baseDateArithmetical) addDuration(ctx sessionctx.Context, d types.Duration, interval string, unit string) (types.Duration, bool, error) { dur, err := types.ExtractDurationValue(unit, interval) if err != nil { return types.ZeroDuration, true, handleInvalidTimeError(ctx, err) @@ -2983,7 +2984,7 @@ func (du *baseDateArithmitical) addDuration(ctx sessionctx.Context, d types.Dura return retDur, false, nil } -func (du *baseDateArithmitical) subDuration(ctx sessionctx.Context, d types.Duration, interval string, unit string) (types.Duration, bool, error) { +func (du *baseDateArithmetical) subDuration(ctx sessionctx.Context, d types.Duration, interval string, unit string) (types.Duration, bool, error) { dur, err := types.ExtractDurationValue(unit, interval) if err != nil { return types.ZeroDuration, true, handleInvalidTimeError(ctx, err) @@ -2995,7 +2996,7 @@ func (du *baseDateArithmitical) subDuration(ctx sessionctx.Context, d types.Dura return retDur, false, nil } -func (du *baseDateArithmitical) sub(ctx sessionctx.Context, date types.Time, interval string, unit string) (types.Time, bool, error) { +func (du *baseDateArithmetical) sub(ctx sessionctx.Context, date types.Time, interval string, unit string) (types.Time, bool, error) { year, month, day, nano, err := types.ParseDurationValue(unit, interval) if err := handleInvalidTimeError(ctx, err); err != nil { return types.ZeroTime, true, err @@ -3003,7 +3004,7 @@ func (du *baseDateArithmitical) sub(ctx sessionctx.Context, date types.Time, int return du.addDate(ctx, date, -year, -month, -day, -nano) } -func (du *baseDateArithmitical) vecGetDateFromInt(b *baseBuiltinFunc, input *chunk.Chunk, unit string, result *chunk.Column) error { +func (du 
*baseDateArithmetical) vecGetDateFromInt(b *baseBuiltinFunc, input *chunk.Chunk, unit string, result *chunk.Column) error { n := input.NumRows() buf, err := b.bufAllocator.get(types.ETInt, n) if err != nil { @@ -3045,7 +3046,7 @@ func (du *baseDateArithmitical) vecGetDateFromInt(b *baseBuiltinFunc, input *chu return nil } -func (du *baseDateArithmitical) vecGetDateFromString(b *baseBuiltinFunc, input *chunk.Chunk, unit string, result *chunk.Column) error { +func (du *baseDateArithmetical) vecGetDateFromString(b *baseBuiltinFunc, input *chunk.Chunk, unit string, result *chunk.Column) error { n := input.NumRows() buf, err := b.bufAllocator.get(types.ETString, n) if err != nil { @@ -3089,7 +3090,7 @@ func (du *baseDateArithmitical) vecGetDateFromString(b *baseBuiltinFunc, input * return nil } -func (du *baseDateArithmitical) vecGetDateFromDatetime(b *baseBuiltinFunc, input *chunk.Chunk, unit string, result *chunk.Column) error { +func (du *baseDateArithmetical) vecGetDateFromDatetime(b *baseBuiltinFunc, input *chunk.Chunk, unit string, result *chunk.Column) error { n := input.NumRows() result.ResizeTime(n, false) if err := b.args[0].VecEvalTime(b.ctx, input, result); err != nil { @@ -3110,7 +3111,7 @@ func (du *baseDateArithmitical) vecGetDateFromDatetime(b *baseBuiltinFunc, input return nil } -func (du *baseDateArithmitical) vecGetIntervalFromString(b *baseBuiltinFunc, input *chunk.Chunk, unit string, result *chunk.Column) error { +func (du *baseDateArithmetical) vecGetIntervalFromString(b *baseBuiltinFunc, input *chunk.Chunk, unit string, result *chunk.Column) error { n := input.NumRows() buf, err := b.bufAllocator.get(types.ETString, n) if err != nil { @@ -3147,7 +3148,7 @@ func (du *baseDateArithmitical) vecGetIntervalFromString(b *baseBuiltinFunc, inp return nil } -func (du *baseDateArithmitical) vecGetIntervalFromDecimal(b *baseBuiltinFunc, input *chunk.Chunk, unit string, result *chunk.Column) error { +func (du *baseDateArithmetical) vecGetIntervalFromDecimal(b *baseBuiltinFunc, input *chunk.Chunk, unit string, result *chunk.Column) error { n := input.NumRows() buf, err := b.bufAllocator.get(types.ETDecimal, n) if err != nil { @@ -3248,7 +3249,7 @@ func (du *baseDateArithmitical) vecGetIntervalFromDecimal(b *baseBuiltinFunc, in return nil } -func (du *baseDateArithmitical) vecGetIntervalFromInt(b *baseBuiltinFunc, input *chunk.Chunk, unit string, result *chunk.Column) error { +func (du *baseDateArithmetical) vecGetIntervalFromInt(b *baseBuiltinFunc, input *chunk.Chunk, unit string, result *chunk.Column) error { n := input.NumRows() buf, err := b.bufAllocator.get(types.ETInt, n) if err != nil { @@ -3271,7 +3272,7 @@ func (du *baseDateArithmitical) vecGetIntervalFromInt(b *baseBuiltinFunc, input return nil } -func (du *baseDateArithmitical) vecGetIntervalFromReal(b *baseBuiltinFunc, input *chunk.Chunk, unit string, result *chunk.Column) error { +func (du *baseDateArithmetical) vecGetIntervalFromReal(b *baseBuiltinFunc, input *chunk.Chunk, unit string, result *chunk.Column) error { n := input.NumRows() buf, err := b.bufAllocator.get(types.ETReal, n) if err != nil { @@ -3355,97 +3356,97 @@ func (c *addDateFunctionClass) getFunction(ctx sessionctx.Context, args []Expres case dateEvalTp == types.ETString && intervalEvalTp == types.ETString: sig = &builtinAddDateStringStringSig{ baseBuiltinFunc: bf, - baseDateArithmitical: newDateArighmeticalUtil(), + baseDateArithmetical: newDateArighmeticalUtil(), } sig.setPbCode(tipb.ScalarFuncSig_AddDateStringString) case dateEvalTp == types.ETString && 
intervalEvalTp == types.ETInt: sig = &builtinAddDateStringIntSig{ baseBuiltinFunc: bf, - baseDateArithmitical: newDateArighmeticalUtil(), + baseDateArithmetical: newDateArighmeticalUtil(), } sig.setPbCode(tipb.ScalarFuncSig_AddDateStringInt) case dateEvalTp == types.ETString && intervalEvalTp == types.ETReal: sig = &builtinAddDateStringRealSig{ baseBuiltinFunc: bf, - baseDateArithmitical: newDateArighmeticalUtil(), + baseDateArithmetical: newDateArighmeticalUtil(), } sig.setPbCode(tipb.ScalarFuncSig_AddDateStringReal) case dateEvalTp == types.ETString && intervalEvalTp == types.ETDecimal: sig = &builtinAddDateStringDecimalSig{ baseBuiltinFunc: bf, - baseDateArithmitical: newDateArighmeticalUtil(), + baseDateArithmetical: newDateArighmeticalUtil(), } sig.setPbCode(tipb.ScalarFuncSig_AddDateStringDecimal) case dateEvalTp == types.ETInt && intervalEvalTp == types.ETString: sig = &builtinAddDateIntStringSig{ baseBuiltinFunc: bf, - baseDateArithmitical: newDateArighmeticalUtil(), + baseDateArithmetical: newDateArighmeticalUtil(), } sig.setPbCode(tipb.ScalarFuncSig_AddDateIntString) case dateEvalTp == types.ETInt && intervalEvalTp == types.ETInt: sig = &builtinAddDateIntIntSig{ baseBuiltinFunc: bf, - baseDateArithmitical: newDateArighmeticalUtil(), + baseDateArithmetical: newDateArighmeticalUtil(), } sig.setPbCode(tipb.ScalarFuncSig_AddDateIntInt) case dateEvalTp == types.ETInt && intervalEvalTp == types.ETReal: sig = &builtinAddDateIntRealSig{ baseBuiltinFunc: bf, - baseDateArithmitical: newDateArighmeticalUtil(), + baseDateArithmetical: newDateArighmeticalUtil(), } sig.setPbCode(tipb.ScalarFuncSig_AddDateIntReal) case dateEvalTp == types.ETInt && intervalEvalTp == types.ETDecimal: sig = &builtinAddDateIntDecimalSig{ baseBuiltinFunc: bf, - baseDateArithmitical: newDateArighmeticalUtil(), + baseDateArithmetical: newDateArighmeticalUtil(), } sig.setPbCode(tipb.ScalarFuncSig_AddDateIntDecimal) case dateEvalTp == types.ETDatetime && intervalEvalTp == types.ETString: sig = &builtinAddDateDatetimeStringSig{ baseBuiltinFunc: bf, - baseDateArithmitical: newDateArighmeticalUtil(), + baseDateArithmetical: newDateArighmeticalUtil(), } sig.setPbCode(tipb.ScalarFuncSig_AddDateDatetimeString) case dateEvalTp == types.ETDatetime && intervalEvalTp == types.ETInt: sig = &builtinAddDateDatetimeIntSig{ baseBuiltinFunc: bf, - baseDateArithmitical: newDateArighmeticalUtil(), + baseDateArithmetical: newDateArighmeticalUtil(), } sig.setPbCode(tipb.ScalarFuncSig_AddDateDatetimeInt) case dateEvalTp == types.ETDatetime && intervalEvalTp == types.ETReal: sig = &builtinAddDateDatetimeRealSig{ baseBuiltinFunc: bf, - baseDateArithmitical: newDateArighmeticalUtil(), + baseDateArithmetical: newDateArighmeticalUtil(), } sig.setPbCode(tipb.ScalarFuncSig_AddDateDatetimeReal) case dateEvalTp == types.ETDatetime && intervalEvalTp == types.ETDecimal: sig = &builtinAddDateDatetimeDecimalSig{ baseBuiltinFunc: bf, - baseDateArithmitical: newDateArighmeticalUtil(), + baseDateArithmetical: newDateArighmeticalUtil(), } sig.setPbCode(tipb.ScalarFuncSig_AddDateDatetimeDecimal) case dateEvalTp == types.ETDuration && intervalEvalTp == types.ETString: sig = &builtinAddDateDurationStringSig{ baseBuiltinFunc: bf, - baseDateArithmitical: newDateArighmeticalUtil(), + baseDateArithmetical: newDateArighmeticalUtil(), } sig.setPbCode(tipb.ScalarFuncSig_AddDateDurationString) case dateEvalTp == types.ETDuration && intervalEvalTp == types.ETInt: sig = &builtinAddDateDurationIntSig{ baseBuiltinFunc: bf, - baseDateArithmitical: 
newDateArighmeticalUtil(), + baseDateArithmetical: newDateArighmeticalUtil(), } sig.setPbCode(tipb.ScalarFuncSig_AddDateDurationInt) case dateEvalTp == types.ETDuration && intervalEvalTp == types.ETReal: sig = &builtinAddDateDurationRealSig{ baseBuiltinFunc: bf, - baseDateArithmitical: newDateArighmeticalUtil(), + baseDateArithmetical: newDateArighmeticalUtil(), } sig.setPbCode(tipb.ScalarFuncSig_AddDateDurationReal) case dateEvalTp == types.ETDuration && intervalEvalTp == types.ETDecimal: sig = &builtinAddDateDurationDecimalSig{ baseBuiltinFunc: bf, - baseDateArithmitical: newDateArighmeticalUtil(), + baseDateArithmetical: newDateArighmeticalUtil(), } sig.setPbCode(tipb.ScalarFuncSig_AddDateDurationDecimal) } @@ -3454,11 +3455,11 @@ func (c *addDateFunctionClass) getFunction(ctx sessionctx.Context, args []Expres type builtinAddDateStringStringSig struct { baseBuiltinFunc - baseDateArithmitical + baseDateArithmetical } func (b *builtinAddDateStringStringSig) Clone() builtinFunc { - newSig := &builtinAddDateStringStringSig{baseDateArithmitical: b.baseDateArithmitical} + newSig := &builtinAddDateStringStringSig{baseDateArithmetical: b.baseDateArithmetical} newSig.cloneFrom(&b.baseBuiltinFunc) return newSig } @@ -3487,11 +3488,11 @@ func (b *builtinAddDateStringStringSig) evalTime(row chunk.Row) (types.Time, boo type builtinAddDateStringIntSig struct { baseBuiltinFunc - baseDateArithmitical + baseDateArithmetical } func (b *builtinAddDateStringIntSig) Clone() builtinFunc { - newSig := &builtinAddDateStringIntSig{baseDateArithmitical: b.baseDateArithmitical} + newSig := &builtinAddDateStringIntSig{baseDateArithmetical: b.baseDateArithmetical} newSig.cloneFrom(&b.baseBuiltinFunc) return newSig } @@ -3520,11 +3521,11 @@ func (b *builtinAddDateStringIntSig) evalTime(row chunk.Row) (types.Time, bool, type builtinAddDateStringRealSig struct { baseBuiltinFunc - baseDateArithmitical + baseDateArithmetical } func (b *builtinAddDateStringRealSig) Clone() builtinFunc { - newSig := &builtinAddDateStringRealSig{baseDateArithmitical: b.baseDateArithmitical} + newSig := &builtinAddDateStringRealSig{baseDateArithmetical: b.baseDateArithmetical} newSig.cloneFrom(&b.baseBuiltinFunc) return newSig } @@ -3553,11 +3554,11 @@ func (b *builtinAddDateStringRealSig) evalTime(row chunk.Row) (types.Time, bool, type builtinAddDateStringDecimalSig struct { baseBuiltinFunc - baseDateArithmitical + baseDateArithmetical } func (b *builtinAddDateStringDecimalSig) Clone() builtinFunc { - newSig := &builtinAddDateStringDecimalSig{baseDateArithmitical: b.baseDateArithmitical} + newSig := &builtinAddDateStringDecimalSig{baseDateArithmetical: b.baseDateArithmetical} newSig.cloneFrom(&b.baseBuiltinFunc) return newSig } @@ -3586,11 +3587,11 @@ func (b *builtinAddDateStringDecimalSig) evalTime(row chunk.Row) (types.Time, bo type builtinAddDateIntStringSig struct { baseBuiltinFunc - baseDateArithmitical + baseDateArithmetical } func (b *builtinAddDateIntStringSig) Clone() builtinFunc { - newSig := &builtinAddDateIntStringSig{baseDateArithmitical: b.baseDateArithmitical} + newSig := &builtinAddDateIntStringSig{baseDateArithmetical: b.baseDateArithmetical} newSig.cloneFrom(&b.baseBuiltinFunc) return newSig } @@ -3619,11 +3620,11 @@ func (b *builtinAddDateIntStringSig) evalTime(row chunk.Row) (types.Time, bool, type builtinAddDateIntIntSig struct { baseBuiltinFunc - baseDateArithmitical + baseDateArithmetical } func (b *builtinAddDateIntIntSig) Clone() builtinFunc { - newSig := &builtinAddDateIntIntSig{baseDateArithmitical: 
b.baseDateArithmitical} + newSig := &builtinAddDateIntIntSig{baseDateArithmetical: b.baseDateArithmetical} newSig.cloneFrom(&b.baseBuiltinFunc) return newSig } @@ -3652,11 +3653,11 @@ func (b *builtinAddDateIntIntSig) evalTime(row chunk.Row) (types.Time, bool, err type builtinAddDateIntRealSig struct { baseBuiltinFunc - baseDateArithmitical + baseDateArithmetical } func (b *builtinAddDateIntRealSig) Clone() builtinFunc { - newSig := &builtinAddDateIntRealSig{baseDateArithmitical: b.baseDateArithmitical} + newSig := &builtinAddDateIntRealSig{baseDateArithmetical: b.baseDateArithmetical} newSig.cloneFrom(&b.baseBuiltinFunc) return newSig } @@ -3685,11 +3686,11 @@ func (b *builtinAddDateIntRealSig) evalTime(row chunk.Row) (types.Time, bool, er type builtinAddDateIntDecimalSig struct { baseBuiltinFunc - baseDateArithmitical + baseDateArithmetical } func (b *builtinAddDateIntDecimalSig) Clone() builtinFunc { - newSig := &builtinAddDateIntDecimalSig{baseDateArithmitical: b.baseDateArithmitical} + newSig := &builtinAddDateIntDecimalSig{baseDateArithmetical: b.baseDateArithmetical} newSig.cloneFrom(&b.baseBuiltinFunc) return newSig } @@ -3718,11 +3719,11 @@ func (b *builtinAddDateIntDecimalSig) evalTime(row chunk.Row) (types.Time, bool, type builtinAddDateDatetimeStringSig struct { baseBuiltinFunc - baseDateArithmitical + baseDateArithmetical } func (b *builtinAddDateDatetimeStringSig) Clone() builtinFunc { - newSig := &builtinAddDateDatetimeStringSig{baseDateArithmitical: b.baseDateArithmitical} + newSig := &builtinAddDateDatetimeStringSig{baseDateArithmetical: b.baseDateArithmetical} newSig.cloneFrom(&b.baseBuiltinFunc) return newSig } @@ -3751,11 +3752,11 @@ func (b *builtinAddDateDatetimeStringSig) evalTime(row chunk.Row) (types.Time, b type builtinAddDateDatetimeIntSig struct { baseBuiltinFunc - baseDateArithmitical + baseDateArithmetical } func (b *builtinAddDateDatetimeIntSig) Clone() builtinFunc { - newSig := &builtinAddDateDatetimeIntSig{baseDateArithmitical: b.baseDateArithmitical} + newSig := &builtinAddDateDatetimeIntSig{baseDateArithmetical: b.baseDateArithmetical} newSig.cloneFrom(&b.baseBuiltinFunc) return newSig } @@ -3784,11 +3785,11 @@ func (b *builtinAddDateDatetimeIntSig) evalTime(row chunk.Row) (types.Time, bool type builtinAddDateDatetimeRealSig struct { baseBuiltinFunc - baseDateArithmitical + baseDateArithmetical } func (b *builtinAddDateDatetimeRealSig) Clone() builtinFunc { - newSig := &builtinAddDateDatetimeRealSig{baseDateArithmitical: b.baseDateArithmitical} + newSig := &builtinAddDateDatetimeRealSig{baseDateArithmetical: b.baseDateArithmetical} newSig.cloneFrom(&b.baseBuiltinFunc) return newSig } @@ -3817,11 +3818,11 @@ func (b *builtinAddDateDatetimeRealSig) evalTime(row chunk.Row) (types.Time, boo type builtinAddDateDatetimeDecimalSig struct { baseBuiltinFunc - baseDateArithmitical + baseDateArithmetical } func (b *builtinAddDateDatetimeDecimalSig) Clone() builtinFunc { - newSig := &builtinAddDateDatetimeDecimalSig{baseDateArithmitical: b.baseDateArithmitical} + newSig := &builtinAddDateDatetimeDecimalSig{baseDateArithmetical: b.baseDateArithmetical} newSig.cloneFrom(&b.baseBuiltinFunc) return newSig } @@ -3850,11 +3851,11 @@ func (b *builtinAddDateDatetimeDecimalSig) evalTime(row chunk.Row) (types.Time, type builtinAddDateDurationStringSig struct { baseBuiltinFunc - baseDateArithmitical + baseDateArithmetical } func (b *builtinAddDateDurationStringSig) Clone() builtinFunc { - newSig := &builtinAddDateDurationStringSig{baseDateArithmitical: b.baseDateArithmitical} + 
newSig := &builtinAddDateDurationStringSig{baseDateArithmetical: b.baseDateArithmetical} newSig.cloneFrom(&b.baseBuiltinFunc) return newSig } @@ -3881,11 +3882,11 @@ func (b *builtinAddDateDurationStringSig) evalDuration(row chunk.Row) (types.Dur type builtinAddDateDurationIntSig struct { baseBuiltinFunc - baseDateArithmitical + baseDateArithmetical } func (b *builtinAddDateDurationIntSig) Clone() builtinFunc { - newSig := &builtinAddDateDurationIntSig{baseDateArithmitical: b.baseDateArithmitical} + newSig := &builtinAddDateDurationIntSig{baseDateArithmetical: b.baseDateArithmetical} newSig.cloneFrom(&b.baseBuiltinFunc) return newSig } @@ -3911,11 +3912,11 @@ func (b *builtinAddDateDurationIntSig) evalDuration(row chunk.Row) (types.Durati type builtinAddDateDurationDecimalSig struct { baseBuiltinFunc - baseDateArithmitical + baseDateArithmetical } func (b *builtinAddDateDurationDecimalSig) Clone() builtinFunc { - newSig := &builtinAddDateDurationDecimalSig{baseDateArithmitical: b.baseDateArithmitical} + newSig := &builtinAddDateDurationDecimalSig{baseDateArithmetical: b.baseDateArithmetical} newSig.cloneFrom(&b.baseBuiltinFunc) return newSig } @@ -3941,11 +3942,11 @@ func (b *builtinAddDateDurationDecimalSig) evalDuration(row chunk.Row) (types.Du type builtinAddDateDurationRealSig struct { baseBuiltinFunc - baseDateArithmitical + baseDateArithmetical } func (b *builtinAddDateDurationRealSig) Clone() builtinFunc { - newSig := &builtinAddDateDurationRealSig{baseDateArithmitical: b.baseDateArithmitical} + newSig := &builtinAddDateDurationRealSig{baseDateArithmetical: b.baseDateArithmetical} newSig.cloneFrom(&b.baseBuiltinFunc) return newSig } @@ -4029,97 +4030,97 @@ func (c *subDateFunctionClass) getFunction(ctx sessionctx.Context, args []Expres case dateEvalTp == types.ETString && intervalEvalTp == types.ETString: sig = &builtinSubDateStringStringSig{ baseBuiltinFunc: bf, - baseDateArithmitical: newDateArighmeticalUtil(), + baseDateArithmetical: newDateArighmeticalUtil(), } sig.setPbCode(tipb.ScalarFuncSig_SubDateStringString) case dateEvalTp == types.ETString && intervalEvalTp == types.ETInt: sig = &builtinSubDateStringIntSig{ baseBuiltinFunc: bf, - baseDateArithmitical: newDateArighmeticalUtil(), + baseDateArithmetical: newDateArighmeticalUtil(), } sig.setPbCode(tipb.ScalarFuncSig_SubDateStringInt) case dateEvalTp == types.ETString && intervalEvalTp == types.ETReal: sig = &builtinSubDateStringRealSig{ baseBuiltinFunc: bf, - baseDateArithmitical: newDateArighmeticalUtil(), + baseDateArithmetical: newDateArighmeticalUtil(), } sig.setPbCode(tipb.ScalarFuncSig_SubDateStringReal) case dateEvalTp == types.ETString && intervalEvalTp == types.ETDecimal: sig = &builtinSubDateStringDecimalSig{ baseBuiltinFunc: bf, - baseDateArithmitical: newDateArighmeticalUtil(), + baseDateArithmetical: newDateArighmeticalUtil(), } sig.setPbCode(tipb.ScalarFuncSig_SubDateStringDecimal) case dateEvalTp == types.ETInt && intervalEvalTp == types.ETString: sig = &builtinSubDateIntStringSig{ baseBuiltinFunc: bf, - baseDateArithmitical: newDateArighmeticalUtil(), + baseDateArithmetical: newDateArighmeticalUtil(), } sig.setPbCode(tipb.ScalarFuncSig_SubDateIntString) case dateEvalTp == types.ETInt && intervalEvalTp == types.ETInt: sig = &builtinSubDateIntIntSig{ baseBuiltinFunc: bf, - baseDateArithmitical: newDateArighmeticalUtil(), + baseDateArithmetical: newDateArighmeticalUtil(), } sig.setPbCode(tipb.ScalarFuncSig_SubDateIntInt) case dateEvalTp == types.ETInt && intervalEvalTp == types.ETReal: sig = 
&builtinSubDateIntRealSig{ baseBuiltinFunc: bf, - baseDateArithmitical: newDateArighmeticalUtil(), + baseDateArithmetical: newDateArighmeticalUtil(), } sig.setPbCode(tipb.ScalarFuncSig_SubDateIntReal) case dateEvalTp == types.ETInt && intervalEvalTp == types.ETDecimal: sig = &builtinSubDateIntDecimalSig{ baseBuiltinFunc: bf, - baseDateArithmitical: newDateArighmeticalUtil(), + baseDateArithmetical: newDateArighmeticalUtil(), } sig.setPbCode(tipb.ScalarFuncSig_SubDateIntDecimal) case dateEvalTp == types.ETDatetime && intervalEvalTp == types.ETString: sig = &builtinSubDateDatetimeStringSig{ baseBuiltinFunc: bf, - baseDateArithmitical: newDateArighmeticalUtil(), + baseDateArithmetical: newDateArighmeticalUtil(), } sig.setPbCode(tipb.ScalarFuncSig_SubDateDatetimeString) case dateEvalTp == types.ETDatetime && intervalEvalTp == types.ETInt: sig = &builtinSubDateDatetimeIntSig{ baseBuiltinFunc: bf, - baseDateArithmitical: newDateArighmeticalUtil(), + baseDateArithmetical: newDateArighmeticalUtil(), } sig.setPbCode(tipb.ScalarFuncSig_SubDateDatetimeInt) case dateEvalTp == types.ETDatetime && intervalEvalTp == types.ETReal: sig = &builtinSubDateDatetimeRealSig{ baseBuiltinFunc: bf, - baseDateArithmitical: newDateArighmeticalUtil(), + baseDateArithmetical: newDateArighmeticalUtil(), } sig.setPbCode(tipb.ScalarFuncSig_SubDateDatetimeReal) case dateEvalTp == types.ETDatetime && intervalEvalTp == types.ETDecimal: sig = &builtinSubDateDatetimeDecimalSig{ baseBuiltinFunc: bf, - baseDateArithmitical: newDateArighmeticalUtil(), + baseDateArithmetical: newDateArighmeticalUtil(), } sig.setPbCode(tipb.ScalarFuncSig_SubDateDatetimeDecimal) case dateEvalTp == types.ETDuration && intervalEvalTp == types.ETString: sig = &builtinSubDateDurationStringSig{ baseBuiltinFunc: bf, - baseDateArithmitical: newDateArighmeticalUtil(), + baseDateArithmetical: newDateArighmeticalUtil(), } sig.setPbCode(tipb.ScalarFuncSig_SubDateDurationString) case dateEvalTp == types.ETDuration && intervalEvalTp == types.ETInt: sig = &builtinSubDateDurationIntSig{ baseBuiltinFunc: bf, - baseDateArithmitical: newDateArighmeticalUtil(), + baseDateArithmetical: newDateArighmeticalUtil(), } sig.setPbCode(tipb.ScalarFuncSig_SubDateDurationInt) case dateEvalTp == types.ETDuration && intervalEvalTp == types.ETReal: sig = &builtinSubDateDurationRealSig{ baseBuiltinFunc: bf, - baseDateArithmitical: newDateArighmeticalUtil(), + baseDateArithmetical: newDateArighmeticalUtil(), } sig.setPbCode(tipb.ScalarFuncSig_SubDateDurationReal) case dateEvalTp == types.ETDuration && intervalEvalTp == types.ETDecimal: sig = &builtinSubDateDurationDecimalSig{ baseBuiltinFunc: bf, - baseDateArithmitical: newDateArighmeticalUtil(), + baseDateArithmetical: newDateArighmeticalUtil(), } sig.setPbCode(tipb.ScalarFuncSig_SubDateDurationDecimal) } @@ -4128,11 +4129,11 @@ func (c *subDateFunctionClass) getFunction(ctx sessionctx.Context, args []Expres type builtinSubDateStringStringSig struct { baseBuiltinFunc - baseDateArithmitical + baseDateArithmetical } func (b *builtinSubDateStringStringSig) Clone() builtinFunc { - newSig := &builtinSubDateStringStringSig{baseDateArithmitical: b.baseDateArithmitical} + newSig := &builtinSubDateStringStringSig{baseDateArithmetical: b.baseDateArithmetical} newSig.cloneFrom(&b.baseBuiltinFunc) return newSig } @@ -4161,11 +4162,11 @@ func (b *builtinSubDateStringStringSig) evalTime(row chunk.Row) (types.Time, boo type builtinSubDateStringIntSig struct { baseBuiltinFunc - baseDateArithmitical + baseDateArithmetical } func (b 
*builtinSubDateStringIntSig) Clone() builtinFunc { - newSig := &builtinSubDateStringIntSig{baseDateArithmitical: b.baseDateArithmitical} + newSig := &builtinSubDateStringIntSig{baseDateArithmetical: b.baseDateArithmetical} newSig.cloneFrom(&b.baseBuiltinFunc) return newSig } @@ -4194,11 +4195,11 @@ func (b *builtinSubDateStringIntSig) evalTime(row chunk.Row) (types.Time, bool, type builtinSubDateStringRealSig struct { baseBuiltinFunc - baseDateArithmitical + baseDateArithmetical } func (b *builtinSubDateStringRealSig) Clone() builtinFunc { - newSig := &builtinSubDateStringRealSig{baseDateArithmitical: b.baseDateArithmitical} + newSig := &builtinSubDateStringRealSig{baseDateArithmetical: b.baseDateArithmetical} newSig.cloneFrom(&b.baseBuiltinFunc) return newSig } @@ -4227,11 +4228,11 @@ func (b *builtinSubDateStringRealSig) evalTime(row chunk.Row) (types.Time, bool, type builtinSubDateStringDecimalSig struct { baseBuiltinFunc - baseDateArithmitical + baseDateArithmetical } func (b *builtinSubDateStringDecimalSig) Clone() builtinFunc { - newSig := &builtinSubDateStringDecimalSig{baseDateArithmitical: b.baseDateArithmitical} + newSig := &builtinSubDateStringDecimalSig{baseDateArithmetical: b.baseDateArithmetical} newSig.cloneFrom(&b.baseBuiltinFunc) return newSig } @@ -4258,11 +4259,11 @@ func (b *builtinSubDateStringDecimalSig) evalTime(row chunk.Row) (types.Time, bo type builtinSubDateIntStringSig struct { baseBuiltinFunc - baseDateArithmitical + baseDateArithmetical } func (b *builtinSubDateIntStringSig) Clone() builtinFunc { - newSig := &builtinSubDateIntStringSig{baseDateArithmitical: b.baseDateArithmitical} + newSig := &builtinSubDateIntStringSig{baseDateArithmetical: b.baseDateArithmetical} newSig.cloneFrom(&b.baseBuiltinFunc) return newSig } @@ -4291,11 +4292,11 @@ func (b *builtinSubDateIntStringSig) evalTime(row chunk.Row) (types.Time, bool, type builtinSubDateIntIntSig struct { baseBuiltinFunc - baseDateArithmitical + baseDateArithmetical } func (b *builtinSubDateIntIntSig) Clone() builtinFunc { - newSig := &builtinSubDateIntIntSig{baseDateArithmitical: b.baseDateArithmitical} + newSig := &builtinSubDateIntIntSig{baseDateArithmetical: b.baseDateArithmetical} newSig.cloneFrom(&b.baseBuiltinFunc) return newSig } @@ -4324,11 +4325,11 @@ func (b *builtinSubDateIntIntSig) evalTime(row chunk.Row) (types.Time, bool, err type builtinSubDateIntRealSig struct { baseBuiltinFunc - baseDateArithmitical + baseDateArithmetical } func (b *builtinSubDateIntRealSig) Clone() builtinFunc { - newSig := &builtinSubDateIntRealSig{baseDateArithmitical: b.baseDateArithmitical} + newSig := &builtinSubDateIntRealSig{baseDateArithmetical: b.baseDateArithmetical} newSig.cloneFrom(&b.baseBuiltinFunc) return newSig } @@ -4357,16 +4358,16 @@ func (b *builtinSubDateIntRealSig) evalTime(row chunk.Row) (types.Time, bool, er type builtinSubDateDatetimeStringSig struct { baseBuiltinFunc - baseDateArithmitical + baseDateArithmetical } type builtinSubDateIntDecimalSig struct { baseBuiltinFunc - baseDateArithmitical + baseDateArithmetical } func (b *builtinSubDateIntDecimalSig) Clone() builtinFunc { - newSig := &builtinSubDateIntDecimalSig{baseDateArithmitical: b.baseDateArithmitical} + newSig := &builtinSubDateIntDecimalSig{baseDateArithmetical: b.baseDateArithmetical} newSig.cloneFrom(&b.baseBuiltinFunc) return newSig } @@ -4394,7 +4395,7 @@ func (b *builtinSubDateIntDecimalSig) evalTime(row chunk.Row) (types.Time, bool, } func (b *builtinSubDateDatetimeStringSig) Clone() builtinFunc { - newSig := 
&builtinSubDateDatetimeStringSig{baseDateArithmitical: b.baseDateArithmitical} + newSig := &builtinSubDateDatetimeStringSig{baseDateArithmetical: b.baseDateArithmetical} newSig.cloneFrom(&b.baseBuiltinFunc) return newSig } @@ -4423,11 +4424,11 @@ func (b *builtinSubDateDatetimeStringSig) evalTime(row chunk.Row) (types.Time, b type builtinSubDateDatetimeIntSig struct { baseBuiltinFunc - baseDateArithmitical + baseDateArithmetical } func (b *builtinSubDateDatetimeIntSig) Clone() builtinFunc { - newSig := &builtinSubDateDatetimeIntSig{baseDateArithmitical: b.baseDateArithmitical} + newSig := &builtinSubDateDatetimeIntSig{baseDateArithmetical: b.baseDateArithmetical} newSig.cloneFrom(&b.baseBuiltinFunc) return newSig } @@ -4456,11 +4457,11 @@ func (b *builtinSubDateDatetimeIntSig) evalTime(row chunk.Row) (types.Time, bool type builtinSubDateDatetimeRealSig struct { baseBuiltinFunc - baseDateArithmitical + baseDateArithmetical } func (b *builtinSubDateDatetimeRealSig) Clone() builtinFunc { - newSig := &builtinSubDateDatetimeRealSig{baseDateArithmitical: b.baseDateArithmitical} + newSig := &builtinSubDateDatetimeRealSig{baseDateArithmetical: b.baseDateArithmetical} newSig.cloneFrom(&b.baseBuiltinFunc) return newSig } @@ -4489,11 +4490,11 @@ func (b *builtinSubDateDatetimeRealSig) evalTime(row chunk.Row) (types.Time, boo type builtinSubDateDatetimeDecimalSig struct { baseBuiltinFunc - baseDateArithmitical + baseDateArithmetical } func (b *builtinSubDateDatetimeDecimalSig) Clone() builtinFunc { - newSig := &builtinSubDateDatetimeDecimalSig{baseDateArithmitical: b.baseDateArithmitical} + newSig := &builtinSubDateDatetimeDecimalSig{baseDateArithmetical: b.baseDateArithmetical} newSig.cloneFrom(&b.baseBuiltinFunc) return newSig } @@ -4522,11 +4523,11 @@ func (b *builtinSubDateDatetimeDecimalSig) evalTime(row chunk.Row) (types.Time, type builtinSubDateDurationStringSig struct { baseBuiltinFunc - baseDateArithmitical + baseDateArithmetical } func (b *builtinSubDateDurationStringSig) Clone() builtinFunc { - newSig := &builtinSubDateDurationStringSig{baseDateArithmitical: b.baseDateArithmitical} + newSig := &builtinSubDateDurationStringSig{baseDateArithmetical: b.baseDateArithmetical} newSig.cloneFrom(&b.baseBuiltinFunc) return newSig } @@ -4553,11 +4554,11 @@ func (b *builtinSubDateDurationStringSig) evalDuration(row chunk.Row) (types.Dur type builtinSubDateDurationIntSig struct { baseBuiltinFunc - baseDateArithmitical + baseDateArithmetical } func (b *builtinSubDateDurationIntSig) Clone() builtinFunc { - newSig := &builtinSubDateDurationIntSig{baseDateArithmitical: b.baseDateArithmitical} + newSig := &builtinSubDateDurationIntSig{baseDateArithmetical: b.baseDateArithmetical} newSig.cloneFrom(&b.baseBuiltinFunc) return newSig } @@ -4584,11 +4585,11 @@ func (b *builtinSubDateDurationIntSig) evalDuration(row chunk.Row) (types.Durati type builtinSubDateDurationDecimalSig struct { baseBuiltinFunc - baseDateArithmitical + baseDateArithmetical } func (b *builtinSubDateDurationDecimalSig) Clone() builtinFunc { - newSig := &builtinSubDateDurationDecimalSig{baseDateArithmitical: b.baseDateArithmitical} + newSig := &builtinSubDateDurationDecimalSig{baseDateArithmetical: b.baseDateArithmetical} newSig.cloneFrom(&b.baseBuiltinFunc) return newSig } @@ -4615,11 +4616,11 @@ func (b *builtinSubDateDurationDecimalSig) evalDuration(row chunk.Row) (types.Du type builtinSubDateDurationRealSig struct { baseBuiltinFunc - baseDateArithmitical + baseDateArithmetical } func (b *builtinSubDateDurationRealSig) Clone() builtinFunc 
{ - newSig := &builtinSubDateDurationRealSig{baseDateArithmitical: b.baseDateArithmitical} + newSig := &builtinSubDateDurationRealSig{baseDateArithmetical: b.baseDateArithmetical} newSig.cloneFrom(&b.baseBuiltinFunc) return newSig } @@ -7113,3 +7114,97 @@ func handleInvalidZeroTime(ctx sessionctx.Context, t types.Time) (bool, error) { } return true, handleInvalidTimeError(ctx, types.ErrWrongValue.GenWithStackByArgs(types.DateTimeStr, t.String())) } + +// tidbBoundedStalenessFunctionClass reads a time window [a, b] and compares it with the latest SafeTS +// to determine which TS to use in a read-only transaction. +type tidbBoundedStalenessFunctionClass struct { + baseFunctionClass +} + +func (c *tidbBoundedStalenessFunctionClass) getFunction(ctx sessionctx.Context, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(args); err != nil { + return nil, err + } + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETDatetime, types.ETDatetime, types.ETDatetime) + if err != nil { + return nil, err + } + sig := &builtinTiDBBoundedStalenessSig{bf} + return sig, nil +} + +type builtinTiDBBoundedStalenessSig struct { + baseBuiltinFunc +} + +func (b *builtinTiDBBoundedStalenessSig) Clone() builtinFunc { + newSig := &builtinTiDBBoundedStalenessSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (b *builtinTiDBBoundedStalenessSig) evalTime(row chunk.Row) (types.Time, bool, error) { + leftTime, isNull, err := b.args[0].EvalTime(b.ctx, row) + if isNull || err != nil { + return types.ZeroTime, true, handleInvalidTimeError(b.ctx, err) + } + rightTime, isNull, err := b.args[1].EvalTime(b.ctx, row) + if isNull || err != nil { + return types.ZeroTime, true, handleInvalidTimeError(b.ctx, err) + } + if invalidLeftTime, invalidRightTime := leftTime.InvalidZero(), rightTime.InvalidZero(); invalidLeftTime || invalidRightTime { + if invalidLeftTime { + err = handleInvalidTimeError(b.ctx, types.ErrWrongValue.GenWithStackByArgs(types.DateTimeStr, leftTime.String())) + } + if invalidRightTime { + err = handleInvalidTimeError(b.ctx, types.ErrWrongValue.GenWithStackByArgs(types.DateTimeStr, rightTime.String())) + } + return types.ZeroTime, true, err + } + timeZone := getTimeZone(b.ctx) + minTime, err := leftTime.GoTime(timeZone) + if err != nil { + return types.ZeroTime, true, err + } + maxTime, err := rightTime.GoTime(timeZone) + if err != nil { + return types.ZeroTime, true, err + } + if minTime.After(maxTime) { + return types.ZeroTime, true, nil + } + // The minimum unit of a TSO is a millisecond, so an fsp of 3 is enough. + return types.NewTime(types.FromGoTime(calAppropriateTime(minTime, maxTime, getMinSafeTime(b.ctx, timeZone))), mysql.TypeDatetime, 3), false, nil +} + +func getMinSafeTime(sessionCtx sessionctx.Context, timeZone *time.Location) time.Time { + var minSafeTS uint64 + if store := sessionCtx.GetStore(); store != nil { + minSafeTS = store.GetMinSafeTS(sessionCtx.GetSessionVars().CheckAndGetTxnScope()) + } + // Inject mocked SafeTS for test. + failpoint.Inject("injectSafeTS", func(val failpoint.Value) { + injectTS := val.(int) + minSafeTS = uint64(injectTS) + }) + // Try to get from the stmt cache to make sure this function is deterministic. + stmtCtx := sessionCtx.GetSessionVars().StmtCtx + minSafeTS = stmtCtx.GetOrStoreStmtCache(stmtctx.StmtSafeTSCacheKey, minSafeTS).(uint64) + return oracle.GetTimeFromTS(minSafeTS).In(timeZone) +} + +// For a SafeTS t and a time range [t1, t2]: // 1.
If t < t1, we will use t1 as the result, + // and with it, a read request may fail because t1 is still ahead of the reached SafeTS. + // 2. If t1 <= t <= t2, we will use t as the result, and with it, + // a read request won't fail. + // 3. If t2 < t, we will use t2 as the result, + // and with it, a read request won't fail because the SafeTS has already passed t2. +func calAppropriateTime(minTime, maxTime, minSafeTime time.Time) time.Time { + if minSafeTime.Before(minTime) { + return minTime + } else if minSafeTime.After(maxTime) { + return maxTime + } + return minSafeTime +} diff --git a/expression/builtin_time_test.go b/expression/builtin_time_test.go index f82f6fb8f76ea..e247e8756ae9a 100644 --- a/expression/builtin_time_test.go +++ b/expression/builtin_time_test.go @@ -14,12 +14,14 @@ package expression import ( + "fmt" "math" "strings" "time" . "github.com/pingcap/check" "github.com/pingcap/errors" + "github.com/pingcap/failpoint" "github.com/pingcap/parser/ast" "github.com/pingcap/parser/charset" "github.com/pingcap/parser/mysql" @@ -27,6 +29,7 @@ import ( "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/sessionctx/variable" + "github.com/pingcap/tidb/store/tikv/oracle" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/mock" @@ -804,7 +807,7 @@ func (s *testEvaluatorSuite) TestTime(c *C) { } func resetStmtContext(ctx sessionctx.Context) { - ctx.GetSessionVars().StmtCtx.ResetNowTs() + ctx.GetSessionVars().StmtCtx.ResetStmtCache() } func (s *testEvaluatorSuite) TestNowAndUTCTimestamp(c *C) { @@ -2854,8 +2857,105 @@ func (s *testEvaluatorSuite) TestTidbParseTso(c *C) { } } +func (s *testEvaluatorSuite) TestTiDBBoundedStaleness(c *C) { + t1, err := time.Parse(types.TimeFormat, "2015-09-21 09:53:04") + c.Assert(err, IsNil) + // time.Parse uses the UTC time zone by default, so we need to change it to Local manually. + t1 = t1.Local() + t1Str := t1.Format(types.TimeFormat) + t2 := time.Now() + t2Str := t2.Format(types.TimeFormat) + timeZone := time.Local + s.ctx.GetSessionVars().TimeZone = timeZone + tests := []struct { + leftTime interface{} + rightTime interface{} + injectSafeTS uint64 + isNull bool + expect time.Time + }{ + // SafeTS is in the range. + { + leftTime: t1Str, + rightTime: t2Str, + injectSafeTS: oracle.GoTimeToTS(t2.Add(-1 * time.Second)), + isNull: false, + expect: t2.Add(-1 * time.Second), + }, + // SafeTS is less than the left time. + { + leftTime: t1Str, + rightTime: t2Str, + injectSafeTS: oracle.GoTimeToTS(t1.Add(-1 * time.Second)), + isNull: false, + expect: t1, + }, + // SafeTS is bigger than the right time. + { + leftTime: t1Str, + rightTime: t2Str, + injectSafeTS: oracle.GoTimeToTS(t2.Add(time.Second)), + isNull: false, + expect: t2, + },
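calAppropriateTime above is a plain clamp of the resolved SafeTS into [minTime, maxTime]. A quick standalone check of the three cases in its comment (hypothetical times, illustrative only):

package main

import (
	"fmt"
	"time"
)

// clamp mirrors calAppropriateTime: pin safe into the window [lo, hi].
func clamp(lo, hi, safe time.Time) time.Time {
	if safe.Before(lo) {
		return lo
	}
	if safe.After(hi) {
		return hi
	}
	return safe
}

func main() {
	lo := time.Date(2021, 4, 27, 12, 0, 0, 0, time.UTC)
	hi := lo.Add(time.Hour)
	fmt.Println(clamp(lo, hi, lo.Add(-time.Minute)).Equal(lo)) // case 1: SafeTS below the window -> lo
	fmt.Println(clamp(lo, hi, lo.Add(time.Minute)))            // case 2: SafeTS inside the window -> SafeTS itself
	fmt.Println(clamp(lo, hi, hi.Add(time.Minute)).Equal(hi))  // case 3: SafeTS above the window -> hi
}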
+ // Wrong time order. + { + leftTime: t2Str, + rightTime: t1Str, + injectSafeTS: 0, + isNull: true, + expect: time.Time{}, + }, + } + + fc := funcs[ast.TiDBBoundedStaleness] + for _, test := range tests { + c.Assert(failpoint.Enable("github.com/pingcap/tidb/expression/injectSafeTS", + fmt.Sprintf("return(%v)", test.injectSafeTS)), IsNil) + f, err := fc.getFunction(s.ctx, s.datumsToConstants([]types.Datum{types.NewDatum(test.leftTime), types.NewDatum(test.rightTime)})) + c.Assert(err, IsNil) + d, err := evalBuiltinFunc(f, chunk.Row{}) + c.Assert(err, IsNil) + if test.isNull { + c.Assert(d.IsNull(), IsTrue) + } else { + goTime, err := d.GetMysqlTime().GoTime(timeZone) + c.Assert(err, IsNil) + c.Assert(goTime.Format(types.TimeFormat), Equals, test.expect.Format(types.TimeFormat)) + } + resetStmtContext(s.ctx) + } + + // Test whether it's deterministic. + safeTime1 := t2.Add(-1 * time.Second) + safeTS1 := oracle.ComposeTS(safeTime1.Unix()*1000, 0) + c.Assert(failpoint.Enable("github.com/pingcap/tidb/expression/injectSafeTS", + fmt.Sprintf("return(%v)", safeTS1)), IsNil) + f, err := fc.getFunction(s.ctx, s.datumsToConstants([]types.Datum{types.NewDatum(t1Str), types.NewDatum(t2Str)})) + c.Assert(err, IsNil) + d, err := evalBuiltinFunc(f, chunk.Row{}) + c.Assert(err, IsNil) + goTime, err := d.GetMysqlTime().GoTime(timeZone) + c.Assert(err, IsNil) + resultTime := goTime.Format(types.TimeFormat) + c.Assert(resultTime, Equals, safeTime1.Format(types.TimeFormat)) + // SafeTS updated. + safeTime2 := t2.Add(1 * time.Second) + safeTS2 := oracle.ComposeTS(safeTime2.Unix()*1000, 0) + c.Assert(failpoint.Enable("github.com/pingcap/tidb/expression/injectSafeTS", + fmt.Sprintf("return(%v)", safeTS2)), IsNil) + f, err = fc.getFunction(s.ctx, s.datumsToConstants([]types.Datum{types.NewDatum(t1Str), types.NewDatum(t2Str)})) + c.Assert(err, IsNil) + d, err = evalBuiltinFunc(f, chunk.Row{}) + c.Assert(err, IsNil) + goTime, err = d.GetMysqlTime().GoTime(timeZone) + c.Assert(err, IsNil) + // Still safeTime1, because the SafeTS cached for the statement wins over the updated failpoint value. + c.Assert(goTime.Format(types.TimeFormat), Equals, safeTime1.Format(types.TimeFormat)) + resetStmtContext(s.ctx) + failpoint.Disable("github.com/pingcap/tidb/expression/injectSafeTS") +} + func (s *testEvaluatorSuite) TestGetIntervalFromDecimal(c *C) { - du := baseDateArithmitical{} + du := baseDateArithmetical{} tests := []struct { param string diff --git a/expression/builtin_time_vec.go b/expression/builtin_time_vec.go index 94c1cd8b6f0c4..6f74a8f587e50 100644 --- a/expression/builtin_time_vec.go +++ b/expression/builtin_time_vec.go @@ -854,6 +854,70 @@ func (b *builtinTidbParseTsoSig) vecEvalTime(input *chunk.Chunk, result *chunk.C return nil } +func (b *builtinTiDBBoundedStalenessSig) vectorized() bool { + return true +} + +func (b *builtinTiDBBoundedStalenessSig) vecEvalTime(input *chunk.Chunk, result *chunk.Column) error { + n := input.NumRows() + buf0, err := b.bufAllocator.get(types.ETDatetime, n) + if err != nil { + return err + } + defer b.bufAllocator.put(buf0) + if err = b.args[0].VecEvalTime(b.ctx, input, buf0); err != nil { + return err + } + buf1, err := b.bufAllocator.get(types.ETDatetime, n) + if err != nil { + return err + } + defer b.bufAllocator.put(buf1) + if err = b.args[1].VecEvalTime(b.ctx, input, buf1); err != nil { + return err + } + args0 := buf0.Times() + args1 := buf1.Times() + timeZone := getTimeZone(b.ctx) + minSafeTime := getMinSafeTime(b.ctx, timeZone) + result.ResizeTime(n, false) + result.MergeNulls(buf0, buf1) + times := result.Times() + for i := 0; i < n; i++ { + if result.IsNull(i) { + continue + } + if invalidArg0, invalidArg1 := args0[i].InvalidZero(), args1[i].InvalidZero();
invalidArg0 || invalidArg1 { + if invalidArg0 { + err = handleInvalidTimeError(b.ctx, types.ErrWrongValue.GenWithStackByArgs(types.DateTimeStr, args0[i].String())) + } + if invalidArg1 { + err = handleInvalidTimeError(b.ctx, types.ErrWrongValue.GenWithStackByArgs(types.DateTimeStr, args1[i].String())) + } + if err != nil { + return err + } + result.SetNull(i, true) + continue + } + minTime, err := args0[i].GoTime(timeZone) + if err != nil { + return err + } + maxTime, err := args1[i].GoTime(timeZone) + if err != nil { + return err + } + if minTime.After(maxTime) { + result.SetNull(i, true) + continue + } + // The minimum unit of a TSO is a millisecond, so an fsp of 3 is enough. + times[i] = types.NewTime(types.FromGoTime(calAppropriateTime(minTime, maxTime, minSafeTime)), mysql.TypeDatetime, 3) + } + return nil +} + func (b *builtinFromDaysSig) vectorized() bool { return true } diff --git a/expression/builtin_time_vec_test.go b/expression/builtin_time_vec_test.go index 593cce162d7ff..a757b867b783c 100644 --- a/expression/builtin_time_vec_test.go +++ b/expression/builtin_time_vec_test.go @@ -519,6 +519,13 @@ var vecBuiltinTimeCases = map[string][]vecExprBenchCase{ geners: []dataGenerator{newRangeInt64Gener(0, math.MaxInt64)}, }, }, + // TODO: figure out how to inject the SafeTS for better testing. + ast.TiDBBoundedStaleness: { + { + retEvalType: types.ETDatetime, + childrenTypes: []types.EvalType{types.ETDatetime, types.ETDatetime}, + }, + }, ast.LastDay: { {retEvalType: types.ETDatetime, childrenTypes: []types.EvalType{types.ETDatetime}}, }, diff --git a/expression/helper.go b/expression/helper.go index c5f91dbd090b5..d9f1e22610b62 100644 --- a/expression/helper.go +++ b/expression/helper.go @@ -22,6 +22,7 @@ import ( "github.com/pingcap/parser/mysql" "github.com/pingcap/parser/terror" "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/types" driver "github.com/pingcap/tidb/types/parser_driver" @@ -155,5 +156,5 @@ func getStmtTimestamp(ctx sessionctx.Context) (time.Time, error) { return time.Unix(timestamp, 0), nil } stmtCtx := ctx.GetSessionVars().StmtCtx - return stmtCtx.GetNowTsCached(), nil + return stmtCtx.GetOrStoreStmtCache(stmtctx.StmtNowTsCacheKey, time.Now()).(time.Time), nil } diff --git a/expression/integration_test.go b/expression/integration_test.go index 80e39b76ce746..a0fb6fd8b5499 100644 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -41,6 +41,7 @@ import ( "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/store/mockstore" + "github.com/pingcap/tidb/store/tikv/oracle" "github.com/pingcap/tidb/table" "github.com/pingcap/tidb/table/tables" "github.com/pingcap/tidb/tablecodec" @@ -2263,6 +2264,79 @@ func (s *testIntegrationSuite2) TestTimeBuiltin(c *C) { result = tk.MustQuery(`select tidb_parse_tso(-1)`) result.Check(testkit.Rows("")) + // for tidb_bounded_staleness + tk.MustExec("SET time_zone = '+00:00';") + t := time.Now().UTC() + ts := oracle.GoTimeToTS(t) + tidbBoundedStalenessTests := []struct { + sql string + injectSafeTS uint64 + expect string + }{ + { + sql: `select tidb_bounded_staleness(DATE_SUB(NOW(), INTERVAL 600 SECOND), DATE_ADD(NOW(), INTERVAL 600 SECOND))`, + injectSafeTS: ts, + expect: t.Format(types.TimeFSPFormat[:len(types.TimeFSPFormat)-3]), + }, + { + sql: `select tidb_bounded_staleness("2021-04-27 12:00:00.000", "2021-04-27 13:00:00.000")`, + injectSafeTS: func()
uint64 { + t, err := time.Parse("2006-01-02 15:04:05.000", "2021-04-27 13:30:04.877") + c.Assert(err, IsNil) + return oracle.GoTimeToTS(t) + }(), + expect: "2021-04-27 13:00:00.000", + }, + { + sql: `select tidb_bounded_staleness("2021-04-27 12:00:00.000", "2021-04-27 13:00:00.000")`, + injectSafeTS: func() uint64 { + t, err := time.Parse("2006-01-02 15:04:05.000", "2021-04-27 11:30:04.877") + c.Assert(err, IsNil) + return oracle.GoTimeToTS(t) + }(), + expect: "2021-04-27 12:00:00.000", + }, + { + sql: `select tidb_bounded_staleness("2021-04-27 12:00:00.000", "2021-04-27 11:00:00.000")`, + injectSafeTS: 0, + expect: "", + }, + // Time is too small. + { + sql: `select tidb_bounded_staleness("0020-04-27 12:00:00.000", "2021-04-27 11:00:00.000")`, + injectSafeTS: 0, + expect: "1970-01-01 00:00:00.000", + }, + // Wrong value. + { + sql: `select tidb_bounded_staleness(1, 2)`, + injectSafeTS: 0, + expect: "", + }, + { + sql: `select tidb_bounded_staleness("invalid_time_1", "invalid_time_2")`, + injectSafeTS: 0, + expect: "", + }, + } + for _, test := range tidbBoundedStalenessTests { + c.Assert(failpoint.Enable("github.com/pingcap/tidb/expression/injectSafeTS", + fmt.Sprintf("return(%v)", test.injectSafeTS)), IsNil) + result = tk.MustQuery(test.sql) + result.Check(testkit.Rows(test.expect)) + } + failpoint.Disable("github.com/pingcap/tidb/expression/injectSafeTS") + // test whether tidb_bounded_staleness is deterministic + result = tk.MustQuery(`select tidb_bounded_staleness(NOW(), DATE_ADD(NOW(), INTERVAL 600 SECOND)), tidb_bounded_staleness(NOW(), DATE_ADD(NOW(), INTERVAL 600 SECOND))`) + c.Assert(result.Rows()[0], HasLen, 2) + c.Assert(result.Rows()[0][0], Equals, result.Rows()[0][1]) + preResult := result.Rows()[0][0] + time.Sleep(time.Second) + result = tk.MustQuery(`select tidb_bounded_staleness(NOW(), DATE_ADD(NOW(), INTERVAL 600 SECOND)), tidb_bounded_staleness(NOW(), DATE_ADD(NOW(), INTERVAL 600 SECOND))`) + c.Assert(result.Rows()[0], HasLen, 2) + c.Assert(result.Rows()[0][0], Equals, result.Rows()[0][1]) + c.Assert(result.Rows()[0][0], Not(Equals), preResult) + // fix issue 10308 result = tk.MustQuery("select time(\"- -\");") result.Check(testkit.Rows("00:00:00")) @@ -9361,3 +9435,76 @@ func (s *testIntegrationSuite) TestEnumIndex(c *C) { tk.MustQuery("select /*+ use_index(t,idx) */ col3 from t where col2 = 'b' and col1 is not null;").Check( testkit.Rows("2")) } + +func (s *testIntegrationSuite) TestControlFunctionWithEnumOrSet(c *C) { + defer s.cleanEnv(c) + + // issue 23114 + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists e;") + tk.MustExec("create table e(e enum('c', 'b', 'a'));") + tk.MustExec("insert into e values ('a'),('b'),('a'),('b');") + tk.MustQuery("select e from e where if(e>1, e, e);").Sort().Check( + testkit.Rows("a", "a", "b", "b")) + tk.MustQuery("select e from e where case e when 1 then e else e end;").Sort().Check( + testkit.Rows("a", "a", "b", "b")) + tk.MustQuery("select e from e where case 1 when e then e end;").Check(testkit.Rows()) + + tk.MustQuery("select if(e>1,e,e)='a' from e").Sort().Check( + testkit.Rows("0", "0", "1", "1")) + tk.MustQuery("select if(e>1,e,e)=1 from e").Sort().Check( + testkit.Rows("0", "0", "0", "0")) + // if and if + tk.MustQuery("select if(e>2,e,e) and if(e<=2,e,e) from e;").Sort().Check( + testkit.Rows("1", "1", "1", "1")) + tk.MustQuery("select if(e>2,e,e) and (if(e<3,0,e) or if(e>=2,0,e)) from e;").Sort().Check( + testkit.Rows("0", "0", "1", "1")) + tk.MustQuery("select * from e 
where if(e>2,e,e) and if(e<=2,e,e);").Sort().Check( + testkit.Rows("a", "a", "b", "b")) + tk.MustQuery("select * from e where if(e>2,e,e) and (if(e<3,0,e) or if(e>=2,0,e));").Sort().Check( + testkit.Rows("a", "a")) + + // issue 24494 + tk.MustExec("drop table if exists t;") + tk.MustExec("create table t(a int,b enum(\"b\",\"y\",\"1\"));") + tk.MustExec("insert into t values(0,\"y\"),(1,\"b\"),(null,null),(2,\"1\");") + tk.MustQuery("SELECT count(*) FROM t where if(a,b ,null);").Check(testkit.Rows("2")) + + tk.MustExec("drop table if exists t;") + tk.MustExec("create table t(a int,b enum(\"b\"),c enum(\"c\"));") + tk.MustExec("insert into t values(1,1,1),(2,1,1),(1,1,1),(2,1,1);") + tk.MustQuery("select a from t where if(a=1,b,c)=\"b\";").Check(testkit.Rows("1", "1")) + tk.MustQuery("select a from t where if(a=1,b,c)=\"c\";").Check(testkit.Rows("2", "2")) + tk.MustQuery("select a from t where if(a=1,b,c)=1;").Sort().Check(testkit.Rows("1", "1", "2", "2")) + tk.MustQuery("select a from t where if(a=1,b,c);").Sort().Check(testkit.Rows("1", "1", "2", "2")) + + tk.MustExec("drop table if exists e;") + tk.MustExec("create table e(e enum('c', 'b', 'a'));") + tk.MustExec("insert into e values(3)") + tk.MustQuery("select elt(1,e) = 'a' from e").Check(testkit.Rows("1")) + tk.MustQuery("select elt(1,e) = 3 from e").Check(testkit.Rows("1")) + tk.MustQuery("select e from e where elt(1,e)").Check(testkit.Rows("a")) + + // test set type + tk.MustExec("drop table if exists s;") + tk.MustExec("create table s(s set('c', 'b', 'a'));") + tk.MustExec("insert into s values ('a'),('b'),('a'),('b');") + tk.MustQuery("select s from s where if(s>1, s, s);").Sort().Check( + testkit.Rows("a", "a", "b", "b")) + tk.MustQuery("select s from s where case s when 1 then s else s end;").Sort().Check( + testkit.Rows("a", "a", "b", "b")) + tk.MustQuery("select s from s where case 1 when s then s end;").Check(testkit.Rows()) + + tk.MustQuery("select if(s>1,s,s)='a' from s").Sort().Check( + testkit.Rows("0", "0", "1", "1")) + tk.MustQuery("select if(s>1,s,s)=4 from s").Sort().Check( + testkit.Rows("0", "0", "1", "1")) + + tk.MustExec("drop table if exists s;") + tk.MustExec("create table s(s set('c', 'b', 'a'));") + tk.MustExec("insert into s values('a')") + tk.MustQuery("select elt(1,s) = 'a' from s").Check(testkit.Rows("1")) + tk.MustQuery("select elt(1,s) = 4 from s").Check(testkit.Rows("1")) + tk.MustQuery("select s from s where elt(1,s)").Check(testkit.Rows("a")) +} diff --git a/go.mod b/go.mod index fe8e08ae42e47..f82a8a187775f 100644 --- a/go.mod +++ b/go.mod @@ -47,7 +47,7 @@ require ( github.com/pingcap/goleveldb v0.0.0-20191226122134-f82aafb29989 github.com/pingcap/kvproto v0.0.0-20210507054410-a8152f8a876c github.com/pingcap/log v0.0.0-20210317133921-96f4fcab92a4 - github.com/pingcap/parser v0.0.0-20210513020953-ae2c4497c07b + github.com/pingcap/parser v0.0.0-20210518053259-92fa6fe07eb6 github.com/pingcap/sysutil v0.0.0-20210315073920-cc0985d983a3 github.com/pingcap/tidb-tools v4.0.9-0.20201127090955-2707c97b3853+incompatible github.com/pingcap/tipb v0.0.0-20210422074242-57dd881b81b1 diff --git a/go.sum b/go.sum index 3ee71da011a54..14986c3d1f025 100644 --- a/go.sum +++ b/go.sum @@ -443,8 +443,8 @@ github.com/pingcap/log v0.0.0-20200511115504-543df19646ad/go.mod h1:4rbK1p9ILyIf github.com/pingcap/log v0.0.0-20201112100606-8f1e84a3abc8/go.mod h1:4rbK1p9ILyIfb6hU7OG2CiWSqMXnp3JMbiaVJ6mvoY8= github.com/pingcap/log v0.0.0-20210317133921-96f4fcab92a4 h1:ERrF0fTuIOnwfGbt71Ji3DKbOEaP189tjym50u8gpC8= github.com/pingcap/log 
v0.0.0-20210317133921-96f4fcab92a4/go.mod h1:4rbK1p9ILyIfb6hU7OG2CiWSqMXnp3JMbiaVJ6mvoY8= -github.com/pingcap/parser v0.0.0-20210513020953-ae2c4497c07b h1:eLuDQ6eJCEKCbGwhGrkjzagwev1GJGU2Y2kFkAsBzV0= -github.com/pingcap/parser v0.0.0-20210513020953-ae2c4497c07b/go.mod h1:xZC8I7bug4GJ5KtHhgAikjTfU4kBv1Sbo3Pf1MZ6lVw= +github.com/pingcap/parser v0.0.0-20210518053259-92fa6fe07eb6 h1:wsH3psMH5ksDowsN9VUE9ZqSrX6oF4AYQQfOunkvSfU= +github.com/pingcap/parser v0.0.0-20210518053259-92fa6fe07eb6/go.mod h1:xZC8I7bug4GJ5KtHhgAikjTfU4kBv1Sbo3Pf1MZ6lVw= github.com/pingcap/sysutil v0.0.0-20200206130906-2bfa6dc40bcd/go.mod h1:EB/852NMQ+aRKioCpToQ94Wl7fktV+FNnxf3CX/TTXI= github.com/pingcap/sysutil v0.0.0-20210315073920-cc0985d983a3 h1:A9KL9R+lWSVPH8IqUuH1QSTRJ5FGoY1bT2IcfPKsWD8= github.com/pingcap/sysutil v0.0.0-20210315073920-cc0985d983a3/go.mod h1:tckvA041UWP+NqYzrJ3fMgC/Hw9wnmQ/tUkp/JaHly8= diff --git a/infoschema/builder.go b/infoschema/builder.go index 28591d8679baf..88e8b71add319 100644 --- a/infoschema/builder.go +++ b/infoschema/builder.go @@ -26,6 +26,7 @@ import ( "github.com/pingcap/tidb/config" "github.com/pingcap/tidb/ddl/placement" "github.com/pingcap/tidb/domain/infosync" + "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/meta" "github.com/pingcap/tidb/meta/autoid" "github.com/pingcap/tidb/table" @@ -35,8 +36,10 @@ import ( // Builder builds a new InfoSchema. type Builder struct { - is *infoSchema - handle *Handle + is *infoSchema + // TODO: store is only used by autoid allocators + // detach allocators from storage, use passed transaction in the feature + store kv.Storage } // ApplyDiff applies SchemaDiff to the new InfoSchema. @@ -352,14 +355,14 @@ func (b *Builder) applyCreateTable(m *meta.Meta, dbInfo *model.DBInfo, tableID i ConvertOldVersionUTF8ToUTF8MB4IfNeed(tblInfo) if len(allocs) == 0 { - allocs = autoid.NewAllocatorsFromTblInfo(b.handle.store, dbInfo.ID, tblInfo) + allocs = autoid.NewAllocatorsFromTblInfo(b.store, dbInfo.ID, tblInfo) } else { switch tp { case model.ActionRebaseAutoID, model.ActionModifyTableAutoIdCache: - newAlloc := autoid.NewAllocator(b.handle.store, dbInfo.ID, tblInfo.IsAutoIncColUnsigned(), autoid.RowIDAllocType) + newAlloc := autoid.NewAllocator(b.store, dbInfo.ID, tblInfo.IsAutoIncColUnsigned(), autoid.RowIDAllocType) allocs = append(allocs, newAlloc) case model.ActionRebaseAutoRandomBase: - newAlloc := autoid.NewAllocator(b.handle.store, dbInfo.ID, tblInfo.IsAutoRandomBitColUnsigned(), autoid.AutoRandomType) + newAlloc := autoid.NewAllocator(b.store, dbInfo.ID, tblInfo.IsAutoRandomBitColUnsigned(), autoid.AutoRandomType) allocs = append(allocs, newAlloc) case model.ActionModifyColumn: // Change column attribute from auto_increment to auto_random. @@ -368,7 +371,7 @@ func (b *Builder) applyCreateTable(m *meta.Meta, dbInfo *model.DBInfo, tableID i allocs = allocs.Filter(func(a autoid.Allocator) bool { return a.GetType() != autoid.AutoIncrementType && a.GetType() != autoid.RowIDAllocType }) - newAlloc := autoid.NewAllocator(b.handle.store, dbInfo.ID, tblInfo.IsAutoRandomBitColUnsigned(), autoid.AutoRandomType) + newAlloc := autoid.NewAllocator(b.store, dbInfo.ID, tblInfo.IsAutoRandomBitColUnsigned(), autoid.AutoRandomType) allocs = append(allocs, newAlloc) } } @@ -470,9 +473,14 @@ func (b *Builder) applyPlacementUpdate(id string) error { return nil } +// Build builds and returns the built infoschema. 
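// A minimal sketch (hypothetical helper; assumes the usual infoschema, kv and
// model imports) of the refactored Builder flow established by the hunks above
// and exercised by the test updates later in this diff: the Builder is now
// constructed from a kv.Storage (needed only for the autoid allocators, per the
// TODO above) and Build() returns the InfoSchema to the caller instead of
// storing it into a Handle.
func buildInfoSchemaSketch(store kv.Storage, dbInfos []*model.DBInfo) (infoschema.InfoSchema, error) {
	builder, err := infoschema.NewBuilder(store).InitWithDBInfos(dbInfos, nil, 1)
	if err != nil {
		return nil, err
	}
	// The result is typically handed to InfoCache.Insert by the caller.
	return builder.Build(), nil
}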
+func (b *Builder) Build() InfoSchema {
+ return b.is
+}
+
-// InitWithOldInfoSchema initializes an empty new InfoSchema by copies all the data from old InfoSchema.
+// InitWithOldInfoSchema initializes an empty new InfoSchema by copying all the data from the old InfoSchema.
-func (b *Builder) InitWithOldInfoSchema() *Builder {
- oldIS := b.handle.Get().(*infoSchema)
+func (b *Builder) InitWithOldInfoSchema(oldSchema InfoSchema) *Builder {
+ oldIS := oldSchema.(*infoSchema)
b.is.schemaMetaVersion = oldIS.schemaMetaVersion
b.copySchemasMap(oldIS)
b.copyBundlesMap(oldIS)
@@ -549,7 +557,7 @@ func (b *Builder) createSchemaTablesForDB(di *model.DBInfo, tableFromMeta tableF
b.is.schemaMap[di.Name.L] = schTbls
for _, t := range di.Tables {
- allocs := autoid.NewAllocatorsFromTblInfo(b.handle.store, di.ID, t)
+ allocs := autoid.NewAllocatorsFromTblInfo(b.store, di.ID, t)
var tbl table.Table
tbl, err := tableFromMeta(allocs, t)
if err != nil {
@@ -574,21 +582,16 @@ func RegisterVirtualTable(dbInfo *model.DBInfo, tableFromMeta tableFromMetaFunc)
drivers = append(drivers, &virtualTableDriver{dbInfo, tableFromMeta})
}
-// Build sets new InfoSchema to the handle in the Builder.
-func (b *Builder) Build() {
- b.handle.value.Store(b.is)
-}
-
-// NewBuilder creates a new Builder with a Handle.
+// NewBuilder creates a new Builder with the given kv.Storage.
-func NewBuilder(handle *Handle) *Builder {
- b := new(Builder)
- b.handle = handle
- b.is = &infoSchema{
- schemaMap: map[string]*schemaTables{},
- ruleBundleMap: map[string]*placement.Bundle{},
- sortedTablesBuckets: make([]sortedTables, bucketCount),
+func NewBuilder(store kv.Storage) *Builder {
+ return &Builder{
+ store: store,
+ is: &infoSchema{
+ schemaMap: map[string]*schemaTables{},
+ ruleBundleMap: map[string]*placement.Bundle{},
+ sortedTablesBuckets: make([]sortedTables, bucketCount),
+ },
}
- return b
}
func tableBucketIdx(tableID int64) int {
diff --git a/infoschema/cache.go b/infoschema/cache.go
new file mode 100644
index 0000000000000..4c3371b1bc354
--- /dev/null
+++ b/infoschema/cache.go
@@ -0,0 +1,95 @@
+// Copyright 2021 PingCAP, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package infoschema
+
+import (
+ "sort"
+ "sync"
+
+ "github.com/pingcap/tidb/metrics"
+)
+
+// InfoCache handles information schema, including getting and setting.
+// The cache behavior, however, is transparent and under automatic management.
+// It is only guaranteed to cache an infoschema if it is newer than all currently cached ones.
+type InfoCache struct {
+ mu sync.RWMutex
+ // cache is sorted by SchemaVersion in descending order
+ cache []InfoSchema
+}
+
+// NewCache creates a new InfoCache.
+func NewCache(capacity int) *InfoCache {
+ return &InfoCache{cache: make([]InfoSchema, 0, capacity)}
+}
+
+// GetLatest gets the newest information schema.
+func (h *InfoCache) GetLatest() InfoSchema {
+ h.mu.RLock()
+ defer h.mu.RUnlock()
+ metrics.InfoCacheCounters.WithLabelValues("get").Inc()
+ if len(h.cache) > 0 {
+ metrics.InfoCacheCounters.WithLabelValues("hit").Inc()
+ return h.cache[0]
+ }
+ return nil
+}
+
+// GetByVersion gets the information schema based on schemaVersion. Returns nil if it is not loaded.
+func (h *InfoCache) GetByVersion(version int64) InfoSchema {
+ h.mu.RLock()
+ defer h.mu.RUnlock()
+ metrics.InfoCacheCounters.WithLabelValues("get").Inc()
+ i := sort.Search(len(h.cache), func(i int) bool {
+ return h.cache[i].SchemaMetaVersion() <= version
+ })
+ if i < len(h.cache) && h.cache[i].SchemaMetaVersion() == version {
+ metrics.InfoCacheCounters.WithLabelValues("hit").Inc()
+ return h.cache[i]
+ }
+ return nil
+}
+
+// Insert will **TRY** to insert the infoschema into the cache.
+// It is only guaranteed to cache the newest infoschema.
+// It returns 'true' if it is cached, 'false' otherwise.
+func (h *InfoCache) Insert(is InfoSchema) bool {
+ h.mu.Lock()
+ defer h.mu.Unlock()
+
+ version := is.SchemaMetaVersion()
+ i := sort.Search(len(h.cache), func(i int) bool {
+ return h.cache[i].SchemaMetaVersion() <= version
+ })
+
+ // cached entry
+ if i < len(h.cache) && h.cache[i].SchemaMetaVersion() == version {
+ return true
+ }
+
+ if len(h.cache) < cap(h.cache) {
+ // has free space, grow the slice
+ h.cache = h.cache[:len(h.cache)+1]
+ copy(h.cache[i+1:], h.cache[i:])
+ h.cache[i] = is
+ return true
+ } else if i < len(h.cache) {
+ // no free space, drop the oldest schema
+ copy(h.cache[i+1:], h.cache[i:])
+ h.cache[i] = is
+ return true
+ }
+ // older than all cached schemas, refuse to cache it
+ return false
+}
diff --git a/infoschema/cache_test.go b/infoschema/cache_test.go
new file mode 100644
index 0000000000000..a8e9ddcc0df5a
--- /dev/null
+++ b/infoschema/cache_test.go
@@ -0,0 +1,119 @@
+// Copyright 2021 PingCAP, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package infoschema_test
+
+import (
+ .
"github.com/pingcap/check" + "github.com/pingcap/tidb/infoschema" +) + +var _ = Suite(&testInfoCacheSuite{}) + +type testInfoCacheSuite struct { +} + +func (s *testInfoCacheSuite) TestNewCache(c *C) { + ic := infoschema.NewCache(16) + c.Assert(ic, NotNil) +} + +func (s *testInfoCacheSuite) TestInsert(c *C) { + ic := infoschema.NewCache(3) + c.Assert(ic, NotNil) + + is2 := infoschema.MockInfoSchemaWithSchemaVer(nil, 2) + ic.Insert(is2) + c.Assert(ic.GetByVersion(2), NotNil) + + // newer + is5 := infoschema.MockInfoSchemaWithSchemaVer(nil, 5) + ic.Insert(is5) + c.Assert(ic.GetByVersion(5), NotNil) + c.Assert(ic.GetByVersion(2), NotNil) + + // older + is0 := infoschema.MockInfoSchemaWithSchemaVer(nil, 0) + ic.Insert(is0) + c.Assert(ic.GetByVersion(5), NotNil) + c.Assert(ic.GetByVersion(2), NotNil) + c.Assert(ic.GetByVersion(0), NotNil) + + // replace 5, drop 0 + is6 := infoschema.MockInfoSchemaWithSchemaVer(nil, 6) + ic.Insert(is6) + c.Assert(ic.GetByVersion(6), NotNil) + c.Assert(ic.GetByVersion(5), NotNil) + c.Assert(ic.GetByVersion(2), NotNil) + c.Assert(ic.GetByVersion(0), IsNil) + + // replace 2, drop 2 + is3 := infoschema.MockInfoSchemaWithSchemaVer(nil, 3) + ic.Insert(is3) + c.Assert(ic.GetByVersion(6), NotNil) + c.Assert(ic.GetByVersion(5), NotNil) + c.Assert(ic.GetByVersion(3), NotNil) + c.Assert(ic.GetByVersion(2), IsNil) + c.Assert(ic.GetByVersion(0), IsNil) + + // insert 2, but failed silently + ic.Insert(is2) + c.Assert(ic.GetByVersion(6), NotNil) + c.Assert(ic.GetByVersion(5), NotNil) + c.Assert(ic.GetByVersion(3), NotNil) + c.Assert(ic.GetByVersion(2), IsNil) + c.Assert(ic.GetByVersion(0), IsNil) + + // insert 5, but it is already in + ic.Insert(is5) + c.Assert(ic.GetByVersion(6), NotNil) + c.Assert(ic.GetByVersion(5), NotNil) + c.Assert(ic.GetByVersion(3), NotNil) + c.Assert(ic.GetByVersion(2), IsNil) + c.Assert(ic.GetByVersion(0), IsNil) +} + +func (s *testInfoCacheSuite) TestGetByVersion(c *C) { + ic := infoschema.NewCache(2) + c.Assert(ic, NotNil) + is1 := infoschema.MockInfoSchemaWithSchemaVer(nil, 1) + ic.Insert(is1) + is3 := infoschema.MockInfoSchemaWithSchemaVer(nil, 3) + ic.Insert(is3) + + c.Assert(ic.GetByVersion(1), Equals, is1) + c.Assert(ic.GetByVersion(3), Equals, is3) + c.Assert(ic.GetByVersion(0), IsNil, Commentf("index == 0, but not found")) + c.Assert(ic.GetByVersion(2), IsNil, Commentf("index in the middle, but not found")) + c.Assert(ic.GetByVersion(4), IsNil, Commentf("index == length, but not found")) +} + +func (s *testInfoCacheSuite) TestGetLatest(c *C) { + ic := infoschema.NewCache(16) + c.Assert(ic, NotNil) + c.Assert(ic.GetLatest(), IsNil) + + is1 := infoschema.MockInfoSchemaWithSchemaVer(nil, 1) + ic.Insert(is1) + c.Assert(ic.GetLatest(), Equals, is1) + + // newer change the newest + is2 := infoschema.MockInfoSchemaWithSchemaVer(nil, 2) + ic.Insert(is2) + c.Assert(ic.GetLatest(), Equals, is2) + + // older schema doesn't change the newest + is0 := infoschema.MockInfoSchemaWithSchemaVer(nil, 0) + ic.Insert(is0) + c.Assert(ic.GetLatest(), Equals, is2) +} diff --git a/infoschema/infoschema.go b/infoschema/infoschema.go index ac8afd14605f1..2494e89b4d57f 100644 --- a/infoschema/infoschema.go +++ b/infoschema/infoschema.go @@ -17,12 +17,10 @@ import ( "fmt" "sort" "sync" - "sync/atomic" "github.com/pingcap/parser/model" "github.com/pingcap/parser/mysql" "github.com/pingcap/tidb/ddl/placement" - "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/meta/autoid" "github.com/pingcap/tidb/table" "github.com/pingcap/tidb/util" @@ -312,40 +310,6 @@ func (is 
*infoSchema) SequenceByName(schema, sequence model.CIStr) (util.Sequenc return tbl.(util.SequenceTable), nil } -// Handle handles information schema, including getting and setting. -type Handle struct { - value atomic.Value - store kv.Storage -} - -// NewHandle creates a new Handle. -func NewHandle(store kv.Storage) *Handle { - h := &Handle{ - store: store, - } - return h -} - -// Get gets information schema from Handle. -func (h *Handle) Get() InfoSchema { - v := h.value.Load() - schema, _ := v.(InfoSchema) - return schema -} - -// IsValid uses to check whether handle value is valid. -func (h *Handle) IsValid() bool { - return h.value.Load() != nil -} - -// EmptyClone creates a new Handle with the same store and memSchema, but the value is not set. -func (h *Handle) EmptyClone() *Handle { - newHandle := &Handle{ - store: h.store, - } - return newHandle -} - func init() { // Initialize the information shema database and register the driver to `drivers` dbID := autoid.InformationSchemaDBID diff --git a/infoschema/infoschema_test.go b/infoschema/infoschema_test.go index 6aa0c5526f467..87276ef1452b9 100644 --- a/infoschema/infoschema_test.go +++ b/infoschema/infoschema_test.go @@ -15,7 +15,6 @@ package infoschema_test import ( "context" - "sync" "testing" . "github.com/pingcap/check" @@ -57,7 +56,6 @@ func (*testSuite) TestT(c *C) { c.Assert(err, IsNil) defer dom.Close() - handle := infoschema.NewHandle(store) dbName := model.NewCIStr("Test") tbName := model.NewCIStr("T") colName := model.NewCIStr("A") @@ -116,7 +114,7 @@ func (*testSuite) TestT(c *C) { }) c.Assert(err, IsNil) - builder, err := infoschema.NewBuilder(handle).InitWithDBInfos(dbInfos, nil, 1) + builder, err := infoschema.NewBuilder(dom.Store()).InitWithDBInfos(dbInfos, nil, 1) c.Assert(err, IsNil) txn, err := store.Begin() @@ -126,8 +124,7 @@ func (*testSuite) TestT(c *C) { err = txn.Rollback() c.Assert(err, IsNil) - builder.Build() - is := handle.Get() + is := builder.Build() schemaNames := is.AllSchemaNames() c.Assert(schemaNames, HasLen, 4) @@ -213,14 +210,10 @@ func (*testSuite) TestT(c *C) { c.Assert(err, IsNil) err = txn.Rollback() c.Assert(err, IsNil) - builder.Build() - is = handle.Get() + is = builder.Build() schema, ok = is.SchemaByID(dbID) c.Assert(ok, IsTrue) c.Assert(len(schema.Tables), Equals, 1) - - emptyHandle := handle.EmptyClone() - c.Assert(emptyHandle.Get(), IsNil) } func (testSuite) TestMockInfoSchema(c *C) { @@ -258,32 +251,6 @@ func checkApplyCreateNonExistsTableDoesNotPanic(c *C, txn kv.Transaction, builde c.Assert(infoschema.ErrTableNotExists.Equal(err), IsTrue) } -// TestConcurrent makes sure it is safe to concurrently create handle on multiple stores. -func (testSuite) TestConcurrent(c *C) { - defer testleak.AfterTest(c)() - storeCount := 5 - stores := make([]kv.Storage, storeCount) - for i := 0; i < storeCount; i++ { - store, err := mockstore.NewMockStore() - c.Assert(err, IsNil) - stores[i] = store - } - defer func() { - for _, store := range stores { - store.Close() - } - }() - var wg sync.WaitGroup - wg.Add(storeCount) - for _, store := range stores { - go func(s kv.Storage) { - defer wg.Done() - _ = infoschema.NewHandle(s) - }(store) - } - wg.Wait() -} - // TestInfoTables makes sure that all tables of information_schema could be found in infoschema handle. 
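// A minimal sketch (hypothetical function, using the mock constructor from
// cache_test.go above) of how the new InfoCache replaces the Handle deleted
// here: instead of a single atomic slot, several recent schema versions stay
// addressable, which is what snapshot and stale reads need.
func infoCacheSketch() {
	ic := infoschema.NewCache(16)
	ic.Insert(infoschema.MockInfoSchemaWithSchemaVer(nil, 1))
	ic.Insert(infoschema.MockInfoSchemaWithSchemaVer(nil, 2))
	latest := ic.GetLatest()    // the version-2 schema, the newest cached
	older := ic.GetByVersion(1) // still reachable, unlike with the old Handle
	_, _ = latest, older
}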
func (*testSuite) TestInfoTables(c *C) { defer testleak.AfterTest(c)() @@ -293,12 +260,10 @@ func (*testSuite) TestInfoTables(c *C) { err := store.Close() c.Assert(err, IsNil) }() - handle := infoschema.NewHandle(store) - builder, err := infoschema.NewBuilder(handle).InitWithDBInfos(nil, nil, 0) + + builder, err := infoschema.NewBuilder(store).InitWithDBInfos(nil, nil, 0) c.Assert(err, IsNil) - builder.Build() - is := handle.Get() - c.Assert(is, NotNil) + is := builder.Build() infoTables := []string{ "SCHEMATA", @@ -360,12 +325,9 @@ func (*testSuite) TestGetBundle(c *C) { c.Assert(err, IsNil) }() - handle := infoschema.NewHandle(store) - builder, err := infoschema.NewBuilder(handle).InitWithDBInfos(nil, nil, 0) + builder, err := infoschema.NewBuilder(store).InitWithDBInfos(nil, nil, 0) c.Assert(err, IsNil) - builder.Build() - - is := handle.Get() + is := builder.Build() bundle := &placement.Bundle{ ID: placement.PDBundleID, diff --git a/infoschema/tables_test.go b/infoschema/tables_test.go index ebe4a0620256f..1e5687928f3ad 100644 --- a/infoschema/tables_test.go +++ b/infoschema/tables_test.go @@ -965,8 +965,6 @@ func (s *testTableSuite) TestStmtSummaryTable(c *C) { tk.MustExec("set global tidb_enable_stmt_summary = 1") tk.MustQuery("select @@global.tidb_enable_stmt_summary").Check(testkit.Rows("1")) - // Invalidate the cache manually so that tidb_enable_stmt_summary works immediately. - s.dom.GetGlobalVarsCache().Disable() // Disable refreshing summary. tk.MustExec("set global tidb_stmt_summary_refresh_interval = 999999999") tk.MustQuery("select @@global.tidb_stmt_summary_refresh_interval").Check(testkit.Rows("999999999")) @@ -1209,8 +1207,6 @@ func (s *testClusterTableSuite) TestStmtSummaryHistoryTable(c *C) { tk.MustExec("set global tidb_enable_stmt_summary = 1") tk.MustQuery("select @@global.tidb_enable_stmt_summary").Check(testkit.Rows("1")) - // Invalidate the cache manually so that tidb_enable_stmt_summary works immediately. - s.dom.GetGlobalVarsCache().Disable() // Disable refreshing summary. tk.MustExec("set global tidb_stmt_summary_refresh_interval = 999999999") tk.MustQuery("select @@global.tidb_stmt_summary_refresh_interval").Check(testkit.Rows("999999999")) @@ -1266,8 +1262,6 @@ func (s *testTableSuite) TestStmtSummaryInternalQuery(c *C) { tk.MustExec("create global binding for select * from t where t.a = 1 using select * from t ignore index(k) where t.a = 1") tk.MustExec("set global tidb_enable_stmt_summary = 1") tk.MustQuery("select @@global.tidb_enable_stmt_summary").Check(testkit.Rows("1")) - // Invalidate the cache manually so that tidb_enable_stmt_summary works immediately. - s.dom.GetGlobalVarsCache().Disable() // Disable refreshing summary. tk.MustExec("set global tidb_stmt_summary_refresh_interval = 999999999") tk.MustQuery("select @@global.tidb_stmt_summary_refresh_interval").Check(testkit.Rows("999999999")) diff --git a/kv/interface_mock_test.go b/kv/interface_mock_test.go index e1d41f1693088..5d85261bc2111 100644 --- a/kv/interface_mock_test.go +++ b/kv/interface_mock_test.go @@ -213,6 +213,10 @@ func (s *mockStorage) GetMemCache() MemManager { return nil } +func (s *mockStorage) GetMinSafeTS(txnScope string) uint64 { + return 0 +} + // newMockStorage creates a new mockStorage. 
func newMockStorage() Storage { return &mockStorage{} diff --git a/kv/kv.go b/kv/kv.go index 572fe104024bc..20b0fc84b7144 100644 --- a/kv/kv.go +++ b/kv/kv.go @@ -417,6 +417,8 @@ type Storage interface { ShowStatus(ctx context.Context, key string) (interface{}, error) // GetMemCache return memory manager of the storage. GetMemCache() MemManager + // GetMinSafeTS return the minimal SafeTS of the storage with given txnScope. + GetMinSafeTS(txnScope string) uint64 } // EtcdBackend is used for judging a storage is a real TiKV. diff --git a/metrics/domain.go b/metrics/domain.go index dd3912555d59c..a05b25dd6a46a 100644 --- a/metrics/domain.go +++ b/metrics/domain.go @@ -38,6 +38,19 @@ var ( Buckets: prometheus.ExponentialBuckets(0.001, 2, 20), // 1ms ~ 524s }) + // InfoCacheCounters are the counters of get/hit. + InfoCacheCounters = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Namespace: "tidb", + Subsystem: "domain", + Name: "infocache_counters", + Help: "Counters of infoCache: get/hit.", + }, []string{LblType}) + // InfoCacheCounterGet is the total number of getting entry. + InfoCacheCounterGet = "get" + // InfoCacheCounterHit is the cache hit numbers for get. + InfoCacheCounterHit = "hit" + // LoadPrivilegeCounter records the counter of load privilege. LoadPrivilegeCounter = prometheus.NewCounterVec( prometheus.CounterOpts{ @@ -47,6 +60,15 @@ var ( Help: "Counter of load privilege", }, []string{LblType}) + // LoadSysVarCacheCounter records the counter of loading sysvars + LoadSysVarCacheCounter = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Namespace: "tidb", + Subsystem: "domain", + Name: "load_sysvarcache_total", + Help: "Counter of load sysvar cache", + }, []string{LblType}) + SchemaValidatorStop = "stop" SchemaValidatorRestart = "restart" SchemaValidatorReset = "reset" diff --git a/metrics/metrics.go b/metrics/metrics.go index ff2ac3b1aa08d..542398e7bbdee 100644 --- a/metrics/metrics.go +++ b/metrics/metrics.go @@ -98,6 +98,7 @@ func RegisterMetrics() { prometheus.MustRegister(JobsGauge) prometheus.MustRegister(KeepAliveCounter) prometheus.MustRegister(LoadPrivilegeCounter) + prometheus.MustRegister(InfoCacheCounters) prometheus.MustRegister(LoadSchemaCounter) prometheus.MustRegister(LoadSchemaDuration) prometheus.MustRegister(MetaHistogram) @@ -150,6 +151,7 @@ func RegisterMetrics() { prometheus.MustRegister(TiFlashQueryTotalCounter) prometheus.MustRegister(SmallTxnWriteDuration) prometheus.MustRegister(TxnWriteThroughput) + prometheus.MustRegister(LoadSysVarCacheCounter) tikvmetrics.InitMetrics(TiDB, TiKVClient) tikvmetrics.RegisterMetrics() diff --git a/owner/manager_test.go b/owner/manager_test.go index e25b204e6bbb4..e239419057291 100644 --- a/owner/manager_test.go +++ b/owner/manager_test.go @@ -24,6 +24,7 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/parser/terror" . 
"github.com/pingcap/tidb/ddl" + "github.com/pingcap/tidb/infoschema" "github.com/pingcap/tidb/owner" "github.com/pingcap/tidb/store/mockstore" "github.com/pingcap/tidb/util/logutil" @@ -72,11 +73,14 @@ func TestSingle(t *testing.T) { defer clus.Terminate(t) cli := clus.RandClient() ctx := goctx.Background() + ic := infoschema.NewCache(2) + ic.Insert(infoschema.MockInfoSchemaWithSchemaVer(nil, 0)) d := NewDDL( ctx, WithEtcdClient(cli), WithStore(store), WithLease(testLease), + WithInfoCache(ic), ) err = d.Start(nil) if err != nil { @@ -142,11 +146,14 @@ func TestCluster(t *testing.T) { defer clus.Terminate(t) cli := clus.Client(0) + ic := infoschema.NewCache(2) + ic.Insert(infoschema.MockInfoSchemaWithSchemaVer(nil, 0)) d := NewDDL( goctx.Background(), WithEtcdClient(cli), WithStore(store), WithLease(testLease), + WithInfoCache(ic), ) err = d.Start(nil) if err != nil { @@ -157,11 +164,14 @@ func TestCluster(t *testing.T) { t.Fatalf("expect true, got isOwner:%v", isOwner) } cli1 := clus.Client(1) + ic2 := infoschema.NewCache(2) + ic2.Insert(infoschema.MockInfoSchemaWithSchemaVer(nil, 0)) d1 := NewDDL( goctx.Background(), WithEtcdClient(cli1), WithStore(store), WithLease(testLease), + WithInfoCache(ic2), ) err = d1.Start(nil) if err != nil { @@ -189,11 +199,14 @@ func TestCluster(t *testing.T) { // d3 (not owner) stop cli3 := clus.Client(3) + ic3 := infoschema.NewCache(2) + ic3.Insert(infoschema.MockInfoSchemaWithSchemaVer(nil, 0)) d3 := NewDDL( goctx.Background(), WithEtcdClient(cli3), WithStore(store), WithLease(testLease), + WithInfoCache(ic3), ) err = d3.Start(nil) if err != nil { diff --git a/planner/core/logical_plan_builder.go b/planner/core/logical_plan_builder.go index 64bc0c41407e1..fc6fb53dcff44 100644 --- a/planner/core/logical_plan_builder.go +++ b/planner/core/logical_plan_builder.go @@ -975,6 +975,19 @@ func (b *PlanBuilder) buildSelection(ctx context.Context, p LogicalPlan, where a if len(cnfExpres) == 0 { return p, nil } + // check expr field types. + for i, expr := range cnfExpres { + if expr.GetType().EvalType() == types.ETString { + tp := &types.FieldType{ + Tp: mysql.TypeDouble, + Flag: expr.GetType().Flag, + Flen: mysql.MaxRealWidth, + Decimal: types.UnspecifiedLength, + } + types.SetBinChsClnFlag(tp) + cnfExpres[i] = expression.TryPushCastIntoControlFunctionForHybridType(b.ctx, expr, tp) + } + } selection.Conditions = cnfExpres selection.SetChildren(p) return selection, nil diff --git a/planner/core/prepare_test.go b/planner/core/prepare_test.go index cd43b3964d59b..d6bfe69f82b39 100644 --- a/planner/core/prepare_test.go +++ b/planner/core/prepare_test.go @@ -278,9 +278,6 @@ func (s *testPrepareSerialSuite) TestPrepareOverMaxPreparedStmtCount(c *C) { tk.MustExec("set @@global.max_prepared_stmt_count = 2") tk.MustQuery("select @@global.max_prepared_stmt_count").Check(testkit.Rows("2")) - // Disable global variable cache, so load global session variable take effect immediate. - dom.GetGlobalVarsCache().Disable() - // test close session to give up all prepared stmt tk.MustExec(`prepare stmt2 from "select 1"`) prePrepared = readGaugeInt(metrics.PreparedStmtGauge) diff --git a/session/session.go b/session/session.go index af3f41c863dc0..efd6706c4ffb3 100644 --- a/session/session.go +++ b/session/session.go @@ -473,6 +473,9 @@ func (s *session) doCommit(ctx context.Context) error { if err != nil { return err } + if err = s.removeTempTableFromBuffer(); err != nil { + return err + } // mockCommitError and mockGetTSErrorInRetry use to test PR #8743. 
failpoint.Inject("mockCommitError", func(val failpoint.Value) { @@ -502,8 +505,14 @@ func (s *session) doCommit(ctx context.Context) error { // Get the related table or partition IDs. relatedPhysicalTables := s.GetSessionVars().TxnCtx.TableDeltaMap + // Get accessed global temporary tables in the transaction. + temporaryTables := s.GetSessionVars().TxnCtx.GlobalTemporaryTables physicalTableIDs := make([]int64, 0, len(relatedPhysicalTables)) for id := range relatedPhysicalTables { + // Schema change on global temporary tables doesn't affect transactions. + if _, ok := temporaryTables[id]; ok { + continue + } physicalTableIDs = append(physicalTableIDs, id) } // Set this option for 2 phase commit to validate schema lease. @@ -526,29 +535,40 @@ func (s *session) doCommit(ctx context.Context) error { s.GetSessionVars().TxnCtx.IsExplicit && s.GetSessionVars().GuaranteeLinearizability) } - // Filter out the temporary table key-values. - if tables := s.sessionVars.TxnCtx.GlobalTemporaryTables; tables != nil { - memBuffer := s.txn.GetMemBuffer() - for tid := range tables { - seekKey := tablecodec.EncodeTablePrefix(tid) - endKey := tablecodec.EncodeTablePrefix(tid + 1) - iter, err := memBuffer.Iter(seekKey, endKey) - if err != nil { + return s.txn.Commit(tikvutil.SetSessionID(ctx, s.GetSessionVars().ConnectionID)) +} + +// removeTempTableFromBuffer filters out the temporary table key-values. +func (s *session) removeTempTableFromBuffer() error { + tables := s.GetSessionVars().TxnCtx.GlobalTemporaryTables + if len(tables) == 0 { + return nil + } + memBuffer := s.txn.GetMemBuffer() + // Reset and new an empty stage buffer. + defer func() { + s.txn.cleanup() + }() + for tid := range tables { + seekKey := tablecodec.EncodeTablePrefix(tid) + endKey := tablecodec.EncodeTablePrefix(tid + 1) + iter, err := memBuffer.Iter(seekKey, endKey) + if err != nil { + return err + } + for iter.Valid() && iter.Key().HasPrefix(seekKey) { + if err = memBuffer.Delete(iter.Key()); err != nil { return err } - for iter.Valid() && iter.Key().HasPrefix(seekKey) { - if err = memBuffer.Delete(iter.Key()); err != nil { - return errors.Trace(err) - } - s.txn.UpdateEntriesCountAndSize() - if err = iter.Next(); err != nil { - return errors.Trace(err) - } + s.txn.UpdateEntriesCountAndSize() + if err = iter.Next(); err != nil { + return err } } } - - return s.txn.Commit(tikvutil.SetSessionID(ctx, s.GetSessionVars().ConnectionID)) + // Flush to the root membuffer. + s.txn.flushStmtBuf() + return nil } // errIsNoisy is used to filter DUPLCATE KEY errors. @@ -977,6 +997,7 @@ func (s *session) replaceTableValue(ctx context.Context, tblName string, varName return err } _, _, err = s.ExecRestrictedStmt(ctx, stmt) + domain.GetDomain(s).NotifyUpdateSysVarCache(s) return err } @@ -997,16 +1018,27 @@ func (s *session) GetGlobalSysVar(name string) (string, error) { // When running bootstrap or upgrade, we should not access global storage. return "", nil } - sysVar, err := s.getTableValue(context.TODO(), mysql.GlobalVariablesTable, name) + + sv := variable.GetSysVar(name) + if sv == nil { + // It might be a recently unregistered sysvar. We should return unknown + // since GetSysVar is the canonical version, but we can update the cache + // so the next request doesn't attempt to load this. + logutil.BgLogger().Info("sysvar does not exist. 
sysvar cache may be stale", zap.String("name", name)) + return "", variable.ErrUnknownSystemVar.GenWithStackByArgs(name) + } + + sysVar, err := domain.GetDomain(s).GetSysVarCache().GetGlobalVar(s, name) if err != nil { - if errResultIsEmpty.Equal(err) { - sv := variable.GetSysVar(name) - if sv != nil { - return sv.Value, nil - } - return "", variable.ErrUnknownSystemVar.GenWithStackByArgs(name) + // The sysvar exists, but there is no cache entry yet. + // This might be because the sysvar was only recently registered. + // In which case it is safe to return the default, but we can also + // update the cache for the future. + logutil.BgLogger().Info("sysvar not in cache yet. sysvar cache may be stale", zap.String("name", name)) + sysVar, err = s.getTableValue(context.TODO(), mysql.GlobalVariablesTable, name) + if err != nil { + return sv.Value, nil } - return "", err } // Fetch mysql.tidb values if required if s.varFromTiDBTable(name) { @@ -1051,12 +1083,7 @@ func (s *session) updateGlobalSysVar(sv *variable.SysVar, value string) error { return err } } - stmt, err := s.ParseWithParams(context.TODO(), "REPLACE %n.%n VALUES (%?, %?)", mysql.SystemDB, mysql.GlobalVariablesTable, sv.Name, value) - if err != nil { - return err - } - _, _, err = s.ExecRestrictedStmt(context.TODO(), stmt) - return err + return s.replaceTableValue(context.TODO(), mysql.GlobalVariablesTable, sv.Name, value) } // setTiDBTableValue handles tikv_* sysvars which need to update mysql.tidb @@ -1935,7 +1962,7 @@ func (s *session) isTxnRetryable() bool { func (s *session) NewTxn(ctx context.Context) error { if s.txn.Valid() { - txnID := s.txn.StartTS() + txnStartTS := s.txn.StartTS() txnScope := s.GetSessionVars().TxnCtx.TxnScope err := s.CommitTxn(ctx) if err != nil { @@ -1944,7 +1971,7 @@ func (s *session) NewTxn(ctx context.Context) error { vars := s.GetSessionVars() logutil.Logger(ctx).Info("NewTxn() inside a transaction auto commit", zap.Int64("schemaVersion", vars.GetInfoSchema().SchemaMetaVersion()), - zap.Uint64("txnStartTS", txnID), + zap.Uint64("txnStartTS", txnStartTS), zap.String("txnScope", txnScope)) } @@ -2316,13 +2343,18 @@ func BootstrapSession(store kv.Storage) (*domain.Domain, error) { } } + // Rebuild sysvar cache in a loop + err = dom.LoadSysVarCacheLoop(se) + if err != nil { + return nil, err + } + if len(cfg.Plugin.Load) > 0 { err := plugin.Init(context.Background(), plugin.Config{EtcdClient: dom.GetEtcdClient()}) if err != nil { return nil, err } } - se4, err := createSession(store) if err != nil { return nil, err @@ -2425,7 +2457,7 @@ func createSessionWithOpt(store kv.Storage, opt *Opt) (*session, error) { // CreateSessionWithDomain creates a new Session and binds it with a Domain. // We need this because when we start DDL in Domain, the DDL need a session // to change some system tables. But at that time, we have been already in -// a lock context, which cause we can't call createSesion directly. +// a lock context, which cause we can't call createSession directly. func CreateSessionWithDomain(store kv.Storage, dom *domain.Domain) (*session, error) { s := &session{ store: store, @@ -2620,6 +2652,7 @@ var builtinGlobalVariable = []string{ variable.TiDBAllowFallbackToTiKV, variable.TiDBEnableDynamicPrivileges, variable.CTEMaxRecursionDepth, + variable.TiDBDMLBatchSize, } // loadCommonGlobalVariablesIfNeeded loads and applies commonly used global variables for the session. 
@@ -2633,38 +2666,30 @@ func (s *session) loadCommonGlobalVariablesIfNeeded() error { return nil } - var err error - // Use GlobalVariableCache if TiDB just loaded global variables within 2 second ago. - // When a lot of connections connect to TiDB simultaneously, it can protect TiKV meta region from overload. - gvc := domain.GetDomain(s).GetGlobalVarsCache() - loadFunc := func() ([]chunk.Row, []*ast.ResultField, error) { - vars := append(make([]string, 0, len(builtinGlobalVariable)+len(variable.PluginVarNames)), builtinGlobalVariable...) - if len(variable.PluginVarNames) > 0 { - vars = append(vars, variable.PluginVarNames...) - } - - stmt, err := s.ParseWithParams(context.TODO(), "select HIGH_PRIORITY * from mysql.global_variables where variable_name in (%?) order by VARIABLE_NAME", vars) - if err != nil { - return nil, nil, errors.Trace(err) - } + vars.CommonGlobalLoaded = true - return s.ExecRestrictedStmt(context.TODO(), stmt) - } - rows, _, err := gvc.LoadGlobalVariables(loadFunc) + // Deep copy sessionvar cache + // Eventually this whole map will be applied to systems[], which is a MySQL behavior. + sessionCache, err := domain.GetDomain(s).GetSysVarCache().GetSessionCache(s) if err != nil { - logutil.BgLogger().Warn("failed to load global variables", - zap.Uint64("conn", s.sessionVars.ConnectionID), zap.Error(err)) return err } - vars.CommonGlobalLoaded = true - - for _, row := range rows { - varName := row.GetString(0) - varVal := row.GetString(1) + for _, varName := range builtinGlobalVariable { + // The item should be in the sessionCache, but due to a strange current behavior there are some Global-only + // vars that are in builtinGlobalVariable. For compatibility we need to fall back to the Global cache on these items. + // TODO: don't load these globals into the session! + var varVal string + var ok bool + if varVal, ok = sessionCache[varName]; !ok { + varVal, err = s.GetGlobalSysVar(varName) + if err != nil { + continue // skip variables that are not loaded. + } + } // `collation_server` is related to `character_set_server`, set `character_set_server` will also set `collation_server`. // We have to make sure we set the `collation_server` with right value. 
if _, ok := vars.GetSystemVar(varName); !ok || varName == variable.CollationServer { - err = vars.SetSystemVar(varName, varVal) + err = vars.SetSystemVarWithRelaxedValidation(varName, varVal) if err != nil { return err } @@ -2679,8 +2704,6 @@ func (s *session) loadCommonGlobalVariablesIfNeeded() error { } } } - - vars.CommonGlobalLoaded = true return nil } @@ -2805,7 +2828,10 @@ func (s *session) NewTxnWithStalenessOption(ctx context.Context, option sessionc txn.SetOption(kv.IsStalenessReadOnly, true) txn.SetOption(kv.TxnScope, txnScope) s.txn.changeInvalidToValid(txn) - is := domain.GetDomain(s).InfoSchema() + is, err := domain.GetDomain(s).GetSnapshotInfoSchema(txn.StartTS()) + if err != nil { + return errors.Trace(err) + } s.sessionVars.TxnCtx = &variable.TransactionContext{ InfoSchema: is, CreateTime: time.Now(), diff --git a/session/session_test.go b/session/session_test.go index 4870215f33c9e..df2a167921e56 100644 --- a/session/session_test.go +++ b/session/session_test.go @@ -213,7 +213,6 @@ func (s *testSessionSuiteBase) SetUpSuite(c *C) { var err error s.dom, err = session.BootstrapSession(s.store) c.Assert(err, IsNil) - s.dom.GetGlobalVarsCache().Disable() } func (s *testSessionSuiteBase) TearDownSuite(c *C) { @@ -639,7 +638,6 @@ func (s *testSessionSuite) TestGlobalVarAccessor(c *C) { c.Assert(v, Equals, varValue2) // For issue 10955, make sure the new session load `max_execution_time` into sessionVars. - s.dom.GetGlobalVarsCache().Disable() tk1.MustExec("set @@global.max_execution_time = 100") tk2 := testkit.NewTestKitWithInit(c, s.store) c.Assert(tk2.Se.GetSessionVars().MaxExecutionTime, Equals, uint64(100)) @@ -789,6 +787,49 @@ func (s *testSessionSuite) TestRetryUnion(c *C) { c.Assert(err, ErrorMatches, ".*can not retry select for update statement") } +func (s *testSessionSuite) TestRetryGlobalTempTable(c *C) { + tk := testkit.NewTestKitWithInit(c, s.store) + tk.MustExec("drop table if exists normal_table") + tk.MustExec("create table normal_table(a int primary key, b int)") + defer tk.MustExec("drop table if exists normal_table") + tk.MustExec("drop table if exists temp_table") + tk.MustExec("create global temporary table temp_table(a int primary key, b int) on commit delete rows") + defer tk.MustExec("drop table if exists temp_table") + + // insert select + tk.MustExec("set tidb_disable_txn_auto_retry = 0") + tk.MustExec("insert normal_table value(100, 100)") + tk.MustExec("set @@autocommit = 0") + // used to make conflicts + tk.MustExec("update normal_table set b=b+1 where a=100") + tk.MustExec("insert temp_table value(1, 1)") + tk.MustExec("insert normal_table select * from temp_table") + c.Assert(session.GetHistory(tk.Se).Count(), Equals, 3) + + // try to conflict with tk + tk1 := testkit.NewTestKitWithInit(c, s.store) + tk1.MustExec("update normal_table set b=b+1 where a=100") + + // It will retry internally. 
+ tk.MustExec("commit") + tk.MustQuery("select a, b from normal_table order by a").Check(testkit.Rows("1 1", "100 102")) + tk.MustQuery("select a, b from temp_table order by a").Check(testkit.Rows()) + + // update multi-tables + tk.MustExec("update normal_table set b=b+1 where a=100") + tk.MustExec("insert temp_table value(1, 2)") + // before update: normal_table=(1 1) (100 102), temp_table=(1 2) + tk.MustExec("update normal_table, temp_table set normal_table.b=temp_table.b where normal_table.a=temp_table.a") + c.Assert(session.GetHistory(tk.Se).Count(), Equals, 3) + + // try to conflict with tk + tk1.MustExec("update normal_table set b=b+1 where a=100") + + // It will retry internally. + tk.MustExec("commit") + tk.MustQuery("select a, b from normal_table order by a").Check(testkit.Rows("1 2", "100 104")) +} + func (s *testSessionSuite) TestRetryShow(c *C) { tk := testkit.NewTestKitWithInit(c, s.store) tk.MustExec("set @@autocommit = 0") @@ -2090,6 +2131,45 @@ func (s *testSchemaSerialSuite) TestSchemaCheckerSQL(c *C) { c.Assert(err, NotNil) } +func (s *testSchemaSerialSuite) TestSchemaCheckerTempTable(c *C) { + tk := testkit.NewTestKitWithInit(c, s.store) + tk1 := testkit.NewTestKitWithInit(c, s.store) + + // create table + tk.MustExec(`drop table if exists normal_table`) + tk.MustExec(`create table normal_table (id int, c int);`) + defer tk.MustExec(`drop table if exists normal_table`) + tk.MustExec(`drop table if exists temp_table`) + tk.MustExec(`create global temporary table temp_table (id int, c int) on commit delete rows;`) + defer tk.MustExec(`drop table if exists temp_table`) + + // The schema version is out of date in the first transaction, and the SQL can't be retried. + atomic.StoreUint32(&session.SchemaChangedWithoutRetry, 1) + defer func() { + atomic.StoreUint32(&session.SchemaChangedWithoutRetry, 0) + }() + + // It's fine to change the schema of temporary tables. + tk.MustExec(`begin;`) + tk1.MustExec(`alter table temp_table modify column c bigint;`) + tk.MustExec(`insert into temp_table values(3, 3);`) + tk.MustExec(`commit;`) + + // Truncate will modify table ID. + tk.MustExec(`begin;`) + tk1.MustExec(`truncate table temp_table;`) + tk.MustExec(`insert into temp_table values(3, 3);`) + tk.MustExec(`commit;`) + + // It reports error when also changing the schema of a normal table. + tk.MustExec(`begin;`) + tk1.MustExec(`alter table normal_table modify column c bigint;`) + tk.MustExec(`insert into temp_table values(3, 3);`) + tk.MustExec(`insert into normal_table values(3, 3);`) + _, err := tk.Exec(`commit;`) + c.Assert(terror.ErrorEqual(err, domain.ErrInfoSchemaChanged), IsTrue, Commentf("err %v", err)) +} + func (s *testSchemaSuite) TestPrepareStmtCommitWhenSchemaChanged(c *C) { tk := testkit.NewTestKitWithInit(c, s.store) tk1 := testkit.NewTestKitWithInit(c, s.store) @@ -2575,8 +2655,6 @@ func (s *testSessionSuite) TestSetGlobalTZ(c *C) { tk.MustQuery("show variables like 'time_zone'").Check(testkit.Rows("time_zone +08:00")) - // Disable global variable cache, so load global session variable take effect immediate. 
- s.dom.GetGlobalVarsCache().Disable() tk1 := testkit.NewTestKitWithInit(c, s.store) tk1.MustQuery("show variables like 'time_zone'").Check(testkit.Rows("time_zone +00:00")) } @@ -2718,8 +2796,6 @@ func (s *testSessionSuite3) TestEnablePartition(c *C) { tk.MustExec("set tidb_enable_list_partition=on") tk.MustQuery("show variables like 'tidb_enable_list_partition'").Check(testkit.Rows("tidb_enable_list_partition ON")) - // Disable global variable cache, so load global session variable take effect immediate. - s.dom.GetGlobalVarsCache().Disable() tk1 := testkit.NewTestKitWithInit(c, s.store) tk1.MustQuery("show variables like 'tidb_enable_table_partition'").Check(testkit.Rows("tidb_enable_table_partition ON")) } @@ -3903,9 +3979,7 @@ func (s *testSessionSerialSuite) TestIssue21943(c *C) { c.Assert(err.Error(), Equals, "[variable:1238]Variable 'last_plan_from_cache' is a read only variable") } -func (s *testSessionSuite) TestValidateReadOnlyInStalenessTransaction(c *C) { - c.Assert(failpoint.Enable("github.com/pingcap/tidb/executor/mockStalenessTxnSchemaVer", "return(false)"), IsNil) - defer failpoint.Disable("github.com/pingcap/tidb/executor/mockStalenessTxnSchemaVer") +func (s *testSessionSerialSuite) TestValidateReadOnlyInStalenessTransaction(c *C) { testcases := []struct { name string sql string @@ -4036,7 +4110,7 @@ func (s *testSessionSuite) TestValidateReadOnlyInStalenessTransaction(c *C) { tk.MustExec(`set @@tidb_enable_noop_functions=1;`) for _, testcase := range testcases { c.Log(testcase.name) - tk.MustExec(`START TRANSACTION READ ONLY WITH TIMESTAMP BOUND READ TIMESTAMP '2020-09-06 00:00:00';`) + tk.MustExec(`START TRANSACTION READ ONLY WITH TIMESTAMP BOUND EXACT STALENESS '00:00:00';`) if testcase.isValidate { _, err := tk.Exec(testcase.sql) c.Assert(err, IsNil) @@ -4050,8 +4124,6 @@ func (s *testSessionSuite) TestValidateReadOnlyInStalenessTransaction(c *C) { } func (s *testSessionSerialSuite) TestSpecialSQLInStalenessTxn(c *C) { - c.Assert(failpoint.Enable("github.com/pingcap/tidb/executor/mockStalenessTxnSchemaVer", "return(false)"), IsNil) - defer failpoint.Disable("github.com/pingcap/tidb/executor/mockStalenessTxnSchemaVer") tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") testcases := []struct { @@ -4098,7 +4170,7 @@ func (s *testSessionSerialSuite) TestSpecialSQLInStalenessTxn(c *C) { tk.MustExec("CREATE USER 'newuser' IDENTIFIED BY 'mypassword';") for _, testcase := range testcases { comment := Commentf(testcase.name) - tk.MustExec(`START TRANSACTION READ ONLY WITH TIMESTAMP BOUND READ TIMESTAMP '2020-09-06 00:00:00';`) + tk.MustExec(`START TRANSACTION READ ONLY WITH TIMESTAMP BOUND EXACT STALENESS '00:00:00';`) c.Assert(tk.Se.GetSessionVars().TxnCtx.IsStaleness, Equals, true, comment) tk.MustExec(testcase.sql) c.Assert(tk.Se.GetSessionVars().TxnCtx.IsStaleness, Equals, testcase.sameSession, comment) @@ -4404,3 +4476,13 @@ func (s *testTxnStateSuite) TestRollbacking(c *C) { c.Assert(tk.Se.TxnInfo().State, Equals, txninfo.TxnRollingBack) <-ch } + +func (s *testSessionSuite) TestReadDMLBatchSize(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("set global tidb_dml_batch_size=1000") + se, err := session.CreateSession(s.store) + c.Assert(err, IsNil) + // `select 1` to load the global variables. 
+ _, _ = se.Execute(context.TODO(), "select 1") + c.Assert(se.GetSessionVars().DMLBatchSize, Equals, 1000) +} diff --git a/sessionctx/stmtctx/stmtctx.go b/sessionctx/stmtctx/stmtctx.go index 8df0001427173..d8a75aec48610 100644 --- a/sessionctx/stmtctx/stmtctx.go +++ b/sessionctx/stmtctx/stmtctx.go @@ -140,8 +140,6 @@ type StatementContext struct { RuntimeStatsColl *execdetails.RuntimeStatsColl TableIDs []int64 IndexNames []string - nowTs time.Time // use this variable for now/current_timestamp calculation/cache for one stmt - stmtTimeCached bool StmtType string OriginalSQL string digestMemo struct { @@ -164,6 +162,9 @@ type StatementContext struct { TblInfo2UnionScan map[*model.TableInfo]bool TaskID uint64 // unique ID for an execution of a statement TaskMapBakTS uint64 // counter for + + // stmtCache is used to store some statement-related values. + stmtCache map[StmtCacheKey]interface{} } // StmtHints are SessionVars related sql hints. @@ -195,19 +196,35 @@ func (sh *StmtHints) TaskMapNeedBackUp() bool { return sh.ForceNthPlan != -1 } -// GetNowTsCached getter for nowTs, if not set get now time and cache it -func (sc *StatementContext) GetNowTsCached() time.Time { - if !sc.stmtTimeCached { - now := time.Now() - sc.nowTs = now - sc.stmtTimeCached = true +// StmtCacheKey represents the key type in the StmtCache. +type StmtCacheKey int + +const ( + // StmtNowTsCacheKey is a variable for now/current_timestamp calculation/cache of one stmt. + StmtNowTsCacheKey StmtCacheKey = iota + // StmtSafeTSCacheKey is a variable for safeTS calculation/cache of one stmt. + StmtSafeTSCacheKey +) + +// GetOrStoreStmtCache gets the cached value of the given key if it exists, otherwise stores the value. +func (sc *StatementContext) GetOrStoreStmtCache(key StmtCacheKey, value interface{}) interface{} { + if sc.stmtCache == nil { + sc.stmtCache = make(map[StmtCacheKey]interface{}) + } + if _, ok := sc.stmtCache[key]; !ok { + sc.stmtCache[key] = value } - return sc.nowTs + return sc.stmtCache[key] +} + +// ResetInStmtCache resets the cache of given key. +func (sc *StatementContext) ResetInStmtCache(key StmtCacheKey) { + delete(sc.stmtCache, key) } -// ResetNowTs resetter for nowTs, clear cached time flag -func (sc *StatementContext) ResetNowTs() { - sc.stmtTimeCached = false +// ResetStmtCache resets all cached values. +func (sc *StatementContext) ResetStmtCache() { + sc.stmtCache = make(map[StmtCacheKey]interface{}) } // SQLDigest gets normalized and digest for provided sql. diff --git a/sessionctx/variable/session.go b/sessionctx/variable/session.go index 0c6c74d90a26d..c474e7905fa7b 100644 --- a/sessionctx/variable/session.go +++ b/sessionctx/variable/session.go @@ -1421,6 +1421,15 @@ func (s *SessionVars) SetSystemVar(name string, val string) error { return sv.SetSessionFromHook(s, val) } +// SetSystemVarWithRelaxedValidation sets the value of a system variable for session scope. +// Validation functions are called, but scope validation is skipped. +// Errors are not expected to be returned because this could cause upgrade issues. +func (s *SessionVars) SetSystemVarWithRelaxedValidation(name string, val string) error { + sv := GetSysVar(name) + val = sv.ValidateWithRelaxedValidation(s, val, ScopeSession) + return sv.SetSessionFromHook(s, val) +} + // GetReadableTxnMode returns the session variable TxnMode but rewrites it to "OPTIMISTIC" when it's empty. 
func (s *SessionVars) GetReadableTxnMode() string { txnMode := s.TxnMode diff --git a/sessionctx/variable/sysvar.go b/sessionctx/variable/sysvar.go index 98518fe4af0f0..e6f632db6b02d 100644 --- a/sessionctx/variable/sysvar.go +++ b/sessionctx/variable/sysvar.go @@ -188,6 +188,10 @@ func (sv *SysVar) HasGlobalScope() bool { // Validate checks if system variable satisfies specific restriction. func (sv *SysVar) Validate(vars *SessionVars, value string, scope ScopeFlag) (string, error) { + // Check that the scope is correct first. + if err := sv.validateScope(scope); err != nil { + return value, err + } // Normalize the value and apply validation based on type. // i.e. TypeBool converts 1/on/ON to ON. normalizedValue, err := sv.validateFromType(vars, value, scope) @@ -203,17 +207,6 @@ func (sv *SysVar) Validate(vars *SessionVars, value string, scope ScopeFlag) (st // validateFromType provides automatic validation based on the SysVar's type func (sv *SysVar) validateFromType(vars *SessionVars, value string, scope ScopeFlag) (string, error) { - // Check that the scope is correct and return the appropriate error message. - if sv.ReadOnly || sv.Scope == ScopeNone { - return value, ErrIncorrectScope.FastGenByArgs(sv.Name, "read only") - } - if scope == ScopeGlobal && !sv.HasGlobalScope() { - return value, errLocalVariable.FastGenByArgs(sv.Name) - } - if scope == ScopeSession && !sv.HasSessionScope() { - return value, errGlobalVariable.FastGenByArgs(sv.Name) - } - // The string "DEFAULT" is a special keyword in MySQL, which restores // the compiled sysvar value. In which case we can skip further validation. if strings.EqualFold(value, "DEFAULT") { @@ -245,6 +238,37 @@ func (sv *SysVar) validateFromType(vars *SessionVars, value string, scope ScopeF return value, nil // typeString } +func (sv *SysVar) validateScope(scope ScopeFlag) error { + if sv.ReadOnly || sv.Scope == ScopeNone { + return ErrIncorrectScope.FastGenByArgs(sv.Name, "read only") + } + if scope == ScopeGlobal && !sv.HasGlobalScope() { + return errLocalVariable.FastGenByArgs(sv.Name) + } + if scope == ScopeSession && !sv.HasSessionScope() { + return errGlobalVariable.FastGenByArgs(sv.Name) + } + return nil +} + +// ValidateWithRelaxedValidation normalizes values but can not return errors. +// Normalization+validation needs to be applied when reading values because older versions of TiDB +// may be less sophisticated in normalizing values. But errors should be caught and handled, +// because otherwise there will be upgrade issues. +func (sv *SysVar) ValidateWithRelaxedValidation(vars *SessionVars, value string, scope ScopeFlag) string { + normalizedValue, err := sv.validateFromType(vars, value, scope) + if err != nil { + return normalizedValue + } + if sv.Validation != nil { + normalizedValue, err = sv.Validation(vars, normalizedValue, value, scope) + if err != nil { + return normalizedValue + } + } + return normalizedValue +} + const ( localDayTimeFormat = "15:04" // FullDayTimeFormat is the full format of analyze start time and end time. @@ -485,11 +509,15 @@ func SetSysVar(name string, value string) { sysVars[name].Value = value } -// GetSysVars returns the sysVars list under a RWLock +// GetSysVars deep copies the sysVars list under a RWLock func GetSysVars() map[string]*SysVar { sysVarsLock.RLock() defer sysVarsLock.RUnlock() - return sysVars + copy := make(map[string]*SysVar, len(sysVars)) + for name, sv := range sysVars { + copy[name] = sv + } + return copy } // PluginVarNames is global plugin var names set. 
@@ -876,7 +904,7 @@ var defaultSysVars = []*SysVar{
return nil
}},
{Scope: ScopeGlobal | ScopeSession, Name: TiDBDMLBatchSize, Value: strconv.Itoa(DefDMLBatchSize), Type: TypeUnsigned, MinValue: 0, MaxValue: math.MaxUint64, SetSession: func(s *SessionVars, val string) error {
- s.DMLBatchSize = int(tidbOptInt64(val, DefOptCorrelationExpFactor))
+ s.DMLBatchSize = int(tidbOptInt64(val, DefDMLBatchSize))
return nil
}},
{Scope: ScopeSession, Name: TiDBCurrentTS, Value: strconv.Itoa(DefCurretTS), ReadOnly: true},
diff --git a/store/copr/batch_coprocessor.go b/store/copr/batch_coprocessor.go
index b0c0ad5c9ea7b..8c73f58fbf892 100644
--- a/store/copr/batch_coprocessor.go
+++ b/store/copr/batch_coprocessor.go
@@ -16,6 +16,8 @@ package copr
import (
"context"
"io"
+ "math"
+ "strconv"
"sync"
"sync/atomic"
"time"
@@ -25,6 +27,7 @@ import (
"github.com/pingcap/kvproto/pkg/coprocessor"
"github.com/pingcap/kvproto/pkg/kvrpcpb"
"github.com/pingcap/kvproto/pkg/metapb"
+ "github.com/pingcap/log"
"github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/store/driver/backoff"
derr "github.com/pingcap/tidb/store/driver/error"
@@ -40,8 +43,9 @@ import (
type batchCopTask struct {
storeAddr string
cmdType tikvrpc.CmdType
+ ctx *tikv.RPCContext
- copTasks []copTaskAndRPCContext
+ regionInfos []tikv.RegionInfo
}
type batchCopResponse struct {
@@ -93,9 +97,152 @@ func (rs *batchCopResponse) RespTime() time.Duration {
return rs.respTime
}
-type copTaskAndRPCContext struct {
- task *copTask
- ctx *tikv.RPCContext
+// balanceBatchCopTask balances the regions among the available stores. The basic rules are:
+// 1. the first region of each original batch cop task stays with its original store, because some
+// metadata (like the RPC context) in batchCopTask is tied to it
+// 2. for the remaining regions:
+// if there is only 1 available store, put the region into the related store
+// otherwise, use a greedy algorithm to put it into the store with the highest weight
+func balanceBatchCopTask(originalTasks []*batchCopTask) []*batchCopTask {
+ if len(originalTasks) <= 1 {
+ return originalTasks
+ }
+ storeTaskMap := make(map[uint64]*batchCopTask)
+ storeCandidateRegionMap := make(map[uint64]map[string]tikv.RegionInfo)
+ totalRegionCandidateNum := 0
+ totalRemainingRegionNum := 0
+
+ for _, task := range originalTasks {
+ taskStoreID := task.regionInfos[0].AllStores[0]
+ batchTask := &batchCopTask{
+ storeAddr: task.storeAddr,
+ cmdType: task.cmdType,
+ ctx: task.ctx,
+ regionInfos: []tikv.RegionInfo{task.regionInfos[0]},
+ }
+ storeTaskMap[taskStoreID] = batchTask
+ }
+
+ for _, task := range originalTasks {
+ taskStoreID := task.regionInfos[0].AllStores[0]
+ for index, ri := range task.regionInfos {
+ // for each region, figure out the valid store num
+ validStoreNum := 0
+ if index == 0 {
+ continue
+ }
+ if len(ri.AllStores) <= 1 {
+ validStoreNum = 1
+ } else {
+ for _, storeID := range ri.AllStores {
+ if _, ok := storeTaskMap[storeID]; ok {
+ validStoreNum++
+ }
+ }
+ }
+ if validStoreNum == 1 {
+ // if only one store is valid, just put the region into storeTaskMap
+ storeTaskMap[taskStoreID].regionInfos = append(storeTaskMap[taskStoreID].regionInfos, ri)
+ } else {
+ // if more than one store is valid, put the region
+ // into the store candidate map
+ totalRegionCandidateNum += validStoreNum
+ totalRemainingRegionNum += 1
+ taskKey := ri.Region.String()
+ for _, storeID := range ri.AllStores {
+ if _, validStore := storeTaskMap[storeID]; !validStore {
+ continue
+ }
+ if _, ok := storeCandidateRegionMap[storeID]; !ok {
+ candidateMap
:= make(map[string]tikv.RegionInfo) + storeCandidateRegionMap[storeID] = candidateMap + } + if _, duplicateRegion := storeCandidateRegionMap[storeID][taskKey]; duplicateRegion { + // duplicated region, should not happen, just give up balance + logutil.BgLogger().Warn("Meet duplicated region info during when trying to balance batch cop task, give up balancing") + return originalTasks + } + storeCandidateRegionMap[storeID][taskKey] = ri + } + } + } + } + if totalRemainingRegionNum == 0 { + return originalTasks + } + + avgStorePerRegion := float64(totalRegionCandidateNum) / float64(totalRemainingRegionNum) + findNextStore := func(candidateStores []uint64) uint64 { + store := uint64(math.MaxUint64) + weightedRegionNum := math.MaxFloat64 + if candidateStores != nil { + for _, storeID := range candidateStores { + if _, validStore := storeCandidateRegionMap[storeID]; !validStore { + continue + } + num := float64(len(storeCandidateRegionMap[storeID]))/avgStorePerRegion + float64(len(storeTaskMap[storeID].regionInfos)) + if num < weightedRegionNum { + store = storeID + weightedRegionNum = num + } + } + if store != uint64(math.MaxUint64) { + return store + } + } + for storeID := range storeTaskMap { + if _, validStore := storeCandidateRegionMap[storeID]; !validStore { + continue + } + num := float64(len(storeCandidateRegionMap[storeID]))/avgStorePerRegion + float64(len(storeTaskMap[storeID].regionInfos)) + if num < weightedRegionNum { + store = storeID + weightedRegionNum = num + } + } + return store + } + + store := findNextStore(nil) + for totalRemainingRegionNum > 0 { + if store == uint64(math.MaxUint64) { + break + } + var key string + var ri tikv.RegionInfo + for key, ri = range storeCandidateRegionMap[store] { + // get the first region + break + } + storeTaskMap[store].regionInfos = append(storeTaskMap[store].regionInfos, ri) + totalRemainingRegionNum-- + for _, id := range ri.AllStores { + if _, ok := storeCandidateRegionMap[id]; ok { + delete(storeCandidateRegionMap[id], key) + totalRegionCandidateNum-- + if len(storeCandidateRegionMap[id]) == 0 { + delete(storeCandidateRegionMap, id) + } + } + } + if totalRemainingRegionNum > 0 { + avgStorePerRegion = float64(totalRegionCandidateNum) / float64(totalRemainingRegionNum) + // it is not optimal because we only check the stores that affected by this region, in fact in order + // to find out the store with the lowest weightedRegionNum, all stores should be checked, but I think + // check only the affected stores is more simple and will get a good enough result + store = findNextStore(ri.AllStores) + } + } + if totalRemainingRegionNum > 0 { + logutil.BgLogger().Warn("Some regions are not used when trying to balance batch cop task, give up balancing") + return originalTasks + } + + var ret []*batchCopTask + for _, task := range storeTaskMap { + ret = append(ret, task) + } + return ret } func buildBatchCopTasks(bo *Backoffer, cache *tikv.RegionCache, ranges *tikv.KeyRanges, storeType kv.StoreType) ([]*batchCopTask, error) { @@ -138,13 +285,15 @@ func buildBatchCopTasks(bo *Backoffer, cache *tikv.RegionCache, ranges *tikv.Key // Then `splitRegion` will reloads these regions. 
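Editorial note: the weight bookkeeping above is interwoven with the RegionCache types, but the greedy pick itself is easy to see in isolation. A toy version, assuming plain maps in place of storeTaskMap/storeCandidateRegionMap: the next region goes to the store with the lowest weight, where pending candidate regions are discounted by avgStorePerRegion and already-assigned regions count fully.

```go
package main

import "fmt"

// pickNextStore mirrors findNextStore's weighting; lower weight wins.
func pickNextStore(assigned, candidates map[string]int, avgStorePerRegion float64) string {
	best, bestWeight := "", 1e308
	for store, pending := range candidates {
		w := float64(pending)/avgStorePerRegion + float64(assigned[store])
		if w < bestWeight {
			best, bestWeight = store, w
		}
	}
	return best
}

func main() {
	assigned := map[string]int{"s1": 5, "s2": 1}   // regions already pinned per store
	candidates := map[string]int{"s1": 2, "s2": 4} // movable candidate regions per store
	avg := 6.0 / 4.0                               // totalRegionCandidateNum / totalRemainingRegionNum
	fmt.Println(pickNextStore(assigned, candidates, avg)) // s2: 4/1.5+1 beats 2/1.5+5
}
```

As the in-code comment concedes, rechecking only the stores touched by the just-assigned region is an approximation, traded for simplicity over a global minimum.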
 
 func buildBatchCopTasks(bo *Backoffer, cache *tikv.RegionCache, ranges *tikv.KeyRanges, storeType kv.StoreType) ([]*batchCopTask, error) {
@@ -138,13 +285,15 @@ func buildBatchCopTasks(bo *Backoffer, cache *tikv.RegionCache, ranges *tikv.Key
 			// Then `splitRegion` will reloads these regions.
 			continue
 		}
+		allStores := cache.GetAllValidTiFlashStores(task.region, rpcCtx.Store)
 		if batchCop, ok := storeTaskMap[rpcCtx.Addr]; ok {
-			batchCop.copTasks = append(batchCop.copTasks, copTaskAndRPCContext{task: task, ctx: rpcCtx})
+			batchCop.regionInfos = append(batchCop.regionInfos, tikv.RegionInfo{Region: task.region, Meta: rpcCtx.Meta, Ranges: task.ranges, AllStores: allStores})
 		} else {
 			batchTask := &batchCopTask{
-				storeAddr: rpcCtx.Addr,
-				cmdType:   cmdType,
-				copTasks:  []copTaskAndRPCContext{{task, rpcCtx}},
+				storeAddr:   rpcCtx.Addr,
+				cmdType:     cmdType,
+				ctx:         rpcCtx,
+				regionInfos: []tikv.RegionInfo{{Region: task.region, Meta: rpcCtx.Meta, Ranges: task.ranges, AllStores: allStores}},
 			}
 			storeTaskMap[rpcCtx.Addr] = batchTask
 		}
@@ -159,9 +308,25 @@ func buildBatchCopTasks(bo *Backoffer, cache *tikv.RegionCache, ranges *tikv.Key
 		}
 		continue
 	}
+
 	for _, task := range storeTaskMap {
 		batchTasks = append(batchTasks, task)
 	}
+	if log.GetLevel() <= zap.DebugLevel {
+		msg := "Before region balance:"
+		for _, task := range batchTasks {
+			msg += " store " + task.storeAddr + ": " + strconv.Itoa(len(task.regionInfos)) + " regions,"
+		}
+		logutil.BgLogger().Debug(msg)
+	}
+	batchTasks = balanceBatchCopTask(batchTasks)
+	if log.GetLevel() <= zap.DebugLevel {
+		msg := "After region balance:"
+		for _, task := range batchTasks {
+			msg += " store " + task.storeAddr + ": " + strconv.Itoa(len(task.regionInfos)) + " regions,"
+		}
+		logutil.BgLogger().Debug(msg)
+	}
 
 	if elapsed := time.Since(start); elapsed > time.Millisecond*500 {
 		logutil.BgLogger().Warn("buildBatchCopTasks takes too much time",
@@ -311,8 +476,8 @@ func (b *batchCopIterator) handleTask(ctx context.Context, bo *Backoffer, task *
 // Merge all ranges and request again.
 func (b *batchCopIterator) retryBatchCopTask(ctx context.Context, bo *Backoffer, batchTask *batchCopTask) ([]*batchCopTask, error) {
 	var ranges []tikvstore.KeyRange
-	for _, taskCtx := range batchTask.copTasks {
-		taskCtx.task.ranges.Do(func(ran *tikvstore.KeyRange) {
+	for _, ri := range batchTask.regionInfos {
+		ri.Ranges.Do(func(ran *tikvstore.KeyRange) {
 			ranges = append(ranges, *ran)
 		})
 	}
@@ -320,16 +485,16 @@ func (b *batchCopIterator) retryBatchCopTask(ctx context.Context, bo *Backoffer,
 }
 
 func (b *batchCopIterator) handleTaskOnce(ctx context.Context, bo *Backoffer, task *batchCopTask) ([]*batchCopTask, error) {
-	sender := NewRegionBatchRequestSender(b.store.GetRegionCache(), b.store.GetTiKVClient())
-	var regionInfos = make([]*coprocessor.RegionInfo, 0, len(task.copTasks))
-	for _, task := range task.copTasks {
+	sender := tikv.NewRegionBatchRequestSender(b.store.GetRegionCache(), b.store.GetTiKVClient())
+	var regionInfos = make([]*coprocessor.RegionInfo, 0, len(task.regionInfos))
+	for _, ri := range task.regionInfos {
 		regionInfos = append(regionInfos, &coprocessor.RegionInfo{
-			RegionId: task.task.region.GetID(),
+			RegionId: ri.Region.GetID(),
 			RegionEpoch: &metapb.RegionEpoch{
-				ConfVer: task.task.region.GetConfVer(),
-				Version: task.task.region.GetVer(),
+				ConfVer: ri.Region.GetConfVer(),
+				Version: ri.Region.GetVer(),
 			},
-			Ranges: task.task.ranges.ToPBRanges(),
+			Ranges: ri.Ranges.ToPBRanges(),
 		})
 	}
 
@@ -351,13 +516,14 @@ func (b *batchCopIterator) handleTaskOnce(ctx context.Context, bo *Backoffer, ta
 	})
 	req.StoreTp = tikvrpc.TiFlash
 
-	logutil.BgLogger().Debug("send batch request to ", zap.String("req info", req.String()), zap.Int("cop task len", len(task.copTasks)))
-	resp, retry, cancel, err := sender.sendStreamReqToAddr(bo, task.copTasks, req, tikv.ReadTimeoutUltraLong)
+	logutil.BgLogger().Debug("send batch request to ", zap.String("req info", req.String()), zap.Int("cop task len", len(task.regionInfos)))
+	resp, retry, cancel, err := sender.SendReqToAddr(bo.TiKVBackoffer(), task.ctx, task.regionInfos, req, tikv.ReadTimeoutUltraLong)
 	// If there are store errors, we should retry for all regions.
 	if retry {
 		return b.retryBatchCopTask(ctx, bo, task)
 	}
 	if err != nil {
+		err = derr.ToTiDBErr(err)
 		return nil, errors.Trace(err)
 	}
 	defer cancel()
diff --git a/store/copr/mpp.go b/store/copr/mpp.go
index 2aaf4223ed8e5..1941f2b3fbfa4 100644
--- a/store/copr/mpp.go
+++ b/store/copr/mpp.go
@@ -180,14 +180,14 @@ func (m *mppIterator) handleDispatchReq(ctx context.Context, bo *Backoffer, req
 	var regionInfos []*coprocessor.RegionInfo
 	originalTask, ok := req.Meta.(*batchCopTask)
 	if ok {
-		for _, task := range originalTask.copTasks {
+		for _, ri := range originalTask.regionInfos {
 			regionInfos = append(regionInfos, &coprocessor.RegionInfo{
-				RegionId: task.task.region.GetID(),
+				RegionId: ri.Region.GetID(),
 				RegionEpoch: &metapb.RegionEpoch{
-					ConfVer: task.task.region.GetConfVer(),
-					Version: task.task.region.GetVer(),
+					ConfVer: ri.Region.GetConfVer(),
+					Version: ri.Region.GetVer(),
 				},
-				Ranges: task.task.ranges.ToPBRanges(),
+				Ranges: ri.Ranges.ToPBRanges(),
 			})
 		}
 	}
@@ -214,8 +214,8 @@ func (m *mppIterator) handleDispatchReq(ctx context.Context, bo *Backoffer, req
 	// Or else it's the task without region, which always happens in high layer task without table.
 	// In that case
 	if originalTask != nil {
-		sender := NewRegionBatchRequestSender(m.store.GetRegionCache(), m.store.GetTiKVClient())
-		rpcResp, _, _, err = sender.sendStreamReqToAddr(bo, originalTask.copTasks, wrappedReq, tikv.ReadTimeoutMedium)
+		sender := tikv.NewRegionBatchRequestSender(m.store.GetRegionCache(), m.store.GetTiKVClient())
+		rpcResp, _, _, err = sender.SendReqToAddr(bo.TiKVBackoffer(), originalTask.ctx, originalTask.regionInfos, wrappedReq, tikv.ReadTimeoutMedium)
 		// No matter what the rpc error is, we won't retry the mpp dispatch tasks.
 		// TODO: If we want to retry, we must redo the plan fragment cutting and task scheduling.
 		// That's a hard job but we can try it in the future.
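Editorial note: handleTaskOnce above and handleDispatchReq in mpp.go now repeat the same regionInfos-to-proto loop. It could plausibly be factored into a helper along these lines; toPBRegionInfos is a hypothetical name, but the accessors are exactly the ones used in the diff:

```go
package copr

import (
	"github.com/pingcap/kvproto/pkg/coprocessor"
	"github.com/pingcap/kvproto/pkg/metapb"
	"github.com/pingcap/tidb/store/tikv"
)

// toPBRegionInfos converts the patch's tikv.RegionInfo slice into the
// coprocessor proto form shared by batch cop and MPP dispatch.
func toPBRegionInfos(ris []tikv.RegionInfo) []*coprocessor.RegionInfo {
	ret := make([]*coprocessor.RegionInfo, 0, len(ris))
	for _, ri := range ris {
		ret = append(ret, &coprocessor.RegionInfo{
			RegionId: ri.Region.GetID(),
			RegionEpoch: &metapb.RegionEpoch{
				ConfVer: ri.Region.GetConfVer(),
				Version: ri.Region.GetVer(),
			},
			Ranges: ri.Ranges.ToPBRanges(),
		})
	}
	return ret
}
```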
diff --git a/store/helper/helper.go b/store/helper/helper.go
index e96ad4ae21851..49aa7cf2107e0 100644
--- a/store/helper/helper.go
+++ b/store/helper/helper.go
@@ -71,6 +71,7 @@ type Storage interface {
 	SetTiKVClient(client tikv.Client)
 	GetTiKVClient() tikv.Client
 	Closed() <-chan struct{}
+	GetMinSafeTS(txnScope string) uint64
 }
 
 // Helper is a middleware to get some information from tikv/pd. It can be used for TiDB's http api or mem table.
diff --git a/store/mockstore/mockstorage/storage.go b/store/mockstore/mockstorage/storage.go
index 36ded5e434817..6221ef855707d 100644
--- a/store/mockstore/mockstorage/storage.go
+++ b/store/mockstore/mockstorage/storage.go
@@ -99,6 +99,11 @@ func (s *mockStorage) CurrentVersion(txnScope string) (kv.Version, error) {
 	return kv.NewVersion(ver), err
 }
 
+// GetMinSafeTS returns the minimal SafeTS of the storage with given txnScope.
+func (s *mockStorage) GetMinSafeTS(txnScope string) uint64 {
+	return 0
+}
+
 func newTiKVTxn(txn *tikv.KVTxn, err error) (kv.Transaction, error) {
 	if err != nil {
 		return nil, err
diff --git a/store/tikv/2pc.go b/store/tikv/2pc.go
index 19f3e4faf40e3..14609f5f77400 100644
--- a/store/tikv/2pc.go
+++ b/store/tikv/2pc.go
@@ -739,15 +739,11 @@ func (tm *ttlManager) keepAlive(c *twoPhaseCommitter) {
 				return
 			}
 			bo := retry.NewBackofferWithVars(context.Background(), pessimisticLockMaxBackoff, c.txn.vars)
-			now, err := c.store.GetOracle().GetTimestamp(bo.GetCtx(), &oracle.Option{TxnScope: oracle.GlobalTxnScope})
+			now, err := c.store.getTimestampWithRetry(bo, c.txn.GetScope())
 			if err != nil {
-				err1 := bo.Backoff(retry.BoPDRPC, err)
-				if err1 != nil {
-					logutil.Logger(bo.GetCtx()).Warn("keepAlive get tso fail",
-						zap.Error(err))
-					return
-				}
-				continue
+				logutil.Logger(bo.GetCtx()).Warn("keepAlive get tso fail",
+					zap.Error(err))
+				return
 			}
 
 			uptime := uint64(oracle.ExtractPhysical(now) - oracle.ExtractPhysical(c.startTS))
@@ -999,7 +995,7 @@ func (c *twoPhaseCommitter) execute(ctx context.Context) (err error) {
 		// from PD and plus one as our MinCommitTS.
 		if commitTSMayBeCalculated && c.needLinearizability() {
 			failpoint.Inject("getMinCommitTSFromTSO", nil)
-			latestTS, err := c.store.oracle.GetTimestamp(ctx, &oracle.Option{TxnScope: oracle.GlobalTxnScope})
+			latestTS, err := c.store.getTimestampWithRetry(retry.NewBackofferWithVars(ctx, tsoMaxBackoff, c.txn.vars), c.txn.GetScope())
 			// If we fail to get a timestamp from PD, we just propagate the failure
 			// instead of falling back to the normal 2PC because a normal 2PC will
 			// also be likely to fail due to the same timestamp issue.
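Editorial note: keepAlive previously drove its own Backoff/continue loop around the oracle call; delegating to getTimestampWithRetry collapses that into a single call that either succeeds or returns a final error, which keepAlive now treats as fatal. A sketch of that retry-wrapping shape, with a toy budget standing in for the real Backoffer (names here are illustrative, not the tikv package internals):

```go
package main

import (
	"errors"
	"fmt"
)

// retryBudget is a toy stand-in for tikv's Backoffer: it permits a fixed
// number of retries before surfacing the last error.
type retryBudget struct{ left int }

func (b *retryBudget) backoff(err error) error {
	if b.left == 0 {
		return fmt.Errorf("backoff budget exhausted: %w", err)
	}
	b.left--
	return nil
}

// getTSWithRetry mirrors the shape of getTimestampWithRetry: loop until
// success or until the backoffer gives up.
func getTSWithRetry(b *retryBudget, getTS func() (uint64, error)) (uint64, error) {
	for {
		ts, err := getTS()
		if err == nil {
			return ts, nil
		}
		if berr := b.backoff(err); berr != nil {
			return 0, berr
		}
	}
}

func main() {
	calls := 0
	ts, err := getTSWithRetry(&retryBudget{left: 3}, func() (uint64, error) {
		calls++
		if calls < 3 {
			return 0, errors.New("pd unavailable")
		}
		return 424242, nil
	})
	fmt.Println(ts, err) // 424242 <nil> after two retries
}
```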
diff --git a/store/copr/batch_request_sender.go b/store/tikv/batch_request_sender.go
similarity index 54%
rename from store/copr/batch_request_sender.go
rename to store/tikv/batch_request_sender.go
index 422306382337d..9aad070b70306 100644
--- a/store/copr/batch_request_sender.go
+++ b/store/tikv/batch_request_sender.go
@@ -11,7 +11,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-package copr
+package tikv
 
 import (
 	"context"
@@ -19,45 +19,52 @@
 	"time"
 
 	"github.com/pingcap/errors"
-	"github.com/pingcap/tidb/store/tikv"
+	"github.com/pingcap/kvproto/pkg/metapb"
 	tikverr "github.com/pingcap/tidb/store/tikv/error"
 	"github.com/pingcap/tidb/store/tikv/tikvrpc"
 	"google.golang.org/grpc/codes"
 	"google.golang.org/grpc/status"
 )
 
+// RegionInfo contains region-related information for batchCopTask
+type RegionInfo struct {
+	Region    RegionVerID
+	Meta      *metapb.Region
+	Ranges    *KeyRanges
+	AllStores []uint64
+}
+
 // RegionBatchRequestSender sends BatchCop requests to TiFlash server by stream way.
 type RegionBatchRequestSender struct {
-	*tikv.RegionRequestSender
+	*RegionRequestSender
}
 
 // NewRegionBatchRequestSender creates a RegionBatchRequestSender object.
-func NewRegionBatchRequestSender(cache *tikv.RegionCache, client tikv.Client) *RegionBatchRequestSender {
+func NewRegionBatchRequestSender(cache *RegionCache, client Client) *RegionBatchRequestSender {
 	return &RegionBatchRequestSender{
-		RegionRequestSender: tikv.NewRegionRequestSender(cache, client),
+		RegionRequestSender: NewRegionRequestSender(cache, client),
 	}
 }
 
-func (ss *RegionBatchRequestSender) sendStreamReqToAddr(bo *Backoffer, ctxs []copTaskAndRPCContext, req *tikvrpc.Request, timout time.Duration) (resp *tikvrpc.Response, retry bool, cancel func(), err error) {
-	// use the first ctx to send request, because every ctx has same address.
+// SendReqToAddr sends a batch cop request.
+func (ss *RegionBatchRequestSender) SendReqToAddr(bo *Backoffer, rpcCtx *RPCContext, regionInfos []RegionInfo, req *tikvrpc.Request, timout time.Duration) (resp *tikvrpc.Response, retry bool, cancel func(), err error) {
 	cancel = func() {}
-	rpcCtx := ctxs[0].ctx
 	if e := tikvrpc.SetContext(req, rpcCtx.Meta, rpcCtx.Peer); e != nil {
 		return nil, false, cancel, errors.Trace(e)
 	}
 	ctx := bo.GetCtx()
-	if rawHook := ctx.Value(tikv.RPCCancellerCtxKey{}); rawHook != nil {
-		ctx, cancel = rawHook.(*tikv.RPCCanceller).WithCancel(ctx)
+	if rawHook := ctx.Value(RPCCancellerCtxKey{}); rawHook != nil {
+		ctx, cancel = rawHook.(*RPCCanceller).WithCancel(ctx)
 	}
 	start := time.Now()
 	resp, err = ss.GetClient().SendRequest(ctx, rpcCtx.Addr, req, timout)
 	if ss.Stats != nil {
-		tikv.RecordRegionRequestRuntimeStats(ss.Stats, req.Type, time.Since(start))
+		RecordRegionRequestRuntimeStats(ss.Stats, req.Type, time.Since(start))
 	}
 	if err != nil {
 		cancel()
 		ss.SetRPCError(err)
-		e := ss.onSendFail(bo, ctxs, err)
+		e := ss.onSendFailForBatchRegions(bo, rpcCtx, regionInfos, err)
 		if e != nil {
 			return nil, false, func() {}, errors.Trace(e)
 		}
@@ -67,30 +74,25 @@ func (ss *RegionBatchRequestSender) sendStreamReqToAddr(bo *Backoffer, ctxs []co
 	return
 }
 
-func (ss *RegionBatchRequestSender) onSendFail(bo *Backoffer, ctxs []copTaskAndRPCContext, err error) error {
+func (ss *RegionBatchRequestSender) onSendFailForBatchRegions(bo *Backoffer, ctx *RPCContext, regionInfos []RegionInfo, err error) error {
 	// If it failed because the context is cancelled by ourself, don't retry.
 	if errors.Cause(err) == context.Canceled || status.Code(errors.Cause(err)) == codes.Canceled {
 		return errors.Trace(err)
-	} else if atomic.LoadUint32(&tikv.ShuttingDown) > 0 {
+	} else if atomic.LoadUint32(&ShuttingDown) > 0 {
 		return tikverr.ErrTiDBShuttingDown
 	}
 
-	for _, failedCtx := range ctxs {
-		ctx := failedCtx.ctx
-		if ctx.Meta != nil {
-			// The reload region param is always true. Because that every time we try, we must
-			// re-build the range then re-create the batch sender. As a result, the len of "failStores"
-			// will change. If tiflash's replica is more than two, the "reload region" will always be false.
-			// Now that the batch cop and mpp has a relative low qps, it's reasonable to reload every time
-			// when meeting io error.
-			ss.GetRegionCache().OnSendFail(bo.TiKVBackoffer(), ctx, true, err)
-		}
-	}
+	// The reload region param is always true, because every time we retry we must
+	// re-build the ranges and re-create the batch sender, so the len of "failStores"
+	// will change. If tiflash has more than two replicas, the "reload region" would
+	// always be false. Since batch cop and mpp have a relatively low qps, it is
+	// reasonable to reload the regions every time an io error occurs.
+	ss.GetRegionCache().OnSendFailForBatchRegions(bo, ctx.Store, regionInfos, true, err)
 
 	// Retry on send request failure when it's not canceled.
 	// When a store is not available, the leader of related region should be elected quickly.
 	// TODO: the number of retry time should be limited:since region may be unavailable
 	// when some unrecoverable disaster happened.
-	err = bo.Backoff(tikv.BoTiFlashRPC, errors.Errorf("send tikv request error: %v, ctxs: %v, try next peer later", err, ctxs))
+	err = bo.Backoff(BoTiFlashRPC, errors.Errorf("send request error: %v, ctx: %v, regionInfos: %v", err, ctx, regionInfos))
 	return errors.Trace(err)
 }
diff --git a/store/tikv/kv.go b/store/tikv/kv.go
index bbf8517a42a8c..edaef3b4744d7 100644
--- a/store/tikv/kv.go
+++ b/store/tikv/kv.go
@@ -18,6 +18,7 @@ import (
 	"crypto/tls"
 	"math"
 	"math/rand"
+	"strconv"
 	"sync"
 	"sync/atomic"
 	"time"
@@ -26,6 +27,7 @@ import (
 	"github.com/pingcap/errors"
 	"github.com/pingcap/failpoint"
 	"github.com/pingcap/kvproto/pkg/kvrpcpb"
+	"github.com/pingcap/kvproto/pkg/metapb"
 	tidbkv "github.com/pingcap/tidb/kv"
 	"github.com/pingcap/tidb/store/tikv/config"
 	tikverr "github.com/pingcap/tidb/store/tikv/error"
@@ -84,7 +86,7 @@ type KVStore struct {
 	safePoint uint64
 	spTime    time.Time
 	spMutex   sync.RWMutex  // this is used to update safePoint and spTime
-	closed    chan struct{} // this is used to nofity when the store is closed
+	closed    chan struct{} // this is used to notify when the store is closed
 
 	// storeID -> safeTS, stored as map[uint64]uint64
 	// safeTS here will be used during the Stale Read process,
@@ -358,6 +360,27 @@ func (s *KVStore) GetTiKVClient() (client Client) {
 	return s.clientMu.client
 }
 
+// GetMinSafeTS returns the minimal safeTS of the storage with given txnScope.
+func (s *KVStore) GetMinSafeTS(txnScope string) uint64 {
+	stores := make([]*Store, 0)
+	allStores := s.regionCache.getStoresByType(tikvrpc.TiKV)
+	if txnScope != oracle.GlobalTxnScope {
+		for _, store := range allStores {
+			if store.IsLabelsMatch([]*metapb.StoreLabel{
+				{
+					Key:   DCLabelKey,
+					Value: txnScope,
+				},
+			}) {
+				stores = append(stores, store)
+			}
+		}
+	} else {
+		stores = allStores
+	}
+	return s.getMinSafeTSByStores(stores)
+}
+
 func (s *KVStore) getSafeTS(storeID uint64) uint64 {
 	safeTS, ok := s.safeTSMap.Load(storeID)
 	if !ok {
@@ -414,17 +437,20 @@ func (s *KVStore) updateSafeTS(ctx context.Context) {
 		storeAddr := store.addr
 		go func(ctx context.Context, wg *sync.WaitGroup, storeID uint64, storeAddr string) {
 			defer wg.Done()
-			// TODO: add metrics for updateSafeTS
 			resp, err := tikvClient.SendRequest(ctx, storeAddr, tikvrpc.NewRequest(tikvrpc.CmdStoreSafeTS, &kvrpcpb.StoreSafeTSRequest{KeyRange: &kvrpcpb.KeyRange{
 				StartKey: []byte(""),
 				EndKey:   []byte(""),
 			}}), ReadTimeoutShort)
+			storeIDStr := strconv.Itoa(int(storeID))
 			if err != nil {
+				metrics.TiKVSafeTSUpdateCounter.WithLabelValues("fail", storeIDStr).Inc()
 				logutil.BgLogger().Debug("update safeTS failed", zap.Error(err), zap.Uint64("store-id", storeID))
 				return
 			}
 			safeTSResp := resp.Resp.(*kvrpcpb.StoreSafeTSResponse)
 			s.setSafeTS(storeID, safeTSResp.GetSafeTs())
+			metrics.TiKVSafeTSUpdateCounter.WithLabelValues("success", storeIDStr).Inc()
+			metrics.TiKVSafeTSUpdateStats.WithLabelValues(storeIDStr).Set(float64(safeTSResp.GetSafeTs()))
 		}(ctx, wg, storeID, storeAddr)
 	}
 	wg.Wait()
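Editorial note: getMinSafeTSByStores itself is not shown in this diff, so the min-selection below is an assumption, but the label filtering mirrors GetMinSafeTS above: for a non-global txnScope, only stores whose DC label matches the scope contribute to the minimum.

```go
package main

import "fmt"

type store struct {
	id     uint64
	labels map[string]string
	safeTS uint64
}

const dcLabelKey = "zone" // stand-in for the real DCLabelKey

// minSafeTS assumes the unshown getMinSafeTSByStores simply takes the
// minimum safeTS over the chosen stores.
func minSafeTS(stores []store, txnScope string) uint64 {
	var minTS uint64
	first := true
	for _, s := range stores {
		if txnScope != "global" && s.labels[dcLabelKey] != txnScope {
			continue // non-global scope: only stores in the matching DC count
		}
		if first || s.safeTS < minTS {
			minTS, first = s.safeTS, false
		}
	}
	return minTS
}

func main() {
	stores := []store{
		{id: 1, labels: map[string]string{dcLabelKey: "bj"}, safeTS: 100},
		{id: 2, labels: map[string]string{dcLabelKey: "sh"}, safeTS: 90},
		{id: 3, labels: map[string]string{dcLabelKey: "bj"}, safeTS: 95},
	}
	fmt.Println(minSafeTS(stores, "bj"))     // 95: only bj stores considered
	fmt.Println(minSafeTS(stores, "global")) // 90: all stores considered
}
```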
diff --git a/store/tikv/lock_resolver.go b/store/tikv/lock_resolver.go
index fe50910a896e6..0ed9ecb3fa471 100644
--- a/store/tikv/lock_resolver.go
+++ b/store/tikv/lock_resolver.go
@@ -229,11 +229,6 @@ func (lr *LockResolver) BatchResolveLocks(bo *Backoffer, locks []*Lock, loc Regi
 	// locks have been cleaned before GC.
 	expiredLocks := locks
 
-	callerStartTS, err := lr.store.GetOracle().GetTimestamp(bo.GetCtx(), &oracle.Option{TxnScope: oracle.GlobalTxnScope})
-	if err != nil {
-		return false, errors.Trace(err)
-	}
-
 	txnInfos := make(map[uint64]uint64)
 	startTime := time.Now()
 	for _, l := range expiredLocks {
@@ -243,7 +238,7 @@ func (lr *LockResolver) BatchResolveLocks(bo *Backoffer, locks []*Lock, loc Regi
 		metrics.LockResolverCountWithExpired.Inc()
 
 		// Use currentTS = math.MaxUint64 means rollback the txn, no matter the lock is expired or not!
-		status, err := lr.getTxnStatus(bo, l.TxnID, l.Primary, callerStartTS, math.MaxUint64, true, false, l)
+		status, err := lr.getTxnStatus(bo, l.TxnID, l.Primary, 0, math.MaxUint64, true, false, l)
 		if err != nil {
 			return false, err
 		}
@@ -257,7 +252,7 @@ func (lr *LockResolver) BatchResolveLocks(bo *Backoffer, locks []*Lock, loc Regi
 			continue
 		}
 		if _, ok := errors.Cause(err).(*nonAsyncCommitLock); ok {
-			status, err = lr.getTxnStatus(bo, l.TxnID, l.Primary, callerStartTS, math.MaxUint64, true, true, l)
+			status, err = lr.getTxnStatus(bo, l.TxnID, l.Primary, 0, math.MaxUint64, true, true, l)
 			if err != nil {
 				return false, err
 			}
diff --git a/store/tikv/metrics/metrics.go b/store/tikv/metrics/metrics.go
index 8d71582fa2522..6b8ea32d456f7 100644
--- a/store/tikv/metrics/metrics.go
+++ b/store/tikv/metrics/metrics.go
@@ -59,6 +59,8 @@ var (
 	TiKVPanicCounter          *prometheus.CounterVec
 	TiKVForwardRequestCounter *prometheus.CounterVec
 	TiKVTSFutureWaitDuration  prometheus.Histogram
+	TiKVSafeTSUpdateCounter   *prometheus.CounterVec
+	TiKVSafeTSUpdateStats     *prometheus.GaugeVec
 )
 
 // Label constants.
@@ -414,6 +416,22 @@ func initMetrics(namespace, subsystem string) {
 			Buckets:   prometheus.ExponentialBuckets(0.000005, 2, 30), // 5us ~ 2560s
 		})
 
+	TiKVSafeTSUpdateCounter = prometheus.NewCounterVec(
+		prometheus.CounterOpts{
+			Namespace: namespace,
+			Subsystem: subsystem,
+			Name:      "safets_update_counter",
+			Help:      "Counter of tikv safe_ts being updated.",
+		}, []string{LblResult, LblStore})
+
+	TiKVSafeTSUpdateStats = prometheus.NewGaugeVec(
+		prometheus.GaugeOpts{
+			Namespace: namespace,
+			Subsystem: subsystem,
+			Name:      "safets_update_stats",
+			Help:      "Stats of tikv safe_ts being updated.",
+		}, []string{LblStore})
+
 	initShortcuts()
 }
 
@@ -468,6 +486,8 @@ func RegisterMetrics() {
 	prometheus.MustRegister(TiKVPanicCounter)
 	prometheus.MustRegister(TiKVForwardRequestCounter)
 	prometheus.MustRegister(TiKVTSFutureWaitDuration)
+	prometheus.MustRegister(TiKVSafeTSUpdateCounter)
+	prometheus.MustRegister(TiKVSafeTSUpdateStats)
 }
 
 // readCounter reads the value of a prometheus.Counter.
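Editorial note: for reference, this is roughly how the two new collectors behave once wired into updateSafeTS: a counter vector labelled by result and store, plus a per-store gauge holding the last uploaded safe_ts. The namespace/subsystem values and the recordSafeTSUpdate helper are illustrative only.

```go
package main

import (
	"strconv"

	"github.com/prometheus/client_golang/prometheus"
)

var (
	safeTSUpdateCounter = prometheus.NewCounterVec(prometheus.CounterOpts{
		Namespace: "tikv_client", Subsystem: "demo",
		Name: "safets_update_counter",
		Help: "Counter of tikv safe_ts being updated.",
	}, []string{"result", "store"})

	safeTSUpdateStats = prometheus.NewGaugeVec(prometheus.GaugeOpts{
		Namespace: "tikv_client", Subsystem: "demo",
		Name: "safets_update_stats",
		Help: "Stats of tikv safe_ts being updated.",
	}, []string{"store"})
)

// recordSafeTSUpdate is a hypothetical helper mirroring the calls added in
// updateSafeTS: count failures, and on success count and record the value.
func recordSafeTSUpdate(storeID, safeTS uint64, err error) {
	store := strconv.FormatUint(storeID, 10)
	if err != nil {
		safeTSUpdateCounter.WithLabelValues("fail", store).Inc()
		return
	}
	safeTSUpdateCounter.WithLabelValues("success", store).Inc()
	safeTSUpdateStats.WithLabelValues(store).Set(float64(safeTS))
}

func main() {
	prometheus.MustRegister(safeTSUpdateCounter, safeTSUpdateStats)
	recordSafeTSUpdate(1, 424242, nil)
}
```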
diff --git a/store/tikv/region_cache.go b/store/tikv/region_cache.go
index f6225a2724f8e..0d9423a9f5a7e 100644
--- a/store/tikv/region_cache.go
+++ b/store/tikv/region_cache.go
@@ -112,6 +112,15 @@ func (r *RegionStore) accessStore(mode AccessMode, idx AccessIndex) (int, *Store
 	return sidx, r.stores[sidx]
 }
 
+func (r *RegionStore) getAccessIndex(mode AccessMode, store *Store) AccessIndex {
+	for index, sidx := range r.accessIndex[mode] {
+		if r.stores[sidx].storeID == store.storeID {
+			return AccessIndex(index)
+		}
+	}
+	return -1
+}
+
 func (r *RegionStore) accessStoreNum(mode AccessMode) int {
 	return len(r.accessIndex[mode])
 }
@@ -526,6 +535,40 @@ func (c *RegionCache) GetTiKVRPCContext(bo *Backoffer, id RegionVerID, replicaRe
 	}, nil
 }
 
+// GetAllValidTiFlashStores returns the store IDs of all valid TiFlash stores; the store ID of currentStore is always the first one
+func (c *RegionCache) GetAllValidTiFlashStores(id RegionVerID, currentStore *Store) []uint64 {
+	// set the cap to 2 because usually a TiFlash table will have 2 replicas
+	allStores := make([]uint64, 0, 2)
+	// make sure currentStore id is always the first in allStores
+	allStores = append(allStores, currentStore.storeID)
+	ts := time.Now().Unix()
+	cachedRegion := c.getCachedRegionWithRLock(id)
+	if cachedRegion == nil {
+		return allStores
+	}
+	if !cachedRegion.checkRegionCacheTTL(ts) {
+		return allStores
+	}
+	regionStore := cachedRegion.getStore()
+	currentIndex := regionStore.getAccessIndex(TiFlashOnly, currentStore)
+	if currentIndex == -1 {
+		return allStores
+	}
+	for startOffset := 1; startOffset < regionStore.accessStoreNum(TiFlashOnly); startOffset++ {
+		accessIdx := AccessIndex((int(currentIndex) + startOffset) % regionStore.accessStoreNum(TiFlashOnly))
+		storeIdx, store := regionStore.accessStore(TiFlashOnly, accessIdx)
+		if store.getResolveState() == needCheck {
+			continue
+		}
+		storeFailEpoch := atomic.LoadUint32(&store.epoch)
+		if storeFailEpoch != regionStore.storeEpochs[storeIdx] {
+			continue
+		}
+		allStores = append(allStores, store.storeID)
+	}
+	return allStores
+}
+
 // GetTiFlashRPCContext returns RPCContext for a region must access flash store. If it returns nil, the region
 // must be out of date and already dropped from cache or not flash store found.
 // `loadBalance` is an option. For MPP and batch cop, it is pointless and might cause try the failed store repeatly.
@@ -668,6 +711,64 @@ func (c *RegionCache) findRegionByKey(bo *Backoffer, key []byte, isEndKey bool)
 	return r, nil
 }
 
+// OnSendFailForBatchRegions handles send request fail logic.
+func (c *RegionCache) OnSendFailForBatchRegions(bo *Backoffer, store *Store, regionInfos []RegionInfo, scheduleReload bool, err error) {
+	metrics.RegionCacheCounterWithSendFail.Add(float64(len(regionInfos)))
+	if store.storeType != tikvrpc.TiFlash {
+		logutil.Logger(bo.GetCtx()).Info("Should not reach here, OnSendFailForBatchRegions only supports TiFlash")
+		return
+	}
+	for _, ri := range regionInfos {
+		if ri.Meta == nil {
+			continue
+		}
+		r := c.getCachedRegionWithRLock(ri.Region)
+		if r != nil {
+			peersNum := len(r.meta.Peers)
+			if len(ri.Meta.Peers) != peersNum {
+				logutil.Logger(bo.GetCtx()).Info("retry and refresh current region after send request fail and up/down stores length changed",
+					zap.Stringer("region", &ri.Region),
+					zap.Bool("needReload", scheduleReload),
+					zap.Reflect("oldPeers", ri.Meta.Peers),
+					zap.Reflect("newPeers", r.meta.Peers),
+					zap.Error(err))
+				continue
+			}
+
+			rs := r.getStore()
+
+			accessMode := TiFlashOnly
+			accessIdx := rs.getAccessIndex(accessMode, store)
+			if accessIdx == -1 {
+				logutil.Logger(bo.GetCtx()).Warn("can not get access index for region " + ri.Region.String())
+				continue
+			}
+			if err != nil {
+				storeIdx, s := rs.accessStore(accessMode, accessIdx)
+				epoch := rs.storeEpochs[storeIdx]
+				if atomic.CompareAndSwapUint32(&s.epoch, epoch, epoch+1) {
+					logutil.BgLogger().Info("mark store's regions need to be refilled", zap.String("store", s.addr))
+					metrics.RegionCacheCounterWithInvalidateStoreRegionsOK.Inc()
+				}
+				// schedule a store addr resolve.
+				s.markNeedCheck(c.notifyCheckCh)
+			}
+
+			// try next peer
+			rs.switchNextFlashPeer(r, accessIdx)
+			logutil.Logger(bo.GetCtx()).Info("switch region tiflash peer to next due to send request fail",
+				zap.Stringer("region", &ri.Region),
+				zap.Bool("needReload", scheduleReload),
+				zap.Error(err))
+
+			// force reload region when retry all known peers in region.
+			if scheduleReload {
+				r.scheduleReload()
+			}
+		}
+	}
+}
+
 // OnSendFail handles send request fail logic.
 func (c *RegionCache) OnSendFail(bo *Backoffer, ctx *RPCContext, scheduleReload bool, err error) {
 	metrics.RegionCacheCounterWithSendFail.Inc()
diff --git a/tablecodec/tablecodec.go b/tablecodec/tablecodec.go
index 215f8d05c27fa..de766831bc245 100644
--- a/tablecodec/tablecodec.go
+++ b/tablecodec/tablecodec.go
@@ -1292,6 +1292,7 @@ func TruncateIndexValue(v *types.Datum, idxCol *model.IndexColumn, tblCol *model
 	if notStringType {
 		return
 	}
+	originalKind := v.Kind()
 	isUTF8Charset := tblCol.Charset == charset.CharsetUTF8 || tblCol.Charset == charset.CharsetUTF8MB4
 	if isUTF8Charset && utf8.RuneCount(v.GetBytes()) > idxCol.Length {
 		rs := bytes.Runes(v.GetBytes())
@@ -1303,7 +1304,7 @@ func TruncateIndexValue(v *types.Datum, idxCol *model.IndexColumn, tblCol *model
 		}
 	} else if !isUTF8Charset && len(v.GetBytes()) > idxCol.Length {
 		v.SetBytes(v.GetBytes()[:idxCol.Length])
-		if v.Kind() == types.KindString {
+		if originalKind == types.KindString {
 			v.SetString(v.GetString(), tblCol.Collate)
 		}
 	}
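Editorial note: the one-line tablecodec fix is subtle. Datum.SetBytes switches the datum's kind to KindBytes, so the old `v.Kind() == types.KindString` check placed after truncation could never fire, silently dropping the collation. A toy datum reproducing the order-of-operations issue (kinds and setters simplified):

```go
package main

import "fmt"

type kind int

const (
	kindString kind = iota
	kindBytes
)

type datum struct {
	k       kind
	b       []byte
	collate string
}

func (d *datum) SetBytes(b []byte) { d.k, d.b = kindBytes, b }

func (d *datum) SetString(s, collate string) { d.k, d.b, d.collate = kindString, []byte(s), collate }

// truncate mirrors the fixed TruncateIndexValue flow for the non-UTF8 branch.
func truncate(d *datum, length int, collate string) {
	originalKind := d.k // the fix: capture the kind before SetBytes overwrites it
	if len(d.b) > length {
		d.SetBytes(d.b[:length])
		if originalKind == kindString { // testing d.k here instead would never fire
			d.SetString(string(d.b), collate)
		}
	}
}

func main() {
	d := &datum{}
	d.SetString("abcdef", "utf8mb4_bin")
	truncate(d, 3, "utf8mb4_bin")
	fmt.Println(d.k == kindString, string(d.b), d.collate) // true abc utf8mb4_bin
}
```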
diff --git a/util/mock/store.go b/util/mock/store.go
index 804f3d6a3f2d3..7c86de4b3cb72 100644
--- a/util/mock/store.go
+++ b/util/mock/store.go
@@ -72,3 +72,8 @@ func (s *Store) GetMemCache() kv.MemManager {
 // ShowStatus implements kv.Storage interface.
 func (s *Store) ShowStatus(ctx context.Context, key string) (interface{}, error) {
 	return nil, nil
 }
+
+// GetMinSafeTS implements kv.Storage interface.
+func (s *Store) GetMinSafeTS(txnScope string) uint64 {
+	return 0
+}
diff --git a/util/sem/sem.go b/util/sem/sem.go
index d29d29b601559..1aac6d0a9a999 100644
--- a/util/sem/sem.go
+++ b/util/sem/sem.go
@@ -138,6 +138,7 @@ func IsInvisibleSysVar(varNameInLower string) bool {
 		variable.TiDBCheckMb4ValueInUTF8,
 		variable.TiDBConfig,
 		variable.TiDBEnableSlowLog,
+		variable.TiDBEnableTelemetry,
 		variable.TiDBExpensiveQueryTimeThreshold,
 		variable.TiDBForcePriority,
 		variable.TiDBGeneralLog,
@@ -146,12 +147,13 @@ func IsInvisibleSysVar(varNameInLower string) bool {
 		variable.TiDBOptWriteRowID,
 		variable.TiDBPProfSQLCPU,
 		variable.TiDBRecordPlanInSlowLog,
+		variable.TiDBRowFormatVersion,
 		variable.TiDBSlowQueryFile,
 		variable.TiDBSlowLogThreshold,
 		variable.TiDBEnableCollectExecutionInfo,
 		variable.TiDBMemoryUsageAlarmRatio,
-		variable.TiDBEnableTelemetry,
-		variable.TiDBRowFormatVersion:
+		variable.TiDBRedactLog,
+		variable.TiDBSlowLogMasking:
 		return true
 	}
 	return false
diff --git a/util/sem/sem_test.go b/util/sem/sem_test.go
index 073a195139c37..c2a54170dcf99 100644
--- a/util/sem/sem_test.go
+++ b/util/sem/sem_test.go
@@ -98,4 +98,6 @@ func (s *testSecurity) TestIsInvisibleSysVar(c *C) {
 	c.Assert(IsInvisibleSysVar(variable.TiDBMemoryUsageAlarmRatio), IsTrue)
 	c.Assert(IsInvisibleSysVar(variable.TiDBEnableTelemetry), IsTrue)
 	c.Assert(IsInvisibleSysVar(variable.TiDBRowFormatVersion), IsTrue)
+	c.Assert(IsInvisibleSysVar(variable.TiDBRedactLog), IsTrue)
+	c.Assert(IsInvisibleSysVar(variable.TiDBSlowLogMasking), IsTrue)
 }
diff --git a/util/testkit/testkit.go b/util/testkit/testkit.go
index 7cacaf211375e..ef0a0858f76cf 100644
--- a/util/testkit/testkit.go
+++ b/util/testkit/testkit.go
@@ -271,6 +271,21 @@ func (tk *TestKit) MustPartition(sql string, partitions string, args ...interfac
 	return tk.MustQuery(sql, args...)
 }
 
+// UsedPartitions returns the partition names that will be used, or all/dual.
+func (tk *TestKit) UsedPartitions(sql string, args ...interface{}) *Result {
+	rs := tk.MustQuery("explain "+sql, args...)
+	var usedPartitions [][]string
+	for i := range rs.rows {
+		index := strings.Index(rs.rows[i][3], "partition:")
+		if index != -1 {
+			p := rs.rows[i][3][index+len("partition:"):]
+			partitions := strings.Split(strings.SplitN(p, " ", 2)[0], ",")
+			usedPartitions = append(usedPartitions, partitions)
+		}
+	}
+	return &Result{rows: usedPartitions, c: tk.c, comment: check.Commentf("sql:%s, args:%v", sql, args)}
+}
+
 // MustUseIndex checks if the result execution plan contains specific index(es).
 func (tk *TestKit) MustUseIndex(sql string, index string, args ...interface{}) bool {
 	rs := tk.MustQuery("explain "+sql, args...)
@@ -312,6 +327,19 @@ func (tk *TestKit) MustQuery(sql string, args ...interface{}) *Result {
 	return tk.ResultSetToResult(rs, comment)
 }
 
+// MayQuery executes the statement and returns result rows if a result set is returned.
+// If an expected result is set, it asserts that the query result equals the expected result.
+func (tk *TestKit) MayQuery(sql string, args ...interface{}) *Result {
+	comment := check.Commentf("sql:%s, args:%v", sql, args)
+	rs, err := tk.Exec(sql, args...)
+	tk.c.Assert(errors.ErrorStack(err), check.Equals, "", comment)
+	if rs == nil {
+		var emptyStringAoA [][]string
+		return &Result{rows: emptyStringAoA, c: tk.c, comment: comment}
+	}
+	return tk.ResultSetToResult(rs, comment)
+}
+
 // QueryToErr executes a sql statement and discard results.
 func (tk *TestKit) QueryToErr(sql string, args ...interface{}) error {
 	comment := check.Commentf("sql:%s, args:%v", sql, args)
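Editorial note: a hypothetical gocheck-style test showing how the two new testkit helpers read. The suite wiring, the table, and the expected partition are illustrative only (for four hash partitions, a = 1 lands in p1), and only functions defined in this patch or already in testkit are used.

```go
func (s *testSuite) TestPartitionHelpers(c *C) {
	tk := testkit.NewTestKit(c, s.store)
	tk.MustExec("use test")
	tk.MustExec("create table t (a int) partition by hash(a) partitions 4")

	// UsedPartitions scrapes the "partition:" annotation out of EXPLAIN output,
	// so pruning behaviour can be asserted directly.
	tk.UsedPartitions("select * from t where a = 1").Check(testkit.Rows("p1"))

	// MayQuery tolerates statements that may or may not return a result set,
	// so the same helper works for queries and statement-style commands alike.
	tk.MayQuery("admin check table t")
}
```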