Skip to content

Commit

Permalink
ddl,expression: introduce CtxWithHandleTruncateErrLevel to wrap a e…
Browse files Browse the repository at this point in the history
…xpression context to handle truncate error (#53441)

ref #53388
  • Loading branch information
lcwangchao authored May 28, 2024
1 parent 57d0b40 commit c02a2b1
Show file tree
Hide file tree
Showing 9 changed files with 185 additions and 113 deletions.
3 changes: 1 addition & 2 deletions pkg/ddl/column.go
Original file line number Diff line number Diff line change
Expand Up @@ -1820,8 +1820,7 @@ func updateColumnDefaultValue(d *ddlCtx, t *meta.Meta, job *model.Job, newCol *m
oldCol.AddFlag(mysql.NoDefaultValueFlag)
} else {
oldCol.DelFlag(mysql.NoDefaultValueFlag)
sctx := newReorgSessCtx(d.store)
err = checkDefaultValue(sctx, table.ToColumn(oldCol), true)
err = checkDefaultValue(newReorgExprCtx(), table.ToColumn(oldCol), true)
if err != nil {
job.State = model.JobStateCancelled
return ver, err
Expand Down
49 changes: 21 additions & 28 deletions pkg/ddl/ddl_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -1433,7 +1433,7 @@ func getFuncCallDefaultValue(col *table.Column, option *ast.ColumnOption, expr *
// getDefaultValue will get the default value for column.
// 1: get the expr restored string for the column which uses sequence next value as default value.
// 2: get specific default value for the other column.
func getDefaultValue(ctx sessionctx.Context, col *table.Column, option *ast.ColumnOption) (any, bool, error) {
func getDefaultValue(ctx exprctx.BuildContext, col *table.Column, option *ast.ColumnOption) (any, bool, error) {
// handle default value with function call
tp, fsp := col.FieldType.GetType(), col.FieldType.GetDecimal()
if x, ok := option.Expr.(*ast.FuncCallExpr); ok {
Expand All @@ -1445,7 +1445,7 @@ func getDefaultValue(ctx sessionctx.Context, col *table.Column, option *ast.Colu
}

if tp == mysql.TypeTimestamp || tp == mysql.TypeDatetime || tp == mysql.TypeDate {
vd, err := expression.GetTimeValue(ctx.GetExprCtx(), option.Expr, tp, fsp, nil)
vd, err := expression.GetTimeValue(ctx, option.Expr, tp, fsp, nil)
value := vd.GetValue()
if err != nil {
return nil, false, dbterror.ErrInvalidDefaultValue.GenWithStackByArgs(col.Name.O)
Expand All @@ -1465,7 +1465,7 @@ func getDefaultValue(ctx sessionctx.Context, col *table.Column, option *ast.Colu
}

// evaluate the non-function-call expr to a certain value.
v, err := expression.EvalSimpleAst(ctx.GetExprCtx(), option.Expr)
v, err := expression.EvalSimpleAst(ctx, option.Expr)
if err != nil {
return nil, false, errors.Trace(err)
}
Expand All @@ -1491,7 +1491,7 @@ func getDefaultValue(ctx sessionctx.Context, col *table.Column, option *ast.Colu
return str, false, err
}
// For other kind of fields (e.g. INT), we supply its integer as string value.
value, err := v.GetBinaryLiteral().ToInt(ctx.GetSessionVars().StmtCtx.TypeCtx())
value, err := v.GetBinaryLiteral().ToInt(ctx.GetEvalCtx().TypeCtx())
if err != nil {
return nil, false, err
}
Expand All @@ -1506,7 +1506,7 @@ func getDefaultValue(ctx sessionctx.Context, col *table.Column, option *ast.Colu
val, err := getEnumDefaultValue(v, col)
return val, false, err
case mysql.TypeDuration, mysql.TypeDate:
if v, err = v.ConvertTo(ctx.GetSessionVars().StmtCtx.TypeCtx(), &col.FieldType); err != nil {
if v, err = v.ConvertTo(ctx.GetEvalCtx().TypeCtx(), &col.FieldType); err != nil {
return "", false, errors.Trace(err)
}
case mysql.TypeBit:
Expand All @@ -1518,7 +1518,7 @@ func getDefaultValue(ctx sessionctx.Context, col *table.Column, option *ast.Colu
// For these types, convert it to standard format firstly.
// like integer fields, convert it into integer string literals. like convert "1.25" into "1" and "2.8" into "3".
// if raise a error, we will use original expression. We will handle it in check phase
if temp, err := v.ConvertTo(ctx.GetSessionVars().StmtCtx.TypeCtx(), &col.FieldType); err == nil {
if temp, err := v.ConvertTo(ctx.GetEvalCtx().TypeCtx(), &col.FieldType); err == nil {
v = temp
}
}
Expand Down Expand Up @@ -1665,7 +1665,7 @@ func setNoDefaultValueFlag(c *table.Column, hasDefaultValue bool) {
}
}

func checkDefaultValue(ctx sessionctx.Context, c *table.Column, hasDefaultValue bool) (err error) {
func checkDefaultValue(ctx exprctx.BuildContext, c *table.Column, hasDefaultValue bool) (err error) {
if !hasDefaultValue {
return nil
}
Expand All @@ -1677,9 +1677,10 @@ func checkDefaultValue(ctx sessionctx.Context, c *table.Column, hasDefaultValue
}
return nil
}
handleWithTruncateErr(ctx, func() {
_, err = table.GetColDefaultValue(ctx.GetExprCtx(), c.ToInfo())
})
_, err = table.GetColDefaultValue(
exprctx.CtxWithHandleTruncateErrLevel(ctx, errctx.LevelError),
c.ToInfo(),
)
if err != nil {
return types.ErrInvalidDefault.GenWithStackByArgs(c.Name)
}
Expand Down Expand Up @@ -5518,23 +5519,14 @@ func checkModifyTypes(origin *types.FieldType, to *types.FieldType, needRewriteC
return errors.Trace(err)
}

// handleWithTruncateErr handles the doFunc with FlagTruncateAsWarning and FlagIgnoreTruncateErr flags, both of which are false.
func handleWithTruncateErr(ctx sessionctx.Context, doFunc func()) {
sv := ctx.GetSessionVars().StmtCtx
oldTypeFlags := sv.TypeFlags()
newTypeFlags := oldTypeFlags.WithTruncateAsWarning(false).WithIgnoreTruncateErr(false)
sv.SetTypeFlags(newTypeFlags)
doFunc()
sv.SetTypeFlags(oldTypeFlags)
}

// SetDefaultValue sets the default value of the column.
func SetDefaultValue(ctx sessionctx.Context, col *table.Column, option *ast.ColumnOption) (hasDefaultValue bool, err error) {
var value any
var isSeqExpr bool
handleWithTruncateErr(ctx, func() {
value, isSeqExpr, err = getDefaultValue(ctx, col, option)
})
value, isSeqExpr, err = getDefaultValue(
exprctx.CtxWithHandleTruncateErrLevel(ctx.GetExprCtx(), errctx.LevelError),
col, option,
)
if err != nil {
return false, errors.Trace(err)
}
Expand Down Expand Up @@ -5685,7 +5677,7 @@ func processAndCheckDefaultValueAndColumn(ctx sessionctx.Context, col *table.Col
if err = checkColumnValueConstraint(col, col.GetCollate()); err != nil {
return errors.Trace(err)
}
if err = checkDefaultValue(ctx, col, hasDefaultValue); err != nil {
if err = checkDefaultValue(ctx.GetExprCtx(), col, hasDefaultValue); err != nil {
return errors.Trace(err)
}
if err = checkColumnFieldLength(col); err != nil {
Expand Down Expand Up @@ -5941,9 +5933,10 @@ func GetModifiableColumnJob(
return nil, dbterror.ErrUnsupportedModifyColumn.GenWithStack("cannot parse generated PartitionInfo")
}
pAst := at.Specs[0].Partition
handleWithTruncateErr(sctx, func() {
_, err = buildPartitionDefinitionsInfo(sctx.GetExprCtx(), pAst.Definitions, &newTblInfo, uint64(len(newTblInfo.Partition.Definitions)))
})
_, err = buildPartitionDefinitionsInfo(
exprctx.CtxWithHandleTruncateErrLevel(sctx.GetExprCtx(), errctx.LevelError),
pAst.Definitions, &newTblInfo, uint64(len(newTblInfo.Partition.Definitions)),
)
if err != nil {
return nil, dbterror.ErrUnsupportedModifyColumn.GenWithStack("New column does not match partition definitions: %s", err.Error())
}
Expand Down Expand Up @@ -6361,7 +6354,7 @@ func (d *ddl) AlterColumn(ctx sessionctx.Context, ident ast.Ident, spec *ast.Alt
if err != nil {
return errors.Trace(err)
}
if err = checkDefaultValue(ctx, col, hasDefaultValue); err != nil {
if err = checkDefaultValue(ctx.GetExprCtx(), col, hasDefaultValue); err != nil {
return errors.Trace(err)
}
}
Expand Down
14 changes: 11 additions & 3 deletions pkg/expression/context/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,17 @@ go_library(
go_test(
name = "context_test",
timeout = "short",
srcs = ["optional_test.go"],
srcs = [
"context_override_test.go",
"optional_test.go",
],
embed = [":context"],
flaky = True,
shard_count = 3,
deps = ["@com_github_stretchr_testify//require"],
shard_count = 4,
deps = [
"//pkg/errctx",
"//pkg/expression/contextstatic",
"//pkg/types",
"@com_github_stretchr_testify//require",
],
)
59 changes: 59 additions & 0 deletions pkg/expression/context/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,65 @@ func (ctx *NullRejectCheckExprContext) IsInNullRejectCheck() bool {
return true
}

type innerOverrideEvalContext struct {
EvalContext
typeCtx types.Context
errCtx errctx.Context
}

// TypeCtx implements EvalContext.TypeCtx
func (ctx *innerOverrideEvalContext) TypeCtx() types.Context {
return ctx.typeCtx
}

// ErrCtx implements EvalContext.GetEvalCtx
func (ctx *innerOverrideEvalContext) ErrCtx() errctx.Context {
return ctx.errCtx
}

type innerOverrideBuildContext struct {
BuildContext
evalCtx EvalContext
}

// GetEvalCtx implements BuildContext.GetEvalCtx
func (ctx *innerOverrideBuildContext) GetEvalCtx() EvalContext {
return ctx.evalCtx
}

// CtxWithHandleTruncateErrLevel returns a new BuildContext with the specified level for handling truncate error.
func CtxWithHandleTruncateErrLevel(ctx BuildContext, level errctx.Level) BuildContext {
truncateAsWarnings, ignoreTruncate := false, false
switch level {
case errctx.LevelWarn:
truncateAsWarnings = true
case errctx.LevelIgnore:
ignoreTruncate = true
default:
}

evalCtx := ctx.GetEvalCtx()
tc, ec := evalCtx.TypeCtx(), evalCtx.ErrCtx()

flags := tc.Flags().
WithTruncateAsWarning(truncateAsWarnings).
WithIgnoreTruncateErr(ignoreTruncate)

if tc.Flags() == flags && ec.LevelForGroup(errctx.ErrGroupTruncate) == level {
// We do not need to create a new context if the flags and level are the same.
return ctx
}

return &innerOverrideBuildContext{
BuildContext: ctx,
evalCtx: &innerOverrideEvalContext{
EvalContext: evalCtx,
typeCtx: tc.WithFlags(flags),
errCtx: ec.WithErrGroupLevel(errctx.ErrGroupTruncate, level),
},
}
}

// AssertLocationWithSessionVars asserts the location in the context and session variables are the same.
// It is only used for testing.
func AssertLocationWithSessionVars(ctxLoc *time.Location, vars *variable.SessionVars) {
Expand Down
85 changes: 85 additions & 0 deletions pkg/expression/context/context_override_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
// Copyright 2024 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package context_test

import (
"testing"
"time"

"github.com/pingcap/tidb/pkg/errctx"
"github.com/pingcap/tidb/pkg/expression/context"
"github.com/pingcap/tidb/pkg/expression/contextstatic"
"github.com/pingcap/tidb/pkg/types"
"github.com/stretchr/testify/require"
)

func TestCtxWithHandleTruncateErrLevel(t *testing.T) {
for _, level := range []errctx.Level{errctx.LevelWarn, errctx.LevelIgnore, errctx.LevelError} {
originalLevelMap := errctx.LevelMap{errctx.ErrGroupDividedByZero: errctx.LevelError}
expectedLevelMap := originalLevelMap
expectedLevelMap[errctx.ErrGroupTruncate] = level

originalFlags := types.DefaultStmtFlags
originalLoc := time.FixedZone("tz1", 3600*2)
var expectedFlags types.Flags
switch level {
case errctx.LevelError:
originalFlags = originalFlags.WithTruncateAsWarning(true)
originalLevelMap[errctx.ErrGroupTruncate] = errctx.LevelWarn
expectedFlags = originalFlags.WithTruncateAsWarning(false)
case errctx.LevelWarn:
expectedFlags = originalFlags.WithTruncateAsWarning(true)
case errctx.LevelIgnore:
expectedFlags = originalFlags.WithIgnoreTruncateErr(true)
default:
require.FailNow(t, "unexpected level")
}

evalCtx := contextstatic.NewStaticEvalContext(
contextstatic.WithTypeFlags(originalFlags),
contextstatic.WithLocation(originalLoc),
contextstatic.WithErrLevelMap(errctx.LevelMap{errctx.ErrGroupTruncate: level}),
)

tc, ec := evalCtx.TypeCtx(), evalCtx.ErrCtx()
ctx := contextstatic.NewStaticExprContext(
contextstatic.WithEvalCtx(evalCtx),
contextstatic.WithConnectionID(1234),
)

// override should take effect
newCtx := context.CtxWithHandleTruncateErrLevel(ctx, level)
newEvalCtx := newCtx.GetEvalCtx()
newTypeCtx, newErrCtx := newEvalCtx.TypeCtx(), newEvalCtx.ErrCtx()
require.Equal(t, expectedFlags, newTypeCtx.Flags())
require.Equal(t, expectedLevelMap, newErrCtx.LevelMap())

// other fields should not change
require.Equal(t, originalLoc, newTypeCtx.Location())
require.Equal(t, originalLoc, newEvalCtx.Location())
require.Equal(t, uint64(1234), newCtx.ConnectionID())

// old ctx should not change
require.Same(t, evalCtx, ctx.GetEvalCtx())
require.Equal(t, tc, evalCtx.TypeCtx())
require.Equal(t, ec, evalCtx.ErrCtx())
require.Same(t, originalLoc, evalCtx.Location())
require.Equal(t, uint64(1234), ctx.ConnectionID())

// not create new ctx case
newCtx2 := context.CtxWithHandleTruncateErrLevel(newCtx, level)
require.Same(t, newCtx, newCtx2)
}
}
44 changes: 2 additions & 42 deletions pkg/expression/expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (

"github.com/pingcap/errors"
"github.com/pingcap/tidb/pkg/errctx"
exprctx "github.com/pingcap/tidb/pkg/expression/context"
"github.com/pingcap/tidb/pkg/parser/ast"
"github.com/pingcap/tidb/pkg/parser/model"
"github.com/pingcap/tidb/pkg/parser/mysql"
Expand Down Expand Up @@ -1007,47 +1008,6 @@ func TableInfo2SchemaAndNames(ctx BuildContext, dbName model.CIStr, tbl *model.T
return schema, names, nil
}

type ignoreTruncateExprCtx struct {
BuildContext
EvalContext
tc types.Context
ec errctx.Context
}

// ignoreTruncate returns a new BuildContext that ignores the truncate error.
func ignoreTruncate(ctx BuildContext) BuildContext {
evalCtx := ctx.GetEvalCtx()
tc, ec := evalCtx.TypeCtx(), evalCtx.ErrCtx()
if tc.Flags().IgnoreTruncateErr() && ec.LevelForGroup(errctx.ErrGroupTruncate) == errctx.LevelIgnore {
return ctx
}

tc = tc.WithFlags(tc.Flags().WithIgnoreTruncateErr(true))
ec = ec.WithErrGroupLevel(errctx.ErrGroupTruncate, errctx.LevelIgnore)

return &ignoreTruncateExprCtx{
BuildContext: ctx,
EvalContext: evalCtx,
tc: tc,
ec: ec,
}
}

// GetEvalCtx implements the BuildContext.EvalCtx().
func (ctx *ignoreTruncateExprCtx) GetEvalCtx() EvalContext {
return ctx
}

// TypeCtx implements the EvalContext.TypeCtx().
func (ctx *ignoreTruncateExprCtx) TypeCtx() types.Context {
return ctx.tc
}

// ErrCtx implements the EvalContext.ErrCtx().
func (ctx *ignoreTruncateExprCtx) ErrCtx() errctx.Context {
return ctx.ec
}

// ColumnInfos2ColumnsAndNames converts the ColumnInfo to the *Column and NameSlice.
func ColumnInfos2ColumnsAndNames(ctx BuildContext, dbName, tblName model.CIStr, colInfos []*model.ColumnInfo, tblInfo *model.TableInfo) ([]*Column, types.NameSlice, error) {
columns := make([]*Column, 0, len(colInfos))
Expand Down Expand Up @@ -1078,7 +1038,7 @@ func ColumnInfos2ColumnsAndNames(ctx BuildContext, dbName, tblName model.CIStr,
if col.IsVirtualGenerated() {
if !truncateIgnored {
// Ignore redundant warning here.
ctx = ignoreTruncate(ctx)
ctx = exprctx.CtxWithHandleTruncateErrLevel(ctx, errctx.LevelIgnore)
truncateIgnored = true
}

Expand Down
Loading

0 comments on commit c02a2b1

Please sign in to comment.