From c02a2b1fe69aa024d1c1f49d98c46553ea8cf654 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E8=B6=85?= Date: Tue, 28 May 2024 12:02:49 +0800 Subject: [PATCH] ddl,expression: introduce `CtxWithHandleTruncateErrLevel` to wrap a expression context to handle truncate error (#53441) ref pingcap/tidb#53388 --- pkg/ddl/column.go | 3 +- pkg/ddl/ddl_api.go | 49 +++++------ pkg/expression/context/BUILD.bazel | 14 ++- pkg/expression/context/context.go | 59 +++++++++++++ .../context/context_override_test.go | 85 +++++++++++++++++++ pkg/expression/expression.go | 44 +--------- pkg/expression/expression_test.go | 32 ------- pkg/planner/core/BUILD.bazel | 2 + pkg/planner/core/logical_plan_builder.go | 10 +-- 9 files changed, 185 insertions(+), 113 deletions(-) create mode 100644 pkg/expression/context/context_override_test.go diff --git a/pkg/ddl/column.go b/pkg/ddl/column.go index 2c49fa0477fae..e7f8efa4744a1 100644 --- a/pkg/ddl/column.go +++ b/pkg/ddl/column.go @@ -1820,8 +1820,7 @@ func updateColumnDefaultValue(d *ddlCtx, t *meta.Meta, job *model.Job, newCol *m oldCol.AddFlag(mysql.NoDefaultValueFlag) } else { oldCol.DelFlag(mysql.NoDefaultValueFlag) - sctx := newReorgSessCtx(d.store) - err = checkDefaultValue(sctx, table.ToColumn(oldCol), true) + err = checkDefaultValue(newReorgExprCtx(), table.ToColumn(oldCol), true) if err != nil { job.State = model.JobStateCancelled return ver, err diff --git a/pkg/ddl/ddl_api.go b/pkg/ddl/ddl_api.go index e50c24a0cd96c..839c19487eaf9 100644 --- a/pkg/ddl/ddl_api.go +++ b/pkg/ddl/ddl_api.go @@ -1433,7 +1433,7 @@ func getFuncCallDefaultValue(col *table.Column, option *ast.ColumnOption, expr * // getDefaultValue will get the default value for column. // 1: get the expr restored string for the column which uses sequence next value as default value. // 2: get specific default value for the other column. -func getDefaultValue(ctx sessionctx.Context, col *table.Column, option *ast.ColumnOption) (any, bool, error) { +func getDefaultValue(ctx exprctx.BuildContext, col *table.Column, option *ast.ColumnOption) (any, bool, error) { // handle default value with function call tp, fsp := col.FieldType.GetType(), col.FieldType.GetDecimal() if x, ok := option.Expr.(*ast.FuncCallExpr); ok { @@ -1445,7 +1445,7 @@ func getDefaultValue(ctx sessionctx.Context, col *table.Column, option *ast.Colu } if tp == mysql.TypeTimestamp || tp == mysql.TypeDatetime || tp == mysql.TypeDate { - vd, err := expression.GetTimeValue(ctx.GetExprCtx(), option.Expr, tp, fsp, nil) + vd, err := expression.GetTimeValue(ctx, option.Expr, tp, fsp, nil) value := vd.GetValue() if err != nil { return nil, false, dbterror.ErrInvalidDefaultValue.GenWithStackByArgs(col.Name.O) @@ -1465,7 +1465,7 @@ func getDefaultValue(ctx sessionctx.Context, col *table.Column, option *ast.Colu } // evaluate the non-function-call expr to a certain value. - v, err := expression.EvalSimpleAst(ctx.GetExprCtx(), option.Expr) + v, err := expression.EvalSimpleAst(ctx, option.Expr) if err != nil { return nil, false, errors.Trace(err) } @@ -1491,7 +1491,7 @@ func getDefaultValue(ctx sessionctx.Context, col *table.Column, option *ast.Colu return str, false, err } // For other kind of fields (e.g. INT), we supply its integer as string value. - value, err := v.GetBinaryLiteral().ToInt(ctx.GetSessionVars().StmtCtx.TypeCtx()) + value, err := v.GetBinaryLiteral().ToInt(ctx.GetEvalCtx().TypeCtx()) if err != nil { return nil, false, err } @@ -1506,7 +1506,7 @@ func getDefaultValue(ctx sessionctx.Context, col *table.Column, option *ast.Colu val, err := getEnumDefaultValue(v, col) return val, false, err case mysql.TypeDuration, mysql.TypeDate: - if v, err = v.ConvertTo(ctx.GetSessionVars().StmtCtx.TypeCtx(), &col.FieldType); err != nil { + if v, err = v.ConvertTo(ctx.GetEvalCtx().TypeCtx(), &col.FieldType); err != nil { return "", false, errors.Trace(err) } case mysql.TypeBit: @@ -1518,7 +1518,7 @@ func getDefaultValue(ctx sessionctx.Context, col *table.Column, option *ast.Colu // For these types, convert it to standard format firstly. // like integer fields, convert it into integer string literals. like convert "1.25" into "1" and "2.8" into "3". // if raise a error, we will use original expression. We will handle it in check phase - if temp, err := v.ConvertTo(ctx.GetSessionVars().StmtCtx.TypeCtx(), &col.FieldType); err == nil { + if temp, err := v.ConvertTo(ctx.GetEvalCtx().TypeCtx(), &col.FieldType); err == nil { v = temp } } @@ -1665,7 +1665,7 @@ func setNoDefaultValueFlag(c *table.Column, hasDefaultValue bool) { } } -func checkDefaultValue(ctx sessionctx.Context, c *table.Column, hasDefaultValue bool) (err error) { +func checkDefaultValue(ctx exprctx.BuildContext, c *table.Column, hasDefaultValue bool) (err error) { if !hasDefaultValue { return nil } @@ -1677,9 +1677,10 @@ func checkDefaultValue(ctx sessionctx.Context, c *table.Column, hasDefaultValue } return nil } - handleWithTruncateErr(ctx, func() { - _, err = table.GetColDefaultValue(ctx.GetExprCtx(), c.ToInfo()) - }) + _, err = table.GetColDefaultValue( + exprctx.CtxWithHandleTruncateErrLevel(ctx, errctx.LevelError), + c.ToInfo(), + ) if err != nil { return types.ErrInvalidDefault.GenWithStackByArgs(c.Name) } @@ -5518,23 +5519,14 @@ func checkModifyTypes(origin *types.FieldType, to *types.FieldType, needRewriteC return errors.Trace(err) } -// handleWithTruncateErr handles the doFunc with FlagTruncateAsWarning and FlagIgnoreTruncateErr flags, both of which are false. -func handleWithTruncateErr(ctx sessionctx.Context, doFunc func()) { - sv := ctx.GetSessionVars().StmtCtx - oldTypeFlags := sv.TypeFlags() - newTypeFlags := oldTypeFlags.WithTruncateAsWarning(false).WithIgnoreTruncateErr(false) - sv.SetTypeFlags(newTypeFlags) - doFunc() - sv.SetTypeFlags(oldTypeFlags) -} - // SetDefaultValue sets the default value of the column. func SetDefaultValue(ctx sessionctx.Context, col *table.Column, option *ast.ColumnOption) (hasDefaultValue bool, err error) { var value any var isSeqExpr bool - handleWithTruncateErr(ctx, func() { - value, isSeqExpr, err = getDefaultValue(ctx, col, option) - }) + value, isSeqExpr, err = getDefaultValue( + exprctx.CtxWithHandleTruncateErrLevel(ctx.GetExprCtx(), errctx.LevelError), + col, option, + ) if err != nil { return false, errors.Trace(err) } @@ -5685,7 +5677,7 @@ func processAndCheckDefaultValueAndColumn(ctx sessionctx.Context, col *table.Col if err = checkColumnValueConstraint(col, col.GetCollate()); err != nil { return errors.Trace(err) } - if err = checkDefaultValue(ctx, col, hasDefaultValue); err != nil { + if err = checkDefaultValue(ctx.GetExprCtx(), col, hasDefaultValue); err != nil { return errors.Trace(err) } if err = checkColumnFieldLength(col); err != nil { @@ -5941,9 +5933,10 @@ func GetModifiableColumnJob( return nil, dbterror.ErrUnsupportedModifyColumn.GenWithStack("cannot parse generated PartitionInfo") } pAst := at.Specs[0].Partition - handleWithTruncateErr(sctx, func() { - _, err = buildPartitionDefinitionsInfo(sctx.GetExprCtx(), pAst.Definitions, &newTblInfo, uint64(len(newTblInfo.Partition.Definitions))) - }) + _, err = buildPartitionDefinitionsInfo( + exprctx.CtxWithHandleTruncateErrLevel(sctx.GetExprCtx(), errctx.LevelError), + pAst.Definitions, &newTblInfo, uint64(len(newTblInfo.Partition.Definitions)), + ) if err != nil { return nil, dbterror.ErrUnsupportedModifyColumn.GenWithStack("New column does not match partition definitions: %s", err.Error()) } @@ -6361,7 +6354,7 @@ func (d *ddl) AlterColumn(ctx sessionctx.Context, ident ast.Ident, spec *ast.Alt if err != nil { return errors.Trace(err) } - if err = checkDefaultValue(ctx, col, hasDefaultValue); err != nil { + if err = checkDefaultValue(ctx.GetExprCtx(), col, hasDefaultValue); err != nil { return errors.Trace(err) } } diff --git a/pkg/expression/context/BUILD.bazel b/pkg/expression/context/BUILD.bazel index a5f120a606841..862fa7ad121e1 100644 --- a/pkg/expression/context/BUILD.bazel +++ b/pkg/expression/context/BUILD.bazel @@ -22,9 +22,17 @@ go_library( go_test( name = "context_test", timeout = "short", - srcs = ["optional_test.go"], + srcs = [ + "context_override_test.go", + "optional_test.go", + ], embed = [":context"], flaky = True, - shard_count = 3, - deps = ["@com_github_stretchr_testify//require"], + shard_count = 4, + deps = [ + "//pkg/errctx", + "//pkg/expression/contextstatic", + "//pkg/types", + "@com_github_stretchr_testify//require", + ], ) diff --git a/pkg/expression/context/context.go b/pkg/expression/context/context.go index 736376c21ac17..d9084c0746bf9 100644 --- a/pkg/expression/context/context.go +++ b/pkg/expression/context/context.go @@ -143,6 +143,65 @@ func (ctx *NullRejectCheckExprContext) IsInNullRejectCheck() bool { return true } +type innerOverrideEvalContext struct { + EvalContext + typeCtx types.Context + errCtx errctx.Context +} + +// TypeCtx implements EvalContext.TypeCtx +func (ctx *innerOverrideEvalContext) TypeCtx() types.Context { + return ctx.typeCtx +} + +// ErrCtx implements EvalContext.GetEvalCtx +func (ctx *innerOverrideEvalContext) ErrCtx() errctx.Context { + return ctx.errCtx +} + +type innerOverrideBuildContext struct { + BuildContext + evalCtx EvalContext +} + +// GetEvalCtx implements BuildContext.GetEvalCtx +func (ctx *innerOverrideBuildContext) GetEvalCtx() EvalContext { + return ctx.evalCtx +} + +// CtxWithHandleTruncateErrLevel returns a new BuildContext with the specified level for handling truncate error. +func CtxWithHandleTruncateErrLevel(ctx BuildContext, level errctx.Level) BuildContext { + truncateAsWarnings, ignoreTruncate := false, false + switch level { + case errctx.LevelWarn: + truncateAsWarnings = true + case errctx.LevelIgnore: + ignoreTruncate = true + default: + } + + evalCtx := ctx.GetEvalCtx() + tc, ec := evalCtx.TypeCtx(), evalCtx.ErrCtx() + + flags := tc.Flags(). + WithTruncateAsWarning(truncateAsWarnings). + WithIgnoreTruncateErr(ignoreTruncate) + + if tc.Flags() == flags && ec.LevelForGroup(errctx.ErrGroupTruncate) == level { + // We do not need to create a new context if the flags and level are the same. + return ctx + } + + return &innerOverrideBuildContext{ + BuildContext: ctx, + evalCtx: &innerOverrideEvalContext{ + EvalContext: evalCtx, + typeCtx: tc.WithFlags(flags), + errCtx: ec.WithErrGroupLevel(errctx.ErrGroupTruncate, level), + }, + } +} + // AssertLocationWithSessionVars asserts the location in the context and session variables are the same. // It is only used for testing. func AssertLocationWithSessionVars(ctxLoc *time.Location, vars *variable.SessionVars) { diff --git a/pkg/expression/context/context_override_test.go b/pkg/expression/context/context_override_test.go new file mode 100644 index 0000000000000..dc608c08a0eb0 --- /dev/null +++ b/pkg/expression/context/context_override_test.go @@ -0,0 +1,85 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package context_test + +import ( + "testing" + "time" + + "github.com/pingcap/tidb/pkg/errctx" + "github.com/pingcap/tidb/pkg/expression/context" + "github.com/pingcap/tidb/pkg/expression/contextstatic" + "github.com/pingcap/tidb/pkg/types" + "github.com/stretchr/testify/require" +) + +func TestCtxWithHandleTruncateErrLevel(t *testing.T) { + for _, level := range []errctx.Level{errctx.LevelWarn, errctx.LevelIgnore, errctx.LevelError} { + originalLevelMap := errctx.LevelMap{errctx.ErrGroupDividedByZero: errctx.LevelError} + expectedLevelMap := originalLevelMap + expectedLevelMap[errctx.ErrGroupTruncate] = level + + originalFlags := types.DefaultStmtFlags + originalLoc := time.FixedZone("tz1", 3600*2) + var expectedFlags types.Flags + switch level { + case errctx.LevelError: + originalFlags = originalFlags.WithTruncateAsWarning(true) + originalLevelMap[errctx.ErrGroupTruncate] = errctx.LevelWarn + expectedFlags = originalFlags.WithTruncateAsWarning(false) + case errctx.LevelWarn: + expectedFlags = originalFlags.WithTruncateAsWarning(true) + case errctx.LevelIgnore: + expectedFlags = originalFlags.WithIgnoreTruncateErr(true) + default: + require.FailNow(t, "unexpected level") + } + + evalCtx := contextstatic.NewStaticEvalContext( + contextstatic.WithTypeFlags(originalFlags), + contextstatic.WithLocation(originalLoc), + contextstatic.WithErrLevelMap(errctx.LevelMap{errctx.ErrGroupTruncate: level}), + ) + + tc, ec := evalCtx.TypeCtx(), evalCtx.ErrCtx() + ctx := contextstatic.NewStaticExprContext( + contextstatic.WithEvalCtx(evalCtx), + contextstatic.WithConnectionID(1234), + ) + + // override should take effect + newCtx := context.CtxWithHandleTruncateErrLevel(ctx, level) + newEvalCtx := newCtx.GetEvalCtx() + newTypeCtx, newErrCtx := newEvalCtx.TypeCtx(), newEvalCtx.ErrCtx() + require.Equal(t, expectedFlags, newTypeCtx.Flags()) + require.Equal(t, expectedLevelMap, newErrCtx.LevelMap()) + + // other fields should not change + require.Equal(t, originalLoc, newTypeCtx.Location()) + require.Equal(t, originalLoc, newEvalCtx.Location()) + require.Equal(t, uint64(1234), newCtx.ConnectionID()) + + // old ctx should not change + require.Same(t, evalCtx, ctx.GetEvalCtx()) + require.Equal(t, tc, evalCtx.TypeCtx()) + require.Equal(t, ec, evalCtx.ErrCtx()) + require.Same(t, originalLoc, evalCtx.Location()) + require.Equal(t, uint64(1234), ctx.ConnectionID()) + + // not create new ctx case + newCtx2 := context.CtxWithHandleTruncateErrLevel(newCtx, level) + require.Same(t, newCtx, newCtx2) + } +} diff --git a/pkg/expression/expression.go b/pkg/expression/expression.go index eeba4d6ca88dc..bdb4b32b7788e 100644 --- a/pkg/expression/expression.go +++ b/pkg/expression/expression.go @@ -20,6 +20,7 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/tidb/pkg/errctx" + exprctx "github.com/pingcap/tidb/pkg/expression/context" "github.com/pingcap/tidb/pkg/parser/ast" "github.com/pingcap/tidb/pkg/parser/model" "github.com/pingcap/tidb/pkg/parser/mysql" @@ -1007,47 +1008,6 @@ func TableInfo2SchemaAndNames(ctx BuildContext, dbName model.CIStr, tbl *model.T return schema, names, nil } -type ignoreTruncateExprCtx struct { - BuildContext - EvalContext - tc types.Context - ec errctx.Context -} - -// ignoreTruncate returns a new BuildContext that ignores the truncate error. -func ignoreTruncate(ctx BuildContext) BuildContext { - evalCtx := ctx.GetEvalCtx() - tc, ec := evalCtx.TypeCtx(), evalCtx.ErrCtx() - if tc.Flags().IgnoreTruncateErr() && ec.LevelForGroup(errctx.ErrGroupTruncate) == errctx.LevelIgnore { - return ctx - } - - tc = tc.WithFlags(tc.Flags().WithIgnoreTruncateErr(true)) - ec = ec.WithErrGroupLevel(errctx.ErrGroupTruncate, errctx.LevelIgnore) - - return &ignoreTruncateExprCtx{ - BuildContext: ctx, - EvalContext: evalCtx, - tc: tc, - ec: ec, - } -} - -// GetEvalCtx implements the BuildContext.EvalCtx(). -func (ctx *ignoreTruncateExprCtx) GetEvalCtx() EvalContext { - return ctx -} - -// TypeCtx implements the EvalContext.TypeCtx(). -func (ctx *ignoreTruncateExprCtx) TypeCtx() types.Context { - return ctx.tc -} - -// ErrCtx implements the EvalContext.ErrCtx(). -func (ctx *ignoreTruncateExprCtx) ErrCtx() errctx.Context { - return ctx.ec -} - // ColumnInfos2ColumnsAndNames converts the ColumnInfo to the *Column and NameSlice. func ColumnInfos2ColumnsAndNames(ctx BuildContext, dbName, tblName model.CIStr, colInfos []*model.ColumnInfo, tblInfo *model.TableInfo) ([]*Column, types.NameSlice, error) { columns := make([]*Column, 0, len(colInfos)) @@ -1078,7 +1038,7 @@ func ColumnInfos2ColumnsAndNames(ctx BuildContext, dbName, tblName model.CIStr, if col.IsVirtualGenerated() { if !truncateIgnored { // Ignore redundant warning here. - ctx = ignoreTruncate(ctx) + ctx = exprctx.CtxWithHandleTruncateErrLevel(ctx, errctx.LevelIgnore) truncateIgnored = true } diff --git a/pkg/expression/expression_test.go b/pkg/expression/expression_test.go index 361d9e88e86fa..e30c9960a28c4 100644 --- a/pkg/expression/expression_test.go +++ b/pkg/expression/expression_test.go @@ -17,7 +17,6 @@ package expression import ( "testing" - "github.com/pingcap/tidb/pkg/errctx" "github.com/pingcap/tidb/pkg/parser/ast" "github.com/pingcap/tidb/pkg/parser/model" "github.com/pingcap/tidb/pkg/parser/mysql" @@ -296,34 +295,3 @@ func TestExpressionMemeoryUsage(t *testing.T) { c4 := Constant{Value: types.NewStringDatum("11")} require.Greater(t, c4.MemoryUsage(), c3.MemoryUsage()) } - -func TestIgnoreTruncateExprCtx(t *testing.T) { - ctx := createContext(t) - ctx.GetSessionVars().StmtCtx.SetTypeFlags(types.StrictFlags) - evalCtx := ctx.GetEvalCtx() - tc, ec := evalCtx.TypeCtx(), evalCtx.ErrCtx() - require.True(t, !tc.Flags().IgnoreTruncateErr() && !tc.Flags().TruncateAsWarning()) - require.Equal(t, errctx.LevelError, ec.LevelForGroup(errctx.ErrGroupTruncate)) - - // new ctx will ignore truncate error - newEvalCtx := ignoreTruncate(ctx).GetEvalCtx() - tc, ec = newEvalCtx.TypeCtx(), newEvalCtx.ErrCtx() - require.True(t, tc.Flags().IgnoreTruncateErr() && !tc.Flags().TruncateAsWarning()) - require.Equal(t, errctx.LevelIgnore, ec.LevelForGroup(errctx.ErrGroupTruncate)) - - // old eval ctx will not change - tc, ec = evalCtx.TypeCtx(), evalCtx.ErrCtx() - require.True(t, !tc.Flags().IgnoreTruncateErr() && !tc.Flags().TruncateAsWarning()) - require.Equal(t, errctx.LevelError, ec.LevelForGroup(errctx.ErrGroupTruncate)) - - // old build ctx will not change - evalCtx = ctx.GetEvalCtx() - tc, ec = evalCtx.TypeCtx(), evalCtx.ErrCtx() - require.True(t, !tc.Flags().IgnoreTruncateErr() && !tc.Flags().TruncateAsWarning()) - require.Equal(t, errctx.LevelError, ec.LevelForGroup(errctx.ErrGroupTruncate)) - - // truncate ignored ctx will not create new ctx - ctx.GetSessionVars().StmtCtx.SetTypeFlags(types.StrictFlags.WithIgnoreTruncateErr(true)) - newCtx := ignoreTruncate(ctx) - require.Same(t, ctx, newCtx) -} diff --git a/pkg/planner/core/BUILD.bazel b/pkg/planner/core/BUILD.bazel index 2d521e5e70246..84160875c791b 100644 --- a/pkg/planner/core/BUILD.bazel +++ b/pkg/planner/core/BUILD.bazel @@ -93,8 +93,10 @@ go_library( "//pkg/config", "//pkg/distsql", "//pkg/domain", + "//pkg/errctx", "//pkg/expression", "//pkg/expression/aggregation", + "//pkg/expression/context", "//pkg/infoschema", "//pkg/infoschema/context", "//pkg/kv", diff --git a/pkg/planner/core/logical_plan_builder.go b/pkg/planner/core/logical_plan_builder.go index 5f462ac9a6002..dec80f8d3e71f 100644 --- a/pkg/planner/core/logical_plan_builder.go +++ b/pkg/planner/core/logical_plan_builder.go @@ -28,8 +28,10 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/failpoint" "github.com/pingcap/tidb/pkg/domain" + "github.com/pingcap/tidb/pkg/errctx" "github.com/pingcap/tidb/pkg/expression" "github.com/pingcap/tidb/pkg/expression/aggregation" + exprctx "github.com/pingcap/tidb/pkg/expression/context" "github.com/pingcap/tidb/pkg/infoschema" "github.com/pingcap/tidb/pkg/kv" "github.com/pingcap/tidb/pkg/parser" @@ -6555,12 +6557,8 @@ func (b *PlanBuilder) buildWindowFunctionFrameBound(_ context.Context, spec *ast // If it has paramMarker and is in prepare stmt. We don't need to eval it since its value is not decided yet. if !checker.InPrepareStmt { // Do not raise warnings for truncate. - sc := b.ctx.GetSessionVars().StmtCtx - oldTypeFlags := sc.TypeFlags() - newTypeFlags := oldTypeFlags.WithIgnoreTruncateErr(true) - sc.SetTypeFlags(newTypeFlags) - uVal, isNull, err := expr.EvalInt(b.ctx.GetExprCtx().GetEvalCtx(), chunk.Row{}) - sc.SetTypeFlags(oldTypeFlags) + exprCtx := exprctx.CtxWithHandleTruncateErrLevel(b.ctx.GetExprCtx(), errctx.LevelIgnore) + uVal, isNull, err := expr.EvalInt(exprCtx.GetEvalCtx(), chunk.Row{}) if uVal < 0 || isNull || err != nil { return nil, plannererrors.ErrWindowFrameIllegal.GenWithStackByArgs(getWindowName(spec.Name.O)) }