Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
lcwangchao committed Nov 29, 2022
1 parent da74f02 commit 6a27b83
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 34 deletions.
64 changes: 43 additions & 21 deletions ttl/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,37 +18,59 @@ import (
"context"

"github.com/pingcap/errors"
"github.com/pingcap/tidb/infoschema"
"github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/parser/terror"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/sessiontxn"
"github.com/pingcap/tidb/util/chunk"
"github.com/pingcap/tidb/util/sqlexec"
)

// Session is used to execute queries for TTL case
type Session struct {
Sctx sessionctx.Context
SQLExec sqlexec.SQLExecutor
CloseFn func()
type Session interface {
sessionctx.Context
// SessionInfoSchema returns information schema of current session
SessionInfoSchema() infoschema.InfoSchema
// ExecuteSQL executes the sql
ExecuteSQL(ctx context.Context, sql string, args ...interface{}) ([]chunk.Row, error)
// RunInTxn executes the specified function in a txn
RunInTxn(ctx context.Context, fn func() error) (err error)
// Close closes the session
Close()
}

// GetSessionVars returns the sessionVars
func (s *Session) GetSessionVars() *variable.SessionVars {
if s.Sctx != nil {
return s.Sctx.GetSessionVars()
type session struct {
sessionctx.Context
sqlExec sqlexec.SQLExecutor
closeFn func()
}

// NewSession creates a new Session
func NewSession(sctx sessionctx.Context, sqlExec sqlexec.SQLExecutor, closeFn func()) Session {
return &session{
Context: sctx,
sqlExec: sqlExec,
closeFn: closeFn,
}
}

// SessionInfoSchema returns information schema of current session
func (s *session) SessionInfoSchema() infoschema.InfoSchema {
if s.Context == nil {
return nil
}
return nil
return sessiontxn.GetTxnManager(s.Context).GetTxnInfoSchema()
}

// ExecuteSQL executes the sql
func (s *Session) ExecuteSQL(ctx context.Context, sql string, args ...interface{}) ([]chunk.Row, error) {
if s.SQLExec == nil {
func (s *session) ExecuteSQL(ctx context.Context, sql string, args ...interface{}) ([]chunk.Row, error) {
if s.sqlExec == nil {
return nil, errors.New("session is closed")
}

ctx = kv.WithInternalSourceType(ctx, kv.InternalTxnTTL)
rs, err := s.SQLExec.ExecuteInternal(ctx, sql, args...)
rs, err := s.sqlExec.ExecuteInternal(ctx, sql, args...)
if err != nil {
return nil, err
}
Expand All @@ -65,7 +87,7 @@ func (s *Session) ExecuteSQL(ctx context.Context, sql string, args ...interface{
}

// RunInTxn executes the specified function in a txn
func (s *Session) RunInTxn(ctx context.Context, fn func() error) (err error) {
func (s *session) RunInTxn(ctx context.Context, fn func() error) (err error) {
if _, err = s.ExecuteSQL(ctx, "BEGIN"); err != nil {
return err
}
Expand All @@ -90,12 +112,12 @@ func (s *Session) RunInTxn(ctx context.Context, fn func() error) (err error) {
return err
}

// Close closed the session
func (s *Session) Close() {
if s.CloseFn != nil {
s.CloseFn()
s.Sctx = nil
s.SQLExec = nil
s.CloseFn = nil
// Close closes the session
func (s *session) Close() {
if s.closeFn != nil {
s.closeFn()
s.Context = nil
s.sqlExec = nil
s.closeFn = nil
}
}
5 changes: 1 addition & 4 deletions ttl/session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,7 @@ func TestSessionRunInTxn(t *testing.T) {
tk := testkit.NewTestKit(t, store)
tk.MustExec("use test")
tk.MustExec("create table t(id int primary key, v int)")
se := &Session{
Sctx: tk.Session(),
SQLExec: tk.Session(),
}
se := NewSession(tk.Session(), tk.Session(), nil)
tk2 := testkit.NewTestKit(t, store)
tk2.MustExec("use test")

Expand Down
2 changes: 1 addition & 1 deletion ttl/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ func (t *PhysicalTable) ValidateKey(key []types.Datum) error {
}

// EvalExpireTime returns the expired time
func (t *PhysicalTable) EvalExpireTime(ctx context.Context, se *Session, now time.Time) (expire time.Time, err error) {
func (t *PhysicalTable) EvalExpireTime(ctx context.Context, se Session, now time.Time) (expire time.Time, err error) {
tz := se.GetSessionVars().TimeZone

expireExpr := t.TTLInfo.IntervalExprStr
Expand Down
13 changes: 5 additions & 8 deletions ttl/table_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,38 +177,35 @@ func TestEvalTTLExpireTime(t *testing.T) {
ttlTbl2, err := ttl.NewPhysicalTable(model.NewCIStr("test"), tblInfo2, model.NewCIStr(""))
require.NoError(t, err)

se := &ttl.Session{
Sctx: tk.Session(),
SQLExec: tk.Session(),
}
se := ttl.NewSession(tk.Session(), tk.Session(), nil)

now := time.UnixMilli(0)
tz1, err := time.LoadLocation("Asia/Shanghai")
require.NoError(t, err)
tz2, err := time.LoadLocation("Europe/Berlin")
require.NoError(t, err)

se.Sctx.GetSessionVars().TimeZone = tz1
se.GetSessionVars().TimeZone = tz1
tm, err := ttlTbl.EvalExpireTime(context.TODO(), se, now)
require.NoError(t, err)
require.Equal(t, now.Add(-time.Hour*24).Unix(), tm.Unix())
require.Equal(t, "1969-12-31 08:00:00", tm.Format("2006-01-02 15:04:05"))
require.Equal(t, tz1.String(), tm.Location().String())

se.Sctx.GetSessionVars().TimeZone = tz2
se.GetSessionVars().TimeZone = tz2
tm, err = ttlTbl.EvalExpireTime(context.TODO(), se, now)
require.NoError(t, err)
require.Equal(t, now.Add(-time.Hour*24).Unix(), tm.Unix())
require.Equal(t, "1969-12-31 01:00:00", tm.Format("2006-01-02 15:04:05"))
require.Equal(t, tz2.String(), tm.Location().String())

se.Sctx.GetSessionVars().TimeZone = tz1
se.GetSessionVars().TimeZone = tz1
tm, err = ttlTbl2.EvalExpireTime(context.TODO(), se, now)
require.NoError(t, err)
require.Equal(t, "1969-10-01 08:00:00", tm.Format("2006-01-02 15:04:05"))
require.Equal(t, tz1.String(), tm.Location().String())

se.Sctx.GetSessionVars().TimeZone = tz2
se.GetSessionVars().TimeZone = tz2
tm, err = ttlTbl2.EvalExpireTime(context.TODO(), se, now)
require.NoError(t, err)
require.Equal(t, "1969-10-01 01:00:00", tm.Format("2006-01-02 15:04:05"))
Expand Down

0 comments on commit 6a27b83

Please sign in to comment.