diff --git a/executor/fktest/foreign_key_test.go b/executor/fktest/foreign_key_test.go index e17285fd3592b..a162bd22b96aa 100644 --- a/executor/fktest/foreign_key_test.go +++ b/executor/fktest/foreign_key_test.go @@ -2521,3 +2521,21 @@ func TestTableLockInForeignKeyCascade(t *testing.T) { tk.MustQuery("select * from t1 order by id").Check(testkit.Rows("2", "3")) tk.MustQuery("select * from t2 order by id").Check(testkit.Rows("2", "3")) } + +func TestForeignKeyIssue39732(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("set @@global.tidb_enable_stmt_summary=1") + tk.MustExec("set @@foreign_key_checks=1") + tk.MustExec("use test") + tk.MustExec("create user 'u1'@'%' identified by '';") + tk.MustExec("GRANT ALL PRIVILEGES ON *.* TO 'u1'@'%'") + err := tk.Session().Auth(&auth.UserIdentity{Username: "u1", Hostname: "localhost", CurrentUser: true, AuthUsername: "u1", AuthHostname: "%"}, nil, []byte("012345678901234567890")) + require.NoError(t, err) + tk.MustExec("create table t1 (id int key, leader int, index(leader), foreign key (leader) references t1(id) ON DELETE CASCADE);") + tk.MustExec("insert into t1 values (1, null), (10, 1), (11, 1), (20, 10)") + tk.MustExec(`prepare stmt1 from 'delete from t1 where id = ?';`) + tk.MustExec(`set @a = 1;`) + tk.MustExec("execute stmt1 using @a;") + tk.MustQuery("select * from t1 order by id").Check(testkit.Rows()) +} diff --git a/planner/core/encode.go b/planner/core/encode.go index 663f743517953..14931d4d1ef0a 100644 --- a/planner/core/encode.go +++ b/planner/core/encode.go @@ -41,7 +41,7 @@ func EncodeFlatPlan(flat *FlatPhysicalPlan) string { return "" } failpoint.Inject("mockPlanRowCount", func(val failpoint.Value) { - selectPlan := flat.Main.GetSelectPlan() + selectPlan, _ := flat.Main.GetSelectPlan() for _, op := range selectPlan { op.Origin.statsInfo().RowCount = float64(val.(int)) } @@ -262,7 +262,7 @@ type planDigester struct { // NormalizeFlatPlan normalizes a FlatPhysicalPlan and generates plan digest. func NormalizeFlatPlan(flat *FlatPhysicalPlan) (normalized string, digest *parser.Digest) { - selectPlan := flat.Main.GetSelectPlan() + selectPlan, selectPlanOffset := flat.Main.GetSelectPlan() if len(selectPlan) == 0 || !selectPlan[0].IsPhysicalPlan { return "", parser.NewDigest(nil) } @@ -274,18 +274,11 @@ func NormalizeFlatPlan(flat *FlatPhysicalPlan) (normalized string, digest *parse }() // assume an operator costs around 30 bytes, preallocate space for them d.buf.Grow(30 * len(selectPlan)) - depthOffset := len(flat.Main) - len(selectPlan) -loop1: for _, op := range selectPlan { - switch op.Origin.(type) { - case *FKCheck, *FKCascade: - // Generate plan digest doesn't need to contain the foreign key check/cascade plan, so just break the loop. - break loop1 - } taskTypeInfo := plancodec.EncodeTaskTypeForNormalize(op.IsRoot, op.StoreType) p := op.Origin.(PhysicalPlan) plancodec.NormalizePlanNode( - int(op.Depth-uint32(depthOffset)), + int(op.Depth-uint32(selectPlanOffset)), op.Origin.TP(), taskTypeInfo, p.ExplainNormalizedInfo(), diff --git a/planner/core/flat_plan.go b/planner/core/flat_plan.go index bbd6ea5f593f6..da200961e821c 100644 --- a/planner/core/flat_plan.go +++ b/planner/core/flat_plan.go @@ -54,23 +54,34 @@ type FlatPhysicalPlan struct { // depth-first traversal plus some special rule for some operators. type FlatPlanTree []*FlatOperator -// GetSelectPlan skips Insert, Delete and Update at the beginning of the FlatPlanTree. +// GetSelectPlan skips Insert, Delete, and Update at the beginning of the FlatPlanTree and the foreign key check/cascade plan at the end of the FlatPlanTree. // Note: // // It returns a reference to the original FlatPlanTree, please avoid modifying the returned value. -// Since you get a part of the original slice, you need to adjust the FlatOperator.Depth and FlatOperator.ChildrenIdx when using them. -func (e FlatPlanTree) GetSelectPlan() FlatPlanTree { +// The second return value is the offset. Because the returned FlatPlanTree is a part of the original slice, you need to minus them by the offset when using the returned FlatOperator.Depth and FlatOperator.ChildrenIdx. +func (e FlatPlanTree) GetSelectPlan() (FlatPlanTree, int) { if len(e) == 0 { - return nil + return nil, 0 } + hasDML := false for i, op := range e { switch op.Origin.(type) { case *Insert, *Delete, *Update: + hasDML = true default: - return e[i:] + if hasDML { + for j := i; j < len(e); j++ { + switch e[j].Origin.(type) { + case *FKCheck, *FKCascade: + // The later plans are belong to foreign key check/cascade plans, doesn't belong to select plan, just skip it. + return e[i:j], i + } + } + } + return e[i:], i } } - return nil + return nil, 0 } // FlatOperator is a simplified operator. diff --git a/planner/core/hints.go b/planner/core/hints.go index f67a66b1df001..baf9f91330d20 100644 --- a/planner/core/hints.go +++ b/planner/core/hints.go @@ -35,7 +35,7 @@ func GenHintsFromFlatPlan(flat *FlatPhysicalPlan) []*ast.TableOptimizerHint { nodeTp = utilhint.TypeDelete } var hints []*ast.TableOptimizerHint - selectPlan := flat.Main.GetSelectPlan() + selectPlan, _ := flat.Main.GetSelectPlan() if len(selectPlan) == 0 || !selectPlan[0].IsPhysicalPlan { return nil } diff --git a/planner/core/plan_test.go b/planner/core/plan_test.go index f50a04a796a7e..f1c65fcd9fb4d 100644 --- a/planner/core/plan_test.go +++ b/planner/core/plan_test.go @@ -139,6 +139,8 @@ func TestNormalizedPlan(t *testing.T) { newNormalized, newDigest := core.NormalizeFlatPlan(flat) require.Equal(t, normalized, newNormalized) require.Equal(t, digest, newDigest) + // Test for GenHintsFromFlatPlan won't panic. + core.GenHintsFromFlatPlan(flat) normalizedPlan, err := plancodec.DecodeNormalizedPlan(normalized) normalizedPlanRows := getPlanRows(normalizedPlan) diff --git a/session/bootstrap_test.go b/session/bootstrap_test.go index dd3e833d36553..4ebb7001461a6 100644 --- a/session/bootstrap_test.go +++ b/session/bootstrap_test.go @@ -46,7 +46,7 @@ func TestBootstrap(t *testing.T) { se := createSessionAndSetID(t, store) mustExec(t, se, "set global tidb_txn_mode=''") mustExec(t, se, "use mysql") - r := mustExec(t, se, "select * from user") + r := mustExecToRecodeSet(t, se, "select * from user") require.NotNil(t, r) ctx := context.Background() @@ -64,19 +64,14 @@ func TestBootstrap(t *testing.T) { mustExec(t, se, "use test") // Check privilege tables. - rs := mustExec(t, se, "SELECT * from mysql.global_priv") - require.NoError(t, rs.Close()) - rs = mustExec(t, se, "SELECT * from mysql.db") - require.NoError(t, rs.Close()) - rs = mustExec(t, se, "SELECT * from mysql.tables_priv") - require.NoError(t, rs.Close()) - rs = mustExec(t, se, "SELECT * from mysql.columns_priv") - require.NoError(t, rs.Close()) - rs = mustExec(t, se, "SELECT * from mysql.global_grants") - require.NoError(t, rs.Close()) + mustExec(t, se, "SELECT * from mysql.global_priv") + mustExec(t, se, "SELECT * from mysql.db") + mustExec(t, se, "SELECT * from mysql.tables_priv") + mustExec(t, se, "SELECT * from mysql.columns_priv") + mustExec(t, se, "SELECT * from mysql.global_grants") // Check privilege tables. - r = mustExec(t, se, "SELECT COUNT(*) from mysql.global_variables") + r = mustExecToRecodeSet(t, se, "SELECT COUNT(*) from mysql.global_variables") require.NotNil(t, r) req = r.NewChunk(nil) @@ -100,7 +95,7 @@ func TestBootstrap(t *testing.T) { se, err = CreateSession4Test(store) require.NoError(t, err) mustExec(t, se, "USE test") - r = mustExec(t, se, "select * from t") + r = mustExecToRecodeSet(t, se, "select * from t") require.NotNil(t, r) req = r.NewChunk(nil) @@ -116,7 +111,7 @@ func TestBootstrap(t *testing.T) { se, err = CreateSession4Test(store) require.NoError(t, err) doDMLWorks(se) - r = mustExec(t, se, "select * from mysql.expr_pushdown_blacklist where name = 'date_add'") + r = mustExecToRecodeSet(t, se, "select * from mysql.expr_pushdown_blacklist where name = 'date_add'") req = r.NewChunk(nil) err = r.Next(ctx, req) require.NoError(t, err) @@ -179,7 +174,7 @@ func TestBootstrapWithError(t *testing.T) { se := createSessionAndSetID(t, store) mustExec(t, se, "USE mysql") - r := mustExec(t, se, `select * from user`) + r := mustExecToRecodeSet(t, se, `select * from user`) req := r.NewChunk(nil) err = r.Next(ctx, req) require.NoError(t, err) @@ -192,15 +187,15 @@ func TestBootstrapWithError(t *testing.T) { mustExec(t, se, "USE test") // Check privilege tables. - mustExec(t, se, "SELECT * from mysql.global_priv").Close() - mustExec(t, se, "SELECT * from mysql.db").Close() - mustExec(t, se, "SELECT * from mysql.tables_priv").Close() - mustExec(t, se, "SELECT * from mysql.columns_priv").Close() + mustExec(t, se, "SELECT * from mysql.global_priv") + mustExec(t, se, "SELECT * from mysql.db") + mustExec(t, se, "SELECT * from mysql.tables_priv") + mustExec(t, se, "SELECT * from mysql.columns_priv") // Check role tables. - mustExec(t, se, "SELECT * from mysql.role_edges").Close() - mustExec(t, se, "SELECT * from mysql.default_roles").Close() + mustExec(t, se, "SELECT * from mysql.role_edges") + mustExec(t, se, "SELECT * from mysql.default_roles") // Check global variables. - r = mustExec(t, se, "SELECT COUNT(*) from mysql.global_variables") + r = mustExecToRecodeSet(t, se, "SELECT COUNT(*) from mysql.global_variables") req = r.NewChunk(nil) err = r.Next(ctx, req) require.NoError(t, err) @@ -208,7 +203,7 @@ func TestBootstrapWithError(t *testing.T) { require.Equal(t, globalVarsCount(), v.GetInt64(0)) require.NoError(t, r.Close()) - r = mustExec(t, se, `SELECT VARIABLE_VALUE from mysql.TiDB where VARIABLE_NAME="bootstrapped"`) + r = mustExecToRecodeSet(t, se, `SELECT VARIABLE_VALUE from mysql.TiDB where VARIABLE_NAME="bootstrapped"`) req = r.NewChunk(nil) err = r.Next(ctx, req) require.NoError(t, err) @@ -219,7 +214,7 @@ func TestBootstrapWithError(t *testing.T) { require.NoError(t, r.Close()) // Check tidb_ttl_table_status table - mustExec(t, se, "SELECT * from mysql.tidb_ttl_table_status").Close() + mustExec(t, se, "SELECT * from mysql.tidb_ttl_table_status") } // TestUpgrade tests upgrading @@ -233,7 +228,7 @@ func TestUpgrade(t *testing.T) { mustExec(t, se, "USE mysql") // bootstrap with currentBootstrapVersion - r := mustExec(t, se, `SELECT VARIABLE_VALUE from mysql.TiDB where VARIABLE_NAME="tidb_server_version"`) + r := mustExecToRecodeSet(t, se, `SELECT VARIABLE_VALUE from mysql.TiDB where VARIABLE_NAME="tidb_server_version"`) req := r.NewChunk(nil) err := r.Next(ctx, req) row := req.GetRow(0) @@ -262,7 +257,7 @@ func TestUpgrade(t *testing.T) { mustExec(t, se1, `commit`) unsetStoreBootstrapped(store.UUID()) // Make sure the version is downgraded. - r = mustExec(t, se1, `SELECT VARIABLE_VALUE from mysql.TiDB where VARIABLE_NAME="tidb_server_version"`) + r = mustExecToRecodeSet(t, se1, `SELECT VARIABLE_VALUE from mysql.TiDB where VARIABLE_NAME="tidb_server_version"`) req = r.NewChunk(nil) err = r.Next(ctx, req) require.NoError(t, err) @@ -278,7 +273,7 @@ func TestUpgrade(t *testing.T) { require.NoError(t, err) se2 := createSessionAndSetID(t, store) - r = mustExec(t, se2, `SELECT VARIABLE_VALUE from mysql.TiDB where VARIABLE_NAME="tidb_server_version"`) + r = mustExecToRecodeSet(t, se2, `SELECT VARIABLE_VALUE from mysql.TiDB where VARIABLE_NAME="tidb_server_version"`) req = r.NewChunk(nil) err = r.Next(ctx, req) require.NoError(t, err) @@ -293,7 +288,7 @@ func TestUpgrade(t *testing.T) { require.Equal(t, currentBootstrapVersion, ver) // Verify that 'new_collation_enabled' is false. - r = mustExec(t, se2, fmt.Sprintf(`SELECT VARIABLE_VALUE from mysql.TiDB where VARIABLE_NAME='%s'`, tidbNewCollationEnabled)) + r = mustExecToRecodeSet(t, se2, fmt.Sprintf(`SELECT VARIABLE_VALUE from mysql.TiDB where VARIABLE_NAME='%s'`, tidbNewCollationEnabled)) req = r.NewChunk(nil) err = r.Next(ctx, req) require.NoError(t, err) @@ -331,7 +326,7 @@ func TestIssue17979_1(t *testing.T) { ver, err = getBootstrapVersion(seV4) require.NoError(t, err) require.Equal(t, currentBootstrapVersion, ver) - r := mustExec(t, seV4, "select variable_value from mysql.tidb where variable_name='default_oom_action'") + r := mustExecToRecodeSet(t, seV4, "select variable_value from mysql.tidb where variable_name='default_oom_action'") req := r.NewChunk(nil) require.NoError(t, r.Next(ctx, req)) require.Equal(t, variable.OOMActionLog, req.GetRow(0).GetString(0)) @@ -368,7 +363,7 @@ func TestIssue17979_2(t *testing.T) { ver, err = getBootstrapVersion(seV4) require.NoError(t, err) require.Equal(t, currentBootstrapVersion, ver) - r := mustExec(t, seV4, "select variable_value from mysql.tidb where variable_name='default_oom_action'") + r := mustExecToRecodeSet(t, seV4, "select variable_value from mysql.tidb where variable_name='default_oom_action'") req := r.NewChunk(nil) require.NoError(t, r.Next(ctx, req)) require.Equal(t, 0, req.NumRows()) @@ -409,12 +404,12 @@ func TestIssue20900_2(t *testing.T) { ver, err = getBootstrapVersion(seV4) require.NoError(t, err) require.Equal(t, currentBootstrapVersion, ver) - r := mustExec(t, seV4, "select @@tidb_mem_quota_query") + r := mustExecToRecodeSet(t, seV4, "select @@tidb_mem_quota_query") req := r.NewChunk(nil) require.NoError(t, r.Next(ctx, req)) require.Equal(t, "1073741824", req.GetRow(0).GetString(0)) require.Equal(t, int64(1073741824), seV4.GetSessionVars().MemQuotaQuery) - r = mustExec(t, seV4, "select variable_value from mysql.tidb where variable_name='default_memory_quota_query'") + r = mustExecToRecodeSet(t, seV4, "select variable_value from mysql.tidb where variable_name='default_memory_quota_query'") req = r.NewChunk(nil) require.NoError(t, r.Next(ctx, req)) require.Equal(t, 0, req.NumRows()) @@ -476,7 +471,7 @@ func TestStmtSummary(t *testing.T) { defer dom.Close() se := createSessionAndSetID(t, store) - r := mustExec(t, se, "select variable_value from mysql.global_variables where variable_name='tidb_enable_stmt_summary'") + r := mustExecToRecodeSet(t, se, "select variable_value from mysql.global_variables where variable_name='tidb_enable_stmt_summary'") req := r.NewChunk(nil) require.NoError(t, r.Next(ctx, req)) row := req.GetRow(0) @@ -538,7 +533,7 @@ func TestUpdateBindInfo(t *testing.T) { mustExec(t, se, sql) upgradeToVer67(se, version66) - r := mustExec(t, se, `select original_sql, bind_sql, default_db, status from mysql.bind_info where source != 'builtin'`) + r := mustExecToRecodeSet(t, se, `select original_sql, bind_sql, default_db, status from mysql.bind_info where source != 'builtin'`) req := r.NewChunk(nil) require.NoError(t, r.Next(ctx, req)) row := req.GetRow(0) @@ -549,7 +544,7 @@ func TestUpdateBindInfo(t *testing.T) { require.NoError(t, r.Close()) sql = fmt.Sprintf("drop global binding for %s", bindCase.deleteText) mustExec(t, se, sql) - r = mustExec(t, se, `select original_sql, bind_sql, status from mysql.bind_info where source != 'builtin'`) + r = mustExecToRecodeSet(t, se, `select original_sql, bind_sql, status from mysql.bind_info where source != 'builtin'`) require.NoError(t, r.Next(ctx, req)) row = req.GetRow(0) require.Equal(t, bindCase.originWithDB, row.GetString(0)) @@ -581,7 +576,7 @@ func TestUpdateDuplicateBindInfo(t *testing.T) { upgradeToVer67(se, version66) - r := mustExec(t, se, `select original_sql, bind_sql, default_db, status, create_time from mysql.bind_info where source != 'builtin' order by create_time`) + r := mustExecToRecodeSet(t, se, `select original_sql, bind_sql, default_db, status, create_time from mysql.bind_info where source != 'builtin' order by create_time`) req := r.NewChunk(nil) require.NoError(t, r.Next(ctx, req)) require.Equal(t, 3, req.NumRows()) @@ -637,7 +632,7 @@ func TestUpgradeClusteredIndexDefaultValue(t *testing.T) { require.NoError(t, err) require.Equal(t, currentBootstrapVersion, ver) - r := mustExec(t, seV68, `select @@global.tidb_enable_clustered_index, @@session.tidb_enable_clustered_index`) + r := mustExecToRecodeSet(t, seV68, `select @@global.tidb_enable_clustered_index, @@session.tidb_enable_clustered_index`) req := r.NewChunk(nil) require.NoError(t, r.Next(context.Background(), req)) require.Equal(t, 1, req.NumRows()) @@ -674,7 +669,7 @@ func TestUpgradeVersion66(t *testing.T) { ver, err = getBootstrapVersion(seV66) require.NoError(t, err) require.Equal(t, currentBootstrapVersion, ver) - r := mustExec(t, seV66, `select @@global.tidb_track_aggregate_memory_usage, @@session.tidb_track_aggregate_memory_usage`) + r := mustExecToRecodeSet(t, seV66, `select @@global.tidb_track_aggregate_memory_usage, @@session.tidb_track_aggregate_memory_usage`) req := r.NewChunk(nil) require.NoError(t, r.Next(ctx, req)) require.Equal(t, 1, req.NumRows()) @@ -724,7 +719,7 @@ func TestUpgradeVersion74(t *testing.T) { ver, err = getBootstrapVersion(seV74) require.NoError(t, err) require.Equal(t, currentBootstrapVersion, ver) - r := mustExec(t, seV74, `SELECT @@global.tidb_stmt_summary_max_stmt_count`) + r := mustExecToRecodeSet(t, seV74, `SELECT @@global.tidb_stmt_summary_max_stmt_count`) req := r.NewChunk(nil) require.NoError(t, r.Next(ctx, req)) require.Equal(t, 1, req.NumRows()) @@ -757,7 +752,7 @@ func TestUpgradeVersion75(t *testing.T) { ver, err := getBootstrapVersion(seV74) require.NoError(t, err) require.Equal(t, int64(74), ver) - r := mustExec(t, seV74, `desc mysql.user`) + r := mustExecToRecodeSet(t, seV74, `desc mysql.user`) req := r.NewChunk(nil) row := req.GetRow(0) require.NoError(t, r.Next(ctx, req)) @@ -771,7 +766,7 @@ func TestUpgradeVersion75(t *testing.T) { ver, err = getBootstrapVersion(seV75) require.NoError(t, err) require.Equal(t, currentBootstrapVersion, ver) - r = mustExec(t, seV75, `desc mysql.user`) + r = mustExecToRecodeSet(t, seV75, `desc mysql.user`) req = r.NewChunk(nil) row = req.GetRow(0) require.NoError(t, r.Next(ctx, req)) @@ -853,7 +848,7 @@ func TestAnalyzeVersionUpgradeFrom300To500(t *testing.T) { require.Equal(t, int64(ver300), ver) // We are now in 3.0.0, check tidb_analyze_version should not exist. - res := mustExec(t, seV3, fmt.Sprintf("select * from mysql.GLOBAL_VARIABLES where variable_name='%s'", variable.TiDBAnalyzeVersion)) + res := mustExecToRecodeSet(t, seV3, fmt.Sprintf("select * from mysql.GLOBAL_VARIABLES where variable_name='%s'", variable.TiDBAnalyzeVersion)) chk := res.NewChunk(nil) err = res.Next(ctx, chk) require.NoError(t, err) @@ -868,7 +863,7 @@ func TestAnalyzeVersionUpgradeFrom300To500(t *testing.T) { require.Equal(t, currentBootstrapVersion, ver) // We are now in version no lower than 5.x, tidb_enable_index_merge should be 1. - res = mustExec(t, seCurVer, "select @@tidb_analyze_version") + res = mustExecToRecodeSet(t, seCurVer, "select @@tidb_analyze_version") chk = res.NewChunk(nil) err = res.Next(ctx, chk) require.NoError(t, err) @@ -891,7 +886,7 @@ func TestIndexMergeInNewCluster(t *testing.T) { // In a new created cluster(above 5.4+), tidb_enable_index_merge is 1 by default. mustExec(t, se, "use test;") - r := mustExec(t, se, "select @@tidb_enable_index_merge;") + r := mustExecToRecodeSet(t, se, "select @@tidb_enable_index_merge;") require.NotNil(t, r) ctx := context.Background() @@ -928,7 +923,7 @@ func TestIndexMergeUpgradeFrom300To540(t *testing.T) { require.Equal(t, int64(ver300), ver) // We are now in 3.0.0, check tidb_enable_index_merge shoudle not exist. - res := mustExec(t, seV3, fmt.Sprintf("select * from mysql.GLOBAL_VARIABLES where variable_name='%s'", variable.TiDBEnableIndexMerge)) + res := mustExecToRecodeSet(t, seV3, fmt.Sprintf("select * from mysql.GLOBAL_VARIABLES where variable_name='%s'", variable.TiDBEnableIndexMerge)) chk := res.NewChunk(nil) err = res.Next(ctx, chk) require.NoError(t, err) @@ -943,7 +938,7 @@ func TestIndexMergeUpgradeFrom300To540(t *testing.T) { require.Equal(t, currentBootstrapVersion, ver) // We are now in 5.x, tidb_enable_index_merge should be off. - res = mustExec(t, seCurVer, "select @@tidb_enable_index_merge") + res = mustExecToRecodeSet(t, seCurVer, "select @@tidb_enable_index_merge") chk = res.NewChunk(nil) err = res.Next(ctx, chk) require.NoError(t, err) @@ -979,7 +974,7 @@ func TestIndexMergeUpgradeFrom400To540(t *testing.T) { require.Equal(t, int64(ver400), ver) // We are now in 4.0.0, tidb_enable_index_merge is off. - res := mustExec(t, seV4, fmt.Sprintf("select * from mysql.GLOBAL_VARIABLES where variable_name='%s'", variable.TiDBEnableIndexMerge)) + res := mustExecToRecodeSet(t, seV4, fmt.Sprintf("select * from mysql.GLOBAL_VARIABLES where variable_name='%s'", variable.TiDBEnableIndexMerge)) chk := res.NewChunk(nil) err = res.Next(ctx, chk) require.NoError(t, err) @@ -1005,7 +1000,7 @@ func TestIndexMergeUpgradeFrom400To540(t *testing.T) { require.Equal(t, currentBootstrapVersion, ver) // We are now in 5.x, tidb_enable_index_merge should be on because we enable it in 4.0.0. - res = mustExec(t, seCurVer, "select @@tidb_enable_index_merge") + res = mustExecToRecodeSet(t, seCurVer, "select @@tidb_enable_index_merge") chk = res.NewChunk(nil) err = res.Next(ctx, chk) require.NoError(t, err) @@ -1037,7 +1032,7 @@ func TestUpgradeToVer85(t *testing.T) { mustExec(t, se, `insert into mysql.bind_info values('select * from t4', 'select /*+ use_index(t4, idx_a)*/ * from t4', 'test', 'invalid', '2021-01-08 14:50:58.257', '2021-01-08 14:50:58.257', 'utf8', 'utf8_general_ci', 'manual')`) upgradeToVer85(se, version84) - r := mustExec(t, se, `select count(*) from mysql.bind_info where status = 'enabled'`) + r := mustExecToRecodeSet(t, se, `select count(*) from mysql.bind_info where status = 'enabled'`) req := r.NewChunk(nil) require.NoError(t, r.Next(ctx, req)) require.Equal(t, 1, req.NumRows()) @@ -1058,7 +1053,7 @@ func TestTiDBEnablePagingVariable(t *testing.T) { "select @@global.tidb_enable_paging", "select @@session.tidb_enable_paging", } { - r := mustExec(t, se, sql) + r := mustExecToRecodeSet(t, se, sql) require.NotNil(t, r) req := r.NewChunk(nil) @@ -1100,7 +1095,7 @@ func TestTiDBOptRangeMaxSizeWhenUpgrading(t *testing.T) { require.Equal(t, int64(ver94), ver) // We are now in 6.3.0, check tidb_opt_range_max_size should not exist. - res := mustExec(t, seV630, fmt.Sprintf("select * from mysql.GLOBAL_VARIABLES where variable_name='%s'", variable.TiDBOptRangeMaxSize)) + res := mustExecToRecodeSet(t, seV630, fmt.Sprintf("select * from mysql.GLOBAL_VARIABLES where variable_name='%s'", variable.TiDBOptRangeMaxSize)) chk := res.NewChunk(nil) err = res.Next(ctx, chk) require.NoError(t, err) @@ -1115,7 +1110,7 @@ func TestTiDBOptRangeMaxSizeWhenUpgrading(t *testing.T) { require.Equal(t, currentBootstrapVersion, ver) // We are now in version no lower than v6.4.0, tidb_opt_range_max_size should be 0. - res = mustExec(t, seCurVer, "select @@session.tidb_opt_range_max_size") + res = mustExecToRecodeSet(t, seCurVer, "select @@session.tidb_opt_range_max_size") chk = res.NewChunk(nil) err = res.Next(ctx, chk) require.NoError(t, err) @@ -1124,7 +1119,7 @@ func TestTiDBOptRangeMaxSizeWhenUpgrading(t *testing.T) { require.Equal(t, 1, row.Len()) require.Equal(t, "0", row.GetString(0)) - res = mustExec(t, seCurVer, "select @@global.tidb_opt_range_max_size") + res = mustExecToRecodeSet(t, seCurVer, "select @@global.tidb_opt_range_max_size") chk = res.NewChunk(nil) err = res.Next(ctx, chk) require.NoError(t, err) @@ -1147,7 +1142,7 @@ func TestTiDBCostModelInNewCluster(t *testing.T) { // In a new created cluster(above 6.5+), tidb_cost_model_version is 2 by default. mustExec(t, se, "use test;") - r := mustExec(t, se, "select @@tidb_cost_model_version;") + r := mustExecToRecodeSet(t, se, "select @@tidb_cost_model_version;") require.NotNil(t, r) ctx := context.Background() @@ -1184,7 +1179,7 @@ func TestTiDBCostModelUpgradeFrom300To650(t *testing.T) { require.Equal(t, int64(ver300), ver) // We are now in 3.0.0, check TiDBCostModelVersion should not exist. - res := mustExec(t, seV3, fmt.Sprintf("select * from mysql.GLOBAL_VARIABLES where variable_name='%s'", variable.TiDBCostModelVersion)) + res := mustExecToRecodeSet(t, seV3, fmt.Sprintf("select * from mysql.GLOBAL_VARIABLES where variable_name='%s'", variable.TiDBCostModelVersion)) chk := res.NewChunk(nil) err = res.Next(ctx, chk) require.NoError(t, err) @@ -1199,7 +1194,7 @@ func TestTiDBCostModelUpgradeFrom300To650(t *testing.T) { require.Equal(t, currentBootstrapVersion, ver) // We are now in 6.5+, TiDBCostModelVersion should be 1. - res = mustExec(t, seCurVer, "select @@tidb_cost_model_version") + res = mustExecToRecodeSet(t, seCurVer, "select @@tidb_cost_model_version") chk = res.NewChunk(nil) err = res.Next(ctx, chk) require.NoError(t, err) @@ -1235,7 +1230,7 @@ func TestTiDBCostModelUpgradeFrom610To650(t *testing.T) { require.Equal(t, int64(ver61), ver) // We are now in 6.1, tidb_cost_model_version is 1. - res := mustExec(t, seV61, fmt.Sprintf("select * from mysql.GLOBAL_VARIABLES where variable_name='%s'", variable.TiDBCostModelVersion)) + res := mustExecToRecodeSet(t, seV61, fmt.Sprintf("select * from mysql.GLOBAL_VARIABLES where variable_name='%s'", variable.TiDBCostModelVersion)) chk := res.NewChunk(nil) err = res.Next(ctx, chk) require.NoError(t, err) @@ -1243,6 +1238,7 @@ func TestTiDBCostModelUpgradeFrom610To650(t *testing.T) { row := chk.GetRow(0) require.Equal(t, 2, row.Len()) require.Equal(t, "1", row.GetString(1)) + res.Close() if i == 0 { // For the first time, We set tidb_cost_model_version to 2. @@ -1261,7 +1257,7 @@ func TestTiDBCostModelUpgradeFrom610To650(t *testing.T) { require.Equal(t, currentBootstrapVersion, ver) // We are now in 6.5. - res = mustExec(t, seCurVer, "select @@tidb_cost_model_version") + res = mustExecToRecodeSet(t, seCurVer, "select @@tidb_cost_model_version") chk = res.NewChunk(nil) err = res.Next(ctx, chk) require.NoError(t, err) @@ -1273,6 +1269,7 @@ func TestTiDBCostModelUpgradeFrom610To650(t *testing.T) { } else { require.Equal(t, "1", row.GetString(0)) } + res.Close() }() } } diff --git a/session/main_test.go b/session/main_test.go index 790903eb68c8a..1841bbbdd3570 100644 --- a/session/main_test.go +++ b/session/main_test.go @@ -99,7 +99,15 @@ func createSessionAndSetID(t *testing.T, store kv.Storage) Session { return se } -func mustExec(t *testing.T, se Session, sql string, args ...interface{}) sqlexec.RecordSet { +func mustExec(t *testing.T, se Session, sql string, args ...interface{}) { + rs, err := exec(se, sql, args...) + require.NoError(t, err) + if rs != nil { + require.NoError(t, rs.Close()) + } +} + +func mustExecToRecodeSet(t *testing.T, se Session, sql string, args ...interface{}) sqlexec.RecordSet { rs, err := exec(se, sql, args...) require.NoError(t, err) return rs diff --git a/types/const_test.go b/types/const_test.go index 3942b6d1c5fd1..3815efde54338 100644 --- a/types/const_test.go +++ b/types/const_test.go @@ -338,8 +338,7 @@ func TestIgnoreSpaceMode(t *testing.T) { tk.MustExec("DROP TABLE BIT_AND;") tk.MustExec("CREATE TABLE `BIT_AND` (a bigint);") tk.MustExec("DROP TABLE BIT_AND;") - _, err = tk.Exec("CREATE TABLE BIT_AND(a bigint);") - require.Error(t, err) + tk.MustExecToErr("CREATE TABLE BIT_AND(a bigint);") tk.MustExec("CREATE TABLE test.BIT_AND(a bigint);") tk.MustExec("DROP TABLE BIT_AND;") @@ -347,36 +346,29 @@ func TestIgnoreSpaceMode(t *testing.T) { tk.MustExec("DROP TABLE NOW;") tk.MustExec("CREATE TABLE `NOW` (a bigint);") tk.MustExec("DROP TABLE NOW;") - _, err = tk.Exec("CREATE TABLE NOW(a bigint);") - require.Error(t, err) + tk.MustExecToErr("CREATE TABLE NOW(a bigint);") tk.MustExec("CREATE TABLE test.NOW(a bigint);") tk.MustExec("DROP TABLE NOW;") tk.MustExec("set sql_mode='IGNORE_SPACE'") - _, err = tk.Exec("CREATE TABLE COUNT (a bigint);") - require.Error(t, err) + tk.MustExecToErr("CREATE TABLE COUNT (a bigint);") tk.MustExec("CREATE TABLE `COUNT` (a bigint);") tk.MustExec("DROP TABLE COUNT;") - _, err = tk.Exec("CREATE TABLE COUNT(a bigint);") - require.Error(t, err) + tk.MustExecToErr("CREATE TABLE COUNT(a bigint);") tk.MustExec("CREATE TABLE test.COUNT(a bigint);") tk.MustExec("DROP TABLE COUNT;") - _, err = tk.Exec("CREATE TABLE BIT_AND (a bigint);") - require.Error(t, err) + tk.MustExecToErr("CREATE TABLE BIT_AND (a bigint);") tk.MustExec("CREATE TABLE `BIT_AND` (a bigint);") tk.MustExec("DROP TABLE BIT_AND;") - _, err = tk.Exec("CREATE TABLE BIT_AND(a bigint);") - require.Error(t, err) + tk.MustExecToErr("CREATE TABLE BIT_AND(a bigint);") tk.MustExec("CREATE TABLE test.BIT_AND(a bigint);") tk.MustExec("DROP TABLE BIT_AND;") - _, err = tk.Exec("CREATE TABLE NOW (a bigint);") - require.Error(t, err) + tk.MustExecToErr("CREATE TABLE NOW (a bigint);") tk.MustExec("CREATE TABLE `NOW` (a bigint);") tk.MustExec("DROP TABLE NOW;") - _, err = tk.Exec("CREATE TABLE NOW(a bigint);") - require.Error(t, err) + tk.MustExecToErr("CREATE TABLE NOW(a bigint);") tk.MustExec("CREATE TABLE test.NOW(a bigint);") tk.MustExec("DROP TABLE NOW;") }