Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

domain: Optimize GetDomain api #58550

Merged
merged 7 commits into from
Dec 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 4 additions & 19 deletions pkg/domain/domainctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,26 +18,11 @@ import (
contextutil "github.com/pingcap/tidb/pkg/util/context"
)

// domainKeyType is a dummy type to avoid naming collision in context.
type domainKeyType int

// String defines a Stringer function for debugging and pretty printing.
func (domainKeyType) String() string {
return "domain"
}

const domainKey domainKeyType = 0

// BindDomain binds domain to context.
func BindDomain(ctx contextutil.ValueStoreContext, domain *Domain) {
ctx.SetValue(domainKey, domain)
}

// GetDomain gets domain from context.
func GetDomain(ctx contextutil.ValueStoreContext) *Domain {
v, ok := ctx.Value(domainKey).(*Domain)
if !ok {
return nil
v, ok := ctx.GetDomain().(*Domain)
if ok {
return v
}
return v
return nil
}
8 changes: 3 additions & 5 deletions pkg/domain/domainctx_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,11 @@ import (

func TestDomainCtx(t *testing.T) {
ctx := mock.NewContext()
require.NotEqual(t, "", domainKey.String())

BindDomain(ctx, nil)
ctx.BindDomain(nil)
v := GetDomain(ctx)
require.Nil(t, v)

ctx.ClearValue(domainKey)
ctx.BindDomain(&Domain{})
v = GetDomain(ctx)
require.Nil(t, v)
require.NotNil(t, v)
}
2 changes: 1 addition & 1 deletion pkg/executor/executor_pkg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ func TestFilterTemporaryTableKeys(t *testing.T) {

func TestErrLevelsForResetStmtContext(t *testing.T) {
ctx := mock.NewContext()
domain.BindDomain(ctx, &domain.Domain{})
ctx.BindDomain(&domain.Domain{})

cases := []struct {
name string
Expand Down
2 changes: 1 addition & 1 deletion pkg/executor/executor_required_rows_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ func defaultCtx() sessionctx.Context {
ctx.GetSessionVars().StmtCtx.MemTracker = memory.NewTracker(-1, ctx.GetSessionVars().MemQuotaQuery)
ctx.GetSessionVars().StmtCtx.DiskTracker = disk.NewTracker(-1, -1)
ctx.GetSessionVars().SnapshotTS = uint64(1)
domain.BindDomain(ctx, domain.NewMockDomain())
ctx.BindDomain(domain.NewMockDomain())
return ctx
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/executor/join/joiner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func defaultCtx() sessionctx.Context {
ctx.GetSessionVars().StmtCtx.MemTracker = memory.NewTracker(-1, ctx.GetSessionVars().MemQuotaQuery)
ctx.GetSessionVars().StmtCtx.DiskTracker = disk.NewTracker(-1, -1)
ctx.GetSessionVars().SnapshotTS = uint64(1)
domain.BindDomain(ctx, domain.NewMockDomain())
ctx.BindDomain(domain.NewMockDomain())
return ctx
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/executor/select_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import (
func BenchmarkResetContextOfStmt(b *testing.B) {
stmt := &ast.SelectStmt{}
ctx := mock.NewContext()
domain.BindDomain(ctx, &domain.Domain{})
ctx.BindDomain(&domain.Domain{})
for i := 0; i < b.N; i++ {
executor.ResetContextOfStmt(ctx, stmt)
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/planner/core/logical_plans_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ func createPlannerSuite() (s *plannerSuite) {
if err := do.CreateStatsHandle(ctx, initStatsCtx); err != nil {
panic(fmt.Sprintf("create mock context panic: %+v", err))
}
domain.BindDomain(ctx, do)
ctx.BindDomain(do)
ctx.SetInfoSchema(s.is)
s.ctx = ctx
s.sctx = ctx
Expand Down
2 changes: 1 addition & 1 deletion pkg/planner/core/mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,7 @@ func MockContext() *mock.Context {
if err := do.CreateStatsHandle(ctx, initStatsCtx); err != nil {
panic(fmt.Sprintf("create mock context panic: %+v", err))
}
domain.BindDomain(ctx, do)
ctx.BindDomain(do)
return ctx
}

Expand Down
4 changes: 2 additions & 2 deletions pkg/session/bootstrap_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ func TestBootstrapWithError(t *testing.T) {
dom, err := domap.Get(store)
require.NoError(t, err)
require.NoError(t, dom.Start(ddl.Bootstrap))
domain.BindDomain(se, dom)
se.dom = dom
b, err := checkBootstrapped(se)
require.False(t, b)
require.NoError(t, err)
Expand Down Expand Up @@ -2409,7 +2409,7 @@ func TestTiDBUpgradeToVer211(t *testing.T) {
require.NoError(t, err)
require.Less(t, int64(ver210), ver)

domain.BindDomain(seV210, dom)
seV210.(*session).dom = dom
r := MustExecToRecodeSet(t, seV210, "select count(summary) from mysql.tidb_background_subtask_history;")
req := r.NewChunk(nil)
err = r.Next(ctx, req)
Expand Down
11 changes: 9 additions & 2 deletions pkg/session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,8 @@ type session struct {
currentCtx context.Context // only use for runtime.trace, Please NEVER use it.
currentPlan base.Plan

// dom is *domain.Domain, use `any` to avoid import cycle.
dom any
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe leave a TODO to replace the usage of any in the future, it will reduce code maintainability.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I add a comment here.

store kv.Storage

sessionPlanCache sessionctx.SessionPlanCache
Expand Down Expand Up @@ -3789,6 +3791,7 @@ func createSessionWithOpt(store kv.Storage, opt *Opt) (*session, error) {
return nil, err
}
s := &session{
dom: dom,
store: store,
ddlOwnerManager: dom.DDL().OwnerManager(),
client: store.GetClient(),
Expand All @@ -3808,7 +3811,6 @@ func createSessionWithOpt(store kv.Storage, opt *Opt) (*session, error) {
s.lockedTables = make(map[int64]model.TableLockTpInfo)
s.advisoryLocks = make(map[string]*advisoryLock)

domain.BindDomain(s, dom)
// session implements variable.GlobalVarAccessor. Bind it to ctx.
s.sessionVars.GlobalVarsAccessor = s
s.txn.init()
Expand Down Expand Up @@ -3852,6 +3854,7 @@ func detachStatsCollector(s *session) *session {
// a lock context, which cause we can't call createSession directly.
func CreateSessionWithDomain(store kv.Storage, dom *domain.Domain) (*session, error) {
s := &session{
dom: dom,
store: store,
sessionVars: variable.NewSessionVars(nil),
client: store.GetClient(),
Expand All @@ -3864,7 +3867,6 @@ func CreateSessionWithDomain(store kv.Storage, dom *domain.Domain) (*session, er
s.tblctx = tblsession.NewMutateContext(s)
s.mu.values = make(map[fmt.Stringer]any)
s.lockedTables = make(map[int64]model.TableLockTpInfo)
domain.BindDomain(s, dom)
// session implements variable.GlobalVarAccessor. Bind it to ctx.
s.sessionVars.GlobalVarsAccessor = s
s.txn.init()
Expand Down Expand Up @@ -4674,3 +4676,8 @@ func (s *session) GetCursorTracker() cursor.Tracker {
func (s *session) GetCommitWaitGroup() *sync.WaitGroup {
return &s.commitWaitGroup
}

// GetDomain get domain from session.
func (s *session) GetDomain() any {
return s.dom
}
3 changes: 3 additions & 0 deletions pkg/util/context/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ type ValueStoreContext interface {

// ClearValue clears the value associated with this context for key.
ClearValue(key fmt.Stringer)

// GetDomain returns the domain.
GetDomain() any
}

var contextIDGenerator atomic.Uint64
Expand Down
13 changes: 12 additions & 1 deletion pkg/util/mock/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ var (
type Context struct {
planctx.EmptyPlanContextExtended
*sessionexpr.ExprContext
txn wrapTxn // mock global variable
txn wrapTxn // mock global variable
dom any
Store kv.Storage // mock global variable
ctx context.Context
sm util.SessionManager
Expand Down Expand Up @@ -639,6 +640,16 @@ func (*Context) GetCommitWaitGroup() *sync.WaitGroup {
return nil
}

// BindDomain bind domain into ctx.
func (c *Context) BindDomain(dom any) {
c.dom = dom
}

// GetDomain get domain from ctx.
func (c *Context) GetDomain() any {
return c.dom
}

// NewContextDeprecated creates a new mocked sessionctx.Context.
// Deprecated: This method is only used for some legacy code.
// DO NOT use mock.Context in new production code, and use the real Context instead.
Expand Down