diff --git a/kv/option.go b/kv/option.go index ee5354141cd7b..a0e658f45aade 100644 --- a/kv/option.go +++ b/kv/option.go @@ -167,4 +167,6 @@ const ( InternalTxnBR = InternalTxnTools // InternalTxnTrace handles the trace statement. InternalTxnTrace = "Trace" + // InternalTxnTTL is the type of TTL usage + InternalTxnTTL = "TTL" ) diff --git a/ttl/BUILD.bazel b/ttl/BUILD.bazel index e6a76c69d8df5..e5ec05168b29b 100644 --- a/ttl/BUILD.bazel +++ b/ttl/BUILD.bazel @@ -3,17 +3,25 @@ load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") go_library( name = "ttl", srcs = [ + "session.go", "sql.go", "table.go", ], importpath = "github.com/pingcap/tidb/ttl", visibility = ["//visibility:public"], deps = [ + "//infoschema", + "//kv", "//parser/ast", "//parser/format", "//parser/model", "//parser/mysql", + "//parser/terror", + "//sessionctx", + "//sessiontxn", + "//table/tables", "//types", + "//util/chunk", "//util/sqlexec", "@com_github_pingcap_errors//:errors", "@com_github_pkg_errors//:errors", @@ -24,11 +32,13 @@ go_test( name = "ttl_test", srcs = [ "main_test.go", + "session_test.go", "sql_test.go", + "table_test.go", ], + embed = [":ttl"], flaky = True, deps = [ - ":ttl", "//kv", "//parser", "//parser/ast", @@ -38,6 +48,7 @@ go_test( "//testkit/testsetup", "//types", "//util/sqlexec", + "@com_github_pingcap_errors//:errors", "@com_github_stretchr_testify//require", "@org_uber_go_goleak//:goleak", ], diff --git a/ttl/session.go b/ttl/session.go new file mode 100644 index 0000000000000..b3321e0d53c06 --- /dev/null +++ b/ttl/session.go @@ -0,0 +1,123 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ttl + +import ( + "context" + + "github.com/pingcap/errors" + "github.com/pingcap/tidb/infoschema" + "github.com/pingcap/tidb/kv" + "github.com/pingcap/tidb/parser/terror" + "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/sessiontxn" + "github.com/pingcap/tidb/util/chunk" + "github.com/pingcap/tidb/util/sqlexec" +) + +// Session is used to execute queries for TTL case +type Session interface { + sessionctx.Context + // SessionInfoSchema returns information schema of current session + SessionInfoSchema() infoschema.InfoSchema + // ExecuteSQL executes the sql + ExecuteSQL(ctx context.Context, sql string, args ...interface{}) ([]chunk.Row, error) + // RunInTxn executes the specified function in a txn + RunInTxn(ctx context.Context, fn func() error) (err error) + // Close closes the session + Close() +} + +type session struct { + sessionctx.Context + sqlExec sqlexec.SQLExecutor + closeFn func() +} + +// NewSession creates a new Session +func NewSession(sctx sessionctx.Context, sqlExec sqlexec.SQLExecutor, closeFn func()) Session { + return &session{ + Context: sctx, + sqlExec: sqlExec, + closeFn: closeFn, + } +} + +// SessionInfoSchema returns information schema of current session +func (s *session) SessionInfoSchema() infoschema.InfoSchema { + if s.Context == nil { + return nil + } + return sessiontxn.GetTxnManager(s.Context).GetTxnInfoSchema() +} + +// ExecuteSQL executes the sql +func (s *session) ExecuteSQL(ctx context.Context, sql string, args ...interface{}) ([]chunk.Row, error) { + if s.sqlExec == nil { + return nil, errors.New("session is closed") + } + + ctx = kv.WithInternalSourceType(ctx, kv.InternalTxnTTL) + rs, err := s.sqlExec.ExecuteInternal(ctx, sql, args...) + if err != nil { + return nil, err + } + + if rs == nil { + return nil, nil + } + + defer func() { + terror.Log(rs.Close()) + }() + + return sqlexec.DrainRecordSet(ctx, rs, 8) +} + +// RunInTxn executes the specified function in a txn +func (s *session) RunInTxn(ctx context.Context, fn func() error) (err error) { + if _, err = s.ExecuteSQL(ctx, "BEGIN"); err != nil { + return err + } + + success := false + defer func() { + if !success { + _, err = s.ExecuteSQL(ctx, "ROLLBACK") + terror.Log(err) + } + }() + + if err = fn(); err != nil { + return err + } + + if _, err = s.ExecuteSQL(ctx, "COMMIT"); err != nil { + return err + } + + success = true + return err +} + +// Close closes the session +func (s *session) Close() { + if s.closeFn != nil { + s.closeFn() + s.Context = nil + s.sqlExec = nil + s.closeFn = nil + } +} diff --git a/ttl/session_test.go b/ttl/session_test.go new file mode 100644 index 0000000000000..90d47ed313e73 --- /dev/null +++ b/ttl/session_test.go @@ -0,0 +1,52 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ttl + +import ( + "context" + "testing" + + "github.com/pingcap/errors" + "github.com/pingcap/tidb/testkit" + "github.com/stretchr/testify/require" +) + +func TestSessionRunInTxn(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("create table t(id int primary key, v int)") + se := NewSession(tk.Session(), tk.Session(), nil) + tk2 := testkit.NewTestKit(t, store) + tk2.MustExec("use test") + + require.NoError(t, se.RunInTxn(context.TODO(), func() error { + tk.MustExec("insert into t values (1, 10)") + return nil + })) + tk2.MustQuery("select * from t order by id asc").Check(testkit.Rows("1 10")) + + require.NoError(t, se.RunInTxn(context.TODO(), func() error { + tk.MustExec("insert into t values (2, 20)") + return errors.New("err") + })) + tk2.MustQuery("select * from t order by id asc").Check(testkit.Rows("1 10")) + + require.NoError(t, se.RunInTxn(context.TODO(), func() error { + tk.MustExec("insert into t values (3, 30)") + return nil + })) + tk2.MustQuery("select * from t order by id asc").Check(testkit.Rows("1 10", "3 30")) +} diff --git a/ttl/sql.go b/ttl/sql.go index 9cdf762846d4d..3d100fd62eee7 100644 --- a/ttl/sql.go +++ b/ttl/sql.go @@ -31,6 +31,8 @@ import ( "github.com/pkg/errors" ) +const dateTimeFormat = "2006-01-02 15:04:05.999999" + func writeHex(in io.Writer, d types.Datum) error { _, err := fmt.Fprintf(in, "x'%s'", hex.EncodeToString(d.GetBytes())) return err @@ -179,7 +181,7 @@ func (b *SQLBuilder) WriteExpireCondition(expire time.Time) error { b.writeColNames([]*model.ColumnInfo{b.tbl.TimeColumn}, false) b.restoreCtx.WritePlain(" < ") b.restoreCtx.WritePlain("'") - b.restoreCtx.WritePlain(expire.Format("2006-01-02 15:04:05.999999")) + b.restoreCtx.WritePlain(expire.Format(dateTimeFormat)) b.restoreCtx.WritePlain("'") b.hasWriteExpireCond = true return nil diff --git a/ttl/table.go b/ttl/table.go index b9c59a34e5c17..4885da0e137b4 100644 --- a/ttl/table.go +++ b/ttl/table.go @@ -15,23 +15,115 @@ package ttl import ( + "context" + "fmt" + "time" + "github.com/pingcap/errors" + "github.com/pingcap/tidb/parser/ast" "github.com/pingcap/tidb/parser/model" + "github.com/pingcap/tidb/parser/mysql" + "github.com/pingcap/tidb/table/tables" "github.com/pingcap/tidb/types" + "github.com/pingcap/tidb/util/chunk" ) +func getTableKeyColumns(tbl *model.TableInfo) ([]*model.ColumnInfo, []*types.FieldType, error) { + if tbl.PKIsHandle { + for i, col := range tbl.Columns { + if mysql.HasPriKeyFlag(col.GetFlag()) { + return []*model.ColumnInfo{tbl.Columns[i]}, []*types.FieldType{&tbl.Columns[i].FieldType}, nil + } + } + return nil, nil, errors.Errorf("Cannot find primary key for table: %s", tbl.Name) + } + + if tbl.IsCommonHandle { + idxInfo := tables.FindPrimaryIndex(tbl) + columns := make([]*model.ColumnInfo, len(idxInfo.Columns)) + fieldTypes := make([]*types.FieldType, len(idxInfo.Columns)) + for i, idxCol := range idxInfo.Columns { + columns[i] = tbl.Columns[idxCol.Offset] + fieldTypes[i] = &tbl.Columns[idxCol.Offset].FieldType + } + return columns, fieldTypes, nil + } + + extraHandleColInfo := model.NewExtraHandleColInfo() + return []*model.ColumnInfo{extraHandleColInfo}, []*types.FieldType{&extraHandleColInfo.FieldType}, nil +} + // PhysicalTable is used to provide some information for a physical table in TTL job type PhysicalTable struct { + // Schema is the database name of the table Schema model.CIStr *model.TableInfo + // Partition is the partition name + Partition model.CIStr // PartitionDef is the partition definition PartitionDef *model.PartitionDefinition // KeyColumns is the cluster index key columns for the table KeyColumns []*model.ColumnInfo + // KeyColumnTypes is the types of the key columns + KeyColumnTypes []*types.FieldType // TimeColum is the time column used for TTL TimeColumn *model.ColumnInfo } +// NewPhysicalTable create a new PhysicalTable +func NewPhysicalTable(schema model.CIStr, tbl *model.TableInfo, partition model.CIStr) (*PhysicalTable, error) { + if tbl.State != model.StatePublic { + return nil, errors.Errorf("table '%s.%s' is not a public table", schema, tbl.Name) + } + + ttlInfo := tbl.TTLInfo + if ttlInfo == nil { + return nil, errors.Errorf("table '%s.%s' is not a ttl table", schema, tbl.Name) + } + + timeColumn := tbl.FindPublicColumnByName(ttlInfo.ColumnName.L) + if timeColumn == nil { + return nil, errors.Errorf("time column '%s' is not public in ttl table '%s.%s'", ttlInfo.ColumnName, schema, tbl.Name) + } + + keyColumns, keyColumTypes, err := getTableKeyColumns(tbl) + if err != nil { + return nil, err + } + + var partitionDef *model.PartitionDefinition + if tbl.Partition == nil { + if partition.L != "" { + return nil, errors.Errorf("table '%s.%s' is not a partitioned table", schema, tbl.Name) + } + } else { + if partition.L == "" { + return nil, errors.Errorf("partition name is required, table '%s.%s' is a partitioned table", schema, tbl.Name) + } + + for i := range tbl.Partition.Definitions { + def := &tbl.Partition.Definitions[i] + if def.Name.L == partition.L { + partitionDef = def + } + } + + if partitionDef == nil { + return nil, errors.Errorf("partition '%s' is not found in ttl table '%s.%s'", partition.O, schema, tbl.Name) + } + } + + return &PhysicalTable{ + Schema: schema, + TableInfo: tbl, + Partition: partition, + PartitionDef: partitionDef, + KeyColumns: keyColumns, + KeyColumnTypes: keyColumTypes, + TimeColumn: timeColumn, + }, nil +} + // ValidateKey validates a key func (t *PhysicalTable) ValidateKey(key []types.Datum) error { if len(t.KeyColumns) != len(key) { @@ -39,3 +131,25 @@ func (t *PhysicalTable) ValidateKey(key []types.Datum) error { } return nil } + +// EvalExpireTime returns the expired time +func (t *PhysicalTable) EvalExpireTime(ctx context.Context, se Session, now time.Time) (expire time.Time, err error) { + tz := se.GetSessionVars().TimeZone + + expireExpr := t.TTLInfo.IntervalExprStr + unit := ast.TimeUnitType(t.TTLInfo.IntervalTimeUnit) + + var rows []chunk.Row + rows, err = se.ExecuteSQL( + ctx, + // FROM_UNIXTIME does not support negative value, so we use `FROM_UNIXTIME(0) + INTERVAL ` to present current time + fmt.Sprintf("SELECT FROM_UNIXTIME(0) + INTERVAL %d SECOND - INTERVAL %s %s", now.Unix(), expireExpr, unit.String()), + ) + + if err != nil { + return + } + + tm := rows[0].GetTime(0) + return tm.CoreTime().GoTime(tz) +} diff --git a/ttl/table_test.go b/ttl/table_test.go new file mode 100644 index 0000000000000..f77556c98dc09 --- /dev/null +++ b/ttl/table_test.go @@ -0,0 +1,213 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ttl_test + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/pingcap/tidb/parser/model" + "github.com/pingcap/tidb/testkit" + "github.com/pingcap/tidb/ttl" + "github.com/stretchr/testify/require" +) + +func TestNewTTLTable(t *testing.T) { + cases := []struct { + db string + tbl string + def string + timeCol string + keyCols []string + }{ + { + db: "test", + tbl: "t1", + def: "(a int)", + }, + { + db: "test", + tbl: "ttl1", + def: "(a int, t datetime) ttl = `t` + interval 2 hour", + timeCol: "t", + keyCols: []string{"_tidb_rowid"}, + }, + { + db: "test", + tbl: "ttl2", + def: "(id int primary key, t datetime) ttl = `t` + interval 3 hour", + timeCol: "t", + keyCols: []string{"id"}, + }, + { + db: "test", + tbl: "ttl3", + def: "(a int, b varchar(32), c binary(32), t datetime, primary key (a, b, c)) ttl = `t` + interval 1 month", + timeCol: "t", + keyCols: []string{"a", "b", "c"}, + }, + { + db: "test", + tbl: "ttl4", + def: "(id int primary key, t datetime) " + + "ttl = `t` + interval 1 day " + + "PARTITION BY RANGE (id) (" + + " PARTITION p0 VALUES LESS THAN (10)," + + " PARTITION p1 VALUES LESS THAN (100)," + + " PARTITION p2 VALUES LESS THAN (1000)," + + " PARTITION p3 VALUES LESS THAN MAXVALUE)", + timeCol: "t", + keyCols: []string{"id"}, + }, + { + db: "test", + tbl: "ttl5", + def: "(id int primary key nonclustered, t datetime) ttl = `t` + interval 3 hour", + timeCol: "t", + keyCols: []string{"_tidb_rowid"}, + }, + } + + store, do := testkit.CreateMockStoreAndDomain(t) + tk := testkit.NewTestKit(t, store) + + for _, c := range cases { + tk.MustExec("use " + c.db) + tk.MustExec("create table " + c.tbl + c.def) + } + + for _, c := range cases { + is := do.InfoSchema() + tbl, err := is.TableByName(model.NewCIStr(c.db), model.NewCIStr(c.tbl)) + require.NoError(t, err) + tblInfo := tbl.Meta() + var physicalTbls []*ttl.PhysicalTable + if tblInfo.Partition == nil { + ttlTbl, err := ttl.NewPhysicalTable(model.NewCIStr(c.db), tblInfo, model.NewCIStr("")) + if c.timeCol == "" { + require.Error(t, err) + continue + } + require.NoError(t, err) + physicalTbls = append(physicalTbls, ttlTbl) + } else { + for _, partition := range tblInfo.Partition.Definitions { + ttlTbl, err := ttl.NewPhysicalTable(model.NewCIStr(c.db), tblInfo, model.NewCIStr(partition.Name.O)) + if c.timeCol == "" { + require.Error(t, err) + continue + } + require.NoError(t, err) + physicalTbls = append(physicalTbls, ttlTbl) + } + if c.timeCol == "" { + continue + } + } + + for i, ttlTbl := range physicalTbls { + require.Equal(t, c.db, ttlTbl.Schema.O) + require.Same(t, tblInfo, ttlTbl.TableInfo) + timeColumn := tblInfo.FindPublicColumnByName(c.timeCol) + require.NotNil(t, timeColumn) + require.Same(t, timeColumn, ttlTbl.TimeColumn) + + if tblInfo.Partition == nil { + require.Equal(t, "", ttlTbl.Partition.L) + require.Nil(t, ttlTbl.PartitionDef) + } else { + def := tblInfo.Partition.Definitions[i] + require.Equal(t, def.Name.L, ttlTbl.Partition.L) + require.Equal(t, def, *(ttlTbl.PartitionDef)) + } + + require.Equal(t, len(c.keyCols), len(ttlTbl.KeyColumns)) + require.Equal(t, len(c.keyCols), len(ttlTbl.KeyColumnTypes)) + + for j, keyCol := range c.keyCols { + msg := fmt.Sprintf("%s, col: %s", c.tbl, keyCol) + var col *model.ColumnInfo + if keyCol == model.ExtraHandleName.L { + col = model.NewExtraHandleColInfo() + } else { + col = tblInfo.FindPublicColumnByName(keyCol) + } + colJ := ttlTbl.KeyColumns[j] + colFieldJ := ttlTbl.KeyColumnTypes[j] + + require.NotNil(t, col, msg) + require.Equal(t, col.ID, colJ.ID, msg) + require.Equal(t, col.Name.L, colJ.Name.L, msg) + require.Equal(t, col.FieldType, colJ.FieldType, msg) + require.Equal(t, col.FieldType, *colFieldJ, msg) + } + } + } +} + +func TestEvalTTLExpireTime(t *testing.T) { + store, do := testkit.CreateMockStoreAndDomain(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("create table test.t(a int, t datetime) ttl = `t` + interval 1 day") + tk.MustExec("create table test.t2(a int, t datetime) ttl = `t` + interval 3 month") + + tb, err := do.InfoSchema().TableByName(model.NewCIStr("test"), model.NewCIStr("t")) + require.NoError(t, err) + tblInfo := tb.Meta() + ttlTbl, err := ttl.NewPhysicalTable(model.NewCIStr("test"), tblInfo, model.NewCIStr("")) + require.NoError(t, err) + + tb2, err := do.InfoSchema().TableByName(model.NewCIStr("test"), model.NewCIStr("t2")) + require.NoError(t, err) + tblInfo2 := tb2.Meta() + ttlTbl2, err := ttl.NewPhysicalTable(model.NewCIStr("test"), tblInfo2, model.NewCIStr("")) + require.NoError(t, err) + + se := ttl.NewSession(tk.Session(), tk.Session(), nil) + + now := time.UnixMilli(0) + tz1, err := time.LoadLocation("Asia/Shanghai") + require.NoError(t, err) + tz2, err := time.LoadLocation("Europe/Berlin") + require.NoError(t, err) + + se.GetSessionVars().TimeZone = tz1 + tm, err := ttlTbl.EvalExpireTime(context.TODO(), se, now) + require.NoError(t, err) + require.Equal(t, now.Add(-time.Hour*24).Unix(), tm.Unix()) + require.Equal(t, "1969-12-31 08:00:00", tm.Format("2006-01-02 15:04:05")) + require.Equal(t, tz1.String(), tm.Location().String()) + + se.GetSessionVars().TimeZone = tz2 + tm, err = ttlTbl.EvalExpireTime(context.TODO(), se, now) + require.NoError(t, err) + require.Equal(t, now.Add(-time.Hour*24).Unix(), tm.Unix()) + require.Equal(t, "1969-12-31 01:00:00", tm.Format("2006-01-02 15:04:05")) + require.Equal(t, tz2.String(), tm.Location().String()) + + se.GetSessionVars().TimeZone = tz1 + tm, err = ttlTbl2.EvalExpireTime(context.TODO(), se, now) + require.NoError(t, err) + require.Equal(t, "1969-10-01 08:00:00", tm.Format("2006-01-02 15:04:05")) + require.Equal(t, tz1.String(), tm.Location().String()) + + se.GetSessionVars().TimeZone = tz2 + tm, err = ttlTbl2.EvalExpireTime(context.TODO(), se, now) + require.NoError(t, err) + require.Equal(t, "1969-10-01 01:00:00", tm.Format("2006-01-02 15:04:05")) + require.Equal(t, tz2.String(), tm.Location().String()) +}