diff --git a/server/conn.go b/server/conn.go index 3af992d798cd2..b3972097b8239 100644 --- a/server/conn.go +++ b/server/conn.go @@ -152,6 +152,7 @@ type clientConn struct { lastActive time.Time // last active time authPlugin string // default authentication plugin isUnixSocket bool // connection is Unix Socket file + closeOnce sync.Once // closeOnce is used to make sure clientConn closes only once rsEncoder *resultEncoder // rsEncoder is used to encode the string result to different charsets. inputDecoder *inputDecoder // inputDecoder is used to decode the different charsets of incoming strings to utf-8. socketCredUID uint32 // UID from the other end of the Unix Socket @@ -320,21 +321,33 @@ func (cc *clientConn) Close() error { } func closeConn(cc *clientConn, connections int) error { - metrics.ConnGauge.Set(float64(connections)) - if cc.bufReadConn != nil { - err := cc.bufReadConn.Close() - if err != nil { - // We need to expect connection might have already disconnected. - // This is because closeConn() might be called after a connection read-timeout. - logutil.Logger(context.Background()).Debug("could not close connection", zap.Error(err)) + var err error + cc.closeOnce.Do(func() { + metrics.ConnGauge.Set(float64(connections)) + + if cc.bufReadConn != nil { + err = cc.bufReadConn.Close() + if err != nil { + // We need to expect connection might have already disconnected. + // This is because closeConn() might be called after a connection read-timeout. + logutil.Logger(context.Background()).Debug("could not close connection", zap.Error(err)) + } + if cc.bufReadConn != nil { + err = cc.bufReadConn.Close() + if err != nil { + // We need to expect connection might have already disconnected. + // This is because closeConn() might be called after a connection read-timeout. + logutil.Logger(context.Background()).Debug("could not close connection", zap.Error(err)) + } + } + // Close statements and session + // This will release advisory locks, row locks, etc. + if ctx := cc.getCtx(); ctx != nil { + err = ctx.Close() + } } - } - // Close statements and session - // This will release advisory locks, row locks, etc. - if ctx := cc.getCtx(); ctx != nil { - return ctx.Close() - } - return nil + }) + return err } func (cc *clientConn) closeWithoutLock() error { diff --git a/server/conn_test.go b/server/conn_test.go index f97f5d0a91a06..0a2a8c428dcb6 100644 --- a/server/conn_test.go +++ b/server/conn_test.go @@ -25,6 +25,7 @@ import ( "io" "path/filepath" "strings" + "sync" "sync/atomic" "testing" "time" @@ -2025,3 +2026,44 @@ func TestLDAPAuthSwitch(t *testing.T) { require.NoError(t, err) require.Equal(t, []byte(mysql.AuthMySQLClearPassword), respAuthSwitch) } + +func TestCloseConn(t *testing.T) { + var outBuffer bytes.Buffer + + store, _ := testkit.CreateMockStoreAndDomain(t) + cfg := newTestConfig() + cfg.Port = 0 + cfg.Status.StatusPort = 0 + drv := NewTiDBDriver(store) + server, err := NewServer(cfg, drv) + require.NoError(t, err) + + cc := &clientConn{ + connectionID: 0, + salt: []byte{ + 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, + 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10, 0x11, 0x12, 0x13, 0x14, + }, + server: server, + pkt: &packetIO{ + bufWriter: bufio.NewWriter(&outBuffer), + }, + collation: mysql.DefaultCollationID, + peerHost: "localhost", + alloc: arena.NewAllocator(512), + chunkAlloc: chunk.NewAllocator(), + capability: mysql.ClientProtocol41, + } + + var wg sync.WaitGroup + const numGoroutines = 10 + wg.Add(numGoroutines) + for i := 0; i < numGoroutines; i++ { + go func() { + defer wg.Done() + err := closeConn(cc, 1) + require.NoError(t, err) + }() + } + wg.Wait() +}