Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP]expression: support ValidatePasswordStrength() #9812

Closed
wants to merge 12 commits into from
2 changes: 2 additions & 0 deletions executor/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ var (
ErrBadDB = terror.ClassExecutor.New(mysql.ErrBadDB, mysql.MySQLErrName[mysql.ErrBadDB])
ErrWrongObject = terror.ClassExecutor.New(mysql.ErrWrongObject, mysql.MySQLErrName[mysql.ErrWrongObject])
ErrRoleNotGranted = terror.ClassPrivilege.New(mysql.ErrRoleNotGranted, mysql.MySQLErrName[mysql.ErrRoleNotGranted])
ErrNotValidPassword = terror.ClassExecutor.New(mysql.ErrNotValidPassword, mysql.MySQLErrName[mysql.ErrNotValidPassword])
ErrDeadlock = terror.ClassExecutor.New(mysql.ErrLockDeadlock, mysql.MySQLErrName[mysql.ErrLockDeadlock])
ErrQueryInterrupted = terror.ClassExecutor.New(mysql.ErrQueryInterrupted, mysql.MySQLErrName[mysql.ErrQueryInterrupted])
)
Expand All @@ -69,6 +70,7 @@ func init() {
mysql.ErrTableaccessDenied: mysql.ErrTableaccessDenied,
mysql.ErrBadDB: mysql.ErrBadDB,
mysql.ErrWrongObject: mysql.ErrWrongObject,
mysql.ErrNotValidPassword: mysql.ErrNotValidPassword,
mysql.ErrLockDeadlock: mysql.ErrLockDeadlock,
mysql.ErrQueryInterrupted: mysql.ErrQueryInterrupted,
}
Expand Down
154 changes: 153 additions & 1 deletion expression/builtin_encryption.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ import (
"fmt"
"hash"
"io"
"strconv"
"strings"
"unicode"

"github.com/pingcap/errors"
"github.com/pingcap/parser/auth"
Expand Down Expand Up @@ -895,10 +897,160 @@ func (b *builtinUncompressedLengthSig) evalInt(row chunk.Row) (int64, bool, erro
return int64(len), false, nil
}

func reverse(s string) string {
SunRunAway marked this conversation as resolved.
Show resolved Hide resolved
runes := []rune(s)
for i, j := 0, len(runes)-1; i < j; i, j = i+1, j-1 {
runes[i], runes[j] = runes[j], runes[i]
}
return string(runes)
}

func validateByUser(s *variable.SessionVars, psw string) (bool, error) {
v, err := variable.GetGlobalSystemVar(s, variable.ValidatePasswordCheckUserName)
if err != nil || strings.EqualFold(v, "OFF") {
return err == nil, err
}
if s.User == nil {
return true, nil
}
if n := s.User.Username; n != "" {
if psw == n || psw == reverse(n) {
return false, nil
}
}
if n := s.User.AuthUsername; n != "" {
if psw == n || psw == reverse(n) {
return false, nil
}
}
return true, nil
}

func validateByMixedDigitSpecial(s *variable.SessionVars, pwd string) (bool, error) {
SunRunAway marked this conversation as resolved.
Show resolved Hide resolved
numLower := int64(0)
numUpper := int64(0)
numDigit := int64(0)
numSpecial := int64(0)
for _, c := range pwd {
if unicode.IsLower(c) {
numLower++
} else if unicode.IsUpper(c) {
numUpper++
} else if unicode.IsDigit(c) {
numDigit++
} else {
numSpecial++
}
}
v, err := variable.GetGlobalSystemVar(s, variable.ValidatePasswordMixedCaseCount)
if err != nil {
return false, err
}
minMixed, err := strconv.ParseInt(v, 10, 64)
if err != nil || numLower < minMixed || numUpper < minMixed {
return false, err
}
v, err = variable.GetGlobalSystemVar(s, variable.ValidatePasswordNumberCount)
if err != nil {
return false, err
}
minDigit, err := strconv.ParseInt(v, 10, 64)
if err != nil || numDigit < minDigit {
return false, err
}
v, err = variable.GetGlobalSystemVar(s, variable.ValidatePasswordSpecialCharCount)
if err != nil {
return false, err
}
minSpecial, err := strconv.ParseInt(v, 10, 64)
if err != nil || numSpecial < minSpecial {
return false, err
}
wjhuang2016 marked this conversation as resolved.
Show resolved Hide resolved

return true, nil
}

// TODO: Support validating password by dictionary file
func validateByDictionary(s *variable.SessionVars, pwd string) (bool, error) {
return true, nil
}

type validatePasswordStrengthFunctionClass struct {
baseFunctionClass
}

func (c *validatePasswordStrengthFunctionClass) getFunction(ctx sessionctx.Context, args []Expression) (builtinFunc, error) {
return nil, errFunctionNotExists.GenWithStackByArgs("FUNCTION", "VALIDATE_PASSWORD_STRENGTH")
if err := c.verifyArgs(args); err != nil {
return nil, err
}
bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETInt, types.ETString)
bf.tp.Flag = mysql.MaxIntWidth
sig := &builtinValidatePasswordStrengthSig{bf}
return sig, nil
}

