diff --git a/bindinfo/bind_test.go b/bindinfo/bind_test.go index a1e53c7f696ac..bd7ab2e848c76 100644 --- a/bindinfo/bind_test.go +++ b/bindinfo/bind_test.go @@ -508,7 +508,7 @@ func (s *testSuite) TestGlobalBinding(c *C) { c.Check(err, IsNil) c.Check(chk.NumRows(), Equals, 0) - _, err = tk.Exec("delete from mysql.bind_info") + _, err = tk.Exec("delete from mysql.bind_info where source != 'builtin'") c.Assert(err, IsNil) } } @@ -1102,6 +1102,7 @@ func (s *testSuite) TestBaselineDBLowerCase(c *C) { // default_db should have lower case. c.Assert(rows[0][2], Equals, "spm") tk.MustQuery("select original_sql, default_db, status from mysql.bind_info where original_sql = 'select * from `spm` . `t`'").Check(testkit.Rows( + "select * from `spm` . `t` SPM deleted", "select * from `spm` . `t` spm using", )) } @@ -1512,9 +1513,9 @@ func (s *testSuite) TestReloadBindings(c *C) { tk.MustExec("create global binding for select * from t using select * from t use index(idx)") rows := tk.MustQuery("show global bindings").Rows() c.Assert(len(rows), Equals, 1) - rows = tk.MustQuery("select * from mysql.bind_info").Rows() + rows = tk.MustQuery("select * from mysql.bind_info where source != 'builtin'").Rows() c.Assert(len(rows), Equals, 1) - tk.MustExec("truncate table mysql.bind_info") + tk.MustExec("delete from mysql.bind_info where source != 'builtin'") c.Assert(s.domain.BindHandle().Update(false), IsNil) rows = tk.MustQuery("show global bindings").Rows() c.Assert(len(rows), Equals, 1) @@ -1595,7 +1596,7 @@ func (s *testSuite) TestOutdatedInfoSchema(c *C) { tk.MustExec("create table t(a int, b int, index idx(a))") tk.MustExec("create global binding for select * from t using select * from t use index(idx)") c.Assert(s.domain.BindHandle().Update(false), IsNil) - tk.MustExec("truncate table mysql.bind_info") + s.cleanBindingEnv(tk) tk.MustExec("create global binding for select * from t using select * from t use index(idx)") } @@ -2002,11 +2003,11 @@ func (s *testSuite) TestReCreateBind(c *C) { tk.MustExec("drop table if exists t") tk.MustExec("create table t(a int, b int, index idx(a))") - tk.MustQuery("select * from mysql.bind_info").Check(testkit.Rows()) + tk.MustQuery("select * from mysql.bind_info where source != 'builtin'").Check(testkit.Rows()) tk.MustQuery("show global bindings").Check(testkit.Rows()) tk.MustExec("create global binding for select * from t using select * from t") - tk.MustQuery("select original_sql, status from mysql.bind_info").Check(testkit.Rows( + tk.MustQuery("select original_sql, status from mysql.bind_info where source != 'builtin';").Check(testkit.Rows( "select * from `test` . `t` using", )) rows := tk.MustQuery("show global bindings").Rows() @@ -2015,13 +2016,15 @@ func (s *testSuite) TestReCreateBind(c *C) { c.Assert(rows[0][3], Equals, "using") tk.MustExec("create global binding for select * from t using select * from t") - tk.MustQuery("select original_sql, status from mysql.bind_info").Check(testkit.Rows( - "select * from `test` . `t` using", - )) rows = tk.MustQuery("show global bindings").Rows() c.Assert(len(rows), Equals, 1) c.Assert(rows[0][0], Equals, "select * from `test` . `t`") c.Assert(rows[0][3], Equals, "using") + + rows = tk.MustQuery("select original_sql, status from mysql.bind_info where source != 'builtin';").Rows() + c.Assert(len(rows), Equals, 2) + c.Assert(rows[0][1], Equals, "deleted") + c.Assert(rows[1][1], Equals, "using") } func (s *testSuite) TestExplainShowBindSQL(c *C) { @@ -2036,10 +2039,9 @@ func (s *testSuite) TestExplainShowBindSQL(c *C) { "select * from `test` . `t` SELECT * FROM `test`.`t` USE INDEX (`a`)", )) - tk.MustExec("explain select * from t") - tk.MustQuery("show warnings").Check(testkit.Rows("Warning 1105 Using the bindSQL: SELECT * FROM `test`.`t` USE INDEX (`a`)")) - tk.MustExec("explain analyze select * from t") - tk.MustQuery("show warnings").Check(testkit.Rows("Warning 1105 Using the bindSQL: SELECT * FROM `test`.`t` USE INDEX (`a`)")) + tk.MustExec("explain format = 'verbose' select * from t") + tk.MustQuery("show warnings").Check(testkit.Rows("Note 1105 Using the bindSQL: SELECT * FROM `test`.`t` USE INDEX (`a`)")) + // explain analyze do not support verbose yet. } func (s *testSuite) TestDMLIndexHintBind(c *C) { @@ -2097,8 +2099,9 @@ func (s *testSuite) TestConcurrentCapture(c *C) { tk.MustExec("select * from t") tk.MustExec("select * from t") tk.MustExec("admin capture bindings") - tk.MustQuery("select original_sql, source from mysql.bind_info where source != 'builtin'").Check(testkit.Rows( - "select * from `test` . `t` capture", + tk.MustQuery("select original_sql, source, status from mysql.bind_info where source != 'builtin'").Check(testkit.Rows( + "select * from `test` . `t` manual deleted", + "select * from `test` . `t` capture using", )) } diff --git a/bindinfo/handle.go b/bindinfo/handle.go index 530952da9038a..1301f6e2556ac 100644 --- a/bindinfo/handle.go +++ b/bindinfo/handle.go @@ -129,7 +129,7 @@ func (h *BindHandle) Update(fullLoad bool) (err error) { exec := h.sctx.Context.(sqlexec.RestrictedSQLExecutor) stmt, err := exec.ParseWithParams(context.TODO(), `SELECT original_sql, bind_sql, default_db, status, create_time, update_time, charset, collation, source - FROM mysql.bind_info WHERE update_time > %? ORDER BY update_time`, updateTime) + FROM mysql.bind_info WHERE update_time > %? ORDER BY update_time, create_time`, updateTime) if err != nil { return err } @@ -218,14 +218,16 @@ func (h *BindHandle) CreateBindRecord(sctx sessionctx.Context, record *BindRecor if err = h.lockBindInfoTable(); err != nil { return err } - // Binding recreation should physically delete previous bindings. - _, err = exec.ExecuteInternal(context.TODO(), `DELETE FROM mysql.bind_info WHERE original_sql = %?`, record.OriginalSQL) + + now := types.NewTime(types.FromGoTime(time.Now()), mysql.TypeTimestamp, 3) + + updateTs := now.String() + _, err = exec.ExecuteInternal(context.TODO(), `UPDATE mysql.bind_info SET status = %?, update_time = %? WHERE original_sql = %? AND update_time < %?`, + deleted, updateTs, record.OriginalSQL, updateTs) if err != nil { return err } - now := types.NewTime(types.FromGoTime(time.Now()), mysql.TypeTimestamp, 3) - for i := range record.Bindings { record.Bindings[i].CreateTime = now record.Bindings[i].UpdateTime = now @@ -697,7 +699,13 @@ func getHintsForSQL(sctx sessionctx.Context, sql string) (string, error) { rs, err := sctx.(sqlexec.SQLExecutor).ExecuteInternal(context.TODO(), fmt.Sprintf("EXPLAIN FORMAT='hint' %s", sql)) sctx.GetSessionVars().UsePlanBaselines = origVals if rs != nil { - defer terror.Call(rs.Close) + defer func() { + // Audit log is collected in Close(), set InRestrictedSQL to avoid 'create sql binding' been recorded as 'explain'. + origin := sctx.GetSessionVars().InRestrictedSQL + sctx.GetSessionVars().InRestrictedSQL = true + terror.Call(rs.Close) + sctx.GetSessionVars().InRestrictedSQL = origin + }() } if err != nil { return "", err diff --git a/config/config.go b/config/config.go index 874f1bb6f7426..eea787f6f4f82 100644 --- a/config/config.go +++ b/config/config.go @@ -723,6 +723,7 @@ var deprecatedConfig = map[string]struct{}{ "tikv-client.copr-cache.enable": {}, "alter-primary-key": {}, // use NONCLUSTERED keyword instead "enable-streaming": {}, + "allow-expression-index": {}, } func isAllDeprecatedConfigItems(items []string) bool { diff --git a/ddl/column_type_change_test.go b/ddl/column_type_change_test.go index 6c7c5595f35ba..57a4e3d167a21 100644 --- a/ddl/column_type_change_test.go +++ b/ddl/column_type_change_test.go @@ -2175,8 +2175,8 @@ func (s *testColumnTypeChangeSuite) TestCastDateToTimestampInReorgAttribute(c *C s.dom.DDL().(ddl.DDLForTest).SetHook(hook) tk.MustExec("alter table t modify column a TIMESTAMP NULL DEFAULT '2021-04-28 03:35:11' FIRST") - c.Assert(checkErr1.Error(), Equals, "[types:1292]Incorrect datetime value: '3977-02-22 00:00:00'") - c.Assert(checkErr2.Error(), Equals, "[types:1292]Incorrect datetime value: '3977-02-22 00:00:00'") + c.Assert(checkErr1.Error(), Equals, "[types:1292]Incorrect timestamp value: '3977-02-22'") + c.Assert(checkErr2.Error(), Equals, "[types:1292]Incorrect timestamp value: '3977-02-22'") tk.MustExec("drop table if exists t") } diff --git a/ddl/db_change_test.go b/ddl/db_change_test.go index 3c9171f6ad91c..ccbd5bd2e6be7 100644 --- a/ddl/db_change_test.go +++ b/ddl/db_change_test.go @@ -1052,8 +1052,9 @@ func (s *testStateChangeSuite) TestParallelAlterModifyColumn(c *C) { f := func(c *C, err1, err2 error) { c.Assert(err1, IsNil) c.Assert(err2, IsNil) - _, err := s.se.Execute(context.Background(), "select * from t") + rs, err := s.se.Execute(context.Background(), "select * from t") c.Assert(err, IsNil) + c.Assert(rs[0].Close(), IsNil) } s.testControlParallelExecSQL(c, sql, sql, f) } @@ -1064,8 +1065,9 @@ func (s *testStateChangeSuite) TestParallelAddGeneratedColumnAndAlterModifyColum f := func(c *C, err1, err2 error) { c.Assert(err1, IsNil) c.Assert(err2.Error(), Equals, "[ddl:8200]Unsupported modify column: oldCol is a dependent column 'a' for generated column") - _, err := s.se.Execute(context.Background(), "select * from t") + rs, err := s.se.Execute(context.Background(), "select * from t") c.Assert(err, IsNil) + c.Assert(rs[0].Close(), IsNil) } s.testControlParallelExecSQL(c, sql1, sql2, f) } @@ -1076,8 +1078,9 @@ func (s *testStateChangeSuite) TestParallelAlterModifyColumnAndAddPK(c *C) { f := func(c *C, err1, err2 error) { c.Assert(err1, IsNil) c.Assert(err2.Error(), Equals, "[ddl:8200]Unsupported modify column: this column has primary key flag") - _, err := s.se.Execute(context.Background(), "select * from t") + rs, err := s.se.Execute(context.Background(), "select * from t") c.Assert(err, IsNil) + c.Assert(rs[0].Close(), IsNil) } s.testControlParallelExecSQL(c, sql1, sql2, f) } @@ -1361,12 +1364,24 @@ func (s *testStateChangeSuiteBase) testControlParallelExecSQL(c *C, sql1, sql2 s wg.Add(2) go func() { defer wg.Done() - _, err1 = se.Execute(context.Background(), sql1) + var rss []sqlexec.RecordSet + rss, err1 = se.Execute(context.Background(), sql1) + if err1 == nil && len(rss) > 0 { + for _, rs := range rss { + c.Assert(rs.Close(), IsNil) + } + } }() go func() { defer wg.Done() <-ch - _, err2 = se1.Execute(context.Background(), sql2) + var rss []sqlexec.RecordSet + rss, err2 = se1.Execute(context.Background(), sql2) + if err2 == nil && len(rss) > 0 { + for _, rs := range rss { + c.Assert(rs.Close(), IsNil) + } + } }() wg.Wait() diff --git a/executor/adapter.go b/executor/adapter.go index 4527230880c23..7dd83d2de71b1 100644 --- a/executor/adapter.go +++ b/executor/adapter.go @@ -567,6 +567,12 @@ func (a *ExecStmt) handleNoDelayExecutor(ctx context.Context, e Executor) (sqlex ctx = opentracing.ContextWithSpan(ctx, span1) } + var err error + defer func() { + terror.Log(e.Close()) + a.logAudit() + }() + // Check if "tidb_snapshot" is set for the write executors. // In history read mode, we can not do write operations. switch e.(type) { @@ -581,12 +587,6 @@ func (a *ExecStmt) handleNoDelayExecutor(ctx context.Context, e Executor) (sqlex } } - var err error - defer func() { - terror.Log(e.Close()) - a.logAudit() - }() - err = Next(ctx, e, newFirstChunk(e)) if err != nil { return nil, err @@ -829,6 +829,7 @@ func (a *ExecStmt) logAudit() { if sessVars.InRestrictedSQL { return } + err := plugin.ForeachPlugin(plugin.Audit, func(p *plugin.Plugin) error { audit := plugin.DeclareAuditManifest(p.Manifest) if audit.OnGeneralEvent != nil { diff --git a/executor/aggregate.go b/executor/aggregate.go index 61daab6407930..cf5bf5687e07f 100644 --- a/executor/aggregate.go +++ b/executor/aggregate.go @@ -26,6 +26,7 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/failpoint" "github.com/pingcap/parser/mysql" + "github.com/pingcap/parser/terror" "github.com/pingcap/tidb/config" "github.com/pingcap/tidb/executor/aggfuncs" "github.com/pingcap/tidb/expression" @@ -298,14 +299,17 @@ func (e *HashAggExec) Close() error { // Open implements the Executor Open interface. func (e *HashAggExec) Open(ctx context.Context) error { - if err := e.baseExecutor.Open(ctx); err != nil { - return err - } failpoint.Inject("mockHashAggExecBaseExecutorOpenReturnedError", func(val failpoint.Value) { if val.(bool) { failpoint.Return(errors.New("mock HashAggExec.baseExecutor.Open returned error")) } }) + + if err := e.baseExecutor.Open(ctx); err != nil { + return err + } + // If panic here, the children executor should be closed because they are open. + defer closeBaseExecutor(&e.baseExecutor) e.prepared = false e.memTracker = memory.NewTracker(e.id, -1) @@ -344,6 +348,15 @@ func (e *HashAggExec) initForUnparallelExec() { } } +func closeBaseExecutor(b *baseExecutor) { + if r := recover(); r != nil { + // Release the resource, but throw the panic again and let the top level handle it. + terror.Log(b.Close()) + logutil.BgLogger().Warn("panic in Open(), close base executor and throw exception again") + panic(r) + } +} + func (e *HashAggExec) initForParallelExec(ctx sessionctx.Context) { sessionVars := e.ctx.GetSessionVars() finalConcurrency := sessionVars.HashAggFinalConcurrency() @@ -1218,14 +1231,18 @@ type StreamAggExec struct { // Open implements the Executor Open interface. func (e *StreamAggExec) Open(ctx context.Context) error { - if err := e.baseExecutor.Open(ctx); err != nil { - return err - } failpoint.Inject("mockStreamAggExecBaseExecutorOpenReturnedError", func(val failpoint.Value) { if val.(bool) { failpoint.Return(errors.New("mock StreamAggExec.baseExecutor.Open returned error")) } }) + + if err := e.baseExecutor.Open(ctx); err != nil { + return err + } + // If panic in Open, the children executor should be closed because they are open. + defer closeBaseExecutor(&e.baseExecutor) + e.childResult = newFirstChunk(e.children[0]) e.executed = false e.isChildReturnEmpty = true @@ -1886,10 +1903,13 @@ type AggSpillDiskAction struct { // Action set HashAggExec spill mode. func (a *AggSpillDiskAction) Action(t *memory.Tracker) { - if atomic.LoadUint32(&a.e.inSpillMode) == 0 && a.spillTimes < maxSpillTimes { + // Guarantee that processed data is at least 20% of the threshold, to avoid spilling too frequently. + if atomic.LoadUint32(&a.e.inSpillMode) == 0 && a.spillTimes < maxSpillTimes && a.e.memTracker.BytesConsumed() >= t.GetBytesLimit()/5 { a.spillTimes++ logutil.BgLogger().Info("memory exceeds quota, set aggregate mode to spill-mode", - zap.Uint32("spillTimes", a.spillTimes)) + zap.Uint32("spillTimes", a.spillTimes), + zap.Int64("consumed", t.BytesConsumed()), + zap.Int64("quota", t.GetBytesLimit())) atomic.StoreUint32(&a.e.inSpillMode, 1) return } diff --git a/executor/aggregate_test.go b/executor/aggregate_test.go index a0d14cc7e90dc..df3be90097f64 100644 --- a/executor/aggregate_test.go +++ b/executor/aggregate_test.go @@ -1471,7 +1471,7 @@ func (s *testSerialSuite) TestAggInDisk(c *C) { tk.MustExec("drop table if exists t1") tk.MustExec("create table t(a int)") sql := "insert into t values (0)" - for i := 1; i <= 300; i++ { + for i := 1; i <= 200; i++ { sql += fmt.Sprintf(",(%v)", i) } sql += ";" @@ -1488,4 +1488,15 @@ func (s *testSerialSuite) TestAggInDisk(c *C) { strings.Contains(disk, "Bytes"), IsTrue) } } + + // Add code cover + // Test spill chunk. Add a line to avoid tmp spill chunk is always full. + tk.MustExec("insert into t values(0)") + tk.MustQuery("select sum(tt.b) from ( select /*+ HASH_AGG() */ avg(t1.a) as b from t t1 join t t2 group by t1.a, t2.a) as tt").Check( + testkit.Rows("4040100.0000")) + // Test no groupby and no data. + tk.MustExec("drop table t;") + tk.MustExec("create table t(c int, c1 int);") + tk.MustQuery("select /*+ HASH_AGG() */ count(c) from t;").Check(testkit.Rows("0")) + tk.MustQuery("select /*+ HASH_AGG() */ count(c) from t group by c1;").Check(testkit.Rows()) } diff --git a/executor/bind.go b/executor/bind.go index 4b66d46415316..42552f1dcdee1 100644 --- a/executor/bind.go +++ b/executor/bind.go @@ -77,6 +77,13 @@ func (e *SQLBindExec) dropSQLBind() error { } func (e *SQLBindExec) createSQLBind() error { + // For audit log, SQLBindExec execute "explain" statement internally, save and recover stmtctx + // is necessary to avoid 'create binding' been recorded as 'explain'. + saveStmtCtx := e.ctx.GetSessionVars().StmtCtx + defer func() { + e.ctx.GetSessionVars().StmtCtx = saveStmtCtx + }() + bindInfo := bindinfo.Binding{ BindSQL: e.bindSQL, Charset: e.charset, diff --git a/executor/builder.go b/executor/builder.go index 1c6b18dc51d88..10e6b22148a65 100644 --- a/executor/builder.go +++ b/executor/builder.go @@ -1950,7 +1950,8 @@ func (b *executorBuilder) buildUpdate(v *plannercore.Update) Executor { if b.err != nil { return nil } - b.err = plannercore.CheckUpdateList(assignFlag, v) + // should use the new tblID2table, since the update's schema may have been changed in Execstmt. + b.err = plannercore.CheckUpdateList(assignFlag, v, tblID2table) if b.err != nil { return nil } diff --git a/executor/compiler.go b/executor/compiler.go index 511b516a96dfc..8c310b004f310 100644 --- a/executor/compiler.go +++ b/executor/compiler.go @@ -55,7 +55,9 @@ func (c *Compiler) Compile(ctx context.Context, stmtNode ast.StmtNode) (*ExecStm } ret := &plannercore.PreprocessorReturn{} - if err := plannercore.Preprocess(c.Ctx, stmtNode, plannercore.WithPreprocessorReturn(ret)); err != nil { + pe := &plannercore.PreprocessExecuteISUpdate{ExecuteInfoSchemaUpdate: planner.GetExecuteForUpdateReadIS, Node: stmtNode} + err := plannercore.Preprocess(c.Ctx, stmtNode, plannercore.WithPreprocessorReturn(ret), plannercore.WithExecuteInfoSchemaUpdate(pe)) + if err != nil { return nil, err } stmtNode = plannercore.TryAddExtraLimit(c.Ctx, stmtNode) @@ -335,6 +337,9 @@ func GetStmtLabel(stmtNode ast.StmtNode) string { case *ast.DropIndexStmt: return "DropIndex" case *ast.DropTableStmt: + if x.IsView { + return "DropView" + } return "DropTable" case *ast.ExplainStmt: return "Explain" @@ -373,6 +378,12 @@ func GetStmtLabel(stmtNode ast.StmtNode) string { return "CreateBinding" case *ast.IndexAdviseStmt: return "IndexAdvise" + case *ast.DropBindingStmt: + return "DropBinding" + case *ast.TraceStmt: + return "Trace" + case *ast.ShutdownStmt: + return "Shutdown" } return "other" } diff --git a/executor/executor.go b/executor/executor.go index 8a408911915cd..12cab3e4a8866 100644 --- a/executor/executor.go +++ b/executor/executor.go @@ -1701,6 +1701,7 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) { if explainStmt, ok := s.(*ast.ExplainStmt); ok { sc.InExplainStmt = true sc.IgnoreExplainIDSuffix = (strings.ToLower(explainStmt.Format) == types.ExplainFormatBrief) + sc.InVerboseExplain = strings.ToLower(explainStmt.Format) == types.ExplainFormatVerbose s = explainStmt.Stmt } if _, ok := s.(*ast.ExplainForStmt); ok { diff --git a/executor/executor_pkg_test.go b/executor/executor_pkg_test.go index b9b01d9678762..c427156ba073b 100644 --- a/executor/executor_pkg_test.go +++ b/executor/executor_pkg_test.go @@ -43,7 +43,7 @@ import ( "github.com/pingcap/tidb/util/tableutil" ) -var _ = Suite(&testExecSuite{}) +var _ = SerialSuites(&testExecSuite{}) var _ = SerialSuites(&testExecSerialSuite{}) // Note: it's a tricky way to export the `inspectionSummaryRules` and `inspectionRules` for unit test but invisible for normal code diff --git a/executor/executor_test.go b/executor/executor_test.go index 6c02d867feb54..ffb281424ef89 100644 --- a/executor/executor_test.go +++ b/executor/executor_test.go @@ -6690,7 +6690,7 @@ func (s *testClusterTableSuite) TestSQLDigestTextRetriever(c *C) { insertNormalized, insertDigest := parser.NormalizeDigest("insert into test_sql_digest_text_retriever values (1, 1)") _, updateDigest := parser.NormalizeDigest("update test_sql_digest_text_retriever set v = v + 1 where id = 1") - r := &executor.SQLDigestTextRetriever{ + r := &expression.SQLDigestTextRetriever{ SQLDigestsMap: map[string]string{ insertDigest.String(): "", updateDigest.String(): "", @@ -6702,6 +6702,89 @@ func (s *testClusterTableSuite) TestSQLDigestTextRetriever(c *C) { c.Assert(r.SQLDigestsMap[updateDigest.String()], Equals, "") } +func (s *testClusterTableSuite) TestFunctionDecodeSQLDigests(c *C) { + tk := testkit.NewTestKitWithInit(c, s.store) + c.Assert(tk.Se.Auth(&auth.UserIdentity{Username: "root", Hostname: "%"}, nil, nil), IsTrue) + tk.MustExec("set global tidb_enable_stmt_summary = 1") + tk.MustQuery("select @@global.tidb_enable_stmt_summary").Check(testkit.Rows("1")) + tk.MustExec("drop table if exists test_func_decode_sql_digests") + tk.MustExec("create table test_func_decode_sql_digests(id int primary key, v int)") + + q1 := "begin" + norm1, digest1 := parser.NormalizeDigest(q1) + q2 := "select @@tidb_current_ts" + norm2, digest2 := parser.NormalizeDigest(q2) + q3 := "select id, v from test_func_decode_sql_digests where id = 1 for update" + norm3, digest3 := parser.NormalizeDigest(q3) + + // TIDB_DECODE_SQL_DIGESTS function doesn't actually do "decoding", instead it queries `statements_summary` and it's + // variations for the corresponding statements. + // Execute the statements so that the queries will be saved into statements_summary table. + tk.MustExec(q1) + // Save the ts to query the transaction from tidb_trx. + ts, err := strconv.ParseUint(tk.MustQuery(q2).Rows()[0][0].(string), 10, 64) + c.Assert(err, IsNil) + c.Assert(ts, Greater, uint64(0)) + tk.MustExec(q3) + tk.MustExec("rollback") + + // Test statements truncating. + decoded := fmt.Sprintf(`["%s","%s","%s"]`, norm1, norm2, norm3) + digests := fmt.Sprintf(`["%s","%s","%s"]`, digest1, digest2, digest3) + tk.MustQuery("select tidb_decode_sql_digests(?, 0)", digests).Check(testkit.Rows(decoded)) + // The three queries are shorter than truncate length, equal to truncate length and longer than truncate length respectively. + tk.MustQuery("select tidb_decode_sql_digests(?, ?)", digests, len(norm2)).Check(testkit.Rows( + "[\"begin\",\"select @@tidb_current_ts\",\"select `id` , `v` from `...\"]")) + + // Empty array. + tk.MustQuery("select tidb_decode_sql_digests('[]')").Check(testkit.Rows("[]")) + + // NULL + tk.MustQuery("select tidb_decode_sql_digests(null)").Check(testkit.Rows("")) + + // Array containing wrong types and not-existing digests (maps to null). + tk.MustQuery("select tidb_decode_sql_digests(?)", fmt.Sprintf(`["%s",1,null,"%s",{"a":1},[2],"%s","","abcde"]`, digest1, digest2, digest3)). + Check(testkit.Rows(fmt.Sprintf(`["%s",null,null,"%s",null,null,"%s",null,null]`, norm1, norm2, norm3))) + + // Not JSON array (throws warnings) + tk.MustQuery(`select tidb_decode_sql_digests('{"a":1}')`).Check(testkit.Rows("")) + tk.MustQuery(`show warnings`).Check(testkit.Rows(`Warning 1210 The argument can't be unmarshalled as JSON array: '{"a":1}'`)) + tk.MustQuery(`select tidb_decode_sql_digests('aabbccdd')`).Check(testkit.Rows("")) + tk.MustQuery(`show warnings`).Check(testkit.Rows(`Warning 1210 The argument can't be unmarshalled as JSON array: 'aabbccdd'`)) + + // Invalid argument count. + tk.MustGetErrCode("select tidb_decode_sql_digests('a', 1, 2)", 1582) + tk.MustGetErrCode("select tidb_decode_sql_digests()", 1582) +} + +func (s *testClusterTableSuite) TestFunctionDecodeSQLDigestsPrivilege(c *C) { + dropUserTk := testkit.NewTestKitWithInit(c, s.store) + c.Assert(dropUserTk.Se.Auth(&auth.UserIdentity{Username: "root", Hostname: "%"}, nil, nil), IsTrue) + + tk := testkit.NewTestKitWithInit(c, s.store) + c.Assert(tk.Se.Auth(&auth.UserIdentity{Username: "root", Hostname: "%"}, nil, nil), IsTrue) + tk.MustExec("create user 'testuser'@'localhost'") + defer dropUserTk.MustExec("drop user 'testuser'@'localhost'") + c.Assert(tk.Se.Auth(&auth.UserIdentity{ + Username: "testuser", + Hostname: "localhost", + }, nil, nil), IsTrue) + err := tk.ExecToErr("select tidb_decode_sql_digests('[\"aa\"]')") + c.Assert(err, NotNil) + c.Assert(err.Error(), Equals, "[expression:1227]Access denied; you need (at least one of) the PROCESS privilege(s) for this operation") + + tk = testkit.NewTestKitWithInit(c, s.store) + c.Assert(tk.Se.Auth(&auth.UserIdentity{Username: "root", Hostname: "%"}, nil, nil), IsTrue) + tk.MustExec("create user 'testuser2'@'localhost'") + defer dropUserTk.MustExec("drop user 'testuser2'@'localhost'") + tk.MustExec("grant process on *.* to 'testuser2'@'localhost'") + c.Assert(tk.Se.Auth(&auth.UserIdentity{ + Username: "testuser2", + Hostname: "localhost", + }, nil, nil), IsTrue) + _ = tk.MustQuery("select tidb_decode_sql_digests('[\"aa\"]')") +} + func prepareLogs(c *C, logData []string, fileNames []string) { writeFile := func(file string, data string) { f, err := os.OpenFile(file, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644) @@ -8338,9 +8421,9 @@ func (s *testSerialSuite) TestDeadlocksTable(c *C) { id1 := strconv.FormatUint(rec.ID, 10) id2 := strconv.FormatUint(rec2.ID, 10) - c.Assert(failpoint.Enable("github.com/pingcap/tidb/executor/sqlDigestRetrieverSkipRetrieveGlobal", "return"), IsNil) + c.Assert(failpoint.Enable("github.com/pingcap/tidb/expression/sqlDigestRetrieverSkipRetrieveGlobal", "return"), IsNil) defer func() { - c.Assert(failpoint.Disable("github.com/pingcap/tidb/executor/sqlDigestRetrieverSkipRetrieveGlobal"), IsNil) + c.Assert(failpoint.Disable("github.com/pingcap/tidb/expression/sqlDigestRetrieverSkipRetrieveGlobal"), IsNil) }() tk := testkit.NewTestKit(c, s.store) diff --git a/executor/explain.go b/executor/explain.go index a584028b686ee..4e6116975b978 100644 --- a/executor/explain.go +++ b/executor/explain.go @@ -44,6 +44,10 @@ func (e *ExplainExec) Open(ctx context.Context) error { // Close implements the Executor Close interface. func (e *ExplainExec) Close() error { e.rows = nil + if e.analyzeExec != nil && !e.executed { + // Open(), but Next() is not called. + return e.analyzeExec.Close() + } return nil } diff --git a/executor/index_lookup_join.go b/executor/index_lookup_join.go index 6f62fcff339cb..d617cddf77bd7 100644 --- a/executor/index_lookup_join.go +++ b/executor/index_lookup_join.go @@ -517,6 +517,11 @@ func (iw *innerWorker) constructLookupContent(task *lookUpJoinTask) ([]*indexJoi for rowIdx := 0; rowIdx < numRows; rowIdx++ { dLookUpKey, dHashKey, err := iw.constructDatumLookupKey(task, chkIdx, rowIdx) if err != nil { + if terror.ErrorEqual(err, types.ErrWrongValue) { + // We ignore rows with invalid datetime. + task.encodedLookUpKeys[chkIdx].AppendNull(0) + continue + } return nil, err } if dHashKey == nil { diff --git a/executor/infoschema_reader.go b/executor/infoschema_reader.go index 0096f1cc5a991..d29d523249d70 100644 --- a/executor/infoschema_reader.go +++ b/executor/infoschema_reader.go @@ -39,6 +39,7 @@ import ( "github.com/pingcap/tidb/domain" "github.com/pingcap/tidb/domain/infosync" "github.com/pingcap/tidb/errno" + "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/infoschema" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/meta/autoid" @@ -2137,11 +2138,11 @@ func (e *tidbTrxTableRetriever) retrieve(ctx context.Context, sctx sessionctx.Co var res [][]types.Datum err = e.nextBatch(func(start, end int) error { // Before getting rows, collect the SQL digests that needs to be retrieved first. - var sqlRetriever *SQLDigestTextRetriever + var sqlRetriever *expression.SQLDigestTextRetriever for _, c := range e.columns { if c.Name.O == txninfo.CurrentSQLDigestTextStr { if sqlRetriever == nil { - sqlRetriever = NewSQLDigestTextRetriever() + sqlRetriever = expression.NewSQLDigestTextRetriever() } for i := start; i < end; i++ { @@ -2250,9 +2251,9 @@ func (r *dataLockWaitsTableRetriever) retrieve(ctx context.Context, sctx session } // Fetch the SQL Texts of the digests above if necessary. - var sqlRetriever *SQLDigestTextRetriever + var sqlRetriever *expression.SQLDigestTextRetriever if needSQLText { - sqlRetriever = NewSQLDigestTextRetriever() + sqlRetriever = expression.NewSQLDigestTextRetriever() for _, digest := range digests { if len(digest) > 0 { sqlRetriever.SQLDigestsMap[digest] = "" @@ -2390,11 +2391,11 @@ func (r *deadlocksTableRetriever) retrieve(ctx context.Context, sctx sessionctx. err = r.nextBatch(func(start, end int) error { // Before getting rows, collect the SQL digests that needs to be retrieved first. - var sqlRetriever *SQLDigestTextRetriever + var sqlRetriever *expression.SQLDigestTextRetriever for _, c := range r.columns { if c.Name.O == deadlockhistory.ColCurrentSQLDigestTextStr { if sqlRetriever == nil { - sqlRetriever = NewSQLDigestTextRetriever() + sqlRetriever = expression.NewSQLDigestTextRetriever() } idx, waitChainIdx := r.currentIdx, r.currentWaitChainIdx diff --git a/executor/insert_test.go b/executor/insert_test.go index bde17e6c1218b..7b6978298849d 100644 --- a/executor/insert_test.go +++ b/executor/insert_test.go @@ -336,6 +336,18 @@ func (s *testSuite3) TestInsertWrongValueForField(c *C) { tk.MustExec(`create table t (a year);`) _, err = tk.Exec(`insert into t values(2156);`) c.Assert(err.Error(), Equals, `[types:8033]invalid year`) + + tk.MustExec(`DROP TABLE IF EXISTS ts`) + tk.MustExec(`CREATE TABLE ts (id int DEFAULT NULL, time1 TIMESTAMP NULL DEFAULT NULL)`) + tk.MustExec(`SET @@sql_mode=''`) + tk.MustExec(`INSERT INTO ts (id, time1) VALUES (1, TIMESTAMP '1018-12-23 00:00:00')`) + tk.MustQuery(`SHOW WARNINGS`).Check(testkit.Rows(`Warning 1292 Incorrect timestamp value: '1018-12-23 00:00:00'`)) + tk.MustQuery(`SELECT * FROM ts ORDER BY id`).Check(testkit.Rows(`1 0000-00-00 00:00:00`)) + + tk.MustExec(`SET @@sql_mode='STRICT_TRANS_TABLES'`) + _, err = tk.Exec(`INSERT INTO ts (id, time1) VALUES (2, TIMESTAMP '1018-12-24 00:00:00')`) + c.Assert(err.Error(), Equals, `[table:1292]Incorrect timestamp value: '1018-12-24 00:00:00' for column 'time1' at row 1`) + tk.MustExec(`DROP TABLE ts`) } func (s *testSuite3) TestInsertValueForCastDecimalField(c *C) { diff --git a/executor/prepared.go b/executor/prepared.go index caa8f2dc09272..e6d2e197d00b8 100644 --- a/executor/prepared.go +++ b/executor/prepared.go @@ -87,6 +87,11 @@ type PrepareExec struct { ID uint32 ParamCount int Fields []*ast.ResultField + + // If it's generated from executing "prepare stmt from '...'", the process is parse -> plan -> executor + // If it's generated from the prepare protocol, the process is session.PrepareStmt -> NewPrepareExec + // They both generate a PrepareExec struct, but the second case needs to reset the statement context while the first already do that. + needReset bool } // NewPrepareExec creates a new PrepareExec. @@ -96,6 +101,7 @@ func NewPrepareExec(ctx sessionctx.Context, sqlTxt string) *PrepareExec { return &PrepareExec{ baseExecutor: base, sqlText: sqlTxt, + needReset: true, } } @@ -135,9 +141,11 @@ func (e *PrepareExec) Next(ctx context.Context, req *chunk.Chunk) error { } stmt := stmts[0] - err = ResetContextOfStmt(e.ctx, stmt) - if err != nil { - return err + if e.needReset { + err = ResetContextOfStmt(e.ctx, stmt) + if err != nil { + return err + } } var extractor paramMarkerExtractor diff --git a/executor/seqtest/prepared_test.go b/executor/seqtest/prepared_test.go index 93ba9cb30cace..1b54221a224b5 100644 --- a/executor/seqtest/prepared_test.go +++ b/executor/seqtest/prepared_test.go @@ -57,16 +57,22 @@ func (s *seqTestSuite) TestPrepared(c *C) { tk.MustExec("create table prepare_test (id int PRIMARY KEY AUTO_INCREMENT, c1 int, c2 int, c3 int default 1)") tk.MustExec("insert prepare_test (c1) values (1),(2),(NULL)") - tk.MustExec(`prepare stmt_test_1 from 'select id from prepare_test where id > ?'; set @a = 1; execute stmt_test_1 using @a;`) + tk.MustExec(`prepare stmt_test_1 from 'select id from prepare_test where id > ?';`) + tk.MustExec(`set @a = 1;`) + tk.MustExec(`execute stmt_test_1 using @a;`) tk.MustExec(`prepare stmt_test_2 from 'select 1'`) // Prepare multiple statement is not allowed. _, err = tk.Exec(`prepare stmt_test_3 from 'select id from prepare_test where id > ?;select id from prepare_test where id > ?;'`) c.Assert(executor.ErrPrepareMulti.Equal(err), IsTrue) + // The variable count does not match. - _, err = tk.Exec(`prepare stmt_test_4 from 'select id from prepare_test where id > ? and id < ?'; set @a = 1; execute stmt_test_4 using @a;`) + tk.MustExec(`prepare stmt_test_4 from 'select id from prepare_test where id > ? and id < ?';`) + tk.MustExec(`set @a = 1;`) + _, err = tk.Exec(`execute stmt_test_4 using @a;`) c.Assert(plannercore.ErrWrongParamCount.Equal(err), IsTrue) // Prepare and deallocate prepared statement immediately. - tk.MustExec(`prepare stmt_test_5 from 'select id from prepare_test where id > ?'; deallocate prepare stmt_test_5;`) + tk.MustExec(`prepare stmt_test_5 from 'select id from prepare_test where id > ?';`) + tk.MustExec(`deallocate prepare stmt_test_5;`) // Statement not found. _, err = tk.Exec("deallocate prepare stmt_test_5") @@ -166,8 +172,11 @@ func (s *seqTestSuite) TestPrepared(c *C) { c.Assert(err, IsNil) // Should success as the changed schema do not affect the prepared statement. - _, err = tk.Se.ExecutePreparedStmt(ctx, stmtID, []types.Datum{types.NewDatum(1)}) + rs, err = tk.Se.ExecutePreparedStmt(ctx, stmtID, []types.Datum{types.NewDatum(1)}) c.Assert(err, IsNil) + if rs != nil { + rs.Close() + } // Drop a column so the prepared statement become invalid. query = "select c1, c2 from prepare_test where c1 = ?" diff --git a/executor/show_test.go b/executor/show_test.go index afea20e638680..ff630bc5cae2a 100644 --- a/executor/show_test.go +++ b/executor/show_test.go @@ -1130,7 +1130,7 @@ func (s *testSuite5) TestShowBuiltin(c *C) { res := tk.MustQuery("show builtins;") c.Assert(res, NotNil) rows := res.Rows() - const builtinFuncNum = 272 + const builtinFuncNum = 273 c.Assert(builtinFuncNum, Equals, len(rows)) c.Assert("abs", Equals, rows[0][0].(string)) c.Assert("yearweek", Equals, rows[builtinFuncNum-1][0].(string)) diff --git a/executor/stale_txn_test.go b/executor/stale_txn_test.go index 8c4dbcd40350f..0b021c39ab1ef 100644 --- a/executor/stale_txn_test.go +++ b/executor/stale_txn_test.go @@ -211,12 +211,15 @@ func (s *testStaleTxnSerialSuite) TestSelectAsOf(c *C) { } else if testcase.preSec > 0 { c.Assert(failpoint.Enable("github.com/pingcap/tidb/executor/assertStaleTSOWithTolerance", fmt.Sprintf(`return(%d)`, time.Now().Unix()-testcase.preSec)), IsNil) } - _, err := tk.Exec(testcase.sql) + rs, err := tk.Exec(testcase.sql) if len(testcase.errorStr) != 0 { c.Assert(err, ErrorMatches, testcase.errorStr) continue } c.Assert(err, IsNil, Commentf("sql:%s, error stack %v", testcase.sql, errors.ErrorStack(err))) + if rs != nil { + rs.Close() + } if testcase.expectPhysicalTS > 0 { c.Assert(failpoint.Disable("github.com/pingcap/tidb/executor/assertStaleTSO"), IsNil) } else if testcase.preSec > 0 { @@ -696,8 +699,7 @@ func (s *testStaleTxnSerialSuite) TestValidateReadOnlyInStalenessTransaction(c * c.Log(testcase.name) tk.MustExec(`START TRANSACTION READ ONLY AS OF TIMESTAMP NOW(3);`) if testcase.isValidate { - _, err := tk.Exec(testcase.sql) - c.Assert(err, IsNil) + tk.MustExec(testcase.sql) } else { err := tk.ExecToErr(testcase.sql) c.Assert(err, NotNil) @@ -706,8 +708,7 @@ func (s *testStaleTxnSerialSuite) TestValidateReadOnlyInStalenessTransaction(c * tk.MustExec("commit") tk.MustExec("set transaction read only as of timestamp NOW(3);") if testcase.isValidate { - _, err := tk.Exec(testcase.sql) - c.Assert(err, IsNil) + tk.MustExec(testcase.sql) } else { err := tk.ExecToErr(testcase.sql) c.Assert(err, NotNil) diff --git a/executor/trace.go b/executor/trace.go index fd3ab5ac92223..f0faa25ed504b 100644 --- a/executor/trace.go +++ b/executor/trace.go @@ -64,6 +64,12 @@ func (e *TraceExec) Next(ctx context.Context, req *chunk.Chunk) error { return nil } + // For audit log plugin to set the correct statement. + stmtCtx := e.ctx.GetSessionVars().StmtCtx + defer func() { + e.ctx.GetSessionVars().StmtCtx = stmtCtx + }() + switch e.format { case core.TraceFormatLog: return e.nextTraceLog(ctx, se, req) @@ -130,6 +136,14 @@ func (e *TraceExec) nextRowJSON(ctx context.Context, se sqlexec.SQLExecutor, req } func (e *TraceExec) executeChild(ctx context.Context, se sqlexec.SQLExecutor) { + // For audit log plugin to log the statement correctly. + // Should be logged as 'explain ...', instead of the executed SQL. + vars := e.ctx.GetSessionVars() + origin := vars.InRestrictedSQL + vars.InRestrictedSQL = true + defer func() { + vars.InRestrictedSQL = origin + }() rs, err := se.ExecuteStmt(ctx, e.stmtNode) if err != nil { var errCode uint16 diff --git a/executor/utils.go b/executor/utils.go index cf0eaeb6f0245..dcd2a394331d6 100644 --- a/executor/utils.go +++ b/executor/utils.go @@ -14,13 +14,7 @@ package executor import ( - "context" "strings" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/sessionctx" - "github.com/pingcap/tidb/util/sqlexec" ) // SetFromString constructs a slice of strings from a comma separated string. @@ -60,178 +54,6 @@ func deleteFromSet(set []string, value string) []string { return set } -// SQLDigestTextRetriever is used to find the normalized SQL statement text by SQL digests in statements_summary table. -// It's exported for test purposes. -type SQLDigestTextRetriever struct { - // SQLDigestsMap is the place to put the digests that's requested for getting SQL text and also the place to put - // the query result. - SQLDigestsMap map[string]string - - // Replace querying for test purposes. - mockLocalData map[string]string - mockGlobalData map[string]string - // There are two ways for querying information: 1) query specified digests by WHERE IN query, or 2) query all - // information to avoid the too long WHERE IN clause. If there are more than `fetchAllLimit` digests needs to be - // queried, the second way will be chosen; otherwise, the first way will be chosen. - fetchAllLimit int -} - -// NewSQLDigestTextRetriever creates a new SQLDigestTextRetriever. -func NewSQLDigestTextRetriever() *SQLDigestTextRetriever { - return &SQLDigestTextRetriever{ - SQLDigestsMap: make(map[string]string), - fetchAllLimit: 512, - } -} - -func (r *SQLDigestTextRetriever) runMockQuery(data map[string]string, inValues []interface{}) (map[string]string, error) { - if len(inValues) == 0 { - return data, nil - } - res := make(map[string]string, len(inValues)) - for _, digest := range inValues { - if text, ok := data[digest.(string)]; ok { - res[digest.(string)] = text - } - } - return res, nil -} - -// runFetchDigestQuery runs query to the system tables to fetch the kv mapping of SQL digests and normalized SQL texts -// of the given SQL digests, if `inValues` is given, or all these mappings otherwise. If `queryGlobal` is false, it -// queries information_schema.statements_summary and information_schema.statements_summary_history; otherwise, it -// queries the cluster version of these two tables. -func (r *SQLDigestTextRetriever) runFetchDigestQuery(ctx context.Context, sctx sessionctx.Context, queryGlobal bool, inValues []interface{}) (map[string]string, error) { - // If mock data is set, query the mock data instead of the real statements_summary tables. - if !queryGlobal && r.mockLocalData != nil { - return r.runMockQuery(r.mockLocalData, inValues) - } else if queryGlobal && r.mockGlobalData != nil { - return r.runMockQuery(r.mockGlobalData, inValues) - } - - exec, ok := sctx.(sqlexec.RestrictedSQLExecutor) - if !ok { - return nil, errors.New("restricted sql can't be executed in this context") - } - - // Information in statements_summary will be periodically moved to statements_summary_history. Union them together - // to avoid missing information when statements_summary is just cleared. - stmt := "select digest, digest_text from information_schema.statements_summary union distinct " + - "select digest, digest_text from information_schema.statements_summary_history" - if queryGlobal { - stmt = "select digest, digest_text from information_schema.cluster_statements_summary union distinct " + - "select digest, digest_text from information_schema.cluster_statements_summary_history" - } - // Add the where clause if `inValues` is specified. - if len(inValues) > 0 { - stmt += " where digest in (" + strings.Repeat("%?,", len(inValues)-1) + "%?)" - } - - stmtNode, err := exec.ParseWithParams(ctx, stmt, inValues...) - if err != nil { - return nil, err - } - rows, _, err := exec.ExecRestrictedStmt(ctx, stmtNode) - if err != nil { - return nil, err - } - - res := make(map[string]string, len(rows)) - for _, row := range rows { - res[row.GetString(0)] = row.GetString(1) - } - return res, nil -} - -func (r *SQLDigestTextRetriever) updateDigestInfo(queryResult map[string]string) { - for digest, text := range r.SQLDigestsMap { - if len(text) > 0 { - // The text of this digest is already known - continue - } - sqlText, ok := queryResult[digest] - if ok { - r.SQLDigestsMap[digest] = sqlText - } - } -} - -// RetrieveLocal tries to retrieve the SQL text of the SQL digests from local information. -func (r *SQLDigestTextRetriever) RetrieveLocal(ctx context.Context, sctx sessionctx.Context) error { - if len(r.SQLDigestsMap) == 0 { - return nil - } - - var queryResult map[string]string - if len(r.SQLDigestsMap) <= r.fetchAllLimit { - inValues := make([]interface{}, 0, len(r.SQLDigestsMap)) - for key := range r.SQLDigestsMap { - inValues = append(inValues, key) - } - var err error - queryResult, err = r.runFetchDigestQuery(ctx, sctx, false, inValues) - if err != nil { - return errors.Trace(err) - } - - if len(queryResult) == len(r.SQLDigestsMap) { - r.SQLDigestsMap = queryResult - return nil - } - } else { - var err error - queryResult, err = r.runFetchDigestQuery(ctx, sctx, false, nil) - if err != nil { - return errors.Trace(err) - } - } - - r.updateDigestInfo(queryResult) - return nil -} - -// RetrieveGlobal tries to retrieve the SQL text of the SQL digests from the information of the whole cluster. -func (r *SQLDigestTextRetriever) RetrieveGlobal(ctx context.Context, sctx sessionctx.Context) error { - err := r.RetrieveLocal(ctx, sctx) - if err != nil { - return errors.Trace(err) - } - - // In some unit test environments it's unable to retrieve global info, and this function blocks it for tens of - // seconds, which wastes much time during unit test. In this case, enable this failpoint to bypass retrieving - // globally. - failpoint.Inject("sqlDigestRetrieverSkipRetrieveGlobal", func() { - failpoint.Return(nil) - }) - - var unknownDigests []interface{} - for k, v := range r.SQLDigestsMap { - if len(v) == 0 { - unknownDigests = append(unknownDigests, k) - } - } - - if len(unknownDigests) == 0 { - return nil - } - - var queryResult map[string]string - if len(r.SQLDigestsMap) <= r.fetchAllLimit { - queryResult, err = r.runFetchDigestQuery(ctx, sctx, true, unknownDigests) - if err != nil { - return errors.Trace(err) - } - } else { - queryResult, err = r.runFetchDigestQuery(ctx, sctx, true, nil) - if err != nil { - return errors.Trace(err) - } - } - - r.updateDigestInfo(queryResult) - return nil -} - // batchRetrieverHelper is a helper for batch returning data with known total rows. This helps implementing memtable // retrievers of some information_schema tables. Initialize `batchSize` and `totalRows` fields to use it. type batchRetrieverHelper struct { diff --git a/executor/utils_test.go b/executor/utils_test.go index f22155c7c1c46..4bba62668d572 100644 --- a/executor/utils_test.go +++ b/executor/utils_test.go @@ -14,8 +14,6 @@ package executor import ( - "context" - . "github.com/pingcap/check" "github.com/pingcap/errors" ) @@ -92,65 +90,3 @@ func (s *pkgTestSuite) TestBatchRetrieverHelper(c *C) { c.Assert(rangeStarts, DeepEquals, []int{0}) c.Assert(rangeEnds, DeepEquals, []int{10}) } - -func (s *pkgTestSuite) TestSQLDigestTextRetriever(c *C) { - // Create a fake session as the argument to the retriever, though it's actually not used when mock data is set. - - r := NewSQLDigestTextRetriever() - clearResult := func() { - r.SQLDigestsMap = map[string]string{ - "digest1": "", - "digest2": "", - "digest3": "", - "digest4": "", - "digest5": "", - } - } - clearResult() - r.mockLocalData = map[string]string{ - "digest1": "text1", - "digest2": "text2", - "digest6": "text6", - } - r.mockGlobalData = map[string]string{ - "digest2": "text2", - "digest3": "text3", - "digest4": "text4", - "digest7": "text7", - } - - expectedLocalResult := map[string]string{ - "digest1": "text1", - "digest2": "text2", - "digest3": "", - "digest4": "", - "digest5": "", - } - expectedGlobalResult := map[string]string{ - "digest1": "text1", - "digest2": "text2", - "digest3": "text3", - "digest4": "text4", - "digest5": "", - } - - err := r.RetrieveLocal(context.Background(), nil) - c.Assert(err, IsNil) - c.Assert(r.SQLDigestsMap, DeepEquals, expectedLocalResult) - clearResult() - - err = r.RetrieveGlobal(context.Background(), nil) - c.Assert(err, IsNil) - c.Assert(r.SQLDigestsMap, DeepEquals, expectedGlobalResult) - clearResult() - - r.fetchAllLimit = 1 - err = r.RetrieveLocal(context.Background(), nil) - c.Assert(err, IsNil) - c.Assert(r.SQLDigestsMap, DeepEquals, expectedLocalResult) - clearResult() - - err = r.RetrieveGlobal(context.Background(), nil) - c.Assert(err, IsNil) - c.Assert(r.SQLDigestsMap, DeepEquals, expectedGlobalResult) -} diff --git a/expression/builtin.go b/expression/builtin.go index f5b6caedc03fb..3625602eb8fea 100644 --- a/expression/builtin.go +++ b/expression/builtin.go @@ -896,9 +896,10 @@ var funcs = map[string]functionClass{ // TiDB internal function. ast.TiDBDecodeKey: &tidbDecodeKeyFunctionClass{baseFunctionClass{ast.TiDBDecodeKey, 1, 1}}, // This function is used to show tidb-server version info. - ast.TiDBVersion: &tidbVersionFunctionClass{baseFunctionClass{ast.TiDBVersion, 0, 0}}, - ast.TiDBIsDDLOwner: &tidbIsDDLOwnerFunctionClass{baseFunctionClass{ast.TiDBIsDDLOwner, 0, 0}}, - ast.TiDBDecodePlan: &tidbDecodePlanFunctionClass{baseFunctionClass{ast.TiDBDecodePlan, 1, 1}}, + ast.TiDBVersion: &tidbVersionFunctionClass{baseFunctionClass{ast.TiDBVersion, 0, 0}}, + ast.TiDBIsDDLOwner: &tidbIsDDLOwnerFunctionClass{baseFunctionClass{ast.TiDBIsDDLOwner, 0, 0}}, + ast.TiDBDecodePlan: &tidbDecodePlanFunctionClass{baseFunctionClass{ast.TiDBDecodePlan, 1, 1}}, + ast.TiDBDecodeSQLDigests: &tidbDecodeSQLDigestsFunctionClass{baseFunctionClass{ast.TiDBDecodeSQLDigests, 1, 2}}, // TiDB Sequence function. ast.NextVal: &nextValFunctionClass{baseFunctionClass{ast.NextVal, 1, 1}}, diff --git a/expression/builtin_info.go b/expression/builtin_info.go index 7be431165bf21..381869098b35a 100644 --- a/expression/builtin_info.go +++ b/expression/builtin_info.go @@ -18,8 +18,11 @@ package expression import ( + "context" + "encoding/json" "sort" "strings" + "time" "github.com/pingcap/errors" "github.com/pingcap/parser/model" @@ -52,6 +55,7 @@ var ( _ functionClass = &tidbIsDDLOwnerFunctionClass{} _ functionClass = &tidbDecodePlanFunctionClass{} _ functionClass = &tidbDecodeKeyFunctionClass{} + _ functionClass = &tidbDecodeSQLDigestsFunctionClass{} _ functionClass = &nextValFunctionClass{} _ functionClass = &lastValFunctionClass{} _ functionClass = &setValFunctionClass{} @@ -71,6 +75,7 @@ var ( _ builtinFunc = &builtinTiDBVersionSig{} _ builtinFunc = &builtinRowCountSig{} _ builtinFunc = &builtinTiDBDecodeKeySig{} + _ builtinFunc = &builtinTiDBDecodeSQLDigestsSig{} _ builtinFunc = &builtinNextValSig{} _ builtinFunc = &builtinLastValSig{} _ builtinFunc = &builtinSetValSig{} @@ -771,6 +776,131 @@ func (k TiDBDecodeKeyFunctionKeyType) String() string { // TiDBDecodeKeyFunctionKey is used to identify the decoder function in context. const TiDBDecodeKeyFunctionKey TiDBDecodeKeyFunctionKeyType = 0 +type tidbDecodeSQLDigestsFunctionClass struct { + baseFunctionClass +} + +func (c *tidbDecodeSQLDigestsFunctionClass) getFunction(ctx sessionctx.Context, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(args); err != nil { + return nil, err + } + + pm := privilege.GetPrivilegeManager(ctx) + if pm != nil && !pm.RequestVerification(ctx.GetSessionVars().ActiveRoles, "", "", "", mysql.ProcessPriv) { + return nil, errSpecificAccessDenied.GenWithStackByArgs("PROCESS") + } + + var argTps []types.EvalType + if len(args) > 1 { + argTps = []types.EvalType{types.ETString, types.ETInt} + } else { + argTps = []types.EvalType{types.ETString} + } + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETString, argTps...) + if err != nil { + return nil, err + } + sig := &builtinTiDBDecodeSQLDigestsSig{bf} + return sig, nil +} + +type builtinTiDBDecodeSQLDigestsSig struct { + baseBuiltinFunc +} + +func (b *builtinTiDBDecodeSQLDigestsSig) Clone() builtinFunc { + newSig := &builtinTiDBDecodeSQLDigestsSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (b *builtinTiDBDecodeSQLDigestsSig) evalString(row chunk.Row) (string, bool, error) { + args := b.getArgs() + digestsStr, isNull, err := args[0].EvalString(b.ctx, row) + if err != nil { + return "", true, err + } + if isNull { + return "", true, nil + } + + stmtTruncateLength := int64(0) + if len(args) > 1 { + stmtTruncateLength, isNull, err = args[1].EvalInt(b.ctx, row) + if err != nil { + return "", true, err + } + if isNull { + stmtTruncateLength = 0 + } + } + + var digests []interface{} + err = json.Unmarshal([]byte(digestsStr), &digests) + if err != nil { + const errMsgMaxLength = 32 + if len(digestsStr) > errMsgMaxLength { + digestsStr = digestsStr[:errMsgMaxLength] + "..." + } + b.ctx.GetSessionVars().StmtCtx.AppendWarning(errIncorrectArgs.GenWithStack("The argument can't be unmarshalled as JSON array: '%s'", digestsStr)) + return "", true, nil + } + + // Query the SQL Statements by digests. + retriever := NewSQLDigestTextRetriever() + for _, item := range digests { + if item != nil { + digest, ok := item.(string) + if ok { + retriever.SQLDigestsMap[digest] = "" + } + } + } + + // Querying may take some time and it takes a context.Context as argument, which is not available here. + // We simply create a context with a timeout here. + timeout := time.Duration(b.ctx.GetSessionVars().MaxExecutionTime) * time.Millisecond + if timeout == 0 || timeout > 20*time.Second { + timeout = 20 * time.Second + } + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + err = retriever.RetrieveGlobal(ctx, b.ctx) + if err != nil { + if errors.Cause(err) == context.DeadlineExceeded || errors.Cause(err) == context.Canceled { + return "", true, errUnknown.GenWithStack("Retrieving cancelled internally with error: %v", err) + } + + b.ctx.GetSessionVars().StmtCtx.AppendWarning(errUnknown.GenWithStack("Retrieving statements information failed with error: %v", err)) + return "", true, nil + } + + // Collect the result. + result := make([]interface{}, len(digests)) + for i, item := range digests { + if item == nil { + continue + } + if digest, ok := item.(string); ok { + if stmt, ok := retriever.SQLDigestsMap[digest]; ok && len(stmt) > 0 { + // Truncate too-long statements if necessary. + if stmtTruncateLength > 0 && int64(len(stmt)) > stmtTruncateLength { + stmt = stmt[:stmtTruncateLength] + "..." + } + result[i] = stmt + } + } + } + + resultStr, err := json.Marshal(result) + if err != nil { + b.ctx.GetSessionVars().StmtCtx.AppendWarning(errUnknown.GenWithStack("Marshalling result as JSON failed with error: %v", err)) + return "", true, nil + } + + return string(resultStr), false, nil +} + type tidbDecodePlanFunctionClass struct { baseFunctionClass } diff --git a/expression/builtin_miscellaneous.go b/expression/builtin_miscellaneous.go index a354d18a666ab..c73299211b7a9 100644 --- a/expression/builtin_miscellaneous.go +++ b/expression/builtin_miscellaneous.go @@ -399,6 +399,7 @@ func (c *inetAtonFunctionClass) getFunction(ctx sessionctx.Context, args []Expre bf.tp.Flen = 21 bf.tp.Flag |= mysql.UnsignedFlag sig := &builtinInetAtonSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_InetAton) return sig, nil } @@ -476,6 +477,7 @@ func (c *inetNtoaFunctionClass) getFunction(ctx sessionctx.Context, args []Expre bf.tp.Flen = 93 bf.tp.Decimal = 0 sig := &builtinInetNtoaSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_InetNtoa) return sig, nil } @@ -528,6 +530,7 @@ func (c *inet6AtonFunctionClass) getFunction(ctx sessionctx.Context, args []Expr types.SetBinChsClnFlag(bf.tp) bf.tp.Decimal = 0 sig := &builtinInet6AtonSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_Inet6Aton) return sig, nil } @@ -600,6 +603,7 @@ func (c *inet6NtoaFunctionClass) getFunction(ctx sessionctx.Context, args []Expr bf.tp.Flen = 117 bf.tp.Decimal = 0 sig := &builtinInet6NtoaSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_Inet6Ntoa) return sig, nil } @@ -654,6 +658,7 @@ func (c *isIPv4FunctionClass) getFunction(ctx sessionctx.Context, args []Express } bf.tp.Flen = 1 sig := &builtinIsIPv4Sig{bf} + sig.setPbCode(tipb.ScalarFuncSig_IsIPv4) return sig, nil } @@ -721,6 +726,7 @@ func (c *isIPv4CompatFunctionClass) getFunction(ctx sessionctx.Context, args []E } bf.tp.Flen = 1 sig := &builtinIsIPv4CompatSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_IsIPv4Compat) return sig, nil } @@ -769,6 +775,7 @@ func (c *isIPv4MappedFunctionClass) getFunction(ctx sessionctx.Context, args []E } bf.tp.Flen = 1 sig := &builtinIsIPv4MappedSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_IsIPv4Mapped) return sig, nil } @@ -817,6 +824,7 @@ func (c *isIPv6FunctionClass) getFunction(ctx sessionctx.Context, args []Express } bf.tp.Flen = 1 sig := &builtinIsIPv6Sig{bf} + sig.setPbCode(tipb.ScalarFuncSig_IsIPv6) return sig, nil } diff --git a/expression/errors.go b/expression/errors.go index ad0f49e64653c..b4905b1e9b1f5 100644 --- a/expression/errors.go +++ b/expression/errors.go @@ -48,6 +48,8 @@ var ( errUnknownLocale = dbterror.ClassExpression.NewStd(mysql.ErrUnknownLocale) errNonUniq = dbterror.ClassExpression.NewStd(mysql.ErrNonUniq) errWrongValueForType = dbterror.ClassExpression.NewStd(mysql.ErrWrongValueForType) + errUnknown = dbterror.ClassExpression.NewStd(mysql.ErrUnknown) + errSpecificAccessDenied = dbterror.ClassExpression.NewStd(mysql.ErrSpecificAccessDenied) // Sequence usage privilege check. errSequenceAccessDenied = dbterror.ClassExpression.NewStd(mysql.ErrTableaccessDenied) diff --git a/expression/expr_to_pb_test.go b/expression/expr_to_pb_test.go index 8f7c32811837e..8a79e25ade48f 100644 --- a/expression/expr_to_pb_test.go +++ b/expression/expr_to_pb_test.go @@ -755,7 +755,7 @@ func (s *testEvaluatorSuite) TestExprPushDownToFlash(c *C) { c.Assert(err, IsNil) exprs = append(exprs, function) - // ScalarFuncSig_CeilDecimalToDecimal + // ScalarFuncSig_CeilDecToDec function, err = NewFunction(mock.NewContext(), ast.Ceil, types.NewFieldType(mysql.TypeNewDecimal), decimalColumn) c.Assert(err, IsNil) exprs = append(exprs, function) @@ -770,16 +770,66 @@ func (s *testEvaluatorSuite) TestExprPushDownToFlash(c *C) { c.Assert(err, IsNil) exprs = append(exprs, function) - // ScalarFuncSig_FloorDecimalToInt + // ScalarFuncSig_FloorDecToInt function, err = NewFunction(mock.NewContext(), ast.Floor, types.NewFieldType(mysql.TypeLonglong), decimalColumn) c.Assert(err, IsNil) exprs = append(exprs, function) - // ScalarFuncSig_FloorDecimalToDecimal + // ScalarFuncSig_FloorDecToDec function, err = NewFunction(mock.NewContext(), ast.Floor, types.NewFieldType(mysql.TypeNewDecimal), decimalColumn) c.Assert(err, IsNil) exprs = append(exprs, function) + // ScalarFuncSig_Log1Arg + function, err = NewFunction(mock.NewContext(), ast.Log, types.NewFieldType(mysql.TypeDouble), realColumn) + c.Assert(err, IsNil) + exprs = append(exprs, function) + + // ScalarFuncSig_Log2Args + function, err = NewFunction(mock.NewContext(), ast.Log, types.NewFieldType(mysql.TypeDouble), realColumn, realColumn) + c.Assert(err, IsNil) + exprs = append(exprs, function) + + // ScalarFuncSig_Log2 + function, err = NewFunction(mock.NewContext(), ast.Log2, types.NewFieldType(mysql.TypeDouble), realColumn) + c.Assert(err, IsNil) + exprs = append(exprs, function) + + // ScalarFuncSig_Log10 + function, err = NewFunction(mock.NewContext(), ast.Log10, types.NewFieldType(mysql.TypeDouble), realColumn) + c.Assert(err, IsNil) + exprs = append(exprs, function) + + // ScalarFuncSig_Exp + function, err = NewFunction(mock.NewContext(), ast.Exp, types.NewFieldType(mysql.TypeDouble), realColumn) + c.Assert(err, IsNil) + exprs = append(exprs, function) + + // ScalarFuncSig_Pow + function, err = NewFunction(mock.NewContext(), ast.Pow, types.NewFieldType(mysql.TypeDouble), realColumn, realColumn) + c.Assert(err, IsNil) + exprs = append(exprs, function) + + // ScalarFuncSig_Radians + function, err = NewFunction(mock.NewContext(), ast.Radians, types.NewFieldType(mysql.TypeDouble), realColumn) + c.Assert(err, IsNil) + exprs = append(exprs, function) + + // ScalarFuncSig_Degrees + function, err = NewFunction(mock.NewContext(), ast.Degrees, types.NewFieldType(mysql.TypeDouble), realColumn) + c.Assert(err, IsNil) + exprs = append(exprs, function) + + // ScalarFuncSig_CRC32 + function, err = NewFunction(mock.NewContext(), ast.CRC32, types.NewFieldType(mysql.TypeLonglong), stringColumn) + c.Assert(err, IsNil) + exprs = append(exprs, function) + + // ScalarFuncSig_Conv + function, err = NewFunction(mock.NewContext(), ast.Conv, types.NewFieldType(mysql.TypeDouble), stringColumn, intColumn, intColumn) + c.Assert(err, IsNil) + exprs = append(exprs, function) + // Replace function, err = NewFunction(mock.NewContext(), ast.Replace, types.NewFieldType(mysql.TypeString), stringColumn, stringColumn, stringColumn) c.Assert(err, IsNil) @@ -958,6 +1008,58 @@ func (s *testEvaluatorSuite) TestExprOnlyPushDownToFlash(c *C) { c.Assert(len(remained), Equals, len(exprs)) } +func (s *testEvaluatorSuite) TestExprPushDownToTiKV(c *C) { + sc := new(stmtctx.StatementContext) + client := new(mock.Client) + dg := new(dataGen4Expr2PbTest) + exprs := make([]Expression, 0) + + //jsonColumn := dg.genColumn(mysql.TypeJSON, 1) + //intColumn := dg.genColumn(mysql.TypeLonglong, 2) + //realColumn := dg.genColumn(mysql.TypeDouble, 3) + //decimalColumn := dg.genColumn(mysql.TypeNewDecimal, 4) + stringColumn := dg.genColumn(mysql.TypeString, 5) + //datetimeColumn := dg.genColumn(mysql.TypeDatetime, 6) + binaryStringColumn := dg.genColumn(mysql.TypeString, 7) + binaryStringColumn.RetType.Collate = charset.CollationBin + + function, err := NewFunction(mock.NewContext(), ast.InetAton, types.NewFieldType(mysql.TypeString), stringColumn) + c.Assert(err, IsNil) + exprs = append(exprs, function) + + function, err = NewFunction(mock.NewContext(), ast.InetNtoa, types.NewFieldType(mysql.TypeLonglong), stringColumn) + c.Assert(err, IsNil) + exprs = append(exprs, function) + + function, err = NewFunction(mock.NewContext(), ast.Inet6Aton, types.NewFieldType(mysql.TypeString), stringColumn) + c.Assert(err, IsNil) + exprs = append(exprs, function) + + function, err = NewFunction(mock.NewContext(), ast.Inet6Ntoa, types.NewFieldType(mysql.TypeLonglong), stringColumn) + c.Assert(err, IsNil) + exprs = append(exprs, function) + + function, err = NewFunction(mock.NewContext(), ast.IsIPv4, types.NewFieldType(mysql.TypeString), stringColumn) + c.Assert(err, IsNil) + exprs = append(exprs, function) + + function, err = NewFunction(mock.NewContext(), ast.IsIPv6, types.NewFieldType(mysql.TypeString), stringColumn) + c.Assert(err, IsNil) + exprs = append(exprs, function) + + function, err = NewFunction(mock.NewContext(), ast.IsIPv4Compat, types.NewFieldType(mysql.TypeString), stringColumn) + c.Assert(err, IsNil) + exprs = append(exprs, function) + + function, err = NewFunction(mock.NewContext(), ast.IsIPv4Mapped, types.NewFieldType(mysql.TypeString), stringColumn) + c.Assert(err, IsNil) + exprs = append(exprs, function) + + pushed, remained := PushDownExprs(sc, exprs, client, kv.TiKV) + c.Assert(len(pushed), Equals, 0) + c.Assert(len(remained), Equals, len(exprs)) +} + func (s *testEvaluatorSuite) TestExprOnlyPushDownToTiKV(c *C) { sc := new(stmtctx.StatementContext) client := new(mock.Client) diff --git a/expression/expression.go b/expression/expression.go index a73c6df8ba735..77559112635b3 100644 --- a/expression/expression.go +++ b/expression/expression.go @@ -981,7 +981,9 @@ func scalarExprSupportedByTiKV(sf *ScalarFunction) bool { ast.Cast, // misc functions. - ast.InetNtoa, ast.InetAton, ast.Inet6Ntoa, ast.Inet6Aton, ast.IsIPv4, ast.IsIPv4Compat, ast.IsIPv4Mapped, ast.IsIPv6, ast.UUID: + // TODO(#26942): enable functions below after them are fully tested in TiKV. + /*ast.InetNtoa, ast.InetAton, ast.Inet6Ntoa, ast.Inet6Aton, ast.IsIPv4, ast.IsIPv4Compat, ast.IsIPv4Mapped, ast.IsIPv6,*/ + ast.UUID: return true @@ -1015,9 +1017,10 @@ func scalarExprSupportedByFlash(function *ScalarFunction) bool { ast.Plus, ast.Minus, ast.Div, ast.Mul, ast.Abs, ast.Mod, ast.If, ast.Ifnull, ast.Case, ast.Concat, ast.ConcatWS, - ast.Year, ast.Month, ast.Day, + ast.Date, ast.Year, ast.Month, ast.Day, ast.DateDiff, ast.TimestampDiff, ast.DateFormat, ast.FromUnixTime, - ast.Sqrt, + ast.Sqrt, ast.Log, ast.Log2, ast.Log10, ast.Ln, ast.Exp, ast.Pow, ast.Sign, + ast.Radians, ast.Degrees, ast.Conv, ast.CRC32, ast.JSONLength: return true case ast.Substr, ast.Substring, ast.Left, ast.Right, ast.CharLength: @@ -1044,7 +1047,7 @@ func scalarExprSupportedByFlash(function *ScalarFunction) bool { } case ast.DateAdd, ast.AddDate: switch function.Function.PbCode() { - case tipb.ScalarFuncSig_AddDateDatetimeInt, tipb.ScalarFuncSig_AddDateStringInt: + case tipb.ScalarFuncSig_AddDateDatetimeInt, tipb.ScalarFuncSig_AddDateStringInt, tipb.ScalarFuncSig_AddDateStringReal: return true } case ast.DateSub, ast.SubDate: @@ -1160,15 +1163,14 @@ func init() { func canScalarFuncPushDown(scalarFunc *ScalarFunction, pc PbConverter, storeType kv.StoreType) bool { pbCode := scalarFunc.Function.PbCode() - if pbCode <= tipb.ScalarFuncSig_Unspecified { - failpoint.Inject("PanicIfPbCodeUnspecified", func() { - panic(errors.Errorf("unspecified PbCode: %T", scalarFunc.Function)) - }) - return false - } // Check whether this function can be pushed. - if !canFuncBePushed(scalarFunc, storeType) { + if unspecified := pbCode <= tipb.ScalarFuncSig_Unspecified; unspecified || !canFuncBePushed(scalarFunc, storeType) { + if unspecified { + failpoint.Inject("PanicIfPbCodeUnspecified", func() { + panic(errors.Errorf("unspecified PbCode: %T", scalarFunc.Function)) + }) + } if pc.sc.InExplainStmt { storageName := storeType.Name() if storeType == kv.UnSpecified { diff --git a/expression/util.go b/expression/util.go index 3a9ecfe27b53a..56d526495638f 100644 --- a/expression/util.go +++ b/expression/util.go @@ -14,6 +14,7 @@ package expression import ( + "context" "math" "strconv" "strings" @@ -21,6 +22,7 @@ import ( "unicode" "github.com/pingcap/errors" + "github.com/pingcap/failpoint" "github.com/pingcap/parser/ast" "github.com/pingcap/parser/mysql" "github.com/pingcap/parser/opcode" @@ -31,6 +33,7 @@ import ( "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/collate" "github.com/pingcap/tidb/util/logutil" + "github.com/pingcap/tidb/util/sqlexec" "go.uber.org/zap" "golang.org/x/tools/container/intsets" ) @@ -994,3 +997,176 @@ func GetFormatNanoTime(time float64) string { } return strconv.FormatFloat(value, 'f', 2, 64) + " " + unit } + +// SQLDigestTextRetriever is used to find the normalized SQL statement text by SQL digests in statements_summary table. +// It's exported for test purposes. It's used by the `tidb_decode_sql_digests` builtin function, but also exposed to +// be used in other modules. +type SQLDigestTextRetriever struct { + // SQLDigestsMap is the place to put the digests that's requested for getting SQL text and also the place to put + // the query result. + SQLDigestsMap map[string]string + + // Replace querying for test purposes. + mockLocalData map[string]string + mockGlobalData map[string]string + // There are two ways for querying information: 1) query specified digests by WHERE IN query, or 2) query all + // information to avoid the too long WHERE IN clause. If there are more than `fetchAllLimit` digests needs to be + // queried, the second way will be chosen; otherwise, the first way will be chosen. + fetchAllLimit int +} + +// NewSQLDigestTextRetriever creates a new SQLDigestTextRetriever. +func NewSQLDigestTextRetriever() *SQLDigestTextRetriever { + return &SQLDigestTextRetriever{ + SQLDigestsMap: make(map[string]string), + fetchAllLimit: 512, + } +} + +func (r *SQLDigestTextRetriever) runMockQuery(data map[string]string, inValues []interface{}) (map[string]string, error) { + if len(inValues) == 0 { + return data, nil + } + res := make(map[string]string, len(inValues)) + for _, digest := range inValues { + if text, ok := data[digest.(string)]; ok { + res[digest.(string)] = text + } + } + return res, nil +} + +// runFetchDigestQuery runs query to the system tables to fetch the kv mapping of SQL digests and normalized SQL texts +// of the given SQL digests, if `inValues` is given, or all these mappings otherwise. If `queryGlobal` is false, it +// queries information_schema.statements_summary and information_schema.statements_summary_history; otherwise, it +// queries the cluster version of these two tables. +func (r *SQLDigestTextRetriever) runFetchDigestQuery(ctx context.Context, sctx sessionctx.Context, queryGlobal bool, inValues []interface{}) (map[string]string, error) { + // If mock data is set, query the mock data instead of the real statements_summary tables. + if !queryGlobal && r.mockLocalData != nil { + return r.runMockQuery(r.mockLocalData, inValues) + } else if queryGlobal && r.mockGlobalData != nil { + return r.runMockQuery(r.mockGlobalData, inValues) + } + + exec, ok := sctx.(sqlexec.RestrictedSQLExecutor) + if !ok { + return nil, errors.New("restricted sql can't be executed in this context") + } + + // Information in statements_summary will be periodically moved to statements_summary_history. Union them together + // to avoid missing information when statements_summary is just cleared. + stmt := "select digest, digest_text from information_schema.statements_summary union distinct " + + "select digest, digest_text from information_schema.statements_summary_history" + if queryGlobal { + stmt = "select digest, digest_text from information_schema.cluster_statements_summary union distinct " + + "select digest, digest_text from information_schema.cluster_statements_summary_history" + } + // Add the where clause if `inValues` is specified. + if len(inValues) > 0 { + stmt += " where digest in (" + strings.Repeat("%?,", len(inValues)-1) + "%?)" + } + + stmtNode, err := exec.ParseWithParams(ctx, stmt, inValues...) + if err != nil { + return nil, err + } + rows, _, err := exec.ExecRestrictedStmt(ctx, stmtNode) + if err != nil { + return nil, err + } + + res := make(map[string]string, len(rows)) + for _, row := range rows { + res[row.GetString(0)] = row.GetString(1) + } + return res, nil +} + +func (r *SQLDigestTextRetriever) updateDigestInfo(queryResult map[string]string) { + for digest, text := range r.SQLDigestsMap { + if len(text) > 0 { + // The text of this digest is already known + continue + } + sqlText, ok := queryResult[digest] + if ok { + r.SQLDigestsMap[digest] = sqlText + } + } +} + +// RetrieveLocal tries to retrieve the SQL text of the SQL digests from local information. +func (r *SQLDigestTextRetriever) RetrieveLocal(ctx context.Context, sctx sessionctx.Context) error { + if len(r.SQLDigestsMap) == 0 { + return nil + } + + var queryResult map[string]string + if len(r.SQLDigestsMap) <= r.fetchAllLimit { + inValues := make([]interface{}, 0, len(r.SQLDigestsMap)) + for key := range r.SQLDigestsMap { + inValues = append(inValues, key) + } + var err error + queryResult, err = r.runFetchDigestQuery(ctx, sctx, false, inValues) + if err != nil { + return errors.Trace(err) + } + + if len(queryResult) == len(r.SQLDigestsMap) { + r.SQLDigestsMap = queryResult + return nil + } + } else { + var err error + queryResult, err = r.runFetchDigestQuery(ctx, sctx, false, nil) + if err != nil { + return errors.Trace(err) + } + } + + r.updateDigestInfo(queryResult) + return nil +} + +// RetrieveGlobal tries to retrieve the SQL text of the SQL digests from the information of the whole cluster. +func (r *SQLDigestTextRetriever) RetrieveGlobal(ctx context.Context, sctx sessionctx.Context) error { + err := r.RetrieveLocal(ctx, sctx) + if err != nil { + return errors.Trace(err) + } + + // In some unit test environments it's unable to retrieve global info, and this function blocks it for tens of + // seconds, which wastes much time during unit test. In this case, enable this failpoint to bypass retrieving + // globally. + failpoint.Inject("sqlDigestRetrieverSkipRetrieveGlobal", func() { + failpoint.Return(nil) + }) + + var unknownDigests []interface{} + for k, v := range r.SQLDigestsMap { + if len(v) == 0 { + unknownDigests = append(unknownDigests, k) + } + } + + if len(unknownDigests) == 0 { + return nil + } + + var queryResult map[string]string + if len(r.SQLDigestsMap) <= r.fetchAllLimit { + queryResult, err = r.runFetchDigestQuery(ctx, sctx, true, unknownDigests) + if err != nil { + return errors.Trace(err) + } + } else { + queryResult, err = r.runFetchDigestQuery(ctx, sctx, true, nil) + if err != nil { + return errors.Trace(err) + } + } + + r.updateDigestInfo(queryResult) + return nil +} diff --git a/expression/util_test.go b/expression/util_test.go index 039f399573466..ed9ba8448b116 100644 --- a/expression/util_test.go +++ b/expression/util_test.go @@ -14,6 +14,7 @@ package expression import ( + "context" "reflect" "testing" "time" @@ -413,6 +414,68 @@ func (s *testUtilSuite) TestDisableParseJSONFlag4Expr(c *check.C) { c.Assert(mysql.HasParseToJSONFlag(ft.Flag), check.IsFalse) } +func (s *testUtilSuite) TestSQLDigestTextRetriever(c *check.C) { + // Create a fake session as the argument to the retriever, though it's actually not used when mock data is set. + + r := NewSQLDigestTextRetriever() + clearResult := func() { + r.SQLDigestsMap = map[string]string{ + "digest1": "", + "digest2": "", + "digest3": "", + "digest4": "", + "digest5": "", + } + } + clearResult() + r.mockLocalData = map[string]string{ + "digest1": "text1", + "digest2": "text2", + "digest6": "text6", + } + r.mockGlobalData = map[string]string{ + "digest2": "text2", + "digest3": "text3", + "digest4": "text4", + "digest7": "text7", + } + + expectedLocalResult := map[string]string{ + "digest1": "text1", + "digest2": "text2", + "digest3": "", + "digest4": "", + "digest5": "", + } + expectedGlobalResult := map[string]string{ + "digest1": "text1", + "digest2": "text2", + "digest3": "text3", + "digest4": "text4", + "digest5": "", + } + + err := r.RetrieveLocal(context.Background(), nil) + c.Assert(err, check.IsNil) + c.Assert(r.SQLDigestsMap, check.DeepEquals, expectedLocalResult) + clearResult() + + err = r.RetrieveGlobal(context.Background(), nil) + c.Assert(err, check.IsNil) + c.Assert(r.SQLDigestsMap, check.DeepEquals, expectedGlobalResult) + clearResult() + + r.fetchAllLimit = 1 + err = r.RetrieveLocal(context.Background(), nil) + c.Assert(err, check.IsNil) + c.Assert(r.SQLDigestsMap, check.DeepEquals, expectedLocalResult) + clearResult() + + err = r.RetrieveGlobal(context.Background(), nil) + c.Assert(err, check.IsNil) + c.Assert(r.SQLDigestsMap, check.DeepEquals, expectedGlobalResult) +} + func BenchmarkExtractColumns(b *testing.B) { conditions := []Expression{ newFunction(ast.EQ, newColumn(0), newColumn(1)), diff --git a/go.mod b/go.mod index 26df343063a8b..8b7cba7ed1a0e 100644 --- a/go.mod +++ b/go.mod @@ -51,7 +51,7 @@ require ( github.com/pingcap/fn v0.0.0-20200306044125-d5540d389059 github.com/pingcap/kvproto v0.0.0-20210805052247-76981389e818 github.com/pingcap/log v0.0.0-20210625125904-98ed8e2eb1c7 - github.com/pingcap/parser v0.0.0-20210802034743-dd9b189324ce + github.com/pingcap/parser v0.0.0-20210803205906-cece3020391a github.com/pingcap/sysutil v0.0.0-20210315073920-cc0985d983a3 github.com/pingcap/tidb-tools v5.0.3+incompatible github.com/pingcap/tipb v0.0.0-20210708040514-0f154bb0dc0f diff --git a/go.sum b/go.sum index 9f8e16bb51946..64a60a99c1eed 100644 --- a/go.sum +++ b/go.sum @@ -560,8 +560,8 @@ github.com/pingcap/log v0.0.0-20210317133921-96f4fcab92a4/go.mod h1:4rbK1p9ILyIf github.com/pingcap/log v0.0.0-20210625125904-98ed8e2eb1c7 h1:k2BbABz9+TNpYRwsCCFS8pEEnFVOdbgEjL/kTlLuzZQ= github.com/pingcap/log v0.0.0-20210625125904-98ed8e2eb1c7/go.mod h1:8AanEdAHATuRurdGxZXBz0At+9avep+ub7U1AGYLIMM= github.com/pingcap/parser v0.0.0-20210525032559-c37778aff307/go.mod h1:xZC8I7bug4GJ5KtHhgAikjTfU4kBv1Sbo3Pf1MZ6lVw= -github.com/pingcap/parser v0.0.0-20210802034743-dd9b189324ce h1:3KjHJw5FjUbrLLunmzEdmU/CeXNfLaqnP9AMVfOVOQU= -github.com/pingcap/parser v0.0.0-20210802034743-dd9b189324ce/go.mod h1:Ek0mLKEqUGnQqBw1JnYrJQxsguU433DU68yUbsoeJ7s= +github.com/pingcap/parser v0.0.0-20210803205906-cece3020391a h1:NPO1iSULt7ztYOEifJ73IZA+xF3ywgX0Ik0X6PKy8BI= +github.com/pingcap/parser v0.0.0-20210803205906-cece3020391a/go.mod h1:Ek0mLKEqUGnQqBw1JnYrJQxsguU433DU68yUbsoeJ7s= github.com/pingcap/sysutil v0.0.0-20200206130906-2bfa6dc40bcd/go.mod h1:EB/852NMQ+aRKioCpToQ94Wl7fktV+FNnxf3CX/TTXI= github.com/pingcap/sysutil v0.0.0-20210315073920-cc0985d983a3 h1:A9KL9R+lWSVPH8IqUuH1QSTRJ5FGoY1bT2IcfPKsWD8= github.com/pingcap/sysutil v0.0.0-20210315073920-cc0985d983a3/go.mod h1:tckvA041UWP+NqYzrJ3fMgC/Hw9wnmQ/tUkp/JaHly8= diff --git a/infoschema/tables_test.go b/infoschema/tables_test.go index 6fe1feeed4e5b..51629df4ada75 100644 --- a/infoschema/tables_test.go +++ b/infoschema/tables_test.go @@ -1308,7 +1308,7 @@ func (s *testTableSuite) TestStmtSummaryInternalQuery(c *C) { "where digest_text like \"select `original_sql` , `bind_sql` , `default_db` , status%\"" tk.MustQuery(sql).Check(testkit.Rows( "select `original_sql` , `bind_sql` , `default_db` , status , `create_time` , `update_time` , charset , " + - "collation , source from `mysql` . `bind_info` where `update_time` > ? order by `update_time`")) + "collation , source from `mysql` . `bind_info` where `update_time` > ? order by `update_time` , `create_time`")) // Test for issue #21642. tk.MustQuery(`select tidb_version()`) @@ -1754,9 +1754,14 @@ func (s *testTableSuite) TestInfoschemaClientErrors(c *C) { c.Assert(err.Error(), Equals, "[planner:1227]Access denied; you need (at least one of) the RELOAD privilege(s) for this operation") } -func (s *testTableSuite) TestTrx(c *C) { +func (s *testTableSuite) TestTiDBTrx(c *C) { tk := s.newTestKitWithRoot(c) - _, digest := parser.NormalizeDigest("select * from trx for update;") + tk.MustExec("drop table if exists test_tidb_trx") + tk.MustExec("create table test_tidb_trx(i int)") + // Execute the statement once so that the statement will be collected into statements_summary and able to be found + // by digest. + tk.MustExec("update test_tidb_trx set i = i + 1") + _, digest := parser.NormalizeDigest("update test_tidb_trx set i = i + 1") sm := &mockSessionManager{nil, make([]*txninfo.TxnInfo, 2)} sm.txnInfo[0] = &txninfo.TxnInfo{ StartTS: 424768545227014155, @@ -1772,7 +1777,7 @@ func (s *testTableSuite) TestTrx(c *C) { sm.txnInfo[1] = &txninfo.TxnInfo{ StartTS: 425070846483628033, CurrentSQLDigest: "", - AllSQLDigests: []string{"sql1", "sql2"}, + AllSQLDigests: []string{"sql1", "sql2", digest.String()}, State: txninfo.TxnLockWaiting, ConnectionID: 10, Username: "user1", @@ -1781,9 +1786,19 @@ func (s *testTableSuite) TestTrx(c *C) { sm.txnInfo[1].BlockStartTime.Valid = true sm.txnInfo[1].BlockStartTime.Time = blockTime2 tk.Se.SetSessionManager(sm) + tk.MustQuery("select * from information_schema.TIDB_TRX;").Check(testkit.Rows( - "424768545227014155 2021-05-07 12:56:48.001000 "+digest.String()+" Idle 1 19 2 root test []", - "425070846483628033 2021-05-20 21:16:35.778000 LockWaiting 2021-05-20 13:18:30.123456 0 0 10 user1 db1 [\"sql1\",\"sql2\"]")) + "424768545227014155 2021-05-07 12:56:48.001000 "+digest.String()+" update `test_tidb_trx` set `i` = `i` + ? Idle 1 19 2 root test []", + "425070846483628033 2021-05-20 21:16:35.778000 LockWaiting 2021-05-20 13:18:30.123456 0 0 10 user1 db1 [\"sql1\",\"sql2\",\""+digest.String()+"\"]")) + + // Test the all_sql_digests column can be directly passed to the tidb_decode_sql_digests function. + c.Assert(failpoint.Enable("github.com/pingcap/tidb/expression/sqlDigestRetrieverSkipRetrieveGlobal", "return"), IsNil) + defer func() { + c.Assert(failpoint.Disable("github.com/pingcap/tidb/expression/sqlDigestRetrieverSkipRetrieveGlobal"), IsNil) + }() + tk.MustQuery("select tidb_decode_sql_digests(all_sql_digests) from information_schema.tidb_trx").Check(testkit.Rows( + "[]", + "[null,null,\"update `test_tidb_trx` set `i` = `i` + ?\"]")) } func (s *testTableSuite) TestInfoschemaDeadlockPrivilege(c *C) { diff --git a/planner/core/common_plans.go b/planner/core/common_plans.go index 8f82ce46ddda1..2fa0878d5552e 100644 --- a/planner/core/common_plans.go +++ b/planner/core/common_plans.go @@ -282,6 +282,10 @@ func (e *Execute) OptimizePreparedPlan(ctx context.Context, sctx sessionctx.Cont preparedObj.Executor = nil // If the schema version has changed we need to preprocess it again, // if this time it failed, the real reason for the error is schema changed. + // Example: + // When running update in prepared statement's schema version distinguished from the one of execute statement + // We should reset the tableRefs in the prepared update statements, otherwise, the ast nodes still hold the old + // tableRefs columnInfo which will cause chaos in logic of trying point get plan. (should ban non-public column) ret := &PreprocessorReturn{InfoSchema: is} err := Preprocess(sctx, prepared.Stmt, InPrepare, WithPreprocessorReturn(ret)) if err != nil { diff --git a/planner/core/find_best_task.go b/planner/core/find_best_task.go index 9cd9f7cac3f31..305f884ec3085 100644 --- a/planner/core/find_best_task.go +++ b/planner/core/find_best_task.go @@ -649,6 +649,16 @@ func (ds *DataSource) skylinePruning(prop *property.PhysicalProperty) []*candida return candidates } +func (ds *DataSource) isPointGetConvertableSchema() bool { + for _, col := range ds.Columns { + // Only handle tables that all columns are public. + if col.State != model.StatePublic { + return false + } + } + return true +} + // findBestTask implements the PhysicalPlan interface. // It will enumerate all the available indices and choose a plan with least cost. func (ds *DataSource) findBestTask(prop *property.PhysicalProperty, planCounter *PlanCounterTp) (t task, cntPlan int64, err error) { @@ -745,7 +755,7 @@ func (ds *DataSource) findBestTask(prop *property.PhysicalProperty, planCounter p: dual, }, cntPlan, nil } - canConvertPointGet := len(path.Ranges) > 0 && path.StoreType == kv.TiKV + canConvertPointGet := len(path.Ranges) > 0 && path.StoreType == kv.TiKV && ds.isPointGetConvertableSchema() if canConvertPointGet && !path.IsIntHandlePath { // We simply do not build [batch] point get for prefix indexes. This can be optimized. canConvertPointGet = path.Index.Unique && !path.Index.HasPrefixIndex() diff --git a/planner/core/integration_test.go b/planner/core/integration_test.go index 5a4613e42ec29..c63ac9b6b6a53 100644 --- a/planner/core/integration_test.go +++ b/planner/core/integration_test.go @@ -350,8 +350,7 @@ func (s *testIntegrationSerialSuite) TestNoneAccessPathsFoundByIsolationRead(c * tk.MustExec("drop table if exists t") tk.MustExec("create table t(a int primary key)") - _, err := tk.Exec("select * from t") - c.Assert(err, IsNil) + tk.MustExec("select * from t") tk.MustExec("set @@session.tidb_isolation_read_engines = 'tiflash'") @@ -360,7 +359,7 @@ func (s *testIntegrationSerialSuite) TestNoneAccessPathsFoundByIsolationRead(c * "TableReader 10000.00 root data:TableFullScan", "└─TableFullScan 10000.00 cop[tikv] table:stats_meta keep order:false, stats:pseudo")) - _, err = tk.Exec("select * from t") + _, err := tk.Exec("select * from t") c.Assert(err, NotNil) c.Assert(err.Error(), Equals, "[planner:1815]Internal : Can not find access path matching 'tidb_isolation_read_engines'(value: 'tiflash'). Available values are 'tikv'.") @@ -2212,6 +2211,7 @@ func (s *testIntegrationSuite) TestSelectLimit(c *C) { // normal test tk.MustExec("set @@session.sql_select_limit=1") result := tk.MustQuery("select * from t order by a") + c.Assert(tk.Se.GetSessionVars().StmtCtx.GetWarnings(), HasLen, 0) result.Check(testkit.Rows("1")) result = tk.MustQuery("select * from t order by a limit 2") result.Check(testkit.Rows("1", "1")) diff --git a/planner/core/logical_plan_builder.go b/planner/core/logical_plan_builder.go index 838a97bf0359d..7d2dd60520ab4 100644 --- a/planner/core/logical_plan_builder.go +++ b/planner/core/logical_plan_builder.go @@ -4655,10 +4655,10 @@ type tblUpdateInfo struct { } // CheckUpdateList checks all related columns in updatable state. -func CheckUpdateList(assignFlags []int, updt *Update) error { +func CheckUpdateList(assignFlags []int, updt *Update, newTblID2Table map[int64]table.Table) error { updateFromOtherAlias := make(map[int64]tblUpdateInfo) for _, content := range updt.TblColPosInfos { - tbl := updt.tblID2Table[content.TblID] + tbl := newTblID2Table[content.TblID] flags := assignFlags[content.Start:content.End] var update, updatePK bool for i, col := range tbl.WritableCols() { diff --git a/planner/core/preprocess.go b/planner/core/preprocess.go index fff5467ead5cd..2efc7dedf5f3d 100644 --- a/planner/core/preprocess.go +++ b/planner/core/preprocess.go @@ -63,6 +63,13 @@ func WithPreprocessorReturn(ret *PreprocessorReturn) PreprocessOpt { } } +// WithExecuteInfoSchemaUpdate return a PreprocessOpt to update the `Execute` infoSchema under some conditions. +func WithExecuteInfoSchemaUpdate(pe *PreprocessExecuteISUpdate) PreprocessOpt { + return func(p *preprocessor) { + p.PreprocessExecuteISUpdate = pe + } +} + // TryAddExtraLimit trys to add an extra limit for SELECT or UNION statement when sql_select_limit is set. func TryAddExtraLimit(ctx sessionctx.Context, node ast.StmtNode) ast.StmtNode { if ctx.GetSessionVars().SelectLimit == math.MaxUint64 || ctx.GetSessionVars().InRestrictedSQL { @@ -143,6 +150,12 @@ type PreprocessorReturn struct { TxnScope string } +// PreprocessExecuteISUpdate is used to update information schema for special Execute statement in the preprocessor. +type PreprocessExecuteISUpdate struct { + ExecuteInfoSchemaUpdate func(node ast.Node, sctx sessionctx.Context) infoschema.InfoSchema + Node ast.Node +} + // preprocessor is an ast.Visitor that preprocess // ast Nodes parsed from parser. type preprocessor struct { @@ -157,6 +170,7 @@ type preprocessor struct { // values that may be returned *PreprocessorReturn + *PreprocessExecuteISUpdate err error } @@ -1596,9 +1610,17 @@ func (p *preprocessor) handleAsOfAndReadTS(node *ast.AsOfClause) { // - session variable // - transaction context func (p *preprocessor) ensureInfoSchema() infoschema.InfoSchema { - if p.InfoSchema == nil { - p.InfoSchema = p.ctx.GetInfoSchema().(infoschema.InfoSchema) + if p.InfoSchema != nil { + return p.InfoSchema + } + // `Execute` under some conditions need to see the latest information schema. + if p.PreprocessExecuteISUpdate != nil { + if newInfoSchema := p.ExecuteInfoSchemaUpdate(p.Node, p.ctx); newInfoSchema != nil { + p.InfoSchema = newInfoSchema + return p.InfoSchema + } } + p.InfoSchema = p.ctx.GetInfoSchema().(infoschema.InfoSchema) return p.InfoSchema } diff --git a/planner/core/task.go b/planner/core/task.go index 49fd3c75efcb9..c651d81b1a1d3 100644 --- a/planner/core/task.go +++ b/planner/core/task.go @@ -1442,7 +1442,6 @@ func CheckAggCanPushCop(sctx sessionctx.Context, aggFuncs []*aggregation.AggFunc } if !ret && sc.InExplainStmt { - sctx.GetSessionVars().RaiseWarningWhenMPPEnforced("MPP mode may be blocked because " + reason) storageName := storeType.Name() if storeType == kv.UnSpecified { storageName = "storage layer" diff --git a/planner/core/testdata/enforce_mpp_suite_out.json b/planner/core/testdata/enforce_mpp_suite_out.json index 5b479befa17c4..6be7eb342aadb 100644 --- a/planner/core/testdata/enforce_mpp_suite_out.json +++ b/planner/core/testdata/enforce_mpp_suite_out.json @@ -295,15 +295,10 @@ " └─TableRowIDScan_41(Probe) 10.00 cop[tikv] table:t keep order:false, stats:pseudo" ], "Warn": [ - "MPP mode may be blocked because expressions of AggFunc `count` contain virtual column or correlated column, which is not supported now", "Aggregation can not be pushed to tiflash because expressions of AggFunc `count` contain virtual column or correlated column, which is not supported now", - "MPP mode may be blocked because expressions of AggFunc `count` contain virtual column or correlated column, which is not supported now", "Aggregation can not be pushed to tiflash because expressions of AggFunc `count` contain virtual column or correlated column, which is not supported now", - "MPP mode may be blocked because expressions of AggFunc `count` contain virtual column or correlated column, which is not supported now", "Aggregation can not be pushed to tikv because expressions of AggFunc `count` contain virtual column or correlated column, which is not supported now", - "MPP mode may be blocked because expressions of AggFunc `count` contain virtual column or correlated column, which is not supported now", "Aggregation can not be pushed to tiflash because expressions of AggFunc `count` contain virtual column or correlated column, which is not supported now", - "MPP mode may be blocked because expressions of AggFunc `count` contain virtual column or correlated column, which is not supported now", "Aggregation can not be pushed to tiflash because expressions of AggFunc `count` contain virtual column or correlated column, which is not supported now" ] }, @@ -316,11 +311,8 @@ " └─TableFullScan_10 10000.00 cop[tiflash] table:t keep order:false, stats:pseudo" ], "Warn": [ - "MPP mode may be blocked because groupByItems contain virtual columns, which is not supported now", "Aggregation can not be pushed to tiflash because groupByItems contain virtual columns, which is not supported now", - "MPP mode may be blocked because groupByItems contain virtual columns, which is not supported now", "Aggregation can not be pushed to tiflash because groupByItems contain virtual columns, which is not supported now", - "MPP mode may be blocked because groupByItems contain virtual columns, which is not supported now", "Aggregation can not be pushed to tiflash because groupByItems contain virtual columns, which is not supported now" ] }, @@ -333,13 +325,9 @@ " └─TableFullScan_12 10000.00 cop[tiflash] table:t keep order:false, stats:pseudo" ], "Warn": [ - "MPP mode may be blocked because AggFunc `group_concat` is not supported now", "Aggregation can not be pushed to tiflash because AggFunc `group_concat` is not supported now", - "MPP mode may be blocked because AggFunc `group_concat` is not supported now", "Aggregation can not be pushed to tiflash because AggFunc `group_concat` is not supported now", - "MPP mode may be blocked because AggFunc `group_concat` is not supported now", "Aggregation can not be pushed to tiflash because AggFunc `group_concat` is not supported now", - "MPP mode may be blocked because AggFunc `group_concat` is not supported now", "Aggregation can not be pushed to tiflash because AggFunc `group_concat` is not supported now" ] }, @@ -353,13 +341,10 @@ ], "Warn": [ "Scalar function 'md5'(signature: MD5) can not be pushed to tiflash", - "MPP mode may be blocked because groupByItems contain unsupported exprs", "Aggregation can not be pushed to tiflash because groupByItems contain unsupported exprs", "Scalar function 'md5'(signature: MD5) can not be pushed to tiflash", - "MPP mode may be blocked because groupByItems contain unsupported exprs", "Aggregation can not be pushed to tiflash because groupByItems contain unsupported exprs", "Scalar function 'md5'(signature: MD5) can not be pushed to tiflash", - "MPP mode may be blocked because groupByItems contain unsupported exprs", "Aggregation can not be pushed to tiflash because groupByItems contain unsupported exprs" ] }, diff --git a/planner/optimize.go b/planner/optimize.go index 80225befe1b42..05ba86f3bbce7 100644 --- a/planner/optimize.go +++ b/planner/optimize.go @@ -77,6 +77,24 @@ func IsReadOnly(node ast.Node, vars *variable.SessionVars) bool { return ast.IsReadOnly(node) } +// GetExecuteForUpdateReadIS is used to check whether the statement is `execute` and target statement has a forUpdateRead flag. +// If so, we will return the latest information schema. +func GetExecuteForUpdateReadIS(node ast.Node, sctx sessionctx.Context) infoschema.InfoSchema { + if execStmt, isExecStmt := node.(*ast.ExecuteStmt); isExecStmt { + vars := sctx.GetSessionVars() + execID := execStmt.ExecID + if execStmt.Name != "" { + execID = vars.PreparedStmtNameToID[execStmt.Name] + } + if preparedPointer, ok := vars.PreparedStmts[execID]; ok { + if preparedObj, ok := preparedPointer.(*core.CachedPrepareStmt); ok && preparedObj.ForUpdateRead { + return domain.GetDomain(sctx).InfoSchema() + } + } + } + return nil +} + // Optimize does optimization and creates a Plan. // The node must be prepared first. func Optimize(ctx context.Context, sctx sessionctx.Context, node ast.Node, is infoschema.InfoSchema) (plannercore.Plan, types.NameSlice, error) { @@ -127,10 +145,6 @@ func Optimize(ctx context.Context, sctx sessionctx.Context, node ast.Node, is in if !ok { useBinding = false } - if useBinding && sessVars.SelectLimit != math.MaxUint64 { - sessVars.StmtCtx.AppendWarning(errors.New("sql_select_limit is set, ignore SQL bindings")) - useBinding = false - } var ( bindRecord *bindinfo.BindRecord scope string @@ -142,6 +156,10 @@ func Optimize(ctx context.Context, sctx sessionctx.Context, node ast.Node, is in useBinding = false } } + if useBinding && sessVars.SelectLimit != math.MaxUint64 { + sessVars.StmtCtx.AppendWarning(errors.New("sql_select_limit is set, ignore SQL bindings")) + useBinding = false + } var names types.NameSlice var bestPlan, bestPlanFromBind plannercore.Plan @@ -186,8 +204,8 @@ func Optimize(ctx context.Context, sctx sessionctx.Context, node ast.Node, is in if err := setFoundInBinding(sctx, true); err != nil { logutil.BgLogger().Warn("set tidb_found_in_binding failed", zap.Error(err)) } - if _, ok := stmtNode.(*ast.ExplainStmt); ok { - sessVars.StmtCtx.AppendWarning(errors.Errorf("Using the bindSQL: %v", chosenBinding.BindSQL)) + if sessVars.StmtCtx.InVerboseExplain { + sessVars.StmtCtx.AppendNote(errors.Errorf("Using the bindSQL: %v", chosenBinding.BindSQL)) } } // Restore the hint to avoid changing the stmt node. @@ -318,18 +336,6 @@ func optimize(ctx context.Context, sctx sessionctx.Context, node ast.Node, is in } sctx.GetSessionVars().RewritePhaseInfo.DurationRewrite = time.Since(beginRewrite) - if execPlan, ok := p.(*plannercore.Execute); ok { - execID := execPlan.ExecID - if execPlan.Name != "" { - execID = sctx.GetSessionVars().PreparedStmtNameToID[execPlan.Name] - } - if preparedPointer, ok := sctx.GetSessionVars().PreparedStmts[execID]; ok { - if preparedObj, ok := preparedPointer.(*core.CachedPrepareStmt); ok && preparedObj.ForUpdateRead { - is = domain.GetDomain(sctx).InfoSchema() - } - } - } - sctx.GetSessionVars().StmtCtx.Tables = builder.GetDBTableInfo() activeRoles := sctx.GetSessionVars().ActiveRoles // Check privilege. Maybe it's better to move this to the Preprocess, but diff --git a/plugin/integration_test.go b/plugin/integration_test.go new file mode 100644 index 0000000000000..9015a9be276e0 --- /dev/null +++ b/plugin/integration_test.go @@ -0,0 +1,162 @@ +// Copyright 2021 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package plugin_test + +import ( + "bytes" + "context" + "fmt" + "strconv" + "testing" + + "github.com/pingcap/tidb/config" + "github.com/pingcap/tidb/domain" + "github.com/pingcap/tidb/kv" + "github.com/pingcap/tidb/plugin" + "github.com/pingcap/tidb/session" + "github.com/pingcap/tidb/sessionctx/variable" + "github.com/pingcap/tidb/store/mockstore" + "github.com/pingcap/tidb/testkit" + "github.com/stretchr/testify/require" + "github.com/tikv/client-go/v2/testutils" +) + +type testAuditLogSuite struct { + cluster testutils.Cluster + store kv.Storage + dom *domain.Domain + + bytes.Buffer +} + +func (s *testAuditLogSuite) setup(t *testing.T) { + pluginName := "test_audit_log" + pluginVersion := uint16(1) + pluginSign := pluginName + "-" + strconv.Itoa(int(pluginVersion)) + + config.UpdateGlobal(func(conf *config.Config) { + conf.Plugin.Load = pluginSign + }) + + // setup load test hook. + loadOne := func(p *plugin.Plugin, dir string, pluginID plugin.ID) (manifest func() *plugin.Manifest, err error) { + return func() *plugin.Manifest { + m := &plugin.AuditManifest{ + Manifest: plugin.Manifest{ + Kind: plugin.Audit, + Name: pluginName, + Version: pluginVersion, + OnInit: OnInit, + OnShutdown: OnShutdown, + Validate: Validate, + }, + OnGeneralEvent: s.OnGeneralEvent, + OnConnectionEvent: OnConnectionEvent, + } + return plugin.ExportManifest(m) + }, nil + } + plugin.SetTestHook(loadOne) + + store, err := mockstore.NewMockStore( + mockstore.WithClusterInspector(func(c testutils.Cluster) { + mockstore.BootstrapWithSingleStore(c) + s.cluster = c + }), + ) + require.NoError(t, err) + s.store = store + session.SetSchemaLease(0) + session.DisableStats4Test() + + d, err := session.BootstrapSession(s.store) + require.NoError(t, err) + d.SetStatsUpdating(true) + s.dom = d +} + +func (s *testAuditLogSuite) teardown() { + s.dom.Close() + s.store.Close() +} + +func TestAuditLog(t *testing.T) { + var s testAuditLogSuite + s.setup(t) + defer s.teardown() + + var buf1 bytes.Buffer + tk := testkit.NewAsyncTestKit(t, s.store) + ctx := tk.OpenSession(context.Background(), "test") + buf1.WriteString("Use use `test`\n") // Workaround for the testing framework. + + tk.MustExec(ctx, "use test") + buf1.WriteString("Use use `test`\n") + + tk.MustExec(ctx, "create table t (id int primary key, a int, b int unique)") + buf1.WriteString("CreateTable create table `t` ( `id` int primary key , `a` int , `b` int unique )\n") + + tk.MustExec(ctx, "create view v1 as select * from t where id > 2") + buf1.WriteString("CreateView create view `v1` as select * from `t` where `id` > ?\n") + + tk.MustExec(ctx, "drop view v1") + buf1.WriteString("DropView drop view `v1`\n") + + tk.MustExec(ctx, "create session binding for select * from t where b = 123 using select * from t ignore index(b) where b = 123") + buf1.WriteString("CreateBinding create session binding for select * from `t` where `b` = ? using select * from `t` where `b` = ?\n") + + tk.MustExec(ctx, "prepare mystmt from 'select ? as num from DUAL'") + buf1.WriteString("Prepare prepare `mystmt` from ?\n") + + tk.MustExec(ctx, "set @number = 5") + buf1.WriteString("Set set @number = ?\n") + + tk.MustExec(ctx, "execute mystmt using @number") + buf1.WriteString("Select select ? as `num` from dual\n") + + tk.MustQuery(ctx, "trace format = 'row' select * from t") + buf1.WriteString("Trace trace format = ? select * from `t`\n") + + tk.MustExec(ctx, "shutdown") + buf1.WriteString("Shutdown shutdown\n") + + require.Equal(t, buf1.String(), s.Buffer.String()) +} + +func Validate(ctx context.Context, m *plugin.Manifest) error { + return nil +} + +// OnInit implements TiDB plugin's OnInit SPI. +func OnInit(ctx context.Context, manifest *plugin.Manifest) error { + return nil +} + +// OnShutdown implements TiDB plugin's OnShutdown SPI. +func OnShutdown(ctx context.Context, manifest *plugin.Manifest) error { + return nil +} + +// OnGeneralEvent implements TiDB Audit plugin's OnGeneralEvent SPI. +func (s *testAuditLogSuite) OnGeneralEvent(ctx context.Context, sctx *variable.SessionVars, event plugin.GeneralEvent, cmd string) { + if sctx != nil { + normalized, _ := sctx.StmtCtx.SQLDigest() + fmt.Fprintln(&s.Buffer, sctx.StmtCtx.StmtType, normalized) + } +} + +// OnConnectionEvent implements TiDB Audit plugin's OnConnectionEvent SPI. +func OnConnectionEvent(ctx context.Context, event plugin.ConnectionEvent, info *variable.ConnectionInfo) error { + return nil +} diff --git a/plugin/main_test.go b/plugin/main_test.go index 108caec196390..25773ec64f3c5 100644 --- a/plugin/main_test.go +++ b/plugin/main_test.go @@ -22,5 +22,12 @@ import ( func TestMain(m *testing.M) { testbridge.WorkaroundGoCheckFlags() - goleak.VerifyTestMain(m) + + opts := []goleak.Option{ + goleak.IgnoreTopFunction("go.etcd.io/etcd/pkg/logutil.(*MergeLogger).outputLoop"), + goleak.IgnoreTopFunction("go.opencensus.io/stats/view.(*worker).start"), + goleak.IgnoreTopFunction("time.Sleep"), + } + + goleak.VerifyTestMain(m, opts...) } diff --git a/session/bootstrap.go b/session/bootstrap.go index 2572189fae957..372badc39932b 100644 --- a/session/bootstrap.go +++ b/session/bootstrap.go @@ -1378,9 +1378,6 @@ func upgradeToVer67(s Session, ver int64) { if err != nil { logutil.BgLogger().Fatal("upgradeToVer67 error", zap.Error(err)) } - if rs != nil { - defer terror.Call(rs.Close) - } req := rs.NewChunk() iter := chunk.NewIterator4Chunk(req) p := parser.New() @@ -1395,6 +1392,7 @@ func upgradeToVer67(s Session, ver int64) { } updateBindInfo(iter, p, bindMap) } + terror.Call(rs.Close) mustExecute(s, "DELETE FROM mysql.bind_info where source != 'builtin'") for original, bind := range bindMap { diff --git a/session/bootstrap_test.go b/session/bootstrap_test.go index 2349a312fe042..26208f6e91aaa 100644 --- a/session/bootstrap_test.go +++ b/session/bootstrap_test.go @@ -60,10 +60,14 @@ func (s *testBootstrapSuite) TestBootstrap(c *C) { c.Assert(se.Auth(&auth.UserIdentity{Username: "root", Hostname: "anyhost"}, []byte(""), []byte("")), IsTrue) mustExecSQL(c, se, "USE test;") // Check privilege tables. - mustExecSQL(c, se, "SELECT * from mysql.global_priv;") - mustExecSQL(c, se, "SELECT * from mysql.db;") - mustExecSQL(c, se, "SELECT * from mysql.tables_priv;") - mustExecSQL(c, se, "SELECT * from mysql.columns_priv;") + rs := mustExecSQL(c, se, "SELECT * from mysql.global_priv;") + c.Assert(rs.Close(), IsNil) + rs = mustExecSQL(c, se, "SELECT * from mysql.db;") + c.Assert(rs.Close(), IsNil) + rs = mustExecSQL(c, se, "SELECT * from mysql.tables_priv;") + c.Assert(rs.Close(), IsNil) + rs = mustExecSQL(c, se, "SELECT * from mysql.columns_priv;") + c.Assert(rs.Close(), IsNil) // Check privilege tables. r = mustExecSQL(c, se, "SELECT COUNT(*) from mysql.global_variables;") c.Assert(r, NotNil) diff --git a/session/session_test.go b/session/session_test.go index 42378b536b8c2..ec44c00c0ae42 100644 --- a/session/session_test.go +++ b/session/session_test.go @@ -1328,8 +1328,9 @@ func (s *testSessionSuite) TestPrepare(c *C) { c.Assert(id, Equals, uint32(1)) c.Assert(ps, Equals, 1) tk.MustExec(`set @a=1`) - _, err = tk.Se.ExecutePreparedStmt(ctx, id, []types.Datum{types.NewDatum("1")}) + rs, err := tk.Se.ExecutePreparedStmt(ctx, id, []types.Datum{types.NewDatum("1")}) c.Assert(err, IsNil) + rs.Close() err = tk.Se.DropPreparedStmt(id) c.Assert(err, IsNil) @@ -1349,7 +1350,7 @@ func (s *testSessionSuite) TestPrepare(c *C) { tk.MustExec("insert multiexec values (1, 1), (2, 2)") id, _, _, err = tk.Se.PrepareStmt("select a from multiexec where b = ? order by b") c.Assert(err, IsNil) - rs, err := tk.Se.ExecutePreparedStmt(ctx, id, []types.Datum{types.NewDatum(1)}) + rs, err = tk.Se.ExecutePreparedStmt(ctx, id, []types.Datum{types.NewDatum(1)}) c.Assert(err, IsNil) rs.Close() rs, err = tk.Se.ExecutePreparedStmt(ctx, id, []types.Datum{types.NewDatum(2)}) @@ -1963,17 +1964,26 @@ func (s *testSessionSuite3) TestCaseInsensitive(c *C) { tk.MustExec("create table T (a text, B int)") tk.MustExec("insert t (A, b) values ('aaa', 1)") - rs, _ := tk.Exec("select * from t") + rs, err := tk.Exec("select * from t") + c.Assert(err, IsNil) fields := rs.Fields() c.Assert(fields[0].ColumnAsName.O, Equals, "a") c.Assert(fields[1].ColumnAsName.O, Equals, "B") - rs, _ = tk.Exec("select A, b from t") + rs.Close() + + rs, err = tk.Exec("select A, b from t") + c.Assert(err, IsNil) fields = rs.Fields() c.Assert(fields[0].ColumnAsName.O, Equals, "A") c.Assert(fields[1].ColumnAsName.O, Equals, "b") - rs, _ = tk.Exec("select a as A from t where A > 0") + rs.Close() + + rs, err = tk.Exec("select a as A from t where A > 0") + c.Assert(err, IsNil) fields = rs.Fields() c.Assert(fields[0].ColumnAsName.O, Equals, "A") + rs.Close() + tk.MustExec("update T set b = B + 1") tk.MustExec("update T set B = b + 1") tk.MustQuery("select b from T").Check(testkit.Rows("3")) @@ -3907,6 +3917,7 @@ func (s *testSessionSerialSuite) TestDoDDLJobQuit(c *C) { func (s *testBackupRestoreSuite) TestBackupAndRestore(c *C) { // only run BR SQL integration test with tikv store. + // TODO move this test to BR integration tests. if *withTiKV { cfg := config.GetGlobalConfig() cfg.Store = "tikv" @@ -3928,7 +3939,7 @@ func (s *testBackupRestoreSuite) TestBackupAndRestore(c *C) { tmpDir := path.Join(os.TempDir(), "bk1") os.RemoveAll(tmpDir) // backup database to tmp dir - tk.MustQuery("backup database * to 'local://" + tmpDir + "'") + tk.MustQuery("backup database br to 'local://" + tmpDir + "'") // remove database for recovery tk.MustExec("drop database br") @@ -3939,7 +3950,6 @@ func (s *testBackupRestoreSuite) TestBackupAndRestore(c *C) { tk.MustExec("use br") tk.MustQuery("select count(*) from t1").Check(testkit.Rows("3")) tk.MustExec("drop database br") - tk.MustExec("drop database br02") } } diff --git a/sessionctx/stmtctx/stmtctx.go b/sessionctx/stmtctx/stmtctx.go index 97565705bedc4..23d4040e5c3be 100644 --- a/sessionctx/stmtctx/stmtctx.go +++ b/sessionctx/stmtctx/stmtctx.go @@ -63,9 +63,7 @@ type StatementContext struct { // IsDDLJobInQueue is used to mark whether the DDL job is put into the queue. // If IsDDLJobInQueue is true, it means the DDL job is in the queue of storage, and it can be handled by the DDL worker. - IsDDLJobInQueue bool - // InReorgAttribute is indicated for cast function that the transition is a kind of reorg process. - InReorgAttribute bool + IsDDLJobInQueue bool InInsertStmt bool InUpdateStmt bool InDeleteStmt bool @@ -182,6 +180,9 @@ type StatementContext struct { DiskTracker disk.Tracker LogOnExceed [2]memory.LogOnExceed } + + // InVerboseExplain indicates the statement is "explain format='verbose' ...". + InVerboseExplain bool } // StmtHints are SessionVars related sql hints. diff --git a/statistics/handle/handle_test.go b/statistics/handle/handle_test.go index 0b6e275a71c1c..ca44f560d6e02 100644 --- a/statistics/handle/handle_test.go +++ b/statistics/handle/handle_test.go @@ -901,11 +901,11 @@ func (s *testStatsSuite) prepareForGlobalStatsWithOpts(c *C, tk *testkit.TestKit } // nolint:unused -func (s *testStatsSuite) checkForGlobalStatsWithOpts(c *C, tk *testkit.TestKit, p string, topn, buckets int) { +func (s *testStatsSuite) checkForGlobalStatsWithOpts(c *C, tk *testkit.TestKit, t string, p string, topn, buckets int) { delta := buckets/2 + 1 for _, isIdx := range []int{0, 1} { - c.Assert(len(tk.MustQuery(fmt.Sprintf("show stats_topn where partition_name='%v' and is_index=%v", p, isIdx)).Rows()), Equals, topn) - numBuckets := len(tk.MustQuery(fmt.Sprintf("show stats_buckets where partition_name='%v' and is_index=%v", p, isIdx)).Rows()) + c.Assert(len(tk.MustQuery(fmt.Sprintf("show stats_topn where table_name='%v' and partition_name='%v' and is_index=%v", t, p, isIdx)).Rows()), Equals, topn) + numBuckets := len(tk.MustQuery(fmt.Sprintf("show stats_buckets where table_name='%v' and partition_name='%v' and is_index=%v", t, p, isIdx)).Rows()) // since the hist-building algorithm doesn't stipulate the final bucket number to be equal to the expected number exactly, // we have to check the results by a range here. c.Assert(numBuckets >= buckets-delta, IsTrue) @@ -942,9 +942,9 @@ func (s *testStatsSuite) TestAnalyzeGlobalStatsWithOpts(c *C) { sql := fmt.Sprintf("analyze table test_gstats_opt with %v topn, %v buckets", ca.topn, ca.buckets) if !ca.err { tk.MustExec(sql) - s.checkForGlobalStatsWithOpts(c, tk, "global", ca.topn, ca.buckets) - s.checkForGlobalStatsWithOpts(c, tk, "p0", ca.topn, ca.buckets) - s.checkForGlobalStatsWithOpts(c, tk, "p1", ca.topn, ca.buckets) + s.checkForGlobalStatsWithOpts(c, tk, "test_gstats_opt", "global", ca.topn, ca.buckets) + s.checkForGlobalStatsWithOpts(c, tk, "test_gstats_opt", "p0", ca.topn, ca.buckets) + s.checkForGlobalStatsWithOpts(c, tk, "test_gstats_opt", "p1", ca.topn, ca.buckets) } else { err := tk.ExecToErr(sql) c.Assert(err, NotNil) @@ -961,25 +961,25 @@ func (s *testStatsSuite) TestAnalyzeGlobalStatsWithOpts2(c *C) { s.prepareForGlobalStatsWithOpts(c, tk, "test_gstats_opt2", "test_gstats_opt2") tk.MustExec("analyze table test_gstats_opt2 with 20 topn, 50 buckets, 1000 samples") - s.checkForGlobalStatsWithOpts(c, tk, "global", 2, 50) - s.checkForGlobalStatsWithOpts(c, tk, "p0", 1, 50) - s.checkForGlobalStatsWithOpts(c, tk, "p1", 1, 50) + s.checkForGlobalStatsWithOpts(c, tk, "test_gstats_opt2", "global", 2, 50) + s.checkForGlobalStatsWithOpts(c, tk, "test_gstats_opt2", "p0", 1, 50) + s.checkForGlobalStatsWithOpts(c, tk, "test_gstats_opt2", "p1", 1, 50) // analyze a partition to let its options be different with others' tk.MustExec("analyze table test_gstats_opt2 partition p0 with 10 topn, 20 buckets") - s.checkForGlobalStatsWithOpts(c, tk, "global", 10, 20) // use new options - s.checkForGlobalStatsWithOpts(c, tk, "p0", 10, 20) - s.checkForGlobalStatsWithOpts(c, tk, "p1", 1, 50) + s.checkForGlobalStatsWithOpts(c, tk, "test_gstats_opt2", "global", 10, 20) // use new options + s.checkForGlobalStatsWithOpts(c, tk, "test_gstats_opt2", "p0", 10, 20) + s.checkForGlobalStatsWithOpts(c, tk, "test_gstats_opt2", "p1", 1, 50) tk.MustExec("analyze table test_gstats_opt2 partition p1 with 100 topn, 200 buckets") - s.checkForGlobalStatsWithOpts(c, tk, "global", 100, 200) - s.checkForGlobalStatsWithOpts(c, tk, "p0", 10, 20) - s.checkForGlobalStatsWithOpts(c, tk, "p1", 100, 200) + s.checkForGlobalStatsWithOpts(c, tk, "test_gstats_opt2", "global", 100, 200) + s.checkForGlobalStatsWithOpts(c, tk, "test_gstats_opt2", "p0", 10, 20) + s.checkForGlobalStatsWithOpts(c, tk, "test_gstats_opt2", "p1", 100, 200) tk.MustExec("analyze table test_gstats_opt2 partition p0 with 20 topn") // change back to 20 topn - s.checkForGlobalStatsWithOpts(c, tk, "global", 20, 256) - s.checkForGlobalStatsWithOpts(c, tk, "p0", 20, 256) - s.checkForGlobalStatsWithOpts(c, tk, "p1", 100, 200) + s.checkForGlobalStatsWithOpts(c, tk, "test_gstats_opt2", "global", 20, 256) + s.checkForGlobalStatsWithOpts(c, tk, "test_gstats_opt2", "p0", 20, 256) + s.checkForGlobalStatsWithOpts(c, tk, "test_gstats_opt2", "p1", 100, 200) } func (s *testStatsSuite) TestGlobalStatsHealthy(c *C) { diff --git a/table/column.go b/table/column.go index 843ab857a1fd2..6b9c2996b7509 100644 --- a/table/column.go +++ b/table/column.go @@ -224,11 +224,13 @@ func handleZeroDatetime(ctx sessionctx.Context, col *model.ColumnInfo, casted ty ignoreErr := sc.DupKeyAsWarning + // Timestamp in MySQL is since EPOCH 1970-01-01 00:00:00 UTC and can by definition not have invalid dates! + // Zero date is special for MySQL timestamp and *NOT* 1970-01-01 00:00:00, but 0000-00-00 00:00:00! // in MySQL 8.0, the Timestamp's case is different to Datetime/Date, as shown below: // // | | NZD | NZD|ST | ELSE | ELSE|ST | // | ------------ | ----------------- | ------- | ----------------- | -------- | - // | `0000-00-01` | Success + Warning | Error | Success + Warning | Error | + // | `0000-00-01` | Truncate + Warning| Error | Truncate + Warning| Error | // | `0000-00-00` | Success + Warning | Error | Success | Success | // // * **NZD**: NO_ZERO_DATE_MODE @@ -273,21 +275,13 @@ func handleZeroDatetime(ctx sessionctx.Context, col *model.ColumnInfo, casted ty // CastValue casts a value based on column type. // If forceIgnoreTruncate is true, truncated errors will be ignored. -// If returnOverflow is true, don't handle overflow errors in this function. +// If returnErr is true, directly return any conversion errors. // It's safe now and it's the same as the behavior of select statement. // Set it to true only in FillVirtualColumnValue and UnionScanExec.Next() // If the handle of err is changed latter, the behavior of forceIgnoreTruncate also need to change. // TODO: change the third arg to TypeField. Not pass ColumnInfo. func CastValue(ctx sessionctx.Context, val types.Datum, col *model.ColumnInfo, returnErr, forceIgnoreTruncate bool) (casted types.Datum, err error) { sc := ctx.GetSessionVars().StmtCtx - // Set the reorg attribute for cast value functionality. - if col.ChangeStateInfo != nil { - origin := ctx.GetSessionVars().StmtCtx.InReorgAttribute - ctx.GetSessionVars().StmtCtx.InReorgAttribute = true - defer func() { - ctx.GetSessionVars().StmtCtx.InReorgAttribute = origin - }() - } casted, err = val.ConvertTo(sc, &col.FieldType) // TODO: make sure all truncate errors are handled by ConvertTo. if returnErr && err != nil { @@ -302,7 +296,12 @@ func CastValue(ctx sessionctx.Context, val types.Datum, col *model.ColumnInfo, r } else if (sc.InInsertStmt || sc.InUpdateStmt) && !casted.IsNull() && (val.Kind() != types.KindMysqlTime || !val.GetMysqlTime().IsZero()) && (col.Tp == mysql.TypeDate || col.Tp == mysql.TypeDatetime || col.Tp == mysql.TypeTimestamp) { - if innCasted, exit, innErr := handleZeroDatetime(ctx, col, casted, val.GetString(), types.ErrWrongValue.Equal(err)); exit { + str, err1 := val.ToString() + if err1 != nil { + logutil.BgLogger().Warn("Datum ToString failed", zap.Stringer("Datum", val), zap.Error(err1)) + str = val.GetString() + } + if innCasted, exit, innErr := handleZeroDatetime(ctx, col, casted, str, types.ErrWrongValue.Equal(err)); exit { return innCasted, innErr } } diff --git a/types/datum.go b/types/datum.go index 93172fc1117a1..dd2d0557099b8 100644 --- a/types/datum.go +++ b/types/datum.go @@ -1128,16 +1128,10 @@ func (d *Datum) convertToMysqlTimestamp(sc *stmtctx.StatementContext, target *Fi } switch d.k { case KindMysqlTime: - // `select timestamp(cast("1000-01-02 23:59:59" as date)); ` casts usage will succeed. - // Alter datetime("1000-01-02 23:59:59") to timestamp will error. - if sc.InReorgAttribute { - t, err = d.GetMysqlTime().Convert(sc, target.Tp) - if err != nil { - ret.SetMysqlTime(t) - return ret, errors.Trace(ErrWrongValue.GenWithStackByArgs(DateTimeStr, t.String())) - } - } else { - t = d.GetMysqlTime() + t, err = d.GetMysqlTime().Convert(sc, target.Tp) + if err != nil { + ret.SetMysqlTime(ZeroTimestamp) + return ret, errors.Trace(ErrWrongValue.GenWithStackByArgs(TimestampStr, t.String())) } t, err = t.RoundFrac(sc, fsp) case KindMysqlDuration: