Skip to content

Commit

Permalink
ddl: introduce newReorgExprCtx to replace mock.Context usage
Browse files Browse the repository at this point in the history
  • Loading branch information
lcwangchao committed May 21, 2024
1 parent 044f113 commit a6d009c
Show file tree
Hide file tree
Showing 9 changed files with 104 additions and 29 deletions.
2 changes: 2 additions & 0 deletions pkg/ddl/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ go_library(
"//pkg/domain/resourcegroup",
"//pkg/errctx",
"//pkg/expression",
"//pkg/expression/context",
"//pkg/expression/contextstatic",
"//pkg/infoschema",
"//pkg/kv",
"//pkg/lightning/backend",
Expand Down
7 changes: 3 additions & 4 deletions pkg/ddl/backfilling.go
Original file line number Diff line number Diff line change
Expand Up @@ -527,12 +527,12 @@ func loadDDLReorgVars(ctx context.Context, sessPool *sess.Pool) error {
return ddlutil.LoadDDLReorgVars(ctx, sCtx)
}

func makeupDecodeColMap(sessCtx sessionctx.Context, dbName model.CIStr, t table.Table) (map[int64]decoder.Column, error) {
func makeupDecodeColMap(dbName model.CIStr, t table.Table) (map[int64]decoder.Column, error) {
writableColInfos := make([]*model.ColumnInfo, 0, len(t.WritableCols()))
for _, col := range t.WritableCols() {
writableColInfos = append(writableColInfos, col.ColumnInfo)
}
exprCols, _, err := expression.ColumnInfos2ColumnsAndNames(sessCtx.GetExprCtx(), dbName, t.Meta().Name, writableColInfos, t.Meta())
exprCols, _, err := expression.ColumnInfos2ColumnsAndNames(newReorgExprCtx(), dbName, t.Meta().Name, writableColInfos, t.Meta())
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -603,11 +603,10 @@ func (dc *ddlCtx) writePhysicalTableRecord(
})

jc := reorgInfo.NewJobContext()
sessCtx := newReorgSessCtx(reorgInfo.d.store)

eg, egCtx := util.NewErrorGroupWithRecoverWithCtx(dc.ctx)

scheduler, err := newBackfillScheduler(egCtx, reorgInfo, sessPool, bfWorkerType, t, sessCtx, jc)
scheduler, err := newBackfillScheduler(egCtx, reorgInfo, sessPool, bfWorkerType, t, jc)
if err != nil {
return errors.Trace(err)
}
Expand Down
7 changes: 3 additions & 4 deletions pkg/ddl/backfilling_scheduler.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,20 +88,19 @@ func newBackfillScheduler(
sessPool *sess.Pool,
tp backfillerType,
tbl table.PhysicalTable,
sessCtx sessionctx.Context,
jobCtx *JobContext,
) (backfillScheduler, error) {
if tp == typeAddIndexWorker && info.ReorgMeta.ReorgTp == model.ReorgTypeLitMerge {
ctx = logutil.WithCategory(ctx, "ddl-ingest")
return newIngestBackfillScheduler(ctx, info, sessPool, tbl)
}
return newTxnBackfillScheduler(ctx, info, sessPool, tp, tbl, sessCtx, jobCtx)
return newTxnBackfillScheduler(ctx, info, sessPool, tp, tbl, jobCtx)
}

func newTxnBackfillScheduler(ctx context.Context, info *reorgInfo, sessPool *sess.Pool,
tp backfillerType, tbl table.PhysicalTable, sessCtx sessionctx.Context,
tp backfillerType, tbl table.PhysicalTable,
jobCtx *JobContext) (backfillScheduler, error) {
decColMap, err := makeupDecodeColMap(sessCtx, info.dbInfo.Name, tbl)
decColMap, err := makeupDecodeColMap(info.dbInfo.Name, tbl)
if err != nil {
return nil, err
}
Expand Down
60 changes: 60 additions & 0 deletions pkg/ddl/backfilling_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"bytes"
"context"
"testing"
"time"

"github.com/pingcap/tidb/pkg/ddl/ingest"
"github.com/pingcap/tidb/pkg/kv"
Expand Down Expand Up @@ -84,3 +85,62 @@ func TestPickBackfillType(t *testing.T) {
require.NoError(t, err)
require.Equal(t, tp, model.ReorgTypeLitMerge)
}

// TestReorgExprContext is used in refactor stage to make sure the newReorgExprCtx() is
// compatible with newReorgSessCtx(nil).GetExprCtx() to make it safe to replace `mock.Context` usage.
// After refactor, the TestReorgExprContext can be removed.
func TestReorgExprContext(t *testing.T) {
sctx := newReorgSessCtx(nil)
sessCtx := sctx.GetExprCtx()
exprCtx := newReorgExprCtx()
cs1, col1 := sessCtx.GetCharsetInfo()
cs2, col2 := exprCtx.GetCharsetInfo()
require.Equal(t, cs1, cs2)
require.Equal(t, col1, col2)
require.Equal(t, sessCtx.GetDefaultCollationForUTF8MB4(), exprCtx.GetDefaultCollationForUTF8MB4())
if sessCtx.GetBlockEncryptionMode() == "" {
// The newReorgSessCtx returns a block encryption mode as an empty string.
// Though it is not a valid value, it does not matter because `GetBlockEncryptionMode` is never used in DDL.
// So we do not want to modify the behavior of `newReorgSessCtx` or `newReorgExprCtx`, and just to
// place the test code here to check:
// If `GetBlockEncryptionMode` still returns empty string in `newReorgSessCtx`, that means the behavior is
// not changed, and we just need to return a default value for `newReorgExprCtx`.
// If `GetBlockEncryptionMode` returns some other values, that means `GetBlockEncryptionMode` may have been
// used in somewhere and two return values should be the same.
require.Equal(t, "aes-128-ecb", exprCtx.GetBlockEncryptionMode())
} else {
require.Equal(t, sessCtx.GetBlockEncryptionMode(), exprCtx.GetBlockEncryptionMode())
}
require.Equal(t, sessCtx.GetSysdateIsNow(), exprCtx.GetSysdateIsNow())
require.Equal(t, sessCtx.GetNoopFuncsMode(), exprCtx.GetNoopFuncsMode())
require.Equal(t, sessCtx.IsUseCache(), exprCtx.IsUseCache())
require.Equal(t, sessCtx.IsInNullRejectCheck(), exprCtx.IsInNullRejectCheck())
require.Equal(t, sessCtx.ConnectionID(), exprCtx.ConnectionID())
require.Equal(t, sessCtx.AllocPlanColumnID(), exprCtx.AllocPlanColumnID())
require.Equal(t, sessCtx.GetWindowingUseHighPrecision(), exprCtx.GetWindowingUseHighPrecision())
require.Equal(t, sessCtx.GetGroupConcatMaxLen(), exprCtx.GetGroupConcatMaxLen())

evalCtx1 := sessCtx.GetEvalCtx()
evalCtx := exprCtx.GetEvalCtx()
require.Equal(t, evalCtx1.SQLMode(), evalCtx.SQLMode())
tc1 := evalCtx1.TypeCtx()
tc2 := evalCtx.TypeCtx()
require.Equal(t, tc1.Flags(), tc2.Flags())
require.Equal(t, tc1.Location().String(), tc2.Location().String())
ec1 := evalCtx1.ErrCtx()
ec2 := evalCtx.ErrCtx()
require.Equal(t, ec1.LevelMap(), ec2.LevelMap())
require.Equal(t, time.UTC, sctx.GetSessionVars().Location())
require.Equal(t, time.UTC, sctx.GetSessionVars().StmtCtx.TimeZone())
require.Equal(t, time.UTC, evalCtx1.Location())
require.Equal(t, time.UTC, evalCtx.Location())
require.Equal(t, evalCtx1.CurrentDB(), evalCtx.CurrentDB())
tm1, err := evalCtx1.CurrentTime()
require.NoError(t, err)
tm2, err := evalCtx.CurrentTime()
require.NoError(t, err)
require.InDelta(t, tm1.Unix(), tm2.Unix(), 2)
require.Equal(t, evalCtx1.GetMaxAllowedPacket(), evalCtx.GetMaxAllowedPacket())
require.Equal(t, evalCtx1.GetDefaultWeekFormatMode(), evalCtx.GetDefaultWeekFormatMode())
require.Equal(t, evalCtx1.GetDivPrecisionIncrement(), evalCtx.GetDivPrecisionIncrement())
}
9 changes: 5 additions & 4 deletions pkg/ddl/column.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import (
sess "github.com/pingcap/tidb/pkg/ddl/internal/session"
"github.com/pingcap/tidb/pkg/ddl/logutil"
"github.com/pingcap/tidb/pkg/expression"
exprctx "github.com/pingcap/tidb/pkg/expression/context"
"github.com/pingcap/tidb/pkg/infoschema"
"github.com/pingcap/tidb/pkg/kv"
"github.com/pingcap/tidb/pkg/meta"
Expand Down Expand Up @@ -491,11 +492,11 @@ func getModifyColumnInfo(t *meta.Meta, job *model.Job) (*model.DBInfo, *model.Ta
// Otherwise we set the zero value as original default value.
// Besides, in insert & update records, we have already implement using the casted value of relative column to insert
// rather than the original default value.
func GetOriginDefaultValueForModifyColumn(sessCtx sessionctx.Context, changingCol, oldCol *model.ColumnInfo) (any, error) {
func GetOriginDefaultValueForModifyColumn(ctx exprctx.BuildContext, changingCol, oldCol *model.ColumnInfo) (any, error) {
var err error
originDefVal := oldCol.GetOriginDefaultValue()
if originDefVal != nil {
odv, err := table.CastValue(sessCtx, types.NewDatum(originDefVal), changingCol, false, false)
odv, err := table.CastColumnValue(ctx, types.NewDatum(originDefVal), changingCol, false, false)
if err != nil {
logutil.DDLLogger().Info("cast origin default value failed", zap.Error(err))
}
Expand Down Expand Up @@ -580,7 +581,7 @@ func (w *worker) onModifyColumn(d *ddlCtx, t *meta.Meta, job *model.Job) (ver in
changingCol.Name = newColName
changingCol.ChangeStateInfo = &model.ChangeStateInfo{DependencyColumnOffset: oldCol.Offset}

originDefVal, err := GetOriginDefaultValueForModifyColumn(newReorgSessCtx(d.store), changingCol, oldCol)
originDefVal, err := GetOriginDefaultValueForModifyColumn(newReorgExprCtx(), changingCol, oldCol)
if err != nil {
return ver, errors.Trace(err)
}
Expand Down Expand Up @@ -1803,7 +1804,7 @@ func updateColumnDefaultValue(d *ddlCtx, t *meta.Meta, job *model.Job, newCol *m
return ver, infoschema.ErrColumnNotExists.GenWithStackByArgs(newCol.Name, tblInfo.Name)
}

if hasDefaultValue, _, err := checkColumnDefaultValue(newReorgSessCtx(d.store), table.ToColumn(oldCol.Clone()), newCol.DefaultValue); err != nil {
if hasDefaultValue, _, err := checkColumnDefaultValue(newReorgExprCtx(), table.ToColumn(oldCol.Clone()), newCol.DefaultValue); err != nil {
job.State = model.JobStateCancelled
return ver, errors.Trace(err)
} else if !hasDefaultValue {
Expand Down
16 changes: 8 additions & 8 deletions pkg/ddl/ddl_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ import (
rg "github.com/pingcap/tidb/pkg/domain/resourcegroup"
"github.com/pingcap/tidb/pkg/errctx"
"github.com/pingcap/tidb/pkg/expression"
exprctx "github.com/pingcap/tidb/pkg/expression/context"
"github.com/pingcap/tidb/pkg/infoschema"
"github.com/pingcap/tidb/pkg/kv"
"github.com/pingcap/tidb/pkg/meta"
Expand Down Expand Up @@ -1023,13 +1024,13 @@ func buildColumnAndConstraint(
// In non-strict SQL mode, if the default value of the column is an empty string, the default value can be ignored.
// In strict SQL mode, TEXT/BLOB/JSON can't have not null default values.
// In NO_ZERO_DATE SQL mode, TIMESTAMP/DATE/DATETIME type can't have zero date like '0000-00-00' or '0000-00-00 00:00:00'.
func checkColumnDefaultValue(ctx sessionctx.Context, col *table.Column, value any) (bool, any, error) {
func checkColumnDefaultValue(ctx exprctx.BuildContext, col *table.Column, value any) (bool, any, error) {
hasDefaultValue := true
if value != nil && (col.GetType() == mysql.TypeJSON ||
col.GetType() == mysql.TypeTinyBlob || col.GetType() == mysql.TypeMediumBlob ||
col.GetType() == mysql.TypeLongBlob || col.GetType() == mysql.TypeBlob) {
// In non-strict SQL mode.
if !ctx.GetSessionVars().SQLMode.HasStrictMode() && value == "" {
if !ctx.GetEvalCtx().SQLMode().HasStrictMode() && value == "" {
if col.GetType() == mysql.TypeBlob || col.GetType() == mysql.TypeLongBlob {
// The TEXT/BLOB default value can be ignored.
hasDefaultValue = false
Expand All @@ -1038,17 +1039,16 @@ func checkColumnDefaultValue(ctx sessionctx.Context, col *table.Column, value an
if col.GetType() == mysql.TypeJSON {
value = `null`
}
sc := ctx.GetSessionVars().StmtCtx
sc.AppendWarning(dbterror.ErrBlobCantHaveDefault.FastGenByArgs(col.Name.O))
ctx.GetEvalCtx().AppendWarning(dbterror.ErrBlobCantHaveDefault.FastGenByArgs(col.Name.O))
return hasDefaultValue, value, nil
}
// In strict SQL mode or default value is not an empty string.
return hasDefaultValue, value, dbterror.ErrBlobCantHaveDefault.GenWithStackByArgs(col.Name.O)
}
if value != nil && ctx.GetSessionVars().SQLMode.HasNoZeroDateMode() &&
ctx.GetSessionVars().SQLMode.HasStrictMode() && types.IsTypeTime(col.GetType()) {
if value != nil && ctx.GetEvalCtx().SQLMode().HasNoZeroDateMode() &&
ctx.GetEvalCtx().SQLMode().HasStrictMode() && types.IsTypeTime(col.GetType()) {
if vv, ok := value.(string); ok {
timeValue, err := expression.GetTimeValue(ctx.GetExprCtx(), vv, col.GetType(), col.GetDecimal(), nil)
timeValue, err := expression.GetTimeValue(ctx, vv, col.GetType(), col.GetDecimal(), nil)
if err != nil {
return hasDefaultValue, value, errors.Trace(err)
}
Expand Down Expand Up @@ -5547,7 +5547,7 @@ func SetDefaultValue(ctx sessionctx.Context, col *table.Column, option *ast.Colu

// When the default value is expression, we skip check and convert.
if !col.DefaultIsExpr {
if hasDefaultValue, value, err = checkColumnDefaultValue(ctx, col, value); err != nil {
if hasDefaultValue, value, err = checkColumnDefaultValue(ctx.GetExprCtx(), col, value); err != nil {
return hasDefaultValue, errors.Trace(err)
}
value, err = convertTimestampDefaultValToUTC(ctx, value, col)
Expand Down
23 changes: 18 additions & 5 deletions pkg/ddl/reorg.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ import (
sess "github.com/pingcap/tidb/pkg/ddl/internal/session"
"github.com/pingcap/tidb/pkg/ddl/logutil"
"github.com/pingcap/tidb/pkg/distsql"
exprctx "github.com/pingcap/tidb/pkg/expression/context"
"github.com/pingcap/tidb/pkg/expression/contextstatic"
"github.com/pingcap/tidb/pkg/kv"
"github.com/pingcap/tidb/pkg/meta"
"github.com/pingcap/tidb/pkg/metrics"
Expand Down Expand Up @@ -73,6 +75,19 @@ type reorgCtx struct {
references atomicutil.Int32
}

func newReorgExprCtx() exprctx.ExprContext {
evalCtx := contextstatic.NewStaticEvalContext(
contextstatic.WithSQLMode(mysql.ModeNone),
contextstatic.WithTypeFlags(types.DefaultStmtFlags),
contextstatic.WithErrLevelMap(stmtctx.DefaultStmtErrLevels),
)

return contextstatic.NewStaticExprContext(
contextstatic.WithEvalCtx(evalCtx),
contextstatic.WithUseCache(false),
)
}

func newReorgSessCtx(store kv.Storage) sessionctx.Context {
c := mock.NewContext()
c.Store = store
Expand Down Expand Up @@ -601,16 +616,15 @@ func (dc *ddlCtx) GetTableMaxHandle(ctx *JobContext, startTS uint64, tbl table.P
// empty table
return nil, true, nil
}
sessCtx := newReorgSessCtx(dc.store)
row := chk.GetRow(0)
if tblInfo.IsCommonHandle {
maxHandle, err = buildCommonHandleFromChunkRow(sessCtx.GetSessionVars().StmtCtx, tblInfo, pkIdx, handleCols, row)
maxHandle, err = buildCommonHandleFromChunkRow(time.UTC, tblInfo, pkIdx, handleCols, row)
return maxHandle, false, err
}
return kv.IntHandle(row.GetInt64(0)), false, nil
}

func buildCommonHandleFromChunkRow(sctx *stmtctx.StatementContext, tblInfo *model.TableInfo, idxInfo *model.IndexInfo,
func buildCommonHandleFromChunkRow(loc *time.Location, tblInfo *model.TableInfo, idxInfo *model.IndexInfo,
cols []*model.ColumnInfo, row chunk.Row) (kv.Handle, error) {
fieldTypes := make([]*types.FieldType, 0, len(cols))
for _, col := range cols {
Expand All @@ -620,8 +634,7 @@ func buildCommonHandleFromChunkRow(sctx *stmtctx.StatementContext, tblInfo *mode
tablecodec.TruncateIndexValues(tblInfo, idxInfo, datumRow)

var handleBytes []byte
handleBytes, err := codec.EncodeKey(sctx.TimeZone(), nil, datumRow...)
err = sctx.HandleError(err)
handleBytes, err := codec.EncodeKey(loc, nil, datumRow...)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/ddl/schematracker/dm_tracker.go
Original file line number Diff line number Diff line change
Expand Up @@ -725,7 +725,7 @@ func (d SchemaTracker) handleModifyColumn(
tblInfo.AutoRandomBits = updatedAutoRandomBits
oldCol := table.FindCol(t.Cols(), originalColName.L).ColumnInfo

originDefVal, err := ddl.GetOriginDefaultValueForModifyColumn(sctx, newColInfo, oldCol)
originDefVal, err := ddl.GetOriginDefaultValueForModifyColumn(sctx.GetExprCtx(), newColInfo, oldCol)
if err != nil {
return errors.Trace(err)
}
Expand Down
7 changes: 4 additions & 3 deletions pkg/sessionctx/stmtctx/stmtctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,8 @@ type StatementContext struct {
MDLRelatedTableIDs map[int64]struct{}
}

var defaultErrLevels = func() (l errctx.LevelMap) {
// DefaultStmtErrLevels is the default error levels for statement
var DefaultStmtErrLevels = func() (l errctx.LevelMap) {
l[errctx.ErrGroupDividedByZero] = errctx.LevelWarn
return
}()
Expand All @@ -421,7 +422,7 @@ func NewStmtCtxWithTimeZone(tz *time.Location) *StatementContext {
ctxID: contextutil.GenContextID(),
}
sc.typeCtx = types.NewContext(types.DefaultStmtFlags, tz, sc)
sc.errCtx = newErrCtx(sc.typeCtx, defaultErrLevels, sc)
sc.errCtx = newErrCtx(sc.typeCtx, DefaultStmtErrLevels, sc)
sc.PlanCacheTracker = contextutil.NewPlanCacheTracker(sc)
sc.RangeFallbackHandler = contextutil.NewRangeFallbackHandler(&sc.PlanCacheTracker, sc)
sc.WarnHandler = contextutil.NewStaticWarnHandler(0)
Expand All @@ -435,7 +436,7 @@ func (sc *StatementContext) Reset() {
ctxID: contextutil.GenContextID(),
}
sc.typeCtx = types.NewContext(types.DefaultStmtFlags, time.UTC, sc)
sc.errCtx = newErrCtx(sc.typeCtx, defaultErrLevels, sc)
sc.errCtx = newErrCtx(sc.typeCtx, DefaultStmtErrLevels, sc)
sc.PlanCacheTracker = contextutil.NewPlanCacheTracker(sc)
sc.RangeFallbackHandler = contextutil.NewRangeFallbackHandler(&sc.PlanCacheTracker, sc)
sc.WarnHandler = contextutil.NewStaticWarnHandler(0)
Expand Down

0 comments on commit a6d009c

Please sign in to comment.