From aef4fc124b5a0f83e3076fbd8f9d38c61ff7f4f7 Mon Sep 17 00:00:00 2001 From: clark1013 Date: Wed, 21 Feb 2024 10:36:57 +0800 Subject: [PATCH] extension/serverless: introduce serverless audit log (#870) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * init * audit: init first version of audit log * audit: support log global rotate * audit: update log keys * audit: more common notify implement * audit: add record id * audit: rename some names * audit: default use normal log path * audit: add server ip info * audit: do some refine * audit: Add some log keys * audit: update some item format * audit: add stmt demo * audit: support redact * audit: fix bug * audit: update filter * audit: update filter * audit: check user * audit: update * update * audit: update * audit: fix bug * aduit: update * audit: update * update * update * audit: add filter unit test * fmt * Add tests for `tidb_audit_enabled` and `tidb_audit_log` * Add tests for `tidb_audit_log_max_size` and `tidb_audit_log_max_lifetime` * Add tests for `tidb_audit_log_reserved_*` * TODO: TestAuditLogRedact * Fininsh `TestAuditLogRedact` * Update (#1) * fix typo (#2) * Add `TRANSACTION` * rename sysvar * audit_log_create_filter, audit_log_remove_filter * finish function call and table test * test privilege * finish test for sysvar * TODO: TestConnectionEvenClass * update * audit: fix lint for audit log (#1) * audit: fix UT failure caused by the change of redact log (#3) * audit: use `t.TempDir()` to make test stable (#4) * audit: fix test failed for 7.1 * Add `OWNERS` file (#35) Co-authored-by: Chao Wang * audit: fix panic when logging sometime (#26) (#31) * test: fix unstable test TestAuditLogReservedDays (#8) * Format sysvar_test.go * audit: use `StatementContext` to generate redacted SQL (#9) * audit: fix panic when logging sometime (#26) * update * add owner * Update OWNERS --------- Co-authored-by: CbcWestwolf <1004626265@qq.com> Co-authored-by: wuhuizuo * audit: fix 
panic sometimes when `create user` without password (#37) (#39) * This is an automated cherry-pick of #37 Signed-off-by: ti-chi-bot * fix conflict --------- Signed-off-by: ti-chi-bot Co-authored-by: 王超 * extension/audit: introduce serverless audit log Signed-off-by: Wen Jiazhi * update bazel config Signed-off-by: Wen Jiazhi * support enable audit log when activate Signed-off-by: Wen Jiazhi * add gwconnid to audit log Signed-off-by: Wen Jiazhi * comment unstable test * update bazel * add log about activate request * diff audit log enabled * audit: fix memory leak for executeSQL (#44) (#45) * This is an automated cherry-pick of #44 Signed-off-by: ti-chi-bot * Update util.go --------- Signed-off-by: ti-chi-bot Co-authored-by: 王超 * Update extension/serverless/OWNERS --------- Signed-off-by: ti-chi-bot Signed-off-by: Wen Jiazhi Co-authored-by: Chao Wang Co-authored-by: cbcwestwolf <1004626265@qq.com> Co-authored-by: Ti Chi Robot Co-authored-by: wuhuizuo Co-authored-by: zzm Co-authored-by: Yuqing Bai --- config/config.go | 3 + config/serverless.go | 37 + extension/serverless/OWNERS | 4 + extension/serverless/README.md | 1 + extension/serverless/audit/BUILD.bazel | 72 ++ extension/serverless/audit/entry.go | 602 ++++++++++++++ extension/serverless/audit/entry_test.go | 117 +++ extension/serverless/audit/filter.go | 453 ++++++++++ extension/serverless/audit/filter_test.go | 867 ++++++++++++++++++++ extension/serverless/audit/function_test.go | 340 ++++++++ extension/serverless/audit/logger.go | 254 ++++++ extension/serverless/audit/manager.go | 713 ++++++++++++++++ extension/serverless/audit/register.go | 362 ++++++++ extension/serverless/audit/sysvar_test.go | 447 ++++++++++ extension/serverless/audit/util.go | 185 +++++ extension/serverless/generate/main.go | 71 ++ extension/session.go | 2 + server/extension.go | 4 + server/server.go | 2 + sessionctx/variable/session.go | 1 + standby/standby.go | 31 +- tidb-server/BUILD.bazel | 1 + tidb-server/main.go | 12 + 23 files 
changed, 4579 insertions(+), 2 deletions(-) create mode 100644 extension/serverless/OWNERS create mode 100644 extension/serverless/README.md create mode 100644 extension/serverless/audit/BUILD.bazel create mode 100644 extension/serverless/audit/entry.go create mode 100644 extension/serverless/audit/entry_test.go create mode 100644 extension/serverless/audit/filter.go create mode 100644 extension/serverless/audit/filter_test.go create mode 100644 extension/serverless/audit/function_test.go create mode 100644 extension/serverless/audit/logger.go create mode 100644 extension/serverless/audit/manager.go create mode 100644 extension/serverless/audit/register.go create mode 100644 extension/serverless/audit/sysvar_test.go create mode 100644 extension/serverless/audit/util.go create mode 100644 extension/serverless/generate/main.go diff --git a/config/config.go b/config/config.go index 949baf36fd7b7..92a3eb75c513f 100644 --- a/config/config.go +++ b/config/config.go @@ -322,6 +322,8 @@ type Config struct { TiDBWorker TiDBWorker `toml:"tidb-worker" json:"tidb-worker"` + AuditLog AuditLog `toml:"audit-log" json:"audit-log"` + // The following items are deprecated. We need to keep them here temporarily // to support the upgrade process. They can be removed in future. 
@@ -1188,6 +1190,7 @@ var defaultConf = Config{ }, RewriteCollations: make(map[string]map[string]string), TiDBWorker: defaultTiDBWorker(), + AuditLog: defaultAuditLog(), ExportID: "", } diff --git a/config/serverless.go b/config/serverless.go index e584b29285294..2679e82fca94b 100644 --- a/config/serverless.go +++ b/config/serverless.go @@ -147,3 +147,40 @@ func (w *TiDBWorker) Valid(c *Config) error { } return nil } + +const ( + // LogFormatText is stardard log format surrounded with [] + LogFormatText = "TEXT" + // LogFormatJSON is json format + LogFormatJSON = "JSON" +) + +// AuditLog is the config for serverless audit log +type AuditLog struct { + // Enable indicates whether audit log will be recorded + Enable bool `toml:"enable" json:"enable"` + // Path is the audit log output path + Path string `toml:"path" json:"path"` + // Format is the output format of audit log, both text and json are supported + Format string `toml:"format" json:"format"` + // MaxFilesize is the maximum file size before the log file be rotated, unit is MB + MaxFilesize int64 `toml:"max-filesize" json:"max-filesize"` + // MaxLifetime is the maximum time before the log file be rotated, unit is second + MaxLifetime int64 `toml:"max-lifetime" json:"max-lifetime"` + // Redacted indicates whether audit log redaction is enabled. If it set to true, user data will be replaced with `?` + Redacted bool `toml:"redacted" json:"redacted"` + // EncryptKey is used to encrypt sensitive informations in audit log. 
This key should be 32 bytes, we use AES-256 for encryption + EncryptKey string `toml:"encrypt-key" json:"encrypt-key"` +} + +func defaultAuditLog() AuditLog { + return AuditLog{ + Enable: false, + Path: "tidb-audit.log", + Format: LogFormatText, + MaxFilesize: 10, + MaxLifetime: 24 * 60 * 60, + Redacted: true, + EncryptKey: "", + } +} diff --git a/extension/serverless/OWNERS b/extension/serverless/OWNERS new file mode 100644 index 0000000000000..1d52fb19a86e0 --- /dev/null +++ b/extension/serverless/OWNERS @@ -0,0 +1,4 @@ +# See the OWNERS docs at https://go.k8s.io/owners +approvers: +reviewers: [] + diff --git a/extension/serverless/README.md b/extension/serverless/README.md new file mode 100644 index 0000000000000..42ecb835ae63f --- /dev/null +++ b/extension/serverless/README.md @@ -0,0 +1 @@ +This repo maintains the tidb enterprise extensions code. diff --git a/extension/serverless/audit/BUILD.bazel b/extension/serverless/audit/BUILD.bazel new file mode 100644 index 0000000000000..5cfae0fff2d5e --- /dev/null +++ b/extension/serverless/audit/BUILD.bazel @@ -0,0 +1,72 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "audit", + srcs = [ + "entry.go", + "filter.go", + "logger.go", + "manager.go", + "register.go", + "util.go", + ], + importpath = "github.com/pingcap/tidb/extension/serverless/audit", + visibility = ["//visibility:public"], + deps = [ + "//config", + "//extension", + "//keyspace", + "//kv", + "//metrics", + "//parser/ast", + "//parser/auth", + "//parser/mysql", + "//parser/terror", + "//sessionctx/stmtctx", + "//sessionctx/variable", + "//types", + "//util/chunk", + "//util/encrypt", + "//util/logutil", + "//util/sqlexec", + "//util/stringutil", + "//util/table-filter", + "@com_github_google_uuid//:uuid", + "@com_github_ngaut_pools//:pools", + "@com_github_pingcap_errors//:errors", + "@com_github_pingcap_log//:log", + "@in_gopkg_natefinch_lumberjack_v2//:lumberjack_v2", + "@io_etcd_go_etcd_api_v3//mvccpb", + 
"@io_etcd_go_etcd_client_v3//:client", + "@org_uber_go_atomic//:atomic", + "@org_uber_go_zap//:zap", + "@org_uber_go_zap//zapcore", + ], +) + +go_test( + name = "audit_test", + timeout = "short", + srcs = [ + "entry_test.go", + "filter_test.go", + "function_test.go", + "sysvar_test.go", + ], + embed = [":audit"], + flaky = True, + shard_count = 20, + deps = [ + "//config", + "//errno", + "//parser", + "//parser/auth", + "//server", + "//sessionctx/stmtctx", + "//testkit", + "//util/sem", + "@com_github_pingcap_errors//:errors", + "@com_github_stretchr_testify//require", + "@org_uber_go_zap//:zap", + ], +) diff --git a/extension/serverless/audit/entry.go b/extension/serverless/audit/entry.go new file mode 100644 index 0000000000000..08435415395f1 --- /dev/null +++ b/extension/serverless/audit/entry.go @@ -0,0 +1,602 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package audit + +import ( + "encoding/base64" + "fmt" + "strconv" + "strings" + "time" + + "github.com/ngaut/pools" + "github.com/pingcap/tidb/extension" + "github.com/pingcap/tidb/keyspace" + "github.com/pingcap/tidb/metrics" + "github.com/pingcap/tidb/parser/ast" + "github.com/pingcap/tidb/parser/auth" + "github.com/pingcap/tidb/parser/terror" + "github.com/pingcap/tidb/sessionctx/stmtctx" + "github.com/pingcap/tidb/sessionctx/variable" + "github.com/pingcap/tidb/types" + "github.com/pingcap/tidb/util/encrypt" + "go.uber.org/zap" +) + +// EventClass is the class for event +type EventClass int + +// The below defines all event classes +const ( + ClassUnknown EventClass = iota + ClassConnection + ClassConnect + ClassDisconnect + ClassChangeUser + ClassQuery + ClassTransaction + ClassExecute + ClassDML + ClassInsert + ClassReplace + ClassUpdate + ClassDelete + ClassLoadData + ClassSelect + ClassDDL + ClassAudit + ClassAuditSetSysVar + ClassAuditFuncCall + ClassAuditEnable + ClassAuditDisable + + // ClassCount is the count of all classes + // New class should be added before it + ClassCount +) + +var class2String = map[EventClass]string{ + ClassConnection: "CONNECTION", + ClassConnect: "CONNECT", + ClassDisconnect: "DISCONNECT", + ClassChangeUser: "CHANGE_USER", + ClassQuery: "QUERY", + ClassTransaction: "TRANSACTION", + ClassExecute: "EXECUTE", + ClassDML: "QUERY_DML", + ClassInsert: "INSERT", + ClassReplace: "REPLACE", + ClassUpdate: "UPDATE", + ClassDelete: "DELETE", + ClassLoadData: "LOAD_DATA", + ClassSelect: "SELECT", + ClassDDL: "QUERY_DDL", + ClassAudit: "AUDIT", + ClassAuditSetSysVar: "AUDIT_SET_SYS_VAR", + ClassAuditFuncCall: "AUDIT_FUNC_CALL", + ClassAuditEnable: "AUDIT_ENABLE", + ClassAuditDisable: "AUDIT_DISABLE", +} + +var string2class map[string]EventClass + +func init() { + string2class = make(map[string]EventClass, ClassCount) + for class := ClassUnknown + 1; class < ClassCount; class++ { + if str, ok := class2String[class]; ok { + 
string2class[strings.ToUpper(str)] = class + } else { + panic(fmt.Sprintf("string for audit class '%d' not set", class)) + } + } +} + +func getEventClass(s string) (EventClass, bool) { + if c, ok := string2class[strings.ToUpper(s)]; ok { + return c, true + } + return ClassUnknown, false +} + +func (EventClass) displayInLog() bool { + return true +} + +func (c EventClass) String() string { + return class2String[c] +} + +func getStmtClasses(stmt ast.StmtNode, execute bool) []EventClass { + classes := append(make([]EventClass, 0, 4), ClassQuery) + if execute { + classes = append(classes, ClassExecute) + } + switch stmt.(type) { + case *ast.SelectStmt: + classes = append(classes, ClassSelect) + case ast.DMLNode: + classes = appendDMLClass(classes, stmt) + case ast.DDLNode: + classes = append(classes, ClassDDL) + case *ast.BeginStmt, *ast.CommitStmt, *ast.RollbackStmt, *ast.SavepointStmt, *ast.ReleaseSavepointStmt: + classes = append(classes, ClassTransaction) + } + return classes +} + +func appendDMLClass(classes []EventClass, stmt ast.StmtNode) []EventClass { + switch n := stmt.(type) { + case *ast.InsertStmt: + if n.IsReplace { + classes = append(classes, ClassDML, ClassReplace) + } else { + classes = append(classes, ClassDML, ClassInsert) + } + case *ast.UpdateStmt: + classes = append(classes, ClassDML, ClassUpdate) + case *ast.DeleteStmt: + classes = append(classes, ClassDML, ClassDelete) + case *ast.LoadDataStmt: + classes = append(classes, ClassDML, ClassLoadData) + case *ast.NonTransactionalDMLStmt: + classes = appendDMLClass(classes, n.DMLStmt) + } + return classes +} + +// log keys +const ( + LogKeyID = "ID" + LogKeyTime = "TIME" + LogKeyEvent = "EVENT" + LogKeyUser = "USER" + LogKeyRoles = "ROLES" + LogKeyConnectionID = "CONNECTION_ID" + LogKeyGwConnID = "GATEWAY_CONNECTION_ID" + LogKeyCurrentDB = "CURRENT_DB" + LogKeyTables = "TABLES" + LogKeyReason = "REASON" + LogKeyStatusCode = "STATUS_CODE" + + LogKeyConnectionTYPE = "CONNECTION_TYPE" + LogKeyPID = "PID" 
+ LogKeyServerVersion = "SERVER_VERSION" + LogKeySSLVersion = "SSL_VERSION" + LogKeyHostIP = "HOST_IP" + LogKeyHostPort = "HOST_PORT" + LogKeyClientIP = "CLIENT_IP" + LogKeyClientPort = "CLIENT_PORT" + LogKeySQLText = "SQL_TEXT" + LogKeyExecuteParams = "EXECUTE_PARAMS" + LogKeyAffectedRows = "AFFECTED_ROWS" + + LogKeyAuditOpTarget = "AUDIT_OP_TARGET" + LogKeyAuditOpArgs = "AUDIT_OP_ARGS" + + // Below keys are serverless related fileds + LogKeyKeyspaceName = "KEYSPACE_NAME" + LogKeyServerlessTenantID = "SERVERLESS_TENANT_ID" + LogKeyServerlessProjectID = "SERVERLESS_PROJECT_ID" + LogKeyServerlessClusterID = "SERVERLESS_CLUSTER_ID" + + // Below keys are copied form slow log, which are intended for performance tuning. + // Infomations exposed by executor such as kv_total/retry_count can't be collected in audit log. + LogKeyTimeTotal = "QUERY_TIME" + LogKeyTimeParse = "PARSE_TIME" + LogKeyTimeComile = "COMPILE_TIME" + LogKeyTimeOptimize = "OPTIMIZE_TIME" + LogKeyTimeWaitTS = "WAIT_TS" + LogKeyIndexNames = "INDEX_NAMES" + LogKeyCopTasks = "COP_TASKS" + LogKeyExecDetail = "EXEC_DETAIL" + LogKeyMemMax = "MEM_MAX" + LogKeyDiskMax = "DISK_MAX" + LogKeyPrevStmt = "PREV_STMT" + LogKeyPlanDigest = "PLAN_DIGEST" + + // Below keys are collected for debugging transaction + LogKeyStartTime = "START_TIME" + LogKeyTxnCreateTime = "TXN_CREATE_TIME" + LogKeyTxnStartTS = "TXN_START_TS" + LogKeyTxnForUpdateTS = "TXN_FOR_UPDATE_TS" +) + +var logFieldsPool = createLogFieldsPool(1024) + +func createLogFieldsPool(capacity int) *pools.ResourcePool { + return pools.NewResourcePool( + func() (pools.Resource, error) { + return newLogFields(), nil + }, + capacity, capacity, 0, + ) +} + +type logFields []zap.Field + +func newLogFields() logFields { + return make([]zap.Field, 0, 32) +} + +func fieldsFromPool(pool *pools.ResourcePool) (logFields, func()) { + resource, err := pool.TryGet() + if err != nil { + terror.Log(err) + return newLogFields(), func() {} + } + + if fields, ok := 
resource.(logFields); ok { + return fields, func() { + pool.Put(fields) + } + } + + return newLogFields(), func() {} +} + +func (logFields) Close() {} + +// LogEntry is the entry for the audit log +type LogEntry struct { + user string + host string + roles []*auth.RoleIdentity + connID uint64 + gwConnID string + classes []EventClass + tables []stmtctx.TableEntry + err error + keyspaceName string + fieldsFunc func(fields []zap.Field, cfg *LoggerConfig) []zap.Field +} + +func serverlessFields() []zap.Field { + return []zap.Field{ + zap.String(LogKeyServerlessTenantID, metrics.ServerlessTenantID), + zap.String(LogKeyServerlessProjectID, metrics.ServerlessProjectID), + zap.String(LogKeyServerlessClusterID, metrics.ServerlessClusterID), + } +} + +func newAuditEntryWithSetGlobalVar(name, value string, sessVars *variable.SessionVars, err error) *LogEntry { + connInfo := sessVars.ConnectionInfo + if connInfo == nil { + // if connInfo is nil, it means it is a background setting, only log this event when user manually set it. + return nil + } + + var classes []EventClass + switch name { + case TiDBAuditEnabled: + var auditEnableClass EventClass + if variable.TiDBOptOn(value) { + auditEnableClass = ClassAuditEnable + } else { + auditEnableClass = ClassAuditDisable + } + classes = []EventClass{ClassAudit, ClassAuditSetSysVar, auditEnableClass} + default: + classes = []EventClass{ClassAudit, ClassAuditSetSysVar} + } + + return &LogEntry{ + user: connInfo.User, + host: connInfo.Host, + roles: sessVars.ActiveRoles, + connID: connInfo.ConnectionID, + gwConnID: connInfo.GwConnID, + classes: classes, + err: err, + keyspaceName: keyspace.GetKeyspaceNameBySettings(), + fieldsFunc: func(fields []zap.Field, cfg *LoggerConfig) []zap.Field { + fields = append(fields, serverlessFields()...) 
+ return append( + fields, + zap.String(LogKeyAuditOpTarget, name), + zap.Strings(LogKeyAuditOpArgs, []string{value}), + ) + }, + } +} + +func newAuditFuncCallEntry(name string, args []types.Datum, ctx extension.FunctionContext, err error) *LogEntry { + var connID uint64 + var gwConnID, user, host string + if conn := ctx.ConnectionInfo(); conn != nil { + connID = conn.ConnectionID + gwConnID = conn.GwConnID + user = conn.User + host = conn.Host + } + + return &LogEntry{ + user: user, + host: host, + roles: ctx.ActiveRoles(), + connID: connID, + gwConnID: gwConnID, + err: err, + keyspaceName: keyspace.GetKeyspaceNameBySettings(), + classes: []EventClass{ClassAudit, ClassAuditFuncCall}, + fieldsFunc: func(fields []zap.Field, cfg *LoggerConfig) []zap.Field { + fields = append(fields, serverlessFields()...) + return append( + fields, + zap.String(LogKeyAuditOpTarget, name), + zap.Strings(LogKeyAuditOpArgs, toStringArgs(args)), + ) + }, + } +} + +func newConnEventEntry(tp extension.ConnEventTp, info *extension.ConnEventInfo) *LogEntry { + e := &LogEntry{ + user: info.User, + host: info.Host, + roles: info.ActiveRoles, + connID: info.ConnectionID, + gwConnID: info.GwConnID, + err: info.Error, + keyspaceName: keyspace.GetKeyspaceNameBySettings(), + } + + switch tp { + case extension.ConnHandshakeAccepted, extension.ConnHandshakeRejected: + e.classes = []EventClass{ClassConnection, ClassConnect} + case extension.ConnReset: + e.classes = []EventClass{ClassConnection, ClassChangeUser} + case extension.ConnDisconnected: + e.classes = []EventClass{ClassConnection, ClassDisconnect} + default: + return nil + } + + e.user = info.User + e.host = info.Host + e.roles = info.ActiveRoles + e.connID = info.ConnectionID + e.gwConnID = info.GwConnID + e.err = info.Error + e.fieldsFunc = func(fields []zap.Field, cfg *LoggerConfig) []zap.Field { + fields = append(fields, serverlessFields()...) 
+ if tp != extension.ConnDisconnected { + // only log current db when connect or change user + fields = append(fields, zap.String(LogKeyCurrentDB, info.DB)) + } + return append( + fields, + zap.String(LogKeyConnectionTYPE, info.ConnectionType), + zap.String(LogKeyPID, strconv.Itoa(info.PID)), + zap.String(LogKeyServerVersion, info.ServerVersion), + zap.String(LogKeySSLVersion, info.SSLVersion), + zap.String(LogKeyHostIP, normalizeIP(info.ServerIP)), + zap.String(LogKeyHostPort, strconv.Itoa(info.ServerPort)), + zap.String(LogKeyClientIP, normalizeIP(info.ClientIP)), + zap.String(LogKeyClientPort, info.ClientPort), + ) + } + + return e +} + +// Statements containing password should always be masked. +func isPasswordStmt(node ast.StmtNode) bool { + if node == nil { + return false + } + + findInSpecs := func(specs []*ast.UserSpec) bool { + for _, spec := range specs { + if authOpt := spec.AuthOpt; authOpt != nil && (authOpt.ByAuthString || authOpt.ByHashString) { + return true + } + } + return false + } + + switch stmt := node.(type) { + case *ast.SetPwdStmt: + return true + case *ast.CreateUserStmt: + return findInSpecs(stmt.Specs) + case *ast.AlterUserStmt: + return findInSpecs(stmt.Specs) + } + return false +} + +func newStmtEventEntry(tp extension.StmtEventTp, info extension.StmtEventInfo) *LogEntry { + if tp != extension.StmtSuccess && tp != extension.StmtError { + return nil + } + + isExecute := info.ExecuteStmtNode() != nil + innerStmtNode := info.ExecutePreparedStmt() + if innerStmtNode == nil { + innerStmtNode = info.StmtNode() + } + + entry := &LogEntry{ + classes: getStmtClasses(innerStmtNode, isExecute), + tables: info.RelatedTables(), + err: info.GetError(), + keyspaceName: keyspace.GetKeyspaceNameBySettings(), + } + + if connInfo := info.ConnectionInfo(); connInfo != nil { + entry.connID = connInfo.ConnectionID + entry.gwConnID = connInfo.GwConnID + entry.user = connInfo.User + entry.host = connInfo.Host + entry.roles = info.ActiveRoles() + } + 
entry.fieldsFunc = func(fields []zap.Field, cfg *LoggerConfig) []zap.Field { + var sqlText string + if cfg.Redact || isPasswordStmt(info.StmtNode()) { + sqlText, _ = info.SQLDigest() + } else { + sqlText = info.OriginalText() + } + sqlText, _ = encryptIfKeySet(sqlText, cfg.EncryptKey, cfg.EncryptIv) + + currentDB, _ := encryptIfKeySet(info.CurrentDB(), cfg.EncryptKey, cfg.EncryptIv) + + fields = append(fields, serverlessFields()...) + fields = append( + fields, + zap.String(LogKeyCurrentDB, currentDB), + zap.String(LogKeySQLText, sqlText), + ) + + if sessVars := info.GetSessionVars(); sessVars != nil { + _, planDigest := sessVars.StmtCtx.GetPlanDigest() + var digest string + if planDigest != nil { + digest, _ = encryptIfKeySet(planDigest.String(), cfg.EncryptKey, cfg.EncryptIv) + } + + indexNames := make([]string, len(sessVars.StmtCtx.IndexNames)) + for i, name := range sessVars.StmtCtx.IndexNames { + indexNames[i], _ = encryptIfKeySet(name, cfg.EncryptKey, cfg.EncryptIv) + } + + fields = append( + fields, + zap.String(LogKeyTimeTotal, strconv.FormatFloat((time.Since(sessVars.StartTime)+sessVars.DurationParse).Seconds(), 'f', -1, 64)), + zap.String(LogKeyTimeParse, strconv.FormatFloat(sessVars.DurationParse.Seconds(), 'f', -1, 64)), + zap.String(LogKeyTimeComile, strconv.FormatFloat(sessVars.DurationCompile.Seconds(), 'f', -1, 64)), + zap.String(LogKeyTimeOptimize, strconv.FormatFloat(sessVars.DurationOptimization.Seconds(), 'f', -1, 64)), + zap.String(LogKeyTimeWaitTS, strconv.FormatFloat(sessVars.DurationWaitTS.Seconds(), 'f', -1, 64)), + zap.Strings(LogKeyIndexNames, indexNames), + zap.Any(LogKeyCopTasks, sessVars.StmtCtx.CopTasksDetails()), + zap.Any(LogKeyExecDetail, sessVars.StmtCtx.GetExecDetails()), + zap.Int64(LogKeyMemMax, sessVars.MemTracker.MaxConsumed()), + zap.Int64(LogKeyDiskMax, sessVars.DiskTracker.MaxConsumed()), + zap.String(LogKeyPlanDigest, digest), + zap.Time(LogKeyStartTime, sessVars.StartTime), + ) + + if sessVars.PrevStmt != nil { + 
prevStmt, _ := encryptIfKeySet(sessVars.PrevStmt.String(), cfg.EncryptKey, cfg.EncryptIv) + fields = append(fields, zap.String(LogKeyPrevStmt, prevStmt)) + } + + if sessVars.TxnCtx != nil { + fields = append( + fields, + zap.Time(LogKeyTxnCreateTime, sessVars.TxnCtx.CreateTime), + zap.Uint64(LogKeyTxnStartTS, sessVars.TxnCtx.StartTS), + zap.Uint64(LogKeyTxnForUpdateTS, sessVars.TxnCtx.GetForUpdateTS()), + ) + } + } + + if isExecute && !cfg.Redact { + fields = append(fields, zap.Strings(LogKeyExecuteParams, toStringArgs(info.PreparedParams()))) + } + + for _, class := range entry.classes { + if class == ClassDML { + fields = append(fields, zap.Uint64(LogKeyAffectedRows, info.AffectedRows())) + break + } + } + + return fields + } + return entry +} + +// Filter filters this log entry by filter +func (e *LogEntry) Filter(filter *LogFilterRuleBundle) *LogEntry { + return filter.Filter(e) +} + +// Log logs this entry +func (e *LogEntry) Log(logger *Logger, id *entryIDGenerator) { + if e == nil { + return + } + + events := make([]string, 0, len(e.classes)) + for _, class := range e.classes { + if class.displayInLog() { + events = append(events, class.String()) + } + } + + var roles []string + if len(e.roles) > 0 { + roles = make([]string, len(e.roles)) + for i, role := range e.roles { + roles[i] = role.String() + } + } + + var tables []string + if len(e.tables) > 0 { + tables = make([]string, len(e.tables)) + for i, tb := range e.tables { + tables[i], _ = encryptIfKeySet(fmt.Sprintf("`%s`.`%s`", tb.DB, tb.Table), logger.cfg.EncryptKey, logger.cfg.EncryptIv) + } + } + + user, _ := encryptIfKeySet(e.user, logger.cfg.EncryptKey, logger.cfg.EncryptIv) + + fields, closeFields := fieldsFromPool(logFieldsPool) + defer closeFields() + + fields = append( + fields, + zap.String(LogKeyID, id.Next()), + zap.Strings(LogKeyEvent, events), + zap.String(LogKeyUser, user), + zap.Strings(LogKeyRoles, roles), + zap.String(LogKeyConnectionID, strconv.FormatUint(e.connID, 10)), + 
zap.String(LogKeyGwConnID, e.gwConnID), + zap.Strings(LogKeyTables, tables), + zap.Int(LogKeyStatusCode, getStatusCode(e.err)), + zap.String(LogKeyKeyspaceName, e.keyspaceName), + ) + + if err := e.err; err != nil { + fields = append(fields, zap.String(LogKeyReason, err.Error())) + } + + if e.fieldsFunc != nil { + fields = e.fieldsFunc(fields, &logger.cfg) + } + + logger.Event(fields) +} + +func getStatusCode(err error) int { + if err != nil { + return 0 + } + return 1 +} + +// Encrypt text to base64 string with AES-256, if the key if empty, return the original text +func encryptIfKeySet(text, key, iv string) (string, error) { + if key == "" { + return text, nil + } + encrypted, err := encrypt.AESEncryptWithCBC([]byte(text), []byte(key), []byte(iv)) + if err != nil { + return "", err + } + return base64.StdEncoding.EncodeToString(encrypted), nil +} diff --git a/extension/serverless/audit/entry_test.go b/extension/serverless/audit/entry_test.go new file mode 100644 index 0000000000000..3396800b03d44 --- /dev/null +++ b/extension/serverless/audit/entry_test.go @@ -0,0 +1,117 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package audit + +import ( + "testing" + "time" + "unsafe" + + "github.com/pingcap/tidb/parser" + "github.com/stretchr/testify/require" + "go.uber.org/zap" +) + +func TestEntryPool(t *testing.T) { + poolCap := 8 + pool := createLogFieldsPool(poolCap) + require.Equal(t, time.Duration(0), pool.IdleTimeout()) + + // cnt = poolCap * 2 to make sure the pool will be exhausted + cnt := poolCap * 2 + resources := make([]logFields, 0, cnt) + closeFuncs := make([]func(), 0, cnt) + dedup := make(map[uintptr]logFields) + for i := 0; i < cnt; i++ { + fields, fn := fieldsFromPool(pool) + require.Len(t, fields, 0) + fields = append(fields, zap.String("foo", "bar")) + resources = append(resources, fields) + closeFuncs = append(closeFuncs, fn) + + ptr := uintptr(unsafe.Pointer(unsafe.SliceData(fields))) + dedup[ptr] = fields + + if i < poolCap { + require.Equal(t, int64(poolCap-i-1), pool.Available()) + } else { + require.Equal(t, int64(0), pool.Available()) + } + } + + // All new fields should be different. 
+ require.Equal(t, cnt, len(dedup)) + for _, fn := range closeFuncs { + fn() + } + + // fields should be reused + require.Equal(t, int64(poolCap), pool.Available()) + fields, fn := fieldsFromPool(pool) + require.Len(t, fields, 0) + require.Contains(t, dedup, uintptr(unsafe.Pointer(unsafe.SliceData(fields)))) + fn() +} + +func TestIsPasswordStmt(t *testing.T) { + cases := []struct { + sql string + hasPassword bool + }{ + { + sql: "create user u1", + hasPassword: false, + }, + { + sql: "create user u1 identified by '123'", + hasPassword: true, + }, + { + sql: "create user u1 identified with authentication_ldap_sasl as 'uid=u1_ldap,ou=People,dc=example,dc=com'", + hasPassword: true, + }, + { + sql: "alter user u1", + hasPassword: false, + }, + { + sql: "alter user u1 identified by '123'", + hasPassword: true, + }, + { + sql: "alter user u1 identified with authentication_ldap_sasl as 'uid=u1_ldap,ou=People,dc=example,dc=com'", + hasPassword: true, + }, + { + sql: "set password for u1 = '123'", + hasPassword: true, + }, + { + sql: "set @a=1", + hasPassword: false, + }, + { + sql: "select * from t", + hasPassword: false, + }, + } + + p := parser.New() + for _, c := range cases { + stmt, err := p.ParseOneStmt(c.sql, "", "") + require.NoError(t, err) + require.Equal(t, c.hasPassword, isPasswordStmt(stmt)) + } +} diff --git a/extension/serverless/audit/filter.go b/extension/serverless/audit/filter.go new file mode 100644 index 0000000000000..e215c178fe3bd --- /dev/null +++ b/extension/serverless/audit/filter.go @@ -0,0 +1,453 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package audit + +import ( + "context" + "encoding/json" + "fmt" + "strings" + + "github.com/pingcap/errors" + "github.com/pingcap/tidb/parser/mysql" + "github.com/pingcap/tidb/sessionctx/stmtctx" + "github.com/pingcap/tidb/util/logutil" + "github.com/pingcap/tidb/util/sqlexec" + "github.com/pingcap/tidb/util/stringutil" + tablefilter "github.com/pingcap/tidb/util/table-filter" + "go.uber.org/zap" +) + +var ( + // SchemaName is the database of audit log tables + SchemaName = "mysql" + // TableNameAuditLogFilters is the table of audit log filters + TableNameAuditLogFilters = fmt.Sprintf("%s.%s", SchemaName, "audit_log_filters") + // TableNameAuditLogFilterRules is the table of audit log filter rules + TableNameAuditLogFilterRules = fmt.Sprintf("%s.%s", SchemaName, "audit_log_filter_rules") +) + +func checkTableAccess(db, tbl, _ string, _ mysql.PrivilegeType, sem bool) []string { + if db != SchemaName { + return nil + } + + fullTblName := strings.ToLower(fmt.Sprintf("%s.%s", db, tbl)) + if fullTblName != TableNameAuditLogFilters && fullTblName != TableNameAuditLogFilterRules { + return nil + } + + return requirePrivileges(sem) +} + +var ( + // CreateFilterTableSQL creates audit log filter table + CreateFilterTableSQL = fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %s ( + FILTER_NAME VARCHAR(128), + CONTENT TEXT, + PRIMARY KEY(FILTER_NAME) + )`, TableNameAuditLogFilters) + // InsertFilterSQL inserts a filter + InsertFilterSQL = fmt.Sprintf( + "INSERT INTO %s (FILTER_NAME, CONTENT) VALUES (%%?, %%?)", TableNameAuditLogFilters) + // ReplaceFilterSQL 
replaces a filter + ReplaceFilterSQL = fmt.Sprintf( + "REPLACE INTO %s (FILTER_NAME, CONTENT) VALUES (%%?, %%?)", TableNameAuditLogFilters) + // SelectFilterByNameSQL queries a filter by name + SelectFilterByNameSQL = fmt.Sprintf( + "SELECT FILTER_NAME, CONTENT FROM %s WHERE FILTER_NAME = %%?", TableNameAuditLogFilters) + // DeleteFilterByNameSQL deletes a filter by name + DeleteFilterByNameSQL = fmt.Sprintf( + "DELETE FROM %s WHERE FILTER_NAME = %%?", TableNameAuditLogFilters) + // SelectFilterListSQL queries all filters + SelectFilterListSQL = fmt.Sprintf( + "SELECT FILTER_NAME, CONTENT FROM %s", TableNameAuditLogFilters) + + // CreateFilterRuleTableSQL creates audit log filter rule table + CreateFilterRuleTableSQL = fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %s ( + USER VARCHAR(64), + FILTER_NAME VARCHAR(128), + ENABLED TINYINT(4), + PRIMARY KEY(FILTER_NAME, USER) + )`, TableNameAuditLogFilterRules) + // InsertFilterRuleSQL inerts a new filter rule + InsertFilterRuleSQL = fmt.Sprintf( + "INSERT INTO %s (USER, FILTER_NAME, ENABLED) VALUES (%%?, %%?, 1)", TableNameAuditLogFilterRules) + // ReplaceFilterRuleSQL replaces a filter rule + ReplaceFilterRuleSQL = fmt.Sprintf( + "REPLACE INTO %s (USER, FILTER_NAME, ENABLED) VALUES (%%?, %%?, 1)", TableNameAuditLogFilterRules) + // DeleteFilterRuleSQL deletes a filter rule + DeleteFilterRuleSQL = fmt.Sprintf( + "DELETE FROM %s WHERE (USER, FILTER_NAME) = (%%?, %%?)", TableNameAuditLogFilterRules) + // SelectFilterRuleListSQL queries all filter rules + SelectFilterRuleListSQL = fmt.Sprintf( + "SELECT USER, FILTER_NAME, ENABLED FROM %s", TableNameAuditLogFilterRules) + // SelectFilterRuleByUserAndFilterSQL queries a filter by user and filter name + SelectFilterRuleByUserAndFilterSQL = fmt.Sprintf( + "SELECT USER, FILTER_NAME FROM %s WHERE (USER, FILTER_NAME) = (%%?, %%?)", TableNameAuditLogFilterRules) + // SelectFilterRuleByFilterSQL queries filters by filter name + SelectFilterRuleByFilterSQL = fmt.Sprintf( + "SELECT USER, 
FILTER_NAME FROM %s WHERE FILTER_NAME = %%?", TableNameAuditLogFilterRules) + // UpdateFilterRuleEnabledSQL update a filter by user and filter name + UpdateFilterRuleEnabledSQL = fmt.Sprintf( + "UPDATE %s SET ENABLED = %%? WHERE (USER, FILTER_NAME) = (%%?, %%?)", TableNameAuditLogFilterRules) +) + +// filter function names +const ( + FuncAuditLogCreateFilter = "audit_log_create_filter" + FuncAuditLogRemoveFilter = "audit_log_remove_filter" + FuncAuditLogCreateRule = "audit_log_create_rule" + FuncAuditLogRemoveRule = "audit_log_remove_rule" + FuncAuditLogEnableRule = "audit_log_enable_rule" + FuncAuditLogDisableRule = "audit_log_disable_rule" +) + +type filterFunc func(entry *LogEntry) bool + +var trueFilter = filterFunc(func(_ *LogEntry) bool { return true }) +var falseFilter = filterFunc(func(_ *LogEntry) bool { return false }) + +func valuesMatchSet[T any, E any](values []T, set []E, include bool, match func(T, E) bool) bool { + if set == nil { + return true + } + if include && len(values) == 0 { + return false + } + for _, v := range values { + for _, item := range set { + if match(v, item) { + return include + } + } + } + return !include +} + +func eventClassMatch(v, m EventClass) bool { + return v == m +} + +func tableFilterMatch(v stmtctx.TableEntry, f tablefilter.Filter) bool { + return f.MatchTable(v.DB, v.Table) +} + +func statusCodeMatch(v int, codes []int) bool { + if codes == nil { + return true + } + for _, c := range codes { + if v == c { + return true + } + } + return false +} + +type filterSpec struct { + Classes []string `json:"class,omitempty"` + ClassesExclude []string `json:"class_excl,omitempty"` + Tables []string `json:"table,omitempty"` + TablesExclude []string `json:"table_excl,omitempty"` + StatusCodes []int `json:"status_code,omitempty"` +} + +func (f *filterSpec) Validate() error { + for _, c := range f.Classes { + if _, ok := getEventClass(c); !ok { + return errors.Errorf("invalid event class name '%s'", c) + } + } + for _, c := range 
f.ClassesExclude { + if _, ok := getEventClass(c); !ok { + return errors.Errorf("invalid event class name '%s'", c) + } + } + for _, t := range f.Tables { + if _, err := parseTableFilter(t); err != nil { + return err + } + } + for _, t := range f.TablesExclude { + if _, err := parseTableFilter(t); err != nil { + return err + } + } + return nil +} + +type logFilter struct { + Name string `json:"-"` + Filter []filterSpec `json:"filter,omitempty"` +} + +func newLogFilterFormJSON(name string, data []byte) (*logFilter, error) { + var f logFilter + if err := json.Unmarshal(data, &f); err != nil { + return nil, err + } + f.Name = name + return &f, nil +} + +func (f *logFilter) Normalize() (*logFilter, error) { + newFilter := *f + for i := range newFilter.Filter { + spec := &newFilter.Filter[i] + for j, c := range spec.Classes { + if class, ok := getEventClass(c); ok { + spec.Classes[j] = class.String() + } else { + return nil, errors.Errorf("invalid event class: '%s'", c) + } + } + for j, c := range spec.ClassesExclude { + if class, ok := getEventClass(c); ok { + spec.ClassesExclude[j] = class.String() + } else { + return nil, errors.Errorf("invalid event class: '%s'", c) + } + } + } + return &newFilter, nil +} + +func (f *logFilter) ToJSON() (string, error) { + bs, err := json.Marshal(f) + if err != nil { + return "", err + } + return string(bs), nil +} + +func (f *logFilter) Validate() error { + if f.Name == "" { + return errors.New("filter name should not be empty") + } + for _, spec := range f.Filter { + if err := spec.Validate(); err != nil { + return err + } + } + return nil +} + +func (f *logFilter) createFilterFunc() filterFunc { + if len(f.Filter) == 0 { + return trueFilter + } + + classes := make([][]EventClass, len(f.Filter)) + exclClasses := make([][]EventClass, len(f.Filter)) + tableFilters := make([][]tablefilter.Filter, len(f.Filter)) + exclTableFilters := make([][]tablefilter.Filter, len(f.Filter)) + for i, spec := range f.Filter { + classes[i] = 
f.getEventClassFilterList(spec.Classes) + exclClasses[i] = f.getEventClassFilterList(spec.ClassesExclude) + tableFilters[i] = f.getTableFilterList(spec.Tables) + exclTableFilters[i] = f.getTableFilterList(spec.TablesExclude) + } + + return func(entry *LogEntry) bool { + statusCode := getStatusCode(entry.err) + for i, spec := range f.Filter { + if valuesMatchSet(entry.classes, classes[i], true, eventClassMatch) && + valuesMatchSet(entry.classes, exclClasses[i], false, eventClassMatch) && + valuesMatchSet(entry.tables, tableFilters[i], true, tableFilterMatch) && + valuesMatchSet(entry.tables, exclTableFilters[i], false, tableFilterMatch) && + statusCodeMatch(statusCode, spec.StatusCodes) { + return true + } + } + return false + } +} + +func (f *logFilter) getEventClassFilterList(classes []string) []EventClass { + if classes == nil { + return nil + } + + class := make([]EventClass, 0, len(classes)) + for _, s := range classes { + if c, ok := string2class[strings.ToUpper(s)]; ok { + class = append(class, c) + } else { + logutil.BgLogger().Error( + "error occurs when loading audit log filter, parse event class failed", + zap.String("filter", f.Name), + zap.String("eventClass", s), + ) + } + } + return class +} + +func (f *logFilter) getTableFilterList(tables []string) []tablefilter.Filter { + if len(tables) == 0 { + return nil + } + + tblFilters := make([]tablefilter.Filter, 0, len(tables)) + for _, tbl := range tables { + tblFilter, err := parseTableFilter(tbl) + if err != nil { + logutil.BgLogger().Error( + "error occurs when loading audit log filter, parse table filter failed", + zap.String("filter", f.Name), + zap.String("tableFilter", tbl), + zap.Error(err), + ) + continue + } + tblFilters = append(tblFilters, tblFilter) + } + return tblFilters +} + +type logFilterRule struct { + user string + filterName string + enabled bool + filter *logFilter +} + +func (r *logFilterRule) createFilterFunc() filterFunc { + if r.filter == nil || !r.enabled { + return falseFilter + 
} + + fn := r.filter.createFilterFunc() + _, userPattern, hostPattern, err := normalizeUserAndHost(r.user) + if err != nil { + logutil.BgLogger().Error( + "error occurs when loading audit log filter, parse user filter failed", + zap.String("filter", r.filterName), + zap.String("userFilter", r.user), + zap.Error(err), + ) + return falseFilter + } + userPatWeights, userPatTypes := stringutil.CompilePattern(userPattern, '\\') + hostPatWeights, hostPatTypes := stringutil.CompilePattern(hostPattern, '\\') + + return func(entry *LogEntry) bool { + if !stringutil.DoMatch(entry.user, userPatWeights, userPatTypes) { + return false + } + + switch entry.host { + case "localhost", "127.0.0.1": + if !stringutil.DoMatch("localhost", hostPatWeights, hostPatTypes) && + !stringutil.DoMatch("127.0.0.1", hostPatWeights, hostPatTypes) { + return false + } + default: + if !stringutil.DoMatch(entry.host, hostPatWeights, hostPatTypes) { + return false + } + } + return fn(entry) + } +} + +// LogFilterRuleBundle is a bundle that contains some rules +type LogFilterRuleBundle struct { + rules []*logFilterRule + ruleFuncs []filterFunc +} + +func newLogFilterRuleBundle(rules []*logFilterRule) *LogFilterRuleBundle { + ruleFuncs := make([]filterFunc, len(rules)) + for i, rule := range rules { + ruleFuncs[i] = rule.createFilterFunc() + } + + return &LogFilterRuleBundle{ + rules: rules, + ruleFuncs: ruleFuncs, + } +} + +// Filter filters a log entry +func (b *LogFilterRuleBundle) Filter(entry *LogEntry) *LogEntry { + if entry == nil || b == nil { + return entry + } + + for _, fn := range b.ruleFuncs { + if fn(entry) { + return entry + } + } + return nil +} + +func listFilters(ctx context.Context, exec sqlexec.SQLExecutor) ([]*logFilter, error) { + rows, err := executeSQL(ctx, exec, SelectFilterListSQL) + if err != nil { + return nil, err + } + + filters := make([]*logFilter, 0, len(rows)) + for _, row := range rows { + f, err := newLogFilterFormJSON(row.GetString(0), row.GetBytes(1)) + if err != 
nil { + return nil, err + } + filters = append(filters, f) + } + return filters, nil +} + +func listFilterRules(ctx context.Context, exec sqlexec.SQLExecutor) ([]*logFilterRule, error) { + filters, err := listFilters(ctx, exec) + if err != nil { + return nil, err + } + + filterMap := make(map[string]*logFilter, len(filters)) + for _, f := range filters { + filterMap[f.Name] = f + } + + rows, err := executeSQL(ctx, exec, SelectFilterRuleListSQL) + if err != nil { + return nil, err + } + + rules := make([]*logFilterRule, 0, len(rows)) + for _, row := range rows { + user := row.GetString(0) + filterName := row.GetString(1) + filter, ok := filterMap[filterName] + if !ok { + logutil.BgLogger().Warn( + "cannot find filter when listFilterRules", + zap.String("filter", filterName), + zap.String("user", user), + ) + continue + } + rules = append(rules, &logFilterRule{ + user: user, + filterName: filterName, + enabled: row.GetInt64(2) == 1, + filter: filter, + }) + } + return rules, nil +} diff --git a/extension/serverless/audit/filter_test.go b/extension/serverless/audit/filter_test.go new file mode 100644 index 0000000000000..ddfee57f92f99 --- /dev/null +++ b/extension/serverless/audit/filter_test.go @@ -0,0 +1,867 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package audit + +import ( + "context" + "encoding/json" + "fmt" + "testing" + + "github.com/pingcap/errors" + "github.com/pingcap/tidb/sessionctx/stmtctx" + "github.com/pingcap/tidb/testkit" + "github.com/stretchr/testify/require" +) + +type entryFilterTest struct { + entry LogEntry + match bool +} + +func TestFilterSpec(t *testing.T) { + register4Test() + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + cases := []struct { + spec filterSpec + invalid bool + entries []entryFilterTest + jsonContent string + }{ + { + spec: filterSpec{ + Classes: []string{"abc"}, + }, + invalid: true, + }, + { + spec: filterSpec{ + ClassesExclude: []string{"abc"}, + }, + invalid: true, + }, + { + spec: filterSpec{ + Tables: []string{"t1"}, + }, + invalid: true, + }, + { + spec: filterSpec{ + TablesExclude: []string{"t1"}, + }, + invalid: true, + }, + { + // empty filterSpec can pass all types of log + spec: filterSpec{}, + jsonContent: "{}", + entries: []entryFilterTest{ + { + entry: LogEntry{}, + match: true, + }, + { + entry: LogEntry{user: "root", classes: []EventClass{ClassQuery}}, + match: true, + }, + { + entry: LogEntry{ + user: "u1", + classes: []EventClass{ClassQuery}, + tables: []stmtctx.TableEntry{ + {DB: "db1", Table: "t1"}, + {DB: "db2", Table: "t2"}, + }, + }, + match: true, + }, + { + entry: LogEntry{ + err: errors.New(""), + }, + match: true, + }, + }, + }, + { + spec: filterSpec{ + Classes: []string{"QUERY_DML", "QUERY_DDL"}, + }, + jsonContent: `{"class": ["QUERY_DML", "QUERY_DDL"]}`, + entries: []entryFilterTest{ + { + entry: LogEntry{user: "root", classes: []EventClass{ClassQuery}}, + match: false, + }, + { + entry: LogEntry{user: "root", classes: []EventClass{ClassQuery, ClassDML}}, + match: true, + }, + { + entry: LogEntry{user: "root", classes: []EventClass{ClassQuery, ClassDDL}}, + match: true, + }, + { + entry: LogEntry{user: "root", classes: []EventClass{ClassQuery, ClassSelect}}, + match: false, + }, + { + entry: LogEntry{user: 
"root", err: errors.New(""), classes: []EventClass{ClassQuery, ClassDML}, tables: []stmtctx.TableEntry{ + {DB: "db1", Table: "t1"}, + {DB: "db2", Table: "t2"}, + }}, + match: true, + }, + }, + }, + { + spec: filterSpec{ + ClassesExclude: []string{"SELECT", "INSERT"}, + }, + jsonContent: `{"class_excl": ["SELECT", "INSERT"]}`, + entries: []entryFilterTest{ + { + entry: LogEntry{user: "root"}, + match: true, + }, + { + entry: LogEntry{user: "root", classes: []EventClass{ClassQuery}}, + match: true, + }, + { + entry: LogEntry{user: "root", classes: []EventClass{ClassQuery, ClassSelect}}, + match: false, + }, + { + entry: LogEntry{user: "root", classes: []EventClass{ClassQuery, ClassInsert}}, + match: false, + }, + { + entry: LogEntry{user: "root", classes: []EventClass{ClassQuery, ClassUpdate}}, + match: true, + }, + }, + }, + { + spec: filterSpec{ + Classes: []string{"QUERY_DDL", "QUERY_DML"}, + ClassesExclude: []string{"SELECT", "INSERT"}, + }, + jsonContent: `{"class": ["QUERY_DDL", "QUERY_DML"], "class_excl": ["SELECT", "INSERT"]}`, + entries: []entryFilterTest{ + { + entry: LogEntry{user: "root"}, + match: false, + }, + { + entry: LogEntry{user: "root", classes: []EventClass{ClassDML}}, + match: true, + }, + { + entry: LogEntry{user: "root", classes: []EventClass{ClassDDL}}, + match: true, + }, + { + entry: LogEntry{user: "root", classes: []EventClass{ClassDML, ClassDDL}}, + match: true, + }, + { + entry: LogEntry{user: "root", classes: []EventClass{ClassUpdate}}, + match: false, + }, + { + entry: LogEntry{user: "root", classes: []EventClass{ClassDML, ClassUpdate}}, + match: true, + }, + { + entry: LogEntry{user: "root", classes: []EventClass{ClassDML, ClassInsert}}, + match: false, + }, + { + entry: LogEntry{user: "root", classes: []EventClass{ClassDML, ClassSelect}}, + match: false, + }, + }, + }, + { + spec: filterSpec{ + Classes: []string{"SELECT", "QUERY_DML"}, + ClassesExclude: []string{"SELECT", "INSERT"}, + }, + jsonContent: `{"class": ["SELECT", 
"QUERY_DML"], "class_excl": ["SELECT", "INSERT"]}`, + entries: []entryFilterTest{ + { + entry: LogEntry{user: "root"}, + match: false, + }, + { + entry: LogEntry{user: "root", classes: []EventClass{ClassDML}}, + match: true, + }, + { + entry: LogEntry{user: "root", classes: []EventClass{ClassDDL}}, + match: false, + }, + { + entry: LogEntry{user: "root", classes: []EventClass{ClassDML, ClassDDL}}, + match: true, + }, + { + entry: LogEntry{user: "root", classes: []EventClass{ClassUpdate}}, + match: false, + }, + { + entry: LogEntry{user: "root", classes: []EventClass{ClassDML, ClassUpdate}}, + match: true, + }, + { + entry: LogEntry{user: "root", classes: []EventClass{ClassDML, ClassInsert}}, + match: false, + }, + { + entry: LogEntry{user: "root", classes: []EventClass{ClassDML, ClassSelect}}, + match: false, + }, + }, + }, + { + spec: filterSpec{ + Tables: []string{"*.*"}, + }, + jsonContent: `{"table": ["*.*"]}`, + entries: []entryFilterTest{ + { + entry: LogEntry{user: "root"}, + match: false, + }, + { + entry: LogEntry{user: "root", tables: []stmtctx.TableEntry{ + {DB: "test", Table: "t1"}, + }}, + match: true, + }, + }, + }, + { + spec: filterSpec{ + Tables: []string{"db*.test*", "test.t1"}, + }, + jsonContent: `{"table": ["db*.test*", "test.t1"]}`, + entries: []entryFilterTest{ + { + entry: LogEntry{user: "root"}, + match: false, + }, + { + entry: LogEntry{user: "root", tables: []stmtctx.TableEntry{ + {DB: "db1", Table: "test1"}, + }}, + match: true, + }, + { + entry: LogEntry{user: "root", err: errors.New(""), tables: []stmtctx.TableEntry{ + {DB: "db2", Table: "test2"}, + }}, + match: true, + }, + { + entry: LogEntry{user: "root", tables: []stmtctx.TableEntry{ + {DB: "test", Table: "t1"}, + }}, + match: true, + }, + { + entry: LogEntry{user: "root", tables: []stmtctx.TableEntry{ + {DB: "db1", Table: "t1"}, + }}, + match: false, + }, + { + entry: LogEntry{user: "root", tables: []stmtctx.TableEntry{ + {DB: "d1", Table: "test1"}, + }}, + match: false, + }, + { 
+ entry: LogEntry{user: "root", tables: []stmtctx.TableEntry{ + {DB: "test", Table: "test1"}, + }}, + match: false, + }, + { + entry: LogEntry{user: "root", tables: []stmtctx.TableEntry{ + {DB: "d2", Table: "t2"}, + {DB: "db2", Table: "test2"}, + }}, + match: true, + }, + }, + }, + { + spec: filterSpec{ + TablesExclude: []string{"db*.test*", "test.t1"}, + }, + jsonContent: `{"table_excl": ["db*.test*", "test.t1"]}`, + entries: []entryFilterTest{ + { + entry: LogEntry{user: "root"}, + match: true, + }, + { + entry: LogEntry{user: "root", tables: []stmtctx.TableEntry{ + {DB: "d2", Table: "t2"}, + {DB: "db2", Table: "test2"}, + }}, + match: false, + }, + { + entry: LogEntry{user: "root", tables: []stmtctx.TableEntry{ + {DB: "d2", Table: "t2"}, + {DB: "test", Table: "t1"}, + }}, + match: false, + }, + { + entry: LogEntry{user: "root", tables: []stmtctx.TableEntry{ + {DB: "d2", Table: "t2"}, + {DB: "test", Table: "t2"}, + }}, + match: true, + }, + { + entry: LogEntry{user: "root", tables: []stmtctx.TableEntry{ + {DB: "d2", Table: "t2"}, + {DB: "db2", Table: "t2"}, + }}, + match: true, + }, + }, + }, + { + spec: filterSpec{ + Tables: []string{"db*.test*"}, + TablesExclude: []string{"db2.test*"}, + }, + jsonContent: `{"table": ["db*.test*"], "table_excl": ["db2.test*"]}`, + entries: []entryFilterTest{ + { + entry: LogEntry{user: "root"}, + match: false, + }, + { + entry: LogEntry{user: "root", tables: []stmtctx.TableEntry{ + {DB: "db2", Table: "test2"}, + }}, + match: false, + }, + { + entry: LogEntry{user: "root", tables: []stmtctx.TableEntry{ + {DB: "db3", Table: "test1"}, + }}, + match: true, + }, + }, + }, + { + spec: filterSpec{ + Tables: []string{"db*.test*"}, + TablesExclude: []string{"db*.test*"}, + }, + jsonContent: `{"table": ["db*.test*"], "table_excl": ["db*.test*"]}`, + entries: []entryFilterTest{ + { + entry: LogEntry{user: "root"}, + match: false, + }, + { + entry: LogEntry{user: "root", tables: []stmtctx.TableEntry{ + {DB: "db2", Table: "test2"}, + }}, + 
match: false, + }, + { + entry: LogEntry{user: "root", tables: []stmtctx.TableEntry{ + {DB: "db3", Table: "test1"}, + }}, + match: false, + }, + }, + }, + { + spec: filterSpec{ + StatusCodes: []int{0}, + }, + jsonContent: `{"status_code": [0]}`, + entries: []entryFilterTest{ + { + entry: LogEntry{user: "root"}, + match: false, + }, + { + entry: LogEntry{user: "root", err: errors.New("err")}, + match: true, + }, + }, + }, + { + spec: filterSpec{ + StatusCodes: []int{1}, + }, + jsonContent: `{"status_code": [1]}`, + entries: []entryFilterTest{ + { + entry: LogEntry{user: "root"}, + match: true, + }, + { + entry: LogEntry{user: "root", err: errors.New("err")}, + match: false, + }, + }, + }, + { + spec: filterSpec{ + StatusCodes: []int{0, 1}, + }, + jsonContent: `{"status_code": [0, 1]}`, + entries: []entryFilterTest{ + { + entry: LogEntry{user: "root"}, + match: true, + }, + { + entry: LogEntry{user: "root", err: errors.New("err")}, + match: true, + }, + }, + }, + } + + for _, c := range cases { + // test filter + bs, err := json.Marshal(c.spec) + require.NoError(t, err) + specJSON := string(bs) + if c.invalid { + require.Error(t, c.spec.Validate(), specJSON) + continue + } + require.NoError(t, c.spec.Validate(), specJSON) + filter := logFilter{Name: "test", Filter: []filterSpec{c.spec}} + fn := filter.createFilterFunc() + for _, ent := range c.entries { + require.Equal(t, ent.match, fn(&ent.entry), "%v %v", specJSON, ent.entry) + } + + // test creating filter by sql + tk.MustExec(fmt.Sprintf(`SELECT audit_log_create_filter('test', '{"filter":[%s]}')`, c.jsonContent)) + tk.MustExec(`SELECT audit_log_create_rule('%@%','test')`) + tk.MustQuery("SELECT count(*) FROM mysql.audit_log_filters").Check(testkit.Rows("1")) + tk.MustQuery("SELECT COALESCE(content->>'$.filter[0]', '{}') FROM mysql.audit_log_filters WHERE filter_name = 'test'").Check(testkit.Rows(c.jsonContent)) + rules, err := listFilterRules(context.Background(), tk.Session()) + require.NoError(t, err, specJSON) 
+ require.Equal(t, 1, len(rules), specJSON) + require.Equal(t, c.spec, rules[0].filter.Filter[0], specJSON) + require.Equal(t, "test", rules[0].filterName, specJSON) + tk.MustExec("SELECT audit_log_remove_rule('%@%','test')") + tk.MustExec("SELECT audit_log_remove_filter('test')") + } + _, err := deleteAllAuditLogs(workDir, "tidb-audit", ".log") + require.NoError(t, err) +} + +func TestFilterRule(t *testing.T) { + register4Test() + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + cases := []struct { + rule logFilterRule + entries []entryFilterTest + jsonContent string + }{ + { + rule: logFilterRule{ + user: "root", + filterName: "f1", + enabled: true, + filter: &logFilter{ + Name: "f1", + }, + }, + jsonContent: `{}`, + entries: []entryFilterTest{ + { + entry: LogEntry{user: "root"}, + match: true, + }, + { + entry: LogEntry{user: "u1"}, + match: false, + }, + }, + }, + { + rule: logFilterRule{ + user: "root", + filterName: "f1", + enabled: false, + filter: &logFilter{ + Name: "f1", + }, + }, + jsonContent: `{}`, + entries: []entryFilterTest{ + { + entry: LogEntry{user: "root"}, + match: false, + }, + { + entry: LogEntry{user: "u1"}, + match: false, + }, + }, + }, + { + rule: logFilterRule{ + user: "root@127.0.0._", + filterName: "f1", + enabled: true, + filter: &logFilter{ + Name: "f1", + }, + }, + jsonContent: `{}`, + entries: []entryFilterTest{ + { + entry: LogEntry{user: "root"}, + match: false, + }, + { + entry: LogEntry{user: "u1"}, + match: false, + }, + { + entry: LogEntry{user: "root", host: "localhost"}, + match: true, + }, + { + entry: LogEntry{user: "root", host: "127.0.0.1"}, + match: true, + }, + { + entry: LogEntry{user: "root", host: "127.0.0.11"}, + match: false, + }, + { + entry: LogEntry{user: "root", host: "%"}, + match: false, + }, + }, + }, + { + rule: logFilterRule{ + user: "%@%", + filterName: "f1", + enabled: true, + filter: &logFilter{ + Name: "f1", + }, + }, + jsonContent: `{}`, + entries: []entryFilterTest{ + { + 
entry: LogEntry{user: "root"}, + match: true, + }, + { + entry: LogEntry{user: "u1"}, + match: true, + }, + { + entry: LogEntry{user: "root", host: "localhost"}, + match: true, + }, + { + entry: LogEntry{user: "root", host: "127.0.0.1"}, + match: true, + }, + { + entry: LogEntry{user: "root", host: "127.0.0.11"}, + match: true, + }, + { + entry: LogEntry{user: "root", host: "%"}, + match: true, + }, + }, + }, + { + rule: logFilterRule{ + user: "root", + filterName: "f1", + enabled: true, + filter: &logFilter{ + Name: "f1", + Filter: []filterSpec{}, + }, + }, + jsonContent: `{}`, + entries: []entryFilterTest{ + { + entry: LogEntry{user: "root"}, + match: true, + }, + { + entry: LogEntry{user: "u1"}, + match: false, + }, + }, + }, + { + rule: logFilterRule{ + user: "root", + filterName: "f1", + enabled: true, + filter: &logFilter{ + Name: "f1", + Filter: []filterSpec{ + // An LogEntry matches the Filter if only any filterSpec matches + { + Classes: []string{"QUERY_DML"}, + }, + { + Classes: []string{"QUERY_DDL"}, + }, + }, + }, + }, + jsonContent: `{"filter":[{"class":["QUERY_DML"]},{"class":["QUERY_DDL"]}]}`, + entries: []entryFilterTest{ + { + entry: LogEntry{user: "root"}, + match: false, + }, + { + entry: LogEntry{user: "root", classes: []EventClass{ClassDML}}, + match: true, + }, + { + entry: LogEntry{user: "root", classes: []EventClass{ClassDDL}}, + match: true, + }, + { + entry: LogEntry{user: "root", classes: []EventClass{ClassConnection}}, + match: false, + }, + { + entry: LogEntry{user: "user1", classes: []EventClass{ClassDML}}, + match: false, + }, + }, + }, + { + rule: logFilterRule{ + user: "root", + filterName: "f1", + enabled: true, + filter: &logFilter{ + Name: "f1", + Filter: []filterSpec{ + { + Classes: []string{"QUERY_DML"}, + }, + { + ClassesExclude: []string{"QUERY_DML"}, + }, + }, + }, + }, + jsonContent: `{"filter":[{"class":["QUERY_DML"]},{"class_excl":["QUERY_DML"]}]}`, + entries: []entryFilterTest{ + { + entry: LogEntry{user: "root"}, + 
match: true, + }, + { + entry: LogEntry{user: "root", classes: []EventClass{ClassDML}}, + match: true, + }, + { + entry: LogEntry{user: "root", classes: []EventClass{ClassDDL}}, + match: true, + }, + { + entry: LogEntry{user: "root", classes: []EventClass{ClassDML, ClassDDL}}, + match: true, + }, + { + entry: LogEntry{user: "user1", classes: []EventClass{ClassDML}}, + match: false, + }, + }, + }, + { + rule: logFilterRule{ + user: "root", + filterName: "f1", + enabled: true, + filter: &logFilter{ + Name: "f1", + Filter: []filterSpec{ + { + Tables: []string{"test*.t*"}, + }, + { + TablesExclude: []string{"test*.t*"}, + }, + }, + }, + }, + jsonContent: `{"filter":[{"table":["test*.t*"]},{"table_excl":["test*.t*"]}]}`, + entries: []entryFilterTest{ + { + entry: LogEntry{user: "root"}, + match: true, + }, + { + entry: LogEntry{user: "root", tables: []stmtctx.TableEntry{}}, + match: true, + }, + { + entry: LogEntry{user: "root", tables: []stmtctx.TableEntry{{DB: "test", Table: "t"}}}, + match: true, + }, + { + entry: LogEntry{user: "root", tables: []stmtctx.TableEntry{{DB: "test", Table: "c"}}}, + match: true, + }, + { + entry: LogEntry{user: "user1"}, + match: false, + }, + }, + }, + } + + for i, c := range cases { + // test rule + fn := c.rule.createFilterFunc() + for j, ent := range c.entries { + require.Equal(t, ent.match, fn(&ent.entry), "case index: %d entry index: %d", i, j) + } + + // test creating rule by sql + fmt.Println(c.jsonContent) + tk.MustExec(fmt.Sprintf(`SELECT audit_log_create_filter('%s', '%s')`, c.rule.filterName, c.jsonContent)) + tk.MustExec(fmt.Sprintf(`SELECT audit_log_create_rule('%s','%s')`, c.rule.user, c.rule.filterName)) + tk.MustQuery("SELECT count(*) FROM mysql.audit_log_filters").Check(testkit.Rows("1")) + tk.MustQuery("SELECT count(*) FROM mysql.audit_log_filter_rules").Check(testkit.Rows("1")) + tk.MustQuery(fmt.Sprintf("SELECT content FROM mysql.audit_log_filters WHERE filter_name = '%s'", 
c.rule.filterName)).Check(testkit.Rows(c.jsonContent)) + if !c.rule.enabled { + tk.MustExec(fmt.Sprintf(`SELECT audit_log_disable_rule('%s','%s')`, c.rule.user, c.rule.filterName)) + } + rules, err := listFilterRules(context.Background(), tk.Session()) + require.NoError(t, err, c.jsonContent) + require.Equal(t, 1, len(rules), c.jsonContent) + if rules[0].filter != nil && len(rules[0].filter.Filter) > 0 { + require.Equal(t, c.rule, *rules[0], c.jsonContent) + } + tk.MustExec(fmt.Sprintf(`SELECT audit_log_remove_rule('%s','%s')`, c.rule.user, c.rule.filterName)) + tk.MustExec(fmt.Sprintf(`SELECT audit_log_remove_filter('%s')`, c.rule.filterName)) + } +} + +func TestFilterRuleBundle(t *testing.T) { + cases := []struct { + bundle *LogFilterRuleBundle + entries []entryFilterTest + }{ + { + bundle: newLogFilterRuleBundle([]*logFilterRule{ + { + user: "root", + filterName: "f1", + enabled: true, + filter: &logFilter{ + Name: "f1", + Filter: []filterSpec{ + {Classes: []string{"query_ddl"}}, + }, + }, + }, + { + user: "root", + filterName: "f2", + enabled: true, + filter: &logFilter{ + Name: "f2", + Filter: []filterSpec{ + {Classes: []string{"query_dml"}, ClassesExclude: []string{"query_ddl"}}, + }, + }, + }, + { + user: "u1", + filterName: "f2", + enabled: true, + filter: &logFilter{ + Name: "f2", + Filter: []filterSpec{ + {Classes: []string{"query_dml"}}, + }, + }, + }, + }), + entries: []entryFilterTest{ + { + entry: LogEntry{user: "root", classes: []EventClass{ClassDML}}, + match: true, + }, + { + entry: LogEntry{user: "root", classes: []EventClass{ClassDDL}}, + match: true, + }, + { + entry: LogEntry{user: "root", classes: []EventClass{ClassDML, ClassDDL}}, + match: true, + }, + { + entry: LogEntry{user: "root", classes: []EventClass{ClassConnection}}, + match: false, + }, + { + entry: LogEntry{user: "u1", classes: []EventClass{ClassDML}}, + match: true, + }, + { + entry: LogEntry{user: "u1", classes: []EventClass{ClassDDL}}, + match: false, + }, + }, + }, + } + + for 
i, c := range cases { + for j, ent := range c.entries { + filtered := c.bundle.Filter(&ent.entry) + if filtered != nil { + require.Same(t, &ent.entry, filtered, "case index: %d entry index: %d", i, j) + } + require.Equal(t, ent.match, filtered != nil, "case index: %d entry index: %d", i, j) + } + } +} diff --git a/extension/serverless/audit/function_test.go b/extension/serverless/audit/function_test.go new file mode 100644 index 0000000000000..4e8c187bf5ccf --- /dev/null +++ b/extension/serverless/audit/function_test.go @@ -0,0 +1,340 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package audit + +import ( + "context" + "fmt" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/pingcap/tidb/config" + "github.com/pingcap/tidb/errno" + "github.com/pingcap/tidb/parser/auth" + "github.com/pingcap/tidb/server" + "github.com/pingcap/tidb/testkit" + "github.com/pingcap/tidb/util/sem" + "github.com/stretchr/testify/require" +) + +func eventClasses2String(eventClasses []EventClass) string { + eventClassNames := make([]string, len(eventClasses)) + for i, eventClass := range eventClasses { + eventClassNames[i] = class2String[eventClass] + } + return fmt.Sprintf(`[EVENT="[%s]"]`, strings.Join(eventClassNames, ",")) +} + +func TestAuditLogRotate(t *testing.T) { + register4Test() + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + + _, err := deleteAllAuditLogs(workDir, "tidb-audit", ".log") + require.NoError(t, err) + + tk.MustExec("SET global tidb_audit_log_reserved_backups = 2") + tk.MustExec("SET global tidb_audit_enabled = 1") + for i := 0; i < 5; i++ { + tk.MustExec("SELECT audit_log_rotate()") + time.Sleep(time.Second) + } + + files, err := deleteAllAuditLogs(workDir, "tidb-audit-", ".log") + require.NoError(t, err) + require.Equal(t, 2, len(files), files) +} + +func TestAuditLogEnableRule(t *testing.T) { + register4Test() + tempDir := t.TempDir() + store := testkit.CreateMockStore(t) + srv := server.CreateMockServer(t, store) + defer srv.Close() + conn := server.CreateMockConn(t, srv) + defer conn.Close() + tk := testkit.NewTestKit(t, store) + err := tk.Session().Auth(&auth.UserIdentity{Username: "root", Hostname: "%"}, nil, nil, nil) + require.NoError(t, err) + tk.MustExec("SET global tidb_audit_enabled = 1") + tk.MustExec(fmt.Sprintf("SET global tidb_audit_log = '%s'", filepath.Join(tempDir, DefAuditLogName))) + + _, err = deleteAllAuditLogs(tempDir, "tidb-audit", ".log") + require.NoError(t, err) + + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t (c 
int)") + tk.MustQuery("select audit_log_create_filter('all', '{}')").Check(testkit.Rows("OK")) + tk.MustQuery("select content from mysql.audit_log_filters where filter_name = 'all'").Check(testkit.Rows(`{}`)) + tk.MustQuery("select audit_log_create_rule('%@%', 'all')").Check(testkit.Rows("OK")) + tk.MustQuery("select enabled from mysql.audit_log_filter_rules where user = '%@%' and filter_name = 'all'").Check(testkit.Rows("1")) + tk.MustExec("SET global tidb_audit_log_redacted = 0") + + require.NoError(t, conn.HandleQuery(context.Background(), "use test")) + require.NoError(t, conn.HandleQuery(context.Background(), "insert into t values (1)")) + ok, _, err := containsMessage(filepath.Join(tempDir, DefAuditLogName), "insert into t values (1)") + require.NoError(t, err) + require.True(t, ok) + + tk.MustQuery("select audit_log_disable_rule('%@%', 'all')").Check(testkit.Rows("OK")) + require.NoError(t, conn.HandleQuery(context.Background(), "insert into t values (2)")) + ok, _, err = containsMessage(filepath.Join(tempDir, DefAuditLogName), "insert into t values (2)") + require.NoError(t, err) + require.False(t, ok) + + tk.MustQuery("select audit_log_enable_rule('%@%', 'all')").Check(testkit.Rows("OK")) + require.NoError(t, conn.HandleQuery(context.Background(), "insert into t values (3)")) + ok, _, err = containsMessage(filepath.Join(tempDir, DefAuditLogName), "insert into t values (3)") + require.NoError(t, err) + require.True(t, ok) + + _, err = deleteAllAuditLogs(tempDir, "tidb-audit", ".log") + require.NoError(t, err) +} + +func TestAuditAdmin(t *testing.T) { + register4Test() + store := testkit.CreateMockStore(t) + srv := server.CreateMockServer(t, store) + defer srv.Close() + conn := server.CreateMockConn(t, srv) + defer conn.Close() + tk := testkit.NewTestKit(t, store) + err := tk.Session().Auth(&auth.UserIdentity{Username: "root", Hostname: "%"}, nil, nil, nil) + require.NoError(t, err) + + tk.MustExec("CREATE USER testuser") + tk2 := testkit.NewTestKit(t, 
store) + err = tk2.Session().Auth(&auth.UserIdentity{Username: "testuser", Hostname: "%"}, nil, nil, nil) + require.NoError(t, err) + tk.MustExec("GRANT SELECT, SYSTEM_VARIABLES_ADMIN ON *.* to testuser") + + tk2.MustGetErrCode(`select audit_log_create_filter("empty", '{"filter":[]}')`, errno.ErrSpecificAccessDenied) + tk2.MustGetErrCode(`set global tidb_audit_enabled = 0`, errno.ErrSpecificAccessDenied) + tk2.MustGetErrCode(`select * from mysql.audit_log_filters`, errno.ErrTableaccessDenied) + + tk.MustExec("GRANT AUDIT_ADMIN ON *.* to testuser") + + tk2.MustExec(`select audit_log_create_filter("all_query", '{"filter":[{"class":["QUERY"]}]}')`) + tk2.MustExec(`set global tidb_audit_enabled = 0`) + tk2.MustQuery(`select * from mysql.audit_log_filters`).Check(testkit.Rows(`all_query {"filter":[{"class":["QUERY"]}]}`)) + + _, err = deleteAllAuditLogs(workDir, "tidb-audit", ".log") + require.NoError(t, err) +} + +func TestRestrictedAuditAdmin(t *testing.T) { + register4Test() + store := testkit.CreateMockStore(t) + sem.Enable(config.SEMLevelBasic) + defer sem.Disable() + srv := server.CreateMockServer(t, store) + defer srv.Close() + conn := server.CreateMockConn(t, srv) + defer conn.Close() + tk := testkit.NewTestKit(t, store) + + tk.MustExec("CREATE USER testuser") + tk2 := testkit.NewTestKit(t, store) + err := tk2.Session().Auth(&auth.UserIdentity{Username: "testuser", Hostname: "%"}, nil, nil, nil) + require.NoError(t, err) + tk.MustExec("GRANT SELECT, SYSTEM_VARIABLES_ADMIN, AUDIT_ADMIN ON *.* to testuser") + + tk2.MustGetErrCode(`select audit_log_create_filter("empty", '{"filter":[]}')`, errno.ErrSpecificAccessDenied) + tk2.MustGetErrCode(`set global tidb_audit_enabled = 0`, errno.ErrSpecificAccessDenied) + tk2.MustGetErrCode(`select * from mysql.audit_log_filters`, errno.ErrTableaccessDenied) + + tk.MustExec("GRANT RESTRICTED_AUDIT_ADMIN ON *.* to testuser") + + tk2.MustExec(`select audit_log_create_filter("all_query", '{"filter":[{"class":["QUERY"]}]}')`) + 
tk2.MustExec(`set global tidb_audit_enabled = 0`) + tk2.MustQuery(`select * from mysql.audit_log_filters`).Check(testkit.Rows(`all_query {"filter":[{"class":["QUERY"]}]}`)) + + _, err = deleteAllAuditLogs(workDir, "tidb-audit", ".log") + require.NoError(t, err) +} + +func TestEventClass(t *testing.T) { + register4Test() + store := testkit.CreateMockStore(t) + srv := server.CreateMockServer(t, store) + defer srv.Close() + conn := server.CreateMockConn(t, srv) + defer conn.Close() + + defer func() { + _, err := deleteAllAuditLogs(workDir, "tidb-audit", ".log") + require.NoError(t, err) + }() + logPath := filepath.Join(workDir, DefAuditLogName) + _, err := deleteAllAuditLogs(workDir, "tidb-audit", ".log") + require.NoError(t, err) + require.NoError(t, conn.HandleQuery(context.Background(), "SET global tidb_audit_enabled = 1")) + require.NoError(t, conn.HandleQuery(context.Background(), "SELECT audit_log_create_filter('all', '{}')")) + require.NoError(t, conn.HandleQuery(context.Background(), "SELECT audit_log_create_rule('%@%', 'all')")) + + // test QUERY and AUDIT + testcases := []struct { + sql string + requireEvents []EventClass + success bool + }{ + {"use test", []EventClass{ClassQuery}, true}, + {"begin", []EventClass{ClassQuery, ClassTransaction}, true}, + {"create table if not exists t (c int)", []EventClass{ClassQuery, ClassDDL}, true}, + {"select * from t", []EventClass{ClassQuery, ClassSelect}, true}, + {"commit", []EventClass{ClassQuery, ClassTransaction}, true}, + {"prepare abs_func from 'select abs(?)'", []EventClass{ClassQuery}, true}, + {"set @num = -1", []EventClass{ClassQuery}, true}, + {"execute abs_func using @num", []EventClass{ClassQuery, ClassExecute, ClassSelect}, true}, + {"insert into t values (1)", []EventClass{ClassQuery, ClassDML, ClassInsert}, true}, + {"replace into t values (2)", []EventClass{ClassQuery, ClassDML, ClassReplace}, true}, + {"update t set c = 3 where c = 1", []EventClass{ClassQuery, ClassDML, ClassUpdate}, true}, + {"delete 
from t", []EventClass{ClassQuery, ClassDML, ClassDelete}, true}, + {`LOAD DATA LOCAL INFILE 'load_data.csv' INTO TABLE t FIELDS TERMINATED BY ',' ENCLOSED BY '\"' LINES TERMINATED BY '\r\n'`, []EventClass{ClassQuery, ClassDML, ClassLoadData}, false}, + {"select audit_log_create_filter('empty', '{}')", []EventClass{ClassAudit, ClassAuditFuncCall}, true}, + {"set global tidb_audit_enabled = 1", []EventClass{ClassAudit, ClassAuditSetSysVar, ClassAuditEnable}, true}, + {"set global tidb_audit_enabled = 0", []EventClass{ClassAudit, ClassAuditSetSysVar, ClassAuditDisable}, true}, + } + for _, tc := range testcases { + if tc.success { + require.NoError(t, conn.HandleQuery(context.Background(), tc.sql), tc.sql) + } else { + require.Error(t, conn.HandleQuery(context.Background(), tc.sql), tc.sql) + } + eventClassesStr := eventClasses2String(tc.requireEvents) + ok, log, err := containsMessage(logPath, eventClassesStr) + require.NoError(t, err, "%s\n%s\n%s", tc.sql, eventClassesStr, log) + require.True(t, ok, "%s\n%s\n%s", tc.sql, eventClassesStr, log) + require.NoError(t, conn.HandleQuery(context.Background(), "SELECT audit_log_rotate()"), "%s\n%s\n%s", tc.sql, eventClassesStr, log) + time.Sleep(time.Second) + } +} + +func lastLineContainsMessage(logpath, msg string) (ok bool, str string, err error) { + bytes, err := os.ReadFile(logpath) + strs := strings.Split(string(bytes), "\n") + str = strs[len(strs)-2] + if err != nil { + return false, str, err + } + return strings.Contains(str, msg), str, nil +} + +func TestServerlessDisabled(t *testing.T) { + conf := config.AuditLog{Enable: false, Path: "tidb-audit.log"} + + registerForServerlessTest(&conf) + store := testkit.CreateMockStore(t) + srv := server.CreateMockServer(t, store) + defer srv.Close() + conn := server.CreateMockConn(t, srv) + defer conn.Close() + + defer func() { + _, err := deleteAllAuditLogs(workDir, "tidb-audit", ".log") + require.NoError(t, err) + }() + logPath := filepath.Join(workDir, conf.Path) + _, err := 
deleteAllAuditLogs(workDir, "tidb-audit", ".log") + require.NoError(t, err) + + sqls := []string{ + "use test", + "begin", + "create table if not exists t (c int)", + "select * from t", + "commit", + } + for _, sql := range sqls { + require.NoError(t, conn.HandleQuery(context.Background(), sql), sql) + } + + _, err = os.ReadFile(logPath) + require.Error(t, err, "audit log file should not exists if audit log disabled") +} + +func TestServerlessEnabled(t *testing.T) { + conf := config.AuditLog{ + Enable: true, + Path: "tidb-audit.log", + Format: LogFormatText, + MaxFilesize: 10, + MaxLifetime: 60 * 60, + Redacted: true, + } + registerForServerlessTest(&conf) + store := testkit.CreateMockStore(t) + srv := server.CreateMockServer(t, store) + defer srv.Close() + conn := server.CreateMockConn(t, srv) + defer conn.Close() + + defer func() { + _, err := deleteAllAuditLogs(workDir, "tidb-audit", ".log") + require.NoError(t, err) + }() + logPath := filepath.Join(workDir, conf.Path) + _, err := deleteAllAuditLogs(workDir, "tidb-audit", ".log") + require.NoError(t, err) + + testcases := []struct { + sql string + requireEvents []EventClass + success bool + }{ + {"use test", []EventClass{ClassQuery}, true}, + {"begin", []EventClass{ClassQuery, ClassTransaction}, true}, + {"create table if not exists t (c int)", []EventClass{ClassQuery, ClassDDL}, true}, + {"select * from t", []EventClass{ClassQuery, ClassSelect}, true}, + {"commit", []EventClass{ClassQuery, ClassTransaction}, true}, + {"insert into t values (1)", []EventClass{ClassQuery, ClassDML, ClassInsert}, true}, + {"replace into t values (2)", []EventClass{ClassQuery, ClassDML, ClassReplace}, true}, + {"update t set c = 3 where c = 1", []EventClass{ClassQuery, ClassDML, ClassUpdate}, true}, + {"delete from t", []EventClass{ClassQuery, ClassDML, ClassDelete}, true}, + } + for _, tc := range testcases { + if tc.success { + require.NoError(t, conn.HandleQuery(context.Background(), tc.sql), tc.sql) + } else { + 
require.Error(t, conn.HandleQuery(context.Background(), tc.sql), tc.sql) + } + eventClassesStr := eventClasses2String(tc.requireEvents) + ok, log, err := lastLineContainsMessage(logPath, eventClassesStr) + require.NoError(t, err, "%s\n%s\n%s", tc.sql, eventClassesStr, log) + require.True(t, ok, "%s\n%s\n%s", tc.sql, eventClassesStr, log) + ok, log, err = lastLineContainsMessage(logPath, tc.sql) + require.NoError(t, err, "%s\n%s", log, tc.sql) + require.True(t, ok, "%s\n%s", log, tc.sql) + } + + // Enable encryption to validate it works + conf.EncryptKey = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" + registerForServerlessTest(&conf) + for _, tc := range testcases { + if tc.success { + require.NoError(t, conn.HandleQuery(context.Background(), tc.sql), tc.sql) + } else { + require.Error(t, conn.HandleQuery(context.Background(), tc.sql), tc.sql) + } + ok, log, err := lastLineContainsMessage(logPath, tc.sql) + require.NoError(t, err, "%s\n%s", log, tc.sql) + require.False(t, ok, "%s\n%s", log, tc.sql) + } +} diff --git a/extension/serverless/audit/logger.go b/extension/serverless/audit/logger.go new file mode 100644 index 0000000000000..2ae77e9ef2d30 --- /dev/null +++ b/extension/serverless/audit/logger.go @@ -0,0 +1,254 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package audit + +import ( + "context" + "fmt" + "os" + "path/filepath" + "strings" + "time" + "unicode" + + "github.com/pingcap/errors" + "github.com/pingcap/log" + "github.com/pingcap/tidb/config" + "github.com/pingcap/tidb/parser/terror" + "github.com/pingcap/tidb/util/logutil" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" + "gopkg.in/natefinch/lumberjack.v2" +) + +// Log formats +const ( + LogFormatText = "TEXT" + LogFormatJSON = "JSON" +) + +// Values of system variables +const ( + MaxAuditLogFileMaxSize uint64 = 100 * 1024 // 100GB + MaxAuditLogFileMaxLifetime uint64 = 30 * 24 * 60 * 60 // 30 days + MaxAuditLogFileReservedBackups uint64 = 1024 + MaxAuditLogFileReservedDays uint64 = 1024 + + DefAuditLogName = "tidb-audit.log" + DefAuditLogFormat = LogFormatText + DefAuditLogFileMaxSize int64 = 100 // 100MB + DefAuditLogFileMaxLifetime int64 = 1 * 24 * 60 * 60 // 1 day + DefAuditLogFileReservedBackups int = 10 + DefAuditLogFileReservedDays int = 0 + DefAuditLogRedact bool = true +) + +type zapLoggerConfig struct { + Filepath string + FileMaxSize int64 + FileMaxLifetime int64 + FileReservedBackups int + FileReservedDays int + LogFormat string +} + +func (cfg zapLoggerConfig) zapCfgChanged(cmp LoggerConfig) bool { + return cfg != cmp.zapLoggerConfig +} + +// LoggerConfig is the config for logger +type LoggerConfig struct { + zapLoggerConfig + Enabled bool + ConfigPath string + Redact bool + Filter *LogFilterRuleBundle + EncryptKey string + EncryptIv string +} + +// Logger is used to log audit logs +type Logger struct { + zapLogger *zap.Logger + rotate func() error + cfg LoggerConfig + + ctx context.Context + cancel context.CancelFunc +} + +// LoggerWithConfig creates a new Logger, it will reuse some inner states if possible +func LoggerWithConfig(ctx context.Context, cfg LoggerConfig, old *Logger) (_ *Logger, err error) { + if old != nil && old.cfg == cfg { + return old, nil + } + + var zapLogger *zap.Logger + var rotate func() error + + if old == nil || 
old.cfg.zapCfgChanged(cfg) { + if zapLogger, rotate, err = newZapLogger(cfg); err != nil { + return nil, err + } + } else { + zapLogger = old.zapLogger + rotate = old.rotate + } + + subCtx, cancel := context.WithCancel(ctx) + return &Logger{ + zapLogger: zapLogger, + rotate: rotate, + cfg: cfg, + ctx: subCtx, + cancel: cancel, + }, nil +} + +// Rotate rotates the log +func (l *Logger) Rotate() error { + if l == nil { + return errors.New("audit logger is not initialized") + } + return l.rotate() +} + +// Event logs the event +func (l *Logger) Event(fields []zap.Field) { + if l == nil { + terror.Log(errors.New("audit logger is not initialized")) + return + } + l.zapLogger.Info("", fields...) +} + +func (l *Logger) loopCheckLogMaxLifetime(interval time.Duration) { + go func() { + ticker := time.NewTicker(interval) + for { + select { + case <-ticker.C: + if err := l.Rotate(); err != nil { + logutil.BgLogger().Error("Fail to rotate audit log by tidb_audit_log_max_lifetime", zap.Error(err)) + } + case <-l.ctx.Done(): + ticker.Stop() + return + } + } + }() +} + +func getServerEndpoint() string { + globalConf := config.GetGlobalConfig() + var serverEndpoint strings.Builder + for _, c := range fmt.Sprintf("%s-%d", globalConf.AdvertiseAddress, globalConf.Port) { + if unicode.IsSpace(c) { + serverEndpoint.WriteRune('_') + continue + } + + if c == '.' 
{ + serverEndpoint.WriteRune('-') + continue + } + + serverEndpoint.WriteRune(c) + } + return serverEndpoint.String() +} + +func getFilePath(configPath string, format string) (string, error) { + if configPath == "" { + configPath = DefAuditLogName + } + if !filepath.IsAbs(configPath) { + baseFolder := "" + if logFile := config.GetGlobalConfig().Log.File.Filename; logFile != "" { + baseFolder = filepath.Dir(logFile) + } else { + baseFolder = filepath.Dir(configPath) + } + if !filepath.IsAbs(baseFolder) { + wd, err := os.Getwd() + if err != nil { + return "", err + } + baseFolder = filepath.Join(wd, baseFolder) + } + configPath = filepath.Join(baseFolder, configPath) + } + + if format == LogFormatJSON { + configPath = configPath + ".json" + } + + return strings.ReplaceAll(configPath, "%e", getServerEndpoint()), nil +} + +func newZapLogger(cfg LoggerConfig) (*zap.Logger, func() error, error) { + if st, err := os.Stat(cfg.Filepath); err == nil { + if st.IsDir() { + return nil, nil, errors.New("can't use directory as log file name") + } + } + + zapCfg := &log.Config{ + Level: "info", + DisableCaller: true, + File: log.FileLogConfig{ + Filename: cfg.Filepath, + MaxSize: int(cfg.FileMaxSize), + MaxBackups: cfg.FileReservedBackups, + MaxDays: cfg.FileReservedDays, + }, + } + + // use lumberjack to logrotate + writer := &lumberjack.Logger{ + Filename: zapCfg.File.Filename, + MaxSize: zapCfg.File.MaxSize, + MaxBackups: zapCfg.File.MaxBackups, + MaxAge: zapCfg.File.MaxDays, + LocalTime: true, + } + + logger, props, err := log.InitLoggerWithWriteSyncer( + zapCfg, + zapcore.AddSync(writer), + zapcore.AddSync(&errLogWriter{}), + ) + + if err != nil { + return nil, nil, err + } + + var newEncoder zapcore.Encoder + if cfg.LogFormat == LogFormatJSON { + newEncoder = zapcore.NewJSONEncoder(zapcore.EncoderConfig{ + TimeKey: LogKeyTime, + EncodeTime: log.DefaultTimeEncoder, + }) + } + + if newEncoder != nil { + newCore := log.NewTextCore(newEncoder, props.Syncer, props.Level) + logger 
= logger.WithOptions(zap.WrapCore(func(core zapcore.Core) zapcore.Core { + return newCore + })) + } + + return logger, writer.Rotate, nil +} diff --git a/extension/serverless/audit/manager.go b/extension/serverless/audit/manager.go new file mode 100644 index 0000000000000..2723b2e3ff3af --- /dev/null +++ b/extension/serverless/audit/manager.go @@ -0,0 +1,713 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package audit + +import ( + "context" + "crypto/aes" + "fmt" + "strings" + "sync" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/tidb/config" + "github.com/pingcap/tidb/extension" + "github.com/pingcap/tidb/kv" + "github.com/pingcap/tidb/parser/terror" + "github.com/pingcap/tidb/util/logutil" + "github.com/pingcap/tidb/util/sqlexec" + "go.etcd.io/etcd/api/v3/mvccpb" + clientv3 "go.etcd.io/etcd/client/v3" + "go.uber.org/atomic" + "go.uber.org/zap" +) + +const ( + // EtcdKeyNotifyPrefix is the key prefix to notify audit log admin events to all instances + EtcdKeyNotifyPrefix = "/tidb/audit/log/notify/" + // EtcdKeyNotifyRotate is used to notify to rotate logs + EtcdKeyNotifyRotate = EtcdKeyNotifyPrefix + "rotate" + // EtcdKeyNotifyFilter is used to notify filter changes + EtcdKeyNotifyFilter = EtcdKeyNotifyPrefix + "filter" + // GlobalNotifyTimeout is the timeout to notify global event + GlobalNotifyTimeout = time.Minute + // MinEtcdReWatchInterval is the min interval for retry watch + 
MinEtcdReWatchInterval = 5 * time.Second + // Expect to use AES-256 for log encryption + expectedEncryptKeyLength = 32 +) + +// LogManager is used to manage audit logs +type LogManager struct { + sync.Mutex + + ctx context.Context + cancel func() + + loggerRef atomic.Value + loggerCfg LoggerConfig + etcdCli *clientv3.Client + sessPool extension.SessionPool +} + +// InitWithConfig initialize a manager with configuration +func (m *LogManager) InitWithConfig(conf *config.AuditLog) (err error) { + m.Lock() + defer m.Unlock() + + m.ctx = kv.WithInternalSourceType(context.Background(), kv.InternalTxnOthers) + m.ctx, m.cancel = context.WithCancel(m.ctx) + defer func() { + if err != nil { + m.Close() + } + }() + + logFilepath, err := getFilePath(conf.Path, string(conf.Format)) + if err != nil { + return err + } + + var encryptKey, encryptIv string + if conf.EncryptKey != "" { + if len(conf.EncryptKey) != expectedEncryptKeyLength { + return fmt.Errorf("bad encrypt key, %d != %d", len(conf.EncryptKey), expectedEncryptKeyLength) + } + encryptKey = conf.EncryptKey + encryptIv = conf.EncryptKey[:aes.BlockSize] + } + + m.loggerCfg = LoggerConfig{ + Enabled: conf.Enable, + ConfigPath: conf.Path, + zapLoggerConfig: zapLoggerConfig{ + Filepath: logFilepath, + FileMaxSize: conf.MaxFilesize, + FileMaxLifetime: conf.MaxLifetime, + FileReservedBackups: 0, + FileReservedDays: 0, + LogFormat: string(conf.Format), + }, + Redact: conf.Redacted, + EncryptKey: encryptKey, + EncryptIv: encryptIv, + } + + return m.refreshLoggerRef() +} + +// Init inits a manager +func (m *LogManager) Init() (err error) { + m.Lock() + defer m.Unlock() + + // m.ctx != nil means already inited + if m.ctx != nil { + return nil + } + + m.ctx = kv.WithInternalSourceType(context.Background(), kv.InternalTxnOthers) + m.ctx, m.cancel = context.WithCancel(m.ctx) + defer func() { + if err != nil { + m.Close() + } + }() + + logConfigPath := DefAuditLogName + logFilepath, err := getFilePath(logConfigPath, 
DefAuditLogFormat) + if err != nil { + return err + } + + m.loggerCfg = LoggerConfig{ + Enabled: false, + ConfigPath: logConfigPath, + zapLoggerConfig: zapLoggerConfig{ + Filepath: logFilepath, + FileMaxSize: DefAuditLogFileMaxSize, + FileMaxLifetime: DefAuditLogFileMaxLifetime, + FileReservedBackups: DefAuditLogFileReservedBackups, + FileReservedDays: DefAuditLogFileReservedDays, + LogFormat: DefAuditLogFormat, + }, + Redact: DefAuditLogRedact, + } + + return m.refreshLoggerRef() +} + +// Bootstrap bootstraps audit log manager +func (m *LogManager) Bootstrap(ctx extension.BootstrapContext) error { + etcdCli := ctx.EtcdClient() + sessPool := ctx.SessionPool() + + m.Lock() + if m.ctx == nil { + m.Unlock() + return errors.New("not initialized") + } + m.etcdCli = etcdCli + m.sessPool = sessPool + m.Unlock() + + if err := m.bootstrapSQL(); err != nil { + return err + } + + if err := m.ReloadFilter(); err != nil { + return err + } + + go m.handleNotifyLoop(ctx, etcdCli) + return nil +} + +func (m *LogManager) bootstrapSQL() error { + sqlList := []string{ + CreateFilterTableSQL, + CreateFilterRuleTableSQL, + } + + return m.runSQLInExecutor(func(exec sqlexec.SQLExecutor) error { + for _, sql := range sqlList { + if _, err := executeSQL(m.ctx, exec, sql); err != nil { + return err + } + } + return nil + }) +} + +func (m *LogManager) runSQLInExecutor(fn func(sqlexec.SQLExecutor) error) error { + m.Lock() + sessPool := m.sessPool + m.Unlock() + + if sessPool == nil { + return errors.New("sessPool is not initialized") + } + + sess, err := sessPool.Get() + if err != nil { + return err + } + + defer func() { + if sess != nil { + sessPool.Put(sess) + } + }() + + exec, ok := sess.(sqlexec.SQLExecutor) + if !ok { + return errors.Errorf("type %T cannot be casted to SQLExecutor", sess) + } + + return fn(exec) +} + +func (m *LogManager) runInTxn(ctx context.Context, fn func(sqlexec.SQLExecutor) error) error { + return m.runSQLInExecutor(func(exec sqlexec.SQLExecutor) (err error) { + 
done := false + defer func() { + if done { + _, err = executeSQL(ctx, exec, "COMMIT") + } else { + _, rollbackErr := exec.ExecuteInternal(ctx, "ROLLBACK") + terror.Log(rollbackErr) + } + }() + + if _, err = executeSQL(ctx, exec, "BEGIN"); err != nil { + return err + } + + if err = fn(exec); err != nil { + return err + } + + done = true + return nil + }) +} + +func (m *LogManager) globalNotify(key string) (bool, error) { + var cli *clientv3.Client + m.Lock() + cli = m.etcdCli + m.Unlock() + + if cli == nil { + // `cli == nil` it is not a TiKV cluster, return false, nil to indicate global notify is invalid + return false, nil + } + + ctx, cancel := context.WithTimeout(m.ctx, GlobalNotifyTimeout) + defer cancel() + + if _, err := cli.Put(ctx, key, fmt.Sprintf(`{"ts": %d}`, time.Now().UnixMilli())); err != nil { + return false, err + } + + return true, nil +} + +func (m *LogManager) handleNotifyLoop(ctx context.Context, etcd *clientv3.Client) { + if etcd == nil { + return + } + + watcher := newKeyWatcher(ctx, etcd, EtcdKeyNotifyPrefix) + reloadFilterTicker := time.NewTicker(5 * time.Minute) +Loop: + for { + select { + case <-ctx.Done(): + break Loop + case resp, ok := <-watcher.Chan: + if !ok { + watcher.RenewClosedChan() + continue + } + + processedKeys := make(map[string]struct{}, len(resp.Events)) + for _, event := range resp.Events { + if event.Type != mvccpb.PUT { + continue + } + + key := string(event.Kv.Key) + logutil.BgLogger().Info("audit notify message received", zap.String("key", key), zap.ByteString("value", event.Kv.Value)) + if _, ok = processedKeys[key]; ok { + continue + } + + switch key { + case EtcdKeyNotifyRotate: + terror.Log(m.RotateLog()) + case EtcdKeyNotifyFilter: + terror.Log(m.ReloadFilter()) + } + } + case <-reloadFilterTicker.C: + terror.Log(m.ReloadFilter()) + } + } + logutil.BgLogger().Info("audit globalRotateNotifyLoop exit") +} + +// Close closes the log manager +func (m *LogManager) Close() { + m.Lock() + defer m.Unlock() + m.ctx = nil 
+ m.loggerCfg = LoggerConfig{} + m.loggerRef.Store(&Logger{}) + m.etcdCli = nil + if m.cancel != nil { + m.cancel() + m.cancel = nil + } +} + +// Enabled returns whether the log is enabled +func (m *LogManager) Enabled() bool { + m.Lock() + defer m.Unlock() + return m.loggerCfg.Enabled +} + +// SetEnabled sets whether the log should be enabled +func (m *LogManager) SetEnabled(enabled bool) error { + m.Lock() + defer m.Unlock() + m.loggerCfg.Enabled = enabled + return m.refreshLoggerRef() +} + +// SetLogConfigPath sets the log config path +func (m *LogManager) SetLogConfigPath(path string) error { + m.Lock() + defer m.Unlock() + + logFilepath, err := getFilePath(path, m.loggerCfg.LogFormat) + if err != nil { + return err + } + + m.loggerCfg.ConfigPath = path + m.loggerCfg.Filepath = logFilepath + return m.refreshLoggerRef() +} + +// GetLogConfigPath gets the log config path +func (m *LogManager) GetLogConfigPath() string { + m.Lock() + defer m.Unlock() + return m.loggerCfg.ConfigPath +} + +// used for tests +func (m *LogManager) getLogPath() string { + m.Lock() + defer m.Unlock() + return m.loggerCfg.zapLoggerConfig.Filepath +} + +// SetLogFormat sets the log format +func (m *LogManager) SetLogFormat(format string) error { + m.Lock() + defer m.Unlock() + format = strings.ToUpper(format) + + logFilePath, err := getFilePath(m.loggerCfg.ConfigPath, format) + if err != nil { + return err + } + + m.loggerCfg.LogFormat = format + m.loggerCfg.Filepath = logFilePath + return m.refreshLoggerRef() +} + +// GetLogFormat gets the format of log +func (m *LogManager) GetLogFormat() string { + m.Lock() + defer m.Unlock() + return m.loggerCfg.LogFormat +} + +// SetFileMaxSize sets the max size of a log file +func (m *LogManager) SetFileMaxSize(val int64) error { + m.Lock() + defer m.Unlock() + m.loggerCfg.FileMaxSize = val + return m.refreshLoggerRef() +} + +// SetFileMaxLifetime sets the max rotation time of a log file +func (m *LogManager) SetFileMaxLifetime(val int64) error { + 
m.Lock() + defer m.Unlock() + m.loggerCfg.FileMaxLifetime = val + return m.refreshLoggerRef() +} + +// GetFileMaxSize returns the max size of a log file +func (m *LogManager) GetFileMaxSize() int64 { + m.Lock() + defer m.Unlock() + return m.loggerCfg.FileMaxSize +} + +// GetFileMaxLifetime returns the max rotation time of a log file +func (m *LogManager) GetFileMaxLifetime() int64 { + m.Lock() + defer m.Unlock() + return m.loggerCfg.FileMaxLifetime +} + +// SetFileReservedBackups sets the max backup count +func (m *LogManager) SetFileReservedBackups(val int) error { + m.Lock() + defer m.Unlock() + m.loggerCfg.FileReservedBackups = val + return m.refreshLoggerRef() +} + +// GetFileReservedBackups returns the max backup count +func (m *LogManager) GetFileReservedBackups() int { + m.Lock() + defer m.Unlock() + return m.loggerCfg.FileReservedBackups +} + +// SetFileReservedDays sets the max reserved days for logs +func (m *LogManager) SetFileReservedDays(val int) error { + m.Lock() + defer m.Unlock() + m.loggerCfg.FileReservedDays = val + return m.refreshLoggerRef() +} + +// GetFileReservedDays returns the max reserved days for logs +func (m *LogManager) GetFileReservedDays() int { + m.Lock() + defer m.Unlock() + return m.loggerCfg.FileReservedDays +} + +// RedactLog indicates whether to redact the log +func (m *LogManager) RedactLog() bool { + m.Lock() + defer m.Unlock() + return m.loggerCfg.Redact +} + +// SetRedactLog sets whether redact the log +func (m *LogManager) SetRedactLog(val bool) error { + m.Lock() + defer m.Unlock() + m.loggerCfg.Redact = val + return m.refreshLoggerRef() +} + +// GetFilter returns the filter +func (m *LogManager) GetFilter() *LogFilterRuleBundle { + m.Lock() + defer m.Unlock() + return m.loggerCfg.Filter +} + +// SetFilter sets the filter +func (m *LogManager) SetFilter(f *LogFilterRuleBundle) error { + m.Lock() + defer m.Unlock() + m.loggerCfg.Filter = f + return m.refreshLoggerRef() +} + +// RotateLog rotates the log in current 
instance +func (m *LogManager) RotateLog() error { + logutil.BgLogger().Info("rotate audit log in local") + return m.logger().Rotate() +} + +// GlobalRotateLog rotates all logs in cluster +func (m *LogManager) GlobalRotateLog() error { + ok, err := m.globalNotify(EtcdKeyNotifyRotate) + if err != nil { + return err + } + + if !ok { + // !ok means the not a TiKV cluster, only rotate in local + return m.RotateLog() + } + return nil +} + +func (m *LogManager) refreshLoggerRef() error { + oldLogger := m.logger() + newLogger, err := LoggerWithConfig(m.ctx, m.loggerCfg, oldLogger) + if err != nil { + return err + } + + if newLogger != oldLogger { + m.loggerRef.Store(newLogger) + + if cfg := m.loggerCfg; cfg.Enabled && cfg.FileMaxLifetime > 0 { + newLogger.loopCheckLogMaxLifetime(time.Duration(m.loggerCfg.FileMaxLifetime) * time.Second) + } + if oldLogger != nil && oldLogger.cfg.Enabled && oldLogger.cancel != nil { + oldLogger.cancel() + } + } + + return nil +} + +func (m *LogManager) notifyFilterConfigUpdated() error { + ok, err := m.globalNotify(EtcdKeyNotifyFilter) + if err != nil { + return err + } + + if !ok { + return m.ReloadFilter() + } + + return nil +} + +// ReloadFilter reloads filters from table +func (m *LogManager) ReloadFilter() error { + return m.runInTxn(m.ctx, func(exec sqlexec.SQLExecutor) error { + rules, err := listFilterRules(m.ctx, exec) + if err != nil { + return err + } + filter := newLogFilterRuleBundle(rules) + return m.SetFilter(filter) + }) +} + +// CreateLogFilter creates log filter +func (m *LogManager) CreateLogFilter(filter *logFilter, replace bool) error { + if err := filter.Validate(); err != nil { + return err + } + + normalizedFilter, err := filter.Normalize() + if err != nil { + return err + } + + content, err := normalizedFilter.ToJSON() + if err != nil { + return err + } + + err = m.runSQLInExecutor(func(exec sqlexec.SQLExecutor) error { + insertSQL := InsertFilterSQL + if replace { + insertSQL = ReplaceFilterSQL + } + + if _, err = 
executeSQL(m.ctx, exec, insertSQL, filter.Name, content); err != nil { + return err + } + return nil + }) + + if err != nil { + return err + } + + return m.notifyFilterConfigUpdated() +} + +// RemoveLogFilter removes a log filter +func (m *LogManager) RemoveLogFilter(name string) error { + err := m.runInTxn(m.ctx, func(exec sqlexec.SQLExecutor) error { + rows, err := executeSQL(m.ctx, exec, SelectFilterRuleByFilterSQL, name) + if err != nil { + return err + } + + if len(rows) > 0 { + return errors.Errorf("filter '%s' is in use", name) + } + + if _, err = executeSQL(m.ctx, exec, DeleteFilterByNameSQL, name); err != nil { + return err + } + return nil + }) + + if err != nil { + return err + } + + return m.notifyFilterConfigUpdated() +} + +// CreateLogFilterRule creates a log filter rule +func (m *LogManager) CreateLogFilterRule(user, filter string, replace bool) error { + user, _, _, err := normalizeUserAndHost(user) + if err != nil { + return err + } + + err = m.runInTxn(m.ctx, func(exec sqlexec.SQLExecutor) error { + rows, err := executeSQL(m.ctx, exec, SelectFilterByNameSQL, filter) + if err != nil { + return err + } + + if len(rows) == 0 { + return errors.Errorf("cannot find filter with name '%s'", filter) + } + + insertSQL := InsertFilterRuleSQL + if replace { + insertSQL = ReplaceFilterRuleSQL + } + + if _, err = executeSQL(m.ctx, exec, insertSQL, user, filter); err != nil { + return err + } + return nil + }) + + if err != nil { + return err + } + + return m.notifyFilterConfigUpdated() +} + +// RemoveLogFilterRule removes a log filter rule +func (m *LogManager) RemoveLogFilterRule(user, filter string) error { + err := m.runSQLInExecutor(func(exec sqlexec.SQLExecutor) error { + if _, err := executeSQL(m.ctx, exec, DeleteFilterRuleSQL, user, filter); err != nil { + return err + } + return nil + }) + + if err != nil { + return err + } + + return m.notifyFilterConfigUpdated() +} + +// SwitchLogFilterRuleEnabled switches a log filter rule to enabled/disabled +func 
(m *LogManager) SwitchLogFilterRuleEnabled(user, filter string, enabled bool) error { + err := m.runInTxn(m.ctx, func(exec sqlexec.SQLExecutor) error { + rows, err := executeSQL(m.ctx, exec, SelectFilterRuleByUserAndFilterSQL, user, filter) + if err != nil { + return err + } + + if len(rows) == 0 { + return errors.New("rule does not exist") + } + + enabledVal := 0 + if enabled { + enabledVal = 1 + } + + if _, err = executeSQL(m.ctx, exec, UpdateFilterRuleEnabledSQL, enabledVal, user, filter); err != nil { + return err + } + return nil + }) + + if err != nil { + return err + } + + return m.notifyFilterConfigUpdated() +} + +func (m *LogManager) logger() *Logger { + if ref, ok := m.loggerRef.Load().(*Logger); ok { + return ref + } + return nil +} + +// GetSessionHandler returns `*extension.SessionHandler` to register +func (m *LogManager) GetSessionHandler() *extension.SessionHandler { + id := newEntryIDGenerator() + return &extension.SessionHandler{ + OnConnectionEvent: func(tp extension.ConnEventTp, info *extension.ConnEventInfo) { + if logger := m.logger(); logger.cfg.Enabled { + newConnEventEntry(tp, info).Filter(logger.cfg.Filter).Log(logger, id) + } + }, + OnStmtEvent: func(tp extension.StmtEventTp, info extension.StmtEventInfo) { + if logger := m.logger(); logger.cfg.Enabled { + newStmtEventEntry(tp, info).Filter(logger.cfg.Filter).Log(logger, id) + } + }, + } +} diff --git a/extension/serverless/audit/register.go b/extension/serverless/audit/register.go new file mode 100644 index 0000000000000..2cce8245fe515 --- /dev/null +++ b/extension/serverless/audit/register.go @@ -0,0 +1,362 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package audit + +import ( + "context" + "encoding/json" + "flag" + "strconv" + + "github.com/pingcap/tidb/config" + "github.com/pingcap/tidb/extension" + "github.com/pingcap/tidb/parser/terror" + "github.com/pingcap/tidb/sessionctx/variable" + "github.com/pingcap/tidb/types" + "github.com/pingcap/tidb/util/chunk" +) + +// ExtensionName is the extension name for audit log +const ExtensionName = "Audit" + +const ( + // PrivAuditAdmin is the dynamic privilege that allows the admin of audit log + PrivAuditAdmin = "AUDIT_ADMIN" + // PrivRestrictedAuditAdmin is the same with `PRIV_AUDIT_ADMIN` but used in sem mode. 
+ PrivRestrictedAuditAdmin = "RESTRICTED_AUDIT_ADMIN" +) + +const ( + // TiDBAuditEnabled is a system variable that used to enable/disable audit feature + TiDBAuditEnabled = "tidb_audit_enabled" + // TiDBAuditLog is a system variable that used to set audit log file path + TiDBAuditLog = "tidb_audit_log" + // TiDBAuditLogFormat is a system variable that used to set the output format of the audit log + TiDBAuditLogFormat = "tidb_audit_log_format" + // TiDBAuditLogMaxFileSize is a system variable that used to control the max size of a log file + TiDBAuditLogMaxFileSize = "tidb_audit_log_max_filesize" + // TiDBAuditLogMaxLifetime is a system variable that used to control the max lifetime of a log file + TiDBAuditLogMaxLifetime = "tidb_audit_log_max_lifetime" + // TiDBAuditLogReservedBackups is a system variable that used to control the max backup count of the logs + TiDBAuditLogReservedBackups = "tidb_audit_log_reserved_backups" + // TiDBAuditLogMaxReservedDays is a system variable that used to control the max days of the reserved log + TiDBAuditLogMaxReservedDays = "tidb_audit_log_reserved_days" + // TiDBAuditRedactLog is a system variable that used to switch on/off of redact log + TiDBAuditRedactLog = "tidb_audit_log_redacted" +) + +const ( + // FuncAuditLogRotate is used to rotate audit log + FuncAuditLogRotate = "audit_log_rotate" +) + +// globalLogManager manages the global state of the audit log +var globalLogManager LogManager + +var enableAuditExtension bool + +// Register registers enterprise extension +func Register() { + flag.BoolVar(&enableAuditExtension, "enable-audit-extension", true, "enable audit extension or not") + err := extension.RegisterFactory(ExtensionName, extensionOptionsFactory) + terror.MustNil(err) +} + +func register4Test() { + extension.Reset() + enableAuditExtension = true + err := extension.RegisterFactory(ExtensionName, extensionOptionsFactory) + terror.MustNil(err) +} + +// RegisterForServerless register serverless version audit log 
with configuration +func RegisterForServerless(conf *config.AuditLog) error { + if err := globalLogManager.InitWithConfig(conf); err != nil { + return err + } + return extension.RegisterFactory(ExtensionName, func() ([]extension.Option, error) { + return []extension.Option{ + extension.WithSessionHandlerFactory(globalLogManager.GetSessionHandler), + extension.WithClose(globalLogManager.Close), + }, nil + }) +} + +func registerForServerlessTest(conf *config.AuditLog) { + extension.Reset() + err := RegisterForServerless(conf) + terror.MustNil(err) +} + +func requirePrivileges(sem bool) []string { + if sem { + return []string{PrivRestrictedAuditAdmin} + } + return []string{PrivAuditAdmin} +} + +func extensionOptionsFactory() ([]extension.Option, error) { + if !enableAuditExtension { + return nil, nil + } + + if err := globalLogManager.Init(); err != nil { + return nil, err + } + + setGlobalFunc := func(sysVar string, fn func(context.Context, *variable.SessionVars, string) error) func(context.Context, *variable.SessionVars, string) error { + return func(ctx context.Context, vars *variable.SessionVars, val string) (err error) { + defer func() { + newAuditEntryWithSetGlobalVar(sysVar, val, vars, err).Log(globalLogManager.logger(), newEntryIDGenerator()) + }() + return fn(ctx, vars, val) + } + } + + statusFunc := func(name string, fn func(extension.FunctionContext, []types.Datum) error) func(ctx extension.FunctionContext, row chunk.Row) (string, bool, error) { + return func(ctx extension.FunctionContext, row chunk.Row) (_ string, _ bool, err error) { + var args []types.Datum + defer func() { + newAuditFuncCallEntry(name, args, ctx, err).Log(globalLogManager.logger(), newEntryIDGenerator()) + }() + + args, err = ctx.EvalArgs(row) + if err != nil { + return "", false, err + } + + if err = fn(ctx, args); err != nil { + return "", false, err + } + return "OK", false, nil + } + } + + sysVarPrivileges := func(_ bool, sem bool) []string { + return requirePrivileges(sem) + } + + 
sysVars := []*variable.SysVar{ + { + Name: TiDBAuditEnabled, + Scope: variable.ScopeGlobal, + Type: variable.TypeBool, + Value: variable.BoolToOnOff(globalLogManager.Enabled()), + SetGlobal: setGlobalFunc(TiDBAuditEnabled, func(_ context.Context, sessVars *variable.SessionVars, s string) (err error) { + enable := variable.TiDBOptOn(s) + return globalLogManager.SetEnabled(enable) + }), + RequireDynamicPrivileges: sysVarPrivileges, + }, + { + Name: TiDBAuditLog, + Scope: variable.ScopeGlobal, + Type: variable.TypeStr, + Value: globalLogManager.GetLogConfigPath(), + SetGlobal: setGlobalFunc(TiDBAuditLog, func(_ context.Context, _ *variable.SessionVars, s string) error { + return globalLogManager.SetLogConfigPath(s) + }), + RequireDynamicPrivileges: sysVarPrivileges, + }, + { + Name: TiDBAuditLogFormat, + Scope: variable.ScopeGlobal, + Type: variable.TypeEnum, + Value: globalLogManager.GetLogFormat(), + PossibleValues: []string{LogFormatText, LogFormatJSON}, + SetGlobal: setGlobalFunc(TiDBAuditLogFormat, func(_ context.Context, _ *variable.SessionVars, s string) error { + return globalLogManager.SetLogFormat(s) + }), + RequireDynamicPrivileges: sysVarPrivileges, + }, + { + Name: TiDBAuditLogMaxFileSize, + Scope: variable.ScopeGlobal, + Type: variable.TypeInt, + Value: strconv.FormatInt(globalLogManager.GetFileMaxSize(), 10), + MinValue: 0, + MaxValue: MaxAuditLogFileMaxSize, + Validation: func(vars *variable.SessionVars, normalizedValue string, originalValue string, scope variable.ScopeFlag) (string, error) { + val, err := strconv.ParseInt(normalizedValue, 10, 64) + if err != nil { + return normalizedValue, err + } + if val == 0 { + return strconv.FormatInt(DefAuditLogFileMaxSize, 10), err + } + return normalizedValue, err + }, + SetGlobal: setGlobalFunc(TiDBAuditLogMaxFileSize, func(_ context.Context, _ *variable.SessionVars, s string) error { + val, err := strconv.ParseInt(s, 10, 64) + if err != nil { + return err + } + return globalLogManager.SetFileMaxSize(val) + 
}), + RequireDynamicPrivileges: sysVarPrivileges, + }, + { + Name: TiDBAuditLogMaxLifetime, + Scope: variable.ScopeGlobal, + Type: variable.TypeInt, + Value: strconv.FormatInt(globalLogManager.GetFileMaxLifetime(), 10), + MinValue: 0, + MaxValue: MaxAuditLogFileMaxLifetime, + SetGlobal: setGlobalFunc(TiDBAuditLogMaxLifetime, func(_ context.Context, _ *variable.SessionVars, s string) error { + val, err := strconv.ParseInt(s, 10, 64) + if err != nil { + return err + } + return globalLogManager.SetFileMaxLifetime(val) + }), + RequireDynamicPrivileges: sysVarPrivileges, + }, + { + Name: TiDBAuditLogReservedBackups, + Scope: variable.ScopeGlobal, + Type: variable.TypeInt, + Value: strconv.Itoa(globalLogManager.GetFileReservedBackups()), + MinValue: 0, + MaxValue: MaxAuditLogFileReservedBackups, + SetGlobal: setGlobalFunc(TiDBAuditLogReservedBackups, func(ctx context.Context, vars *variable.SessionVars, s string) error { + val, err := strconv.Atoi(s) + if err != nil { + return err + } + return globalLogManager.SetFileReservedBackups(val) + }), + RequireDynamicPrivileges: sysVarPrivileges, + }, + { + Name: TiDBAuditLogMaxReservedDays, + Scope: variable.ScopeGlobal, + Type: variable.TypeInt, + Value: strconv.Itoa(globalLogManager.GetFileReservedDays()), + MinValue: 0, + MaxValue: MaxAuditLogFileReservedDays, + SetGlobal: setGlobalFunc(TiDBAuditLogMaxReservedDays, func(ctx context.Context, vars *variable.SessionVars, s string) error { + val, err := strconv.Atoi(s) + if err != nil { + return err + } + return globalLogManager.SetFileReservedDays(val) + }), + RequireDynamicPrivileges: sysVarPrivileges, + }, + { + Name: TiDBAuditRedactLog, + Scope: variable.ScopeGlobal, + Type: variable.TypeBool, + Value: variable.BoolToOnOff(globalLogManager.RedactLog()), + SetGlobal: setGlobalFunc(TiDBAuditRedactLog, func(_ context.Context, _ *variable.SessionVars, s string) error { + return globalLogManager.SetRedactLog(variable.TiDBOptOn(s)) + }), + RequireDynamicPrivileges: 
sysVarPrivileges, + }, + } + + functions := []*extension.FunctionDef{ + { + Name: FuncAuditLogRotate, + EvalTp: types.ETString, + EvalStringFunc: statusFunc(FuncAuditLogRotate, func(ctx extension.FunctionContext, _ []types.Datum) error { + return globalLogManager.GlobalRotateLog() + }), + RequireDynamicPrivileges: requirePrivileges, + }, + { + Name: FuncAuditLogCreateFilter, + EvalTp: types.ETString, + ArgTps: []types.EvalType{types.ETString, types.ETString, types.ETInt}, + OptionalArgsLen: 1, + EvalStringFunc: statusFunc(FuncAuditLogCreateFilter, func(ctx extension.FunctionContext, args []types.Datum) error { + filterName := args[0].GetString() + var filter logFilter + if err := json.Unmarshal(args[1].GetBytes(), &filter); err != nil { + return err + } + + replace := false + if len(args) > 2 { + replace = args[2].GetInt64() == 1 + } + + filter.Name = filterName + return globalLogManager.CreateLogFilter(&filter, replace) + }), + RequireDynamicPrivileges: requirePrivileges, + }, + { + Name: FuncAuditLogRemoveFilter, + EvalTp: types.ETString, + ArgTps: []types.EvalType{types.ETString}, + EvalStringFunc: statusFunc(FuncAuditLogRemoveFilter, func(ctx extension.FunctionContext, args []types.Datum) error { + return globalLogManager.RemoveLogFilter(args[0].GetString()) + }), + RequireDynamicPrivileges: requirePrivileges, + }, + { + Name: FuncAuditLogCreateRule, + EvalTp: types.ETString, + ArgTps: []types.EvalType{types.ETString, types.ETString, types.ETInt}, + OptionalArgsLen: 1, + EvalStringFunc: statusFunc(FuncAuditLogCreateRule, func(ctx extension.FunctionContext, args []types.Datum) error { + replace := false + if len(args) > 2 { + replace = args[2].GetInt64() == 1 + } + return globalLogManager.CreateLogFilterRule(args[0].GetString(), args[1].GetString(), replace) + }), + RequireDynamicPrivileges: requirePrivileges, + }, + { + Name: FuncAuditLogRemoveRule, + EvalTp: types.ETString, + ArgTps: []types.EvalType{types.ETString, types.ETString}, + EvalStringFunc: 
statusFunc(FuncAuditLogRemoveRule, func(ctx extension.FunctionContext, args []types.Datum) error { + return globalLogManager.RemoveLogFilterRule(args[0].GetString(), args[1].GetString()) + }), + RequireDynamicPrivileges: requirePrivileges, + }, + { + Name: FuncAuditLogEnableRule, + EvalTp: types.ETString, + ArgTps: []types.EvalType{types.ETString, types.ETString}, + EvalStringFunc: statusFunc(FuncAuditLogEnableRule, func(ctx extension.FunctionContext, args []types.Datum) error { + return globalLogManager.SwitchLogFilterRuleEnabled(args[0].GetString(), args[1].GetString(), true) + }), + RequireDynamicPrivileges: requirePrivileges, + }, + { + Name: FuncAuditLogDisableRule, + EvalTp: types.ETString, + ArgTps: []types.EvalType{types.ETString, types.ETString}, + EvalStringFunc: statusFunc(FuncAuditLogDisableRule, func(ctx extension.FunctionContext, args []types.Datum) error { + return globalLogManager.SwitchLogFilterRuleEnabled(args[0].GetString(), args[1].GetString(), false) + }), + RequireDynamicPrivileges: requirePrivileges, + }, + } + + return []extension.Option{ + extension.WithCustomDynPrivs([]string{PrivAuditAdmin, PrivRestrictedAuditAdmin}), + extension.WithCustomSysVariables(sysVars), + extension.WithSessionHandlerFactory(globalLogManager.GetSessionHandler), + extension.WithCustomFunctions(functions), + extension.WithBootstrap(globalLogManager.Bootstrap), + extension.WithCustomAccessCheck(checkTableAccess), + extension.WithClose(globalLogManager.Close), + }, nil +} diff --git a/extension/serverless/audit/sysvar_test.go b/extension/serverless/audit/sysvar_test.go new file mode 100644 index 0000000000000..da872c7522528 --- /dev/null +++ b/extension/serverless/audit/sysvar_test.go @@ -0,0 +1,447 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package audit + +import ( + "context" + "fmt" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/pingcap/tidb/config" + "github.com/pingcap/tidb/parser/auth" + "github.com/pingcap/tidb/server" + "github.com/pingcap/tidb/sessionctx/stmtctx" + "github.com/pingcap/tidb/testkit" + "github.com/stretchr/testify/require" +) + +/* + 1. Any statement executed by MockConn.HandleQuery() would be trailed. + 2. Due to the different implementation of `testkit`, statements executed by testkit.MustExec() could only trail SET GLOBAL and FUNCTION CALL. +*/ + +var ( + workDir string +) + +func init() { + workDir, _ = os.Getwd() +} + +func getAllAuditLogs(dir, prefix, suffix string) ([]string, error) { + f, err := os.Open(dir) + if err != nil { + return nil, err + } + files, err := f.Readdir(0) + if err != nil { + return nil, err + } + + var res []string + for _, file := range files { + if name := file.Name(); strings.HasPrefix(name, prefix) && strings.HasSuffix(name, suffix) { + res = append(res, filepath.Join(dir, name)) + } + } + + return res, nil +} + +// return all audit logs +func deleteAllAuditLogs(dir, prefix, suffix string) ([]string, error) { + files, err := getAllAuditLogs(dir, prefix, suffix) + if err != nil { + return nil, err + } + for _, file := range files { + err = os.Remove(file) + if err != nil { + return nil, err + } + } + return files, nil +} + +func containsMessage(logpath, msg string) (ok bool, str string, err error) { + bytes, err := os.ReadFile(logpath) + str = string(bytes) + if err != nil { + return false, str, err 
+ } + return strings.Contains(str, msg), str, nil +} + +func TestAuditEnabled(t *testing.T) { + _, err := deleteAllAuditLogs(workDir, "tidb-audit", ".log") + require.NoError(t, err) + + register4Test() + store := testkit.CreateMockStore(t) + srv := server.CreateMockServer(t, store) + defer srv.Close() + conn := server.CreateMockConn(t, srv) + defer conn.Close() + tk := testkit.NewTestKit(t, store) + err = tk.Session().Auth(&auth.UserIdentity{Username: "root", Hostname: "%"}, nil, nil, nil) + require.NoError(t, err) + + tk.MustQuery("SELECT @@global.tidb_audit_enabled").Check(testkit.Rows("0")) + require.False(t, globalLogManager.Enabled()) + require.NoError(t, conn.HandleQuery(context.Background(), "SET global tidb_audit_enabled = 1")) + tk.MustQuery("SELECT @@global.tidb_audit_enabled").Check(testkit.Rows("1")) + require.True(t, globalLogManager.Enabled()) + require.Equal(t, filepath.Join(workDir, DefAuditLogName), globalLogManager.getLogPath()) + time.Sleep(time.Second) + + files, err := deleteAllAuditLogs(workDir, "tidb-audit", ".log") + require.NoError(t, err) + require.True(t, len(files) > 0, files) +} + +func TestAuditLogDefault(t *testing.T) { + register4Test() + tempDir := t.TempDir() + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + + tk.MustQuery("SELECT @@global.tidb_audit_log").Check(testkit.Rows(DefAuditLogName)) + require.Equal(t, DefAuditLogName, globalLogManager.GetLogConfigPath()) + require.Equal(t, filepath.Join(workDir, DefAuditLogName), globalLogManager.getLogPath()) + + // Set empty log name + tk.MustExec("SET global tidb_audit_log = ''") + tk.MustQuery("SELECT @@global.tidb_audit_log").Check(testkit.Rows("")) + require.Equal(t, "", globalLogManager.GetLogConfigPath()) + require.Equal(t, filepath.Join(workDir, DefAuditLogName), globalLogManager.getLogPath()) + + // Set custom log name + tk.MustExec(fmt.Sprintf("SET global tidb_audit_log = '%s'", "custom-tidb-audit.log")) + tk.MustQuery("SELECT 
@@global.tidb_audit_log").Check(testkit.Rows("custom-tidb-audit.log")) + require.Equal(t, "custom-tidb-audit.log", globalLogManager.GetLogConfigPath()) + require.Equal(t, filepath.Join(workDir, "custom-tidb-audit.log"), globalLogManager.getLogPath()) + rawAddr, rawPort := config.GetGlobalConfig().AdvertiseAddress, config.GetGlobalConfig().Port + config.GetGlobalConfig().AdvertiseAddress = "127.0.0.1" + config.GetGlobalConfig().Port = 4000 + defer func() { + config.GetGlobalConfig().AdvertiseAddress, config.GetGlobalConfig().Port = rawAddr, rawPort + }() + tk.MustExec(fmt.Sprintf("SET global tidb_audit_log = '%s'", "custom-tidb-audit-%e.log")) + tk.MustQuery("SELECT @@global.tidb_audit_log").Check(testkit.Rows("custom-tidb-audit-%e.log")) + require.Equal(t, "custom-tidb-audit-%e.log", globalLogManager.GetLogConfigPath()) + require.Equal(t, filepath.Join(workDir, "custom-tidb-audit-127-0-0-1-4000.log"), globalLogManager.getLogPath()) + + // Write to absolute path + tempDirLog := filepath.Join(tempDir, "temp-audit.log") + tk.MustExec(fmt.Sprintf("SET global tidb_audit_log = '%s'", tempDirLog)) + tk.MustQuery("SELECT @@global.tidb_audit_log").Check(testkit.Rows(tempDirLog)) + require.Equal(t, tempDirLog, globalLogManager.GetLogConfigPath()) + require.Equal(t, tempDirLog, globalLogManager.getLogPath()) +} + +func TestAuditLogDefaultInstanceLog(t *testing.T) { + tempDir := t.TempDir() + // Default to the same directory as tidb instance logs + config.GetGlobalConfig().Log.File.Filename = filepath.Join(tempDir, "tidb.log") + rawAddr, rawPort := config.GetGlobalConfig().AdvertiseAddress, config.GetGlobalConfig().Port + config.GetGlobalConfig().AdvertiseAddress = "127.0.0.1" + config.GetGlobalConfig().Port = 4000 + defer func() { + config.GetGlobalConfig().Log.File.Filename = "" + config.GetGlobalConfig().AdvertiseAddress, config.GetGlobalConfig().Port = rawAddr, rawPort + }() + register4Test() + + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + 
tk.MustExec(fmt.Sprintf("SET global tidb_audit_log = '%s'", DefAuditLogName)) + tk.MustQuery("SELECT @@global.tidb_audit_log").Check(testkit.Rows(DefAuditLogName)) + require.Equal(t, DefAuditLogName, globalLogManager.GetLogConfigPath()) + require.Equal(t, filepath.Join(tempDir, DefAuditLogName), globalLogManager.getLogPath()) + + tk.MustExec("SET global tidb_audit_log = 'tidb-audit-%e.log'") + tk.MustQuery("SELECT @@global.tidb_audit_log").Check(testkit.Rows("tidb-audit-%e.log")) + require.Equal(t, "tidb-audit-%e.log", globalLogManager.GetLogConfigPath()) + require.Equal(t, filepath.Join(tempDir, "tidb-audit-127-0-0-1-4000.log"), globalLogManager.getLogPath()) +} + +func TestAuditLogFormat(t *testing.T) { + register4Test() + store := testkit.CreateMockStore(t) + srv := server.CreateMockServer(t, store) + defer srv.Close() + conn := server.CreateMockConn(t, srv) + defer conn.Close() + tk := testkit.NewTestKit(t, store) + err := tk.Session().Auth(&auth.UserIdentity{Username: "root", Hostname: "%"}, nil, nil, nil) + require.NoError(t, err) + + _, err = deleteAllAuditLogs(workDir, "tidb-audit", ".log") + require.NoError(t, err) + _, err = deleteAllAuditLogs(workDir, "tidb-audit", ".json") + require.NoError(t, err) + + tk.MustQuery("SELECT @@global.tidb_audit_log_format").Check(testkit.Rows("TEXT")) + require.NoError(t, conn.HandleQuery(context.Background(), "SET global tidb_audit_enabled = 1")) + require.Equal(t, DefAuditLogName, globalLogManager.GetLogConfigPath()) + require.Equal(t, filepath.Join(workDir, DefAuditLogName), globalLogManager.getLogPath()) + require.Equal(t, "TEXT", globalLogManager.GetLogFormat()) + + require.NoError(t, conn.HandleQuery(context.Background(), "SET global tidb_audit_log_format = 'JSON'")) + tk.MustQuery("SELECT @@global.tidb_audit_log_format").Check(testkit.Rows("JSON")) + require.Equal(t, DefAuditLogName, globalLogManager.GetLogConfigPath()) + require.Equal(t, filepath.Join(workDir, DefAuditLogName+".json"), 
globalLogManager.getLogPath()) + require.Equal(t, "JSON", globalLogManager.GetLogFormat()) + + files, err := deleteAllAuditLogs(workDir, "tidb-audit", ".log") + require.NoError(t, err) + require.Equal(t, 1, len(files), files) + files, err = deleteAllAuditLogs(workDir, "tidb-audit", ".json") + require.NoError(t, err) + require.Equal(t, 1, len(files), files) +} + +func TestAuditLogMaxSize(t *testing.T) { + register4Test() + store := testkit.CreateMockStore(t) + srv := server.CreateMockServer(t, store) + defer srv.Close() + conn := server.CreateMockConn(t, srv) + defer conn.Close() + tk := testkit.NewTestKit(t, store) + + // test setting + tk.MustQuery("SELECT @@global.tidb_audit_log_max_filesize").Check(testkit.Rows("100")) + require.Equal(t, DefAuditLogFileMaxSize, globalLogManager.GetFileMaxSize()) + tk.MustExec("SET global tidb_audit_log_max_filesize = 102401") + tk.MustQuery("SELECT @@global.tidb_audit_log_max_filesize").Check(testkit.Rows("102400")) + require.Equal(t, int64(102400), globalLogManager.GetFileMaxSize()) + tk.MustExec("SET global tidb_audit_log_max_filesize = 0") + tk.MustQuery("SELECT @@global.tidb_audit_log_max_filesize").Check(testkit.Rows("100")) + require.Equal(t, DefAuditLogFileMaxSize, globalLogManager.GetFileMaxSize()) + + // test rotation + tk.MustExec("SET global tidb_audit_log_max_filesize = 1") + _, err := deleteAllAuditLogs(workDir, "tidb-audit", ".log") + require.NoError(t, err) + for i := 0; i <= 10_240; i++ { + require.NoError(t, conn.HandleQuery(context.Background(), "set global tidb_audit_enabled = 1")) + } + files, err := deleteAllAuditLogs(workDir, "tidb-audit", ".log") + require.NoError(t, err) + require.Greater(t, len(files), 1, files) +} + +func TestAuditLogMaxLifetime(t *testing.T) { + register4Test() + tempDir := t.TempDir() + + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + + _, err := deleteAllAuditLogs(tempDir, "tidb-audit", ".log") + require.NoError(t, err) + + tk.MustQuery("SELECT 
@@global.tidb_audit_enabled").Check(testkit.Rows("0")) + tk.MustQuery("SELECT @@global.tidb_audit_log_max_lifetime").Check(testkit.Rows("86400")) + + tk.MustExec(fmt.Sprintf("SET global tidb_audit_log = '%s'", filepath.Join(tempDir, DefAuditLogName))) + tk.MustQuery("SELECT @@global.tidb_audit_log").Check(testkit.Rows(filepath.Join(tempDir, DefAuditLogName))) + + tk.MustExec("SET global tidb_audit_enabled = 1") + tk.MustQuery("SELECT @@global.tidb_audit_enabled").Check(testkit.Rows("1")) + tk.MustExec("SET global tidb_audit_log_max_lifetime = 3") + tk.MustQuery("SELECT @@global.tidb_audit_log_max_lifetime").Check(testkit.Rows("3")) + require.Equal(t, int64(3), globalLogManager.GetFileMaxLifetime()) + // Generate 3 log files + time.Sleep(10 * time.Second) + files, err := getAllAuditLogs(tempDir, "tidb-audit", ".log") + require.NoError(t, err) + require.Equal(t, 3, len(files), files) + + tk.MustExec("SET global tidb_audit_log_max_lifetime = 0") + tk.MustQuery("SELECT @@global.tidb_audit_log_max_lifetime").Check(testkit.Rows("0")) + require.Equal(t, int64(0), globalLogManager.GetFileMaxLifetime()) + // Not generate log file anymore + time.Sleep(7 * time.Second) + + files, err = deleteAllAuditLogs(tempDir, "tidb-audit", ".log") + require.NoError(t, err) + require.Equal(t, 3, len(files), files) +} + +func TestAuditLogReservedBackups(t *testing.T) { + register4Test() + tempDir := t.TempDir() + + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + + tk.MustQuery("SELECT @@global.tidb_audit_log_reserved_backups").Check(testkit.Rows("10")) + + tk.MustExec("SET global tidb_audit_enabled = 1") + tk.MustQuery("SELECT @@global.tidb_audit_enabled").Check(testkit.Rows("1")) + tk.MustExec(fmt.Sprintf("SET global tidb_audit_log = '%s'", filepath.Join(tempDir, DefAuditLogName))) + tk.MustQuery("SELECT @@global.tidb_audit_log").Check(testkit.Rows(filepath.Join(tempDir, DefAuditLogName))) + + _, err := deleteAllAuditLogs(tempDir, "tidb-audit", ".log") + 
require.NoError(t, err) + + // Generate 5 log files, only two files are reserved + tk.MustExec("SET global tidb_audit_log_reserved_backups = 2") + tk.MustQuery("SELECT @@global.tidb_audit_log_reserved_backups").Check(testkit.Rows("2")) + require.Equal(t, 2, globalLogManager.GetFileReservedBackups()) + tk.MustExec("SET global tidb_audit_log_max_lifetime = 3") + tk.MustQuery("SELECT @@global.tidb_audit_log_max_lifetime").Check(testkit.Rows("3")) + time.Sleep(16 * time.Second) + tk.MustExec("SET global tidb_audit_log_max_lifetime = 0") + tk.MustQuery("SELECT @@global.tidb_audit_log_max_lifetime").Check(testkit.Rows("0")) + files, err := deleteAllAuditLogs(tempDir, "tidb-audit-2", ".log") + require.NoError(t, err) + require.Equal(t, 2, len(files), files) + + // Generate 5 log files, all are reserved + tk.MustExec("SET global tidb_audit_log_reserved_backups = 0") + tk.MustQuery("SELECT @@global.tidb_audit_log_reserved_backups").Check(testkit.Rows("0")) + require.Equal(t, 0, globalLogManager.GetFileReservedBackups()) + tk.MustExec("SET global tidb_audit_log_max_lifetime = 3") + tk.MustQuery("SELECT @@global.tidb_audit_log_max_lifetime").Check(testkit.Rows("3")) + time.Sleep(16 * time.Second) + tk.MustExec("SET global tidb_audit_log_max_lifetime = 0") + tk.MustQuery("SELECT @@global.tidb_audit_log_max_lifetime").Check(testkit.Rows("0")) + files, err = deleteAllAuditLogs(tempDir, "tidb-audit-", ".log") + require.NoError(t, err) + require.Equal(t, 5, len(files), files) +} + +// func TestAuditLogReservedDays(t *testing.T) { +// register4Test() +// store := testkit.CreateMockStore(t) +// srv := server.CreateMockServer(t, store) +// defer srv.Close() +// conn := server.CreateMockConn(t, srv) +// defer conn.Close() +// tk := testkit.NewTestKit(t, store) +// _, err := deleteAllAuditLogs(workDir, "tidb-audit", ".log") +// require.NoError(t, err) +// +// require.NoError(t, conn.HandleQuery(context.Background(), "SET global tidb_audit_enabled = 1")) +// tk.MustQuery("SELECT 
@@global.tidb_audit_log_reserved_days").Check(testkit.Rows("0")) +// require.NoError(t, conn.HandleQuery(context.Background(), "SELECT audit_log_rotate()")) +// files, err := getAllAuditLogs(workDir, "tidb-audit-", ".log") +// require.NoError(t, err) +// require.Equal(t, 1, len(files), files) +// deprecatedLog := files[0] +// oldLog := filepath.Join(filepath.Dir(deprecatedLog), "tidb-audit-2020-01-01T14-17-15.536.log") +// logutil.BgLogger().Info("TestAuditLogReservedDays", zap.String("deprecatedLog", deprecatedLog), zap.String("oldLog", oldLog)) +// +// require.NoError(t, os.Rename(deprecatedLog, oldLog)) +// require.NoError(t, conn.HandleQuery(context.Background(), "SELECT audit_log_rotate()")) +// time.Sleep(3 * time.Second) +// // Currently 3 logs: tidb-audit.log, oldLog, and another new one +// files, err = getAllAuditLogs(workDir, "tidb-audit", ".log") +// require.NoError(t, err) +// require.Equal(t, 3, len(files), files) +// require.NotContains(t, files, deprecatedLog, "%v", files) +// require.Contains(t, files, oldLog, "%v", files) +// require.NoError(t, conn.HandleQuery(context.Background(), "SET global tidb_audit_log_reserved_days = 1")) +// tk.MustQuery("SELECT @@global.tidb_audit_log_reserved_days").Check(testkit.Rows("1")) +// require.Equal(t, 1, globalLogManager.GetFileReservedDays()) +// files, err = deleteAllAuditLogs(workDir, "tidb-audit", ".log") +// // Currently 2 logs: tidb-audit.log, and another new one +// require.NoError(t, err) +// require.Equal(t, 2, len(files), files) +// require.NotContains(t, files, oldLog, "%v", files) +// } + +func TestAuditLogRedact(t *testing.T) { + register4Test() + tempDir := t.TempDir() + + store := testkit.CreateMockStore(t) + srv := server.CreateMockServer(t, store) + defer srv.Close() + conn := server.CreateMockConn(t, srv) + defer conn.Close() + tk := testkit.NewTestKit(t, store) + err := tk.Session().Auth(&auth.UserIdentity{Username: "root", Hostname: "%"}, nil, nil, nil) + require.NoError(t, err) + 
tk.MustExec("SET global tidb_audit_enabled = 1") + tk.MustExec(fmt.Sprintf("SET global tidb_audit_log = '%s'", filepath.Join(tempDir, DefAuditLogName))) + + _, err = deleteAllAuditLogs(tempDir, "tidb-audit", ".log") + require.NoError(t, err) + + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t (c int)") + tk.MustQuery("select audit_log_create_filter('all', '{}')").Check(testkit.Rows("OK")) + tk.MustQuery("select content from mysql.audit_log_filters where filter_name = 'all'").Check(testkit.Rows(`{}`)) + tk.MustQuery("select audit_log_create_rule('%@%', 'all')").Check(testkit.Rows("OK")) + tk.MustQuery("select enabled from mysql.audit_log_filter_rules where user = '%@%' and filter_name = 'all'").Check(testkit.Rows("1")) + + require.NoError(t, conn.HandleQuery(context.Background(), "use test")) + testcases := []struct { + sql string + redactedSQL string + }{ + {"insert into t values (1), (2), (3)", ""}, + } + + sctx := &stmtctx.StatementContext{} + for i := range testcases { + sctx.OriginalSQL = testcases[i].sql + testcases[i].redactedSQL, _ = sctx.SQLDigest() + } + + for _, testcase := range testcases { + // default on + tk.MustExec("SET global tidb_audit_log_redacted = 1") + tk.MustQuery("select @@global.tidb_audit_log_redacted").Check(testkit.Rows("1")) + require.NoError(t, conn.HandleQuery(context.Background(), testcase.sql)) + ok, log, err := containsMessage(filepath.Join(tempDir, DefAuditLogName), testcase.redactedSQL) + require.NoError(t, err) + require.True(t, ok, log) + + // turn off + tk.MustExec("SET global tidb_audit_log_redacted = 0") + tk.MustQuery("select @@global.tidb_audit_log_redacted").Check(testkit.Rows("0")) + require.NoError(t, conn.HandleQuery(context.Background(), testcase.sql)) + ok, log, err = containsMessage(filepath.Join(tempDir, DefAuditLogName), testcase.sql) + require.NoError(t, err) + require.True(t, ok, log) + } + + // Statements containing password would always be redacted + 
tk.MustExec("SET global tidb_audit_log_redacted = 1") + require.True(t, globalLogManager.RedactLog()) + require.NoError(t, conn.HandleQuery(context.Background(), "create user if not exists testuser identified by '1234'")) + ok, log, err := containsMessage(filepath.Join(tempDir, DefAuditLogName), "create user if not exists `testuser` identified by ?") + require.NoError(t, err) + require.True(t, ok, log) + tk.MustExec("SET global tidb_audit_log_redacted = 0") + require.False(t, globalLogManager.RedactLog()) + require.NoError(t, conn.HandleQuery(context.Background(), "create user if not exists testuser identified by '1234'")) + ok, log, err = containsMessage(filepath.Join(tempDir, DefAuditLogName), "create user if not exists `testuser` identified by ?") + require.NoError(t, err) + require.True(t, ok, log) + + _, err = deleteAllAuditLogs(tempDir, "tidb-audit", ".log") + require.NoError(t, err) +} diff --git a/extension/serverless/audit/util.go b/extension/serverless/audit/util.go new file mode 100644 index 0000000000000..3b89763a10e98 --- /dev/null +++ b/extension/serverless/audit/util.go @@ -0,0 +1,185 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package audit + +import ( + "context" + "fmt" + "math" + "strings" + "time" + + "github.com/google/uuid" + "github.com/pingcap/errors" + "github.com/pingcap/tidb/kv" + "github.com/pingcap/tidb/parser/terror" + "github.com/pingcap/tidb/types" + "github.com/pingcap/tidb/util/chunk" + "github.com/pingcap/tidb/util/logutil" + "github.com/pingcap/tidb/util/sqlexec" + tablefilter "github.com/pingcap/tidb/util/table-filter" + clientv3 "go.etcd.io/etcd/client/v3" + "go.uber.org/zap" +) + +// UserHostSep is a separator to separate user and host +const UserHostSep = "@" + +type keyWatcher struct { + ctx context.Context + etcd *clientv3.Client + prefix string + Chan clientv3.WatchChan + chCreateTime time.Time +} + +func newKeyWatcher(ctx context.Context, etcd *clientv3.Client, prefix string) *keyWatcher { + return &keyWatcher{ + ctx: ctx, + etcd: etcd, + prefix: prefix, + Chan: etcd.Watch(ctx, prefix, clientv3.WithPrefix()), + chCreateTime: time.Now(), + } +} + +func (w *keyWatcher) RenewClosedChan() { + var sleep time.Duration + if interval := time.Since(w.chCreateTime); interval < MinEtcdReWatchInterval { + sleep = interval + } + + logutil.BgLogger().Info("audit watch chan closed, renew it", zap.String("watch", w.prefix), zap.Duration("sleep", sleep)) + i := 0 + for time.Since(w.chCreateTime) < MinEtcdReWatchInterval { + if err := w.ctx.Err(); err != nil { + return + } + + i++ + if i <= 5 { + time.Sleep(10 * time.Millisecond) + } else { + time.Sleep(100 * time.Millisecond) + } + } + + w.Chan = w.etcd.Watch(w.ctx, w.prefix, clientv3.WithPrefix()) + w.chCreateTime = time.Now() +} + +type entryIDGenerator struct { + base string + seq uint +} + +func newEntryIDGenerator() *entryIDGenerator { + return &entryIDGenerator{ + base: uuid.NewString(), + seq: 1, + } +} + +func (g *entryIDGenerator) Next() (id string) { + if g.seq < math.MaxUint16 { + g.seq++ + } else { + g.base = uuid.NewString() + g.seq = 1 + } + return fmt.Sprintf("%s-%04x", g.base, g.seq) +} + +type errLogWriter 
struct{} + +func (*errLogWriter) Write(err []byte) (int, error) { + logutil.BgLogger().Error("audit log error", zap.ByteString("err", err)) + return len(err), nil +} + +func normalizeIP(ip string) string { + switch ip { + case "localhost": + return "127.0.0.1" + default: + return ip + } +} + +func parseTableFilter(table string) (tablefilter.Filter, error) { + // forbid wildcards '!' and '@' for safety, + // please see https://github.com/pingcap/tidb-tools/tree/master/pkg/table-filter for more details. + arg := strings.TrimLeft(table, " \t") + if arg == "" || arg[0] == '!' || arg[0] == '@' || arg[0] == '/' { + return nil, errors.Errorf("invalid table format: '%s'", table) + } + + f, err := tablefilter.Parse([]string{arg}) + if err != nil { + return nil, errors.Errorf("invalid table format: '%s'", table) + } + + return f, nil +} + +func normalizeUserAndHost(userAndHost string) (normalized string, user string, host string, err error) { + parts := strings.Split(userAndHost, UserHostSep) + switch len(parts) { + case 1: + user = strings.TrimSpace(userAndHost) + host = "%" + normalized = user + case 2: + user = strings.TrimSpace(parts[0]) + host = strings.TrimSpace(parts[1]) + normalized = user + UserHostSep + host + default: + return "", "", "", errors.Errorf("illegal user format: '%s'", userAndHost) + } + return +} + +func toStringArgs(args []types.Datum) []string { + if len(args) == 0 { + return nil + } + + stringArgs := make([]string, len(args)) + for i, arg := range args { + if arg.IsNull() { + stringArgs[i] = "NULL" + } else { + stringArgs[i] = fmt.Sprintf("%v", arg.GetValue()) + } + } + return stringArgs +} + +func executeSQL(ctx context.Context, exec sqlexec.SQLExecutor, sql string, args ...interface{}) ([]chunk.Row, error) { + ctx = kv.WithInternalSourceType(ctx, kv.InternalTxnOthers) + rs, err := exec.ExecuteInternal(ctx, sql, args...) 
+ if err != nil { + return nil, err + } + + if rs != nil { + defer func() { + terror.Log(rs.Close()) + }() + return sqlexec.DrainRecordSet(ctx, rs, 8) + } + + return nil, nil +} diff --git a/extension/serverless/generate/main.go b/extension/serverless/generate/main.go new file mode 100644 index 0000000000000..1f5284184ed53 --- /dev/null +++ b/extension/serverless/generate/main.go @@ -0,0 +1,71 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build ignore +// +build ignore + +//go:generate go run main.go genfile +//go:generate go run main.go clear + +package main + +import ( + "os" + + "github.com/pingcap/errors" + "github.com/pingcap/tidb/parser/terror" +) + +const ( + importGoFile = "../../_import/generated-enterprise.go" + importGoFileContent = `// This file is auto generated by enterprise extension + +//go:build enterprise +// +build enterprise + +package extensionimport + +import ( + "github.com/pingcap/tidb/extension/enterprise/audit" +) + +func init() { + audit.Register() +} +` +) + +func main() { + doClear := false + if len(os.Args) > 1 { + action := os.Args[1] + switch action { + case "genfile": + break + case "clear": + doClear = true + default: + terror.MustNil(errors.New("Invalid action: " + action)) + } + } + + if doClear { + err := os.Remove(importGoFile) + if !os.IsNotExist(err) { + terror.MustNil(err) + } + } else { + terror.MustNil(os.WriteFile(importGoFile, []byte(importGoFileContent), 0644)) + } 
+} diff --git a/extension/session.go b/extension/session.go index e35f31ec68920..30dc223bc229d 100644 --- a/extension/session.go +++ b/extension/session.go @@ -88,6 +88,8 @@ type StmtEventInfo interface { RelatedTables() []stmtctx.TableEntry // GetError will return the error when the current statement is failed GetError() error + // GetSessionVars will return variables of current session + GetSessionVars() *variable.SessionVars } // SessionHandler is used to listen session events diff --git a/server/extension.go b/server/extension.go index 3851f58c09826..74ccbea12c74b 100644 --- a/server/extension.go +++ b/server/extension.go @@ -208,6 +208,10 @@ func (e *stmtEventInfo) GetError() error { return e.err } +func (e *stmtEventInfo) GetSessionVars() *variable.SessionVars { + return e.sessVars +} + func (e *stmtEventInfo) ensureExecutePreparedCache() *core.PlanCacheStmt { if e.executeStmt == nil { return nil diff --git a/server/server.go b/server/server.go index 7dee9aecd5ace..64f00b4b7f35d 100644 --- a/server/server.go +++ b/server/server.go @@ -37,6 +37,7 @@ import ( "math/rand" "net" "net/http" //nolint:goimports + // For pprof _ "net/http/pprof" // #nosec G108 "os" @@ -782,6 +783,7 @@ func (cc *clientConn) connectInfo() *variable.ConnectionInfo { } connInfo := &variable.ConnectionInfo{ ConnectionID: cc.connectionID, + GwConnID: cc.gwConnID, ConnectionType: connType, Host: cc.peerHost, ClientIP: cc.peerHost, diff --git a/sessionctx/variable/session.go b/sessionctx/variable/session.go index cf753e66fa64e..914184234e600 100644 --- a/sessionctx/variable/session.go +++ b/sessionctx/variable/session.go @@ -1815,6 +1815,7 @@ func (p *PlanCacheParamList) AllParamValues() []types.Datum { // ConnectionInfo presents the connection information, which is mainly used by audit logs. 
type ConnectionInfo struct { ConnectionID uint64 + GwConnID string ConnectionType string Host string ClientIP string diff --git a/standby/standby.go b/standby/standby.go index d3949588286f8..9c499176eeec3 100644 --- a/standby/standby.go +++ b/standby/standby.go @@ -45,8 +45,22 @@ const ( // ActivateRequest is the request body for activating the tidb server. type ActivateRequest struct { - KeyspaceName string `json:"keyspace_name"` - ExportID string `json:"export_id"` + KeyspaceName string `json:"keyspace_name"` + AuditLog *AuditLogConfig `json:"audit_log,omitempty"` + ExportID string `json:"export_id"` +} + +func (r *ActivateRequest) auditLogEnabled() bool { + if r.AuditLog == nil { + return false + } + return r.AuditLog.Enable +} + +// AuditLogConfig is the configuration about audit log when activating the tidb server. +type AuditLogConfig struct { + Enable bool `json:"enable"` + EncryptKey string `json:"encrypt_key"` } type sessionManager interface { @@ -160,6 +174,14 @@ func Handler(sm sessionManager) *http.ServeMux { w.WriteHeader(http.StatusBadRequest) return } + if req.AuditLog != nil && req.AuditLog.Enable { + if len(req.AuditLog.EncryptKey) != 32 { + logutil.BgLogger().Error("bad audit log encrypt key", zap.String("encrypt_key", req.AuditLog.EncryptKey)) + w.WriteHeader(http.StatusBadRequest) + return + } + logutil.BgLogger().Info("activate with audit log enabled", zap.String("keyspaceName", req.KeyspaceName)) + } mu.Lock() if state == standbyState { @@ -171,6 +193,11 @@ func Handler(sm sessionManager) *http.ServeMux { w.WriteHeader(http.StatusPreconditionFailed) w.Write([]byte("server is not in standby mode")) return + } else if activateRequest.auditLogEnabled() != req.auditLogEnabled() { + mu.Unlock() + w.WriteHeader(http.StatusPreconditionFailed) + w.Write([]byte("server audit log status has changed")) + return } // if client tries to activate with same keyspace name, wait for ready signal and return 200. 
mu.Unlock() diff --git a/tidb-server/BUILD.bazel b/tidb-server/BUILD.bazel index 097ca2aacb082..fadf472a4f266 100644 --- a/tidb-server/BUILD.bazel +++ b/tidb-server/BUILD.bazel @@ -14,6 +14,7 @@ go_library( "//executor", "//extension", "//extension/_import", + "//extension/serverless/audit", "//keyspace", "//kv", "//metrics", diff --git a/tidb-server/main.go b/tidb-server/main.go index 1f0157027c922..47165bbe10b76 100644 --- a/tidb-server/main.go +++ b/tidb-server/main.go @@ -42,6 +42,7 @@ import ( "github.com/pingcap/tidb/executor" "github.com/pingcap/tidb/extension" _ "github.com/pingcap/tidb/extension/_import" + "github.com/pingcap/tidb/extension/serverless/audit" "github.com/pingcap/tidb/keyspace" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/metrics" @@ -239,6 +240,10 @@ func main() { config.GetGlobalConfig().ActivationTimeout) config.UpdateGlobal(func(c *config.Config) { c.KeyspaceName = activateRequest.KeyspaceName + if activateRequest.AuditLog != nil { + c.AuditLog.Enable = activateRequest.AuditLog.Enable + c.AuditLog.EncryptKey = activateRequest.AuditLog.EncryptKey + } c.ExportID = activateRequest.ExportID }) // replace mainErrHandler to make sure standby handler can exit gracefully. @@ -300,6 +305,8 @@ func main() { } err = setupLog(keyspaceID) mainErrHandler(err) + err = setupAuditLog() + mainErrHandler(err) _, err = setupExtensions() mainErrHandler(err) err = setupStmtSummary() @@ -1103,6 +1110,11 @@ func setupLog(keyspaceID uint32) error { return nil } +func setupAuditLog() error { + auditLogCfg := config.GetGlobalConfig().AuditLog + return audit.RegisterForServerless(&auditLogCfg) +} + func setupExtensions() (*extension.Extensions, error) { err := extension.Setup() if err != nil {