From 6aa1d8b2720c32a7a7ba5108cabdcb9c0e2881b2 Mon Sep 17 00:00:00 2001 From: tiancaiamao Date: Thu, 6 Jun 2019 20:05:20 +0800 Subject: [PATCH] executor: execute some statement (`create user` `grant` etc) would commit current transaction automically --- executor/grant.go | 11 ++++++++++- executor/revoke.go | 9 ++++++++- executor/simple.go | 22 ++++++++++++++++++++-- executor/simple_test.go | 30 ++++++++++++++++++++++++++++++ 4 files changed, 68 insertions(+), 4 deletions(-) diff --git a/executor/grant.go b/executor/grant.go index d23315006d871..aaa2d285e0279 100644 --- a/executor/grant.go +++ b/executor/grant.go @@ -64,7 +64,7 @@ func (e *GrantExec) Next(ctx context.Context, chk *chunk.Chunk) error { dbName = e.ctx.GetSessionVars().CurrentDB } // Grant for each user - for _, user := range e.Users { + for idx, user := range e.Users { // Check if user exists. exists, err := userExists(e.ctx, user.User.Username, user.User.Hostname) if err != nil { @@ -103,6 +103,15 @@ func (e *GrantExec) Next(ctx context.Context, chk *chunk.Chunk) error { if e.WithGrant { privs = append(privs, &ast.PrivElem{Priv: mysql.GrantPriv}) } + + if idx == 0 { + // Commit the old transaction, like DDL. + if err := e.ctx.NewTxn(); err != nil { + return err + } + defer func() { e.ctx.GetSessionVars().SetStatusFlag(mysql.ServerStatusInTrans, false) }() + } + // Grant each priv to the user. for _, priv := range privs { if len(priv.Cols) > 0 { diff --git a/executor/revoke.go b/executor/revoke.go index 6856c816551b5..d75f99a30f1e8 100644 --- a/executor/revoke.go +++ b/executor/revoke.go @@ -58,7 +58,7 @@ func (e *RevokeExec) Next(ctx context.Context, chk *chunk.Chunk) error { e.done = true // Revoke for each user. - for _, user := range e.Users { + for idx, user := range e.Users { // Check if user exists. exists, err := userExists(e.ctx, user.User.Username, user.User.Hostname) if err != nil { @@ -68,6 +68,13 @@ func (e *RevokeExec) Next(ctx context.Context, chk *chunk.Chunk) error { return errors.Errorf("Unknown user: %s", user.User) } + if idx == 0 { + // Commit the old transaction, like DDL. + if err := e.ctx.NewTxn(); err != nil { + return err + } + defer func() { e.ctx.GetSessionVars().SetStatusFlag(mysql.ServerStatusInTrans, false) }() + } err = e.revokeOneUser(user.User.Username, user.User.Hostname) if err != nil { return errors.Trace(err) diff --git a/executor/simple.go b/executor/simple.go index 3fa6054515bd3..cfd19af0cce49 100644 --- a/executor/simple.go +++ b/executor/simple.go @@ -54,6 +54,15 @@ func (e *SimpleExec) Next(ctx context.Context, chk *chunk.Chunk) (err error) { if e.done { return nil } + + if e.autoNewTxn() { + // Commit the old transaction, like DDL. + if err := e.ctx.NewTxn(); err != nil { + return err + } + defer func() { e.ctx.GetSessionVars().SetStatusFlag(mysql.ServerStatusInTrans, false) }() + } + switch x := e.Statement.(type) { case *ast.UseStmt: err = e.executeUse(x) @@ -66,7 +75,7 @@ func (e *SimpleExec) Next(ctx context.Context, chk *chunk.Chunk) (err error) { case *ast.RollbackStmt: err = e.executeRollback(x) case *ast.CreateUserStmt: - err = e.executeCreateUser(x) + err = e.executeCreateUser(ctx, x) case *ast.AlterUserStmt: err = e.executeAlterUser(x) case *ast.DropUserStmt: @@ -141,7 +150,7 @@ func (e *SimpleExec) executeRollback(s *ast.RollbackStmt) error { return nil } -func (e *SimpleExec) executeCreateUser(s *ast.CreateUserStmt) error { +func (e *SimpleExec) executeCreateUser(ctx context.Context, s *ast.CreateUserStmt) error { users := make([]string, 0, len(s.Specs)) for _, spec := range s.Specs { exists, err1 := userExists(e.ctx, spec.User.Username, spec.User.Hostname) @@ -164,6 +173,7 @@ func (e *SimpleExec) executeCreateUser(s *ast.CreateUserStmt) error { if len(users) == 0 { return nil } + sql := fmt.Sprintf(`INSERT INTO %s.%s (Host, User, Password) VALUES %s;`, mysql.SystemDB, mysql.UserTable, strings.Join(users, ", ")) _, err := e.ctx.(sqlexec.SQLExecutor).Execute(context.Background(), sql) if err != nil { @@ -370,3 +380,11 @@ func (e *SimpleExec) executeDropStats(s *ast.DropStatsStmt) error { } return errors.Trace(h.Update(GetInfoSchema(e.ctx))) } + +func (e *SimpleExec) autoNewTxn() bool { + switch e.Statement.(type) { + case *ast.CreateUserStmt, *ast.AlterUserStmt, *ast.DropUserStmt: + return true + } + return false +} diff --git a/executor/simple_test.go b/executor/simple_test.go index e97d21b142902..cd1810863e8fa 100644 --- a/executor/simple_test.go +++ b/executor/simple_test.go @@ -296,3 +296,33 @@ func (s *testSuite) TestDropStats(c *C) { c.Assert(statsTbl.Pseudo, IsTrue) h.Lease = 0 } + +func (s *testSuite) TestStmtAutoNewTxn(c *C) { + // Some statements are like DDL, they commit the previous txn automically. + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + + // Fix issue https://github.com/pingcap/tidb/issues/10705 + tk.MustExec("begin") + tk.MustExec("create user 'xxx'@'%';") + tk.MustExec("grant all privileges on *.* to 'xxx'@'%';") + + tk.MustExec("create table auto_new (id int)") + tk.MustExec("begin") + tk.MustExec("insert into auto_new values (1)") + tk.MustExec("revoke all privileges on *.* from 'xxx'@'%'") + tk.MustExec("rollback") // insert statement has already committed + tk.MustQuery("select * from auto_new").Check(testkit.Rows("1")) + + // Test the behavior when autocommit is false. + tk.MustExec("set autocommit = 0") + tk.MustExec("insert into auto_new values (2)") + tk.MustExec("create user 'yyy'@'%'") + tk.MustExec("rollback") + tk.MustQuery("select * from auto_new").Check(testkit.Rows("1", "2")) + + tk.MustExec("drop user 'yyy'@'%'") + tk.MustExec("insert into auto_new values (3)") + tk.MustExec("rollback") + tk.MustQuery("select * from auto_new").Check(testkit.Rows("1", "2")) +}