From 2d33c6eb6c23d7f27b10863fd8bfe01c25d348f8 Mon Sep 17 00:00:00 2001 From: bb7133 Date: Fri, 1 Dec 2023 18:07:20 -0800 Subject: [PATCH] This is an automated cherry-pick of #49073 Signed-off-by: ti-chi-bot --- server/conn.go | 51 +++++++++++++-- server/conn_test.go | 155 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 199 insertions(+), 7 deletions(-) diff --git a/server/conn.go b/server/conn.go index 44f43920a8b37..622bfe6adadec 100644 --- a/server/conn.go +++ b/server/conn.go @@ -198,6 +198,7 @@ type clientConn struct { sync.RWMutex *TiDBContext // an interface to execute sql statements. } +<<<<<<< HEAD:server/conn.go attrs map[string]string // attributes parsed from client handshake response, not used for now. serverHost string // server host peerHost string // peer host @@ -211,6 +212,22 @@ type clientConn struct { rsEncoder *resultEncoder // rsEncoder is used to encode the string result to different charsets. inputDecoder *inputDecoder // inputDecoder is used to decode the different charsets of incoming strings to utf-8. socketCredUID uint32 // UID from the other end of the Unix Socket +======= + attrs map[string]string // attributes parsed from client handshake response. + serverHost string // server host + peerHost string // peer host + peerPort string // peer port + status int32 // dispatching/reading/shutdown/waitshutdown + lastCode uint16 // last error code + collation uint8 // collation used by client, may be different from the collation used by database. + lastActive time.Time // last active time + authPlugin string // default authentication plugin + isUnixSocket bool // connection is Unix Socket file + closeOnce sync.Once // closeOnce is used to make sure clientConn closes only once + rsEncoder *column.ResultEncoder // rsEncoder is used to encode the string result to different charsets + inputDecoder *util2.InputDecoder // inputDecoder is used to decode the different charsets of incoming strings to utf-8 + socketCredUID uint32 // UID from the other end of the Unix Socket +>>>>>>> 43825796a66 (server: make `clientConn()` thread-safe (#49073)):pkg/server/conn.go // mu is used for cancelling the execution of current transaction. mu struct { sync.RWMutex @@ -348,6 +365,7 @@ func (cc *clientConn) Close() error { return closeConn(cc, connections) } +<<<<<<< HEAD:server/conn.go func closeConn(cc *clientConn, connections int) error { metrics.ConnGauge.Set(float64(connections)) if cc.bufReadConn != nil { @@ -356,14 +374,33 @@ func closeConn(cc *clientConn, connections int) error { // We need to expect connection might have already disconnected. // This is because closeConn() might be called after a connection read-timeout. logutil.Logger(context.Background()).Debug("could not close connection", zap.Error(err)) +======= +// closeConn is idempotent and thread-safe. +// It will be called on the same `clientConn` more than once to avoid connection leak. +func closeConn(cc *clientConn, connections int) error { + var err error + cc.closeOnce.Do(func() { + metrics.ConnGauge.Set(float64(connections)) + if cc.connectionID > 0 { + cc.server.dom.ReleaseConnID(cc.connectionID) + cc.connectionID = 0 +>>>>>>> 43825796a66 (server: make `clientConn()` thread-safe (#49073)):pkg/server/conn.go + } + if cc.bufReadConn != nil { + err := cc.bufReadConn.Close() + if err != nil { + // We need to expect connection might have already disconnected. + // This is because closeConn() might be called after a connection read-timeout. + logutil.Logger(context.Background()).Debug("could not close connection", zap.Error(err)) + } } - } - // Close statements and session - // This will release advisory locks, row locks, etc. - if ctx := cc.getCtx(); ctx != nil { - return ctx.Close() - } - return nil + // Close statements and session + // This will release advisory locks, row locks, etc. + if ctx := cc.getCtx(); ctx != nil { + err = ctx.Close() + } + }) + return err } func (cc *clientConn) closeWithoutLock() error { diff --git a/server/conn_test.go b/server/conn_test.go index c540b1784793d..1062c78e2e08e 100644 --- a/server/conn_test.go +++ b/server/conn_test.go @@ -25,6 +25,7 @@ import ( "io" "path/filepath" "strings" + "sync" "sync/atomic" "testing" "time" @@ -1805,3 +1806,157 @@ func TestProcessInfoForExecuteCommand(t *testing.T) { 0x0A, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0})) require.Equal(t, cc.ctx.Session.ShowProcess().Info, "select sum(col1) from t where col1 < ? and col1 > 100") } +<<<<<<< HEAD:server/conn_test.go +======= + +func TestLDAPAuthSwitch(t *testing.T) { + store := testkit.CreateMockStore(t) + cfg := serverutil.NewTestConfig() + cfg.Port = 0 + cfg.Status.StatusPort = 0 + drv := NewTiDBDriver(store) + srv, err := NewServer(cfg, drv) + require.NoError(t, err) + tk := testkit.NewTestKit(t, store) + tk.MustExec("CREATE USER test_simple_ldap IDENTIFIED WITH authentication_ldap_simple AS 'uid=test_simple_ldap,dc=example,dc=com'") + + cc := &clientConn{ + connectionID: 1, + alloc: arena.NewAllocator(1024), + chunkAlloc: chunk.NewAllocator(), + pkt: internal.NewPacketIOForTest(bufio.NewWriter(bytes.NewBuffer(nil))), + server: srv, + user: "test_simple_ldap", + } + se, _ := session.CreateSession4Test(store) + tc := &TiDBContext{ + Session: se, + stmts: make(map[int]*TiDBStatement), + } + cc.SetCtx(tc) + cc.isUnixSocket = true + + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/pkg/server/FakeAuthSwitch", "return(1)")) + respAuthSwitch, err := cc.checkAuthPlugin(context.Background(), &handshake.Response41{ + Capability: mysql.ClientProtocol41 | mysql.ClientPluginAuth, + User: "test_simple_ldap", + }) + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/pkg/server/FakeAuthSwitch")) + require.NoError(t, err) + require.Equal(t, []byte(mysql.AuthMySQLClearPassword), respAuthSwitch) +} + +func TestEmptyOrgName(t *testing.T) { + inputs := []dispatchInput{ + { + com: mysql.ComQuery, + in: append([]byte("SELECT DATE_FORMAT(CONCAT('2023-07-0', a), '%Y') AS 'YEAR' FROM test.t"), 0x0), + err: nil, + out: []byte{0x1, 0x0, 0x0, 0x0, 0x1, // 1 column + 0x1a, 0x0, 0x0, + 0x1, 0x3, 0x64, 0x65, 0x66, // catalog + 0x0, // schema + 0x0, // table name + 0x0, // org table + 0x4, 0x59, 0x45, 0x41, 0x52, // name 'YEAR' + 0x0, // org name + 0xc, 0x2e, 0x0, 0x2c, 0x0, 0x0, 0x0, 0xfd, 0x0, 0x0, 0x1f, 0x0, 0x0, 0x1, 0x0, 0x0, 0x2, 0xfe, 0x5, 0x0, + 0x0, 0x3, 0x4, 0x32, 0x30, 0x32, 0x33, 0x1, 0x0, 0x0, 0x4, 0xfe}, + }, + } + + testDispatch(t, inputs, 0) +} + +func TestStats(t *testing.T) { + var outBuffer bytes.Buffer + + store := testkit.CreateMockStore(t) + cfg := serverutil.NewTestConfig() + cfg.Port = 0 + cfg.Status.StatusPort = 0 + drv := NewTiDBDriver(store) + server, err := NewServer(cfg, drv) + require.NoError(t, err) + tk := testkit.NewTestKit(t, store) + + cc := &clientConn{ + connectionID: 1, + salt: []byte{ + 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, + 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10, 0x11, 0x12, 0x13, 0x14, + }, + server: server, + pkt: internal.NewPacketIOForTest(bufio.NewWriter(&outBuffer)), + collation: mysql.DefaultCollationID, + peerHost: "localhost", + alloc: arena.NewAllocator(512), + chunkAlloc: chunk.NewAllocator(), + capability: mysql.ClientProtocol41, + } + + // No compression + vars := tk.Session().GetSessionVars() + m, err := cc.Stats(vars) + require.NoError(t, err) + require.Equal(t, "OFF", m["Compression"]) + require.Equal(t, "", m["Compression_algorithm"]) + require.Equal(t, 0, m["Compression_level"]) + + // zlib compression + vars.CompressionAlgorithm = mysql.CompressionZlib + m, err = cc.Stats(vars) + require.NoError(t, err) + require.Equal(t, "ON", m["Compression"]) + require.Equal(t, "zlib", m["Compression_algorithm"]) + require.Equal(t, mysql.ZlibCompressDefaultLevel, m["Compression_level"]) + + // zstd compression, with level 1 + vars.CompressionAlgorithm = mysql.CompressionZstd + vars.CompressionLevel = 1 + m, err = cc.Stats(vars) + require.NoError(t, err) + require.Equal(t, "ON", m["Compression"]) + require.Equal(t, "zstd", m["Compression_algorithm"]) + require.Equal(t, 1, m["Compression_level"]) +} + +func TestCloseConn(t *testing.T) { + var outBuffer bytes.Buffer + + store, _ := testkit.CreateMockStoreAndDomain(t) + cfg := serverutil.NewTestConfig() + cfg.Port = 0 + cfg.Status.StatusPort = 0 + drv := NewTiDBDriver(store) + server, err := NewServer(cfg, drv) + require.NoError(t, err) + + cc := &clientConn{ + connectionID: 0, + salt: []byte{ + 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, + 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10, 0x11, 0x12, 0x13, 0x14, + }, + server: server, + pkt: internal.NewPacketIOForTest(bufio.NewWriter(&outBuffer)), + collation: mysql.DefaultCollationID, + peerHost: "localhost", + alloc: arena.NewAllocator(512), + chunkAlloc: chunk.NewAllocator(), + capability: mysql.ClientProtocol41, + } + + var wg sync.WaitGroup + const numGoroutines = 10 + wg.Add(numGoroutines) + for i := 0; i < numGoroutines; i++ { + go func() { + defer wg.Done() + err := closeConn(cc, 1) + require.NoError(t, err) + }() + } + wg.Wait() +} +>>>>>>> 43825796a66 (server: make `clientConn()` thread-safe (#49073)):pkg/server/conn_test.go