diff --git a/bindinfo/bind_test.go b/bindinfo/bind_test.go index beca8ed520928..a0acb7ba89cf4 100644 --- a/bindinfo/bind_test.go +++ b/bindinfo/bind_test.go @@ -151,7 +151,7 @@ func normalizeWithDefaultDB(c *C, sql, db string) (string, string) { testParser := parser.New() stmt, err := testParser.ParseOneStmt(sql, "", "") c.Assert(err, IsNil) - return parser.NormalizeDigest(utilparser.RestoreWithDefaultDB(stmt, "test")) + return parser.NormalizeDigest(utilparser.RestoreWithDefaultDB(stmt, "test", "")) } func (s *testSuite) TestBindParse(c *C) { diff --git a/bindinfo/handle.go b/bindinfo/handle.go index 727b6dcc106cd..61148d42712b8 100644 --- a/bindinfo/handle.go +++ b/bindinfo/handle.go @@ -624,7 +624,7 @@ func (h *BindHandle) CaptureBaselines() { continue } dbName := utilparser.GetDefaultDB(stmt, bindableStmt.Schema) - normalizedSQL, digest := parser.NormalizeDigest(utilparser.RestoreWithDefaultDB(stmt, dbName)) + normalizedSQL, digest := parser.NormalizeDigest(utilparser.RestoreWithDefaultDB(stmt, dbName, bindableStmt.Query)) if r := h.GetBindRecord(digest, normalizedSQL, dbName); r != nil && r.HasUsingBinding() { continue } diff --git a/cmd/explaintest/r/explain.result b/cmd/explaintest/r/explain.result index 4a279d7dc5058..259a7c49be9ea 100644 --- a/cmd/explaintest/r/explain.result +++ b/cmd/explaintest/r/explain.result @@ -43,4 +43,4 @@ drop view if exists v; create view v as select cast(replace(substring_index(substring_index("",',',1),':',-1),'"','') as CHAR(32)) as event_id; desc v; Field Type Null Key Default Extra -event_id varchar(32) YES NULL +event_id varchar(32) NO NULL diff --git a/cmd/explaintest/r/explain_easy.result b/cmd/explaintest/r/explain_easy.result index 5700618e49ca6..0d7bc00eccddb 100644 --- a/cmd/explaintest/r/explain_easy.result +++ b/cmd/explaintest/r/explain_easy.result @@ -613,14 +613,15 @@ HashJoin 8002.00 root right outer join, equal:[eq(test.t.nb, test.t.nb)] explain format = 'brief' select ifnull(t.a, 1) in (select count(*) from t s , t t1 where s.a = t.a and s.a = t1.a) from t; id estRows task access object operator info Projection 10000.00 root Column#14 -└─Apply 10000.00 root CARTESIAN left outer semi join, other cond:eq(ifnull(test.t.a, 1), Column#13) - ├─TableReader(Build) 10000.00 root data:TableFullScan - │ └─TableFullScan 10000.00 cop[tikv] table:t keep order:false, stats:pseudo - └─HashAgg(Probe) 1.00 root funcs:count(Column#15)->Column#13 +└─Apply 10000.00 root left outer semi join, equal:[eq(Column#15, Column#13)] + ├─Projection(Build) 10000.00 root test.t.a, ifnull(test.t.a, 1)->Column#15 + │ └─TableReader 10000.00 root data:TableFullScan + │ └─TableFullScan 10000.00 cop[tikv] table:t keep order:false, stats:pseudo + └─HashAgg(Probe) 1.00 root funcs:count(Column#17)->Column#13 └─HashJoin 9.99 root inner join, equal:[eq(test.t.a, test.t.a)] - ├─HashAgg(Build) 7.99 root group by:test.t.a, funcs:count(Column#16)->Column#15, funcs:firstrow(test.t.a)->test.t.a + ├─HashAgg(Build) 7.99 root group by:test.t.a, funcs:count(Column#18)->Column#17, funcs:firstrow(test.t.a)->test.t.a │ └─TableReader 7.99 root data:HashAgg - │ └─HashAgg 7.99 cop[tikv] group by:test.t.a, funcs:count(1)->Column#16 + │ └─HashAgg 7.99 cop[tikv] group by:test.t.a, funcs:count(1)->Column#18 │ └─Selection 9.99 cop[tikv] eq(test.t.a, test.t.a), not(isnull(test.t.a)) │ └─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo └─TableReader(Probe) 9.99 root data:Selection diff --git a/distsql/distsql.go b/distsql/distsql.go index 9c8d8e4e63384..83236ce4dc5f4 100644 --- a/distsql/distsql.go +++ b/distsql/distsql.go @@ -30,7 +30,7 @@ import ( // DispatchMPPTasks dispathes all tasks and returns an iterator. func DispatchMPPTasks(ctx context.Context, sctx sessionctx.Context, tasks []*kv.MPPDispatchRequest, fieldTypes []*types.FieldType, planIDs []int, rootID int) (SelectResult, error) { - resp := sctx.GetMPPClient().DispatchMPPTasks(ctx, tasks) + resp := sctx.GetMPPClient().DispatchMPPTasks(ctx, sctx.GetSessionVars().KVVars, tasks) if resp == nil { err := errors.New("client returns nil response") return nil, err diff --git a/executor/analyze_test.go b/executor/analyze_test.go index daf12ba5f2ab5..d535707738173 100644 --- a/executor/analyze_test.go +++ b/executor/analyze_test.go @@ -506,14 +506,20 @@ func (s *testFastAnalyze) TestFastAnalyze(c *C) { c.Assert(result.Rows()[1][5], Equals, "2") c.Assert(result.Rows()[2][5], Equals, "3") */ +} +func (s *testSerialSuite2) TestFastAnalyze4GlobalStats(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("set @@session.tidb_enable_fast_analyze=1") + tk.MustExec("set @@session.tidb_build_stats_concurrency=1") // test fast analyze in dynamic mode tk.MustExec("set @@session.tidb_analyze_version = 2;") tk.MustExec("set @@session.tidb_partition_prune_mode = 'dynamic';") tk.MustExec("drop table if exists t4;") tk.MustExec("create table t4(a int, b int) PARTITION BY HASH(a) PARTITIONS 2;") tk.MustExec("insert into t4 values(1,1),(3,3),(4,4),(2,2),(5,5);") - err = tk.ExecToErr("analyze table t4;") + err := tk.ExecToErr("analyze table t4;") c.Assert(err.Error(), Equals, "Fast analyze hasn't reached General Availability and only support analyze version 1 currently.") } diff --git a/executor/executor.go b/executor/executor.go index 3303436c39a44..201ef2e248f56 100644 --- a/executor/executor.go +++ b/executor/executor.go @@ -28,7 +28,6 @@ import ( "github.com/cznic/mathutil" "github.com/opentracing/opentracing-go" "github.com/pingcap/errors" - "github.com/pingcap/failpoint" "github.com/pingcap/parser/ast" "github.com/pingcap/parser/auth" "github.com/pingcap/parser/model" @@ -1442,12 +1441,6 @@ type UnionExec struct { results []*chunk.Chunk wg sync.WaitGroup initialized bool - mu struct { - *sync.Mutex - maxOpenedChildID int - } - - childInFlightForTest int32 } // unionWorkerResult stores the result for a union worker. @@ -1467,11 +1460,12 @@ func (e *UnionExec) waitAllFinished() { // Open implements the Executor Open interface. func (e *UnionExec) Open(ctx context.Context) error { + if err := e.baseExecutor.Open(ctx); err != nil { + return err + } e.stopFetchData.Store(false) e.initialized = false e.finished = make(chan struct{}) - e.mu.Mutex = &sync.Mutex{} - e.mu.maxOpenedChildID = -1 return nil } @@ -1517,19 +1511,6 @@ func (e *UnionExec) resultPuller(ctx context.Context, workerID int) { e.wg.Done() }() for childID := range e.childIDChan { - e.mu.Lock() - if childID > e.mu.maxOpenedChildID { - e.mu.maxOpenedChildID = childID - } - e.mu.Unlock() - if err := e.children[childID].Open(ctx); err != nil { - result.err = err - e.stopFetchData.Store(true) - e.resultPool <- result - } - failpoint.Inject("issue21441", func() { - atomic.AddInt32(&e.childInFlightForTest, 1) - }) for { if e.stopFetchData.Load().(bool) { return @@ -1544,20 +1525,12 @@ func (e *UnionExec) resultPuller(ctx context.Context, workerID int) { e.resourcePools[workerID] <- result.chk break } - failpoint.Inject("issue21441", func() { - if int(atomic.LoadInt32(&e.childInFlightForTest)) > e.concurrency { - panic("the count of child in flight is larger than e.concurrency unexpectedly") - } - }) e.resultPool <- result if result.err != nil { e.stopFetchData.Store(true) return } } - failpoint.Inject("issue21441", func() { - atomic.AddInt32(&e.childInFlightForTest, -1) - }) } } @@ -1596,15 +1569,7 @@ func (e *UnionExec) Close() error { for range e.childIDChan { } } - // We do not need to acquire the e.mu.Lock since all the resultPuller can be - // promised to exit when reaching here (e.childIDChan been closed). - var firstErr error - for i := 0; i <= e.mu.maxOpenedChildID; i++ { - if err := e.children[i].Close(); err != nil && firstErr == nil { - firstErr = err - } - } - return firstErr + return e.baseExecutor.Close() } // ResetContextOfStmt resets the StmtContext and session variables. diff --git a/executor/executor_test.go b/executor/executor_test.go index 1b90e1f79375a..aff376228b938 100644 --- a/executor/executor_test.go +++ b/executor/executor_test.go @@ -7399,40 +7399,6 @@ func (s *testSuite) TestOOMActionPriority(c *C) { c.Assert(action.GetPriority(), Equals, int64(memory.DefLogPriority)) } -func (s *testSerialSuite) TestIssue21441(c *C) { - err := failpoint.Enable("github.com/pingcap/tidb/executor/issue21441", `return`) - c.Assert(err, IsNil) - defer func() { - err := failpoint.Disable("github.com/pingcap/tidb/executor/issue21441") - c.Assert(err, IsNil) - }() - - tk := testkit.NewTestKit(c, s.store) - tk.MustExec("use test") - tk.MustExec("drop table if exists t") - tk.MustExec("create table t(a int)") - tk.MustExec(`insert into t values(1),(2),(3)`) - tk.Se.GetSessionVars().InitChunkSize = 1 - tk.Se.GetSessionVars().MaxChunkSize = 1 - sql := ` -select a from t union all -select a from t union all -select a from t union all -select a from t union all -select a from t union all -select a from t union all -select a from t union all -select a from t` - tk.MustQuery(sql).Sort().Check(testkit.Rows( - "1", "1", "1", "1", "1", "1", "1", "1", - "2", "2", "2", "2", "2", "2", "2", "2", - "3", "3", "3", "3", "3", "3", "3", "3", - )) - - tk.MustQuery("select a from (" + sql + ") t order by a limit 4").Check(testkit.Rows("1", "1", "1", "1")) - tk.MustQuery("select a from (" + sql + ") t order by a limit 7, 4").Check(testkit.Rows("1", "2", "2", "2")) -} - func (s *testSuite) Test17780(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") diff --git a/executor/mpp_gather.go b/executor/mpp_gather.go index fd6e97d47b840..107d3501c6d5f 100644 --- a/executor/mpp_gather.go +++ b/executor/mpp_gather.go @@ -83,6 +83,7 @@ func (e *MPPGather) appendMPPDispatchReq(pf *plannercore.Fragment, tasks []*kv.M Timeout: 10, SchemaVar: e.is.SchemaMetaVersion(), StartTs: e.startTS, + State: kv.MppTaskReady, } e.mppReqs = append(e.mppReqs, req) } diff --git a/executor/tiflash_test.go b/executor/tiflash_test.go index 1d628375a9c82..20679644dabb6 100644 --- a/executor/tiflash_test.go +++ b/executor/tiflash_test.go @@ -15,18 +15,25 @@ package executor_test import ( "fmt" + "sync" + "sync/atomic" + "time" . "github.com/pingcap/check" + "github.com/pingcap/errors" "github.com/pingcap/failpoint" "github.com/pingcap/kvproto/pkg/metapb" "github.com/pingcap/parser" + "github.com/pingcap/parser/terror" "github.com/pingcap/tidb/domain" + "github.com/pingcap/tidb/executor" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/session" "github.com/pingcap/tidb/store/mockstore" "github.com/pingcap/tidb/store/mockstore/unistore" "github.com/pingcap/tidb/store/tikv/mockstore/cluster" "github.com/pingcap/tidb/util/testkit" + "github.com/pingcap/tidb/util/testleak" ) type tiflashTestSuite struct { @@ -36,6 +43,7 @@ type tiflashTestSuite struct { } func (s *tiflashTestSuite) SetUpSuite(c *C) { + testleak.BeforeTest() var err error s.store, err = mockstore.NewMockStore( mockstore.WithClusterInspector(func(c cluster.Cluster) { @@ -271,3 +279,74 @@ func (s *tiflashTestSuite) TestPartitionTable(c *C) { failpoint.Disable("github.com/pingcap/tidb/executor/checkTotalMPPTasks") failpoint.Disable("github.com/pingcap/tidb/executor/checkUseMPP") } + +func (s *tiflashTestSuite) TestCancelMppTasks(c *C) { + var hang = "github.com/pingcap/tidb/store/mockstore/unistore/mppRecvHang" + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a int not null primary key, b int not null)") + tk.MustExec("alter table t set tiflash replica 1") + tk.MustExec("insert into t values(1,0)") + tk.MustExec("insert into t values(2,0)") + tk.MustExec("insert into t values(3,0)") + tk.MustExec("insert into t values(4,0)") + tb := testGetTableByName(c, tk.Se, "test", "t") + err := domain.GetDomain(tk.Se).DDL().UpdateTableReplicaInfo(tk.Se, tb.Meta().ID, true) + c.Assert(err, IsNil) + tk.MustExec("set @@session.tidb_isolation_read_engines=\"tiflash\"") + tk.MustExec("set @@session.tidb_allow_mpp=ON") + atomic.StoreUint32(&tk.Se.GetSessionVars().Killed, 0) + c.Assert(failpoint.Enable(hang, `return(true)`), IsNil) + wg := &sync.WaitGroup{} + wg.Add(1) + go func() { + defer wg.Done() + err := tk.QueryToErr("select count(*) from t as t1 , t where t1.a = t.a") + c.Assert(err, NotNil) + c.Assert(int(terror.ToSQLError(errors.Cause(err).(*terror.Error)).Code), Equals, int(executor.ErrQueryInterrupted.Code())) + }() + time.Sleep(1 * time.Second) + atomic.StoreUint32(&tk.Se.GetSessionVars().Killed, 1) + wg.Wait() + c.Assert(failpoint.Disable(hang), IsNil) +} + +// all goroutines exit if one goroutine hangs but another return errors +func (s *tiflashTestSuite) TestMppGoroutinesExitFromErrors(c *C) { + defer testleak.AfterTest(c)() + // mock non-root tasks return error + var mppNonRootTaskError = "github.com/pingcap/tidb/store/copr/mppNonRootTaskError" + // mock root tasks hang + var hang = "github.com/pingcap/tidb/store/mockstore/unistore/mppRecvHang" + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a int not null primary key, b int not null)") + tk.MustExec("alter table t set tiflash replica 1") + tb := testGetTableByName(c, tk.Se, "test", "t") + err := domain.GetDomain(tk.Se).DDL().UpdateTableReplicaInfo(tk.Se, tb.Meta().ID, true) + c.Assert(err, IsNil) + tk.MustExec("insert into t values(1,0)") + tk.MustExec("insert into t values(2,0)") + tk.MustExec("insert into t values(3,0)") + tk.MustExec("drop table if exists t1") + tk.MustExec("create table t1(a int not null primary key, b int not null)") + tk.MustExec("alter table t1 set tiflash replica 1") + tb = testGetTableByName(c, tk.Se, "test", "t1") + err = domain.GetDomain(tk.Se).DDL().UpdateTableReplicaInfo(tk.Se, tb.Meta().ID, true) + c.Assert(err, IsNil) + tk.MustExec("insert into t1 values(1,0)") + tk.MustExec("insert into t1 values(2,0)") + tk.MustExec("insert into t1 values(3,0)") + tk.MustExec("set @@session.tidb_isolation_read_engines=\"tiflash\"") + tk.MustExec("set @@session.tidb_allow_mpp=ON") + c.Assert(failpoint.Enable(mppNonRootTaskError, `return(true)`), IsNil) + c.Assert(failpoint.Enable(hang, `return(true)`), IsNil) + + // generate 2 root tasks, one will hang and another will return errors + err = tk.QueryToErr("select count(*) from t as t1 , t where t1.a = t.a") + c.Assert(err, NotNil) + c.Assert(failpoint.Disable(mppNonRootTaskError), IsNil) + c.Assert(failpoint.Disable(hang), IsNil) +} diff --git a/expression/constant_fold.go b/expression/constant_fold.go index 6f84eefeaefb9..c698d4a0a9536 100644 --- a/expression/constant_fold.go +++ b/expression/constant_fold.go @@ -16,6 +16,7 @@ package expression import ( "github.com/pingcap/parser/ast" "github.com/pingcap/parser/mysql" + "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/logutil" "go.uber.org/zap" @@ -207,14 +208,24 @@ func foldConstant(expr Expression) (Expression, bool) { return expr, isDeferredConst } value, err := x.Eval(chunk.Row{}) + retType := x.RetType.Clone() + if !hasNullArg { + // set right not null flag for constant value + switch value.Kind() { + case types.KindNull: + retType.Flag &= ^mysql.NotNullFlag + default: + retType.Flag |= mysql.NotNullFlag + } + } if err != nil { logutil.BgLogger().Debug("fold expression to constant", zap.String("expression", x.ExplainInfo()), zap.Error(err)) return expr, isDeferredConst } if isDeferredConst { - return &Constant{Value: value, RetType: x.RetType, DeferredExpr: x}, true + return &Constant{Value: value, RetType: retType, DeferredExpr: x}, true } - return &Constant{Value: value, RetType: x.RetType}, false + return &Constant{Value: value, RetType: retType}, false case *Constant: if x.ParamMarker != nil { return &Constant{ diff --git a/expression/expr_to_pb.go b/expression/expr_to_pb.go index 223f71df311e8..5dcefc47aab16 100644 --- a/expression/expr_to_pb.go +++ b/expression/expr_to_pb.go @@ -64,9 +64,6 @@ func (pc PbConverter) ExprToPB(expr Expression) *tipb.Expr { if pbExpr == nil { return nil } - if !x.Value.IsNull() { - pbExpr.FieldType.Flag |= uint32(mysql.NotNullFlag) - } return pbExpr case *CorrelatedColumn: return pc.conOrCorColToPBExpr(expr) diff --git a/expression/expr_to_pb_test.go b/expression/expr_to_pb_test.go index de7dc5858b171..4299e637b6df1 100644 --- a/expression/expr_to_pb_test.go +++ b/expression/expr_to_pb_test.go @@ -270,6 +270,7 @@ func (s *testEvaluatorSuite) TestLikeFunc2Pb(c *C) { client := new(mock.Client) retTp := types.NewFieldType(mysql.TypeString) + retTp.Flag |= mysql.NotNullFlag retTp.Charset = charset.CharsetUTF8 retTp.Collate = charset.CollationUTF8 args := []Expression{ diff --git a/expression/typeinfer_test.go b/expression/typeinfer_test.go index f5abbd9f4657f..aac2eabee5632 100644 --- a/expression/typeinfer_test.go +++ b/expression/typeinfer_test.go @@ -156,26 +156,26 @@ func (s *testInferTypeSuite) TestInferType(c *C) { func (s *testInferTypeSuite) createTestCase4Constants() []typeInferTestCase { return []typeInferTestCase{ - {"1", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0}, - {"-1", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 2, 0}, - {"1.23", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, 4, 2}, - {"-1.23", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, 5, 2}, - {"123e5", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, 8, types.UnspecifiedLength}, - {"-123e5", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, 9, types.UnspecifiedLength}, - {"123e-5", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, 7, types.UnspecifiedLength}, - {"-123e-5", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, 8, types.UnspecifiedLength}, + {"1", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag | mysql.NotNullFlag, 1, 0}, + {"-1", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag | mysql.NotNullFlag, 2, 0}, + {"1.23", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag | mysql.NotNullFlag, 4, 2}, + {"-1.23", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag | mysql.NotNullFlag, 5, 2}, + {"123e5", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag | mysql.NotNullFlag, 8, types.UnspecifiedLength}, + {"-123e5", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag | mysql.NotNullFlag, 9, types.UnspecifiedLength}, + {"123e-5", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag | mysql.NotNullFlag, 7, types.UnspecifiedLength}, + {"-123e-5", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag | mysql.NotNullFlag, 8, types.UnspecifiedLength}, {"NULL", mysql.TypeNull, charset.CharsetBin, mysql.BinaryFlag, 0, 0}, - {"TRUE", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag | mysql.IsBooleanFlag, 1, 0}, - {"FALSE", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag | mysql.IsBooleanFlag, 1, 0}, - {"'1234'", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 4, types.UnspecifiedLength}, - {"_utf8'1234'", mysql.TypeVarString, charset.CharsetUTF8, 0, 4, types.UnspecifiedLength}, - {"_binary'1234'", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag, 4, types.UnspecifiedLength}, - {"b'0001'", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag, 1, 0}, - {"b'000100001'", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag, 2, 0}, - {"b'0000000000010000'", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag, 2, 0}, - {"x'10'", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag | mysql.UnsignedFlag, 3, 0}, - {"x'ff10'", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag | mysql.UnsignedFlag, 6, 0}, - {"x'0000000000000000ff10'", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag | mysql.UnsignedFlag, 30, 0}, + {"TRUE", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag | mysql.IsBooleanFlag | mysql.NotNullFlag, 1, 0}, + {"FALSE", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag | mysql.IsBooleanFlag | mysql.NotNullFlag, 1, 0}, + {"'1234'", mysql.TypeVarString, charset.CharsetUTF8MB4, 0 | mysql.NotNullFlag, 4, types.UnspecifiedLength}, + {"_utf8'1234'", mysql.TypeVarString, charset.CharsetUTF8, 0 | mysql.NotNullFlag, 4, types.UnspecifiedLength}, + {"_binary'1234'", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag | mysql.NotNullFlag, 4, types.UnspecifiedLength}, + {"b'0001'", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag | mysql.NotNullFlag, 1, 0}, + {"b'000100001'", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag | mysql.NotNullFlag, 2, 0}, + {"b'0000000000010000'", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag | mysql.NotNullFlag, 2, 0}, + {"x'10'", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag | mysql.UnsignedFlag | mysql.NotNullFlag, 3, 0}, + {"x'ff10'", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag | mysql.UnsignedFlag | mysql.NotNullFlag, 6, 0}, + {"x'0000000000000000ff10'", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag | mysql.UnsignedFlag | mysql.NotNullFlag, 30, 0}, } } @@ -241,9 +241,9 @@ func (s *testInferTypeSuite) createTestCase4StrFuncs() []typeInferTestCase { {"space(c_int_d)", mysql.TypeLongBlob, mysql.DefaultCharset, 0, mysql.MaxBlobWidth, types.UnspecifiedLength}, {"CONCAT(c_binary, c_int_d)", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag, 40, types.UnspecifiedLength}, {"CONCAT(c_bchar, c_int_d)", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 40, types.UnspecifiedLength}, - {"CONCAT('T', 'i', 'DB')", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 4, types.UnspecifiedLength}, + {"CONCAT('T', 'i', 'DB')", mysql.TypeVarString, charset.CharsetUTF8MB4, 0 | mysql.NotNullFlag, 4, types.UnspecifiedLength}, {"CONCAT('T', 'i', 'DB', c_binary)", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag, 24, types.UnspecifiedLength}, - {"CONCAT_WS('-', 'T', 'i', 'DB')", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 6, types.UnspecifiedLength}, + {"CONCAT_WS('-', 'T', 'i', 'DB')", mysql.TypeVarString, charset.CharsetUTF8MB4, 0 | mysql.NotNullFlag, 6, types.UnspecifiedLength}, {"CONCAT_WS(',', 'TiDB', c_binary)", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag, 25, types.UnspecifiedLength}, {"left(c_int_d, c_int_d)", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 20, types.UnspecifiedLength}, {"right(c_int_d, c_int_d)", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 20, types.UnspecifiedLength}, @@ -251,7 +251,7 @@ func (s *testInferTypeSuite) createTestCase4StrFuncs() []typeInferTestCase { {"lower(c_binary)", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag, 20, types.UnspecifiedLength}, {"upper(c_int_d)", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 20, types.UnspecifiedLength}, {"upper(c_binary)", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag, 20, types.UnspecifiedLength}, - {"replace(1234, 2, 55)", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 20, types.UnspecifiedLength}, + {"replace(1234, 2, 55)", mysql.TypeVarString, charset.CharsetUTF8MB4, 0 | mysql.NotNullFlag, 20, types.UnspecifiedLength}, {"replace(c_binary, 1, 2)", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag, 20, types.UnspecifiedLength}, {"to_base64(c_binary)", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 28, types.UnspecifiedLength}, {"substr(c_int_d, c_int_d)", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 20, types.UnspecifiedLength}, @@ -273,8 +273,8 @@ func (s *testInferTypeSuite) createTestCase4StrFuncs() []typeInferTestCase { {"ascii(c_char)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 3, 0}, {"ord(c_char)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 10, 0}, {`c_int_d like 'abc%'`, mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag | mysql.IsBooleanFlag, 1, 0}, - {"tidb_version()", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, len(printer.GetTiDBInfo()), types.UnspecifiedLength}, - {"tidb_is_ddl_owner()", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxIntWidth, 0}, + {"tidb_version()", mysql.TypeVarString, charset.CharsetUTF8MB4, 0 | mysql.NotNullFlag, len(printer.GetTiDBInfo()), types.UnspecifiedLength}, + {"tidb_is_ddl_owner()", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag | mysql.NotNullFlag, mysql.MaxIntWidth, 0}, {"password(c_char)", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, mysql.PWDHashLen + 1, types.UnspecifiedLength}, {"elt(c_int_d, c_char, c_char, c_char)", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 20, types.UnspecifiedLength}, {"elt(c_int_d, c_char, c_char, c_binary)", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag, 20, types.UnspecifiedLength}, @@ -291,11 +291,11 @@ func (s *testInferTypeSuite) createTestCase4StrFuncs() []typeInferTestCase { {"locate(c_binary, c_char, c_int_d)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxIntWidth, 0}, {"locate(c_binary, c_binary, c_int_d)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxIntWidth, 0}, - {"lpad('TiDB', 12, 'go' )", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 48, types.UnspecifiedLength}, + {"lpad('TiDB', 12, 'go' )", mysql.TypeVarString, charset.CharsetUTF8MB4, mysql.NotNullFlag, 48, types.UnspecifiedLength}, {"lpad(c_binary, 12, 'go' )", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag, 12, types.UnspecifiedLength}, {"lpad(c_char, c_int_d, c_binary)", mysql.TypeLongBlob, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxBlobWidth, types.UnspecifiedLength}, {"lpad(c_char, c_int_d, c_char )", mysql.TypeLongBlob, charset.CharsetUTF8MB4, 0, mysql.MaxBlobWidth, types.UnspecifiedLength}, - {"rpad('TiDB', 12, 'go' )", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 48, types.UnspecifiedLength}, + {"rpad('TiDB', 12, 'go' )", mysql.TypeVarString, charset.CharsetUTF8MB4, mysql.NotNullFlag, 48, types.UnspecifiedLength}, {"rpad(c_binary, 12, 'go' )", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag, 12, types.UnspecifiedLength}, {"rpad(c_char, c_int_d, c_binary)", mysql.TypeLongBlob, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxBlobWidth, types.UnspecifiedLength}, {"rpad(c_char, c_int_d, c_char )", mysql.TypeLongBlob, charset.CharsetUTF8MB4, 0, mysql.MaxBlobWidth, types.UnspecifiedLength}, @@ -493,7 +493,7 @@ func (s *testInferTypeSuite) createTestCase4MathFuncs() []typeInferTestCase { {"exp(c_time_d)", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxRealWidth, types.UnspecifiedLength}, {"exp(c_timestamp_d)", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxRealWidth, types.UnspecifiedLength}, {"exp(c_binary)", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxRealWidth, types.UnspecifiedLength}, - {"pi()", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, 8, 6}, + {"pi()", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag | mysql.NotNullFlag, 8, 6}, {"~c_int_d", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag | mysql.UnsignedFlag, mysql.MaxIntWidth, 0}, {"!c_int_d", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag | mysql.IsBooleanFlag, 1, 0}, {"c_int_d & c_int_d", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag | mysql.UnsignedFlag, mysql.MaxIntWidth, 0}, @@ -535,8 +535,8 @@ func (s *testInferTypeSuite) createTestCase4MathFuncs() []typeInferTestCase { {"floor(c_time_d)", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxRealWidth, 0}, {"floor(c_enum)", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxRealWidth, 0}, {"floor(c_text_d)", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxRealWidth, 0}, - {"floor(18446744073709551615)", mysql.TypeLonglong, charset.CharsetBin, mysql.UnsignedFlag | mysql.BinaryFlag, 20, 0}, - {"floor(18446744073709551615.1)", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, 22, 0}, + {"floor(18446744073709551615)", mysql.TypeLonglong, charset.CharsetBin, mysql.UnsignedFlag | mysql.BinaryFlag | mysql.NotNullFlag, 20, 0}, + {"floor(18446744073709551615.1)", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag | mysql.NotNullFlag, 22, 0}, {"ceil(c_int_d)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 11, 0}, {"ceil(c_uint_d)", mysql.TypeLonglong, charset.CharsetBin, mysql.UnsignedFlag | mysql.BinaryFlag, 10, 0}, @@ -553,8 +553,8 @@ func (s *testInferTypeSuite) createTestCase4MathFuncs() []typeInferTestCase { {"ceil(c_time_d)", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxRealWidth, 0}, {"ceil(c_enum)", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxRealWidth, 0}, {"ceil(c_text_d)", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxRealWidth, 0}, - {"ceil(18446744073709551615)", mysql.TypeLonglong, charset.CharsetBin, mysql.UnsignedFlag | mysql.BinaryFlag, 20, 0}, - {"ceil(18446744073709551615.1)", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, 22, 0}, + {"ceil(18446744073709551615)", mysql.TypeLonglong, charset.CharsetBin, mysql.UnsignedFlag | mysql.BinaryFlag | mysql.NotNullFlag, 20, 0}, + {"ceil(18446744073709551615.1)", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag | mysql.NotNullFlag, 22, 0}, {"ceiling(c_int_d)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 11, 0}, {"ceiling(c_decimal)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 6, 0}, @@ -564,8 +564,8 @@ func (s *testInferTypeSuite) createTestCase4MathFuncs() []typeInferTestCase { {"ceiling(c_time_d)", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxRealWidth, 0}, {"ceiling(c_enum)", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxRealWidth, 0}, {"ceiling(c_text_d)", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxRealWidth, 0}, - {"ceiling(18446744073709551615)", mysql.TypeLonglong, charset.CharsetBin, mysql.UnsignedFlag | mysql.BinaryFlag, 20, 0}, - {"ceiling(18446744073709551615.1)", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, 22, 0}, + {"ceiling(18446744073709551615)", mysql.TypeLonglong, charset.CharsetBin, mysql.UnsignedFlag | mysql.BinaryFlag | mysql.NotNullFlag, 20, 0}, + {"ceiling(18446744073709551615.1)", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag | mysql.NotNullFlag, 22, 0}, {"conv(c_char, c_int_d, c_int_d)", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 64, types.UnspecifiedLength}, {"conv(c_int_d, c_int_d, c_int_d)", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 64, types.UnspecifiedLength}, @@ -843,15 +843,15 @@ func (s *testInferTypeSuite) createTestCase4Aggregations() []typeInferTestCase { func (s *testInferTypeSuite) createTestCase4InfoFunc() []typeInferTestCase { return []typeInferTestCase{ - {"last_insert_id( )", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag | mysql.UnsignedFlag, mysql.MaxIntWidth, 0}, + {"last_insert_id( )", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag | mysql.UnsignedFlag | mysql.NotNullFlag, mysql.MaxIntWidth, 0}, {"last_insert_id(c_int_d)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag | mysql.UnsignedFlag, mysql.MaxIntWidth, 0}, {"found_rows()", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag | mysql.UnsignedFlag, mysql.MaxIntWidth, 0}, - {"database()", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 64, types.UnspecifiedLength}, + {"database()", mysql.TypeVarString, charset.CharsetUTF8MB4, mysql.NotNullFlag, 64, types.UnspecifiedLength}, {"current_user()", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 64, types.UnspecifiedLength}, - {"current_role()", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 64, types.UnspecifiedLength}, + {"current_role()", mysql.TypeVarString, charset.CharsetUTF8MB4, mysql.NotNullFlag, 64, types.UnspecifiedLength}, {"user()", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 64, types.UnspecifiedLength}, - {"connection_id()", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag | mysql.UnsignedFlag, mysql.MaxIntWidth, 0}, - {"version()", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 64, types.UnspecifiedLength}, + {"connection_id()", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag | mysql.UnsignedFlag | mysql.NotNullFlag, mysql.MaxIntWidth, 0}, + {"version()", mysql.TypeVarString, charset.CharsetUTF8MB4, mysql.NotNullFlag, 64, types.UnspecifiedLength}, } } @@ -873,8 +873,8 @@ func (s *testInferTypeSuite) createTestCase4EncryptionFuncs() []typeInferTestCas {"md5(c_blob_d )", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 32, types.UnspecifiedLength}, {"md5(c_set )", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 32, types.UnspecifiedLength}, {"md5(c_enum )", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 32, types.UnspecifiedLength}, - {"md5('1234' )", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 32, types.UnspecifiedLength}, - {"md5(1234 )", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 32, types.UnspecifiedLength}, + {"md5('1234' )", mysql.TypeVarString, charset.CharsetUTF8MB4, mysql.NotNullFlag, 32, types.UnspecifiedLength}, + {"md5(1234 )", mysql.TypeVarString, charset.CharsetUTF8MB4, mysql.NotNullFlag, 32, types.UnspecifiedLength}, {"sha(c_int_d )", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 40, types.UnspecifiedLength}, {"sha(c_bigint_d )", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 40, types.UnspecifiedLength}, @@ -892,8 +892,8 @@ func (s *testInferTypeSuite) createTestCase4EncryptionFuncs() []typeInferTestCas {"sha(c_blob_d )", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 40, types.UnspecifiedLength}, {"sha(c_set )", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 40, types.UnspecifiedLength}, {"sha(c_enum )", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 40, types.UnspecifiedLength}, - {"sha('1234' )", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 40, types.UnspecifiedLength}, - {"sha(1234 )", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 40, types.UnspecifiedLength}, + {"sha('1234' )", mysql.TypeVarString, charset.CharsetUTF8MB4, mysql.NotNullFlag, 40, types.UnspecifiedLength}, + {"sha(1234 )", mysql.TypeVarString, charset.CharsetUTF8MB4, mysql.NotNullFlag, 40, types.UnspecifiedLength}, {"sha1(c_int_d )", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 40, types.UnspecifiedLength}, {"sha1(c_bigint_d )", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 40, types.UnspecifiedLength}, @@ -911,8 +911,8 @@ func (s *testInferTypeSuite) createTestCase4EncryptionFuncs() []typeInferTestCas {"sha1(c_blob_d )", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 40, types.UnspecifiedLength}, {"sha1(c_set )", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 40, types.UnspecifiedLength}, {"sha1(c_enum )", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 40, types.UnspecifiedLength}, - {"sha1('1234' )", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 40, types.UnspecifiedLength}, - {"sha1(1234 )", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 40, types.UnspecifiedLength}, + {"sha1('1234' )", mysql.TypeVarString, charset.CharsetUTF8MB4, mysql.NotNullFlag, 40, types.UnspecifiedLength}, + {"sha1(1234 )", mysql.TypeVarString, charset.CharsetUTF8MB4, mysql.NotNullFlag, 40, types.UnspecifiedLength}, {"sha2(c_int_d , 0)", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 128, types.UnspecifiedLength}, {"sha2(c_bigint_d , 0)", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 128, types.UnspecifiedLength}, @@ -930,8 +930,8 @@ func (s *testInferTypeSuite) createTestCase4EncryptionFuncs() []typeInferTestCas {"sha2(c_blob_d , 0)", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 128, types.UnspecifiedLength}, {"sha2(c_set , 0)", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 128, types.UnspecifiedLength}, {"sha2(c_enum , 0)", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 128, types.UnspecifiedLength}, - {"sha2('1234' , 0)", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 128, types.UnspecifiedLength}, - {"sha2(1234 , 0)", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 128, types.UnspecifiedLength}, + {"sha2('1234' , 0)", mysql.TypeVarString, charset.CharsetUTF8MB4, mysql.NotNullFlag, 128, types.UnspecifiedLength}, + {"sha2(1234 , 0)", mysql.TypeVarString, charset.CharsetUTF8MB4, mysql.NotNullFlag, 128, types.UnspecifiedLength}, {"sha2(c_int_d , '256')", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 128, types.UnspecifiedLength}, {"sha2(c_bigint_d , '256')", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 128, types.UnspecifiedLength}, @@ -949,20 +949,19 @@ func (s *testInferTypeSuite) createTestCase4EncryptionFuncs() []typeInferTestCas {"sha2(c_blob_d , '256')", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 128, types.UnspecifiedLength}, {"sha2(c_set , '256')", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 128, types.UnspecifiedLength}, {"sha2(c_enum , '256')", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 128, types.UnspecifiedLength}, - {"sha2('1234' , '256')", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 128, types.UnspecifiedLength}, - {"sha2(1234 , '256')", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 128, types.UnspecifiedLength}, + {"sha2('1234' , '256')", mysql.TypeVarString, charset.CharsetUTF8MB4, mysql.NotNullFlag, 128, types.UnspecifiedLength}, + {"sha2(1234 , '256')", mysql.TypeVarString, charset.CharsetUTF8MB4, mysql.NotNullFlag, 128, types.UnspecifiedLength}, {"AES_ENCRYPT(c_int_d, 'key')", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag, 32, types.UnspecifiedLength}, {"AES_ENCRYPT(c_char, 'key')", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag, 32, types.UnspecifiedLength}, {"AES_ENCRYPT(c_varchar, 'key')", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag, 32, types.UnspecifiedLength}, {"AES_ENCRYPT(c_binary, 'key')", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag, 32, types.UnspecifiedLength}, {"AES_ENCRYPT(c_varbinary, 'key')", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag, 32, types.UnspecifiedLength}, - {"AES_ENCRYPT('', 'key')", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag, 16, types.UnspecifiedLength}, - {"AES_ENCRYPT('111111', 'key')", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag, 16, types.UnspecifiedLength}, - {"AES_ENCRYPT('111111111111111', 'key')", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag, 16, types.UnspecifiedLength}, - {"AES_ENCRYPT('1111111111111111', 'key')", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag, 32, types.UnspecifiedLength}, - {"AES_ENCRYPT('11111111111111111', 'key')", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag, 32, types.UnspecifiedLength}, - + {"AES_ENCRYPT('', 'key')", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag | mysql.NotNullFlag, 16, types.UnspecifiedLength}, + {"AES_ENCRYPT('111111', 'key')", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag | mysql.NotNullFlag, 16, types.UnspecifiedLength}, + {"AES_ENCRYPT('111111111111111', 'key')", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag | mysql.NotNullFlag, 16, types.UnspecifiedLength}, + {"AES_ENCRYPT('1111111111111111', 'key')", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag | mysql.NotNullFlag, 32, types.UnspecifiedLength}, + {"AES_ENCRYPT('11111111111111111', 'key')", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag | mysql.NotNullFlag, 32, types.UnspecifiedLength}, {"AES_DECRYPT('1111111111111111', 'key')", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag, 16, types.UnspecifiedLength}, {"AES_DECRYPT('11111111111111112222222222222222', 'key')", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag, 32, types.UnspecifiedLength}, @@ -972,8 +971,8 @@ func (s *testInferTypeSuite) createTestCase4EncryptionFuncs() []typeInferTestCas {"COMPRESS(c_varchar)", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag, 33, types.UnspecifiedLength}, {"COMPRESS(c_binary)", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag, 33, types.UnspecifiedLength}, {"COMPRESS(c_varbinary)", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag, 33, types.UnspecifiedLength}, - {"COMPRESS('')", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag, 13, types.UnspecifiedLength}, - {"COMPRESS('abcde')", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag, 18, types.UnspecifiedLength}, + {"COMPRESS('')", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag | mysql.NotNullFlag, 13, types.UnspecifiedLength}, + {"COMPRESS('abcde')", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag | mysql.NotNullFlag, 18, types.UnspecifiedLength}, {"UNCOMPRESS(c_int_d)", mysql.TypeLongBlob, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxBlobWidth, types.UnspecifiedLength}, {"UNCOMPRESS(c_char)", mysql.TypeLongBlob, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxBlobWidth, types.UnspecifiedLength}, @@ -982,8 +981,8 @@ func (s *testInferTypeSuite) createTestCase4EncryptionFuncs() []typeInferTestCas {"UNCOMPRESSED_LENGTH(c_varchar)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 10, 0}, {"UNCOMPRESSED_LENGTH(c_int_d)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 10, 0}, - {"RANDOM_BYTES(5)", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag, 1024, types.UnspecifiedLength}, - {"RANDOM_BYTES('123')", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag, 1024, types.UnspecifiedLength}, + {"RANDOM_BYTES(5)", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag | mysql.NotNullFlag, 1024, types.UnspecifiedLength}, + {"RANDOM_BYTES('123')", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag | mysql.NotNullFlag, 1024, types.UnspecifiedLength}, {"RANDOM_BYTES('abc')", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag, 1024, types.UnspecifiedLength}, } } @@ -1145,8 +1144,8 @@ func (s *testInferTypeSuite) createTestCase4OpFuncs() []typeInferTestCase { {"c_time_d is true", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag | mysql.IsBooleanFlag, 1, 0}, {"c_enum is true", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag | mysql.IsBooleanFlag, 1, 0}, {"c_text_d is true", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag | mysql.IsBooleanFlag, 1, 0}, - {"18446 is true", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag | mysql.IsBooleanFlag, 1, 0}, - {"1844674.1 is true", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag | mysql.IsBooleanFlag, 1, 0}, + {"18446 is true", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag | mysql.IsBooleanFlag | mysql.NotNullFlag, 1, 0}, + {"1844674.1 is true", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag | mysql.IsBooleanFlag | mysql.NotNullFlag, 1, 0}, {"c_int_d is false", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag | mysql.IsBooleanFlag, 1, 0}, {"c_decimal is false", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag | mysql.IsBooleanFlag, 1, 0}, @@ -1156,8 +1155,8 @@ func (s *testInferTypeSuite) createTestCase4OpFuncs() []typeInferTestCase { {"c_time_d is false", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag | mysql.IsBooleanFlag, 1, 0}, {"c_enum is false", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag | mysql.IsBooleanFlag, 1, 0}, {"c_text_d is false", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag | mysql.IsBooleanFlag, 1, 0}, - {"18446 is false", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag | mysql.IsBooleanFlag, 1, 0}, - {"1844674.1 is false", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag | mysql.IsBooleanFlag, 1, 0}, + {"18446 is false", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag | mysql.IsBooleanFlag | mysql.NotNullFlag, 1, 0}, + {"1844674.1 is false", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag | mysql.IsBooleanFlag | mysql.NotNullFlag, 1, 0}, } } @@ -1195,14 +1194,14 @@ func (s *testInferTypeSuite) createTestCase4OtherFuncs() []typeInferTestCase { func (s *testInferTypeSuite) createTestCase4TimeFuncs() []typeInferTestCase { return []typeInferTestCase{ - {`time_format('150:02:28', '%r%r%r%r')`, mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 44, types.UnspecifiedLength}, - {`time_format(123456, '%r%r%r%r')`, mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 44, types.UnspecifiedLength}, + {`time_format('150:02:28', '%r%r%r%r')`, mysql.TypeVarString, charset.CharsetUTF8MB4, 0 | mysql.NotNullFlag, 44, types.UnspecifiedLength}, + {`time_format(123456, '%r%r%r%r')`, mysql.TypeVarString, charset.CharsetUTF8MB4, 0 | mysql.NotNullFlag, 44, types.UnspecifiedLength}, {`time_format('bad string', '%r%r%r%r')`, mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 44, types.UnspecifiedLength}, {`time_format(null, '%r%r%r%r')`, mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 44, types.UnspecifiedLength}, {`date_format(null, '%r%r%r%r')`, mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 44, types.UnspecifiedLength}, - {`date_format('2017-06-15', '%r%r%r%r')`, mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 44, types.UnspecifiedLength}, - {`date_format(151113102019.12, '%r%r%r%r')`, mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 44, types.UnspecifiedLength}, + {`date_format('2017-06-15', '%r%r%r%r')`, mysql.TypeVarString, charset.CharsetUTF8MB4, 0 | mysql.NotNullFlag, 44, types.UnspecifiedLength}, + {`date_format(151113102019.12, '%r%r%r%r')`, mysql.TypeVarString, charset.CharsetUTF8MB4, 0 | mysql.NotNullFlag, 44, types.UnspecifiedLength}, {"timestampadd(HOUR, c_int_d, c_timestamp_d)", mysql.TypeString, charset.CharsetUTF8MB4, 0, 19, types.UnspecifiedLength}, {"timestampadd(minute, c_double_d, c_timestamp_d)", mysql.TypeString, charset.CharsetUTF8MB4, 0, 19, types.UnspecifiedLength}, @@ -1666,38 +1665,38 @@ func (s *testInferTypeSuite) createTestCase4TimeFuncs() []typeInferTestCase { {"dayName(c_set )", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 10, types.UnspecifiedLength}, {"dayName(c_enum )", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 10, types.UnspecifiedLength}, - {"now() ", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag, 19, 0}, - {"now(0) ", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag, 19, 0}, - {"now(1) ", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag, 21, 1}, - {"now(2) ", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag, 22, 2}, - {"now(3) ", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag, 23, 3}, - {"now(4) ", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag, 24, 4}, - {"now(5) ", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag, 25, 5}, - {"now(6) ", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag, 26, 6}, + {"now() ", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag | mysql.NotNullFlag, 19, 0}, + {"now(0) ", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag | mysql.NotNullFlag, 19, 0}, + {"now(1) ", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag | mysql.NotNullFlag, 21, 1}, + {"now(2) ", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag | mysql.NotNullFlag, 22, 2}, + {"now(3) ", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag | mysql.NotNullFlag, 23, 3}, + {"now(4) ", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag | mysql.NotNullFlag, 24, 4}, + {"now(5) ", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag | mysql.NotNullFlag, 25, 5}, + {"now(6) ", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag | mysql.NotNullFlag, 26, 6}, {"now(7) ", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag, 26, 6}, - {"utc_timestamp() ", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag, 19, 0}, - {"utc_timestamp(0) ", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag, 19, 0}, - {"utc_timestamp(1) ", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag, 21, 1}, - {"utc_timestamp(2) ", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag, 22, 2}, - {"utc_timestamp(3) ", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag, 23, 3}, - {"utc_timestamp(4) ", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag, 24, 4}, - {"utc_timestamp(5) ", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag, 25, 5}, - {"utc_timestamp(6) ", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag, 26, 6}, + {"utc_timestamp() ", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag | mysql.NotNullFlag, 19, 0}, + {"utc_timestamp(0) ", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag | mysql.NotNullFlag, 19, 0}, + {"utc_timestamp(1) ", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag | mysql.NotNullFlag, 21, 1}, + {"utc_timestamp(2) ", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag | mysql.NotNullFlag, 22, 2}, + {"utc_timestamp(3) ", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag | mysql.NotNullFlag, 23, 3}, + {"utc_timestamp(4) ", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag | mysql.NotNullFlag, 24, 4}, + {"utc_timestamp(5) ", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag | mysql.NotNullFlag, 25, 5}, + {"utc_timestamp(6) ", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag | mysql.NotNullFlag, 26, 6}, {"utc_timestamp(7) ", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag, 26, 6}, - {"utc_time() ", mysql.TypeDuration, charset.CharsetBin, mysql.BinaryFlag, 8, 0}, - {"utc_time(0) ", mysql.TypeDuration, charset.CharsetBin, mysql.BinaryFlag, 8, 0}, - {"utc_time(1) ", mysql.TypeDuration, charset.CharsetBin, mysql.BinaryFlag, 10, 1}, - {"utc_time(2) ", mysql.TypeDuration, charset.CharsetBin, mysql.BinaryFlag, 11, 2}, - {"utc_time(3) ", mysql.TypeDuration, charset.CharsetBin, mysql.BinaryFlag, 12, 3}, - {"utc_time(4) ", mysql.TypeDuration, charset.CharsetBin, mysql.BinaryFlag, 13, 4}, - {"utc_time(5) ", mysql.TypeDuration, charset.CharsetBin, mysql.BinaryFlag, 14, 5}, - {"utc_time(6) ", mysql.TypeDuration, charset.CharsetBin, mysql.BinaryFlag, 15, 6}, + {"utc_time() ", mysql.TypeDuration, charset.CharsetBin, mysql.BinaryFlag | mysql.NotNullFlag, 8, 0}, + {"utc_time(0) ", mysql.TypeDuration, charset.CharsetBin, mysql.BinaryFlag | mysql.NotNullFlag, 8, 0}, + {"utc_time(1) ", mysql.TypeDuration, charset.CharsetBin, mysql.BinaryFlag | mysql.NotNullFlag, 10, 1}, + {"utc_time(2) ", mysql.TypeDuration, charset.CharsetBin, mysql.BinaryFlag | mysql.NotNullFlag, 11, 2}, + {"utc_time(3) ", mysql.TypeDuration, charset.CharsetBin, mysql.BinaryFlag | mysql.NotNullFlag, 12, 3}, + {"utc_time(4) ", mysql.TypeDuration, charset.CharsetBin, mysql.BinaryFlag | mysql.NotNullFlag, 13, 4}, + {"utc_time(5) ", mysql.TypeDuration, charset.CharsetBin, mysql.BinaryFlag | mysql.NotNullFlag, 14, 5}, + {"utc_time(6) ", mysql.TypeDuration, charset.CharsetBin, mysql.BinaryFlag | mysql.NotNullFlag, 15, 6}, {"utc_time(7) ", mysql.TypeDuration, charset.CharsetBin, mysql.BinaryFlag, 15, 6}, - {"utc_date() ", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag, 10, 0}, - {"curdate()", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag, 10, 0}, + {"utc_date() ", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag | mysql.NotNullFlag, 10, 0}, + {"curdate()", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag | mysql.NotNullFlag, 10, 0}, {"sysdate(4)", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag, 19, 0}, {"date(c_int_d )", mysql.TypeDate, charset.CharsetBin, mysql.BinaryFlag, 10, 0}, @@ -1768,9 +1767,9 @@ func (s *testInferTypeSuite) createTestCase4TimeFuncs() []typeInferTestCase { {"quarter(c_set )", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0}, {"quarter(c_enum )", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0}, - {"current_time()", mysql.TypeDuration, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxDurationWidthNoFsp, int(types.MinFsp)}, - {"current_time(0)", mysql.TypeDuration, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxDurationWidthWithFsp, int(types.MinFsp)}, - {"current_time(6)", mysql.TypeDuration, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxDurationWidthWithFsp, int(types.MaxFsp)}, + {"current_time()", mysql.TypeDuration, charset.CharsetBin, mysql.BinaryFlag | mysql.NotNullFlag, mysql.MaxDurationWidthNoFsp, int(types.MinFsp)}, + {"current_time(0)", mysql.TypeDuration, charset.CharsetBin, mysql.BinaryFlag | mysql.NotNullFlag, mysql.MaxDurationWidthWithFsp, int(types.MinFsp)}, + {"current_time(6)", mysql.TypeDuration, charset.CharsetBin, mysql.BinaryFlag | mysql.NotNullFlag, mysql.MaxDurationWidthWithFsp, int(types.MaxFsp)}, {"sec_to_time(c_int_d )", mysql.TypeDuration, charset.CharsetBin, mysql.BinaryFlag, 10, 0}, {"sec_to_time(c_bigint_d )", mysql.TypeDuration, charset.CharsetBin, mysql.BinaryFlag, 10, 0}, @@ -1859,14 +1858,14 @@ func (s *testInferTypeSuite) createTestCase4TimeFuncs() []typeInferTestCase { {"maketime(c_int_d, c_int_d, c_varchar)", mysql.TypeDuration, charset.CharsetBin, mysql.BinaryFlag, 17, 6}, {"maketime(c_int_d, c_int_d, 1.2345)", mysql.TypeDuration, charset.CharsetBin, mysql.BinaryFlag, 15, 4}, - {"get_format(DATE, 'USA')", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 17, types.UnspecifiedLength}, + {"get_format(DATE, 'USA')", mysql.TypeVarString, charset.CharsetUTF8MB4, mysql.NotNullFlag, 17, types.UnspecifiedLength}, {"convert_tz(c_time_d, c_text_d, c_text_d)", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxDatetimeWidthWithFsp, int(types.MaxFsp)}, - {"from_unixtime(20170101.999)", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxDatetimeWidthWithFsp, 3}, - {"from_unixtime(20170101.1234567)", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxDatetimeWidthWithFsp, int(types.MaxFsp)}, - {"from_unixtime('20170101.999')", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxDatetimeWidthWithFsp, int(types.MaxFsp)}, - {"from_unixtime(20170101.123, '%H')", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, -1, types.UnspecifiedLength}, + {"from_unixtime(20170101.999)", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag | mysql.NotNullFlag, mysql.MaxDatetimeWidthWithFsp, 3}, + {"from_unixtime(20170101.1234567)", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag | mysql.NotNullFlag, mysql.MaxDatetimeWidthWithFsp, int(types.MaxFsp)}, + {"from_unixtime('20170101.999')", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag | mysql.NotNullFlag, mysql.MaxDatetimeWidthWithFsp, int(types.MaxFsp)}, + {"from_unixtime(20170101.123, '%H')", mysql.TypeVarString, charset.CharsetUTF8MB4, 0 | mysql.NotNullFlag, -1, types.UnspecifiedLength}, {"extract(day from c_char)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxIntWidth, 0}, {"extract(hour from c_char)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxIntWidth, 0}, @@ -1913,13 +1912,13 @@ func (s *testInferTypeSuite) createTestCase4LikeFuncs() []typeInferTestCase { func (s *testInferTypeSuite) createTestCase4Literals() []typeInferTestCase { return []typeInferTestCase{ - {"time '00:00:00'", mysql.TypeDuration, charset.CharsetBin, mysql.BinaryFlag, 10, 0}, - {"time '00'", mysql.TypeDuration, charset.CharsetBin, mysql.BinaryFlag, 10, 0}, - {"time '3 00:00:00'", mysql.TypeDuration, charset.CharsetBin, mysql.BinaryFlag, 10, 0}, - {"time '3 00:00:00.1234'", mysql.TypeDuration, charset.CharsetBin, mysql.BinaryFlag, 15, 4}, - {"timestamp '2017-01-01 01:01:01'", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxDatetimeWidthNoFsp, 0}, - {"timestamp '2017-01-00000000001 01:01:01.001'", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag, 23, 3}, - {"date '2017-01-01'", mysql.TypeDate, charset.CharsetBin, mysql.BinaryFlag, 10, 0}, + {"time '00:00:00'", mysql.TypeDuration, charset.CharsetBin, mysql.BinaryFlag | mysql.NotNullFlag, 10, 0}, + {"time '00'", mysql.TypeDuration, charset.CharsetBin, mysql.BinaryFlag | mysql.NotNullFlag, 10, 0}, + {"time '3 00:00:00'", mysql.TypeDuration, charset.CharsetBin, mysql.BinaryFlag | mysql.NotNullFlag, 10, 0}, + {"time '3 00:00:00.1234'", mysql.TypeDuration, charset.CharsetBin, mysql.BinaryFlag | mysql.NotNullFlag, 15, 4}, + {"timestamp '2017-01-01 01:01:01'", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag | mysql.NotNullFlag, mysql.MaxDatetimeWidthNoFsp, 0}, + {"timestamp '2017-01-00000000001 01:01:01.001'", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag | mysql.NotNullFlag, 23, 3}, + {"date '2017-01-01'", mysql.TypeDate, charset.CharsetBin, mysql.BinaryFlag | mysql.NotNullFlag, 10, 0}, } } diff --git a/go.mod b/go.mod index b42ff2ea6b384..5a53c7b55cb36 100644 --- a/go.mod +++ b/go.mod @@ -31,7 +31,7 @@ require ( github.com/kr/text v0.2.0 // indirect github.com/ngaut/pools v0.0.0-20180318154953-b7bc8c42aac7 github.com/ngaut/sync2 v0.0.0-20141008032647-7a24ed77b2ef - github.com/ngaut/unistore v0.0.0-20210304095907-0ebafaf44efb + github.com/ngaut/unistore v0.0.0-20210310131351-7ad6a204de87 github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e // indirect github.com/opentracing/basictracer-go v1.0.0 github.com/opentracing/opentracing-go v1.1.0 @@ -80,7 +80,7 @@ require ( gopkg.in/check.v1 v1.0.0-20200902074654-038fdea0a05b // indirect gopkg.in/natefinch/lumberjack.v2 v2.0.0 gopkg.in/yaml.v2 v2.3.0 // indirect - honnef.co/go/tools v0.1.2 // indirect + honnef.co/go/tools v0.1.3 // indirect sourcegraph.com/sourcegraph/appdash v0.0.0-20190731080439-ebfcffb1b5c0 sourcegraph.com/sourcegraph/appdash-data v0.0.0-20151005221446-73f23eafcf67 ) diff --git a/go.sum b/go.sum index 7c152071e31fd..0636203d83bd1 100644 --- a/go.sum +++ b/go.sum @@ -353,8 +353,8 @@ github.com/ngaut/pools v0.0.0-20180318154953-b7bc8c42aac7 h1:7KAv7KMGTTqSmYZtNdc github.com/ngaut/pools v0.0.0-20180318154953-b7bc8c42aac7/go.mod h1:iWMfgwqYW+e8n5lC/jjNEhwcjbRDpl5NT7n2h+4UNcI= github.com/ngaut/sync2 v0.0.0-20141008032647-7a24ed77b2ef h1:K0Fn+DoFqNqktdZtdV3bPQ/0cuYh2H4rkg0tytX/07k= github.com/ngaut/sync2 v0.0.0-20141008032647-7a24ed77b2ef/go.mod h1:7WjlapSfwQyo6LNmIvEWzsW1hbBQfpUO4JWnuQRmva8= -github.com/ngaut/unistore v0.0.0-20210304095907-0ebafaf44efb h1:2rGvEhflp/uK1l1rNUmoHA4CiHpbddHGxg52H71Fke8= -github.com/ngaut/unistore v0.0.0-20210304095907-0ebafaf44efb/go.mod h1:ZR3NH+HzqfiYetwdoAivApnIy8iefPZHTMLfrFNm8g4= +github.com/ngaut/unistore v0.0.0-20210310131351-7ad6a204de87 h1:lVRrhmqIT2zMbmoalrgxQLwWzFd3VtFaaWy0fnMwPro= +github.com/ngaut/unistore v0.0.0-20210310131351-7ad6a204de87/go.mod h1:ZR3NH+HzqfiYetwdoAivApnIy8iefPZHTMLfrFNm8g4= github.com/nicksnyder/go-i18n v1.10.0/go.mod h1:HrK7VCrbOvQoUAQ7Vpy7i87N7JZZZ7R2xBGjv0j365Q= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= @@ -466,7 +466,6 @@ github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/sasha-s/go-deadlock v0.2.0/go.mod h1:StQn567HiB1fF2yJ44N9au7wOhrPS3iZqiDbRupzT10= github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0= -github.com/sergi/go-diff v1.0.1-0.20180205163309-da645544ed44 h1:tB9NOR21++IjLyVx3/PCPhWMwqGNCMQEH96A6dMZ/gc= github.com/sergi/go-diff v1.0.1-0.20180205163309-da645544ed44/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo= github.com/shirou/gopsutil v2.19.10+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA= github.com/shirou/gopsutil v3.20.12+incompatible h1:6VEGkOXP/eP4o2Ilk8cSsX0PhOEfX6leqAnD+urrp9M= @@ -845,9 +844,10 @@ honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWh honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +honnef.co/go/tools v0.0.1-2019.2.3 h1:3JgtbtFHMiCmsznwGVTUWbgGov+pVqnlf1dEJTNAXeM= honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= -honnef.co/go/tools v0.1.2 h1:SMdYLJl312RXuxXziCCHhRsp/tvct9cGKey0yv95tZM= -honnef.co/go/tools v0.1.2/go.mod h1:NgwopIslSNH47DimFoV78dnkksY2EFtX0ajyb3K/las= +honnef.co/go/tools v0.1.3 h1:qTakTkI6ni6LFD5sBwwsdSO+AQqbSIxOauHTTQKZ/7o= +honnef.co/go/tools v0.1.3/go.mod h1:NgwopIslSNH47DimFoV78dnkksY2EFtX0ajyb3K/las= k8s.io/klog v1.0.0/go.mod h1:4Bi6QPql/J/LkTDqv7R/cd3hPo4k2DG6Ptcz060Ez5I= rsc.io/binaryregexp v0.2.0/go.mod h1:qTv7/COck+e2FymRvadv62gMdZztPaShugOCi3I+8D8= sigs.k8s.io/yaml v1.1.0 h1:4A07+ZFc2wgJwo8YNlQpr1rVlgUDlxXHhPJciaPY5gs= diff --git a/kv/mpp.go b/kv/mpp.go index f275346e2d477..df38894d0c5d1 100644 --- a/kv/mpp.go +++ b/kv/mpp.go @@ -45,6 +45,20 @@ func (t *MPPTask) ToPB() *mpp.TaskMeta { return meta } +//MppTaskStates denotes the state of mpp tasks +type MppTaskStates uint8 + +const ( + // MppTaskReady means the task is ready + MppTaskReady MppTaskStates = iota + // MppTaskRunning means the task is running + MppTaskRunning + // MppTaskCancelled means the task is cancelled + MppTaskCancelled + // MppTaskDone means the task is done + MppTaskDone +) + // MPPDispatchRequest stands for a dispatching task. type MPPDispatchRequest struct { Data []byte // data encodes the dag coprocessor request. @@ -55,6 +69,7 @@ type MPPDispatchRequest struct { SchemaVar int64 StartTs uint64 ID int64 // identify a single task + State MppTaskStates } // MPPClient accepts and processes mpp requests. @@ -64,7 +79,7 @@ type MPPClient interface { ConstructMPPTasks(context.Context, *MPPBuildTasksRequest) ([]MPPTaskMeta, error) // DispatchMPPTasks dispatches ALL mpp requests at once, and returns an iterator that transfers the data. - DispatchMPPTasks(context.Context, []*MPPDispatchRequest) Response + DispatchMPPTasks(context.Context, *Variables, []*MPPDispatchRequest) Response } // MPPBuildTasksRequest request the stores allocation for a mpp plan fragment. diff --git a/planner/core/expression_rewriter.go b/planner/core/expression_rewriter.go index e6a6372e80a29..f10134bb849c2 100644 --- a/planner/core/expression_rewriter.go +++ b/planner/core/expression_rewriter.go @@ -1016,8 +1016,16 @@ func (er *expressionRewriter) Leave(originInNode ast.Node) (retNode ast.Node, ok case *ast.AggregateFuncExpr, *ast.ColumnNameExpr, *ast.ParenthesesExpr, *ast.WhenClause, *ast.SubqueryExpr, *ast.ExistsSubqueryExpr, *ast.CompareSubqueryExpr, *ast.ValuesExpr, *ast.WindowFuncExpr, *ast.TableNameExpr: case *driver.ValueExpr: - v.Datum.SetValue(v.Datum.GetValue(), &v.Type) - value := &expression.Constant{Value: v.Datum, RetType: &v.Type} + // set right not null flag for constant value + retType := v.Type.Clone() + switch v.Datum.Kind() { + case types.KindNull: + retType.Flag &= ^mysql.NotNullFlag + default: + retType.Flag |= mysql.NotNullFlag + } + v.Datum.SetValue(v.Datum.GetValue(), retType) + value := &expression.Constant{Value: v.Datum, RetType: retType} er.ctxStackAppend(value, types.EmptyName) case *driver.ParamMarkerExpr: var value expression.Expression diff --git a/planner/core/planbuilder.go b/planner/core/planbuilder.go index 5b9d61d0a64ea..7929df815721a 100644 --- a/planner/core/planbuilder.go +++ b/planner/core/planbuilder.go @@ -756,12 +756,12 @@ func (b *PlanBuilder) buildSet(ctx context.Context, v *ast.SetStmt) (Plan, error func (b *PlanBuilder) buildDropBindPlan(v *ast.DropBindingStmt) (Plan, error) { p := &SQLBindPlan{ SQLBindOp: OpSQLBindDrop, - NormdOrigSQL: parser.Normalize(utilparser.RestoreWithDefaultDB(v.OriginNode, b.ctx.GetSessionVars().CurrentDB)), + NormdOrigSQL: parser.Normalize(utilparser.RestoreWithDefaultDB(v.OriginNode, b.ctx.GetSessionVars().CurrentDB, v.OriginNode.Text())), IsGlobal: v.GlobalScope, Db: utilparser.GetDefaultDB(v.OriginNode, b.ctx.GetSessionVars().CurrentDB), } if v.HintedNode != nil { - p.BindSQL = utilparser.RestoreWithDefaultDB(v.HintedNode, b.ctx.GetSessionVars().CurrentDB) + p.BindSQL = utilparser.RestoreWithDefaultDB(v.HintedNode, b.ctx.GetSessionVars().CurrentDB, v.HintedNode.Text()) } b.visitInfo = appendVisitInfo(b.visitInfo, mysql.SuperPriv, "", "", "", nil) return p, nil @@ -800,8 +800,8 @@ func (b *PlanBuilder) buildCreateBindPlan(v *ast.CreateBindingStmt) (Plan, error p := &SQLBindPlan{ SQLBindOp: OpSQLBindCreate, - NormdOrigSQL: parser.Normalize(utilparser.RestoreWithDefaultDB(v.OriginNode, b.ctx.GetSessionVars().CurrentDB)), - BindSQL: utilparser.RestoreWithDefaultDB(v.HintedNode, b.ctx.GetSessionVars().CurrentDB), + NormdOrigSQL: parser.Normalize(utilparser.RestoreWithDefaultDB(v.OriginNode, b.ctx.GetSessionVars().CurrentDB, v.OriginNode.Text())), + BindSQL: utilparser.RestoreWithDefaultDB(v.HintedNode, b.ctx.GetSessionVars().CurrentDB, v.HintedNode.Text()), IsGlobal: v.GlobalScope, BindStmt: v.HintedNode, Db: utilparser.GetDefaultDB(v.OriginNode, b.ctx.GetSessionVars().CurrentDB), diff --git a/planner/core/preprocess.go b/planner/core/preprocess.go index 628658515d413..7c00903139efc 100644 --- a/planner/core/preprocess.go +++ b/planner/core/preprocess.go @@ -315,8 +315,8 @@ func (p *preprocessor) checkBindGrammar(originNode, hintedNode ast.StmtNode, def return } } - originSQL := parser.Normalize(utilparser.RestoreWithDefaultDB(originNode, defaultDB)) - hintedSQL := parser.Normalize(utilparser.RestoreWithDefaultDB(hintedNode, defaultDB)) + originSQL := parser.Normalize(utilparser.RestoreWithDefaultDB(originNode, defaultDB, originNode.Text())) + hintedSQL := parser.Normalize(utilparser.RestoreWithDefaultDB(hintedNode, defaultDB, hintedNode.Text())) if originSQL != hintedSQL { p.err = errors.Errorf("hinted sql and origin sql don't match when hinted sql erase the hint info, after erase hint info, originSQL:%s, hintedSQL:%s", originSQL, hintedSQL) } diff --git a/planner/optimize.go b/planner/optimize.go index 3985b0ef7d83a..b8bac7a8c2cd8 100644 --- a/planner/optimize.go +++ b/planner/optimize.go @@ -290,7 +290,7 @@ func extractSelectAndNormalizeDigest(stmtNode ast.StmtNode, specifiledDB string) } switch x.Stmt.(type) { case *ast.SelectStmt, *ast.DeleteStmt, *ast.UpdateStmt, *ast.InsertStmt: - normalizeSQL := parser.Normalize(utilparser.RestoreWithDefaultDB(x.Stmt, specifiledDB)) + normalizeSQL := parser.Normalize(utilparser.RestoreWithDefaultDB(x.Stmt, specifiledDB, x.Text())) normalizeSQL = plannercore.EraseLastSemicolonInSQL(normalizeSQL) hash := parser.DigestNormalized(normalizeSQL) return x.Stmt, normalizeSQL, hash, nil @@ -298,7 +298,7 @@ func extractSelectAndNormalizeDigest(stmtNode ast.StmtNode, specifiledDB string) plannercore.EraseLastSemicolon(x) var normalizeExplainSQL string if specifiledDB != "" { - normalizeExplainSQL = parser.Normalize(utilparser.RestoreWithDefaultDB(x, specifiledDB)) + normalizeExplainSQL = parser.Normalize(utilparser.RestoreWithDefaultDB(x, specifiledDB, x.Text())) } else { normalizeExplainSQL = parser.Normalize(x.Text()) } @@ -321,7 +321,7 @@ func extractSelectAndNormalizeDigest(stmtNode ast.StmtNode, specifiledDB string) if len(x.Text()) == 0 { return x, "", "", nil } - normalizedSQL, hash := parser.NormalizeDigest(utilparser.RestoreWithDefaultDB(x, specifiledDB)) + normalizedSQL, hash := parser.NormalizeDigest(utilparser.RestoreWithDefaultDB(x, specifiledDB, x.Text())) return x, normalizedSQL, hash, nil } return nil, "", "", nil diff --git a/server/conn_test.go b/server/conn_test.go index ee9d79e10d61c..2aab4e0d3733e 100644 --- a/server/conn_test.go +++ b/server/conn_test.go @@ -778,7 +778,11 @@ func (ts *ConnTestSuite) TestTiFlashFallback(c *C) { c.Assert(cc.handleQuery(ctx, "select * from t t1 join t t2 on t1.a = t2.a"), NotNil) c.Assert(failpoint.Disable("github.com/pingcap/tidb/server/secondNextErr"), IsNil) - // TODO: simple TiFlash query (unary + non-streaming) + // simple TiFlash query (unary + non-streaming) + tk.MustExec("set @@tidb_allow_batch_cop=0; set @@tidb_allow_mpp=0;") + c.Assert(failpoint.Enable("github.com/pingcap/tidb/store/tikv/tikvStoreSendReqResult", "return(\"requestTiFlashError\")"), IsNil) + testFallbackWork(c, tk, cc, "select sum(a) from t") + c.Assert(failpoint.Disable("github.com/pingcap/tidb/store/tikv/tikvStoreSendReqResult"), IsNil) // TiFlash query based on batch cop (batch + streaming) tk.MustExec("set @@tidb_allow_batch_cop=1; set @@tidb_allow_mpp=0;") diff --git a/session/bootstrap.go b/session/bootstrap.go index 7672d3d6c8c10..ab43ce0a0e98b 100644 --- a/session/bootstrap.go +++ b/session/bootstrap.go @@ -1389,7 +1389,7 @@ func updateBindInfo(iter *chunk.Iterator4Chunk, p *parser.Parser, bindMap map[st if err != nil { logutil.BgLogger().Fatal("updateBindInfo error", zap.Error(err)) } - originWithDB := parser.Normalize(utilparser.RestoreWithDefaultDB(stmt, db)) + originWithDB := parser.Normalize(utilparser.RestoreWithDefaultDB(stmt, db, bind)) if _, ok := bindMap[originWithDB]; ok { // The results are sorted in descending order of time. // And in the following cases, duplicate originWithDB may occur @@ -1400,7 +1400,7 @@ func updateBindInfo(iter *chunk.Iterator4Chunk, p *parser.Parser, bindMap map[st continue } bindMap[originWithDB] = bindInfo{ - bindSQL: utilparser.RestoreWithDefaultDB(stmt, db), + bindSQL: utilparser.RestoreWithDefaultDB(stmt, db, bind), status: row.GetString(2), createTime: row.GetTime(3), charset: charset, diff --git a/statistics/cmsketch_test.go b/statistics/cmsketch_test.go index b3320bb7bf5b3..38fe81c79c60f 100644 --- a/statistics/cmsketch_test.go +++ b/statistics/cmsketch_test.go @@ -17,7 +17,6 @@ import ( "fmt" "math" "math/rand" - "strconv" "time" . "github.com/pingcap/check" @@ -304,89 +303,3 @@ func (s *testStatisticsSuite) TestCMSketchCodingTopN(c *C) { // do not panic DecodeCMSketchAndTopN([]byte{}, rows) } - -func (s *testStatisticsSuite) TestMergeTopN(c *C) { - tests := []struct { - topnNum int - n int - maxTopNVal int - maxTopNCnt int - }{ - { - topnNum: 10, - n: 5, - maxTopNVal: 50, - maxTopNCnt: 100, - }, - { - topnNum: 1, - n: 5, - maxTopNVal: 50, - maxTopNCnt: 100, - }, - { - topnNum: 5, - n: 5, - maxTopNVal: 5, - maxTopNCnt: 100, - }, - { - topnNum: 5, - n: 5, - maxTopNVal: 10, - maxTopNCnt: 100, - }, - } - for _, t := range tests { - topnNum, n := t.topnNum, t.n - maxTopNVal, maxTopNCnt := t.maxTopNVal, t.maxTopNCnt - - // the number of maxTopNVal should be bigger than n. - ok := maxTopNVal >= n - c.Assert(ok, Equals, true) - - topNs := make([]*TopN, 0, topnNum) - res := make(map[int]uint64) - rand.Seed(time.Now().Unix()) - for i := 0; i < topnNum; i++ { - topN := NewTopN(n) - occur := make(map[int]bool) - for j := 0; j < n; j++ { - // The range of numbers in the topn structure is in [0, maxTopNVal) - // But there cannot be repeated occurrences of value in a topN structure. - randNum := rand.Intn(maxTopNVal) - for occur[randNum] { - randNum = rand.Intn(maxTopNVal) - } - occur[randNum] = true - tString := []byte(fmt.Sprintf("%d", randNum)) - // The range of the number of occurrences in the topn structure is in [0, maxTopNCnt) - randCnt := uint64(rand.Intn(maxTopNCnt)) - res[randNum] += randCnt - topNMeta := TopNMeta{tString, randCnt} - topN.TopN = append(topN.TopN, topNMeta) - } - topNs = append(topNs, topN) - } - topN, remainTopN := MergeTopN(topNs, uint32(n)) - cnt := len(topN.TopN) - var minTopNCnt uint64 - for _, topNMeta := range topN.TopN { - val, err := strconv.Atoi(string(topNMeta.Encoded)) - c.Assert(err, IsNil) - c.Assert(topNMeta.Count, Equals, res[val]) - minTopNCnt = topNMeta.Count - } - if remainTopN != nil { - cnt += len(remainTopN) - for _, remainTopNMeta := range remainTopN { - val, err := strconv.Atoi(string(remainTopNMeta.Encoded)) - c.Assert(err, IsNil) - c.Assert(remainTopNMeta.Count, Equals, res[val]) - ok = minTopNCnt > remainTopNMeta.Count - c.Assert(ok, Equals, true) - } - } - c.Assert(cnt, Equals, len(res)) - } -} diff --git a/statistics/handle/handle.go b/statistics/handle/handle.go index 048ea4fc8e318..6e8b3b3d3f607 100644 --- a/statistics/handle/handle.go +++ b/statistics/handle/handle.go @@ -327,9 +327,9 @@ func (h *Handle) MergePartitionStats2GlobalStats(sc sessionctx.Context, opts map return } globalTableInfo := globalTable.Meta() - partitionNum := globalTableInfo.Partition.Num + partitionNum := len(globalTableInfo.Partition.Definitions) partitionIDs := make([]int64, 0, partitionNum) - for i := uint64(0); i < partitionNum; i++ { + for i := 0; i < partitionNum; i++ { partitionIDs = append(partitionIDs, globalTableInfo.Partition.Definitions[i].ID) } @@ -420,7 +420,7 @@ func (h *Handle) MergePartitionStats2GlobalStats(sc sessionctx.Context, opts map for i := 0; i < globalStats.Num; i++ { // Merge CMSketch globalStats.Cms[i] = allCms[i][0].Copy() - for j := uint64(1); j < partitionNum; j++ { + for j := 1; j < partitionNum; j++ { err = globalStats.Cms[i].MergeCMSketch(allCms[i][j]) if err != nil { return @@ -444,7 +444,7 @@ func (h *Handle) MergePartitionStats2GlobalStats(sc sessionctx.Context, opts map // For the column stats, we should merge the FMSketch first. And use the FMSketch to calculate the new NDV. // merge FMSketch globalStats.Fms[i] = allFms[i][0].Copy() - for j := uint64(1); j < partitionNum; j++ { + for j := 1; j < partitionNum; j++ { globalStats.Fms[i].MergeFMSketch(allFms[i][j]) } diff --git a/statistics/handle/handle_test.go b/statistics/handle/handle_test.go index 47fdc81edc931..9d5c2eccf57d3 100644 --- a/statistics/handle/handle_test.go +++ b/statistics/handle/handle_test.go @@ -1504,18 +1504,18 @@ partition by range (a) ( c.Assert(s.do.StatsHandle().DumpStatsDeltaToKV(handle.DumpAll), IsNil) tk.MustExec("set @@tidb_partition_prune_mode='static'") - tk.MustExec("set @@tidb_analyze_version=1") + tk.MustExec("set @@session.tidb_analyze_version=1") tk.MustExec("analyze table t") // both p0 and p1 are in ver1 c.Assert(len(tk.MustQuery("show stats_meta").Rows()), Equals, 2) tk.MustExec("set @@tidb_partition_prune_mode='dynamic'") - tk.MustExec("set @@tidb_analyze_version=1") + tk.MustExec("set @@session.tidb_analyze_version=1") err := tk.ExecToErr("analyze table t") // try to build global-stats on ver1 c.Assert(err, NotNil) c.Assert(err.Error(), Equals, "[stats]: some partition level statistics are not in statistics version 2, please set tidb_analyze_version to 2 and analyze the this table") tk.MustExec("set @@tidb_partition_prune_mode='dynamic'") - tk.MustExec("set @@tidb_analyze_version=2") + tk.MustExec("set @@session.tidb_analyze_version=2") err = tk.ExecToErr("analyze table t partition p1") // only analyze p1 to let it in ver2 while p0 is in ver1 c.Assert(err, NotNil) c.Assert(err.Error(), Equals, "[stats]: some partition level statistics are not in statistics version 2, please set tidb_analyze_version to 2 and analyze the this table") @@ -1523,11 +1523,43 @@ partition by range (a) ( tk.MustExec("analyze table t") // both p0 and p1 are in ver2 c.Assert(len(tk.MustQuery("show stats_meta").Rows()), Equals, 3) + // If we already have global-stats, we can get the latest global-stats by analyzing the newly added partition. tk.MustExec("alter table t add partition (partition p2 values less than (30))") - tk.MustExec("insert t values (13), (14)") + tk.MustExec("insert t values (13), (14), (22), (23)") + c.Assert(s.do.StatsHandle().DumpStatsDeltaToKV(handle.DumpAll), IsNil) + tk.MustExec("analyze table t partition p2") // it will success since p0 and p1 are both in ver2 + c.Assert(s.do.StatsHandle().DumpStatsDeltaToKV(handle.DumpAll), IsNil) + do := s.do + is := do.InfoSchema() + h := do.StatsHandle() + c.Assert(h.Update(is), IsNil) + tbl, err := is.TableByName(model.NewCIStr("test"), model.NewCIStr("t")) + c.Assert(err, IsNil) + tableInfo := tbl.Meta() + globalStats := h.GetTableStats(tableInfo) + // global.count = p0.count(3) + p1.count(2) + p2.count(2) + // We did not analyze partition p1, so the value here has not changed + c.Assert(globalStats.Count, Equals, int64(7)) + + tk.MustExec("analyze table t partition p1;") + globalStats = h.GetTableStats(tableInfo) + // global.count = p0.count(3) + p1.count(4) + p2.count(4) + // The value of p1.Count is correct now. + c.Assert(globalStats.Count, Equals, int64(9)) + c.Assert(globalStats.ModifyCount, Equals, int64(0)) + + tk.MustExec("alter table t drop partition p2;") c.Assert(s.do.StatsHandle().DumpStatsDeltaToKV(handle.DumpAll), IsNil) - tk.MustExec("analyze table t partition p2") // it will success since p0 and p1 are both in ver2 - c.Assert(len(tk.MustQuery("show stats_meta").Rows()), Equals, 4) // p0, p1, p2 and global + globalStats = h.GetTableStats(tableInfo) + // The value of global.count will be updated the next time analyze. + c.Assert(globalStats.Count, Equals, int64(9)) + c.Assert(globalStats.ModifyCount, Equals, int64(0)) + + tk.MustExec("analyze table t;") + globalStats = h.GetTableStats(tableInfo) + // global.count = p0.count(3) + p1.count(4) + // The value of global.Count is correct now. + c.Assert(globalStats.Count, Equals, int64(7)) } func (s *testStatsSuite) TestExtendedStatsDefaultSwitch(c *C) { diff --git a/statistics/handle/update_test.go b/statistics/handle/update_test.go index d08be9e65b332..9c8c7ea39d600 100644 --- a/statistics/handle/update_test.go +++ b/statistics/handle/update_test.go @@ -16,6 +16,7 @@ package handle_test import ( "fmt" "math" + "math/rand" "os" "strconv" "strings" @@ -2082,6 +2083,93 @@ func (s *testStatsSuite) TestFeedbackCounter(c *C) { c.Assert(subtraction(newNum, oldNum), Equals, 20) } +func (s *testSerialStatsSuite) TestMergeTopN(c *C) { + // Move this test to here to avoid race test. + tests := []struct { + topnNum int + n int + maxTopNVal int + maxTopNCnt int + }{ + { + topnNum: 10, + n: 5, + maxTopNVal: 50, + maxTopNCnt: 100, + }, + { + topnNum: 1, + n: 5, + maxTopNVal: 50, + maxTopNCnt: 100, + }, + { + topnNum: 5, + n: 5, + maxTopNVal: 5, + maxTopNCnt: 100, + }, + { + topnNum: 5, + n: 5, + maxTopNVal: 10, + maxTopNCnt: 100, + }, + } + for _, t := range tests { + topnNum, n := t.topnNum, t.n + maxTopNVal, maxTopNCnt := t.maxTopNVal, t.maxTopNCnt + + // the number of maxTopNVal should be bigger than n. + ok := maxTopNVal >= n + c.Assert(ok, Equals, true) + + topNs := make([]*statistics.TopN, 0, topnNum) + res := make(map[int]uint64) + rand.Seed(time.Now().Unix()) + for i := 0; i < topnNum; i++ { + topN := statistics.NewTopN(n) + occur := make(map[int]bool) + for j := 0; j < n; j++ { + // The range of numbers in the topn structure is in [0, maxTopNVal) + // But there cannot be repeated occurrences of value in a topN structure. + randNum := rand.Intn(maxTopNVal) + for occur[randNum] { + randNum = rand.Intn(maxTopNVal) + } + occur[randNum] = true + tString := []byte(fmt.Sprintf("%d", randNum)) + // The range of the number of occurrences in the topn structure is in [0, maxTopNCnt) + randCnt := uint64(rand.Intn(maxTopNCnt)) + res[randNum] += randCnt + topNMeta := statistics.TopNMeta{Encoded: tString, Count: randCnt} + topN.TopN = append(topN.TopN, topNMeta) + } + topNs = append(topNs, topN) + } + topN, remainTopN := statistics.MergeTopN(topNs, uint32(n)) + cnt := len(topN.TopN) + var minTopNCnt uint64 + for _, topNMeta := range topN.TopN { + val, err := strconv.Atoi(string(topNMeta.Encoded)) + c.Assert(err, IsNil) + c.Assert(topNMeta.Count, Equals, res[val]) + minTopNCnt = topNMeta.Count + } + if remainTopN != nil { + cnt += len(remainTopN) + for _, remainTopNMeta := range remainTopN { + val, err := strconv.Atoi(string(remainTopNMeta.Encoded)) + c.Assert(err, IsNil) + c.Assert(remainTopNMeta.Count, Equals, res[val]) + ok = minTopNCnt > remainTopNMeta.Count + c.Assert(ok, Equals, true) + } + } + c.Assert(cnt, Equals, len(res)) + } +} + func (s *testSerialStatsSuite) TestAutoUpdatePartitionInDynamicOnlyMode(c *C) { defer cleanEnv(c, s.store, s.do) testKit := testkit.NewTestKit(c, s.store) diff --git a/store/copr/mpp.go b/store/copr/mpp.go index da008937f65cf..8a6bd3c108b38 100644 --- a/store/copr/mpp.go +++ b/store/copr/mpp.go @@ -21,6 +21,7 @@ import ( "time" "github.com/pingcap/errors" + "github.com/pingcap/failpoint" "github.com/pingcap/kvproto/pkg/coprocessor" "github.com/pingcap/kvproto/pkg/kvrpcpb" "github.com/pingcap/kvproto/pkg/metapb" @@ -123,15 +124,31 @@ type mppIterator struct { respChan chan *mppResponse - rpcCancel *tikv.RPCCanceller + cancelFunc context.CancelFunc wg sync.WaitGroup closed uint32 + + vars *kv.Variables + + mu sync.Mutex } func (m *mppIterator) run(ctx context.Context) { for _, task := range m.tasks { + if atomic.LoadUint32(&m.closed) == 1 { + break + } + m.mu.Lock() + switch task.State { + case kv.MppTaskReady: + task.State = kv.MppTaskRunning + m.mu.Unlock() + default: + m.mu.Unlock() + break + } m.wg.Add(1) bo := tikv.NewBackoffer(ctx, copNextMaxBackoff) go m.handleDispatchReq(ctx, bo, task) @@ -142,6 +159,7 @@ func (m *mppIterator) run(ctx context.Context) { func (m *mppIterator) sendError(err error) { m.sendToRespCh(&mppResponse{err: err}) + m.cancelMppTasks() } func (m *mppIterator) sendToRespCh(resp *mppResponse) (exit bool) { @@ -223,7 +241,13 @@ func (m *mppIterator) handleDispatchReq(ctx context.Context, bo *tikv.Backoffer, m.sendError(errors.New(realResp.Error.Msg)) return } - + failpoint.Inject("mppNonRootTaskError", func(val failpoint.Value) { + if val.(bool) && !req.IsRoot { + time.Sleep(1 * time.Second) + m.sendError(tikv.ErrTiFlashServerTimeout) + return + } + }) if !req.IsRoot { return } @@ -231,6 +255,39 @@ func (m *mppIterator) handleDispatchReq(ctx context.Context, bo *tikv.Backoffer, m.establishMPPConns(bo, req, taskMeta) } +// NOTE: We do not retry here, because retry is helpless when errors result from TiFlash or Network. If errors occur, the execution on TiFlash will finally stop after some minutes. +// This function is exclusively called, and only the first call succeeds sending tasks and setting all tasks as cancelled, while others will not work. +func (m *mppIterator) cancelMppTasks() { + m.mu.Lock() + defer m.mu.Unlock() + killReq := &mpp.CancelTaskRequest{ + Meta: &mpp.TaskMeta{StartTs: m.startTs}, + } + + wrappedReq := tikvrpc.NewRequest(tikvrpc.CmdMPPCancel, killReq, kvrpcpb.Context{}) + wrappedReq.StoreTp = kv.TiFlash + + usedStoreAddrs := make(map[string]bool) + for _, task := range m.tasks { + // get the store address of running tasks + if task.State == kv.MppTaskRunning && !usedStoreAddrs[task.Meta.GetAddress()] { + usedStoreAddrs[task.Meta.GetAddress()] = true + } else if task.State == kv.MppTaskCancelled { + return + } + task.State = kv.MppTaskCancelled + } + + // send cancel cmd to all stores where tasks run + for addr := range usedStoreAddrs { + _, err := m.store.GetTiKVClient().SendRequest(context.Background(), addr, wrappedReq, tikv.ReadTimeoutUltraLong) + logutil.BgLogger().Debug("cancel task ", zap.Uint64("query id ", m.startTs), zap.String(" on addr ", addr)) + if err != nil { + logutil.BgLogger().Error("cancel task error: ", zap.Error(err), zap.Uint64(" for query id ", m.startTs), zap.String(" on addr ", addr)) + } + } +} + func (m *mppIterator) establishMPPConns(bo *tikv.Backoffer, req *kv.MPPDispatchRequest, taskMeta *mpp.TaskMeta) { connReq := &mpp.EstablishMPPConnectionRequest{ SenderMeta: taskMeta, @@ -260,13 +317,13 @@ func (m *mppIterator) establishMPPConns(bo *tikv.Backoffer, req *kv.MPPDispatchR return } - // TODO: cancel the whole process when some error happens for { err := m.handleMPPStreamResponse(bo, resp, req) if err != nil { m.sendError(err) return } + resp, err = stream.Recv() if err != nil { if errors.Cause(err) == io.EOF { @@ -280,9 +337,7 @@ func (m *mppIterator) establishMPPConns(bo *tikv.Backoffer, req *kv.MPPDispatchR logutil.BgLogger().Info("stream unknown error", zap.Error(err)) } } - m.sendToRespCh(&mppResponse{ - err: tikv.ErrTiFlashServerTimeout, - }) + m.sendError(tikv.ErrTiFlashServerTimeout) return } } @@ -293,7 +348,7 @@ func (m *mppIterator) Close() error { if atomic.CompareAndSwapUint32(&m.closed, 0, 1) { close(m.finishCh) } - m.rpcCancel.CancelAll() + m.cancelFunc() m.wg.Wait() return nil } @@ -336,7 +391,11 @@ func (m *mppIterator) nextImpl(ctx context.Context) (resp *mppResponse, ok bool, case resp, ok = <-m.respChan: return case <-ticker.C: - //TODO: kill query + if m.vars != nil && m.vars.Killed != nil && atomic.LoadUint32(m.vars.Killed) == 1 { + err = tikv.ErrQueryInterrupted + exit = true + return + } case <-m.finishCh: exit = true return @@ -370,19 +429,18 @@ func (m *mppIterator) Next(ctx context.Context) (kv.ResultSubset, error) { return resp, nil } -// DispatchMPPTasks dispatches all the mpp task and waits for the reponses. -func (c *MPPClient) DispatchMPPTasks(ctx context.Context, dispatchReqs []*kv.MPPDispatchRequest) kv.Response { +// DispatchMPPTasks dispatches all the mpp task and waits for the responses. +func (c *MPPClient) DispatchMPPTasks(ctx context.Context, vars *kv.Variables, dispatchReqs []*kv.MPPDispatchRequest) kv.Response { + ctxChild, cancelFunc := context.WithCancel(ctx) iter := &mppIterator{ - store: c.store, - tasks: dispatchReqs, - finishCh: make(chan struct{}), - rpcCancel: tikv.NewRPCanceller(), - respChan: make(chan *mppResponse, 4096), - startTs: dispatchReqs[0].StartTs, + store: c.store, + tasks: dispatchReqs, + finishCh: make(chan struct{}), + cancelFunc: cancelFunc, + respChan: make(chan *mppResponse, 4096), + startTs: dispatchReqs[0].StartTs, + vars: vars, } - ctx = context.WithValue(ctx, tikv.RPCCancellerCtxKey{}, iter.rpcCancel) - - // TODO: Process the case of query cancellation. - go iter.run(ctx) + go iter.run(ctxChild) return iter } diff --git a/store/mockstore/unistore/cophandler/cop_handler.go b/store/mockstore/unistore/cophandler/cop_handler.go index 685804f12314e..c119a49db16e2 100644 --- a/store/mockstore/unistore/cophandler/cop_handler.go +++ b/store/mockstore/unistore/cophandler/cop_handler.go @@ -15,6 +15,7 @@ package cophandler import ( "bytes" + "context" "fmt" "time" @@ -46,6 +47,7 @@ type MPPCtx struct { RPCClient client.Client StoreAddr string TaskHandler *MPPTaskHandler + Ctx context.Context } // HandleCopRequest handles coprocessor request. diff --git a/store/mockstore/unistore/cophandler/mpp_exec.go b/store/mockstore/unistore/cophandler/mpp_exec.go index f586230a4ea53..8da2ed1085f1d 100644 --- a/store/mockstore/unistore/cophandler/mpp_exec.go +++ b/store/mockstore/unistore/cophandler/mpp_exec.go @@ -14,7 +14,6 @@ package cophandler import ( - "context" "io" "math" "sync" @@ -258,7 +257,7 @@ func (e *exchRecvExec) next() (*chunk.Chunk, error) { func (e *exchRecvExec) EstablishConnAndReceiveData(h *MPPTaskHandler, meta *mpp.TaskMeta) ([]*mpp.MPPDataPacket, error) { req := &mpp.EstablishMPPConnectionRequest{ReceiverMeta: h.Meta, SenderMeta: meta} rpcReq := tikvrpc.NewRequest(tikvrpc.CmdMPPConn, req, kvrpcpb.Context{}) - rpcResp, err := h.RPCClient.SendRequest(context.Background(), meta.Address, rpcReq, 3600*time.Second) + rpcResp, err := h.RPCClient.SendRequest(e.mppCtx.Ctx, meta.Address, rpcReq, 3600*time.Second) if err != nil { return nil, errors.Trace(err) } diff --git a/store/mockstore/unistore/rpc.go b/store/mockstore/unistore/rpc.go index 69d9a5639fc39..72f36eb239abe 100644 --- a/store/mockstore/unistore/rpc.go +++ b/store/mockstore/unistore/rpc.go @@ -223,11 +223,6 @@ func (c *RPCClient) SendRequest(ctx context.Context, addr string, req *tikvrpc.R case tikvrpc.CmdRawScan: resp.Resp, err = c.rawHandler.RawScan(ctx, req.RawScan()) case tikvrpc.CmdCop: - failpoint.Inject("copRpcErr"+addr, func(value failpoint.Value) { - if value.(string) == addr { - failpoint.Return(nil, errors.New("cop rpc error")) - } - }) resp.Resp, err = c.usSvr.Coprocessor(ctx, req.Cop()) case tikvrpc.CmdCopStream: resp.Resp, err = c.handleCopStream(ctx, req.Cop()) @@ -253,6 +248,7 @@ func (c *RPCClient) SendRequest(ctx context.Context, addr string, req *tikvrpc.R } }) resp.Resp, err = c.handleDispatchMPPTask(ctx, req.DispatchMPPTask(), storeID) + case tikvrpc.CmdMPPCancel: case tikvrpc.CmdMvccGetByKey: resp.Resp, err = c.usSvr.MvccGetByKey(ctx, req.MvccGetByKey()) case tikvrpc.CmdMvccGetByStartTs: @@ -302,7 +298,7 @@ func (c *RPCClient) handleEstablishMPPConnection(ctx context.Context, r *mpp.Est if err != nil { return nil, err } - var mockClient = mockMPPConnectionClient{mppResponses: mockServer.mppResponses, idx: 0, targetTask: r.ReceiverMeta} + var mockClient = mockMPPConnectionClient{mppResponses: mockServer.mppResponses, idx: 0, ctx: ctx, targetTask: r.ReceiverMeta} streamResp := &tikvrpc.MPPStreamResponse{Tikv_EstablishMPPConnectionClient: &mockClient} _, cancel := context.WithCancel(ctx) streamResp.Lease.Cancel = cancel @@ -477,8 +473,8 @@ type mockMPPConnectionClient struct { mockClientStream mppResponses []*mpp.MPPDataPacket idx int - - targetTask *mpp.TaskMeta + ctx context.Context + targetTask *mpp.TaskMeta } func (mock *mockMPPConnectionClient) Recv() (*mpp.MPPDataPacket, error) { @@ -492,6 +488,18 @@ func (mock *mockMPPConnectionClient) Recv() (*mpp.MPPDataPacket, error) { failpoint.Return(nil, context.Canceled) } }) + failpoint.Inject("mppRecvHang", func(val failpoint.Value) { + for val.(bool) { + select { + case <-mock.ctx.Done(): + { + failpoint.Return(nil, context.Canceled) + } + default: + time.Sleep(1 * time.Second) + } + } + }) return nil, io.EOF } diff --git a/store/tikv/region_request.go b/store/tikv/region_request.go index 5116c36964c09..7182bf23d9e91 100644 --- a/store/tikv/region_request.go +++ b/store/tikv/region_request.go @@ -254,6 +254,10 @@ func (s *RegionRequestSender) SendReqCtx( if sType == kv.TiDB { failpoint.Return(nil, nil, ErrTiKVServerTimeout) } + case "requestTiFlashError": + if sType == kv.TiFlash { + failpoint.Return(nil, nil, ErrTiFlashServerTimeout) + } } }) diff --git a/store/tikv/tikvrpc/tikvrpc.go b/store/tikv/tikvrpc/tikvrpc.go index c695f8ecb3827..ed4da00e146a9 100644 --- a/store/tikv/tikvrpc/tikvrpc.go +++ b/store/tikv/tikvrpc/tikvrpc.go @@ -73,6 +73,7 @@ const ( CmdBatchCop CmdMPPTask CmdMPPConn + CmdMPPCancel CmdMvccGetByKey CmdType = 1024 + iota CmdMvccGetByStartTs @@ -147,6 +148,8 @@ func (t CmdType) String() string { return "DispatchMPPTask" case CmdMPPConn: return "EstablishMPPConnection" + case CmdMPPCancel: + return "CancelMPPTask" case CmdMvccGetByKey: return "MvccGetByKey" case CmdMvccGetByStartTs: @@ -170,7 +173,7 @@ type Request struct { Type CmdType Req interface{} kvrpcpb.Context - ReplicaReadType kv.ReplicaReadType // dirrerent from `kvrpcpb.Context.ReplicaRead` + ReplicaReadType kv.ReplicaReadType // different from `kvrpcpb.Context.ReplicaRead` ReplicaReadSeed *uint32 // pointer to follower read seed in snapshot/coprocessor StoreTp kv.StoreType } @@ -334,11 +337,16 @@ func (req *Request) DispatchMPPTask() *mpp.DispatchTaskRequest { return req.Req.(*mpp.DispatchTaskRequest) } -// EstablishMPPConn returns stablishMPPConnectionRequest in request. +// EstablishMPPConn returns EstablishMPPConnectionRequest in request. func (req *Request) EstablishMPPConn() *mpp.EstablishMPPConnectionRequest { return req.Req.(*mpp.EstablishMPPConnectionRequest) } +// CancelMPPTask returns canceling task in request +func (req *Request) CancelMPPTask() *mpp.CancelTaskRequest { + return req.Req.(*mpp.CancelTaskRequest) +} + // MvccGetByKey returns MvccGetByKeyRequest in request. func (req *Request) MvccGetByKey() *kvrpcpb.MvccGetByKeyRequest { return req.Req.(*kvrpcpb.MvccGetByKeyRequest) @@ -531,7 +539,7 @@ func FromBatchCommandsResponse(res *tikvpb.BatchCommandsResponse_Response) (*Res panic("unreachable") } -// CopStreamResponse combinates tikvpb.Tikv_CoprocessorStreamClient and the first Recv() result together. +// CopStreamResponse combines tikvpb.Tikv_CoprocessorStreamClient and the first Recv() result together. // In streaming API, get grpc stream client may not involve any network packet, then region error have // to be handled in Recv() function. This struct facilitates the error handling. type CopStreamResponse struct { @@ -803,7 +811,7 @@ func (resp *Response) GetRegionError() (*errorpb.Error, error) { } // CallRPC launches a rpc call. -// ch is needed to implement timeout for coprocessor streaing, the stream object's +// ch is needed to implement timeout for coprocessor streaming, the stream object's // cancel function will be sent to the channel, together with a lease checked by a background goroutine. func CallRPC(ctx context.Context, client tikvpb.TikvClient, req *Request) (*Response, error) { resp := &Response{} @@ -871,6 +879,9 @@ func CallRPC(ctx context.Context, client tikvpb.TikvClient, req *Request) (*Resp resp.Resp = &MPPStreamResponse{ Tikv_EstablishMPPConnectionClient: streamClient, } + case CmdMPPCancel: + // it cannot use the ctx with cancel(), otherwise this cmd will fail. + resp.Resp, err = client.CancelMPPTask(ctx, req.CancelMPPTask()) case CmdCopStream: var streamClient tikvpb.Tikv_CoprocessorStreamClient streamClient, err = client.CoprocessorStream(ctx, req.Cop()) @@ -988,7 +999,7 @@ func (resp *MPPStreamResponse) Close() { } } -// CheckStreamTimeoutLoop runs periodically to check is there any stream request timeouted. +// CheckStreamTimeoutLoop runs periodically to check is there any stream request timed out. // Lease is an object to track stream requests, call this function with "go CheckStreamTimeoutLoop()" // It is not guaranteed to call every Lease.Cancel() putting into channel when exits. // If grpc-go supports SetDeadline(https://github.com/grpc/grpc-go/issues/2917), we can stop using this method. diff --git a/util/parser/ast.go b/util/parser/ast.go index 7a1c1ebf4c31c..ded22800022cd 100644 --- a/util/parser/ast.go +++ b/util/parser/ast.go @@ -49,8 +49,76 @@ func (i *implicitDatabase) Leave(in ast.Node) (out ast.Node, ok bool) { return in, true } +func findTablePos(s, t string) int { + l := 0 + for i := range s { + if s[i] == ' ' || s[i] == ',' { + if len(t) == i-l && strings.Compare(s[l:i], t) == 0 { + return l + } + l = i + 1 + } + } + if len(t) == len(s)-l && strings.Compare(s[l:], t) == 0 { + return l + } + return -1 +} + +// SimpleCases captures simple SQL statements and uses string replacement instead of `restore` to improve performance. +// See https://github.com/pingcap/tidb/issues/22398. +func SimpleCases(node ast.StmtNode, defaultDB, origin string) (s string, ok bool) { + if len(origin) == 0 { + return "", false + } + insert, ok := node.(*ast.InsertStmt) + if !ok { + return "", false + } + if insert.Select != nil || insert.Setlist != nil || insert.OnDuplicate != nil || (insert.TableHints != nil && len(insert.TableHints) != 0) { + return "", false + } + join := insert.Table.TableRefs + if join.Tp != 0 || join.Right != nil { + return "", false + } + ts, ok := join.Left.(*ast.TableSource) + if !ok { + return "", false + } + tn, ok := ts.Source.(*ast.TableName) + if !ok { + return "", false + } + parenPos := strings.Index(origin, "(") + if parenPos == -1 { + return "", false + } + if strings.Contains(origin[:parenPos], ".") { + return origin, true + } + lower := strings.ToLower(origin[:parenPos]) + pos := findTablePos(lower, tn.Name.L) + if pos == -1 { + return "", false + } + var builder strings.Builder + builder.WriteString(origin[:pos]) + if tn.Schema.String() != "" { + builder.WriteString(tn.Schema.String()) + } else { + builder.WriteString(defaultDB) + } + builder.WriteString(".") + builder.WriteString(origin[pos:]) + return builder.String(), true +} + // RestoreWithDefaultDB returns restore strings for StmtNode with defaultDB -func RestoreWithDefaultDB(node ast.StmtNode, defaultDB string) string { +func RestoreWithDefaultDB(node ast.StmtNode, defaultDB, origin string) string { + if s, ok := SimpleCases(node, defaultDB, origin); ok { + return s + } var sb strings.Builder // Three flags for restore with default DB: // 1. RestoreStringSingleQuotes specifies to use single quotes to surround the string; diff --git a/util/parser/ast_test.go b/util/parser/ast_test.go new file mode 100644 index 0000000000000..177caf16f1978 --- /dev/null +++ b/util/parser/ast_test.go @@ -0,0 +1,70 @@ +// Copyright 2021 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package parser_test + +import ( + "testing" + + . "github.com/pingcap/check" + "github.com/pingcap/parser" + _ "github.com/pingcap/tidb/types/parser_driver" + utilparser "github.com/pingcap/tidb/util/parser" +) + +var _ = Suite(&testASTSuite{}) + +type testASTSuite struct { +} + +func TestT(t *testing.T) { + TestingT(t) +} + +func (s *testASTSuite) TestSimpleCases(c *C) { + tests := []struct { + sql string + db string + ans string + }{ + { + sql: "insert into t values(1, 2)", + db: "test", + ans: "insert into test.t values(1, 2)", + }, + { + sql: "insert into mydb.t values(1, 2)", + db: "test", + ans: "insert into mydb.t values(1, 2)", + }, + { + sql: "insert into t(a, b) values(1, 2)", + db: "test", + ans: "insert into test.t(a, b) values(1, 2)", + }, + { + sql: "insert into value value(2, 3)", + db: "test", + ans: "insert into test.value value(2, 3)", + }, + } + + for _, t := range tests { + p := parser.New() + stmt, err := p.ParseOneStmt(t.sql, "", "") + c.Assert(err, IsNil) + ans, ok := utilparser.SimpleCases(stmt, t.db, t.sql) + c.Assert(ok, IsTrue) + c.Assert(t.ans, Equals, ans) + } +}