Skip to content

Commit

Permalink
server: make clientConn() thread-safe
Browse files Browse the repository at this point in the history
  • Loading branch information
bb7133 committed Dec 1, 2023
1 parent b21bb3e commit 8dbd4ea
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 3 deletions.
11 changes: 8 additions & 3 deletions pkg/server/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,8 +169,9 @@ type clientConn struct {
lastActive time.Time // last active time
authPlugin string // default authentication plugin
isUnixSocket bool // connection is Unix Socket file
rsEncoder *column.ResultEncoder // rsEncoder is used to encode the string result to different charsets.
inputDecoder *util2.InputDecoder // inputDecoder is used to decode the different charsets of incoming strings to utf-8.
isClosed int32 // atomic variable to track whether the connection is closed
rsEncoder *column.ResultEncoder // rsEncoder is used to encode the string result to different charsets
inputDecoder *util2.InputDecoder // inputDecoder is used to decode the different charsets of incoming strings to utf-8
socketCredUID uint32 // UID from the other end of the Unix Socket
// mu is used for cancelling the execution of current transaction.
mu struct {
Expand Down Expand Up @@ -349,9 +350,13 @@ func (cc *clientConn) Close() error {
return closeConn(cc, connections)
}

// closeConn should be idempotent.
// closeConn is idempotent and thread-safe.
// It will be called on the same `clientConn` more than once to avoid connection leak.
func closeConn(cc *clientConn, connections int) error {
if !atomic.CompareAndSwapInt32(&cc.isClosed, 0, 1) {
return nil
}

metrics.ConnGauge.Set(float64(connections))
if cc.connectionID > 0 {
cc.server.dom.ReleaseConnID(cc.connectionID)
Expand Down
41 changes: 41 additions & 0 deletions pkg/server/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"io"
"path/filepath"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
Expand Down Expand Up @@ -2052,3 +2053,43 @@ func TestStats(t *testing.T) {
require.Equal(t, "zstd", m["Compression_algorithm"])
require.Equal(t, 1, m["Compression_level"])
}

func TestCloseConn(t *testing.T) {
var outBuffer bytes.Buffer

store, _ := testkit.CreateMockStoreAndDomain(t)
cfg := serverutil.NewTestConfig()
cfg.Port = 0
cfg.Status.StatusPort = 0
drv := NewTiDBDriver(store)
server, err := NewServer(cfg, drv)
require.NoError(t, err)

cc := &clientConn{
connectionID: 0,
salt: []byte{
0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A,
0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10, 0x11, 0x12, 0x13, 0x14,
},
server: server,
pkt: internal.NewPacketIOForTest(bufio.NewWriter(&outBuffer)),
collation: mysql.DefaultCollationID,
peerHost: "localhost",
alloc: arena.NewAllocator(512),
chunkAlloc: chunk.NewAllocator(),
capability: mysql.ClientProtocol41,
}

var wg sync.WaitGroup
const numGoroutines = 10
wg.Add(numGoroutines)
for i := 0; i < numGoroutines; i++ {
go func() {
defer wg.Done()
err := closeConn(cc, 1)
require.NoError(t, err)
}()
}
wg.Wait()
require.Equalf(t, int32(1), cc.isClosed, "Expected isClosed to be 1, got 0")
}

0 comments on commit 8dbd4ea

Please sign in to comment.