From f7d66f49139b7eb2d5ce2baf50012c85fa70e787 Mon Sep 17 00:00:00 2001 From: YangKeao Date: Mon, 1 Apr 2024 16:27:16 +0800 Subject: [PATCH] distsql: use a more accurate type of the context of distsql (#51618) close pingcap/tidb#51617 --- pkg/ddl/index_cop.go | 6 +- pkg/ddl/reorg.go | 6 +- pkg/distsql/BUILD.bazel | 6 +- pkg/distsql/context/BUILD.bazel | 12 ++- pkg/distsql/context/context.go | 76 ++++++++++++-- pkg/distsql/context_test.go | 46 +++++++++ pkg/distsql/distsql.go | 66 ++++++------- pkg/distsql/distsql_test.go | 17 ++-- pkg/distsql/request_builder.go | 99 +++++++++---------- pkg/distsql/request_builder_test.go | 41 ++++---- pkg/distsql/select_result.go | 35 +++---- pkg/distsql/select_result_test.go | 4 +- pkg/executor/admin.go | 24 +++-- pkg/executor/analyze_col.go | 4 +- pkg/executor/analyze_idx.go | 6 +- pkg/executor/builder.go | 37 +++---- pkg/executor/checksum.go | 6 +- pkg/executor/distsql.go | 37 ++++--- pkg/executor/executor_pkg_test.go | 6 +- pkg/executor/index_merge_reader.go | 14 +-- .../internal/builder/builder_utils.go | 2 +- .../internal/mpp/local_mpp_coordinator.go | 6 +- pkg/executor/mpp_gather.go | 2 +- pkg/executor/simple.go | 2 +- pkg/executor/table_reader.go | 28 +++--- .../table_readers_required_rows_test.go | 2 +- pkg/planner/core/fragment.go | 4 +- pkg/session/session.go | 60 ++++++++++- pkg/sessionctx/context.go | 2 +- pkg/sessionctx/stmtctx/BUILD.bazel | 1 + pkg/sessionctx/stmtctx/stmtctx.go | 21 ++++ pkg/util/mock/context.go | 29 +++++- 32 files changed, 463 insertions(+), 244 deletions(-) create mode 100644 pkg/distsql/context_test.go diff --git a/pkg/ddl/index_cop.go b/pkg/ddl/index_cop.go index 4a6e058f3d779..ab8802686b69a 100644 --- a/pkg/ddl/index_cop.go +++ b/pkg/ddl/index_cop.go @@ -274,7 +274,7 @@ func buildTableScan(ctx context.Context, c *copr.CopContextBase, startTS uint64, SetStartTS(startTS). SetKeyRanges([]kv.KeyRange{{StartKey: start, EndKey: end}}). SetKeepOrder(true). - SetFromSessionVars(c.SessionContext.GetSessionVars()). + SetFromSessionVars(c.SessionContext.GetDistSQLCtx()). SetFromInfoSchema(c.SessionContext.GetDomainInfoSchema()). SetConcurrency(1). Build() @@ -284,7 +284,7 @@ func buildTableScan(ctx context.Context, c *copr.CopContextBase, startTS uint64, if err != nil { return nil, err } - return distsql.Select(ctx, c.SessionContext, kvReq, c.FieldTypes) + return distsql.Select(ctx, c.SessionContext.GetDistSQLCtx(), kvReq, c.FieldTypes) } func fetchTableScanResult( @@ -353,7 +353,7 @@ func buildDAGPB(sCtx sessionctx.Context, tblInfo *model.TableInfo, colInfos []*m return nil, err } dagReq.Executors = append(dagReq.Executors, execPB) - distsql.SetEncodeType(sCtx, dagReq) + distsql.SetEncodeType(sCtx.GetDistSQLCtx(), dagReq) return dagReq, nil } diff --git a/pkg/ddl/reorg.go b/pkg/ddl/reorg.go index 8095e63e8d2f2..c1061773dad8a 100644 --- a/pkg/ddl/reorg.go +++ b/pkg/ddl/reorg.go @@ -500,7 +500,7 @@ func buildDescTableScanDAG(ctx sessionctx.Context, tbl table.PhysicalTable, hand tblScanExec := constructDescTableScanPB(tbl.GetPhysicalID(), tbl.Meta(), handleCols) dagReq.Executors = append(dagReq.Executors, tblScanExec) dagReq.Executors = append(dagReq.Executors, constructLimitPB(limit)) - distsql.SetEncodeType(ctx, dagReq) + distsql.SetEncodeType(ctx.GetDistSQLCtx(), dagReq) return dagReq, nil } @@ -528,7 +528,7 @@ func (dc *ddlCtx) buildDescTableScan(ctx *JobContext, startTS uint64, tbl table. 
} else { ranges = ranger.FullIntRange(false) } - builder = b.SetHandleRanges(sctx.GetSessionVars().StmtCtx, tbl.GetPhysicalID(), tbl.Meta().IsCommonHandle, ranges) + builder = b.SetHandleRanges(sctx.GetDistSQLCtx(), tbl.GetPhysicalID(), tbl.Meta().IsCommonHandle, ranges) builder.SetDAGRequest(dagPB). SetStartTS(startTS). SetKeepOrder(true). @@ -547,7 +547,7 @@ func (dc *ddlCtx) buildDescTableScan(ctx *JobContext, startTS uint64, tbl table. return nil, errors.Trace(err) } - result, err := distsql.Select(ctx.ddlJobCtx, sctx, kvReq, getColumnsTypes(handleCols)) + result, err := distsql.Select(ctx.ddlJobCtx, sctx.GetDistSQLCtx(), kvReq, getColumnsTypes(handleCols)) if err != nil { return nil, errors.Trace(err) } diff --git a/pkg/distsql/BUILD.bazel b/pkg/distsql/BUILD.bazel index 4a2cf4a0e76d9..590727e0cb7b4 100644 --- a/pkg/distsql/BUILD.bazel +++ b/pkg/distsql/BUILD.bazel @@ -23,7 +23,6 @@ go_library( "//pkg/parser/mysql", "//pkg/parser/terror", "//pkg/planner/util", - "//pkg/sessionctx/stmtctx", "//pkg/sessionctx/variable", "//pkg/store/copr", "//pkg/tablecodec", @@ -36,6 +35,7 @@ go_library( "//pkg/util/logutil", "//pkg/util/memory", "//pkg/util/ranger", + "//pkg/util/topsql/stmtstats", "//pkg/util/tracing", "//pkg/util/trxevents", "@com_github_pingcap_errors//:errors", @@ -58,6 +58,7 @@ go_test( timeout = "short", srcs = [ "bench_test.go", + "context_test.go", "distsql_test.go", "main_test.go", "request_builder_test.go", @@ -68,7 +69,9 @@ go_test( race = "on", shard_count = 26, deps = [ + "//pkg/distsql/context", "//pkg/domain/resourcegroup", + "//pkg/errctx", "//pkg/kv", "//pkg/parser/charset", "//pkg/parser/model", @@ -84,6 +87,7 @@ go_test( "//pkg/util/chunk", "//pkg/util/codec", "//pkg/util/collate", + "//pkg/util/context", "//pkg/util/disk", "//pkg/util/execdetails", "//pkg/util/memory", diff --git a/pkg/distsql/context/BUILD.bazel b/pkg/distsql/context/BUILD.bazel index f0e4f4d8ae156..fe2811f78585f 100644 --- a/pkg/distsql/context/BUILD.bazel +++ b/pkg/distsql/context/BUILD.bazel @@ -6,7 +6,17 @@ go_library( importpath = "github.com/pingcap/tidb/pkg/distsql/context", visibility = ["//visibility:public"], deps = [ + "//pkg/domain/resourcegroup", + "//pkg/errctx", "//pkg/kv", - "//pkg/sessionctx/variable", + "//pkg/parser/mysql", + "//pkg/util/execdetails", + "//pkg/util/memory", + "//pkg/util/nocopy", + "//pkg/util/sqlkiller", + "//pkg/util/tiflash", + "//pkg/util/topsql/stmtstats", + "@com_github_tikv_client_go_v2//kv", + "@com_github_tikv_client_go_v2//tikvrpc", ], ) diff --git a/pkg/distsql/context/context.go b/pkg/distsql/context/context.go index 4fec70980f809..6625cdeac7376 100644 --- a/pkg/distsql/context/context.go +++ b/pkg/distsql/context/context.go @@ -15,14 +15,76 @@ package context import ( + "time" + + "github.com/pingcap/tidb/pkg/domain/resourcegroup" + "github.com/pingcap/tidb/pkg/errctx" "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/sessionctx/variable" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/util/execdetails" + "github.com/pingcap/tidb/pkg/util/memory" + "github.com/pingcap/tidb/pkg/util/nocopy" + "github.com/pingcap/tidb/pkg/util/sqlkiller" + "github.com/pingcap/tidb/pkg/util/tiflash" + "github.com/pingcap/tidb/pkg/util/topsql/stmtstats" + tikvstore "github.com/tikv/client-go/v2/kv" + "github.com/tikv/client-go/v2/tikvrpc" ) -// DistSQLContext gives the interface -type DistSQLContext interface { - // GetSessionVars gets the session variables. 
- GetSessionVars() *variable.SessionVars
- // GetClient gets a kv.Client.
- GetClient() kv.Client
+// DistSQLContext provides all information needed by the functions in `distsql`
+type DistSQLContext struct {
+ // TODO: provide a `Clone` to copy this struct.
+ // The life cycle of some fields in this struct cannot be extended. For example, some fields will be recycled before
+ // the next execution. They'll need to be handled specially.
+ _ nocopy.NoCopy
+
+ AppendWarning func(error)
+ InRestrictedSQL bool
+ Client kv.Client
+
+ EnabledRateLimitAction bool
+ EnableChunkRPC bool
+ OriginalSQL string
+ KVVars *tikvstore.Variables
+ KvExecCounter *stmtstats.KvExecCounter
+ SessionMemTracker *memory.Tracker
+
+ Location *time.Location
+ RuntimeStatsColl *execdetails.RuntimeStatsColl
+ SQLKiller *sqlkiller.SQLKiller
+ ErrCtx errctx.Context
+
+ // TiFlash related configurations
+ TiFlashReplicaRead tiflash.ReplicaRead
+ TiFlashMaxThreads int64
+ TiFlashMaxBytesBeforeExternalJoin int64
+ TiFlashMaxBytesBeforeExternalGroupBy int64
+ TiFlashMaxBytesBeforeExternalSort int64
+ TiFlashMaxQueryMemoryPerNode int64
+ TiFlashQuerySpillRatio float64
+
+ DistSQLConcurrency int
+ ReplicaReadType kv.ReplicaReadType
+ WeakConsistency bool
+ RCCheckTS bool
+ NotFillCache bool
+ TaskID uint64
+ Priority mysql.PriorityEnum
+ ResourceGroupTagger tikvrpc.ResourceGroupTagger
+ EnablePaging bool
+ MinPagingSize int
+ MaxPagingSize int
+ RequestSourceType string
+ ExplicitRequestSourceType string
+ StoreBatchSize int
+ ResourceGroupName string
+ LoadBasedReplicaReadThreshold time.Duration
+ RunawayChecker *resourcegroup.RunawayChecker
+ TiKVClientReadTimeout uint64
+
+ ReplicaClosestReadThreshold int64
+ ConnectionID uint64
+ SessionAlias string
+
+ ExecDetails *execdetails.SyncExecDetails
}
diff --git a/pkg/distsql/context_test.go b/pkg/distsql/context_test.go
new file mode 100644
index 0000000000000..bca80c67a1e7c
--- /dev/null
+++ b/pkg/distsql/context_test.go
@@ -0,0 +1,46 @@
+// Copyright 2024 PingCAP, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +package distsql + +import ( + distsqlctx "github.com/pingcap/tidb/pkg/distsql/context" + "github.com/pingcap/tidb/pkg/errctx" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + contextutil "github.com/pingcap/tidb/pkg/util/context" +) + +// NewDistSQLContextForTest creates a new dist sql context for test +func NewDistSQLContextForTest() *distsqlctx.DistSQLContext { + return &distsqlctx.DistSQLContext{ + AppendWarning: func(error) {}, + TiFlashMaxThreads: variable.DefTiFlashMaxThreads, + TiFlashMaxBytesBeforeExternalJoin: variable.DefTiFlashMaxBytesBeforeExternalJoin, + TiFlashMaxBytesBeforeExternalGroupBy: variable.DefTiFlashMaxBytesBeforeExternalGroupBy, + TiFlashMaxBytesBeforeExternalSort: variable.DefTiFlashMaxBytesBeforeExternalSort, + TiFlashMaxQueryMemoryPerNode: variable.DefTiFlashMemQuotaQueryPerNode, + TiFlashQuerySpillRatio: variable.DefTiFlashQuerySpillRatio, + + DistSQLConcurrency: variable.DefDistSQLScanConcurrency, + MinPagingSize: variable.DefMinPagingSize, + MaxPagingSize: variable.DefMaxPagingSize, + ResourceGroupName: "default", + + ErrCtx: errctx.NewContext(contextutil.IgnoreWarn), + } +} + +// DefaultDistSQLContext is an empty distsql context used for testing, which doesn't have a client and cannot be used to +// send requests. +var DefaultDistSQLContext = NewDistSQLContextForTest() diff --git a/pkg/distsql/distsql.go b/pkg/distsql/distsql.go index af8120f3f96ff..3949b30ea8dc9 100644 --- a/pkg/distsql/distsql.go +++ b/pkg/distsql/distsql.go @@ -24,10 +24,10 @@ import ( distsqlctx "github.com/pingcap/tidb/pkg/distsql/context" "github.com/pingcap/tidb/pkg/kv" "github.com/pingcap/tidb/pkg/metrics" - "github.com/pingcap/tidb/pkg/sessionctx/stmtctx" "github.com/pingcap/tidb/pkg/sessionctx/variable" "github.com/pingcap/tidb/pkg/types" "github.com/pingcap/tidb/pkg/util/logutil" + "github.com/pingcap/tidb/pkg/util/topsql/stmtstats" "github.com/pingcap/tidb/pkg/util/tracing" "github.com/pingcap/tidb/pkg/util/trxevents" "github.com/pingcap/tipb/go-tipb" @@ -37,7 +37,7 @@ import ( ) // GenSelectResultFromMPPResponse generates an iterator from response. -func GenSelectResultFromMPPResponse(dctx distsqlctx.DistSQLContext, fieldTypes []*types.FieldType, planIDs []int, rootID int, resp kv.Response) SelectResult { +func GenSelectResultFromMPPResponse(dctx *distsqlctx.DistSQLContext, fieldTypes []*types.FieldType, planIDs []int, rootID int, resp kv.Response) SelectResult { // TODO: Add metric label and set open tracing. return &selectResult{ label: "mpp", @@ -53,7 +53,7 @@ func GenSelectResultFromMPPResponse(dctx distsqlctx.DistSQLContext, fieldTypes [ // Select sends a DAG request, returns SelectResult. // In kvReq, KeyRanges is required, Concurrency/KeepOrder/Desc/IsolationLevel/Priority are optional. 
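// A hedged usage sketch (not part of this patch): with the new signature the
// caller passes the concrete *DistSQLContext — typically obtained from
// sessionctx.Context.GetDistSQLCtx(), or NewDistSQLContextForTest in tests —
// instead of the whole session context. The handle/DAG/timestamp values below
// (tblID, dagReq, startTS, fieldTypes) are illustrative:
//
//	dctx := sctx.GetDistSQLCtx()
//	kvReq, err := (&RequestBuilder{}).
//		SetHandleRanges(dctx, tblID, false, ranger.FullIntRange(false)).
//		SetDAGRequest(dagReq).
//		SetStartTS(startTS).
//		SetFromSessionVars(dctx).
//		Build()
//	if err != nil {
//		return nil, err
//	}
//	result, err := Select(ctx, dctx, kvReq, fieldTypes)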
-func Select(ctx context.Context, dctx distsqlctx.DistSQLContext, kvReq *kv.Request, fieldTypes []*types.FieldType) (SelectResult, error) { +func Select(ctx context.Context, dctx *distsqlctx.DistSQLContext, kvReq *kv.Request, fieldTypes []*types.FieldType) (SelectResult, error) { r, ctx := tracing.StartRegionEx(ctx, "distsql.Select") defer r.End() @@ -62,8 +62,8 @@ func Select(ctx context.Context, dctx distsqlctx.DistSQLContext, kvReq *kv.Reque hook.(func(*kv.Request))(kvReq) } - enabledRateLimitAction := dctx.GetSessionVars().EnabledRateLimitAction - originalSQL := dctx.GetSessionVars().StmtCtx.OriginalSQL + enabledRateLimitAction := dctx.EnabledRateLimitAction + originalSQL := dctx.OriginalSQL eventCb := func(event trxevents.TransactionEvent) { // Note: Do not assume this callback will be invoked within the same goroutine. if copMeetLock := event.GetCopMeetLock(); copMeetLock != nil { @@ -74,27 +74,27 @@ func Select(ctx context.Context, dctx distsqlctx.DistSQLContext, kvReq *kv.Reque } } - ctx = WithSQLKvExecCounterInterceptor(ctx, dctx.GetSessionVars().StmtCtx) + ctx = WithSQLKvExecCounterInterceptor(ctx, dctx.KvExecCounter) option := &kv.ClientSendOption{ - SessionMemTracker: dctx.GetSessionVars().MemTracker, + SessionMemTracker: dctx.SessionMemTracker, EnabledRateLimitAction: enabledRateLimitAction, EventCb: eventCb, EnableCollectExecutionInfo: config.GetGlobalConfig().Instance.EnableCollectExecutionInfo.Load(), } if kvReq.StoreType == kv.TiFlash { - ctx = SetTiFlashConfVarsInContext(ctx, dctx.GetSessionVars()) - option.TiFlashReplicaRead = dctx.GetSessionVars().TiFlashReplicaRead - option.AppendWarning = dctx.GetSessionVars().StmtCtx.AppendWarning + ctx = SetTiFlashConfVarsInContext(ctx, dctx) + option.TiFlashReplicaRead = dctx.TiFlashReplicaRead + option.AppendWarning = dctx.AppendWarning } - resp := dctx.GetClient().Send(ctx, kvReq, dctx.GetSessionVars().KVVars, option) + resp := dctx.Client.Send(ctx, kvReq, dctx.KVVars, option) if resp == nil { return nil, errors.New("client returns nil response") } label := metrics.LblGeneral - if dctx.GetSessionVars().InRestrictedSQL { + if dctx.InRestrictedSQL { label = metrics.LblInternal } @@ -116,32 +116,32 @@ func Select(ctx context.Context, dctx distsqlctx.DistSQLContext, kvReq *kv.Reque } // SetTiFlashConfVarsInContext set some TiFlash config variables in context. 
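// The values are forwarded to TiFlash as outgoing gRPC metadata, keyed by the
// corresponding tidb_... system-variable names from package variable. A small
// sketch (illustrative value, using the test constructor added by this patch):
//
//	dctx := NewDistSQLContextForTest()
//	dctx.TiFlashMaxThreads = 8
//	ctx := SetTiFlashConfVarsInContext(context.Background(), dctx)
//	md, _ := metadata.FromOutgoingContext(ctx) // carries variable.TiDBMaxTiFlashThreads = "8"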
-func SetTiFlashConfVarsInContext(ctx context.Context, vars *variable.SessionVars) context.Context { - if vars.TiFlashMaxThreads != -1 { - ctx = metadata.AppendToOutgoingContext(ctx, variable.TiDBMaxTiFlashThreads, strconv.FormatInt(vars.TiFlashMaxThreads, 10)) +func SetTiFlashConfVarsInContext(ctx context.Context, dctx *distsqlctx.DistSQLContext) context.Context { + if dctx.TiFlashMaxThreads != -1 { + ctx = metadata.AppendToOutgoingContext(ctx, variable.TiDBMaxTiFlashThreads, strconv.FormatInt(dctx.TiFlashMaxThreads, 10)) } - if vars.TiFlashMaxBytesBeforeExternalJoin != -1 { - ctx = metadata.AppendToOutgoingContext(ctx, variable.TiDBMaxBytesBeforeTiFlashExternalJoin, strconv.FormatInt(vars.TiFlashMaxBytesBeforeExternalJoin, 10)) + if dctx.TiFlashMaxBytesBeforeExternalJoin != -1 { + ctx = metadata.AppendToOutgoingContext(ctx, variable.TiDBMaxBytesBeforeTiFlashExternalJoin, strconv.FormatInt(dctx.TiFlashMaxBytesBeforeExternalJoin, 10)) } - if vars.TiFlashMaxBytesBeforeExternalGroupBy != -1 { - ctx = metadata.AppendToOutgoingContext(ctx, variable.TiDBMaxBytesBeforeTiFlashExternalGroupBy, strconv.FormatInt(vars.TiFlashMaxBytesBeforeExternalGroupBy, 10)) + if dctx.TiFlashMaxBytesBeforeExternalGroupBy != -1 { + ctx = metadata.AppendToOutgoingContext(ctx, variable.TiDBMaxBytesBeforeTiFlashExternalGroupBy, strconv.FormatInt(dctx.TiFlashMaxBytesBeforeExternalGroupBy, 10)) } - if vars.TiFlashMaxBytesBeforeExternalSort != -1 { - ctx = metadata.AppendToOutgoingContext(ctx, variable.TiDBMaxBytesBeforeTiFlashExternalSort, strconv.FormatInt(vars.TiFlashMaxBytesBeforeExternalSort, 10)) + if dctx.TiFlashMaxBytesBeforeExternalSort != -1 { + ctx = metadata.AppendToOutgoingContext(ctx, variable.TiDBMaxBytesBeforeTiFlashExternalSort, strconv.FormatInt(dctx.TiFlashMaxBytesBeforeExternalSort, 10)) } - if vars.TiFlashMaxQueryMemoryPerNode <= 0 { + if dctx.TiFlashMaxQueryMemoryPerNode <= 0 { ctx = metadata.AppendToOutgoingContext(ctx, variable.TiFlashMemQuotaQueryPerNode, "0") } else { - ctx = metadata.AppendToOutgoingContext(ctx, variable.TiFlashMemQuotaQueryPerNode, strconv.FormatInt(vars.TiFlashMaxQueryMemoryPerNode, 10)) + ctx = metadata.AppendToOutgoingContext(ctx, variable.TiFlashMemQuotaQueryPerNode, strconv.FormatInt(dctx.TiFlashMaxQueryMemoryPerNode, 10)) } - ctx = metadata.AppendToOutgoingContext(ctx, variable.TiFlashQuerySpillRatio, strconv.FormatFloat(vars.TiFlashQuerySpillRatio, 'f', -1, 64)) + ctx = metadata.AppendToOutgoingContext(ctx, variable.TiFlashQuerySpillRatio, strconv.FormatFloat(dctx.TiFlashQuerySpillRatio, 'f', -1, 64)) return ctx } // SelectWithRuntimeStats sends a DAG request, returns SelectResult. // The difference from Select is that SelectWithRuntimeStats will set copPlanIDs into selectResult, // which can help selectResult to collect runtime stats. -func SelectWithRuntimeStats(ctx context.Context, dctx distsqlctx.DistSQLContext, kvReq *kv.Request, +func SelectWithRuntimeStats(ctx context.Context, dctx *distsqlctx.DistSQLContext, kvReq *kv.Request, fieldTypes []*types.FieldType, copPlanIDs []int, rootPlanID int) (SelectResult, error) { sr, err := Select(ctx, dctx, kvReq, fieldTypes) if err != nil { @@ -156,8 +156,8 @@ func SelectWithRuntimeStats(ctx context.Context, dctx distsqlctx.DistSQLContext, // Analyze do a analyze request. 
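// After this patch the trailing argument is the *DistSQLContext (only its
// KvExecCounter is consulted here); mirroring the updated test below, a call
// now looks like:
//
//	resp, err := Analyze(context.TODO(), sctx.GetClient(), kvReq,
//		tikvstore.DefaultVars, true, sctx.GetDistSQLCtx())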
func Analyze(ctx context.Context, client kv.Client, kvReq *kv.Request, vars any, - isRestrict bool, stmtCtx *stmtctx.StatementContext) (SelectResult, error) { - ctx = WithSQLKvExecCounterInterceptor(ctx, stmtCtx) + isRestrict bool, dctx *distsqlctx.DistSQLContext) (SelectResult, error) { + ctx = WithSQLKvExecCounterInterceptor(ctx, dctx.KvExecCounter) kvReq.RequestSource.RequestSourceInternal = true kvReq.RequestSource.RequestSourceType = kv.InternalTxnStats resp := client.Send(ctx, kvReq, vars, &kv.ClientSendOption{}) @@ -198,7 +198,7 @@ func Checksum(ctx context.Context, client kv.Client, kvReq *kv.Request, vars any // methods are: // 1. TypeChunk: the result is encoded using the Chunk format, refer util/chunk/chunk.go // 2. TypeDefault: the result is encoded row by row -func SetEncodeType(ctx distsqlctx.DistSQLContext, dagReq *tipb.DAGRequest) { +func SetEncodeType(ctx *distsqlctx.DistSQLContext, dagReq *tipb.DAGRequest) { if canUseChunkRPC(ctx) { dagReq.EncodeType = tipb.EncodeType_TypeChunk setChunkMemoryLayout(dagReq) @@ -207,8 +207,8 @@ func SetEncodeType(ctx distsqlctx.DistSQLContext, dagReq *tipb.DAGRequest) { } } -func canUseChunkRPC(ctx distsqlctx.DistSQLContext) bool { - if !ctx.GetSessionVars().EnableChunkRPC { +func canUseChunkRPC(ctx *distsqlctx.DistSQLContext) bool { + if !ctx.EnableChunkRPC { return false } if !checkAlignment() { @@ -250,12 +250,12 @@ func init() { // WithSQLKvExecCounterInterceptor binds an interceptor for client-go to count the // number of SQL executions of each TiKV (if any). -func WithSQLKvExecCounterInterceptor(ctx context.Context, stmtCtx *stmtctx.StatementContext) context.Context { - if stmtCtx.KvExecCounter != nil { +func WithSQLKvExecCounterInterceptor(ctx context.Context, counter *stmtstats.KvExecCounter) context.Context { + if counter != nil { // Unlike calling Transaction or Snapshot interface, in distsql package we directly // face tikv Request. So we need to manually bind RPCInterceptor to ctx. Instead of // calling SetRPCInterceptor on Transaction or Snapshot. 
- return interceptor.WithRPCInterceptor(ctx, stmtCtx.KvExecCounter.RPCInterceptor()) + return interceptor.WithRPCInterceptor(ctx, counter.RPCInterceptor()) } return ctx } diff --git a/pkg/distsql/distsql_test.go b/pkg/distsql/distsql_test.go index 00da879d3f59b..7031e494f636e 100644 --- a/pkg/distsql/distsql_test.go +++ b/pkg/distsql/distsql_test.go @@ -25,7 +25,6 @@ import ( "github.com/pingcap/tidb/pkg/parser/mysql" "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tidb/pkg/sessionctx/stmtctx" - "github.com/pingcap/tidb/pkg/sessionctx/variable" "github.com/pingcap/tidb/pkg/types" "github.com/pingcap/tidb/pkg/util/chunk" "github.com/pingcap/tidb/pkg/util/codec" @@ -161,7 +160,7 @@ func TestAnalyze(t *testing.T) { Build() require.NoError(t, err) - response, err := Analyze(context.TODO(), sctx.GetClient(), request, tikvstore.DefaultVars, true, sctx.GetSessionVars().StmtCtx) + response, err := Analyze(context.TODO(), sctx.GetClient(), request, tikvstore.DefaultVars, true, sctx.GetDistSQLCtx()) require.NoError(t, err) result, ok := response.(*selectResult) @@ -231,7 +230,7 @@ func (resp *mockResponse) Next(context.Context) (kv.ResultSubset, error) { resp.count += numRows var chunks []tipb.Chunk - if !canUseChunkRPC(resp.ctx) { + if !canUseChunkRPC(resp.ctx.GetDistSQLCtx()) { datum := types.NewIntDatum(1) bytes := make([]byte, 0, 100) bytes, _ = codec.EncodeValue(time.UTC, bytes, datum, datum, datum, datum) @@ -269,7 +268,7 @@ func (resp *mockResponse) Next(context.Context) (kv.ResultSubset, error) { Chunks: chunks, OutputCounts: []int64{1}, } - if canUseChunkRPC(resp.ctx) { + if canUseChunkRPC(resp.ctx.GetDistSQLCtx()) { respPB.EncodeType = tipb.EncodeType_TypeChunk } else { respPB.EncodeType = tipb.EncodeType_TypeDefault @@ -320,7 +319,7 @@ func createSelectNormalByBenchmarkTest(batch, totalRows int, ctx sessionctx.Cont SetDAGRequest(&tipb.DAGRequest{}). SetDesc(false). SetKeepOrder(false). - SetFromSessionVars(variable.NewSessionVars(nil)). + SetFromSessionVars(DefaultDistSQLContext). SetMemTracker(memory.NewTracker(-1, -1)). Build() @@ -336,7 +335,7 @@ func createSelectNormalByBenchmarkTest(batch, totalRows int, ctx sessionctx.Cont // Test Next. var response SelectResult - response, _ = Select(context.TODO(), ctx, request, colTypes) + response, _ = Select(context.TODO(), ctx.GetDistSQLCtx(), request, colTypes) result, _ := response.(*selectResult) resp, _ := result.resp.(*mockResponse) @@ -389,7 +388,7 @@ func createSelectNormal(t *testing.T, batch, totalRows int, planIDs []int, sctx SetDAGRequest(&tipb.DAGRequest{}). SetDesc(false). SetKeepOrder(false). - SetFromSessionVars(variable.NewSessionVars(nil)). + SetFromSessionVars(DefaultDistSQLContext). SetMemTracker(memory.NewTracker(-1, -1)). Build() require.NoError(t, err) @@ -411,9 +410,9 @@ func createSelectNormal(t *testing.T, batch, totalRows int, planIDs []int, sctx // Test Next. 
var response SelectResult if planIDs == nil { - response, err = Select(context.TODO(), sctx, request, colTypes) + response, err = Select(context.TODO(), sctx.GetDistSQLCtx(), request, colTypes) } else { - response, err = SelectWithRuntimeStats(context.TODO(), sctx, request, colTypes, planIDs, 1) + response, err = SelectWithRuntimeStats(context.TODO(), sctx.GetDistSQLCtx(), request, colTypes, planIDs, 1) } require.NoError(t, err) diff --git a/pkg/distsql/request_builder.go b/pkg/distsql/request_builder.go index af4da47bc7a3f..302579734d5c9 100644 --- a/pkg/distsql/request_builder.go +++ b/pkg/distsql/request_builder.go @@ -25,13 +25,12 @@ import ( "github.com/pingcap/failpoint" "github.com/pingcap/kvproto/pkg/metapb" "github.com/pingcap/tidb/pkg/ddl/placement" + distsqlctx "github.com/pingcap/tidb/pkg/distsql/context" "github.com/pingcap/tidb/pkg/errctx" "github.com/pingcap/tidb/pkg/infoschema" "github.com/pingcap/tidb/pkg/kv" "github.com/pingcap/tidb/pkg/parser/model" "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/sessionctx/stmtctx" - "github.com/pingcap/tidb/pkg/sessionctx/variable" "github.com/pingcap/tidb/pkg/tablecodec" "github.com/pingcap/tidb/pkg/types" "github.com/pingcap/tidb/pkg/util/codec" @@ -98,35 +97,35 @@ func (builder *RequestBuilder) SetTableRanges(tid int64, tableRanges []*ranger.R // SetIndexRanges sets "KeyRanges" for "kv.Request" by converting index range // "ranges" to "KeyRanges" firstly. -func (builder *RequestBuilder) SetIndexRanges(sc *stmtctx.StatementContext, tid, idxID int64, ranges []*ranger.Range) *RequestBuilder { +func (builder *RequestBuilder) SetIndexRanges(dctx *distsqlctx.DistSQLContext, tid, idxID int64, ranges []*ranger.Range) *RequestBuilder { if builder.err == nil { - builder.Request.KeyRanges, builder.err = IndexRangesToKVRanges(sc, tid, idxID, ranges) + builder.Request.KeyRanges, builder.err = IndexRangesToKVRanges(dctx, tid, idxID, ranges) } return builder } // SetIndexRangesForTables sets "KeyRanges" for "kv.Request" by converting multiple indexes range // "ranges" to "KeyRanges" firstly. -func (builder *RequestBuilder) SetIndexRangesForTables(sc *stmtctx.StatementContext, tids []int64, idxID int64, ranges []*ranger.Range) *RequestBuilder { +func (builder *RequestBuilder) SetIndexRangesForTables(dctx *distsqlctx.DistSQLContext, tids []int64, idxID int64, ranges []*ranger.Range) *RequestBuilder { if builder.err == nil { - builder.Request.KeyRanges, builder.err = IndexRangesToKVRangesForTables(sc, tids, idxID, ranges) + builder.Request.KeyRanges, builder.err = IndexRangesToKVRangesForTables(dctx, tids, idxID, ranges) } return builder } // SetHandleRanges sets "KeyRanges" for "kv.Request" by converting table handle range // "ranges" to "KeyRanges" firstly. -func (builder *RequestBuilder) SetHandleRanges(sc *stmtctx.StatementContext, tid int64, isCommonHandle bool, ranges []*ranger.Range) *RequestBuilder { - builder = builder.SetHandleRangesForTables(sc, []int64{tid}, isCommonHandle, ranges) +func (builder *RequestBuilder) SetHandleRanges(dctx *distsqlctx.DistSQLContext, tid int64, isCommonHandle bool, ranges []*ranger.Range) *RequestBuilder { + builder = builder.SetHandleRangesForTables(dctx, []int64{tid}, isCommonHandle, ranges) builder.err = builder.Request.KeyRanges.SetToNonPartitioned() return builder } // SetHandleRangesForTables sets "KeyRanges" for "kv.Request" by converting table handle range // "ranges" to "KeyRanges" firstly for multiple tables. 
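// Sketch of a caller under the new signature (adapted from the analyze
// executor changes later in this diff; physicalTableID and tblInfo are
// illustrative names):
//
//	builder.SetHandleRangesForTables(ctx.GetDistSQLCtx(),
//		[]int64{physicalTableID}, tblInfo.IsCommonHandle, ranges)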
-func (builder *RequestBuilder) SetHandleRangesForTables(sc *stmtctx.StatementContext, tid []int64, isCommonHandle bool, ranges []*ranger.Range) *RequestBuilder { +func (builder *RequestBuilder) SetHandleRangesForTables(dctx *distsqlctx.DistSQLContext, tid []int64, isCommonHandle bool, ranges []*ranger.Range) *RequestBuilder { if builder.err == nil { - builder.Request.KeyRanges, builder.err = TableHandleRangesToKVRanges(sc, tid, isCommonHandle, ranges) + builder.Request.KeyRanges, builder.err = TableHandleRangesToKVRanges(dctx, tid, isCommonHandle, ranges) } return builder } @@ -264,8 +263,8 @@ func (builder *RequestBuilder) getIsolationLevel() kv.IsoLevel { return kv.SI } -func (*RequestBuilder) getKVPriority(sv *variable.SessionVars) int { - switch sv.StmtCtx.Priority { +func (*RequestBuilder) getKVPriority(dctx *distsqlctx.DistSQLContext) int { + switch dctx.Priority { case mysql.NoPriority, mysql.DelayedPriority: return kv.PriorityNormal case mysql.LowPriority: @@ -279,8 +278,8 @@ func (*RequestBuilder) getKVPriority(sv *variable.SessionVars) int { // SetFromSessionVars sets the following fields for "kv.Request" from session variables: // "Concurrency", "IsolationLevel", "NotFillCache", "TaskID", "Priority", "ReplicaRead", // "ResourceGroupTagger", "ResourceGroupName" -func (builder *RequestBuilder) SetFromSessionVars(sv *variable.SessionVars) *RequestBuilder { - distsqlConcurrency := sv.DistSQLScanConcurrency() +func (builder *RequestBuilder) SetFromSessionVars(dctx *distsqlctx.DistSQLContext) *RequestBuilder { + distsqlConcurrency := dctx.DistSQLConcurrency if builder.Request.Concurrency == 0 { // Concurrency unset. builder.Request.Concurrency = distsqlConcurrency @@ -288,33 +287,33 @@ func (builder *RequestBuilder) SetFromSessionVars(sv *variable.SessionVars) *Req // Concurrency is set in SetDAGRequest, check the upper limit. 
builder.Request.Concurrency = distsqlConcurrency } - replicaReadType := sv.GetReplicaRead() - if sv.StmtCtx.WeakConsistency { + replicaReadType := dctx.ReplicaReadType + if dctx.WeakConsistency { builder.Request.IsolationLevel = kv.RC - } else if sv.StmtCtx.RCCheckTS { + } else if dctx.RCCheckTS { builder.Request.IsolationLevel = kv.RCCheckTS replicaReadType = kv.ReplicaReadLeader } else { builder.Request.IsolationLevel = builder.getIsolationLevel() } - builder.Request.NotFillCache = sv.StmtCtx.NotFillCache - builder.Request.TaskID = sv.StmtCtx.TaskID - builder.Request.Priority = builder.getKVPriority(sv) + builder.Request.NotFillCache = dctx.NotFillCache + builder.Request.TaskID = dctx.TaskID + builder.Request.Priority = builder.getKVPriority(dctx) builder.Request.ReplicaRead = replicaReadType - builder.SetResourceGroupTagger(sv.StmtCtx.GetResourceGroupTagger()) + builder.SetResourceGroupTagger(dctx.ResourceGroupTagger) { - builder.SetPaging(sv.EnablePaging) - builder.Request.Paging.MinPagingSize = uint64(sv.MinPagingSize) - builder.Request.Paging.MaxPagingSize = uint64(sv.MaxPagingSize) + builder.SetPaging(dctx.EnablePaging) + builder.Request.Paging.MinPagingSize = uint64(dctx.MinPagingSize) + builder.Request.Paging.MaxPagingSize = uint64(dctx.MaxPagingSize) } - builder.RequestSource.RequestSourceInternal = sv.InRestrictedSQL - builder.RequestSource.RequestSourceType = sv.RequestSourceType - builder.RequestSource.ExplicitRequestSourceType = sv.ExplicitRequestSourceType - builder.StoreBatchSize = sv.StoreBatchSize - builder.Request.ResourceGroupName = sv.StmtCtx.ResourceGroupName - builder.Request.StoreBusyThreshold = sv.LoadBasedReplicaReadThreshold - builder.Request.RunawayChecker = sv.StmtCtx.RunawayChecker - builder.Request.TiKVClientReadTimeout = sv.GetTiKVClientReadTimeout() + builder.RequestSource.RequestSourceInternal = dctx.InRestrictedSQL + builder.RequestSource.RequestSourceType = dctx.RequestSourceType + builder.RequestSource.ExplicitRequestSourceType = dctx.ExplicitRequestSourceType + builder.StoreBatchSize = dctx.StoreBatchSize + builder.Request.ResourceGroupName = dctx.ResourceGroupName + builder.Request.StoreBusyThreshold = dctx.LoadBasedReplicaReadThreshold + builder.Request.RunawayChecker = dctx.RunawayChecker + builder.Request.TiKVClientReadTimeout = dctx.TiKVClientReadTimeout return builder } @@ -439,11 +438,11 @@ func (builder *RequestBuilder) SetConnIDAndConnAlias(connID uint64, connAlias st } // TableHandleRangesToKVRanges convert table handle ranges to "KeyRanges" for multiple tables. -func TableHandleRangesToKVRanges(sc *stmtctx.StatementContext, tid []int64, isCommonHandle bool, ranges []*ranger.Range) (*kv.KeyRanges, error) { +func TableHandleRangesToKVRanges(dctx *distsqlctx.DistSQLContext, tid []int64, isCommonHandle bool, ranges []*ranger.Range) (*kv.KeyRanges, error) { if !isCommonHandle { return tablesRangesToKVRanges(tid, ranges), nil } - return CommonHandleRangesToKVRanges(sc, tid, ranges) + return CommonHandleRangesToKVRanges(dctx, tid, ranges) } // TableRangesToKVRanges converts table ranges to "KeyRange". @@ -631,14 +630,14 @@ func PartitionHandlesToKVRanges(handles []kv.Handle) ([]kv.KeyRange, []int) { } // IndexRangesToKVRanges converts index ranges to "KeyRange". 
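// As in the updated request_builder tests, a DistSQLContext now supplies the
// time zone and error context used for key encoding (IDs are illustrative):
//
//	krs, err := IndexRangesToKVRanges(DefaultDistSQLContext, 12, 15, ranges)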
-func IndexRangesToKVRanges(sc *stmtctx.StatementContext, tid, idxID int64, ranges []*ranger.Range) (*kv.KeyRanges, error) { - return IndexRangesToKVRangesWithInterruptSignal(sc, tid, idxID, ranges, nil, nil) +func IndexRangesToKVRanges(dctx *distsqlctx.DistSQLContext, tid, idxID int64, ranges []*ranger.Range) (*kv.KeyRanges, error) { + return IndexRangesToKVRangesWithInterruptSignal(dctx, tid, idxID, ranges, nil, nil) } // IndexRangesToKVRangesWithInterruptSignal converts index ranges to "KeyRange". // The process can be interrupted by set `interruptSignal` to true. -func IndexRangesToKVRangesWithInterruptSignal(sc *stmtctx.StatementContext, tid, idxID int64, ranges []*ranger.Range, memTracker *memory.Tracker, interruptSignal *atomic.Value) (*kv.KeyRanges, error) { - keyRanges, err := indexRangesToKVRangesForTablesWithInterruptSignal(sc, []int64{tid}, idxID, ranges, memTracker, interruptSignal) +func IndexRangesToKVRangesWithInterruptSignal(dctx *distsqlctx.DistSQLContext, tid, idxID int64, ranges []*ranger.Range, memTracker *memory.Tracker, interruptSignal *atomic.Value) (*kv.KeyRanges, error) { + keyRanges, err := indexRangesToKVRangesForTablesWithInterruptSignal(dctx, []int64{tid}, idxID, ranges, memTracker, interruptSignal) if err != nil { return nil, err } @@ -647,21 +646,21 @@ func IndexRangesToKVRangesWithInterruptSignal(sc *stmtctx.StatementContext, tid, } // IndexRangesToKVRangesForTables converts indexes ranges to "KeyRange". -func IndexRangesToKVRangesForTables(sc *stmtctx.StatementContext, tids []int64, idxID int64, ranges []*ranger.Range) (*kv.KeyRanges, error) { - return indexRangesToKVRangesForTablesWithInterruptSignal(sc, tids, idxID, ranges, nil, nil) +func IndexRangesToKVRangesForTables(dctx *distsqlctx.DistSQLContext, tids []int64, idxID int64, ranges []*ranger.Range) (*kv.KeyRanges, error) { + return indexRangesToKVRangesForTablesWithInterruptSignal(dctx, tids, idxID, ranges, nil, nil) } // IndexRangesToKVRangesForTablesWithInterruptSignal converts indexes ranges to "KeyRange". // The process can be interrupted by set `interruptSignal` to true. -func indexRangesToKVRangesForTablesWithInterruptSignal(sc *stmtctx.StatementContext, tids []int64, idxID int64, ranges []*ranger.Range, memTracker *memory.Tracker, interruptSignal *atomic.Value) (*kv.KeyRanges, error) { - return indexRangesToKVWithoutSplit(sc, tids, idxID, ranges, memTracker, interruptSignal) +func indexRangesToKVRangesForTablesWithInterruptSignal(dctx *distsqlctx.DistSQLContext, tids []int64, idxID int64, ranges []*ranger.Range, memTracker *memory.Tracker, interruptSignal *atomic.Value) (*kv.KeyRanges, error) { + return indexRangesToKVWithoutSplit(dctx, tids, idxID, ranges, memTracker, interruptSignal) } // CommonHandleRangesToKVRanges converts common handle ranges to "KeyRange". 
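// Each range is encoded via EncodeIndexKey, which reads dctx.Location and
// dctx.ErrCtx when dctx is non-nil and falls back to UTC plus a strict error
// context otherwise (see the EncodeIndexKey change below). A hedged sketch
// with an illustrative table ID:
//
//	krs, err := CommonHandleRangesToKVRanges(dctx, []int64{tblID}, ranger.FullRange())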
-func CommonHandleRangesToKVRanges(sc *stmtctx.StatementContext, tids []int64, ranges []*ranger.Range) (*kv.KeyRanges, error) { +func CommonHandleRangesToKVRanges(dctx *distsqlctx.DistSQLContext, tids []int64, ranges []*ranger.Range) (*kv.KeyRanges, error) { rans := make([]*ranger.Range, 0, len(ranges)) for _, ran := range ranges { - low, high, err := EncodeIndexKey(sc, ran) + low, high, err := EncodeIndexKey(dctx, ran) if err != nil { return nil, err } @@ -711,7 +710,7 @@ func VerifyTxnScope(txnScope string, physicalTableID int64, is infoschema.InfoSc return true } -func indexRangesToKVWithoutSplit(sc *stmtctx.StatementContext, tids []int64, idxID int64, ranges []*ranger.Range, memTracker *memory.Tracker, interruptSignal *atomic.Value) (*kv.KeyRanges, error) { +func indexRangesToKVWithoutSplit(dctx *distsqlctx.DistSQLContext, tids []int64, idxID int64, ranges []*ranger.Range, memTracker *memory.Tracker, interruptSignal *atomic.Value) (*kv.KeyRanges, error) { krs := make([][]kv.KeyRange, len(tids)) for i := range krs { krs[i] = make([]kv.KeyRange, 0, len(ranges)) @@ -722,7 +721,7 @@ func indexRangesToKVWithoutSplit(sc *stmtctx.StatementContext, tids []int64, idx // encodeIndexKey and EncodeIndexSeekKey is time-consuming, thus we need to // check the interrupt signal periodically. for i, ran := range ranges { - low, high, err := EncodeIndexKey(sc, ran) + low, high, err := EncodeIndexKey(dctx, ran) if err != nil { return nil, err } @@ -751,12 +750,12 @@ func indexRangesToKVWithoutSplit(sc *stmtctx.StatementContext, tids []int64, idx } // EncodeIndexKey gets encoded keys containing low and high -func EncodeIndexKey(sc *stmtctx.StatementContext, ran *ranger.Range) ([]byte, []byte, error) { +func EncodeIndexKey(dctx *distsqlctx.DistSQLContext, ran *ranger.Range) ([]byte, []byte, error) { tz := time.UTC errCtx := errctx.StrictNoWarningContext - if sc != nil { - tz = sc.TimeZone() - errCtx = sc.ErrCtx() + if dctx != nil { + tz = dctx.Location + errCtx = dctx.ErrCtx } low, err := codec.EncodeKey(tz, nil, ran.LowVal...) diff --git a/pkg/distsql/request_builder_test.go b/pkg/distsql/request_builder_test.go index 459c10131e146..ffe180400f787 100644 --- a/pkg/distsql/request_builder_test.go +++ b/pkg/distsql/request_builder_test.go @@ -22,7 +22,6 @@ import ( "github.com/pingcap/tidb/pkg/domain/resourcegroup" "github.com/pingcap/tidb/pkg/kv" "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/sessionctx/stmtctx" "github.com/pingcap/tidb/pkg/sessionctx/variable" "github.com/pingcap/tidb/pkg/tablecodec" "github.com/pingcap/tidb/pkg/types" @@ -192,7 +191,7 @@ func TestIndexRangesToKVRanges(t *testing.T) { }, } - actual, err := IndexRangesToKVRanges(stmtctx.NewStmtCtx(), 12, 15, ranges) + actual, err := IndexRangesToKVRanges(DefaultDistSQLContext, 12, 15, ranges) require.NoError(t, err) for i := range actual.FirstPartitionRange() { require.Equal(t, expect[i], actual.FirstPartitionRange()[i]) @@ -237,7 +236,7 @@ func TestRequestBuilder1(t *testing.T) { SetDAGRequest(&tipb.DAGRequest{}). SetDesc(false). SetKeepOrder(false). - SetFromSessionVars(variable.NewSessionVars(nil)). + SetFromSessionVars(DefaultDistSQLContext). Build() require.NoError(t, err) expect := &kv.Request{ @@ -317,11 +316,11 @@ func TestRequestBuilder2(t *testing.T) { }, } - actual, err := (&RequestBuilder{}).SetIndexRanges(stmtctx.NewStmtCtx(), 12, 15, ranges). + actual, err := (&RequestBuilder{}).SetIndexRanges(DefaultDistSQLContext, 12, 15, ranges). SetDAGRequest(&tipb.DAGRequest{}). SetDesc(false). 
SetKeepOrder(false). - SetFromSessionVars(variable.NewSessionVars(nil)). + SetFromSessionVars(DefaultDistSQLContext). Build() require.NoError(t, err) expect := &kv.Request{ @@ -375,7 +374,7 @@ func TestRequestBuilder3(t *testing.T) { SetDAGRequest(&tipb.DAGRequest{}). SetDesc(false). SetKeepOrder(false). - SetFromSessionVars(variable.NewSessionVars(nil)). + SetFromSessionVars(DefaultDistSQLContext). Build() require.NoError(t, err) expect := &kv.Request{ @@ -441,7 +440,7 @@ func TestRequestBuilder4(t *testing.T) { SetDAGRequest(&tipb.DAGRequest{}). SetDesc(false). SetKeepOrder(false). - SetFromSessionVars(variable.NewSessionVars(nil)). + SetFromSessionVars(DefaultDistSQLContext). Build() require.NoError(t, err) expect := &kv.Request{ @@ -549,12 +548,12 @@ func TestRequestBuilder7(t *testing.T) { // copy iterator variable into a new variable, see issue #27779 replicaRead := replicaRead t.Run(replicaRead.src, func(t *testing.T) { - vars := variable.NewSessionVars(nil) - vars.SetReplicaRead(replicaRead.replicaReadType) + dctx := NewDistSQLContextForTest() + dctx.ReplicaReadType = replicaRead.replicaReadType concurrency := 10 actual, err := (&RequestBuilder{}). - SetFromSessionVars(vars). + SetFromSessionVars(dctx). SetConcurrency(concurrency). Build() require.NoError(t, err) @@ -581,10 +580,10 @@ func TestRequestBuilder7(t *testing.T) { } func TestRequestBuilder8(t *testing.T) { - sv := variable.NewSessionVars(nil) - sv.StmtCtx.ResourceGroupName = "test" + dctx := NewDistSQLContextForTest() + dctx.ResourceGroupName = "test" actual, err := (&RequestBuilder{}). - SetFromSessionVars(sv). + SetFromSessionVars(dctx). Build() require.NoError(t, err) expect := &kv.Request{ @@ -607,10 +606,10 @@ func TestRequestBuilder8(t *testing.T) { } func TestRequestBuilderTiKVClientReadTimeout(t *testing.T) { - sv := variable.NewSessionVars(nil) - sv.TiKVClientReadTimeout = 100 + dctx := NewDistSQLContextForTest() + dctx.TiKVClientReadTimeout = 100 actual, err := (&RequestBuilder{}). - SetFromSessionVars(sv). + SetFromSessionVars(dctx). Build() require.NoError(t, err) expect := &kv.Request{ @@ -662,7 +661,7 @@ func TestIndexRangesToKVRangesWithFbs(t *testing.T) { Collators: collate.GetBinaryCollatorSlice(1), }, } - actual, err := IndexRangesToKVRanges(stmtctx.NewStmtCtx(), 0, 0, ranges) + actual, err := IndexRangesToKVRanges(DefaultDistSQLContext, 0, 0, ranges) require.NoError(t, err) expect := []kv.KeyRange{ { @@ -676,7 +675,7 @@ func TestIndexRangesToKVRangesWithFbs(t *testing.T) { } func TestScanLimitConcurrency(t *testing.T) { - vars := variable.NewSessionVars(nil) + dctx := NewDistSQLContextForTest() for _, tt := range []struct { tp tipb.ExecType limit uint64 @@ -685,8 +684,8 @@ func TestScanLimitConcurrency(t *testing.T) { }{ {tipb.ExecType_TypeTableScan, 1, 1, "TblScan_Def"}, {tipb.ExecType_TypeIndexScan, 1, 1, "IdxScan_Def"}, - {tipb.ExecType_TypeTableScan, 1000000, vars.Concurrency.DistSQLScanConcurrency(), "TblScan_SessionVars"}, - {tipb.ExecType_TypeIndexScan, 1000000, vars.Concurrency.DistSQLScanConcurrency(), "IdxScan_SessionVars"}, + {tipb.ExecType_TypeTableScan, 1000000, dctx.DistSQLConcurrency, "TblScan_SessionVars"}, + {tipb.ExecType_TypeIndexScan, 1000000, dctx.DistSQLConcurrency, "IdxScan_SessionVars"}, } { // copy iterator variable into a new variable, see issue #27779 tt := tt @@ -703,7 +702,7 @@ func TestScanLimitConcurrency(t *testing.T) { dag := &tipb.DAGRequest{Executors: []*tipb.Executor{firstExec, limitExec}} actual, err := (&RequestBuilder{}). SetDAGRequest(dag). 
- SetFromSessionVars(vars). + SetFromSessionVars(dctx). Build() require.NoError(t, err) require.Equal(t, tt.concurrency, actual.Concurrency) diff --git a/pkg/distsql/select_result.go b/pkg/distsql/select_result.go index d298be1ff4b61..ac7f8713204d4 100644 --- a/pkg/distsql/select_result.go +++ b/pkg/distsql/select_result.go @@ -33,7 +33,6 @@ import ( "github.com/pingcap/tidb/pkg/metrics" "github.com/pingcap/tidb/pkg/parser/terror" "github.com/pingcap/tidb/pkg/planner/util" - "github.com/pingcap/tidb/pkg/sessionctx/stmtctx" "github.com/pingcap/tidb/pkg/store/copr" "github.com/pingcap/tidb/pkg/types" "github.com/pingcap/tidb/pkg/util/chunk" @@ -286,7 +285,7 @@ type selectResult struct { rowLen int fieldTypes []*types.FieldType - ctx dcontext.DistSQLContext + ctx *dcontext.DistSQLContext selectResp *tipb.SelectResponse selectRespSize int64 // record the selectResp.Size() when it is initialized. @@ -353,13 +352,11 @@ func (r *selectResult) fetchResp(ctx context.Context) error { if err := r.selectResp.Error; err != nil { return dbterror.ClassTiKV.Synthesize(terror.ErrCode(err.Code), err.Msg) } - sessVars := r.ctx.GetSessionVars() - if err = sessVars.SQLKiller.HandleSignal(); err != nil { + if err = r.ctx.SQLKiller.HandleSignal(); err != nil { return err } - sc := sessVars.StmtCtx for _, warning := range r.selectResp.Warnings { - sc.AppendWarning(dbterror.ClassTiKV.Synthesize(terror.ErrCode(warning.Code), warning.Msg)) + r.ctx.AppendWarning(dbterror.ClassTiKV.Synthesize(terror.ErrCode(warning.Code), warning.Msg)) } r.partialCount++ @@ -372,7 +369,7 @@ func (r *selectResult) fetchResp(ctx context.Context) error { return err } copStats.CopTime = duration - sc.MergeExecDetails(&copStats.ExecDetails, nil) + r.ctx.ExecDetails.MergeExecDetails(&copStats.ExecDetails, nil) } } if len(r.selectResp.Chunks) != 0 { @@ -477,32 +474,32 @@ func (r *selectResult) readFromChunk(ctx context.Context, chk *chunk.Chunk) erro } // FillDummySummariesForTiFlashTasks fills dummy execution summaries for mpp tasks which lack summaries -func FillDummySummariesForTiFlashTasks(sctx *stmtctx.StatementContext, callee string, storeTypeName string, allPlanIDs []int, recordedPlanIDs map[int]int) { +func FillDummySummariesForTiFlashTasks(runtimeStatsColl *execdetails.RuntimeStatsColl, callee string, storeTypeName string, allPlanIDs []int, recordedPlanIDs map[int]int) { num := uint64(0) dummySummary := &tipb.ExecutorExecutionSummary{TimeProcessedNs: &num, NumProducedRows: &num, NumIterations: &num, ExecutorId: nil} for _, planID := range allPlanIDs { if _, ok := recordedPlanIDs[planID]; !ok { - sctx.RuntimeStatsColl.RecordOneCopTask(planID, storeTypeName, callee, dummySummary) + runtimeStatsColl.RecordOneCopTask(planID, storeTypeName, callee, dummySummary) } } } // recordExecutionSummariesForTiFlashTasks records mpp task execution summaries -func recordExecutionSummariesForTiFlashTasks(sctx *stmtctx.StatementContext, executionSummaries []*tipb.ExecutorExecutionSummary, callee string, storeTypeName string, allPlanIDs []int) { +func recordExecutionSummariesForTiFlashTasks(runtimeStatsColl *execdetails.RuntimeStatsColl, executionSummaries []*tipb.ExecutorExecutionSummary, callee string, storeTypeName string, allPlanIDs []int) { var recordedPlanIDs = make(map[int]int) for _, detail := range executionSummaries { if detail != nil && detail.TimeProcessedNs != nil && detail.NumProducedRows != nil && detail.NumIterations != nil { - recordedPlanIDs[sctx.RuntimeStatsColl. + recordedPlanIDs[runtimeStatsColl. 
RecordOneCopTask(-1, storeTypeName, callee, detail)] = 0 } } - FillDummySummariesForTiFlashTasks(sctx, callee, storeTypeName, allPlanIDs, recordedPlanIDs) + FillDummySummariesForTiFlashTasks(runtimeStatsColl, callee, storeTypeName, allPlanIDs, recordedPlanIDs) } func (r *selectResult) updateCopRuntimeStats(ctx context.Context, copStats *copr.CopRuntimeStats, respTime time.Duration) (err error) { callee := copStats.CalleeAddress - if r.rootPlanID <= 0 || r.ctx.GetSessionVars().StmtCtx.RuntimeStatsColl == nil || (callee == "" && len(copStats.Stats) == 0) { + if r.rootPlanID <= 0 || r.ctx.RuntimeStatsColl == nil || (callee == "" && len(copStats.Stats) == 0) { return } @@ -528,7 +525,7 @@ func (r *selectResult) updateCopRuntimeStats(ctx context.Context, copStats *copr r.stats.mergeCopRuntimeStats(copStats, respTime) if copStats.ScanDetail != nil && len(r.copPlanIDs) > 0 { - r.ctx.GetSessionVars().StmtCtx.RuntimeStatsColl.RecordScanDetail(r.copPlanIDs[len(r.copPlanIDs)-1], r.storeType.Name(), copStats.ScanDetail) + r.ctx.RuntimeStatsColl.RecordScanDetail(r.copPlanIDs[len(r.copPlanIDs)-1], r.storeType.Name(), copStats.ScanDetail) } // If hasExecutor is true, it means the summary is returned from TiFlash. @@ -549,7 +546,7 @@ func (r *selectResult) updateCopRuntimeStats(ctx context.Context, copStats *copr } } if hasExecutor { - recordExecutionSummariesForTiFlashTasks(r.ctx.GetSessionVars().StmtCtx, r.selectResp.GetExecutionSummaries(), callee, r.storeType.Name(), r.copPlanIDs) + recordExecutionSummariesForTiFlashTasks(r.ctx.RuntimeStatsColl, r.selectResp.GetExecutionSummaries(), callee, r.storeType.Name(), r.copPlanIDs) } else { // For cop task cases, we still need this protection. if len(r.selectResp.GetExecutionSummaries()) != len(r.copPlanIDs) { @@ -567,7 +564,7 @@ func (r *selectResult) updateCopRuntimeStats(ctx context.Context, copStats *copr if detail != nil && detail.TimeProcessedNs != nil && detail.NumProducedRows != nil && detail.NumIterations != nil { planID := r.copPlanIDs[i] - r.ctx.GetSessionVars().StmtCtx.RuntimeStatsColl. + r.ctx.RuntimeStatsColl. 
RecordOneCopTask(planID, r.storeType.Name(), callee, detail) } } @@ -577,7 +574,7 @@ func (r *selectResult) updateCopRuntimeStats(ctx context.Context, copStats *copr func (r *selectResult) readRowsData(chk *chunk.Chunk) (err error) { rowsData := r.selectResp.Chunks[r.respChkIdx].RowsData - decoder := codec.NewDecoder(chk, r.ctx.GetSessionVars().Location()) + decoder := codec.NewDecoder(chk, r.ctx.Location) for !chk.IsFull() && len(rowsData) > 0 { for i := 0; i < r.rowLen; i++ { rowsData, err = decoder.DecodeOne(rowsData, i, r.fieldTypes[i]) @@ -608,7 +605,7 @@ func (r *selectResult) Close() error { unconsumedCopStats := unconsumed.CollectUnconsumedCopRuntimeStats() for _, copStats := range unconsumedCopStats { _ = r.updateCopRuntimeStats(context.Background(), copStats, time.Duration(0)) - r.ctx.GetSessionVars().StmtCtx.MergeExecDetails(&copStats.ExecDetails, nil) + r.ctx.ExecDetails.MergeExecDetails(&copStats.ExecDetails, nil) } } } @@ -621,7 +618,7 @@ func (r *selectResult) Close() error { r.stats.storeBatchedNum, r.stats.storeBatchedFallbackNum = batched, fallback } } - r.ctx.GetSessionVars().StmtCtx.RuntimeStatsColl.RegisterStats(r.rootPlanID, r.stats) + r.ctx.RuntimeStatsColl.RegisterStats(r.rootPlanID, r.stats) }() } return r.resp.Close() diff --git a/pkg/distsql/select_result_test.go b/pkg/distsql/select_result_test.go index 63ae252a836a7..478f9dcd8d205 100644 --- a/pkg/distsql/select_result_test.go +++ b/pkg/distsql/select_result_test.go @@ -30,13 +30,15 @@ import ( func TestUpdateCopRuntimeStats(t *testing.T) { ctx := mock.NewContext() ctx.GetSessionVars().StmtCtx = stmtctx.NewStmtCtx() - sr := selectResult{ctx: ctx, storeType: kv.TiKV} + sr := selectResult{ctx: ctx.GetDistSQLCtx(), storeType: kv.TiKV} require.Nil(t, ctx.GetSessionVars().StmtCtx.RuntimeStatsColl) sr.rootPlanID = 1234 sr.updateCopRuntimeStats(context.Background(), &copr.CopRuntimeStats{ExecDetails: execdetails.ExecDetails{DetailsNeedP90: execdetails.DetailsNeedP90{CalleeAddress: "a"}}}, 0) ctx.GetSessionVars().StmtCtx.RuntimeStatsColl = execdetails.NewRuntimeStatsColl(nil) + // refresh the ctx after assigning `RuntimeStatsColl`. + sr.ctx = ctx.GetDistSQLCtx() i := uint64(1) sr.selectResp = &tipb.SelectResponse{ ExecutionSummaries: []*tipb.ExecutorExecutionSummary{ diff --git a/pkg/executor/admin.go b/pkg/executor/admin.go index 9f3d0257c6c34..b75fae4f4c5db 100644 --- a/pkg/executor/admin.go +++ b/pkg/executor/admin.go @@ -116,17 +116,16 @@ func (e *CheckIndexRangeExec) Open(ctx context.Context) error { if err != nil { return err } - sc := e.Ctx().GetSessionVars().StmtCtx txn, err := e.Ctx().Txn(true) if err != nil { return nil } var builder distsql.RequestBuilder - kvReq, err := builder.SetIndexRanges(sc, e.table.ID, e.index.ID, ranger.FullRange()). + kvReq, err := builder.SetIndexRanges(e.Ctx().GetDistSQLCtx(), e.table.ID, e.index.ID, ranger.FullRange()). SetDAGRequest(dagPB). SetStartTS(txn.StartTS()). SetKeepOrder(true). - SetFromSessionVars(e.Ctx().GetSessionVars()). + SetFromSessionVars(e.Ctx().GetDistSQLCtx()). SetFromInfoSchema(e.Ctx().GetInfoSchema()). SetConnIDAndConnAlias(e.Ctx().GetSessionVars().ConnectionID, e.Ctx().GetSessionVars().SessionAlias). 
Build() @@ -134,7 +133,7 @@ func (e *CheckIndexRangeExec) Open(ctx context.Context) error { return err } - e.result, err = distsql.Select(ctx, e.Ctx(), kvReq, e.RetFieldTypes()) + e.result, err = distsql.Select(ctx, e.Ctx().GetDistSQLCtx(), kvReq, e.RetFieldTypes()) if err != nil { return err } @@ -156,7 +155,7 @@ func (e *CheckIndexRangeExec) buildDAGPB() (*tipb.DAGRequest, error) { if err != nil { return nil, err } - distsql.SetEncodeType(e.Ctx(), dagReq) + distsql.SetEncodeType(e.Ctx().GetDistSQLCtx(), dagReq) return dagReq, nil } @@ -259,7 +258,7 @@ func (e *RecoverIndexExec) buildDAGPB(_ kv.Transaction, limitCnt uint64) (*tipb. limitExec := e.constructLimitPB(limitCnt) dagReq.Executors = append(dagReq.Executors, limitExec) - distsql.SetEncodeType(e.Ctx(), dagReq) + distsql.SetEncodeType(e.Ctx().GetDistSQLCtx(), dagReq) return dagReq, nil } @@ -278,7 +277,7 @@ func (e *RecoverIndexExec) buildTableScan(ctx context.Context, txn kv.Transactio SetDAGRequest(dagPB). SetStartTS(txn.StartTS()). SetKeepOrder(true). - SetFromSessionVars(e.Ctx().GetSessionVars()). + SetFromSessionVars(e.Ctx().GetDistSQLCtx()). SetFromInfoSchema(e.Ctx().GetInfoSchema()). SetConnIDAndConnAlias(e.Ctx().GetSessionVars().ConnectionID, e.Ctx().GetSessionVars().SessionAlias). Build() @@ -289,7 +288,7 @@ func (e *RecoverIndexExec) buildTableScan(ctx context.Context, txn kv.Transactio // Actually, with limitCnt, the match datas maybe only in one region, so let the concurrency to be 1, // avoid unnecessary region scan. kvReq.Concurrency = 1 - result, err := distsql.Select(ctx, e.Ctx(), kvReq, e.columnsTypes()) + result, err := distsql.Select(ctx, e.Ctx().GetDistSQLCtx(), kvReq, e.columnsTypes()) if err != nil { return nil, err } @@ -792,10 +791,9 @@ func (e *CleanupIndexExec) buildIndexScan(ctx context.Context, txn kv.Transactio if err != nil { return nil, err } - sc := e.Ctx().GetSessionVars().StmtCtx var builder distsql.RequestBuilder ranges := ranger.FullRange() - keyRanges, err := distsql.IndexRangesToKVRanges(sc, e.physicalID, e.index.Meta().ID, ranges) + keyRanges, err := distsql.IndexRangesToKVRanges(e.Ctx().GetDistSQLCtx(), e.physicalID, e.index.Meta().ID, ranges) if err != nil { return nil, err } @@ -808,7 +806,7 @@ func (e *CleanupIndexExec) buildIndexScan(ctx context.Context, txn kv.Transactio SetDAGRequest(dagPB). SetStartTS(txn.StartTS()). SetKeepOrder(true). - SetFromSessionVars(e.Ctx().GetSessionVars()). + SetFromSessionVars(e.Ctx().GetDistSQLCtx()). SetFromInfoSchema(e.Ctx().GetInfoSchema()). SetConnIDAndConnAlias(e.Ctx().GetSessionVars().ConnectionID, e.Ctx().GetSessionVars().SessionAlias). 
Build() @@ -817,7 +815,7 @@ func (e *CleanupIndexExec) buildIndexScan(ctx context.Context, txn kv.Transactio } kvReq.Concurrency = 1 - result, err := distsql.Select(ctx, e.Ctx(), kvReq, e.getIdxColTypes()) + result, err := distsql.Select(ctx, e.Ctx().GetDistSQLCtx(), kvReq, e.getIdxColTypes()) if err != nil { return nil, err } @@ -864,7 +862,7 @@ func (e *CleanupIndexExec) buildIdxDAGPB() (*tipb.DAGRequest, error) { limitExec := e.constructLimitPB() dagReq.Executors = append(dagReq.Executors, limitExec) - distsql.SetEncodeType(e.Ctx(), dagReq) + distsql.SetEncodeType(e.Ctx().GetDistSQLCtx(), dagReq) return dagReq, nil } diff --git a/pkg/executor/analyze_col.go b/pkg/executor/analyze_col.go index 99f9af2b44c76..989339daa4233 100644 --- a/pkg/executor/analyze_col.go +++ b/pkg/executor/analyze_col.go @@ -106,7 +106,7 @@ func (e *AnalyzeColumnsExec) open(ranges []*ranger.Range) error { func (e *AnalyzeColumnsExec) buildResp(ranges []*ranger.Range) (distsql.SelectResult, error) { var builder distsql.RequestBuilder - reqBuilder := builder.SetHandleRangesForTables(e.ctx.GetSessionVars().StmtCtx, []int64{e.TableID.GetStatisticsID()}, e.handleCols != nil && !e.handleCols.IsInt(), ranges) + reqBuilder := builder.SetHandleRangesForTables(e.ctx.GetDistSQLCtx(), []int64{e.TableID.GetStatisticsID()}, e.handleCols != nil && !e.handleCols.IsInt(), ranges) builder.SetResourceGroupTagger(e.ctx.GetSessionVars().StmtCtx.GetResourceGroupTagger()) startTS := uint64(math.MaxUint64) isoLevel := kv.RC @@ -129,7 +129,7 @@ func (e *AnalyzeColumnsExec) buildResp(ranges []*ranger.Range) (distsql.SelectRe return nil, err } ctx := context.TODO() - result, err := distsql.Analyze(ctx, e.ctx.GetClient(), kvReq, e.ctx.GetSessionVars().KVVars, e.ctx.GetSessionVars().InRestrictedSQL, e.ctx.GetSessionVars().StmtCtx) + result, err := distsql.Analyze(ctx, e.ctx.GetClient(), kvReq, e.ctx.GetSessionVars().KVVars, e.ctx.GetSessionVars().InRestrictedSQL, e.ctx.GetDistSQLCtx()) if err != nil { return nil, err } diff --git a/pkg/executor/analyze_idx.go b/pkg/executor/analyze_idx.go index 9bdf6a8455a68..a5b62957dfab6 100644 --- a/pkg/executor/analyze_idx.go +++ b/pkg/executor/analyze_idx.go @@ -143,9 +143,9 @@ func (e *AnalyzeIndexExec) fetchAnalyzeResult(ranges []*ranger.Range, isNullRang var builder distsql.RequestBuilder var kvReqBuilder *distsql.RequestBuilder if e.isCommonHandle && e.idxInfo.Primary { - kvReqBuilder = builder.SetHandleRangesForTables(e.ctx.GetSessionVars().StmtCtx, []int64{e.tableID.GetStatisticsID()}, true, ranges) + kvReqBuilder = builder.SetHandleRangesForTables(e.ctx.GetDistSQLCtx(), []int64{e.tableID.GetStatisticsID()}, true, ranges) } else { - kvReqBuilder = builder.SetIndexRangesForTables(e.ctx.GetSessionVars().StmtCtx, []int64{e.tableID.GetStatisticsID()}, e.idxInfo.ID, ranges) + kvReqBuilder = builder.SetIndexRangesForTables(e.ctx.GetDistSQLCtx(), []int64{e.tableID.GetStatisticsID()}, e.idxInfo.ID, ranges) } kvReqBuilder.SetResourceGroupTagger(e.ctx.GetSessionVars().StmtCtx.GetResourceGroupTagger()) startTS := uint64(math.MaxUint64) @@ -166,7 +166,7 @@ func (e *AnalyzeIndexExec) fetchAnalyzeResult(ranges []*ranger.Range, isNullRang return err } ctx := context.TODO() - result, err := distsql.Analyze(ctx, e.ctx.GetClient(), kvReq, e.ctx.GetSessionVars().KVVars, e.ctx.GetSessionVars().InRestrictedSQL, e.ctx.GetSessionVars().StmtCtx) + result, err := distsql.Analyze(ctx, e.ctx.GetClient(), kvReq, e.ctx.GetSessionVars().KVVars, e.ctx.GetSessionVars().InRestrictedSQL, e.ctx.GetDistSQLCtx()) if err != nil { 
return err } diff --git a/pkg/executor/builder.go b/pkg/executor/builder.go index aa43633b1cf84..e6fc4eb67da81 100644 --- a/pkg/executor/builder.go +++ b/pkg/executor/builder.go @@ -36,6 +36,7 @@ import ( "github.com/pingcap/tidb/pkg/ddl" "github.com/pingcap/tidb/pkg/ddl/placement" "github.com/pingcap/tidb/pkg/distsql" + distsqlctx "github.com/pingcap/tidb/pkg/distsql/context" "github.com/pingcap/tidb/pkg/domain" "github.com/pingcap/tidb/pkg/executor/aggfuncs" "github.com/pingcap/tidb/pkg/executor/aggregate" @@ -4094,7 +4095,7 @@ func (builder *dataReaderBuilder) buildTableReaderForIndexJoin(ctx context.Conte tbInfo := e.table.Meta() if tbInfo.GetPartitionInfo() == nil || !builder.ctx.GetSessionVars().StmtCtx.UseDynamicPartitionPrune() { if v.IsCommonHandle { - kvRanges, err := buildKvRangesForIndexJoin(e.GetSessionVars().StmtCtx, e.pctx, getPhysicalTableID(e.table), -1, lookUpContents, indexRanges, keyOff2IdxOff, cwc, memTracker, interruptSignal) + kvRanges, err := buildKvRangesForIndexJoin(e.dctx, e.pctx, getPhysicalTableID(e.table), -1, lookUpContents, indexRanges, keyOff2IdxOff, cwc, memTracker, interruptSignal) if err != nil { return nil, err } @@ -4144,7 +4145,7 @@ func (builder *dataReaderBuilder) buildTableReaderForIndexJoin(ctx context.Conte } for pid, contents := range lookUpContentsByPID { // buildKvRanges for each partition. - tmp, err := buildKvRangesForIndexJoin(e.GetSessionVars().StmtCtx, e.pctx, pid, -1, contents, indexRanges, keyOff2IdxOff, cwc, nil, interruptSignal) + tmp, err := buildKvRangesForIndexJoin(e.dctx, e.pctx, pid, -1, contents, indexRanges, keyOff2IdxOff, cwc, nil, interruptSignal) if err != nil { return nil, err } @@ -4153,7 +4154,7 @@ func (builder *dataReaderBuilder) buildTableReaderForIndexJoin(ctx context.Conte } else { kvRanges = make([]kv.KeyRange, 0, len(usedPartitions)*len(lookUpContents)) for _, p := range usedPartitionList { - tmp, err := buildKvRangesForIndexJoin(e.GetSessionVars().StmtCtx, e.pctx, p.GetPhysicalID(), -1, lookUpContents, indexRanges, keyOff2IdxOff, cwc, memTracker, interruptSignal) + tmp, err := buildKvRangesForIndexJoin(e.dctx, e.pctx, p.GetPhysicalID(), -1, lookUpContents, indexRanges, keyOff2IdxOff, cwc, memTracker, interruptSignal) if err != nil { return nil, err } @@ -4241,7 +4242,7 @@ func (h kvRangeBuilderFromRangeAndPartition) buildKeyRangeSeparately(ranges []*r if len(ranges) == 0 { continue } - kvRange, err := distsql.TableHandleRangesToKVRanges(h.sctx.GetSessionVars().StmtCtx, []int64{pid}, meta != nil && meta.IsCommonHandle, ranges) + kvRange, err := distsql.TableHandleRangesToKVRanges(h.sctx.GetDistSQLCtx(), []int64{pid}, meta != nil && meta.IsCommonHandle, ranges) if err != nil { return nil, nil, err } @@ -4258,7 +4259,7 @@ func (h kvRangeBuilderFromRangeAndPartition) buildKeyRange(ranges []*ranger.Rang for i, p := range h.partitions { pid := p.GetPhysicalID() meta := p.Meta() - kvRange, err := distsql.TableHandleRangesToKVRanges(h.sctx.GetSessionVars().StmtCtx, []int64{pid}, meta != nil && meta.IsCommonHandle, ranges) + kvRange, err := distsql.TableHandleRangesToKVRanges(h.sctx.GetDistSQLCtx(), []int64{pid}, meta != nil && meta.IsCommonHandle, ranges) if err != nil { return nil, err } @@ -4269,13 +4270,13 @@ func (h kvRangeBuilderFromRangeAndPartition) buildKeyRange(ranges []*ranger.Rang // newClosestReadAdjuster let the request be sent to closest replica(within the same zone) // if response size exceeds certain threshold. 
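// Worked example with illustrative numbers: netDataSize = 4<<20 bytes across
// copTaskCount = 4 gives an estimated 1 MiB per cop task; if that is >=
// dctx.ReplicaClosestReadThreshold, the adjuster appends a MatchStoreLabels
// entry for the local zone so the read is served by the closest replicas.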
-func newClosestReadAdjuster(vars *variable.SessionVars, req *kv.Request, netDataSize float64) kv.CoprRequestAdjuster { +func newClosestReadAdjuster(dctx *distsqlctx.DistSQLContext, req *kv.Request, netDataSize float64) kv.CoprRequestAdjuster { if req.ReplicaRead != kv.ReplicaReadClosestAdaptive { return nil } return func(req *kv.Request, copTaskCount int) bool { // copTaskCount is the number of coprocessor requests - if int64(netDataSize/float64(copTaskCount)) >= vars.ReplicaClosestReadThreshold { + if int64(netDataSize/float64(copTaskCount)) >= dctx.ReplicaClosestReadThreshold { req.MatchStoreLabels = append(req.MatchStoreLabels, &metapb.StoreLabel{ Key: placement.DCLabelKey, Value: config.GetTxnScopeFromConfig(), @@ -4301,18 +4302,18 @@ func (builder *dataReaderBuilder) buildTableReaderBase(ctx context.Context, e *T SetTxnScope(e.txnScope). SetReadReplicaScope(e.readReplicaScope). SetIsStaleness(e.isStaleness). - SetFromSessionVars(e.GetSessionVars()). + SetFromSessionVars(e.dctx). SetFromInfoSchema(e.GetInfoSchema()). - SetClosestReplicaReadAdjuster(newClosestReadAdjuster(e.GetSessionVars(), &reqBuilderWithRange.Request, e.netDataSize)). + SetClosestReplicaReadAdjuster(newClosestReadAdjuster(e.dctx, &reqBuilderWithRange.Request, e.netDataSize)). SetPaging(e.paging). - SetConnIDAndConnAlias(e.GetSessionVars().ConnectionID, e.GetSessionVars().SessionAlias). + SetConnIDAndConnAlias(e.dctx.ConnectionID, e.dctx.SessionAlias). Build() if err != nil { return nil, err } e.kvRanges = kvReq.KeyRanges.AppendSelfTo(e.kvRanges) e.resultHandler = &tableResultHandler{} - result, err := builder.SelectResult(ctx, builder.ctx, kvReq, exec.RetTypes(e), getPhysicalPlanIDs(e.plans), e.ID()) + result, err := builder.SelectResult(ctx, builder.ctx.GetDistSQLCtx(), kvReq, exec.RetTypes(e), getPhysicalPlanIDs(e.plans), e.ID()) if err != nil { return nil, err } @@ -4353,7 +4354,7 @@ func (builder *dataReaderBuilder) buildIndexReaderForIndexJoin(ctx context.Conte } tbInfo := e.table.Meta() if tbInfo.GetPartitionInfo() == nil || !builder.ctx.GetSessionVars().StmtCtx.UseDynamicPartitionPrune() { - kvRanges, err := buildKvRangesForIndexJoin(e.Ctx().GetSessionVars().StmtCtx, e.Ctx().GetPlanCtx(), e.physicalTableID, e.index.ID, lookUpContents, indexRanges, keyOff2IdxOff, cwc, memoryTracker, interruptSignal) + kvRanges, err := buildKvRangesForIndexJoin(e.Ctx().GetDistSQLCtx(), e.Ctx().GetPlanCtx(), e.physicalTableID, e.index.ID, lookUpContents, indexRanges, keyOff2IdxOff, cwc, memoryTracker, interruptSignal) if err != nil { return nil, err } @@ -4424,7 +4425,7 @@ func (builder *dataReaderBuilder) buildIndexLookUpReaderForIndexJoin(ctx context tbInfo := e.table.Meta() if tbInfo.GetPartitionInfo() == nil || !builder.ctx.GetSessionVars().StmtCtx.UseDynamicPartitionPrune() { - e.kvRanges, err = buildKvRangesForIndexJoin(e.Ctx().GetSessionVars().StmtCtx, e.Ctx().GetPlanCtx(), getPhysicalTableID(e.table), e.index.ID, lookUpContents, indexRanges, keyOff2IdxOff, cwc, memTracker, interruptSignal) + e.kvRanges, err = buildKvRangesForIndexJoin(e.Ctx().GetDistSQLCtx(), e.Ctx().GetPlanCtx(), getPhysicalTableID(e.table), e.index.ID, lookUpContents, indexRanges, keyOff2IdxOff, cwc, memTracker, interruptSignal) if err != nil { return nil, err } @@ -4567,7 +4568,7 @@ func buildRangesForIndexJoin(ctx sessionctx.Context, lookUpContents []*indexJoin } // buildKvRangesForIndexJoin builds kv ranges for index join when the inner plan is index scan plan. 
-func buildKvRangesForIndexJoin(sc *stmtctx.StatementContext, pctx planctx.PlanContext, tableID, indexID int64, lookUpContents []*indexJoinLookUpContent, +func buildKvRangesForIndexJoin(dctx *distsqlctx.DistSQLContext, pctx planctx.PlanContext, tableID, indexID int64, lookUpContents []*indexJoinLookUpContent, ranges []*ranger.Range, keyOff2IdxOff []int, cwc *plannercore.ColWithCmpFuncManager, memTracker *memory.Tracker, interruptSignal *atomic.Value) (_ []kv.KeyRange, err error) { kvRanges := make([]kv.KeyRange, 0, len(ranges)*len(lookUpContents)) if len(ranges) == 0 { @@ -4587,9 +4588,9 @@ func buildKvRangesForIndexJoin(sc *stmtctx.StatementContext, pctx planctx.PlanCo var tmpKvRanges *kv.KeyRanges var err error if indexID == -1 { - tmpKvRanges, err = distsql.CommonHandleRangesToKVRanges(sc, []int64{tableID}, ranges) + tmpKvRanges, err = distsql.CommonHandleRangesToKVRanges(dctx, []int64{tableID}, ranges) } else { - tmpKvRanges, err = distsql.IndexRangesToKVRangesWithInterruptSignal(sc, tableID, indexID, ranges, memTracker, interruptSignal) + tmpKvRanges, err = distsql.IndexRangesToKVRangesWithInterruptSignal(dctx, tableID, indexID, ranges, memTracker, interruptSignal) } if err != nil { return nil, err @@ -4634,10 +4635,10 @@ func buildKvRangesForIndexJoin(sc *stmtctx.StatementContext, pctx planctx.PlanCo } // Index id is -1 means it's a common handle. if indexID == -1 { - tmpKeyRanges, err := distsql.CommonHandleRangesToKVRanges(sc, []int64{tableID}, tmpDatumRanges) + tmpKeyRanges, err := distsql.CommonHandleRangesToKVRanges(dctx, []int64{tableID}, tmpDatumRanges) return tmpKeyRanges.FirstPartitionRange(), err } - tmpKeyRanges, err := distsql.IndexRangesToKVRangesWithInterruptSignal(sc, tableID, indexID, tmpDatumRanges, memTracker, interruptSignal) + tmpKeyRanges, err := distsql.IndexRangesToKVRangesWithInterruptSignal(dctx, tableID, indexID, tmpDatumRanges, memTracker, interruptSignal) return tmpKeyRanges.FirstPartitionRange(), err } diff --git a/pkg/executor/checksum.go b/pkg/executor/checksum.go index bca0af9e645e5..c4d349359a74d 100644 --- a/pkg/executor/checksum.go +++ b/pkg/executor/checksum.go @@ -150,7 +150,7 @@ func (e *ChecksumTableExec) handleChecksumRequest(req *kv.Request) (resp *tipb.C if err = e.Ctx().GetSessionVars().SQLKiller.HandleSignal(); err != nil { return nil, err } - ctx := distsql.WithSQLKvExecCounterInterceptor(context.TODO(), e.Ctx().GetSessionVars().StmtCtx) + ctx := distsql.WithSQLKvExecCounterInterceptor(context.TODO(), e.Ctx().GetSessionVars().StmtCtx.KvExecCounter) res, err := distsql.Checksum(ctx, e.Ctx().GetClient(), req, e.Ctx().GetSessionVars().KVVars) if err != nil { return nil, err @@ -287,7 +287,7 @@ func (c *checksumContext) buildTableRequest(ctx sessionctx.Context, physicalTabl var builder distsql.RequestBuilder builder.SetResourceGroupTagger(ctx.GetSessionVars().StmtCtx.GetResourceGroupTagger()) - return builder.SetHandleRanges(ctx.GetSessionVars().StmtCtx, physicalTableID, c.tableInfo.IsCommonHandle, ranges). + return builder.SetHandleRanges(ctx.GetDistSQLCtx(), physicalTableID, c.tableInfo.IsCommonHandle, ranges). SetChecksumRequest(checksum). SetStartTS(c.startTs). SetConcurrency(ctx.GetSessionVars().DistSQLScanConcurrency()). 
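The checksum request hunks here show the mechanical substitution this patch applies everywhere: builder methods such as SetHandleRanges and SetIndexRanges, which used to take the statement context (or *variable.SessionVars), now take the *distsqlctx.DistSQLContext snapshot instead. A minimal sketch of the resulting call shape follows; the helper buildRangeScan and its parameter list are illustrative, not part of the patch, but every distsql call in it appears in the diff:

    package example

    import (
    	"context"

    	"github.com/pingcap/tidb/pkg/distsql"
    	"github.com/pingcap/tidb/pkg/sessionctx"
    	"github.com/pingcap/tidb/pkg/types"
    	"github.com/pingcap/tidb/pkg/util/ranger"
    )

    // buildRangeScan is a hypothetical helper showing the post-patch calling
    // convention: fetch the DistSQLContext once from the session and hand the
    // same snapshot to both the request builder and distsql.Select.
    func buildRangeScan(ctx context.Context, sctx sessionctx.Context,
    	physicalTableID int64, isCommonHandle bool, ranges []*ranger.Range,
    	startTS uint64, fieldTypes []*types.FieldType) (distsql.SelectResult, error) {
    	dctx := sctx.GetDistSQLCtx() // a plain struct, no session methods behind it

    	var builder distsql.RequestBuilder
    	kvReq, err := builder.
    		SetHandleRanges(dctx, physicalTableID, isCommonHandle, ranges).
    		SetStartTS(startTS).
    		SetConcurrency(1).
    		Build()
    	if err != nil {
    		return nil, err
    	}
    	// Select no longer needs the whole sessionctx.Context.
    	return distsql.Select(ctx, dctx, kvReq, fieldTypes)
    }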
@@ -306,7 +306,7 @@ func (c *checksumContext) buildIndexRequest(ctx sessionctx.Context, physicalTabl var builder distsql.RequestBuilder builder.SetResourceGroupTagger(ctx.GetSessionVars().StmtCtx.GetResourceGroupTagger()) - return builder.SetIndexRanges(ctx.GetSessionVars().StmtCtx, physicalTableID, indexInfo.ID, ranges). + return builder.SetIndexRanges(ctx.GetDistSQLCtx(), physicalTableID, indexInfo.ID, ranges). SetChecksumRequest(checksum). SetStartTS(c.startTs). SetConcurrency(ctx.GetSessionVars().DistSQLScanConcurrency()). diff --git a/pkg/executor/distsql.go b/pkg/executor/distsql.go index 078712f713d43..b29ce3d425178 100644 --- a/pkg/executor/distsql.go +++ b/pkg/executor/distsql.go @@ -29,6 +29,7 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/tidb/pkg/distsql" + distsqlctx "github.com/pingcap/tidb/pkg/distsql/context" "github.com/pingcap/tidb/pkg/executor/internal/builder" "github.com/pingcap/tidb/pkg/executor/internal/exec" "github.com/pingcap/tidb/pkg/expression" @@ -41,7 +42,6 @@ import ( plannercore "github.com/pingcap/tidb/pkg/planner/core" plannerutil "github.com/pingcap/tidb/pkg/planner/util" "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/sessionctx/stmtctx" "github.com/pingcap/tidb/pkg/table" "github.com/pingcap/tidb/pkg/table/tables" "github.com/pingcap/tidb/pkg/tablecodec" @@ -252,15 +252,15 @@ func (e *IndexReaderExecutor) Next(ctx context.Context, req *chunk.Chunk) error } // TODO: cleanup this method. -func (e *IndexReaderExecutor) buildKeyRanges(sc *stmtctx.StatementContext, ranges []*ranger.Range, physicalID int64) ([]kv.KeyRange, error) { +func (e *IndexReaderExecutor) buildKeyRanges(dctx *distsqlctx.DistSQLContext, ranges []*ranger.Range, physicalID int64) ([]kv.KeyRange, error) { var ( rRanges *kv.KeyRanges err error ) if e.index.ID == -1 { - rRanges, err = distsql.CommonHandleRangesToKVRanges(sc, []int64{physicalID}, ranges) + rRanges, err = distsql.CommonHandleRangesToKVRanges(dctx, []int64{physicalID}, ranges) } else { - rRanges, err = distsql.IndexRangesToKVRanges(sc, physicalID, e.index.ID, ranges) + rRanges, err = distsql.IndexRangesToKVRanges(dctx, physicalID, e.index.ID, ranges) } return rRanges.FirstPartitionRange(), err } @@ -275,7 +275,6 @@ func (e *IndexReaderExecutor) Open(ctx context.Context) error { } } - sc := e.Ctx().GetSessionVars().StmtCtx var kvRanges []kv.KeyRange if len(e.partitions) > 0 { for _, p := range e.partitions { @@ -283,14 +282,14 @@ func (e *IndexReaderExecutor) Open(ctx context.Context) error { if pRange, ok := e.partRangeMap[p.GetPhysicalID()]; ok { partRange = pRange } - kvRange, err := e.buildKeyRanges(sc, partRange, p.GetPhysicalID()) + kvRange, err := e.buildKeyRanges(e.Ctx().GetDistSQLCtx(), partRange, p.GetPhysicalID()) if err != nil { return err } kvRanges = append(kvRanges, kvRange...) } } else { - kvRanges, err = e.buildKeyRanges(sc, e.ranges, e.physicalTableID) + kvRanges, err = e.buildKeyRanges(e.Ctx().GetDistSQLCtx(), e.ranges, e.physicalTableID) } if err != nil { return err @@ -309,10 +308,10 @@ func (e *IndexReaderExecutor) buildKVReq(r []kv.KeyRange) (*kv.Request, error) { SetTxnScope(e.txnScope). SetReadReplicaScope(e.readReplicaScope). SetIsStaleness(e.isStaleness). - SetFromSessionVars(e.Ctx().GetSessionVars()). + SetFromSessionVars(e.Ctx().GetDistSQLCtx()). SetFromInfoSchema(e.Ctx().GetInfoSchema()). SetMemTracker(e.memTracker). - SetClosestReplicaReadAdjuster(newClosestReadAdjuster(e.Ctx().GetSessionVars(), &builder.Request, e.netDataSize)). 
+ SetClosestReplicaReadAdjuster(newClosestReadAdjuster(e.Ctx().GetDistSQLCtx(), &builder.Request, e.netDataSize)). SetConnIDAndConnAlias(e.Ctx().GetSessionVars().ConnectionID, e.Ctx().GetSessionVars().SessionAlias) kvReq, err := builder.Build() return kvReq, err @@ -390,7 +389,7 @@ func (e *IndexReaderExecutor) open(ctx context.Context, kvRanges []kv.KeyRange) if err != nil { return err } - e.result, err = e.SelectResult(ctx, e.Ctx(), kvReq, exec.RetTypes(e), getPhysicalPlanIDs(e.plans), e.ID()) + e.result, err = e.SelectResult(ctx, e.Ctx().GetDistSQLCtx(), kvReq, exec.RetTypes(e), getPhysicalPlanIDs(e.plans), e.ID()) if err != nil { return err } @@ -405,7 +404,7 @@ func (e *IndexReaderExecutor) open(ctx context.Context, kvRanges []kv.KeyRange) } var results []distsql.SelectResult for _, kvReq := range kvReqs { - result, err := e.SelectResult(ctx, e.Ctx(), kvReq, exec.RetTypes(e), getPhysicalPlanIDs(e.plans), e.ID()) + result, err := e.SelectResult(ctx, e.Ctx().GetDistSQLCtx(), kvReq, exec.RetTypes(e), getPhysicalPlanIDs(e.plans), e.ID()) if err != nil { return err } @@ -533,7 +532,7 @@ func (e *IndexLookUpExecutor) Open(ctx context.Context) error { } func (e *IndexLookUpExecutor) buildTableKeyRanges() (err error) { - sc := e.Ctx().GetSessionVars().StmtCtx + dctx := e.Ctx().GetDistSQLCtx() if e.partitionTableMode { e.partitionKVRanges = make([][]kv.KeyRange, 0, len(e.prunedPartitions)) for _, p := range e.prunedPartitions { @@ -547,9 +546,9 @@ func (e *IndexLookUpExecutor) buildTableKeyRanges() (err error) { } var kvRange *kv.KeyRanges if e.index.ID == -1 { - kvRange, err = distsql.CommonHandleRangesToKVRanges(sc, []int64{physicalID}, ranges) + kvRange, err = distsql.CommonHandleRangesToKVRanges(dctx, []int64{physicalID}, ranges) } else { - kvRange, err = distsql.IndexRangesToKVRanges(sc, physicalID, e.index.ID, ranges) + kvRange, err = distsql.IndexRangesToKVRanges(dctx, physicalID, e.index.ID, ranges) } if err != nil { return err @@ -560,9 +559,9 @@ func (e *IndexLookUpExecutor) buildTableKeyRanges() (err error) { physicalID := getPhysicalTableID(e.table) var kvRanges *kv.KeyRanges if e.index.ID == -1 { - kvRanges, err = distsql.CommonHandleRangesToKVRanges(sc, []int64{physicalID}, e.ranges) + kvRanges, err = distsql.CommonHandleRangesToKVRanges(dctx, []int64{physicalID}, e.ranges) } else { - kvRanges, err = distsql.IndexRangesToKVRanges(sc, physicalID, e.index.ID, e.ranges) + kvRanges, err = distsql.IndexRangesToKVRanges(dctx, physicalID, e.index.ID, e.ranges) } e.kvRanges = kvRanges.FirstPartitionRange() } @@ -714,9 +713,9 @@ func (e *IndexLookUpExecutor) startIndexWorker(ctx context.Context, workCh chan< SetTxnScope(e.txnScope). SetReadReplicaScope(e.readReplicaScope). SetIsStaleness(e.isStaleness). - SetFromSessionVars(e.Ctx().GetSessionVars()). + SetFromSessionVars(e.Ctx().GetDistSQLCtx()). SetFromInfoSchema(e.Ctx().GetInfoSchema()). - SetClosestReplicaReadAdjuster(newClosestReadAdjuster(e.Ctx().GetSessionVars(), &builder.Request, e.idxNetDataSize/float64(len(kvRanges)))). + SetClosestReplicaReadAdjuster(newClosestReadAdjuster(e.Ctx().GetDistSQLCtx(), &builder.Request, e.idxNetDataSize/float64(len(kvRanges)))). SetMemTracker(tracker). 
SetConnIDAndConnAlias(e.Ctx().GetSessionVars().ConnectionID, e.Ctx().GetSessionVars().SessionAlias) @@ -743,7 +742,7 @@ func (e *IndexLookUpExecutor) startIndexWorker(ctx context.Context, workCh chan< worker.syncErr(err) break } - result, err := distsql.SelectWithRuntimeStats(ctx, e.Ctx(), kvReq, tps, getPhysicalPlanIDs(e.idxPlans), idxID) + result, err := distsql.SelectWithRuntimeStats(ctx, e.Ctx().GetDistSQLCtx(), kvReq, tps, getPhysicalPlanIDs(e.idxPlans), idxID) if err != nil { worker.syncErr(err) break diff --git a/pkg/executor/executor_pkg_test.go b/pkg/executor/executor_pkg_test.go index 6656df4042020..7680445d84955 100644 --- a/pkg/executor/executor_pkg_test.go +++ b/pkg/executor/executor_pkg_test.go @@ -65,7 +65,7 @@ func TestBuildKvRangesForIndexJoinWithoutCwc(t *testing.T) { keyOff2IdxOff := []int{1, 3} ctx := mock.NewContext() - kvRanges, err := buildKvRangesForIndexJoin(ctx.GetSessionVars().StmtCtx, ctx.GetPlanCtx(), 0, 0, joinKeyRows, indexRanges, keyOff2IdxOff, nil, nil, nil) + kvRanges, err := buildKvRangesForIndexJoin(ctx.GetDistSQLCtx(), ctx.GetPlanCtx(), 0, 0, joinKeyRows, indexRanges, keyOff2IdxOff, nil, nil, nil) require.NoError(t, err) // Check the kvRanges is in order. for i, kvRange := range kvRanges { @@ -95,7 +95,7 @@ func TestBuildKvRangesForIndexJoinWithoutCwcAndWithMemoryTracker(t *testing.T) { keyOff2IdxOff := []int{1, 3} ctx := mock.NewContext() memTracker := memory.NewTracker(memory.LabelForIndexWorker, -1) - kvRanges, err := buildKvRangesForIndexJoin(ctx.GetSessionVars().StmtCtx, ctx.GetPlanCtx(), 0, 0, joinKeyRows, indexRanges, keyOff2IdxOff, nil, memTracker, nil) + kvRanges, err := buildKvRangesForIndexJoin(ctx.GetDistSQLCtx(), ctx.GetPlanCtx(), 0, 0, joinKeyRows, indexRanges, keyOff2IdxOff, nil, memTracker, nil) require.NoError(t, err) // Check the kvRanges is in order. for i, kvRange := range kvRanges { @@ -117,7 +117,7 @@ func TestBuildKvRangesForIndexJoinWithoutCwcAndWithMemoryTracker(t *testing.T) { keyOff2IdxOff := []int{1, 3} ctx := mock.NewContext() memTracker := memory.NewTracker(memory.LabelForIndexWorker, -1) - kvRanges, err := buildKvRangesForIndexJoin(ctx.GetSessionVars().StmtCtx, ctx.GetPlanCtx(), 0, 0, joinKeyRows, indexRanges, keyOff2IdxOff, nil, memTracker, nil) + kvRanges, err := buildKvRangesForIndexJoin(ctx.GetDistSQLCtx(), ctx.GetPlanCtx(), 0, 0, joinKeyRows, indexRanges, keyOff2IdxOff, nil, memTracker, nil) require.NoError(t, err) // Check the kvRanges is in order. 
for i, kvRange := range kvRanges { diff --git a/pkg/executor/index_merge_reader.go b/pkg/executor/index_merge_reader.go index 288f90d25784d..61e74bf1c77da 100644 --- a/pkg/executor/index_merge_reader.go +++ b/pkg/executor/index_merge_reader.go @@ -224,16 +224,16 @@ func (e *IndexMergeReaderExecutor) rebuildRangeForCorCol() (err error) { } func (e *IndexMergeReaderExecutor) buildKeyRangesForTable(tbl table.Table) (ranges [][]kv.KeyRange, err error) { - sc := e.Ctx().GetSessionVars().StmtCtx + dctx := e.Ctx().GetDistSQLCtx() for i, plan := range e.partialPlans { _, ok := plan[0].(*plannercore.PhysicalIndexScan) if !ok { firstPartRanges, secondPartRanges := distsql.SplitRangesAcrossInt64Boundary(e.ranges[i], false, e.descs[i], tbl.Meta().IsCommonHandle) - firstKeyRanges, err := distsql.TableHandleRangesToKVRanges(sc, []int64{getPhysicalTableID(tbl)}, tbl.Meta().IsCommonHandle, firstPartRanges) + firstKeyRanges, err := distsql.TableHandleRangesToKVRanges(dctx, []int64{getPhysicalTableID(tbl)}, tbl.Meta().IsCommonHandle, firstPartRanges) if err != nil { return nil, err } - secondKeyRanges, err := distsql.TableHandleRangesToKVRanges(sc, []int64{getPhysicalTableID(tbl)}, tbl.Meta().IsCommonHandle, secondPartRanges) + secondKeyRanges, err := distsql.TableHandleRangesToKVRanges(dctx, []int64{getPhysicalTableID(tbl)}, tbl.Meta().IsCommonHandle, secondPartRanges) if err != nil { return nil, err } @@ -241,7 +241,7 @@ func (e *IndexMergeReaderExecutor) buildKeyRangesForTable(tbl table.Table) (rang ranges = append(ranges, keyRanges) continue } - keyRange, err := distsql.IndexRangesToKVRanges(sc, getPhysicalTableID(tbl), e.indexes[i].ID, e.ranges[i]) + keyRange, err := distsql.IndexRangesToKVRanges(dctx, getPhysicalTableID(tbl), e.indexes[i].ID, e.ranges[i]) if err != nil { return nil, err } @@ -384,11 +384,11 @@ func (e *IndexMergeReaderExecutor) startPartialIndexWorker(ctx context.Context, SetTxnScope(e.txnScope). SetReadReplicaScope(e.readReplicaScope). SetIsStaleness(e.isStaleness). - SetFromSessionVars(e.Ctx().GetSessionVars()). + SetFromSessionVars(e.Ctx().GetDistSQLCtx()). SetMemTracker(e.memTracker). SetPaging(e.paging). SetFromInfoSchema(e.Ctx().GetInfoSchema()). - SetClosestReplicaReadAdjuster(newClosestReadAdjuster(e.Ctx().GetSessionVars(), &builder.Request, e.partialNetDataSizes[workID])). + SetClosestReplicaReadAdjuster(newClosestReadAdjuster(e.Ctx().GetDistSQLCtx(), &builder.Request, e.partialNetDataSizes[workID])). 
SetConnIDAndConnAlias(e.Ctx().GetSessionVars().ConnectionID, e.Ctx().GetSessionVars().SessionAlias) tps := worker.getRetTpsForIndexScan(e.handleCols) @@ -421,7 +421,7 @@ func (e *IndexMergeReaderExecutor) startPartialIndexWorker(ctx context.Context, syncErr(ctx, e.finished, fetchCh, err) return } - result, err := distsql.SelectWithRuntimeStats(ctx, e.Ctx(), kvReq, tps, getPhysicalPlanIDs(e.partialPlans[workID]), e.getPartitalPlanID(workID)) + result, err := distsql.SelectWithRuntimeStats(ctx, e.Ctx().GetDistSQLCtx(), kvReq, tps, getPhysicalPlanIDs(e.partialPlans[workID]), e.getPartitalPlanID(workID)) if err != nil { syncErr(ctx, e.finished, fetchCh, err) return diff --git a/pkg/executor/internal/builder/builder_utils.go b/pkg/executor/internal/builder/builder_utils.go index 79311cd133a68..d161ca8a9d9e3 100644 --- a/pkg/executor/internal/builder/builder_utils.go +++ b/pkg/executor/internal/builder/builder_utils.go @@ -66,6 +66,6 @@ func ConstructDAGReq(ctx sessionctx.Context, plans []plannercore.PhysicalPlan, s dagReq.Executors, err = ConstructListBasedDistExec(ctx.GetPlanCtx(), plans) } - distsql.SetEncodeType(ctx, dagReq) + distsql.SetEncodeType(ctx.GetDistSQLCtx(), dagReq) return dagReq, err } diff --git a/pkg/executor/internal/mpp/local_mpp_coordinator.go b/pkg/executor/internal/mpp/local_mpp_coordinator.go index 16e1415732866..549d15f219350 100644 --- a/pkg/executor/internal/mpp/local_mpp_coordinator.go +++ b/pkg/executor/internal/mpp/local_mpp_coordinator.go @@ -601,7 +601,7 @@ func (c *localMppCoordinator) handleAllReports() error { } } } - distsql.FillDummySummariesForTiFlashTasks(c.sessionCtx.GetSessionVars().StmtCtx, "", kv.TiFlash.Name(), c.planIDs, recordedPlanIDs) + distsql.FillDummySummariesForTiFlashTasks(c.sessionCtx.GetSessionVars().StmtCtx.RuntimeStatsColl, "", kv.TiFlash.Name(), c.planIDs, recordedPlanIDs) case <-time.After(receiveReportTimeout): metrics.MppCoordinatorStatsReportNotReceived.Inc() logutil.BgLogger().Warn(fmt.Sprintf("Mpp coordinator not received all reports within %d seconds", int(receiveReportTimeout.Seconds())), @@ -752,9 +752,9 @@ func (c *localMppCoordinator) Execute(ctx context.Context) (kv.Response, []kv.Ke } }) - ctx = distsql.WithSQLKvExecCounterInterceptor(ctx, sctx.GetSessionVars().StmtCtx) + ctx = distsql.WithSQLKvExecCounterInterceptor(ctx, sctx.GetSessionVars().StmtCtx.KvExecCounter) _, allowTiFlashFallback := sctx.GetSessionVars().AllowFallbackToTiKV[kv.TiFlash] - ctx = distsql.SetTiFlashConfVarsInContext(ctx, sctx.GetSessionVars()) + ctx = distsql.SetTiFlashConfVarsInContext(ctx, sctx.GetDistSQLCtx()) c.needTriggerFallback = allowTiFlashFallback c.enableCollectExecutionInfo = config.GetGlobalConfig().Instance.EnableCollectExecutionInfo.Load() diff --git a/pkg/executor/mpp_gather.go b/pkg/executor/mpp_gather.go index 2bbb6fdde4810..45160e22c4490 100644 --- a/pkg/executor/mpp_gather.go +++ b/pkg/executor/mpp_gather.go @@ -104,7 +104,7 @@ func (e *MPPGather) Open(ctx context.Context) (err error) { return err } e.kvRanges = e.mppExec.KVRanges - e.respIter = distsql.GenSelectResultFromMPPResponse(e.Ctx(), e.RetFieldTypes(), planIDs, e.ID(), e.mppExec) + e.respIter = distsql.GenSelectResultFromMPPResponse(e.Ctx().GetDistSQLCtx(), e.RetFieldTypes(), planIDs, e.ID(), e.mppExec) return nil } diff --git a/pkg/executor/simple.go b/pkg/executor/simple.go index bf897b1bd6f7e..ea886786a1494 100644 --- a/pkg/executor/simple.go +++ b/pkg/executor/simple.go @@ -2612,7 +2612,7 @@ func killRemoteConn(ctx context.Context, sctx sessionctx.Context, gcid *globalco 
var builder distsql.RequestBuilder kvReq, err := builder. SetDAGRequest(dagReq). - SetFromSessionVars(sctx.GetSessionVars()). + SetFromSessionVars(sctx.GetDistSQLCtx()). SetFromInfoSchema(sctx.GetInfoSchema()). SetStoreType(kv.TiDB). SetTiDBServerID(gcid.ServerID). diff --git a/pkg/executor/table_reader.go b/pkg/executor/table_reader.go index 32fa2bc9c4840..19b4ef3f4e270 100644 --- a/pkg/executor/table_reader.go +++ b/pkg/executor/table_reader.go @@ -58,11 +58,11 @@ var _ exec.Executor = &TableReaderExecutor{} // selectResultHook is used to hack distsql.SelectWithRuntimeStats safely for testing. type selectResultHook struct { - selectResultFunc func(ctx context.Context, dctx distsqlctx.DistSQLContext, kvReq *kv.Request, + selectResultFunc func(ctx context.Context, dctx *distsqlctx.DistSQLContext, kvReq *kv.Request, fieldTypes []*types.FieldType, copPlanIDs []int) (distsql.SelectResult, error) } -func (sr selectResultHook) SelectResult(ctx context.Context, dctx distsqlctx.DistSQLContext, kvReq *kv.Request, +func (sr selectResultHook) SelectResult(ctx context.Context, dctx *distsqlctx.DistSQLContext, kvReq *kv.Request, fieldTypes []*types.FieldType, copPlanIDs []int, rootPlanID int) (distsql.SelectResult, error) { if sr.selectResultFunc == nil { return distsql.SelectWithRuntimeStats(ctx, dctx, kvReq, fieldTypes, copPlanIDs, rootPlanID) @@ -77,7 +77,7 @@ type kvRangeBuilder interface { // tableReaderExecutorContext is the execution context for the `TableReaderExecutor` type tableReaderExecutorContext struct { - dctx distsqlctx.DistSQLContext + dctx *distsqlctx.DistSQLContext pctx planctx.PlanContext ectx exprctx.BuildContext @@ -85,7 +85,7 @@ type tableReaderExecutorContext struct { } func (treCtx *tableReaderExecutorContext) GetSessionVars() *variable.SessionVars { - return treCtx.dctx.GetSessionVars() + return treCtx.pctx.GetSessionVars() } func (treCtx *tableReaderExecutorContext) GetInfoSchema() infoschema.InfoSchema { @@ -434,14 +434,14 @@ func (e *TableReaderExecutor) buildKVReqSeparately(ctx context.Context, ranges [ SetKeepOrder(e.keepOrder). SetTxnScope(e.txnScope). SetReadReplicaScope(e.readReplicaScope). - SetFromSessionVars(e.GetSessionVars()). + SetFromSessionVars(e.dctx). SetFromInfoSchema(e.GetInfoSchema()). SetMemTracker(e.memTracker). SetStoreType(e.storeType). SetPaging(e.paging). SetAllowBatchCop(e.batchCop). - SetClosestReplicaReadAdjuster(newClosestReadAdjuster(e.GetSessionVars(), &reqBuilder.Request, e.netDataSize)). - SetConnIDAndConnAlias(e.GetSessionVars().ConnectionID, e.GetSessionVars().SessionAlias). + SetClosestReplicaReadAdjuster(newClosestReadAdjuster(e.dctx, &reqBuilder.Request, e.netDataSize)). + SetConnIDAndConnAlias(e.dctx.ConnectionID, e.dctx.SessionAlias). Build() if err != nil { return nil, err @@ -476,14 +476,14 @@ func (e *TableReaderExecutor) buildKVReqForPartitionTableScan(ctx context.Contex SetKeepOrder(e.keepOrder). SetTxnScope(e.txnScope). SetReadReplicaScope(e.readReplicaScope). - SetFromSessionVars(e.GetSessionVars()). + SetFromSessionVars(e.dctx). SetFromInfoSchema(e.GetInfoSchema()). SetMemTracker(e.memTracker). SetStoreType(e.storeType). SetPaging(e.paging). SetAllowBatchCop(e.batchCop). - SetClosestReplicaReadAdjuster(newClosestReadAdjuster(e.GetSessionVars(), &reqBuilder.Request, e.netDataSize)). - SetConnIDAndConnAlias(e.GetSessionVars().ConnectionID, e.GetSessionVars().SessionAlias). + SetClosestReplicaReadAdjuster(newClosestReadAdjuster(e.dctx, &reqBuilder.Request, e.netDataSize)). 
+ SetConnIDAndConnAlias(e.dctx.ConnectionID, e.dctx.SessionAlias). Build() if err != nil { return nil, err @@ -501,7 +501,7 @@ func (e *TableReaderExecutor) buildKVReq(ctx context.Context, ranges []*ranger.R } reqBuilder = builder.SetPartitionKeyRanges(kvRange) } else { - reqBuilder = builder.SetHandleRanges(e.GetSessionVars().StmtCtx, getPhysicalTableID(e.table), e.table.Meta() != nil && e.table.Meta().IsCommonHandle, ranges) + reqBuilder = builder.SetHandleRanges(e.dctx, getPhysicalTableID(e.table), e.table.Meta() != nil && e.table.Meta().IsCommonHandle, ranges) } if e.table != nil && e.table.Type().IsClusterTable() { copDestination := infoschema.GetClusterTableCopDestination(e.table.Meta().Name.L) @@ -521,14 +521,14 @@ func (e *TableReaderExecutor) buildKVReq(ctx context.Context, ranges []*ranger.R SetTxnScope(e.txnScope). SetReadReplicaScope(e.readReplicaScope). SetIsStaleness(e.isStaleness). - SetFromSessionVars(e.GetSessionVars()). + SetFromSessionVars(e.dctx). SetFromInfoSchema(e.GetInfoSchema()). SetMemTracker(e.memTracker). SetStoreType(e.storeType). SetAllowBatchCop(e.batchCop). - SetClosestReplicaReadAdjuster(newClosestReadAdjuster(e.GetSessionVars(), &reqBuilder.Request, e.netDataSize)). + SetClosestReplicaReadAdjuster(newClosestReadAdjuster(e.dctx, &reqBuilder.Request, e.netDataSize)). SetPaging(e.paging). - SetConnIDAndConnAlias(e.GetSessionVars().ConnectionID, e.GetSessionVars().SessionAlias) + SetConnIDAndConnAlias(e.dctx.ConnectionID, e.dctx.SessionAlias) return reqBuilder.Build() } diff --git a/pkg/executor/table_readers_required_rows_test.go b/pkg/executor/table_readers_required_rows_test.go index 238b5b3d31d20..18133245570b7 100644 --- a/pkg/executor/table_readers_required_rows_test.go +++ b/pkg/executor/table_readers_required_rows_test.go @@ -112,7 +112,7 @@ func mockDistsqlSelectCtxGet(ctx context.Context) (totalRows int, expectedRowsRe return } -func mockSelectResult(ctx context.Context, dctx distsqlctx.DistSQLContext, kvReq *kv.Request, +func mockSelectResult(ctx context.Context, dctx *distsqlctx.DistSQLContext, kvReq *kv.Request, fieldTypes []*types.FieldType, copPlanIDs []int) (distsql.SelectResult, error) { totalRows, expectedRowsRet := mockDistsqlSelectCtxGet(ctx) return &requiredRowsSelectResult{ diff --git a/pkg/planner/core/fragment.go b/pkg/planner/core/fragment.go index 648588ea2dbdd..4db7eeaba2023 100644 --- a/pkg/planner/core/fragment.go +++ b/pkg/planner/core/fragment.go @@ -626,7 +626,7 @@ func (e *mppTaskGenerator) constructMPPBuildTaskReqForPartitionedTable(ts *Physi for i, p := range partitions { pid := p.GetPhysicalID() meta := p.Meta() - kvRanges, err := distsql.TableHandleRangesToKVRanges(e.ctx.GetSessionVars().StmtCtx, []int64{pid}, meta != nil && ts.Table.IsCommonHandle, splitedRanges) + kvRanges, err := distsql.TableHandleRangesToKVRanges(e.ctx.GetDistSQLCtx(), []int64{pid}, meta != nil && ts.Table.IsCommonHandle, splitedRanges) if err != nil { return nil, nil, errors.Trace(err) } @@ -638,7 +638,7 @@ func (e *mppTaskGenerator) constructMPPBuildTaskReqForPartitionedTable(ts *Physi } func (e *mppTaskGenerator) constructMPPBuildTaskForNonPartitionTable(tid int64, isCommonHandle bool, splitedRanges []*ranger.Range) (*kv.MPPBuildTasksRequest, error) { - kvRanges, err := distsql.TableHandleRangesToKVRanges(e.ctx.GetSessionVars().StmtCtx, []int64{tid}, isCommonHandle, splitedRanges) + kvRanges, err := distsql.TableHandleRangesToKVRanges(e.ctx.GetDistSQLCtx(), []int64{tid}, isCommonHandle, splitedRanges) if err != nil { return nil, errors.Trace(err) } 
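In the executor hunks above, every call site of newClosestReadAdjuster switched from passing *variable.SessionVars to the flat context, and the returned closure now captures only the ReplicaClosestReadThreshold field rather than a whole session object. A self-contained sketch of that closure pattern under simplified types; the names adjuster and newThresholdAdjuster are illustrative, not the patch's:

    package main

    import "fmt"

    // adjuster mirrors the role of kv.CoprRequestAdjuster: once the coprocessor
    // task count is known, it may adjust the request and reports whether it did.
    type adjuster func(copTaskCount int) bool

    // newThresholdAdjuster returns nil when the feature does not apply, so the
    // builder can store the result unconditionally and skip a branch later.
    func newThresholdAdjuster(applicable bool, netDataSize float64, threshold int64) adjuster {
    	if !applicable {
    		return nil
    	}
    	// Only the threshold is captured, not the session: the closure stays
    	// valid even after the statement's other state changes.
    	return func(copTaskCount int) bool {
    		return int64(netDataSize/float64(copTaskCount)) >= threshold
    	}
    }

    func main() {
    	adjust := newThresholdAdjuster(true, 4096, 1024)
    	fmt.Println(adjust(2)) // avg 2048 per task >= 1024: adjust (true)
    	fmt.Println(adjust(8)) // avg 512 per task < 1024: leave as-is (false)
    }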
diff --git a/pkg/session/session.go b/pkg/session/session.go index f028f397e07b4..cd8496f0598e0 100644 --- a/pkg/session/session.go +++ b/pkg/session/session.go @@ -2637,8 +2637,64 @@ func (s *session) GetTableCtx() tbctx.MutateContext { } // GetDistSQLCtx returns the context used in DistSQL -func (s *session) GetDistSQLCtx() distsqlctx.DistSQLContext { - return s +func (s *session) GetDistSQLCtx() *distsqlctx.DistSQLContext { + vars := s.GetSessionVars() + sc := vars.StmtCtx + + dctx := sc.GetOrInitDistSQLFromCache(func() *distsqlctx.DistSQLContext { + return &distsqlctx.DistSQLContext{ + AppendWarning: sc.AppendWarning, + InRestrictedSQL: sc.InRestrictedSQL, + Client: s.GetClient(), + + EnabledRateLimitAction: vars.EnabledRateLimitAction, + EnableChunkRPC: vars.EnableChunkRPC, + OriginalSQL: sc.OriginalSQL, + KVVars: vars.KVVars, + KvExecCounter: sc.KvExecCounter, + SessionMemTracker: vars.MemTracker, + + Location: sc.TimeZone(), + RuntimeStatsColl: sc.RuntimeStatsColl, + SQLKiller: &vars.SQLKiller, + ErrCtx: sc.ErrCtx(), + + TiFlashReplicaRead: vars.TiFlashReplicaRead, + TiFlashMaxThreads: vars.TiFlashMaxThreads, + TiFlashMaxBytesBeforeExternalJoin: vars.TiFlashMaxBytesBeforeExternalJoin, + TiFlashMaxBytesBeforeExternalGroupBy: vars.TiFlashMaxBytesBeforeExternalGroupBy, + TiFlashMaxBytesBeforeExternalSort: vars.TiFlashMaxBytesBeforeExternalSort, + TiFlashMaxQueryMemoryPerNode: vars.TiFlashMaxQueryMemoryPerNode, + TiFlashQuerySpillRatio: vars.TiFlashQuerySpillRatio, + + DistSQLConcurrency: vars.DistSQLScanConcurrency(), + ReplicaReadType: vars.GetReplicaRead(), + WeakConsistency: sc.WeakConsistency, + RCCheckTS: sc.RCCheckTS, + NotFillCache: sc.NotFillCache, + TaskID: sc.TaskID, + Priority: sc.Priority, + ResourceGroupTagger: sc.GetResourceGroupTagger(), + EnablePaging: vars.EnablePaging, + MinPagingSize: vars.MinPagingSize, + MaxPagingSize: vars.MaxPagingSize, + RequestSourceType: vars.RequestSourceType, + ExplicitRequestSourceType: vars.ExplicitRequestSourceType, + StoreBatchSize: vars.StoreBatchSize, + ResourceGroupName: sc.ResourceGroupName, + LoadBasedReplicaReadThreshold: vars.LoadBasedReplicaReadThreshold, + RunawayChecker: sc.RunawayChecker, + TiKVClientReadTimeout: vars.GetTiKVClientReadTimeout(), + + ReplicaClosestReadThreshold: vars.ReplicaClosestReadThreshold, + ConnectionID: vars.ConnectionID, + SessionAlias: vars.SessionAlias, + + ExecDetails: &sc.SyncExecDetails, + } + }) + + return dctx.(*distsqlctx.DistSQLContext) } func (s *session) AuthPluginForUser(user *auth.UserIdentity) (string, error) { diff --git a/pkg/sessionctx/context.go b/pkg/sessionctx/context.go index 840db9d9f705f..4888177c369f7 100644 --- a/pkg/sessionctx/context.go +++ b/pkg/sessionctx/context.go @@ -117,7 +117,7 @@ type Context interface { GetPlanCtx() planctx.PlanContext // GetDistSQLCtx gets the distsql ctx of the current session - GetDistSQLCtx() distsqlctx.DistSQLContext + GetDistSQLCtx() *distsqlctx.DistSQLContext GetSessionManager() util.SessionManager diff --git a/pkg/sessionctx/stmtctx/BUILD.bazel b/pkg/sessionctx/stmtctx/BUILD.bazel index 29e00813dd8ec..894dc72b81e30 100644 --- a/pkg/sessionctx/stmtctx/BUILD.bazel +++ b/pkg/sessionctx/stmtctx/BUILD.bazel @@ -6,6 +6,7 @@ go_library( importpath = "github.com/pingcap/tidb/pkg/sessionctx/stmtctx", visibility = ["//visibility:public"], deps = [ + "//pkg/distsql/context", "//pkg/domain/resourcegroup", "//pkg/errctx", "//pkg/parser", diff --git a/pkg/sessionctx/stmtctx/stmtctx.go b/pkg/sessionctx/stmtctx/stmtctx.go index fd2d9192c18e4..da6d1d7217a68 
100644
--- a/pkg/sessionctx/stmtctx/stmtctx.go
+++ b/pkg/sessionctx/stmtctx/stmtctx.go
@@ -28,6 +28,7 @@ import (
 	"time"
 
 	"github.com/pingcap/errors"
+	distsqlctx "github.com/pingcap/tidb/pkg/distsql/context"
 	"github.com/pingcap/tidb/pkg/domain/resourcegroup"
 	"github.com/pingcap/tidb/pkg/errctx"
 	"github.com/pingcap/tidb/pkg/parser"
@@ -167,6 +168,13 @@ type StatementContext struct {
 	// errCtx is used to indicate how to handle the errors
 	errCtx errctx.Context
 
+	// distSQLCtxCache caches all variables and tools needed by the `distsql` package.
+	// It lives on the `StatementContext` because it has to be refreshed after each statement.
+	distSQLCtxCache struct {
+		init sync.Once
+		dctx *distsqlctx.DistSQLContext
+	}
+
 	// Set the following variables before execution
 
 	hint.StmtHints
@@ -1063,6 +1071,9 @@ func (sc *StatementContext) ResetForRetry() {
 	sc.IndexNames = sc.IndexNames[:0]
 	sc.TaskID = AllocateTaskID()
 	sc.SyncExecDetails.Reset()
+
+	// `TaskID` has been reset, so the cached distSQLCtx must be rebuilt as well.
+	sc.distSQLCtxCache.init = sync.Once{}
 }
 
 // GetExecDetails gets the execution details for the statement.
@@ -1234,6 +1245,16 @@ func (sc *StatementContext) TypeCtxOrDefault() types.Context {
 	return types.DefaultStmtNoWarningContext
 }
 
+// GetOrInitDistSQLFromCache returns the cached `DistSQLContext`. If it does not exist yet, a new one is built by
+// the `create` function and stored. The return type is `any` to avoid a cyclic dependency.
+func (sc *StatementContext) GetOrInitDistSQLFromCache(create func() *distsqlctx.DistSQLContext) any {
+	sc.distSQLCtxCache.init.Do(func() {
+		sc.distSQLCtxCache.dctx = create()
+	})
+
+	return sc.distSQLCtxCache.dctx
+}
+
 func newErrCtx(tc types.Context, otherLevels errctx.LevelMap, handler contextutil.WarnHandler) errctx.Context {
 	l := errctx.LevelError
 	if flags := tc.Flags(); flags.IgnoreTruncateErr() {
diff --git a/pkg/util/mock/context.go b/pkg/util/mock/context.go
index a4f790d89b1a2..bcdbe98f66dd5 100644
--- a/pkg/util/mock/context.go
+++ b/pkg/util/mock/context.go
@@ -249,8 +249,33 @@ func (c *Context) GetTableCtx() tbctx.MutateContext {
 }
 
 // GetDistSQLCtx returns the distsql context of the session
-func (c *Context) GetDistSQLCtx() distsqlctx.DistSQLContext {
-	return c
+func (c *Context) GetDistSQLCtx() *distsqlctx.DistSQLContext {
+	vars := c.GetSessionVars()
+	sc := vars.StmtCtx
+
+	return &distsqlctx.DistSQLContext{
+		AppendWarning:                        sc.AppendWarning,
+		InRestrictedSQL:                      sc.InRestrictedSQL,
+		Client:                               c.GetClient(),
+		EnabledRateLimitAction:               vars.EnabledRateLimitAction,
+		EnableChunkRPC:                       vars.EnableChunkRPC,
+		OriginalSQL:                          sc.OriginalSQL,
+		KVVars:                               vars.KVVars,
+		KvExecCounter:                        sc.KvExecCounter,
+		SessionMemTracker:                    vars.MemTracker,
+		Location:                             sc.TimeZone(),
+		RuntimeStatsColl:                     sc.RuntimeStatsColl,
+		SQLKiller:                            &vars.SQLKiller,
+		ErrCtx:                               sc.ErrCtx(),
+		TiFlashReplicaRead:                   vars.TiFlashReplicaRead,
+		TiFlashMaxThreads:                    vars.TiFlashMaxThreads,
+		TiFlashMaxBytesBeforeExternalJoin:    vars.TiFlashMaxBytesBeforeExternalJoin,
+		TiFlashMaxBytesBeforeExternalGroupBy: vars.TiFlashMaxBytesBeforeExternalGroupBy,
+		TiFlashMaxBytesBeforeExternalSort:    vars.TiFlashMaxBytesBeforeExternalSort,
+		TiFlashMaxQueryMemoryPerNode:         vars.TiFlashMaxQueryMemoryPerNode,
+		TiFlashQuerySpillRatio:               vars.TiFlashQuerySpillRatio,
+		ExecDetails:                          &sc.SyncExecDetails,
+	}
 }
 
 // Txn implements sessionctx.Context Txn interface.
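A closing note on the caching introduced in stmtctx.go above: the sync.Once is stored by value, so ResetForRetry re-arms the lazy initialization simply by overwriting the field with a zero sync.Once, and the next GetOrInitDistSQLFromCache call rebuilds the snapshot with the fresh TaskID. A reduced, runnable sketch of that mechanism; the names snapshot, stmtContext, getOrInit, and resetForRetry are illustrative stand-ins for the patch's types:

    package main

    import (
    	"fmt"
    	"sync"
    )

    // snapshot stands in for *distsqlctx.DistSQLContext.
    type snapshot struct{ taskID uint64 }

    type stmtContext struct {
    	taskID uint64

    	// cache is re-armed by assigning a zero sync.Once, mirroring what
    	// ResetForRetry does in the patch.
    	cache struct {
    		init sync.Once
    		snap *snapshot
    	}
    }

    // getOrInit lazily builds the snapshot once per statement execution.
    func (sc *stmtContext) getOrInit(create func() *snapshot) *snapshot {
    	sc.cache.init.Do(func() { sc.cache.snap = create() })
    	return sc.cache.snap
    }

    // resetForRetry invalidates the cached snapshot for the next attempt.
    func (sc *stmtContext) resetForRetry() {
    	sc.taskID++
    	sc.cache.init = sync.Once{}
    }

    func main() {
    	sc := &stmtContext{taskID: 1}
    	build := func() *snapshot { return &snapshot{taskID: sc.taskID} }

    	a := sc.getOrInit(build)
    	b := sc.getOrInit(build)
    	fmt.Println(a == b, a.taskID) // true 1: second call hits the cache

    	sc.resetForRetry()
    	c := sc.getOrInit(build)
    	fmt.Println(a == c, c.taskID) // false 2: reset re-armed the Once
    }

The snapshot is only as fresh as the last re-arm, which is why the cache lives on the per-statement context rather than on the session.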