From 464d5a9017d565a28f274ad3af0612965cf5824e Mon Sep 17 00:00:00 2001 From: Chao Wang Date: Thu, 12 Oct 2023 17:57:28 +0800 Subject: [PATCH 1/2] stmtctx: move time zone info in stmtctx.StatementContext into TypeCtx --- bindinfo/session_handle.go | 2 +- br/pkg/lightning/backend/kv/session.go | 2 +- br/pkg/lightning/backend/local/local_test.go | 2 +- .../backend/local/localhelper_test.go | 2 +- br/pkg/restore/merge_test.go | 2 +- br/pkg/task/backup.go | 4 +- ddl/backfilling_scheduler.go | 4 +- ddl/column.go | 2 +- ddl/export_test.go | 2 +- ddl/index.go | 2 +- ddl/partition.go | 2 +- ddl/reorg.go | 2 +- distsql/distsql_test.go | 8 +- distsql/request_builder_test.go | 6 +- distsql/select_result_test.go | 2 +- domain/domain_test.go | 2 +- executor/BUILD.bazel | 1 + executor/aggfuncs/func_count_test.go | 2 +- executor/analyze_col_v2.go | 2 +- executor/brie.go | 4 +- executor/coprocessor.go | 6 +- executor/executor.go | 17 +-- executor/infoschema_reader.go | 2 +- executor/simple.go | 2 +- executor/split_test.go | 2 +- executor/stmtsummary.go | 4 +- executor/test/admintest/admin_test.go | 2 +- executor/test/executor/executor_test.go | 4 +- expression/aggregation/util_test.go | 2 +- expression/bench_test.go | 2 +- expression/builtin_cast.go | 8 +- expression/builtin_cast_test.go | 12 +- expression/builtin_other_test.go | 2 +- expression/builtin_time.go | 2 +- expression/builtin_time_test.go | 8 +- expression/distsql_builtin.go | 2 +- expression/distsql_builtin_test.go | 8 +- expression/expr_to_pb_test.go | 42 +++--- expression/expression_test.go | 2 +- expression/main_test.go | 2 +- expression/scalar_function_test.go | 2 +- expression/util_test.go | 2 +- kv/key_test.go | 8 +- planner/cardinality/selectivity_test.go | 4 +- planner/core/memtable_predicate_extractor.go | 28 ++-- .../core/memtable_predicate_extractor_test.go | 4 +- planner/core/planbuilder.go | 2 +- server/extension.go | 2 +- .../handler/optimizor/statistics_handler.go | 2 +- server/handler/tests/http_handler_test.go | 6 +- server/handler/tikv_handler.go | 7 +- server/internal/column/column_test.go | 2 +- server/internal/dump/dump_test.go | 4 +- server/internal/parse/parse_test.go | 6 +- sessionctx/stmtctx/BUILD.bazel | 6 +- sessionctx/stmtctx/stmtctx.go | 61 ++++++++- sessionctx/stmtctx/stmtctx_test.go | 128 ++++++++++++++++-- sessionctx/variable/session.go | 6 +- statistics/cmsketch_test.go | 4 +- statistics/fmsketch_test.go | 6 +- statistics/handle/bootstrap.go | 4 +- .../handle/globalstats/global_stats_async.go | 2 +- .../handle/globalstats/topn_bench_test.go | 4 +- statistics/handle/handle_hist_test.go | 8 +- statistics/handle/storage/json.go | 4 +- statistics/handle/storage/read.go | 4 +- statistics/histogram.go | 2 +- statistics/main_test.go | 2 +- statistics/row_sampler.go | 2 +- statistics/sample.go | 2 +- statistics/sample_test.go | 2 +- statistics/scalar.go | 4 +- statistics/statistics_test.go | 2 +- store/mockstore/cluster_test.go | 2 +- store/mockstore/mockcopr/analyze.go | 3 +- store/mockstore/mockcopr/cop_handler_dag.go | 9 +- .../mockstore/unistore/cophandler/analyze.go | 8 +- .../unistore/cophandler/closure_exec.go | 4 +- .../unistore/cophandler/cop_handler.go | 9 +- .../unistore/cophandler/cop_handler_test.go | 2 +- store/mockstore/unistore/cophandler/mpp.go | 4 +- .../mockstore/unistore/cophandler/mpp_exec.go | 2 +- store/mockstore/unistore/tikv/mvcc.go | 2 +- table/column_test.go | 4 +- table/tables/mutation_checker_test.go | 4 +- tablecodec/tablecodec.go | 4 +- tablecodec/tablecodec_test.go | 34 ++--- testkit/testutil/handle.go | 2 +- testkit/testutil/require.go | 2 +- types/binary_literal_test.go | 2 +- types/compare_test.go | 4 +- types/context/context.go | 14 +- types/convert_test.go | 53 ++++---- types/datum.go | 8 +- types/datum_test.go | 36 +++-- types/time.go | 22 +-- types/time_test.go | 62 ++++----- util/chunk/chunk_test.go | 2 +- util/chunk/mutrow_test.go | 2 +- util/codec/codec.go | 4 +- util/codec/codec_test.go | 28 ++-- util/codec/collation_test.go | 6 +- util/dbutil/common.go | 2 +- util/keydecoder/keydecoder_test.go | 2 +- .../memoryusagealarm/memoryusagealarm_test.go | 2 +- util/misc_test.go | 6 +- util/mock/context.go | 2 +- util/nocopy/BUILD.bazel | 8 ++ util/nocopy/nocopy.go | 24 ++++ util/rowDecoder/decoder_test.go | 4 +- util/rowcodec/bench_test.go | 2 +- util/rowcodec/encoder.go | 4 +- util/rowcodec/rowcodec_test.go | 60 ++++---- util/stmtsummary/statement_summary_test.go | 28 ++-- util/stmtsummary/v2/record.go | 11 +- util/util_test.go | 2 +- 116 files changed, 602 insertions(+), 411 deletions(-) create mode 100644 util/nocopy/BUILD.bazel create mode 100644 util/nocopy/nocopy.go diff --git a/bindinfo/session_handle.go b/bindinfo/session_handle.go index f5d97ee671cc1..007b654974f42 100644 --- a/bindinfo/session_handle.go +++ b/bindinfo/session_handle.go @@ -62,7 +62,7 @@ func (h *SessionHandle) CreateBindRecord(sctx sessionctx.Context, record *BindRe return err } record.Db = strings.ToLower(record.Db) - now := types.NewTime(types.FromGoTime(time.Now().In(sctx.GetSessionVars().StmtCtx.TimeZone)), mysql.TypeTimestamp, 3) + now := types.NewTime(types.FromGoTime(time.Now().In(sctx.GetSessionVars().StmtCtx.TimeZone())), mysql.TypeTimestamp, 3) for i := range record.Bindings { record.Bindings[i].CreateTime = now record.Bindings[i].UpdateTime = now diff --git a/br/pkg/lightning/backend/kv/session.go b/br/pkg/lightning/backend/kv/session.go index d6d551251b7cd..d3b21b1dec561 100644 --- a/br/pkg/lightning/backend/kv/session.go +++ b/br/pkg/lightning/backend/kv/session.go @@ -311,7 +311,7 @@ func NewSession(options *encode.SessionOptions, logger log.Logger) *Session { } } } - vars.StmtCtx.TimeZone = vars.Location() + vars.StmtCtx.SetTimeZone(vars.Location()) if err := vars.SetSystemVar("timestamp", strconv.FormatInt(options.Timestamp, 10)); err != nil { logger.Warn("new session: failed to set timestamp", log.ShortError(err)) diff --git a/br/pkg/lightning/backend/local/local_test.go b/br/pkg/lightning/backend/local/local_test.go index 82b7740cc06f6..60c34c96c1e69 100644 --- a/br/pkg/lightning/backend/local/local_test.go +++ b/br/pkg/lightning/backend/local/local_test.go @@ -108,7 +108,7 @@ func TestNextKey(t *testing.T) { {types.NewStringDatum("test\255"), types.NewStringDatum("test\255\000")}, } - stmtCtx := new(stmtctx.StatementContext) + stmtCtx := stmtctx.NewStmtCtx() for _, datums := range testDatums { keyBytes, err := codec.EncodeKey(stmtCtx, nil, types.NewIntDatum(123), datums[0]) require.NoError(t, err) diff --git a/br/pkg/lightning/backend/local/localhelper_test.go b/br/pkg/lightning/backend/local/localhelper_test.go index b31482236cf0a..75ab382e2b0b7 100644 --- a/br/pkg/lightning/backend/local/localhelper_test.go +++ b/br/pkg/lightning/backend/local/localhelper_test.go @@ -783,7 +783,7 @@ func doTestBatchSplitByRangesWithClusteredIndex(t *testing.T, hook clientHook) { splitRegionBaseBackOffTime = oldSplitBackoffTime }() - stmtCtx := new(stmtctx.StatementContext) + stmtCtx := stmtctx.NewStmtCtx() tableID := int64(1) tableStartKey := tablecodec.EncodeTablePrefix(tableID) diff --git a/br/pkg/restore/merge_test.go b/br/pkg/restore/merge_test.go index 8bd2121dfe0c7..6b8f38f75cf98 100644 --- a/br/pkg/restore/merge_test.go +++ b/br/pkg/restore/merge_test.go @@ -46,7 +46,7 @@ func (fb *fileBulder) build(tableID, indexID, num, bytes, kv int) (files []*back if indexID != 0 { lowVal := types.NewIntDatum(fb.startKeyOffset - 10) highVal := types.NewIntDatum(fb.startKeyOffset) - sc := &stmtctx.StatementContext{TimeZone: time.UTC} + sc := stmtctx.NewStmtCtxWithTimeZone(time.UTC) lowValue, err := codec.EncodeKey(sc, nil, lowVal) if err != nil { panic(err) diff --git a/br/pkg/task/backup.go b/br/pkg/task/backup.go index 4d8f58e1ceb58..ec4dcc42ae348 100644 --- a/br/pkg/task/backup.go +++ b/br/pkg/task/backup.go @@ -767,9 +767,7 @@ func ParseTSString(ts string, tzCheck bool) (uint64, error) { } loc := time.Local - sc := &stmtctx.StatementContext{ - TimeZone: loc, - } + sc := stmtctx.NewStmtCtxWithTimeZone(loc) if tzCheck { tzIdx, _, _, _, _ := types.GetTimezone(ts) if tzIdx < 0 { diff --git a/ddl/backfilling_scheduler.go b/ddl/backfilling_scheduler.go index 9cbf25afd4b09..7f54f6650110a 100644 --- a/ddl/backfilling_scheduler.go +++ b/ddl/backfilling_scheduler.go @@ -149,9 +149,9 @@ func initSessCtx( tzLocation *model.TimeZoneLocation, ) error { // Unify the TimeZone settings in newContext. - if sessCtx.GetSessionVars().StmtCtx.TimeZone == nil { + if sessCtx.GetSessionVars().StmtCtx.TimeZone() == nil { tz := *time.UTC - sessCtx.GetSessionVars().StmtCtx.TimeZone = &tz + sessCtx.GetSessionVars().StmtCtx.SetTimeZone(&tz) } sessCtx.GetSessionVars().StmtCtx.IsDDLJobInQueue = true // Set the row encode format version. diff --git a/ddl/column.go b/ddl/column.go index e4e533310691a..596d200856f92 100644 --- a/ddl/column.go +++ b/ddl/column.go @@ -1314,7 +1314,7 @@ func (w *updateColumnWorker) fetchRowColVals(txn kv.Transaction, taskRange reorg } func (w *updateColumnWorker) getRowRecord(handle kv.Handle, recordKey []byte, rawRow []byte) error { - sysTZ := w.sessCtx.GetSessionVars().StmtCtx.TimeZone + sysTZ := w.sessCtx.GetSessionVars().StmtCtx.TimeZone() _, err := w.rowDecoder.DecodeTheExistedColumnMap(w.sessCtx, handle, rawRow, sysTZ, w.rowMap) if err != nil { return errors.Trace(dbterror.ErrCantDecodeRecord.GenWithStackByArgs("column", err)) diff --git a/ddl/export_test.go b/ddl/export_test.go index f5bfba95d7ac9..919a4b2c32049 100644 --- a/ddl/export_test.go +++ b/ddl/export_test.go @@ -63,6 +63,6 @@ func ConvertRowToHandleAndIndexDatum( c := copCtx.GetBase() idxData := extractDatumByOffsets(row, copCtx.IndexColumnOutputOffsets(idxID), c.ExprColumnInfos, nil) handleData := extractDatumByOffsets(row, c.HandleOutputOffsets, c.ExprColumnInfos, nil) - handle, err := buildHandle(handleData, c.TableInfo, c.PrimaryKeyInfo, &stmtctx.StatementContext{TimeZone: time.Local}) + handle, err := buildHandle(handleData, c.TableInfo, c.PrimaryKeyInfo, stmtctx.NewStmtCtxWithTimeZone(time.Local)) return handle, idxData, err } diff --git a/ddl/index.go b/ddl/index.go index 659a49a449906..c5f9ac7e131ce 100644 --- a/ddl/index.go +++ b/ddl/index.go @@ -1494,7 +1494,7 @@ func (w *baseIndexWorker) getNextKey(taskRange reorgBackfillTask, taskDone bool) } func (w *baseIndexWorker) updateRowDecoder(handle kv.Handle, rawRecord []byte) error { - sysZone := w.sessCtx.GetSessionVars().StmtCtx.TimeZone + sysZone := w.sessCtx.GetSessionVars().StmtCtx.TimeZone() _, err := w.rowDecoder.DecodeAndEvalRowWithMap(w.sessCtx, handle, rawRecord, sysZone, w.rowMap) return errors.Trace(err) } diff --git a/ddl/partition.go b/ddl/partition.go index d201dba4acda7..9d38a94745713 100644 --- a/ddl/partition.go +++ b/ddl/partition.go @@ -3181,7 +3181,7 @@ func (w *reorgPartitionWorker) fetchRowColVals(txn kv.Transaction, taskRange reo // taskDone means that the added handle is out of taskRange.endHandle. taskDone := false - sysTZ := w.sessCtx.GetSessionVars().StmtCtx.TimeZone + sysTZ := w.sessCtx.GetSessionVars().StmtCtx.TimeZone() tmpRow := make([]types.Datum, w.maxOffset+1) var lastAccessedHandle kv.Key diff --git a/ddl/reorg.go b/ddl/reorg.go index 997dcea0384e6..1978c81c04932 100644 --- a/ddl/reorg.go +++ b/ddl/reorg.go @@ -82,7 +82,7 @@ func newContext(store kv.Storage) sessionctx.Context { tz := *time.UTC c.GetSessionVars().TimeZone = &tz - c.GetSessionVars().StmtCtx.TimeZone = &tz + c.GetSessionVars().StmtCtx.SetTimeZone(&tz) return c } diff --git a/distsql/distsql_test.go b/distsql/distsql_test.go index 49ed70c2b672d..660b7f36f5cc7 100644 --- a/distsql/distsql_test.go +++ b/distsql/distsql_test.go @@ -300,10 +300,10 @@ func (r *mockResultSubset) RespTime() time.Duration { return 0 } func newMockSessionContext() sessionctx.Context { ctx := mock.NewContext() - ctx.GetSessionVars().StmtCtx = &stmtctx.StatementContext{ - MemTracker: memory.NewTracker(-1, -1), - DiskTracker: disk.NewTracker(-1, -1), - } + ctx.GetSessionVars().StmtCtx = stmtctx.NewStmtCtx() + ctx.GetSessionVars().StmtCtx.MemTracker = memory.NewTracker(-1, -1) + ctx.GetSessionVars().StmtCtx.DiskTracker = disk.NewTracker(-1, -1) + ctx.Store = &mock.Store{ Client: &mock.Client{ MockResponse: &mockResponse{ diff --git a/distsql/request_builder_test.go b/distsql/request_builder_test.go index ad51f661e5ee8..8db7e8598df11 100644 --- a/distsql/request_builder_test.go +++ b/distsql/request_builder_test.go @@ -189,7 +189,7 @@ func TestIndexRangesToKVRanges(t *testing.T) { }, } - actual, err := IndexRangesToKVRanges(new(stmtctx.StatementContext), 12, 15, ranges) + actual, err := IndexRangesToKVRanges(stmtctx.NewStmtCtx(), 12, 15, ranges) require.NoError(t, err) for i := range actual.FirstPartitionRange() { require.Equal(t, expect[i], actual.FirstPartitionRange()[i]) @@ -314,7 +314,7 @@ func TestRequestBuilder2(t *testing.T) { }, } - actual, err := (&RequestBuilder{}).SetIndexRanges(new(stmtctx.StatementContext), 12, 15, ranges). + actual, err := (&RequestBuilder{}).SetIndexRanges(stmtctx.NewStmtCtx(), 12, 15, ranges). SetDAGRequest(&tipb.DAGRequest{}). SetDesc(false). SetKeepOrder(false). @@ -659,7 +659,7 @@ func TestIndexRangesToKVRangesWithFbs(t *testing.T) { Collators: collate.GetBinaryCollatorSlice(1), }, } - actual, err := IndexRangesToKVRanges(new(stmtctx.StatementContext), 0, 0, ranges) + actual, err := IndexRangesToKVRanges(stmtctx.NewStmtCtx(), 0, 0, ranges) require.NoError(t, err) expect := []kv.KeyRange{ { diff --git a/distsql/select_result_test.go b/distsql/select_result_test.go index 4ec56a286e5ab..636f0d46467ae 100644 --- a/distsql/select_result_test.go +++ b/distsql/select_result_test.go @@ -29,7 +29,7 @@ import ( func TestUpdateCopRuntimeStats(t *testing.T) { ctx := mock.NewContext() - ctx.GetSessionVars().StmtCtx = new(stmtctx.StatementContext) + ctx.GetSessionVars().StmtCtx = stmtctx.NewStmtCtx() sr := selectResult{ctx: ctx, storeType: kv.TiKV} require.Nil(t, ctx.GetSessionVars().StmtCtx.RuntimeStatsColl) diff --git a/domain/domain_test.go b/domain/domain_test.go index 5ae256a277fb4..259e50804fa11 100644 --- a/domain/domain_test.go +++ b/domain/domain_test.go @@ -196,7 +196,7 @@ func TestStatWorkRecoverFromPanic(t *testing.T) { require.Equal(t, expiredTimeStamp, ts) // set expiredTimeStamp4PC to "2023-08-02 12:15:00" - ts, _ = types.ParseTimestamp(&stmtctx.StatementContext{TimeZone: time.UTC}, "2023-08-02 12:15:00") + ts, _ = types.ParseTimestamp(stmtctx.NewStmtCtxWithTimeZone(time.UTC), "2023-08-02 12:15:00") dom.SetExpiredTimeStamp4PC(ts) expiredTimeStamp = dom.ExpiredTimeStamp4PC() require.Equal(t, expiredTimeStamp, ts) diff --git a/executor/BUILD.bazel b/executor/BUILD.bazel index 1f8b8c4eaefa6..8b6764cddd89c 100644 --- a/executor/BUILD.bazel +++ b/executor/BUILD.bazel @@ -176,6 +176,7 @@ go_library( "//telemetry", "//tidb-binlog/node", "//types", + "//types/context", "//types/parser_driver", "//util", "//util/admin", diff --git a/executor/aggfuncs/func_count_test.go b/executor/aggfuncs/func_count_test.go index 6f8a684650a06..c55c07fdb5f6b 100644 --- a/executor/aggfuncs/func_count_test.go +++ b/executor/aggfuncs/func_count_test.go @@ -158,7 +158,7 @@ func TestMemCount(t *testing.T) { } func TestWriteTime(t *testing.T) { - tt, err := types.ParseDate(&(stmtctx.StatementContext{}), "2020-11-11") + tt, err := types.ParseDate(stmtctx.NewStmtCtx(), "2020-11-11") require.NoError(t, err) buf := make([]byte, 16) diff --git a/executor/analyze_col_v2.go b/executor/analyze_col_v2.go index ed08d5f718e6e..e84c738aac2d9 100644 --- a/executor/analyze_col_v2.go +++ b/executor/analyze_col_v2.go @@ -350,7 +350,7 @@ func (e *AnalyzeColumnsExecV2) buildSamplingStats( // If there's no virtual column or we meet error during eval virtual column, we fallback to normal decode otherwise. for _, sample := range rootRowCollector.Base().Samples { for i := range sample.Columns { - sample.Columns[i], err = tablecodec.DecodeColumnValue(sample.Columns[i].GetBytes(), &e.colsInfo[i].FieldType, sc.TimeZone) + sample.Columns[i], err = tablecodec.DecodeColumnValue(sample.Columns[i].GetBytes(), &e.colsInfo[i].FieldType, sc.TimeZone()) if err != nil { return 0, nil, nil, nil, nil, err } diff --git a/executor/brie.go b/executor/brie.go index 237c8b2466255..1ef1f3abadca9 100644 --- a/executor/brie.go +++ b/executor/brie.go @@ -235,12 +235,12 @@ func (bq *brieQueue) clearTask(sc *stmtctx.StatementContext) { } func (b *executorBuilder) parseTSString(ts string) (uint64, error) { - sc := &stmtctx.StatementContext{TimeZone: b.ctx.GetSessionVars().Location()} + sc := stmtctx.NewStmtCtxWithTimeZone(b.ctx.GetSessionVars().Location()) t, err := types.ParseTime(sc, ts, mysql.TypeTimestamp, types.MaxFsp, nil) if err != nil { return 0, err } - t1, err := t.GoTime(sc.TimeZone) + t1, err := t.GoTime(sc.TimeZone()) if err != nil { return 0, err } diff --git a/executor/coprocessor.go b/executor/coprocessor.go index 820922c6b8053..e5879334d7404 100644 --- a/executor/coprocessor.go +++ b/executor/coprocessor.go @@ -183,11 +183,13 @@ func (h *CoprocessorDAGHandler) buildDAGExecutor(req *coprocessor.Request) (exec stmtCtx := h.sctx.GetSessionVars().StmtCtx stmtCtx.SetFlagsFromPBFlag(dagReq.Flags) - stmtCtx.TimeZone, err = timeutil.ConstructTimeZone(dagReq.TimeZoneName, int(dagReq.TimeZoneOffset)) - h.sctx.GetSessionVars().TimeZone = stmtCtx.TimeZone + tz, err := timeutil.ConstructTimeZone(dagReq.TimeZoneName, int(dagReq.TimeZoneOffset)) if err != nil { return nil, errors.Trace(err) } + + stmtCtx.SetTimeZone(tz) + h.sctx.GetSessionVars().TimeZone = tz h.dagReq = dagReq is := h.sctx.GetInfoSchema().(infoschema.InfoSchema) // Build physical plan. diff --git a/executor/executor.go b/executor/executor.go index 727ecfae2b2c5..568cbc410c6f6 100644 --- a/executor/executor.go +++ b/executor/executor.go @@ -18,6 +18,7 @@ import ( "cmp" "context" "fmt" + typectx "github.com/pingcap/tidb/types/context" "math" "runtime/pprof" "slices" @@ -1938,12 +1939,11 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) { if vars.TxnCtx.CouldRetry || mysql.HasCursorExistsFlag(vars.Status) { // Must construct new statement context object, the retry history need context for every statement. // TODO: Maybe one day we can get rid of transaction retry, then this logic can be deleted. - sc = &stmtctx.StatementContext{} + sc = stmtctx.NewStmtCtx() } else { sc = vars.InitStatementContext() } - var typeConvFlags types.Flags - sc.TimeZone = vars.Location() + sc.SetTimeZone(vars.Location()) sc.TaskID = stmtctx.AllocateTaskID() sc.CTEStorageMap = map[int]*CTEStorages{} sc.IsStaleness = false @@ -2139,10 +2139,12 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) { sc.AllowInvalidDate = vars.SQLMode.HasAllowInvalidDatesMode() } - typeConvFlags = typeConvFlags. - WithSkipUTF8Check(vars.SkipUTF8Check). - WithSkipSACIICheck(vars.SkipASCIICheck). - WithSkipUTF8MB4Check(!globalConfig.Instance.CheckMb4ValueInUTF8.Load()) + sc.UpdateTypeFlags(func(flags typectx.Flags) typectx.Flags { + return flags. + WithSkipUTF8Check(vars.SkipUTF8Check). + WithSkipSACIICheck(vars.SkipASCIICheck). + WithSkipUTF8MB4Check(!globalConfig.Instance.CheckMb4ValueInUTF8.Load()) + }) vars.PlanCacheParams.Reset() if priority := mysql.PriorityEnum(atomic.LoadInt32(&variable.ForcePriority)); priority != mysql.NoPriority { @@ -2169,7 +2171,6 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) { sc.RuntimeStatsColl = execdetails.NewRuntimeStatsColl(reuseObj) } - sc.TypeConvContext = types.NewContext(typeConvFlags, sc.TimeZone, sc.AppendWarning) sc.TblInfo2UnionScan = make(map[*model.TableInfo]bool) errCount, warnCount := vars.StmtCtx.NumErrorWarnings() vars.SysErrorCount = errCount diff --git a/executor/infoschema_reader.go b/executor/infoschema_reader.go index 382ef32f928f2..1b2be9f2b9d63 100644 --- a/executor/infoschema_reader.go +++ b/executor/infoschema_reader.go @@ -1460,7 +1460,7 @@ func (e *memtableRetriever) setDataForProcessList(ctx sessionctx.Context) { continue } - rows := pi.ToRow(ctx.GetSessionVars().StmtCtx.TimeZone) + rows := pi.ToRow(ctx.GetSessionVars().StmtCtx.TimeZone()) record := types.MakeDatums(rows...) records = append(records, record) } diff --git a/executor/simple.go b/executor/simple.go index 13ad68e0a5bde..fa0ee887994f6 100644 --- a/executor/simple.go +++ b/executor/simple.go @@ -2807,7 +2807,7 @@ func (e *SimpleExec) executeAdminFlushPlanCache(s *ast.AdminStmt) error { e.Ctx().GetSessionVars().StmtCtx.AppendWarning(errors.New("The plan cache is disable. So there no need to flush the plan cache")) return nil } - now := types.NewTime(types.FromGoTime(time.Now().In(e.Ctx().GetSessionVars().StmtCtx.TimeZone)), mysql.TypeTimestamp, 3) + now := types.NewTime(types.FromGoTime(time.Now().In(e.Ctx().GetSessionVars().StmtCtx.TimeZone())), mysql.TypeTimestamp, 3) e.Ctx().GetSessionVars().LastUpdateTime4PC = now e.Ctx().GetSessionPlanCache().DeleteAll() if s.StatementScope == ast.StatementScopeInstance { diff --git a/executor/split_test.go b/executor/split_test.go index a557244cc407a..c2c58e8547c78 100644 --- a/executor/split_test.go +++ b/executor/split_test.go @@ -432,7 +432,7 @@ func TestClusterIndexSplitTable(t *testing.T) { }(minRegionStepValue) minRegionStepValue = 3 ctx := mock.NewContext() - sc := &stmtctx.StatementContext{TimeZone: time.Local} + sc := stmtctx.NewStmtCtxWithTimeZone(time.Local) e := &SplitTableRegionExec{ BaseExecutor: exec.NewBaseExecutor(ctx, nil, 0), tableInfo: tbInfo, diff --git a/executor/stmtsummary.go b/executor/stmtsummary.go index 272d863d38de4..e5a13c12c8294 100644 --- a/executor/stmtsummary.go +++ b/executor/stmtsummary.go @@ -142,7 +142,7 @@ func (e *stmtSummaryRetriever) initEvictedRowsReader(sctx sessionctx.Context) (* func (e *stmtSummaryRetriever) initSummaryRowsReader(sctx sessionctx.Context) (*rowsReader, error) { vars := sctx.GetSessionVars() user := vars.User - tz := vars.StmtCtx.TimeZone + tz := vars.StmtCtx.TimeZone() columns := e.columns priv := hasPriv(sctx, mysql.ProcessPriv) instanceAddr, err := clusterTableInstanceAddr(sctx, e.table.Name.O) @@ -241,7 +241,7 @@ func (r *stmtSummaryRetrieverV2) initEvictedRowsReader(sctx sessionctx.Context) func (r *stmtSummaryRetrieverV2) initSummaryRowsReader(ctx context.Context, sctx sessionctx.Context) (*rowsReader, error) { vars := sctx.GetSessionVars() user := vars.User - tz := vars.StmtCtx.TimeZone + tz := vars.StmtCtx.TimeZone() stmtSummary := r.stmtSummary columns := r.columns timeRanges := r.timeRanges diff --git a/executor/test/admintest/admin_test.go b/executor/test/admintest/admin_test.go index 686c9297d869b..1aff97da3f8f1 100644 --- a/executor/test/admintest/admin_test.go +++ b/executor/test/admintest/admin_test.go @@ -1272,7 +1272,7 @@ func TestCheckFailReport(t *testing.T) { txn, err := store.Begin() require.NoError(t, err) - encoded, err := codec.EncodeKey(new(stmtctx.StatementContext), nil, types.NewBytesDatum([]byte{1, 0, 1, 0, 0, 1, 1})) + encoded, err := codec.EncodeKey(stmtctx.NewStmtCtx(), nil, types.NewBytesDatum([]byte{1, 0, 1, 0, 0, 1, 1})) require.NoError(t, err) hd, err := kv.NewCommonHandle(encoded) require.NoError(t, err) diff --git a/executor/test/executor/executor_test.go b/executor/test/executor/executor_test.go index bf2517e89bb6d..3e1f5eaf14354 100644 --- a/executor/test/executor/executor_test.go +++ b/executor/test/executor/executor_test.go @@ -954,7 +954,7 @@ func TestCheckIndex(t *testing.T) { mockCtx := mock.NewContext() idx := tb.Indices()[0] - sc := &stmtctx.StatementContext{TimeZone: time.Local} + sc := stmtctx.NewStmtCtxWithTimeZone(time.Local) _, err = se.Execute(context.Background(), "admin check index t idx_inexistent") require.Error(t, err) @@ -1012,7 +1012,7 @@ func TestCheckIndex(t *testing.T) { func setColValue(t *testing.T, txn kv.Transaction, key kv.Key, v types.Datum) { row := []types.Datum{v, {}} colIDs := []int64{2, 3} - sc := &stmtctx.StatementContext{TimeZone: time.Local} + sc := stmtctx.NewStmtCtxWithTimeZone(time.Local) rd := rowcodec.Encoder{Enable: true} value, err := tablecodec.EncodeRow(sc, row, colIDs, nil, nil, &rd) require.NoError(t, err) diff --git a/expression/aggregation/util_test.go b/expression/aggregation/util_test.go index 3ac3720ffb714..2c476d19cd136 100644 --- a/expression/aggregation/util_test.go +++ b/expression/aggregation/util_test.go @@ -24,7 +24,7 @@ import ( ) func TestDistinct(t *testing.T) { - sc := &stmtctx.StatementContext{TimeZone: time.Local} + sc := stmtctx.NewStmtCtxWithTimeZone(time.Local) dc := createDistinctChecker(sc) testCases := []struct { vals []interface{} diff --git a/expression/bench_test.go b/expression/bench_test.go index cf8c3bef8a390..ecf23f07f604a 100644 --- a/expression/bench_test.go +++ b/expression/bench_test.go @@ -58,7 +58,7 @@ func (h *benchHelper) init() { numRows := 4 * 1024 h.ctx = mock.NewContext() - h.ctx.GetSessionVars().StmtCtx.TimeZone = time.Local + h.ctx.GetSessionVars().StmtCtx.SetTimeZone(time.Local) h.ctx.GetSessionVars().InitChunkSize = 32 h.ctx.GetSessionVars().MaxChunkSize = numRows diff --git a/expression/builtin_cast.go b/expression/builtin_cast.go index a4876714ebc0c..a9a9095250f0c 100644 --- a/expression/builtin_cast.go +++ b/expression/builtin_cast.go @@ -476,7 +476,13 @@ func (b *castJSONAsArrayFunctionSig) Clone() builtinFunc { } // fakeSctx is used to ignore the sql mode, `cast as array` should always return error if any. -var fakeSctx = &stmtctx.StatementContext{InInsertStmt: true} +var fakeSctx = newFakeSctx() + +func newFakeSctx() *stmtctx.StatementContext { + sc := stmtctx.NewStmtCtx() + sc.InInsertStmt = true + return sc +} func (b *castJSONAsArrayFunctionSig) evalJSON(row chunk.Row) (res types.BinaryJSON, isNull bool, err error) { val, isNull, err := b.args[0].EvalJSON(b.ctx, row) diff --git a/expression/builtin_cast_test.go b/expression/builtin_cast_test.go index 9a03ff4a80025..143030dd3435a 100644 --- a/expression/builtin_cast_test.go +++ b/expression/builtin_cast_test.go @@ -297,12 +297,12 @@ func TestCastFuncSig(t *testing.T) { sc := ctx.GetSessionVars().StmtCtx originIgnoreTruncate := sc.IgnoreTruncate.Load() - originTZ := sc.TimeZone + originTZ := sc.TimeZone() sc.IgnoreTruncate.Store(true) - sc.TimeZone = time.UTC + sc.SetTimeZone(time.UTC) defer func() { sc.IgnoreTruncate.Store(originIgnoreTruncate) - sc.TimeZone = originTZ + sc.SetTimeZone(originTZ) }() var sig builtinFunc @@ -1292,10 +1292,10 @@ func TestWrapWithCastAsTypesClasses(t *testing.T) { func TestWrapWithCastAsTime(t *testing.T) { ctx := createContext(t) sc := ctx.GetSessionVars().StmtCtx - save := sc.TimeZone - sc.TimeZone = time.UTC + save := sc.TimeZone() + sc.SetTimeZone(time.UTC) defer func() { - sc.TimeZone = save + sc.SetTimeZone(save) }() cases := []struct { expr Expression diff --git a/expression/builtin_other_test.go b/expression/builtin_other_test.go index dfa32ddf8a15e..66b201ba77e89 100644 --- a/expression/builtin_other_test.go +++ b/expression/builtin_other_test.go @@ -67,7 +67,7 @@ func TestBitCount(t *testing.T) { require.Nil(t, test.count) continue } - sc := new(stmtctx.StatementContext) + sc := stmtctx.NewStmtCtx() sc.IgnoreTruncate.Store(true) res, err := count.ToInt64(sc) require.NoError(t, err) diff --git a/expression/builtin_time.go b/expression/builtin_time.go index b672b7c1e8214..4f9e4d5ac1a6d 100644 --- a/expression/builtin_time.go +++ b/expression/builtin_time.go @@ -1736,7 +1736,7 @@ func evalFromUnixTime(ctx sessionctx.Context, fsp int, unixTimeStamp *types.MyDe } sc := ctx.GetSessionVars().StmtCtx - tmp := time.Unix(integralPart, fractionalPart).In(sc.TimeZone) + tmp := time.Unix(integralPart, fractionalPart).In(sc.TimeZone()) t, err := convertTimeToMysqlTime(tmp, fsp, types.ModeHalfUp) if err != nil { return res, true, err diff --git a/expression/builtin_time_test.go b/expression/builtin_time_test.go index 44121d729514b..034c9d213461c 100644 --- a/expression/builtin_time_test.go +++ b/expression/builtin_time_test.go @@ -1130,7 +1130,7 @@ func TestSubTimeSig(t *testing.T) { func TestSysDate(t *testing.T) { fc := funcs[ast.Sysdate] ctx := mock.NewContext() - ctx.GetSessionVars().StmtCtx.TimeZone = timeutil.SystemLocation() + ctx.GetSessionVars().StmtCtx.SetTimeZone(timeutil.SystemLocation()) timezones := []string{"1234", "0"} for _, timezone := range timezones { // sysdate() result is not affected by "timestamp" session variable. @@ -1254,10 +1254,10 @@ func TestFromUnixTime(t *testing.T) { {true, 32536771199, 32536771199.99999, "", "3001-01-18 23:59:59.999990"}, } sc := ctx.GetSessionVars().StmtCtx - originTZ := sc.TimeZone - sc.TimeZone = time.UTC + originTZ := sc.TimeZone() + sc.SetTimeZone(time.UTC) defer func() { - sc.TimeZone = originTZ + sc.SetTimeZone(originTZ) }() fc := funcs[ast.FromUnixTime] for _, c := range tbl { diff --git a/expression/distsql_builtin.go b/expression/distsql_builtin.go index 924d7841128c0..01a52a9dc5a8b 100644 --- a/expression/distsql_builtin.go +++ b/expression/distsql_builtin.go @@ -1153,7 +1153,7 @@ func PBToExpr(expr *tipb.Expr, tps []*types.FieldType, sc *stmtctx.StatementCont case tipb.ExprType_MysqlDuration: return convertDuration(expr.Val) case tipb.ExprType_MysqlTime: - return convertTime(expr.Val, expr.FieldType, sc.TimeZone) + return convertTime(expr.Val, expr.FieldType, sc.TimeZone()) case tipb.ExprType_MysqlJson: return convertJSON(expr.Val) case tipb.ExprType_MysqlEnum: diff --git a/expression/distsql_builtin_test.go b/expression/distsql_builtin_test.go index 1ed2df7048db3..4483602a0bcb8 100644 --- a/expression/distsql_builtin_test.go +++ b/expression/distsql_builtin_test.go @@ -30,7 +30,7 @@ import ( ) func TestPBToExpr(t *testing.T) { - sc := new(stmtctx.StatementContext) + sc := stmtctx.NewStmtCtx() fieldTps := make([]*types.FieldType, 1) ds := []types.Datum{types.NewIntDatum(1), types.NewUintDatum(1), types.NewFloat64Datum(1), types.NewDecimalDatum(newMyDecimal(t, "1")), types.NewDurationDatum(newDuration(time.Second))} @@ -778,7 +778,7 @@ func TestEval(t *testing.T) { types.NewIntDatum(1), }, } - sc := new(stmtctx.StatementContext) + sc := stmtctx.NewStmtCtx() for _, tt := range tests { expr, err := PBToExpr(tt.expr, fieldTps, sc) require.NoError(t, err) @@ -793,7 +793,7 @@ func TestEval(t *testing.T) { func TestPBToExprWithNewCollation(t *testing.T) { collate.SetNewCollationEnabledForTest(false) - sc := new(stmtctx.StatementContext) + sc := stmtctx.NewStmtCtx() fieldTps := make([]*types.FieldType, 1) cases := []struct { @@ -848,7 +848,7 @@ func TestPBToExprWithNewCollation(t *testing.T) { // Test convert various scalar functions. func TestPBToScalarFuncExpr(t *testing.T) { - sc := new(stmtctx.StatementContext) + sc := stmtctx.NewStmtCtx() fieldTps := make([]*types.FieldType, 1) exprs := []*tipb.Expr{ { diff --git a/expression/expr_to_pb_test.go b/expression/expr_to_pb_test.go index 381143c328e67..877508d962bc9 100644 --- a/expression/expr_to_pb_test.go +++ b/expression/expr_to_pb_test.go @@ -44,7 +44,7 @@ func genColumn(tp byte, id int64) *Column { func TestConstant2Pb(t *testing.T) { t.Skip("constant pb has changed") var constExprs []Expression - sc := new(stmtctx.StatementContext) + sc := stmtctx.NewStmtCtx() client := new(mock.Client) // can be transformed @@ -127,7 +127,7 @@ func TestConstant2Pb(t *testing.T) { func TestColumn2Pb(t *testing.T) { var colExprs []Expression - sc := new(stmtctx.StatementContext) + sc := stmtctx.NewStmtCtx() client := new(mock.Client) colExprs = append(colExprs, genColumn(mysql.TypeSet, 1)) @@ -219,7 +219,7 @@ func TestColumn2Pb(t *testing.T) { func TestCompareFunc2Pb(t *testing.T) { var compareExprs = make([]Expression, 0) - sc := new(stmtctx.StatementContext) + sc := stmtctx.NewStmtCtx() client := new(mock.Client) funcNames := []string{ast.LT, ast.LE, ast.GT, ast.GE, ast.EQ, ast.NE, ast.NullEQ} @@ -255,7 +255,7 @@ func TestCompareFunc2Pb(t *testing.T) { func TestLikeFunc2Pb(t *testing.T) { var likeFuncs []Expression - sc := new(stmtctx.StatementContext) + sc := stmtctx.NewStmtCtx() client := new(mock.Client) retTp := types.NewFieldType(mysql.TypeString) @@ -293,7 +293,7 @@ func TestLikeFunc2Pb(t *testing.T) { func TestArithmeticalFunc2Pb(t *testing.T) { var arithmeticalFuncs = make([]Expression, 0) - sc := new(stmtctx.StatementContext) + sc := stmtctx.NewStmtCtx() client := new(mock.Client) funcNames := []string{ast.Plus, ast.Minus, ast.Mul, ast.Div} @@ -339,7 +339,7 @@ func TestArithmeticalFunc2Pb(t *testing.T) { } func TestDateFunc2Pb(t *testing.T) { - sc := new(stmtctx.StatementContext) + sc := stmtctx.NewStmtCtx() client := new(mock.Client) fc, err := NewFunction( @@ -360,7 +360,7 @@ func TestDateFunc2Pb(t *testing.T) { func TestLogicalFunc2Pb(t *testing.T) { var logicalFuncs = make([]Expression, 0) - sc := new(stmtctx.StatementContext) + sc := stmtctx.NewStmtCtx() client := new(mock.Client) funcNames := []string{ast.LogicAnd, ast.LogicOr, ast.LogicXor, ast.UnaryNot} @@ -396,7 +396,7 @@ func TestLogicalFunc2Pb(t *testing.T) { func TestBitwiseFunc2Pb(t *testing.T) { var bitwiseFuncs = make([]Expression, 0) - sc := new(stmtctx.StatementContext) + sc := stmtctx.NewStmtCtx() client := new(mock.Client) funcNames := []string{ast.And, ast.Or, ast.Xor, ast.LeftShift, ast.RightShift, ast.BitNeg} @@ -434,7 +434,7 @@ func TestBitwiseFunc2Pb(t *testing.T) { func TestControlFunc2Pb(t *testing.T) { var controlFuncs = make([]Expression, 0) - sc := new(stmtctx.StatementContext) + sc := stmtctx.NewStmtCtx() client := new(mock.Client) funcNames := []string{ @@ -475,7 +475,7 @@ func TestControlFunc2Pb(t *testing.T) { func TestOtherFunc2Pb(t *testing.T) { var otherFuncs = make([]Expression, 0) - sc := new(stmtctx.StatementContext) + sc := stmtctx.NewStmtCtx() client := new(mock.Client) funcNames := []string{ast.Coalesce, ast.IsNull} @@ -504,7 +504,7 @@ func TestOtherFunc2Pb(t *testing.T) { } func TestExprPushDownToFlash(t *testing.T) { - sc := new(stmtctx.StatementContext) + sc := stmtctx.NewStmtCtx() client := new(mock.Client) exprs := make([]Expression, 0) @@ -1300,7 +1300,7 @@ func TestExprPushDownToFlash(t *testing.T) { func TestExprOnlyPushDownToFlash(t *testing.T) { t.Skip("Skip this unstable test temporarily and bring it back before 2021-07-26") - sc := new(stmtctx.StatementContext) + sc := stmtctx.NewStmtCtx() client := new(mock.Client) exprs := make([]Expression, 0) @@ -1357,7 +1357,7 @@ func TestExprOnlyPushDownToFlash(t *testing.T) { } func TestExprPushDownToTiKV(t *testing.T) { - sc := new(stmtctx.StatementContext) + sc := stmtctx.NewStmtCtx() client := new(mock.Client) exprs := make([]Expression, 0) @@ -1545,7 +1545,7 @@ func TestExprPushDownToTiKV(t *testing.T) { } func TestExprOnlyPushDownToTiKV(t *testing.T) { - sc := new(stmtctx.StatementContext) + sc := stmtctx.NewStmtCtx() client := new(mock.Client) function, err := NewFunction(mock.NewContext(), "uuid", types.NewFieldType(mysql.TypeLonglong)) @@ -1571,7 +1571,7 @@ func TestExprOnlyPushDownToTiKV(t *testing.T) { } func TestGroupByItem2Pb(t *testing.T) { - sc := new(stmtctx.StatementContext) + sc := stmtctx.NewStmtCtx() client := new(mock.Client) item := genColumn(mysql.TypeDouble, 0) @@ -1588,7 +1588,7 @@ func TestGroupByItem2Pb(t *testing.T) { } func TestSortByItem2Pb(t *testing.T) { - sc := new(stmtctx.StatementContext) + sc := stmtctx.NewStmtCtx() client := new(mock.Client) item := genColumn(mysql.TypeDouble, 0) @@ -1614,7 +1614,7 @@ func TestPushCollationDown(t *testing.T) { fc, err := NewFunction(mock.NewContext(), ast.EQ, types.NewFieldType(mysql.TypeUnspecified), genColumn(mysql.TypeVarchar, 0), genColumn(mysql.TypeVarchar, 1)) require.NoError(t, err) client := new(mock.Client) - sc := new(stmtctx.StatementContext) + sc := stmtctx.NewStmtCtx() tps := []*types.FieldType{types.NewFieldType(mysql.TypeVarchar), types.NewFieldType(mysql.TypeVarchar)} for _, coll := range []string{charset.CollationBin, charset.CollationLatin1, charset.CollationUTF8, charset.CollationUTF8MB4} { @@ -1636,7 +1636,7 @@ func columnCollation(c *Column, chs, coll string) *Column { func TestNewCollationsEnabled(t *testing.T) { var colExprs []Expression - sc := new(stmtctx.StatementContext) + sc := stmtctx.NewStmtCtx() client := new(mock.Client) colExprs = colExprs[:0] @@ -1675,7 +1675,7 @@ func TestNewCollationsEnabled(t *testing.T) { } func TestMetadata(t *testing.T) { - sc := new(stmtctx.StatementContext) + sc := stmtctx.NewStmtCtx() client := new(mock.Client) require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/expression/PushDownTestSwitcher", `return("all")`)) @@ -1717,7 +1717,7 @@ func TestMetadata(t *testing.T) { func TestPushDownSwitcher(t *testing.T) { var funcs = make([]Expression, 0) - sc := new(stmtctx.StatementContext) + sc := stmtctx.NewStmtCtx() client := new(mock.Client) cases := []struct { @@ -1798,6 +1798,6 @@ func TestPanicIfPbCodeUnspecified(t *testing.T) { defer func() { require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/expression/PanicIfPbCodeUnspecified")) }() - pc := PbConverter{client: new(mock.Client), sc: new(stmtctx.StatementContext)} + pc := PbConverter{client: new(mock.Client), sc: stmtctx.NewStmtCtx()} require.PanicsWithError(t, "unspecified PbCode: *expression.builtinBitAndSig", func() { pc.ExprToPB(fn) }) } diff --git a/expression/expression_test.go b/expression/expression_test.go index abc82e7330794..15e28f2e50466 100644 --- a/expression/expression_test.go +++ b/expression/expression_test.go @@ -106,7 +106,7 @@ func TestEvaluateExprWithNullNoChangeRetType(t *testing.T) { func TestConstant(t *testing.T) { ctx := createContext(t) - sc := &stmtctx.StatementContext{TimeZone: time.Local} + sc := stmtctx.NewStmtCtxWithTimeZone(time.Local) require.False(t, NewZero().IsCorrelated()) require.True(t, NewZero().ConstItem(sc)) require.True(t, NewZero().Decorrelate(nil).Equal(ctx, NewZero())) diff --git a/expression/main_test.go b/expression/main_test.go index a2ba67229587e..e161a2a3d3778 100644 --- a/expression/main_test.go +++ b/expression/main_test.go @@ -57,7 +57,7 @@ func TestMain(m *testing.M) { func createContext(t *testing.T) *mock.Context { ctx := mock.NewContext() - ctx.GetSessionVars().StmtCtx.TimeZone = time.Local + ctx.GetSessionVars().StmtCtx.SetTimeZone(time.Local) sc := ctx.GetSessionVars().StmtCtx sc.TruncateAsWarning = true require.NoError(t, ctx.GetSessionVars().SetSystemVar("max_allowed_packet", "67108864")) diff --git a/expression/scalar_function_test.go b/expression/scalar_function_test.go index 5b7952d5ea283..cc8cca60ec95f 100644 --- a/expression/scalar_function_test.go +++ b/expression/scalar_function_test.go @@ -99,7 +99,7 @@ func TestScalarFunction(t *testing.T) { UniqueID: 1, RetType: types.NewFieldType(mysql.TypeDouble), } - sc := &stmtctx.StatementContext{TimeZone: time.Local} + sc := stmtctx.NewStmtCtxWithTimeZone(time.Local) sf := newFunction(ast.LT, a, NewOne()) res, err := sf.MarshalJSON() require.NoError(t, err) diff --git a/expression/util_test.go b/expression/util_test.go index 8340afde7b326..f71f31a02a0cf 100644 --- a/expression/util_test.go +++ b/expression/util_test.go @@ -333,7 +333,7 @@ func TestFilterOutInPlace(t *testing.T) { func TestHashGroupKey(t *testing.T) { ctx := mock.NewContext() - sc := &stmtctx.StatementContext{TimeZone: time.Local} + sc := stmtctx.NewStmtCtxWithTimeZone(time.Local) eTypes := []types.EvalType{types.ETInt, types.ETReal, types.ETDecimal, types.ETString, types.ETTimestamp, types.ETDatetime, types.ETDuration} tNames := []string{"int", "real", "decimal", "string", "timestamp", "datetime", "duration"} for i := 0; i < len(tNames); i++ { diff --git a/kv/key_test.go b/kv/key_test.go index 550b0dd93bd0a..528b5f7ec2f30 100644 --- a/kv/key_test.go +++ b/kv/key_test.go @@ -33,7 +33,7 @@ import ( ) func TestPartialNext(t *testing.T) { - sc := &stmtctx.StatementContext{TimeZone: time.Local} + sc := stmtctx.NewStmtCtxWithTimeZone(time.Local) // keyA represents a multi column index. keyA, err := codec.EncodeValue(sc, nil, types.NewDatum("abc"), types.NewDatum("def")) require.NoError(t, err) @@ -149,7 +149,7 @@ func TestHandle(t *testing.T) { func TestPaddingHandle(t *testing.T) { dec := types.NewDecFromInt(1) - encoded, err := codec.EncodeKey(new(stmtctx.StatementContext), nil, types.NewDecimalDatum(dec)) + encoded, err := codec.EncodeKey(stmtctx.NewStmtCtx(), nil, types.NewDecimalDatum(dec)) assert.Nil(t, err) assert.Less(t, len(encoded), 9) @@ -229,7 +229,7 @@ func TestHandleMapWithPartialHandle(t *testing.T) { m.Set(ih, 3) dec := types.NewDecFromInt(1) - encoded, err := codec.EncodeKey(new(stmtctx.StatementContext), nil, types.NewDecimalDatum(dec)) + encoded, err := codec.EncodeKey(stmtctx.NewStmtCtx(), nil, types.NewDecimalDatum(dec)) assert.Nil(t, err) assert.Less(t, len(encoded), 9) @@ -285,7 +285,7 @@ func TestMemAwareHandleMapWithPartialHandle(t *testing.T) { m.Set(ih, 3) dec := types.NewDecFromInt(1) - encoded, err := codec.EncodeKey(new(stmtctx.StatementContext), nil, types.NewDecimalDatum(dec)) + encoded, err := codec.EncodeKey(stmtctx.NewStmtCtx(), nil, types.NewDecimalDatum(dec)) assert.Nil(t, err) assert.Less(t, len(encoded), 9) diff --git a/planner/cardinality/selectivity_test.go b/planner/cardinality/selectivity_test.go index a19a61c371441..dcbd96054ca95 100644 --- a/planner/cardinality/selectivity_test.go +++ b/planner/cardinality/selectivity_test.go @@ -634,7 +634,7 @@ func generateIntDatum(dimension, num int) ([]types.Datum, error) { ret[i] = types.NewIntDatum(int64(i)) } } else { - sc := &stmtctx.StatementContext{TimeZone: time.Local} + sc := stmtctx.NewStmtCtxWithTimeZone(time.Local) // In this way, we can guarantee the datum is in order. for i := 0; i < length; i++ { data := make([]types.Datum, dimension) @@ -1007,7 +1007,7 @@ func TestOrderingIdxSelectivityThreshold(t *testing.T) { store, dom := testkit.CreateMockStoreAndDomain(t) testKit := testkit.NewTestKit(t, store) h := dom.StatsHandle() - sc := &stmtctx.StatementContext{TimeZone: time.UTC} + sc := stmtctx.NewStmtCtxWithTimeZone(time.UTC) testKit.MustExec("use test") testKit.MustExec("drop table if exists t") diff --git a/planner/core/memtable_predicate_extractor.go b/planner/core/memtable_predicate_extractor.go index 16de60bb67792..470f652400d28 100644 --- a/planner/core/memtable_predicate_extractor.go +++ b/planner/core/memtable_predicate_extractor.go @@ -711,11 +711,11 @@ func (e *ClusterLogTableExtractor) explainInfo(p *PhysicalMemTable) string { st, et := e.StartTime, e.EndTime if st > 0 { st := time.UnixMilli(st) - fmt.Fprintf(r, "start_time:%v, ", st.In(p.SCtx().GetSessionVars().StmtCtx.TimeZone).Format(MetricTableTimeFormat)) + fmt.Fprintf(r, "start_time:%v, ", st.In(p.SCtx().GetSessionVars().StmtCtx.TimeZone()).Format(MetricTableTimeFormat)) } if et > 0 { et := time.UnixMilli(et) - fmt.Fprintf(r, "end_time:%v, ", et.In(p.SCtx().GetSessionVars().StmtCtx.TimeZone).Format(MetricTableTimeFormat)) + fmt.Fprintf(r, "end_time:%v, ", et.In(p.SCtx().GetSessionVars().StmtCtx.TimeZone()).Format(MetricTableTimeFormat)) } if len(e.NodeTypes) > 0 { fmt.Fprintf(r, "node_types:[%s], ", extractStringFromStringSet(e.NodeTypes)) @@ -822,7 +822,7 @@ func (e *HotRegionsHistoryTableExtractor) Extract( e.HotRegionTypes.Insert(HotRegionTypeWrite) } - remained, startTime, endTime := e.extractTimeRange(ctx, schema, names, remained, "update_time", ctx.GetSessionVars().StmtCtx.TimeZone) + remained, startTime, endTime := e.extractTimeRange(ctx, schema, names, remained, "update_time", ctx.GetSessionVars().StmtCtx.TimeZone()) // The time unit for search hot regions is millisecond startTime = startTime / int64(time.Millisecond) endTime = endTime / int64(time.Millisecond) @@ -846,11 +846,11 @@ func (e *HotRegionsHistoryTableExtractor) explainInfo(p *PhysicalMemTable) strin st, et := e.StartTime, e.EndTime if st > 0 { st := time.UnixMilli(st) - fmt.Fprintf(r, "start_time:%v, ", st.In(p.SCtx().GetSessionVars().StmtCtx.TimeZone).Format(time.DateTime)) + fmt.Fprintf(r, "start_time:%v, ", st.In(p.SCtx().GetSessionVars().StmtCtx.TimeZone()).Format(time.DateTime)) } if et > 0 { et := time.UnixMilli(et) - fmt.Fprintf(r, "end_time:%v, ", et.In(p.SCtx().GetSessionVars().StmtCtx.TimeZone).Format(time.DateTime)) + fmt.Fprintf(r, "end_time:%v, ", et.In(p.SCtx().GetSessionVars().StmtCtx.TimeZone()).Format(time.DateTime)) } if len(e.RegionIDs) > 0 { fmt.Fprintf(r, "region_ids:[%s], ", extractStringFromUint64Slice(e.RegionIDs)) @@ -914,7 +914,7 @@ func (e *MetricTableExtractor) Extract( } // Extract the `time` columns - remained, startTime, endTime := e.extractTimeRange(ctx, schema, names, remained, "time", ctx.GetSessionVars().StmtCtx.TimeZone) + remained, startTime, endTime := e.extractTimeRange(ctx, schema, names, remained, "time", ctx.GetSessionVars().StmtCtx.TimeZone()) e.StartTime, e.EndTime = e.getTimeRange(startTime, endTime) e.SkipRequest = e.StartTime.After(e.EndTime) if e.SkipRequest { @@ -963,8 +963,8 @@ func (e *MetricTableExtractor) explainInfo(p *PhysicalMemTable) string { step := time.Second * time.Duration(p.SCtx().GetSessionVars().MetricSchemaStep) return fmt.Sprintf("PromQL:%v, start_time:%v, end_time:%v, step:%v", promQL, - startTime.In(p.SCtx().GetSessionVars().StmtCtx.TimeZone).Format(MetricTableTimeFormat), - endTime.In(p.SCtx().GetSessionVars().StmtCtx.TimeZone).Format(MetricTableTimeFormat), + startTime.In(p.SCtx().GetSessionVars().StmtCtx.TimeZone()).Format(MetricTableTimeFormat), + endTime.In(p.SCtx().GetSessionVars().StmtCtx.TimeZone()).Format(MetricTableTimeFormat), step, ) } @@ -1181,7 +1181,7 @@ func (e *SlowQueryExtractor) Extract( names []*types.FieldName, predicates []expression.Expression, ) []expression.Expression { - remained, startTime, endTime := e.extractTimeRange(ctx, schema, names, predicates, "time", ctx.GetSessionVars().StmtCtx.TimeZone) + remained, startTime, endTime := e.extractTimeRange(ctx, schema, names, predicates, "time", ctx.GetSessionVars().StmtCtx.TimeZone()) e.setTimeRange(startTime, endTime) e.SkipRequest = e.Enable && e.TimeRanges[0].StartTime.After(e.TimeRanges[0].EndTime) if e.SkipRequest { @@ -1323,8 +1323,8 @@ func (e *SlowQueryExtractor) explainInfo(p *PhysicalMemTable) string { if !e.Enable { return fmt.Sprintf("only search in the current '%v' file", p.SCtx().GetSessionVars().SlowQueryFile) } - startTime := e.TimeRanges[0].StartTime.In(p.SCtx().GetSessionVars().StmtCtx.TimeZone) - endTime := e.TimeRanges[0].EndTime.In(p.SCtx().GetSessionVars().StmtCtx.TimeZone) + startTime := e.TimeRanges[0].StartTime.In(p.SCtx().GetSessionVars().StmtCtx.TimeZone()) + endTime := e.TimeRanges[0].EndTime.In(p.SCtx().GetSessionVars().StmtCtx.TimeZone()) return fmt.Sprintf("start_time:%v, end_time:%v", types.NewTime(types.FromGoTime(startTime), mysql.TypeDatetime, types.MaxFsp).String(), types.NewTime(types.FromGoTime(endTime), mysql.TypeDatetime, types.MaxFsp).String()) @@ -1451,8 +1451,8 @@ func (e *StatementsSummaryExtractor) explainInfo(p *PhysicalMemTable) string { } if e.CoarseTimeRange != nil && p.SCtx().GetSessionVars() != nil && p.SCtx().GetSessionVars().StmtCtx != nil { stmtCtx := p.SCtx().GetSessionVars().StmtCtx - startTime := e.CoarseTimeRange.StartTime.In(stmtCtx.TimeZone) - endTime := e.CoarseTimeRange.EndTime.In(stmtCtx.TimeZone) + startTime := e.CoarseTimeRange.StartTime.In(stmtCtx.TimeZone()) + endTime := e.CoarseTimeRange.EndTime.In(stmtCtx.TimeZone()) startTimeStr := types.NewTime(types.FromGoTime(startTime), mysql.TypeDatetime, types.MaxFsp).String() endTimeStr := types.NewTime(types.FromGoTime(endTime), mysql.TypeDatetime, types.MaxFsp).String() fmt.Fprintf(buf, "start_time: %v, end_time: %v, ", startTimeStr, endTimeStr) @@ -1471,7 +1471,7 @@ func (e *StatementsSummaryExtractor) findCoarseTimeRange( names []*types.FieldName, predicates []expression.Expression, ) *TimeRange { - tz := sctx.GetSessionVars().StmtCtx.TimeZone + tz := sctx.GetSessionVars().StmtCtx.TimeZone() _, _, endTime := e.extractTimeRange(sctx, schema, names, predicates, "summary_begin_time", tz) _, startTime, _ := e.extractTimeRange(sctx, schema, names, predicates, "summary_end_time", tz) return e.buildTimeRange(startTime, endTime) diff --git a/planner/core/memtable_predicate_extractor_test.go b/planner/core/memtable_predicate_extractor_test.go index 7ecc0b1143733..9e7d3963cfdc4 100644 --- a/planner/core/memtable_predicate_extractor_test.go +++ b/planner/core/memtable_predicate_extractor_test.go @@ -634,7 +634,7 @@ func TestMetricTableExtractor(t *testing.T) { quantiles: []float64{0}, }, } - se.GetSessionVars().StmtCtx.TimeZone = time.Local + se.GetSessionVars().StmtCtx.SetTimeZone(time.Local) for _, ca := range cases { logicalMemTable := getLogicalMemTable(t, dom, se, parser, ca.sql) require.NotNil(t, logicalMemTable.Extractor) @@ -1048,7 +1048,7 @@ func TestTiDBHotRegionsHistoryTableExtractor(t *testing.T) { se, err := session.CreateSession4Test(store) require.NoError(t, err) - se.GetSessionVars().StmtCtx.TimeZone = time.Local + se.GetSessionVars().StmtCtx.SetTimeZone(time.Local) var cases = []struct { sql string diff --git a/planner/core/planbuilder.go b/planner/core/planbuilder.go index 7de39a5f3929c..61e5840fc7073 100644 --- a/planner/core/planbuilder.go +++ b/planner/core/planbuilder.go @@ -5574,7 +5574,7 @@ func calcTSForPlanReplayer(sctx sessionctx.Context, tsExpr ast.ExprNode) uint64 // a date/time string into a TSO. // To achieve this, we need to set fields like StatementContext.IgnoreTruncate to false, and maybe it's better // not to modify and reuse the original StatementContext, so we use a temporary one here. - tmpStmtCtx := &stmtctx.StatementContext{TimeZone: sctx.GetSessionVars().Location()} + tmpStmtCtx := stmtctx.NewStmtCtxWithTimeZone(sctx.GetSessionVars().Location()) tso, err := tsVal.ConvertTo(tmpStmtCtx, tpLonglong) if err == nil { return tso.GetUint64() diff --git a/server/extension.go b/server/extension.go index 9228b1fd5bc32..367946c6d7bdb 100644 --- a/server/extension.go +++ b/server/extension.go @@ -97,7 +97,7 @@ func (cc *clientConn) onExtensionStmtEnd(node interface{}, stmtCtxValid bool, er if stmtCtxValid { info.sc = sessVars.StmtCtx } else { - info.sc = &stmtctx.StatementContext{} + info.sc = stmtctx.NewStmtCtx() } cc.extensions.OnStmtEvent(tp, info) } diff --git a/server/handler/optimizor/statistics_handler.go b/server/handler/optimizor/statistics_handler.go index d4e4381f1c3ab..f56056f510134 100644 --- a/server/handler/optimizor/statistics_handler.go +++ b/server/handler/optimizor/statistics_handler.go @@ -110,7 +110,7 @@ func (sh StatsHistoryHandler) ServeHTTP(w http.ResponseWriter, req *http.Request return } - se.GetSessionVars().StmtCtx.TimeZone = time.Local + se.GetSessionVars().StmtCtx.SetTimeZone(time.Local) t, err := types.ParseTime(se.GetSessionVars().StmtCtx, params[handler.Snapshot], mysql.TypeTimestamp, 6, nil) if err != nil { handler.WriteError(w, err) diff --git a/server/handler/tests/http_handler_test.go b/server/handler/tests/http_handler_test.go index f7c395a015571..bbc5973b3351b 100644 --- a/server/handler/tests/http_handler_test.go +++ b/server/handler/tests/http_handler_test.go @@ -102,7 +102,7 @@ func TestRegionIndexRange(t *testing.T) { } expectIndexValues = append(expectIndexValues, str) } - encodedValue, err := codec.EncodeKey(&stmtctx.StatementContext{TimeZone: time.Local}, nil, indexValues...) + encodedValue, err := codec.EncodeKey(stmtctx.NewStmtCtxWithTimeZone(time.Local), nil, indexValues...) require.NoError(t, err) startKey := tablecodec.EncodeIndexSeekKey(sTableID, sIndex, encodedValue) @@ -169,7 +169,7 @@ func TestRegionCommonHandleRange(t *testing.T) { } expectIndexValues = append(expectIndexValues, str) } - encodedValue, err := codec.EncodeKey(&stmtctx.StatementContext{TimeZone: time.Local}, nil, indexValues...) + encodedValue, err := codec.EncodeKey(stmtctx.NewStmtCtxWithTimeZone(time.Local), nil, indexValues...) require.NoError(t, err) startKey := tablecodec.EncodeRowKey(sTableID, encodedValue) @@ -701,7 +701,7 @@ func TestDecodeColumnValue(t *testing.T) { colIDs = append(colIDs, col.id) } rd := rowcodec.Encoder{Enable: true} - sc := &stmtctx.StatementContext{TimeZone: time.UTC} + sc := stmtctx.NewStmtCtxWithTimeZone(time.UTC) bs, err := tablecodec.EncodeRow(sc, row, colIDs, nil, nil, &rd) require.NoError(t, err) require.NotNil(t, bs) diff --git a/server/handler/tikv_handler.go b/server/handler/tikv_handler.go index e494e1dd10dca..0839c901d0a9d 100644 --- a/server/handler/tikv_handler.go +++ b/server/handler/tikv_handler.go @@ -90,8 +90,8 @@ func (t *TikvHandlerTool) GetHandle(tb table.PhysicalTable, params map[string]st for _, idxCol := range pkIdx.Columns { pkCols = append(pkCols, cols[idxCol.Offset]) } - sc := new(stmtctx.StatementContext) - sc.TimeZone = time.UTC + sc := stmtctx.NewStmtCtx() + sc.SetTimeZone(time.UTC) pkDts, err := t.formValue2DatumRow(sc, values, pkCols) if err != nil { return nil, errors.Trace(err) @@ -112,10 +112,9 @@ func (t *TikvHandlerTool) GetHandle(tb table.PhysicalTable, params map[string]st // GetMvccByIdxValue gets the mvcc by the index value. func (t *TikvHandlerTool) GetMvccByIdxValue(idx table.Index, values url.Values, idxCols []*model.ColumnInfo, handle kv.Handle) ([]*helper.MvccKV, error) { - sc := new(stmtctx.StatementContext) // HTTP request is not a database session, set timezone to UTC directly here. // See https://github.com/pingcap/tidb/blob/master/docs/tidb_http_api.md for more details. - sc.TimeZone = time.UTC + sc := stmtctx.NewStmtCtxWithTimeZone(time.UTC) idxRow, err := t.formValue2DatumRow(sc, values, idxCols) if err != nil { return nil, errors.Trace(err) diff --git a/server/internal/column/column_test.go b/server/internal/column/column_test.go index edd9355f1207b..b918b88f140ef 100644 --- a/server/internal/column/column_test.go +++ b/server/internal/column/column_test.go @@ -185,7 +185,7 @@ func TestDumpTextValue(t *testing.T) { sc.IgnoreZeroInDate = true losAngelesTz, err := time.LoadLocation("America/Los_Angeles") require.NoError(t, err) - sc.TimeZone = losAngelesTz + sc.SetTimeZone(losAngelesTz) time, err := types.ParseTime(sc, "2017-01-05 23:59:59.575601", mysql.TypeDatetime, 0, nil) require.NoError(t, err) diff --git a/server/internal/dump/dump_test.go b/server/internal/dump/dump_test.go index b75b41ba610ea..37d1e4e6909aa 100644 --- a/server/internal/dump/dump_test.go +++ b/server/internal/dump/dump_test.go @@ -24,13 +24,13 @@ import ( ) func TestDumpBinaryTime(t *testing.T) { - sc := &stmtctx.StatementContext{TimeZone: time.UTC} + sc := stmtctx.NewStmtCtxWithTimeZone(time.UTC) parsedTime, err := types.ParseTimestamp(sc, "0000-00-00 00:00:00.000000") require.NoError(t, err) d := BinaryDateTime(nil, parsedTime) require.Equal(t, []byte{0}, d) - parsedTime, err = types.ParseTimestamp(&stmtctx.StatementContext{TimeZone: time.Local}, "1991-05-01 01:01:01.100001") + parsedTime, err = types.ParseTimestamp(stmtctx.NewStmtCtxWithTimeZone(time.Local), "1991-05-01 01:01:01.100001") require.NoError(t, err) d = BinaryDateTime(nil, parsedTime) // 199 & 7 composed to uint16 1991 (litter-endian) diff --git a/server/internal/parse/parse_test.go b/server/internal/parse/parse_test.go index 3b9a2d12e857a..01a1a7efc48b1 100644 --- a/server/internal/parse/parse_test.go +++ b/server/internal/parse/parse_test.go @@ -221,7 +221,7 @@ func TestParseExecArgs(t *testing.T) { }, } for _, tt := range tests { - err := ExecArgs(&stmtctx.StatementContext{}, tt.args.args, tt.args.boundParams, tt.args.nullBitmap, tt.args.paramTypes, tt.args.paramValues, nil) + err := ExecArgs(stmtctx.NewStmtCtx(), tt.args.args, tt.args.boundParams, tt.args.nullBitmap, tt.args.paramTypes, tt.args.paramValues, nil) require.Truef(t, terror.ErrorEqual(err, tt.err), "err %v", err) if err == nil { require.Equal(t, tt.expect, tt.args.args[0].(*expression.Constant).Value.GetValue()) @@ -231,7 +231,7 @@ func TestParseExecArgs(t *testing.T) { func TestParseExecArgsAndEncode(t *testing.T) { dt := expression.Args2Expressions4Test(1) - err := ExecArgs(&stmtctx.StatementContext{}, + err := ExecArgs(stmtctx.NewStmtCtx(), dt, [][]byte{nil}, []byte{0x0}, @@ -241,7 +241,7 @@ func TestParseExecArgsAndEncode(t *testing.T) { require.NoError(t, err) require.Equal(t, "测试", dt[0].(*expression.Constant).Value.GetValue()) - err = ExecArgs(&stmtctx.StatementContext{}, + err = ExecArgs(stmtctx.NewStmtCtx(), dt, [][]byte{{178, 226, 202, 212}}, []byte{0x0}, diff --git a/sessionctx/stmtctx/BUILD.bazel b/sessionctx/stmtctx/BUILD.bazel index 6134a44ab5f9f..e718ddc260834 100644 --- a/sessionctx/stmtctx/BUILD.bazel +++ b/sessionctx/stmtctx/BUILD.bazel @@ -16,7 +16,9 @@ go_library( "//types/context", "//util/disk", "//util/execdetails", + "//util/intest", "//util/memory", + "//util/nocopy", "//util/resourcegrouptag", "//util/topsql/stmtstats", "//util/tracing", @@ -38,12 +40,14 @@ go_test( ], embed = [":stmtctx"], flaky = True, - shard_count = 6, + shard_count = 10, deps = [ "//kv", "//sessionctx/variable", "//testkit", "//testkit/testsetup", + "//types", + "//types/context", "//util/execdetails", "@com_github_pingcap_errors//:errors", "@com_github_stretchr_testify//require", diff --git a/sessionctx/stmtctx/stmtctx.go b/sessionctx/stmtctx/stmtctx.go index 347657958fa4f..a2fd9df291da8 100644 --- a/sessionctx/stmtctx/stmtctx.go +++ b/sessionctx/stmtctx/stmtctx.go @@ -38,7 +38,9 @@ import ( typectx "github.com/pingcap/tidb/types/context" "github.com/pingcap/tidb/util/disk" "github.com/pingcap/tidb/util/execdetails" + "github.com/pingcap/tidb/util/intest" "github.com/pingcap/tidb/util/memory" + "github.com/pingcap/tidb/util/nocopy" "github.com/pingcap/tidb/util/resourcegrouptag" "github.com/pingcap/tidb/util/topsql/stmtstats" "github.com/pingcap/tidb/util/tracing" @@ -147,12 +149,16 @@ func (rf *ReferenceCount) UnFreeze() { // StatementContext contains variables for a statement. // It should be reset before executing a statement. type StatementContext struct { + // NoCopy indicates that this struct cannot be copied because + // copying this object will make the new TypeCtx referring a wrong `AppendWarnings` func. + _ nocopy.NoCopy + + // TypeCtx is used to indicate how make the type conversation. + TypeCtx typectx.Context + // Set the following variables before execution StmtHints - // TypeConvContext is used to indicate how make the type conversation. - TypeConvContext typectx.Context - // IsDDLJobInQueue is used to mark whether the DDL job is put into the queue. // If IsDDLJobInQueue is true, it means the DDL job is in the queue of storage, and it can be handled by the DDL worker. IsDDLJobInQueue bool @@ -243,7 +249,6 @@ type StatementContext struct { MaxRowID int64 // Copied from SessionVars.TimeZone. - TimeZone *time.Location Priority mysql.PriorityEnum NotFillCache bool MemTracker *memory.Tracker @@ -420,6 +425,54 @@ type StatementContext struct { } } +// NewStmtCtx creates a new statement context +func NewStmtCtx() *StatementContext { + sc := &StatementContext{} + sc.TypeCtx = typectx.NewContext(typectx.StrictFlags, time.UTC, sc.AppendWarning) + return sc +} + +// NewStmtCtxWithTimeZone creates a new StatementContext with the given timezone +func NewStmtCtxWithTimeZone(tz *time.Location) *StatementContext { + intest.Assert(tz) + sc := &StatementContext{} + sc.TypeCtx = typectx.NewContext(typectx.StrictFlags, tz, sc.AppendWarning) + return sc +} + +// Reset resets a statement context +func (sc *StatementContext) Reset() { + *sc = StatementContext{ + TypeCtx: typectx.NewContext(typectx.StrictFlags, time.UTC, sc.AppendWarning), + } +} + +// TimeZone returns the timezone of the type context +func (sc *StatementContext) TimeZone() *time.Location { + return sc.TypeCtx.Location() +} + +// SetTimeZone sets the timezone +func (sc *StatementContext) SetTimeZone(tz *time.Location) { + intest.Assert(tz) + sc.TypeCtx = sc.TypeCtx.WithLocation(tz) +} + +// TypeFlags returns the type flags +func (sc *StatementContext) TypeFlags() typectx.Flags { + return sc.TypeCtx.Flags() +} + +// SetTypeFlags sets the type flags +func (sc *StatementContext) SetTypeFlags(flags typectx.Flags) { + sc.TypeCtx = sc.TypeCtx.WithFlags(flags) +} + +func (sc *StatementContext) UpdateTypeFlags(fn func(typectx.Flags) typectx.Flags) { + flags := fn(sc.TypeCtx.Flags()) + sc.TypeCtx = sc.TypeCtx.WithFlags(flags) +} + // StmtHints are SessionVars related sql hints. type StmtHints struct { // Hint Information diff --git a/sessionctx/stmtctx/stmtctx_test.go b/sessionctx/stmtctx/stmtctx_test.go index 9a3951278befc..dc45f464c27ee 100644 --- a/sessionctx/stmtctx/stmtctx_test.go +++ b/sessionctx/stmtctx/stmtctx_test.go @@ -18,6 +18,7 @@ import ( "context" "encoding/json" "fmt" + typectx "github.com/pingcap/tidb/types/context" "math/rand" "reflect" "sort" @@ -29,6 +30,7 @@ import ( "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/testkit" + "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/execdetails" "github.com/stretchr/testify/require" "github.com/tikv/client-go/v2/util" @@ -36,7 +38,7 @@ import ( ) func TestCopTasksDetails(t *testing.T) { - ctx := new(stmtctx.StatementContext) + ctx := stmtctx.NewStmtCtx() backoffs := []string{"tikvRPC", "pdRPC", "regionMiss"} for i := 0; i < 100; i++ { d := &execdetails.ExecDetails{ @@ -79,23 +81,39 @@ func TestCopTasksDetails(t *testing.T) { } func TestStatementContextPushDownFLags(t *testing.T) { + newStmtCtx := func(fn func(*stmtctx.StatementContext)) *stmtctx.StatementContext { + sc := stmtctx.NewStmtCtx() + fn(sc) + return sc + } + testCases := []struct { in *stmtctx.StatementContext out uint64 }{ - {&stmtctx.StatementContext{InInsertStmt: true}, 8}, - {&stmtctx.StatementContext{InUpdateStmt: true}, 16}, - {&stmtctx.StatementContext{InDeleteStmt: true}, 16}, - {&stmtctx.StatementContext{InSelectStmt: true}, 32}, - {&stmtctx.StatementContext{IgnoreTruncate: *atomic.NewBool(true)}, 1}, - {&stmtctx.StatementContext{TruncateAsWarning: true}, 2}, - {&stmtctx.StatementContext{OverflowAsWarning: true}, 64}, - {&stmtctx.StatementContext{IgnoreZeroInDate: true}, 128}, - {&stmtctx.StatementContext{DividedByZeroAsWarning: true}, 256}, - {&stmtctx.StatementContext{InLoadDataStmt: true}, 1024}, - {&stmtctx.StatementContext{InSelectStmt: true, TruncateAsWarning: true}, 34}, - {&stmtctx.StatementContext{DividedByZeroAsWarning: true, IgnoreTruncate: *atomic.NewBool(true)}, 257}, - {&stmtctx.StatementContext{InUpdateStmt: true, IgnoreZeroInDate: true, InLoadDataStmt: true}, 1168}, + {newStmtCtx(func(sc *stmtctx.StatementContext) { sc.InInsertStmt = true }), 8}, + {newStmtCtx(func(sc *stmtctx.StatementContext) { sc.InUpdateStmt = true }), 16}, + {newStmtCtx(func(sc *stmtctx.StatementContext) { sc.InDeleteStmt = true }), 16}, + {newStmtCtx(func(sc *stmtctx.StatementContext) { sc.InSelectStmt = true }), 32}, + {newStmtCtx(func(sc *stmtctx.StatementContext) { sc.IgnoreTruncate = *atomic.NewBool(true) }), 1}, + {newStmtCtx(func(sc *stmtctx.StatementContext) { sc.TruncateAsWarning = true }), 2}, + {newStmtCtx(func(sc *stmtctx.StatementContext) { sc.OverflowAsWarning = true }), 64}, + {newStmtCtx(func(sc *stmtctx.StatementContext) { sc.IgnoreZeroInDate = true }), 128}, + {newStmtCtx(func(sc *stmtctx.StatementContext) { sc.DividedByZeroAsWarning = true }), 256}, + {newStmtCtx(func(sc *stmtctx.StatementContext) { sc.InLoadDataStmt = true }), 1024}, + {newStmtCtx(func(sc *stmtctx.StatementContext) { + sc.InSelectStmt = true + sc.TruncateAsWarning = true + }), 34}, + {newStmtCtx(func(sc *stmtctx.StatementContext) { + sc.DividedByZeroAsWarning = true + sc.IgnoreTruncate = *atomic.NewBool(true) + }), 257}, + {newStmtCtx(func(sc *stmtctx.StatementContext) { + sc.InUpdateStmt = true + sc.IgnoreZeroInDate = true + sc.InLoadDataStmt = true + }), 1168}, } for _, tt := range testCases { got := tt.in.PushDownFlags() @@ -222,7 +240,7 @@ func TestApproxRuntimeInfo(t *testing.T) { details[rand.Intn(n)].BackoffSleep[backoff] = time.Millisecond * 100 * time.Duration(valRange) } - ctx := new(stmtctx.StatementContext) + ctx := stmtctx.NewStmtCtx() for i := 0; i < n; i++ { ctx.MergeExecDetails(details[i], nil) } @@ -296,3 +314,83 @@ func TestStmtHintsClone(t *testing.T) { } require.Equal(t, hints, *hints.Clone()) } + +func TestNewStmtCtx(t *testing.T) { + sc := stmtctx.NewStmtCtx() + require.Equal(t, types.StrictFlags, sc.TypeCtx.Flags()) + require.Same(t, time.UTC, sc.TypeCtx.Location()) + require.Same(t, time.UTC, sc.TimeZone()) + sc.TypeCtx.AppendWarning(errors.New("err1")) + warnings := sc.GetWarnings() + require.Equal(t, 1, len(warnings)) + require.Equal(t, stmtctx.WarnLevelWarning, warnings[0].Level) + require.Equal(t, "err1", warnings[0].Err.Error()) + + tz := time.FixedZone("UTC+1", 2*60*60) + sc = stmtctx.NewStmtCtxWithTimeZone(tz) + require.Equal(t, types.StrictFlags, sc.TypeCtx.Flags()) + require.Same(t, tz, sc.TypeCtx.Location()) + require.Same(t, tz, sc.TimeZone()) + sc.TypeCtx.AppendWarning(errors.New("err2")) + warnings = sc.GetWarnings() + require.Equal(t, 1, len(warnings)) + require.Equal(t, stmtctx.WarnLevelWarning, warnings[0].Level) + require.Equal(t, "err2", warnings[0].Err.Error()) +} + +func TestSetStmtCtxTimeZone(t *testing.T) { + sc := stmtctx.NewStmtCtx() + require.Same(t, time.UTC, sc.TypeCtx.Location()) + tz := time.FixedZone("UTC+1", 2*60*60) + sc.SetTimeZone(tz) + require.Same(t, tz, sc.TypeCtx.Location()) +} + +func TestSetStmtCtxTypeFlags(t *testing.T) { + sc := stmtctx.NewStmtCtx() + require.Equal(t, types.StrictFlags, sc.TypeCtx.Flags()) + + sc.SetTypeFlags(typectx.FlagClipNegativeToZero | typectx.FlagSkipASCIICheck) + require.Equal(t, typectx.FlagClipNegativeToZero|typectx.FlagSkipASCIICheck, sc.TypeFlags()) + require.Equal(t, sc.TypeFlags(), sc.TypeCtx.Flags()) + + sc.SetTypeFlags(typectx.FlagSkipASCIICheck | typectx.FlagSkipUTF8Check | typectx.FlagInvalidDateAsWarning) + require.Equal(t, typectx.FlagSkipASCIICheck|typectx.FlagSkipUTF8Check|typectx.FlagInvalidDateAsWarning, sc.TypeFlags()) + require.Equal(t, sc.TypeFlags(), sc.TypeCtx.Flags()) + + sc.UpdateTypeFlags(func(flags typectx.Flags) typectx.Flags { + return (flags | typectx.FlagSkipUTF8Check | typectx.FlagClipNegativeToZero) &^ typectx.FlagSkipASCIICheck + }) + require.Equal(t, typectx.FlagSkipUTF8Check|typectx.FlagClipNegativeToZero|typectx.FlagInvalidDateAsWarning, sc.TypeFlags()) + require.Equal(t, sc.TypeFlags(), sc.TypeCtx.Flags()) +} + +func TestResetStmtCtx(t *testing.T) { + sc := stmtctx.NewStmtCtx() + require.Equal(t, types.StrictFlags, sc.TypeCtx.Flags()) + + tz := time.FixedZone("UTC+1", 2*60*60) + sc.SetTimeZone(tz) + sc.SetTypeFlags(typectx.FlagClipNegativeToZero | typectx.FlagSkipASCIICheck) + sc.AppendWarning(errors.New("err1")) + sc.InRestrictedSQL = true + sc.StmtType = "Insert" + + require.Same(t, tz, sc.TimeZone()) + require.Equal(t, typectx.FlagClipNegativeToZero|typectx.FlagSkipASCIICheck, sc.TypeFlags()) + require.Equal(t, 1, len(sc.GetWarnings())) + + sc.Reset() + require.Same(t, time.UTC, sc.TimeZone()) + require.Same(t, time.UTC, sc.TypeCtx.Location()) + require.Equal(t, types.StrictFlags, sc.TypeFlags()) + require.Equal(t, types.StrictFlags, sc.TypeCtx.Flags()) + require.False(t, sc.InRestrictedSQL) + require.Empty(t, sc.StmtType) + require.Equal(t, 0, len(sc.GetWarnings())) + sc.AppendWarning(errors.New("err2")) + warnings := sc.GetWarnings() + require.Equal(t, 1, len(warnings)) + require.Equal(t, stmtctx.WarnLevelWarning, warnings[0].Level) + require.Equal(t, "err2", warnings[0].Err.Error()) +} diff --git a/sessionctx/variable/session.go b/sessionctx/variable/session.go index 398f7dc52c6f2..43f7d1787fc3a 100644 --- a/sessionctx/variable/session.go +++ b/sessionctx/variable/session.go @@ -1677,10 +1677,10 @@ func (s *SessionVars) InitStatementContext() *stmtctx.StatementContext { sc = &s.cachedStmtCtx[1] } if s.RefCountOfStmtCtx.TryFreeze() { - *sc = stmtctx.StatementContext{} + sc.Reset() s.RefCountOfStmtCtx.UnFreeze() } else { - sc = &stmtctx.StatementContext{} + sc = stmtctx.NewStmtCtx() } return sc } @@ -1930,7 +1930,7 @@ func NewSessionVars(hctx HookContext) *SessionVars { AutoIncrementIncrement: DefAutoIncrementIncrement, AutoIncrementOffset: DefAutoIncrementOffset, Status: mysql.ServerStatusAutocommit, - StmtCtx: new(stmtctx.StatementContext), + StmtCtx: stmtctx.NewStmtCtx(), AllowAggPushDown: false, AllowCartesianBCJ: DefOptCartesianBCJ, MPPOuterJoinFixedBuildSide: DefOptMPPOuterJoinFixedBuildSide, diff --git a/statistics/cmsketch_test.go b/statistics/cmsketch_test.go index 84b77e1822dd5..c010c6de6cb44 100644 --- a/statistics/cmsketch_test.go +++ b/statistics/cmsketch_test.go @@ -258,7 +258,7 @@ func TestCMSketchCodingTopN(t *testing.T) { func TestMergePartTopN2GlobalTopNWithoutHists(t *testing.T) { loc := time.UTC - sc := &stmtctx.StatementContext{TimeZone: loc} + sc := stmtctx.NewStmtCtxWithTimeZone(loc) version := 1 isKilled := uint32(0) @@ -303,7 +303,7 @@ func TestSortTopnMeta(t *testing.T) { func TestMergePartTopN2GlobalTopNWithHists(t *testing.T) { loc := time.UTC - sc := &stmtctx.StatementContext{TimeZone: loc} + sc := stmtctx.NewStmtCtxWithTimeZone(loc) version := 1 isKilled := uint32(0) diff --git a/statistics/fmsketch_test.go b/statistics/fmsketch_test.go index c8f9caca7b8cd..d714c25b6cbda 100644 --- a/statistics/fmsketch_test.go +++ b/statistics/fmsketch_test.go @@ -48,7 +48,7 @@ func buildFMSketch(sc *stmtctx.StatementContext, values []types.Datum, maxSize i func SubTestSketch() func(*testing.T) { return func(t *testing.T) { s := createTestStatisticsSamples(t) - sc := &stmtctx.StatementContext{TimeZone: time.Local} + sc := stmtctx.NewStmtCtxWithTimeZone(time.Local) maxSize := 1000 sampleSketch, ndv, err := buildFMSketch(sc, extractSampleItemsDatums(s.samples), maxSize) require.NoError(t, err) @@ -79,7 +79,7 @@ func SubTestSketch() func(*testing.T) { func SubTestSketchProtoConversion() func(*testing.T) { return func(t *testing.T) { s := createTestStatisticsSamples(t) - sc := &stmtctx.StatementContext{TimeZone: time.Local} + sc := stmtctx.NewStmtCtxWithTimeZone(time.Local) maxSize := 1000 sampleSketch, ndv, err := buildFMSketch(sc, extractSampleItemsDatums(s.samples), maxSize) require.NoError(t, err) @@ -98,7 +98,7 @@ func SubTestSketchProtoConversion() func(*testing.T) { func SubTestFMSketchCoding() func(*testing.T) { return func(t *testing.T) { s := createTestStatisticsSamples(t) - sc := &stmtctx.StatementContext{TimeZone: time.Local} + sc := stmtctx.NewStmtCtxWithTimeZone(time.Local) maxSize := 1000 sampleSketch, ndv, err := buildFMSketch(sc, extractSampleItemsDatums(s.samples), maxSize) require.NoError(t, err) diff --git a/statistics/handle/bootstrap.go b/statistics/handle/bootstrap.go index 85120ef2f9bdb..c32d4f97d8cb0 100644 --- a/statistics/handle/bootstrap.go +++ b/statistics/handle/bootstrap.go @@ -405,7 +405,9 @@ func (*Handle) initStatsBuckets4Chunk(cache *cache.StatsCache, iter *chunk.Itera d := types.NewBytesDatum(row.GetBytes(5)) // Setting TimeZone to time.UTC aligns with HistogramFromStorage and can fix #41938. However, #41985 still exist. // TODO: do the correct time zone conversion for timestamp-type columns' upper/lower bounds. - sc := &stmtctx.StatementContext{TimeZone: time.UTC, AllowInvalidDate: true, IgnoreZeroInDate: true} + sc := stmtctx.NewStmtCtxWithTimeZone(time.UTC) + sc.AllowInvalidDate = true + sc.IgnoreZeroInDate = true var err error lower, err = d.ConvertTo(sc, &column.Info.FieldType) if err != nil { diff --git a/statistics/handle/globalstats/global_stats_async.go b/statistics/handle/globalstats/global_stats_async.go index bf4e1cfed57f5..8e616ea444416 100644 --- a/statistics/handle/globalstats/global_stats_async.go +++ b/statistics/handle/globalstats/global_stats_async.go @@ -296,7 +296,7 @@ func (a *AsyncMergePartitionStats2GlobalStats) MergePartitionStats2GlobalStats( isIndex bool, ) error { a.skipMissingPartitionStats = sctx.GetSessionVars().SkipMissingPartitionStats - tz := sctx.GetSessionVars().StmtCtx.TimeZone + tz := sctx.GetSessionVars().StmtCtx.TimeZone() analyzeVersion := sctx.GetSessionVars().AnalyzeVersion stmtCtx := sctx.GetSessionVars().StmtCtx return a.callWithSCtxFunc( diff --git a/statistics/handle/globalstats/topn_bench_test.go b/statistics/handle/globalstats/topn_bench_test.go index 0269d260db680..25c73095a3019 100644 --- a/statistics/handle/globalstats/topn_bench_test.go +++ b/statistics/handle/globalstats/topn_bench_test.go @@ -32,7 +32,7 @@ import ( // cmd: go test -run=^$ -bench=BenchmarkMergePartTopN2GlobalTopNWithHists -benchmem github.com/pingcap/tidb/statistics/handle/globalstats func benchmarkMergePartTopN2GlobalTopNWithHists(partitions int, b *testing.B) { loc := time.UTC - sc := &stmtctx.StatementContext{TimeZone: loc} + sc := stmtctx.NewStmtCtxWithTimeZone(loc) version := 1 isKilled := uint32(0) @@ -83,7 +83,7 @@ func benchmarkMergePartTopN2GlobalTopNWithHists(partitions int, b *testing.B) { // cmd: go test -run=^$ -bench=BenchmarkMergeGlobalStatsTopNByConcurrencyWithHists -benchmem github.com/pingcap/tidb/statistics/handle/globalstats func benchmarkMergeGlobalStatsTopNByConcurrencyWithHists(partitions int, b *testing.B) { loc := time.UTC - sc := &stmtctx.StatementContext{TimeZone: loc} + sc := stmtctx.NewStmtCtxWithTimeZone(loc) version := 1 isKilled := uint32(0) diff --git a/statistics/handle/handle_hist_test.go b/statistics/handle/handle_hist_test.go index 5e7c0a095b869..3f8f65741be68 100644 --- a/statistics/handle/handle_hist_test.go +++ b/statistics/handle/handle_hist_test.go @@ -79,7 +79,7 @@ func TestConcurrentLoadHist(t *testing.T) { hg = stat.Columns[tableInfo.Columns[2].ID].Histogram topn = stat.Columns[tableInfo.Columns[2].ID].TopN require.Equal(t, 0, hg.Len()+topn.Num()) - stmtCtx := &stmtctx.StatementContext{} + stmtCtx := stmtctx.NewStmtCtx() neededColumns := make([]model.TableItemID, 0, len(tableInfo.Columns)) for _, col := range tableInfo.Columns { neededColumns = append(neededColumns, model.TableItemID{TableID: tableInfo.ID, ID: col.ID, IsIndex: false}) @@ -124,7 +124,7 @@ func TestConcurrentLoadHistTimeout(t *testing.T) { hg = stat.Columns[tableInfo.Columns[2].ID].Histogram topn = stat.Columns[tableInfo.Columns[2].ID].TopN require.Equal(t, 0, hg.Len()+topn.Num()) - stmtCtx := &stmtctx.StatementContext{} + stmtCtx := stmtctx.NewStmtCtx() neededColumns := make([]model.TableItemID, 0, len(tableInfo.Columns)) for _, col := range tableInfo.Columns { neededColumns = append(neededColumns, model.TableItemID{TableID: tableInfo.ID, ID: col.ID, IsIndex: false}) @@ -195,9 +195,9 @@ func TestConcurrentLoadHistWithPanicAndFail(t *testing.T) { topn := stat.Columns[tableInfo.Columns[2].ID].TopN require.Equal(t, 0, hg.Len()+topn.Num()) - stmtCtx1 := &stmtctx.StatementContext{} + stmtCtx1 := stmtctx.NewStmtCtx() h.SendLoadRequests(stmtCtx1, neededColumns, timeout) - stmtCtx2 := &stmtctx.StatementContext{} + stmtCtx2 := stmtctx.NewStmtCtx() h.SendLoadRequests(stmtCtx2, neededColumns, timeout) exitCh := make(chan struct{}) diff --git a/statistics/handle/storage/json.go b/statistics/handle/storage/json.go index 1749cafd0db4a..da8bc786f0873 100644 --- a/statistics/handle/storage/json.go +++ b/statistics/handle/storage/json.go @@ -135,7 +135,7 @@ func GenJSONTableFromStats(dbName string, tableInfo *model.TableInfo, tbl *stati Version: tbl.Version, } for _, col := range tbl.Columns { - sc := &stmtctx.StatementContext{TimeZone: time.UTC} + sc := stmtctx.NewStmtCtxWithTimeZone(time.UTC) hist, err := col.ConvertTo(sc, types.NewFieldType(mysql.TypeBlob)) if err != nil { return nil, errors.Trace(err) @@ -198,7 +198,7 @@ func TableStatsFromJSON(tableInfo *model.TableInfo, physicalID int64, jsonTbl *J continue } hist := statistics.HistogramFromProto(jsonCol.Histogram) - sc := &stmtctx.StatementContext{TimeZone: time.UTC} + sc := stmtctx.NewStmtCtxWithTimeZone(time.UTC) tmpFT := colInfo.FieldType // For new collation data, when storing the bounds of the histogram, we store the collate key instead of the // original value. diff --git a/statistics/handle/storage/read.go b/statistics/handle/storage/read.go index b03be31699fc2..7e5979432637f 100644 --- a/statistics/handle/storage/read.go +++ b/statistics/handle/storage/read.go @@ -73,7 +73,9 @@ func HistogramFromStorage(sctx sessionctx.Context, tableID int64, colID int64, t } else { // Invalid date values may be inserted into table under some relaxed sql mode. Those values may exist in statistics. // Hence, when reading statistics, we should skip invalid date check. See #39336. - sc := &stmtctx.StatementContext{TimeZone: time.UTC, AllowInvalidDate: true, IgnoreZeroInDate: true} + sc := stmtctx.NewStmtCtxWithTimeZone(time.UTC) + sc.AllowInvalidDate = true + sc.IgnoreZeroInDate = true d := rows[i].GetDatum(2, &fields[2].Column.FieldType) // For new collation data, when storing the bounds of the histogram, we store the collate key instead of the // original value. diff --git a/statistics/histogram.go b/statistics/histogram.go index 0476bc858d691..5671afaf46004 100644 --- a/statistics/histogram.go +++ b/statistics/histogram.go @@ -1420,7 +1420,7 @@ func MergePartitionHist2GlobalHist(sc *stmtctx.StatementContext, hists []*Histog for _, meta := range popedTopN { totCount += int64(meta.Count) - d, err := topNMetaToDatum(meta, hists[0].Tp.GetType(), isIndex, sc.TimeZone) + d, err := topNMetaToDatum(meta, hists[0].Tp.GetType(), isIndex, sc.TimeZone()) if err != nil { return nil, err } diff --git a/statistics/main_test.go b/statistics/main_test.go index c5eece51b7b3f..8d61b87ff0842 100644 --- a/statistics/main_test.go +++ b/statistics/main_test.go @@ -103,7 +103,7 @@ func createTestStatisticsSamples(t *testing.T) *testStatisticsSamples { for i := start; i < len(samples); i += 5 { samples[i].Value.SetInt64(samples[i].Value.GetInt64() + 2) } - sc := new(stmtctx.StatementContext) + sc := stmtctx.NewStmtCtx() var err error s.samples, err = SortSampleItems(sc, samples) diff --git a/statistics/row_sampler.go b/statistics/row_sampler.go index 58336cfefea96..3764277cb6016 100644 --- a/statistics/row_sampler.go +++ b/statistics/row_sampler.go @@ -190,7 +190,7 @@ func (s *RowSampleBuilder) Collect() (RowSampleCollector, error) { for i, val := range datums { // For string values, we use the collation key instead of the original value. if s.Collators[i] != nil && !val.IsNull() { - decodedVal, err := tablecodec.DecodeColumnValue(val.GetBytes(), s.ColsFieldType[i], s.Sc.TimeZone) + decodedVal, err := tablecodec.DecodeColumnValue(val.GetBytes(), s.ColsFieldType[i], s.Sc.TimeZone()) if err != nil { return nil, err } diff --git a/statistics/sample.go b/statistics/sample.go index 0e37cd7e5de98..8b843614caa3d 100644 --- a/statistics/sample.go +++ b/statistics/sample.go @@ -254,7 +254,7 @@ func (s SampleBuilder) CollectColumnStats() ([]*SampleCollector, *SortedBuilder, } for i, val := range datums { if s.Collators[i] != nil && !val.IsNull() { - decodedVal, err := tablecodec.DecodeColumnValue(val.GetBytes(), s.ColsFieldType[i], s.Sc.TimeZone) + decodedVal, err := tablecodec.DecodeColumnValue(val.GetBytes(), s.ColsFieldType[i], s.Sc.TimeZone()) if err != nil { return nil, nil, err } diff --git a/statistics/sample_test.go b/statistics/sample_test.go index 7e3058edcd537..fb69f0a8e77d1 100644 --- a/statistics/sample_test.go +++ b/statistics/sample_test.go @@ -272,7 +272,7 @@ func SubTestMergeSampleCollector(s *testSampleSuite) func(*testing.T) { ColsFieldType: []*types.FieldType{types.NewFieldType(mysql.TypeLonglong), types.NewFieldType(mysql.TypeLonglong)}, } require.Nil(t, s.rs.Close()) - sc := &stmtctx.StatementContext{TimeZone: time.Local} + sc := stmtctx.NewStmtCtxWithTimeZone(time.Local) collectors, pkBuilder, err := builder.CollectColumnStats() require.NoError(t, err) require.Nil(t, pkBuilder) diff --git a/statistics/scalar.go b/statistics/scalar.go index 505eba014ce5b..323a05ef0a112 100644 --- a/statistics/scalar.go +++ b/statistics/scalar.go @@ -73,7 +73,7 @@ func convertDatumToScalar(value *types.Datum, commonPfxLen int) float64 { case mysql.TypeTimestamp: minTime = types.MinTimestamp } - sc := &stmtctx.StatementContext{TimeZone: types.BoundTimezone} + sc := stmtctx.NewStmtCtxWithTimeZone(types.BoundTimezone) return float64(valueTime.Sub(sc, &minTime).Duration) case types.KindString, types.KindBytes: bytes := value.GetBytes() @@ -276,7 +276,7 @@ func EnumRangeValues(low, high types.Datum, lowExclude, highExclude bool) []type } fsp := mathutil.Max(lowTime.Fsp(), highTime.Fsp()) var stepSize int64 - sc := &stmtctx.StatementContext{TimeZone: time.UTC} + sc := stmtctx.NewStmtCtxWithTimeZone(time.UTC) if lowTime.Type() == mysql.TypeDate { stepSize = 24 * int64(time.Hour) lowTime.SetCoreTime(types.FromDate(lowTime.Year(), lowTime.Month(), lowTime.Day(), 0, 0, 0, 0)) diff --git a/statistics/statistics_test.go b/statistics/statistics_test.go index 5df0bc08dee7e..ae77884315c91 100644 --- a/statistics/statistics_test.go +++ b/statistics/statistics_test.go @@ -455,7 +455,7 @@ func SubTestIndexRanges() func(*testing.T) { } func encodeKey(key types.Datum) types.Datum { - sc := &stmtctx.StatementContext{TimeZone: time.Local} + sc := stmtctx.NewStmtCtxWithTimeZone(time.Local) buf, _ := codec.EncodeKey(sc, nil, key) return types.NewBytesDatum(buf) } diff --git a/store/mockstore/cluster_test.go b/store/mockstore/cluster_test.go index 9384a433fcba7..3f6d5654921cd 100644 --- a/store/mockstore/cluster_test.go +++ b/store/mockstore/cluster_test.go @@ -52,7 +52,7 @@ func TestClusterSplit(t *testing.T) { idxID := int64(2) colID := int64(3) handle := int64(1) - sc := &stmtctx.StatementContext{TimeZone: time.UTC} + sc := stmtctx.NewStmtCtxWithTimeZone(time.UTC) for i := 0; i < 1000; i++ { rowKey := tablecodec.EncodeRowKeyWithHandle(tblID, kv.IntHandle(handle)) colValue := types.NewStringDatum(strconv.Itoa(int(handle))) diff --git a/store/mockstore/mockcopr/analyze.go b/store/mockstore/mockcopr/analyze.go index a5548b25264f4..93e4396aef4c6 100644 --- a/store/mockstore/mockcopr/analyze.go +++ b/store/mockstore/mockcopr/analyze.go @@ -128,10 +128,11 @@ type analyzeColumnsExec struct { func (h coprHandler) handleAnalyzeColumnsReq(req *coprocessor.Request, analyzeReq *tipb.AnalyzeReq) (_ *coprocessor.Response, err error) { sc := flagsToStatementContext(analyzeReq.Flags) - sc.TimeZone, err = constructTimeZone("", int(analyzeReq.TimeZoneOffset)) + tz, err := constructTimeZone("", int(analyzeReq.TimeZoneOffset)) if err != nil { return nil, errors.Trace(err) } + sc.SetTimeZone(tz) evalCtx := &evalContext{sc: sc} columns := analyzeReq.ColReq.ColumnsInfo diff --git a/store/mockstore/mockcopr/cop_handler_dag.go b/store/mockstore/mockcopr/cop_handler_dag.go index d9e8f84646fc7..2a3e749e714a5 100644 --- a/store/mockstore/mockcopr/cop_handler_dag.go +++ b/store/mockstore/mockcopr/cop_handler_dag.go @@ -104,10 +104,11 @@ func (h coprHandler) buildDAGExecutor(req *coprocessor.Request) (*dagContext, ex } sc := flagsToStatementContext(dagReq.Flags) - sc.TimeZone, err = constructTimeZone(dagReq.TimeZoneName, int(dagReq.TimeZoneOffset)) + tz, err := constructTimeZone(dagReq.TimeZoneName, int(dagReq.TimeZoneOffset)) if err != nil { return nil, nil, nil, errors.Trace(err) } + sc.SetTimeZone(tz) ctx := &dagContext{ dagReq: dagReq, @@ -457,7 +458,7 @@ func (e *evalContext) setColumnInfo(cols []*tipb.ColumnInfo) { func (e *evalContext) decodeRelatedColumnVals(relatedColOffsets []int, value [][]byte, row []types.Datum) error { var err error for _, offset := range relatedColOffsets { - row[offset], err = tablecodec.DecodeColumnValue(value[offset], e.fieldTps[offset], e.sc.TimeZone) + row[offset], err = tablecodec.DecodeColumnValue(value[offset], e.fieldTps[offset], e.sc.TimeZone()) if err != nil { return errors.Trace(err) } @@ -467,7 +468,7 @@ func (e *evalContext) decodeRelatedColumnVals(relatedColOffsets []int, value [][ // flagsToStatementContext creates a StatementContext from a `tipb.SelectRequest.Flags`. func flagsToStatementContext(flags uint64) *stmtctx.StatementContext { - sc := new(stmtctx.StatementContext) + sc := stmtctx.NewStmtCtx() sc.IgnoreTruncate.Store((flags & model.FlagIgnoreTruncate) > 0) sc.TruncateAsWarning = (flags & model.FlagTruncateAsWarning) > 0 sc.InInsertStmt = (flags & model.FlagInInsertStmt) > 0 @@ -559,7 +560,7 @@ func (h coprHandler) fillUpData4SelectResponse(selResp *tipb.SelectResponse, dag h.encodeDefault(selResp, rows, dagReq.OutputOffsets) case tipb.EncodeType_TypeChunk: colTypes := h.constructRespSchema(dagCtx) - loc := dagCtx.evalCtx.sc.TimeZone + loc := dagCtx.evalCtx.sc.TimeZone() err := h.encodeChunk(selResp, rows, colTypes, dagReq.OutputOffsets, loc) if err != nil { return err diff --git a/store/mockstore/unistore/cophandler/analyze.go b/store/mockstore/unistore/cophandler/analyze.go index 626147ed16d22..bed3c71db83a3 100644 --- a/store/mockstore/unistore/cophandler/analyze.go +++ b/store/mockstore/unistore/cophandler/analyze.go @@ -267,14 +267,14 @@ type analyzeColumnsExec struct { func buildBaseAnalyzeColumnsExec(dbReader *dbreader.DBReader, rans []kv.KeyRange, analyzeReq *tipb.AnalyzeReq, startTS uint64) (*analyzeColumnsExec, *statistics.SampleBuilder, int64, error) { sc := flagsToStatementContext(analyzeReq.Flags) - sc.TimeZone = time.FixedZone("UTC", int(analyzeReq.TimeZoneOffset)) + sc.SetTimeZone(time.FixedZone("UTC", int(analyzeReq.TimeZoneOffset))) evalCtx := &evalContext{sc: sc} columns := analyzeReq.ColReq.ColumnsInfo evalCtx.setColumnInfo(columns) if len(analyzeReq.ColReq.PrimaryColumnIds) > 0 { evalCtx.primaryCols = analyzeReq.ColReq.PrimaryColumnIds } - decoder, err := newRowDecoder(evalCtx.columnInfos, evalCtx.fieldTps, evalCtx.primaryCols, evalCtx.sc.TimeZone) + decoder, err := newRowDecoder(evalCtx.columnInfos, evalCtx.fieldTps, evalCtx.primaryCols, evalCtx.sc.TimeZone()) if err != nil { return nil, nil, -1, err } @@ -373,14 +373,14 @@ func handleAnalyzeFullSamplingReq( startTS uint64, ) (*coprocessor.Response, error) { sc := flagsToStatementContext(analyzeReq.Flags) - sc.TimeZone = time.FixedZone("UTC", int(analyzeReq.TimeZoneOffset)) + sc.SetTimeZone(time.FixedZone("UTC", int(analyzeReq.TimeZoneOffset))) evalCtx := &evalContext{sc: sc} columns := analyzeReq.ColReq.ColumnsInfo evalCtx.setColumnInfo(columns) if len(analyzeReq.ColReq.PrimaryColumnIds) > 0 { evalCtx.primaryCols = analyzeReq.ColReq.PrimaryColumnIds } - decoder, err := newRowDecoder(evalCtx.columnInfos, evalCtx.fieldTps, evalCtx.primaryCols, evalCtx.sc.TimeZone) + decoder, err := newRowDecoder(evalCtx.columnInfos, evalCtx.fieldTps, evalCtx.primaryCols, evalCtx.sc.TimeZone()) if err != nil { return nil, err } diff --git a/store/mockstore/unistore/cophandler/closure_exec.go b/store/mockstore/unistore/cophandler/closure_exec.go index 515fe6f6e1a32..1e707b9b2955a 100644 --- a/store/mockstore/unistore/cophandler/closure_exec.go +++ b/store/mockstore/unistore/cophandler/closure_exec.go @@ -286,7 +286,7 @@ func newClosureExecutor(dagCtx *dagContext, outputOffsets []uint32, scanExec *ti e.kvRanges = ranges e.scanCtx.chk = chunk.NewChunkWithCapacity(e.fieldTps, 32) if e.scanType == TableScan { - e.scanCtx.decoder, err = newRowDecoder(e.evalContext.columnInfos, e.evalContext.fieldTps, e.evalContext.primaryCols, e.evalContext.sc.TimeZone) + e.scanCtx.decoder, err = newRowDecoder(e.evalContext.columnInfos, e.evalContext.fieldTps, e.evalContext.primaryCols, e.evalContext.sc.TimeZone()) if err != nil { return nil, errors.Trace(err) } @@ -925,7 +925,7 @@ func (e *closureExecutor) indexScanProcessCore(key, value []byte) error { } } chk := e.scanCtx.chk - decoder := codec.NewDecoder(chk, e.sc.TimeZone) + decoder := codec.NewDecoder(chk, e.sc.TimeZone()) for i, colVal := range values { if i < len(e.fieldTps) { _, err = decoder.DecodeOne(colVal, i, e.fieldTps[i]) diff --git a/store/mockstore/unistore/cophandler/cop_handler.go b/store/mockstore/unistore/cophandler/cop_handler.go index 083978c4d7b43..b0fd540672229 100644 --- a/store/mockstore/unistore/cophandler/cop_handler.go +++ b/store/mockstore/unistore/cophandler/cop_handler.go @@ -297,14 +297,15 @@ func buildDAG(reader *dbreader.DBReader, lockStore *lockstore.MemStore, req *cop sc := flagsToStatementContext(dagReq.Flags) switch dagReq.TimeZoneName { case "": - sc.TimeZone = time.FixedZone("UTC", int(dagReq.TimeZoneOffset)) + sc.SetTimeZone(time.FixedZone("UTC", int(dagReq.TimeZoneOffset))) case "System": - sc.TimeZone = time.Local + sc.SetTimeZone(time.Local) default: - sc.TimeZone, err = time.LoadLocation(dagReq.TimeZoneName) + tz, err := time.LoadLocation(dagReq.TimeZoneName) if err != nil { return nil, nil, errors.Trace(err) } + sc.SetTimeZone(tz) } ctx := &dagContext{ evalContext: &evalContext{sc: sc}, @@ -422,7 +423,7 @@ func newRowDecoder(columnInfos []*tipb.ColumnInfo, fieldTps []*types.FieldType, // flagsToStatementContext creates a StatementContext from a `tipb.SelectRequest.Flags`. func flagsToStatementContext(flags uint64) *stmtctx.StatementContext { - sc := new(stmtctx.StatementContext) + sc := stmtctx.NewStmtCtx() sc.IgnoreTruncate.Store((flags & model.FlagIgnoreTruncate) > 0) sc.TruncateAsWarning = (flags & model.FlagTruncateAsWarning) > 0 sc.InInsertStmt = (flags & model.FlagInInsertStmt) > 0 diff --git a/store/mockstore/unistore/cophandler/cop_handler_test.go b/store/mockstore/unistore/cophandler/cop_handler_test.go index 8e40070de2101..ed72642600fd5 100644 --- a/store/mockstore/unistore/cophandler/cop_handler_test.go +++ b/store/mockstore/unistore/cophandler/cop_handler_test.go @@ -94,7 +94,7 @@ func makeATestMutaion(op kvrpcpb.Op, key []byte, value []byte) *kvrpcpb.Mutation } func prepareTestTableData(keyNumber int, tableID int64) (*data, error) { - stmtCtx := new(stmtctx.StatementContext) + stmtCtx := stmtctx.NewStmtCtx() colIds := []int64{1, 2, 3} colTypes := []*types.FieldType{ types.NewFieldType(mysql.TypeLonglong), diff --git a/store/mockstore/unistore/cophandler/mpp.go b/store/mockstore/unistore/cophandler/mpp.go index c31180c947298..37c41e5e850cb 100644 --- a/store/mockstore/unistore/cophandler/mpp.go +++ b/store/mockstore/unistore/cophandler/mpp.go @@ -91,7 +91,7 @@ func (b *mppExecBuilder) buildMPPTableScan(pb *tipb.TableScan) (*tableScanExec, ft := fieldTypeFromPBColumn(col) ts.fieldTypes = append(ts.fieldTypes, ft) } - ts.decoder, err = newRowDecoder(pb.Columns, ts.fieldTypes, pb.PrimaryColumnIds, b.sc.TimeZone) + ts.decoder, err = newRowDecoder(pb.Columns, ts.fieldTypes, pb.PrimaryColumnIds, b.sc.TimeZone()) return ts, err } @@ -114,7 +114,7 @@ func (b *mppExecBuilder) buildMPPPartitionTableScan(pb *tipb.PartitionTableScan) ft := fieldTypeFromPBColumn(col) ts.fieldTypes = append(ts.fieldTypes, ft) } - ts.decoder, err = newRowDecoder(pb.Columns, ts.fieldTypes, pb.PrimaryColumnIds, b.sc.TimeZone) + ts.decoder, err = newRowDecoder(pb.Columns, ts.fieldTypes, pb.PrimaryColumnIds, b.sc.TimeZone()) return ts, err } diff --git a/store/mockstore/unistore/cophandler/mpp_exec.go b/store/mockstore/unistore/cophandler/mpp_exec.go index eea1aa2810268..80e6caf3141f9 100644 --- a/store/mockstore/unistore/cophandler/mpp_exec.go +++ b/store/mockstore/unistore/cophandler/mpp_exec.go @@ -304,7 +304,7 @@ func (e *indexScanExec) Process(key, value []byte) error { e.prevVals[i] = append(e.prevVals[i][:0], values[i]...) } } - decoder := codec.NewDecoder(e.chk, e.sc.TimeZone) + decoder := codec.NewDecoder(e.chk, e.sc.TimeZone()) for i, value := range values { if i < len(e.fieldTypes) { _, err = decoder.DecodeOne(value, i, e.fieldTypes[i]) diff --git a/store/mockstore/unistore/tikv/mvcc.go b/store/mockstore/unistore/tikv/mvcc.go index e7f1e46daa5f4..d8f30ade2dcae 100644 --- a/store/mockstore/unistore/tikv/mvcc.go +++ b/store/mockstore/unistore/tikv/mvcc.go @@ -1000,7 +1000,7 @@ func encodeFromOldRow(oldRow, buf []byte) ([]byte, error) { datums = append(datums, d) } var encoder rowcodec.Encoder - return encoder.Encode(&stmtctx.StatementContext{}, colIDs, datums, buf) + return encoder.Encode(stmtctx.NewStmtCtx(), colIDs, datums, buf) } func (store *MVCCStore) buildPrewriteLock(reqCtx *requestCtx, m *kvrpcpb.Mutation, item *badger.Item, diff --git a/table/column_test.go b/table/column_test.go index a0c14a2360e6e..bb20858aba81d 100644 --- a/table/column_test.go +++ b/table/column_test.go @@ -128,7 +128,7 @@ func TestCheck(t *testing.T) { func TestHandleBadNull(t *testing.T) { col := newCol("a") - sc := new(stmtctx.StatementContext) + sc := stmtctx.NewStmtCtx() d := types.Datum{} err := col.HandleBadNull(&d, sc, 0) require.NoError(t, err) @@ -256,7 +256,7 @@ func TestGetZeroValue(t *testing.T) { types.NewDatum(types.CreateBinaryJSON(nil)), }, } - sc := new(stmtctx.StatementContext) + sc := stmtctx.NewStmtCtx() for _, tt := range tests { t.Run(fmt.Sprintf("%+v", tt.ft), func(t *testing.T) { colInfo := &model.ColumnInfo{FieldType: *tt.ft} diff --git a/table/tables/mutation_checker_test.go b/table/tables/mutation_checker_test.go index 4c44e90a7d244..7c4dd8e140764 100644 --- a/table/tables/mutation_checker_test.go +++ b/table/tables/mutation_checker_test.go @@ -72,7 +72,7 @@ func TestCompareIndexData(t *testing.T) { } for caseID, data := range testData { - sc := &stmtctx.StatementContext{} + sc := stmtctx.NewStmtCtx() cols := make([]*table.Column, 0) indexCols := make([]*model.IndexColumn, 0) for i, ft := range data.fts { @@ -256,7 +256,7 @@ func TestCheckIndexKeysAndCheckHandleConsistency(t *testing.T) { for _, isCommonHandle := range []bool{true, false} { for _, lc := range locations { for _, columnInfos := range columnInfoSets { - sessVars.StmtCtx.TimeZone = lc + sessVars.StmtCtx.SetTimeZone(lc) tableInfo := model.TableInfo{ ID: 1, Name: model.NewCIStr("t"), diff --git a/tablecodec/tablecodec.go b/tablecodec/tablecodec.go index 9ec296256eed5..1f8508a2585f6 100644 --- a/tablecodec/tablecodec.go +++ b/tablecodec/tablecodec.go @@ -377,8 +377,8 @@ func flatten(sc *stmtctx.StatementContext, data types.Datum, ret *types.Datum) e case types.KindMysqlTime: // for mysql datetime, timestamp and date type t := data.GetMysqlTime() - if t.Type() == mysql.TypeTimestamp && sc.TimeZone != time.UTC { - err := t.ConvertTimeZone(sc.TimeZone, time.UTC) + if t.Type() == mysql.TypeTimestamp && sc.TimeZone() != time.UTC { + err := t.ConvertTimeZone(sc.TimeZone(), time.UTC) if err != nil { return errors.Trace(err) } diff --git a/tablecodec/tablecodec_test.go b/tablecodec/tablecodec_test.go index 116971fc36273..a3769481e7f24 100644 --- a/tablecodec/tablecodec_test.go +++ b/tablecodec/tablecodec_test.go @@ -99,7 +99,7 @@ func TestRowCodec(t *testing.T) { colIDs = append(colIDs, col.id) } rd := rowcodec.Encoder{Enable: true} - sc := &stmtctx.StatementContext{TimeZone: time.Local} + sc := stmtctx.NewStmtCtxWithTimeZone(time.Local) bs, err := EncodeRow(sc, row, colIDs, nil, nil, &rd) require.NoError(t, err) require.NotNil(t, bs) @@ -165,7 +165,7 @@ func TestRowCodec(t *testing.T) { } func TestDecodeColumnValue(t *testing.T) { - sc := &stmtctx.StatementContext{TimeZone: time.Local} + sc := stmtctx.NewStmtCtxWithTimeZone(time.Local) // test timestamp d := types.NewTimeDatum(types.NewTime(types.FromGoTime(time.Now()), mysql.TypeTimestamp, types.DefaultFsp)) @@ -175,7 +175,7 @@ func TestDecodeColumnValue(t *testing.T) { _, bs, err = codec.CutOne(bs) // ignore colID require.NoError(t, err) tp := types.NewFieldType(mysql.TypeTimestamp) - d1, err := DecodeColumnValue(bs, tp, sc.TimeZone) + d1, err := DecodeColumnValue(bs, tp, sc.TimeZone()) require.NoError(t, err) cmp, err := d1.Compare(sc, &d, collate.GetBinaryCollator()) require.NoError(t, err) @@ -192,7 +192,7 @@ func TestDecodeColumnValue(t *testing.T) { require.NoError(t, err) tp = types.NewFieldType(mysql.TypeSet) tp.SetElems(elems) - d1, err = DecodeColumnValue(bs, tp, sc.TimeZone) + d1, err = DecodeColumnValue(bs, tp, sc.TimeZone()) require.NoError(t, err) cmp, err = d1.Compare(sc, &d, collate.GetCollator(tp.GetCollate())) require.NoError(t, err) @@ -207,7 +207,7 @@ func TestDecodeColumnValue(t *testing.T) { require.NoError(t, err) tp = types.NewFieldType(mysql.TypeBit) tp.SetFlen(24) - d1, err = DecodeColumnValue(bs, tp, sc.TimeZone) + d1, err = DecodeColumnValue(bs, tp, sc.TimeZone()) require.NoError(t, err) cmp, err = d1.Compare(sc, &d, collate.GetBinaryCollator()) require.NoError(t, err) @@ -221,7 +221,7 @@ func TestDecodeColumnValue(t *testing.T) { _, bs, err = codec.CutOne(bs) // ignore colID require.NoError(t, err) tp = types.NewFieldType(mysql.TypeEnum) - d1, err = DecodeColumnValue(bs, tp, sc.TimeZone) + d1, err = DecodeColumnValue(bs, tp, sc.TimeZone()) require.NoError(t, err) cmp, err = d1.Compare(sc, &d, collate.GetCollator(tp.GetCollate())) require.NoError(t, err) @@ -229,10 +229,10 @@ func TestDecodeColumnValue(t *testing.T) { } func TestUnflattenDatums(t *testing.T) { - sc := &stmtctx.StatementContext{TimeZone: time.UTC} + sc := stmtctx.NewStmtCtxWithTimeZone(time.UTC) input := types.MakeDatums(int64(1)) tps := []*types.FieldType{types.NewFieldType(mysql.TypeLonglong)} - output, err := UnflattenDatums(input, tps, sc.TimeZone) + output, err := UnflattenDatums(input, tps, sc.TimeZone()) require.NoError(t, err) cmp, err := input[0].Compare(sc, &output[0], collate.GetBinaryCollator()) require.NoError(t, err) @@ -241,7 +241,7 @@ func TestUnflattenDatums(t *testing.T) { input = []types.Datum{types.NewCollationStringDatum("aaa", "utf8mb4_unicode_ci")} tps = []*types.FieldType{types.NewFieldType(mysql.TypeBlob)} tps[0].SetCollate("utf8mb4_unicode_ci") - output, err = UnflattenDatums(input, tps, sc.TimeZone) + output, err = UnflattenDatums(input, tps, sc.TimeZone()) require.NoError(t, err) cmp, err = input[0].Compare(sc, &output[0], collate.GetBinaryCollator()) require.NoError(t, err) @@ -260,7 +260,7 @@ func TestTimeCodec(t *testing.T) { row := make([]types.Datum, colLen) row[0] = types.NewIntDatum(100) row[1] = types.NewBytesDatum([]byte("abc")) - ts, err := types.ParseTimestamp(&stmtctx.StatementContext{TimeZone: time.UTC}, + ts, err := types.ParseTimestamp(stmtctx.NewStmtCtxWithTimeZone(time.UTC), "2016-06-23 11:30:45") require.NoError(t, err) row[2] = types.NewDatum(ts) @@ -274,7 +274,7 @@ func TestTimeCodec(t *testing.T) { colIDs = append(colIDs, col.id) } rd := rowcodec.Encoder{Enable: true} - sc := &stmtctx.StatementContext{TimeZone: time.UTC} + sc := stmtctx.NewStmtCtxWithTimeZone(time.UTC) bs, err := EncodeRow(sc, row, colIDs, nil, nil, &rd) require.NoError(t, err) require.NotNil(t, bs) @@ -310,7 +310,7 @@ func TestCutRow(t *testing.T) { row[1] = types.NewBytesDatum([]byte("abc")) row[2] = types.NewDecimalDatum(types.NewDecFromInt(1)) - sc := &stmtctx.StatementContext{TimeZone: time.UTC} + sc := stmtctx.NewStmtCtxWithTimeZone(time.UTC) data := make([][]byte, 3) data[0], err = EncodeValue(sc, nil, row[0]) require.NoError(t, err) @@ -354,7 +354,7 @@ func TestCutKeyNew(t *testing.T) { values := []types.Datum{types.NewIntDatum(1), types.NewBytesDatum([]byte("abc")), types.NewFloat64Datum(5.5)} handle := types.NewIntDatum(100) values = append(values, handle) - sc := &stmtctx.StatementContext{TimeZone: time.UTC} + sc := stmtctx.NewStmtCtxWithTimeZone(time.UTC) encodedValue, err := codec.EncodeKey(sc, nil, values...) require.NoError(t, err) tableID := int64(4) @@ -377,7 +377,7 @@ func TestCutKey(t *testing.T) { values := []types.Datum{types.NewIntDatum(1), types.NewBytesDatum([]byte("abc")), types.NewFloat64Datum(5.5)} handle := types.NewIntDatum(100) values = append(values, handle) - sc := &stmtctx.StatementContext{TimeZone: time.UTC} + sc := stmtctx.NewStmtCtxWithTimeZone(time.UTC) encodedValue, err := codec.EncodeKey(sc, nil, values...) require.NoError(t, err) tableID := int64(4) @@ -493,7 +493,7 @@ func TestDecodeIndexKey(t *testing.T) { } valueStrs = append(valueStrs, str) } - sc := &stmtctx.StatementContext{TimeZone: time.UTC} + sc := stmtctx.NewStmtCtxWithTimeZone(time.UTC) encodedValue, err := codec.EncodeKey(sc, nil, values...) require.NoError(t, err) indexKey := EncodeIndexSeekKey(tableID, indexID, encodedValue) @@ -599,7 +599,7 @@ func TestUntouchedIndexKValue(t *testing.T) { func TestTempIndexKey(t *testing.T) { values := []types.Datum{types.NewIntDatum(1), types.NewBytesDatum([]byte("abc")), types.NewFloat64Datum(5.5)} - encodedValue, err := codec.EncodeKey(&stmtctx.StatementContext{TimeZone: time.UTC}, nil, values...) + encodedValue, err := codec.EncodeKey(stmtctx.NewStmtCtxWithTimeZone(time.UTC), nil, values...) require.NoError(t, err) tableID := int64(4) indexID := int64(5) @@ -626,7 +626,7 @@ func TestTempIndexKey(t *testing.T) { func TestTempIndexValueCodec(t *testing.T) { // Test encode temp index value. - encodedValue, err := codec.EncodeValue(&stmtctx.StatementContext{TimeZone: time.UTC}, nil, types.NewIntDatum(1)) + encodedValue, err := codec.EncodeValue(stmtctx.NewStmtCtxWithTimeZone(time.UTC), nil, types.NewIntDatum(1)) require.NoError(t, err) encodedValueCopy := make([]byte, len(encodedValue)) copy(encodedValueCopy, encodedValue) diff --git a/testkit/testutil/handle.go b/testkit/testutil/handle.go index c0832562f8816..7e96b063e8ea6 100644 --- a/testkit/testutil/handle.go +++ b/testkit/testutil/handle.go @@ -30,7 +30,7 @@ import ( // MustNewCommonHandle create a common handle with given values. func MustNewCommonHandle(t *testing.T, values ...interface{}) kv.Handle { - encoded, err := codec.EncodeKey(new(stmtctx.StatementContext), nil, types.MakeDatums(values...)...) + encoded, err := codec.EncodeKey(stmtctx.NewStmtCtx(), nil, types.MakeDatums(values...)...) require.NoError(t, err) ch, err := kv.NewCommonHandle(encoded) require.NoError(t, err) diff --git a/testkit/testutil/require.go b/testkit/testutil/require.go index 09e8e871312ae..e73123b262a40 100644 --- a/testkit/testutil/require.go +++ b/testkit/testutil/require.go @@ -29,7 +29,7 @@ import ( // DatumEqual verifies that the actual value is equal to the expected value. For string datum, they are compared by the binary collation. func DatumEqual(t testing.TB, expected, actual types.Datum, msgAndArgs ...interface{}) { - sc := new(stmtctx.StatementContext) + sc := stmtctx.NewStmtCtx() res, err := actual.Compare(sc, &expected, collate.GetBinaryCollator()) require.NoError(t, err, msgAndArgs) require.Zero(t, res, msgAndArgs) diff --git a/types/binary_literal_test.go b/types/binary_literal_test.go index a76f4c5a8dcb1..29d3f8ee9f5c3 100644 --- a/types/binary_literal_test.go +++ b/types/binary_literal_test.go @@ -207,7 +207,7 @@ func TestBinaryLiteral(t *testing.T) { {"0x1010ffff8080ff12", 0x1010ffff8080ff12, false}, {"0x1010ffff8080ff12ff", 0xffffffffffffffff, true}, } - sc := new(stmtctx.StatementContext) + sc := stmtctx.NewStmtCtx() for _, item := range tbl { hex, err := ParseHexStr(item.Input) require.NoError(t, err) diff --git a/types/compare_test.go b/types/compare_test.go index ce1168a2bea90..d4f212876ae33 100644 --- a/types/compare_test.go +++ b/types/compare_test.go @@ -146,7 +146,7 @@ func TestCompare(t *testing.T) { } func compareForTest(a, b interface{}) (int, error) { - sc := new(stmtctx.StatementContext) + sc := stmtctx.NewStmtCtx() sc.IgnoreTruncate.Store(true) aDatum := NewDatum(a) bDatum := NewDatum(b) @@ -168,7 +168,7 @@ func TestCompareDatum(t *testing.T) { {Datum{}, MinNotNullDatum(), -1}, {MinNotNullDatum(), MaxValueDatum(), -1}, } - sc := new(stmtctx.StatementContext) + sc := stmtctx.NewStmtCtx() sc.IgnoreTruncate.Store(true) for i, tt := range cmpTbl { ret, err := tt.lhs.Compare(sc, &tt.rhs, collate.GetBinaryCollator()) diff --git a/types/context/context.go b/types/context/context.go index fb3c76f70e54c..ddccdb4a6e445 100644 --- a/types/context/context.go +++ b/types/context/context.go @@ -113,7 +113,6 @@ type Context struct { // NewContext creates a new `Context` func NewContext(flags Flags, loc *time.Location, appendWarningFn func(err error)) Context { - intest.Assert(loc) return Context{ flags: flags, loc: loc, @@ -133,6 +132,14 @@ func (c *Context) WithFlags(f Flags) Context { return ctx } +// WithLocation returns a new context with the given location +func (c *Context) WithLocation(loc *time.Location) Context { + intest.Assert(loc) + ctx := *c + ctx.loc = loc + return ctx +} + // Location returns the location of the context func (c *Context) Location() *time.Location { intest.Assert(c.loc) @@ -149,3 +156,8 @@ func (c *Context) AppendWarning(err error) { fn(err) } } + +// AppendWarningFunc returns the inner `appendWarningFn` +func (c *Context) AppendWarningFunc() func(err error) { + return c.appendWarningFn +} diff --git a/types/convert_test.go b/types/convert_test.go index 67764972849a0..5ca795cb23ab5 100644 --- a/types/convert_test.go +++ b/types/convert_test.go @@ -35,8 +35,7 @@ type invalidMockType struct { // Convert converts the val with type tp. func Convert(val interface{}, target *FieldType) (v interface{}, err error) { d := NewDatum(val) - sc := new(stmtctx.StatementContext) - sc.TimeZone = time.UTC + sc := stmtctx.NewStmtCtxWithTimeZone(time.UTC) ret, err := d.ConvertTo(sc, target) if err != nil { return ret.GetValue(), errors.Trace(err) @@ -149,7 +148,7 @@ func TestConvertType(t *testing.T) { vv, err := Convert(v, ft) require.NoError(t, err) require.Equal(t, "10:11:12.1", vv.(Duration).String()) - sc := &stmtctx.StatementContext{TimeZone: time.UTC} + sc := stmtctx.NewStmtCtxWithTimeZone(time.UTC) vd, err := ParseTime(sc, "2010-10-10 10:11:11.12345", mysql.TypeDatetime, 2, nil) require.Equal(t, "2010-10-10 10:11:11.12", vd.String()) require.NoError(t, err) @@ -346,7 +345,7 @@ func TestConvertToString(t *testing.T) { testToString(t, Enum{Name: "a", Value: 1}, "a") testToString(t, Set{Name: "a", Value: 1}, "a") - t1, err := ParseTime(&stmtctx.StatementContext{TimeZone: time.UTC}, "2011-11-10 11:11:11.999999", mysql.TypeTimestamp, 6, nil) + t1, err := ParseTime(stmtctx.NewStmtCtxWithTimeZone(time.UTC), "2011-11-10 11:11:11.999999", mysql.TypeTimestamp, 6, nil) require.NoError(t, err) testToString(t, t1, "2011-11-10 11:11:11.999999") @@ -384,7 +383,7 @@ func TestConvertToString(t *testing.T) { ft.SetFlen(tt.flen) ft.SetCharset(tt.charset) inputDatum := NewStringDatum(tt.input) - sc := new(stmtctx.StatementContext) + sc := stmtctx.NewStmtCtx() outputDatum, err := inputDatum.ConvertTo(sc, ft) if tt.input != tt.output { require.True(t, ErrDataTooLong.Equal(err), "flen: %d, charset: %s, input: %s, output: %s", tt.flen, tt.charset, tt.input, tt.output) @@ -421,9 +420,9 @@ func TestConvertToStringWithCheck(t *testing.T) { ft.SetFlen(255) ft.SetCharset(tt.outputChs) inputDatum := NewStringDatum(tt.input) - sc := new(stmtctx.StatementContext) - flags := tt.newFlags(sc.TypeConvContext.Flags()) - sc.TypeConvContext = sc.TypeConvContext.WithFlags(flags) + sc := stmtctx.NewStmtCtx() + flags := tt.newFlags(sc.TypeCtx.Flags()) + sc.SetTypeFlags(flags) outputDatum, err := inputDatum.ConvertTo(sc, ft) if len(tt.output) == 0 { require.True(t, charset.ErrInvalidCharacterString.Equal(err), tt) @@ -461,7 +460,7 @@ func TestConvertToBinaryString(t *testing.T) { ft.SetFlen(255) ft.SetCharset(tt.outputCharset) inputDatum := NewCollationStringDatum(tt.input, tt.inputCollate) - sc := new(stmtctx.StatementContext) + sc := stmtctx.NewStmtCtx() outputDatum, err := inputDatum.ConvertTo(sc, ft) if len(tt.output) == 0 { require.True(t, charset.ErrInvalidCharacterString.Equal(err), tt) @@ -473,7 +472,7 @@ func TestConvertToBinaryString(t *testing.T) { } func testStrToInt(t *testing.T, str string, expect int64, truncateAsErr bool, expectErr error) { - sc := new(stmtctx.StatementContext) + sc := stmtctx.NewStmtCtx() sc.IgnoreTruncate.Store(!truncateAsErr) val, err := StrToInt(sc, str, false) if expectErr != nil { @@ -485,7 +484,7 @@ func testStrToInt(t *testing.T, str string, expect int64, truncateAsErr bool, ex } func testStrToUint(t *testing.T, str string, expect uint64, truncateAsErr bool, expectErr error) { - sc := new(stmtctx.StatementContext) + sc := stmtctx.NewStmtCtx() sc.IgnoreTruncate.Store(!truncateAsErr) val, err := StrToUint(sc, str, false) if expectErr != nil { @@ -497,7 +496,7 @@ func testStrToUint(t *testing.T, str string, expect uint64, truncateAsErr bool, } func testStrToFloat(t *testing.T, str string, expect float64, truncateAsErr bool, expectErr error) { - sc := new(stmtctx.StatementContext) + sc := stmtctx.NewStmtCtx() sc.IgnoreTruncate.Store(!truncateAsErr) val, err := StrToFloat(sc, str, false) if expectErr != nil { @@ -566,7 +565,7 @@ func testSelectUpdateDeleteEmptyStringError(t *testing.T) { {true, false}, {false, true}, } - sc := new(stmtctx.StatementContext) + sc := stmtctx.NewStmtCtx() sc.TruncateAsWarning = true for _, tc := range testCases { sc.InSelectStmt = tc.inSelect @@ -604,8 +603,8 @@ func accept(t *testing.T, tp byte, value interface{}, unsigned bool, expected st ft.AddFlag(mysql.UnsignedFlag) } d := NewDatum(value) - sc := new(stmtctx.StatementContext) - sc.TimeZone = time.UTC + sc := stmtctx.NewStmtCtx() + sc.SetTimeZone(time.UTC) sc.IgnoreTruncate.Store(true) casted, err := d.ConvertTo(sc, ft) require.NoErrorf(t, err, "%v", ft) @@ -632,7 +631,7 @@ func deny(t *testing.T, tp byte, value interface{}, unsigned bool, expected stri ft.AddFlag(mysql.UnsignedFlag) } d := NewDatum(value) - sc := new(stmtctx.StatementContext) + sc := stmtctx.NewStmtCtx() casted, err := d.ConvertTo(sc, ft) require.Error(t, err) if casted.IsNull() { @@ -887,7 +886,7 @@ func TestGetValidInt(t *testing.T) { {"123e+", "123", true, true}, {"123de", "123", true, true}, } - sc := new(stmtctx.StatementContext) + sc := stmtctx.NewStmtCtx() sc.TruncateAsWarning = true sc.InSelectStmt = true warningCount := 0 @@ -967,7 +966,7 @@ func TestGetValidFloat(t *testing.T) { {"9-3", "9"}, {"1001001\\u0000\\u0000\\u0000", "1001001"}, } - sc := new(stmtctx.StatementContext) + sc := stmtctx.NewStmtCtx() for _, tt := range tests { prefix, _ := getValidFloatPrefix(sc, tt.origin, false) require.Equal(t, tt.valid, prefix) @@ -1012,9 +1011,7 @@ func TestConvertTime(t *testing.T) { } for _, timezone := range timezones { - sc := &stmtctx.StatementContext{ - TimeZone: timezone, - } + sc := stmtctx.NewStmtCtxWithTimeZone(timezone) testConvertTimeTimeZone(t, sc) } } @@ -1081,7 +1078,7 @@ func TestConvertJSONToInt(t *testing.T) { j, err := ParseBinaryJSONFromString(tt.in) require.NoError(t, err) - casted, err := ConvertJSONToInt64(new(stmtctx.StatementContext), j, false) + casted, err := ConvertJSONToInt64(stmtctx.NewStmtCtx(), j, false) if tt.err { require.Error(t, err, tt) } else { @@ -1114,7 +1111,7 @@ func TestConvertJSONToFloat(t *testing.T) { for _, tt := range tests { j := CreateBinaryJSON(tt.in) require.Equal(t, tt.ty, j.TypeCode) - casted, err := ConvertJSONToFloat(new(stmtctx.StatementContext), j) + casted, err := ConvertJSONToFloat(stmtctx.NewStmtCtx(), j) if tt.err { require.Error(t, err, tt) } else { @@ -1142,7 +1139,7 @@ func TestConvertJSONToDecimal(t *testing.T) { for _, tt := range tests { j, err := ParseBinaryJSONFromString(tt.in) require.NoError(t, err) - casted, err := ConvertJSONToDecimal(new(stmtctx.StatementContext), j) + casted, err := ConvertJSONToDecimal(stmtctx.NewStmtCtx(), j) errMsg := fmt.Sprintf("input: %v, casted: %v, out: %v, json: %#v", tt.in, casted, tt.out, j) if tt.err { require.Error(t, err, errMsg) @@ -1205,7 +1202,7 @@ func TestNumberToDuration(t *testing.T) { } func TestStrToDuration(t *testing.T) { - sc := new(stmtctx.StatementContext) + sc := stmtctx.NewStmtCtx() var tests = []struct { str string fsp int @@ -1283,7 +1280,7 @@ func TestConvertDecimalStrToUint(t *testing.T) { {"-10000000000000000000.0", 0, false}, } for _, ca := range cases { - result, err := convertDecimalStrToUint(&stmtctx.StatementContext{}, ca.input, math.MaxUint64, 0) + result, err := convertDecimalStrToUint(stmtctx.NewStmtCtx(), ca.input, math.MaxUint64, 0) if !ca.succ { require.Error(t, err) } else { @@ -1292,11 +1289,11 @@ func TestConvertDecimalStrToUint(t *testing.T) { require.Equal(t, ca.result, result, "input=%v", ca.input) } - result, err := convertDecimalStrToUint(&stmtctx.StatementContext{}, "-99.0", math.MaxUint8, 0) + result, err := convertDecimalStrToUint(stmtctx.NewStmtCtx(), "-99.0", math.MaxUint8, 0) require.Error(t, err) require.Equal(t, uint64(0), result) - result, err = convertDecimalStrToUint(&stmtctx.StatementContext{}, "-100.0", math.MaxUint8, 0) + result, err = convertDecimalStrToUint(stmtctx.NewStmtCtx(), "-100.0", math.MaxUint8, 0) require.Error(t, err) require.Equal(t, uint64(0), result) } diff --git a/types/datum.go b/types/datum.go index 2836afb75ceae..35c072cc3c03d 100644 --- a/types/datum.go +++ b/types/datum.go @@ -1026,7 +1026,7 @@ func (d *Datum) convertToString(sc *stmtctx.StatementContext, target *FieldType) s string err error ) - ctx := sc.TypeConvContext + ctx := sc.TypeCtx switch d.k { case KindInt64: s = strconv.FormatInt(d.GetInt64(), 10) @@ -1373,13 +1373,13 @@ func (d *Datum) convertToMysqlDuration(sc *stmtctx.StatementContext, target *Fie ret.SetMysqlDuration(dur) return ret, errors.Trace(err) } - dur, err = dur.RoundFrac(fsp, sc.TimeZone) + dur, err = dur.RoundFrac(fsp, sc.TimeZone()) ret.SetMysqlDuration(dur) if err != nil { return ret, errors.Trace(err) } case KindMysqlDuration: - dur, err := d.GetMysqlDuration().RoundFrac(fsp, sc.TimeZone) + dur, err := d.GetMysqlDuration().RoundFrac(fsp, sc.TimeZone()) ret.SetMysqlDuration(dur) if err != nil { return ret, errors.Trace(err) @@ -1880,7 +1880,7 @@ func (d *Datum) toSignedInteger(sc *stmtctx.StatementContext, tp byte) (int64, e case KindMysqlDuration: // 11:11:11.999999 -> 111112 // 11:59:59.999999 -> 120000 - dur, err := d.GetMysqlDuration().RoundFrac(DefaultFsp, sc.TimeZone) + dur, err := d.GetMysqlDuration().RoundFrac(DefaultFsp, sc.TimeZone()) if err != nil { return 0, errors.Trace(err) } diff --git a/types/datum_test.go b/types/datum_test.go index 09107b15c9328..41430d3943bf7 100644 --- a/types/datum_test.go +++ b/types/datum_test.go @@ -56,7 +56,7 @@ func TestDatum(t *testing.T) { func testDatumToBool(t *testing.T, in interface{}, res int) { datum := NewDatum(in) res64 := int64(res) - sc := new(stmtctx.StatementContext) + sc := stmtctx.NewStmtCtx() sc.IgnoreTruncate.Store(true) b, err := datum.ToBool(sc) require.NoError(t, err) @@ -94,7 +94,7 @@ func TestToBool(t *testing.T) { testDatumToBool(t, CreateBinaryJSON(true), 1) testDatumToBool(t, CreateBinaryJSON(false), 1) testDatumToBool(t, CreateBinaryJSON(""), 1) - t1, err := ParseTime(&stmtctx.StatementContext{TimeZone: time.UTC}, "2011-11-10 11:11:11.999999", mysql.TypeTimestamp, 6, nil) + t1, err := ParseTime(stmtctx.NewStmtCtxWithTimeZone(time.UTC), "2011-11-10 11:11:11.999999", mysql.TypeTimestamp, 6, nil) require.NoError(t, err) testDatumToBool(t, t1, 1) @@ -108,7 +108,7 @@ func TestToBool(t *testing.T) { require.NoError(t, err) testDatumToBool(t, v, 1) d := NewDatum(&invalidMockType{}) - sc := new(stmtctx.StatementContext) + sc := stmtctx.NewStmtCtx() sc.IgnoreTruncate.Store(true) _, err = d.ToBool(sc) require.Error(t, err) @@ -116,7 +116,7 @@ func TestToBool(t *testing.T) { func testDatumToInt64(t *testing.T, val interface{}, expect int64) { d := NewDatum(val) - sc := new(stmtctx.StatementContext) + sc := stmtctx.NewStmtCtx() sc.IgnoreTruncate.Store(true) b, err := d.ToInt64(sc) require.NoError(t, err) @@ -135,9 +135,7 @@ func TestToInt64(t *testing.T) { testDatumToInt64(t, Set{Name: "a", Value: 1}, int64(1)) testDatumToInt64(t, CreateBinaryJSON(int64(3)), int64(3)) - t1, err := ParseTime(&stmtctx.StatementContext{ - TimeZone: time.UTC, - }, "2011-11-10 11:11:11.999999", mysql.TypeTimestamp, 0, nil) + t1, err := ParseTime(stmtctx.NewStmtCtxWithTimeZone(time.UTC), "2011-11-10 11:11:11.999999", mysql.TypeTimestamp, 0, nil) require.NoError(t, err) testDatumToInt64(t, t1, int64(20111110111112)) @@ -154,7 +152,7 @@ func TestToInt64(t *testing.T) { func testDatumToUInt32(t *testing.T, val interface{}, expect uint32, hasError bool) { d := NewDatum(val) - sc := new(stmtctx.StatementContext) + sc := stmtctx.NewStmtCtx() sc.IgnoreTruncate.Store(true) ft := NewFieldType(mysql.TypeLong) @@ -206,7 +204,7 @@ func TestConvertToFloat(t *testing.T) { {NewDatum("281.37"), mysql.TypeFloat, "", 281.37, 281.37}, } - sc := new(stmtctx.StatementContext) + sc := stmtctx.NewStmtCtx() sc.IgnoreTruncate.Store(true) for _, testCase := range testCases { converted, err := testCase.d.ConvertTo(sc, NewFieldType(testCase.tp)) @@ -227,7 +225,7 @@ func TestConvertToFloat(t *testing.T) { } func mustParseTime(s string, tp byte, fsp int) Time { - t, err := ParseTime(&stmtctx.StatementContext{TimeZone: time.UTC}, s, tp, fsp, nil) + t, err := ParseTime(stmtctx.NewStmtCtxWithTimeZone(time.UTC), s, tp, fsp, nil) if err != nil { panic("ParseTime fail") } @@ -243,7 +241,7 @@ func mustParseTimeIntoDatum(s string, tp byte, fsp int) (d Datum) { func TestToJSON(t *testing.T) { ft := NewFieldType(mysql.TypeJSON) - sc := new(stmtctx.StatementContext) + sc := stmtctx.NewStmtCtx() tests := []struct { datum Datum expected interface{} @@ -311,7 +309,7 @@ func TestToBytes(t *testing.T) { {NewStringDatum("abc"), []byte("abc")}, {Datum{}, []byte{}}, } - sc := new(stmtctx.StatementContext) + sc := stmtctx.NewStmtCtx() sc.IgnoreTruncate.Store(true) for _, tt := range tests { bin, err := tt.a.ToBytes() @@ -321,7 +319,7 @@ func TestToBytes(t *testing.T) { } func TestComputePlusAndMinus(t *testing.T) { - sc := &stmtctx.StatementContext{TimeZone: time.UTC} + sc := stmtctx.NewStmtCtxWithTimeZone(time.UTC) tests := []struct { a Datum b Datum @@ -360,7 +358,7 @@ func TestCloneDatum(t *testing.T) { raw, } - sc := new(stmtctx.StatementContext) + sc := stmtctx.NewStmtCtx() sc.IgnoreTruncate.Store(true) for _, tt := range tests { tt1 := *tt.Clone() @@ -414,7 +412,7 @@ func TestEstimatedMemUsage(t *testing.T) { } func TestChangeReverseResultByUpperLowerBound(t *testing.T) { - sc := new(stmtctx.StatementContext) + sc := stmtctx.NewStmtCtx() sc.IgnoreTruncate.Store(true) sc.OverflowAsWarning = true // TODO: add more reserve convert tests for each pair of convert type. @@ -539,7 +537,7 @@ func TestStringToMysqlBit(t *testing.T) { {NewStringDatum("b'1'"), []byte{1}}, {NewStringDatum("b'0'"), []byte{0}}, } - sc := new(stmtctx.StatementContext) + sc := stmtctx.NewStmtCtx() sc.IgnoreTruncate.Store(true) tp := NewFieldType(mysql.TypeBit) tp.SetFlen(1) @@ -605,7 +603,7 @@ func TestMarshalDatum(t *testing.T) { func BenchmarkCompareDatum(b *testing.B) { vals, vals1 := prepareCompareDatums() - sc := new(stmtctx.StatementContext) + sc := stmtctx.NewStmtCtx() b.ResetTimer() for i := 0; i < b.N; i++ { for j, v := range vals { @@ -652,7 +650,7 @@ func TestProduceDecWithSpecifiedTp(t *testing.T) { {"99.9999", 6, 3, "100.000", false, true}, {"-99.9999", 6, 3, "-100.000", false, true}, } - sc := new(stmtctx.StatementContext) + sc := stmtctx.NewStmtCtx() for _, tt := range tests { tp := NewFieldTypeBuilder().SetType(mysql.TypeNewDecimal).SetFlen(tt.flen).SetDecimal(tt.frac).BuildP() dec := NewDecFromStringForTest(tt.dec) @@ -698,7 +696,7 @@ func TestNULLNotEqualWithOthers(t *testing.T) { MaxValueDatum(), } nullDatum := NewDatum(nil) - sc := new(stmtctx.StatementContext) + sc := stmtctx.NewStmtCtx() for _, d := range datums { result, err := d.Compare(sc, &nullDatum, collate.GetBinaryCollator()) require.NoError(t, err) diff --git a/types/time.go b/types/time.go index adeb549950a3c..957090306f227 100644 --- a/types/time.go +++ b/types/time.go @@ -524,13 +524,13 @@ func (t Time) RoundFrac(sc *stmtctx.StatementContext, fsp int) (Time, error) { } var nt CoreTime - if t1, err := t.GoTime(sc.TimeZone); err == nil { + if t1, err := t.GoTime(sc.TimeZone()); err == nil { t1 = roundTime(t1, fsp) nt = FromGoTime(t1) } else { // Take the hh:mm:ss part out to avoid handle month or day = 0. hour, minute, second, microsecond := t.Hour(), t.Minute(), t.Second(), t.Microsecond() - t1 := gotime.Date(1, 1, 1, hour, minute, second, microsecond*1000, sc.TimeZone) + t1 := gotime.Date(1, 1, 1, hour, minute, second, microsecond*1000, sc.TimeZone()) t2 := roundTime(t1, fsp) hour, minute, second = t2.Clock() microsecond = t2.Nanosecond() / 1000 @@ -702,9 +702,9 @@ func (t *Time) Check(sc *stmtctx.StatementContext) error { func (t *Time) Sub(sc *stmtctx.StatementContext, t1 *Time) Duration { var duration gotime.Duration if t.Type() == mysql.TypeTimestamp && t1.Type() == mysql.TypeTimestamp { - a, err := t.GoTime(sc.TimeZone) + a, err := t.GoTime(sc.TimeZone()) terror.Log(errors.Trace(err)) - b, err := t1.GoTime(sc.TimeZone) + b, err := t1.GoTime(sc.TimeZone()) terror.Log(errors.Trace(err)) duration = a.Sub(b) } else { @@ -1194,7 +1194,7 @@ func parseDatetime(sc *stmtctx.StatementContext, str string, fsp int, isFloat bo if explicitTz != nil { t1, err = tmp.GoTime(explicitTz) } else { - t1, err = tmp.GoTime(sc.TimeZone) + t1, err = tmp.GoTime(sc.TimeZone()) } if err != nil { return ZeroDatetime, errors.Trace(err) @@ -1231,7 +1231,7 @@ func parseDatetime(sc *stmtctx.StatementContext, str string, fsp int, isFloat bo if explicitTz != nil { t1 = t1.In(explicitTz) } else { - t1 = t1.In(sc.TimeZone) + t1 = t1.In(sc.TimeZone()) } tmp = FromGoTime(t1) } @@ -1501,7 +1501,7 @@ func (d Duration) ToNumber() *MyDecimal { // ConvertToTime converts duration to Time. // Tp is TypeDatetime, TypeTimestamp and TypeDate. func (d Duration) ConvertToTime(sc *stmtctx.StatementContext, tp uint8) (Time, error) { - year, month, day := gotime.Now().In(sc.TimeZone).Date() + year, month, day := gotime.Now().In(sc.TimeZone()).Date() datePart := FromDate(year, int(month), day, 0, 0, 0, 0) mixDateAndDuration(&datePart, d) @@ -1512,7 +1512,7 @@ func (d Duration) ConvertToTime(sc *stmtctx.StatementContext, tp uint8) (Time, e // ConvertToTimeWithTimestamp converts duration to Time by system timestamp. // Tp is TypeDatetime, TypeTimestamp and TypeDate. func (d Duration) ConvertToTimeWithTimestamp(sc *stmtctx.StatementContext, tp uint8, ts gotime.Time) (Time, error) { - year, month, day := ts.In(sc.TimeZone).Date() + year, month, day := ts.In(sc.TimeZone()).Date() datePart := FromDate(year, int(month), day, 0, 0, 0, 0) mixDateAndDuration(&datePart, d) @@ -1833,7 +1833,7 @@ func ParseDuration(sc *stmtctx.StatementContext, str string, fsp int) (Duration, return ZeroDuration, true, ErrTruncatedWrongVal.GenWithStackByArgs("time", str) } - d, err = d.RoundFrac(fsp, sc.TimeZone) + d, err = d.RoundFrac(fsp, sc.TimeZone()) return d, false, err } @@ -2168,7 +2168,7 @@ func checkTimestampType(sc *stmtctx.StatementContext, t CoreTime, explicitTz *go } var checkTime CoreTime - tz := sc.TimeZone + tz := sc.TimeZone() if explicitTz != nil { tz = explicitTz } @@ -3435,7 +3435,7 @@ func DateFSP(date string) (fsp int) { // DateTimeIsOverflow returns if this date is overflow. // See: https://dev.mysql.com/doc/refman/8.0/en/datetime.html func DateTimeIsOverflow(sc *stmtctx.StatementContext, date Time) (bool, error) { - tz := sc.TimeZone + tz := sc.TimeZone() if tz == nil { logutil.BgLogger().Warn("use gotime.local because sc.timezone is nil") tz = gotime.Local diff --git a/types/time_test.go b/types/time_test.go index db74d9642acbe..11390f3d2a29b 100644 --- a/types/time_test.go +++ b/types/time_test.go @@ -192,7 +192,7 @@ func TestTimestamp(t *testing.T) { } for _, test := range table { - v, err := types.ParseTimestamp(&stmtctx.StatementContext{TimeZone: time.UTC}, test.Input) + v, err := types.ParseTimestamp(stmtctx.NewStmtCtxWithTimeZone(time.UTC), test.Input) require.NoError(t, err) require.Equal(t, test.Expect, v.String()) } @@ -203,7 +203,7 @@ func TestTimestamp(t *testing.T) { } for _, test := range errTable { - _, err := types.ParseTimestamp(&stmtctx.StatementContext{TimeZone: time.UTC}, test) + _, err := types.ParseTimestamp(stmtctx.NewStmtCtxWithTimeZone(time.UTC), test) require.Error(t, err) } } @@ -575,7 +575,7 @@ func TestYear(t *testing.T) { } func TestCodec(t *testing.T) { - sc := &stmtctx.StatementContext{TimeZone: time.UTC} + sc := stmtctx.NewStmtCtxWithTimeZone(time.UTC) // MySQL timestamp value doesn't allow month=0 or day=0. _, err := types.ParseTimestamp(sc, "2016-12-00 00:00:00") @@ -680,9 +680,7 @@ func TestParseTimeFromNum(t *testing.T) { require.Equal(t, test.ExpectDateTimeValue, t1.String()) // testtypes.ParseTimestampFromNum - t1, err = types.ParseTimestampFromNum(&stmtctx.StatementContext{ - TimeZone: time.UTC, - }, test.Input) + t1, err = types.ParseTimestampFromNum(stmtctx.NewStmtCtxWithTimeZone(time.UTC), test.Input) if test.ExpectTimeStampError { require.Error(t, err) } else { @@ -709,7 +707,7 @@ func TestToNumber(t *testing.T) { sc.IgnoreZeroInDate = true losAngelesTz, err := time.LoadLocation("America/Los_Angeles") require.NoError(t, err) - sc.TimeZone = losAngelesTz + sc.SetTimeZone(losAngelesTz) tblDateTime := []struct { Input string Fsp int @@ -851,7 +849,7 @@ func TestParseFrac(t *testing.T) { func TestRoundFrac(t *testing.T) { sc := mock.NewContext().GetSessionVars().StmtCtx sc.IgnoreZeroInDate = true - sc.TimeZone = time.UTC + sc.SetTimeZone(time.UTC) tbl := []struct { Input string Fsp int @@ -880,7 +878,7 @@ func TestRoundFrac(t *testing.T) { // test different time zone losAngelesTz, err := time.LoadLocation("America/Los_Angeles") require.NoError(t, err) - sc.TimeZone = losAngelesTz + sc.SetTimeZone(losAngelesTz) tbl = []struct { Input string Fsp int @@ -919,7 +917,7 @@ func TestRoundFrac(t *testing.T) { for _, tt := range tbl { v, _, err := types.ParseDuration(sc, tt.Input, types.MaxFsp) require.NoError(t, err) - nv, err := v.RoundFrac(tt.Fsp, sc.TimeZone) + nv, err := v.RoundFrac(tt.Fsp, sc.TimeZone()) require.NoError(t, err) require.Equal(t, tt.Except, nv.String()) } @@ -944,7 +942,7 @@ func TestConvert(t *testing.T) { sc := mock.NewContext().GetSessionVars().StmtCtx sc.IgnoreZeroInDate = true losAngelesTz, _ := time.LoadLocation("America/Los_Angeles") - sc.TimeZone = losAngelesTz + sc.SetTimeZone(losAngelesTz) tbl := []struct { Input string Fsp int @@ -977,21 +975,21 @@ func TestConvert(t *testing.T) { {"1 11:30:45.999999", 0}, } // test different time zone. - sc.TimeZone = time.UTC + sc.SetTimeZone(time.UTC) for _, tt := range tblDuration { v, _, err := types.ParseDuration(sc, tt.Input, tt.Fsp) require.NoError(t, err) - year, month, day := time.Now().In(sc.TimeZone).Date() - n := time.Date(year, month, day, 0, 0, 0, 0, sc.TimeZone) + year, month, day := time.Now().In(sc.TimeZone()).Date() + n := time.Date(year, month, day, 0, 0, 0, 0, sc.TimeZone()) t1, err := v.ConvertToTime(sc, mysql.TypeDatetime) require.NoError(t, err) - t2, _ := t1.GoTime(sc.TimeZone) + t2, _ := t1.GoTime(sc.TimeZone()) require.Equal(t, v.Duration, t2.Sub(n)) } } func TestCompare(t *testing.T) { - sc := &stmtctx.StatementContext{TimeZone: time.UTC} + sc := stmtctx.NewStmtCtxWithTimeZone(time.UTC) tbl := []struct { Arg1 string Arg2 string @@ -1054,7 +1052,7 @@ func TestDurationClock(t *testing.T) { } for _, tt := range tbl { - d, _, err := types.ParseDuration(&stmtctx.StatementContext{TimeZone: time.UTC}, tt.Input, types.MaxFsp) + d, _, err := types.ParseDuration(stmtctx.NewStmtCtxWithTimeZone(time.UTC), tt.Input, types.MaxFsp) require.NoError(t, err) require.Equal(t, tt.Hour, d.Hour()) require.Equal(t, tt.Minute, d.Minute()) @@ -1165,9 +1163,7 @@ func TestTimeAdd(t *testing.T) { {"2017-08-21", "01:01:01.001", "2017-08-21 01:01:01.001"}, } - sc := &stmtctx.StatementContext{ - TimeZone: time.UTC, - } + sc := stmtctx.NewStmtCtxWithTimeZone(time.UTC) for _, tt := range tbl { v1, err := types.ParseTime(sc, tt.Arg1, mysql.TypeDatetime, types.MaxFsp, nil) require.NoError(t, err) @@ -1256,7 +1252,7 @@ func TestCheckTimestamp(t *testing.T) { } for _, tt := range tests { - validTimestamp := types.CheckTimestampTypeForTest(&stmtctx.StatementContext{TimeZone: tt.tz}, tt.input, nil) + validTimestamp := types.CheckTimestampTypeForTest(stmtctx.NewStmtCtxWithTimeZone(tt.tz), tt.input, nil) if tt.expectRetError { require.Errorf(t, validTimestamp, "For %s %s", tt.input, tt.tz) } else { @@ -1313,7 +1309,7 @@ func TestCheckTimestamp(t *testing.T) { } for _, tt := range tests { - validTimestamp := types.CheckTimestampTypeForTest(&stmtctx.StatementContext{TimeZone: tt.tz}, tt.input, nil) + validTimestamp := types.CheckTimestampTypeForTest(stmtctx.NewStmtCtxWithTimeZone(tt.tz), tt.input, nil) if tt.expectRetError { require.Errorf(t, validTimestamp, "For %s %s", tt.input, tt.tz) } else { @@ -1987,9 +1983,7 @@ func TestTimeSub(t *testing.T) { {"2019-04-12 18:20:00", "2019-04-12 14:00:00", "04:20:00"}, } - sc := &stmtctx.StatementContext{ - TimeZone: time.UTC, - } + sc := stmtctx.NewStmtCtxWithTimeZone(time.UTC) for _, tt := range tbl { v1, err := types.ParseTime(sc, tt.Arg1, mysql.TypeDatetime, types.MaxFsp, nil) require.NoError(t, err) @@ -2022,10 +2016,8 @@ func TestCheckMonthDay(t *testing.T) { {types.FromDate(3200, 2, 29, 0, 0, 0, 0), true}, } - sc := &stmtctx.StatementContext{ - TimeZone: time.UTC, - AllowInvalidDate: false, - } + sc := stmtctx.NewStmtCtxWithTimeZone(time.UTC) + sc.AllowInvalidDate = false for _, tt := range dates { v := types.NewTime(tt.date, mysql.TypeDate, types.DefaultFsp) @@ -2190,7 +2182,7 @@ func TestParseWithTimezone(t *testing.T) { }, } for ith, ca := range cases { - v, err := types.ParseTime(&stmtctx.StatementContext{TimeZone: ca.sysTZ}, ca.lit, mysql.TypeTimestamp, ca.fsp, nil) + v, err := types.ParseTime(stmtctx.NewStmtCtxWithTimeZone(ca.sysTZ), ca.lit, mysql.TypeTimestamp, ca.fsp, nil) require.NoErrorf(t, err, "tidb time parse misbehaved on %d", ith) if err != nil { continue @@ -2223,9 +2215,7 @@ func BenchmarkFormat(b *testing.B) { } func BenchmarkTimeAdd(b *testing.B) { - sc := &stmtctx.StatementContext{ - TimeZone: time.UTC, - } + sc := stmtctx.NewStmtCtxWithTimeZone(time.UTC) arg1, _ := types.ParseTime(sc, "2017-01-18", mysql.TypeDatetime, types.MaxFsp, nil) arg2, _, _ := types.ParseDuration(sc, "12:30:59", types.MaxFsp) for i := 0; i < b.N; i++ { @@ -2237,7 +2227,7 @@ func BenchmarkTimeAdd(b *testing.B) { } func BenchmarkTimeCompare(b *testing.B) { - sc := &stmtctx.StatementContext{TimeZone: time.UTC} + sc := stmtctx.NewStmtCtxWithTimeZone(time.UTC) mustParse := func(str string) types.Time { t, err := types.ParseDatetime(sc, str) if err != nil { @@ -2300,7 +2290,7 @@ func benchmarkDatetimeFormat(b *testing.B, name string, sc *stmtctx.StatementCon } func BenchmarkParseDatetimeFormat(b *testing.B) { - sc := &stmtctx.StatementContext{TimeZone: time.UTC} + sc := stmtctx.NewStmtCtxWithTimeZone(time.UTC) benchmarkDatetimeFormat(b, "datetime without timezone", sc, "2020-10-10T10:10:10") benchmarkDatetimeFormat(b, "datetime with timezone", sc, "2020-10-10T10:10:10Z+08:00") } @@ -2315,7 +2305,7 @@ func benchmarkStrToDate(b *testing.B, name string, sc *stmtctx.StatementContext, } func BenchmarkStrToDate(b *testing.B) { - sc := &stmtctx.StatementContext{TimeZone: time.UTC} + sc := stmtctx.NewStmtCtxWithTimeZone(time.UTC) benchmarkStrToDate(b, "strToDate yyyyMMdd hhmmss ffff", sc, "31/05/2016 12:34:56.1234", "%d/%m/%Y %H:%i:%S.%f") benchmarkStrToDate(b, "strToDate %r ddMMyyyy", sc, "04:13:56 AM 13/05/2019", "%r %d/%c/%Y") benchmarkStrToDate(b, "strToDate %T ddMMyyyy", sc, " 4:13:56 13/05/2019", "%T %d/%c/%Y") diff --git a/util/chunk/chunk_test.go b/util/chunk/chunk_test.go index df7d3728a196d..d6550a8398842 100644 --- a/util/chunk/chunk_test.go +++ b/util/chunk/chunk_test.go @@ -545,7 +545,7 @@ func TestGetDecimalDatum(t *testing.T) { decType := types.NewFieldType(mysql.TypeNewDecimal) decType.SetFlen(4) decType.SetDecimal(2) - sc := new(stmtctx.StatementContext) + sc := stmtctx.NewStmtCtx() decDatum, err := datum.ConvertTo(sc, decType) require.NoError(t, err) diff --git a/util/chunk/mutrow_test.go b/util/chunk/mutrow_test.go index ef9e962f4fb96..a69cfcd75fac3 100644 --- a/util/chunk/mutrow_test.go +++ b/util/chunk/mutrow_test.go @@ -29,7 +29,7 @@ func TestMutRow(t *testing.T) { allTypes := newAllTypes() mutRow := MutRowFromTypes(allTypes) row := mutRow.ToRow() - sc := new(stmtctx.StatementContext) + sc := stmtctx.NewStmtCtx() for i := 0; i < row.Len(); i++ { val := zeroValForType(allTypes[i]) d := row.GetDatum(i, allTypes[i]) diff --git a/util/codec/codec.go b/util/codec/codec.go index 29af9a2f01f80..68d3d46321b0e 100644 --- a/util/codec/codec.go +++ b/util/codec/codec.go @@ -187,8 +187,8 @@ func EncodeMySQLTime(sc *stmtctx.StatementContext, t types.Time, tp byte, b []by if tp == mysql.TypeUnspecified { tp = t.Type() } - if tp == mysql.TypeTimestamp && sc.TimeZone != time.UTC { - err = t.ConvertTimeZone(sc.TimeZone, time.UTC) + if tp == mysql.TypeTimestamp && sc.TimeZone() != time.UTC { + err = t.ConvertTimeZone(sc.TimeZone(), time.UTC) if err != nil { return nil, err } diff --git a/util/codec/codec_test.go b/util/codec/codec_test.go index 7c73ee3910c74..63618af18adee 100644 --- a/util/codec/codec_test.go +++ b/util/codec/codec_test.go @@ -72,7 +72,7 @@ func TestCodecKey(t *testing.T) { types.MakeDatums(uint64(1), uint64(1)), }, } - sc := &stmtctx.StatementContext{TimeZone: time.Local} + sc := stmtctx.NewStmtCtxWithTimeZone(time.Local) for i, datums := range table { comment := fmt.Sprintf("%d %v", i, datums) b, err := EncodeKey(sc, nil, datums.Input...) @@ -214,7 +214,7 @@ func TestCodecKeyCompare(t *testing.T) { -1, }, } - sc := &stmtctx.StatementContext{TimeZone: time.Local} + sc := stmtctx.NewStmtCtxWithTimeZone(time.Local) for _, datums := range table { b1, err := EncodeKey(sc, nil, datums.Left...) require.NoError(t, err) @@ -519,7 +519,7 @@ func TestBytes(t *testing.T) { } func parseTime(t *testing.T, s string) types.Time { - sc := &stmtctx.StatementContext{TimeZone: time.UTC} + sc := stmtctx.NewStmtCtxWithTimeZone(time.UTC) m, err := types.ParseTime(sc, s, mysql.TypeDatetime, types.DefaultFsp, nil) require.NoError(t, err) return m @@ -537,7 +537,7 @@ func TestTime(t *testing.T) { "2011-01-01 00:00:00", "0001-01-01 00:00:00", } - sc := &stmtctx.StatementContext{TimeZone: time.Local} + sc := stmtctx.NewStmtCtxWithTimeZone(time.Local) for _, timeDatum := range tbl { m := types.NewDatum(parseTime(t, timeDatum)) @@ -584,7 +584,7 @@ func TestDuration(t *testing.T) { "00:00:00", "1 11:11:11", } - sc := &stmtctx.StatementContext{TimeZone: time.Local} + sc := stmtctx.NewStmtCtxWithTimeZone(time.Local) for _, duration := range tbl { m := parseDuration(t, duration) @@ -637,7 +637,7 @@ func TestDecimal(t *testing.T) { "-12.340", "-0.1234", } - sc := &stmtctx.StatementContext{TimeZone: time.Local} + sc := stmtctx.NewStmtCtxWithTimeZone(time.Local) for _, decimalNum := range tbl { dec := new(types.MyDecimal) err := dec.FromString([]byte(decimalNum)) @@ -876,7 +876,7 @@ func TestCut(t *testing.T) { types.MakeDatums(types.CreateBinaryJSON(types.Opaque{TypeCode: mysql.TypeString, Buf: []byte("abc")})), }, } - sc := &stmtctx.StatementContext{TimeZone: time.Local} + sc := stmtctx.NewStmtCtxWithTimeZone(time.Local) for i, datums := range table { b, err := EncodeKey(sc, nil, datums.Input...) require.NoErrorf(t, err, "%d %v", i, datums) @@ -933,7 +933,7 @@ func TestCutOneError(t *testing.T) { } func TestSetRawValues(t *testing.T) { - sc := &stmtctx.StatementContext{TimeZone: time.Local} + sc := stmtctx.NewStmtCtxWithTimeZone(time.Local) datums := types.MakeDatums(1, "abc", 1.1, []byte("def")) rowData, err := EncodeValue(sc, nil, datums...) require.NoError(t, err) @@ -951,7 +951,7 @@ func TestSetRawValues(t *testing.T) { } func TestDecodeOneToChunk(t *testing.T) { - sc := &stmtctx.StatementContext{TimeZone: time.Local} + sc := stmtctx.NewStmtCtxWithTimeZone(time.Local) datums, tps := datumsForTest(sc) rowCount := 3 chk := chunkForTest(t, sc, datums, tps, rowCount) @@ -975,7 +975,7 @@ func TestDecodeOneToChunk(t *testing.T) { } func TestHashGroup(t *testing.T) { - sc := &stmtctx.StatementContext{TimeZone: time.Local} + sc := stmtctx.NewStmtCtxWithTimeZone(time.Local) tp := types.NewFieldType(mysql.TypeNewDecimal) tps := []*types.FieldType{tp} chk1 := chunk.New(tps, 3, 3) @@ -1066,7 +1066,7 @@ func datumsForTest(_ *stmtctx.StatementContext) ([]types.Datum, []*types.FieldTy } func chunkForTest(t *testing.T, sc *stmtctx.StatementContext, datums []types.Datum, tps []*types.FieldType, rowCount int) *chunk.Chunk { - decoder := NewDecoder(chunk.New(tps, 32, 32), sc.TimeZone) + decoder := NewDecoder(chunk.New(tps, 32, 32), sc.TimeZone()) for rowIdx := 0; rowIdx < rowCount; rowIdx++ { encoded, err := EncodeValue(sc, nil, datums...) require.NoError(t, err) @@ -1103,7 +1103,7 @@ func TestDecodeRange(t *testing.T) { } func testHashChunkRowEqual(t *testing.T, a, b interface{}, equal bool) { - sc := &stmtctx.StatementContext{TimeZone: time.Local} + sc := stmtctx.NewStmtCtxWithTimeZone(time.Local) buf1 := make([]byte, 1) buf2 := make([]byte, 1) @@ -1146,7 +1146,7 @@ func testHashChunkRowEqual(t *testing.T, a, b interface{}, equal bool) { } func TestHashChunkRow(t *testing.T) { - sc := &stmtctx.StatementContext{TimeZone: time.Local} + sc := stmtctx.NewStmtCtxWithTimeZone(time.Local) buf := make([]byte, 1) datums, tps := datumsForTest(sc) chk := chunkForTest(t, sc, datums, tps, 1) @@ -1234,7 +1234,7 @@ func TestValueSizeOfUnsignedInt(t *testing.T) { } func TestHashChunkColumns(t *testing.T) { - sc := &stmtctx.StatementContext{TimeZone: time.Local} + sc := stmtctx.NewStmtCtxWithTimeZone(time.Local) buf := make([]byte, 1) datums, tps := datumsForTest(sc) chk := chunkForTest(t, sc, datums, tps, 4) diff --git a/util/codec/collation_test.go b/util/codec/collation_test.go index a6358f700b1eb..db73ba78c5dbd 100644 --- a/util/codec/collation_test.go +++ b/util/codec/collation_test.go @@ -45,7 +45,7 @@ func prepareCollationData() (int, *chunk.Chunk, *chunk.Chunk) { } func TestHashGroupKeyCollation(t *testing.T) { - sc := &stmtctx.StatementContext{TimeZone: time.Local} + sc := stmtctx.NewStmtCtxWithTimeZone(time.Local) tp := types.NewFieldType(mysql.TypeString) n, chk1, chk2 := prepareCollationData() @@ -82,7 +82,7 @@ func TestHashGroupKeyCollation(t *testing.T) { } func TestHashChunkRowCollation(t *testing.T) { - sc := &stmtctx.StatementContext{TimeZone: time.Local} + sc := stmtctx.NewStmtCtxWithTimeZone(time.Local) tp := types.NewFieldType(mysql.TypeString) tps := []*types.FieldType{tp} n, chk1, chk2 := prepareCollationData() @@ -124,7 +124,7 @@ func TestHashChunkRowCollation(t *testing.T) { } func TestHashChunkColumnsCollation(t *testing.T) { - sc := &stmtctx.StatementContext{TimeZone: time.Local} + sc := stmtctx.NewStmtCtxWithTimeZone(time.Local) tp := types.NewFieldType(mysql.TypeString) n, chk1, chk2 := prepareCollationData() buf := make([]byte, 1) diff --git a/util/dbutil/common.go b/util/dbutil/common.go index 7a6a2927ac791..44199ac11e6d3 100644 --- a/util/dbutil/common.go +++ b/util/dbutil/common.go @@ -550,7 +550,7 @@ func AnalyzeValuesFromBuckets(valueString string, cols []*model.ColumnInfo) ([]s for i, col := range cols { if IsTimeTypeAndNeedDecode(col.GetType()) { // check if values[i] is already a time string - sc := &stmtctx.StatementContext{TimeZone: time.UTC} + sc := stmtctx.NewStmtCtxWithTimeZone(time.UTC) _, err := types.ParseTime(sc, values[i], col.GetType(), types.MinFsp, nil) if err == nil { continue diff --git a/util/keydecoder/keydecoder_test.go b/util/keydecoder/keydecoder_test.go index 8314949803b4f..3a545a7a9b23c 100644 --- a/util/keydecoder/keydecoder_test.go +++ b/util/keydecoder/keydecoder_test.go @@ -116,7 +116,7 @@ func TestDecodeKey(t *testing.T) { assert.Nil(t, decodedKey.IndexValues) values := types.MakeDatums("abc", 1) - sc := &stmtctx.StatementContext{} + sc := stmtctx.NewStmtCtx() encodedValue, err := codec.EncodeKey(sc, nil, values...) assert.Nil(t, err) key = []byte{ diff --git a/util/memoryusagealarm/memoryusagealarm_test.go b/util/memoryusagealarm/memoryusagealarm_test.go index d5d5257355920..5e6535d3c0c12 100644 --- a/util/memoryusagealarm/memoryusagealarm_test.go +++ b/util/memoryusagealarm/memoryusagealarm_test.go @@ -106,7 +106,7 @@ func genMockProcessInfoList(memConsumeList []int64, startTimeList []time.Time, s tracker.Consume(memConsumeList[i]) var stmtCtxRefCount stmtctx.ReferenceCount = 0 processInfo := util.ProcessInfo{Time: startTimeList[i], - StmtCtx: &stmtctx.StatementContext{}, + StmtCtx: stmtctx.NewStmtCtx(), MemTracker: tracker, StatsInfo: func(interface{}) map[string]uint64 { return map[string]uint64{} diff --git a/util/misc_test.go b/util/misc_test.go index ffdc75c49e0e0..aac12016da84d 100644 --- a/util/misc_test.go +++ b/util/misc_test.go @@ -114,6 +114,8 @@ func TestBasicFuncSyntaxWarn(t *testing.T) { } func TestBasicFuncProcessInfo(t *testing.T) { + sc := stmtctx.NewStmtCtx() + sc.MemTracker = memory.NewTracker(-1, -1) pi := ProcessInfo{ ID: 1, User: "test", @@ -124,9 +126,7 @@ func TestBasicFuncProcessInfo(t *testing.T) { Time: time.Now(), State: 3, Info: "test", - StmtCtx: &stmtctx.StatementContext{ - MemTracker: memory.NewTracker(-1, -1), - }, + StmtCtx: sc, } row := pi.ToRowForShow(false) row2 := pi.ToRowForShow(true) diff --git a/util/mock/context.go b/util/mock/context.go index 3ecf16466b3d6..38e09ec72a6ec 100644 --- a/util/mock/context.go +++ b/util/mock/context.go @@ -465,7 +465,7 @@ func NewContext() *Context { sctx.sessionVars = vars vars.InitChunkSize = 2 vars.MaxChunkSize = 32 - vars.StmtCtx.TimeZone = time.UTC + vars.StmtCtx.SetTimeZone(time.UTC) vars.MemTracker.SetBytesLimit(-1) vars.DiskTracker.SetBytesLimit(-1) vars.StmtCtx.MemTracker, vars.StmtCtx.DiskTracker = memory.NewTracker(-1, -1), disk.NewTracker(-1, -1) diff --git a/util/nocopy/BUILD.bazel b/util/nocopy/BUILD.bazel new file mode 100644 index 0000000000000..fa034f988cd19 --- /dev/null +++ b/util/nocopy/BUILD.bazel @@ -0,0 +1,8 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library") + +go_library( + name = "nocopy", + srcs = ["nocopy.go"], + importpath = "github.com/pingcap/tidb/util/nocopy", + visibility = ["//visibility:public"], +) diff --git a/util/nocopy/nocopy.go b/util/nocopy/nocopy.go new file mode 100644 index 0000000000000..30ee74104b676 --- /dev/null +++ b/util/nocopy/nocopy.go @@ -0,0 +1,24 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package nocopy + +// NoCopy implements sync.Locker to make an object no copy +type NoCopy struct{} + +// Lock is an empty function to implement sync.Locker interface +func (_ *NoCopy) Lock() {} + +// Unlock is an empty function to implement sync.Locker interface +func (_ *NoCopy) Unlock() {} diff --git a/util/rowDecoder/decoder_test.go b/util/rowDecoder/decoder_test.go index 8af317ab375c2..94eab0a17c0a8 100644 --- a/util/rowDecoder/decoder_test.go +++ b/util/rowDecoder/decoder_test.go @@ -53,7 +53,7 @@ func TestRowDecoder(t *testing.T) { tbl := tables.MockTableFromMeta(tblInfo) ctx := mock.NewContext() - sc := &stmtctx.StatementContext{TimeZone: time.UTC} + sc := stmtctx.NewStmtCtxWithTimeZone(time.UTC) decodeColsMap := make(map[int64]decoder.Column, len(cols)) decodeColsMap2 := make(map[int64]decoder.Column, len(cols)) for _, col := range tbl.Cols() { @@ -164,7 +164,7 @@ func TestClusterIndexRowDecoder(t *testing.T) { tbl := tables.MockTableFromMeta(tblInfo) ctx := mock.NewContext() - sc := &stmtctx.StatementContext{TimeZone: time.UTC} + sc := stmtctx.NewStmtCtxWithTimeZone(time.UTC) decodeColsMap := make(map[int64]decoder.Column, len(cols)) for _, col := range tbl.Cols() { tpExpr := decoder.Column{ diff --git a/util/rowcodec/bench_test.go b/util/rowcodec/bench_test.go index c91916f3d0a40..e8b7150f2045c 100644 --- a/util/rowcodec/bench_test.go +++ b/util/rowcodec/bench_test.go @@ -67,7 +67,7 @@ func BenchmarkEncode(b *testing.B) { func BenchmarkEncodeFromOldRow(b *testing.B) { b.ReportAllocs() oldRow := types.MakeDatums(1, "abc", 1.1) - oldRowData, err := tablecodec.EncodeOldRow(new(stmtctx.StatementContext), oldRow, []int64{1, 2, 3}, nil, nil) + oldRowData, err := tablecodec.EncodeOldRow(stmtctx.NewStmtCtx(), oldRow, []int64{1, 2, 3}, nil, nil) if err != nil { b.Fatal(err) } diff --git a/util/rowcodec/encoder.go b/util/rowcodec/encoder.go index 55fad578d612e..95ff35b13aa14 100644 --- a/util/rowcodec/encoder.go +++ b/util/rowcodec/encoder.go @@ -173,8 +173,8 @@ func encodeValueDatum(sc *stmtctx.StatementContext, d *types.Datum, buffer []byt case types.KindMysqlTime: // for mysql datetime, timestamp and date type t := d.GetMysqlTime() - if t.Type() == mysql.TypeTimestamp && sc != nil && sc.TimeZone != time.UTC { - err = t.ConvertTimeZone(sc.TimeZone, time.UTC) + if t.Type() == mysql.TypeTimestamp && sc != nil && sc.TimeZone() != time.UTC { + err = t.ConvertTimeZone(sc.TimeZone(), time.UTC) if err != nil { return } diff --git a/util/rowcodec/rowcodec_test.go b/util/rowcodec/rowcodec_test.go index dacf576ded5ab..933895dc78ea3 100644 --- a/util/rowcodec/rowcodec_test.go +++ b/util/rowcodec/rowcodec_test.go @@ -51,7 +51,7 @@ func TestEncodeLargeSmallReuseBug(t *testing.T) { colFt := types.NewFieldType(mysql.TypeString) largeColID := int64(300) - b, err := encoder.Encode(&stmtctx.StatementContext{}, []int64{largeColID}, []types.Datum{types.NewBytesDatum([]byte(""))}, nil) + b, err := encoder.Encode(stmtctx.NewStmtCtx(), []int64{largeColID}, []types.Datum{types.NewBytesDatum([]byte(""))}, nil) require.NoError(t, err) bDecoder := rowcodec.NewDatumMapDecoder([]rowcodec.ColInfo{ @@ -66,7 +66,7 @@ func TestEncodeLargeSmallReuseBug(t *testing.T) { colFt = types.NewFieldType(mysql.TypeLonglong) smallColID := int64(1) - b, err = encoder.Encode(&stmtctx.StatementContext{}, []int64{smallColID}, []types.Datum{types.NewIntDatum(2)}, nil) + b, err = encoder.Encode(stmtctx.NewStmtCtx(), []int64{smallColID}, []types.Datum{types.NewIntDatum(2)}, nil) require.NoError(t, err) bDecoder = rowcodec.NewDatumMapDecoder([]rowcodec.ColInfo{ @@ -162,17 +162,16 @@ func TestDecodeRowWithHandle(t *testing.T) { // test encode input. var encoder rowcodec.Encoder - sc := new(stmtctx.StatementContext) - sc.TimeZone = time.UTC + sc := stmtctx.NewStmtCtxWithTimeZone(time.UTC) newRow, err := encoder.Encode(sc, colIDs, dts, nil) require.NoError(t, err) // decode to datum map. - mDecoder := rowcodec.NewDatumMapDecoder(cols, sc.TimeZone) + mDecoder := rowcodec.NewDatumMapDecoder(cols, sc.TimeZone()) dm, err := mDecoder.DecodeToDatumMap(newRow, nil) require.NoError(t, err) - dm, err = tablecodec.DecodeHandleToDatumMap(kv.IntHandle(handleValue), []int64{handleID}, handleColFtMap, sc.TimeZone, dm) + dm, err = tablecodec.DecodeHandleToDatumMap(kv.IntHandle(handleValue), []int64{handleID}, handleColFtMap, sc.TimeZone(), dm) require.NoError(t, err) for _, d := range td { @@ -182,7 +181,7 @@ func TestDecodeRowWithHandle(t *testing.T) { } // decode to chunk. - cDecoder := rowcodec.NewChunkDecoder(cols, []int64{-1}, nil, sc.TimeZone) + cDecoder := rowcodec.NewChunkDecoder(cols, []int64{-1}, nil, sc.TimeZone()) chk := chunk.New(fts, 1, 1) err = cDecoder.DecodeToChunk(newRow, kv.IntHandle(handleValue), chk) require.NoError(t, err) @@ -223,8 +222,7 @@ func TestDecodeRowWithHandle(t *testing.T) { func TestEncodeKindNullDatum(t *testing.T) { var encoder rowcodec.Encoder - sc := new(stmtctx.StatementContext) - sc.TimeZone = time.UTC + sc := stmtctx.NewStmtCtxWithTimeZone(time.UTC) colIDs := []int64{1, 2} var nilDt types.Datum @@ -236,7 +234,7 @@ func TestEncodeKindNullDatum(t *testing.T) { require.NoError(t, err) cols := []rowcodec.ColInfo{{ID: 1, Ft: ft}, {ID: 2, Ft: ft}} - cDecoder := rowcodec.NewChunkDecoder(cols, []int64{-1}, nil, sc.TimeZone) + cDecoder := rowcodec.NewChunkDecoder(cols, []int64{-1}, nil, sc.TimeZone()) chk := chunk.New(fts, 1, 1) err = cDecoder.DecodeToChunk(newRow, kv.IntHandle(-1), chk) require.NoError(t, err) @@ -249,8 +247,7 @@ func TestEncodeKindNullDatum(t *testing.T) { func TestDecodeDecimalFspNotMatch(t *testing.T) { var encoder rowcodec.Encoder - sc := new(stmtctx.StatementContext) - sc.TimeZone = time.UTC + sc := stmtctx.NewStmtCtxWithTimeZone(time.UTC) colIDs := []int64{ 1, } @@ -270,7 +267,7 @@ func TestDecodeDecimalFspNotMatch(t *testing.T) { ID: 1, Ft: ft, }) - cDecoder := rowcodec.NewChunkDecoder(cols, []int64{-1}, nil, sc.TimeZone) + cDecoder := rowcodec.NewChunkDecoder(cols, []int64{-1}, nil, sc.TimeZone()) chk := chunk.New(fts, 1, 1) err = cDecoder.DecodeToChunk(newRow, kv.IntHandle(-1), chk) require.NoError(t, err) @@ -295,7 +292,7 @@ func TestTypesNewRowCodec(t *testing.T) { return d } getTime := func(value string) types.Time { - d, err := types.ParseTime(&stmtctx.StatementContext{TimeZone: time.UTC}, value, mysql.TypeTimestamp, 6, nil) + d, err := types.ParseTime(stmtctx.NewStmtCtxWithTimeZone(time.UTC), value, mysql.TypeTimestamp, 6, nil) require.NoError(t, err) return d } @@ -515,13 +512,12 @@ func TestTypesNewRowCodec(t *testing.T) { } // test encode input. - sc := new(stmtctx.StatementContext) - sc.TimeZone = time.UTC + sc := stmtctx.NewStmtCtxWithTimeZone(time.UTC) newRow, err := encoder.Encode(sc, colIDs, dts, nil) require.NoError(t, err) // decode to datum map. - mDecoder := rowcodec.NewDatumMapDecoder(cols, sc.TimeZone) + mDecoder := rowcodec.NewDatumMapDecoder(cols, sc.TimeZone()) dm, err := mDecoder.DecodeToDatumMap(newRow, nil) require.NoError(t, err) @@ -532,7 +528,7 @@ func TestTypesNewRowCodec(t *testing.T) { } // decode to chunk. - cDecoder := rowcodec.NewChunkDecoder(cols, []int64{-1}, nil, sc.TimeZone) + cDecoder := rowcodec.NewChunkDecoder(cols, []int64{-1}, nil, sc.TimeZone()) chk := chunk.New(fts, 1, 1) err = cDecoder.DecodeToChunk(newRow, kv.IntHandle(-1), chk) require.NoError(t, err) @@ -630,13 +626,12 @@ func TestNilAndDefault(t *testing.T) { // test encode input. var encoder rowcodec.Encoder - sc := new(stmtctx.StatementContext) - sc.TimeZone = time.UTC + sc := stmtctx.NewStmtCtxWithTimeZone(time.UTC) newRow, err := encoder.Encode(sc, colIDs, dts, nil) require.NoError(t, err) // decode to datum map. - mDecoder := rowcodec.NewDatumMapDecoder(cols, sc.TimeZone) + mDecoder := rowcodec.NewDatumMapDecoder(cols, sc.TimeZone()) dm, err := mDecoder.DecodeToDatumMap(newRow, nil) require.NoError(t, err) @@ -653,7 +648,7 @@ func TestNilAndDefault(t *testing.T) { // decode to chunk. chk := chunk.New(fts, 1, 1) - cDecoder := rowcodec.NewChunkDecoder(cols, []int64{-1}, ddf, sc.TimeZone) + cDecoder := rowcodec.NewChunkDecoder(cols, []int64{-1}, ddf, sc.TimeZone()) err = cDecoder.DecodeToChunk(newRow, kv.IntHandle(-1), chk) require.NoError(t, err) @@ -669,7 +664,7 @@ func TestNilAndDefault(t *testing.T) { } chk = chunk.New(fts, 1, 1) - cDecoder = rowcodec.NewChunkDecoder(cols, []int64{-1}, nil, sc.TimeZone) + cDecoder = rowcodec.NewChunkDecoder(cols, []int64{-1}, nil, sc.TimeZone()) err = cDecoder.DecodeToChunk(newRow, kv.IntHandle(-1), chk) require.NoError(t, err) @@ -687,7 +682,7 @@ func TestNilAndDefault(t *testing.T) { for i, t := range td { colOffset[t.id] = i } - bDecoder := rowcodec.NewByteDecoder(cols, []int64{-1}, bdf, sc.TimeZone) + bDecoder := rowcodec.NewByteDecoder(cols, []int64{-1}, bdf, sc.TimeZone()) oldRow, err := bDecoder.DecodeToBytes(colOffset, kv.IntHandle(-1), newRow, nil) require.NoError(t, err) @@ -740,12 +735,11 @@ func TestVarintCompatibility(t *testing.T) { // test encode input. var encoder rowcodec.Encoder - sc := new(stmtctx.StatementContext) - sc.TimeZone = time.UTC + sc := stmtctx.NewStmtCtxWithTimeZone(time.UTC) newRow, err := encoder.Encode(sc, colIDs, dts, nil) require.NoError(t, err) - decoder := rowcodec.NewByteDecoder(cols, []int64{-1}, nil, sc.TimeZone) + decoder := rowcodec.NewByteDecoder(cols, []int64{-1}, nil, sc.TimeZone()) // decode to old row bytes. colOffset := make(map[int64]int) for i, t := range td { @@ -768,7 +762,7 @@ func TestCodecUtil(t *testing.T) { tps[i] = types.NewFieldType(mysql.TypeLonglong) } tps[3] = types.NewFieldType(mysql.TypeNull) - sc := new(stmtctx.StatementContext) + sc := stmtctx.NewStmtCtx() oldRow, err := tablecodec.EncodeOldRow(sc, types.MakeDatums(1, 2, 3, nil), colIDs, nil, nil) require.NoError(t, err) @@ -818,7 +812,7 @@ func TestOldRowCodec(t *testing.T) { tps[i] = types.NewFieldType(mysql.TypeLonglong) } tps[3] = types.NewFieldType(mysql.TypeNull) - sc := new(stmtctx.StatementContext) + sc := stmtctx.NewStmtCtx() oldRow, err := tablecodec.EncodeOldRow(sc, types.MakeDatums(1, 2, 3, nil), colIDs, nil, nil) require.NoError(t, err) @@ -850,7 +844,7 @@ func Test65535Bug(t *testing.T) { colIds := []int64{1} tps := make([]*types.FieldType, 1) tps[0] = types.NewFieldType(mysql.TypeString) - sc := new(stmtctx.StatementContext) + sc := stmtctx.NewStmtCtx() text65535 := strings.Repeat("a", 65535) encode := rowcodec.Encoder{} bd, err := encode.Encode(sc, colIds, []types.Datum{types.NewStringDatum(text65535)}, nil) @@ -1201,7 +1195,7 @@ func TestRowChecksum(t *testing.T) { } func TestEncodeDecodeRowWithChecksum(t *testing.T) { - sc := new(stmtctx.StatementContext) + sc := stmtctx.NewStmtCtx() enc := rowcodec.Encoder{} for _, tt := range []struct { @@ -1216,7 +1210,7 @@ func TestEncodeDecodeRowWithChecksum(t *testing.T) { t.Run(tt.name, func(t *testing.T) { raw, err := enc.Encode(sc, nil, nil, nil, tt.checksums...) require.NoError(t, err) - dec := rowcodec.NewDatumMapDecoder([]rowcodec.ColInfo{}, sc.TimeZone) + dec := rowcodec.NewDatumMapDecoder([]rowcodec.ColInfo{}, sc.TimeZone()) _, err = dec.DecodeToDatumMap(raw, nil) require.NoError(t, err) v1, ok1 := enc.GetChecksum() @@ -1251,7 +1245,7 @@ func TestEncodeDecodeRowWithChecksum(t *testing.T) { } t.Run("ReuseDecoder", func(t *testing.T) { - dec := rowcodec.NewDatumMapDecoder([]rowcodec.ColInfo{}, sc.TimeZone) + dec := rowcodec.NewDatumMapDecoder([]rowcodec.ColInfo{}, sc.TimeZone()) raw1, err := enc.Encode(sc, nil, nil, nil) require.NoError(t, err) diff --git a/util/stmtsummary/statement_summary_test.go b/util/stmtsummary/statement_summary_test.go index ac7e1b06e059b..a208ae92ce036 100644 --- a/util/stmtsummary/statement_summary_test.go +++ b/util/stmtsummary/statement_summary_test.go @@ -65,6 +65,11 @@ func TestAddStatement(t *testing.T) { tables := []stmtctx.TableEntry{{DB: "db1", Table: "tb1"}, {DB: "db2", Table: "tb2"}} indexes := []string{"a", "b"} + sc := stmtctx.NewStmtCtx() + sc.StmtType = "Select" + sc.Tables = tables + sc.IndexNames = indexes + // first statement stmtExecInfo1 := generateAnyExecInfo() stmtExecInfo1.ExecDetail.CommitDetail.Mu.PrewriteBackoffTypes = make([]string, 0) @@ -220,11 +225,7 @@ func TestAddStatement(t *testing.T) { }, CalleeAddress: "202", }, }, - StmtCtx: &stmtctx.StatementContext{ - StmtType: "Select", - Tables: tables, - IndexNames: indexes, - }, + StmtCtx: sc, MemMax: 20000, DiskMax: 20000, StartTime: time.Date(2019, 1, 1, 10, 10, 20, 10, time.UTC), @@ -360,11 +361,7 @@ func TestAddStatement(t *testing.T) { CalleeAddress: "302", }, }, - StmtCtx: &stmtctx.StatementContext{ - StmtType: "Select", - Tables: tables, - IndexNames: indexes, - }, + StmtCtx: sc, MemMax: 200, DiskMax: 200, StartTime: time.Date(2019, 1, 1, 10, 10, 0, 10, time.UTC), @@ -585,6 +582,11 @@ func match(t *testing.T, row []types.Datum, expected ...interface{}) { func generateAnyExecInfo() *StmtExecInfo { tables := []stmtctx.TableEntry{{DB: "db1", Table: "tb1"}, {DB: "db2", Table: "tb2"}} indexes := []string{"a"} + sc := stmtctx.NewStmtCtx() + sc.StmtType = "Select" + sc.Tables = tables + sc.IndexNames = indexes + stmtExecInfo := &StmtExecInfo{ SchemaName: "schema_name", OriginalSQL: "original_sql1", @@ -654,11 +656,7 @@ func generateAnyExecInfo() *StmtExecInfo { CalleeAddress: "129", }, }, - StmtCtx: &stmtctx.StatementContext{ - StmtType: "Select", - Tables: tables, - IndexNames: indexes, - }, + StmtCtx: sc, MemMax: 10000, DiskMax: 10000, StartTime: time.Date(2019, 1, 1, 10, 10, 10, 10, time.UTC), diff --git a/util/stmtsummary/v2/record.go b/util/stmtsummary/v2/record.go index 465192c9d1251..a0155babee45c 100644 --- a/util/stmtsummary/v2/record.go +++ b/util/stmtsummary/v2/record.go @@ -585,6 +585,11 @@ func maxSQLLength() uint32 { func GenerateStmtExecInfo4Test(digest string) *stmtsummary.StmtExecInfo { tables := []stmtctx.TableEntry{{DB: "db1", Table: "tb1"}, {DB: "db2", Table: "tb2"}} indexes := []string{"a"} + sc := stmtctx.NewStmtCtx() + sc.StmtType = "Select" + sc.Tables = tables + sc.IndexNames = indexes + stmtExecInfo := &stmtsummary.StmtExecInfo{ SchemaName: "schema_name", OriginalSQL: "original_sql1", @@ -654,11 +659,7 @@ func GenerateStmtExecInfo4Test(digest string) *stmtsummary.StmtExecInfo { CalleeAddress: "129", }, }, - StmtCtx: &stmtctx.StatementContext{ - StmtType: "Select", - Tables: tables, - IndexNames: indexes, - }, + StmtCtx: sc, MemMax: 10000, DiskMax: 10000, StartTime: time.Date(2019, 1, 1, 10, 10, 10, 10, time.UTC), diff --git a/util/util_test.go b/util/util_test.go index 8ed974ce020fd..858fbea11804d 100644 --- a/util/util_test.go +++ b/util/util_test.go @@ -43,7 +43,7 @@ func TestLogFormat(t *testing.T) { StatsInfo: func(interface{}) map[string]uint64 { return nil }, - StmtCtx: &stmtctx.StatementContext{}, + StmtCtx: stmtctx.NewStmtCtx(), RefCountOfStmtCtx: &refCount, MemTracker: mem, RedactSQL: false, From 15745e6bb42fb9959fc678dc17efcfc60da2caab Mon Sep 17 00:00:00 2001 From: Chao Wang Date: Thu, 12 Oct 2023 18:10:49 +0800 Subject: [PATCH 2/2] update --- executor/BUILD.bazel | 1 - executor/executor.go | 3 +-- sessionctx/stmtctx/stmtctx.go | 3 ++- sessionctx/stmtctx/stmtctx_test.go | 2 +- types/context/context.go | 5 ++++- types/context/context_test.go | 2 +- util/nocopy/nocopy.go | 4 ++-- 7 files changed, 11 insertions(+), 9 deletions(-) diff --git a/executor/BUILD.bazel b/executor/BUILD.bazel index 8b6764cddd89c..1f8b8c4eaefa6 100644 --- a/executor/BUILD.bazel +++ b/executor/BUILD.bazel @@ -176,7 +176,6 @@ go_library( "//telemetry", "//tidb-binlog/node", "//types", - "//types/context", "//types/parser_driver", "//util", "//util/admin", diff --git a/executor/executor.go b/executor/executor.go index 568cbc410c6f6..73d1c4a4a553a 100644 --- a/executor/executor.go +++ b/executor/executor.go @@ -18,7 +18,6 @@ import ( "cmp" "context" "fmt" - typectx "github.com/pingcap/tidb/types/context" "math" "runtime/pprof" "slices" @@ -2139,7 +2138,7 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) { sc.AllowInvalidDate = vars.SQLMode.HasAllowInvalidDatesMode() } - sc.UpdateTypeFlags(func(flags typectx.Flags) typectx.Flags { + sc.UpdateTypeFlags(func(flags types.Flags) types.Flags { return flags. WithSkipUTF8Check(vars.SkipUTF8Check). WithSkipSACIICheck(vars.SkipASCIICheck). diff --git a/sessionctx/stmtctx/stmtctx.go b/sessionctx/stmtctx/stmtctx.go index a2fd9df291da8..083fa27db8be3 100644 --- a/sessionctx/stmtctx/stmtctx.go +++ b/sessionctx/stmtctx/stmtctx.go @@ -150,7 +150,7 @@ func (rf *ReferenceCount) UnFreeze() { // It should be reset before executing a statement. type StatementContext struct { // NoCopy indicates that this struct cannot be copied because - // copying this object will make the new TypeCtx referring a wrong `AppendWarnings` func. + // copying this object will make the copied TypeCtx field to refer a wrong `AppendWarnings` func. _ nocopy.NoCopy // TypeCtx is used to indicate how make the type conversation. @@ -468,6 +468,7 @@ func (sc *StatementContext) SetTypeFlags(flags typectx.Flags) { sc.TypeCtx = sc.TypeCtx.WithFlags(flags) } +// UpdateTypeFlags updates the flags of the type context func (sc *StatementContext) UpdateTypeFlags(fn func(typectx.Flags) typectx.Flags) { flags := fn(sc.TypeCtx.Flags()) sc.TypeCtx = sc.TypeCtx.WithFlags(flags) diff --git a/sessionctx/stmtctx/stmtctx_test.go b/sessionctx/stmtctx/stmtctx_test.go index dc45f464c27ee..a9d03bcc42429 100644 --- a/sessionctx/stmtctx/stmtctx_test.go +++ b/sessionctx/stmtctx/stmtctx_test.go @@ -18,7 +18,6 @@ import ( "context" "encoding/json" "fmt" - typectx "github.com/pingcap/tidb/types/context" "math/rand" "reflect" "sort" @@ -31,6 +30,7 @@ import ( "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/testkit" "github.com/pingcap/tidb/types" + typectx "github.com/pingcap/tidb/types/context" "github.com/pingcap/tidb/util/execdetails" "github.com/stretchr/testify/require" "github.com/tikv/client-go/v2/util" diff --git a/types/context/context.go b/types/context/context.go index ddccdb4a6e445..2e57b1690e811 100644 --- a/types/context/context.go +++ b/types/context/context.go @@ -113,6 +113,7 @@ type Context struct { // NewContext creates a new `Context` func NewContext(flags Flags, loc *time.Location, appendWarningFn func(err error)) Context { + intest.Assert(loc != nil && appendWarningFn != nil) return Context{ flags: flags, loc: loc, @@ -144,7 +145,7 @@ func (c *Context) WithLocation(loc *time.Location) Context { func (c *Context) Location() *time.Location { intest.Assert(c.loc) if c.loc == nil { - // this should never happen, just make the code safe here. + // c.loc should always not be nil, just make the code safe here. return time.UTC } return c.loc @@ -152,7 +153,9 @@ func (c *Context) Location() *time.Location { // AppendWarning appends the error to warning. If the inner `appendWarningFn` is nil, do nothing. func (c *Context) AppendWarning(err error) { + intest.Assert(c.appendWarningFn != nil) if fn := c.appendWarningFn; fn != nil { + // appendWarningFn should always not be nil, check fn != nil here to just make code safe. fn(err) } } diff --git a/types/context/context_test.go b/types/context/context_test.go index 4cf057c474893..9a140a2ceb476 100644 --- a/types/context/context_test.go +++ b/types/context/context_test.go @@ -23,7 +23,7 @@ import ( ) func TestWithNewFlags(t *testing.T) { - ctx := NewContext(FlagSkipASCIICheck, time.UTC, nil) + ctx := NewContext(FlagSkipASCIICheck, time.UTC, func(_ error) {}) ctx2 := ctx.WithFlags(FlagSkipUTF8Check) require.Equal(t, FlagSkipASCIICheck, ctx.Flags()) require.Equal(t, FlagSkipUTF8Check, ctx2.Flags()) diff --git a/util/nocopy/nocopy.go b/util/nocopy/nocopy.go index 30ee74104b676..7bfab1988fcb3 100644 --- a/util/nocopy/nocopy.go +++ b/util/nocopy/nocopy.go @@ -18,7 +18,7 @@ package nocopy type NoCopy struct{} // Lock is an empty function to implement sync.Locker interface -func (_ *NoCopy) Lock() {} +func (*NoCopy) Lock() {} // Unlock is an empty function to implement sync.Locker interface -func (_ *NoCopy) Unlock() {} +func (*NoCopy) Unlock() {}