diff --git a/br/pkg/lightning/backend/kv/BUILD.bazel b/br/pkg/lightning/backend/kv/BUILD.bazel index 207297a01ddc1..a15f31d90f2e8 100644 --- a/br/pkg/lightning/backend/kv/BUILD.bazel +++ b/br/pkg/lightning/backend/kv/BUILD.bazel @@ -22,6 +22,7 @@ go_library( "//br/pkg/logutil", "//br/pkg/redact", "//br/pkg/utils", + "//pkg/errctx", "//pkg/expression", "//pkg/kv", "//pkg/meta/autoid", diff --git a/br/pkg/lightning/backend/kv/session.go b/br/pkg/lightning/backend/kv/session.go index 157e3d642175f..d777a7ade81aa 100644 --- a/br/pkg/lightning/backend/kv/session.go +++ b/br/pkg/lightning/backend/kv/session.go @@ -29,6 +29,7 @@ import ( "github.com/pingcap/tidb/br/pkg/lightning/log" "github.com/pingcap/tidb/br/pkg/lightning/manual" "github.com/pingcap/tidb/br/pkg/utils" + "github.com/pingcap/tidb/pkg/errctx" "github.com/pingcap/tidb/pkg/kv" "github.com/pingcap/tidb/pkg/parser/model" "github.com/pingcap/tidb/pkg/sessionctx" @@ -293,6 +294,12 @@ func NewSession(options *encode.SessionOptions, logger log.Logger) *Session { WithIgnoreInvalidDateErr(sqlMode.HasAllowInvalidDatesMode()). WithIgnoreZeroInDate(!sqlMode.HasStrictMode() || sqlMode.HasAllowInvalidDatesMode()) vars.StmtCtx.SetTypeFlags(typeFlags) + + errLevels := vars.StmtCtx.ErrLevels() + errLevels[errctx.ErrGroupDividedByZero] = + errctx.ResolveErrLevel(!sqlMode.HasErrorForDivisionByZeroMode(), !sqlMode.HasStrictMode()) + vars.StmtCtx.SetErrLevels(errLevels) + if options.SysVars != nil { for k, v := range options.SysVars { // since 6.3(current master) tidb checks whether we can set a system variable diff --git a/pkg/ddl/BUILD.bazel b/pkg/ddl/BUILD.bazel index ffb39636197ff..992f22f58effd 100644 --- a/pkg/ddl/BUILD.bazel +++ b/pkg/ddl/BUILD.bazel @@ -89,6 +89,7 @@ go_library( "//pkg/disttask/operator", "//pkg/domain/infosync", "//pkg/domain/resourcegroup", + "//pkg/errctx", "//pkg/expression", "//pkg/infoschema", "//pkg/kv", diff --git a/pkg/ddl/backfilling_scheduler.go b/pkg/ddl/backfilling_scheduler.go index 3dafeef033c35..91f1779a7204d 100644 --- a/pkg/ddl/backfilling_scheduler.go +++ b/pkg/ddl/backfilling_scheduler.go @@ -24,6 +24,7 @@ import ( "github.com/pingcap/tidb/pkg/ddl/copr" "github.com/pingcap/tidb/pkg/ddl/ingest" sess "github.com/pingcap/tidb/pkg/ddl/internal/session" + "github.com/pingcap/tidb/pkg/errctx" "github.com/pingcap/tidb/pkg/kv" "github.com/pingcap/tidb/pkg/metrics" "github.com/pingcap/tidb/pkg/parser/model" @@ -163,10 +164,13 @@ func initSessCtx( if err := setSessCtxLocation(sessCtx, tzLocation); err != nil { return errors.Trace(err) } - sessCtx.GetSessionVars().StmtCtx.InReorg = true sessCtx.GetSessionVars().StmtCtx.SetTimeZone(sessCtx.GetSessionVars().Location()) sessCtx.GetSessionVars().StmtCtx.BadNullAsWarning = !sqlMode.HasStrictMode() - sessCtx.GetSessionVars().StmtCtx.DividedByZeroAsWarning = !sqlMode.HasStrictMode() + + errLevels := sessCtx.GetSessionVars().StmtCtx.ErrLevels() + errLevels[errctx.ErrGroupDividedByZero] = + errctx.ResolveErrLevel(!sqlMode.HasErrorForDivisionByZeroMode(), !sqlMode.HasStrictMode()) + sessCtx.GetSessionVars().StmtCtx.SetErrLevels(errLevels) typeFlags := types.StrictFlags. WithTruncateAsWarning(!sqlMode.HasStrictMode()). @@ -195,8 +199,8 @@ func restoreSessCtx(sessCtx sessionctx.Context) func(sessCtx sessionctx.Context) timezone = &tz } badNullAsWarn := sv.StmtCtx.BadNullAsWarning - dividedZeroAsWarn := sv.StmtCtx.DividedByZeroAsWarning typeFlags := sv.StmtCtx.TypeFlags() + errLevels := sv.StmtCtx.ErrLevels() resGroupName := sv.StmtCtx.ResourceGroupName return func(usedSessCtx sessionctx.Context) { uv := usedSessCtx.GetSessionVars() @@ -204,10 +208,9 @@ func restoreSessCtx(sessCtx sessionctx.Context) func(sessCtx sessionctx.Context) uv.SQLMode = sqlMode uv.TimeZone = timezone uv.StmtCtx.BadNullAsWarning = badNullAsWarn - uv.StmtCtx.DividedByZeroAsWarning = dividedZeroAsWarn uv.StmtCtx.SetTypeFlags(typeFlags) + uv.StmtCtx.SetErrLevels(errLevels) uv.StmtCtx.ResourceGroupName = resGroupName - uv.StmtCtx.InReorg = false } } diff --git a/pkg/errctx/context.go b/pkg/errctx/context.go index 3f127510f5634..7b7709f2b8f7f 100644 --- a/pkg/errctx/context.go +++ b/pkg/errctx/context.go @@ -209,5 +209,19 @@ func init() { errGroupMap[errCode] = ErrGroupTruncate } + errGroupMap[errno.ErrDivisionByZero] = ErrGroupDividedByZero errGroupMap[errno.ErrAutoincReadFailed] = ErrGroupAutoIncReadFailed } + +// ResolveErrLevel resolves the error level according to the `ignore` and `warn` flags +// if ignore is true, it will return `LevelIgnore` to ignore the error, +// otherwise, it will return `LevelWarn` or `LevelError` according to the `warn` flag +func ResolveErrLevel(ignore bool, warn bool) Level { + if ignore { + return LevelIgnore + } + if warn { + return LevelWarn + } + return LevelError +} diff --git a/pkg/executor/BUILD.bazel b/pkg/executor/BUILD.bazel index 7bb68be626aa4..80e4056ad82ba 100644 --- a/pkg/executor/BUILD.bazel +++ b/pkg/executor/BUILD.bazel @@ -380,6 +380,7 @@ go_test( "//pkg/distsql", "//pkg/domain", "//pkg/domain/infosync", + "//pkg/errctx", "//pkg/errno", "//pkg/executor/aggfuncs", "//pkg/executor/aggregate", diff --git a/pkg/executor/executor.go b/pkg/executor/executor.go index 1bbe2a3e4b94e..a54c241632af9 100644 --- a/pkg/executor/executor.go +++ b/pkg/executor/executor.go @@ -2105,15 +2105,19 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) { // pushing them down to TiKV as flags. sc.InRestrictedSQL = vars.InRestrictedSQL + + errLevels := sc.ErrLevels() + errLevels[errctx.ErrGroupDividedByZero] = errctx.LevelWarn switch stmt := s.(type) { // `ResetUpdateStmtCtx` and `ResetDeleteStmtCtx` may modify the flags, so we'll need to store them. case *ast.UpdateStmt: ResetUpdateStmtCtx(sc, stmt, vars) + errLevels = sc.ErrLevels() case *ast.DeleteStmt: ResetDeleteStmtCtx(sc, stmt, vars) + errLevels = sc.ErrLevels() case *ast.InsertStmt: sc.InInsertStmt = true - var errLevels errctx.LevelMap // For insert statement (not for update statement), disabling the StrictSQLMode // should make TruncateAsWarning and DividedByZeroAsWarning, // but should not make DupKeyAsWarning. @@ -2123,9 +2127,11 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) { if stmt.IgnoreErr { errLevels[errctx.ErrGroupAutoIncReadFailed] = errctx.LevelWarn } - sc.DividedByZeroAsWarning = !vars.StrictSQLMode || stmt.IgnoreErr + errLevels[errctx.ErrGroupDividedByZero] = errctx.ResolveErrLevel( + !vars.SQLMode.HasErrorForDivisionByZeroMode(), + !vars.StrictSQLMode || stmt.IgnoreErr, + ) sc.Priority = stmt.Priority - sc.SetErrLevels(errLevels) sc.SetTypeFlags(sc.TypeFlags(). WithTruncateAsWarning(!vars.StrictSQLMode || stmt.IgnoreErr). WithIgnoreInvalidDateErr(vars.SQLMode.HasAllowInvalidDatesMode()). @@ -2191,6 +2197,10 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) { WithIgnoreInvalidDateErr(vars.SQLMode.HasAllowInvalidDatesMode())) } + if errLevels != sc.ErrLevels() { + sc.SetErrLevels(errLevels) + } + sc.SetTypeFlags(sc.TypeFlags(). WithSkipUTF8Check(vars.SkipUTF8Check). WithSkipSACIICheck(vars.SkipASCIICheck). @@ -2248,9 +2258,14 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) { // ResetUpdateStmtCtx resets statement context for UpdateStmt. func ResetUpdateStmtCtx(sc *stmtctx.StatementContext, stmt *ast.UpdateStmt, vars *variable.SessionVars) { sc.InUpdateStmt = true + errLevels := sc.ErrLevels() sc.DupKeyAsWarning = stmt.IgnoreErr sc.BadNullAsWarning = !vars.StrictSQLMode || stmt.IgnoreErr - sc.DividedByZeroAsWarning = !vars.StrictSQLMode || stmt.IgnoreErr + errLevels[errctx.ErrGroupDividedByZero] = errctx.ResolveErrLevel( + !vars.SQLMode.HasErrorForDivisionByZeroMode(), + !vars.StrictSQLMode || stmt.IgnoreErr, + ) + sc.SetErrLevels(errLevels) sc.Priority = stmt.Priority sc.IgnoreNoPartition = stmt.IgnoreErr sc.SetTypeFlags(sc.TypeFlags(). @@ -2263,9 +2278,14 @@ func ResetUpdateStmtCtx(sc *stmtctx.StatementContext, stmt *ast.UpdateStmt, vars // ResetDeleteStmtCtx resets statement context for DeleteStmt. func ResetDeleteStmtCtx(sc *stmtctx.StatementContext, stmt *ast.DeleteStmt, vars *variable.SessionVars) { sc.InDeleteStmt = true + errLevels := sc.ErrLevels() sc.DupKeyAsWarning = stmt.IgnoreErr sc.BadNullAsWarning = !vars.StrictSQLMode || stmt.IgnoreErr - sc.DividedByZeroAsWarning = !vars.StrictSQLMode || stmt.IgnoreErr + errLevels[errctx.ErrGroupDividedByZero] = errctx.ResolveErrLevel( + !vars.SQLMode.HasErrorForDivisionByZeroMode(), + !vars.StrictSQLMode || stmt.IgnoreErr, + ) + sc.SetErrLevels(errLevels) sc.Priority = stmt.Priority sc.SetTypeFlags(sc.TypeFlags(). WithTruncateAsWarning(!vars.StrictSQLMode || stmt.IgnoreErr). diff --git a/pkg/executor/executor_pkg_test.go b/pkg/executor/executor_pkg_test.go index 31c841c92f1f2..e3fe3df4dca40 100644 --- a/pkg/executor/executor_pkg_test.go +++ b/pkg/executor/executor_pkg_test.go @@ -15,14 +15,19 @@ package executor import ( + "fmt" "runtime" "strconv" "testing" "time" "unsafe" + "github.com/pingcap/tidb/pkg/domain" + "github.com/pingcap/tidb/pkg/errctx" "github.com/pingcap/tidb/pkg/executor/aggfuncs" "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/parser/ast" + "github.com/pingcap/tidb/pkg/parser/mysql" "github.com/pingcap/tidb/pkg/sessionctx/variable" "github.com/pingcap/tidb/pkg/tablecodec" "github.com/pingcap/tidb/pkg/types" @@ -263,3 +268,93 @@ func TestFilterTemporaryTableKeys(t *testing.T) { res := filterTemporaryTableKeys(vars, []kv.Key{tablecodec.EncodeTablePrefix(tableID), tablecodec.EncodeTablePrefix(42)}) require.Len(t, res, 1) } + +func TestErrLevelsForResetStmtContext(t *testing.T) { + ctx := mock.NewContext() + domain.BindDomain(ctx, &domain.Domain{}) + + cases := []struct { + name string + sqlMode mysql.SQLMode + stmt []ast.StmtNode + levels errctx.LevelMap + }{ + { + name: "strict,write", + sqlMode: mysql.ModeStrictAllTables | mysql.ModeErrorForDivisionByZero, + stmt: []ast.StmtNode{&ast.InsertStmt{}, &ast.UpdateStmt{}, &ast.DeleteStmt{}}, + levels: errctx.LevelMap{}, + }, + { + name: "non-strict,write", + sqlMode: mysql.ModeErrorForDivisionByZero, + stmt: []ast.StmtNode{&ast.InsertStmt{}, &ast.UpdateStmt{}, &ast.DeleteStmt{}}, + levels: func() (l errctx.LevelMap) { + l[errctx.ErrGroupTruncate] = errctx.LevelWarn + l[errctx.ErrGroupDividedByZero] = errctx.LevelWarn + return + }(), + }, + { + name: "strict,insert ignore", + sqlMode: mysql.ModeStrictAllTables | mysql.ModeErrorForDivisionByZero, + stmt: []ast.StmtNode{&ast.InsertStmt{IgnoreErr: true}}, + levels: func() (l errctx.LevelMap) { + l[errctx.ErrGroupTruncate] = errctx.LevelWarn + l[errctx.ErrGroupDividedByZero] = errctx.LevelWarn + l[errctx.ErrGroupAutoIncReadFailed] = errctx.LevelWarn + return + }(), + }, + { + name: "strict,update/delete ignore", + sqlMode: mysql.ModeStrictAllTables | mysql.ModeErrorForDivisionByZero, + stmt: []ast.StmtNode{&ast.UpdateStmt{IgnoreErr: true}, &ast.DeleteStmt{IgnoreErr: true}}, + levels: func() (l errctx.LevelMap) { + l[errctx.ErrGroupTruncate] = errctx.LevelWarn + l[errctx.ErrGroupDividedByZero] = errctx.LevelWarn + return + }(), + }, + { + name: "strict without error_for_division_by_zero,write", + sqlMode: mysql.ModeStrictAllTables, + stmt: []ast.StmtNode{&ast.InsertStmt{}, &ast.UpdateStmt{}, &ast.DeleteStmt{}}, + levels: func() (l errctx.LevelMap) { + l[errctx.ErrGroupDividedByZero] = errctx.LevelIgnore + return + }(), + }, + { + name: "strict,select/union", + sqlMode: mysql.ModeStrictAllTables | mysql.ModeErrorForDivisionByZero, + stmt: []ast.StmtNode{&ast.SelectStmt{}, &ast.SetOprStmt{}}, + levels: func() (l errctx.LevelMap) { + l[errctx.ErrGroupTruncate] = errctx.LevelWarn + l[errctx.ErrGroupDividedByZero] = errctx.LevelWarn + return + }(), + }, + { + name: "non-strict,select/union", + sqlMode: mysql.ModeStrictAllTables | mysql.ModeErrorForDivisionByZero, + stmt: []ast.StmtNode{&ast.SelectStmt{}, &ast.SetOprStmt{}}, + levels: func() (l errctx.LevelMap) { + l[errctx.ErrGroupTruncate] = errctx.LevelWarn + l[errctx.ErrGroupDividedByZero] = errctx.LevelWarn + return + }(), + }, + } + + for i, c := range cases { + for _, stmt := range c.stmt { + msg := fmt.Sprintf("%d: %s, stmt: %T", i, c.name, stmt) + ctx.GetSessionVars().SQLMode = c.sqlMode + ctx.GetSessionVars().StrictSQLMode = ctx.GetSessionVars().SQLMode.HasStrictMode() + require.NoError(t, ResetContextOfStmt(ctx, stmt), msg) + ec := ctx.GetSessionVars().StmtCtx.ErrCtx() + require.Equal(t, c.levels, ec.LevelMap(), msg) + } + } +} diff --git a/pkg/expression/BUILD.bazel b/pkg/expression/BUILD.bazel index 8d155c3582bed..429cdf68d1c6d 100644 --- a/pkg/expression/BUILD.bazel +++ b/pkg/expression/BUILD.bazel @@ -194,6 +194,7 @@ go_test( shard_count = 50, deps = [ "//pkg/config", + "//pkg/errctx", "//pkg/errno", "//pkg/kv", "//pkg/parser", diff --git a/pkg/expression/errors.go b/pkg/expression/errors.go index 3207a946be451..3255ceed72466 100644 --- a/pkg/expression/errors.go +++ b/pkg/expression/errors.go @@ -85,17 +85,8 @@ func handleInvalidTimeError(ctx EvalContext, err error) error { // handleDivisionByZeroError reports error or warning depend on the context. func handleDivisionByZeroError(ctx EvalContext) error { - sc := ctx.GetSessionVars().StmtCtx - if sc.InInsertStmt || sc.InUpdateStmt || sc.InDeleteStmt || sc.InReorg { - if !ctx.GetSessionVars().SQLMode.HasErrorForDivisionByZeroMode() { - return nil - } - if ctx.GetSessionVars().StrictSQLMode && !sc.DividedByZeroAsWarning { - return ErrDivisionByZero - } - } - sc.AppendWarning(ErrDivisionByZero) - return nil + ec := ctx.GetSessionVars().StmtCtx.ErrCtx() + return ec.HandleError(ErrDivisionByZero) } // handleAllowedPacketOverflowed reports error or warning depend on the context. diff --git a/pkg/expression/evaluator_test.go b/pkg/expression/evaluator_test.go index 6234c3b8fa780..d751a256ea82a 100644 --- a/pkg/expression/evaluator_test.go +++ b/pkg/expression/evaluator_test.go @@ -18,6 +18,7 @@ import ( "testing" "time" + "github.com/pingcap/tidb/pkg/errctx" "github.com/pingcap/tidb/pkg/parser/ast" "github.com/pingcap/tidb/pkg/parser/charset" "github.com/pingcap/tidb/pkg/parser/mysql" @@ -438,9 +439,9 @@ func TestBinopNumeric(t *testing.T) { {types.NewDecFromInt(10), ast.Mod, 0}, } - ctx.GetSessionVars().StmtCtx.InSelectStmt = false - ctx.GetSessionVars().SQLMode |= mysql.ModeErrorForDivisionByZero - ctx.GetSessionVars().StmtCtx.InInsertStmt = true + levels := ctx.GetSessionVars().StmtCtx.ErrLevels() + levels[errctx.ErrGroupDividedByZero] = errctx.LevelError + ctx.GetSessionVars().StmtCtx.SetErrLevels(levels) for _, tt := range testcases { fc := funcs[tt.op] f, err := fc.getFunction(ctx, datumsToConstants(types.MakeDatums(tt.lhs, tt.rhs))) @@ -449,7 +450,8 @@ func TestBinopNumeric(t *testing.T) { require.Error(t, err) } - ctx.GetSessionVars().StmtCtx.DividedByZeroAsWarning = true + levels[errctx.ErrGroupDividedByZero] = errctx.LevelWarn + ctx.GetSessionVars().StmtCtx.SetErrLevels(levels) for _, tt := range testcases { fc := funcs[tt.op] f, err := fc.getFunction(ctx, datumsToConstants(types.MakeDatums(tt.lhs, tt.rhs))) diff --git a/pkg/sessionctx/stmtctx/stmtctx.go b/pkg/sessionctx/stmtctx/stmtctx.go index 50e5d003278ae..c32e1ad361a19 100644 --- a/pkg/sessionctx/stmtctx/stmtctx.go +++ b/pkg/sessionctx/stmtctx/stmtctx.go @@ -185,10 +185,8 @@ type StatementContext struct { InCreateOrAlterStmt bool InSetSessionStatesStmt bool InPreparedPlanBuilding bool - InReorg bool DupKeyAsWarning bool BadNullAsWarning bool - DividedByZeroAsWarning bool InShowWarning bool UseCache bool ForcePlanCache bool // force the optimizer to use plan cache even if there is risky optimization, see #49736. @@ -432,6 +430,11 @@ type StatementContext struct { } } +var defaultErrLevels = func() (l errctx.LevelMap) { + l[errctx.ErrGroupDividedByZero] = errctx.LevelWarn + return +}() + // NewStmtCtx creates a new statement context func NewStmtCtx() *StatementContext { return NewStmtCtxWithTimeZone(time.UTC) @@ -444,7 +447,7 @@ func NewStmtCtxWithTimeZone(tz *time.Location) *StatementContext { ctxID: stmtCtxIDGenerator.Add(1), } sc.typeCtx = types.NewContext(types.DefaultStmtFlags, tz, sc) - sc.errCtx = newErrCtx(sc.typeCtx, errctx.LevelMap{}, sc) + sc.errCtx = newErrCtx(sc.typeCtx, defaultErrLevels, sc) return sc } @@ -454,7 +457,7 @@ func (sc *StatementContext) Reset() { ctxID: stmtCtxIDGenerator.Add(1), } sc.typeCtx = types.NewContext(types.DefaultStmtFlags, time.UTC, sc) - sc.errCtx = newErrCtx(sc.typeCtx, errctx.LevelMap{}, sc) + sc.errCtx = newErrCtx(sc.typeCtx, defaultErrLevels, sc) } // CtxID returns the context id of the statement @@ -494,6 +497,12 @@ func (sc *StatementContext) SetErrLevels(otherLevels errctx.LevelMap) { sc.errCtx = newErrCtx(sc.typeCtx, otherLevels, sc) } +// ErrLevels returns the current `errctx.LevelMap` +func (sc *StatementContext) ErrLevels() errctx.LevelMap { + ec := sc.ErrCtx() + return ec.LevelMap() +} + // TypeFlags returns the type flags func (sc *StatementContext) TypeFlags() types.Flags { return sc.typeCtx.Flags() @@ -1176,6 +1185,7 @@ func (sc *StatementContext) GetExecDetails() execdetails.ExecDetails { // PushDownFlags converts StatementContext to tipb.SelectRequest.Flags. func (sc *StatementContext) PushDownFlags() uint64 { var flags uint64 + ec := sc.ErrCtx() if sc.InInsertStmt { flags |= model.FlagInInsertStmt } else if sc.InUpdateStmt || sc.InDeleteStmt { @@ -1193,7 +1203,7 @@ func (sc *StatementContext) PushDownFlags() uint64 { if sc.TypeFlags().IgnoreZeroInDate() { flags |= model.FlagIgnoreZeroInDate } - if sc.DividedByZeroAsWarning { + if ec.LevelForGroup(errctx.ErrGroupDividedByZero) != errctx.LevelError { flags |= model.FlagDividedByZeroAsWarning } if sc.InLoadDataStmt { @@ -1254,7 +1264,11 @@ func (sc *StatementContext) InitFromPBFlagAndTz(flags uint64, tz *time.Location) sc.InInsertStmt = (flags & model.FlagInInsertStmt) > 0 sc.InSelectStmt = (flags & model.FlagInSelectStmt) > 0 sc.InDeleteStmt = (flags & model.FlagInUpdateOrDeleteStmt) > 0 - sc.DividedByZeroAsWarning = (flags & model.FlagDividedByZeroAsWarning) > 0 + levels := sc.ErrLevels() + levels[errctx.ErrGroupDividedByZero] = errctx.ResolveErrLevel(false, + (flags&model.FlagDividedByZeroAsWarning) > 0, + ) + sc.SetErrLevels(levels) sc.SetTimeZone(tz) sc.SetTypeFlags(types.DefaultStmtFlags. WithIgnoreTruncateErr((flags & model.FlagIgnoreTruncate) > 0). diff --git a/pkg/sessionctx/stmtctx/stmtctx_test.go b/pkg/sessionctx/stmtctx/stmtctx_test.go index 4e0fe05a6944d..b6f733a8c7625 100644 --- a/pkg/sessionctx/stmtctx/stmtctx_test.go +++ b/pkg/sessionctx/stmtctx/stmtctx_test.go @@ -82,6 +82,7 @@ func TestCopTasksDetails(t *testing.T) { func TestStatementContextPushDownFLags(t *testing.T) { newStmtCtx := func(fn func(*stmtctx.StatementContext)) *stmtctx.StatementContext { sc := stmtctx.NewStmtCtx() + sc.SetErrLevels(errctx.LevelMap{}) fn(sc) return sc } @@ -97,14 +98,20 @@ func TestStatementContextPushDownFLags(t *testing.T) { {newStmtCtx(func(sc *stmtctx.StatementContext) { sc.SetTypeFlags(sc.TypeFlags().WithIgnoreTruncateErr(true)) }), 1}, {newStmtCtx(func(sc *stmtctx.StatementContext) { sc.SetTypeFlags(sc.TypeFlags().WithTruncateAsWarning(true)) }), 66}, {newStmtCtx(func(sc *stmtctx.StatementContext) { sc.SetTypeFlags(sc.TypeFlags().WithIgnoreZeroInDate(true)) }), 128}, - {newStmtCtx(func(sc *stmtctx.StatementContext) { sc.DividedByZeroAsWarning = true }), 256}, + {newStmtCtx(func(sc *stmtctx.StatementContext) { + var levels errctx.LevelMap + levels[errctx.ErrGroupDividedByZero] = errctx.LevelWarn + sc.SetErrLevels(levels) + }), 256}, {newStmtCtx(func(sc *stmtctx.StatementContext) { sc.InLoadDataStmt = true }), 1024}, {newStmtCtx(func(sc *stmtctx.StatementContext) { sc.InSelectStmt = true sc.SetTypeFlags(sc.TypeFlags().WithTruncateAsWarning(true)) }), 98}, {newStmtCtx(func(sc *stmtctx.StatementContext) { - sc.DividedByZeroAsWarning = true + var levels errctx.LevelMap + levels[errctx.ErrGroupDividedByZero] = errctx.LevelWarn + sc.SetErrLevels(levels) sc.SetTypeFlags(sc.TypeFlags().WithIgnoreTruncateErr(true)) }), 257}, {newStmtCtx(func(sc *stmtctx.StatementContext) { @@ -318,7 +325,9 @@ func TestNewStmtCtx(t *testing.T) { require.Equal(t, types.DefaultStmtFlags, sc.TypeFlags()) require.Same(t, time.UTC, sc.TimeZone()) require.Same(t, time.UTC, sc.TimeZone()) - require.Equal(t, errctx.NewContextWithLevels(errctx.LevelMap{}, sc), sc.ErrCtx()) + var levels errctx.LevelMap + levels[errctx.ErrGroupDividedByZero] = errctx.LevelWarn + require.Equal(t, errctx.NewContextWithLevels(levels, sc), sc.ErrCtx()) sc.AppendWarning(errors.NewNoStackError("err1")) warnings := sc.GetWarnings() require.Equal(t, 1, len(warnings)) @@ -330,7 +339,7 @@ func TestNewStmtCtx(t *testing.T) { require.Equal(t, types.DefaultStmtFlags, sc.TypeFlags()) require.Same(t, tz, sc.TimeZone()) require.Same(t, tz, sc.TimeZone()) - require.Equal(t, errctx.NewContextWithLevels(errctx.LevelMap{}, sc), sc.ErrCtx()) + require.Equal(t, errctx.NewContextWithLevels(levels, sc), sc.ErrCtx()) sc.AppendWarning(errors.NewNoStackError("err2")) warnings = sc.GetWarnings() require.Equal(t, 1, len(warnings)) @@ -351,6 +360,7 @@ func TestSetStmtCtxTypeFlags(t *testing.T) { require.Equal(t, types.DefaultStmtFlags, sc.TypeFlags()) levels := errctx.LevelMap{} + sc.SetErrLevels(levels) sc.SetTypeFlags(types.FlagAllowNegativeToUnsigned | types.FlagSkipASCIICheck) require.Equal(t, types.FlagAllowNegativeToUnsigned|types.FlagSkipASCIICheck, sc.TypeFlags()) require.Equal(t, sc.TypeFlags(), sc.TypeFlags()) @@ -379,6 +389,7 @@ func TestResetStmtCtx(t *testing.T) { require.Equal(t, 1, len(sc.GetWarnings())) levels := errctx.LevelMap{} levels[errctx.ErrGroupTruncate] = errctx.LevelIgnore + levels[errctx.ErrGroupDividedByZero] = errctx.LevelWarn require.Equal(t, errctx.NewContextWithLevels(levels, sc), sc.ErrCtx()) sc.Reset() @@ -395,6 +406,7 @@ func TestResetStmtCtx(t *testing.T) { require.Equal(t, stmtctx.WarnLevelWarning, warnings[0].Level) require.Equal(t, "err2", warnings[0].Err.Error()) levels = errctx.LevelMap{} + levels[errctx.ErrGroupDividedByZero] = errctx.LevelWarn require.Equal(t, errctx.NewContextWithLevels(levels, sc), sc.ErrCtx()) } @@ -426,7 +438,9 @@ func TestErrCtx(t *testing.T) { err := types.ErrTruncated require.Error(t, sc.HandleError(err)) levels := errctx.LevelMap{} + levels[errctx.ErrGroupDividedByZero] = errctx.LevelWarn require.Equal(t, errctx.NewContextWithLevels(levels, sc), sc.ErrCtx()) + levels[errctx.ErrGroupDividedByZero] = errctx.LevelError // set error levels levels[errctx.ErrGroupAutoIncReadFailed] = errctx.LevelIgnore