type builtinValidatePasswordStrengthSig struct {
baseBuiltinFunc
}

func (c *builtinValidatePasswordStrengthSig) Clone() builtinFunc {
newSig := &builtinValidatePasswordStrengthSig{}
newSig.cloneFrom(&c.baseBuiltinFunc)
return newSig
}

// evalInt evals VALIDATE_PASSWORD_STRENGTH(str).
// See https://dev.mysql.com/doc/refman/8.0/en/encryption-functions.html#function_validate-password-strength
wjhuang2016 marked this conversation as resolved.
Show resolved Hide resolved
func (c *builtinValidatePasswordStrengthSig) evalInt(row chunk.Row) (int64, bool, error) {
sv := c.ctx.GetSessionVars()
pwd, isNull, err := c.args[0].EvalString(c.ctx, row)
score := int64(0)

if isNull || err != nil {
return score, true, err
}

// In MySQL, the max length of a password is 100
l := int64(0)
for index := range pwd {
wjhuang2016 marked this conversation as resolved.
Show resolved Hide resolved
l++
if l > 100 {
pwd = pwd[:index-1]
wjhuang2016 marked this conversation as resolved.
Show resolved Hide resolved
break
}
}

valid, err := validateByUser(sv, pwd)
if err != nil || !valid {
return score, err != nil, err
}

if l < 4 {
wjhuang2016 marked this conversation as resolved.
Show resolved Hide resolved
return score, false, nil
}
score += 25
wjhuang2016 marked this conversation as resolved.
Show resolved Hide resolved

v, err := variable.GetGlobalSystemVar(sv, variable.ValidatePasswordLength)
if err != nil {
return score, false, nil
}
valPwdLen, err := strconv.ParseInt(v, 10, 64)
if err != nil || l < valPwdLen {
return score, err != nil, err
}
score += 25

valid, err = validateByMixedDigitSpecial(sv, pwd)
if err != nil || !valid {
return score, err != nil, err
}
score += 25

valid, err = validateByDictionary(sv, pwd)
if err != nil || !valid {
return score, err != nil, nil
}
score += 25

return score, false, nil
}
28 changes: 28 additions & 0 deletions expression/builtin_encryption_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -493,3 +493,31 @@ func (s *testEvaluatorSuite) TestPassword(c *C) {
_, err := funcs[ast.PasswordFunc].getFunction(s.ctx, []Expression{Zero})
c.Assert(err, IsNil)
}

func (s *testEvaluatorSuite) TestValidatePasswordStrength(c *C) {
defer testleak.AfterTest(c)()
tests := []struct {
in interface{}
expect int64
}{
{"", 0},
{"1", 0},
{"你好啊", 0},
{"pass", 25},
{"user123", 25},
{"你好世界", 25},
{"password", 50},
{"auth_user", 50},
{"你好世界你好世界", 50},
{"Pingcap123", 50},
{"Pingcap123_", 100},
}

for _, t := range tests {
f, err := newFunctionForTest(s.ctx, ast.ValidatePasswordStrength, s.primitiveValsToConstants([]interface{}{t.in})...)
c.Assert(err, IsNil)
d, err := f.Eval(chunk.Row{})
c.Assert(err, IsNil)
c.Assert(d.GetInt64(), Equals, t.expect)
}
}
30 changes: 30 additions & 0 deletions expression/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4227,6 +4227,36 @@ func (s *testIntegrationSuite) TestIssue9325(c *C) {
result.Check(testkit.Rows("2019-02-16 14:19:59", "2019-02-16 14:20:01"))
}

