diff --git a/ddl/ddl_api_test.go b/ddl/ddl_api_test.go index 4e2e0912b4923..7492d5bd4ca1b 100644 --- a/ddl/ddl_api_test.go +++ b/ddl/ddl_api_test.go @@ -169,7 +169,7 @@ func TestCreateDropCreateTable(t *testing.T) { originHook := dom.DDL().GetHook() onJobUpdated := func(job *model.Job) { if job.Type == model.ActionDropTable && job.SchemaState == model.StateWriteOnly && !createTable { - fpErr = failpoint.Enable("github.com/pingcap/tidb/pkg/ddl/mockOwnerCheckAllVersionSlow", fmt.Sprintf("return(%d)", job.ID)) + fpErr = failpoint.Enable("github.com/pingcap/tidb/ddl/mockOwnerCheckAllVersionSlow", fmt.Sprintf("return(%d)", job.ID)) wg.Add(1) go func() { _, createErr = tk1.Exec("create table t (b int);") @@ -187,7 +187,7 @@ func TestCreateDropCreateTable(t *testing.T) { wg.Wait() require.NoError(t, createErr) require.NoError(t, fpErr) - require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/pkg/ddl/mockOwnerCheckAllVersionSlow")) + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/ddl/mockOwnerCheckAllVersionSlow")) rs := tk.MustQuery("admin show ddl jobs 3;").Rows() create1JobID := rs[0][0].(string) diff --git a/extension/event_listener_test.go b/extension/event_listener_test.go index 00606fabd5f99..a9947de9d5758 100644 --- a/extension/event_listener_test.go +++ b/extension/event_listener_test.go @@ -329,12 +329,18 @@ func TestExtensionStmtEvents(t *testing.T) { dispatchData: append([]byte{mysql.ComInitDB}, []byte("db1")...), originalText: "use `db1`", redactText: "use `db1`", + tables: []stmtctx.TableEntry{ + {DB: "db1", Table: ""}, + }, }, { dispatchData: append([]byte{mysql.ComInitDB}, []byte("noexistdb")...), originalText: "use `noexistdb`", redactText: "use `noexistdb`", err: "[schema:1049]Unknown database 'noexistdb'", + tables: []stmtctx.TableEntry{ + {DB: "noexistdb", Table: ""}, + }, }, } @@ -424,7 +430,8 @@ func TestExtensionStmtEvents(t *testing.T) { r := subCase.tables[j] return l.DB < r.DB || (l.DB == r.DB && l.Table < r.Table) }) - require.Equal(t, subCase.tables, record.tables) + require.Equal(t, subCase.tables, record.tables, + "sql: %s\noriginalText: %s\n", subCase.sql, subCase.originalText) require.Equal(t, len(subCase.executeParams), len(record.params)) for k, param := range subCase.executeParams { diff --git a/extension/session.go b/extension/session.go index e35f31ec68920..ba430bb3675a8 100644 --- a/extension/session.go +++ b/extension/session.go @@ -85,6 +85,8 @@ type StmtEventInfo interface { // AffectedRows will return the affected rows of the current statement AffectedRows() uint64 // RelatedTables will return the related tables of the current statement + // For statements succeeding to build logical plan, it uses the `visitinfo` to get the related tables + // For statements failing to build logical plan, it traverses the ast node to get the related tables RelatedTables() []stmtctx.TableEntry // GetError will return the error when the current statement is failed GetError() error diff --git a/parser/ast/misc.go b/parser/ast/misc.go index 5b075f8ca1ef1..677578a4e7aa9 100644 --- a/parser/ast/misc.go +++ b/parser/ast/misc.go @@ -950,6 +950,13 @@ func (n *FlushStmt) Accept(v Visitor) (Node, bool) { return v.Leave(newNode) } n = newNode.(*FlushStmt) + for i, t := range n.Tables { + node, ok := t.Accept(v) + if !ok { + return n, false + } + n.Tables[i] = node.(*TableName) + } return v.Leave(n) } diff --git a/planner/core/BUILD.bazel b/planner/core/BUILD.bazel index b82205c7b79e7..35ffd91bf1ad8 100644 --- a/planner/core/BUILD.bazel +++ b/planner/core/BUILD.bazel @@ -218,6 +218,7 @@ go_test( "rule_join_reorder_test.go", "rule_result_reorder_test.go", "stringer_test.go", + "util_test.go", ], data = glob(["testdata/**"]), embed = [":core"], diff --git a/planner/core/logical_plan_builder.go b/planner/core/logical_plan_builder.go index 3d5e85fdbdf0a..3955e9092457d 100644 --- a/planner/core/logical_plan_builder.go +++ b/planner/core/logical_plan_builder.go @@ -3155,8 +3155,7 @@ func (g *gbyResolver) Leave(inNode ast.Node) (ast.Node, bool) { } func tblInfoFromCol(from ast.ResultSetNode, name *types.FieldName) *model.TableInfo { - var tableList []*ast.TableName - tableList = extractTableList(from, tableList, true) + tableList := ExtractTableList(from, true) for _, field := range tableList { if field.Name.L == name.TblName.L { return field.TableInfo @@ -5718,8 +5717,7 @@ func (b *PlanBuilder) buildUpdate(ctx context.Context, update *ast.UpdateStmt) ( return nil, err } - var tableList []*ast.TableName - tableList = extractTableList(update.TableRefs.TableRefs, tableList, false) + tableList := ExtractTableList(update.TableRefs.TableRefs, false) for _, t := range tableList { dbName := t.Schema.L if dbName == "" { @@ -6247,8 +6245,7 @@ func (b *PlanBuilder) buildDelete(ctx context.Context, ds *ast.DeleteStmt) (Plan } } else { // Delete from a, b, c, d. - var tableList []*ast.TableName - tableList = extractTableList(ds.TableRefs.TableRefs, tableList, false) + tableList := ExtractTableList(ds.TableRefs.TableRefs, false) for _, v := range tableList { if isCTE(v) { return nil, ErrNonUpdatableTable.GenWithStackByArgs(v.Name.O, "DELETE") @@ -7063,17 +7060,6 @@ func buildWindowSpecs(specs []ast.WindowSpec) (map[string]*ast.WindowSpec, error return specsMap, nil } -func unfoldSelectList(list *ast.SetOprSelectList, unfoldList *ast.SetOprSelectList) { - for _, sel := range list.Selects { - switch s := sel.(type) { - case *ast.SelectStmt: - unfoldList.Selects = append(unfoldList.Selects, s) - case *ast.SetOprSelectList: - unfoldSelectList(s, unfoldList) - } - } -} - type updatableTableListResolver struct { updatableTableList []*ast.TableName } @@ -7102,111 +7088,149 @@ func (u *updatableTableListResolver) Leave(inNode ast.Node) (ast.Node, bool) { return inNode, true } -// extractTableList extracts all the TableNames from node. +// ExtractTableList is a wrapper for tableListExtractor and removes duplicate TableName // If asName is true, extract AsName prior to OrigName. // Privilege check should use OrigName, while expression may use AsName. -// TODO: extracting all tables by vistor model maybe a better way -func extractTableList(node ast.Node, input []*ast.TableName, asName bool) []*ast.TableName { - switch x := node.(type) { - case *ast.SelectStmt: - if x.From != nil { - input = extractTableList(x.From.TableRefs, input, asName) - } - if x.Where != nil { - input = extractTableList(x.Where, input, asName) - } - if x.With != nil { - for _, cte := range x.With.CTEs { - input = extractTableList(cte.Query, input, asName) - } - } - for _, f := range x.Fields.Fields { - if s, ok := f.Expr.(*ast.SubqueryExpr); ok { - input = extractTableList(s, input, asName) - } - } - case *ast.DeleteStmt: - input = extractTableList(x.TableRefs.TableRefs, input, asName) - if x.IsMultiTable { - for _, t := range x.Tables.Tables { - input = extractTableList(t, input, asName) - } - } - if x.Where != nil { - input = extractTableList(x.Where, input, asName) - } - if x.With != nil { - for _, cte := range x.With.CTEs { - input = extractTableList(cte.Query, input, asName) - } - } - case *ast.UpdateStmt: - input = extractTableList(x.TableRefs.TableRefs, input, asName) - for _, e := range x.List { - input = extractTableList(e.Expr, input, asName) - } - if x.Where != nil { - input = extractTableList(x.Where, input, asName) - } - if x.With != nil { - for _, cte := range x.With.CTEs { - input = extractTableList(cte.Query, input, asName) +func ExtractTableList(node ast.Node, asName bool) []*ast.TableName { + if node == nil { + return []*ast.TableName{} + } + e := &tableListExtractor{ + asName: asName, + tableNames: []*ast.TableName{}, + } + node.Accept(e) + tableNames := e.tableNames + m := make(map[string]map[string]*ast.TableName) // k1: schemaName, k2: tableName, v: ast.TableName + for _, x := range tableNames { + k1, k2 := x.Schema.L, x.Name.L + // allow empty schema name OR empty table name + if k1 != "" || k2 != "" { + if _, ok := m[k1]; !ok { + m[k1] = make(map[string]*ast.TableName) } + m[k1][k2] = x } - case *ast.InsertStmt: - input = extractTableList(x.Table.TableRefs, input, asName) - input = extractTableList(x.Select, input, asName) - case *ast.SetOprStmt: - l := &ast.SetOprSelectList{} - unfoldSelectList(x.SelectList, l) - for _, s := range l.Selects { - input = extractTableList(s.(ast.ResultSetNode), input, asName) - } - case *ast.PatternInExpr: - if s, ok := x.Sel.(*ast.SubqueryExpr); ok { - input = extractTableList(s, input, asName) + } + tableNames = tableNames[:0] + for _, x := range m { + for _, v := range x { + tableNames = append(tableNames, v) } - case *ast.ExistsSubqueryExpr: - if s, ok := x.Sel.(*ast.SubqueryExpr); ok { - input = extractTableList(s, input, asName) + } + return tableNames +} + +// tableListExtractor extracts all the TableNames from node. +type tableListExtractor struct { + asName bool + tableNames []*ast.TableName +} + +func (e *tableListExtractor) Enter(n ast.Node) (_ ast.Node, skipChildren bool) { + innerExtract := func(inner ast.Node) []*ast.TableName { + if inner == nil { + return nil } - case *ast.BinaryOperationExpr: - if s, ok := x.R.(*ast.SubqueryExpr); ok { - input = extractTableList(s, input, asName) + innerExtractor := &tableListExtractor{ + asName: e.asName, + tableNames: []*ast.TableName{}, } - case *ast.SubqueryExpr: - input = extractTableList(x.Query, input, asName) - case *ast.Join: - input = extractTableList(x.Left, input, asName) - input = extractTableList(x.Right, input, asName) + inner.Accept(innerExtractor) + return innerExtractor.tableNames + } + + switch x := n.(type) { + case *ast.TableName: + e.tableNames = append(e.tableNames, x) case *ast.TableSource: if s, ok := x.Source.(*ast.TableName); ok { - if x.AsName.L != "" && asName { + if x.AsName.L != "" && e.asName { newTableName := *s newTableName.Name = x.AsName newTableName.Schema = model.NewCIStr("") - input = append(input, &newTableName) + e.tableNames = append(e.tableNames, &newTableName) } else { - input = append(input, s) + e.tableNames = append(e.tableNames, s) } } else if s, ok := x.Source.(*ast.SelectStmt); ok { if s.From != nil { - var innerList []*ast.TableName - innerList = extractTableList(s.From.TableRefs, innerList, asName) + innerList := innerExtract(s.From.TableRefs) if len(innerList) > 0 { innerTableName := innerList[0] - if x.AsName.L != "" && asName { + if x.AsName.L != "" && e.asName { newTableName := *innerList[0] newTableName.Name = x.AsName newTableName.Schema = model.NewCIStr("") innerTableName = &newTableName } - input = append(input, innerTableName) + e.tableNames = append(e.tableNames, innerTableName) } } } + return n, true + + case *ast.ShowStmt: + if x.DBName != "" { + e.tableNames = append(e.tableNames, &ast.TableName{Schema: model.NewCIStr(x.DBName)}) + } + case *ast.CreateDatabaseStmt: + e.tableNames = append(e.tableNames, &ast.TableName{Schema: x.Name}) + case *ast.AlterDatabaseStmt: + e.tableNames = append(e.tableNames, &ast.TableName{Schema: x.Name}) + case *ast.DropDatabaseStmt: + e.tableNames = append(e.tableNames, &ast.TableName{Schema: x.Name}) + + case *ast.FlashBackDatabaseStmt: + e.tableNames = append(e.tableNames, &ast.TableName{Schema: x.DBName}) + e.tableNames = append(e.tableNames, &ast.TableName{Schema: model.NewCIStr(x.NewName)}) + case *ast.FlashBackToTimestampStmt: + if x.DBName.L != "" { + e.tableNames = append(e.tableNames, &ast.TableName{Schema: x.DBName}) + } + case *ast.FlashBackTableStmt: + if newName := x.NewName; newName != "" { + e.tableNames = append(e.tableNames, &ast.TableName{ + Schema: x.Table.Schema, + Name: model.NewCIStr(newName)}) + } + + case *ast.GrantStmt: + if x.ObjectType == ast.ObjectTypeTable || x.ObjectType == ast.ObjectTypeNone { + if x.Level.Level == ast.GrantLevelDB || x.Level.Level == ast.GrantLevelTable { + e.tableNames = append(e.tableNames, &ast.TableName{ + Schema: model.NewCIStr(x.Level.DBName), + Name: model.NewCIStr(x.Level.TableName), + }) + } + } + case *ast.RevokeStmt: + if x.ObjectType == ast.ObjectTypeTable || x.ObjectType == ast.ObjectTypeNone { + if x.Level.Level == ast.GrantLevelDB || x.Level.Level == ast.GrantLevelTable { + e.tableNames = append(e.tableNames, &ast.TableName{ + Schema: model.NewCIStr(x.Level.DBName), + Name: model.NewCIStr(x.Level.TableName), + }) + } + } + case *ast.BRIEStmt: + if x.Kind == ast.BRIEKindBackup || x.Kind == ast.BRIEKindRestore { + for _, v := range x.Schemas { + e.tableNames = append(e.tableNames, &ast.TableName{Schema: model.NewCIStr(v)}) + } + } + case *ast.UseStmt: + e.tableNames = append(e.tableNames, &ast.TableName{Schema: model.NewCIStr(x.DBName)}) + case *ast.ExecuteStmt: + if v, ok := x.PrepStmt.(*PlanCacheStmt); ok { + e.tableNames = append(e.tableNames, innerExtract(v.PreparedAst.Stmt)...) + } } - return input + return n, false +} + +func (*tableListExtractor) Leave(n ast.Node) (ast.Node, bool) { + return n, true } func collectTableName(node ast.ResultSetNode, updatableName *map[string]bool, info *map[string]*ast.TableName) { diff --git a/planner/core/point_get_plan.go b/planner/core/point_get_plan.go index a56c20a545481..bab0978a6ec4a 100644 --- a/planner/core/point_get_plan.go +++ b/planner/core/point_get_plan.go @@ -1663,8 +1663,7 @@ func buildPointUpdatePlan(ctx sessionctx.Context, pointPlan PhysicalPlan, dbName } if tbl.GetPartitionInfo() != nil { pt := t.(table.PartitionedTable) - var updateTableList []*ast.TableName - updateTableList = extractTableList(updateStmt.TableRefs.TableRefs, updateTableList, true) + updateTableList := ExtractTableList(updateStmt.TableRefs.TableRefs, true) updatePlan.PartitionedTable = make([]table.PartitionedTable, 0, len(updateTableList)) for _, updateTable := range updateTableList { if len(updateTable.PartitionNames) > 0 { diff --git a/planner/core/preprocess.go b/planner/core/preprocess.go index ae1c4006b4c66..a0699c9fa4c9b 100644 --- a/planner/core/preprocess.go +++ b/planner/core/preprocess.go @@ -522,7 +522,7 @@ func (p *preprocessor) checkBindGrammar(originNode, hintedNode ast.StmtNode, def } // Check the bind operation is not on any temporary table. - tblNames := extractTableList(originNode, nil, false) + tblNames := ExtractTableList(originNode, false) for _, tn := range tblNames { tbl, err := p.tableByName(tn) if err != nil { diff --git a/planner/core/util_test.go b/planner/core/util_test.go new file mode 100644 index 0000000000000..d3720110d8058 --- /dev/null +++ b/planner/core/util_test.go @@ -0,0 +1,318 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package core + +import ( + "fmt" + "sort" + "strings" + "testing" + + "github.com/pingcap/tidb/parser" + "github.com/pingcap/tidb/parser/ast" + "github.com/pingcap/tidb/parser/model" + "github.com/stretchr/testify/require" +) + +func tableNamesAsStr(tableNames []*ast.TableName) string { + names := []string{} + for _, tn := range tableNames { + names = append(names, fmt.Sprintf("[%s.%s]", tn.Schema.L, tn.Name.L)) + } + return strings.Join(names, ",") +} + +func sortTableNames(tableNames []*ast.TableName) { + sort.Slice(tableNames, func(i, j int) bool { + if tableNames[i].Schema.L == tableNames[j].Schema.L { + return tableNames[i].Name.L < tableNames[j].Name.L + } + return tableNames[i].Schema.L < tableNames[j].Schema.L + }) +} + +func TestExtractTableList(t *testing.T) { + cases := []struct { + sql string + asName bool + expect []*ast.TableName + }{ + { + sql: "WITH t AS (SELECT * FROM t2) SELECT * FROM t, t1, mysql.user WHERE t1.a = mysql.user.username", + expect: []*ast.TableName{ + {Name: model.NewCIStr("t")}, + {Name: model.NewCIStr("t1")}, + {Name: model.NewCIStr("t2")}, + {Name: model.NewCIStr("user"), Schema: model.NewCIStr("mysql")}, + }, + }, + { + sql: "SELECT (SELECT a,b,c FROM t1) AS t WHERE t.a = 1", + expect: []*ast.TableName{ + {Name: model.NewCIStr("t1")}, + }, + }, + { + sql: "SELECT * FROM t, v AS w", + expect: []*ast.TableName{ + {Name: model.NewCIStr("t")}, + {Name: model.NewCIStr("v")}, + }, + }, + { + sql: "SELECT * FROM t, v AS w", + asName: true, + expect: []*ast.TableName{ + {Name: model.NewCIStr("t")}, + {Name: model.NewCIStr("w")}, + }, + }, + { + sql: `SELECT + AVG(all_scores.avg_score) AS avg_score, + student_name + FROM + ( + SELECT + student_id, + AVG(score) AS avg_score + FROM + scores + GROUP BY + student_id + ) AS all_scores + JOIN students ON students.student_id = all_scores.student_id + GROUP BY + student_id + ORDER BY + avg_score DESC`, + expect: []*ast.TableName{ + {Name: model.NewCIStr("scores")}, + {Name: model.NewCIStr("students")}, + }, + }, + { + sql: "DELETE FROM x.y z WHERE z.a > 0", + expect: []*ast.TableName{ + {Name: model.NewCIStr("y"), Schema: model.NewCIStr("x")}, + }, + }, + { + sql: "WITH t AS (SELECT * FROM v) DELETE FROM x.y z WHERE z.a > t.c", + expect: []*ast.TableName{ + {Name: model.NewCIStr("y"), Schema: model.NewCIStr("x")}, + {Name: model.NewCIStr("v")}, + }, + }, + { + sql: "DELETE FROM `t1` AS `t2` USE INDEX (`fld1`) WHERE `t2`.`fld`=2", + expect: []*ast.TableName{ + {Name: model.NewCIStr("t1")}, + }, + }, + { + sql: "DELETE FROM `t1` AS `t2` USE INDEX (`fld1`) WHERE `t2`.`fld`=2", + asName: true, + expect: []*ast.TableName{ + {Name: model.NewCIStr("t2")}, + }, + }, + { + sql: "UPDATE t1 USE INDEX(idx_a) JOIN t2 SET t1.price=t2.price WHERE t1.id=t2.id;", + expect: []*ast.TableName{ + {Name: model.NewCIStr("t1")}, + {Name: model.NewCIStr("t2")}, + }, + }, + { + sql: "INSERT INTO t (a,b,c) SELECT x,y,z FROM t1;", + expect: []*ast.TableName{ + {Name: model.NewCIStr("t")}, + {Name: model.NewCIStr("t1")}, + }, + }, + { + sql: "WITH t AS (SELECT * FROM v) SELECT a FROM t UNION SELECT b FROM t1", + expect: []*ast.TableName{ + {Name: model.NewCIStr("v")}, + {Name: model.NewCIStr("t")}, + {Name: model.NewCIStr("t1")}, + }, + }, + { + sql: "LOAD DATA INFILE '/a.csv' FORMAT 'sql file' INTO TABLE `t`", + expect: []*ast.TableName{ + {Name: model.NewCIStr("t")}, + }, + }, + { + sql: "batch on c limit 10 delete from t where t.c = 10", + expect: []*ast.TableName{ + {Name: model.NewCIStr("t")}, + }, + }, + { + sql: "split table t1 between () and () regions 10", + expect: []*ast.TableName{ + {Name: model.NewCIStr("t1")}, + }, + }, + { + sql: "show create table t", + expect: []*ast.TableName{ + {Name: model.NewCIStr("t")}, + }, + }, + { + sql: "show create database test", + expect: []*ast.TableName{ + {Schema: model.NewCIStr("test")}, + }, + }, + { + sql: "create database test", + expect: []*ast.TableName{ + {Schema: model.NewCIStr("test")}, + }, + }, + { + sql: "FLASHBACK DATABASE t1 TO t2", + expect: []*ast.TableName{ + {Schema: model.NewCIStr("t1")}, + {Schema: model.NewCIStr("t2")}, + }, + }, + { + sql: "flashback table t,t1,test.t2 to timestamp '2021-05-26 16:45:26'", + expect: []*ast.TableName{ + {Name: model.NewCIStr("t")}, + {Name: model.NewCIStr("t1")}, + {Name: model.NewCIStr("t2"), Schema: model.NewCIStr("test")}, + }, + }, + { + sql: "flashback database test to timestamp '2021-05-26 16:45:26'", + expect: []*ast.TableName{ + {Schema: model.NewCIStr("test")}, + }, + }, + { + sql: "flashback table t TO t1", + expect: []*ast.TableName{ + {Name: model.NewCIStr("t")}, + {Name: model.NewCIStr("t1")}, + }, + }, + { + sql: "create table t", + expect: []*ast.TableName{ + {Name: model.NewCIStr("t")}, + }, + }, + { + sql: "RENAME TABLE t TO t1, test.t2 TO test.t3", + expect: []*ast.TableName{ + {Name: model.NewCIStr("t")}, + {Name: model.NewCIStr("t1")}, + {Name: model.NewCIStr("t2"), Schema: model.NewCIStr("test")}, + {Name: model.NewCIStr("t3"), Schema: model.NewCIStr("test")}, + }, + }, + { + sql: "drop table test.t, t1", + expect: []*ast.TableName{ + {Name: model.NewCIStr("t1")}, + {Name: model.NewCIStr("t"), Schema: model.NewCIStr("test")}, + }, + }, + { + sql: "create view v as (select * from t)", + expect: []*ast.TableName{ + {Name: model.NewCIStr("v")}, + {Name: model.NewCIStr("t")}, + }, + }, + { + sql: "create sequence if not exists seq no cycle", + expect: []*ast.TableName{ + {Name: model.NewCIStr("seq")}, + }, + }, + { + sql: "CREATE INDEX idx ON t ( a ) VISIBLE INVISIBLE", + expect: []*ast.TableName{ + {Name: model.NewCIStr("t")}, + }, + }, + { + sql: "LOCK TABLE t1 WRITE, t2 READ", + expect: []*ast.TableName{ + {Name: model.NewCIStr("t1")}, + {Name: model.NewCIStr("t2")}, + }, + }, + { + sql: "grant select on test.* to u1", + expect: []*ast.TableName{ + {Schema: model.NewCIStr("test")}, + }, + }, + { + sql: "BACKUP TABLE a.b,c.d,e TO 'noop://'", + expect: []*ast.TableName{ + {Name: model.NewCIStr("b"), Schema: model.NewCIStr("a")}, + {Name: model.NewCIStr("d"), Schema: model.NewCIStr("c")}, + {Name: model.NewCIStr("e")}, + }, + }, + { + sql: "TRACE SELECT (SELECT a,b,c FROM t1) AS t WHERE t.a = 1", + expect: []*ast.TableName{ + {Name: model.NewCIStr("t1")}, + }, + }, + { + sql: "EXPLAIN SELECT (SELECT a,b,c FROM t1) AS t WHERE t.a = 1", + expect: []*ast.TableName{ + {Name: model.NewCIStr("t1")}, + }, + }, + { + sql: "PLAN REPLAYER DUMP EXPLAIN SELECT (SELECT a,b,c FROM t1) AS t WHERE t.a = 1", + expect: []*ast.TableName{ + {Name: model.NewCIStr("t1")}, + }, + }, + { + sql: "ALTER TABLE t COMPACT", + expect: []*ast.TableName{ + {Name: model.NewCIStr("t")}, + }, + }, + } + p := parser.New() + for i, c := range cases { + stmtNode, err := p.ParseOneStmt(c.sql, "", "") + require.NoError(t, err, "case %d sql: %s", i, c.sql) + tableNames := ExtractTableList(stmtNode, c.asName) + require.Len(t, tableNames, len(c.expect), "case %d sql: %s, len: %d, actual: %s", i, c.sql, len(tableNames), tableNamesAsStr(tableNames)) + sortTableNames(tableNames) + sortTableNames(c.expect) + for j, tn := range tableNames { + require.Equal(t, c.expect[j].Schema.L, tn.Schema.L, "case %d sql: %s, j: %d, actual: %s", i, c.sql, j, tableNamesAsStr(tableNames)) + require.Equal(t, c.expect[j].Name.L, tn.Name.L, "case %d sql: %s, j: %d, actual: %s", i, c.sql, j, tableNamesAsStr(tableNames)) + } + } +} diff --git a/server/extension.go b/server/extension.go index 3851f58c09826..8737d92b6a5a9 100644 --- a/server/extension.go +++ b/server/extension.go @@ -198,10 +198,26 @@ func (e *stmtEventInfo) AffectedRows() uint64 { } func (e *stmtEventInfo) RelatedTables() []stmtctx.TableEntry { - if e.sc == nil { - return nil + if useDB, ok := e.stmtNode.(*ast.UseStmt); ok { + return []stmtctx.TableEntry{{DB: useDB.DBName}} + } + if e.sc != nil && e.err == nil { + return e.sc.Tables + } + tableNames := core.ExtractTableList(e.stmtNode, false) + tableEntries := make([]stmtctx.TableEntry, 0, len(tableNames)) + for i, tableName := range tableNames { + if tableName != nil { + tableEntries = append(tableEntries, stmtctx.TableEntry{ + Table: tableName.Name.L, + DB: tableName.Schema.L, + }) + if tableEntries[i].DB == "" { + tableEntries[i].DB = e.sessVars.CurrentDB + } + } } - return e.sc.Tables + return tableEntries } func (e *stmtEventInfo) GetError() error { diff --git a/util/topsql/reporter/pubsub_test.go b/util/topsql/reporter/pubsub_test.go index 8ba35aabeb171..482f4f50f7f4b 100644 --- a/util/topsql/reporter/pubsub_test.go +++ b/util/topsql/reporter/pubsub_test.go @@ -87,7 +87,7 @@ func TestPubSubDataSink(t *testing.T) { _ = ds.run() }() - panicPath := "github.com/pingcap/tidb/pkg/util/topsql/reporter/mockGrpcLogPanic" + panicPath := "github.com/pingcap/tidb/util/topsql/reporter/mockGrpcLogPanic" require.NoError(t, failpoint.Enable(panicPath, "panic")) err := ds.TrySend(&ReportData{ DataRecords: []tipb.TopSQLRecord{{