diff --git a/pkg/ddl/db_integration_test.go b/pkg/ddl/db_integration_test.go index 124e602c3e062..f2d697badcd12 100644 --- a/pkg/ddl/db_integration_test.go +++ b/pkg/ddl/db_integration_test.go @@ -1606,6 +1606,65 @@ func TestDefaultColumnWithRand(t *testing.T) { tk.MustGetErrCode("CREATE TABLE t3 (c int, c1 int default a_function_not_supported_yet());", errno.ErrDefValGeneratedNamedFunctionIsNotAllowed) } +func TestDefaultColumnWithUpper(t *testing.T) { + store := testkit.CreateMockStoreWithSchemaLease(t, testLease) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t, t1, t2") + + // create table + tk.MustExec("create table t (c int(10), c1 varchar(256) default (upper(substring_index(user(),'@',1))))") + tk.MustExec("create table t1 (c int(10), c1 int default (upper(substring_index(user(),_utf8mb4'@',1))))") + tk.MustGetErrCode("create table t2 (c int(10), c1 varchar(256) default (substring_index(user(),'@',1)))", errno.ErrDefValGeneratedNamedFunctionIsNotAllowed) + tk.MustGetErrCode("create table t2 (c int(10), c1 varchar(256) default (upper(substring_index('fjks@jkkl','@',1))))", errno.ErrDefValGeneratedNamedFunctionIsNotAllowed) + tk.MustGetErrCode("create table t2 (c int(10), c1 varchar(256) default (upper(substring_index(user(),'x',1))))", errno.ErrDefValGeneratedNamedFunctionIsNotAllowed) + + // add column with default expression for table t is forbidden in MySQL 8.0 + tk.MustGetErrCode("alter table t add column c2 varchar(32) default (upper(substring_index(user(),'@',1)))", errno.ErrBinlogUnsafeSystemFunction) + tk.MustGetErrCode("alter table t add column c3 int default (upper(substring_index('fjks@jkkl','@',1)))", errno.ErrBinlogUnsafeSystemFunction) + + // insert records + tk.Session().GetSessionVars().User = &auth.UserIdentity{Username: "root", Hostname: "localhost"} + tk.MustExec("insert into t(c) values (1),(2),(3)") + tk.MustGetErrCode("insert into t1(c) values (1)", errno.ErrTruncatedWrongValue) + tk.Session().GetSessionVars().User = &auth.UserIdentity{Username: "xyz", Hostname: "localhost"} + tk.MustExec("insert into t(c) values (4),(5),(6)") + + rows := tk.MustQuery("SELECT c1 from t order by c").Rows() + for i, row := range rows { + d, ok := row[0].(string) + require.True(t, ok) + if i < 3 { + require.Equal(t, "ROOT", d) + } else { + require.Equal(t, "XYZ", d) + } + } + + tk.MustQuery("show create table t").Check(testkit.Rows( + "t CREATE TABLE `t` (\n" + + " `c` int(10) DEFAULT NULL,\n" + + " `c1` varchar(256) DEFAULT upper(substring_index(user(), _utf8mb4''@'', 1))\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin")) + tk.MustQuery("show create table t1").Check(testkit.Rows( + "t1 CREATE TABLE `t1` (\n" + + " `c` int(10) DEFAULT NULL,\n" + + " `c1` int(11) DEFAULT upper(substring_index(user(), _utf8mb4''@'', 1))\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin")) + tk.MustExec("alter table t1 modify column c1 varchar(30) default 'xx';") + tk.MustQuery("show create table t1").Check(testkit.Rows( + "t1 CREATE TABLE `t1` (\n" + + " `c` int(10) DEFAULT NULL,\n" + + " `c1` varchar(30) DEFAULT 'xx'\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin")) + tk.MustExec("alter table t1 modify column c1 varchar(32) default (upper(substring_index(user(),'@',1)));") + tk.MustQuery("show create table t1").Check(testkit.Rows( + "t1 CREATE TABLE `t1` (\n" + + " `c` int(10) DEFAULT NULL,\n" + + " `c1` varchar(32) DEFAULT upper(substring_index(user(), _utf8mb4''@'', 1))\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin")) +} + func TestChangingDBCharset(t *testing.T) { store := testkit.CreateMockStore(t, mockstore.WithDDLChecker()) diff --git a/pkg/ddl/ddl_api.go b/pkg/ddl/ddl_api.go index 31c4c8f5d1bfb..be30c3a52238a 100644 --- a/pkg/ddl/ddl_api.go +++ b/pkg/ddl/ddl_api.go @@ -1264,6 +1264,17 @@ func columnDefToCol(ctx sessionctx.Context, offset int, colDef *ast.ColumnDef, o return col, constraints, nil } +func restoreFuncCall(expr *ast.FuncCallExpr) (string, error) { + var sb strings.Builder + restoreFlags := format.RestoreStringSingleQuotes | format.RestoreKeyWordLowercase | format.RestoreNameBackQuotes | + format.RestoreSpacesAroundBinaryOperation + restoreCtx := format.NewRestoreCtx(restoreFlags, &sb) + if err := expr.Restore(restoreCtx); err != nil { + return "", err + } + return sb.String(), nil +} + // getFuncCallDefaultValue gets the default column value of function-call expression. func getFuncCallDefaultValue(col *table.Column, option *ast.ColumnOption, expr *ast.FuncCallExpr) (any, bool, error) { switch expr.FnName.L { @@ -1292,15 +1303,38 @@ func getFuncCallDefaultValue(col *table.Column, option *ast.ColumnOption, expr * if err := expression.VerifyArgsWrapper(expr.FnName.L, len(expr.Args)); err != nil { return nil, false, errors.Trace(err) } + str, err := restoreFuncCall(expr) + if err != nil { + return nil, false, errors.Trace(err) + } col.DefaultIsExpr = true - var sb strings.Builder - restoreFlags := format.RestoreStringSingleQuotes | format.RestoreKeyWordLowercase | format.RestoreNameBackQuotes | - format.RestoreSpacesAroundBinaryOperation - restoreCtx := format.NewRestoreCtx(restoreFlags, &sb) - if err := expr.Restore(restoreCtx); err != nil { - return "", false, err + return str, false, nil + case ast.Upper: + // Support UPPER(SUBSTRING_INDEX(USER(), '@', 1)). + if err := expression.VerifyArgsWrapper(expr.FnName.L, len(expr.Args)); err != nil { + return nil, false, errors.Trace(err) + } + if substringIndexFunc, ok := expr.Args[0].(*ast.FuncCallExpr); ok && substringIndexFunc.FnName.L == ast.SubstringIndex { + if err := expression.VerifyArgsWrapper(substringIndexFunc.FnName.L, len(substringIndexFunc.Args)); err != nil { + return nil, false, errors.Trace(err) + } + if userFunc, ok := substringIndexFunc.Args[0].(*ast.FuncCallExpr); ok && userFunc.FnName.L == ast.User { + if err := expression.VerifyArgsWrapper(userFunc.FnName.L, len(userFunc.Args)); err != nil { + return nil, false, errors.Trace(err) + } + valExpr, isValue := substringIndexFunc.Args[1].(ast.ValueExpr) + if !isValue || valExpr.GetString() != "@" { + return nil, false, dbterror.ErrDefValGeneratedNamedFunctionIsNotAllowed.GenWithStackByArgs(col.Name.String(), valExpr) + } + str, err := restoreFuncCall(expr) + if err != nil { + return nil, false, errors.Trace(err) + } + col.DefaultIsExpr = true + return str, false, nil + } } - return sb.String(), false, nil + return nil, false, dbterror.ErrDefValGeneratedNamedFunctionIsNotAllowed.GenWithStackByArgs(col.Name.String(), expr.FnName.String()) default: return nil, false, dbterror.ErrDefValGeneratedNamedFunctionIsNotAllowed.GenWithStackByArgs(col.Name.String(), expr.FnName.String()) } @@ -4182,7 +4216,7 @@ func CreateNewColumn(ctx sessionctx.Context, schema *model.DBInfo, spec *ast.Alt return nil, errors.Trace(err) } return nil, errors.Trace(dbterror.ErrAddColumnWithSequenceAsDefault.GenWithStackByArgs(specNewColumn.Name.Name.O)) - case ast.Rand, ast.UUID, ast.UUIDToBin: + case ast.Rand, ast.UUID, ast.UUIDToBin, ast.Upper: return nil, errors.Trace(dbterror.ErrBinlogUnsafeSystemFunction.GenWithStackByArgs()) } } diff --git a/pkg/parser/parser_test.go b/pkg/parser/parser_test.go index 4117de8b8d820..c350807a0dadf 100644 --- a/pkg/parser/parser_test.go +++ b/pkg/parser/parser_test.go @@ -2949,6 +2949,8 @@ func TestDDL(t *testing.T) { {"create table t (d date default (current_date()))", true, "CREATE TABLE `t` (`d` DATE DEFAULT CURRENT_DATE())"}, {"create table t (d date default (curdate()))", true, "CREATE TABLE `t` (`d` DATE DEFAULT CURRENT_DATE())"}, {"create table t (d date default curdate())", true, "CREATE TABLE `t` (`d` DATE DEFAULT CURRENT_DATE())"}, + {"create table t (a int default upper(substring_index(user(),'@',1)))", true, "CREATE TABLE `t` (`a` INT DEFAULT UPPER(SUBSTRING_INDEX(USER(), _UTF8MB4'@', 1)))"}, + {"create table t (a int default (upper(substring_index(user(),'@',1))))", true, "CREATE TABLE `t` (`a` INT DEFAULT UPPER(SUBSTRING_INDEX(USER(), _UTF8MB4'@', 1)))"}, // For table option `ENCRYPTION` {"create table t (a int) encryption = 'n';", true, "CREATE TABLE `t` (`a` INT) ENCRYPTION = 'n'"},