diff --git a/pkg/ddl/BUILD.bazel b/pkg/ddl/BUILD.bazel index faa5152bfb6b6..e9e14da010ae5 100644 --- a/pkg/ddl/BUILD.bazel +++ b/pkg/ddl/BUILD.bazel @@ -46,6 +46,7 @@ go_library( "job_scheduler.go", "job_submitter.go", "job_worker.go", + "metabuild.go", "mock.go", "modify_column.go", "multi_schema_change.go", @@ -99,6 +100,7 @@ go_library( "//pkg/expression/exprctx", "//pkg/expression/exprstatic", "//pkg/infoschema", + "//pkg/infoschema/context", "//pkg/kv", "//pkg/lightning/backend", "//pkg/lightning/backend/external", @@ -107,6 +109,7 @@ go_library( "//pkg/lightning/config", "//pkg/meta", "//pkg/meta/autoid", + "//pkg/meta/metabuild", "//pkg/meta/model", "//pkg/metrics", "//pkg/owner", @@ -243,6 +246,7 @@ go_test( "job_submitter_test.go", "job_worker_test.go", "main_test.go", + "metabuild_test.go", "modify_column_test.go", "multi_schema_change_test.go", "mv_index_test.go", @@ -299,6 +303,7 @@ go_test( "//pkg/lightning/backend/external", "//pkg/meta", "//pkg/meta/autoid", + "//pkg/meta/metabuild", "//pkg/meta/model", "//pkg/parser", "//pkg/parser/ast", diff --git a/pkg/ddl/add_column.go b/pkg/ddl/add_column.go index 79e4170f0097e..10cca9134953f 100644 --- a/pkg/ddl/add_column.go +++ b/pkg/ddl/add_column.go @@ -32,6 +32,7 @@ import ( "github.com/pingcap/tidb/pkg/infoschema" "github.com/pingcap/tidb/pkg/meta" "github.com/pingcap/tidb/pkg/meta/autoid" + "github.com/pingcap/tidb/pkg/meta/metabuild" "github.com/pingcap/tidb/pkg/meta/model" "github.com/pingcap/tidb/pkg/parser/ast" "github.com/pingcap/tidb/pkg/parser/charset" @@ -243,17 +244,17 @@ func CreateNewColumn(ctx sessionctx.Context, schema *model.DBInfo, spec *ast.Alt } } - tableCharset, tableCollate, err := ResolveCharsetCollation(ctx.GetSessionVars(), - ast.CharsetOpt{Chs: t.Meta().Charset, Col: t.Meta().Collate}, - ast.CharsetOpt{Chs: schema.Charset, Col: schema.Collate}, - ) + tableCharset, tableCollate, err := ResolveCharsetCollation([]ast.CharsetOpt{ + {Chs: t.Meta().Charset, Col: t.Meta().Collate}, + {Chs: schema.Charset, Col: schema.Collate}, + }, ctx.GetSessionVars().DefaultCollationForUTF8MB4) if err != nil { return nil, errors.Trace(err) } // Ignore table constraints now, they will be checked later. // We use length(t.Cols()) as the default offset firstly, we will change the column's offset later. col, _, err := buildColumnAndConstraint( - ctx, + NewMetaBuildContextWithSctx(ctx), len(t.Cols()), specNewColumn, nil, @@ -277,7 +278,7 @@ func CreateNewColumn(ctx sessionctx.Context, schema *model.DBInfo, spec *ast.Alt // outPriKeyConstraint is the primary key constraint out of column definition. For example: // `create table t1 (id int , age int, primary key(id));` func buildColumnAndConstraint( - ctx sessionctx.Context, + ctx *metabuild.Context, offset int, colDef *ast.ColumnDef, outPriKeyConstraint *ast.Constraint, @@ -289,20 +290,20 @@ func buildColumnAndConstraint( } // specifiedCollate refers to the last collate specified in colDef.Options. - chs, coll, err := getCharsetAndCollateInColumnDef(ctx.GetSessionVars(), colDef) + chs, coll, err := getCharsetAndCollateInColumnDef(colDef, ctx.GetDefaultCollationForUTF8MB4()) if err != nil { return nil, nil, errors.Trace(err) } - chs, coll, err = ResolveCharsetCollation(ctx.GetSessionVars(), - ast.CharsetOpt{Chs: chs, Col: coll}, - ast.CharsetOpt{Chs: tblCharset, Col: tblCollate}, - ) - chs, coll = OverwriteCollationWithBinaryFlag(ctx.GetSessionVars(), colDef, chs, coll) + chs, coll, err = ResolveCharsetCollation([]ast.CharsetOpt{ + {Chs: chs, Col: coll}, + {Chs: tblCharset, Col: tblCollate}, + }, ctx.GetDefaultCollationForUTF8MB4()) + chs, coll = OverwriteCollationWithBinaryFlag(colDef, chs, coll, ctx.GetDefaultCollationForUTF8MB4()) if err != nil { return nil, nil, errors.Trace(err) } - if err := setCharsetCollationFlenDecimal(colDef.Tp, colDef.Name.Name.O, chs, coll, ctx.GetSessionVars()); err != nil { + if err := setCharsetCollationFlenDecimal(ctx, colDef.Tp, colDef.Name.Name.O, chs, coll); err != nil { return nil, nil, errors.Trace(err) } decodeEnumSetBinaryLiteralToUTF8(colDef.Tp, chs) @@ -315,11 +316,11 @@ func buildColumnAndConstraint( // getCharsetAndCollateInColumnDef will iterate collate in the options, validate it by checking the charset // of column definition. If there's no collate in the option, the default collate of column's charset will be used. -func getCharsetAndCollateInColumnDef(sessVars *variable.SessionVars, def *ast.ColumnDef) (chs, coll string, err error) { +func getCharsetAndCollateInColumnDef(def *ast.ColumnDef, defaultUTF8MB4Coll string) (chs, coll string, err error) { chs = def.Tp.GetCharset() coll = def.Tp.GetCollate() if chs != "" && coll == "" { - if coll, err = GetDefaultCollation(sessVars, chs); err != nil { + if coll, err = GetDefaultCollation(chs, defaultUTF8MB4Coll); err != nil { return "", "", errors.Trace(err) } } @@ -345,14 +346,14 @@ func getCharsetAndCollateInColumnDef(sessVars *variable.SessionVars, def *ast.Co // CREATE TABLE t (a VARCHAR(255) BINARY) CHARSET utf8 COLLATE utf8_general_ci; // // The 'BINARY' sets the column collation to *_bin according to the table charset. -func OverwriteCollationWithBinaryFlag(sessVars *variable.SessionVars, colDef *ast.ColumnDef, chs, coll string) (newChs string, newColl string) { +func OverwriteCollationWithBinaryFlag(colDef *ast.ColumnDef, chs, coll string, defaultUTF8MB4Coll string) (newChs string, newColl string) { ignoreBinFlag := colDef.Tp.GetCharset() != "" && (colDef.Tp.GetCollate() != "" || containsColumnOption(colDef, ast.ColumnOptionCollate)) if ignoreBinFlag { return chs, coll } needOverwriteBinColl := types.IsString(colDef.Tp.GetType()) && mysql.HasBinaryFlag(colDef.Tp.GetFlag()) if needOverwriteBinColl { - newColl, err := GetDefaultCollation(sessVars, chs) + newColl, err := GetDefaultCollation(chs, defaultUTF8MB4Coll) if err != nil { return chs, coll } @@ -361,7 +362,7 @@ func OverwriteCollationWithBinaryFlag(sessVars *variable.SessionVars, colDef *as return chs, coll } -func setCharsetCollationFlenDecimal(tp *types.FieldType, colName, colCharset, colCollate string, sessVars *variable.SessionVars) error { +func setCharsetCollationFlenDecimal(ctx *metabuild.Context, tp *types.FieldType, colName, colCharset, colCollate string) error { var err error if typesNeedCharset(tp.GetType()) { tp.SetCharset(colCharset) @@ -389,7 +390,7 @@ func setCharsetCollationFlenDecimal(tp *types.FieldType, colName, colCharset, co return err } } - return checkTooBigFieldLengthAndTryAutoConvert(tp, colName, sessVars) + return checkTooBigFieldLengthAndTryAutoConvert(ctx, tp, colName) } func decodeEnumSetBinaryLiteralToUTF8(tp *types.FieldType, chs string) { @@ -422,8 +423,8 @@ func typesNeedCharset(tp byte) bool { // checkTooBigFieldLengthAndTryAutoConvert will check whether the field length is too big // in non-strict mode and varchar column. If it is, will try to adjust to blob or text, see issue #30328 -func checkTooBigFieldLengthAndTryAutoConvert(tp *types.FieldType, colName string, sessVars *variable.SessionVars) error { - if sessVars != nil && !sessVars.SQLMode.HasStrictMode() && tp.GetType() == mysql.TypeVarchar { +func checkTooBigFieldLengthAndTryAutoConvert(ctx *metabuild.Context, tp *types.FieldType, colName string) error { + if !ctx.GetSQLMode().HasStrictMode() && tp.GetType() == mysql.TypeVarchar { err := types.IsVarcharTooBigFieldLength(tp.GetFlen(), colName, tp.GetCharset()) if err != nil && terror.ErrorEqual(types.ErrTooBigFieldLength, err) { tp.SetType(mysql.TypeBlob) @@ -431,9 +432,9 @@ func checkTooBigFieldLengthAndTryAutoConvert(tp *types.FieldType, colName string return err } if tp.GetCharset() == charset.CharsetBin { - sessVars.StmtCtx.AppendWarning(dbterror.ErrAutoConvert.FastGenByArgs(colName, "VARBINARY", "BLOB")) + ctx.AppendWarning(dbterror.ErrAutoConvert.FastGenByArgs(colName, "VARBINARY", "BLOB")) } else { - sessVars.StmtCtx.AppendWarning(dbterror.ErrAutoConvert.FastGenByArgs(colName, "VARCHAR", "TEXT")) + ctx.AppendWarning(dbterror.ErrAutoConvert.FastGenByArgs(colName, "VARCHAR", "TEXT")) } } } @@ -442,7 +443,7 @@ func checkTooBigFieldLengthAndTryAutoConvert(tp *types.FieldType, colName string // columnDefToCol converts ColumnDef to Col and TableConstraints. // outPriKeyConstraint is the primary key constraint out of column definition. such as: create table t1 (id int , age int, primary key(id)); -func columnDefToCol(ctx sessionctx.Context, offset int, colDef *ast.ColumnDef, outPriKeyConstraint *ast.Constraint) (*table.Column, []*ast.Constraint, error) { +func columnDefToCol(ctx *metabuild.Context, offset int, colDef *ast.ColumnDef, outPriKeyConstraint *ast.Constraint) (*table.Column, []*ast.Constraint, error) { var constraints = make([]*ast.Constraint, 0) col := table.ToColumn(&model.ColumnInfo{ Offset: offset, @@ -511,7 +512,7 @@ func columnDefToCol(ctx sessionctx.Context, offset int, colDef *ast.ColumnDef, o col.AddFlag(mysql.UniqueKeyFlag) } case ast.ColumnOptionDefaultValue: - hasDefaultValue, err = SetDefaultValue(ctx, col, v) + hasDefaultValue, err = SetDefaultValue(ctx.GetExprCtx(), col, v) if err != nil { return nil, nil, errors.Trace(err) } @@ -527,7 +528,7 @@ func columnDefToCol(ctx sessionctx.Context, offset int, colDef *ast.ColumnDef, o col.AddFlag(mysql.OnUpdateNowFlag) setOnUpdateNow = true case ast.ColumnOptionComment: - err := setColumnComment(ctx, col, v) + err := setColumnComment(ctx.GetExprCtx(), col, v) if err != nil { return nil, nil, errors.Trace(err) } @@ -549,10 +550,10 @@ func columnDefToCol(ctx sessionctx.Context, offset int, colDef *ast.ColumnDef, o col.FieldType.SetCollate(v.StrValue) } case ast.ColumnOptionFulltext: - ctx.GetSessionVars().StmtCtx.AppendWarning(dbterror.ErrTableCantHandleFt.FastGenByArgs()) + ctx.AppendWarning(dbterror.ErrTableCantHandleFt.FastGenByArgs()) case ast.ColumnOptionCheck: if !variable.EnableCheckConstraint.Load() { - ctx.GetSessionVars().StmtCtx.AppendWarning(errCheckConstraintIsOff) + ctx.AppendWarning(errCheckConstraintIsOff) } else { // Check the column CHECK constraint dependency lazily, after fill all the name. // Extract column constraint from column option. @@ -570,7 +571,7 @@ func columnDefToCol(ctx sessionctx.Context, offset int, colDef *ast.ColumnDef, o } } - if err = processAndCheckDefaultValueAndColumn(ctx, col, outPriKeyConstraint, hasDefaultValue, setOnUpdateNow, hasNullFlag); err != nil { + if err = processAndCheckDefaultValueAndColumn(ctx.GetExprCtx(), col, outPriKeyConstraint, hasDefaultValue, setOnUpdateNow, hasNullFlag); err != nil { return nil, nil, errors.Trace(err) } return col, constraints, nil @@ -585,11 +586,11 @@ func isExplicitTimeStamp() bool { } // SetDefaultValue sets the default value of the column. -func SetDefaultValue(ctx sessionctx.Context, col *table.Column, option *ast.ColumnOption) (hasDefaultValue bool, err error) { +func SetDefaultValue(ctx expression.BuildContext, col *table.Column, option *ast.ColumnOption) (hasDefaultValue bool, err error) { var value any var isSeqExpr bool value, isSeqExpr, err = getDefaultValue( - exprctx.CtxWithHandleTruncateErrLevel(ctx.GetExprCtx(), errctx.LevelError), + exprctx.CtxWithHandleTruncateErrLevel(ctx, errctx.LevelError), col, option, ) if err != nil { @@ -604,10 +605,10 @@ func SetDefaultValue(ctx sessionctx.Context, col *table.Column, option *ast.Colu // When the default value is expression, we skip check and convert. if !col.DefaultIsExpr { - if hasDefaultValue, value, err = checkColumnDefaultValue(ctx.GetExprCtx(), col, value); err != nil { + if hasDefaultValue, value, err = checkColumnDefaultValue(ctx, col, value); err != nil { return hasDefaultValue, errors.Trace(err) } - value, err = convertTimestampDefaultValToUTC(ctx, value, col) + value, err = convertTimestampDefaultValToUTC(ctx.GetEvalCtx().TypeCtx(), value, col) if err != nil { return hasDefaultValue, errors.Trace(err) } @@ -895,8 +896,8 @@ func setDefaultValueWithBinaryPadding(col *table.Column, value any) error { return nil } -func setColumnComment(ctx sessionctx.Context, col *table.Column, option *ast.ColumnOption) error { - value, err := expression.EvalSimpleAst(ctx.GetExprCtx(), option.Expr) +func setColumnComment(ctx expression.BuildContext, col *table.Column, option *ast.ColumnOption) error { + value, err := expression.EvalSimpleAst(ctx, option.Expr) if err != nil { return errors.Trace(err) } @@ -904,12 +905,12 @@ func setColumnComment(ctx sessionctx.Context, col *table.Column, option *ast.Col return errors.Trace(err) } - sessionVars := ctx.GetSessionVars() - col.Comment, err = validateCommentLength(sessionVars.StmtCtx.ErrCtx(), sessionVars.SQLMode, col.Name.L, &col.Comment, dbterror.ErrTooLongFieldComment) + evalCtx := ctx.GetEvalCtx() + col.Comment, err = validateCommentLength(evalCtx.ErrCtx(), evalCtx.SQLMode(), col.Name.L, &col.Comment, dbterror.ErrTooLongFieldComment) return errors.Trace(err) } -func processAndCheckDefaultValueAndColumn(ctx sessionctx.Context, col *table.Column, +func processAndCheckDefaultValueAndColumn(ctx expression.BuildContext, col *table.Column, outPriKeyConstraint *ast.Constraint, hasDefaultValue, setOnUpdateNow, hasNullFlag bool) error { processDefaultValue(col, hasDefaultValue, setOnUpdateNow) processColumnFlags(col) @@ -921,7 +922,7 @@ func processAndCheckDefaultValueAndColumn(ctx sessionctx.Context, col *table.Col if err = checkColumnValueConstraint(col, col.GetCollate()); err != nil { return errors.Trace(err) } - if err = checkDefaultValue(ctx.GetExprCtx(), col, hasDefaultValue); err != nil { + if err = checkDefaultValue(ctx, col, hasDefaultValue); err != nil { return errors.Trace(err) } if err = checkColumnFieldLength(col); err != nil { @@ -1215,17 +1216,17 @@ func checkSequenceDefaultValue(col *table.Column) error { return dbterror.ErrColumnTypeUnsupportedNextValue.GenWithStackByArgs(col.ColumnInfo.Name.O) } -func convertTimestampDefaultValToUTC(ctx sessionctx.Context, defaultVal any, col *table.Column) (any, error) { +func convertTimestampDefaultValToUTC(tc types.Context, defaultVal any, col *table.Column) (any, error) { if defaultVal == nil || col.GetType() != mysql.TypeTimestamp { return defaultVal, nil } if vv, ok := defaultVal.(string); ok { if vv != types.ZeroDatetimeStr && !strings.EqualFold(vv, ast.CurrentTimestamp) { - t, err := types.ParseTime(ctx.GetSessionVars().StmtCtx.TypeCtx(), vv, col.GetType(), col.GetDecimal()) + t, err := types.ParseTime(tc, vv, col.GetType(), col.GetDecimal()) if err != nil { return defaultVal, errors.Trace(err) } - err = t.ConvertTimeZone(ctx.GetSessionVars().Location(), time.UTC) + err = t.ConvertTimeZone(tc.Location(), time.UTC) if err != nil { return defaultVal, errors.Trace(err) } diff --git a/pkg/ddl/create_table.go b/pkg/ddl/create_table.go index c566496dd23f4..42a01fb41e121 100644 --- a/pkg/ddl/create_table.go +++ b/pkg/ddl/create_table.go @@ -29,17 +29,19 @@ import ( "github.com/pingcap/tidb/pkg/ddl/notifier" "github.com/pingcap/tidb/pkg/ddl/placement" "github.com/pingcap/tidb/pkg/domain/infosync" + "github.com/pingcap/tidb/pkg/errctx" "github.com/pingcap/tidb/pkg/expression" "github.com/pingcap/tidb/pkg/infoschema" + infoschemactx "github.com/pingcap/tidb/pkg/infoschema/context" "github.com/pingcap/tidb/pkg/meta" "github.com/pingcap/tidb/pkg/meta/autoid" + "github.com/pingcap/tidb/pkg/meta/metabuild" "github.com/pingcap/tidb/pkg/meta/model" "github.com/pingcap/tidb/pkg/parser/ast" "github.com/pingcap/tidb/pkg/parser/format" pmodel "github.com/pingcap/tidb/pkg/parser/model" "github.com/pingcap/tidb/pkg/parser/mysql" field_types "github.com/pingcap/tidb/pkg/parser/types" - "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tidb/pkg/sessionctx/variable" "github.com/pingcap/tidb/pkg/table" "github.com/pingcap/tidb/pkg/table/tables" @@ -389,12 +391,12 @@ func findTableIDFromStore(t *meta.Meta, schemaID int64, tableName string) (int64 // BuildTableInfoFromAST builds model.TableInfo from a SQL statement. // Note: TableID and PartitionID are left as uninitialized value. func BuildTableInfoFromAST(s *ast.CreateTableStmt) (*model.TableInfo, error) { - return buildTableInfoWithCheck(mock.NewContext(), s, mysql.DefaultCharset, "", nil) + return buildTableInfoWithCheck(NewMetaBuildContextWithSctx(mock.NewContext()), s, mysql.DefaultCharset, "", nil) } // buildTableInfoWithCheck builds model.TableInfo from a SQL statement. // Note: TableID and PartitionIDs are left as uninitialized value. -func buildTableInfoWithCheck(ctx sessionctx.Context, s *ast.CreateTableStmt, dbCharset, dbCollate string, placementPolicyRef *model.PolicyRefInfo) (*model.TableInfo, error) { +func buildTableInfoWithCheck(ctx *metabuild.Context, s *ast.CreateTableStmt, dbCharset, dbCollate string, placementPolicyRef *model.PolicyRefInfo) (*model.TableInfo, error) { tbInfo, err := BuildTableInfoWithStmt(ctx, s, dbCharset, dbCollate, placementPolicyRef) if err != nil { return nil, err @@ -405,18 +407,18 @@ func buildTableInfoWithCheck(ctx sessionctx.Context, s *ast.CreateTableStmt, dbC if err = checkTableInfoValidWithStmt(ctx, tbInfo, s); err != nil { return nil, err } - if err = checkTableInfoValidExtra(ctx, tbInfo); err != nil { + if err = checkTableInfoValidExtra(ctx.GetExprCtx().GetEvalCtx().ErrCtx(), tbInfo); err != nil { return nil, err } return tbInfo, nil } // CheckTableInfoValidWithStmt exposes checkTableInfoValidWithStmt to SchemaTracker. Maybe one day we can delete it. -func CheckTableInfoValidWithStmt(ctx sessionctx.Context, tbInfo *model.TableInfo, s *ast.CreateTableStmt) (err error) { +func CheckTableInfoValidWithStmt(ctx *metabuild.Context, tbInfo *model.TableInfo, s *ast.CreateTableStmt) (err error) { return checkTableInfoValidWithStmt(ctx, tbInfo, s) } -func checkTableInfoValidWithStmt(ctx sessionctx.Context, tbInfo *model.TableInfo, s *ast.CreateTableStmt) (err error) { +func checkTableInfoValidWithStmt(ctx *metabuild.Context, tbInfo *model.TableInfo, s *ast.CreateTableStmt) (err error) { // All of these rely on the AST structure of expressions, which were // lost in the model (got serialized into strings). if err := checkGeneratedColumn(ctx, s.Table.Schema, tbInfo.Name, s.Cols); err != nil { @@ -424,15 +426,15 @@ func checkTableInfoValidWithStmt(ctx sessionctx.Context, tbInfo *model.TableInfo } // Check if table has a primary key if required. - if !ctx.GetSessionVars().InRestrictedSQL && ctx.GetSessionVars().PrimaryKeyRequired && len(tbInfo.GetPkName().String()) == 0 { + if ctx.PrimaryKeyRequired() && len(tbInfo.GetPkName().String()) == 0 { return infoschema.ErrTableWithoutPrimaryKey } if tbInfo.Partition != nil { - if err := checkPartitionDefinitionConstraints(ctx, tbInfo); err != nil { + if err := checkPartitionDefinitionConstraints(ctx.GetExprCtx(), tbInfo); err != nil { return errors.Trace(err) } if s.Partition != nil { - if err := checkPartitionFuncType(ctx, s.Partition.Expr, s.Table.Schema.O, tbInfo); err != nil { + if err := checkPartitionFuncType(ctx.GetExprCtx(), s.Partition.Expr, s.Table.Schema.O, tbInfo); err != nil { return errors.Trace(err) } if err := checkPartitioningKeysConstraints(ctx, s, tbInfo); err != nil { @@ -441,15 +443,19 @@ func checkTableInfoValidWithStmt(ctx sessionctx.Context, tbInfo *model.TableInfo } } if tbInfo.TTLInfo != nil { - if err := checkTTLInfoValid(ctx, s.Table.Schema, tbInfo); err != nil { - return errors.Trace(err) + var foreignKeyCheckIs infoschemactx.MetaOnlyInfoSchema + if is, ok := ctx.GetInfoSchema(); ok { + foreignKeyCheckIs = is + } + if err = checkTTLInfoValid(s.Table.Schema, tbInfo, foreignKeyCheckIs); err != nil { + return err } } return nil } -func checkGeneratedColumn(ctx sessionctx.Context, schemaName pmodel.CIStr, tableName pmodel.CIStr, colDefs []*ast.ColumnDef) error { +func checkGeneratedColumn(ctx *metabuild.Context, schemaName pmodel.CIStr, tableName pmodel.CIStr, colDefs []*ast.ColumnDef) error { var colName2Generation = make(map[string]columnGenerationInDDL, len(colDefs)) var exists bool var autoIncrementColumn string @@ -484,7 +490,7 @@ func checkGeneratedColumn(ctx sessionctx.Context, schemaName pmodel.CIStr, table // Check whether the generated column refers to any auto-increment columns if exists { - if !ctx.GetSessionVars().EnableAutoIncrementInGenerated { + if !ctx.EnableAutoIncrementInGenerated() { for colName, generated := range colName2Generation { if _, found := generated.dependences[autoIncrementColumn]; found { return dbterror.ErrGeneratedColumnRefAutoInc.GenWithStackByArgs(colName) @@ -507,7 +513,7 @@ func checkGeneratedColumn(ctx sessionctx.Context, schemaName pmodel.CIStr, table // name length and column count. // (checkTableInfoValid is also used in repairing objects which don't perform // these checks. Perhaps the two functions should be merged together regardless?) -func checkTableInfoValidExtra(ctx sessionctx.Context, tbInfo *model.TableInfo) error { +func checkTableInfoValidExtra(ec errctx.Context, tbInfo *model.TableInfo) error { if err := checkTooLongTable(tbInfo.Name); err != nil { return err } @@ -527,7 +533,7 @@ func checkTableInfoValidExtra(ctx sessionctx.Context, tbInfo *model.TableInfo) e if err := checkColumnsAttributes(tbInfo.Columns); err != nil { return errors.Trace(err) } - if err := checkGlobalIndexes(ctx, tbInfo); err != nil { + if err := checkGlobalIndexes(ec, tbInfo); err != nil { return errors.Trace(err) } @@ -613,7 +619,7 @@ func checkColumnAttributes(colName string, tp *types.FieldType) error { } // BuildSessionTemporaryTableInfo builds model.TableInfo from a SQL statement. -func BuildSessionTemporaryTableInfo(ctx sessionctx.Context, is infoschema.InfoSchema, s *ast.CreateTableStmt, dbCharset, dbCollate string, placementPolicyRef *model.PolicyRefInfo) (*model.TableInfo, error) { +func BuildSessionTemporaryTableInfo(ctx *metabuild.Context, is infoschema.InfoSchema, s *ast.CreateTableStmt, dbCharset, dbCollate string, placementPolicyRef *model.PolicyRefInfo) (*model.TableInfo, error) { ident := ast.Ident{Schema: s.Table.Schema, Name: s.Table.Name} //build tableInfo var tbInfo *model.TableInfo @@ -629,7 +635,7 @@ func BuildSessionTemporaryTableInfo(ctx sessionctx.Context, is infoschema.InfoSc if err != nil { return nil, infoschema.ErrTableNotExists.GenWithStackByArgs(referIdent.Schema, referIdent.Name) } - tbInfo, err = BuildTableInfoWithLike(ctx, ident, referTbl.Meta(), s) + tbInfo, err = BuildTableInfoWithLike(ident, referTbl.Meta(), s) } else { tbInfo, err = buildTableInfoWithCheck(ctx, s, dbCharset, dbCollate, placementPolicyRef) } @@ -637,16 +643,16 @@ func BuildSessionTemporaryTableInfo(ctx sessionctx.Context, is infoschema.InfoSc } // BuildTableInfoWithStmt builds model.TableInfo from a SQL statement without validity check -func BuildTableInfoWithStmt(ctx sessionctx.Context, s *ast.CreateTableStmt, dbCharset, dbCollate string, placementPolicyRef *model.PolicyRefInfo) (*model.TableInfo, error) { +func BuildTableInfoWithStmt(ctx *metabuild.Context, s *ast.CreateTableStmt, dbCharset, dbCollate string, placementPolicyRef *model.PolicyRefInfo) (*model.TableInfo, error) { colDefs := s.Cols - tableCharset, tableCollate, err := GetCharsetAndCollateInTableOption(ctx.GetSessionVars(), 0, s.Options) + tableCharset, tableCollate, err := GetCharsetAndCollateInTableOption(0, s.Options, ctx.GetDefaultCollationForUTF8MB4()) if err != nil { return nil, errors.Trace(err) } - tableCharset, tableCollate, err = ResolveCharsetCollation(ctx.GetSessionVars(), - ast.CharsetOpt{Chs: tableCharset, Col: tableCollate}, - ast.CharsetOpt{Chs: dbCharset, Col: dbCollate}, - ) + tableCharset, tableCollate, err = ResolveCharsetCollation([]ast.CharsetOpt{ + {Chs: tableCharset, Col: tableCollate}, + {Chs: dbCharset, Col: dbCollate}, + }, ctx.GetDefaultCollationForUTF8MB4()) if err != nil { return nil, errors.Trace(err) } @@ -666,7 +672,7 @@ func BuildTableInfoWithStmt(ctx sessionctx.Context, s *ast.CreateTableStmt, dbCh if err != nil { return nil, errors.Trace(err) } - if err = setTemporaryType(ctx, tbInfo, s); err != nil { + if err = setTemporaryType(tbInfo, s); err != nil { return nil, errors.Trace(err) } @@ -676,16 +682,15 @@ func BuildTableInfoWithStmt(ctx sessionctx.Context, s *ast.CreateTableStmt, dbCh // set default shard row id bits and pre-split regions for table. if !tbInfo.HasClusteredIndex() && tbInfo.TempTableType == model.TempTableNone { - tbInfo.ShardRowIDBits = ctx.GetSessionVars().ShardRowIDBits - tbInfo.PreSplitRegions = ctx.GetSessionVars().PreSplitRegions + tbInfo.ShardRowIDBits = ctx.GetShardRowIDBits() + tbInfo.PreSplitRegions = ctx.GetPreSplitRegions() } if err = handleTableOptions(s.Options, tbInfo); err != nil { return nil, errors.Trace(err) } - sessionVars := ctx.GetSessionVars() - if _, err = validateCommentLength(sessionVars.StmtCtx.ErrCtx(), sessionVars.SQLMode, tbInfo.Name.L, &tbInfo.Comment, dbterror.ErrTooLongTableComment); err != nil { + if _, err = validateCommentLength(ctx.GetExprCtx().GetEvalCtx().ErrCtx(), ctx.GetSQLMode(), tbInfo.Name.L, &tbInfo.Comment, dbterror.ErrTooLongTableComment); err != nil { return nil, errors.Trace(err) } @@ -703,7 +708,7 @@ func BuildTableInfoWithStmt(ctx sessionctx.Context, s *ast.CreateTableStmt, dbCh return tbInfo, nil } -func setTableAutoRandomBits(ctx sessionctx.Context, tbInfo *model.TableInfo, colDefs []*ast.ColumnDef) error { +func setTableAutoRandomBits(ctx *metabuild.Context, tbInfo *model.TableInfo, colDefs []*ast.ColumnDef) error { for _, col := range colDefs { if containsColumnOption(col, ast.ColumnOptionAutoRandom) { if col.Tp.GetType() != mysql.TypeLonglong { @@ -748,7 +753,7 @@ func setTableAutoRandomBits(ctx sessionctx.Context, tbInfo *model.TableInfo, col return dbterror.ErrInvalidAutoRandom.FastGenByArgs(autoid.AutoRandomIncrementalBitsTooSmall) } msg := fmt.Sprintf(autoid.AutoRandomAvailableAllocTimesNote, shardFmt.IncrementalBitsCapacity()) - ctx.GetSessionVars().StmtCtx.AppendNote(errors.NewNoStackError(msg)) + ctx.AppendNote(errors.NewNoStackError(msg)) } } return nil @@ -851,7 +856,7 @@ func handleTableOptions(options []*ast.TableOption, tbInfo *model.TableInfo) err return nil } -func setTemporaryType(_ sessionctx.Context, tbInfo *model.TableInfo, s *ast.CreateTableStmt) error { +func setTemporaryType(tbInfo *model.TableInfo, s *ast.CreateTableStmt) error { switch s.TemporaryKeyword { case ast.TemporaryGlobal: tbInfo.TempTableType = model.TempTableGlobal @@ -868,7 +873,7 @@ func setTemporaryType(_ sessionctx.Context, tbInfo *model.TableInfo, s *ast.Crea } func buildColumnsAndConstraints( - ctx sessionctx.Context, + ctx *metabuild.Context, colDefs []*ast.ColumnDef, constraints []*ast.Constraint, tblCharset string, @@ -891,13 +896,13 @@ func buildColumnsAndConstraints( case mysql.TypeTiny: // No warning for BOOL-like tinyint(1) if colDef.Tp.GetFlen() != types.UnspecifiedLength && colDef.Tp.GetFlen() != 1 { - ctx.GetSessionVars().StmtCtx.AppendWarning( + ctx.AppendWarning( dbterror.ErrWarnDeprecatedIntegerDisplayWidth.FastGenByArgs(), ) } case mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong: if colDef.Tp.GetFlen() != types.UnspecifiedLength { - ctx.GetSessionVars().StmtCtx.AppendWarning( + ctx.AppendWarning( dbterror.ErrWarnDeprecatedIntegerDisplayWidth.FastGenByArgs(), ) } @@ -909,7 +914,7 @@ func buildColumnsAndConstraints( } col.State = model.StatePublic if mysql.HasZerofillFlag(col.GetFlag()) { - ctx.GetSessionVars().StmtCtx.AppendWarning( + ctx.AppendWarning( dbterror.ErrWarnDeprecatedZerofill.FastGenByArgs(), ) } @@ -1081,13 +1086,13 @@ func setColumnFlagWithConstraint(colMap map[string]*table.Column, v *ast.Constra } // BuildTableInfoWithLike builds a new table info according to CREATE TABLE ... LIKE statement. -func BuildTableInfoWithLike(ctx sessionctx.Context, ident ast.Ident, referTblInfo *model.TableInfo, s *ast.CreateTableStmt) (*model.TableInfo, error) { +func BuildTableInfoWithLike(ident ast.Ident, referTblInfo *model.TableInfo, s *ast.CreateTableStmt) (*model.TableInfo, error) { // Check the referred table is a real table object. if referTblInfo.IsSequence() || referTblInfo.IsView() { return nil, dbterror.ErrWrongObject.GenWithStackByArgs(ident.Schema, referTblInfo.Name, "BASE TABLE") } tblInfo := *referTblInfo - if err := setTemporaryType(ctx, &tblInfo, s); err != nil { + if err := setTemporaryType(&tblInfo, s); err != nil { return nil, errors.Trace(err) } // Check non-public column and adjust column offset. @@ -1138,7 +1143,7 @@ func renameCheckConstraint(tblInfo *model.TableInfo) { // BuildTableInfo creates a TableInfo. func BuildTableInfo( - ctx sessionctx.Context, + ctx *metabuild.Context, tableName pmodel.CIStr, cols []*table.Column, constraints []*ast.Constraint, @@ -1204,7 +1209,7 @@ func BuildTableInfo( return nil, err } isSingleIntPK := isSingleIntPK(constr, lastCol) - if ShouldBuildClusteredIndex(ctx, constr.Option, isSingleIntPK) { + if ShouldBuildClusteredIndex(ctx.GetClusteredIndexDefMode(), constr.Option, isSingleIntPK) { if isSingleIntPK { tbInfo.PKIsHandle = true } else { @@ -1224,7 +1229,7 @@ func BuildTableInfo( } if constr.Tp == ast.ConstraintFulltext { - ctx.GetSessionVars().StmtCtx.AppendWarning(dbterror.ErrTableCantHandleFt.FastGenByArgs()) + ctx.AppendWarning(dbterror.ErrTableCantHandleFt.FastGenByArgs()) continue } @@ -1246,7 +1251,7 @@ func BuildTableInfo( // check constraint if constr.Tp == ast.ConstraintCheck { if !variable.EnableCheckConstraint.Load() { - ctx.GetSessionVars().StmtCtx.AppendWarning(errCheckConstraintIsOff) + ctx.AppendWarning(errCheckConstraintIsOff) continue } // Since column check constraint dependency has been done in columnDefToCol. @@ -1324,8 +1329,7 @@ func BuildTableInfo( if len(hiddenCols) > 0 { AddIndexColumnFlag(tbInfo, idxInfo) } - sessionVars := ctx.GetSessionVars() - _, err = validateCommentLength(sessionVars.StmtCtx.ErrCtx(), sessionVars.SQLMode, idxInfo.Name.String(), &idxInfo.Comment, dbterror.ErrTooLongIndexComment) + _, err = validateCommentLength(ctx.GetExprCtx().GetEvalCtx().ErrCtx(), ctx.GetSQLMode(), idxInfo.Name.String(), &idxInfo.Comment, dbterror.ErrTooLongIndexComment) if err != nil { return nil, errors.Trace(err) } @@ -1358,7 +1362,7 @@ func precheckBuildHiddenColumnInfo( return nil } -func buildHiddenColumnInfoWithCheck(ctx sessionctx.Context, indexPartSpecifications []*ast.IndexPartSpecification, indexName pmodel.CIStr, tblInfo *model.TableInfo, existCols []*table.Column) ([]*model.ColumnInfo, error) { +func buildHiddenColumnInfoWithCheck(ctx *metabuild.Context, indexPartSpecifications []*ast.IndexPartSpecification, indexName pmodel.CIStr, tblInfo *model.TableInfo, existCols []*table.Column) ([]*model.ColumnInfo, error) { if err := precheckBuildHiddenColumnInfo(indexPartSpecifications, indexName); err != nil { return nil, err } @@ -1366,7 +1370,7 @@ func buildHiddenColumnInfoWithCheck(ctx sessionctx.Context, indexPartSpecificati } // BuildHiddenColumnInfo builds hidden column info. -func BuildHiddenColumnInfo(ctx sessionctx.Context, indexPartSpecifications []*ast.IndexPartSpecification, indexName pmodel.CIStr, tblInfo *model.TableInfo, existCols []*table.Column) ([]*model.ColumnInfo, error) { +func BuildHiddenColumnInfo(ctx *metabuild.Context, indexPartSpecifications []*ast.IndexPartSpecification, indexName pmodel.CIStr, tblInfo *model.TableInfo, existCols []*table.Column) ([]*model.ColumnInfo, error) { hiddenCols := make([]*model.ColumnInfo, 0, len(indexPartSpecifications)) for i, idxPart := range indexPartSpecifications { if idxPart.Expr == nil { @@ -1391,8 +1395,10 @@ func BuildHiddenColumnInfo(ctx sessionctx.Context, indexPartSpecifications []*as if err != nil { return nil, errors.Trace(err) } - expr, err := expression.BuildSimpleExpr(ctx.GetExprCtx(), idxPart.Expr, - expression.WithTableInfo(ctx.GetSessionVars().CurrentDB, tblInfo), + + exprCtx := ctx.GetExprCtx() + expr, err := expression.BuildSimpleExpr(exprCtx, idxPart.Expr, + expression.WithTableInfo(exprCtx.GetEvalCtx().CurrentDB(), tblInfo), expression.WithAllowCastArray(true), ) if err != nil { @@ -1434,7 +1440,7 @@ func BuildHiddenColumnInfo(ctx sessionctx.Context, indexPartSpecifications []*as if err = checkDependedColExist(checkDependencies, existCols); err != nil { return nil, errors.Trace(err) } - if !ctx.GetSessionVars().EnableAutoIncrementInGenerated { + if !ctx.EnableAutoIncrementInGenerated() { if err = checkExpressionIndexAutoIncrement(indexName.O, colInfo.Dependences, tblInfo); err != nil { return nil, errors.Trace(err) } @@ -1447,7 +1453,7 @@ func BuildHiddenColumnInfo(ctx sessionctx.Context, indexPartSpecifications []*as // addIndexForForeignKey uses to auto create an index for the foreign key if the table doesn't have any index cover the // foreign key columns. -func addIndexForForeignKey(ctx sessionctx.Context, tbInfo *model.TableInfo) error { +func addIndexForForeignKey(ctx *metabuild.Context, tbInfo *model.TableInfo) error { if len(tbInfo.ForeignKeys) == 0 { return nil } @@ -1499,9 +1505,9 @@ func isSingleIntPK(constr *ast.Constraint, lastCol *model.ColumnInfo) bool { } // ShouldBuildClusteredIndex is used to determine whether the CREATE TABLE statement should build a clustered index table. -func ShouldBuildClusteredIndex(ctx sessionctx.Context, opt *ast.IndexOption, isSingleIntPK bool) bool { +func ShouldBuildClusteredIndex(mode variable.ClusteredIndexDefMode, opt *ast.IndexOption, isSingleIntPK bool) bool { if opt == nil || opt.PrimaryKeyTp == pmodel.PrimaryKeyTypeDefault { - switch ctx.GetSessionVars().EnableClusteredIndex { + switch mode { case variable.ClusteredIndexDefModeOn: return true case variable.ClusteredIndexDefModeIntOnly: diff --git a/pkg/ddl/ddl_test.go b/pkg/ddl/ddl_test.go index f11b486b58791..6b2ecec8124ff 100644 --- a/pkg/ddl/ddl_test.go +++ b/pkg/ddl/ddl_test.go @@ -25,6 +25,7 @@ import ( "github.com/pingcap/tidb/pkg/infoschema" "github.com/pingcap/tidb/pkg/kv" "github.com/pingcap/tidb/pkg/meta" + "github.com/pingcap/tidb/pkg/meta/metabuild" "github.com/pingcap/tidb/pkg/meta/model" "github.com/pingcap/tidb/pkg/parser" "github.com/pingcap/tidb/pkg/parser/ast" @@ -32,7 +33,6 @@ import ( pmodel "github.com/pingcap/tidb/pkg/parser/model" "github.com/pingcap/tidb/pkg/parser/mysql" "github.com/pingcap/tidb/pkg/parser/terror" - "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tidb/pkg/store/mockstore" "github.com/pingcap/tidb/pkg/tablecodec" "github.com/pingcap/tidb/pkg/testkit/testfailpoint" @@ -115,7 +115,7 @@ func TestGetIntervalFromPolicy(t *testing.T) { require.False(t, changed) } -func colDefStrToFieldType(t *testing.T, str string, ctx sessionctx.Context) *types.FieldType { +func colDefStrToFieldType(t *testing.T, str string, ctx *metabuild.Context) *types.FieldType { sqlA := "alter table t modify column a " + str stmt, err := parser.New().ParseOneStmt(sqlA, "", "") require.NoError(t, err) @@ -127,7 +127,7 @@ func colDefStrToFieldType(t *testing.T, str string, ctx sessionctx.Context) *typ } func TestModifyColumn(t *testing.T) { - ctx := mock.NewContext() + ctx := NewMetaBuildContextWithSctx(mock.NewContext()) tests := []struct { origin string to string diff --git a/pkg/ddl/executor.go b/pkg/ddl/executor.go index f49df041b0d38..7704cf453932c 100644 --- a/pkg/ddl/executor.go +++ b/pkg/ddl/executor.go @@ -245,16 +245,13 @@ func (e *executor) CreateSchema(ctx sessionctx.Context, stmt *ast.CreateDatabase } } if !explicitCollation && explicitCharset { - coll, err := getDefaultCollationForUTF8MB4(ctx.GetSessionVars(), charsetOpt.Chs) - if err != nil { - return err - } + coll := getDefaultCollationForUTF8MB4(charsetOpt.Chs, ctx.GetSessionVars().DefaultCollationForUTF8MB4) if len(coll) != 0 { charsetOpt.Col = coll } } dbInfo := &model.DBInfo{Name: stmt.Name} - chs, coll, err := ResolveCharsetCollation(ctx.GetSessionVars(), charsetOpt) + chs, coll, err := ResolveCharsetCollation([]ast.CharsetOpt{charsetOpt}, ctx.GetSessionVars().DefaultCollationForUTF8MB4) if err != nil { return errors.Trace(err) } @@ -335,7 +332,7 @@ func (e *executor) CreateSchemaWithInfo( func (e *executor) ModifySchemaCharsetAndCollate(ctx sessionctx.Context, stmt *ast.AlterDatabaseStmt, toCharset, toCollate string) (err error) { if toCollate == "" { - if toCollate, err = GetDefaultCollation(ctx.GetSessionVars(), toCharset); err != nil { + if toCollate, err = GetDefaultCollation(toCharset, ctx.GetSessionVars().DefaultCollationForUTF8MB4); err != nil { return errors.Trace(err) } } @@ -862,28 +859,21 @@ func checkTooLongForeignKey(fk pmodel.CIStr) error { return nil } -func getDefaultCollationForUTF8MB4(sessVars *variable.SessionVars, cs string) (string, error) { - if sessVars == nil || cs != charset.CharsetUTF8MB4 { - return "", nil +func getDefaultCollationForUTF8MB4(cs string, defaultUTF8MB4Coll string) string { + if cs == charset.CharsetUTF8MB4 { + return defaultUTF8MB4Coll } - defaultCollation, err := sessVars.GetSessionOrGlobalSystemVar(context.Background(), variable.DefaultCollationForUTF8MB4) - if err != nil { - return "", err - } - return defaultCollation, nil + return "" } // GetDefaultCollation returns the default collation for charset and handle the default collation for UTF8MB4. -func GetDefaultCollation(sessVars *variable.SessionVars, cs string) (string, error) { - coll, err := getDefaultCollationForUTF8MB4(sessVars, cs) - if err != nil { - return "", errors.Trace(err) - } +func GetDefaultCollation(cs string, defaultUTF8MB4Collation string) (string, error) { + coll := getDefaultCollationForUTF8MB4(cs, defaultUTF8MB4Collation) if coll != "" { return coll, nil } - coll, err = charset.GetDefaultCollation(cs) + coll, err := charset.GetDefaultCollation(cs) if err != nil { return "", errors.Trace(err) } @@ -893,7 +883,7 @@ func GetDefaultCollation(sessVars *variable.SessionVars, cs string) (string, err // ResolveCharsetCollation will resolve the charset and collate by the order of parameters: // * If any given ast.CharsetOpt is not empty, the resolved charset and collate will be returned. // * If all ast.CharsetOpts are empty, the default charset and collate will be returned. -func ResolveCharsetCollation(sessVars *variable.SessionVars, charsetOpts ...ast.CharsetOpt) (chs string, coll string, err error) { +func ResolveCharsetCollation(charsetOpts []ast.CharsetOpt, utf8MB4DefaultColl string) (chs string, coll string, err error) { for _, v := range charsetOpts { if v.Col != "" { collation, err := collate.GetCollationByName(v.Col) @@ -906,7 +896,7 @@ func ResolveCharsetCollation(sessVars *variable.SessionVars, charsetOpts ...ast. return collation.CharsetName, v.Col, nil } if v.Chs != "" { - coll, err := GetDefaultCollation(sessVars, v.Chs) + coll, err := GetDefaultCollation(v.Chs, utf8MB4DefaultColl) if err != nil { return "", "", errors.Trace(err) } @@ -914,10 +904,7 @@ func ResolveCharsetCollation(sessVars *variable.SessionVars, charsetOpts ...ast. } } chs, coll = charset.GetDefaultCharsetAndCollate() - utf8mb4Coll, err := getDefaultCollationForUTF8MB4(sessVars, chs) - if err != nil { - return "", "", errors.Trace(err) - } + utf8mb4Coll := getDefaultCollationForUTF8MB4(chs, utf8MB4DefaultColl) if utf8mb4Coll != "" { return chs, utf8mb4Coll, nil } @@ -958,7 +945,7 @@ func checkInvisibleIndexOnPK(tblInfo *model.TableInfo) error { } // checkGlobalIndex check if the index is allowed to have global index -func checkGlobalIndex(ctx sessionctx.Context, tblInfo *model.TableInfo, indexInfo *model.IndexInfo) error { +func checkGlobalIndex(ec errctx.Context, tblInfo *model.TableInfo, indexInfo *model.IndexInfo) error { pi := tblInfo.GetPartitionInfo() isPartitioned := pi != nil && pi.Type != pmodel.PartitionTypeNone if indexInfo.Global { @@ -980,15 +967,15 @@ func checkGlobalIndex(ctx sessionctx.Context, tblInfo *model.TableInfo, indexInf if inAllPartitionColumns { return dbterror.ErrGeneralUnsupportedDDL.GenWithStackByArgs("Global Index including all columns in the partitioning expression") } - validateGlobalIndexWithGeneratedColumns(ctx.GetSessionVars().StmtCtx.ErrCtx(), tblInfo, indexInfo.Name.O, indexInfo.Columns) + validateGlobalIndexWithGeneratedColumns(ec, tblInfo, indexInfo.Name.O, indexInfo.Columns) } return nil } // checkGlobalIndexes check if global index is supported. -func checkGlobalIndexes(ctx sessionctx.Context, tblInfo *model.TableInfo) error { +func checkGlobalIndexes(ec errctx.Context, tblInfo *model.TableInfo) error { for _, indexInfo := range tblInfo.Indices { - err := checkGlobalIndex(ctx, tblInfo, indexInfo) + err := checkGlobalIndex(ec, tblInfo, indexInfo) if err != nil { return err } @@ -1026,11 +1013,12 @@ func (e *executor) CreateTable(ctx sessionctx.Context, s *ast.CreateTableStmt) ( } // build tableInfo + metaBuildCtx := NewMetaBuildContextWithSctx(ctx) var tbInfo *model.TableInfo if s.ReferTable != nil { - tbInfo, err = BuildTableInfoWithLike(ctx, ident, referTbl.Meta(), s) + tbInfo, err = BuildTableInfoWithLike(ident, referTbl.Meta(), s) } else { - tbInfo, err = BuildTableInfoWithStmt(ctx, s, schema.Charset, schema.Collate, schema.PlacementPolicyRef) + tbInfo, err = BuildTableInfoWithStmt(metaBuildCtx, s, schema.Charset, schema.Collate, schema.PlacementPolicyRef) } if err != nil { return errors.Trace(err) @@ -1040,7 +1028,7 @@ func (e *executor) CreateTable(ctx sessionctx.Context, s *ast.CreateTableStmt) ( rewritePartitionQueryString(ctx, s.Partition, tbInfo) } - if err = checkTableInfoValidWithStmt(ctx, tbInfo, s); err != nil { + if err = checkTableInfoValidWithStmt(metaBuildCtx, tbInfo, s); err != nil { return err } if err = checkTableForeignKeysValid(ctx, is, schema.Name.L, tbInfo); err != nil { @@ -1097,7 +1085,7 @@ func (e *executor) createTableWithInfoJob( } } - if err := checkTableInfoValidExtra(ctx, tbInfo); err != nil { + if err := checkTableInfoValidExtra(ctx.GetSessionVars().StmtCtx.ErrCtx(), tbInfo); err != nil { return nil, err } @@ -1522,7 +1510,7 @@ func (e *executor) CreateView(ctx sessionctx.Context, s *ast.CreateViewStmt) (er tblCollate = v } - tbInfo, err := BuildTableInfo(ctx, s.ViewName.Name, cols, nil, tblCharset, tblCollate) + tbInfo, err := BuildTableInfo(NewMetaBuildContextWithSctx(ctx), s.ViewName.Name, cols, nil, tblCharset, tblCollate) if err != nil { return err } @@ -1585,7 +1573,7 @@ func isIgnorableSpec(tp ast.AlterTableType) bool { // GetCharsetAndCollateInTableOption will iterate the charset and collate in the options, // and returns the last charset and collate in options. If there is no charset in the options, // the returns charset will be "", the same as collate. -func GetCharsetAndCollateInTableOption(sessVars *variable.SessionVars, startIdx int, options []*ast.TableOption) (chs, coll string, err error) { +func GetCharsetAndCollateInTableOption(startIdx int, options []*ast.TableOption, defaultUTF8MB4Coll string) (chs, coll string, err error) { for i := startIdx; i < len(options); i++ { opt := options[i] // we set the charset to the last option. example: alter table t charset latin1 charset utf8 collate utf8_bin; @@ -1602,10 +1590,7 @@ func GetCharsetAndCollateInTableOption(sessVars *variable.SessionVars, startIdx return "", "", dbterror.ErrConflictingDeclarations.GenWithStackByArgs(chs, info.Name) } if len(coll) == 0 { - defaultColl, err := getDefaultCollationForUTF8MB4(sessVars, chs) - if err != nil { - return "", "", errors.Trace(err) - } + defaultColl := getDefaultCollationForUTF8MB4(chs, defaultUTF8MB4Coll) if len(defaultColl) == 0 { coll = info.DefaultCollation } else { @@ -1884,7 +1869,7 @@ func (e *executor) AlterTable(ctx context.Context, sctx sessionctx.Context, stmt continue } var toCharset, toCollate string - toCharset, toCollate, err = GetCharsetAndCollateInTableOption(sctx.GetSessionVars(), i, spec.Options) + toCharset, toCollate, err = GetCharsetAndCollateInTableOption(i, spec.Options, sctx.GetSessionVars().DefaultCollationForUTF8MB4) if err != nil { return err } @@ -2316,7 +2301,7 @@ func (e *executor) AddTablePartitions(ctx sessionctx.Context, ident ast.Ident, s tmp := *partInfo tmp.Definitions = append(pi.Definitions, tmp.Definitions...) clonedMeta.Partition = &tmp - if err := checkPartitionDefinitionConstraints(ctx, clonedMeta); err != nil { + if err := checkPartitionDefinitionConstraints(ctx.GetExprCtx(), clonedMeta); err != nil { if dbterror.ErrSameNamePartition.Equal(err) && spec.IfNotExists { ctx.GetSessionVars().StmtCtx.AppendNote(err) return nil @@ -2469,7 +2454,7 @@ func (e *executor) AlterTablePartitioning(ctx sessionctx.Context, ident ast.Iden } newMeta := meta.Clone() - err = buildTablePartitionInfo(ctx, spec.Partition, newMeta) + err = buildTablePartitionInfo(NewMetaBuildContextWithSctx(ctx), spec.Partition, newMeta) if err != nil { return err } @@ -2652,7 +2637,7 @@ func checkReorgPartitionDefs(ctx sessionctx.Context, action model.ActionType, tb default: return dbterror.ErrGeneralUnsupportedDDL.GenWithStackByArgs("partition type") } - if err := checkPartitionDefinitionConstraints(ctx, clonedMeta); err != nil { + if err := checkPartitionDefinitionConstraints(ctx.GetExprCtx(), clonedMeta); err != nil { return errors.Trace(err) } if action == model.ActionReorganizePartition { @@ -2666,14 +2651,14 @@ func checkReorgPartitionDefs(ctx sessionctx.Context, action model.ActionType, tb lastAddingPartition := partInfo.Definitions[len(partInfo.Definitions)-1] lastOldPartition := pi.Definitions[lastPartIdx] if len(pi.Columns) > 0 { - newGtOld, err := checkTwoRangeColumns(ctx, &lastAddingPartition, &lastOldPartition, pi, tblInfo) + newGtOld, err := checkTwoRangeColumns(ctx.GetExprCtx(), &lastAddingPartition, &lastOldPartition, pi, tblInfo) if err != nil { return errors.Trace(err) } if newGtOld { return errors.Trace(dbterror.ErrRangeNotIncreasing) } - oldGtNew, err := checkTwoRangeColumns(ctx, &lastOldPartition, &lastAddingPartition, pi, tblInfo) + oldGtNew, err := checkTwoRangeColumns(ctx.GetExprCtx(), &lastOldPartition, &lastAddingPartition, pi, tblInfo) if err != nil { return errors.Trace(err) } @@ -3492,7 +3477,7 @@ func (e *executor) AlterColumn(ctx sessionctx.Context, ident ast.Ident, spec *as if IsAutoRandomColumnID(t.Meta(), col.ID) { return dbterror.ErrInvalidAutoRandom.GenWithStackByArgs(autoid.AutoRandomIncompatibleWithDefaultValueErrMsg) } - hasDefaultValue, err := SetDefaultValue(ctx, col, specNewColumn.Options[0]) + hasDefaultValue, err := SetDefaultValue(ctx.GetExprCtx(), col, specNewColumn.Options[0]) if err != nil { return errors.Trace(err) } @@ -3606,7 +3591,7 @@ func (e *executor) AlterTableCharsetAndCollate(ctx sessionctx.Context, ident ast if toCollate == "" { // Get the default collation of the charset. - toCollate, err = GetDefaultCollation(ctx.GetSessionVars(), toCharset) + toCollate, err = GetDefaultCollation(toCharset, ctx.GetSessionVars().DefaultCollationForUTF8MB4) if err != nil { return errors.Trace(err) } @@ -3738,7 +3723,7 @@ func (e *executor) AlterTableTTLInfoOrEnable(ctx sessionctx.Context, ident ast.I var job *model.Job if ttlInfo != nil { tblInfo.TTLInfo = ttlInfo - err = checkTTLInfoValid(ctx, ident.Schema, tblInfo) + err = checkTTLInfoValid(ident.Schema, tblInfo, is) if err != nil { return err } @@ -3968,10 +3953,10 @@ func checkAlterTableCharset(tblInfo *model.TableInfo, dbInfo *model.DBInfo, toCh } // This DDL will update the table charset to default charset. - origCharset, origCollate, err = ResolveCharsetCollation(nil, - ast.CharsetOpt{Chs: origCharset, Col: origCollate}, - ast.CharsetOpt{Chs: dbInfo.Charset, Col: dbInfo.Collate}, - ) + origCharset, origCollate, err = ResolveCharsetCollation([]ast.CharsetOpt{ + {Chs: origCharset, Col: origCollate}, + {Chs: dbInfo.Charset, Col: dbInfo.Collate}, + }, "") if err != nil { return doNothing, err } @@ -4551,7 +4536,7 @@ func (e *executor) CreatePrimaryKey(ctx sessionctx.Context, ti ast.Ident, indexN // After DDL job is put to the queue, and if the check fail, TiDB will run the DDL cancel logic. // The recover step causes DDL wait a few seconds, makes the unit test painfully slow. // For same reason, decide whether index is global here. - indexColumns, _, err := buildIndexColumns(ctx, tblInfo.Columns, indexPartSpecifications) + indexColumns, _, err := buildIndexColumns(NewMetaBuildContextWithSctx(ctx), tblInfo.Columns, indexPartSpecifications) if err != nil { return errors.Trace(err) } @@ -4683,8 +4668,9 @@ func (e *executor) createIndex(ctx sessionctx.Context, ti ast.Ident, keyType ast tblInfo := t.Meta() + metaBuildCtx := NewMetaBuildContextWithSctx(ctx) // Build hidden columns if necessary. - hiddenCols, err := buildHiddenColumnInfoWithCheck(ctx, indexPartSpecifications, indexName, t.Meta(), t.Cols()) + hiddenCols, err := buildHiddenColumnInfoWithCheck(metaBuildCtx, indexPartSpecifications, indexName, t.Meta(), t.Cols()) if err != nil { return err } @@ -4701,7 +4687,7 @@ func (e *executor) createIndex(ctx sessionctx.Context, ti ast.Ident, keyType ast // After DDL job is put to the queue, and if the check fail, TiDB will run the DDL cancel logic. // The recover step causes DDL wait a few seconds, makes the unit test painfully slow. // For same reason, decide whether index is global here. - indexColumns, _, err := buildIndexColumns(ctx, finalColumns, indexPartSpecifications) + indexColumns, _, err := buildIndexColumns(metaBuildCtx, finalColumns, indexPartSpecifications) if err != nil { return errors.Trace(err) } @@ -4744,7 +4730,7 @@ func (e *executor) createIndex(ctx sessionctx.Context, ti ast.Ident, keyType ast } if indexOption != nil && indexOption.Tp == pmodel.IndexTypeHypo { // for hypo-index - indexInfo, err := BuildIndexInfo(ctx, tblInfo.Columns, indexName, false, unique, + indexInfo, err := BuildIndexInfo(metaBuildCtx, tblInfo.Columns, indexName, false, unique, indexPartSpecifications, indexOption, model.StatePublic) if err != nil { return err @@ -5475,7 +5461,7 @@ func (e *executor) RepairTable(ctx sessionctx.Context, createStmt *ast.CreateTab } // It is necessary to specify the table.ID and partition.ID manually. - newTableInfo, err := buildTableInfoWithCheck(ctx, createStmt, oldTableInfo.Charset, oldTableInfo.Collate, oldTableInfo.PlacementPolicyRef) + newTableInfo, err := buildTableInfoWithCheck(NewMetaBuildContextWithSctx(ctx), createStmt, oldTableInfo.Charset, oldTableInfo.Collate, oldTableInfo.PlacementPolicyRef) if err != nil { return errors.Trace(err) } @@ -5561,7 +5547,7 @@ func (e *executor) CreateSequence(ctx sessionctx.Context, stmt *ast.CreateSequen return err } // TiDB describe the sequence within a tableInfo, as a same-level object of a table and view. - tbInfo, err := BuildTableInfo(ctx, ident.Name, nil, nil, "", "") + tbInfo, err := BuildTableInfo(NewMetaBuildContextWithSctx(ctx), ident.Name, nil, nil, "", "") if err != nil { return err } diff --git a/pkg/ddl/foreign_key.go b/pkg/ddl/foreign_key.go index 9b2f3f47aeb55..3721e28d679da 100644 --- a/pkg/ddl/foreign_key.go +++ b/pkg/ddl/foreign_key.go @@ -21,6 +21,7 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/tidb/pkg/infoschema" + infoschemactx "github.com/pingcap/tidb/pkg/infoschema/context" "github.com/pingcap/tidb/pkg/meta" "github.com/pingcap/tidb/pkg/meta/model" "github.com/pingcap/tidb/pkg/parser/ast" @@ -377,7 +378,7 @@ func isAcceptableForeignKeyColumnChange(newCol, originalCol, relatedCol *model.C return true } -func checkTableHasForeignKeyReferred(is infoschema.InfoSchema, schema, tbl string, ignoreTables []ast.Ident, fkCheck bool) *model.ReferredFKInfo { +func checkTableHasForeignKeyReferred(is infoschemactx.MetaOnlyInfoSchema, schema, tbl string, ignoreTables []ast.Ident, fkCheck bool) *model.ReferredFKInfo { if !fkCheck { return nil } diff --git a/pkg/ddl/index.go b/pkg/ddl/index.go index 44d13b4254305..e262bd507fd2d 100644 --- a/pkg/ddl/index.go +++ b/pkg/ddl/index.go @@ -46,6 +46,7 @@ import ( "github.com/pingcap/tidb/pkg/lightning/backend" litconfig "github.com/pingcap/tidb/pkg/lightning/config" "github.com/pingcap/tidb/pkg/meta" + "github.com/pingcap/tidb/pkg/meta/metabuild" "github.com/pingcap/tidb/pkg/meta/model" "github.com/pingcap/tidb/pkg/metrics" "github.com/pingcap/tidb/pkg/parser/ast" @@ -69,7 +70,6 @@ import ( decoder "github.com/pingcap/tidb/pkg/util/rowDecoder" "github.com/pingcap/tidb/pkg/util/size" "github.com/pingcap/tidb/pkg/util/sqlexec" - "github.com/pingcap/tidb/pkg/util/stringutil" "github.com/tikv/client-go/v2/oracle" "github.com/tikv/client-go/v2/tikv" kvutil "github.com/tikv/client-go/v2/util" @@ -84,22 +84,7 @@ const ( MaxCommentLength = 1024 ) -var ( - // SuppressErrorTooLongKeyKey is used by SchemaTracker to suppress err too long key error - SuppressErrorTooLongKeyKey stringutil.StringerStr = "suppressErrorTooLongKeyKey" -) - -func suppressErrorTooLongKeyForSchemaTracker(sctx sessionctx.Context) bool { - if sctx == nil { - return false - } - if suppress, ok := sctx.Value(SuppressErrorTooLongKeyKey).(bool); ok && suppress { - return true - } - return false -} - -func buildIndexColumns(ctx sessionctx.Context, columns []*model.ColumnInfo, indexPartSpecifications []*ast.IndexPartSpecification) ([]*model.IndexColumn, bool, error) { +func buildIndexColumns(ctx *metabuild.Context, columns []*model.ColumnInfo, indexPartSpecifications []*ast.IndexPartSpecification) ([]*model.IndexColumn, bool, error) { // Build offsets. idxParts := make([]*model.IndexColumn, 0, len(indexPartSpecifications)) var col *model.ColumnInfo @@ -112,8 +97,8 @@ func buildIndexColumns(ctx sessionctx.Context, columns []*model.ColumnInfo, inde if col == nil { return nil, false, dbterror.ErrKeyColumnDoesNotExits.GenWithStack("column does not exist: %s", ip.Column.Name) } - - if err := checkIndexColumn(ctx, col, ip.Length); err != nil { + // return error in strict sql mode + if err := checkIndexColumn(col, ip.Length, ctx != nil && (!ctx.GetSQLMode().HasStrictMode() || ctx.SuppressTooLongIndexErr())); err != nil { return nil, false, err } if col.FieldType.IsArray() { @@ -134,12 +119,12 @@ func buildIndexColumns(ctx sessionctx.Context, columns []*model.ColumnInfo, inde } sumLength += indexColumnLength - if !suppressErrorTooLongKeyForSchemaTracker(ctx) && sumLength > maxIndexLength { + if (ctx == nil || !ctx.SuppressTooLongIndexErr()) && sumLength > maxIndexLength { // The sum of all lengths must be shorter than the max length for prefix. // The multiple column index and the unique index in which the length sum exceeds the maximum size // will return an error instead produce a warning. - if ctx == nil || ctx.GetSessionVars().SQLMode.HasStrictMode() || mysql.HasUniKeyFlag(col.GetFlag()) || len(indexPartSpecifications) > 1 { + if ctx == nil || ctx.GetSQLMode().HasStrictMode() || mysql.HasUniKeyFlag(col.GetFlag()) || len(indexPartSpecifications) > 1 { return nil, false, dbterror.ErrTooLongKey.GenWithStackByArgs(sumLength, maxIndexLength) } // truncate index length and produce warning message in non-restrict sql mode. @@ -149,7 +134,7 @@ func buildIndexColumns(ctx sessionctx.Context, columns []*model.ColumnInfo, inde } indexColLen = maxIndexLength / colLenPerUint // produce warning message - ctx.GetSessionVars().StmtCtx.AppendWarning(dbterror.ErrTooLongKey.FastGenByArgs(sumLength, maxIndexLength)) + ctx.AppendWarning(dbterror.ErrTooLongKey.FastGenByArgs(sumLength, maxIndexLength)) } idxParts = append(idxParts, &model.IndexColumn{ @@ -210,7 +195,7 @@ func indexColumnsLen(cols []*model.ColumnInfo, idxCols []*model.IndexColumn) (co return } -func checkIndexColumn(ctx sessionctx.Context, col *model.ColumnInfo, indexColumnLen int) error { +func checkIndexColumn(col *model.ColumnInfo, indexColumnLen int, suppressTooLongKeyErr bool) error { if col.GetFlen() == 0 && (types.IsTypeChar(col.FieldType.GetType()) || types.IsTypeVarchar(col.FieldType.GetType())) { if col.Hidden { return errors.Trace(dbterror.ErrWrongKeyColumnFunctionalIndex.GenWithStackByArgs(col.GeneratedExprString)) @@ -274,8 +259,7 @@ func checkIndexColumn(ctx sessionctx.Context, col *model.ColumnInfo, indexColumn // Specified length must be shorter than the max length for prefix. maxIndexLength := config.GetGlobalConfig().MaxIndexLength if indexColumnLen > maxIndexLength { - if ctx == nil || (ctx.GetSessionVars().SQLMode.HasStrictMode() && !suppressErrorTooLongKeyForSchemaTracker(ctx)) { - // return error in strict sql mode + if !suppressTooLongKeyErr { return dbterror.ErrTooLongKey.GenWithStackByArgs(indexColumnLen, maxIndexLength) } } @@ -324,7 +308,7 @@ func calcBytesLengthForDecimal(m int) int { // BuildIndexInfo builds a new IndexInfo according to the index information. func BuildIndexInfo( - ctx sessionctx.Context, + ctx *metabuild.Context, allTableColumns []*model.ColumnInfo, indexName pmodel.CIStr, isPrimary bool, diff --git a/pkg/ddl/metabuild.go b/pkg/ddl/metabuild.go new file mode 100644 index 0000000000000..92d2cedb38b84 --- /dev/null +++ b/pkg/ddl/metabuild.go @@ -0,0 +1,43 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ddl + +import ( + "github.com/pingcap/tidb/pkg/meta/metabuild" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/util/intest" +) + +// NewMetaBuildContextWithSctx creates a new MetaBuildContext with the given session context. +func NewMetaBuildContextWithSctx(sctx sessionctx.Context, otherOpts ...metabuild.Option) *metabuild.Context { + intest.AssertNotNil(sctx) + sessVars := sctx.GetSessionVars() + intest.AssertNotNil(sessVars) + opts := []metabuild.Option{ + metabuild.WithExprCtx(sctx.GetExprCtx()), + metabuild.WithEnableAutoIncrementInGenerated(sessVars.EnableAutoIncrementInGenerated), + metabuild.WithPrimaryKeyRequired(!sessVars.InRestrictedSQL && sessVars.PrimaryKeyRequired), + metabuild.WithClusteredIndexDefMode(sessVars.EnableClusteredIndex), + metabuild.WithShardRowIDBits(sessVars.ShardRowIDBits), + metabuild.WithPreSplitRegions(sessVars.PreSplitRegions), + metabuild.WithInfoSchema(sctx.GetDomainInfoSchema()), + } + + if len(otherOpts) > 0 { + opts = append(opts, otherOpts...) + } + + return metabuild.NewContext(opts...) +} diff --git a/pkg/ddl/metabuild_test.go b/pkg/ddl/metabuild_test.go new file mode 100644 index 0000000000000..2f43896176360 --- /dev/null +++ b/pkg/ddl/metabuild_test.go @@ -0,0 +1,170 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ddl + +import ( + "testing" + + "github.com/pingcap/errors" + "github.com/pingcap/tidb/pkg/meta/metabuild" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + contextutil "github.com/pingcap/tidb/pkg/util/context" + "github.com/pingcap/tidb/pkg/util/deeptest" + "github.com/pingcap/tidb/pkg/util/mock" + "github.com/stretchr/testify/require" +) + +func TestNewMetaBuildContextWithSctx(t *testing.T) { + sqlMode := mysql.ModeStrictAllTables | mysql.ModeNoZeroDate + sctx := mock.NewContext() + sctx.GetSessionVars().SQLMode = sqlMode + sessVars := sctx.GetSessionVars() + cases := []struct { + field string + setSctx func(val any) + testVals []any + getter func(*metabuild.Context) any + check func(*metabuild.Context) + extra func() + }{ + { + field: "exprCtx", + check: func(ctx *metabuild.Context) { + require.Same(t, sctx.GetExprCtx(), ctx.GetExprCtx()) + require.Equal(t, sqlMode, ctx.GetSQLMode()) + require.Equal(t, sctx.GetSessionVars().DefaultCollationForUTF8MB4, ctx.GetDefaultCollationForUTF8MB4()) + require.Equal(t, "utf8mb4_bin", ctx.GetDefaultCollationForUTF8MB4()) + warn := errors.New("warn1") + note := errors.New("note1") + ctx.AppendWarning(warn) + ctx.AppendNote(note) + require.Equal(t, []contextutil.SQLWarn{ + {Level: contextutil.WarnLevelWarning, Err: warn}, + {Level: contextutil.WarnLevelNote, Err: note}, + }, ctx.GetExprCtx().GetEvalCtx().CopyWarnings(nil)) + }, + }, + { + field: "enableAutoIncrementInGenerated", + setSctx: func(val any) { + sessVars.EnableAutoIncrementInGenerated = val.(bool) + }, + testVals: []any{true, false}, + getter: func(ctx *metabuild.Context) any { + return ctx.EnableAutoIncrementInGenerated() + }, + }, + { + field: "primaryKeyRequired", + setSctx: func(val any) { + sessVars.PrimaryKeyRequired = val.(bool) + }, + testVals: []any{true, false}, + getter: func(ctx *metabuild.Context) any { + return ctx.PrimaryKeyRequired() + }, + extra: func() { + // `PrimaryKeyRequired` should always return false if `InRestrictedSQL` is true. + sessVars.PrimaryKeyRequired = true + sessVars.InRestrictedSQL = true + require.False(t, NewMetaBuildContextWithSctx(sctx).PrimaryKeyRequired()) + }, + }, + { + field: "clusteredIndexDefMode", + setSctx: func(val any) { + sessVars.EnableClusteredIndex = val.(variable.ClusteredIndexDefMode) + }, + testVals: []any{ + variable.ClusteredIndexDefModeIntOnly, + variable.ClusteredIndexDefModeOff, + variable.ClusteredIndexDefModeOn, + }, + getter: func(ctx *metabuild.Context) any { + return ctx.GetClusteredIndexDefMode() + }, + }, + { + field: "shardRowIDBits", + setSctx: func(val any) { + sessVars.ShardRowIDBits = val.(uint64) + }, + testVals: []any{uint64(variable.DefShardRowIDBits), uint64(6)}, + getter: func(ctx *metabuild.Context) any { + return ctx.GetShardRowIDBits() + }, + }, + { + field: "preSplitRegions", + setSctx: func(val any) { + sessVars.PreSplitRegions = val.(uint64) + }, + testVals: []any{uint64(variable.DefPreSplitRegions), uint64(123)}, + getter: func(ctx *metabuild.Context) any { + return ctx.GetPreSplitRegions() + }, + }, + { + field: "suppressTooLongIndexErr", + extra: func() { + require.True(t, + NewMetaBuildContextWithSctx(sctx, metabuild.WithSuppressTooLongIndexErr(true)). + SuppressTooLongIndexErr(), + ) + require.False(t, + NewMetaBuildContextWithSctx(sctx, metabuild.WithSuppressTooLongIndexErr(false)). + SuppressTooLongIndexErr(), + ) + }, + }, + { + field: "is", + check: func(ctx *metabuild.Context) { + sctxInfoSchema := sctx.GetDomainInfoSchema() + require.NotNil(t, sctxInfoSchema) + is, ok := ctx.GetInfoSchema() + require.True(t, ok) + require.Same(t, sctxInfoSchema, is) + }, + }, + } + + allFields := make([]string, 0, len(cases)) + for _, f := range cases { + t.Run(f.field, func(t *testing.T) { + require.NotEmpty(t, f.field) + allFields = append(allFields, "$."+f.field) + if f.check != nil { + ctx := NewMetaBuildContextWithSctx(sctx) + f.check(ctx) + } + for _, testVal := range f.testVals { + f.setSctx(testVal) + ctx := NewMetaBuildContextWithSctx(sctx) + require.Equal(t, testVal, f.getter(ctx), "field: %s, v: %v", f.field, testVal) + if f.check != nil { + f.check(ctx) + } + } + if f.extra != nil { + f.extra() + } + }) + } + + // make sure all fields are tested (WithIgnorePath contains all fields that the below asserting will pass). + deeptest.AssertRecursivelyNotEqual(t, &metabuild.Context{}, &metabuild.Context{}, deeptest.WithIgnorePath(allFields)) +} diff --git a/pkg/ddl/mock.go b/pkg/ddl/mock.go index d668618ce6df0..a2136d3789508 100644 --- a/pkg/ddl/mock.go +++ b/pkg/ddl/mock.go @@ -54,7 +54,8 @@ func (*mockDelRange) start() {} func (*mockDelRange) clear() {} // MockTableInfo mocks a table info by create table stmt ast and a specified table id. -func MockTableInfo(ctx sessionctx.Context, stmt *ast.CreateTableStmt, tableID int64) (*model.TableInfo, error) { +func MockTableInfo(sctx sessionctx.Context, stmt *ast.CreateTableStmt, tableID int64) (*model.TableInfo, error) { + ctx := NewMetaBuildContextWithSctx(sctx) chs, coll := charset.GetDefaultCharsetAndCollate() cols, newConstraints, err := buildColumnsAndConstraints(ctx, stmt.Cols, stmt.Constraints, chs, coll) if err != nil { diff --git a/pkg/ddl/modify_column.go b/pkg/ddl/modify_column.go index 36cbd022a7192..498c7d234e04c 100644 --- a/pkg/ddl/modify_column.go +++ b/pkg/ddl/modify_column.go @@ -32,6 +32,7 @@ import ( "github.com/pingcap/tidb/pkg/kv" "github.com/pingcap/tidb/pkg/meta" "github.com/pingcap/tidb/pkg/meta/autoid" + "github.com/pingcap/tidb/pkg/meta/metabuild" "github.com/pingcap/tidb/pkg/meta/model" "github.com/pingcap/tidb/pkg/metrics" "github.com/pingcap/tidb/pkg/parser" @@ -724,7 +725,7 @@ func GetModifiableColumnJob( Version: col.Version, }) - if err = ProcessColumnCharsetAndCollation(sctx, col, newCol, t.Meta(), specNewColumn, schema); err != nil { + if err = ProcessColumnCharsetAndCollation(NewMetaBuildContextWithSctx(sctx), col, newCol, t.Meta(), specNewColumn, schema); err != nil { return nil, err } @@ -1006,7 +1007,7 @@ func IsElemsChangedToModifyColumn(oldElems, newElems []string) bool { } // ProcessColumnCharsetAndCollation process column charset and collation -func ProcessColumnCharsetAndCollation(sctx sessionctx.Context, col *table.Column, newCol *table.Column, meta *model.TableInfo, specNewColumn *ast.ColumnDef, schema *model.DBInfo) error { +func ProcessColumnCharsetAndCollation(ctx *metabuild.Context, col *table.Column, newCol *table.Column, meta *model.TableInfo, specNewColumn *ast.ColumnDef, schema *model.DBInfo) error { var chs, coll string var err error // TODO: Remove it when all table versions are greater than or equal to TableInfoVersion1. @@ -1016,22 +1017,22 @@ func ProcessColumnCharsetAndCollation(sctx sessionctx.Context, col *table.Column chs = col.FieldType.GetCharset() coll = col.FieldType.GetCollate() } else { - chs, coll, err = getCharsetAndCollateInColumnDef(sctx.GetSessionVars(), specNewColumn) + chs, coll, err = getCharsetAndCollateInColumnDef(specNewColumn, ctx.GetDefaultCollationForUTF8MB4()) if err != nil { return errors.Trace(err) } - chs, coll, err = ResolveCharsetCollation(sctx.GetSessionVars(), - ast.CharsetOpt{Chs: chs, Col: coll}, - ast.CharsetOpt{Chs: meta.Charset, Col: meta.Collate}, - ast.CharsetOpt{Chs: schema.Charset, Col: schema.Collate}, - ) - chs, coll = OverwriteCollationWithBinaryFlag(sctx.GetSessionVars(), specNewColumn, chs, coll) + chs, coll, err = ResolveCharsetCollation([]ast.CharsetOpt{ + {Chs: chs, Col: coll}, + {Chs: meta.Charset, Col: meta.Collate}, + {Chs: schema.Charset, Col: schema.Collate}, + }, ctx.GetDefaultCollationForUTF8MB4()) + chs, coll = OverwriteCollationWithBinaryFlag(specNewColumn, chs, coll, ctx.GetDefaultCollationForUTF8MB4()) if err != nil { return errors.Trace(err) } } - if err = setCharsetCollationFlenDecimal(&newCol.FieldType, newCol.Name.O, chs, coll, sctx.GetSessionVars()); err != nil { + if err = setCharsetCollationFlenDecimal(ctx, &newCol.FieldType, newCol.Name.O, chs, coll); err != nil { return errors.Trace(err) } decodeEnumSetBinaryLiteralToUTF8(&newCol.FieldType, chs) @@ -1113,7 +1114,7 @@ func checkIndexInModifiableColumns(columns []*model.ColumnInfo, idxColumns []*mo // if the type is still prefixable and larger than old prefix length. prefixLength = ic.Length } - if err := checkIndexColumn(nil, col, prefixLength); err != nil { + if err := checkIndexColumn(col, prefixLength, false); err != nil { return err } } @@ -1162,12 +1163,12 @@ func ProcessModifyColumnOptions(ctx sessionctx.Context, col *table.Column, optio for _, opt := range options { switch opt.Tp { case ast.ColumnOptionDefaultValue: - hasDefaultValue, err = SetDefaultValue(ctx, col, opt) + hasDefaultValue, err = SetDefaultValue(ctx.GetExprCtx(), col, opt) if err != nil { return errors.Trace(err) } case ast.ColumnOptionComment: - err := setColumnComment(ctx, col, opt) + err := setColumnComment(ctx.GetExprCtx(), col, opt) if err != nil { return errors.Trace(err) } @@ -1221,7 +1222,7 @@ func ProcessModifyColumnOptions(ctx sessionctx.Context, col *table.Column, optio } } - if err = processAndCheckDefaultValueAndColumn(ctx, col, nil, hasDefaultValue, setOnUpdateNow, hasNullFlag); err != nil { + if err = processAndCheckDefaultValueAndColumn(ctx.GetExprCtx(), col, nil, hasDefaultValue, setOnUpdateNow, hasNullFlag); err != nil { return errors.Trace(err) } diff --git a/pkg/ddl/partition.go b/pkg/ddl/partition.go index 15d4a6cf9233c..badee48f99c9d 100644 --- a/pkg/ddl/partition.go +++ b/pkg/ddl/partition.go @@ -37,6 +37,7 @@ import ( "github.com/pingcap/tidb/pkg/infoschema" "github.com/pingcap/tidb/pkg/kv" "github.com/pingcap/tidb/pkg/meta" + "github.com/pingcap/tidb/pkg/meta/metabuild" "github.com/pingcap/tidb/pkg/meta/model" "github.com/pingcap/tidb/pkg/metrics" "github.com/pingcap/tidb/pkg/parser" @@ -63,7 +64,6 @@ import ( "github.com/pingcap/tidb/pkg/util/mathutil" decoder "github.com/pingcap/tidb/pkg/util/rowDecoder" "github.com/pingcap/tidb/pkg/util/slice" - "github.com/pingcap/tidb/pkg/util/sqlkiller" "github.com/pingcap/tidb/pkg/util/stringutil" "github.com/tikv/client-go/v2/tikv" kvutil "github.com/tikv/client-go/v2/util" @@ -503,7 +503,7 @@ func checkListPartitions(defs []*ast.PartitionDefinition) error { } // buildTablePartitionInfo builds partition info and checks for some errors. -func buildTablePartitionInfo(ctx sessionctx.Context, s *ast.PartitionOptions, tbInfo *model.TableInfo) error { +func buildTablePartitionInfo(ctx *metabuild.Context, s *ast.PartitionOptions, tbInfo *model.TableInfo) error { if s == nil { return nil } @@ -526,7 +526,7 @@ func buildTablePartitionInfo(ctx sessionctx.Context, s *ast.PartitionOptions, tb } // Note that linear hash is simply ignored, and creates non-linear hash/key. if s.Linear { - ctx.GetSessionVars().StmtCtx.AppendWarning(dbterror.ErrUnsupportedCreatePartition.FastGen(fmt.Sprintf("LINEAR %s is not supported, using non-linear %s instead", s.Tp.String(), s.Tp.String()))) + ctx.AppendWarning(dbterror.ErrUnsupportedCreatePartition.FastGen(fmt.Sprintf("LINEAR %s is not supported, using non-linear %s instead", s.Tp.String(), s.Tp.String()))) } if s.Tp == pmodel.PartitionTypeHash || len(s.ColumnNames) != 0 { enable = true @@ -537,11 +537,11 @@ func buildTablePartitionInfo(ctx sessionctx.Context, s *ast.PartitionOptions, tb } if !enable { - ctx.GetSessionVars().StmtCtx.AppendWarning(dbterror.ErrUnsupportedCreatePartition.FastGen(fmt.Sprintf("Unsupported partition type %v, treat as normal table", s.Tp))) + ctx.AppendWarning(dbterror.ErrUnsupportedCreatePartition.FastGen(fmt.Sprintf("Unsupported partition type %v, treat as normal table", s.Tp))) return nil } if s.Sub != nil { - ctx.GetSessionVars().StmtCtx.AppendWarning(dbterror.ErrUnsupportedCreatePartition.FastGen(fmt.Sprintf("Unsupported subpartitioning, only using %v partitioning", s.Tp))) + ctx.AppendWarning(dbterror.ErrUnsupportedCreatePartition.FastGen(fmt.Sprintf("Unsupported subpartitioning, only using %v partitioning", s.Tp))) } pi := &model.PartitionInfo{ @@ -1788,12 +1788,12 @@ func checkResultOK(ok bool) error { } // checkPartitionFuncType checks partition function return type. -func checkPartitionFuncType(ctx sessionctx.Context, anyExpr any, schema string, tblInfo *model.TableInfo) error { +func checkPartitionFuncType(ctx expression.BuildContext, anyExpr any, schema string, tblInfo *model.TableInfo) error { if anyExpr == nil { return nil } if schema == "" { - schema = ctx.GetSessionVars().CurrentDB + schema = ctx.GetEvalCtx().CurrentDB() } var e expression.Expression var err error @@ -1802,16 +1802,16 @@ func checkPartitionFuncType(ctx sessionctx.Context, anyExpr any, schema string, if expr == "" { return nil } - e, err = expression.ParseSimpleExpr(ctx.GetExprCtx(), expr, expression.WithTableInfo(schema, tblInfo)) + e, err = expression.ParseSimpleExpr(ctx, expr, expression.WithTableInfo(schema, tblInfo)) case ast.ExprNode: - e, err = expression.BuildSimpleExpr(ctx.GetExprCtx(), expr, expression.WithTableInfo(schema, tblInfo)) + e, err = expression.BuildSimpleExpr(ctx, expr, expression.WithTableInfo(schema, tblInfo)) default: return errors.Trace(dbterror.ErrPartitionFuncNotAllowed.GenWithStackByArgs("PARTITION")) } if err != nil { return errors.Trace(err) } - if e.GetType(ctx.GetExprCtx().GetEvalCtx()).EvalType() == types.ETInt { + if e.GetType(ctx.GetEvalCtx()).EvalType() == types.ETInt { return nil } if col, ok := e.(*expression.Column); ok { @@ -1825,7 +1825,7 @@ func checkPartitionFuncType(ctx sessionctx.Context, anyExpr any, schema string, // checkRangePartitionValue checks whether `less than value` is strictly increasing for each partition. // Side effect: it may simplify the partition range definition from a constant expression to an integer. -func checkRangePartitionValue(ctx sessionctx.Context, tblInfo *model.TableInfo) error { +func checkRangePartitionValue(ctx expression.BuildContext, tblInfo *model.TableInfo) error { pi := tblInfo.Partition defs := pi.Definitions if len(defs) == 0 { @@ -1835,14 +1835,14 @@ func checkRangePartitionValue(ctx sessionctx.Context, tblInfo *model.TableInfo) if strings.EqualFold(defs[len(defs)-1].LessThan[0], partitionMaxValue) { defs = defs[:len(defs)-1] } - isUnsigned := isPartExprUnsigned(ctx.GetExprCtx().GetEvalCtx(), tblInfo) + isUnsigned := isPartExprUnsigned(ctx.GetEvalCtx(), tblInfo) var prevRangeValue any for i := 0; i < len(defs); i++ { if strings.EqualFold(defs[i].LessThan[0], partitionMaxValue) { return errors.Trace(dbterror.ErrPartitionMaxvalue) } - currentRangeValue, fromExpr, err := getRangeValue(ctx.GetExprCtx(), defs[i].LessThan[0], isUnsigned) + currentRangeValue, fromExpr, err := getRangeValue(ctx, defs[i].LessThan[0], isUnsigned) if err != nil { return errors.Trace(err) } @@ -3115,7 +3115,7 @@ func (w *worker) onReorganizePartition(jobCtx *jobContext, t *meta.Meta, job *mo return ver, err } } else { - if err = checkPartitionFuncType(sctx, partInfo.Expr, job.SchemaName, tblInfo); err != nil { + if err = checkPartitionFuncType(sctx.GetExprCtx(), partInfo.Expr, job.SchemaName, tblInfo); err != nil { job.State = model.JobStateCancelled return ver, err } @@ -4278,7 +4278,7 @@ func checkPartitionColumnsUnique(tbInfo *model.TableInfo) error { return nil } -func checkNoHashPartitions(_ sessionctx.Context, partitionNum uint64) error { +func checkNoHashPartitions(partitionNum uint64) error { if partitionNum == 0 { return ast.ErrNoParts.GenWithStackByArgs("partitions") } @@ -4308,13 +4308,13 @@ func getPartitionRuleIDs(dbName string, table *model.TableInfo) []string { } // checkPartitioningKeysConstraints checks that the range partitioning key is included in the table constraint. -func checkPartitioningKeysConstraints(sctx sessionctx.Context, s *ast.CreateTableStmt, tblInfo *model.TableInfo) error { +func checkPartitioningKeysConstraints(ctx *metabuild.Context, s *ast.CreateTableStmt, tblInfo *model.TableInfo) error { // Returns directly if there are no unique keys in the table. if len(tblInfo.Indices) == 0 && !tblInfo.PKIsHandle { return nil } - partCols, err := getPartitionColSlices(sctx.GetExprCtx(), tblInfo, s.Partition) + partCols, err := getPartitionColSlices(ctx.GetExprCtx(), tblInfo, s.Partition) if err != nil { return errors.Trace(err) } @@ -4876,7 +4876,7 @@ func generatePartValuesWithTp(partVal types.Datum, tp types.FieldType) (string, return "", dbterror.ErrWrongTypeColumnValue.GenWithStackByArgs() } -func checkPartitionDefinitionConstraints(ctx sessionctx.Context, tbInfo *model.TableInfo) error { +func checkPartitionDefinitionConstraints(ctx expression.BuildContext, tbInfo *model.TableInfo) error { var err error if err = checkPartitionNameUnique(tbInfo.Partition); err != nil { return errors.Trace(err) @@ -4893,25 +4893,24 @@ func checkPartitionDefinitionConstraints(ctx sessionctx.Context, tbInfo *model.T switch tbInfo.Partition.Type { case pmodel.PartitionTypeRange: + failpoint.Inject("CheckPartitionByRangeErr", func() { + panic("mockCheckPartitionByRangeErr") + }) err = checkPartitionByRange(ctx, tbInfo) case pmodel.PartitionTypeHash, pmodel.PartitionTypeKey: - err = checkPartitionByHash(ctx, tbInfo) + err = checkPartitionByHash(tbInfo) case pmodel.PartitionTypeList: err = checkPartitionByList(ctx, tbInfo) } return errors.Trace(err) } -func checkPartitionByHash(ctx sessionctx.Context, tbInfo *model.TableInfo) error { - return checkNoHashPartitions(ctx, tbInfo.Partition.Num) +func checkPartitionByHash(tbInfo *model.TableInfo) error { + return checkNoHashPartitions(tbInfo.Partition.Num) } // checkPartitionByRange checks validity of a "BY RANGE" partition. -func checkPartitionByRange(ctx sessionctx.Context, tbInfo *model.TableInfo) error { - failpoint.Inject("CheckPartitionByRangeErr", func() { - ctx.GetSessionVars().SQLKiller.SendKillSignal(sqlkiller.QueryMemoryExceeded) - panic(ctx.GetSessionVars().SQLKiller.HandleSignal()) - }) +func checkPartitionByRange(ctx expression.BuildContext, tbInfo *model.TableInfo) error { pi := tbInfo.Partition if len(pi.Columns) == 0 { @@ -4921,7 +4920,7 @@ func checkPartitionByRange(ctx sessionctx.Context, tbInfo *model.TableInfo) erro return checkRangeColumnsPartitionValue(ctx, tbInfo) } -func checkRangeColumnsPartitionValue(ctx sessionctx.Context, tbInfo *model.TableInfo) error { +func checkRangeColumnsPartitionValue(ctx expression.BuildContext, tbInfo *model.TableInfo) error { // Range columns partition key supports multiple data types with integer、datetime、string. pi := tbInfo.Partition defs := pi.Definitions @@ -4947,7 +4946,7 @@ func checkRangeColumnsPartitionValue(ctx sessionctx.Context, tbInfo *model.Table return nil } -func checkTwoRangeColumns(ctx sessionctx.Context, curr, prev *model.PartitionDefinition, pi *model.PartitionInfo, tbInfo *model.TableInfo) (bool, error) { +func checkTwoRangeColumns(ctx expression.BuildContext, curr, prev *model.PartitionDefinition, pi *model.PartitionInfo, tbInfo *model.TableInfo) (bool, error) { if len(curr.LessThan) != len(pi.Columns) { return false, errors.Trace(ast.ErrPartitionColumnList) } @@ -4967,7 +4966,7 @@ func checkTwoRangeColumns(ctx sessionctx.Context, curr, prev *model.PartitionDef // PARTITION p1 VALUES LESS THAN (10,20,'mmm') // PARTITION p2 VALUES LESS THAN (15,30,'sss') colInfo := findColumnByName(pi.Columns[i].L, tbInfo) - cmp, err := parseAndEvalBoolExpr(ctx.GetExprCtx(), curr.LessThan[i], prev.LessThan[i], colInfo, tbInfo) + cmp, err := parseAndEvalBoolExpr(ctx, curr.LessThan[i], prev.LessThan[i], colInfo, tbInfo) if err != nil { return false, err } @@ -5025,6 +5024,6 @@ func parseAndEvalBoolExpr(ctx expression.BuildContext, l, r string, colInfo *mod } // checkPartitionByList checks validity of a "BY LIST" partition. -func checkPartitionByList(ctx sessionctx.Context, tbInfo *model.TableInfo) error { - return checkListPartitionValue(ctx.GetExprCtx(), tbInfo) +func checkPartitionByList(ctx expression.BuildContext, tbInfo *model.TableInfo) error { + return checkListPartitionValue(ctx, tbInfo) } diff --git a/pkg/ddl/schematracker/BUILD.bazel b/pkg/ddl/schematracker/BUILD.bazel index 209b1af2b649e..d2e5f5fdf8f47 100644 --- a/pkg/ddl/schematracker/BUILD.bazel +++ b/pkg/ddl/schematracker/BUILD.bazel @@ -17,6 +17,7 @@ go_library( "//pkg/infoschema", "//pkg/kv", "//pkg/meta/autoid", + "//pkg/meta/metabuild", "//pkg/meta/model", "//pkg/owner", "//pkg/parser/ast", diff --git a/pkg/ddl/schematracker/dm_tracker.go b/pkg/ddl/schematracker/dm_tracker.go index be9ce8891cfb2..c74b42583fd75 100644 --- a/pkg/ddl/schematracker/dm_tracker.go +++ b/pkg/ddl/schematracker/dm_tracker.go @@ -25,6 +25,7 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/tidb/pkg/ddl" "github.com/pingcap/tidb/pkg/infoschema" + "github.com/pingcap/tidb/pkg/meta/metabuild" "github.com/pingcap/tidb/pkg/meta/model" "github.com/pingcap/tidb/pkg/parser/ast" "github.com/pingcap/tidb/pkg/parser/charset" @@ -73,11 +74,11 @@ func (d *SchemaTracker) CreateSchema(ctx sessionctx.Context, stmt *ast.CreateDat } } - var sessVars *variable.SessionVars + utf8MB4DefaultColl := "" if ctx != nil { - sessVars = ctx.GetSessionVars() + utf8MB4DefaultColl = ctx.GetSessionVars().DefaultCollationForUTF8MB4 } - chs, coll, err := ddl.ResolveCharsetCollation(sessVars, charsetOpt) + chs, coll, err := ddl.ResolveCharsetCollation([]ast.CharsetOpt{charsetOpt}, utf8MB4DefaultColl) if err != nil { return errors.Trace(err) } @@ -152,7 +153,7 @@ func (d *SchemaTracker) AlterSchema(ctx sessionctx.Context, stmt *ast.AlterDatab } } if toCollate == "" { - if toCollate, err = ddl.GetDefaultCollation(ctx.GetSessionVars(), toCharset); err != nil { + if toCollate, err = ddl.GetDefaultCollation(toCharset, ctx.GetSessionVars().DefaultCollationForUTF8MB4); err != nil { return errors.Trace(err) } } @@ -182,15 +183,6 @@ func (d *SchemaTracker) CreateTable(ctx sessionctx.Context, s *ast.CreateTableSt if schema == nil { return infoschema.ErrDatabaseNotExists.GenWithStackByArgs(ident.Schema) } - // suppress ErrTooLongKey - ctx.SetValue(ddl.SuppressErrorTooLongKeyKey, true) - // support drop PK - enableClusteredIndexBackup := ctx.GetSessionVars().EnableClusteredIndex - ctx.GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeOff - defer func() { - ctx.ClearValue(ddl.SuppressErrorTooLongKeyKey) - ctx.GetSessionVars().EnableClusteredIndex = enableClusteredIndexBackup - }() var ( referTbl *model.TableInfo @@ -203,14 +195,21 @@ func (d *SchemaTracker) CreateTable(ctx sessionctx.Context, s *ast.CreateTableSt } } + metaBuildCtx := ddl.NewMetaBuildContextWithSctx( + ctx, + // suppress ErrTooLongKey + metabuild.WithSuppressTooLongIndexErr(true), + // support drop PK + metabuild.WithClusteredIndexDefMode(variable.ClusteredIndexDefModeOff), + ) // build tableInfo var ( tbInfo *model.TableInfo ) if s.ReferTable != nil { - tbInfo, err = ddl.BuildTableInfoWithLike(ctx, ident, referTbl, s) + tbInfo, err = ddl.BuildTableInfoWithLike(ident, referTbl, s) } else { - tbInfo, err = ddl.BuildTableInfoWithStmt(ctx, s, schema.Charset, schema.Collate, nil) + tbInfo, err = ddl.BuildTableInfoWithStmt(metaBuildCtx, s, schema.Charset, schema.Collate, nil) } if err != nil { return errors.Trace(err) @@ -218,7 +217,7 @@ func (d *SchemaTracker) CreateTable(ctx sessionctx.Context, s *ast.CreateTableSt // TODO: to reuse the constant fold of expression in partition range definition we use CheckTableInfoValidWithStmt, // but it may also introduce unwanted limit check in DM's use case. Should check it later. - if err = ddl.CheckTableInfoValidWithStmt(ctx, tbInfo, s); err != nil { + if err = ddl.CheckTableInfoValidWithStmt(metaBuildCtx, tbInfo, s); err != nil { return err } @@ -277,7 +276,7 @@ func (d *SchemaTracker) CreateView(ctx sessionctx.Context, s *ast.CreateViewStmt }) } - tbInfo, err := ddl.BuildTableInfo(ctx, s.ViewName.Name, cols, nil, "", "") + tbInfo, err := ddl.BuildTableInfo(ddl.NewMetaBuildContextWithSctx(ctx), s.ViewName.Name, cols, nil, "", "") if err != nil { return err } @@ -413,7 +412,7 @@ func (d *SchemaTracker) createIndex( return dbterror.ErrDupKeyName.GenWithStack("index already exist %s", indexName) } - hiddenCols, err := ddl.BuildHiddenColumnInfo(ctx, indexPartSpecifications, indexName, t.Meta(), t.Cols()) + hiddenCols, err := ddl.BuildHiddenColumnInfo(ddl.NewMetaBuildContextWithSctx(ctx), indexPartSpecifications, indexName, t.Meta(), t.Cols()) if err != nil { return err } @@ -426,7 +425,7 @@ func (d *SchemaTracker) createIndex( } indexInfo, err := ddl.BuildIndexInfo( - ctx, + ddl.NewMetaBuildContextWithSctx(ctx), finalColumns, indexName, false, @@ -653,7 +652,7 @@ func (d *SchemaTracker) alterColumn(ctx sessionctx.Context, ident ast.Ident, spe } oldCol.AddFlag(mysql.NoDefaultValueFlag) } else { - _, err := ddl.SetDefaultValue(ctx, oldCol, specNewColumn.Options[0]) + _, err := ddl.SetDefaultValue(ctx.GetExprCtx(), oldCol, specNewColumn.Options[0]) if err != nil { return errors.Trace(err) } @@ -870,7 +869,7 @@ func (d *SchemaTracker) createPrimaryKey( } indexInfo, err := ddl.BuildIndexInfo( - ctx, + ddl.NewMetaBuildContextWithSctx(ctx), tblInfo.Columns, indexName, true, @@ -979,7 +978,7 @@ func (d *SchemaTracker) AlterTable(ctx context.Context, sctx sessionctx.Context, continue } var toCharset, toCollate string - toCharset, toCollate, err = ddl.GetCharsetAndCollateInTableOption(sctx.GetSessionVars(), i, spec.Options) + toCharset, toCollate, err = ddl.GetCharsetAndCollateInTableOption(i, spec.Options, sctx.GetSessionVars().DefaultCollationForUTF8MB4) if err != nil { return err } diff --git a/pkg/ddl/ttl.go b/pkg/ddl/ttl.go index 73b53feffad45..cc993842c492c 100644 --- a/pkg/ddl/ttl.go +++ b/pkg/ddl/ttl.go @@ -19,14 +19,13 @@ import ( "time" "github.com/pingcap/errors" + infoschemactx "github.com/pingcap/tidb/pkg/infoschema/context" "github.com/pingcap/tidb/pkg/meta" "github.com/pingcap/tidb/pkg/meta/model" "github.com/pingcap/tidb/pkg/parser/ast" "github.com/pingcap/tidb/pkg/parser/format" pmodel "github.com/pingcap/tidb/pkg/parser/model" "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/sessiontxn" "github.com/pingcap/tidb/pkg/ttl/cache" "github.com/pingcap/tidb/pkg/types" "github.com/pingcap/tidb/pkg/util/dbterror" @@ -97,15 +96,29 @@ func onTTLInfoChange(jobCtx *jobContext, t *meta.Meta, job *model.Job) (ver int6 return ver, nil } -func checkTTLInfoValid(ctx sessionctx.Context, schema pmodel.CIStr, tblInfo *model.TableInfo) error { +// checkTTLInfoValid checks the TTL settings for a table. +// The argument `isForForeignKeyCheck` is used to check the table should not be referenced by foreign key. +// If `isForForeignKeyCheck` is `nil`, it will skip the foreign key check. +func checkTTLInfoValid(schema pmodel.CIStr, tblInfo *model.TableInfo, foreignKeyCheckIs infoschemactx.MetaOnlyInfoSchema) error { + if tblInfo.TempTableType != model.TempTableNone { + return dbterror.ErrTempTableNotAllowedWithTTL + } + if err := checkTTLIntervalExpr(tblInfo.TTLInfo); err != nil { return err } - if err := checkTTLTableSuitable(ctx, schema, tblInfo); err != nil { + if err := checkPrimaryKeyForTTLTable(tblInfo); err != nil { return err } + if foreignKeyCheckIs != nil { + // checks even when the foreign key check is not enabled, to keep safe + if referredFK := checkTableHasForeignKeyReferred(foreignKeyCheckIs, schema.L, tblInfo.Name.L, nil, true); referredFK != nil { + return dbterror.ErrUnsupportedTTLReferencedByFK + } + } + return checkTTLInfoColumnType(tblInfo) } @@ -126,26 +139,6 @@ func checkTTLInfoColumnType(tblInfo *model.TableInfo) error { return nil } -// checkTTLTableSuitable returns whether this table is suitable to be a TTL table -// A temporary table or a parent table referenced by a foreign key cannot be TTL table -func checkTTLTableSuitable(ctx sessionctx.Context, schema pmodel.CIStr, tblInfo *model.TableInfo) error { - if tblInfo.TempTableType != model.TempTableNone { - return dbterror.ErrTempTableNotAllowedWithTTL - } - - if err := checkPrimaryKeyForTTLTable(tblInfo); err != nil { - return err - } - - // checks even when the foreign key check is not enabled, to keep safe - is := sessiontxn.GetTxnManager(ctx).GetTxnInfoSchema() - if referredFK := checkTableHasForeignKeyReferred(is, schema.L, tblInfo.Name.L, nil, true); referredFK != nil { - return dbterror.ErrUnsupportedTTLReferencedByFK - } - - return nil -} - func checkDropColumnWithTTLConfig(tblInfo *model.TableInfo, colName string) error { if tblInfo.TTLInfo != nil { if tblInfo.TTLInfo.ColumnName.L == colName { diff --git a/pkg/distsql/context_test.go b/pkg/distsql/context_test.go index d81568e0373ab..9cb394fccaf3e 100644 --- a/pkg/distsql/context_test.go +++ b/pkg/distsql/context_test.go @@ -24,7 +24,7 @@ import ( // NewDistSQLContextForTest creates a new dist sql context for test func NewDistSQLContextForTest() *distsqlctx.DistSQLContext { return &distsqlctx.DistSQLContext{ - WarnHandler: contextutil.NewFuncWarnAppenderForTest(func(err error) {}), + WarnHandler: contextutil.NewFuncWarnAppenderForTest(func(level string, err error) {}), TiFlashMaxThreads: variable.DefTiFlashMaxThreads, TiFlashMaxBytesBeforeExternalJoin: variable.DefTiFlashMaxBytesBeforeExternalJoin, TiFlashMaxBytesBeforeExternalGroupBy: variable.DefTiFlashMaxBytesBeforeExternalGroupBy, diff --git a/pkg/errctx/context_test.go b/pkg/errctx/context_test.go index 7f18439b29760..2a94d77e57fb5 100644 --- a/pkg/errctx/context_test.go +++ b/pkg/errctx/context_test.go @@ -27,7 +27,8 @@ import ( func TestContext(t *testing.T) { var warn error - ctx := errctx.NewContext(contextutil.NewFuncWarnAppenderForTest(func(err error) { + ctx := errctx.NewContext(contextutil.NewFuncWarnAppenderForTest(func(level string, err error) { + require.Equal(t, contextutil.WarnLevelWarning, level) warn = err })) @@ -75,7 +76,8 @@ func TestContext(t *testing.T) { // test with a level map levels = errctx.LevelMap{} levels[errctx.ErrGroupAutoIncReadFailed] = errctx.LevelWarn - ctx = errctx.NewContextWithLevels(levels, contextutil.NewFuncWarnAppenderForTest(func(err error) { + ctx = errctx.NewContextWithLevels(levels, contextutil.NewFuncWarnAppenderForTest(func(level string, err error) { + require.Equal(t, contextutil.WarnLevelWarning, level) warn = err })) require.Equal(t, levels, ctx.LevelMap()) diff --git a/pkg/executor/ddl.go b/pkg/executor/ddl.go index 2aa576bf09001..24038b07db878 100644 --- a/pkg/executor/ddl.go +++ b/pkg/executor/ddl.go @@ -289,7 +289,7 @@ func (e *DDLExec) createSessionTemporaryTable(s *ast.CreateTableStmt) error { return errors.Trace(err) } - tbInfo, err := ddl.BuildSessionTemporaryTableInfo(e.Ctx(), is, s, dbInfo.Charset, dbInfo.Collate, dbInfo.PlacementPolicyRef) + tbInfo, err := ddl.BuildSessionTemporaryTableInfo(ddl.NewMetaBuildContextWithSctx(e.Ctx()), is, s, dbInfo.Charset, dbInfo.Collate, dbInfo.PlacementPolicyRef) if err != nil { return err } diff --git a/pkg/expression/exprstatic/evalctx.go b/pkg/expression/exprstatic/evalctx.go index fda980f73561b..5086143497a1c 100644 --- a/pkg/expression/exprstatic/evalctx.go +++ b/pkg/expression/exprstatic/evalctx.go @@ -300,6 +300,13 @@ func (ctx *EvalContext) AppendWarning(err error) { } } +// AppendNote appends notes to the context. +func (ctx *EvalContext) AppendNote(err error) { + if h := ctx.warnHandler; h != nil { + h.AppendNote(err) + } +} + // WarningCount gets warning count. func (ctx *EvalContext) WarningCount() int { if h := ctx.warnHandler; h != nil { @@ -489,11 +496,35 @@ func (ctx *EvalContext) currentTimeFuncFromStringVal(val string) func() (time.Ti func newSessionVarsWithSystemVariables(vars map[string]string) (*variable.SessionVars, error) { sessionVars := variable.NewSessionVars(nil) + var cs, col []string for name, val := range vars { - if err := sessionVars.SetSystemVar(name, val); err != nil { + switch strings.ToLower(name) { + // `charset_connection` and `collation_connection` will overwrite each other. + // To make the result more determinate, just set them at last step in order: + // `charset_connection` first, then `collation_connection`. + case variable.CharacterSetConnection: + cs = []string{name, val} + case variable.CollationConnection: + col = []string{name, val} + default: + if err := sessionVars.SetSystemVar(name, val); err != nil { + return nil, err + } + } + } + + if cs != nil { + if err := sessionVars.SetSystemVar(cs[0], cs[1]); err != nil { + return nil, err + } + } + + if col != nil { + if err := sessionVars.SetSystemVar(col[0], col[1]); err != nil { return nil, err } } + return sessionVars, nil } diff --git a/pkg/expression/exprstatic/evalctx_test.go b/pkg/expression/exprstatic/evalctx_test.go index cf4145e8cdab8..10b5fc02a69c7 100644 --- a/pkg/expression/exprstatic/evalctx_test.go +++ b/pkg/expression/exprstatic/evalctx_test.go @@ -310,9 +310,10 @@ func TestStaticEvalCtxWarnings(t *testing.T) { tc, ec := ctx.TypeCtx(), ctx.ErrCtx() h.AppendWarning(errors.NewNoStackError("warn0")) ctx.AppendWarning(errors.NewNoStackError("warn1")) + ctx.AppendNote(errors.NewNoStackError("note1")) tc.AppendWarning(errors.NewNoStackError("warn2")) ec.AppendWarning(errors.NewNoStackError("warn3")) - require.Equal(t, 4, h.WarningCount()) + require.Equal(t, 5, h.WarningCount()) require.Equal(t, h.WarningCount(), ctx.WarningCount()) // ctx.CopyWarnings @@ -320,15 +321,17 @@ func TestStaticEvalCtxWarnings(t *testing.T) { require.Equal(t, []contextutil.SQLWarn{ {Level: contextutil.WarnLevelWarning, Err: errors.NewNoStackError("warn0")}, {Level: contextutil.WarnLevelWarning, Err: errors.NewNoStackError("warn1")}, + {Level: contextutil.WarnLevelNote, Err: errors.NewNoStackError("note1")}, {Level: contextutil.WarnLevelWarning, Err: errors.NewNoStackError("warn2")}, {Level: contextutil.WarnLevelWarning, Err: errors.NewNoStackError("warn3")}, }, warnings) - require.Equal(t, 4, h.WarningCount()) + require.Equal(t, 5, h.WarningCount()) require.Equal(t, h.WarningCount(), ctx.WarningCount()) // ctx.TruncateWarnings warnings = ctx.TruncateWarnings(2) require.Equal(t, []contextutil.SQLWarn{ + {Level: contextutil.WarnLevelNote, Err: errors.NewNoStackError("note1")}, {Level: contextutil.WarnLevelWarning, Err: errors.NewNoStackError("warn2")}, {Level: contextutil.WarnLevelWarning, Err: errors.NewNoStackError("warn3")}, }, warnings) diff --git a/pkg/expression/sessionexpr/sessionctx.go b/pkg/expression/sessionexpr/sessionctx.go index d75f9ff8661c6..af25f9b08e594 100644 --- a/pkg/expression/sessionexpr/sessionctx.go +++ b/pkg/expression/sessionexpr/sessionctx.go @@ -212,6 +212,11 @@ func (ctx *EvalContext) AppendWarning(err error) { ctx.sctx.GetSessionVars().StmtCtx.AppendWarning(err) } +// AppendNote appends notes to the context. +func (ctx *EvalContext) AppendNote(err error) { + ctx.sctx.GetSessionVars().StmtCtx.AppendNote(err) +} + // WarningCount gets warning count. func (ctx *EvalContext) WarningCount() int { return int(ctx.sctx.GetSessionVars().StmtCtx.WarningCount()) diff --git a/pkg/meta/metabuild/BUILD.bazel b/pkg/meta/metabuild/BUILD.bazel new file mode 100644 index 0000000000000..fde57696100c8 --- /dev/null +++ b/pkg/meta/metabuild/BUILD.bazel @@ -0,0 +1,35 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "metabuild", + srcs = ["context.go"], + importpath = "github.com/pingcap/tidb/pkg/meta/metabuild", + visibility = ["//visibility:public"], + deps = [ + "//pkg/expression/exprctx", + "//pkg/expression/exprstatic", + "//pkg/infoschema/context", + "//pkg/parser/mysql", + "//pkg/sessionctx/variable", + "//pkg/util/intest", + ], +) + +go_test( + name = "metabuild_test", + timeout = "short", + srcs = ["context_test.go"], + flaky = True, + deps = [ + ":metabuild", + "//pkg/expression/exprctx", + "//pkg/expression/exprstatic", + "//pkg/infoschema", + "//pkg/infoschema/context", + "//pkg/parser/charset", + "//pkg/parser/mysql", + "//pkg/sessionctx/variable", + "//pkg/util/deeptest", + "@com_github_stretchr_testify//require", + ], +) diff --git a/pkg/meta/metabuild/context.go b/pkg/meta/metabuild/context.go new file mode 100644 index 0000000000000..faa2056cedbd2 --- /dev/null +++ b/pkg/meta/metabuild/context.go @@ -0,0 +1,194 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package metabuild + +import ( + "github.com/pingcap/tidb/pkg/expression/exprctx" + "github.com/pingcap/tidb/pkg/expression/exprstatic" + infoschemactx "github.com/pingcap/tidb/pkg/infoschema/context" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + "github.com/pingcap/tidb/pkg/util/intest" +) + +// Option is used to set context options. +type Option interface { + applyCtx(*Context) +} + +// funcCtxOption implements the Option interface. +type funcCtxOption struct { + f func(*Context) +} + +func (o funcCtxOption) applyCtx(ctx *Context) { + o.f(ctx) +} + +func funcOpt(f func(ctx *Context)) Option { + return funcCtxOption{f: f} +} + +// WithExprCtx sets the expression context. +func WithExprCtx(exprCtx exprctx.ExprContext) Option { + intest.AssertNotNil(exprCtx) + return funcOpt(func(ctx *Context) { + ctx.exprCtx = exprCtx + }) +} + +// WithEnableAutoIncrementInGenerated sets whether enable auto increment in generated column. +func WithEnableAutoIncrementInGenerated(enable bool) Option { + return funcOpt(func(ctx *Context) { + ctx.enableAutoIncrementInGenerated = enable + }) +} + +// WithPrimaryKeyRequired sets whether primary key is required. +func WithPrimaryKeyRequired(required bool) Option { + return funcOpt(func(ctx *Context) { + ctx.primaryKeyRequired = required + }) +} + +// WithClusteredIndexDefMode sets the clustered index mode. +func WithClusteredIndexDefMode(mode variable.ClusteredIndexDefMode) Option { + return funcOpt(func(ctx *Context) { + ctx.clusteredIndexDefMode = mode + }) +} + +// WithShardRowIDBits sets the shard row id bits. +func WithShardRowIDBits(bits uint64) Option { + return funcOpt(func(ctx *Context) { + ctx.shardRowIDBits = bits + }) +} + +// WithPreSplitRegions sets the pre-split regions. +func WithPreSplitRegions(regions uint64) Option { + return funcOpt(func(ctx *Context) { + ctx.preSplitRegions = regions + }) +} + +// WithSuppressTooLongIndexErr sets whether to suppress too long index error. +func WithSuppressTooLongIndexErr(suppress bool) Option { + return funcOpt(func(ctx *Context) { + ctx.suppressTooLongIndexErr = suppress + }) +} + +// WithInfoSchema sets the info schema. +func WithInfoSchema(schema infoschemactx.MetaOnlyInfoSchema) Option { + return funcOpt(func(ctx *Context) { + ctx.is = schema + }) +} + +// Context is used to build meta like `TableInfo`, `IndexInfo`, etc... +type Context struct { + exprCtx exprctx.ExprContext + enableAutoIncrementInGenerated bool + primaryKeyRequired bool + clusteredIndexDefMode variable.ClusteredIndexDefMode + shardRowIDBits uint64 + preSplitRegions uint64 + suppressTooLongIndexErr bool + is infoschemactx.MetaOnlyInfoSchema +} + +// NewContext creates a new context for meta-building. +func NewContext(opts ...Option) *Context { + ctx := &Context{ + enableAutoIncrementInGenerated: variable.DefTiDBEnableAutoIncrementInGenerated, + primaryKeyRequired: false, + clusteredIndexDefMode: variable.DefTiDBEnableClusteredIndex, + shardRowIDBits: variable.DefShardRowIDBits, + preSplitRegions: variable.DefPreSplitRegions, + suppressTooLongIndexErr: false, + } + + for _, opt := range opts { + opt.applyCtx(ctx) + } + + if ctx.exprCtx == nil { + ctx.exprCtx = exprstatic.NewExprContext() + } + + return ctx +} + +// GetExprCtx returns the expression context of the session. +func (ctx *Context) GetExprCtx() exprctx.ExprContext { + return ctx.exprCtx +} + +// GetDefaultCollationForUTF8MB4 returns the default collation for utf8mb4. +func (ctx *Context) GetDefaultCollationForUTF8MB4() string { + return ctx.exprCtx.GetDefaultCollationForUTF8MB4() +} + +// GetSQLMode returns the SQL mode. +func (ctx *Context) GetSQLMode() mysql.SQLMode { + return ctx.exprCtx.GetEvalCtx().SQLMode() +} + +// AppendWarning appends a warning. +func (ctx *Context) AppendWarning(err error) { + ctx.GetExprCtx().GetEvalCtx().AppendWarning(err) +} + +// AppendNote appends a note. +func (ctx *Context) AppendNote(note error) { + ctx.GetExprCtx().GetEvalCtx().AppendNote(note) +} + +// EnableAutoIncrementInGenerated returns whether enable auto increment in generated column. +func (ctx *Context) EnableAutoIncrementInGenerated() bool { + return ctx.enableAutoIncrementInGenerated +} + +// PrimaryKeyRequired returns whether primary key is required. +func (ctx *Context) PrimaryKeyRequired() bool { + return ctx.primaryKeyRequired +} + +// GetClusteredIndexDefMode returns the clustered index mode. +func (ctx *Context) GetClusteredIndexDefMode() variable.ClusteredIndexDefMode { + return ctx.clusteredIndexDefMode +} + +// GetShardRowIDBits returns the shard row id bits. +func (ctx *Context) GetShardRowIDBits() uint64 { + return ctx.shardRowIDBits +} + +// GetPreSplitRegions returns the pre-split regions. +func (ctx *Context) GetPreSplitRegions() uint64 { + return ctx.preSplitRegions +} + +// SuppressTooLongIndexErr returns whether suppress too long index error. +func (ctx *Context) SuppressTooLongIndexErr() bool { + return ctx.suppressTooLongIndexErr +} + +// GetInfoSchema returns the info schema for check some constraints between tables. +// If the second return value is false, it means that we do not need to check the constraints referred to other tables. +func (ctx *Context) GetInfoSchema() (infoschemactx.MetaOnlyInfoSchema, bool) { + return ctx.is, ctx.is != nil +} diff --git a/pkg/meta/metabuild/context_test.go b/pkg/meta/metabuild/context_test.go new file mode 100644 index 0000000000000..99a20961fe3a8 --- /dev/null +++ b/pkg/meta/metabuild/context_test.go @@ -0,0 +1,172 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package metabuild_test + +import ( + "testing" + + "github.com/pingcap/tidb/pkg/expression/exprctx" + "github.com/pingcap/tidb/pkg/expression/exprstatic" + "github.com/pingcap/tidb/pkg/infoschema" + infoschemactx "github.com/pingcap/tidb/pkg/infoschema/context" + "github.com/pingcap/tidb/pkg/meta/metabuild" + "github.com/pingcap/tidb/pkg/parser/charset" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + "github.com/pingcap/tidb/pkg/util/deeptest" + "github.com/stretchr/testify/require" +) + +func TestMetaBuildContext(t *testing.T) { + defVars := variable.NewSessionVars(nil) + fields := []struct { + name string + getter func(ctx *metabuild.Context) any + checkDefault any + option func(val any) metabuild.Option + testVals []any + }{ + { + name: "exprCtx", + getter: func(ctx *metabuild.Context) any { + return ctx.GetExprCtx() + }, + checkDefault: func(ctx *metabuild.Context) { + require.NotNil(t, ctx.GetExprCtx()) + cs, col := ctx.GetExprCtx().GetCharsetInfo() + defCs, defCol := charset.GetDefaultCharsetAndCollate() + require.Equal(t, defCs, cs) + require.Equal(t, defCol, col) + defSQLMode, err := mysql.GetSQLMode(mysql.DefaultSQLMode) + require.NoError(t, err) + require.Equal(t, defSQLMode, ctx.GetSQLMode()) + require.Equal(t, ctx.GetExprCtx().GetEvalCtx().SQLMode(), ctx.GetSQLMode()) + require.Equal(t, defVars.DefaultCollationForUTF8MB4, ctx.GetDefaultCollationForUTF8MB4()) + require.Equal(t, ctx.GetExprCtx().GetDefaultCollationForUTF8MB4(), ctx.GetDefaultCollationForUTF8MB4()) + }, + option: func(val any) metabuild.Option { + return metabuild.WithExprCtx(val.(exprctx.ExprContext)) + }, + testVals: []any{exprstatic.NewExprContext()}, + }, + { + name: "enableAutoIncrementInGenerated", + getter: func(ctx *metabuild.Context) any { + return ctx.EnableAutoIncrementInGenerated() + }, + checkDefault: defVars.EnableAutoIncrementInGenerated, + option: func(val any) metabuild.Option { + return metabuild.WithEnableAutoIncrementInGenerated(val.(bool)) + }, + testVals: []any{true, false}, + }, + { + name: "primaryKeyRequired", + getter: func(ctx *metabuild.Context) any { + return ctx.PrimaryKeyRequired() + }, + checkDefault: defVars.PrimaryKeyRequired, + option: func(val any) metabuild.Option { + return metabuild.WithPrimaryKeyRequired(val.(bool)) + }, + testVals: []any{true, false}, + }, + { + name: "clusteredIndexDefMode", + getter: func(ctx *metabuild.Context) any { + return ctx.GetClusteredIndexDefMode() + }, + checkDefault: defVars.EnableClusteredIndex, + option: func(val any) metabuild.Option { + return metabuild.WithClusteredIndexDefMode(val.(variable.ClusteredIndexDefMode)) + }, + testVals: []any{variable.ClusteredIndexDefModeOn, variable.ClusteredIndexDefModeOff}, + }, + { + name: "shardRowIDBits", + getter: func(ctx *metabuild.Context) any { + return ctx.GetShardRowIDBits() + }, + checkDefault: defVars.ShardRowIDBits, + option: func(val any) metabuild.Option { + return metabuild.WithShardRowIDBits(val.(uint64)) + }, + testVals: []any{uint64(6), uint64(8)}, + }, + { + name: "preSplitRegions", + getter: func(ctx *metabuild.Context) any { + return ctx.GetPreSplitRegions() + }, + checkDefault: defVars.PreSplitRegions, + option: func(val any) metabuild.Option { + return metabuild.WithPreSplitRegions(val.(uint64)) + }, + testVals: []any{uint64(123), uint64(456)}, + }, + { + name: "suppressTooLongIndexErr", + getter: func(ctx *metabuild.Context) any { + return ctx.SuppressTooLongIndexErr() + }, + checkDefault: false, + option: func(val any) metabuild.Option { + return metabuild.WithSuppressTooLongIndexErr(val.(bool)) + }, + testVals: []any{true, false}, + }, + { + name: "is", + getter: func(ctx *metabuild.Context) any { + is, ok := ctx.GetInfoSchema() + require.Equal(t, ok, is != nil) + return is + }, + checkDefault: nil, + option: func(val any) metabuild.Option { + if val == nil { + return metabuild.WithInfoSchema(nil) + } + return metabuild.WithInfoSchema(val.(infoschemactx.MetaOnlyInfoSchema)) + }, + testVals: []any{infoschema.MockInfoSchema(nil), nil}, + }, + } + defCtx := metabuild.NewContext() + allFields := make([]string, 0, len(fields)) + for _, field := range fields { + t.Run("default_of_"+field.name, func(t *testing.T) { + switch val := field.checkDefault.(type) { + case func(*metabuild.Context): + val(defCtx) + default: + require.Equal(t, field.checkDefault, field.getter(defCtx), field.name) + } + }) + allFields = append(allFields, "$."+field.name) + } + + for _, field := range fields { + t.Run("option_of_"+field.name, func(t *testing.T) { + for _, val := range field.testVals { + ctx := metabuild.NewContext(field.option(val)) + require.Equal(t, val, field.getter(ctx), "%s %v", field.name, val) + } + }) + } + + // test allFields are tested + deeptest.AssertRecursivelyNotEqual(t, metabuild.Context{}, metabuild.Context{}, deeptest.WithIgnorePath(allFields)) +} diff --git a/pkg/server/conn_stmt_params_test.go b/pkg/server/conn_stmt_params_test.go index 66fa3754069ba..a1850c7721e5d 100644 --- a/pkg/server/conn_stmt_params_test.go +++ b/pkg/server/conn_stmt_params_test.go @@ -269,7 +269,8 @@ func TestParseExecArgs(t *testing.T) { } for _, tt := range tests { var warn error - typectx := types.NewContext(types.DefaultStmtFlags.WithTruncateAsWarning(true), time.UTC, contextutil.NewFuncWarnAppenderForTest(func(err error) { + typectx := types.NewContext(types.DefaultStmtFlags.WithTruncateAsWarning(true), time.UTC, contextutil.NewFuncWarnAppenderForTest(func(l string, err error) { + require.Equal(t, contextutil.WarnLevelWarning, l) warn = err })) err := decodeAndParse(typectx, tt.args.args, tt.args.boundParams, tt.args.nullBitmap, tt.args.paramTypes, tt.args.paramValues, nil) diff --git a/pkg/types/context_test.go b/pkg/types/context_test.go index 38d1874ab5024..62bfdbd9bb78c 100644 --- a/pkg/types/context_test.go +++ b/pkg/types/context_test.go @@ -120,6 +120,10 @@ func (w *warnStore) AppendWarning(warn error) { w.warnings = append(w.warnings, warn) } +func (w *warnStore) AppendNote(_ error) { + panic("not implemented") +} + func (w *warnStore) Reset() { w.Lock() defer w.Unlock() diff --git a/pkg/types/time_test.go b/pkg/types/time_test.go index a6779f42a8729..d015f5fb209b3 100644 --- a/pkg/types/time_test.go +++ b/pkg/types/time_test.go @@ -61,7 +61,8 @@ func TestTimeEncoding(t *testing.T) { func TestDateTime(t *testing.T) { var warnings []error - typeCtx := types.NewContext(types.StrictFlags.WithIgnoreZeroInDate(true), time.UTC, contextutil.NewFuncWarnAppenderForTest(func(err error) { + typeCtx := types.NewContext(types.StrictFlags.WithIgnoreZeroInDate(true), time.UTC, contextutil.NewFuncWarnAppenderForTest(func(l string, err error) { + require.Equal(t, contextutil.WarnLevelWarning, l) warnings = append(warnings, err) })) table := []struct { @@ -2214,7 +2215,7 @@ func TestDurationConvertToYearFromNow(t *testing.T) { } for _, c := range cases { - ctx := types.NewContext(types.StrictFlags.WithCastTimeToYearThroughConcat(c.throughStr), c.sysTZ, contextutil.NewFuncWarnAppenderForTest(func(_ error) { + ctx := types.NewContext(types.StrictFlags.WithCastTimeToYearThroughConcat(c.throughStr), c.sysTZ, contextutil.NewFuncWarnAppenderForTest(func(_ string, _ error) { require.Fail(t, "shouldn't append warninng") })) now, err := time.Parse(time.RFC3339, c.nowLit) diff --git a/pkg/util/context/warn.go b/pkg/util/context/warn.go index 7ec43b8eea8dc..b07a91d1a3924 100644 --- a/pkg/util/context/warn.go +++ b/pkg/util/context/warn.go @@ -81,6 +81,8 @@ func (warn *SQLWarn) UnmarshalJSON(data []byte) error { type WarnAppender interface { // AppendWarning appends a warning AppendWarning(err error) + // AppendNote appends a warning with level 'Note'. + AppendNote(msg error) } // WarnHandler provides a handler to append and get warnings. @@ -278,6 +280,8 @@ type ignoreWarn struct{} func (*ignoreWarn) AppendWarning(_ error) {} +func (*ignoreWarn) AppendNote(_ error) {} + func (*ignoreWarn) WarningCount() int { return 0 } func (*ignoreWarn) TruncateWarnings(_ int) []SQLWarn { return nil } @@ -288,15 +292,19 @@ func (*ignoreWarn) CopyWarnings(_ []SQLWarn) []SQLWarn { return nil } var IgnoreWarn WarnHandler = &ignoreWarn{} type funcWarnAppender struct { - fn func(err error) + fn func(level string, err error) } func (r *funcWarnAppender) AppendWarning(err error) { - r.fn(err) + r.fn(WarnLevelWarning, err) +} + +func (r *funcWarnAppender) AppendNote(err error) { + r.fn(WarnLevelNote, err) } // NewFuncWarnAppenderForTest creates a `WarnHandler` which will use the function to handle warn // To have a better performance, it's not suggested to use this function in production. -func NewFuncWarnAppenderForTest(fn func(err error)) WarnAppender { +func NewFuncWarnAppenderForTest(fn func(level string, err error)) WarnAppender { return &funcWarnAppender{fn} }