diff --git a/plugin/audit.go b/plugin/audit.go index f1471562fc657..c9ba27abaa404 100644 --- a/plugin/audit.go +++ b/plugin/audit.go @@ -45,6 +45,8 @@ const ( ChangeUser // PreAuth presents event before start auth. PreAuth + // Reject presents event reject connection event. + Reject ) func (c ConnectionEvent) String() string { @@ -57,6 +59,8 @@ func (c ConnectionEvent) String() string { return "ChangeUser" case PreAuth: return "PreAuth" + case Reject: + return "Reject" } return "" } @@ -85,6 +89,11 @@ type AuditManifest struct { OnParseEvent func(ctx context.Context, sctx *variable.SessionVars, event ParseEvent) error } +type ( + // RejectReasonCtxValue will be used in OnConnectionEvent to pass RejectReason to plugin. + RejectReasonCtxValue struct{} +) + const ( // ExecStartTimeCtxKey indicates stmt start execution time. ExecStartTimeCtxKey = "ExecStartTime" diff --git a/plugin/conn_ip_example/conn_ip_example.go b/plugin/conn_ip_example/conn_ip_example.go index 08c14f9250445..2bac3f4b47b22 100644 --- a/plugin/conn_ip_example/conn_ip_example.go +++ b/plugin/conn_ip_example/conn_ip_example.go @@ -47,3 +47,14 @@ func OnGeneralEvent(ctx context.Context, sctx *variable.SessionVars, event plugi fmt.Printf("new connection by %s\n", ctx.Value("ip")) return } + +// OnConnectionEvent implements TiDB Audit plugin's OnConnectionEvent SPI. +func OnConnectionEvent(ctx context.Context, event plugin.ConnectionEvent, info *variable.ConnectionInfo) error { + var reason string + if r := ctx.Value(plugin.RejectReasonCtxValue{}); r != nil { + reason = r.(string) + } + fmt.Println("conn_ip_example onConnect called") + fmt.Printf("conenct event: %s, reason: %s\n", event, reason) + return nil +} diff --git a/plugin/conn_ip_example/manifest.toml b/plugin/conn_ip_example/manifest.toml index b57badaf689f0..2cbebb6a47f98 100644 --- a/plugin/conn_ip_example/manifest.toml +++ b/plugin/conn_ip_example/manifest.toml @@ -11,5 +11,6 @@ validate = "Validate" onInit = "OnInit" onShutdown = "OnShutdown" export = [ - {extPoint="OnGeneralEvent", impl="OnGeneralEvent"} + {extPoint="OnGeneralEvent", impl="OnGeneralEvent"}, + {extPoint="OnConnectionEvent", impl="OnConnectionEvent"} ] diff --git a/server/server.go b/server/server.go index cbc2f47aca68d..2339517125ef1 100644 --- a/server/server.go +++ b/server/server.go @@ -404,6 +404,18 @@ func (s *Server) Close() { func (s *Server) onConn(conn *clientConn) { ctx := logutil.WithConnID(context.Background(), conn.connectionID) if err := conn.handshake(ctx); err != nil { + if plugin.IsEnable(plugin.Audit) { + conn.ctx.GetSessionVars().ConnectionInfo = conn.connectInfo() + } + err = plugin.ForeachPlugin(plugin.Audit, func(p *plugin.Plugin) error { + authPlugin := plugin.DeclareAuditManifest(p.Manifest) + if authPlugin.OnConnectionEvent != nil { + pluginCtx := context.WithValue(context.Background(), plugin.RejectReasonCtxValue{}, err.Error()) + return authPlugin.OnConnectionEvent(pluginCtx, plugin.Reject, conn.ctx.GetSessionVars().ConnectionInfo) + } + return nil + }) + terror.Log(err) // Some keep alive services will send request to TiDB and disconnect immediately. // So we only record metrics. metrics.HandShakeErrorCounter.Inc()