diff --git a/ddl/ddl_worker_test.go b/ddl/ddl_worker_test.go index cc48c01bb6a34..c07786b82ccdb 100644 --- a/ddl/ddl_worker_test.go +++ b/ddl/ddl_worker_test.go @@ -267,7 +267,7 @@ func testCheckJobCancelled(c *C, d *ddl, job *model.Job, state *model.SchemaStat t := meta.NewMeta(txn) historyJob, err := t.GetHistoryDDLJob(job.ID) c.Assert(err, IsNil) - c.Assert(historyJob.IsCancelled() || historyJob.IsRollbackDone(), IsTrue, Commentf("histroy job %s", historyJob)) + c.Assert(historyJob.IsCancelled() || historyJob.IsRollbackDone(), IsTrue, Commentf("history job %s", historyJob)) if state != nil { c.Assert(historyJob.SchemaState, Equals, *state) } diff --git a/executor/executor.go b/executor/executor.go index 17384f070c6bc..bcfacf04d31ae 100644 --- a/executor/executor.go +++ b/executor/executor.go @@ -1288,11 +1288,17 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) { sc.Priority = priority } } - if vars.LastInsertID > 0 { - vars.PrevLastInsertID = vars.LastInsertID - vars.LastInsertID = 0 + if vars.StmtCtx.LastInsertID > 0 { + sc.PrevLastInsertID = vars.StmtCtx.LastInsertID + } else { + sc.PrevLastInsertID = vars.StmtCtx.PrevLastInsertID + } + sc.PrevAffectedRows = 0 + if vars.StmtCtx.InUpdateOrDeleteStmt || vars.StmtCtx.InInsertStmt { + sc.PrevAffectedRows = int64(vars.StmtCtx.AffectedRows()) + } else if vars.StmtCtx.InSelectStmt { + sc.PrevAffectedRows = -1 } - vars.ResetPrevAffectedRows() err = vars.SetSystemVar("warning_count", fmt.Sprintf("%d", vars.StmtCtx.NumWarnings(false))) if err != nil { return errors.Trace(err) @@ -1301,7 +1307,6 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) { if err != nil { return errors.Trace(err) } - vars.InsertID = 0 vars.StmtCtx = sc return } diff --git a/executor/insert_common.go b/executor/insert_common.go index 224df2b1c3bc9..7cea33e154144 100644 --- a/executor/insert_common.go +++ b/executor/insert_common.go @@ -460,7 +460,7 @@ func (e *InsertValues) adjustAutoIncrementDatum(d types.Datum, hasValue bool, c if err != nil { return types.Datum{}, errors.Trace(err) } - e.ctx.GetSessionVars().InsertID = uint64(recordID) + e.ctx.GetSessionVars().StmtCtx.InsertID = uint64(recordID) retryInfo.AddAutoIncrementID(recordID) d.SetAutoID(recordID, c.Flag) return d, nil diff --git a/expression/builtin_info.go b/expression/builtin_info.go index 0bde0f778e6ec..4d7dc78c90499 100644 --- a/expression/builtin_info.go +++ b/expression/builtin_info.go @@ -262,7 +262,7 @@ func (b *builtinLastInsertIDSig) Clone() builtinFunc { // evalInt evals LAST_INSERT_ID(). // See https://dev.mysql.com/doc/refman/5.7/en/information-functions.html#function_last-insert-id. func (b *builtinLastInsertIDSig) evalInt(row chunk.Row) (res int64, isNull bool, err error) { - res = int64(b.ctx.GetSessionVars().PrevLastInsertID) + res = int64(b.ctx.GetSessionVars().StmtCtx.PrevLastInsertID) return res, false, nil } @@ -439,6 +439,6 @@ func (b *builtinRowCountSig) Clone() builtinFunc { // evalInt evals ROW_COUNT(). // See https://dev.mysql.com/doc/refman/5.7/en/information-functions.html#function_row-count. func (b *builtinRowCountSig) evalInt(_ chunk.Row) (res int64, isNull bool, err error) { - res = int64(b.ctx.GetSessionVars().PrevAffectedRows) + res = int64(b.ctx.GetSessionVars().StmtCtx.PrevAffectedRows) return res, false, nil } diff --git a/expression/builtin_info_test.go b/expression/builtin_info_test.go index b28b6b80defe0..926361bba106a 100644 --- a/expression/builtin_info_test.go +++ b/expression/builtin_info_test.go @@ -154,7 +154,7 @@ func (s *testEvaluatorSuite) TestRowCount(c *C) { defer testleak.AfterTest(c)() ctx := mock.NewContext() sessionVars := ctx.GetSessionVars() - sessionVars.PrevAffectedRows = 10 + sessionVars.StmtCtx.PrevAffectedRows = 10 f, err := funcs[ast.RowCount].getFunction(ctx, nil) c.Assert(err, IsNil) @@ -203,7 +203,7 @@ func (s *testEvaluatorSuite) TestLastInsertID(c *C) { err error ) if t.insertID > 0 { - s.ctx.GetSessionVars().PrevLastInsertID = t.insertID + s.ctx.GetSessionVars().StmtCtx.PrevLastInsertID = t.insertID } if t.args != nil { diff --git a/session/session.go b/session/session.go index 3d5a003af6b2d..7f8463e2a497b 100644 --- a/session/session.go +++ b/session/session.go @@ -174,10 +174,10 @@ func (s *session) Status() uint16 { } func (s *session) LastInsertID() uint64 { - if s.sessionVars.LastInsertID > 0 { - return s.sessionVars.LastInsertID + if s.sessionVars.StmtCtx.LastInsertID > 0 { + return s.sessionVars.StmtCtx.LastInsertID } - return s.sessionVars.InsertID + return s.sessionVars.StmtCtx.InsertID } func (s *session) AffectedRows() uint64 { @@ -427,8 +427,8 @@ func (s *session) String() string { if sessVars.SnapshotTS != 0 { data["snapshotTS"] = sessVars.SnapshotTS } - if sessVars.LastInsertID > 0 { - data["lastInsertID"] = sessVars.LastInsertID + if sessVars.StmtCtx.LastInsertID > 0 { + data["lastInsertID"] = sessVars.StmtCtx.LastInsertID } if len(sessVars.PreparedStmts) > 0 { data["preparedStmtCount"] = len(sessVars.PreparedStmts) @@ -486,6 +486,9 @@ func (s *session) retry(ctx context.Context, maxCnt uint) error { if st.IsReadOnly() { continue } + s.sessionVars.StmtCtx = sr.stmtCtx + s.sessionVars.StmtCtx.ResetForRetry() + s.sessionVars.PreparedParams = s.sessionVars.PreparedParams[:0] schemaVersion, err = st.RebuildPlan() if err != nil { return errors.Trace(err) @@ -499,8 +502,6 @@ func (s *session) retry(ctx context.Context, maxCnt uint) error { } else { log.Warnf("con:%d schema_ver:%d retry_cnt:%d query_num:%d", connID, schemaVersion, retryCnt, i) } - s.sessionVars.StmtCtx = sr.stmtCtx - s.sessionVars.StmtCtx.ResetForRetry() _, err = st.Exec(ctx) if err != nil { s.StmtRollback() diff --git a/session/session_test.go b/session/session_test.go index 993b1ee3bd730..3997ae7d7cc32 100644 --- a/session/session_test.go +++ b/session/session_test.go @@ -677,7 +677,7 @@ func (s *testSessionSuite) TestLastInsertID(c *C) { tk.MustExec("execute stmt1 using @v1") tk.MustExec("execute stmt1 using @v2") tk.MustExec("deallocate prepare stmt1") - currLastInsertID := tk.Se.GetSessionVars().PrevLastInsertID + currLastInsertID := tk.Se.GetSessionVars().StmtCtx.PrevLastInsertID tk.MustQuery("select c1 from t where c2 = 20").Check(testkit.Rows(fmt.Sprint(currLastInsertID))) c.Assert(lastInsertID+2, Equals, currLastInsertID) } @@ -778,7 +778,7 @@ func (s *testSessionSuite) TestAutoIncrementWithRetry(c *C) { tk.MustExec("commit") tk.MustQuery("select c1 from t where c2 = 11").Check(testkit.Rows("6")) - currLastInsertID := tk.Se.GetSessionVars().PrevLastInsertID + currLastInsertID := tk.Se.GetSessionVars().StmtCtx.PrevLastInsertID c.Assert(lastInsertID+5, Equals, currLastInsertID) // insert set @@ -793,7 +793,7 @@ func (s *testSessionSuite) TestAutoIncrementWithRetry(c *C) { tk.MustExec("commit") tk.MustQuery("select c1 from t where c2 = 31").Check(testkit.Rows("9")) - currLastInsertID = tk.Se.GetSessionVars().PrevLastInsertID + currLastInsertID = tk.Se.GetSessionVars().StmtCtx.PrevLastInsertID c.Assert(lastInsertID+3, Equals, currLastInsertID) // replace @@ -808,7 +808,7 @@ func (s *testSessionSuite) TestAutoIncrementWithRetry(c *C) { tk.MustExec("commit") tk.MustQuery("select c1 from t where c2 = 21").Check(testkit.Rows("10")) - currLastInsertID = tk.Se.GetSessionVars().PrevLastInsertID + currLastInsertID = tk.Se.GetSessionVars().StmtCtx.PrevLastInsertID c.Assert(lastInsertID+1, Equals, currLastInsertID) // update @@ -824,7 +824,7 @@ func (s *testSessionSuite) TestAutoIncrementWithRetry(c *C) { tk.MustExec("commit") tk.MustQuery("select c1 from t where c2 = 41").Check(testkit.Rows("0")) - currLastInsertID = tk.Se.GetSessionVars().PrevLastInsertID + currLastInsertID = tk.Se.GetSessionVars().StmtCtx.PrevLastInsertID c.Assert(lastInsertID+3, Equals, currLastInsertID) // prepare @@ -846,7 +846,7 @@ func (s *testSessionSuite) TestAutoIncrementWithRetry(c *C) { tk.MustExec("commit") tk.MustQuery("select c1 from t where c2 = 12").Check(testkit.Rows("7")) - currLastInsertID = tk.Se.GetSessionVars().PrevLastInsertID + currLastInsertID = tk.Se.GetSessionVars().StmtCtx.PrevLastInsertID c.Assert(lastInsertID+3, Equals, currLastInsertID) } @@ -1306,6 +1306,36 @@ func (s *testSessionSuite) TestDelete(c *C) { tk1.MustQuery("select * from t;").Check(testkit.Rows("1")) } +func (s *testSessionSuite) TestResetCtx(c *C) { + tk := testkit.NewTestKitWithInit(c, s.store) + tk1 := testkit.NewTestKitWithInit(c, s.store) + + tk.MustExec("create table t (i int auto_increment not null key);") + tk.MustExec("insert into t values (1);") + tk.MustExec("begin;") + tk.MustExec("insert into t values (10);") + tk.MustExec("update t set i = i + row_count();") + tk.MustQuery("select * from t;").Check(testkit.Rows("2", "11")) + + tk1.MustExec("update t set i = 0 where i = 1;") + tk1.MustQuery("select * from t;").Check(testkit.Rows("0")) + + tk.MustExec("commit;") + tk.MustQuery("select * from t;").Check(testkit.Rows("1", "11")) + + tk.MustExec("delete from t where i = 11;") + tk.MustExec("begin;") + tk.MustExec("insert into t values ();") + tk.MustExec("update t set i = i + last_insert_id() + 1;") + tk.MustQuery("select * from t;").Check(testkit.Rows("14", "25")) + + tk1.MustExec("update t set i = 0 where i = 1;") + tk1.MustQuery("select * from t;").Check(testkit.Rows("0")) + + tk.MustExec("commit;") + tk.MustQuery("select * from t;").Check(testkit.Rows("13", "25")) +} + func (s *testSessionSuite) TestUnique(c *C) { // test for https://github.com/pingcap/tidb/pull/461 diff --git a/sessionctx/stmtctx/stmtctx.go b/sessionctx/stmtctx/stmtctx.go index 66db63081c5b8..7f521dd1d28bd 100644 --- a/sessionctx/stmtctx/stmtctx.go +++ b/sessionctx/stmtctx/stmtctx.go @@ -71,6 +71,14 @@ type StatementContext struct { histogramsNotLoad bool execDetails execdetails.ExecDetails } + // PrevAffectedRows is the affected-rows value(DDL is 0, DML is the number of affected rows). + PrevAffectedRows int64 + // PrevLastInsertID is the last insert ID of previous statement. + PrevLastInsertID uint64 + // LastInsertID is the auto-generated ID in the current statement. + LastInsertID uint64 + // InsertID is the given insert ID of an auto_increment column. + InsertID uint64 // Copied from SessionVars.TimeZone. TimeZone *time.Location @@ -239,6 +247,8 @@ func (sc *StatementContext) ResetForRetry() { sc.mu.foundRows = 0 sc.mu.warnings = nil sc.mu.Unlock() + sc.TableIDs = sc.TableIDs[:0] + sc.IndexIDs = sc.IndexIDs[:0] } // MergeExecDetails merges a single region execution details into self, used to print diff --git a/sessionctx/variable/session.go b/sessionctx/variable/session.go index 3ecb985959ff5..497a6249fc085 100644 --- a/sessionctx/variable/session.go +++ b/sessionctx/variable/session.go @@ -201,14 +201,8 @@ type SessionVars struct { Value string } - // Following variables are special for current session. - - Status uint16 - PrevLastInsertID uint64 // PrevLastInsertID is the last insert ID of previous statement. - LastInsertID uint64 // LastInsertID is the auto-generated ID in the current statement. - InsertID uint64 // InsertID is the given insert ID of an auto_increment column. - // PrevAffectedRows is the affected-rows value(DDL is 0, DML is the number of affected rows). - PrevAffectedRows int64 + // Status stands for the session status. e.g. in transaction or not, auto commit is on or off, and so on. + Status uint16 // ClientCapability is client's capability. ClientCapability uint32 @@ -405,7 +399,7 @@ func (s *SessionVars) GetCharsetInfo() (charset, collation string) { // SetLastInsertID saves the last insert id to the session context. // TODO: we may store the result for last_insert_id sys var later. func (s *SessionVars) SetLastInsertID(insertID uint64) { - s.LastInsertID = insertID + s.StmtCtx.LastInsertID = insertID } // SetStatusFlag sets the session server status variable. @@ -449,18 +443,6 @@ func (s *SessionVars) Location() *time.Location { return loc } -// ResetPrevAffectedRows reset the prev-affected-rows variable. -func (s *SessionVars) ResetPrevAffectedRows() { - s.PrevAffectedRows = 0 - if s.StmtCtx != nil { - if s.StmtCtx.InUpdateOrDeleteStmt || s.StmtCtx.InInsertStmt { - s.PrevAffectedRows = int64(s.StmtCtx.AffectedRows()) - } else if s.StmtCtx.InSelectStmt { - s.PrevAffectedRows = -1 - } - } -} - // GetExecuteArgumentsInfo gets the argument list as a string of execute statement. func (s *SessionVars) GetExecuteArgumentsInfo() string { if len(s.PreparedParams) == 0 { diff --git a/sessionctx/variable/session_test.go b/sessionctx/variable/session_test.go index 340c1759780df..b7cd5f213479d 100644 --- a/sessionctx/variable/session_test.go +++ b/sessionctx/variable/session_test.go @@ -43,7 +43,7 @@ func (*testSessionSuite) TestSession(c *C) { // For last insert id ctx.GetSessionVars().SetLastInsertID(1) - c.Assert(ctx.GetSessionVars().LastInsertID, Equals, uint64(1)) + c.Assert(ctx.GetSessionVars().StmtCtx.LastInsertID, Equals, uint64(1)) ss.ResetForRetry() c.Assert(ss.AffectedRows(), Equals, uint64(0))