Skip to content

Commit

Permalink
Merge pull request #4 from imtbkcat/supportLoginFail
Browse files Browse the repository at this point in the history
privilege: support login fail count and account blacklist
  • Loading branch information
Lingyu Song authored May 9, 2020
2 parents 63dc364 + 559b974 commit 0db9f6e
Show file tree
Hide file tree
Showing 7 changed files with 226 additions and 0 deletions.
6 changes: 6 additions & 0 deletions privilege/privilege.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,12 @@ type Manager interface {

// GetAllRoles return all roles of user.
GetAllRoles(user, host string) []*auth.RoleIdentity

// CheckAccountLock return if the account has been locked.
CheckAccountLocked(ctx sessionctx.Context, user, host string) bool

// IncFailTimer is used to increase lock timer.
IncFailTimer(ctx sessionctx.Context, user, host string)
}

const key keyType = 0
Expand Down
100 changes: 100 additions & 0 deletions privilege/privileges/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,12 @@ type roleGraphEdgesTable struct {
roleList map[string]*auth.RoleIdentity
}

type blackListItem struct {
baseRecord

startTime time.Time
}

// Find method is used to find role from table
func (g roleGraphEdgesTable) Find(user, host string) bool {
if host == "" {
Expand Down Expand Up @@ -229,6 +235,8 @@ type MySQLPrivilege struct {
ColumnsPriv []columnsPrivRecord
DefaultRoles []defaultRoleRecord
RoleGraph map[string]roleGraphEdgesTable
PwdErrorCnt map[string]int
BlackList map[string][]blackListItem
}

// FindAllRole is used to find all roles grant to this user.
Expand Down Expand Up @@ -256,6 +264,66 @@ func (p *MySQLPrivilege) FindAllRole(activeRoles []*auth.RoleIdentity) []*auth.R
return ret
}

func (p *MySQLPrivilege) IncLoginFail(user, host string) int {
key := user + "@" + host
cnt, exist := p.PwdErrorCnt[key]
if exist {
p.PwdErrorCnt[key] = cnt + 1
return cnt + 1
}
p.PwdErrorCnt[key] = 1
return 1
}

func (p *MySQLPrivilege) ClearLoginFail(user, host string) {
key := user + "@" + host
if p.PwdErrorCnt == nil {
p.PwdErrorCnt = make(map[string]int)
}
p.PwdErrorCnt[key] = 0
}

func (p *MySQLPrivilege) LockAccount(user, host string, sctx sessionctx.Context) error {
lock := types.CurrentTime(0)
ctx := context.Background()
sql := fmt.Sprintf("insert into mysql.login_blacklist(USER, HOST, lock_time) values('%s', '%s', '%s') on duplicate key update lock_time = '%s'", user, host, lock.String(), lock.String())
_, err := sctx.(sqlexec.SQLExecutor).Execute(ctx, sql)
return err
}

func (p *MySQLPrivilege) CheckAccountLock(user, host string, sctx sessionctx.Context, limit time.Duration) bool {
recs, exist := p.BlackList[user]
if exist {
for _, r := range recs {
if r.Host == host {
t := r.startTime
if time.Now().Sub(t) > limit*time.Second {
ctx := context.Background()
sql := fmt.Sprintf("delete from mysql.login_blacklist where user = '%s' and host = '%s'", user, host)
_, err := sctx.(sqlexec.SQLExecutor).Execute(ctx, sql)
if err != nil {
// ignore
}
return false
} else {
return true
}
}
}
}
return false
}

func (p *MySQLPrivilege) LoadBlackList(ctx sessionctx.Context) error {
p.BlackList = make(map[string][]blackListItem)
err := p.loadTable(ctx, "select HOST, USER, lock_time from mysql.login_blacklist;", p.decodeBlackListRow)
if err != nil {
logutil.BgLogger().Warn("load mysql.user fail", zap.Error(err))
return errLoadPrivilege.FastGen("mysql.login_blacklist")
}
return nil
}

// FindRole is used to detect whether there is edges between users and roles.
func (p *MySQLPrivilege) FindRole(user string, host string, role *auth.RoleIdentity) bool {
rec := p.matchUser(user, host)
Expand All @@ -269,6 +337,9 @@ func (p *MySQLPrivilege) FindRole(user string, host string, role *auth.RoleIdent

// LoadAll loads the tables from database to memory.
func (p *MySQLPrivilege) LoadAll(ctx sessionctx.Context) error {
if p.PwdErrorCnt == nil {
p.PwdErrorCnt = make(map[string]int)
}
err := p.LoadUserTable(ctx)
if err != nil {
logutil.BgLogger().Warn("load mysql.user fail", zap.Error(err))
Expand Down Expand Up @@ -324,6 +395,11 @@ func (p *MySQLPrivilege) LoadAll(ctx sessionctx.Context) error {
}
logutil.BgLogger().Warn("mysql.role_edges missing")
}

err = p.LoadBlackList(ctx)
if err != nil {
return errLoadPrivilege.FastGen("mysql.login_blacklist")
}
return nil
}

Expand Down Expand Up @@ -584,6 +660,30 @@ func (record *baseRecord) assignUserOrHost(row chunk.Row, i int, f *ast.ResultFi
}
}

func (p *MySQLPrivilege) decodeBlackListRow(row chunk.Row, fs []*ast.ResultField) error {
var value blackListItem
for i, f := range fs {
switch {
case f.ColumnAsName.L == "host":
value.Host = row.GetString(i)
case f.ColumnAsName.L == "user":
value.User = row.GetString(i)
case f.ColumnAsName.L == "lock_time":
ti := row.GetTime(i)
var err error
value.startTime, err = ti.GoTime(time.Local)
if err != nil {
return err
}
}
}
if p.BlackList == nil {
p.BlackList = make(map[string][]blackListItem)
}
p.BlackList[value.User] = append(p.BlackList[value.User], value)
return nil
}

func (p *MySQLPrivilege) decodeUserTableRow(row chunk.Row, fs []*ast.ResultField) error {
var value UserRecord
for i, f := range fs {
Expand Down
79 changes: 79 additions & 0 deletions privilege/privileges/privileges.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,17 @@ package privileges
import (
"crypto/tls"
"fmt"
"strconv"
"strings"
"time"

"github.com/pingcap/parser/auth"
"github.com/pingcap/parser/mysql"
"github.com/pingcap/tidb/infoschema"
"github.com/pingcap/tidb/infoschema/perfschema"
"github.com/pingcap/tidb/privilege"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util"
"github.com/pingcap/tidb/util/logutil"
Expand Down Expand Up @@ -145,6 +148,81 @@ func (p *UserPrivileges) GetAuthWithoutVerification(user, host string) (u string
return
}

// CheckAccountLock return if the account has been locked.
func (p *UserPrivileges) CheckAccountLocked(ctx sessionctx.Context, user, host string) bool {
if SkipWithGrant {
return false
}

mysqlPriv := p.Handle.Get()
record := mysqlPriv.connectionVerification(user, host)
if record == nil {
logutil.BgLogger().Error("get user privilege record fail",
zap.String("user", user), zap.String("host", host))
return true
}

u := record.User
h := record.Host

sessionVars := ctx.GetSessionVars()
lockTimeVar, err := variable.GetSessionSystemVar(sessionVars, variable.LoginBlockInterval)
if err != nil {
logutil.BgLogger().Error("Get locktime fail.", zap.Error(err))
return true
}
lockTime, err := strconv.ParseInt(lockTimeVar, 10, 0)
if err != nil {
logutil.BgLogger().Error("Parse password limit fail.", zap.Error(err))
return true
}
if mysqlPriv.CheckAccountLock(u, h, ctx, time.Duration(lockTime)) {
return true
}

return false
}

func (p *UserPrivileges) IncFailTimer(ctx sessionctx.Context, user, host string) {
if SkipWithGrant {
return
}

mysqlPriv := p.Handle.Get()
record := mysqlPriv.connectionVerification(user, host)
if record == nil {
logutil.BgLogger().Error("get user privilege record fail",
zap.String("user", user), zap.String("host", host))
return
}

u := record.User
h := record.Host

sessionVars := ctx.GetSessionVars()
cnt := mysqlPriv.IncLoginFail(u, h)
limitVar, err := variable.GetSessionSystemVar(sessionVars, variable.MaxLoginAttempts)
if err != nil {
logutil.BgLogger().Error("Get password limit fail.", zap.Error(err))
return
}
limit, err := strconv.ParseInt(limitVar, 10, 0)
if err != nil {
logutil.BgLogger().Error("Parse password limit fail.", zap.Error(err))
return
}
if cnt > int(limit) {
err := mysqlPriv.LockAccount(u, h, ctx)
if err != nil {
logutil.BgLogger().Error("error occuer while locking account", zap.Error(err))
}
err = p.Update(ctx)
if err != nil {
logutil.BgLogger().Error("error occuer while updating account lock info", zap.Error(err))
}
}
}

// ConnectionVerification implements the Manager interface.
func (p *UserPrivileges) ConnectionVerification(user, host string, authentication, salt []byte, tlsState *tls.ConnectionState) (u string, h string, success bool) {
if SkipWithGrant {
Expand Down Expand Up @@ -214,6 +292,7 @@ func (p *UserPrivileges) ConnectionVerification(user, host string, authenticatio

p.user = user
p.host = h
mysqlPriv.ClearLoginFail(u, h)
success = true
return
}
Expand Down
24 changes: 24 additions & 0 deletions session/bootstrap.go
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,13 @@ const (
CreateOptRuleBlacklist = `CREATE TABLE IF NOT EXISTS mysql.opt_rule_blacklist (
name char(100) NOT NULL
);`

CreateLoginBlackList = `CREATE TABLE IF NOT EXISTS mysql.login_blacklist (
HOST char(60) COLLATE utf8_bin NOT NULL DEFAULT '',
USER char(32) COLLATE utf8_bin NOT NULL DEFAULT '',
lock_time timestamp default null,
primary key(USER, HOST)
)`
)

// bootstrap initiates system DB for a store.
Expand Down Expand Up @@ -384,6 +391,8 @@ const (
version44 = 44
// version45 introduces CONFIG_PRIV for SET CONFIG statements.
version45 = 45
// version46 introduces mysql.login_blacklist to detect login fail.
version46 = 46
)

var (
Expand Down Expand Up @@ -432,6 +441,7 @@ var (
upgradeToVer43,
upgradeToVer44,
upgradeToVer45,
upgradeToVer46,
}
)

Expand Down Expand Up @@ -1045,6 +1055,19 @@ func upgradeToVer45(s Session, ver int64) {
mustExecute(s, "UPDATE HIGH_PRIORITY mysql.user SET Config_priv='Y' where Super_priv='Y'")
}

func upgradeToVer46(s Session, ver int64) {
if ver >= version46 {
return
}
mustExecute(s, CreateLoginBlackList)
sql := fmt.Sprintf("INSERT IGNORE INTO %s.%s (`VARIABLE_NAME`, `VARIABLE_VALUE`) VALUES ('%s', '%d')",
mysql.SystemDB, mysql.GlobalVariablesTable, variable.MaxLoginAttempts, 10)
mustExecute(s, sql)
sql = fmt.Sprintf("INSERT IGNORE INTO %s.%s (`VARIABLE_NAME`, `VARIABLE_VALUE`) VALUES ('%s', '%d')",
mysql.SystemDB, mysql.GlobalVariablesTable, variable.LoginBlockInterval, 3600)
mustExecute(s, sql)
}

// updateBootstrapVer updates bootstrap version variable in mysql.TiDB table.
func updateBootstrapVer(s Session) {
// Update bootstrap version.
Expand Down Expand Up @@ -1108,6 +1131,7 @@ func doDDLWorks(s Session) {
mustExecute(s, CreateExprPushdownBlacklist)
// Create opt_rule_blacklist table.
mustExecute(s, CreateOptRuleBlacklist)
mustExecute(s, CreateLoginBlackList)
}

// doDMLWorks executes DML statements in bootstrap stage.
Expand Down
8 changes: 8 additions & 0 deletions session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -1505,6 +1505,12 @@ func (s *session) GetSessionVars() *variable.SessionVars {
func (s *session) Auth(user *auth.UserIdentity, authentication []byte, salt []byte) bool {
pm := privilege.GetPrivilegeManager(s)

// check account lock.
locked := pm.CheckAccountLocked(s, user.Username, user.Hostname)
if locked {
return false
}

// Check IP or localhost.
var success bool
user.AuthUsername, user.AuthHostname, success = pm.ConnectionVerification(user.Username, user.Hostname, authentication, salt, s.sessionVars.TLSConnectionState)
Expand All @@ -1513,6 +1519,7 @@ func (s *session) Auth(user *auth.UserIdentity, authentication []byte, salt []by
s.sessionVars.ActiveRoles = pm.GetDefaultRoles(user.AuthUsername, user.AuthHostname)
return true
} else if user.Hostname == variable.DefHostname {
pm.IncFailTimer(s, user.AuthUsername, user.AuthHostname)
return false
}

Expand All @@ -1530,6 +1537,7 @@ func (s *session) Auth(user *auth.UserIdentity, authentication []byte, salt []by
return true
}
}
pm.IncFailTimer(s, user.AuthUsername, user.AuthHostname)
return false
}

Expand Down
7 changes: 7 additions & 0 deletions sessionctx/variable/sysvar.go
Original file line number Diff line number Diff line change
Expand Up @@ -718,6 +718,9 @@ var defaultSysVars = []*SysVar{
{ScopeSession, TiDBEnableSlowLog, BoolToIntStr(logutil.DefaultTiDBEnableSlowLog)},
{ScopeSession, TiDBQueryLogMaxLen, strconv.Itoa(logutil.DefaultQueryLogMaxLen)},
{ScopeSession, TiDBCheckMb4ValueInUTF8, BoolToIntStr(config.GetGlobalConfig().CheckMb4ValueInUTF8)},
// Used to block DoS.
{ScopeGlobal, MaxLoginAttempts, strconv.Itoa(DefMaxLoginAttempts)},
{ScopeGlobal, LoginBlockInterval, strconv.Itoa(DefLoginBlockInterval)},
}

// SynonymsSysVariables is synonyms of system variables.
Expand Down Expand Up @@ -979,6 +982,10 @@ const (
ThreadPoolSize = "thread_pool_size"
// WindowingUseHighPrecision is the name of 'windowing_use_high_precision' system variable.
WindowingUseHighPrecision = "windowing_use_high_precision"
// MaxLoginAttempts is the name for 'max_login_attempts' system variable.
MaxLoginAttempts = "max_login_attempts"
// LoginBlockInterval is the name for 'login_block_interval' system variable.
LoginBlockInterval = "login_block_interval"
)

// GlobalVarAccessor is the interface for accessing global scope system and status variables.
Expand Down
2 changes: 2 additions & 0 deletions sessionctx/variable/tidb_vars.go
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,8 @@ const (
DefTiDBStoreLimit = 0
DefTiDBMetricSchemaStep = 60 // 60s
DefTiDBMetricSchemaRangeDuration = 60 // 60s
DefMaxLoginAttempts = 10
DefLoginBlockInterval = 3600
)

// Process global variables.
Expand Down

0 comments on commit 0db9f6e

Please sign in to comment.