Skip to content

Commit

Permalink
extension: add error and active roles info to `extension.ConnEventInf…
Browse files Browse the repository at this point in the history
…o` (#38752)

close #38493
  • Loading branch information
lcwangchao authored Nov 1, 2022
1 parent 38e9aa0 commit ecdc0f7
Show file tree
Hide file tree
Showing 8 changed files with 106 additions and 37 deletions.
1 change: 1 addition & 0 deletions extension/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ go_library(
importpath = "github.com/pingcap/tidb/extension",
visibility = ["//visibility:public"],
deps = [
"//parser/auth",
"//sessionctx/variable",
"//types",
"//util/chunk",
Expand Down
2 changes: 1 addition & 1 deletion extension/extensionimpl/bootstrap.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ func (c *bootstrapContext) SessionPool() extension.SessionPool {
return c.sessionPool
}

// Bootstrap bootstrap all extensions
// Bootstrap bootstraps all extensions
func Bootstrap(ctx context.Context, do *domain.Domain) error {
extensions, err := extension.GetExtensions()
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion extension/extensions.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func (es *Extensions) Manifests() []*Manifest {
return manifests
}

// Bootstrap bootstrap all extensions
// Bootstrap bootstraps all extensions
func (es *Extensions) Bootstrap(ctx BootstrapContext) error {
if es == nil {
return nil
Expand Down
16 changes: 11 additions & 5 deletions extension/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,17 @@

package extension

import "github.com/pingcap/tidb/sessionctx/variable"
import (
"github.com/pingcap/tidb/parser/auth"
"github.com/pingcap/tidb/sessionctx/variable"
)

// ConnEventInfo is the connection info for the event
type ConnEventInfo variable.ConnectionInfo
type ConnEventInfo struct {
*variable.ConnectionInfo
ActiveRoles []*auth.RoleIdentity
Error error
}

// ConnEventTp is the type of the connection event
type ConnEventTp uint8
Expand Down Expand Up @@ -60,13 +67,12 @@ type SessionExtensions struct {
}

// OnConnectionEvent will be called when a connection event happens
func (es *SessionExtensions) OnConnectionEvent(tp ConnEventTp, info *variable.ConnectionInfo) {
func (es *SessionExtensions) OnConnectionEvent(tp ConnEventTp, event *ConnEventInfo) {
if es == nil {
return
}

eventInfo := ConnEventInfo(*info)
for _, fn := range es.connectionEventFuncs {
fn(tp, &eventInfo)
fn(tp, event)
}
}
3 changes: 1 addition & 2 deletions server/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -2510,8 +2510,7 @@ func (cc *clientConn) handleCommonConnectionReset(ctx context.Context) error {
connectionInfo := cc.connectInfo()
cc.ctx.GetSessionVars().ConnectionInfo = connectionInfo

cc.extensions.OnConnectionEvent(extension.ConnReset, connectionInfo)

cc.onExtensionConnEvent(extension.ConnReset, nil)
err := plugin.ForeachPlugin(plugin.Audit, func(p *plugin.Plugin) error {
authPlugin := plugin.DeclareAuditManifest(p.Manifest)
if authPlugin.OnConnectionEvent != nil {
Expand Down
17 changes: 13 additions & 4 deletions server/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1663,7 +1663,10 @@ func TestExtensionChangeUser(t *testing.T) {
outBuffer.Reset()
}

expectedConnInfo := extension.ConnEventInfo(*cc.connectInfo())
expectedConnInfo := extension.ConnEventInfo{
ConnectionInfo: cc.connectInfo(),
ActiveRoles: []*auth.RoleIdentity{},
}
expectedConnInfo.User = "user1"
expectedConnInfo.DB = "db1"

Expand All @@ -1679,7 +1682,9 @@ func TestExtensionChangeUser(t *testing.T) {
})
require.True(t, logged)
require.Equal(t, extension.ConnReset, logTp)
require.Equal(t, expectedConnInfo, *logInfo)
require.Equal(t, expectedConnInfo.ActiveRoles, logInfo.ActiveRoles)
require.Equal(t, expectedConnInfo.Error, logInfo.Error)
require.Equal(t, *(expectedConnInfo.ConnectionInfo), *(logInfo.ConnectionInfo))

logged = false
logTp = 0
Expand All @@ -1697,7 +1702,9 @@ func TestExtensionChangeUser(t *testing.T) {
})
require.True(t, logged)
require.Equal(t, extension.ConnReset, logTp)
require.Equal(t, expectedConnInfo, *logInfo)
require.Equal(t, expectedConnInfo.ActiveRoles, logInfo.ActiveRoles)
require.Equal(t, expectedConnInfo.Error, logInfo.Error)
require.Equal(t, *(expectedConnInfo.ConnectionInfo), *(logInfo.ConnectionInfo))

logged = false
logTp = 0
Expand All @@ -1710,5 +1717,7 @@ func TestExtensionChangeUser(t *testing.T) {
})
require.True(t, logged)
require.Equal(t, extension.ConnReset, logTp)
require.Equal(t, expectedConnInfo, *logInfo)
require.Equal(t, expectedConnInfo.ActiveRoles, logInfo.ActiveRoles)
require.Equal(t, expectedConnInfo.Error, logInfo.Error)
require.Equal(t, *(expectedConnInfo.ConnectionInfo), *(logInfo.ConnectionInfo))
}
42 changes: 33 additions & 9 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ import (
"github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/metrics"
"github.com/pingcap/tidb/parser/ast"
"github.com/pingcap/tidb/parser/auth"
"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/parser/terror"
"github.com/pingcap/tidb/planner/core"
Expand Down Expand Up @@ -510,7 +511,6 @@ func (s *Server) onConn(conn *clientConn) {
terror.Log(conn.Close())
return
}
connectionInfo := conn.connectInfo()

extensions, err := extension.GetExtensions()
if err != nil {
Expand All @@ -522,18 +522,17 @@ func (s *Server) onConn(conn *clientConn) {

if sessExtensions := extensions.NewSessionExtensions(); sessExtensions != nil {
conn.extensions = sessExtensions
sessExtensions.OnConnectionEvent(extension.ConnConnected, connectionInfo)
conn.onExtensionConnEvent(extension.ConnConnected, nil)
defer func() {
sessExtensions.OnConnectionEvent(extension.ConnDisconnected, connectionInfo)
conn.onExtensionConnEvent(extension.ConnDisconnected, nil)
}()
}

ctx := logutil.WithConnID(context.Background(), conn.connectionID)
if err := conn.handshake(ctx); err != nil {
connectionInfo = conn.connectInfo()
conn.extensions.OnConnectionEvent(extension.ConnHandshakeRejected, connectionInfo)
conn.onExtensionConnEvent(extension.ConnHandshakeRejected, err)
if plugin.IsEnable(plugin.Audit) && conn.getCtx() != nil {
conn.getCtx().GetSessionVars().ConnectionInfo = connectionInfo
conn.getCtx().GetSessionVars().ConnectionInfo = conn.connectInfo()
err = plugin.ForeachPlugin(plugin.Audit, func(p *plugin.Plugin) error {
authPlugin := plugin.DeclareAuditManifest(p.Manifest)
if authPlugin.OnConnectionEvent != nil {
Expand Down Expand Up @@ -578,9 +577,8 @@ func (s *Server) onConn(conn *clientConn) {
metrics.ConnGauge.Set(float64(connections))

sessionVars := conn.ctx.GetSessionVars()
connectionInfo = conn.connectInfo()
sessionVars.ConnectionInfo = connectionInfo
conn.extensions.OnConnectionEvent(extension.ConnHandshakeAccepted, connectionInfo)
sessionVars.ConnectionInfo = conn.connectInfo()
conn.onExtensionConnEvent(extension.ConnHandshakeAccepted, nil)
err = plugin.ForeachPlugin(plugin.Audit, func(p *plugin.Plugin) error {
authPlugin := plugin.DeclareAuditManifest(p.Manifest)
if authPlugin.OnConnectionEvent != nil {
Expand Down Expand Up @@ -638,6 +636,32 @@ func (cc *clientConn) connectInfo() *variable.ConnectionInfo {
return connInfo
}

func (cc *clientConn) onExtensionConnEvent(tp extension.ConnEventTp, err error) {
if cc.extensions == nil {
return
}

var connInfo *variable.ConnectionInfo
var activeRoles []*auth.RoleIdentity
if ctx := cc.getCtx(); ctx != nil {
sessVars := ctx.GetSessionVars()
connInfo = sessVars.ConnectionInfo
activeRoles = sessVars.ActiveRoles
}

if connInfo == nil {
connInfo = cc.connectInfo()
}

info := &extension.ConnEventInfo{
ConnectionInfo: connInfo,
ActiveRoles: activeRoles,
Error: err,
}

cc.extensions.OnConnectionEvent(tp, info)
}

func (s *Server) checkConnectionCount() error {
// When the value of Instance.MaxConnections is 0, the number of connections is unlimited.
if int(s.cfg.Instance.MaxConnections) == 0 {
Expand Down
60 changes: 45 additions & 15 deletions server/tidb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,10 @@ import (
"github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/parser"
"github.com/pingcap/tidb/parser/ast"
"github.com/pingcap/tidb/parser/auth"
tmysql "github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/session"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/sessiontxn"
"github.com/pingcap/tidb/store/mockstore"
"github.com/pingcap/tidb/store/mockstore/unistore"
Expand Down Expand Up @@ -2825,28 +2827,50 @@ func TestExtensionConnEvent(t *testing.T) {
_ = conn.Close()
}()

var conn1, conn2 extension.ConnEventInfo
var expectedConn2 variable.ConnectionInfo
logs.check(func() {
require.Equal(t, []extension.ConnEventTp{
extension.ConnConnected,
extension.ConnHandshakeAccepted,
}, logs.types)
conn1 = logs.infos[0]
conn2 = conn1
conn2.User = "root"
conn2.DB = "test"

conn1 := logs.infos[0]
require.Equal(t, "127.0.0.1", conn1.ClientIP)
require.Equal(t, "127.0.0.1", conn1.ServerIP)
require.Empty(t, conn1.User)
require.Empty(t, conn1.DB)
require.Equal(t, conn2, logs.infos[1])
require.Equal(t, int(ts.port), conn1.ServerPort)
require.NotEqual(t, conn1.ServerPort, conn1.ClientPort)
require.NotEmpty(t, conn1.ConnectionID)
require.Nil(t, conn1.ActiveRoles)
require.NoError(t, conn1.Error)

expectedConn2 = *(conn1.ConnectionInfo)
expectedConn2.User = "root"
expectedConn2.DB = "test"
require.Equal(t, []*auth.RoleIdentity{}, logs.infos[1].ActiveRoles)
require.Nil(t, logs.infos[1].Error)
require.Equal(t, expectedConn2, *(logs.infos[1].ConnectionInfo))
})

_, err = conn.ExecContext(context.TODO(), "create role r1@'%'")
require.NoError(t, err)
_, err = conn.ExecContext(context.TODO(), "grant r1 TO root")
require.NoError(t, err)
_, err = conn.ExecContext(context.TODO(), "set role all")
require.NoError(t, err)

require.NoError(t, conn.Close())
require.NoError(t, db.Close())
require.NoError(t, logs.waitConnDisconnected())
logs.check(func() {
require.Equal(t, conn2, logs.infos[2])
require.Equal(t, 3, len(logs.infos))
require.Equal(t, 1, len(logs.infos[2].ActiveRoles))
require.Equal(t, auth.RoleIdentity{
Username: "r1",
Hostname: "%",
}, *logs.infos[2].ActiveRoles[0])
require.Nil(t, logs.infos[2].Error)
require.Equal(t, expectedConn2, *(logs.infos[2].ConnectionInfo))
})

// test for login failed
Expand All @@ -2871,16 +2895,22 @@ func TestExtensionConnEvent(t *testing.T) {
extension.ConnHandshakeRejected,
extension.ConnDisconnected,
}, logs.types)
conn1 = logs.infos[0]
conn2 = conn1
conn2.User = "noexist"
conn2.DB = "test"

conn1 := logs.infos[0]
require.Equal(t, "127.0.0.1", conn1.ClientIP)
require.Equal(t, "127.0.0.1", conn1.ServerIP)
require.Empty(t, conn1.User)
require.Empty(t, conn1.DB)
require.Equal(t, conn2, logs.infos[1])
require.Equal(t, conn2, logs.infos[2])
require.Equal(t, int(ts.port), conn1.ServerPort)
require.NotEqual(t, conn1.ServerPort, conn1.ClientPort)
require.NotEmpty(t, conn1.ConnectionID)
require.Nil(t, conn1.ActiveRoles)
require.NoError(t, conn1.Error)

expectedConn2 = *(conn1.ConnectionInfo)
expectedConn2.User = "noexist"
expectedConn2.DB = "test"
require.Equal(t, []*auth.RoleIdentity{}, logs.infos[1].ActiveRoles)
require.EqualError(t, logs.infos[1].Error, "[server:1045]Access denied for user 'noexist'@'127.0.0.1' (using password: NO)")
require.Equal(t, expectedConn2, *(logs.infos[1].ConnectionInfo))
})
}

0 comments on commit ecdc0f7

Please sign in to comment.