func (s *testIntegrationSuite) TestFuncValidatePasswordStrength(c *C) {
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("set global validate_password_check_user_name = 'ON';")
tk.MustExec("set global validate_password_length = 10;")
tk.MustExec("set global validate_password_mixed_case_count = 3;")
tk.MustExec("set global validate_password_number_count = 3;")
tk.MustExec("set global validate_password_special_char_count = 3;")
tk.Se.GetSessionVars().User = &auth.UserIdentity{Username: "user123", AuthUsername: "auth_user"}
result := tk.MustQuery("select VALIDATE_PASSWORD_STRENGTH('user123')")
result.Check(testkit.Rows("0"))
result = tk.MustQuery("select VALIDATE_PASSWORD_STRENGTH('321resu')")
result.Check(testkit.Rows("0"))
result = tk.MustQuery("select VALIDATE_PASSWORD_STRENGTH('auth_user')")
result.Check(testkit.Rows("0"))
result = tk.MustQuery("select VALIDATE_PASSWORD_STRENGTH('resu_htua')")
result.Check(testkit.Rows("0"))
result = tk.MustQuery("select VALIDATE_PASSWORD_STRENGTH('password')")
result.Check(testkit.Rows("25"))
result = tk.MustQuery("select VALIDATE_PASSWORD_STRENGTH('password12')")
result.Check(testkit.Rows("50"))
result = tk.MustQuery("select VALIDATE_PASSWORD_STRENGTH('Pingcap123_')")
result.Check(testkit.Rows("50"))
result = tk.MustQuery("select VALIDATE_PASSWORD_STRENGTH('PingcapPP123_')")
result.Check(testkit.Rows("50"))
result = tk.MustQuery("select VALIDATE_PASSWORD_STRENGTH('PINGCAp123_')")
result.Check(testkit.Rows("50"))
result = tk.MustQuery("select VALIDATE_PASSWORD_STRENGTH('Pingcap_PP_123!')")
result.Check(testkit.Rows("100"))
}

