Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

planner: introduce new functions for the general statement in session interface #37024

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions server/conn_stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -199,9 +199,13 @@ func (cc *clientConn) handleStmtExecute(ctx context.Context, data []byte) (err e
return errors.Annotate(err, cc.preparedStmt2String(stmtID))
}
}
return cc.executePlanCacheStmt(ctx, stmt, args, useCursor)
}

func (cc *clientConn) executePlanCacheStmt(ctx context.Context, stmt interface{}, args []expression.Expression, useCursor bool) (err error) {
ctx = context.WithValue(ctx, execdetails.StmtExecDetailKey, &execdetails.StmtExecDetails{})
ctx = context.WithValue(ctx, util.ExecDetailsKey, &util.ExecDetails{})
retryable, err := cc.executePreparedStmtAndWriteResult(ctx, stmt, args, useCursor)
retryable, err := cc.executePreparedStmtAndWriteResult(ctx, stmt.(PreparedStatement), args, useCursor)
if err != nil {
action, txnErr := sessiontxn.GetTxnManager(&cc.ctx).OnStmtErrorForNextAction(sessiontxn.StmtErrAfterQuery, err)
if txnErr != nil {
Expand All @@ -210,7 +214,7 @@ func (cc *clientConn) handleStmtExecute(ctx context.Context, data []byte) (err e

if retryable && action == sessiontxn.StmtActionRetryReady {
cc.ctx.GetSessionVars().RetryInfo.Retrying = true
_, err = cc.executePreparedStmtAndWriteResult(ctx, stmt, args, useCursor)
_, err = cc.executePreparedStmtAndWriteResult(ctx, stmt.(PreparedStatement), args, useCursor)
cc.ctx.GetSessionVars().RetryInfo.Retrying = false
return err
}
Expand All @@ -224,7 +228,7 @@ func (cc *clientConn) handleStmtExecute(ctx context.Context, data []byte) (err e
defer func() {
cc.ctx.GetSessionVars().IsolationReadEngines[kv.TiFlash] = struct{}{}
}()
_, err = cc.executePreparedStmtAndWriteResult(ctx, stmt, args, useCursor)
_, err = cc.executePreparedStmtAndWriteResult(ctx, stmt.(PreparedStatement), args, useCursor)
// We append warning after the retry because `ResetContextOfStmt` may be called during the retry, which clears warnings.
cc.ctx.GetSessionVars().StmtCtx.AppendError(prevErr)
}
Expand Down
50 changes: 35 additions & 15 deletions session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,8 @@ type Session interface {
CacheGeneralStmt(sql string) error
// ExecutePreparedStmt executes a prepared statement.
ExecutePreparedStmt(ctx context.Context, stmtID uint32, param []expression.Expression) (sqlexec.RecordSet, error)
// ExecuteGeneralStmt executes a general statement.
ExecuteGeneralStmt(ctx context.Context, sql string, param []expression.Expression) (sqlexec.RecordSet, error)
DropPreparedStmt(stmtID uint32) error
// SetSessionStatesHandler sets SessionStatesHandler for type stateType.
SetSessionStatesHandler(stateType sessionstates.SessionStateType, handler sessionctx.SessionStatesHandler)
Expand Down Expand Up @@ -2328,31 +2330,49 @@ func (s *session) preparedStmtExec(ctx context.Context, execStmt *ast.ExecuteStm
}

// ExecutePreparedStmt executes a prepared statement.
func (s *session) ExecutePreparedStmt(ctx context.Context, stmtID uint32, args []expression.Expression) (sqlexec.RecordSet, error) {
var err error
if err = s.PrepareTxnCtx(ctx); err != nil {
return nil, err
}

s.sessionVars.StartTime = time.Now()
func (s *session) ExecutePreparedStmt(ctx context.Context, stmtID uint32, params []expression.Expression) (sqlexec.RecordSet, error) {
prepStmt, err := s.sessionVars.GetPreparedStmtByID(stmtID)
if err != nil {
err = plannercore.ErrStmtNotFound
logutil.Logger(ctx).Error("prepared statement not found", zap.Uint32("stmtID", stmtID))
return nil, err
}
preparedStmt, ok := prepStmt.(*plannercore.PlanCacheStmt)
stmt, ok := prepStmt.(*plannercore.PlanCacheStmt)
if !ok {
return nil, errors.Errorf("invalid PlanCacheStmt type")
}
return s.executePlanCacheStmt(ctx, stmt, params)
}

// ExecuteGeneralStmt executes a general statement.
func (s *session) ExecuteGeneralStmt(ctx context.Context, sql string, params []expression.Expression) (sqlexec.RecordSet, error) {
generalStmt := s.sessionVars.GetGeneralPlanCacheStmt(sql)
if generalStmt == nil {
err := plannercore.ErrStmtNotFound
logutil.Logger(ctx).Error("general statement not found", zap.String("sql", sql))
return nil, err
}
stmt, ok := generalStmt.(*plannercore.PlanCacheStmt)
if !ok {
return nil, errors.Errorf("invalid PlanCacheStmt type")
}
return s.executePlanCacheStmt(ctx, stmt, params)
}

func (s *session) executePlanCacheStmt(ctx context.Context, stmt *plannercore.PlanCacheStmt, params []expression.Expression) (sqlexec.RecordSet, error) {
var err error
if err = s.PrepareTxnCtx(ctx); err != nil {
return nil, err
}

execStmt := &ast.ExecuteStmt{PrepStmt: prepStmt, BinaryArgs: args}
s.sessionVars.StartTime = time.Now()
execStmt := &ast.ExecuteStmt{PrepStmt: stmt, BinaryArgs: params}
if err := executor.ResetContextOfStmt(s, execStmt); err != nil {
return nil, err
}

staleReadProcessor := staleread.NewStaleReadProcessor(s)
if err = staleReadProcessor.OnExecutePreparedStmt(preparedStmt.SnapshotTSEvaluator); err != nil {
if err = staleReadProcessor.OnExecutePreparedStmt(stmt.SnapshotTSEvaluator); err != nil {
return nil, err
}

Expand All @@ -2372,17 +2392,17 @@ func (s *session) ExecutePreparedStmt(ctx context.Context, stmtID uint32, args [
}
}

executor.CountStmtNode(preparedStmt.PreparedAst.Stmt, s.sessionVars.InRestrictedSQL)
s.txn.onStmtStart(preparedStmt.SQLDigest.String())
executor.CountStmtNode(stmt.PreparedAst.Stmt, s.sessionVars.InRestrictedSQL)
s.txn.onStmtStart(stmt.SQLDigest.String())
defer s.txn.onStmtEnd()

if err = s.onTxnManagerStmtStartOrRetry(ctx, execStmt); err != nil {
return nil, err
}
s.setRequestSource(ctx, preparedStmt.PreparedAst.StmtType, preparedStmt.PreparedAst.Stmt)
s.setRequestSource(ctx, stmt.PreparedAst.StmtType, stmt.PreparedAst.Stmt)
// even the txn is valid, still need to set session variable for coprocessor usage.
s.sessionVars.RequestSourceType = preparedStmt.PreparedAst.StmtType
return s.preparedStmtExec(ctx, execStmt, preparedStmt)
s.sessionVars.RequestSourceType = stmt.PreparedAst.StmtType
return s.preparedStmtExec(ctx, execStmt, stmt)
}

func (s *session) DropPreparedStmt(stmtID uint32) error {
Expand Down