From a6d009c0f1f9d44b03c3376c8c11588185ae712d Mon Sep 17 00:00:00 2001 From: Chao Wang Date: Mon, 20 May 2024 15:24:55 +0800 Subject: [PATCH] ddl: introduce `newReorgExprCtx` to replace `mock.Context` usage --- pkg/ddl/BUILD.bazel | 2 + pkg/ddl/backfilling.go | 7 ++-- pkg/ddl/backfilling_scheduler.go | 7 ++-- pkg/ddl/backfilling_test.go | 60 +++++++++++++++++++++++++++++ pkg/ddl/column.go | 9 +++-- pkg/ddl/ddl_api.go | 16 ++++---- pkg/ddl/reorg.go | 23 ++++++++--- pkg/ddl/schematracker/dm_tracker.go | 2 +- pkg/sessionctx/stmtctx/stmtctx.go | 7 ++-- 9 files changed, 104 insertions(+), 29 deletions(-) diff --git a/pkg/ddl/BUILD.bazel b/pkg/ddl/BUILD.bazel index 85cd87e13e279..50fde4c6d1055 100644 --- a/pkg/ddl/BUILD.bazel +++ b/pkg/ddl/BUILD.bazel @@ -89,6 +89,8 @@ go_library( "//pkg/domain/resourcegroup", "//pkg/errctx", "//pkg/expression", + "//pkg/expression/context", + "//pkg/expression/contextstatic", "//pkg/infoschema", "//pkg/kv", "//pkg/lightning/backend", diff --git a/pkg/ddl/backfilling.go b/pkg/ddl/backfilling.go index 2968c5efe0fcb..36493c7fa61cd 100644 --- a/pkg/ddl/backfilling.go +++ b/pkg/ddl/backfilling.go @@ -527,12 +527,12 @@ func loadDDLReorgVars(ctx context.Context, sessPool *sess.Pool) error { return ddlutil.LoadDDLReorgVars(ctx, sCtx) } -func makeupDecodeColMap(sessCtx sessionctx.Context, dbName model.CIStr, t table.Table) (map[int64]decoder.Column, error) { +func makeupDecodeColMap(dbName model.CIStr, t table.Table) (map[int64]decoder.Column, error) { writableColInfos := make([]*model.ColumnInfo, 0, len(t.WritableCols())) for _, col := range t.WritableCols() { writableColInfos = append(writableColInfos, col.ColumnInfo) } - exprCols, _, err := expression.ColumnInfos2ColumnsAndNames(sessCtx.GetExprCtx(), dbName, t.Meta().Name, writableColInfos, t.Meta()) + exprCols, _, err := expression.ColumnInfos2ColumnsAndNames(newReorgExprCtx(), dbName, t.Meta().Name, writableColInfos, t.Meta()) if err != nil { return nil, err } @@ -603,11 +603,10 @@ func (dc *ddlCtx) writePhysicalTableRecord( }) jc := reorgInfo.NewJobContext() - sessCtx := newReorgSessCtx(reorgInfo.d.store) eg, egCtx := util.NewErrorGroupWithRecoverWithCtx(dc.ctx) - scheduler, err := newBackfillScheduler(egCtx, reorgInfo, sessPool, bfWorkerType, t, sessCtx, jc) + scheduler, err := newBackfillScheduler(egCtx, reorgInfo, sessPool, bfWorkerType, t, jc) if err != nil { return errors.Trace(err) } diff --git a/pkg/ddl/backfilling_scheduler.go b/pkg/ddl/backfilling_scheduler.go index 9b173170234e3..69ab527f0cedc 100644 --- a/pkg/ddl/backfilling_scheduler.go +++ b/pkg/ddl/backfilling_scheduler.go @@ -88,20 +88,19 @@ func newBackfillScheduler( sessPool *sess.Pool, tp backfillerType, tbl table.PhysicalTable, - sessCtx sessionctx.Context, jobCtx *JobContext, ) (backfillScheduler, error) { if tp == typeAddIndexWorker && info.ReorgMeta.ReorgTp == model.ReorgTypeLitMerge { ctx = logutil.WithCategory(ctx, "ddl-ingest") return newIngestBackfillScheduler(ctx, info, sessPool, tbl) } - return newTxnBackfillScheduler(ctx, info, sessPool, tp, tbl, sessCtx, jobCtx) + return newTxnBackfillScheduler(ctx, info, sessPool, tp, tbl, jobCtx) } func newTxnBackfillScheduler(ctx context.Context, info *reorgInfo, sessPool *sess.Pool, - tp backfillerType, tbl table.PhysicalTable, sessCtx sessionctx.Context, + tp backfillerType, tbl table.PhysicalTable, jobCtx *JobContext) (backfillScheduler, error) { - decColMap, err := makeupDecodeColMap(sessCtx, info.dbInfo.Name, tbl) + decColMap, err := makeupDecodeColMap(info.dbInfo.Name, tbl) if err != nil { return nil, err } diff --git a/pkg/ddl/backfilling_test.go b/pkg/ddl/backfilling_test.go index 736c08eb0bde8..4702031c67e86 100644 --- a/pkg/ddl/backfilling_test.go +++ b/pkg/ddl/backfilling_test.go @@ -18,6 +18,7 @@ import ( "bytes" "context" "testing" + "time" "github.com/pingcap/tidb/pkg/ddl/ingest" "github.com/pingcap/tidb/pkg/kv" @@ -84,3 +85,62 @@ func TestPickBackfillType(t *testing.T) { require.NoError(t, err) require.Equal(t, tp, model.ReorgTypeLitMerge) } + +// TestReorgExprContext is used in refactor stage to make sure the newReorgExprCtx() is +// compatible with newReorgSessCtx(nil).GetExprCtx() to make it safe to replace `mock.Context` usage. +// After refactor, the TestReorgExprContext can be removed. +func TestReorgExprContext(t *testing.T) { + sctx := newReorgSessCtx(nil) + sessCtx := sctx.GetExprCtx() + exprCtx := newReorgExprCtx() + cs1, col1 := sessCtx.GetCharsetInfo() + cs2, col2 := exprCtx.GetCharsetInfo() + require.Equal(t, cs1, cs2) + require.Equal(t, col1, col2) + require.Equal(t, sessCtx.GetDefaultCollationForUTF8MB4(), exprCtx.GetDefaultCollationForUTF8MB4()) + if sessCtx.GetBlockEncryptionMode() == "" { + // The newReorgSessCtx returns a block encryption mode as an empty string. + // Though it is not a valid value, it does not matter because `GetBlockEncryptionMode` is never used in DDL. + // So we do not want to modify the behavior of `newReorgSessCtx` or `newReorgExprCtx`, and just to + // place the test code here to check: + // If `GetBlockEncryptionMode` still returns empty string in `newReorgSessCtx`, that means the behavior is + // not changed, and we just need to return a default value for `newReorgExprCtx`. + // If `GetBlockEncryptionMode` returns some other values, that means `GetBlockEncryptionMode` may have been + // used in somewhere and two return values should be the same. + require.Equal(t, "aes-128-ecb", exprCtx.GetBlockEncryptionMode()) + } else { + require.Equal(t, sessCtx.GetBlockEncryptionMode(), exprCtx.GetBlockEncryptionMode()) + } + require.Equal(t, sessCtx.GetSysdateIsNow(), exprCtx.GetSysdateIsNow()) + require.Equal(t, sessCtx.GetNoopFuncsMode(), exprCtx.GetNoopFuncsMode()) + require.Equal(t, sessCtx.IsUseCache(), exprCtx.IsUseCache()) + require.Equal(t, sessCtx.IsInNullRejectCheck(), exprCtx.IsInNullRejectCheck()) + require.Equal(t, sessCtx.ConnectionID(), exprCtx.ConnectionID()) + require.Equal(t, sessCtx.AllocPlanColumnID(), exprCtx.AllocPlanColumnID()) + require.Equal(t, sessCtx.GetWindowingUseHighPrecision(), exprCtx.GetWindowingUseHighPrecision()) + require.Equal(t, sessCtx.GetGroupConcatMaxLen(), exprCtx.GetGroupConcatMaxLen()) + + evalCtx1 := sessCtx.GetEvalCtx() + evalCtx := exprCtx.GetEvalCtx() + require.Equal(t, evalCtx1.SQLMode(), evalCtx.SQLMode()) + tc1 := evalCtx1.TypeCtx() + tc2 := evalCtx.TypeCtx() + require.Equal(t, tc1.Flags(), tc2.Flags()) + require.Equal(t, tc1.Location().String(), tc2.Location().String()) + ec1 := evalCtx1.ErrCtx() + ec2 := evalCtx.ErrCtx() + require.Equal(t, ec1.LevelMap(), ec2.LevelMap()) + require.Equal(t, time.UTC, sctx.GetSessionVars().Location()) + require.Equal(t, time.UTC, sctx.GetSessionVars().StmtCtx.TimeZone()) + require.Equal(t, time.UTC, evalCtx1.Location()) + require.Equal(t, time.UTC, evalCtx.Location()) + require.Equal(t, evalCtx1.CurrentDB(), evalCtx.CurrentDB()) + tm1, err := evalCtx1.CurrentTime() + require.NoError(t, err) + tm2, err := evalCtx.CurrentTime() + require.NoError(t, err) + require.InDelta(t, tm1.Unix(), tm2.Unix(), 2) + require.Equal(t, evalCtx1.GetMaxAllowedPacket(), evalCtx.GetMaxAllowedPacket()) + require.Equal(t, evalCtx1.GetDefaultWeekFormatMode(), evalCtx.GetDefaultWeekFormatMode()) + require.Equal(t, evalCtx1.GetDivPrecisionIncrement(), evalCtx.GetDivPrecisionIncrement()) +} diff --git a/pkg/ddl/column.go b/pkg/ddl/column.go index 16c4a1664b81b..2c49fa0477fae 100644 --- a/pkg/ddl/column.go +++ b/pkg/ddl/column.go @@ -31,6 +31,7 @@ import ( sess "github.com/pingcap/tidb/pkg/ddl/internal/session" "github.com/pingcap/tidb/pkg/ddl/logutil" "github.com/pingcap/tidb/pkg/expression" + exprctx "github.com/pingcap/tidb/pkg/expression/context" "github.com/pingcap/tidb/pkg/infoschema" "github.com/pingcap/tidb/pkg/kv" "github.com/pingcap/tidb/pkg/meta" @@ -491,11 +492,11 @@ func getModifyColumnInfo(t *meta.Meta, job *model.Job) (*model.DBInfo, *model.Ta // Otherwise we set the zero value as original default value. // Besides, in insert & update records, we have already implement using the casted value of relative column to insert // rather than the original default value. -func GetOriginDefaultValueForModifyColumn(sessCtx sessionctx.Context, changingCol, oldCol *model.ColumnInfo) (any, error) { +func GetOriginDefaultValueForModifyColumn(ctx exprctx.BuildContext, changingCol, oldCol *model.ColumnInfo) (any, error) { var err error originDefVal := oldCol.GetOriginDefaultValue() if originDefVal != nil { - odv, err := table.CastValue(sessCtx, types.NewDatum(originDefVal), changingCol, false, false) + odv, err := table.CastColumnValue(ctx, types.NewDatum(originDefVal), changingCol, false, false) if err != nil { logutil.DDLLogger().Info("cast origin default value failed", zap.Error(err)) } @@ -580,7 +581,7 @@ func (w *worker) onModifyColumn(d *ddlCtx, t *meta.Meta, job *model.Job) (ver in changingCol.Name = newColName changingCol.ChangeStateInfo = &model.ChangeStateInfo{DependencyColumnOffset: oldCol.Offset} - originDefVal, err := GetOriginDefaultValueForModifyColumn(newReorgSessCtx(d.store), changingCol, oldCol) + originDefVal, err := GetOriginDefaultValueForModifyColumn(newReorgExprCtx(), changingCol, oldCol) if err != nil { return ver, errors.Trace(err) } @@ -1803,7 +1804,7 @@ func updateColumnDefaultValue(d *ddlCtx, t *meta.Meta, job *model.Job, newCol *m return ver, infoschema.ErrColumnNotExists.GenWithStackByArgs(newCol.Name, tblInfo.Name) } - if hasDefaultValue, _, err := checkColumnDefaultValue(newReorgSessCtx(d.store), table.ToColumn(oldCol.Clone()), newCol.DefaultValue); err != nil { + if hasDefaultValue, _, err := checkColumnDefaultValue(newReorgExprCtx(), table.ToColumn(oldCol.Clone()), newCol.DefaultValue); err != nil { job.State = model.JobStateCancelled return ver, errors.Trace(err) } else if !hasDefaultValue { diff --git a/pkg/ddl/ddl_api.go b/pkg/ddl/ddl_api.go index 0c22f0bcb9193..e50c24a0cd96c 100644 --- a/pkg/ddl/ddl_api.go +++ b/pkg/ddl/ddl_api.go @@ -40,6 +40,7 @@ import ( rg "github.com/pingcap/tidb/pkg/domain/resourcegroup" "github.com/pingcap/tidb/pkg/errctx" "github.com/pingcap/tidb/pkg/expression" + exprctx "github.com/pingcap/tidb/pkg/expression/context" "github.com/pingcap/tidb/pkg/infoschema" "github.com/pingcap/tidb/pkg/kv" "github.com/pingcap/tidb/pkg/meta" @@ -1023,13 +1024,13 @@ func buildColumnAndConstraint( // In non-strict SQL mode, if the default value of the column is an empty string, the default value can be ignored. // In strict SQL mode, TEXT/BLOB/JSON can't have not null default values. // In NO_ZERO_DATE SQL mode, TIMESTAMP/DATE/DATETIME type can't have zero date like '0000-00-00' or '0000-00-00 00:00:00'. -func checkColumnDefaultValue(ctx sessionctx.Context, col *table.Column, value any) (bool, any, error) { +func checkColumnDefaultValue(ctx exprctx.BuildContext, col *table.Column, value any) (bool, any, error) { hasDefaultValue := true if value != nil && (col.GetType() == mysql.TypeJSON || col.GetType() == mysql.TypeTinyBlob || col.GetType() == mysql.TypeMediumBlob || col.GetType() == mysql.TypeLongBlob || col.GetType() == mysql.TypeBlob) { // In non-strict SQL mode. - if !ctx.GetSessionVars().SQLMode.HasStrictMode() && value == "" { + if !ctx.GetEvalCtx().SQLMode().HasStrictMode() && value == "" { if col.GetType() == mysql.TypeBlob || col.GetType() == mysql.TypeLongBlob { // The TEXT/BLOB default value can be ignored. hasDefaultValue = false @@ -1038,17 +1039,16 @@ func checkColumnDefaultValue(ctx sessionctx.Context, col *table.Column, value an if col.GetType() == mysql.TypeJSON { value = `null` } - sc := ctx.GetSessionVars().StmtCtx - sc.AppendWarning(dbterror.ErrBlobCantHaveDefault.FastGenByArgs(col.Name.O)) + ctx.GetEvalCtx().AppendWarning(dbterror.ErrBlobCantHaveDefault.FastGenByArgs(col.Name.O)) return hasDefaultValue, value, nil } // In strict SQL mode or default value is not an empty string. return hasDefaultValue, value, dbterror.ErrBlobCantHaveDefault.GenWithStackByArgs(col.Name.O) } - if value != nil && ctx.GetSessionVars().SQLMode.HasNoZeroDateMode() && - ctx.GetSessionVars().SQLMode.HasStrictMode() && types.IsTypeTime(col.GetType()) { + if value != nil && ctx.GetEvalCtx().SQLMode().HasNoZeroDateMode() && + ctx.GetEvalCtx().SQLMode().HasStrictMode() && types.IsTypeTime(col.GetType()) { if vv, ok := value.(string); ok { - timeValue, err := expression.GetTimeValue(ctx.GetExprCtx(), vv, col.GetType(), col.GetDecimal(), nil) + timeValue, err := expression.GetTimeValue(ctx, vv, col.GetType(), col.GetDecimal(), nil) if err != nil { return hasDefaultValue, value, errors.Trace(err) } @@ -5547,7 +5547,7 @@ func SetDefaultValue(ctx sessionctx.Context, col *table.Column, option *ast.Colu // When the default value is expression, we skip check and convert. if !col.DefaultIsExpr { - if hasDefaultValue, value, err = checkColumnDefaultValue(ctx, col, value); err != nil { + if hasDefaultValue, value, err = checkColumnDefaultValue(ctx.GetExprCtx(), col, value); err != nil { return hasDefaultValue, errors.Trace(err) } value, err = convertTimestampDefaultValToUTC(ctx, value, col) diff --git a/pkg/ddl/reorg.go b/pkg/ddl/reorg.go index 66bd682fcfa47..7cf13f60fbb01 100644 --- a/pkg/ddl/reorg.go +++ b/pkg/ddl/reorg.go @@ -29,6 +29,8 @@ import ( sess "github.com/pingcap/tidb/pkg/ddl/internal/session" "github.com/pingcap/tidb/pkg/ddl/logutil" "github.com/pingcap/tidb/pkg/distsql" + exprctx "github.com/pingcap/tidb/pkg/expression/context" + "github.com/pingcap/tidb/pkg/expression/contextstatic" "github.com/pingcap/tidb/pkg/kv" "github.com/pingcap/tidb/pkg/meta" "github.com/pingcap/tidb/pkg/metrics" @@ -73,6 +75,19 @@ type reorgCtx struct { references atomicutil.Int32 } +func newReorgExprCtx() exprctx.ExprContext { + evalCtx := contextstatic.NewStaticEvalContext( + contextstatic.WithSQLMode(mysql.ModeNone), + contextstatic.WithTypeFlags(types.DefaultStmtFlags), + contextstatic.WithErrLevelMap(stmtctx.DefaultStmtErrLevels), + ) + + return contextstatic.NewStaticExprContext( + contextstatic.WithEvalCtx(evalCtx), + contextstatic.WithUseCache(false), + ) +} + func newReorgSessCtx(store kv.Storage) sessionctx.Context { c := mock.NewContext() c.Store = store @@ -601,16 +616,15 @@ func (dc *ddlCtx) GetTableMaxHandle(ctx *JobContext, startTS uint64, tbl table.P // empty table return nil, true, nil } - sessCtx := newReorgSessCtx(dc.store) row := chk.GetRow(0) if tblInfo.IsCommonHandle { - maxHandle, err = buildCommonHandleFromChunkRow(sessCtx.GetSessionVars().StmtCtx, tblInfo, pkIdx, handleCols, row) + maxHandle, err = buildCommonHandleFromChunkRow(time.UTC, tblInfo, pkIdx, handleCols, row) return maxHandle, false, err } return kv.IntHandle(row.GetInt64(0)), false, nil } -func buildCommonHandleFromChunkRow(sctx *stmtctx.StatementContext, tblInfo *model.TableInfo, idxInfo *model.IndexInfo, +func buildCommonHandleFromChunkRow(loc *time.Location, tblInfo *model.TableInfo, idxInfo *model.IndexInfo, cols []*model.ColumnInfo, row chunk.Row) (kv.Handle, error) { fieldTypes := make([]*types.FieldType, 0, len(cols)) for _, col := range cols { @@ -620,8 +634,7 @@ func buildCommonHandleFromChunkRow(sctx *stmtctx.StatementContext, tblInfo *mode tablecodec.TruncateIndexValues(tblInfo, idxInfo, datumRow) var handleBytes []byte - handleBytes, err := codec.EncodeKey(sctx.TimeZone(), nil, datumRow...) - err = sctx.HandleError(err) + handleBytes, err := codec.EncodeKey(loc, nil, datumRow...) if err != nil { return nil, err } diff --git a/pkg/ddl/schematracker/dm_tracker.go b/pkg/ddl/schematracker/dm_tracker.go index 4d80800060e83..e4079b29fb5d5 100644 --- a/pkg/ddl/schematracker/dm_tracker.go +++ b/pkg/ddl/schematracker/dm_tracker.go @@ -725,7 +725,7 @@ func (d SchemaTracker) handleModifyColumn( tblInfo.AutoRandomBits = updatedAutoRandomBits oldCol := table.FindCol(t.Cols(), originalColName.L).ColumnInfo - originDefVal, err := ddl.GetOriginDefaultValueForModifyColumn(sctx, newColInfo, oldCol) + originDefVal, err := ddl.GetOriginDefaultValueForModifyColumn(sctx.GetExprCtx(), newColInfo, oldCol) if err != nil { return errors.Trace(err) } diff --git a/pkg/sessionctx/stmtctx/stmtctx.go b/pkg/sessionctx/stmtctx/stmtctx.go index 5ca6908bbc1cc..5a72c3b6ceb32 100644 --- a/pkg/sessionctx/stmtctx/stmtctx.go +++ b/pkg/sessionctx/stmtctx/stmtctx.go @@ -404,7 +404,8 @@ type StatementContext struct { MDLRelatedTableIDs map[int64]struct{} } -var defaultErrLevels = func() (l errctx.LevelMap) { +// DefaultStmtErrLevels is the default error levels for statement +var DefaultStmtErrLevels = func() (l errctx.LevelMap) { l[errctx.ErrGroupDividedByZero] = errctx.LevelWarn return }() @@ -421,7 +422,7 @@ func NewStmtCtxWithTimeZone(tz *time.Location) *StatementContext { ctxID: contextutil.GenContextID(), } sc.typeCtx = types.NewContext(types.DefaultStmtFlags, tz, sc) - sc.errCtx = newErrCtx(sc.typeCtx, defaultErrLevels, sc) + sc.errCtx = newErrCtx(sc.typeCtx, DefaultStmtErrLevels, sc) sc.PlanCacheTracker = contextutil.NewPlanCacheTracker(sc) sc.RangeFallbackHandler = contextutil.NewRangeFallbackHandler(&sc.PlanCacheTracker, sc) sc.WarnHandler = contextutil.NewStaticWarnHandler(0) @@ -435,7 +436,7 @@ func (sc *StatementContext) Reset() { ctxID: contextutil.GenContextID(), } sc.typeCtx = types.NewContext(types.DefaultStmtFlags, time.UTC, sc) - sc.errCtx = newErrCtx(sc.typeCtx, defaultErrLevels, sc) + sc.errCtx = newErrCtx(sc.typeCtx, DefaultStmtErrLevels, sc) sc.PlanCacheTracker = contextutil.NewPlanCacheTracker(sc) sc.RangeFallbackHandler = contextutil.NewRangeFallbackHandler(&sc.PlanCacheTracker, sc) sc.WarnHandler = contextutil.NewStaticWarnHandler(0)