Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

stmtctx: move time zone info in stmtctx.StatementContext into TypeCtx #47592

Merged
merged 3 commits into from
Oct 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion bindinfo/session_handle.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func (h *SessionHandle) CreateBindRecord(sctx sessionctx.Context, record *BindRe
return err
}
record.Db = strings.ToLower(record.Db)
now := types.NewTime(types.FromGoTime(time.Now().In(sctx.GetSessionVars().StmtCtx.TimeZone)), mysql.TypeTimestamp, 3)
now := types.NewTime(types.FromGoTime(time.Now().In(sctx.GetSessionVars().StmtCtx.TimeZone())), mysql.TypeTimestamp, 3)
for i := range record.Bindings {
record.Bindings[i].CreateTime = now
record.Bindings[i].UpdateTime = now
Expand Down
2 changes: 1 addition & 1 deletion br/pkg/lightning/backend/kv/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ func NewSession(options *encode.SessionOptions, logger log.Logger) *Session {
}
}
}
vars.StmtCtx.TimeZone = vars.Location()
vars.StmtCtx.SetTimeZone(vars.Location())
if err := vars.SetSystemVar("timestamp", strconv.FormatInt(options.Timestamp, 10)); err != nil {
logger.Warn("new session: failed to set timestamp",
log.ShortError(err))
Expand Down
2 changes: 1 addition & 1 deletion br/pkg/lightning/backend/local/local_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ func TestNextKey(t *testing.T) {
{types.NewStringDatum("test\255"), types.NewStringDatum("test\255\000")},
}

stmtCtx := new(stmtctx.StatementContext)
stmtCtx := stmtctx.NewStmtCtx()
for _, datums := range testDatums {
keyBytes, err := codec.EncodeKey(stmtCtx, nil, types.NewIntDatum(123), datums[0])
require.NoError(t, err)
Expand Down
2 changes: 1 addition & 1 deletion br/pkg/lightning/backend/local/localhelper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -783,7 +783,7 @@ func doTestBatchSplitByRangesWithClusteredIndex(t *testing.T, hook clientHook) {
splitRegionBaseBackOffTime = oldSplitBackoffTime
}()

stmtCtx := new(stmtctx.StatementContext)
stmtCtx := stmtctx.NewStmtCtx()

tableID := int64(1)
tableStartKey := tablecodec.EncodeTablePrefix(tableID)
Expand Down
2 changes: 1 addition & 1 deletion br/pkg/restore/merge_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ func (fb *fileBulder) build(tableID, indexID, num, bytes, kv int) (files []*back
if indexID != 0 {
lowVal := types.NewIntDatum(fb.startKeyOffset - 10)
highVal := types.NewIntDatum(fb.startKeyOffset)
sc := &stmtctx.StatementContext{TimeZone: time.UTC}
sc := stmtctx.NewStmtCtxWithTimeZone(time.UTC)
lowValue, err := codec.EncodeKey(sc, nil, lowVal)
if err != nil {
panic(err)
Expand Down
4 changes: 1 addition & 3 deletions br/pkg/task/backup.go
Original file line number Diff line number Diff line change
Expand Up @@ -767,9 +767,7 @@ func ParseTSString(ts string, tzCheck bool) (uint64, error) {
}

loc := time.Local
sc := &stmtctx.StatementContext{
TimeZone: loc,
}
sc := stmtctx.NewStmtCtxWithTimeZone(loc)
if tzCheck {
tzIdx, _, _, _, _ := types.GetTimezone(ts)
if tzIdx < 0 {
Expand Down
4 changes: 2 additions & 2 deletions ddl/backfilling_scheduler.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,9 +149,9 @@ func initSessCtx(
tzLocation *model.TimeZoneLocation,
) error {
// Unify the TimeZone settings in newContext.
if sessCtx.GetSessionVars().StmtCtx.TimeZone == nil {
if sessCtx.GetSessionVars().StmtCtx.TimeZone() == nil {
tz := *time.UTC
sessCtx.GetSessionVars().StmtCtx.TimeZone = &tz
sessCtx.GetSessionVars().StmtCtx.SetTimeZone(&tz)
}
sessCtx.GetSessionVars().StmtCtx.IsDDLJobInQueue = true
// Set the row encode format version.
Expand Down
2 changes: 1 addition & 1 deletion ddl/column.go
Original file line number Diff line number Diff line change
Expand Up @@ -1314,7 +1314,7 @@ func (w *updateColumnWorker) fetchRowColVals(txn kv.Transaction, taskRange reorg
}

func (w *updateColumnWorker) getRowRecord(handle kv.Handle, recordKey []byte, rawRow []byte) error {
sysTZ := w.sessCtx.GetSessionVars().StmtCtx.TimeZone
sysTZ := w.sessCtx.GetSessionVars().StmtCtx.TimeZone()
_, err := w.rowDecoder.DecodeTheExistedColumnMap(w.sessCtx, handle, rawRow, sysTZ, w.rowMap)
if err != nil {
return errors.Trace(dbterror.ErrCantDecodeRecord.GenWithStackByArgs("column", err))
Expand Down
2 changes: 1 addition & 1 deletion ddl/export_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,6 @@ func ConvertRowToHandleAndIndexDatum(
c := copCtx.GetBase()
idxData := extractDatumByOffsets(row, copCtx.IndexColumnOutputOffsets(idxID), c.ExprColumnInfos, nil)
handleData := extractDatumByOffsets(row, c.HandleOutputOffsets, c.ExprColumnInfos, nil)
handle, err := buildHandle(handleData, c.TableInfo, c.PrimaryKeyInfo, &stmtctx.StatementContext{TimeZone: time.Local})
handle, err := buildHandle(handleData, c.TableInfo, c.PrimaryKeyInfo, stmtctx.NewStmtCtxWithTimeZone(time.Local))
return handle, idxData, err
}
2 changes: 1 addition & 1 deletion ddl/index.go
Original file line number Diff line number Diff line change
Expand Up @@ -1494,7 +1494,7 @@ func (w *baseIndexWorker) getNextKey(taskRange reorgBackfillTask, taskDone bool)
}

func (w *baseIndexWorker) updateRowDecoder(handle kv.Handle, rawRecord []byte) error {
sysZone := w.sessCtx.GetSessionVars().StmtCtx.TimeZone
sysZone := w.sessCtx.GetSessionVars().StmtCtx.TimeZone()
_, err := w.rowDecoder.DecodeAndEvalRowWithMap(w.sessCtx, handle, rawRecord, sysZone, w.rowMap)
return errors.Trace(err)
}
Expand Down
2 changes: 1 addition & 1 deletion ddl/partition.go
Original file line number Diff line number Diff line change
Expand Up @@ -3181,7 +3181,7 @@ func (w *reorgPartitionWorker) fetchRowColVals(txn kv.Transaction, taskRange reo

// taskDone means that the added handle is out of taskRange.endHandle.
taskDone := false
sysTZ := w.sessCtx.GetSessionVars().StmtCtx.TimeZone
sysTZ := w.sessCtx.GetSessionVars().StmtCtx.TimeZone()

tmpRow := make([]types.Datum, w.maxOffset+1)
var lastAccessedHandle kv.Key
Expand Down
2 changes: 1 addition & 1 deletion ddl/reorg.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ func newContext(store kv.Storage) sessionctx.Context {

tz := *time.UTC
c.GetSessionVars().TimeZone = &tz
c.GetSessionVars().StmtCtx.TimeZone = &tz
c.GetSessionVars().StmtCtx.SetTimeZone(&tz)
return c
}

Expand Down
8 changes: 4 additions & 4 deletions distsql/distsql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -300,10 +300,10 @@ func (r *mockResultSubset) RespTime() time.Duration { return 0 }

func newMockSessionContext() sessionctx.Context {
ctx := mock.NewContext()
ctx.GetSessionVars().StmtCtx = &stmtctx.StatementContext{
MemTracker: memory.NewTracker(-1, -1),
DiskTracker: disk.NewTracker(-1, -1),
}
ctx.GetSessionVars().StmtCtx = stmtctx.NewStmtCtx()
ctx.GetSessionVars().StmtCtx.MemTracker = memory.NewTracker(-1, -1)
ctx.GetSessionVars().StmtCtx.DiskTracker = disk.NewTracker(-1, -1)

ctx.Store = &mock.Store{
Client: &mock.Client{
MockResponse: &mockResponse{
Expand Down
6 changes: 3 additions & 3 deletions distsql/request_builder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ func TestIndexRangesToKVRanges(t *testing.T) {
},
}

actual, err := IndexRangesToKVRanges(new(stmtctx.StatementContext), 12, 15, ranges)
actual, err := IndexRangesToKVRanges(stmtctx.NewStmtCtx(), 12, 15, ranges)
require.NoError(t, err)
for i := range actual.FirstPartitionRange() {
require.Equal(t, expect[i], actual.FirstPartitionRange()[i])
Expand Down Expand Up @@ -314,7 +314,7 @@ func TestRequestBuilder2(t *testing.T) {
},
}

actual, err := (&RequestBuilder{}).SetIndexRanges(new(stmtctx.StatementContext), 12, 15, ranges).
actual, err := (&RequestBuilder{}).SetIndexRanges(stmtctx.NewStmtCtx(), 12, 15, ranges).
SetDAGRequest(&tipb.DAGRequest{}).
SetDesc(false).
SetKeepOrder(false).
Expand Down Expand Up @@ -659,7 +659,7 @@ func TestIndexRangesToKVRangesWithFbs(t *testing.T) {
Collators: collate.GetBinaryCollatorSlice(1),
},
}
actual, err := IndexRangesToKVRanges(new(stmtctx.StatementContext), 0, 0, ranges)
actual, err := IndexRangesToKVRanges(stmtctx.NewStmtCtx(), 0, 0, ranges)
require.NoError(t, err)
expect := []kv.KeyRange{
{
Expand Down
2 changes: 1 addition & 1 deletion distsql/select_result_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import (

func TestUpdateCopRuntimeStats(t *testing.T) {
ctx := mock.NewContext()
ctx.GetSessionVars().StmtCtx = new(stmtctx.StatementContext)
ctx.GetSessionVars().StmtCtx = stmtctx.NewStmtCtx()
sr := selectResult{ctx: ctx, storeType: kv.TiKV}
require.Nil(t, ctx.GetSessionVars().StmtCtx.RuntimeStatsColl)

Expand Down
2 changes: 1 addition & 1 deletion domain/domain_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ func TestStatWorkRecoverFromPanic(t *testing.T) {
require.Equal(t, expiredTimeStamp, ts)

// set expiredTimeStamp4PC to "2023-08-02 12:15:00"
ts, _ = types.ParseTimestamp(&stmtctx.StatementContext{TimeZone: time.UTC}, "2023-08-02 12:15:00")
ts, _ = types.ParseTimestamp(stmtctx.NewStmtCtxWithTimeZone(time.UTC), "2023-08-02 12:15:00")
dom.SetExpiredTimeStamp4PC(ts)
expiredTimeStamp = dom.ExpiredTimeStamp4PC()
require.Equal(t, expiredTimeStamp, ts)
Expand Down
2 changes: 1 addition & 1 deletion executor/aggfuncs/func_count_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ func TestMemCount(t *testing.T) {
}

func TestWriteTime(t *testing.T) {
tt, err := types.ParseDate(&(stmtctx.StatementContext{}), "2020-11-11")
tt, err := types.ParseDate(stmtctx.NewStmtCtx(), "2020-11-11")
require.NoError(t, err)

buf := make([]byte, 16)
Expand Down
2 changes: 1 addition & 1 deletion executor/analyze_col_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,7 @@ func (e *AnalyzeColumnsExecV2) buildSamplingStats(
// If there's no virtual column or we meet error during eval virtual column, we fallback to normal decode otherwise.
for _, sample := range rootRowCollector.Base().Samples {
for i := range sample.Columns {
sample.Columns[i], err = tablecodec.DecodeColumnValue(sample.Columns[i].GetBytes(), &e.colsInfo[i].FieldType, sc.TimeZone)
sample.Columns[i], err = tablecodec.DecodeColumnValue(sample.Columns[i].GetBytes(), &e.colsInfo[i].FieldType, sc.TimeZone())
if err != nil {
return 0, nil, nil, nil, nil, err
}
Expand Down
4 changes: 2 additions & 2 deletions executor/brie.go
Original file line number Diff line number Diff line change
Expand Up @@ -235,12 +235,12 @@ func (bq *brieQueue) clearTask(sc *stmtctx.StatementContext) {
}

func (b *executorBuilder) parseTSString(ts string) (uint64, error) {
sc := &stmtctx.StatementContext{TimeZone: b.ctx.GetSessionVars().Location()}
sc := stmtctx.NewStmtCtxWithTimeZone(b.ctx.GetSessionVars().Location())
t, err := types.ParseTime(sc, ts, mysql.TypeTimestamp, types.MaxFsp, nil)
if err != nil {
return 0, err
}
t1, err := t.GoTime(sc.TimeZone)
t1, err := t.GoTime(sc.TimeZone())
if err != nil {
return 0, err
}
Expand Down
6 changes: 4 additions & 2 deletions executor/coprocessor.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,11 +183,13 @@ func (h *CoprocessorDAGHandler) buildDAGExecutor(req *coprocessor.Request) (exec

stmtCtx := h.sctx.GetSessionVars().StmtCtx
stmtCtx.SetFlagsFromPBFlag(dagReq.Flags)
stmtCtx.TimeZone, err = timeutil.ConstructTimeZone(dagReq.TimeZoneName, int(dagReq.TimeZoneOffset))
h.sctx.GetSessionVars().TimeZone = stmtCtx.TimeZone
tz, err := timeutil.ConstructTimeZone(dagReq.TimeZoneName, int(dagReq.TimeZoneOffset))
if err != nil {
return nil, errors.Trace(err)
}

stmtCtx.SetTimeZone(tz)
h.sctx.GetSessionVars().TimeZone = tz
h.dagReq = dagReq
is := h.sctx.GetInfoSchema().(infoschema.InfoSchema)
// Build physical plan.
Expand Down
16 changes: 8 additions & 8 deletions executor/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -1938,12 +1938,11 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) {
if vars.TxnCtx.CouldRetry || mysql.HasCursorExistsFlag(vars.Status) {
// Must construct new statement context object, the retry history need context for every statement.
// TODO: Maybe one day we can get rid of transaction retry, then this logic can be deleted.
sc = &stmtctx.StatementContext{}
sc = stmtctx.NewStmtCtx()
} else {
sc = vars.InitStatementContext()
}
var typeConvFlags types.Flags
sc.TimeZone = vars.Location()
sc.SetTimeZone(vars.Location())
sc.TaskID = stmtctx.AllocateTaskID()
sc.CTEStorageMap = map[int]*CTEStorages{}
sc.IsStaleness = false
Expand Down Expand Up @@ -2139,10 +2138,12 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) {
sc.AllowInvalidDate = vars.SQLMode.HasAllowInvalidDatesMode()
}

typeConvFlags = typeConvFlags.
WithSkipUTF8Check(vars.SkipUTF8Check).
WithSkipSACIICheck(vars.SkipASCIICheck).
WithSkipUTF8MB4Check(!globalConfig.Instance.CheckMb4ValueInUTF8.Load())
sc.UpdateTypeFlags(func(flags types.Flags) types.Flags {
return flags.
WithSkipUTF8Check(vars.SkipUTF8Check).
WithSkipSACIICheck(vars.SkipASCIICheck).
WithSkipUTF8MB4Check(!globalConfig.Instance.CheckMb4ValueInUTF8.Load())
})

vars.PlanCacheParams.Reset()
if priority := mysql.PriorityEnum(atomic.LoadInt32(&variable.ForcePriority)); priority != mysql.NoPriority {
Expand All @@ -2169,7 +2170,6 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) {
sc.RuntimeStatsColl = execdetails.NewRuntimeStatsColl(reuseObj)
}

sc.TypeConvContext = types.NewContext(typeConvFlags, sc.TimeZone, sc.AppendWarning)
sc.TblInfo2UnionScan = make(map[*model.TableInfo]bool)
errCount, warnCount := vars.StmtCtx.NumErrorWarnings()
vars.SysErrorCount = errCount
Expand Down
2 changes: 1 addition & 1 deletion executor/infoschema_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -1460,7 +1460,7 @@ func (e *memtableRetriever) setDataForProcessList(ctx sessionctx.Context) {
continue
}

rows := pi.ToRow(ctx.GetSessionVars().StmtCtx.TimeZone)
rows := pi.ToRow(ctx.GetSessionVars().StmtCtx.TimeZone())
record := types.MakeDatums(rows...)
records = append(records, record)
}
Expand Down
2 changes: 1 addition & 1 deletion executor/simple.go
Original file line number Diff line number Diff line change
Expand Up @@ -2807,7 +2807,7 @@ func (e *SimpleExec) executeAdminFlushPlanCache(s *ast.AdminStmt) error {
e.Ctx().GetSessionVars().StmtCtx.AppendWarning(errors.New("The plan cache is disable. So there no need to flush the plan cache"))
return nil
}
now := types.NewTime(types.FromGoTime(time.Now().In(e.Ctx().GetSessionVars().StmtCtx.TimeZone)), mysql.TypeTimestamp, 3)
now := types.NewTime(types.FromGoTime(time.Now().In(e.Ctx().GetSessionVars().StmtCtx.TimeZone())), mysql.TypeTimestamp, 3)
e.Ctx().GetSessionVars().LastUpdateTime4PC = now
e.Ctx().GetSessionPlanCache().DeleteAll()
if s.StatementScope == ast.StatementScopeInstance {
Expand Down
2 changes: 1 addition & 1 deletion executor/split_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,7 @@ func TestClusterIndexSplitTable(t *testing.T) {
}(minRegionStepValue)
minRegionStepValue = 3
ctx := mock.NewContext()
sc := &stmtctx.StatementContext{TimeZone: time.Local}
sc := stmtctx.NewStmtCtxWithTimeZone(time.Local)
e := &SplitTableRegionExec{
BaseExecutor: exec.NewBaseExecutor(ctx, nil, 0),
tableInfo: tbInfo,
Expand Down
4 changes: 2 additions & 2 deletions executor/stmtsummary.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ func (e *stmtSummaryRetriever) initEvictedRowsReader(sctx sessionctx.Context) (*
func (e *stmtSummaryRetriever) initSummaryRowsReader(sctx sessionctx.Context) (*rowsReader, error) {
vars := sctx.GetSessionVars()
user := vars.User
tz := vars.StmtCtx.TimeZone
tz := vars.StmtCtx.TimeZone()
columns := e.columns
priv := hasPriv(sctx, mysql.ProcessPriv)
instanceAddr, err := clusterTableInstanceAddr(sctx, e.table.Name.O)
Expand Down Expand Up @@ -241,7 +241,7 @@ func (r *stmtSummaryRetrieverV2) initEvictedRowsReader(sctx sessionctx.Context)
func (r *stmtSummaryRetrieverV2) initSummaryRowsReader(ctx context.Context, sctx sessionctx.Context) (*rowsReader, error) {
vars := sctx.GetSessionVars()
user := vars.User
tz := vars.StmtCtx.TimeZone
tz := vars.StmtCtx.TimeZone()
stmtSummary := r.stmtSummary
columns := r.columns
timeRanges := r.timeRanges
Expand Down
2 changes: 1 addition & 1 deletion executor/test/admintest/admin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1272,7 +1272,7 @@ func TestCheckFailReport(t *testing.T) {

txn, err := store.Begin()
require.NoError(t, err)
encoded, err := codec.EncodeKey(new(stmtctx.StatementContext), nil, types.NewBytesDatum([]byte{1, 0, 1, 0, 0, 1, 1}))
encoded, err := codec.EncodeKey(stmtctx.NewStmtCtx(), nil, types.NewBytesDatum([]byte{1, 0, 1, 0, 0, 1, 1}))
require.NoError(t, err)
hd, err := kv.NewCommonHandle(encoded)
require.NoError(t, err)
Expand Down
4 changes: 2 additions & 2 deletions executor/test/executor/executor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -954,7 +954,7 @@ func TestCheckIndex(t *testing.T) {

mockCtx := mock.NewContext()
idx := tb.Indices()[0]
sc := &stmtctx.StatementContext{TimeZone: time.Local}
sc := stmtctx.NewStmtCtxWithTimeZone(time.Local)

_, err = se.Execute(context.Background(), "admin check index t idx_inexistent")
require.Error(t, err)
Expand Down Expand Up @@ -1012,7 +1012,7 @@ func TestCheckIndex(t *testing.T) {
func setColValue(t *testing.T, txn kv.Transaction, key kv.Key, v types.Datum) {
row := []types.Datum{v, {}}
colIDs := []int64{2, 3}
sc := &stmtctx.StatementContext{TimeZone: time.Local}
sc := stmtctx.NewStmtCtxWithTimeZone(time.Local)
rd := rowcodec.Encoder{Enable: true}
value, err := tablecodec.EncodeRow(sc, row, colIDs, nil, nil, &rd)
require.NoError(t, err)
Expand Down
2 changes: 1 addition & 1 deletion expression/aggregation/util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import (
)

func TestDistinct(t *testing.T) {
sc := &stmtctx.StatementContext{TimeZone: time.Local}
sc := stmtctx.NewStmtCtxWithTimeZone(time.Local)
dc := createDistinctChecker(sc)
testCases := []struct {
vals []interface{}
Expand Down
2 changes: 1 addition & 1 deletion expression/bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func (h *benchHelper) init() {
numRows := 4 * 1024

h.ctx = mock.NewContext()
h.ctx.GetSessionVars().StmtCtx.TimeZone = time.Local
h.ctx.GetSessionVars().StmtCtx.SetTimeZone(time.Local)
h.ctx.GetSessionVars().InitChunkSize = 32
h.ctx.GetSessionVars().MaxChunkSize = numRows

Expand Down
8 changes: 7 additions & 1 deletion expression/builtin_cast.go
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,13 @@ func (b *castJSONAsArrayFunctionSig) Clone() builtinFunc {
}

// fakeSctx is used to ignore the sql mode, `cast as array` should always return error if any.
var fakeSctx = &stmtctx.StatementContext{InInsertStmt: true}
var fakeSctx = newFakeSctx()

func newFakeSctx() *stmtctx.StatementContext {
sc := stmtctx.NewStmtCtx()
sc.InInsertStmt = true
return sc
}

func (b *castJSONAsArrayFunctionSig) evalJSON(row chunk.Row) (res types.BinaryJSON, isNull bool, err error) {
val, isNull, err := b.args[0].EvalJSON(b.ctx, row)
Expand Down
12 changes: 6 additions & 6 deletions expression/builtin_cast_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -297,12 +297,12 @@ func TestCastFuncSig(t *testing.T) {

sc := ctx.GetSessionVars().StmtCtx
originIgnoreTruncate := sc.IgnoreTruncate.Load()
originTZ := sc.TimeZone
originTZ := sc.TimeZone()
sc.IgnoreTruncate.Store(true)
sc.TimeZone = time.UTC
sc.SetTimeZone(time.UTC)
defer func() {
sc.IgnoreTruncate.Store(originIgnoreTruncate)
sc.TimeZone = originTZ
sc.SetTimeZone(originTZ)
}()
var sig builtinFunc

Expand Down Expand Up @@ -1292,10 +1292,10 @@ func TestWrapWithCastAsTypesClasses(t *testing.T) {
func TestWrapWithCastAsTime(t *testing.T) {
ctx := createContext(t)
sc := ctx.GetSessionVars().StmtCtx
save := sc.TimeZone
sc.TimeZone = time.UTC
save := sc.TimeZone()
sc.SetTimeZone(time.UTC)
defer func() {
sc.TimeZone = save
sc.SetTimeZone(save)
}()
cases := []struct {
expr Expression
Expand Down
Loading
Loading