diff --git a/executor/executor.go b/executor/executor.go index 209262ac0d4db..eee07f8774ed0 100644 --- a/executor/executor.go +++ b/executor/executor.go @@ -983,7 +983,12 @@ func newLockCtx(seVars *variable.SessionVars, lockWaitTime int64) *tikvstore.Loc } if mutation := req.Mutations[0]; mutation != nil { label := resourcegrouptag.GetResourceGroupLabelByKey(mutation.Key) - return seVars.StmtCtx.GetResourceGroupTagByLabel(label) + normalized, digest := seVars.StmtCtx.SQLDigest() + if len(normalized) == 0 { + return nil + } + _, planDigest := seVars.StmtCtx.GetPlanDigest() + return resourcegrouptag.EncodeResourceGroupTag(digest, planDigest, label) } return nil } diff --git a/sessionctx/stmtctx/stmtctx.go b/sessionctx/stmtctx/stmtctx.go index 5d896000d5def..37fbe907f2b10 100644 --- a/sessionctx/stmtctx/stmtctx.go +++ b/sessionctx/stmtctx/stmtctx.go @@ -30,7 +30,6 @@ import ( "github.com/pingcap/tidb/util/memory" "github.com/pingcap/tidb/util/resourcegrouptag" "github.com/pingcap/tidb/util/tracing" - "github.com/pingcap/tipb/go-tipb" "github.com/tikv/client-go/v2/tikvrpc" "github.com/tikv/client-go/v2/util" atomic2 "go.uber.org/atomic" @@ -289,27 +288,20 @@ func (sc *StatementContext) GetPlanDigest() (normalized string, planDigest *pars // GetResourceGroupTagger returns the implementation of tikvrpc.ResourceGroupTagger related to self. func (sc *StatementContext) GetResourceGroupTagger() tikvrpc.ResourceGroupTagger { + normalized, digest := sc.SQLDigest() + planDigest := sc.planDigest return func(req *tikvrpc.Request) { if req == nil { return } - req.ResourceGroupTag = sc.GetResourceGroupTagByLabel( + if len(normalized) == 0 { + return + } + req.ResourceGroupTag = resourcegrouptag.EncodeResourceGroupTag(digest, planDigest, resourcegrouptag.GetResourceGroupLabelByKey(resourcegrouptag.GetFirstKeyFromRequest(req))) } } -// GetResourceGroupTagByLabel gets the resource group of the statement based on the label. -func (sc *StatementContext) GetResourceGroupTagByLabel(label tipb.ResourceGroupTagLabel) []byte { - if sc == nil { - return nil - } - normalized, sqlDigest := sc.SQLDigest() - if len(normalized) == 0 { - return nil - } - return resourcegrouptag.EncodeResourceGroupTag(sqlDigest, sc.planDigest, label) -} - // SetPlanDigest sets the normalized plan and plan digest. func (sc *StatementContext) SetPlanDigest(normalized string, planDigest *parser.Digest) { if planDigest != nil { diff --git a/sessionctx/stmtctx/stmtctx_test.go b/sessionctx/stmtctx/stmtctx_test.go index acfc9c00866d1..f5bf2cca866be 100644 --- a/sessionctx/stmtctx/stmtctx_test.go +++ b/sessionctx/stmtctx/stmtctx_test.go @@ -21,7 +21,6 @@ import ( "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/util/execdetails" - "github.com/pingcap/tipb/go-tipb" "github.com/stretchr/testify/require" "github.com/tikv/client-go/v2/util" ) @@ -91,25 +90,3 @@ func TestStatementContextPushDownFLags(t *testing.T) { require.Equal(t, tt.out, got) } } - -func TestGetResourceGroupTagByLabel(t *testing.T) { - ctx := stmtctx.StatementContext{OriginalSQL: "SELECT * FROM t"} - tagRow := ctx.GetResourceGroupTagByLabel(tipb.ResourceGroupTagLabel_ResourceGroupTagLabelRow) - tagIndex := ctx.GetResourceGroupTagByLabel(tipb.ResourceGroupTagLabel_ResourceGroupTagLabelIndex) - tagUnknown := ctx.GetResourceGroupTagByLabel(tipb.ResourceGroupTagLabel_ResourceGroupTagLabelUnknown) - tagRow2 := ctx.GetResourceGroupTagByLabel(tipb.ResourceGroupTagLabel_ResourceGroupTagLabelRow) - tagIndex2 := ctx.GetResourceGroupTagByLabel(tipb.ResourceGroupTagLabel_ResourceGroupTagLabelIndex) - tagUnknown2 := ctx.GetResourceGroupTagByLabel(tipb.ResourceGroupTagLabel_ResourceGroupTagLabelUnknown) - require.NotEmpty(t, tagRow) - require.NotEmpty(t, tagIndex) - require.NotEmpty(t, tagUnknown) - require.NotEmpty(t, tagRow2) - require.NotEmpty(t, tagIndex2) - require.NotEmpty(t, tagUnknown2) - require.Equal(t, &tagRow, &tagRow2) // mem addr - require.Equal(t, &tagIndex, &tagIndex2) - require.Equal(t, &tagUnknown, &tagUnknown2) - require.NotEqual(t, &tagRow, &tagIndex) - require.NotEqual(t, &tagRow, &tagUnknown) - require.NotEqual(t, &tagIndex, &tagUnknown) -}