From e4461a6fea963f6967fde63498eb6be3543e35cd Mon Sep 17 00:00:00 2001 From: Jack Yu Date: Mon, 8 Apr 2019 17:39:01 +0800 Subject: [PATCH] *: fix the read-only check for the prepare/execute statement (#9723) (#10048) --- executor/adapter.go | 21 ++++++++++++--------- executor/executor.go | 3 +++ session/session.go | 3 --- session/session_test.go | 19 +++++++++++++++++++ session/tidb.go | 2 +- util/sqlexec/restricted_sql_executor.go | 3 ++- 6 files changed, 37 insertions(+), 14 deletions(-) diff --git a/executor/adapter.go b/executor/adapter.go index e301f409539e9..cb54c7f54a390 100644 --- a/executor/adapter.go +++ b/executor/adapter.go @@ -163,15 +163,18 @@ func (a *ExecStmt) IsPrepared() bool { } // IsReadOnly returns true if a statement is read only. -// It will update readOnlyCheckStmt if current ExecStmt can be conveted to -// a plannercore.Execute. Last step is using ast.IsReadOnly function to determine -// a statement is read only or not. -func (a *ExecStmt) IsReadOnly() bool { - readOnlyCheckStmt := a.StmtNode - if checkPlan, ok := a.Plan.(*plannercore.Execute); ok { - readOnlyCheckStmt = checkPlan.Stmt - } - return ast.IsReadOnly(readOnlyCheckStmt) +// If current StmtNode is an ExecuteStmt, we can get its prepared stmt, +// then using ast.IsReadOnly function to determine a statement is read only or not. +func (a *ExecStmt) IsReadOnly(vars *variable.SessionVars) bool { + if execStmt, ok := a.StmtNode.(*ast.ExecuteStmt); ok { + s, err := getPreparedStmt(execStmt, vars) + if err != nil { + logutil.Logger(context.Background()).Error("getPreparedStmt failed", zap.Error(err)) + return false + } + return ast.IsReadOnly(s) + } + return ast.IsReadOnly(a.StmtNode) } // RebuildPlan rebuilds current execute statement plan. diff --git a/executor/executor.go b/executor/executor.go index c1096513aa36e..96ea85726e68f 100644 --- a/executor/executor.go +++ b/executor/executor.go @@ -1260,6 +1260,9 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) { if execStmt, ok := s.(*ast.ExecuteStmt); ok { s, err = getPreparedStmt(execStmt, vars) + if err != nil { + return + } } // TODO: Many same bool variables here. // We should set only two variables ( diff --git a/session/session.go b/session/session.go index fd5970602c553..48f62ae770528 100644 --- a/session/session.go +++ b/session/session.go @@ -523,9 +523,6 @@ func (s *session) retry(ctx context.Context, maxCnt uint) (err error) { s.sessionVars.RetryInfo.ResetOffset() for i, sr := range nh.history { st := sr.st - if st.IsReadOnly() { - continue - } s.sessionVars.StmtCtx = sr.stmtCtx s.sessionVars.StmtCtx.ResetForRetry() s.sessionVars.PreparedParams = s.sessionVars.PreparedParams[:0] diff --git a/session/session_test.go b/session/session_test.go index c4ec53d7912ca..74c890b246541 100644 --- a/session/session_test.go +++ b/session/session_test.go @@ -923,6 +923,25 @@ func (s *testSessionSuite) TestAutoIncrementWithRetry(c *C) { c.Assert(lastInsertID+3, Equals, currLastInsertID) } +func (s *testSessionSuite) TestBinaryReadOnly(c *C) { + tk := testkit.NewTestKitWithInit(c, s.store) + tk.MustExec("create table t (i int key)") + id, _, _, err := tk.Se.PrepareStmt("select i from t where i = ?") + c.Assert(err, IsNil) + id2, _, _, err := tk.Se.PrepareStmt("insert into t values (?)") + c.Assert(err, IsNil) + tk.MustExec("set autocommit = 0") + _, err = tk.Se.ExecutePreparedStmt(context.Background(), id, 1) + c.Assert(err, IsNil) + c.Assert(session.GetHistory(tk.Se).Count(), Equals, 0) + tk.MustExec("insert into t values (1)") + c.Assert(session.GetHistory(tk.Se).Count(), Equals, 1) + _, err = tk.Se.ExecutePreparedStmt(context.Background(), id2, 2) + c.Assert(err, IsNil) + c.Assert(session.GetHistory(tk.Se).Count(), Equals, 2) + tk.MustExec("commit") +} + func (s *testSessionSuite) TestPrepare(c *C) { tk := testkit.NewTestKitWithInit(c, s.store) tk.MustExec("create table t(id TEXT)") diff --git a/session/tidb.go b/session/tidb.go index 8f14ce5004ac7..347eccaa314f8 100644 --- a/session/tidb.go +++ b/session/tidb.go @@ -193,7 +193,7 @@ func runStmt(ctx context.Context, sctx sessionctx.Context, s sqlexec.Statement) sessVars := se.GetSessionVars() // All the history should be added here. sessVars.TxnCtx.StatementCount++ - if !s.IsReadOnly() { + if !s.IsReadOnly(sessVars) { if err == nil { GetHistory(sctx).Add(0, s, se.sessionVars.StmtCtx) } diff --git a/util/sqlexec/restricted_sql_executor.go b/util/sqlexec/restricted_sql_executor.go index 672845de616bc..0d665f1567a73 100644 --- a/util/sqlexec/restricted_sql_executor.go +++ b/util/sqlexec/restricted_sql_executor.go @@ -16,6 +16,7 @@ package sqlexec import ( "github.com/pingcap/parser/ast" "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/util/chunk" "golang.org/x/net/context" ) @@ -69,7 +70,7 @@ type Statement interface { IsPrepared() bool // IsReadOnly returns if the statement is read only. For example: SelectStmt without lock. - IsReadOnly() bool + IsReadOnly(vars *variable.SessionVars) bool // RebuildPlan rebuilds the plan of the statement. RebuildPlan() (schemaVersion int64, err error)