func (s *testIntegrationSuite) TestIssue9710(c *C) {
tk := testkit.NewTestKit(c, s.store)
getSAndMS := func(str string) (int, int) {
Expand Down
26 changes: 21 additions & 5 deletions sessionctx/variable/sysvar.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,14 @@ var defaultSysVars = []*SysVar{
{ScopeGlobal, "slave_pending_jobs_size_max", "16777216"},
{ScopeNone, "innodb_sync_array_size", "1"},
{ScopeSession, "rand_seed2", ""},
{ScopeGlobal, "validate_password_check_user_name", "OFF"},
{ScopeGlobal, "validate_password_dictionary_file", ""},
{ScopeGlobal, "validate_password_length", "8"},
{ScopeGlobal, "validate_password_mixed_case_count", "1"},
{ScopeGlobal, ValidatePasswordCheckUserName, "0"},
{ScopeGlobal, "validate_password_number_count", "1"},
{ScopeGlobal, "validate_password_policy", "MEDIUM"},
{ScopeGlobal, "validate_password_special_char_count", "1"},
{ScopeSession, "gtid_next", ""},
{ScopeGlobal | ScopeSession, SQLSelectLimit, "18446744073709551615"},
{ScopeGlobal, "ndb_show_foreign_key_mock_tables", ""},
Expand Down Expand Up @@ -205,6 +211,8 @@ var defaultSysVars = []*SysVar{
{ScopeNone, "innodb_log_group_home_dir", "./"},
{ScopeNone, "performance_schema_events_statements_history_size", "10"},
{ScopeGlobal, GeneralLog, "0"},
{ScopeGlobal, "binlog_order_commits", "ON"},
{ScopeGlobal, "master_verify_checksum", "OFF"},
{ScopeGlobal, "validate_password_dictionary_file", ""},
{ScopeGlobal, BinlogOrderCommits, "1"},
{ScopeGlobal, MasterVerifyChecksum, "0"},
Expand Down Expand Up @@ -594,6 +602,7 @@ var defaultSysVars = []*SysVar{
{ScopeGlobal | ScopeSession, "character_set_connection", mysql.DefaultCharset},
{ScopeGlobal, MyISAMUseMmap, "0"},
{ScopeGlobal | ScopeSession, "ndb_join_pushdown", ""},
{ScopeGlobal | ScopeSession, "character_set_server", mysql.DefaultCharset},
wjhuang2016 marked this conversation as resolved.
Show resolved Hide resolved
{ScopeGlobal | ScopeSession, CharacterSetServer, mysql.DefaultCharset},
{ScopeGlobal, "validate_password_special_char_count", "1"},
{ScopeNone, "performance_schema_max_thread_instances", "402"},
Expand All @@ -606,7 +615,6 @@ var defaultSysVars = []*SysVar{
{ScopeGlobal, "sync_relay_log_info", "10000"},
{ScopeGlobal | ScopeSession, "optimizer_trace_limit", "1"},
{ScopeNone, "innodb_ft_max_token_size", "84"},
{ScopeGlobal, "validate_password_length", "8"},
{ScopeGlobal, "ndb_log_binlog_index", ""},
{ScopeGlobal, "innodb_api_bk_commit_interval", "5"},
{ScopeNone, "innodb_undo_directory", "."},
Expand Down Expand Up @@ -814,10 +822,20 @@ const (
BlockEncryptionMode = "block_encryption_mode"
// WaitTimeout is the name for 'wait_timeout' system variable.
WaitTimeout = "wait_timeout"
// ValidatePasswordNumberCount is the name of 'validate_password_number_count' system variable.
ValidatePasswordNumberCount = "validate_password_number_count"
// ValidatePasswordCheckUserName is the name of 'validate_password_check_user_name' system variable
ValidatePasswordCheckUserName = "validate_password_check_user_name"
// ValidatePasswordDictionaryFile is the name of 'validate_password_dictionary_file' system variable
ValidatePasswordDictionaryFile = "validate_password_dictionary_file"
// ValidatePasswordLength is the name of 'validate_password_length' system variable.
ValidatePasswordLength = "validate_password_length"
// ValidatePasswordMixedCaseCount is the name of 'validate_password_mixed_case_count' system variable
ValidatePasswordMixedCaseCount = "validate_password_mixed_case_count"
// ValidatePasswordNumberCount is the name of 'validate_password_number_count' system variable.
ValidatePasswordNumberCount = "validate_password_number_count"
// ValidatePasswordPolicy is the name of 'validate_password_policy' system variable
ValidatePasswordPolicy = "validate_password_policy"
// ValidatePasswordSpecialCharCount is the name of 'validate_password_special_char_count' system variable
ValidatePasswordSpecialCharCount = "validate_password_special_char_count"
zz-jason marked this conversation as resolved.
Show resolved Hide resolved
// PluginDir is the name of 'plugin_dir' system variable.
PluginDir = "plugin_dir"
// PluginLoad is the name of 'plugin_load' system variable.
Expand All @@ -834,8 +852,6 @@ const (
BinlogOrderCommits = "binlog_order_commits"
// MasterVerifyChecksum is the name for 'master_verify_checksum' system variable.
MasterVerifyChecksum = "master_verify_checksum"
// ValidatePasswordCheckUserName is the name for 'validate_password_check_user_name' system variable.
ValidatePasswordCheckUserName = "validate_password_check_user_name"
// SuperReadOnly is the name for 'super_read_only' system variable.
SuperReadOnly = "super_read_only"
// SQLNotes is the name for 'sql_notes' system variable.
Expand Down
13 changes: 11 additions & 2 deletions sessionctx/variable/varsutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -347,8 +347,17 @@ func ValidateSetSystemVar(vars *SessionVars, name string, value string) (string,
}
_, err := parseTimeZone(value)
return value, err
case ValidatePasswordLength, ValidatePasswordNumberCount:
return checkUInt64SystemVar(name, value, 0, math.MaxUint64, vars)
case ValidatePasswordLength, ValidatePasswordMixedCaseCount, ValidatePasswordNumberCount, ValidatePasswordSpecialCharCount:
return checkUInt64SystemVar(name, value, 0, 100, vars)
case ValidatePasswordPolicy:
if strings.EqualFold(value, "LOW") || value == "0" {
return "0", nil
} else if strings.EqualFold(value, "MEDIUM") || value == "1" {
return "1", nil
} else if strings.EqualFold(value, "STRONG") || value == "2" {
return "2", nil
}
return value, ErrWrongValueForVar.GenWithStackByArgs(name, value)
case WarningCount, ErrorCount:
return value, ErrReadOnly.GenWithStackByArgs(name)
case EnforceGtidConsistency:
Expand Down