Skip to content

Commit

Permalink
planner: simplify the binding interface (#49987)
Browse files Browse the repository at this point in the history
ref #48875
  • Loading branch information
qw4990 authored Jan 3, 2024
1 parent 2065be4 commit 965ad8a
Show file tree
Hide file tree
Showing 9 changed files with 46 additions and 71 deletions.
6 changes: 3 additions & 3 deletions pkg/bindinfo/binding_match.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,20 +50,20 @@ func getBindRecord(ctx sessionctx.Context, stmt ast.StmtNode) (*BindRecord, stri
if ctx.Value(SessionBindInfoKeyType) == nil {
return nil, "", nil
}
stmtNode, normalizedSQL, sqlDigest, err := normalizeStmt(stmt, ctx.GetSessionVars().CurrentDB)
stmtNode, _, sqlDigest, err := normalizeStmt(stmt, ctx.GetSessionVars().CurrentDB)
if err != nil || stmtNode == nil {
return nil, "", err
}
// the priority: session normal > session universal > global normal > global universal
sessionHandle := ctx.Value(SessionBindInfoKeyType).(SessionBindingHandle)
if bindRecord := sessionHandle.GetSessionBinding(sqlDigest, normalizedSQL, ""); bindRecord != nil && bindRecord.HasEnabledBinding() {
if bindRecord := sessionHandle.GetSessionBinding(sqlDigest); bindRecord != nil && bindRecord.HasEnabledBinding() {
return bindRecord, metrics.ScopeSession, nil
}
globalHandle := GetGlobalBindingHandle(ctx)
if globalHandle == nil {
return nil, "", nil
}
if bindRecord := globalHandle.GetGlobalBinding(sqlDigest, normalizedSQL, ""); bindRecord != nil && bindRecord.HasEnabledBinding() {
if bindRecord := globalHandle.GetGlobalBinding(sqlDigest); bindRecord != nil && bindRecord.HasEnabledBinding() {
return bindRecord, metrics.ScopeGlobal, nil
}
return nil, "", nil
Expand Down
2 changes: 1 addition & 1 deletion pkg/bindinfo/capture.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ func (h *globalBindingHandle) CaptureBaselines() {
}
dbName := utilparser.GetDefaultDB(stmt, bindableStmt.Schema)
normalizedSQL, digest := parser.NormalizeDigest(utilparser.RestoreWithDefaultDB(stmt, dbName, bindableStmt.Query))
if r := h.GetGlobalBinding(digest.String(), normalizedSQL, dbName); r != nil && r.HasAvailableBinding() {
if r := h.GetGlobalBinding(digest.String()); r != nil && r.HasAvailableBinding() {
continue
}
bindSQL := GenerateBindSQL(context.TODO(), stmt, bindableStmt.PlanHint, true, dbName)
Expand Down
8 changes: 4 additions & 4 deletions pkg/bindinfo/capture_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -328,8 +328,8 @@ func TestBindingSource(t *testing.T) {
// Test Source for SQL created sql
tk.MustExec("create global binding for select * from t where a > 10 using select * from t ignore index(idx_a) where a > 10")
bindHandle := dom.BindHandle()
sql, sqlDigest := internal.UtilNormalizeWithDefaultDB(t, "select * from t where a > ?")
bindData := bindHandle.GetGlobalBinding(sqlDigest, sql, "test")
_, sqlDigest := internal.UtilNormalizeWithDefaultDB(t, "select * from t where a > ?")
bindData := bindHandle.GetGlobalBinding(sqlDigest)
require.NotNil(t, bindData)
require.Equal(t, "select * from `test` . `t` where `a` > ?", bindData.OriginalSQL)
require.Len(t, bindData.Bindings, 1)
Expand All @@ -348,8 +348,8 @@ func TestBindingSource(t *testing.T) {
tk.MustExec("select * from t ignore index(idx_a) where a < 10")
tk.MustExec("admin capture bindings")
bindHandle.CaptureBaselines()
sql, sqlDigest = internal.UtilNormalizeWithDefaultDB(t, "select * from t where a < ?")
bindData = bindHandle.GetGlobalBinding(sqlDigest, sql, "test")
_, sqlDigest = internal.UtilNormalizeWithDefaultDB(t, "select * from t where a < ?")
bindData = bindHandle.GetGlobalBinding(sqlDigest)
require.NotNil(t, bindData)
require.Equal(t, "select * from `test` . `t` where `a` < ?", bindData.OriginalSQL)
require.Len(t, bindData.Bindings, 1)
Expand Down
24 changes: 5 additions & 19 deletions pkg/bindinfo/global_handle.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,7 @@ type GlobalBindingHandle interface {
// Methods for create, get, drop global sql bindings.

// GetGlobalBinding returns the BindRecord of the (normalizedSQL,db) if BindRecord exist.
GetGlobalBinding(sqlDigest, normalizedSQL, db string) *BindRecord

// GetGlobalBindingBySQLDigest returns the BindRecord of the sql digest.
GetGlobalBindingBySQLDigest(sqlDigest string) (*BindRecord, error)
GetGlobalBinding(sqlDigest string) *BindRecord

// GetAllGlobalBindings returns all bind records in cache.
GetAllGlobalBindings() (bindRecords []*BindRecord)
Expand Down Expand Up @@ -367,10 +364,7 @@ func (h *globalBindingHandle) DropGlobalBindingByDigest(sqlDigest string) (delet
if sqlDigest == "" {
return 0, errors.New("sql digest is empty")
}
oldRecord, err := h.GetGlobalBindingBySQLDigest(sqlDigest)
if err != nil {
return 0, err
}
oldRecord := h.GetGlobalBinding(sqlDigest)
if oldRecord == nil {
return 0, errors.Errorf("can't find any binding for '%s'", sqlDigest)
}
Expand Down Expand Up @@ -405,7 +399,7 @@ func (h *globalBindingHandle) SetGlobalBindingStatus(originalSQL string, binding
ok = true
record := &BindRecord{OriginalSQL: originalSQL}
sqlDigest := parser.DigestNormalized(record.OriginalSQL)
oldRecord := h.GetGlobalBinding(sqlDigest.String(), originalSQL, "")
oldRecord := h.GetGlobalBinding(sqlDigest.String())
setBindingStatusInCacheSucc := false
if oldRecord != nil && len(oldRecord.Bindings) > 0 {
record.Bindings = make([]Binding, len(oldRecord.Bindings))
Expand Down Expand Up @@ -448,10 +442,7 @@ func (h *globalBindingHandle) SetGlobalBindingStatus(originalSQL string, binding

// SetGlobalBindingStatusByDigest set a BindRecord's status to the storage and bind cache.
func (h *globalBindingHandle) SetGlobalBindingStatusByDigest(newStatus, sqlDigest string) (ok bool, err error) {
oldRecord, err := h.GetGlobalBindingBySQLDigest(sqlDigest)
if err != nil {
return false, err
}
oldRecord := h.GetGlobalBinding(sqlDigest)
if oldRecord == nil {
return false, errors.Errorf("can't find any binding for '%s'", sqlDigest)
}
Expand Down Expand Up @@ -563,15 +554,10 @@ func (h *globalBindingHandle) Size() int {
}

// GetGlobalBinding returns the BindRecord of the (normalizedSQL,db) if BindRecord exist.
func (h *globalBindingHandle) GetGlobalBinding(sqlDigest, _, _ string) *BindRecord {
func (h *globalBindingHandle) GetGlobalBinding(sqlDigest string) *BindRecord {
return h.getCache().GetBinding(sqlDigest)
}

// GetGlobalBindingBySQLDigest returns the BindRecord of the sql digest.
func (h *globalBindingHandle) GetGlobalBindingBySQLDigest(sqlDigest string) (*BindRecord, error) {
return h.getCache().GetBinding(sqlDigest), nil
}

// GetAllGlobalBindings returns all bind records in cache.
func (h *globalBindingHandle) GetAllGlobalBindings() (bindRecords []*BindRecord) {
return h.getCache().GetAllBindings()
Expand Down
18 changes: 9 additions & 9 deletions pkg/bindinfo/global_handle_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ func TestBindingLastUpdateTime(t *testing.T) {
bindHandle := bindinfo.NewGlobalBindingHandle(&mockSessionPool{tk.Session()})
err := bindHandle.Update(true)
require.NoError(t, err)
sql, sqlDigest := parser.NormalizeDigest("select * from test . t0")
bindData := bindHandle.GetGlobalBinding(sqlDigest.String(), sql, "test")
_, sqlDigest := parser.NormalizeDigest("select * from test . t0")
bindData := bindHandle.GetGlobalBinding(sqlDigest.String())
require.Equal(t, 1, len(bindData.Bindings))
bind := bindData.Bindings[0]
updateTime := bind.UpdateTime.String()
Expand Down Expand Up @@ -133,8 +133,8 @@ func TestBindParse(t *testing.T) {
require.NoError(t, err)
require.Equal(t, 1, bindHandle.Size())

sql, sqlDigest := parser.NormalizeDigest("select * from test . t")
bindData := bindHandle.GetGlobalBinding(sqlDigest.String(), sql, "test")
_, sqlDigest := parser.NormalizeDigest("select * from test . t")
bindData := bindHandle.GetGlobalBinding(sqlDigest.String())
require.NotNil(t, bindData)
require.Equal(t, "select * from `test` . `t`", bindData.OriginalSQL)
bind := bindData.Bindings[0]
Expand Down Expand Up @@ -447,9 +447,9 @@ func TestGlobalBinding(t *testing.T) {
require.NoError(t, err)
require.Equal(t, testSQL.memoryUsage, pb.GetGauge().GetValue())

sql, sqlDigest := internal.UtilNormalizeWithDefaultDB(t, testSQL.querySQL)
_, sqlDigest := internal.UtilNormalizeWithDefaultDB(t, testSQL.querySQL)

bindData := dom.BindHandle().GetGlobalBinding(sqlDigest, sql, "test")
bindData := dom.BindHandle().GetGlobalBinding(sqlDigest)
require.NotNil(t, bindData)
require.Equal(t, testSQL.originSQL, bindData.OriginalSQL)
bind := bindData.Bindings[0]
Expand Down Expand Up @@ -482,7 +482,7 @@ func TestGlobalBinding(t *testing.T) {
require.NoError(t, err)
require.Equal(t, 1, bindHandle.Size())

bindData = bindHandle.GetGlobalBinding(sqlDigest, sql, "test")
bindData = bindHandle.GetGlobalBinding(sqlDigest)
require.NotNil(t, bindData)
require.Equal(t, testSQL.originSQL, bindData.OriginalSQL)
bind = bindData.Bindings[0]
Expand All @@ -497,7 +497,7 @@ func TestGlobalBinding(t *testing.T) {
_, err = tk.Exec("drop global " + testSQL.dropSQL)
require.Equal(t, uint64(1), tk.Session().AffectedRows())
require.NoError(t, err)
bindData = dom.BindHandle().GetGlobalBinding(sqlDigest, sql, "test")
bindData = dom.BindHandle().GetGlobalBinding(sqlDigest)
require.Nil(t, bindData)

err = metrics.BindTotalGauge.WithLabelValues(metrics.ScopeGlobal, bindinfo.Enabled).Write(pb)
Expand All @@ -513,7 +513,7 @@ func TestGlobalBinding(t *testing.T) {
require.NoError(t, err)
require.Equal(t, 0, bindHandle.Size())

bindData = bindHandle.GetGlobalBinding(sqlDigest, sql, "test")
bindData = bindHandle.GetGlobalBinding(sqlDigest)
require.Nil(t, bindData)

rs, err = tk.Exec("show global bindings")
Expand Down
21 changes: 5 additions & 16 deletions pkg/bindinfo/session_handle.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,7 @@ type SessionBindingHandle interface {
DropSessionBindingByDigest(sqlDigest string) error

// GetSessionBinding return the binding which can match the digest.
GetSessionBinding(sqlDigest, normdOrigSQL, db string) *BindRecord

// GetSessionBindingBySQLDigest return all bindings which can match the digest.
GetSessionBindingBySQLDigest(sqlDigest string) (*BindRecord, error)
GetSessionBinding(sqlDigest string) *BindRecord

// GetAllSessionBindings return all bindings.
GetAllSessionBindings() (bindRecords []*BindRecord)
Expand Down Expand Up @@ -104,7 +101,7 @@ func (h *sessionBindingHandle) CreateSessionBinding(sctx sessionctx.Context, rec
func (h *sessionBindingHandle) DropSessionBinding(originalSQL, db string, binding *Binding) error {
db = strings.ToLower(db)
sqlDigest := parser.DigestNormalized(originalSQL).String()
oldRecord := h.GetSessionBinding(sqlDigest, originalSQL, db)
oldRecord := h.GetSessionBinding(sqlDigest)
var newRecord *BindRecord
record := &BindRecord{OriginalSQL: originalSQL, Db: db}
if binding != nil {
Expand All @@ -129,26 +126,18 @@ func (h *sessionBindingHandle) DropSessionBindingByDigest(sqlDigest string) erro
if sqlDigest == "" {
return errors.New("sql digest is empty")
}
oldRecord, err := h.GetSessionBindingBySQLDigest(sqlDigest)
if err != nil {
return err
}
oldRecord := h.GetSessionBinding(sqlDigest)
if oldRecord == nil {
return errors.Errorf("can't find any binding for '%s'", sqlDigest)
}
return h.DropSessionBinding(oldRecord.OriginalSQL, strings.ToLower(oldRecord.Db), nil)
}

// GetSessionBinding return the BindMeta of the (normdOrigSQL,db) if BindMeta exist.
func (h *sessionBindingHandle) GetSessionBinding(sqlDigest, _, _ string) *BindRecord {
// GetSessionBinding return all BindMeta corresponding to sqlDigest.
func (h *sessionBindingHandle) GetSessionBinding(sqlDigest string) *BindRecord {
return h.ch.GetBinding(sqlDigest)
}

// GetSessionBindingBySQLDigest return all BindMeta corresponding to sqlDigest.
func (h *sessionBindingHandle) GetSessionBindingBySQLDigest(sqlDigest string) (*BindRecord, error) {
return h.ch.GetBinding(sqlDigest), nil
}

// GetAllSessionBindings return all session bind info.
func (h *sessionBindingHandle) GetAllSessionBindings() (bindRecords []*BindRecord) {
return h.ch.GetAllBindings()
Expand Down
4 changes: 2 additions & 2 deletions pkg/bindinfo/session_handle_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ func TestSessionBinding(t *testing.T) {

handle := tk.Session().Value(bindinfo.SessionBindInfoKeyType).(bindinfo.SessionBindingHandle)
sqlDigest := parser.DigestNormalized(testSQL.originSQL).String()
bindData := handle.GetSessionBinding(sqlDigest, testSQL.originSQL, "test")
bindData := handle.GetSessionBinding(sqlDigest)
require.NotNil(t, bindData)
require.Equal(t, testSQL.originSQL, bindData.OriginalSQL)
bind := bindData.Bindings[0]
Expand Down Expand Up @@ -148,7 +148,7 @@ func TestSessionBinding(t *testing.T) {

_, err = tk.Exec("drop session " + testSQL.dropSQL)
require.NoError(t, err)
bindData = handle.GetSessionBinding(sqlDigest, testSQL.originSQL, "test")
bindData = handle.GetSessionBinding(sqlDigest)
require.NotNil(t, bindData)
require.Equal(t, testSQL.originSQL, bindData.OriginalSQL)
require.Len(t, bindData.Bindings, 0)
Expand Down
30 changes: 15 additions & 15 deletions pkg/bindinfo/tests/bind_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -317,9 +317,9 @@ func TestBindingSymbolList(t *testing.T) {
require.True(t, tk.MustUseIndex("select a, b from t where a = 3 limit 1, 100", "ib(b)"))

// Normalize
sql, hash := parser.NormalizeDigestForBinding("select a, b from test . t where a = 1 limit 0, 1")
_, hash := parser.NormalizeDigestForBinding("select a, b from test . t where a = 1 limit 0, 1")

bindData := dom.BindHandle().GetGlobalBinding(hash.String(), sql, "test")
bindData := dom.BindHandle().GetGlobalBinding(hash.String())
require.NotNil(t, bindData)
require.Equal(t, "select `a` , `b` from `test` . `t` where `a` = ? limit ...", bindData.OriginalSQL)
bind := bindData.Bindings[0]
Expand Down Expand Up @@ -361,9 +361,9 @@ func TestBindingInListWithSingleLiteral(t *testing.T) {
tk.MustQuery("select @@last_plan_from_binding").Check(testkit.Rows("1"))

// Normalize
sql, hash := parser.NormalizeDigestForBinding("select a, b from test . t where a in (1)")
_, hash := parser.NormalizeDigestForBinding("select a, b from test . t where a in (1)")

bindData := dom.BindHandle().GetGlobalBinding(hash.String(), sql, "test")
bindData := dom.BindHandle().GetGlobalBinding(hash.String())
require.NotNil(t, bindData)
require.Equal(t, "select `a` , `b` from `test` . `t` where `a` in ( ... )", bindData.OriginalSQL)
bind := bindData.Bindings[0]
Expand Down Expand Up @@ -397,8 +397,8 @@ func TestBestPlanInBaselines(t *testing.T) {
tk.MustExec(`create global binding for select a, b from t where a = 1 limit 0, 1 using select /*+ use_index(@sel_1 test.t ia) */ a, b from t where a = 1 limit 0, 1`)
tk.MustExec(`create global binding for select a, b from t where b = 1 limit 0, 1 using select /*+ use_index(@sel_1 test.t ib) */ a, b from t where b = 1 limit 0, 1`)

sql, hash := internal.UtilNormalizeWithDefaultDB(t, "select a, b from t where a = 1 limit 0, 1")
bindData := dom.BindHandle().GetGlobalBinding(hash, sql, "test")
_, hash := internal.UtilNormalizeWithDefaultDB(t, "select a, b from t where a = 1 limit 0, 1")
bindData := dom.BindHandle().GetGlobalBinding(hash)
require.NotNil(t, bindData)
require.Equal(t, "select `a` , `b` from `test` . `t` where `a` = ? limit ...", bindData.OriginalSQL)
bind := bindData.Bindings[0]
Expand Down Expand Up @@ -430,8 +430,8 @@ func TestErrorBind(t *testing.T) {
_, err := tk.Exec("create global binding for select * from t where i>100 using select * from t use index(index_t) where i>100")
require.NoError(t, err, "err %v", err)

sql, hash := parser.NormalizeDigestForBinding("select * from test . t where i > ?")
bindData := dom.BindHandle().GetGlobalBinding(hash.String(), sql, "test")
_, hash := parser.NormalizeDigestForBinding("select * from test . t where i > ?")
bindData := dom.BindHandle().GetGlobalBinding(hash.String())
require.NotNil(t, bindData)
require.Equal(t, "select * from `test` . `t` where `i` > ?", bindData.OriginalSQL)
bind := bindData.Bindings[0]
Expand Down Expand Up @@ -487,8 +487,8 @@ func TestHintsSetID(t *testing.T) {
tk.MustExec("create global binding for select * from t where a > 10 using select /*+ use_index(test.t, idx_a) */ * from t where a > 10")
bindHandle := dom.BindHandle()
// Verify the added Binding contains ID with restored query block.
sql, hash := internal.UtilNormalizeWithDefaultDB(t, "select * from t where a > ?")
bindData := bindHandle.GetGlobalBinding(hash, sql, "test")
_, hash := internal.UtilNormalizeWithDefaultDB(t, "select * from t where a > ?")
bindData := bindHandle.GetGlobalBinding(hash)
require.NotNil(t, bindData)
require.Equal(t, "select * from `test` . `t` where `a` > ?", bindData.OriginalSQL)
require.Len(t, bindData.Bindings, 1)
Expand All @@ -497,7 +497,7 @@ func TestHintsSetID(t *testing.T) {

internal.UtilCleanBindingEnv(tk, dom)
tk.MustExec("create global binding for select * from t where a > 10 using select /*+ use_index(t, idx_a) */ * from t where a > 10")
bindData = bindHandle.GetGlobalBinding(hash, sql, "test")
bindData = bindHandle.GetGlobalBinding(hash)
require.NotNil(t, bindData)
require.Equal(t, "select * from `test` . `t` where `a` > ?", bindData.OriginalSQL)
require.Len(t, bindData.Bindings, 1)
Expand All @@ -506,7 +506,7 @@ func TestHintsSetID(t *testing.T) {

internal.UtilCleanBindingEnv(tk, dom)
tk.MustExec("create global binding for select * from t where a > 10 using select /*+ use_index(@sel_1 t, idx_a) */ * from t where a > 10")
bindData = bindHandle.GetGlobalBinding(hash, sql, "test")
bindData = bindHandle.GetGlobalBinding(hash)
require.NotNil(t, bindData)
require.Equal(t, "select * from `test` . `t` where `a` > ?", bindData.OriginalSQL)
require.Len(t, bindData.Bindings, 1)
Expand All @@ -515,7 +515,7 @@ func TestHintsSetID(t *testing.T) {

internal.UtilCleanBindingEnv(tk, dom)
tk.MustExec("create global binding for select * from t where a > 10 using select /*+ use_index(@qb1 t, idx_a) qb_name(qb1) */ * from t where a > 10")
bindData = bindHandle.GetGlobalBinding(hash, sql, "test")
bindData = bindHandle.GetGlobalBinding(hash)
require.NotNil(t, bindData)
require.Equal(t, "select * from `test` . `t` where `a` > ?", bindData.OriginalSQL)
require.Len(t, bindData.Bindings, 1)
Expand All @@ -524,7 +524,7 @@ func TestHintsSetID(t *testing.T) {

internal.UtilCleanBindingEnv(tk, dom)
tk.MustExec("create global binding for select * from t where a > 10 using select /*+ use_index(T, IDX_A) */ * from t where a > 10")
bindData = bindHandle.GetGlobalBinding(hash, sql, "test")
bindData = bindHandle.GetGlobalBinding(hash)
require.NotNil(t, bindData)
require.Equal(t, "select * from `test` . `t` where `a` > ?", bindData.OriginalSQL)
require.Len(t, bindData.Bindings, 1)
Expand All @@ -535,7 +535,7 @@ func TestHintsSetID(t *testing.T) {
err := tk.ExecToErr("create global binding for select * from t using select /*+ non_exist_hint() */ * from t")
require.True(t, terror.ErrorEqual(err, parser.ErrParse))
tk.MustExec("create global binding for select * from t where a > 10 using select * from t where a > 10")
bindData = bindHandle.GetGlobalBinding(hash, sql, "test")
bindData = bindHandle.GetGlobalBinding(hash)
require.NotNil(t, bindData)
require.Equal(t, "select * from `test` . `t` where `a` > ?", bindData.OriginalSQL)
require.Len(t, bindData.Bindings, 1)
Expand Down
4 changes: 2 additions & 2 deletions pkg/planner/core/plan_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -801,7 +801,7 @@ func GetBindSQL4PlanCache(sctx sessionctx.Context, stmt *PlanCacheStmt) (string,
return "", ignore
}
sessionHandle := sctx.Value(bindinfo.SessionBindInfoKeyType).(bindinfo.SessionBindingHandle)
bindRecord := sessionHandle.GetSessionBinding(stmt.SQLDigest4PC, stmt.NormalizedSQL4PC, "")
bindRecord := sessionHandle.GetSessionBinding(stmt.SQLDigest4PC)
if bindRecord != nil {
enabledBinding := bindRecord.FindEnabledBinding()
if enabledBinding != nil {
Expand All @@ -813,7 +813,7 @@ func GetBindSQL4PlanCache(sctx sessionctx.Context, stmt *PlanCacheStmt) (string,
if globalHandle == nil {
return "", ignore
}
bindRecord = globalHandle.GetGlobalBinding(stmt.SQLDigest4PC, stmt.NormalizedSQL4PC, "")
bindRecord = globalHandle.GetGlobalBinding(stmt.SQLDigest4PC)
if bindRecord != nil {
enabledBinding := bindRecord.FindEnabledBinding()
if enabledBinding != nil {
Expand Down

0 comments on commit 965ad8a

Please sign in to comment.