Skip to content

Commit

Permalink
*: use errctx to handle divide zero error
Browse files Browse the repository at this point in the history
  • Loading branch information
lcwangchao committed Jan 4, 2024
1 parent d85a8bd commit 05cff48
Show file tree
Hide file tree
Showing 13 changed files with 199 additions and 35 deletions.
1 change: 1 addition & 0 deletions br/pkg/lightning/backend/kv/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ go_library(
"//br/pkg/logutil",
"//br/pkg/redact",
"//br/pkg/utils",
"//pkg/errctx",
"//pkg/expression",
"//pkg/kv",
"//pkg/meta/autoid",
Expand Down
7 changes: 7 additions & 0 deletions br/pkg/lightning/backend/kv/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (
"github.com/pingcap/tidb/br/pkg/lightning/log"
"github.com/pingcap/tidb/br/pkg/lightning/manual"
"github.com/pingcap/tidb/br/pkg/utils"
"github.com/pingcap/tidb/pkg/errctx"
"github.com/pingcap/tidb/pkg/kv"
"github.com/pingcap/tidb/pkg/parser/model"
"github.com/pingcap/tidb/pkg/sessionctx"
Expand Down Expand Up @@ -293,6 +294,12 @@ func NewSession(options *encode.SessionOptions, logger log.Logger) *Session {
WithIgnoreInvalidDateErr(sqlMode.HasAllowInvalidDatesMode()).
WithIgnoreZeroInDate(!sqlMode.HasStrictMode() || sqlMode.HasAllowInvalidDatesMode())
vars.StmtCtx.SetTypeFlags(typeFlags)

errLevels := vars.StmtCtx.ErrLevels()
errLevels[errctx.ErrGroupDividedByZero] =
errctx.ResolveErrLevel(!sqlMode.HasErrorForDivisionByZeroMode(), !sqlMode.HasStrictMode())
vars.StmtCtx.SetErrLevels(errLevels)

if options.SysVars != nil {
for k, v := range options.SysVars {
// since 6.3(current master) tidb checks whether we can set a system variable
Expand Down
1 change: 1 addition & 0 deletions pkg/ddl/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ go_library(
"//pkg/disttask/operator",
"//pkg/domain/infosync",
"//pkg/domain/resourcegroup",
"//pkg/errctx",
"//pkg/expression",
"//pkg/infoschema",
"//pkg/kv",
Expand Down
13 changes: 8 additions & 5 deletions pkg/ddl/backfilling_scheduler.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"github.com/pingcap/tidb/pkg/ddl/copr"
"github.com/pingcap/tidb/pkg/ddl/ingest"
sess "github.com/pingcap/tidb/pkg/ddl/internal/session"
"github.com/pingcap/tidb/pkg/errctx"
"github.com/pingcap/tidb/pkg/kv"
"github.com/pingcap/tidb/pkg/metrics"
"github.com/pingcap/tidb/pkg/parser/model"
Expand Down Expand Up @@ -163,10 +164,13 @@ func initSessCtx(
if err := setSessCtxLocation(sessCtx, tzLocation); err != nil {
return errors.Trace(err)
}
sessCtx.GetSessionVars().StmtCtx.InReorg = true
sessCtx.GetSessionVars().StmtCtx.SetTimeZone(sessCtx.GetSessionVars().Location())
sessCtx.GetSessionVars().StmtCtx.BadNullAsWarning = !sqlMode.HasStrictMode()
sessCtx.GetSessionVars().StmtCtx.DividedByZeroAsWarning = !sqlMode.HasStrictMode()

errLevels := sessCtx.GetSessionVars().StmtCtx.ErrLevels()
errLevels[errctx.ErrGroupDividedByZero] =
errctx.ResolveErrLevel(!sqlMode.HasErrorForDivisionByZeroMode(), !sqlMode.HasStrictMode())
sessCtx.GetSessionVars().StmtCtx.SetErrLevels(errLevels)

typeFlags := types.StrictFlags.
WithTruncateAsWarning(!sqlMode.HasStrictMode()).
Expand Down Expand Up @@ -195,19 +199,18 @@ func restoreSessCtx(sessCtx sessionctx.Context) func(sessCtx sessionctx.Context)
timezone = &tz
}
badNullAsWarn := sv.StmtCtx.BadNullAsWarning
dividedZeroAsWarn := sv.StmtCtx.DividedByZeroAsWarning
typeFlags := sv.StmtCtx.TypeFlags()
errLevels := sv.StmtCtx.ErrLevels()
resGroupName := sv.StmtCtx.ResourceGroupName
return func(usedSessCtx sessionctx.Context) {
uv := usedSessCtx.GetSessionVars()
uv.RowEncoder.Enable = rowEncoder
uv.SQLMode = sqlMode
uv.TimeZone = timezone
uv.StmtCtx.BadNullAsWarning = badNullAsWarn
uv.StmtCtx.DividedByZeroAsWarning = dividedZeroAsWarn
uv.StmtCtx.SetTypeFlags(typeFlags)
uv.StmtCtx.SetErrLevels(errLevels)
uv.StmtCtx.ResourceGroupName = resGroupName
uv.StmtCtx.InReorg = false
}
}

Expand Down
14 changes: 14 additions & 0 deletions pkg/errctx/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -209,5 +209,19 @@ func init() {
errGroupMap[errCode] = ErrGroupTruncate
}

errGroupMap[errno.ErrDivisionByZero] = ErrGroupDividedByZero
errGroupMap[errno.ErrAutoincReadFailed] = ErrGroupAutoIncReadFailed
}

// ResolveErrLevel resolves the error level according to the `ignore` and `warn` flags
// if ignore is true, it will return `LevelIgnore` to ignore the error,
// otherwise, it will return `LevelWarn` or `LevelError` according to the `warn` flag
func ResolveErrLevel(ignore bool, warn bool) Level {
if ignore {
return LevelIgnore
}
if warn {
return LevelWarn
}
return LevelError
}
1 change: 1 addition & 0 deletions pkg/executor/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,7 @@ go_test(
"//pkg/distsql",
"//pkg/domain",
"//pkg/domain/infosync",
"//pkg/errctx",
"//pkg/errno",
"//pkg/executor/aggfuncs",
"//pkg/executor/aggregate",
Expand Down
30 changes: 25 additions & 5 deletions pkg/executor/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -2105,15 +2105,19 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) {
// pushing them down to TiKV as flags.

sc.InRestrictedSQL = vars.InRestrictedSQL

errLevels := sc.ErrLevels()
errLevels[errctx.ErrGroupDividedByZero] = errctx.LevelWarn
switch stmt := s.(type) {
// `ResetUpdateStmtCtx` and `ResetDeleteStmtCtx` may modify the flags, so we'll need to store them.
case *ast.UpdateStmt:
ResetUpdateStmtCtx(sc, stmt, vars)
errLevels = sc.ErrLevels()
case *ast.DeleteStmt:
ResetDeleteStmtCtx(sc, stmt, vars)
errLevels = sc.ErrLevels()
case *ast.InsertStmt:
sc.InInsertStmt = true
var errLevels errctx.LevelMap
// For insert statement (not for update statement), disabling the StrictSQLMode
// should make TruncateAsWarning and DividedByZeroAsWarning,
// but should not make DupKeyAsWarning.
Expand All @@ -2123,9 +2127,11 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) {
if stmt.IgnoreErr {
errLevels[errctx.ErrGroupAutoIncReadFailed] = errctx.LevelWarn
}
sc.DividedByZeroAsWarning = !vars.StrictSQLMode || stmt.IgnoreErr
errLevels[errctx.ErrGroupDividedByZero] = errctx.ResolveErrLevel(
!vars.SQLMode.HasErrorForDivisionByZeroMode(),
!vars.StrictSQLMode || stmt.IgnoreErr,
)
sc.Priority = stmt.Priority
sc.SetErrLevels(errLevels)
sc.SetTypeFlags(sc.TypeFlags().
WithTruncateAsWarning(!vars.StrictSQLMode || stmt.IgnoreErr).
WithIgnoreInvalidDateErr(vars.SQLMode.HasAllowInvalidDatesMode()).
Expand Down Expand Up @@ -2191,6 +2197,10 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) {
WithIgnoreInvalidDateErr(vars.SQLMode.HasAllowInvalidDatesMode()))
}

if errLevels != sc.ErrLevels() {
sc.SetErrLevels(errLevels)
}

sc.SetTypeFlags(sc.TypeFlags().
WithSkipUTF8Check(vars.SkipUTF8Check).
WithSkipSACIICheck(vars.SkipASCIICheck).
Expand Down Expand Up @@ -2248,9 +2258,14 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) {
// ResetUpdateStmtCtx resets statement context for UpdateStmt.
func ResetUpdateStmtCtx(sc *stmtctx.StatementContext, stmt *ast.UpdateStmt, vars *variable.SessionVars) {
sc.InUpdateStmt = true
errLevels := sc.ErrLevels()
sc.DupKeyAsWarning = stmt.IgnoreErr
sc.BadNullAsWarning = !vars.StrictSQLMode || stmt.IgnoreErr
sc.DividedByZeroAsWarning = !vars.StrictSQLMode || stmt.IgnoreErr
errLevels[errctx.ErrGroupDividedByZero] = errctx.ResolveErrLevel(
!vars.SQLMode.HasErrorForDivisionByZeroMode(),
!vars.StrictSQLMode || stmt.IgnoreErr,
)
sc.SetErrLevels(errLevels)
sc.Priority = stmt.Priority
sc.IgnoreNoPartition = stmt.IgnoreErr
sc.SetTypeFlags(sc.TypeFlags().
Expand All @@ -2263,9 +2278,14 @@ func ResetUpdateStmtCtx(sc *stmtctx.StatementContext, stmt *ast.UpdateStmt, vars
// ResetDeleteStmtCtx resets statement context for DeleteStmt.
func ResetDeleteStmtCtx(sc *stmtctx.StatementContext, stmt *ast.DeleteStmt, vars *variable.SessionVars) {
sc.InDeleteStmt = true
errLevels := sc.ErrLevels()
sc.DupKeyAsWarning = stmt.IgnoreErr
sc.BadNullAsWarning = !vars.StrictSQLMode || stmt.IgnoreErr
sc.DividedByZeroAsWarning = !vars.StrictSQLMode || stmt.IgnoreErr
errLevels[errctx.ErrGroupDividedByZero] = errctx.ResolveErrLevel(
!vars.SQLMode.HasErrorForDivisionByZeroMode(),
!vars.StrictSQLMode || stmt.IgnoreErr,
)
sc.SetErrLevels(errLevels)
sc.Priority = stmt.Priority
sc.SetTypeFlags(sc.TypeFlags().
WithTruncateAsWarning(!vars.StrictSQLMode || stmt.IgnoreErr).
Expand Down
95 changes: 95 additions & 0 deletions pkg/executor/executor_pkg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,19 @@
package executor

import (
"fmt"
"runtime"
"strconv"
"testing"
"time"
"unsafe"

"github.com/pingcap/tidb/pkg/domain"
"github.com/pingcap/tidb/pkg/errctx"
"github.com/pingcap/tidb/pkg/executor/aggfuncs"
"github.com/pingcap/tidb/pkg/kv"
"github.com/pingcap/tidb/pkg/parser/ast"
"github.com/pingcap/tidb/pkg/parser/mysql"
"github.com/pingcap/tidb/pkg/sessionctx/variable"
"github.com/pingcap/tidb/pkg/tablecodec"
"github.com/pingcap/tidb/pkg/types"
Expand Down Expand Up @@ -263,3 +268,93 @@ func TestFilterTemporaryTableKeys(t *testing.T) {
res := filterTemporaryTableKeys(vars, []kv.Key{tablecodec.EncodeTablePrefix(tableID), tablecodec.EncodeTablePrefix(42)})
require.Len(t, res, 1)
}

func TestErrLevelsForResetStmtContext(t *testing.T) {
ctx := mock.NewContext()
domain.BindDomain(ctx, &domain.Domain{})

cases := []struct {
name string
sqlMode mysql.SQLMode
stmt []ast.StmtNode
levels errctx.LevelMap
}{
{
name: "strict,write",
sqlMode: mysql.ModeStrictAllTables | mysql.ModeErrorForDivisionByZero,
stmt: []ast.StmtNode{&ast.InsertStmt{}, &ast.UpdateStmt{}, &ast.DeleteStmt{}},
levels: errctx.LevelMap{},
},
{
name: "non-strict,write",
sqlMode: mysql.ModeErrorForDivisionByZero,
stmt: []ast.StmtNode{&ast.InsertStmt{}, &ast.UpdateStmt{}, &ast.DeleteStmt{}},
levels: func() (l errctx.LevelMap) {
l[errctx.ErrGroupTruncate] = errctx.LevelWarn
l[errctx.ErrGroupDividedByZero] = errctx.LevelWarn
return
}(),
},
{
name: "strict,insert ignore",
sqlMode: mysql.ModeStrictAllTables | mysql.ModeErrorForDivisionByZero,
stmt: []ast.StmtNode{&ast.InsertStmt{IgnoreErr: true}},
levels: func() (l errctx.LevelMap) {
l[errctx.ErrGroupTruncate] = errctx.LevelWarn
l[errctx.ErrGroupDividedByZero] = errctx.LevelWarn
l[errctx.ErrGroupAutoIncReadFailed] = errctx.LevelWarn
return
}(),
},
{
name: "strict,update/delete ignore",
sqlMode: mysql.ModeStrictAllTables | mysql.ModeErrorForDivisionByZero,
stmt: []ast.StmtNode{&ast.UpdateStmt{IgnoreErr: true}, &ast.DeleteStmt{IgnoreErr: true}},
levels: func() (l errctx.LevelMap) {
l[errctx.ErrGroupTruncate] = errctx.LevelWarn
l[errctx.ErrGroupDividedByZero] = errctx.LevelWarn
return
}(),
},
{
name: "strict without error_for_division_by_zero,write",
sqlMode: mysql.ModeStrictAllTables,
stmt: []ast.StmtNode{&ast.InsertStmt{}, &ast.UpdateStmt{}, &ast.DeleteStmt{}},
levels: func() (l errctx.LevelMap) {
l[errctx.ErrGroupDividedByZero] = errctx.LevelIgnore
return
}(),
},
{
name: "strict,select/union",
sqlMode: mysql.ModeStrictAllTables | mysql.ModeErrorForDivisionByZero,
stmt: []ast.StmtNode{&ast.SelectStmt{}, &ast.SetOprStmt{}},
levels: func() (l errctx.LevelMap) {
l[errctx.ErrGroupTruncate] = errctx.LevelWarn
l[errctx.ErrGroupDividedByZero] = errctx.LevelWarn
return
}(),
},
{
name: "non-strict,select/union",
sqlMode: mysql.ModeStrictAllTables | mysql.ModeErrorForDivisionByZero,
stmt: []ast.StmtNode{&ast.SelectStmt{}, &ast.SetOprStmt{}},
levels: func() (l errctx.LevelMap) {
l[errctx.ErrGroupTruncate] = errctx.LevelWarn
l[errctx.ErrGroupDividedByZero] = errctx.LevelWarn
return
}(),
},
}

for i, c := range cases {
for _, stmt := range c.stmt {
msg := fmt.Sprintf("%d: %s, stmt: %T", i, c.name, stmt)
ctx.GetSessionVars().SQLMode = c.sqlMode
ctx.GetSessionVars().StrictSQLMode = ctx.GetSessionVars().SQLMode.HasStrictMode()
require.NoError(t, ResetContextOfStmt(ctx, stmt), msg)
ec := ctx.GetSessionVars().StmtCtx.ErrCtx()
require.Equal(t, c.levels, ec.LevelMap(), msg)
}
}
}
1 change: 1 addition & 0 deletions pkg/expression/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ go_test(
shard_count = 50,
deps = [
"//pkg/config",
"//pkg/errctx",
"//pkg/errno",
"//pkg/kv",
"//pkg/parser",
Expand Down
13 changes: 2 additions & 11 deletions pkg/expression/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,17 +85,8 @@ func handleInvalidTimeError(ctx EvalContext, err error) error {

// handleDivisionByZeroError reports error or warning depend on the context.
func handleDivisionByZeroError(ctx EvalContext) error {
sc := ctx.GetSessionVars().StmtCtx
if sc.InInsertStmt || sc.InUpdateStmt || sc.InDeleteStmt || sc.InReorg {
if !ctx.GetSessionVars().SQLMode.HasErrorForDivisionByZeroMode() {
return nil
}
if ctx.GetSessionVars().StrictSQLMode && !sc.DividedByZeroAsWarning {
return ErrDivisionByZero
}
}
sc.AppendWarning(ErrDivisionByZero)
return nil
ec := ctx.GetSessionVars().StmtCtx.ErrCtx()
return ec.HandleError(ErrDivisionByZero)
}

// handleAllowedPacketOverflowed reports error or warning depend on the context.
Expand Down
10 changes: 6 additions & 4 deletions pkg/expression/evaluator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"testing"
"time"

"github.com/pingcap/tidb/pkg/errctx"
"github.com/pingcap/tidb/pkg/parser/ast"
"github.com/pingcap/tidb/pkg/parser/charset"
"github.com/pingcap/tidb/pkg/parser/mysql"
Expand Down Expand Up @@ -438,9 +439,9 @@ func TestBinopNumeric(t *testing.T) {
{types.NewDecFromInt(10), ast.Mod, 0},
}

ctx.GetSessionVars().StmtCtx.InSelectStmt = false
ctx.GetSessionVars().SQLMode |= mysql.ModeErrorForDivisionByZero
ctx.GetSessionVars().StmtCtx.InInsertStmt = true
levels := ctx.GetSessionVars().StmtCtx.ErrLevels()
levels[errctx.ErrGroupDividedByZero] = errctx.LevelError
ctx.GetSessionVars().StmtCtx.SetErrLevels(levels)
for _, tt := range testcases {
fc := funcs[tt.op]
f, err := fc.getFunction(ctx, datumsToConstants(types.MakeDatums(tt.lhs, tt.rhs)))
Expand All @@ -449,7 +450,8 @@ func TestBinopNumeric(t *testing.T) {
require.Error(t, err)
}

ctx.GetSessionVars().StmtCtx.DividedByZeroAsWarning = true
levels[errctx.ErrGroupDividedByZero] = errctx.LevelWarn
ctx.GetSessionVars().StmtCtx.SetErrLevels(levels)
for _, tt := range testcases {
fc := funcs[tt.op]
f, err := fc.getFunction(ctx, datumsToConstants(types.MakeDatums(tt.lhs, tt.rhs)))
Expand Down
Loading

0 comments on commit 05cff48

Please sign in to comment.