Skip to content

Commit

Permalink
parser, ddl: support UPPER(SUBSTRING_INDEX(user(),'@',1))
Browse files Browse the repository at this point in the history
  • Loading branch information
zimulala committed Feb 19, 2024
1 parent cea3605 commit 4749504
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 8 deletions.
59 changes: 59 additions & 0 deletions pkg/ddl/db_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1606,6 +1606,65 @@ func TestDefaultColumnWithRand(t *testing.T) {
tk.MustGetErrCode("CREATE TABLE t3 (c int, c1 int default a_function_not_supported_yet());", errno.ErrDefValGeneratedNamedFunctionIsNotAllowed)
}

func TestDefaultColumnWithUpper(t *testing.T) {
store := testkit.CreateMockStoreWithSchemaLease(t, testLease)
tk := testkit.NewTestKit(t, store)
tk.MustExec("use test")
tk.MustExec("drop table if exists t, t1, t2")

// create table
tk.MustExec("create table t (c int(10), c1 varchar(256) default (upper(substring_index(user(),'@',1))))")
tk.MustExec("create table t1 (c int(10), c1 int default (upper(substring_index(user(),_utf8mb4'@',1))))")
tk.MustGetErrCode("create table t2 (c int(10), c1 varchar(256) default (substring_index(user(),'@',1)))", errno.ErrDefValGeneratedNamedFunctionIsNotAllowed)
tk.MustGetErrCode("create table t2 (c int(10), c1 varchar(256) default (upper(substring_index('fjks@jkkl','@',1))))", errno.ErrDefValGeneratedNamedFunctionIsNotAllowed)
tk.MustGetErrCode("create table t2 (c int(10), c1 varchar(256) default (upper(substring_index(user(),'x',1))))", errno.ErrDefValGeneratedNamedFunctionIsNotAllowed)

// add column with default expression for table t is forbidden in MySQL 8.0
tk.MustGetErrCode("alter table t add column c2 varchar(32) default (upper(substring_index(user(),'@',1)))", errno.ErrBinlogUnsafeSystemFunction)
tk.MustGetErrCode("alter table t add column c3 int default (upper(substring_index('fjks@jkkl','@',1)))", errno.ErrBinlogUnsafeSystemFunction)

// insert records
tk.Session().GetSessionVars().User = &auth.UserIdentity{Username: "root", Hostname: "localhost"}
tk.MustExec("insert into t(c) values (1),(2),(3)")
tk.MustGetErrCode("insert into t1(c) values (1)", errno.ErrTruncatedWrongValue)
tk.Session().GetSessionVars().User = &auth.UserIdentity{Username: "xyz", Hostname: "localhost"}
tk.MustExec("insert into t(c) values (4),(5),(6)")

rows := tk.MustQuery("SELECT c1 from t order by c").Rows()
for i, row := range rows {
d, ok := row[0].(string)
require.True(t, ok)
if i < 3 {
require.Equal(t, "ROOT", d)
} else {
require.Equal(t, "XYZ", d)
}
}

tk.MustQuery("show create table t").Check(testkit.Rows(
"t CREATE TABLE `t` (\n" +
" `c` int(10) DEFAULT NULL,\n" +
" `c1` varchar(256) DEFAULT upper(substring_index(user(), _utf8mb4''@'', 1))\n" +
") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin"))
tk.MustQuery("show create table t1").Check(testkit.Rows(
"t1 CREATE TABLE `t1` (\n" +
" `c` int(10) DEFAULT NULL,\n" +
" `c1` int(11) DEFAULT upper(substring_index(user(), _utf8mb4''@'', 1))\n" +
") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin"))
tk.MustExec("alter table t1 modify column c1 varchar(30) default 'xx';")
tk.MustQuery("show create table t1").Check(testkit.Rows(
"t1 CREATE TABLE `t1` (\n" +
" `c` int(10) DEFAULT NULL,\n" +
" `c1` varchar(30) DEFAULT 'xx'\n" +
") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin"))
tk.MustExec("alter table t1 modify column c1 varchar(32) default (upper(substring_index(user(),'@',1)));")
tk.MustQuery("show create table t1").Check(testkit.Rows(
"t1 CREATE TABLE `t1` (\n" +
" `c` int(10) DEFAULT NULL,\n" +
" `c1` varchar(32) DEFAULT upper(substring_index(user(), _utf8mb4''@'', 1))\n" +
") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin"))
}

