diff --git a/ddl/internal/session/BUILD.bazel b/ddl/internal/session/BUILD.bazel index e5f238f703c19..64f73641fb73c 100644 --- a/ddl/internal/session/BUILD.bazel +++ b/ddl/internal/session/BUILD.bazel @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") go_library( name = "session", @@ -9,6 +9,7 @@ go_library( importpath = "github.com/pingcap/tidb/ddl/internal/session", visibility = ["//ddl:__subpackages__"], deps = [ + "//domain/infosync", "//kv", "//metrics", "//parser/mysql", @@ -24,3 +25,16 @@ go_library( "@com_github_pingcap_failpoint//:failpoint", ], ) + +go_test( + name = "session_test", + timeout = "short", + srcs = ["session_pool_test.go"], + flaky = True, + deps = [ + ":session", + "//testkit", + "@com_github_ngaut_pools//:pools", + "@com_github_stretchr_testify//require", + ], +) diff --git a/ddl/internal/session/session_pool.go b/ddl/internal/session/session_pool.go index 2e97867034b5d..d86164afe7a76 100644 --- a/ddl/internal/session/session_pool.go +++ b/ddl/internal/session/session_pool.go @@ -20,6 +20,7 @@ import ( "github.com/ngaut/pools" "github.com/pingcap/errors" + "github.com/pingcap/tidb/domain/infosync" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/parser/mysql" "github.com/pingcap/tidb/sessionctx" @@ -54,7 +55,7 @@ func (sg *Pool) Get() (sessionctx.Context, error) { sg.mu.Lock() if sg.mu.closed { sg.mu.Unlock() - return nil, errors.Errorf("Session pool is closed") + return nil, errors.Errorf("session pool is closed") } sg.mu.Unlock() @@ -66,10 +67,11 @@ func (sg *Pool) Get() (sessionctx.Context, error) { ctx, ok := resource.(sessionctx.Context) if !ok { - return nil, fmt.Errorf("Session pool resource get %v", ctx) + return nil, errors.Trace(fmt.Errorf("need sessionctx.Context, but got %T", ctx)) } ctx.GetSessionVars().SetStatusFlag(mysql.ServerStatusAutocommit, true) ctx.GetSessionVars().InRestrictedSQL = true + infosync.StoreInternalSession(ctx) return ctx, nil } @@ -82,6 +84,7 @@ func (sg *Pool) Put(ctx sessionctx.Context) { // no need to protect sg.resPool, even the sg.resPool is closed, the ctx still need to // Put into resPool, because when resPool is closing, it will wait all the ctx returns, then resPool finish closing. sg.resPool.Put(ctx.(pools.Resource)) + infosync.DeleteInternalSession(ctx) } // Close clean up the Pool. @@ -92,7 +95,7 @@ func (sg *Pool) Close() { if sg.mu.closed || sg.resPool == nil { return } - logutil.BgLogger().Info("[ddl] closing Session pool") + logutil.BgLogger().Info("[ddl] closing session pool") sg.resPool.Close() sg.mu.closed = true } diff --git a/ddl/internal/session/session_pool_test.go b/ddl/internal/session/session_pool_test.go new file mode 100644 index 0000000000000..96de32dad51bc --- /dev/null +++ b/ddl/internal/session/session_pool_test.go @@ -0,0 +1,68 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package session_test + +import ( + "context" + "testing" + + "github.com/ngaut/pools" + "github.com/pingcap/tidb/ddl/internal/session" + "github.com/pingcap/tidb/testkit" + "github.com/stretchr/testify/require" +) + +func TestSessionPool(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + resourcePool := pools.NewResourcePool(func() (pools.Resource, error) { + newTk := testkit.NewTestKit(t, store) + return newTk.Session(), nil + }, 4, 4, 0) + pool := session.NewSessionPool(resourcePool, store) + sessCtx, err := pool.Get() + require.NoError(t, err) + se := session.NewSession(sessCtx) + err = se.Begin() + startTS := se.GetSessionVars().TxnCtx.StartTS + require.NoError(t, err) + rows, err := se.Execute(context.Background(), "select 2;", "test") + require.NoError(t, err) + require.Equal(t, 1, len(rows)) + require.Equal(t, int64(2), rows[0].GetInt64(0)) + mgr := tk.Session().GetSessionManager() + tsList := mgr.GetInternalSessionStartTSList() + var targetTS uint64 + for _, ts := range tsList { + if ts == startTS { + targetTS = ts + break + } + } + require.NotEqual(t, uint64(0), targetTS) + err = se.Commit() + pool.Put(sessCtx) + require.NoError(t, err) + tsList = mgr.GetInternalSessionStartTSList() + targetTS = 0 + for _, ts := range tsList { + if ts == startTS { + targetTS = ts + break + } + } + require.Equal(t, uint64(0), targetTS) +} diff --git a/testkit/BUILD.bazel b/testkit/BUILD.bazel index 3e33b8bd3f555..16e766350c6ef 100644 --- a/testkit/BUILD.bazel +++ b/testkit/BUILD.bazel @@ -24,6 +24,7 @@ go_library( "//resourcemanager", "//session", "//session/txninfo", + "//sessionctx", "//sessionctx/variable", "//store/driver", "//store/mockstore", diff --git a/testkit/mocksessionmanager.go b/testkit/mocksessionmanager.go index f858a439f6e92..be09fa66c643a 100644 --- a/testkit/mocksessionmanager.go +++ b/testkit/mocksessionmanager.go @@ -23,6 +23,7 @@ import ( "github.com/pingcap/tidb/planner/core" "github.com/pingcap/tidb/session" "github.com/pingcap/tidb/session/txninfo" + "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/util" ) @@ -35,6 +36,8 @@ type MockSessionManager struct { Dom *domain.Domain Conn map[uint64]session.Session mu sync.Mutex + + internalSessions map[interface{}]struct{} } // ShowTxnList is to show txn list. @@ -104,7 +107,7 @@ func (msm *MockSessionManager) GetProcessInfo(id uint64) (*util.ProcessInfo, boo func (*MockSessionManager) Kill(uint64, bool) { } -// KillAllConnections implements the SessionManager.KillAllConections interface. +// KillAllConnections implements the SessionManager.KillAllConnections interface. func (*MockSessionManager) KillAllConnections() { } @@ -118,14 +121,33 @@ func (msm *MockSessionManager) ServerID() uint64 { } // StoreInternalSession is to store internal session. -func (*MockSessionManager) StoreInternalSession(interface{}) {} +func (msm *MockSessionManager) StoreInternalSession(s interface{}) { + msm.mu.Lock() + if msm.internalSessions == nil { + msm.internalSessions = make(map[interface{}]struct{}) + } + msm.internalSessions[s] = struct{}{} + msm.mu.Unlock() +} // DeleteInternalSession is to delete the internal session pointer from the map in the SessionManager -func (*MockSessionManager) DeleteInternalSession(interface{}) {} +func (msm *MockSessionManager) DeleteInternalSession(s interface{}) { + msm.mu.Lock() + delete(msm.internalSessions, s) + msm.mu.Unlock() +} -// GetInternalSessionStartTSList is to get all startTS of every transactions running in the current internal sessions -func (*MockSessionManager) GetInternalSessionStartTSList() []uint64 { - return nil +// GetInternalSessionStartTSList is to get all startTS of every transaction running in the current internal sessions +func (msm *MockSessionManager) GetInternalSessionStartTSList() []uint64 { + msm.mu.Lock() + defer msm.mu.Unlock() + ret := make([]uint64, 0, len(msm.internalSessions)) + for internalSess := range msm.internalSessions { + se := internalSess.(sessionctx.Context) + startTS := se.GetSessionVars().TxnCtx.StartTS + ret = append(ret, startTS) + } + return ret } // KillNonFlashbackClusterConn implement SessionManager interface.