Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP]expression: support ValidatePasswordStrength() #9812

Closed
wants to merge 12 commits into from
2 changes: 2 additions & 0 deletions executor/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ var (
ErrBadDB = terror.ClassExecutor.New(mysql.ErrBadDB, mysql.MySQLErrName[mysql.ErrBadDB])
ErrWrongObject = terror.ClassExecutor.New(mysql.ErrWrongObject, mysql.MySQLErrName[mysql.ErrWrongObject])
ErrRoleNotGranted = terror.ClassPrivilege.New(mysql.ErrRoleNotGranted, mysql.MySQLErrName[mysql.ErrRoleNotGranted])
ErrNotValidPassword = terror.ClassExecutor.New(mysql.ErrNotValidPassword, mysql.MySQLErrName[mysql.ErrNotValidPassword])
ErrDeadlock = terror.ClassExecutor.New(mysql.ErrLockDeadlock, mysql.MySQLErrName[mysql.ErrLockDeadlock])
ErrQueryInterrupted = terror.ClassExecutor.New(mysql.ErrQueryInterrupted, mysql.MySQLErrName[mysql.ErrQueryInterrupted])
)
Expand All @@ -69,6 +70,7 @@ func init() {
mysql.ErrTableaccessDenied: mysql.ErrTableaccessDenied,
mysql.ErrBadDB: mysql.ErrBadDB,
mysql.ErrWrongObject: mysql.ErrWrongObject,
mysql.ErrNotValidPassword: mysql.ErrNotValidPassword,
mysql.ErrLockDeadlock: mysql.ErrLockDeadlock,
mysql.ErrQueryInterrupted: mysql.ErrQueryInterrupted,
}
Expand Down
164 changes: 163 additions & 1 deletion expression/builtin_encryption.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,10 @@ import (
"fmt"
"hash"
"io"
"strconv"
"strings"
"unicode"
"unicode/utf8"

"github.com/pingcap/errors"
"github.com/pingcap/parser/auth"
Expand Down Expand Up @@ -76,6 +79,15 @@ var (
// ivSize indicates the initialization vector supplied to aes_decrypt
const ivSize = aes.BlockSize

// the max length of a password is 100 in mysql
const maxPwdLength int64 = 100

// VALIDATE_PASSWORD_STRENGTH() will return 0 when the length of a password is less than minPwsLength
const minPwdLength int64 = 4

// Differential score between levels in password test
const differentialScore int64 = 25

// aesModeAttr indicates that the key length and iv attribute for specific block_encryption_mode.
// keySize is the key length in bits and mode is the encryption mode.
// ivRequired indicates that initialization vector is required or not.
Expand Down Expand Up @@ -895,10 +907,160 @@ func (b *builtinUncompressedLengthSig) evalInt(row chunk.Row) (int64, bool, erro
return int64(len), false, nil
}

func reverse(s string) string {
SunRunAway marked this conversation as resolved.
Show resolved Hide resolved
runes := []rune(s)
for i, j := 0, len(runes)-1; i < j; i, j = i+1, j-1 {
runes[i], runes[j] = runes[j], runes[i]
}
return string(runes)
}

func validateByUser(s *variable.SessionVars, psw string) (bool, error) {
v, err := variable.GetGlobalSystemVar(s, variable.ValidatePasswordCheckUserName)
if err != nil || strings.EqualFold(v, "OFF") {
return err == nil, err
}
if s.User == nil {
return true, nil
}
if n := s.User.Username; n != "" {
if psw == n || psw == reverse(n) {
return false, nil
}
}
if n := s.User.AuthUsername; n != "" {
if psw == n || psw == reverse(n) {
return false, nil
}
}
return true, nil
}

func validateByMixedDigitSpecial(s *variable.SessionVars, pwd string) (bool, error) {
SunRunAway marked this conversation as resolved.
Show resolved Hide resolved
// get system variables
v, err := variable.GetGlobalSystemVar(s, variable.ValidatePasswordMixedCaseCount)
if err != nil {
return false, err
}
minMixed, err := strconv.ParseInt(v, 10, 64)
if err != nil {
return false, err
}
v, err = variable.GetGlobalSystemVar(s, variable.ValidatePasswordNumberCount)
if err != nil {
return false, err
}
minDigit, err := strconv.ParseInt(v, 10, 64)
if err != nil {
return false, err
}
v, err = variable.GetGlobalSystemVar(s, variable.ValidatePasswordSpecialCharCount)
if err != nil {
return false, err
}
minSpecial, err := strconv.ParseInt(v, 10, 64)
if err != nil {
return false, err
}
wjhuang2016 marked this conversation as resolved.
Show resolved Hide resolved

numLower := int64(0)
numUpper := int64(0)
numDigit := int64(0)
numSpecial := int64(0)
for _, c := range pwd {
if unicode.IsLower(c) {
numLower++
} else if unicode.IsUpper(c) {
numUpper++
} else if unicode.IsDigit(c) {
numDigit++
} else {
numSpecial++
}
}

if numLower < minMixed || numUpper < minMixed || numDigit < minDigit || numSpecial < minSpecial {
return false, nil
}
return true, nil
}

// TODO: Support validating password by dictionary file
func validateByDictionary(s *variable.SessionVars, pwd string) (bool, error) {
return true, nil
}

type validatePasswordStrengthFunctionClass struct {
baseFunctionClass
}

func (c *validatePasswordStrengthFunctionClass) getFunction(ctx sessionctx.Context, args []Expression) (builtinFunc, error) {
return nil, errFunctionNotExists.GenWithStackByArgs("FUNCTION", "VALIDATE_PASSWORD_STRENGTH")
if err := c.verifyArgs(args); err != nil {
return nil, err
}
bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETInt, types.ETString)
bf.tp.Flag = mysql.MaxIntWidth
sig := &builtinValidatePasswordStrengthSig{bf}
return sig, nil
}

type builtinValidatePasswordStrengthSig struct {
baseBuiltinFunc
}

func (c *builtinValidatePasswordStrengthSig) Clone() builtinFunc {
newSig := &builtinValidatePasswordStrengthSig{}
newSig.cloneFrom(&c.baseBuiltinFunc)
return newSig
}

// evalInt evals VALIDATE_PASSWORD_STRENGTH(str).
// See https://dev.mysql.com/doc/refman/5.7/en/encryption-functions.html#function_validate-password-strength
func (c *builtinValidatePasswordStrengthSig) evalInt(row chunk.Row) (int64, bool, error) {
sv := c.ctx.GetSessionVars()
pwd, isNull, err := c.args[0].EvalString(c.ctx, row)
score := int64(0)

if isNull || err != nil {
return score, true, err
}

l := int64(utf8.RuneCountInString(pwd))
if l > maxPwdLength {
pwd = string([]rune(pwd)[0:maxPwdLength])
}

valid, err := validateByUser(sv, pwd)
if err != nil || !valid {
return score, err != nil, err
}

if l < minPwdLength {
return score, false, nil
}
score += differentialScore

v, err := variable.GetGlobalSystemVar(sv, variable.ValidatePasswordLength)
if err != nil {
return score, false, nil
}
valPwdLen, err := strconv.ParseInt(v, 10, 64)
if err != nil || l < valPwdLen {
return score, err != nil, err
}
score += differentialScore

valid, err = validateByMixedDigitSpecial(sv, pwd)
if err != nil || !valid {
return score, err != nil, err
}
score += differentialScore

valid, err = validateByDictionary(sv, pwd)
if err != nil || !valid {
return score, err != nil, nil
}
score += differentialScore

return score, false, nil
}
32 changes: 32 additions & 0 deletions expression/builtin_encryption_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -493,3 +493,35 @@ func (s *testEvaluatorSuite) TestPassword(c *C) {
_, err := funcs[ast.PasswordFunc].getFunction(s.ctx, []Expression{Zero})
c.Assert(err, IsNil)
}

func (s *testEvaluatorSuite) TestValidatePasswordStrength(c *C) {
defer testleak.AfterTest(c)()
tests := []struct {
in interface{}
expect int64
}{
{"", 0},
{"1", 0},
{"你好啊", 0},
{"pass", 25},
{"user123", 25},
{"你好世界", 25},
{"password", 50},
{"password0000", 50},
{"auth_user", 50},
{"你好世界你好世界", 50},
{"Pingcap123", 50},
{"Pingcap123_", 100},
{"password1A#", 100},
{"PA12wrd!#", 100},
{"PA00wrd!#", 100},
}

for _, t := range tests {
f, err := newFunctionForTest(s.ctx, ast.ValidatePasswordStrength, s.primitiveValsToConstants([]interface{}{t.in})...)
c.Assert(err, IsNil)
d, err := f.Eval(chunk.Row{})
c.Assert(err, IsNil)
c.Assert(d.GetInt64(), Equals, t.expect)
}
}
59 changes: 59 additions & 0 deletions expression/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4368,6 +4368,65 @@ func (s *testIntegrationSuite) TestIssue9325(c *C) {
result.Check(testkit.Rows("2019-02-16 14:19:59", "2019-02-16 14:20:01"))
}

func (s *testIntegrationSuite) TestFuncValidatePasswordStrength(c *C) {
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("set global validate_password_check_user_name = 'ON';")
tk.MustQuery("SELECT @@global.validate_password_check_user_name;").Check(testkit.Rows("1"))
err := tk.ExecToErr("SET @@session.validate_password_check_user_name= ON;")
c.Assert(err, ErrorMatches, "*Variable 'validate_password_check_user_name' is a GLOBAL variable and should be set with SET GLOBAL")
err = tk.ExecToErr("SET validate_password_check_user_name= ON;")
c.Assert(err, ErrorMatches, "*Variable 'validate_password_check_user_name' is a GLOBAL variable and should be set with SET GLOBAL")
tk.MustExec("SET @@global.validate_password_policy=LOW;")
tk.MustExec("SET @@global.validate_password_mixed_case_count=0;")
tk.MustExec("SET @@global.validate_password_number_count=0;")
tk.MustExec("SET @@global.validate_password_special_char_count=0;")
tk.MustExec("SET @@global.validate_password_length=0;")
tk.MustExec("SET @@global.validate_password_check_user_name= ON;")
tk.Se.GetSessionVars().User = &auth.UserIdentity{Username: "root", AuthUsername: "root"}
result := tk.MustQuery("SELECT VALIDATE_PASSWORD_STRENGTH('root') = 0;")
result.Check(testkit.Rows("1"))
result = tk.MustQuery("SELECT VALIDATE_PASSWORD_STRENGTH('toor') = 0;")
result.Check(testkit.Rows("1"))
result = tk.MustQuery("SELECT VALIDATE_PASSWORD_STRENGTH('Root') <> 0;")
result.Check(testkit.Rows("1"))
result = tk.MustQuery("SELECT VALIDATE_PASSWORD_STRENGTH('Toor') <> 0;")
result.Check(testkit.Rows("1"))
result = tk.MustQuery("SELECT VALIDATE_PASSWORD_STRENGTH('fooHoHo%1') <> 0;")
result.Check(testkit.Rows("1"))

err = tk.ExecToErr("SELECT VALIDATE_PASSWORD_STRENGTH('password', 0);")
c.Check(err, ErrorMatches, "*Incorrect parameter count in the call to native function 'validate_password_strength'")
err = tk.ExecToErr("SELECT VALIDATE_PASSWORD_STRENGTH();")
c.Check(err, ErrorMatches, "*Incorrect parameter count in the call to native function 'validate_password_strength'")

tk.MustExec("set global validate_password_length = 10;")
tk.MustExec("set global validate_password_mixed_case_count = 3;")
tk.MustExec("set global validate_password_number_count = 3;")
tk.MustExec("set global validate_password_special_char_count = 3;")
tk.Se.GetSessionVars().User = &auth.UserIdentity{Username: "user123", AuthUsername: "auth_user"}
result = tk.MustQuery("select VALIDATE_PASSWORD_STRENGTH('user123')")
result.Check(testkit.Rows("0"))
result = tk.MustQuery("select VALIDATE_PASSWORD_STRENGTH('321resu')")
result.Check(testkit.Rows("0"))
result = tk.MustQuery("select VALIDATE_PASSWORD_STRENGTH('auth_user')")
result.Check(testkit.Rows("0"))
result = tk.MustQuery("select VALIDATE_PASSWORD_STRENGTH('resu_htua')")
result.Check(testkit.Rows("0"))
result = tk.MustQuery("select VALIDATE_PASSWORD_STRENGTH('password')")
result.Check(testkit.Rows("25"))
result = tk.MustQuery("select VALIDATE_PASSWORD_STRENGTH('password12')")
result.Check(testkit.Rows("50"))
result = tk.MustQuery("select VALIDATE_PASSWORD_STRENGTH('Pingcap123_')")
result.Check(testkit.Rows("50"))
result = tk.MustQuery("select VALIDATE_PASSWORD_STRENGTH('PingcapPP123_')")
result.Check(testkit.Rows("50"))
result = tk.MustQuery("select VALIDATE_PASSWORD_STRENGTH('PINGCAp123_')")
result.Check(testkit.Rows("50"))
result = tk.MustQuery("select VALIDATE_PASSWORD_STRENGTH('Pingcap_PP_123!')")
result.Check(testkit.Rows("100"))

}

func (s *testIntegrationSuite) TestIssue9710(c *C) {
tk := testkit.NewTestKit(c, s.store)
getSAndMS := func(str string) (int, int) {
Expand Down
23 changes: 18 additions & 5 deletions sessionctx/variable/sysvar.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,14 @@ var defaultSysVars = []*SysVar{
{ScopeGlobal, "slave_pending_jobs_size_max", "16777216"},
{ScopeNone, "innodb_sync_array_size", "1"},
{ScopeSession, "rand_seed2", ""},
{ScopeGlobal, "validate_password_check_user_name", "OFF"},
{ScopeGlobal, "validate_password_dictionary_file", ""},
{ScopeGlobal, "validate_password_length", "8"},
{ScopeGlobal, "validate_password_mixed_case_count", "1"},
{ScopeGlobal, ValidatePasswordCheckUserName, "0"},
{ScopeGlobal, "validate_password_number_count", "1"},
{ScopeGlobal, "validate_password_policy", "MEDIUM"},
{ScopeGlobal, "validate_password_special_char_count", "1"},
{ScopeSession, "gtid_next", ""},
{ScopeGlobal | ScopeSession, SQLSelectLimit, "18446744073709551615"},
{ScopeGlobal, "ndb_show_foreign_key_mock_tables", ""},
Expand Down Expand Up @@ -606,7 +612,6 @@ var defaultSysVars = []*SysVar{
{ScopeGlobal, "sync_relay_log_info", "10000"},
{ScopeGlobal | ScopeSession, "optimizer_trace_limit", "1"},
{ScopeNone, "innodb_ft_max_token_size", "84"},
{ScopeGlobal, "validate_password_length", "8"},
{ScopeGlobal, "ndb_log_binlog_index", ""},
{ScopeGlobal, "innodb_api_bk_commit_interval", "5"},
{ScopeNone, "innodb_undo_directory", "."},
Expand Down Expand Up @@ -815,10 +820,20 @@ const (
BlockEncryptionMode = "block_encryption_mode"
// WaitTimeout is the name for 'wait_timeout' system variable.
WaitTimeout = "wait_timeout"
// ValidatePasswordNumberCount is the name of 'validate_password_number_count' system variable.
ValidatePasswordNumberCount = "validate_password_number_count"
// ValidatePasswordCheckUserName is the name of 'validate_password_check_user_name' system variable
ValidatePasswordCheckUserName = "validate_password_check_user_name"
// ValidatePasswordDictionaryFile is the name of 'validate_password_dictionary_file' system variable
ValidatePasswordDictionaryFile = "validate_password_dictionary_file"
// ValidatePasswordLength is the name of 'validate_password_length' system variable.
ValidatePasswordLength = "validate_password_length"
// ValidatePasswordMixedCaseCount is the name of 'validate_password_mixed_case_count' system variable
ValidatePasswordMixedCaseCount = "validate_password_mixed_case_count"
// ValidatePasswordNumberCount is the name of 'validate_password_number_count' system variable.
ValidatePasswordNumberCount = "validate_password_number_count"
// ValidatePasswordPolicy is the name of 'validate_password_policy' system variable
ValidatePasswordPolicy = "validate_password_policy"
// ValidatePasswordSpecialCharCount is the name of 'validate_password_special_char_count' system variable
ValidatePasswordSpecialCharCount = "validate_password_special_char_count"
zz-jason marked this conversation as resolved.
Show resolved Hide resolved
// PluginDir is the name of 'plugin_dir' system variable.
PluginDir = "plugin_dir"
// PluginLoad is the name of 'plugin_load' system variable.
Expand All @@ -835,8 +850,6 @@ const (
BinlogOrderCommits = "binlog_order_commits"
// MasterVerifyChecksum is the name for 'master_verify_checksum' system variable.
MasterVerifyChecksum = "master_verify_checksum"
// ValidatePasswordCheckUserName is the name for 'validate_password_check_user_name' system variable.
ValidatePasswordCheckUserName = "validate_password_check_user_name"
// SuperReadOnly is the name for 'super_read_only' system variable.
SuperReadOnly = "super_read_only"
// SQLNotes is the name for 'sql_notes' system variable.
Expand Down
13 changes: 11 additions & 2 deletions sessionctx/variable/varsutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -347,8 +347,17 @@ func ValidateSetSystemVar(vars *SessionVars, name string, value string) (string,
}
_, err := parseTimeZone(value)
return value, err
case ValidatePasswordLength, ValidatePasswordNumberCount:
return checkUInt64SystemVar(name, value, 0, math.MaxUint64, vars)
case ValidatePasswordLength, ValidatePasswordMixedCaseCount, ValidatePasswordNumberCount, ValidatePasswordSpecialCharCount:
return checkUInt64SystemVar(name, value, 0, 100, vars)
case ValidatePasswordPolicy:
if strings.EqualFold(value, "LOW") || value == "0" {
return "0", nil
} else if strings.EqualFold(value, "MEDIUM") || value == "1" {
return "1", nil
} else if strings.EqualFold(value, "STRONG") || value == "2" {
return "2", nil
}
return value, ErrWrongValueForVar.GenWithStackByArgs(name, value)
case WarningCount, ErrorCount:
return value, ErrReadOnly.GenWithStackByArgs(name)
case EnforceGtidConsistency:
Expand Down