diff --git a/distsql/request_builder.go b/distsql/request_builder.go index 44378b2a262a2..80bfae29e172f 100644 --- a/distsql/request_builder.go +++ b/distsql/request_builder.go @@ -235,7 +235,7 @@ func (builder *RequestBuilder) getKVPriority(sv *variable.SessionVars) int { } // SetFromSessionVars sets the following fields for "kv.Request" from session variables: -// "Concurrency", "IsolationLevel", "NotFillCache", "ReplicaRead", "SchemaVar". +// "Concurrency", "IsolationLevel", "NotFillCache", "TaskID", "Priority", "ReplicaRead", "ResourceGroupTagger". func (builder *RequestBuilder) SetFromSessionVars(sv *variable.SessionVars) *RequestBuilder { if builder.Request.Concurrency == 0 { // Concurrency may be set to 1 by SetDAGRequest @@ -246,7 +246,7 @@ func (builder *RequestBuilder) SetFromSessionVars(sv *variable.SessionVars) *Req builder.Request.TaskID = sv.StmtCtx.TaskID builder.Request.Priority = builder.getKVPriority(sv) builder.Request.ReplicaRead = sv.GetReplicaRead() - builder.SetResourceGroupTag(sv.StmtCtx) + builder.SetResourceGroupTagger(sv.StmtCtx) return builder } @@ -282,10 +282,10 @@ func (builder *RequestBuilder) SetFromInfoSchema(pis interface{}) *RequestBuilde return builder } -// SetResourceGroupTag sets the request resource group tag. -func (builder *RequestBuilder) SetResourceGroupTag(sc *stmtctx.StatementContext) *RequestBuilder { +// SetResourceGroupTagger sets the request resource group tagger. +func (builder *RequestBuilder) SetResourceGroupTagger(sc *stmtctx.StatementContext) *RequestBuilder { if variable.TopSQLEnabled() { - builder.Request.ResourceGroupTag = sc.GetResourceGroupTag() + builder.Request.ResourceGroupTagger = sc.GetResourceGroupTagger() } return builder } diff --git a/executor/analyze.go b/executor/analyze.go index ad471763db032..9414b2fbade1e 100644 --- a/executor/analyze.go +++ b/executor/analyze.go @@ -389,7 +389,7 @@ func (e *AnalyzeIndexExec) fetchAnalyzeResult(ranges []*ranger.Range, isNullRang } else { kvReqBuilder = builder.SetIndexRangesForTables(e.ctx.GetSessionVars().StmtCtx, []int64{e.tableID.GetStatisticsID()}, e.idxInfo.ID, ranges) } - kvReqBuilder.SetResourceGroupTag(e.ctx.GetSessionVars().StmtCtx) + kvReqBuilder.SetResourceGroupTagger(e.ctx.GetSessionVars().StmtCtx) kvReq, err := kvReqBuilder. SetAnalyzeRequest(e.analyzePB). SetStartTS(e.snapshot). @@ -750,7 +750,7 @@ func (e *AnalyzeColumnsExec) open(ranges []*ranger.Range) error { func (e *AnalyzeColumnsExec) buildResp(ranges []*ranger.Range) (distsql.SelectResult, error) { var builder distsql.RequestBuilder reqBuilder := builder.SetHandleRangesForTables(e.ctx.GetSessionVars().StmtCtx, []int64{e.TableID.GetStatisticsID()}, e.handleCols != nil && !e.handleCols.IsInt(), ranges, nil) - builder.SetResourceGroupTag(e.ctx.GetSessionVars().StmtCtx) + builder.SetResourceGroupTagger(e.ctx.GetSessionVars().StmtCtx) // Always set KeepOrder of the request to be true, in order to compute // correct `correlation` of columns. kvReq, err := reqBuilder. 
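For reference, the callback that `SetResourceGroupTagger` now stores on `kv.Request` follows the `tikvrpc.ResourceGroupTagger` contract used throughout the hunks above: it receives the outgoing RPC request and fills its `ResourceGroupTag` in place. The sketch below is illustrative only — the function name and the fixed tag are hypothetical — while the real tagger is produced by `StatementContext.GetResourceGroupTagger` later in this patch.

```go
package example

import "github.com/tikv/client-go/v2/tikvrpc"

// constTagger returns a tikvrpc.ResourceGroupTagger that stamps every
// outgoing request with a fixed tag. It is a minimal sketch of the callback
// shape only; the production tagger in this patch derives the tag from the
// request's first key (row vs. index) instead of using a constant.
func constTagger(tag []byte) tikvrpc.ResourceGroupTagger {
	return func(req *tikvrpc.Request) {
		if req == nil {
			return
		}
		// tikvrpc.Request embeds kvrpcpb.Context, so the tag is written
		// directly into the request context that TiKV sees.
		req.ResourceGroupTag = tag
	}
}
```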
@@ -1853,7 +1853,7 @@ func (e *AnalyzeFastExec) handleScanTasks(bo *tikv.Backoffer) (keysSize int, err if e.ctx.GetSessionVars().GetReplicaRead().IsFollowerRead() { snapshot.SetOption(kv.ReplicaRead, kv.ReplicaReadFollower) } - setResourceGroupTagForTxn(e.ctx.GetSessionVars().StmtCtx, snapshot) + setResourceGroupTaggerForTxn(e.ctx.GetSessionVars().StmtCtx, snapshot) for _, t := range e.scanTasks { iter, err := snapshot.Iter(kv.Key(t.StartKey), kv.Key(t.EndKey)) if err != nil { @@ -1874,7 +1874,7 @@ func (e *AnalyzeFastExec) handleSampTasks(workID int, step uint32, err *error) { snapshot.SetOption(kv.NotFillCache, true) snapshot.SetOption(kv.IsolationLevel, kv.SI) snapshot.SetOption(kv.Priority, kv.PriorityLow) - setResourceGroupTagForTxn(e.ctx.GetSessionVars().StmtCtx, snapshot) + setResourceGroupTaggerForTxn(e.ctx.GetSessionVars().StmtCtx, snapshot) readReplicaType := e.ctx.GetSessionVars().GetReplicaRead() if readReplicaType.IsFollowerRead() { snapshot.SetOption(kv.ReplicaRead, readReplicaType) diff --git a/executor/batch_point_get.go b/executor/batch_point_get.go index 93c037b6775b6..c30bf507d6d9d 100644 --- a/executor/batch_point_get.go +++ b/executor/batch_point_get.go @@ -149,7 +149,7 @@ func (e *BatchPointGetExec) Open(context.Context) error { }, }) } - setResourceGroupTagForTxn(stmtCtx, snapshot) + setResourceGroupTaggerForTxn(stmtCtx, snapshot) var batchGetter kv.BatchGetter = snapshot if txn.Valid() { lock := e.tblInfo.Lock diff --git a/executor/checksum.go b/executor/checksum.go index 611510cec864d..69fd6ed319e75 100644 --- a/executor/checksum.go +++ b/executor/checksum.go @@ -241,7 +241,7 @@ func (c *checksumContext) buildTableRequest(ctx sessionctx.Context, tableID int6 } var builder distsql.RequestBuilder - builder.SetResourceGroupTag(ctx.GetSessionVars().StmtCtx) + builder.SetResourceGroupTagger(ctx.GetSessionVars().StmtCtx) return builder.SetHandleRanges(ctx.GetSessionVars().StmtCtx, tableID, c.TableInfo.IsCommonHandle, ranges, nil). SetChecksumRequest(checksum). SetStartTS(c.StartTs). @@ -258,7 +258,7 @@ func (c *checksumContext) buildIndexRequest(ctx sessionctx.Context, tableID int6 ranges := ranger.FullRange() var builder distsql.RequestBuilder - builder.SetResourceGroupTag(ctx.GetSessionVars().StmtCtx) + builder.SetResourceGroupTagger(ctx.GetSessionVars().StmtCtx) return builder.SetIndexRanges(ctx.GetSessionVars().StmtCtx, tableID, indexInfo.ID, ranges). SetChecksumRequest(checksum). SetStartTS(c.StartTs). 
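The analyze and checksum hunks above all follow the same pattern: attach the statement's tagger to the `RequestBuilder` before building the `kv.Request`. A minimal sketch of that pattern, assuming a hypothetical helper with made-up parameters (`sctx`, `tableID`, `startTS`, `ranges`) and the standard builder API:

```go
package example

import (
	"github.com/pingcap/tidb/distsql"
	"github.com/pingcap/tidb/kv"
	"github.com/pingcap/tidb/sessionctx"
	"github.com/pingcap/tidb/util/ranger"
)

// buildTaggedTableRequest is a hypothetical helper, not part of this patch.
// It shows where SetResourceGroupTagger sits relative to the other builder
// calls for executors (analyze, checksum) that do not go through
// SetFromSessionVars.
func buildTaggedTableRequest(sctx sessionctx.Context, tableID int64, startTS uint64, ranges []*ranger.Range) (*kv.Request, error) {
	sc := sctx.GetSessionVars().StmtCtx
	var builder distsql.RequestBuilder
	// The tagger is only installed when Top SQL is enabled; otherwise this
	// call leaves the request untouched.
	builder.SetResourceGroupTagger(sc)
	return builder.
		SetHandleRanges(sc, tableID, false, ranges, nil).
		SetStartTS(startTS).
		Build()
}
```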
diff --git a/executor/executor.go b/executor/executor.go index 7383941acd046..9ade86612b96f 100644 --- a/executor/executor.go +++ b/executor/executor.go @@ -31,6 +31,7 @@ import ( "github.com/opentracing/opentracing-go" "github.com/pingcap/errors" "github.com/pingcap/failpoint" + "github.com/pingcap/kvproto/pkg/kvrpcpb" "github.com/pingcap/tidb/config" "github.com/pingcap/tidb/domain" "github.com/pingcap/tidb/domain/infosync" @@ -39,7 +40,6 @@ import ( "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/meta" "github.com/pingcap/tidb/meta/autoid" - "github.com/pingcap/tidb/parser" "github.com/pingcap/tidb/parser/ast" "github.com/pingcap/tidb/parser/auth" "github.com/pingcap/tidb/parser/model" @@ -968,18 +968,25 @@ func (e *SelectLockExec) Next(ctx context.Context, req *chunk.Chunk) error { } func newLockCtx(seVars *variable.SessionVars, lockWaitTime int64) *tikvstore.LockCtx { - var planDigest *parser.Digest - _, sqlDigest := seVars.StmtCtx.SQLDigest() - if variable.TopSQLEnabled() { - _, planDigest = seVars.StmtCtx.GetPlanDigest() - } lockCtx := tikvstore.NewLockCtx(seVars.TxnCtx.GetForUpdateTS(), lockWaitTime, seVars.StmtCtx.GetLockWaitStartTime()) lockCtx.Killed = &seVars.Killed lockCtx.PessimisticLockWaited = &seVars.StmtCtx.PessimisticLockWaited lockCtx.LockKeysDuration = &seVars.StmtCtx.LockKeysDuration lockCtx.LockKeysCount = &seVars.StmtCtx.LockKeysCount lockCtx.LockExpired = &seVars.TxnCtx.LockExpire - lockCtx.ResourceGroupTag = resourcegrouptag.EncodeResourceGroupTag(sqlDigest, planDigest) + lockCtx.ResourceGroupTagger = func(req *kvrpcpb.PessimisticLockRequest) []byte { + if req == nil { + return nil + } + if len(req.Mutations) == 0 { + return nil + } + if mutation := req.Mutations[0]; mutation != nil { + label := resourcegrouptag.GetResourceGroupLabelByKey(mutation.Key) + return seVars.StmtCtx.GetResourceGroupTagByLabel(label) + } + return nil + } lockCtx.OnDeadlock = func(deadlock *tikverr.ErrDeadlock) { cfg := config.GetGlobalConfig() if deadlock.IsRetryable && !cfg.PessimisticTxn.DeadlockHistoryCollectRetryable { @@ -1896,8 +1903,8 @@ func FillVirtualColumnValue(virtualRetTypes []*types.FieldType, virtualColumnInd return nil } -func setResourceGroupTagForTxn(sc *stmtctx.StatementContext, snapshot kv.Snapshot) { +func setResourceGroupTaggerForTxn(sc *stmtctx.StatementContext, snapshot kv.Snapshot) { if snapshot != nil && variable.TopSQLEnabled() { - snapshot.SetOption(kv.ResourceGroupTag, sc.GetResourceGroupTag()) + snapshot.SetOption(kv.ResourceGroupTagger, sc.GetResourceGroupTagger()) } } diff --git a/executor/executor_test.go b/executor/executor_test.go index b3e9d613c9c53..ae2fb1916778b 100644 --- a/executor/executor_test.go +++ b/executor/executor_test.go @@ -8692,6 +8692,7 @@ func (s *testResourceTagSuite) TestResourceGroupTag(c *C) { defer failpoint.Disable("github.com/pingcap/tidb/store/mockstore/unistore/unistoreRPCClientSendHook") var sqlDigest, planDigest *parser.Digest + var tagLabel tipb.ResourceGroupTagLabel checkFn := func() {} unistore.UnistoreRPCClientSendHook = func(req *tikvrpc.Request) { var startKey []byte @@ -8734,6 +8735,7 @@ func (s *testResourceTagSuite) TestResourceGroupTag(c *C) { c.Assert(err, IsNil) sqlDigest = parser.NewDigest(tag.SqlDigest) planDigest = parser.NewDigest(tag.PlanDigest) + tagLabel = *tag.Label checkFn() } @@ -8743,19 +8745,78 @@ func (s *testResourceTagSuite) TestResourceGroupTag(c *C) { } cases := []struct { - sql string - ignore bool + sql string + tagLabels []tipb.ResourceGroupTagLabel + ignore bool }{ - {sql: "insert into t 
values(1,1),(2,2),(3,3)"}, - {sql: "select * from t use index (idx) where a=1"}, - {sql: "select * from t use index (idx) where a in (1,2,3)"}, - {sql: "select * from t use index (idx) where a>1"}, - {sql: "select * from t where b>1"}, - {sql: "begin pessimistic", ignore: true}, - {sql: "insert into t values(4,4)"}, - {sql: "commit", ignore: true}, - {sql: "update t set a=5,b=5 where a=5"}, - {sql: "replace into t values(6,6)"}, + { + sql: "insert into t values(1,1),(2,2),(3,3)", + tagLabels: []tipb.ResourceGroupTagLabel{ + tipb.ResourceGroupTagLabel_ResourceGroupTagLabelIndex, + tipb.ResourceGroupTagLabel_ResourceGroupTagLabelIndex, + }, + }, + { + sql: "select * from t use index (idx) where a=1", + tagLabels: []tipb.ResourceGroupTagLabel{ + tipb.ResourceGroupTagLabel_ResourceGroupTagLabelIndex, + tipb.ResourceGroupTagLabel_ResourceGroupTagLabelRow, + }, + }, + { + sql: "select * from t use index (idx) where a in (1,2,3)", + tagLabels: []tipb.ResourceGroupTagLabel{ + tipb.ResourceGroupTagLabel_ResourceGroupTagLabelIndex, + tipb.ResourceGroupTagLabel_ResourceGroupTagLabelRow, + }, + }, + { + sql: "select * from t use index (idx) where a>1", + tagLabels: []tipb.ResourceGroupTagLabel{ + tipb.ResourceGroupTagLabel_ResourceGroupTagLabelIndex, + tipb.ResourceGroupTagLabel_ResourceGroupTagLabelRow, + }, + }, + { + sql: "select * from t where b>1", + tagLabels: []tipb.ResourceGroupTagLabel{ + tipb.ResourceGroupTagLabel_ResourceGroupTagLabelRow, + }, + }, + { + sql: "select a from t use index (idx) where a>1", + tagLabels: []tipb.ResourceGroupTagLabel{ + tipb.ResourceGroupTagLabel_ResourceGroupTagLabelIndex, + }, + }, + { + sql: "begin pessimistic", + ignore: true, + }, + { + sql: "insert into t values(4,4)", + tagLabels: []tipb.ResourceGroupTagLabel{ + tipb.ResourceGroupTagLabel_ResourceGroupTagLabelIndex, + tipb.ResourceGroupTagLabel_ResourceGroupTagLabelRow, + }, + }, + { + sql: "commit", + ignore: true, + }, + { + sql: "update t set a=5,b=5 where a=5", + tagLabels: []tipb.ResourceGroupTagLabel{ + tipb.ResourceGroupTagLabel_ResourceGroupTagLabelIndex, + }, + }, + { + sql: "replace into t values(6,6)", + tagLabels: []tipb.ResourceGroupTagLabel{ + tipb.ResourceGroupTagLabel_ResourceGroupTagLabelIndex, + tipb.ResourceGroupTagLabel_ResourceGroupTagLabelIndex, + }, + }, } for _, ca := range cases { resetVars() @@ -8777,6 +8838,10 @@ func (s *testResourceTagSuite) TestResourceGroupTag(c *C) { } c.Assert(sqlDigest.String(), Equals, expectSQLDigest.String(), commentf) c.Assert(planDigest.String(), Equals, expectPlanDigest.String()) + if len(ca.tagLabels) > 0 { + c.Assert(tagLabel, Equals, ca.tagLabels[0]) + ca.tagLabels = ca.tagLabels[1:] // next label + } checkCnt++ } diff --git a/executor/insert.go b/executor/insert.go index 714964c4715c8..2862416ddf04a 100644 --- a/executor/insert.go +++ b/executor/insert.go @@ -66,7 +66,7 @@ func (e *InsertExec) exec(ctx context.Context, rows [][]types.Datum) error { if err != nil { return err } - setResourceGroupTagForTxn(sessVars.StmtCtx, txn) + setResourceGroupTaggerForTxn(sessVars.StmtCtx, txn) txnSize := txn.Size() sessVars.StmtCtx.AddRecordRows(uint64(len(rows))) // If you use the IGNORE keyword, duplicate-key error that occurs while executing the INSERT statement are ignored. 
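The new `tagLabels` expectations above can be checked by decoding the tag that TiKV receives, which is what the updated test does via `tag.Label`. An illustrative decoder, assuming only the `tipb.ResourceGroupTag` message already used by this patch (the helper name is made up):

```go
package example

import "github.com/pingcap/tipb/go-tipb"

// labelOf is an illustrative helper, not part of this patch. It unmarshals
// an encoded resource group tag and returns its label, falling back to
// Unknown when no label was encoded.
func labelOf(data []byte) (tipb.ResourceGroupTagLabel, error) {
	tag := &tipb.ResourceGroupTag{}
	if err := tag.Unmarshal(data); err != nil {
		return tipb.ResourceGroupTagLabel_ResourceGroupTagLabelUnknown, err
	}
	if tag.Label == nil {
		return tipb.ResourceGroupTagLabel_ResourceGroupTagLabelUnknown, nil
	}
	return *tag.Label, nil
}
```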
diff --git a/executor/point_get.go b/executor/point_get.go index 489bbf9bb8085..45f3fa76e263f 100644 --- a/executor/point_get.go +++ b/executor/point_get.go @@ -190,7 +190,7 @@ func (e *PointGetExecutor) Open(context.Context) error { panic("point get replica option fail") } }) - setResourceGroupTagForTxn(e.ctx.GetSessionVars().StmtCtx, e.snapshot) + setResourceGroupTaggerForTxn(e.ctx.GetSessionVars().StmtCtx, e.snapshot) return nil } diff --git a/executor/replace.go b/executor/replace.go index a9b29707f9ade..78e0085aa520e 100644 --- a/executor/replace.go +++ b/executor/replace.go @@ -222,7 +222,7 @@ func (e *ReplaceExec) exec(ctx context.Context, newRows [][]types.Datum) error { defer snapshot.SetOption(kv.CollectRuntimeStats, nil) } } - setResourceGroupTagForTxn(e.ctx.GetSessionVars().StmtCtx, txn) + setResourceGroupTaggerForTxn(e.ctx.GetSessionVars().StmtCtx, txn) prefetchStart := time.Now() // Use BatchGet to fill cache. // It's an optimization and could be removed without affecting correctness. diff --git a/executor/update.go b/executor/update.go index 03a9979878ab1..7df144b28196c 100644 --- a/executor/update.go +++ b/executor/update.go @@ -274,7 +274,7 @@ func (e *UpdateExec) updateRows(ctx context.Context) (int, error) { if variable.TopSQLEnabled() { txn, err := e.ctx.Txn(true) if err == nil { - txn.SetOption(kv.ResourceGroupTag, e.ctx.GetSessionVars().StmtCtx.GetResourceGroupTag()) + txn.SetOption(kv.ResourceGroupTagger, e.ctx.GetSessionVars().StmtCtx.GetResourceGroupTagger()) } } for rowIdx := 0; rowIdx < chk.NumRows(); rowIdx++ { diff --git a/go.mod b/go.mod index fe719777fb6cf..4f83f251d4d54 100644 --- a/go.mod +++ b/go.mod @@ -65,7 +65,7 @@ require ( github.com/spf13/pflag v1.0.5 github.com/stretchr/testify v1.7.0 github.com/tiancaiamao/appdash v0.0.0-20181126055449-889f96f722a2 - github.com/tikv/client-go/v2 v2.0.0-alpha.0.20211115071040-a3f1c41ac1a0 + github.com/tikv/client-go/v2 v2.0.0-alpha.0.20211118154139-b11da6307c6f github.com/tikv/pd v1.1.0-beta.0.20211104095303-69c86d05d379 github.com/twmb/murmur3 v1.1.3 github.com/uber/jaeger-client-go v2.22.1+incompatible diff --git a/go.sum b/go.sum index 36ec754040cd5..4e8fbc01390bd 100644 --- a/go.sum +++ b/go.sum @@ -710,8 +710,8 @@ github.com/tiancaiamao/appdash v0.0.0-20181126055449-889f96f722a2/go.mod h1:2PfK github.com/tidwall/gjson v1.3.5/go.mod h1:P256ACg0Mn+j1RXIDXoss50DeIABTYK1PULOJHhxOls= github.com/tidwall/match v1.0.1/go.mod h1:LujAq0jyVjBy028G1WhWfIzbpQfMO8bBZ6Tyb0+pL9E= github.com/tidwall/pretty v1.0.0/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk= -github.com/tikv/client-go/v2 v2.0.0-alpha.0.20211115071040-a3f1c41ac1a0 h1:c12Pv8Xks4oubDr/uHHxrlBkwGJFqKZUEIUemHV794g= -github.com/tikv/client-go/v2 v2.0.0-alpha.0.20211115071040-a3f1c41ac1a0/go.mod h1:iiwtsCxcbNLK5i9VRYGvdcihgHXTKy2ukWjoaJsrphg= +github.com/tikv/client-go/v2 v2.0.0-alpha.0.20211118154139-b11da6307c6f h1:UyJjp3wGIjf1edGiQiIdAtL5QFqaqR4+s3LDwUZU7NY= +github.com/tikv/client-go/v2 v2.0.0-alpha.0.20211118154139-b11da6307c6f/go.mod h1:BEAS0vXm5BorlF/HTndqGwcGDvaiwe7B7BkfgwwZMJ4= github.com/tikv/pd v1.1.0-beta.0.20211029083450-e65f0c55b6ae/go.mod h1:varH0IE0jJ9E9WN2Ei/N6pajMlPkcXdDEf7f5mmsUVQ= github.com/tikv/pd v1.1.0-beta.0.20211104095303-69c86d05d379 h1:nFm1jQDz1iRktoyV2SyM5zVk6+PJHQNunJZ7ZJcqzAo= github.com/tikv/pd v1.1.0-beta.0.20211104095303-69c86d05d379/go.mod h1:y+09hAUXJbrd4c0nktL74zXDDuD7atGtfOKxL90PCOE= diff --git a/infoschema/cluster_tables_serial_test.go b/infoschema/cluster_tables_serial_test.go index b39aa73f2efd6..a9cfd496c4b8d 100644 
--- a/infoschema/cluster_tables_serial_test.go +++ b/infoschema/cluster_tables_serial_test.go @@ -43,6 +43,7 @@ import ( "github.com/pingcap/tidb/util/resourcegrouptag" "github.com/pingcap/tidb/util/set" "github.com/pingcap/tidb/util/testutil" + "github.com/pingcap/tipb/go-tipb" "github.com/stretchr/testify/require" "google.golang.org/grpc" ) @@ -160,10 +161,10 @@ func SubTestTestDataLockWaits(s *clusterTablesSuite) func(*testing.T) { _, digest1 := parser.NormalizeDigest("select * from test_data_lock_waits for update") _, digest2 := parser.NormalizeDigest("update test_data_lock_waits set f1=1 where id=2") s.store.(mockstorage.MockLockWaitSetter).SetMockLockWaits([]*deadlock.WaitForEntry{ - {Txn: 1, WaitForTxn: 2, Key: []byte("key1"), ResourceGroupTag: resourcegrouptag.EncodeResourceGroupTag(digest1, nil)}, - {Txn: 3, WaitForTxn: 4, Key: []byte("key2"), ResourceGroupTag: resourcegrouptag.EncodeResourceGroupTag(digest2, nil)}, + {Txn: 1, WaitForTxn: 2, Key: []byte("key1"), ResourceGroupTag: resourcegrouptag.EncodeResourceGroupTag(digest1, nil, tipb.ResourceGroupTagLabel_ResourceGroupTagLabelUnknown)}, + {Txn: 3, WaitForTxn: 4, Key: []byte("key2"), ResourceGroupTag: resourcegrouptag.EncodeResourceGroupTag(digest2, nil, tipb.ResourceGroupTagLabel_ResourceGroupTagLabelUnknown)}, // Invalid digests - {Txn: 5, WaitForTxn: 6, Key: []byte("key3"), ResourceGroupTag: resourcegrouptag.EncodeResourceGroupTag(nil, nil)}, + {Txn: 5, WaitForTxn: 6, Key: []byte("key3"), ResourceGroupTag: resourcegrouptag.EncodeResourceGroupTag(nil, nil, tipb.ResourceGroupTagLabel_ResourceGroupTagLabelUnknown)}, {Txn: 7, WaitForTxn: 8, Key: []byte("key4"), ResourceGroupTag: []byte("asdfghjkl")}, }) diff --git a/kv/kv.go b/kv/kv.go index a6ce83fd8caf0..22610d17b8d9a 100644 --- a/kv/kv.go +++ b/kv/kv.go @@ -30,6 +30,7 @@ import ( tikvstore "github.com/tikv/client-go/v2/kv" "github.com/tikv/client-go/v2/oracle" "github.com/tikv/client-go/v2/tikv" + "github.com/tikv/client-go/v2/tikvrpc" ) // UnCommitIndexKVFlag uses to indicate the index key/value is no need to commit. @@ -335,8 +336,8 @@ type Request struct { IsStaleness bool // MatchStoreLabels indicates the labels the store should be matched MatchStoreLabels []*metapb.StoreLabel - // ResourceGroupTag indicates the kv request task group. - ResourceGroupTag []byte + // ResourceGroupTagger indicates the kv request task group tagger. + ResourceGroupTagger tikvrpc.ResourceGroupTagger } const ( diff --git a/kv/option.go b/kv/option.go index 682c2be4f2d60..683f5fcb8e389 100644 --- a/kv/option.go +++ b/kv/option.go @@ -62,10 +62,13 @@ const ( IsStalenessReadOnly // MatchStoreLabels indicates the labels the store should be matched MatchStoreLabels - // ResourceGroupTag indicates the resource group of the kv request. + // ResourceGroupTag indicates the resource group tag of the kv request. ResourceGroupTag + // ResourceGroupTagger can be used to set the ResourceGroupTag dynamically according to the request content. It will be used only when ResourceGroupTag is nil. + ResourceGroupTagger // KVFilter indicates the filter to ignore key-values in the transaction's memory buffer. 
KVFilter + // SnapInterceptor is used for setting the interceptor for snapshot SnapInterceptor ) diff --git a/session/session.go b/session/session.go index e53c3fe5b4e38..2228431007048 100644 --- a/session/session.go +++ b/session/session.go @@ -544,7 +544,7 @@ func (s *session) doCommit(ctx context.Context) error { } s.txn.SetOption(kv.EnableAsyncCommit, sessVars.EnableAsyncCommit) s.txn.SetOption(kv.Enable1PC, sessVars.Enable1PC) - s.txn.SetOption(kv.ResourceGroupTag, sessVars.StmtCtx.GetResourceGroupTag()) + s.txn.SetOption(kv.ResourceGroupTagger, sessVars.StmtCtx.GetResourceGroupTagger()) // priority of the sysvar is lower than `start transaction with causal consistency only` if val := s.txn.GetOption(kv.GuaranteeLinearizability); val == nil || val.(bool) { // We needn't ask the TiKV client to guarantee linearizability for auto-commit transactions diff --git a/sessionctx/stmtctx/stmtctx.go b/sessionctx/stmtctx/stmtctx.go index 03c488c883a0c..07cbd89d18e52 100644 --- a/sessionctx/stmtctx/stmtctx.go +++ b/sessionctx/stmtctx/stmtctx.go @@ -30,6 +30,8 @@ import ( "github.com/pingcap/tidb/util/memory" "github.com/pingcap/tidb/util/resourcegrouptag" "github.com/pingcap/tidb/util/tracing" + "github.com/pingcap/tipb/go-tipb" + "github.com/tikv/client-go/v2/tikvrpc" "github.com/tikv/client-go/v2/util" atomic2 "go.uber.org/atomic" "go.uber.org/zap" @@ -171,8 +173,12 @@ type StatementContext struct { // stmtCache is used to store some statement-related values. stmtCache map[StmtCacheKey]interface{} - // resourceGroupTag cache for the current statement resource group tag. - resourceGroupTag atomic.Value + // resourceGroupTagWithRow cache for the current statement resource group tag (with `Row` label). + resourceGroupTagWithRow atomic.Value + // resourceGroupTagWithIndex cache for the current statement resource group tag (with `Index` label). + resourceGroupTagWithIndex atomic.Value + // resourceGroupTagWithUnknown cache for the current statement resource group tag (with `Unknown` label). + resourceGroupTagWithUnknown atomic.Value // Map to store all CTE storages of current SQL. // Will clean up at the end of the execution. CTEStorageMap interface{} @@ -280,19 +286,47 @@ func (sc *StatementContext) GetPlanDigest() (normalized string, planDigest *pars return sc.planNormalized, sc.planDigest } -// GetResourceGroupTag gets the resource group of the statement. -func (sc *StatementContext) GetResourceGroupTag() []byte { - tag, _ := sc.resourceGroupTag.Load().([]byte) - if len(tag) > 0 { - return tag +// GetResourceGroupTagger returns the implementation of tikvrpc.ResourceGroupTagger related to self. +func (sc *StatementContext) GetResourceGroupTagger() tikvrpc.ResourceGroupTagger { + return func(req *tikvrpc.Request) { + req.ResourceGroupTag = sc.GetResourceGroupTagByLabel( + resourcegrouptag.GetResourceGroupLabelByKey(resourcegrouptag.GetFirstKeyFromRequest(req))) + } +} + +// GetResourceGroupTagByLabel gets the resource group of the statement based on the label. 
+func (sc *StatementContext) GetResourceGroupTagByLabel(label tipb.ResourceGroupTagLabel) []byte { + switch label { + case tipb.ResourceGroupTagLabel_ResourceGroupTagLabelRow: + v := sc.resourceGroupTagWithRow.Load() + if v != nil { + return v.([]byte) + } + case tipb.ResourceGroupTagLabel_ResourceGroupTagLabelIndex: + v := sc.resourceGroupTagWithIndex.Load() + if v != nil { + return v.([]byte) + } + case tipb.ResourceGroupTagLabel_ResourceGroupTagLabelUnknown: + v := sc.resourceGroupTagWithUnknown.Load() + if v != nil { + return v.([]byte) + } } normalized, sqlDigest := sc.SQLDigest() if len(normalized) == 0 { return nil } - tag = resourcegrouptag.EncodeResourceGroupTag(sqlDigest, sc.planDigest) - sc.resourceGroupTag.Store(tag) - return tag + newTag := resourcegrouptag.EncodeResourceGroupTag(sqlDigest, sc.planDigest, label) + switch label { + case tipb.ResourceGroupTagLabel_ResourceGroupTagLabelRow: + sc.resourceGroupTagWithRow.Store(newTag) + case tipb.ResourceGroupTagLabel_ResourceGroupTagLabelIndex: + sc.resourceGroupTagWithIndex.Store(newTag) + case tipb.ResourceGroupTagLabel_ResourceGroupTagLabelUnknown: + sc.resourceGroupTagWithUnknown.Store(newTag) + } + return newTag } // SetPlanDigest sets the normalized plan and plan digest. diff --git a/sessionctx/stmtctx/stmtctx_test.go b/sessionctx/stmtctx/stmtctx_test.go index f5bf2cca866be..acfc9c00866d1 100644 --- a/sessionctx/stmtctx/stmtctx_test.go +++ b/sessionctx/stmtctx/stmtctx_test.go @@ -21,6 +21,7 @@ import ( "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/util/execdetails" + "github.com/pingcap/tipb/go-tipb" "github.com/stretchr/testify/require" "github.com/tikv/client-go/v2/util" ) @@ -90,3 +91,25 @@ func TestStatementContextPushDownFLags(t *testing.T) { require.Equal(t, tt.out, got) } } + +func TestGetResourceGroupTagByLabel(t *testing.T) { + ctx := stmtctx.StatementContext{OriginalSQL: "SELECT * FROM t"} + tagRow := ctx.GetResourceGroupTagByLabel(tipb.ResourceGroupTagLabel_ResourceGroupTagLabelRow) + tagIndex := ctx.GetResourceGroupTagByLabel(tipb.ResourceGroupTagLabel_ResourceGroupTagLabelIndex) + tagUnknown := ctx.GetResourceGroupTagByLabel(tipb.ResourceGroupTagLabel_ResourceGroupTagLabelUnknown) + tagRow2 := ctx.GetResourceGroupTagByLabel(tipb.ResourceGroupTagLabel_ResourceGroupTagLabelRow) + tagIndex2 := ctx.GetResourceGroupTagByLabel(tipb.ResourceGroupTagLabel_ResourceGroupTagLabelIndex) + tagUnknown2 := ctx.GetResourceGroupTagByLabel(tipb.ResourceGroupTagLabel_ResourceGroupTagLabelUnknown) + require.NotEmpty(t, tagRow) + require.NotEmpty(t, tagIndex) + require.NotEmpty(t, tagUnknown) + require.NotEmpty(t, tagRow2) + require.NotEmpty(t, tagIndex2) + require.NotEmpty(t, tagUnknown2) + require.Equal(t, &tagRow, &tagRow2) // mem addr + require.Equal(t, &tagIndex, &tagIndex2) + require.Equal(t, &tagUnknown, &tagUnknown2) + require.NotEqual(t, &tagRow, &tagIndex) + require.NotEqual(t, &tagRow, &tagUnknown) + require.NotEqual(t, &tagIndex, &tagUnknown) +} diff --git a/store/copr/batch_coprocessor.go b/store/copr/batch_coprocessor.go index 77ed84366c0a4..e2ab8de16c30b 100644 --- a/store/copr/batch_coprocessor.go +++ b/store/copr/batch_coprocessor.go @@ -793,14 +793,16 @@ func (b *batchCopIterator) handleTaskOnce(ctx context.Context, bo *backoff.Backo } req := tikvrpc.NewRequest(task.cmdType, &copReq, kvrpcpb.Context{ - IsolationLevel: isolationLevelToPB(b.req.IsolationLevel), - Priority: priorityToPB(b.req.Priority), - NotFillCache: b.req.NotFillCache, - RecordTimeStat: true, - RecordScanStat: true, - 
TaskId: b.req.TaskID, - ResourceGroupTag: b.req.ResourceGroupTag, + IsolationLevel: isolationLevelToPB(b.req.IsolationLevel), + Priority: priorityToPB(b.req.Priority), + NotFillCache: b.req.NotFillCache, + RecordTimeStat: true, + RecordScanStat: true, + TaskId: b.req.TaskID, }) + if b.req.ResourceGroupTagger != nil { + b.req.ResourceGroupTagger(req) + } req.StoreTp = tikvrpc.TiFlash logutil.BgLogger().Debug("send batch request to ", zap.String("req info", req.String()), zap.Int("cop task len", len(task.regionInfos))) diff --git a/store/copr/coprocessor.go b/store/copr/coprocessor.go index 77fbd2a90a6e9..5606cb246e863 100644 --- a/store/copr/coprocessor.go +++ b/store/copr/coprocessor.go @@ -706,14 +706,16 @@ func (worker *copIteratorWorker) handleTaskOnce(bo *Backoffer, task *copTask, ch } req := tikvrpc.NewReplicaReadRequest(task.cmdType, &copReq, options.GetTiKVReplicaReadType(worker.req.ReplicaRead), &worker.replicaReadSeed, kvrpcpb.Context{ - IsolationLevel: isolationLevelToPB(worker.req.IsolationLevel), - Priority: priorityToPB(worker.req.Priority), - NotFillCache: worker.req.NotFillCache, - RecordTimeStat: true, - RecordScanStat: true, - TaskId: worker.req.TaskID, - ResourceGroupTag: worker.req.ResourceGroupTag, + IsolationLevel: isolationLevelToPB(worker.req.IsolationLevel), + Priority: priorityToPB(worker.req.Priority), + NotFillCache: worker.req.NotFillCache, + RecordTimeStat: true, + RecordScanStat: true, + TaskId: worker.req.TaskID, }) + if worker.req.ResourceGroupTagger != nil { + worker.req.ResourceGroupTagger(req) + } req.StoreTp = getEndPointType(task.storeType) startTime := time.Now() if worker.kvclient.Stats == nil { diff --git a/store/driver/txn/snapshot.go b/store/driver/txn/snapshot.go index 98179d4d62d02..3c372bae83725 100644 --- a/store/driver/txn/snapshot.go +++ b/store/driver/txn/snapshot.go @@ -22,6 +22,7 @@ import ( "github.com/pingcap/tidb/kv" derr "github.com/pingcap/tidb/store/driver/error" "github.com/pingcap/tidb/store/driver/options" + "github.com/tikv/client-go/v2/tikvrpc" "github.com/tikv/client-go/v2/txnkv/txnsnapshot" "github.com/tikv/client-go/v2/txnkv/txnutil" ) @@ -113,6 +114,8 @@ func (s *tikvSnapshot) SetOption(opt int, val interface{}) { s.KVSnapshot.SetMatchStoreLabels(val.([]*metapb.StoreLabel)) case kv.ResourceGroupTag: s.KVSnapshot.SetResourceGroupTag(val.([]byte)) + case kv.ResourceGroupTagger: + s.KVSnapshot.SetResourceGroupTagger(val.(tikvrpc.ResourceGroupTagger)) case kv.ReadReplicaScope: s.KVSnapshot.SetReadReplicaScope(val.(string)) case kv.SnapInterceptor: diff --git a/store/driver/txn/txn_driver.go b/store/driver/txn/txn_driver.go index 0c08e9b3f65db..823d33ac88f59 100644 --- a/store/driver/txn/txn_driver.go +++ b/store/driver/txn/txn_driver.go @@ -31,6 +31,7 @@ import ( tikverr "github.com/tikv/client-go/v2/error" tikvstore "github.com/tikv/client-go/v2/kv" "github.com/tikv/client-go/v2/tikv" + "github.com/tikv/client-go/v2/tikvrpc" "github.com/tikv/client-go/v2/txnkv/txnsnapshot" ) @@ -223,6 +224,8 @@ func (txn *tikvTxn) SetOption(opt int, val interface{}) { txn.KVTxn.GetSnapshot().SetMatchStoreLabels(val.([]*metapb.StoreLabel)) case kv.ResourceGroupTag: txn.KVTxn.SetResourceGroupTag(val.([]byte)) + case kv.ResourceGroupTagger: + txn.KVTxn.SetResourceGroupTagger(val.(tikvrpc.ResourceGroupTagger)) case kv.KVFilter: txn.KVTxn.SetKVFilter(val.(tikv.KVFilter)) case kv.SnapInterceptor: diff --git a/tablecodec/rowindexcodec/main_test.go b/tablecodec/rowindexcodec/main_test.go new file mode 100644 index 0000000000000..55b15ba96e15d --- 
/dev/null +++ b/tablecodec/rowindexcodec/main_test.go @@ -0,0 +1,27 @@ +// Copyright 2021 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package rowindexcodec + +import ( + "testing" + + "github.com/pingcap/tidb/util/testbridge" + "go.uber.org/goleak" +) + +func TestMain(m *testing.M) { + testbridge.WorkaroundGoCheckFlags() + goleak.VerifyTestMain(m) +} diff --git a/tablecodec/rowindexcodec/rowindexcodec.go b/tablecodec/rowindexcodec/rowindexcodec.go new file mode 100644 index 0000000000000..924d3fff2e5cf --- /dev/null +++ b/tablecodec/rowindexcodec/rowindexcodec.go @@ -0,0 +1,57 @@ +// Copyright 2021 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package rowindexcodec + +import ( + "bytes" +) + +// KeyKind is a specific type of key, mainly used to distinguish row/index. +type KeyKind int + +const ( + // KeyKindUnknown indicates that this key is unknown type. + KeyKindUnknown KeyKind = iota + // KeyKindRow means that this key belongs to row. + KeyKindRow + // KeyKindIndex means that this key belongs to index. + KeyKindIndex +) + +var ( + tablePrefix = []byte{'t'} + rowPrefix = []byte("_r") + indexPrefix = []byte("_i") +) + +// GetKeyKind returns the KeyKind that matches the key in the minimalist way. +func GetKeyKind(key []byte) KeyKind { + // [ TABLE_PREFIX | TABLE_ID | ROW_PREFIX (INDEX_PREFIX) | ROW_ID (INDEX_ID) | ... ] (name) + // [ 1 | 8 | 2 | 8 | ... ] (byte) + if len(key) < 11 { + return KeyKindUnknown + } + if !bytes.HasPrefix(key, tablePrefix) { + return KeyKindUnknown + } + key = key[9:] + if bytes.HasPrefix(key, rowPrefix) { + return KeyKindRow + } + if bytes.HasPrefix(key, indexPrefix) { + return KeyKindIndex + } + return KeyKindUnknown +} diff --git a/tablecodec/rowindexcodec/rowindexcodec_test.go b/tablecodec/rowindexcodec/rowindexcodec_test.go new file mode 100644 index 0000000000000..146fc80d9bee0 --- /dev/null +++ b/tablecodec/rowindexcodec/rowindexcodec_test.go @@ -0,0 +1,28 @@ +// Copyright 2021 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +package rowindexcodec + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestGetKeyKind(t *testing.T) { + require.Equal(t, KeyKindRow, GetKeyKind([]byte{116, 128, 0, 0, 0, 0, 0, 0, 0, 95, 114})) + require.Equal(t, KeyKindIndex, GetKeyKind([]byte{116, 128, 0, 0, 0, 0, 0, 0, 0, 95, 105, 128, 0, 0, 0, 0, 0, 0, 0})) + require.Equal(t, KeyKindUnknown, GetKeyKind([]byte(""))) + require.Equal(t, KeyKindUnknown, GetKeyKind(nil)) +} diff --git a/util/resourcegrouptag/resource_group_tag.go b/util/resourcegrouptag/resource_group_tag.go index bf63a08cd6d07..1b70c6e207ea4 100644 --- a/util/resourcegrouptag/resource_group_tag.go +++ b/util/resourcegrouptag/resource_group_tag.go @@ -16,17 +16,21 @@ package resourcegrouptag import ( "github.com/pingcap/errors" + "github.com/pingcap/kvproto/pkg/coprocessor" + "github.com/pingcap/kvproto/pkg/kvrpcpb" "github.com/pingcap/tidb/parser" + "github.com/pingcap/tidb/tablecodec/rowindexcodec" "github.com/pingcap/tipb/go-tipb" + "github.com/tikv/client-go/v2/tikvrpc" ) // EncodeResourceGroupTag encodes sql digest and plan digest into resource group tag. -func EncodeResourceGroupTag(sqlDigest, planDigest *parser.Digest) []byte { +func EncodeResourceGroupTag(sqlDigest, planDigest *parser.Digest, label tipb.ResourceGroupTagLabel) []byte { if sqlDigest == nil && planDigest == nil { return nil } - tag := &tipb.ResourceGroupTag{} + tag := &tipb.ResourceGroupTag{Label: &label} if sqlDigest != nil { tag.SqlDigest = sqlDigest.Bytes() } @@ -52,3 +56,75 @@ func DecodeResourceGroupTag(data []byte) (sqlDigest []byte, err error) { } return tag.SqlDigest, nil } + +// GetResourceGroupLabelByKey determines the tipb.ResourceGroupTagLabel of key. +func GetResourceGroupLabelByKey(key []byte) tipb.ResourceGroupTagLabel { + switch rowindexcodec.GetKeyKind(key) { + case rowindexcodec.KeyKindRow: + return tipb.ResourceGroupTagLabel_ResourceGroupTagLabelRow + case rowindexcodec.KeyKindIndex: + return tipb.ResourceGroupTagLabel_ResourceGroupTagLabelIndex + default: + return tipb.ResourceGroupTagLabel_ResourceGroupTagLabelUnknown + } +} + +// GetFirstKeyFromRequest gets the first Key of the request from tikvrpc.Request. 
+func GetFirstKeyFromRequest(req *tikvrpc.Request) (firstKey []byte) { + if req == nil { + return + } + switch req.Req.(type) { + case *kvrpcpb.GetRequest: + r := req.Req.(*kvrpcpb.GetRequest) + if r != nil { + firstKey = r.Key + } + case *kvrpcpb.BatchGetRequest: + r := req.Req.(*kvrpcpb.BatchGetRequest) + if r != nil && len(r.Keys) > 0 { + firstKey = r.Keys[0] + } + case *kvrpcpb.ScanRequest: + r := req.Req.(*kvrpcpb.ScanRequest) + if r != nil { + firstKey = r.StartKey + } + case *kvrpcpb.PrewriteRequest: + r := req.Req.(*kvrpcpb.PrewriteRequest) + if r != nil && len(r.Mutations) > 0 { + if mutation := r.Mutations[0]; mutation != nil { + firstKey = mutation.Key + } + } + case *kvrpcpb.CommitRequest: + r := req.Req.(*kvrpcpb.CommitRequest) + if r != nil && len(r.Keys) > 0 { + firstKey = r.Keys[0] + } + case *kvrpcpb.BatchRollbackRequest: + r := req.Req.(*kvrpcpb.BatchRollbackRequest) + if r != nil && len(r.Keys) > 0 { + firstKey = r.Keys[0] + } + case *coprocessor.Request: + r := req.Req.(*coprocessor.Request) + if r != nil && len(r.Ranges) > 0 { + if keyRange := r.Ranges[0]; keyRange != nil { + firstKey = keyRange.Start + } + } + case *coprocessor.BatchRequest: + r := req.Req.(*coprocessor.BatchRequest) + if r != nil && len(r.Regions) > 0 { + if region := r.Regions[0]; region != nil { + if len(region.Ranges) > 0 { + if keyRange := region.Ranges[0]; keyRange != nil { + firstKey = keyRange.Start + } + } + } + } + } + return +} diff --git a/util/resourcegrouptag/resource_group_tag_test.go b/util/resourcegrouptag/resource_group_tag_test.go index 355134251a020..5cf65d52ad411 100644 --- a/util/resourcegrouptag/resource_group_tag_test.go +++ b/util/resourcegrouptag/resource_group_tag_test.go @@ -19,40 +19,43 @@ import ( "math/rand" "testing" + "github.com/pingcap/kvproto/pkg/coprocessor" + "github.com/pingcap/kvproto/pkg/kvrpcpb" "github.com/pingcap/tidb/parser" "github.com/pingcap/tidb/util/hack" "github.com/pingcap/tipb/go-tipb" "github.com/stretchr/testify/require" + "github.com/tikv/client-go/v2/tikvrpc" ) func TestResourceGroupTagEncoding(t *testing.T) { t.Parallel() sqlDigest := parser.NewDigest(nil) - tag := EncodeResourceGroupTag(sqlDigest, nil) - require.Len(t, tag, 0) + tag := EncodeResourceGroupTag(sqlDigest, nil, tipb.ResourceGroupTagLabel_ResourceGroupTagLabelUnknown) + require.Len(t, tag, 2) decodedSQLDigest, err := DecodeResourceGroupTag(tag) require.NoError(t, err) require.Len(t, decodedSQLDigest, 0) sqlDigest = parser.NewDigest([]byte{'a', 'a'}) - tag = EncodeResourceGroupTag(sqlDigest, nil) - // version(1) + prefix(1) + length(1) + content(2hex -> 1byte) - require.Len(t, tag, 4) + tag = EncodeResourceGroupTag(sqlDigest, nil, tipb.ResourceGroupTagLabel_ResourceGroupTagLabelUnknown) + // version(1) + prefix(1) + length(1) + content(2hex -> 1byte) + label(2) + require.Len(t, tag, 6) decodedSQLDigest, err = DecodeResourceGroupTag(tag) require.NoError(t, err) require.Equal(t, sqlDigest.Bytes(), decodedSQLDigest) sqlDigest = parser.NewDigest(genRandHex(64)) - tag = EncodeResourceGroupTag(sqlDigest, nil) + tag = EncodeResourceGroupTag(sqlDigest, nil, tipb.ResourceGroupTagLabel_ResourceGroupTagLabelUnknown) decodedSQLDigest, err = DecodeResourceGroupTag(tag) require.NoError(t, err) require.Equal(t, sqlDigest.Bytes(), decodedSQLDigest) sqlDigest = parser.NewDigest(genRandHex(510)) - tag = EncodeResourceGroupTag(sqlDigest, nil) + tag = EncodeResourceGroupTag(sqlDigest, nil, tipb.ResourceGroupTagLabel_ResourceGroupTagLabelUnknown) decodedSQLDigest, err = DecodeResourceGroupTag(tag) 
require.NoError(t, err) require.Equal(t, sqlDigest.Bytes(), decodedSQLDigest) @@ -93,6 +96,98 @@ func TestResourceGroupTagEncodingPB(t *testing.T) { require.Nil(t, tag.PlanDigest) } +func TestGetResourceGroupLabelByKey(t *testing.T) { + var label tipb.ResourceGroupTagLabel + // tablecodec.EncodeRowKey(0, []byte{}) + label = GetResourceGroupLabelByKey([]byte{116, 128, 0, 0, 0, 0, 0, 0, 0, 95, 114}) + require.Equal(t, tipb.ResourceGroupTagLabel_ResourceGroupTagLabelRow, label) + // tablecodec.EncodeIndexSeekKey(0, 0, []byte{})) + label = GetResourceGroupLabelByKey([]byte{116, 128, 0, 0, 0, 0, 0, 0, 0, 95, 105, 128, 0, 0, 0, 0, 0, 0, 0}) + require.Equal(t, tipb.ResourceGroupTagLabel_ResourceGroupTagLabelIndex, label) + label = GetResourceGroupLabelByKey([]byte("")) + require.Equal(t, tipb.ResourceGroupTagLabel_ResourceGroupTagLabelUnknown, label) +} + +func TestGetFirstKeyFromRequest(t *testing.T) { + var testK1 = []byte("TEST-1") + var testK2 = []byte("TEST-2") + var req *tikvrpc.Request + + require.Nil(t, GetFirstKeyFromRequest(nil)) + + req = &tikvrpc.Request{Req: (*kvrpcpb.GetRequest)(nil)} + require.Nil(t, GetFirstKeyFromRequest(req)) + req = &tikvrpc.Request{Req: &kvrpcpb.GetRequest{Key: nil}} + require.Nil(t, GetFirstKeyFromRequest(req)) + req = &tikvrpc.Request{Req: &kvrpcpb.GetRequest{Key: testK1}} + require.Equal(t, testK1, GetFirstKeyFromRequest(req)) + + req = &tikvrpc.Request{Req: (*kvrpcpb.BatchGetRequest)(nil)} + require.Nil(t, GetFirstKeyFromRequest(req)) + req = &tikvrpc.Request{Req: &kvrpcpb.BatchGetRequest{Keys: nil}} + require.Nil(t, GetFirstKeyFromRequest(req)) + req = &tikvrpc.Request{Req: &kvrpcpb.BatchGetRequest{Keys: [][]byte{}}} + require.Nil(t, GetFirstKeyFromRequest(req)) + req = &tikvrpc.Request{Req: &kvrpcpb.BatchGetRequest{Keys: [][]byte{testK2, testK1}}} + require.Equal(t, testK2, GetFirstKeyFromRequest(req)) + + req = &tikvrpc.Request{Req: (*kvrpcpb.ScanRequest)(nil)} + require.Nil(t, GetFirstKeyFromRequest(req)) + req = &tikvrpc.Request{Req: &kvrpcpb.ScanRequest{StartKey: nil}} + require.Nil(t, GetFirstKeyFromRequest(req)) + req = &tikvrpc.Request{Req: &kvrpcpb.ScanRequest{StartKey: testK1}} + require.Equal(t, testK1, GetFirstKeyFromRequest(req)) + + req = &tikvrpc.Request{Req: (*kvrpcpb.PrewriteRequest)(nil)} + require.Nil(t, GetFirstKeyFromRequest(req)) + req = &tikvrpc.Request{Req: &kvrpcpb.PrewriteRequest{Mutations: nil}} + require.Nil(t, GetFirstKeyFromRequest(req)) + req = &tikvrpc.Request{Req: &kvrpcpb.PrewriteRequest{Mutations: []*kvrpcpb.Mutation{}}} + require.Nil(t, GetFirstKeyFromRequest(req)) + req = &tikvrpc.Request{Req: &kvrpcpb.PrewriteRequest{Mutations: []*kvrpcpb.Mutation{{Key: testK2}, {Key: testK1}}}} + require.Equal(t, testK2, GetFirstKeyFromRequest(req)) + + req = &tikvrpc.Request{Req: (*kvrpcpb.CommitRequest)(nil)} + require.Nil(t, GetFirstKeyFromRequest(req)) + req = &tikvrpc.Request{Req: &kvrpcpb.CommitRequest{Keys: nil}} + require.Nil(t, GetFirstKeyFromRequest(req)) + req = &tikvrpc.Request{Req: &kvrpcpb.CommitRequest{Keys: [][]byte{}}} + require.Nil(t, GetFirstKeyFromRequest(req)) + req = &tikvrpc.Request{Req: &kvrpcpb.CommitRequest{Keys: [][]byte{testK1, testK1}}} + require.Equal(t, testK1, GetFirstKeyFromRequest(req)) + + req = &tikvrpc.Request{Req: (*kvrpcpb.BatchRollbackRequest)(nil)} + require.Nil(t, GetFirstKeyFromRequest(req)) + req = &tikvrpc.Request{Req: &kvrpcpb.BatchRollbackRequest{Keys: nil}} + require.Nil(t, GetFirstKeyFromRequest(req)) + req = &tikvrpc.Request{Req: &kvrpcpb.BatchRollbackRequest{Keys: [][]byte{}}} + 
require.Nil(t, GetFirstKeyFromRequest(req)) + req = &tikvrpc.Request{Req: &kvrpcpb.BatchRollbackRequest{Keys: [][]byte{testK2, testK1}}} + require.Equal(t, testK2, GetFirstKeyFromRequest(req)) + + req = &tikvrpc.Request{Req: (*coprocessor.Request)(nil)} + require.Nil(t, GetFirstKeyFromRequest(req)) + req = &tikvrpc.Request{Req: &coprocessor.Request{Ranges: nil}} + require.Nil(t, GetFirstKeyFromRequest(req)) + req = &tikvrpc.Request{Req: &coprocessor.Request{Ranges: []*coprocessor.KeyRange{}}} + require.Nil(t, GetFirstKeyFromRequest(req)) + req = &tikvrpc.Request{Req: &coprocessor.Request{Ranges: []*coprocessor.KeyRange{{Start: testK1}}}} + require.Equal(t, testK1, GetFirstKeyFromRequest(req)) + + req = &tikvrpc.Request{Req: (*coprocessor.BatchRequest)(nil)} + require.Nil(t, GetFirstKeyFromRequest(req)) + req = &tikvrpc.Request{Req: &coprocessor.BatchRequest{Regions: nil}} + require.Nil(t, GetFirstKeyFromRequest(req)) + req = &tikvrpc.Request{Req: &coprocessor.BatchRequest{Regions: []*coprocessor.RegionInfo{}}} + require.Nil(t, GetFirstKeyFromRequest(req)) + req = &tikvrpc.Request{Req: &coprocessor.BatchRequest{Regions: []*coprocessor.RegionInfo{{Ranges: nil}}}} + require.Nil(t, GetFirstKeyFromRequest(req)) + req = &tikvrpc.Request{Req: &coprocessor.BatchRequest{Regions: []*coprocessor.RegionInfo{{Ranges: []*coprocessor.KeyRange{}}}}} + require.Nil(t, GetFirstKeyFromRequest(req)) + req = &tikvrpc.Request{Req: &coprocessor.BatchRequest{Regions: []*coprocessor.RegionInfo{{Ranges: []*coprocessor.KeyRange{{Start: testK2}}}}}} + require.Equal(t, testK2, GetFirstKeyFromRequest(req)) +} + func genRandHex(length int) []byte { const chars = "0123456789abcdef" res := make([]byte, length)