func TestChangingDBCharset(t *testing.T) {
store := testkit.CreateMockStore(t, mockstore.WithDDLChecker())

Expand Down
50 changes: 42 additions & 8 deletions pkg/ddl/ddl_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -1264,6 +1264,17 @@ func columnDefToCol(ctx sessionctx.Context, offset int, colDef *ast.ColumnDef, o
return col, constraints, nil
}

func restoreFuncCall(expr *ast.FuncCallExpr) (string, error) {
var sb strings.Builder
restoreFlags := format.RestoreStringSingleQuotes | format.RestoreKeyWordLowercase | format.RestoreNameBackQuotes |
format.RestoreSpacesAroundBinaryOperation
restoreCtx := format.NewRestoreCtx(restoreFlags, &sb)
if err := expr.Restore(restoreCtx); err != nil {
return "", err
}
return sb.String(), nil
}

// getFuncCallDefaultValue gets the default column value of function-call expression.
func getFuncCallDefaultValue(col *table.Column, option *ast.ColumnOption, expr *ast.FuncCallExpr) (any, bool, error) {
switch expr.FnName.L {
Expand Down Expand Up @@ -1292,15 +1303,38 @@ func getFuncCallDefaultValue(col *table.Column, option *ast.ColumnOption, expr *
if err := expression.VerifyArgsWrapper(expr.FnName.L, len(expr.Args)); err != nil {
return nil, false, errors.Trace(err)
}
str, err := restoreFuncCall(expr)
if err != nil {
return nil, false, errors.Trace(err)
}
col.DefaultIsExpr = true
var sb strings.Builder
restoreFlags := format.RestoreStringSingleQuotes | format.RestoreKeyWordLowercase | format.RestoreNameBackQuotes |
format.RestoreSpacesAroundBinaryOperation
restoreCtx := format.NewRestoreCtx(restoreFlags, &sb)
if err := expr.Restore(restoreCtx); err != nil {
return "", false, err
return str, false, nil
case ast.Upper:
// Support UPPER(SUBSTRING_INDEX(USER(), '@', 1)).
if err := expression.VerifyArgsWrapper(expr.FnName.L, len(expr.Args)); err != nil {
return nil, false, errors.Trace(err)
}
if substringIndexFunc, ok := expr.Args[0].(*ast.FuncCallExpr); ok && substringIndexFunc.FnName.L == ast.SubstringIndex {
if err := expression.VerifyArgsWrapper(substringIndexFunc.FnName.L, len(substringIndexFunc.Args)); err != nil {
return nil, false, errors.Trace(err)
}
if userFunc, ok := substringIndexFunc.Args[0].(*ast.FuncCallExpr); ok && userFunc.FnName.L == ast.User {
if err := expression.VerifyArgsWrapper(userFunc.FnName.L, len(userFunc.Args)); err != nil {
return nil, false, errors.Trace(err)
}
valExpr, isValue := substringIndexFunc.Args[1].(ast.ValueExpr)
if !isValue || valExpr.GetString() != "@" {
return nil, false, dbterror.ErrDefValGeneratedNamedFunctionIsNotAllowed.GenWithStackByArgs(col.Name.String(), valExpr)
}
str, err := restoreFuncCall(expr)
if err != nil {
return nil, false, errors.Trace(err)
}
col.DefaultIsExpr = true
return str, false, nil
}
}
return sb.String(), false, nil
return nil, false, dbterror.ErrDefValGeneratedNamedFunctionIsNotAllowed.GenWithStackByArgs(col.Name.String(), expr.FnName.String())
default:
return nil, false, dbterror.ErrDefValGeneratedNamedFunctionIsNotAllowed.GenWithStackByArgs(col.Name.String(), expr.FnName.String())
}
Expand Down Expand Up @@ -4182,7 +4216,7 @@ func CreateNewColumn(ctx sessionctx.Context, schema *model.DBInfo, spec *ast.Alt
return nil, errors.Trace(err)
}
return nil, errors.Trace(dbterror.ErrAddColumnWithSequenceAsDefault.GenWithStackByArgs(specNewColumn.Name.Name.O))
case ast.Rand, ast.UUID, ast.UUIDToBin:
case ast.Rand, ast.UUID, ast.UUIDToBin, ast.Upper:
return nil, errors.Trace(dbterror.ErrBinlogUnsafeSystemFunction.GenWithStackByArgs())
}
}
Expand Down
2 changes: 2 additions & 0 deletions pkg/parser/parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2949,6 +2949,8 @@ func TestDDL(t *testing.T) {
{"create table t (d date default (current_date()))", true, "CREATE TABLE `t` (`d` DATE DEFAULT CURRENT_DATE())"},
{"create table t (d date default (curdate()))", true, "CREATE TABLE `t` (`d` DATE DEFAULT CURRENT_DATE())"},
{"create table t (d date default curdate())", true, "CREATE TABLE `t` (`d` DATE DEFAULT CURRENT_DATE())"},
{"create table t (a int default upper(substring_index(user(),'@',1)))", true, "CREATE TABLE `t` (`a` INT DEFAULT UPPER(SUBSTRING_INDEX(USER(), _UTF8MB4'@', 1)))"},
{"create table t (a int default (upper(substring_index(user(),'@',1))))", true, "CREATE TABLE `t` (`a` INT DEFAULT UPPER(SUBSTRING_INDEX(USER(), _UTF8MB4'@', 1)))"},

// For table option `ENCRYPTION`
{"create table t (a int) encryption = 'n';", true, "CREATE TABLE `t` (`a` INT) ENCRYPTION = 'n'"},
Expand Down

0 comments on commit 4749504

Please sign in to comment.