Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ddl,expression: introduce CtxWithHandleTruncateErrLevel to wrap a expression context to handle truncate error #53441

Merged
merged 1 commit into from
May 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions pkg/ddl/column.go
Original file line number Diff line number Diff line change
Expand Up @@ -1820,8 +1820,7 @@ func updateColumnDefaultValue(d *ddlCtx, t *meta.Meta, job *model.Job, newCol *m
oldCol.AddFlag(mysql.NoDefaultValueFlag)
} else {
oldCol.DelFlag(mysql.NoDefaultValueFlag)
sctx := newReorgSessCtx(d.store)
lcwangchao marked this conversation as resolved.
Show resolved Hide resolved
err = checkDefaultValue(sctx, table.ToColumn(oldCol), true)
err = checkDefaultValue(newReorgExprCtx(), table.ToColumn(oldCol), true)
if err != nil {
job.State = model.JobStateCancelled
return ver, err
Expand Down
49 changes: 21 additions & 28 deletions pkg/ddl/ddl_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -1433,7 +1433,7 @@ func getFuncCallDefaultValue(col *table.Column, option *ast.ColumnOption, expr *
// getDefaultValue will get the default value for column.
// 1: get the expr restored string for the column which uses sequence next value as default value.
// 2: get specific default value for the other column.
func getDefaultValue(ctx sessionctx.Context, col *table.Column, option *ast.ColumnOption) (any, bool, error) {
func getDefaultValue(ctx exprctx.BuildContext, col *table.Column, option *ast.ColumnOption) (any, bool, error) {
// handle default value with function call
tp, fsp := col.FieldType.GetType(), col.FieldType.GetDecimal()
if x, ok := option.Expr.(*ast.FuncCallExpr); ok {
Expand All @@ -1445,7 +1445,7 @@ func getDefaultValue(ctx sessionctx.Context, col *table.Column, option *ast.Colu
}

if tp == mysql.TypeTimestamp || tp == mysql.TypeDatetime || tp == mysql.TypeDate {
vd, err := expression.GetTimeValue(ctx.GetExprCtx(), option.Expr, tp, fsp, nil)
vd, err := expression.GetTimeValue(ctx, option.Expr, tp, fsp, nil)
value := vd.GetValue()
if err != nil {
return nil, false, dbterror.ErrInvalidDefaultValue.GenWithStackByArgs(col.Name.O)
Expand All @@ -1465,7 +1465,7 @@ func getDefaultValue(ctx sessionctx.Context, col *table.Column, option *ast.Colu
}

// evaluate the non-function-call expr to a certain value.
v, err := expression.EvalSimpleAst(ctx.GetExprCtx(), option.Expr)
v, err := expression.EvalSimpleAst(ctx, option.Expr)
if err != nil {
return nil, false, errors.Trace(err)
}
Expand All @@ -1491,7 +1491,7 @@ func getDefaultValue(ctx sessionctx.Context, col *table.Column, option *ast.Colu
return str, false, err
}
// For other kind of fields (e.g. INT), we supply its integer as string value.
value, err := v.GetBinaryLiteral().ToInt(ctx.GetSessionVars().StmtCtx.TypeCtx())
value, err := v.GetBinaryLiteral().ToInt(ctx.GetEvalCtx().TypeCtx())
lcwangchao marked this conversation as resolved.
Show resolved Hide resolved
if err != nil {
return nil, false, err
}
Expand All @@ -1506,7 +1506,7 @@ func getDefaultValue(ctx sessionctx.Context, col *table.Column, option *ast.Colu
val, err := getEnumDefaultValue(v, col)
return val, false, err
case mysql.TypeDuration, mysql.TypeDate:
if v, err = v.ConvertTo(ctx.GetSessionVars().StmtCtx.TypeCtx(), &col.FieldType); err != nil {
if v, err = v.ConvertTo(ctx.GetEvalCtx().TypeCtx(), &col.FieldType); err != nil {
return "", false, errors.Trace(err)
}
case mysql.TypeBit:
Expand All @@ -1518,7 +1518,7 @@ func getDefaultValue(ctx sessionctx.Context, col *table.Column, option *ast.Colu
// For these types, convert it to standard format firstly.
// like integer fields, convert it into integer string literals. like convert "1.25" into "1" and "2.8" into "3".
// if raise a error, we will use original expression. We will handle it in check phase
if temp, err := v.ConvertTo(ctx.GetSessionVars().StmtCtx.TypeCtx(), &col.FieldType); err == nil {
if temp, err := v.ConvertTo(ctx.GetEvalCtx().TypeCtx(), &col.FieldType); err == nil {
v = temp
}
}
Expand Down Expand Up @@ -1665,7 +1665,7 @@ func setNoDefaultValueFlag(c *table.Column, hasDefaultValue bool) {
}
}

func checkDefaultValue(ctx sessionctx.Context, c *table.Column, hasDefaultValue bool) (err error) {
func checkDefaultValue(ctx exprctx.BuildContext, c *table.Column, hasDefaultValue bool) (err error) {
if !hasDefaultValue {
return nil
}
Expand All @@ -1677,9 +1677,10 @@ func checkDefaultValue(ctx sessionctx.Context, c *table.Column, hasDefaultValue
}
return nil
}
handleWithTruncateErr(ctx, func() {
_, err = table.GetColDefaultValue(ctx.GetExprCtx(), c.ToInfo())
})
_, err = table.GetColDefaultValue(
exprctx.CtxWithHandleTruncateErrLevel(ctx, errctx.LevelError),
c.ToInfo(),
)
if err != nil {
return types.ErrInvalidDefault.GenWithStackByArgs(c.Name)
}
Expand Down Expand Up @@ -5518,23 +5519,14 @@ func checkModifyTypes(origin *types.FieldType, to *types.FieldType, needRewriteC
return errors.Trace(err)
}

// handleWithTruncateErr handles the doFunc with FlagTruncateAsWarning and FlagIgnoreTruncateErr flags, both of which are false.
func handleWithTruncateErr(ctx sessionctx.Context, doFunc func()) {
sv := ctx.GetSessionVars().StmtCtx
oldTypeFlags := sv.TypeFlags()
newTypeFlags := oldTypeFlags.WithTruncateAsWarning(false).WithIgnoreTruncateErr(false)
sv.SetTypeFlags(newTypeFlags)
doFunc()
sv.SetTypeFlags(oldTypeFlags)
}

// SetDefaultValue sets the default value of the column.
func SetDefaultValue(ctx sessionctx.Context, col *table.Column, option *ast.ColumnOption) (hasDefaultValue bool, err error) {
var value any
var isSeqExpr bool
handleWithTruncateErr(ctx, func() {
value, isSeqExpr, err = getDefaultValue(ctx, col, option)
})
value, isSeqExpr, err = getDefaultValue(
exprctx.CtxWithHandleTruncateErrLevel(ctx.GetExprCtx(), errctx.LevelError),
col, option,
)
if err != nil {
return false, errors.Trace(err)
}
Expand Down Expand Up @@ -5685,7 +5677,7 @@ func processAndCheckDefaultValueAndColumn(ctx sessionctx.Context, col *table.Col
if err = checkColumnValueConstraint(col, col.GetCollate()); err != nil {
return errors.Trace(err)
}
if err = checkDefaultValue(ctx, col, hasDefaultValue); err != nil {
if err = checkDefaultValue(ctx.GetExprCtx(), col, hasDefaultValue); err != nil {
return errors.Trace(err)
}
if err = checkColumnFieldLength(col); err != nil {
Expand Down Expand Up @@ -5941,9 +5933,10 @@ func GetModifiableColumnJob(
return nil, dbterror.ErrUnsupportedModifyColumn.GenWithStack("cannot parse generated PartitionInfo")
}
pAst := at.Specs[0].Partition
handleWithTruncateErr(sctx, func() {
_, err = buildPartitionDefinitionsInfo(sctx.GetExprCtx(), pAst.Definitions, &newTblInfo, uint64(len(newTblInfo.Partition.Definitions)))
})
_, err = buildPartitionDefinitionsInfo(
exprctx.CtxWithHandleTruncateErrLevel(sctx.GetExprCtx(), errctx.LevelError),
pAst.Definitions, &newTblInfo, uint64(len(newTblInfo.Partition.Definitions)),
)
if err != nil {
return nil, dbterror.ErrUnsupportedModifyColumn.GenWithStack("New column does not match partition definitions: %s", err.Error())
}
Expand Down Expand Up @@ -6361,7 +6354,7 @@ func (d *ddl) AlterColumn(ctx sessionctx.Context, ident ast.Ident, spec *ast.Alt
if err != nil {
return errors.Trace(err)
}
if err = checkDefaultValue(ctx, col, hasDefaultValue); err != nil {
if err = checkDefaultValue(ctx.GetExprCtx(), col, hasDefaultValue); err != nil {
return errors.Trace(err)
}
}
Expand Down
14 changes: 11 additions & 3 deletions pkg/expression/context/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,17 @@ go_library(
go_test(
name = "context_test",
timeout = "short",
srcs = ["optional_test.go"],
srcs = [
"context_override_test.go",
"optional_test.go",
],
embed = [":context"],
flaky = True,
shard_count = 3,
deps = ["@com_github_stretchr_testify//require"],
shard_count = 4,
deps = [
"//pkg/errctx",
"//pkg/expression/contextstatic",
"//pkg/types",
"@com_github_stretchr_testify//require",
],
)
59 changes: 59 additions & 0 deletions pkg/expression/context/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,65 @@ func (ctx *NullRejectCheckExprContext) IsInNullRejectCheck() bool {
return true
}

type innerOverrideEvalContext struct {
EvalContext
typeCtx types.Context
errCtx errctx.Context
}

// TypeCtx implements EvalContext.TypeCtx
func (ctx *innerOverrideEvalContext) TypeCtx() types.Context {
return ctx.typeCtx
}

// ErrCtx implements EvalContext.GetEvalCtx
func (ctx *innerOverrideEvalContext) ErrCtx() errctx.Context {
return ctx.errCtx
}

type innerOverrideBuildContext struct {
BuildContext
evalCtx EvalContext
}

// GetEvalCtx implements BuildContext.GetEvalCtx
func (ctx *innerOverrideBuildContext) GetEvalCtx() EvalContext {
return ctx.evalCtx
}

// CtxWithHandleTruncateErrLevel returns a new BuildContext with the specified level for handling truncate error.
func CtxWithHandleTruncateErrLevel(ctx BuildContext, level errctx.Level) BuildContext {
truncateAsWarnings, ignoreTruncate := false, false
switch level {
case errctx.LevelWarn:
truncateAsWarnings = true
case errctx.LevelIgnore:
ignoreTruncate = true
default:
}

evalCtx := ctx.GetEvalCtx()
tc, ec := evalCtx.TypeCtx(), evalCtx.ErrCtx()

flags := tc.Flags().
WithTruncateAsWarning(truncateAsWarnings).
WithIgnoreTruncateErr(ignoreTruncate)

if tc.Flags() == flags && ec.LevelForGroup(errctx.ErrGroupTruncate) == level {
// We do not need to create a new context if the flags and level are the same.
return ctx
}

return &innerOverrideBuildContext{
BuildContext: ctx,
evalCtx: &innerOverrideEvalContext{
EvalContext: evalCtx,
typeCtx: tc.WithFlags(flags),
errCtx: ec.WithErrGroupLevel(errctx.ErrGroupTruncate, level),
},
}
}

// AssertLocationWithSessionVars asserts the location in the context and session variables are the same.
// It is only used for testing.
func AssertLocationWithSessionVars(ctxLoc *time.Location, vars *variable.SessionVars) {
Expand Down
85 changes: 85 additions & 0 deletions pkg/expression/context/context_override_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
// Copyright 2024 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package context_test

import (
"testing"
"time"

"github.com/pingcap/tidb/pkg/errctx"
"github.com/pingcap/tidb/pkg/expression/context"
"github.com/pingcap/tidb/pkg/expression/contextstatic"
"github.com/pingcap/tidb/pkg/types"
"github.com/stretchr/testify/require"
)

func TestCtxWithHandleTruncateErrLevel(t *testing.T) {
for _, level := range []errctx.Level{errctx.LevelWarn, errctx.LevelIgnore, errctx.LevelError} {
originalLevelMap := errctx.LevelMap{errctx.ErrGroupDividedByZero: errctx.LevelError}
expectedLevelMap := originalLevelMap
expectedLevelMap[errctx.ErrGroupTruncate] = level

originalFlags := types.DefaultStmtFlags
originalLoc := time.FixedZone("tz1", 3600*2)
var expectedFlags types.Flags
switch level {
case errctx.LevelError:
originalFlags = originalFlags.WithTruncateAsWarning(true)
originalLevelMap[errctx.ErrGroupTruncate] = errctx.LevelWarn
expectedFlags = originalFlags.WithTruncateAsWarning(false)
case errctx.LevelWarn:
expectedFlags = originalFlags.WithTruncateAsWarning(true)
case errctx.LevelIgnore:
expectedFlags = originalFlags.WithIgnoreTruncateErr(true)
default:
require.FailNow(t, "unexpected level")
}

evalCtx := contextstatic.NewStaticEvalContext(
contextstatic.WithTypeFlags(originalFlags),
contextstatic.WithLocation(originalLoc),
contextstatic.WithErrLevelMap(errctx.LevelMap{errctx.ErrGroupTruncate: level}),
)

tc, ec := evalCtx.TypeCtx(), evalCtx.ErrCtx()
ctx := contextstatic.NewStaticExprContext(
contextstatic.WithEvalCtx(evalCtx),
contextstatic.WithConnectionID(1234),
)

// override should take effect
newCtx := context.CtxWithHandleTruncateErrLevel(ctx, level)
newEvalCtx := newCtx.GetEvalCtx()
newTypeCtx, newErrCtx := newEvalCtx.TypeCtx(), newEvalCtx.ErrCtx()
require.Equal(t, expectedFlags, newTypeCtx.Flags())
require.Equal(t, expectedLevelMap, newErrCtx.LevelMap())

// other fields should not change
require.Equal(t, originalLoc, newTypeCtx.Location())
require.Equal(t, originalLoc, newEvalCtx.Location())
require.Equal(t, uint64(1234), newCtx.ConnectionID())

// old ctx should not change
require.Same(t, evalCtx, ctx.GetEvalCtx())
require.Equal(t, tc, evalCtx.TypeCtx())
require.Equal(t, ec, evalCtx.ErrCtx())
require.Same(t, originalLoc, evalCtx.Location())
require.Equal(t, uint64(1234), ctx.ConnectionID())

// not create new ctx case
newCtx2 := context.CtxWithHandleTruncateErrLevel(newCtx, level)
require.Same(t, newCtx, newCtx2)
}
}
44 changes: 2 additions & 42 deletions pkg/expression/expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (

"github.com/pingcap/errors"
"github.com/pingcap/tidb/pkg/errctx"
exprctx "github.com/pingcap/tidb/pkg/expression/context"
"github.com/pingcap/tidb/pkg/parser/ast"
"github.com/pingcap/tidb/pkg/parser/model"
"github.com/pingcap/tidb/pkg/parser/mysql"
Expand Down Expand Up @@ -1007,47 +1008,6 @@ func TableInfo2SchemaAndNames(ctx BuildContext, dbName model.CIStr, tbl *model.T
return schema, names, nil
}

type ignoreTruncateExprCtx struct {
BuildContext
EvalContext
tc types.Context
ec errctx.Context
}

// ignoreTruncate returns a new BuildContext that ignores the truncate error.
func ignoreTruncate(ctx BuildContext) BuildContext {
evalCtx := ctx.GetEvalCtx()
tc, ec := evalCtx.TypeCtx(), evalCtx.ErrCtx()
if tc.Flags().IgnoreTruncateErr() && ec.LevelForGroup(errctx.ErrGroupTruncate) == errctx.LevelIgnore {
return ctx
}

tc = tc.WithFlags(tc.Flags().WithIgnoreTruncateErr(true))
ec = ec.WithErrGroupLevel(errctx.ErrGroupTruncate, errctx.LevelIgnore)

return &ignoreTruncateExprCtx{
BuildContext: ctx,
EvalContext: evalCtx,
tc: tc,
ec: ec,
}
}

// GetEvalCtx implements the BuildContext.EvalCtx().
func (ctx *ignoreTruncateExprCtx) GetEvalCtx() EvalContext {
return ctx
}

// TypeCtx implements the EvalContext.TypeCtx().
func (ctx *ignoreTruncateExprCtx) TypeCtx() types.Context {
return ctx.tc
}

// ErrCtx implements the EvalContext.ErrCtx().
func (ctx *ignoreTruncateExprCtx) ErrCtx() errctx.Context {
return ctx.ec
}

// ColumnInfos2ColumnsAndNames converts the ColumnInfo to the *Column and NameSlice.
func ColumnInfos2ColumnsAndNames(ctx BuildContext, dbName, tblName model.CIStr, colInfos []*model.ColumnInfo, tblInfo *model.TableInfo) ([]*Column, types.NameSlice, error) {
columns := make([]*Column, 0, len(colInfos))
Expand Down Expand Up @@ -1078,7 +1038,7 @@ func ColumnInfos2ColumnsAndNames(ctx BuildContext, dbName, tblName model.CIStr,
if col.IsVirtualGenerated() {
if !truncateIgnored {
// Ignore redundant warning here.
ctx = ignoreTruncate(ctx)
ctx = exprctx.CtxWithHandleTruncateErrLevel(ctx, errctx.LevelIgnore)
truncateIgnored = true
}

Expand Down
Loading