Skip to content

Commit

Permalink
This is an automated cherry-pick of pingcap#49073
Browse files Browse the repository at this point in the history
Signed-off-by: ti-chi-bot <[email protected]>
  • Loading branch information
bb7133 authored and ti-chi-bot committed Dec 2, 2023
1 parent e30242b commit 2d33c6e
Show file tree
Hide file tree
Showing 2 changed files with 199 additions and 7 deletions.
51 changes: 44 additions & 7 deletions server/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ type clientConn struct {
sync.RWMutex
*TiDBContext // an interface to execute sql statements.
}
<<<<<<< HEAD:server/conn.go
attrs map[string]string // attributes parsed from client handshake response, not used for now.
serverHost string // server host
peerHost string // peer host
Expand All @@ -211,6 +212,22 @@ type clientConn struct {
rsEncoder *resultEncoder // rsEncoder is used to encode the string result to different charsets.
inputDecoder *inputDecoder // inputDecoder is used to decode the different charsets of incoming strings to utf-8.
socketCredUID uint32 // UID from the other end of the Unix Socket
=======
attrs map[string]string // attributes parsed from client handshake response.
serverHost string // server host
peerHost string // peer host
peerPort string // peer port
status int32 // dispatching/reading/shutdown/waitshutdown
lastCode uint16 // last error code
collation uint8 // collation used by client, may be different from the collation used by database.
lastActive time.Time // last active time
authPlugin string // default authentication plugin
isUnixSocket bool // connection is Unix Socket file
closeOnce sync.Once // closeOnce is used to make sure clientConn closes only once
rsEncoder *column.ResultEncoder // rsEncoder is used to encode the string result to different charsets
inputDecoder *util2.InputDecoder // inputDecoder is used to decode the different charsets of incoming strings to utf-8
socketCredUID uint32 // UID from the other end of the Unix Socket
>>>>>>> 43825796a66 (server: make `clientConn()` thread-safe (#49073)):pkg/server/conn.go
// mu is used for cancelling the execution of current transaction.
mu struct {
sync.RWMutex
Expand Down Expand Up @@ -348,6 +365,7 @@ func (cc *clientConn) Close() error {
return closeConn(cc, connections)
}

<<<<<<< HEAD:server/conn.go
func closeConn(cc *clientConn, connections int) error {
metrics.ConnGauge.Set(float64(connections))
if cc.bufReadConn != nil {
Expand All @@ -356,14 +374,33 @@ func closeConn(cc *clientConn, connections int) error {
// We need to expect connection might have already disconnected.
// This is because closeConn() might be called after a connection read-timeout.
logutil.Logger(context.Background()).Debug("could not close connection", zap.Error(err))
=======
// closeConn is idempotent and thread-safe.
// It will be called on the same `clientConn` more than once to avoid connection leak.
func closeConn(cc *clientConn, connections int) error {
var err error
cc.closeOnce.Do(func() {
metrics.ConnGauge.Set(float64(connections))
if cc.connectionID > 0 {
cc.server.dom.ReleaseConnID(cc.connectionID)
cc.connectionID = 0
>>>>>>> 43825796a66 (server: make `clientConn()` thread-safe (#49073)):pkg/server/conn.go
}
if cc.bufReadConn != nil {
err := cc.bufReadConn.Close()
if err != nil {
// We need to expect connection might have already disconnected.
// This is because closeConn() might be called after a connection read-timeout.
logutil.Logger(context.Background()).Debug("could not close connection", zap.Error(err))
}
}
}
// Close statements and session
// This will release advisory locks, row locks, etc.
if ctx := cc.getCtx(); ctx != nil {
return ctx.Close()
}
return nil
// Close statements and session
// This will release advisory locks, row locks, etc.
if ctx := cc.getCtx(); ctx != nil {
err = ctx.Close()
}
})
return err
}

func (cc *clientConn) closeWithoutLock() error {
Expand Down
155 changes: 155 additions & 0 deletions server/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"io"
"path/filepath"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
Expand Down Expand Up @@ -1805,3 +1806,157 @@ func TestProcessInfoForExecuteCommand(t *testing.T) {
0x0A, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0}))
require.Equal(t, cc.ctx.Session.ShowProcess().Info, "select sum(col1) from t where col1 < ? and col1 > 100")
}
<<<<<<< HEAD:server/conn_test.go
=======

func TestLDAPAuthSwitch(t *testing.T) {
store := testkit.CreateMockStore(t)
cfg := serverutil.NewTestConfig()
cfg.Port = 0
cfg.Status.StatusPort = 0
drv := NewTiDBDriver(store)
srv, err := NewServer(cfg, drv)
require.NoError(t, err)
tk := testkit.NewTestKit(t, store)
tk.MustExec("CREATE USER test_simple_ldap IDENTIFIED WITH authentication_ldap_simple AS 'uid=test_simple_ldap,dc=example,dc=com'")

cc := &clientConn{
connectionID: 1,
alloc: arena.NewAllocator(1024),
chunkAlloc: chunk.NewAllocator(),
pkt: internal.NewPacketIOForTest(bufio.NewWriter(bytes.NewBuffer(nil))),
server: srv,
user: "test_simple_ldap",
}
se, _ := session.CreateSession4Test(store)
tc := &TiDBContext{
Session: se,
stmts: make(map[int]*TiDBStatement),
}
cc.SetCtx(tc)
cc.isUnixSocket = true

require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/pkg/server/FakeAuthSwitch", "return(1)"))
respAuthSwitch, err := cc.checkAuthPlugin(context.Background(), &handshake.Response41{
Capability: mysql.ClientProtocol41 | mysql.ClientPluginAuth,
User: "test_simple_ldap",
})
require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/pkg/server/FakeAuthSwitch"))
require.NoError(t, err)
require.Equal(t, []byte(mysql.AuthMySQLClearPassword), respAuthSwitch)
}

func TestEmptyOrgName(t *testing.T) {
inputs := []dispatchInput{
{
com: mysql.ComQuery,
in: append([]byte("SELECT DATE_FORMAT(CONCAT('2023-07-0', a), '%Y') AS 'YEAR' FROM test.t"), 0x0),
err: nil,
out: []byte{0x1, 0x0, 0x0, 0x0, 0x1, // 1 column
0x1a, 0x0, 0x0,
0x1, 0x3, 0x64, 0x65, 0x66, // catalog
0x0, // schema
0x0, // table name
0x0, // org table
0x4, 0x59, 0x45, 0x41, 0x52, // name 'YEAR'
0x0, // org name
0xc, 0x2e, 0x0, 0x2c, 0x0, 0x0, 0x0, 0xfd, 0x0, 0x0, 0x1f, 0x0, 0x0, 0x1, 0x0, 0x0, 0x2, 0xfe, 0x5, 0x0,
0x0, 0x3, 0x4, 0x32, 0x30, 0x32, 0x33, 0x1, 0x0, 0x0, 0x4, 0xfe},
},
}

testDispatch(t, inputs, 0)
}

func TestStats(t *testing.T) {
var outBuffer bytes.Buffer

store := testkit.CreateMockStore(t)
cfg := serverutil.NewTestConfig()
cfg.Port = 0
cfg.Status.StatusPort = 0
drv := NewTiDBDriver(store)
server, err := NewServer(cfg, drv)
require.NoError(t, err)
tk := testkit.NewTestKit(t, store)

cc := &clientConn{
connectionID: 1,
salt: []byte{
0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A,
0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10, 0x11, 0x12, 0x13, 0x14,
},
server: server,
pkt: internal.NewPacketIOForTest(bufio.NewWriter(&outBuffer)),
collation: mysql.DefaultCollationID,
peerHost: "localhost",
alloc: arena.NewAllocator(512),
chunkAlloc: chunk.NewAllocator(),
capability: mysql.ClientProtocol41,
}

// No compression
vars := tk.Session().GetSessionVars()
m, err := cc.Stats(vars)
require.NoError(t, err)
require.Equal(t, "OFF", m["Compression"])
require.Equal(t, "", m["Compression_algorithm"])
require.Equal(t, 0, m["Compression_level"])

// zlib compression
vars.CompressionAlgorithm = mysql.CompressionZlib
m, err = cc.Stats(vars)
require.NoError(t, err)
require.Equal(t, "ON", m["Compression"])
require.Equal(t, "zlib", m["Compression_algorithm"])
require.Equal(t, mysql.ZlibCompressDefaultLevel, m["Compression_level"])

// zstd compression, with level 1
vars.CompressionAlgorithm = mysql.CompressionZstd
vars.CompressionLevel = 1
m, err = cc.Stats(vars)
require.NoError(t, err)
require.Equal(t, "ON", m["Compression"])
require.Equal(t, "zstd", m["Compression_algorithm"])
require.Equal(t, 1, m["Compression_level"])
}

func TestCloseConn(t *testing.T) {
var outBuffer bytes.Buffer

store, _ := testkit.CreateMockStoreAndDomain(t)
cfg := serverutil.NewTestConfig()
cfg.Port = 0
cfg.Status.StatusPort = 0
drv := NewTiDBDriver(store)
server, err := NewServer(cfg, drv)
require.NoError(t, err)

cc := &clientConn{
connectionID: 0,
salt: []byte{
0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A,
0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10, 0x11, 0x12, 0x13, 0x14,
},
server: server,
pkt: internal.NewPacketIOForTest(bufio.NewWriter(&outBuffer)),
collation: mysql.DefaultCollationID,
peerHost: "localhost",
alloc: arena.NewAllocator(512),
chunkAlloc: chunk.NewAllocator(),
capability: mysql.ClientProtocol41,
}

var wg sync.WaitGroup
const numGoroutines = 10
wg.Add(numGoroutines)
for i := 0; i < numGoroutines; i++ {
go func() {
defer wg.Done()
err := closeConn(cc, 1)
require.NoError(t, err)
}()
}
wg.Wait()
}
>>>>>>> 43825796a66 (server: make `clientConn()` thread-safe (#49073)):pkg/server/conn_test.go

0 comments on commit 2d33c6e

Please sign in to comment.