Skip to content

Commit

Permalink
ddl: include ddl sessions in session manager's internal list
Browse files Browse the repository at this point in the history
  • Loading branch information
tangenta committed Apr 10, 2023
1 parent a8024e2 commit 6b7734c
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 9 deletions.
9 changes: 6 additions & 3 deletions ddl/internal/session/session_pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (

"github.com/ngaut/pools"
"github.com/pingcap/errors"
"github.com/pingcap/tidb/domain/infosync"
"github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/sessionctx"
Expand Down Expand Up @@ -54,7 +55,7 @@ func (sg *Pool) Get() (sessionctx.Context, error) {
sg.mu.Lock()
if sg.mu.closed {
sg.mu.Unlock()
return nil, errors.Errorf("Session pool is closed")
return nil, errors.Errorf("session pool is closed")
}
sg.mu.Unlock()

Expand All @@ -66,10 +67,11 @@ func (sg *Pool) Get() (sessionctx.Context, error) {

ctx, ok := resource.(sessionctx.Context)
if !ok {
return nil, fmt.Errorf("Session pool resource get %v", ctx)
return nil, errors.Trace(fmt.Errorf("need sessionctx.Context, but got %T", ctx))
}
ctx.GetSessionVars().SetStatusFlag(mysql.ServerStatusAutocommit, true)
ctx.GetSessionVars().InRestrictedSQL = true
infosync.StoreInternalSession(ctx)
return ctx, nil
}

Expand All @@ -82,6 +84,7 @@ func (sg *Pool) Put(ctx sessionctx.Context) {
// no need to protect sg.resPool, even the sg.resPool is closed, the ctx still need to
// Put into resPool, because when resPool is closing, it will wait all the ctx returns, then resPool finish closing.
sg.resPool.Put(ctx.(pools.Resource))
infosync.DeleteInternalSession(ctx)
}

// Close clean up the Pool.
Expand All @@ -92,7 +95,7 @@ func (sg *Pool) Close() {
if sg.mu.closed || sg.resPool == nil {
return
}
logutil.BgLogger().Info("[ddl] closing Session pool")
logutil.BgLogger().Info("[ddl] closing session pool")
sg.resPool.Close()
sg.mu.closed = true
}
68 changes: 68 additions & 0 deletions ddl/internal/session/session_pool_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
// Copyright 2023 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package session_test

import (
"context"
"testing"

"github.com/ngaut/pools"
"github.com/pingcap/tidb/ddl/internal/session"
"github.com/pingcap/tidb/testkit"
"github.com/stretchr/testify/require"
)

func TestSessionPool(t *testing.T) {
store := testkit.CreateMockStore(t)
tk := testkit.NewTestKit(t, store)
tk.MustExec("use test")
resourcePool := pools.NewResourcePool(func() (pools.Resource, error) {
newTk := testkit.NewTestKit(t, store)
return newTk.Session(), nil
}, 4, 4, 0)
pool := session.NewSessionPool(resourcePool, store)
sessCtx, err := pool.Get()
require.NoError(t, err)
se := session.NewSession(sessCtx)
err = se.Begin()
startTS := se.GetSessionVars().TxnCtx.StartTS
require.NoError(t, err)
rows, err := se.Execute(context.Background(), "select 2;", "test")
require.NoError(t, err)
require.Equal(t, 1, len(rows))
require.Equal(t, int64(2), rows[0].GetInt64(0))
mgr := tk.Session().GetSessionManager()
tsList := mgr.GetInternalSessionStartTSList()
var targetTS uint64
for _, ts := range tsList {
if ts == startTS {
targetTS = ts
break
}
}
require.NotEqual(t, uint64(0), targetTS)
err = se.Commit()
pool.Put(sessCtx)
require.NoError(t, err)
tsList = mgr.GetInternalSessionStartTSList()
targetTS = 0
for _, ts := range tsList {
if ts == startTS {
targetTS = ts
break
}
}
require.Equal(t, uint64(0), targetTS)
}
34 changes: 28 additions & 6 deletions testkit/mocksessionmanager.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"github.com/pingcap/tidb/planner/core"
"github.com/pingcap/tidb/session"
"github.com/pingcap/tidb/session/txninfo"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/util"
)

Expand All @@ -35,6 +36,8 @@ type MockSessionManager struct {
Dom *domain.Domain
Conn map[uint64]session.Session
mu sync.Mutex

internalSessions map[interface{}]struct{}
}

// ShowTxnList is to show txn list.
Expand Down Expand Up @@ -104,7 +107,7 @@ func (msm *MockSessionManager) GetProcessInfo(id uint64) (*util.ProcessInfo, boo
func (*MockSessionManager) Kill(uint64, bool) {
}

// KillAllConnections implements the SessionManager.KillAllConections interface.
// KillAllConnections implements the SessionManager.KillAllConnections interface.
func (*MockSessionManager) KillAllConnections() {
}

Expand All @@ -118,14 +121,33 @@ func (msm *MockSessionManager) ServerID() uint64 {
}

// StoreInternalSession is to store internal session.
func (*MockSessionManager) StoreInternalSession(interface{}) {}
func (msm *MockSessionManager) StoreInternalSession(s interface{}) {
msm.mu.Lock()
if msm.internalSessions == nil {
msm.internalSessions = make(map[interface{}]struct{})
}
msm.internalSessions[s] = struct{}{}
msm.mu.Unlock()
}

// DeleteInternalSession is to delete the internal session pointer from the map in the SessionManager
func (*MockSessionManager) DeleteInternalSession(interface{}) {}
func (msm *MockSessionManager) DeleteInternalSession(s interface{}) {
msm.mu.Lock()
delete(msm.internalSessions, s)
msm.mu.Unlock()
}

// GetInternalSessionStartTSList is to get all startTS of every transactions running in the current internal sessions
func (*MockSessionManager) GetInternalSessionStartTSList() []uint64 {
return nil
// GetInternalSessionStartTSList is to get all startTS of every transaction running in the current internal sessions
func (msm *MockSessionManager) GetInternalSessionStartTSList() []uint64 {
msm.mu.Lock()
defer msm.mu.Unlock()
ret := make([]uint64, 0, len(msm.internalSessions))
for internalSess, _ := range msm.internalSessions {
se := internalSess.(sessionctx.Context)
startTS := se.GetSessionVars().TxnCtx.StartTS
ret = append(ret, startTS)
}
return ret
}

// KillNonFlashbackClusterConn implement SessionManager interface.
Expand Down

0 comments on commit 6b7734c

Please sign in to comment.