diff --git a/sessionctx/variable/sysvar.go b/sessionctx/variable/sysvar.go index 0f62114d1d984..f5845afffbcc4 100644 --- a/sessionctx/variable/sysvar.go +++ b/sessionctx/variable/sysvar.go @@ -2115,6 +2115,45 @@ var defaultSysVars = []*SysVar{ s.EnableReuseCheck = TiDBOptOn(val) return nil }}, + {Scope: ScopeGlobal, Name: TiDBTTLJobEnable, Value: BoolToOnOff(DefTiDBTTLJobEnable), Type: TypeBool, SetGlobal: func(ctx context.Context, vars *SessionVars, s string) error { + EnableTTLJob.Store(TiDBOptOn(s)) + return nil + }, GetGlobal: func(ctx context.Context, vars *SessionVars) (string, error) { + return BoolToOnOff(EnableTTLJob.Load()), nil + }}, + {Scope: ScopeGlobal, Name: TiDBTTLScanBatchSize, Value: strconv.Itoa(DefTiDBTTLScanBatchSize), Type: TypeInt, MinValue: DefTiDBTTLScanBatchMinSize, MaxValue: DefTiDBTTLScanBatchMaxSize, SetGlobal: func(ctx context.Context, vars *SessionVars, s string) error { + val, err := strconv.ParseInt(s, 10, 64) + if err != nil { + return err + } + TTLScanBatchSize.Store(val) + return nil + }, GetGlobal: func(ctx context.Context, vars *SessionVars) (string, error) { + val := TTLScanBatchSize.Load() + return strconv.FormatInt(val, 10), nil + }}, + {Scope: ScopeGlobal, Name: TiDBTTLDeleteBatchSize, Value: strconv.Itoa(DefTiDBTTLDeleteBatchSize), Type: TypeInt, MinValue: DefTiDBTTLDeleteBatchMinSize, MaxValue: DefTiDBTTLDeleteBatchMaxSize, SetGlobal: func(ctx context.Context, vars *SessionVars, s string) error { + val, err := strconv.ParseInt(s, 10, 64) + if err != nil { + return err + } + TTLDeleteBatchSize.Store(val) + return nil + }, GetGlobal: func(ctx context.Context, vars *SessionVars) (string, error) { + val := TTLDeleteBatchSize.Load() + return strconv.FormatInt(val, 10), nil + }}, + {Scope: ScopeGlobal, Name: TiDBTTLDeleteRateLimit, Value: strconv.Itoa(DefTiDBTTLDeleteRateLimit), Type: TypeInt, MinValue: 0, MaxValue: math.MaxInt64, SetGlobal: func(ctx context.Context, vars *SessionVars, s string) error { + val, err := strconv.ParseInt(s, 10, 64) + if err != nil { + return err + } + TTLDeleteRateLimit.Store(val) + return nil + }, GetGlobal: func(ctx context.Context, vars *SessionVars) (string, error) { + val := TTLDeleteRateLimit.Load() + return strconv.FormatInt(val, 10), nil + }}, { Scope: ScopeGlobal | ScopeSession, Name: TiDBStoreBatchSize, Value: strconv.FormatInt(DefTiDBStoreBatchSize, 10), Type: TypeInt, MinValue: 0, MaxValue: 25000, SetSession: func(s *SessionVars, val string) error { diff --git a/sessionctx/variable/tidb_vars.go b/sessionctx/variable/tidb_vars.go index 5ca85c251aaf3..0b8f7a12e3d8c 100644 --- a/sessionctx/variable/tidb_vars.go +++ b/sessionctx/variable/tidb_vars.go @@ -868,6 +868,14 @@ const ( TiDBGOGCTunerThreshold = "tidb_gogc_tuner_threshold" // TiDBExternalTS is the ts to read through when the `TiDBEnableExternalTsRead` is on TiDBExternalTS = "tidb_external_ts" + // TiDBTTLJobEnable is used to enable/disable scheduling ttl job + TiDBTTLJobEnable = "tidb_ttl_job_enable" + // TiDBTTLScanBatchSize is used to control the batch size in the SELECT statement for TTL jobs + TiDBTTLScanBatchSize = "tidb_ttl_scan_batch_size" + // TiDBTTLDeleteBatchSize is used to control the batch size in the DELETE statement for TTL jobs + TiDBTTLDeleteBatchSize = "tidb_ttl_delete_batch_size" + // TiDBTTLDeleteRateLimit is used to control the delete rate limit for TTL jobs in each node + TiDBTTLDeleteRateLimit = "tidb_ttl_delete_rate_limit" // PasswordReuseHistory limit a few passwords to reuse. 
PasswordReuseHistory = "password_history" // PasswordReuseTime limit how long passwords can be reused. @@ -1115,6 +1123,14 @@ const ( DefTiDBUseAlloc = false DefTiDBEnablePlanReplayerCapture = false DefTiDBIndexMergeIntersectionConcurrency = ConcurrencyUnset + DefTiDBTTLJobEnable = true + DefTiDBTTLScanBatchSize = 500 + DefTiDBTTLScanBatchMaxSize = 10240 + DefTiDBTTLScanBatchMinSize = 1 + DefTiDBTTLDeleteBatchSize = 500 + DefTiDBTTLDeleteBatchMaxSize = 10240 + DefTiDBTTLDeleteBatchMinSize = 1 + DefTiDBTTLDeleteRateLimit = 0 DefPasswordReuseHistory = 0 DefPasswordReuseTime = 0 DefTiDBStoreBatchSize = 0 @@ -1180,6 +1196,10 @@ var ( PasswordValidationMixedCaseCount = atomic.NewInt32(1) PasswordValidtaionNumberCount = atomic.NewInt32(1) PasswordValidationSpecialCharCount = atomic.NewInt32(1) + EnableTTLJob = atomic.NewBool(DefTiDBTTLJobEnable) + TTLScanBatchSize = atomic.NewInt64(DefTiDBTTLScanBatchSize) + TTLDeleteBatchSize = atomic.NewInt64(DefTiDBTTLDeleteBatchSize) + TTLDeleteRateLimit = atomic.NewInt64(DefTiDBTTLDeleteRateLimit) PasswordHistory = atomic.NewInt64(DefPasswordReuseHistory) PasswordReuseInterval = atomic.NewInt64(DefPasswordReuseTime) ) diff --git a/ttl/cache/table.go b/ttl/cache/table.go index 9b49c5186a2b6..ab8c6c4b0fbd5 100644 --- a/ttl/cache/table.go +++ b/ttl/cache/table.go @@ -56,6 +56,8 @@ func getTableKeyColumns(tbl *model.TableInfo) ([]*model.ColumnInfo, []*types.Fie // PhysicalTable is used to provide some information for a physical table in TTL job type PhysicalTable struct { + // ID is the physical ID of the table + ID int64 // Schema is the database name of the table Schema model.CIStr *model.TableInfo @@ -92,11 +94,13 @@ func NewPhysicalTable(schema model.CIStr, tbl *model.TableInfo, partition model. return nil, err } + var physicalID int64 var partitionDef *model.PartitionDefinition if tbl.Partition == nil { if partition.L != "" { return nil, errors.Errorf("table '%s.%s' is not a partitioned table", schema, tbl.Name) } + physicalID = tbl.ID } else { if partition.L == "" { return nil, errors.Errorf("partition name is required, table '%s.%s' is a partitioned table", schema, tbl.Name) @@ -112,9 +116,12 @@ func NewPhysicalTable(schema model.CIStr, tbl *model.TableInfo, partition model. 
if partitionDef == nil { return nil, errors.Errorf("partition '%s' is not found in ttl table '%s.%s'", partition.O, schema, tbl.Name) } + + physicalID = partitionDef.ID } return &PhysicalTable{ + ID: physicalID, Schema: schema, TableInfo: tbl, Partition: partition, diff --git a/ttl/cache/table_test.go b/ttl/cache/table_test.go index 5c2ecdef89d84..f79d4c3bf4256 100644 --- a/ttl/cache/table_test.go +++ b/ttl/cache/table_test.go @@ -110,7 +110,7 @@ func TestNewTTLTable(t *testing.T) { physicalTbls = append(physicalTbls, ttlTbl) } else { for _, partition := range tblInfo.Partition.Definitions { - ttlTbl, err := cache.NewPhysicalTable(model.NewCIStr(c.db), tblInfo, model.NewCIStr(partition.Name.O)) + ttlTbl, err := cache.NewPhysicalTable(model.NewCIStr(c.db), tblInfo, partition.Name) if c.timeCol == "" { require.Error(t, err) continue @@ -131,10 +131,12 @@ func TestNewTTLTable(t *testing.T) { require.Same(t, timeColumn, ttlTbl.TimeColumn) if tblInfo.Partition == nil { + require.Equal(t, ttlTbl.TableInfo.ID, ttlTbl.ID) require.Equal(t, "", ttlTbl.Partition.L) require.Nil(t, ttlTbl.PartitionDef) } else { def := tblInfo.Partition.Definitions[i] + require.Equal(t, def.ID, ttlTbl.ID) require.Equal(t, def.Name.L, ttlTbl.Partition.L) require.Equal(t, def, *(ttlTbl.PartitionDef)) } diff --git a/ttl/session/BUILD.bazel b/ttl/session/BUILD.bazel index 36bdb259e6107..6d28bff0730dc 100644 --- a/ttl/session/BUILD.bazel +++ b/ttl/session/BUILD.bazel @@ -10,6 +10,7 @@ go_library( "//kv", "//parser/terror", "//sessionctx", + "//sessionctx/variable", "//sessiontxn", "//util/chunk", "//util/sqlexec", @@ -22,10 +23,12 @@ go_test( srcs = [ "main_test.go", "session_test.go", + "sysvar_test.go", ], - embed = [":session"], flaky = True, deps = [ + ":session", + "//sessionctx/variable", "//testkit", "//testkit/testsetup", "@com_github_pingcap_errors//:errors", diff --git a/ttl/session/session.go b/ttl/session/session.go index de03388e4cc2f..ea48bcd1f34dd 100644 --- a/ttl/session/session.go +++ b/ttl/session/session.go @@ -22,6 +22,7 @@ import ( "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/parser/terror" "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/sessiontxn" "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/sqlexec" @@ -36,6 +37,8 @@ type Session interface { ExecuteSQL(ctx context.Context, sql string, args ...interface{}) ([]chunk.Row, error) // RunInTxn executes the specified function in a txn RunInTxn(ctx context.Context, fn func() error) (err error) + // ResetWithGlobalTimeZone resets the session time zone to global time zone + ResetWithGlobalTimeZone(ctx context.Context) error // Close closes the session Close() } @@ -112,6 +115,27 @@ func (s *session) RunInTxn(ctx context.Context, fn func() error) (err error) { return err } +// ResetWithGlobalTimeZone resets the session time zone to global time zone +func (s *session) ResetWithGlobalTimeZone(ctx context.Context) error { + sessVar := s.GetSessionVars() + globalTZ, err := sessVar.GetGlobalSystemVar(ctx, variable.TimeZone) + if err != nil { + return err + } + + tz, err := sessVar.GetSessionOrGlobalSystemVar(ctx, variable.TimeZone) + if err != nil { + return err + } + + if globalTZ == tz { + return nil + } + + _, err = s.ExecuteSQL(ctx, "SET @@time_zone=@@global.time_zone") + return err +} + // Close closes the session func (s *session) Close() { if s.closeFn != nil { diff --git a/ttl/session/session_test.go b/ttl/session/session_test.go index f23a717776b10..a30949206223a 
100644 --- a/ttl/session/session_test.go +++ b/ttl/session/session_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package session +package session_test import ( "context" @@ -20,6 +20,7 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/tidb/testkit" + "github.com/pingcap/tidb/ttl/session" "github.com/stretchr/testify/require" ) @@ -28,7 +29,7 @@ func TestSessionRunInTxn(t *testing.T) { tk := testkit.NewTestKit(t, store) tk.MustExec("use test") tk.MustExec("create table t(id int primary key, v int)") - se := NewSession(tk.Session(), tk.Session(), nil) + se := session.NewSession(tk.Session(), tk.Session(), nil) tk2 := testkit.NewTestKit(t, store) tk2.MustExec("use test") @@ -50,3 +51,15 @@ func TestSessionRunInTxn(t *testing.T) { })) tk2.MustQuery("select * from t order by id asc").Check(testkit.Rows("1 10", "3 30")) } + +func TestSessionResetTimeZone(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("set @@global.time_zone='UTC'") + tk.MustExec("set @@time_zone='Asia/Shanghai'") + + se := session.NewSession(tk.Session(), tk.Session(), nil) + tk.MustQuery("select @@time_zone").Check(testkit.Rows("Asia/Shanghai")) + require.NoError(t, se.ResetWithGlobalTimeZone(context.TODO())) + tk.MustQuery("select @@time_zone").Check(testkit.Rows("UTC")) +} diff --git a/ttl/session/sysvar_test.go b/ttl/session/sysvar_test.go new file mode 100644 index 0000000000000..58f61c3cc88bb --- /dev/null +++ b/ttl/session/sysvar_test.go @@ -0,0 +1,125 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
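sysvar_test.go below exercises the four new variables end to end. Note that all four are deliberately backed by package-level atomics rather than per-session state, so TTL workers can observe a SET GLOBAL without holding a session. A minimal sketch of the intended read path (illustrative only; scanLimit is a hypothetical helper, not part of this change):

package main

import (
	"fmt"

	"github.com/pingcap/tidb/sessionctx/variable"
)

// scanLimit shows the lock-free read path: the sysvar framework already
// clamps stored values into [DefTiDBTTLScanBatchMinSize, DefTiDBTTLScanBatchMaxSize]
// (as the tests below verify), so the result can be used directly as a LIMIT.
func scanLimit() int {
	return int(variable.TTLScanBatchSize.Load())
}

func main() {
	fmt.Println(scanLimit()) // 500 by default (DefTiDBTTLScanBatchSize)
}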
+ +package session_test + +import ( + "fmt" + "strconv" + "testing" + + "github.com/pingcap/tidb/sessionctx/variable" + "github.com/pingcap/tidb/testkit" + "github.com/stretchr/testify/require" +) + +func TestSysVarTTLJobEnable(t *testing.T) { + origEnableTTLJob := variable.EnableTTLJob.Load() + defer func() { + variable.EnableTTLJob.Store(origEnableTTLJob) + }() + + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("set @@global.tidb_ttl_job_enable=0") + require.False(t, variable.EnableTTLJob.Load()) + tk.MustQuery("select @@global.tidb_ttl_job_enable").Check(testkit.Rows("0")) + tk.MustQuery("select @@tidb_ttl_job_enable").Check(testkit.Rows("0")) + + tk.MustExec("set @@global.tidb_ttl_job_enable=1") + require.True(t, variable.EnableTTLJob.Load()) + tk.MustQuery("select @@global.tidb_ttl_job_enable").Check(testkit.Rows("1")) + tk.MustQuery("select @@tidb_ttl_job_enable").Check(testkit.Rows("1")) + + tk.MustExec("set @@global.tidb_ttl_job_enable=0") + require.False(t, variable.EnableTTLJob.Load()) + tk.MustQuery("select @@global.tidb_ttl_job_enable").Check(testkit.Rows("0")) + tk.MustQuery("select @@tidb_ttl_job_enable").Check(testkit.Rows("0")) +} + +func TestSysVarTTLScanBatchSize(t *testing.T) { + origScanBatchSize := variable.TTLScanBatchSize.Load() + defer func() { + variable.TTLScanBatchSize.Store(origScanBatchSize) + }() + + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("set @@global.tidb_ttl_scan_batch_size=789") + require.Equal(t, int64(789), variable.TTLScanBatchSize.Load()) + tk.MustQuery("select @@global.tidb_ttl_scan_batch_size").Check(testkit.Rows("789")) + tk.MustQuery("select @@tidb_ttl_scan_batch_size").Check(testkit.Rows("789")) + + tk.MustExec("set @@global.tidb_ttl_scan_batch_size=0") + require.Equal(t, int64(1), variable.TTLScanBatchSize.Load()) + tk.MustQuery("select @@global.tidb_ttl_scan_batch_size").Check(testkit.Rows("1")) + tk.MustQuery("select @@tidb_ttl_scan_batch_size").Check(testkit.Rows("1")) + + maxVal := int64(variable.DefTiDBTTLScanBatchMaxSize) + tk.MustExec(fmt.Sprintf("set @@global.tidb_ttl_scan_batch_size=%d", maxVal+1)) + require.Equal(t, maxVal, variable.TTLScanBatchSize.Load()) + tk.MustQuery("select @@global.tidb_ttl_scan_batch_size").Check(testkit.Rows(strconv.FormatInt(maxVal, 10))) + tk.MustQuery("select @@tidb_ttl_scan_batch_size").Check(testkit.Rows(strconv.FormatInt(maxVal, 10))) +} + +func TestSysVarTTLScanDeleteBatchSize(t *testing.T) { + origDeleteBatchSize := variable.TTLDeleteBatchSize.Load() + defer func() { + variable.TTLDeleteBatchSize.Store(origDeleteBatchSize) + }() + + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("set @@global.tidb_ttl_delete_batch_size=789") + require.Equal(t, int64(789), variable.TTLDeleteBatchSize.Load()) + tk.MustQuery("select @@global.tidb_ttl_delete_batch_size").Check(testkit.Rows("789")) + tk.MustQuery("select @@tidb_ttl_delete_batch_size").Check(testkit.Rows("789")) + + tk.MustExec("set @@global.tidb_ttl_delete_batch_size=0") + require.Equal(t, int64(1), variable.TTLDeleteBatchSize.Load()) + tk.MustQuery("select @@global.tidb_ttl_delete_batch_size").Check(testkit.Rows("1")) + tk.MustQuery("select @@tidb_ttl_delete_batch_size").Check(testkit.Rows("1")) + + maxVal := int64(variable.DefTiDBTTLDeleteBatchMaxSize) + tk.MustExec(fmt.Sprintf("set @@global.tidb_ttl_delete_batch_size=%d", maxVal+1)) + require.Equal(t, maxVal, variable.TTLDeleteBatchSize.Load()) + tk.MustQuery("select 
@@global.tidb_ttl_delete_batch_size").Check(testkit.Rows(strconv.FormatInt(maxVal, 10))) + tk.MustQuery("select @@tidb_ttl_delete_batch_size").Check(testkit.Rows(strconv.FormatInt(maxVal, 10))) +} + +func TestSysVarTTLScanDeleteLimit(t *testing.T) { + origDeleteLimit := variable.TTLDeleteRateLimit.Load() + defer func() { + variable.TTLDeleteRateLimit.Store(origDeleteLimit) + }() + + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustQuery("select @@global.tidb_ttl_delete_rate_limit").Check(testkit.Rows("0")) + + tk.MustExec("set @@global.tidb_ttl_delete_rate_limit=100000") + require.Equal(t, int64(100000), variable.TTLDeleteRateLimit.Load()) + tk.MustQuery("select @@global.tidb_ttl_delete_rate_limit").Check(testkit.Rows("100000")) + tk.MustQuery("select @@tidb_ttl_delete_rate_limit").Check(testkit.Rows("100000")) + + tk.MustExec("set @@global.tidb_ttl_delete_rate_limit=0") + require.Equal(t, int64(0), variable.TTLDeleteRateLimit.Load()) + tk.MustQuery("select @@global.tidb_ttl_delete_rate_limit").Check(testkit.Rows("0")) + tk.MustQuery("select @@tidb_ttl_delete_rate_limit").Check(testkit.Rows("0")) + + tk.MustExec("set @@global.tidb_ttl_delete_rate_limit=-1") + require.Equal(t, int64(0), variable.TTLDeleteRateLimit.Load()) + tk.MustQuery("select @@global.tidb_ttl_delete_rate_limit").Check(testkit.Rows("0")) + tk.MustQuery("select @@tidb_ttl_delete_rate_limit").Check(testkit.Rows("0")) +} diff --git a/ttl/ttlworker/BUILD.bazel b/ttl/ttlworker/BUILD.bazel new file mode 100644 index 0000000000000..a278b9cac754f --- /dev/null +++ b/ttl/ttlworker/BUILD.bazel @@ -0,0 +1,55 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "ttlworker", + srcs = [ + "del.go", + "scan.go", + "session.go", + "worker.go", + ], + importpath = "github.com/pingcap/tidb/ttl/ttlworker", + visibility = ["//visibility:public"], + deps = [ + "//parser/terror", + "//sessionctx", + "//sessionctx/variable", + "//ttl/cache", + "//ttl/session", + "//ttl/sqlbuilder", + "//types", + "//util", + "//util/chunk", + "//util/logutil", + "//util/sqlexec", + "@com_github_ngaut_pools//:pools", + "@com_github_pingcap_errors//:errors", + "@org_golang_x_time//rate", + "@org_uber_go_zap//:zap", + ], +) + +go_test( + name = "ttlworker_test", + srcs = [ + "del_test.go", + "scan_test.go", + "session_test.go", + ], + embed = [":ttlworker"], + deps = [ + "//infoschema", + "//parser/ast", + "//parser/model", + "//parser/mysql", + "//sessionctx", + "//sessionctx/variable", + "//ttl/cache", + "//types", + "//util/chunk", + "@com_github_ngaut_pools//:pools", + "@com_github_pingcap_errors//:errors", + "@com_github_stretchr_testify//require", + "@org_golang_x_time//rate", + ], +) diff --git a/ttl/ttlworker/del.go b/ttl/ttlworker/del.go new file mode 100644 index 0000000000000..d0bb651da8f41 --- /dev/null +++ b/ttl/ttlworker/del.go @@ -0,0 +1,276 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package ttlworker + +import ( + "container/list" + "context" + "fmt" + "sync" + "sync/atomic" + "time" + + "github.com/pingcap/tidb/sessionctx/variable" + "github.com/pingcap/tidb/ttl/cache" + "github.com/pingcap/tidb/ttl/session" + "github.com/pingcap/tidb/ttl/sqlbuilder" + "github.com/pingcap/tidb/types" + "github.com/pingcap/tidb/util/logutil" + "go.uber.org/zap" + "golang.org/x/time/rate" +) + +const ( + delMaxRetry = 3 + delRetryBufferSize = 128 + delRetryInterval = time.Second * 5 +) + +var globalDelRateLimiter = newDelRateLimiter() + +type delRateLimiter struct { + sync.Mutex + limiter *rate.Limiter + limit atomic.Int64 +} + +func newDelRateLimiter() *delRateLimiter { + limiter := &delRateLimiter{} + limiter.limiter = rate.NewLimiter(0, 1) + limiter.limit.Store(0) + return limiter +} + +func (l *delRateLimiter) Wait(ctx context.Context) error { + limit := l.limit.Load() + if variable.TTLDeleteRateLimit.Load() != limit { + limit = l.reset() + } + + if limit == 0 { + return ctx.Err() + } + + return l.limiter.Wait(ctx) +} + +func (l *delRateLimiter) reset() (newLimit int64) { + l.Lock() + defer l.Unlock() + newLimit = variable.TTLDeleteRateLimit.Load() + if newLimit != l.limit.Load() { + l.limit.Store(newLimit) + l.limiter.SetLimit(rate.Limit(newLimit)) + } + return +} + +type ttlDeleteTask struct { + tbl *cache.PhysicalTable + expire time.Time + rows [][]types.Datum + statistics *ttlStatistics +} + +func (t *ttlDeleteTask) doDelete(ctx context.Context, rawSe session.Session) (retryRows [][]types.Datum) { + leftRows := t.rows + se := newTableSession(rawSe, t.tbl, t.expire) + for len(leftRows) > 0 { + maxBatch := variable.TTLDeleteBatchSize.Load() + var delBatch [][]types.Datum + if int64(len(leftRows)) < maxBatch { + delBatch = leftRows + leftRows = nil + } else { + delBatch = leftRows[0:maxBatch] + leftRows = leftRows[maxBatch:] + } + + sql, err := sqlbuilder.BuildDeleteSQL(t.tbl, delBatch, t.expire) + if err != nil { + t.statistics.IncErrorRows(len(delBatch)) + logutil.BgLogger().Warn( + "build delete SQL in TTL failed", + zap.Error(err), + zap.String("table", t.tbl.Schema.O+"."+t.tbl.Name.O), + ) + continue + } + + if err = globalDelRateLimiter.Wait(ctx); err != nil { + t.statistics.IncErrorRows(len(delBatch)) + return + } + + _, needRetry, err := se.ExecuteSQLWithCheck(ctx, sql) + if err != nil { + needRetry = needRetry && ctx.Err() == nil + logutil.BgLogger().Warn( + "delete SQL in TTL failed", + zap.Error(err), + zap.String("SQL", sql), + zap.Bool("needRetry", needRetry), + ) + + if needRetry { + if retryRows == nil { + retryRows = make([][]types.Datum, 0, len(leftRows)+len(delBatch)) + } + retryRows = append(retryRows, delBatch...) 
+ } else { + t.statistics.IncErrorRows(len(delBatch)) + } + continue + } + + t.statistics.IncSuccessRows(len(delBatch)) + } + return retryRows +} + +type ttlDelRetryItem struct { + task *ttlDeleteTask + retryCnt int + inTime time.Time +} + +type ttlDelRetryBuffer struct { + list *list.List + maxSize int + maxRetry int + retryInterval time.Duration + getTime func() time.Time +} + +func newTTLDelRetryBuffer() *ttlDelRetryBuffer { + return &ttlDelRetryBuffer{ + list: list.New(), + maxSize: delRetryBufferSize, + maxRetry: delMaxRetry, + retryInterval: delRetryInterval, + getTime: time.Now, + } +} + +func (b *ttlDelRetryBuffer) Len() int { + return b.list.Len() +} + +func (b *ttlDelRetryBuffer) RecordTaskResult(task *ttlDeleteTask, retryRows [][]types.Datum) { + b.recordRetryItem(task, retryRows, 0) +} + +func (b *ttlDelRetryBuffer) DoRetry(do func(*ttlDeleteTask) [][]types.Datum) time.Duration { + for b.list.Len() > 0 { + ele := b.list.Front() + item, ok := ele.Value.(*ttlDelRetryItem) + if !ok { + logutil.BgLogger().Error(fmt.Sprintf("invalid retry buffer item type: %T", ele.Value)) + b.list.Remove(ele) + continue + } + + now := b.getTime() + interval := now.Sub(item.inTime) + if interval < b.retryInterval { + return b.retryInterval - interval + } + + b.list.Remove(ele) + if retryRows := do(item.task); len(retryRows) > 0 { + b.recordRetryItem(item.task, retryRows, item.retryCnt+1) + } + } + return b.retryInterval +} + +func (b *ttlDelRetryBuffer) recordRetryItem(task *ttlDeleteTask, retryRows [][]types.Datum, retryCnt int) bool { + if len(retryRows) == 0 { + return false + } + + if retryCnt >= b.maxRetry { + task.statistics.IncErrorRows(len(retryRows)) + return false + } + + for b.list.Len() > 0 && b.list.Len() >= b.maxSize { + ele := b.list.Front() + if item, ok := ele.Value.(*ttlDelRetryItem); ok { + item.task.statistics.IncErrorRows(len(item.task.rows)) + } else { + logutil.BgLogger().Error(fmt.Sprintf("invalid retry buffer item type: %T", ele.Value)) + } + b.list.Remove(b.list.Front()) + } + + newTask := *task + newTask.rows = retryRows + b.list.PushBack(&ttlDelRetryItem{ + task: &newTask, + inTime: b.getTime(), + retryCnt: retryCnt, + }) + return true +} + +type ttlDeleteWorker struct { + baseWorker + delCh <-chan *ttlDeleteTask + sessionPool sessionPool + retryBuffer *ttlDelRetryBuffer +} + +func newDeleteWorker(delCh <-chan *ttlDeleteTask, sessPool sessionPool) *ttlDeleteWorker { + w := &ttlDeleteWorker{ + delCh: delCh, + sessionPool: sessPool, + retryBuffer: newTTLDelRetryBuffer(), + } + w.init(w.loop) + return w +} + +func (w *ttlDeleteWorker) loop() error { + se, err := getSession(w.sessionPool) + if err != nil { + return err + } + defer se.Close() + + ctx := w.baseWorker.ctx + + doRetry := func(task *ttlDeleteTask) [][]types.Datum { + return task.doDelete(ctx, se) + } + + timer := time.NewTimer(w.retryBuffer.retryInterval) + defer timer.Stop() + + for w.status == workerStatusRunning { + select { + case <-ctx.Done(): + return nil + case <-timer.C: + nextInterval := w.retryBuffer.DoRetry(doRetry) + timer.Reset(nextInterval) + case task, ok := <-w.delCh: + if !ok { + return nil + } + retryRows := task.doDelete(ctx, se) + w.retryBuffer.RecordTaskResult(task, retryRows) + } + } + return nil +} diff --git a/ttl/ttlworker/del_test.go b/ttl/ttlworker/del_test.go new file mode 100644 index 0000000000000..62e05dbeb4c14 --- /dev/null +++ b/ttl/ttlworker/del_test.go @@ -0,0 +1,389 @@ +// Copyright 2022 PingCAP, Inc. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ttlworker + +import ( + "context" + "errors" + "strings" + "testing" + "time" + + "github.com/pingcap/tidb/sessionctx/variable" + "github.com/pingcap/tidb/ttl/cache" + "github.com/pingcap/tidb/types" + "github.com/pingcap/tidb/util/chunk" + "github.com/stretchr/testify/require" + "golang.org/x/time/rate" +) + +func TestTTLDelRetryBuffer(t *testing.T) { + createTask := func(name string) (*ttlDeleteTask, [][]types.Datum, *ttlStatistics) { + rows := make([][]types.Datum, 10) + statistics := &ttlStatistics{} + statistics.IncTotalRows(10) + task := &ttlDeleteTask{ + tbl: newMockTTLTbl(t, name), + expire: time.UnixMilli(0), + rows: rows, + statistics: statistics, + } + return task, rows, statistics + } + + shouldNotDoRetry := func(task *ttlDeleteTask) [][]types.Datum { + require.FailNow(t, "should not do retry") + return nil + } + + start := time.UnixMilli(0) + tm := start + buffer := newTTLDelRetryBuffer() + require.Equal(t, delRetryBufferSize, buffer.maxSize) + require.Equal(t, delMaxRetry, buffer.maxRetry) + require.Equal(t, delRetryInterval, buffer.retryInterval) + + buffer.maxSize = 3 + buffer.maxRetry = 2 + buffer.retryInterval = 10 * time.Second + buffer.getTime = func() time.Time { + return tm + } + + // add success task + task1, rows1, statics1 := createTask("t1") + buffer.RecordTaskResult(task1, nil) + require.Equal(t, 0, buffer.Len()) + buffer.DoRetry(shouldNotDoRetry) + require.Equal(t, uint64(0), statics1.ErrorRows.Load()) + + // add a task with 1 failed row + buffer.RecordTaskResult(task1, rows1[:1]) + require.Equal(t, 1, buffer.Len()) + buffer.DoRetry(shouldNotDoRetry) + require.Equal(t, uint64(0), statics1.ErrorRows.Load()) + + // add another task with 2 failed rows + tm = tm.Add(time.Second) + task2, rows2, statics2 := createTask("t2") + buffer.RecordTaskResult(task2, rows2[:2]) + require.Equal(t, 2, buffer.Len()) + buffer.DoRetry(shouldNotDoRetry) + require.Equal(t, uint64(0), statics2.ErrorRows.Load()) + + // add another task with 3 failed rows + tm = tm.Add(time.Second) + task3, rows3, statics3 := createTask("t3") + buffer.RecordTaskResult(task3, rows3[:3]) + require.Equal(t, 3, buffer.Len()) + buffer.DoRetry(shouldNotDoRetry) + require.Equal(t, uint64(0), statics3.ErrorRows.Load()) + + // adding a new task evicts the oldest task + tm = tm.Add(time.Second) + task4, rows4, statics4 := createTask("t4") + buffer.RecordTaskResult(task4, rows4[:4]) + require.Equal(t, 3, buffer.Len()) + buffer.DoRetry(shouldNotDoRetry) + require.Equal(t, uint64(0), statics4.ErrorRows.Load()) + require.Equal(t, uint64(1), statics1.ErrorRows.Load()) + + // poll tasks that are due for retry + tm = tm.Add(10*time.Second - time.Millisecond) + tasks := make([]*ttlDeleteTask, 0) + doRetrySuccess := func(task *ttlDeleteTask) [][]types.Datum { + task.statistics.IncSuccessRows(len(task.rows)) + tasks = append(tasks, task) + return nil + } + nextInterval := buffer.DoRetry(doRetrySuccess) + require.Equal(t, time.Millisecond, nextInterval) + require.Equal(t, 
2, len(tasks)) + require.Equal(t, "t2", tasks[0].tbl.Name.L) + require.Equal(t, time.UnixMilli(0), tasks[0].expire) + require.Equal(t, 2, len(tasks[0].rows)) + require.Equal(t, uint64(2), statics2.SuccessRows.Load()) + require.Equal(t, uint64(0), statics2.ErrorRows.Load()) + require.Equal(t, "t3", tasks[1].tbl.Name.L) + require.Equal(t, time.UnixMilli(0), tasks[1].expire) + require.Equal(t, 3, len(tasks[1].rows)) + require.Equal(t, 1, buffer.Len()) + require.Equal(t, uint64(3), statics3.SuccessRows.Load()) + require.Equal(t, uint64(0), statics3.ErrorRows.Load()) + require.Equal(t, uint64(0), statics4.SuccessRows.Load()) + require.Equal(t, uint64(0), statics4.ErrorRows.Load()) + + // poll next + tm = tm.Add(time.Millisecond) + tasks = make([]*ttlDeleteTask, 0) + nextInterval = buffer.DoRetry(doRetrySuccess) + require.Equal(t, 10*time.Second, nextInterval) + require.Equal(t, 1, len(tasks)) + require.Equal(t, "t4", tasks[0].tbl.Name.L) + require.Equal(t, time.UnixMilli(0), tasks[0].expire) + require.Equal(t, 4, len(tasks[0].rows)) + require.Equal(t, 0, buffer.Len()) + require.Equal(t, uint64(4), statics4.SuccessRows.Load()) + require.Equal(t, uint64(0), statics4.ErrorRows.Load()) + + // test retry max count + retryCnt := 0 + doRetryFail := func(task *ttlDeleteTask) [][]types.Datum { + retryCnt++ + task.statistics.SuccessRows.Add(1) + return task.rows[1:] + } + task5, rows5, statics5 := createTask("t5") + buffer.RecordTaskResult(task5, rows5[:5]) + require.Equal(t, 1, buffer.Len()) + tm = tm.Add(10 * time.Second) + nextInterval = buffer.DoRetry(doRetryFail) + require.Equal(t, 10*time.Second, nextInterval) + require.Equal(t, uint64(1), statics5.SuccessRows.Load()) + require.Equal(t, uint64(0), statics5.ErrorRows.Load()) + require.Equal(t, 1, retryCnt) + tm = tm.Add(10 * time.Second) + buffer.DoRetry(doRetryFail) + require.Equal(t, uint64(2), statics5.SuccessRows.Load()) + require.Equal(t, uint64(3), statics5.ErrorRows.Load()) + require.Equal(t, 2, retryCnt) + require.Equal(t, 0, buffer.Len()) + + // test task should be immutable + require.Equal(t, 10, len(task5.rows)) +} + +func TestTTLDeleteTaskDoDelete(t *testing.T) { + origBatchSize := variable.TTLDeleteBatchSize.Load() + variable.TTLDeleteBatchSize.Store(3) + defer variable.TTLDeleteBatchSize.Store(origBatchSize) + + t1 := newMockTTLTbl(t, "t1") + t2 := newMockTTLTbl(t, "t2") + t3 := newMockTTLTbl(t, "t3") + t4 := newMockTTLTbl(t, "t4") + s := newMockSession(t) + invokes := 0 + s.executeSQL = func(ctx context.Context, sql string, args ...interface{}) ([]chunk.Row, error) { + invokes++ + s.sessionInfoSchema = newMockInfoSchema(t1.TableInfo, t2.TableInfo, t3.TableInfo, t4.TableInfo) + if strings.Contains(sql, "`t1`") { + return nil, nil + } + + if strings.Contains(sql, "`t2`") { + return nil, errors.New("mockErr") + } + + if strings.Contains(sql, "`t3`") { + s.sessionInfoSchema = newMockInfoSchema() + return nil, nil + } + + if strings.Contains(sql, "`t4`") { + switch invokes { + case 1: + return nil, nil + case 2, 4: + return nil, errors.New("mockErr") + case 3: + s.sessionInfoSchema = newMockInfoSchema() + return nil, nil + } + } + + require.FailNow(t, "") + return nil, nil + } + + nRows := func(n int) [][]types.Datum { + rows := make([][]types.Datum, n) + for i := 0; i < n; i++ { + rows[i] = []types.Datum{ + types.NewIntDatum(int64(i)), + } + } + return rows + } + + delTask := func(t *cache.PhysicalTable) *ttlDeleteTask { + task := &ttlDeleteTask{ + tbl: t, + expire: time.UnixMilli(0), + rows: nRows(10), + statistics: &ttlStatistics{}, + } 
+ task.statistics.TotalRows.Add(10) + return task + } + + cases := []struct { + task *ttlDeleteTask + retryRows []int + successRows int + errorRows int + }{ + { + task: delTask(t1), + retryRows: nil, + successRows: 10, + errorRows: 0, + }, + { + task: delTask(t2), + retryRows: []int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}, + successRows: 0, + errorRows: 0, + }, + { + task: delTask(t3), + retryRows: nil, + successRows: 0, + errorRows: 10, + }, + { + task: delTask(t4), + retryRows: []int{3, 4, 5, 9}, + successRows: 3, + errorRows: 3, + }, + } + + for _, c := range cases { + invokes = 0 + retryRows := c.task.doDelete(context.TODO(), s) + require.Equal(t, 4, invokes) + if c.retryRows == nil { + require.Nil(t, retryRows) + } + require.Equal(t, len(c.retryRows), len(retryRows)) + for i, row := range retryRows { + require.Equal(t, int64(c.retryRows[i]), row[0].GetInt64()) + } + require.Equal(t, uint64(10), c.task.statistics.TotalRows.Load()) + require.Equal(t, uint64(c.successRows), c.task.statistics.SuccessRows.Load()) + require.Equal(t, uint64(c.errorRows), c.task.statistics.ErrorRows.Load()) + } +} + +func TestTTLDeleteRateLimiter(t *testing.T) { + origDeleteLimit := variable.TTLDeleteRateLimit.Load() + defer func() { + variable.TTLDeleteRateLimit.Store(origDeleteLimit) + }() + + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer func() { + if cancel != nil { + cancel() + } + }() + + variable.TTLDeleteRateLimit.Store(100000) + require.NoError(t, globalDelRateLimiter.Wait(ctx)) + require.Equal(t, rate.Limit(100000), globalDelRateLimiter.limiter.Limit()) + require.Equal(t, int64(100000), globalDelRateLimiter.limit.Load()) + + variable.TTLDeleteRateLimit.Store(0) + require.NoError(t, globalDelRateLimiter.Wait(ctx)) + require.Equal(t, rate.Limit(0), globalDelRateLimiter.limiter.Limit()) + require.Equal(t, int64(0), globalDelRateLimiter.limit.Load()) + + // 0 stands for no limit + require.NoError(t, globalDelRateLimiter.Wait(ctx)) + // cancel ctx returns an error + cancel() + cancel = nil + require.EqualError(t, globalDelRateLimiter.Wait(ctx), "context canceled") +} + +func TestTTLDeleteTaskWorker(t *testing.T) { + origBatchSize := variable.TTLDeleteBatchSize.Load() + variable.TTLDeleteBatchSize.Store(3) + defer variable.TTLDeleteBatchSize.Store(origBatchSize) + + t1 := newMockTTLTbl(t, "t1") + t2 := newMockTTLTbl(t, "t2") + t3 := newMockTTLTbl(t, "t3") + s := newMockSession(t) + pool := newMockSessionPool(t) + pool.se = s + + sqlMap := make(map[string]struct{}) + s.executeSQL = func(ctx context.Context, sql string, args ...interface{}) ([]chunk.Row, error) { + pool.lastSession.sessionInfoSchema = newMockInfoSchema(t1.TableInfo, t2.TableInfo, t3.TableInfo) + if strings.Contains(sql, "`t1`") { + return nil, nil + } + + if strings.Contains(sql, "`t2`") { + if _, ok := sqlMap[sql]; ok { + return nil, nil + } + sqlMap[sql] = struct{}{} + return nil, errors.New("mockErr") + } + + if strings.Contains(sql, "`t3`") { + pool.lastSession.sessionInfoSchema = newMockInfoSchema() + return nil, nil + } + + require.FailNow(t, "") + return nil, nil + } + + delCh := make(chan *ttlDeleteTask) + w := newDeleteWorker(delCh, pool) + w.retryBuffer.retryInterval = time.Millisecond + require.Equal(t, workerStatusCreated, w.Status()) + w.Start() + require.Equal(t, workerStatusRunning, w.Status()) + defer func() { + w.Stop() + require.NoError(t, w.WaitStopped(context.TODO(), 10*time.Second)) + }() + + tasks := make([]*ttlDeleteTask, 0) + for _, tbl := range []*cache.PhysicalTable{t1, t2, t3} { + task := 
&ttlDeleteTask{ + tbl: tbl, + expire: time.UnixMilli(0), + rows: [][]types.Datum{ + {types.NewIntDatum(1)}, + {types.NewIntDatum(2)}, + {types.NewIntDatum(3)}, + }, + statistics: &ttlStatistics{}, + } + task.statistics.TotalRows.Add(3) + tasks = append(tasks, task) + select { + case delCh <- task: + case <-time.After(time.Second): + require.FailNow(t, "") + } + } + + time.Sleep(time.Millisecond * 100) + require.Equal(t, uint64(3), tasks[0].statistics.SuccessRows.Load()) + require.Equal(t, uint64(0), tasks[0].statistics.ErrorRows.Load()) + + require.Equal(t, uint64(3), tasks[1].statistics.SuccessRows.Load()) + require.Equal(t, uint64(0), tasks[1].statistics.ErrorRows.Load()) + + require.Equal(t, uint64(0), tasks[2].statistics.SuccessRows.Load()) + require.Equal(t, uint64(3), tasks[2].statistics.ErrorRows.Load()) +} diff --git a/ttl/ttlworker/scan.go b/ttl/ttlworker/scan.go new file mode 100644 index 0000000000000..4d9f7924cd9a6 --- /dev/null +++ b/ttl/ttlworker/scan.go @@ -0,0 +1,288 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ttlworker + +import ( + "context" + "strconv" + "sync/atomic" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/tidb/parser/terror" + "github.com/pingcap/tidb/sessionctx/variable" + "github.com/pingcap/tidb/ttl/cache" + "github.com/pingcap/tidb/ttl/sqlbuilder" + "github.com/pingcap/tidb/types" + "github.com/pingcap/tidb/util/chunk" + "github.com/pingcap/tidb/util/logutil" + "go.uber.org/zap" +) + +var ( + scanTaskExecuteSQLMaxRetry = 5 + scanTaskExecuteSQLRetryInterval = 2 * time.Second + taskStartCheckErrorRateCnt = 10000 + taskMaxErrorRate = 0.4 +) + +type ttlStatistics struct { + TotalRows atomic.Uint64 + SuccessRows atomic.Uint64 + ErrorRows atomic.Uint64 +} + +func (s *ttlStatistics) IncTotalRows(cnt int) { + s.TotalRows.Add(uint64(cnt)) +} + +func (s *ttlStatistics) IncSuccessRows(cnt int) { + s.SuccessRows.Add(uint64(cnt)) +} + +func (s *ttlStatistics) IncErrorRows(cnt int) { + s.ErrorRows.Add(uint64(cnt)) +} + +func (s *ttlStatistics) Reset() { + s.SuccessRows.Store(0) + s.ErrorRows.Store(0) + s.TotalRows.Store(0) +} + +type ttlScanTask struct { + tbl *cache.PhysicalTable + expire time.Time + rangeStart []types.Datum + rangeEnd []types.Datum + statistics *ttlStatistics +} + +type ttlScanTaskExecResult struct { + task *ttlScanTask + err error +} + +func (t *ttlScanTask) result(err error) *ttlScanTaskExecResult { + return &ttlScanTaskExecResult{task: t, err: err} +} + +func (t *ttlScanTask) getDatumRows(rows []chunk.Row) [][]types.Datum { + datums := make([][]types.Datum, len(rows)) + for i, row := range rows { + datums[i] = row.GetDatumRow(t.tbl.KeyColumnTypes) + } + return datums +} + +func (t *ttlScanTask) doScan(ctx context.Context, delCh chan<- *ttlDeleteTask, sessPool sessionPool) *ttlScanTaskExecResult { + rawSess, err := getSession(sessPool) + if err != nil { + return t.result(err) + } + defer rawSess.Close() + + origConcurrency := rawSess.GetSessionVars().DistSQLScanConcurrency() + 
if _, err = rawSess.ExecuteSQL(ctx, "set @@tidb_distsql_scan_concurrency=1"); err != nil { + return t.result(err) + } + + defer func() { + _, err = rawSess.ExecuteSQL(ctx, "set @@tidb_distsql_scan_concurrency="+strconv.Itoa(origConcurrency)) + terror.Log(err) + }() + + sess := newTableSession(rawSess, t.tbl, t.expire) + generator, err := sqlbuilder.NewScanQueryGenerator(t.tbl, t.expire, t.rangeStart, t.rangeEnd) + if err != nil { + return t.result(err) + } + + retrySQL := "" + retryTimes := 0 + var lastResult [][]types.Datum + for { + if err = ctx.Err(); err != nil { + return t.result(err) + } + + if total := t.statistics.TotalRows.Load(); total > uint64(taskStartCheckErrorRateCnt) { + if t.statistics.ErrorRows.Load() > uint64(float64(total)*taskMaxErrorRate) { + return t.result(errors.Errorf("error exceeds the limit")) + } + } + + sql := retrySQL + if sql == "" { + limit := int(variable.TTLScanBatchSize.Load()) + if sql, err = generator.NextSQL(lastResult, limit); err != nil { + return t.result(err) + } + } + + if sql == "" { + return t.result(nil) + } + + rows, retryable, sqlErr := sess.ExecuteSQLWithCheck(ctx, sql) + if sqlErr != nil { + needRetry := retryable && retryTimes < scanTaskExecuteSQLMaxRetry && ctx.Err() == nil + logutil.BgLogger().Error("execute query for ttl scan task failed", + zap.String("SQL", sql), + zap.Int("retryTimes", retryTimes), + zap.Bool("needRetry", needRetry), + zap.Error(err), + ) + + if !needRetry { + return t.result(sqlErr) + } + retrySQL = sql + retryTimes++ + select { + case <-ctx.Done(): + return t.result(ctx.Err()) + case <-time.After(scanTaskExecuteSQLRetryInterval): + } + continue + } + + retrySQL = "" + retryTimes = 0 + lastResult = t.getDatumRows(rows) + if len(rows) == 0 { + continue + } + + delTask := &ttlDeleteTask{ + tbl: t.tbl, + expire: t.expire, + rows: lastResult, + statistics: t.statistics, + } + select { + case <-ctx.Done(): + return t.result(ctx.Err()) + case delCh <- delTask: + t.statistics.IncTotalRows(len(lastResult)) + } + } +} + +type scanTaskExecEndMsg struct { + result *ttlScanTaskExecResult +} + +type ttlScanWorker struct { + baseWorker + curTask *ttlScanTask + curTaskResult *ttlScanTaskExecResult + delCh chan<- *ttlDeleteTask + notifyStateCh chan<- interface{} + sessionPool sessionPool +} + +func newScanWorker(delCh chan<- *ttlDeleteTask, notifyStateCh chan<- interface{}, sessPool sessionPool) *ttlScanWorker { + w := &ttlScanWorker{ + delCh: delCh, + notifyStateCh: notifyStateCh, + sessionPool: sessPool, + } + w.init(w.loop) + return w +} + +func (w *ttlScanWorker) Idle() bool { + w.Lock() + defer w.Unlock() + return w.status == workerStatusRunning && w.curTask == nil +} + +func (w *ttlScanWorker) Schedule(task *ttlScanTask) error { + w.Lock() + defer w.Unlock() + if w.status != workerStatusRunning { + return errors.New("worker is not running") + } + + if w.curTaskResult != nil { + return errors.New("the result of previous task has not been polled") + } + + if w.curTask != nil { + return errors.New("a task is running") + } + + w.curTask = task + w.curTaskResult = nil + w.baseWorker.ch <- task + return nil +} + +func (w *ttlScanWorker) CurrentTask() *ttlScanTask { + w.Lock() + defer w.Unlock() + return w.curTask +} + +func (w *ttlScanWorker) PollTaskResult() (*ttlScanTaskExecResult, bool) { + w.Lock() + defer w.Unlock() + if r := w.curTaskResult; r != nil { + w.curTask = nil + w.curTaskResult = nil + return r, true + } + return nil, false +} + +func (w *ttlScanWorker) loop() error { + ctx := w.baseWorker.ctx + for w.status == 
workerStatusRunning { + select { + case <-ctx.Done(): + return nil + case msg, ok := <-w.baseWorker.ch: + if !ok { + return nil + } + switch task := msg.(type) { + case *ttlScanTask: + w.handleScanTask(ctx, task) + default: + logutil.BgLogger().Warn("unrecognized message for ttlScanWorker", zap.Any("msg", msg)) + } + } + } + return nil +} + +func (w *ttlScanWorker) handleScanTask(ctx context.Context, task *ttlScanTask) { + result := task.doScan(ctx, w.delCh, w.sessionPool) + if result == nil { + result = task.result(nil) + } + + w.baseWorker.Lock() + w.curTaskResult = result + w.baseWorker.Unlock() + + if w.notifyStateCh != nil { + select { + case w.notifyStateCh <- &scanTaskExecEndMsg{result: result}: + default: + } + } +} diff --git a/ttl/ttlworker/scan_test.go b/ttl/ttlworker/scan_test.go new file mode 100644 index 0000000000000..1eeb7a306afca --- /dev/null +++ b/ttl/ttlworker/scan_test.go @@ -0,0 +1,397 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ttlworker + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/tidb/sessionctx/variable" + "github.com/pingcap/tidb/ttl/cache" + "github.com/pingcap/tidb/types" + "github.com/pingcap/tidb/util/chunk" + "github.com/stretchr/testify/require" +) + +type mockScanWorker struct { + *ttlScanWorker + t *testing.T + delCh chan *ttlDeleteTask + notifyCh chan interface{} + sessPool *mockSessionPool +} + +func newMockScanWorker(t *testing.T) *mockScanWorker { + w := &mockScanWorker{ + t: t, + delCh: make(chan *ttlDeleteTask), + notifyCh: make(chan interface{}, 10), + sessPool: newMockSessionPool(t), + } + + w.ttlScanWorker = newScanWorker(w.delCh, w.notifyCh, w.sessPool) + require.Equal(t, workerStatusCreated, w.Status()) + require.False(t, w.Idle()) + result, ok := w.PollTaskResult() + require.False(t, ok) + require.Nil(t, result) + return w +} + +func (w *mockScanWorker) checkWorkerStatus(status workerStatus, idle bool, curTask *ttlScanTask) { + require.Equal(w.t, status, w.status) + require.Equal(w.t, idle, w.Idle()) + require.Same(w.t, curTask, w.CurrentTask()) +} + +func (w *mockScanWorker) checkPollResult(exist bool, err string) { + curTask := w.CurrentTask() + r, ok := w.PollTaskResult() + require.Equal(w.t, exist, ok) + if !exist { + require.Nil(w.t, r) + } else { + require.NotNil(w.t, r) + require.NotNil(w.t, r.task) + require.Same(w.t, curTask, r.task) + if err == "" { + require.NoError(w.t, r.err) + } else { + require.EqualError(w.t, r.err, err) + } + } +} + +func (w *mockScanWorker) waitNotifyScanTaskEnd() *scanTaskExecEndMsg { + select { + case msg := <-w.notifyCh: + endMsg, ok := msg.(*scanTaskExecEndMsg) + require.True(w.t, ok) + require.NotNil(w.t, endMsg.result) + require.Same(w.t, w.CurrentTask(), endMsg.result.task) + return endMsg + case <-time.After(10 * time.Second): + require.FailNow(w.t, "timeout") + } + + require.FailNow(w.t, "") + return nil +} + +func (w *mockScanWorker) pollDelTask() *ttlDeleteTask { + select { + 
case del := <-w.delCh: + require.NotNil(w.t, del) + require.NotNil(w.t, del.statistics) + require.Same(w.t, w.curTask.tbl, del.tbl) + require.Equal(w.t, w.curTask.expire, del.expire) + require.NotEqual(w.t, 0, len(del.rows)) + return del + case <-time.After(10 * time.Second): + require.FailNow(w.t, "timeout") + } + + require.FailNow(w.t, "") + return nil +} + +func (w *mockScanWorker) setOneRowResult(tbl *cache.PhysicalTable, val ...interface{}) { + w.sessPool.se.sessionInfoSchema = newMockInfoSchema(tbl.TableInfo) + w.sessPool.se.rows = newMockRows(w.t, tbl.KeyColumnTypes...).Append(val...).Rows() +} + +func (w *mockScanWorker) clearInfoSchema() { + w.sessPool.se.sessionInfoSchema = newMockInfoSchema() +} + +func (w *mockScanWorker) stopWithWait() { + w.Stop() + require.NoError(w.t, w.WaitStopped(context.TODO(), 10*time.Second)) +} + +func TestScanWorkerSchedule(t *testing.T) { + origLimit := variable.TTLScanBatchSize.Load() + variable.TTLScanBatchSize.Store(5) + defer variable.TTLScanBatchSize.Store(origLimit) + + tbl := newMockTTLTbl(t, "t1") + w := newMockScanWorker(t) + w.setOneRowResult(tbl, 7) + defer w.stopWithWait() + + task := &ttlScanTask{ + tbl: tbl, + expire: time.UnixMilli(0), + statistics: &ttlStatistics{}, + } + + require.EqualError(t, w.Schedule(task), "worker is not running") + w.checkWorkerStatus(workerStatusCreated, false, nil) + w.checkPollResult(false, "") + + w.Start() + w.checkWorkerStatus(workerStatusRunning, true, nil) + w.checkPollResult(false, "") + + require.NoError(t, w.Schedule(task)) + w.checkWorkerStatus(workerStatusRunning, false, task) + w.checkPollResult(false, "") + + require.EqualError(t, w.Schedule(task), "a task is running") + w.checkWorkerStatus(workerStatusRunning, false, task) + w.checkPollResult(false, "") + + del := w.pollDelTask() + require.Equal(t, 1, len(del.rows)) + require.Equal(t, 1, len(del.rows[0])) + require.Equal(t, int64(7), del.rows[0][0].GetInt64()) + + msg := w.waitNotifyScanTaskEnd() + require.Same(t, task, msg.result.task) + require.NoError(t, msg.result.err) + w.checkWorkerStatus(workerStatusRunning, false, task) + w.checkPollResult(true, "") + w.checkWorkerStatus(workerStatusRunning, true, nil) + w.checkPollResult(false, "") +} + +func TestScanWorkerScheduleWithFailedTask(t *testing.T) { + origLimit := variable.TTLScanBatchSize.Load() + variable.TTLScanBatchSize.Store(5) + defer variable.TTLScanBatchSize.Store(origLimit) + + tbl := newMockTTLTbl(t, "t1") + w := newMockScanWorker(t) + w.clearInfoSchema() + defer w.stopWithWait() + + task := &ttlScanTask{ + tbl: tbl, + expire: time.UnixMilli(0), + statistics: &ttlStatistics{}, + } + + w.Start() + w.checkWorkerStatus(workerStatusRunning, true, nil) + w.checkPollResult(false, "") + + require.NoError(t, w.Schedule(task)) + w.checkWorkerStatus(workerStatusRunning, false, task) + msg := w.waitNotifyScanTaskEnd() + require.Same(t, task, msg.result.task) + require.EqualError(t, msg.result.err, "table 'test.t1' meta changed, should abort current job: [schema:1146]Table 'test.t1' doesn't exist") + w.checkWorkerStatus(workerStatusRunning, false, task) + w.checkPollResult(true, msg.result.err.Error()) + w.checkWorkerStatus(workerStatusRunning, true, nil) +} + +type mockScanTask struct { + *ttlScanTask + t *testing.T + tbl *cache.PhysicalTable + sessPool *mockSessionPool + sqlRetry []int + + delCh chan *ttlDeleteTask + prevSQL string + prevSQLRetry int + delTasks []*ttlDeleteTask + schemaChangeIdx int + schemaChangeInRetry int +} + +func newMockScanTask(t *testing.T, sqlCnt int) 
*mockScanTask { + tbl := newMockTTLTbl(t, "t1") + task := &mockScanTask{ + t: t, + ttlScanTask: &ttlScanTask{ + tbl: tbl, + expire: time.UnixMilli(0), + rangeStart: []types.Datum{types.NewIntDatum(0)}, + statistics: &ttlStatistics{}, + }, + tbl: tbl, + delCh: make(chan *ttlDeleteTask, sqlCnt*(scanTaskExecuteSQLMaxRetry+1)), + sessPool: newMockSessionPool(t), + sqlRetry: make([]int, sqlCnt), + schemaChangeIdx: -1, + } + task.sessPool.se.executeSQL = task.execSQL + return task +} + +func (t *mockScanTask) selectSQL(i int) string { + return fmt.Sprintf("SELECT LOW_PRIORITY `_tidb_rowid` FROM `test`.`t1` WHERE `_tidb_rowid` > %d AND `time` < '1970-01-01 08:00:00' ORDER BY `_tidb_rowid` ASC LIMIT 3", i*100) +} + +func (t *mockScanTask) runDoScanForTest(delTaskCnt int, errString string) *ttlScanTaskExecResult { + t.ttlScanTask.statistics.Reset() + origLimit := variable.TTLScanBatchSize.Load() + variable.TTLScanBatchSize.Store(3) + origRetryInterval := scanTaskExecuteSQLRetryInterval + scanTaskExecuteSQLRetryInterval = time.Millisecond + defer func() { + variable.TTLScanBatchSize.Store(origLimit) + scanTaskExecuteSQLRetryInterval = origRetryInterval + }() + + t.sessPool.se.sessionInfoSchema = newMockInfoSchema(t.tbl.TableInfo) + t.prevSQL = "" + t.prevSQLRetry = 0 + t.sessPool.lastSession = nil + r := t.doScan(context.TODO(), t.delCh, t.sessPool) + require.NotNil(t.t, t.sessPool.lastSession) + require.True(t.t, t.sessPool.lastSession.closed) + require.Greater(t.t, t.sessPool.lastSession.resetTimeZoneCalls, 0) + require.NotNil(t.t, r) + require.Same(t.t, t.ttlScanTask, r.task) + if errString == "" { + require.NoError(t.t, r.err) + } else { + require.EqualError(t.t, r.err, errString) + } + + previousIdx := delTaskCnt + if errString == "" { + previousIdx = len(t.sqlRetry) - 1 + } + require.Equal(t.t, t.selectSQL(previousIdx), t.prevSQL) + if errString == "" { + require.Equal(t.t, t.sqlRetry[previousIdx], t.prevSQLRetry) + } else if previousIdx == t.schemaChangeIdx && t.schemaChangeInRetry <= scanTaskExecuteSQLMaxRetry { + require.Equal(t.t, t.schemaChangeInRetry, t.prevSQLRetry) + } else { + require.Equal(t.t, scanTaskExecuteSQLMaxRetry, t.prevSQLRetry) + } + t.delTasks = make([]*ttlDeleteTask, 0, len(t.sqlRetry)) +loop: + for { + select { + case del, ok := <-t.delCh: + if !ok { + break loop + } + t.delTasks = append(t.delTasks, del) + default: + break loop + } + } + + require.Equal(t.t, delTaskCnt, len(t.delTasks)) + expectTotalRows := 0 + for i, del := range t.delTasks { + require.NotNil(t.t, del) + require.NotNil(t.t, del.statistics) + require.Same(t.t, t.statistics, del.statistics) + require.Same(t.t, t.tbl, del.tbl) + require.Equal(t.t, t.expire, del.expire) + if i < len(t.sqlRetry)-1 { + require.Equal(t.t, 3, len(del.rows)) + require.Equal(t.t, 1, len(del.rows[2])) + require.Equal(t.t, int64((i+1)*100), del.rows[2][0].GetInt64()) + } else { + require.Equal(t.t, 2, len(del.rows)) + } + require.Equal(t.t, 1, len(del.rows[0])) + require.Equal(t.t, int64(i*100+1), del.rows[0][0].GetInt64()) + require.Equal(t.t, 1, len(del.rows[1])) + require.Equal(t.t, int64(i*100+2), del.rows[1][0].GetInt64()) + expectTotalRows += len(del.rows) + } + require.Equal(t.t, expectTotalRows, int(t.statistics.TotalRows.Load())) + return r +} + +func (t *mockScanTask) checkDelTasks(cnt int) { + require.Equal(t.t, cnt, len(t.delTasks)) + for i := 0; i < cnt; i++ { + del := t.delTasks[i] + require.NotNil(t.t, del) + require.NotNil(t.t, del.statistics) + require.Same(t.t, t.statistics, del.statistics) + if i < 2 { + 
require.Equal(t.t, 3, len(del.rows)) + require.Equal(t.t, 1, len(del.rows[2])) + require.Equal(t.t, int64((i+1)*100), del.rows[2][0].GetInt64()) + } else { + require.Equal(t.t, 2, len(del.rows)) + } + require.Equal(t.t, 1, len(del.rows[0])) + require.Equal(t.t, int64(i*100+1), del.rows[0][0].GetInt64()) + require.Equal(t.t, 1, len(del.rows[1])) + require.Equal(t.t, int64(i*100+2), del.rows[1][0].GetInt64()) + } +} + +func (t *mockScanTask) execSQL(_ context.Context, sql string, _ ...interface{}) ([]chunk.Row, error) { + var i int + found := false + for i = 0; i < len(t.sqlRetry); i++ { + if sql == t.selectSQL(i) { + found = true + break + } + } + require.True(t.t, found, sql) + + curRetry := 0 + if sql == t.prevSQL { + curRetry = t.prevSQLRetry + 1 + } + + if curRetry == 0 && i > 0 { + require.Equal(t.t, t.selectSQL(i-1), t.prevSQL) + require.Equal(t.t, t.sqlRetry[i-1], t.prevSQLRetry) + } + t.prevSQL = sql + t.prevSQLRetry = curRetry + require.LessOrEqual(t.t, curRetry, t.sqlRetry[i]) + + if t.schemaChangeIdx == i && t.schemaChangeInRetry == curRetry { + t.sessPool.lastSession.sessionInfoSchema = newMockInfoSchema() + } + + if curRetry < t.sqlRetry[i] { + return nil, errors.New("mockErr") + } + + rows := newMockRows(t.t, t.tbl.KeyColumnTypes...).Append(i*100 + 1).Append(i*100 + 2) + if i < len(t.sqlRetry)-1 { + rows.Append((i + 1) * 100) + } + return rows.Rows(), nil +} + +func TestScanTaskDoScan(t *testing.T) { + task := newMockScanTask(t, 3) + task.sqlRetry[1] = scanTaskExecuteSQLMaxRetry + task.runDoScanForTest(3, "") + + task.sqlRetry[1] = scanTaskExecuteSQLMaxRetry + 1 + task.runDoScanForTest(1, "mockErr") + + task.sqlRetry[1] = scanTaskExecuteSQLMaxRetry + task.schemaChangeIdx = 1 + task.schemaChangeInRetry = 0 + task.runDoScanForTest(1, "table 'test.t1' meta changed, should abort current job: [schema:1146]Table 'test.t1' doesn't exist") + + task.sqlRetry[1] = scanTaskExecuteSQLMaxRetry + task.schemaChangeIdx = 1 + task.schemaChangeInRetry = 2 + task.runDoScanForTest(1, "table 'test.t1' meta changed, should abort current job: [schema:1146]Table 'test.t1' doesn't exist") +} diff --git a/ttl/ttlworker/session.go b/ttl/ttlworker/session.go new file mode 100644 index 0000000000000..1514c1746f348 --- /dev/null +++ b/ttl/ttlworker/session.go @@ -0,0 +1,168 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
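session.go below wraps a pooled session with a per-table safety check: every statement runs inside a transaction, and the table meta is re-validated after execution, since under metadata locking the schema a transaction actually uses is only pinned once a query has run. A sketch of the calling convention the scan and delete workers rely on (runBatch is a hypothetical helper, not part of this change, written as if it lived in package ttlworker):

package ttlworker

import (
	"context"
	"time"

	"github.com/pingcap/tidb/ttl/cache"
	"github.com/pingcap/tidb/ttl/session"
	"github.com/pingcap/tidb/util/chunk"
)

// runBatch shows how callers are expected to interpret the three results of
// ExecuteSQLWithCheck: rows on success, shouldRetry for transient failures,
// and a non-retryable error when TTL is disabled or the table meta changed.
func runBatch(ctx context.Context, se session.Session, tbl *cache.PhysicalTable, expire time.Time, sql string) ([]chunk.Row, bool, error) {
	ts := newTableSession(se, tbl, expire)
	rows, shouldRetry, err := ts.ExecuteSQLWithCheck(ctx, sql)
	if err != nil {
		// shouldRetry: back off and re-run the same SQL, as doScan does with
		// scanTaskExecuteSQLRetryInterval; otherwise abort the whole job.
		return nil, shouldRetry, err
	}
	return rows, false, nil
}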
+ +package ttlworker + +import ( + "context" + "time" + + "github.com/ngaut/pools" + "github.com/pingcap/errors" + "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/sessionctx/variable" + "github.com/pingcap/tidb/ttl/cache" + "github.com/pingcap/tidb/ttl/session" + "github.com/pingcap/tidb/util/chunk" + "github.com/pingcap/tidb/util/sqlexec" +) + +type sessionPool interface { + Get() (pools.Resource, error) + Put(pools.Resource) +} + +func getSession(pool sessionPool) (session.Session, error) { + resource, err := pool.Get() + if err != nil { + return nil, err + } + + if se, ok := resource.(session.Session); ok { + // Only for tests: in this case the pooled resource is already a mock session.Session + return se, nil + } + + sctx, ok := resource.(sessionctx.Context) + if !ok { + pool.Put(resource) + return nil, errors.Errorf("%T cannot be cast to sessionctx.Context", resource) + } + + exec, ok := resource.(sqlexec.SQLExecutor) + if !ok { + pool.Put(resource) + return nil, errors.Errorf("%T cannot be cast to sqlexec.SQLExecutor", resource) + } + + se := session.NewSession(sctx, exec, func() { + pool.Put(resource) + }) + + if _, err = se.ExecuteSQL(context.Background(), "commit"); err != nil { + se.Close() + return nil, err + } + + return se, nil +} + +func newTableSession(se session.Session, tbl *cache.PhysicalTable, expire time.Time) *ttlTableSession { + return &ttlTableSession{ + Session: se, + tbl: tbl, + expire: expire, + } +} + +type ttlTableSession struct { + session.Session + tbl *cache.PhysicalTable + expire time.Time +} + +func (s *ttlTableSession) ExecuteSQLWithCheck(ctx context.Context, sql string) (rows []chunk.Row, shouldRetry bool, err error) { + if !variable.EnableTTLJob.Load() { + return nil, false, errors.New("global TTL job is disabled") + } + + if err = s.ResetWithGlobalTimeZone(ctx); err != nil { + return nil, false, err + } + + err = s.RunInTxn(ctx, func() error { + rows, err = s.ExecuteSQL(ctx, sql) + // We must validate the table meta after ExecuteSQL because of MDL: the meta used by the current transaction + // can only be determined after at least one query has been executed. 
+
+func (s *ttlTableSession) ExecuteSQLWithCheck(ctx context.Context, sql string) (rows []chunk.Row, shouldRetry bool, err error) {
+    if !variable.EnableTTLJob.Load() {
+        return nil, false, errors.New("global TTL job is disabled")
+    }
+
+    if err = s.ResetWithGlobalTimeZone(ctx); err != nil {
+        return nil, false, err
+    }
+
+    err = s.RunInTxn(ctx, func() error {
+        rows, err = s.ExecuteSQL(ctx, sql)
+        // The table meta must be checked after ExecuteSQL: because of MDL, the meta
+        // used by the current transaction can only be determined after it has
+        // executed at least one query.
+        if validateErr := validateTTLWork(ctx, s.Session, s.tbl, s.expire); validateErr != nil {
+            shouldRetry = false
+            return errors.Annotatef(validateErr, "table '%s.%s' meta changed, should abort current job", s.tbl.Schema, s.tbl.Name)
+        }
+
+        if err != nil {
+            shouldRetry = true
+            return err
+        }
+        return nil
+    })
+
+    if err != nil {
+        return nil, shouldRetry, err
+    }
+
+    return rows, false, nil
+}
+
+// validateTTLWork checks whether the table meta the job was built from is still
+// valid; a non-nil error means the current job should be aborted.
+func validateTTLWork(ctx context.Context, s session.Session, tbl *cache.PhysicalTable, expire time.Time) error {
+    curTbl, err := s.SessionInfoSchema().TableByName(tbl.Schema, tbl.Name)
+    if err != nil {
+        return err
+    }
+
+    newTblInfo := curTbl.Meta()
+    if tbl.TableInfo == newTblInfo {
+        return nil
+    }
+
+    if tbl.TableInfo.ID != newTblInfo.ID {
+        return errors.New("table id changed")
+    }
+
+    newTTLTbl, err := cache.NewPhysicalTable(tbl.Schema, newTblInfo, tbl.Partition)
+    if err != nil {
+        return err
+    }
+
+    if newTTLTbl.ID != tbl.ID {
+        return errors.New("physical id changed")
+    }
+
+    if tbl.Partition.L != "" {
+        if newTTLTbl.PartitionDef.Name.L != tbl.PartitionDef.Name.L {
+            return errors.New("partition name changed")
+        }
+    }
+
+    if !newTTLTbl.TTLInfo.Enable {
+        return errors.New("table TTL disabled")
+    }
+
+    if newTTLTbl.TimeColumn.Name.L != tbl.TimeColumn.Name.L {
+        return errors.New("time column name changed")
+    }
+
+    if newTblInfo.TTLInfo.IntervalExprStr != tbl.TTLInfo.IntervalExprStr ||
+        newTblInfo.TTLInfo.IntervalTimeUnit != tbl.TTLInfo.IntervalTimeUnit {
+        newExpireTime, err := newTTLTbl.EvalExpireTime(ctx, s, time.Now())
+        if err != nil {
+            return err
+        }
+
+        // Aborting is only necessary when the new interval moves the expire time
+        // earlier than the one this job is using; otherwise every row this job
+        // deletes would also have expired under the new setting.
+        if newExpireTime.Before(expire) {
+            return errors.New("expire interval changed")
+        }
+    }
+
+    return nil
+}
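ExecuteSQLWithCheck reports (rows, shouldRetry, err), where shouldRetry distinguishes transient SQL failures from schema-change aborts. A hedged sketch of the retry loop a caller might build on top of it; the loop itself is illustrative, not the scan task's actual implementation elsewhere in this PR:

// Illustrative retry loop around ExecuteSQLWithCheck: transient SQL errors
// are retried up to maxRetry times, while schema-change errors abort at once.
func executeWithRetry(ctx context.Context, se *ttlTableSession, sql string, maxRetry int) ([]chunk.Row, error) {
    var lastErr error
    for i := 0; i < maxRetry; i++ {
        rows, shouldRetry, err := se.ExecuteSQLWithCheck(ctx, sql)
        if err == nil {
            return rows, nil
        }
        lastErr = err
        if !shouldRetry {
            break
        }
    }
    return nil, lastErr
}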
diff --git a/ttl/ttlworker/session_test.go b/ttl/ttlworker/session_test.go
new file mode 100644
index 0000000000000..b12d47dd9f2bc
--- /dev/null
+++ b/ttl/ttlworker/session_test.go
@@ -0,0 +1,329 @@
+// Copyright 2022 PingCAP, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ttlworker
+
+import (
+    "context"
+    "errors"
+    "strings"
+    "testing"
+    "time"
+
+    "github.com/ngaut/pools"
+    "github.com/pingcap/tidb/infoschema"
+    "github.com/pingcap/tidb/parser/ast"
+    "github.com/pingcap/tidb/parser/model"
+    "github.com/pingcap/tidb/parser/mysql"
+    "github.com/pingcap/tidb/sessionctx"
+    "github.com/pingcap/tidb/sessionctx/variable"
+    "github.com/pingcap/tidb/ttl/cache"
+    "github.com/pingcap/tidb/types"
+    "github.com/pingcap/tidb/util/chunk"
+    "github.com/stretchr/testify/require"
+)
+
+func newMockTTLTbl(t *testing.T, name string) *cache.PhysicalTable {
+    tblInfo := &model.TableInfo{
+        Name: model.NewCIStr(name),
+        Columns: []*model.ColumnInfo{
+            {
+                ID:        1,
+                Name:      model.NewCIStr("time"),
+                Offset:    0,
+                FieldType: *types.NewFieldType(mysql.TypeDatetime),
+                State:     model.StatePublic,
+            },
+        },
+        TTLInfo: &model.TTLInfo{
+            ColumnName:       model.NewCIStr("time"),
+            IntervalExprStr:  "1",
+            IntervalTimeUnit: int(ast.TimeUnitSecond),
+            Enable:           true,
+        },
+        State: model.StatePublic,
+    }
+
+    tbl, err := cache.NewPhysicalTable(model.NewCIStr("test"), tblInfo, model.NewCIStr(""))
+    require.NoError(t, err)
+    return tbl
+}
+
+func newMockInfoSchema(tbl ...*model.TableInfo) infoschema.InfoSchema {
+    return infoschema.MockInfoSchema(tbl)
+}
+
+type mockRows struct {
+    t          *testing.T
+    fieldTypes []*types.FieldType
+    *chunk.Chunk
+}
+
+func newMockRows(t *testing.T, fieldTypes ...*types.FieldType) *mockRows {
+    return &mockRows{
+        t:          t,
+        fieldTypes: fieldTypes,
+        Chunk:      chunk.NewChunkWithCapacity(fieldTypes, 8),
+    }
+}
+
+func (r *mockRows) Append(row ...interface{}) *mockRows {
+    require.Equal(r.t, len(r.fieldTypes), len(row))
+    for i, ft := range r.fieldTypes {
+        tp := ft.GetType()
+        switch tp {
+        case mysql.TypeTimestamp, mysql.TypeDate, mysql.TypeDatetime:
+            tm, ok := row[i].(time.Time)
+            require.True(r.t, ok)
+            r.AppendTime(i, types.NewTime(types.FromGoTime(tm), tp, types.DefaultFsp))
+        case mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong:
+            val, ok := row[i].(int)
+            require.True(r.t, ok)
+            r.AppendInt64(i, int64(val))
+        default:
+            require.FailNow(r.t, "unsupported tp %v", tp)
+        }
+    }
+    return r
+}
+
+func (r *mockRows) Rows() []chunk.Row {
+    rows := make([]chunk.Row, r.NumRows())
+    for i := 0; i < r.NumRows(); i++ {
+        rows[i] = r.GetRow(i)
+    }
+    return rows
+}
+
+type mockSessionPool struct {
+    t           *testing.T
+    se          *mockSession
+    lastSession *mockSession
+}
+
+func (p *mockSessionPool) Get() (pools.Resource, error) {
+    se := *(p.se)
+    p.lastSession = &se
+    return p.lastSession, nil
+}
+
+func (p *mockSessionPool) Put(pools.Resource) {}
+
+func newMockSessionPool(t *testing.T, tbl ...*cache.PhysicalTable) *mockSessionPool {
+    return &mockSessionPool{
+        se: newMockSession(t, tbl...),
+    }
+}
+
+type mockSession struct {
+    t *testing.T
+    sessionctx.Context
+    sessionVars        *variable.SessionVars
+    sessionInfoSchema  infoschema.InfoSchema
+    executeSQL         func(ctx context.Context, sql string, args ...interface{}) ([]chunk.Row, error)
+    rows               []chunk.Row
+    execErr            error
+    evalExpire         time.Time
+    resetTimeZoneCalls int
+    closed             bool
+}
+
+func newMockSession(t *testing.T, tbl ...*cache.PhysicalTable) *mockSession {
+    tbls := make([]*model.TableInfo, len(tbl))
+    for i, ttlTbl := range tbl {
+        tbls[i] = ttlTbl.TableInfo
+    }
+    sessVars := variable.NewSessionVars(nil)
+    sessVars.TimeZone = time.UTC
+    return &mockSession{
+        t:                 t,
+        sessionInfoSchema: newMockInfoSchema(tbls...),
+        evalExpire:        time.Now(),
+        sessionVars:       sessVars,
+    }
+}
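Because mockSession itself implements session.Session, the test-only branch at the top of getSession returns it directly, so tests never need a real sessionctx.Context. An illustrative (hypothetical) snippet showing that wiring:

// Illustrative: getSession short-circuits for mock sessions, so worker code
// under test runs without a real server session.
func exampleUseMockPool(t *testing.T) {
    tbl := newMockTTLTbl(t, "t1")
    pool := newMockSessionPool(t, tbl)
    se, err := getSession(pool)
    require.NoError(t, err)
    require.Same(t, pool.lastSession, se) // the pooled mock is handed back as-is
}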
+
+func (s *mockSession) SessionInfoSchema() infoschema.InfoSchema {
+    require.False(s.t, s.closed)
+    return s.sessionInfoSchema
+}
+
+func (s *mockSession) GetSessionVars() *variable.SessionVars {
+    require.False(s.t, s.closed)
+    return s.sessionVars
+}
+
+func (s *mockSession) ExecuteSQL(ctx context.Context, sql string, args ...interface{}) ([]chunk.Row, error) {
+    require.False(s.t, s.closed)
+    if strings.HasPrefix(strings.ToUpper(sql), "SELECT FROM_UNIXTIME") {
+        return newMockRows(s.t, types.NewFieldType(mysql.TypeTimestamp)).Append(s.evalExpire.In(s.GetSessionVars().TimeZone)).Rows(), nil
+    }
+
+    if strings.HasPrefix(strings.ToUpper(sql), "SET ") {
+        return nil, nil
+    }
+
+    if s.executeSQL != nil {
+        return s.executeSQL(ctx, sql, args...)
+    }
+    return s.rows, s.execErr
+}
+
+func (s *mockSession) RunInTxn(_ context.Context, fn func() error) (err error) {
+    require.False(s.t, s.closed)
+    return fn()
+}
+
+func (s *mockSession) ResetWithGlobalTimeZone(_ context.Context) (err error) {
+    require.False(s.t, s.closed)
+    s.resetTimeZoneCalls++
+    return nil
+}
+
+func (s *mockSession) Close() {
+    s.closed = true
+}
+
+func TestExecuteSQLWithCheck(t *testing.T) {
+    ctx := context.TODO()
+    tbl := newMockTTLTbl(t, "t1")
+    s := newMockSession(t, tbl)
+    s.execErr = errors.New("mockErr")
+    s.rows = newMockRows(t, types.NewFieldType(mysql.TypeInt24)).Append(12).Rows()
+    tblSe := newTableSession(s, tbl, time.UnixMilli(0).In(time.UTC))
+
+    // a transient execution error is retryable
+    rows, shouldRetry, err := tblSe.ExecuteSQLWithCheck(ctx, "select 1")
+    require.EqualError(t, err, "mockErr")
+    require.True(t, shouldRetry)
+    require.Nil(t, rows)
+    require.Equal(t, 1, s.resetTimeZoneCalls)
+
+    // a schema change is not retryable
+    s.sessionInfoSchema = newMockInfoSchema()
+    rows, shouldRetry, err = tblSe.ExecuteSQLWithCheck(ctx, "select 1")
+    require.EqualError(t, err, "table 'test.t1' meta changed, should abort current job: [schema:1146]Table 'test.t1' doesn't exist")
+    require.False(t, shouldRetry)
+    require.Nil(t, rows)
+    require.Equal(t, 2, s.resetTimeZoneCalls)
+
+    s.sessionInfoSchema = newMockInfoSchema(tbl.TableInfo)
+    s.execErr = nil
+    rows, shouldRetry, err = tblSe.ExecuteSQLWithCheck(ctx, "select 1")
+    require.NoError(t, err)
+    require.False(t, shouldRetry)
+    require.Equal(t, 1, len(rows))
+    require.Equal(t, int64(12), rows[0].GetInt64(0))
+    require.Equal(t, 3, s.resetTimeZoneCalls)
+}
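The executeSQL hook above lets an individual test intercept any statement that does not match the built-in FROM_UNIXTIME/SET fast paths. An illustrative (hypothetical) use of the hook for fault injection:

// Illustrative: injecting a failure for every DELETE the code under test issues.
func exampleInjectError(t *testing.T) {
    tbl := newMockTTLTbl(t, "t1")
    s := newMockSession(t, tbl)
    s.executeSQL = func(_ context.Context, sql string, _ ...interface{}) ([]chunk.Row, error) {
        if strings.HasPrefix(strings.ToUpper(sql), "DELETE") {
            return nil, errors.New("injected error")
        }
        return nil, nil
    }

    rows, err := s.ExecuteSQL(context.Background(), "DELETE FROM t1")
    require.Nil(t, rows)
    require.EqualError(t, err, "injected error")
}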
+
+func TestValidateTTLWork(t *testing.T) {
+    ctx := context.TODO()
+    tbl := newMockTTLTbl(t, "t1")
+    expire := time.UnixMilli(0).In(time.UTC)
+
+    s := newMockSession(t, tbl)
+    s.execErr = errors.New("mockErr")
+    s.evalExpire = time.UnixMilli(0).In(time.UTC)
+
+    // test table dropped
+    s.sessionInfoSchema = newMockInfoSchema()
+    err := validateTTLWork(ctx, s, tbl, expire)
+    require.EqualError(t, err, "[schema:1146]Table 'test.t1' doesn't exist")
+
+    // test TTL option removed
+    tbl2 := tbl.TableInfo.Clone()
+    tbl2.TTLInfo = nil
+    s.sessionInfoSchema = newMockInfoSchema(tbl2)
+    err = validateTTLWork(ctx, s, tbl, expire)
+    require.EqualError(t, err, "table 'test.t1' is not a ttl table")
+
+    // test table state not public
+    tbl2 = tbl.TableInfo.Clone()
+    tbl2.State = model.StateDeleteOnly
+    s.sessionInfoSchema = newMockInfoSchema(tbl2)
+    err = validateTTLWork(ctx, s, tbl, expire)
+    require.EqualError(t, err, "table 'test.t1' is not a public table")
+
+    // test table name changed
+    tbl2 = tbl.TableInfo.Clone()
+    tbl2.Name = model.NewCIStr("testcc")
+    s.sessionInfoSchema = newMockInfoSchema(tbl2)
+    err = validateTTLWork(ctx, s, tbl, expire)
+    require.EqualError(t, err, "[schema:1146]Table 'test.t1' doesn't exist")
+
+    // test table id changed
+    tbl2 = tbl.TableInfo.Clone()
+    tbl2.ID = 123
+    s.sessionInfoSchema = newMockInfoSchema(tbl2)
+    err = validateTTLWork(ctx, s, tbl, expire)
+    require.EqualError(t, err, "table id changed")
+
+    // test time column name changed
+    tbl2 = tbl.TableInfo.Clone()
+    tbl2.Columns[0] = tbl2.Columns[0].Clone()
+    tbl2.Columns[0].Name = model.NewCIStr("time2")
+    tbl2.TTLInfo.ColumnName = model.NewCIStr("time2")
+    s.sessionInfoSchema = newMockInfoSchema(tbl2)
+    err = validateTTLWork(ctx, s, tbl, expire)
+    require.EqualError(t, err, "time column name changed")
+
+    // test interval changed so that the new expire time is before the previous one
+    tbl2 = tbl.TableInfo.Clone()
+    tbl2.TTLInfo.IntervalExprStr = "10"
+    s.sessionInfoSchema = newMockInfoSchema(tbl2)
+    s.evalExpire = time.UnixMilli(-1)
+    err = validateTTLWork(ctx, s, tbl, expire)
+    require.EqualError(t, err, "expire interval changed")
+
+    tbl2 = tbl.TableInfo.Clone()
+    tbl2.TTLInfo.IntervalTimeUnit = int(ast.TimeUnitDay)
+    s.evalExpire = time.UnixMilli(-1)
+    s.sessionInfoSchema = newMockInfoSchema(tbl2)
+    err = validateTTLWork(ctx, s, tbl, expire)
+    require.EqualError(t, err, "expire interval changed")
+
+    // test for safe meta change
+    tbl2 = tbl.TableInfo.Clone()
+    tbl2.Columns[0] = tbl2.Columns[0].Clone()
+    tbl2.Columns[0].ID += 10
+    tbl2.Columns[0].FieldType = *types.NewFieldType(mysql.TypeDate)
+    tbl2.TTLInfo.IntervalExprStr = "100"
+    s.evalExpire = time.UnixMilli(1000)
+    s.sessionInfoSchema = newMockInfoSchema(tbl2)
+    err = validateTTLWork(ctx, s, tbl, expire)
+    require.NoError(t, err)
+
+    // test table partition name changed
+    tp := tbl.TableInfo.Clone()
+    tp.Partition = &model.PartitionInfo{
+        Definitions: []model.PartitionDefinition{
+            {ID: 1023, Name: model.NewCIStr("p0")},
+        },
+    }
+    tbl, err = cache.NewPhysicalTable(model.NewCIStr("test"), tp, model.NewCIStr("p0"))
+    require.NoError(t, err)
+    tbl2 = tp.Clone()
+    tbl2.Partition = tp.Partition.Clone()
+    tbl2.Partition.Definitions[0].Name = model.NewCIStr("p1")
+    s.sessionInfoSchema = newMockInfoSchema(tbl2)
+    err = validateTTLWork(ctx, s, tbl, expire)
+    require.EqualError(t, err, "partition 'p0' is not found in ttl table 'test.t1'")
+
+    // test table partition id changed
+    tbl2 = tp.Clone()
+    tbl2.Partition = tp.Partition.Clone()
+    tbl2.Partition.Definitions[0].ID += 100
+    s.sessionInfoSchema = newMockInfoSchema(tbl2)
+    err = validateTTLWork(ctx, s, tbl, expire)
+    require.EqualError(t, err, "physical id changed")
+}
diff --git a/ttl/ttlworker/worker.go b/ttl/ttlworker/worker.go
new file mode 100644
index 0000000000000..d03a747d2855e
--- /dev/null
+++ b/ttl/ttlworker/worker.go
@@ -0,0 +1,126 @@
+// Copyright 2022 PingCAP, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ttlworker
+
+import (
+    "context"
+    "sync"
+    "time"
+
+    "github.com/pingcap/tidb/util"
+)
+
+type workerStatus int
+
+const (
+    workerStatusCreated workerStatus = iota
+    workerStatusRunning
+    workerStatusStopping
+    workerStatusStopped
+)
+
+type worker interface {
+    Start()
+    Stop()
+    Status() workerStatus
+    Error() error
+    Send() chan<- interface{}
+    WaitStopped(ctx context.Context, timeout time.Duration) error
+}
+
+type baseWorker struct {
+    sync.Mutex
+    ctx      context.Context
+    cancel   func()
+    ch       chan interface{}
+    loopFunc func() error
+
+    err    error
+    status workerStatus
+    wg     util.WaitGroupWrapper
+}
+
+func (w *baseWorker) init(loop func() error) {
+    w.ctx, w.cancel = context.WithCancel(context.Background())
+    w.status = workerStatusCreated
+    w.loopFunc = loop
+    w.ch = make(chan interface{})
+}
+
+func (w *baseWorker) Start() {
+    w.Lock()
+    defer w.Unlock()
+    if w.status != workerStatusCreated {
+        return
+    }
+
+    w.wg.Run(w.loop)
+    w.status = workerStatusRunning
+}
+
+func (w *baseWorker) Stop() {
+    w.Lock()
+    defer w.Unlock()
+    switch w.status {
+    case workerStatusCreated:
+        // The loop never started, so the worker can go straight to stopped.
+        w.cancel()
+        w.toStopped(nil)
+    case workerStatusRunning:
+        w.cancel()
+        w.status = workerStatusStopping
+    }
+}
+
+func (w *baseWorker) Status() workerStatus {
+    w.Lock()
+    defer w.Unlock()
+    return w.status
+}
+
+func (w *baseWorker) Error() error {
+    w.Lock()
+    defer w.Unlock()
+    return w.err
+}
+
+func (w *baseWorker) WaitStopped(ctx context.Context, timeout time.Duration) error {
+    // Unblock the wait either when the loop goroutine finishes or when the
+    // timeout elapses, whichever comes first.
+    ctx, cancel := context.WithTimeout(ctx, timeout)
+    go func() {
+        w.wg.Wait()
+        cancel()
+    }()
+
+    <-ctx.Done()
+    if w.Status() != workerStatusStopped {
+        return ctx.Err()
+    }
+    return nil
+}
+
+func (w *baseWorker) loop() {
+    var err error
+    defer func() {
+        w.Lock()
+        w.toStopped(err)
+        w.Unlock()
+    }()
+    err = w.loopFunc()
+}
+
+// toStopped must be called with w.Lock held.
+func (w *baseWorker) toStopped(err error) {
+    w.status = workerStatusStopped
+    w.err = err
+    close(w.ch)
+}
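baseWorker is designed to be embedded by concrete workers that hand their loop to init; Send is left to the embedding type, which would typically expose w.ch. A minimal sketch of such a worker, assuming it lives in the same package (noopWorker and its loop body are illustrative, not part of the diff):

// Illustrative: a trivial worker built on baseWorker. Its loop drains the
// message channel until Stop cancels the context.
type noopWorker struct {
    baseWorker
}

func newNoopWorker() *noopWorker {
    w := &noopWorker{}
    w.init(w.loopOnce)
    return w
}

func (w *noopWorker) Send() chan<- interface{} {
    return w.ch
}

func (w *noopWorker) loopOnce() error {
    for {
        select {
        case <-w.ctx.Done():
            return nil
        case <-w.ch:
            // A real worker would handle the message here.
        }
    }
}

Usage would follow the Start/Stop/WaitStopped lifecycle: w.Start(); ...; w.Stop(); err := w.WaitStopped(context.Background(), time.Second).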