Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sessionctx: support encoding and decoding prepared statements #35808

Merged
merged 7 commits into from
Jun 30, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion plugin/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func TestAuditLogNormal(t *testing.T) {
defer clean()
sv := server.CreateMockServer(t, store)
defer sv.Close()
conn := server.CreateMockConn(t, store, sv)
conn := server.CreateMockConn(t, sv)
defer conn.Close()
session.DisableStats4Test()
session.SetSchemaLease(0)
Expand Down
92 changes: 82 additions & 10 deletions server/driver_tidb.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ package server
import (
"context"
"crypto/tls"
"fmt"
"strings"
"sync/atomic"

"github.com/pingcap/errors"
Expand All @@ -27,6 +29,8 @@ import (
"github.com/pingcap/tidb/parser/terror"
"github.com/pingcap/tidb/planner/core"
"github.com/pingcap/tidb/session"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/sessionctx/sessionstates"
"github.com/pingcap/tidb/sessionctx/stmtctx"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/chunk"
Expand All @@ -50,8 +54,7 @@ func NewTiDBDriver(store kv.Storage) *TiDBDriver {
// TiDBContext implements QueryCtx.
type TiDBContext struct {
session.Session
currentDB string
stmts map[int]*TiDBStatement
stmts map[int]*TiDBStatement
}

// TiDBStatement implements PreparedStatement.
Expand Down Expand Up @@ -199,10 +202,10 @@ func (qd *TiDBDriver) OpenCtx(connID uint64, capability uint32, collation uint8,
se.SetClientCapability(capability)
se.SetConnectionID(connID)
tc := &TiDBContext{
Session: se,
currentDB: dbname,
stmts: make(map[int]*TiDBStatement),
Session: se,
stmts: make(map[int]*TiDBStatement),
}
se.SetSessionStatesHandler(sessionstates.StatePrepareStmt, tc)
return tc, nil
}

Expand All @@ -211,11 +214,6 @@ func (tc *TiDBContext) GetWarnings() []stmtctx.SQLWarn {
return tc.GetSessionVars().StmtCtx.GetWarnings()
}

// CurrentDB implements QueryCtx CurrentDB method.
func (tc *TiDBContext) CurrentDB() string {
return tc.currentDB
}

// WarningCount implements QueryCtx WarningCount method.
func (tc *TiDBContext) WarningCount() uint16 {
return tc.GetSessionVars().StmtCtx.WarningCount()
Expand Down Expand Up @@ -308,6 +306,80 @@ func (tc *TiDBContext) GetStmtStats() *stmtstats.StatementStats {
return tc.Session.GetStmtStats()
}

// EncodeSessionStates implements SessionStatesHandler.EncodeSessionStates interface.
func (tc *TiDBContext) EncodeSessionStates(ctx context.Context, sctx sessionctx.Context, sessionStates *sessionstates.SessionStates) error {
sessionVars := tc.Session.GetSessionVars()
sessionStates.PreparedStmts = make(map[uint32]*sessionstates.PreparedStmtInfo, len(sessionVars.PreparedStmts))
for preparedID, preparedObj := range sessionVars.PreparedStmts {
preparedStmt, ok := preparedObj.(*core.CachedPrepareStmt)
if !ok {
return errors.Errorf("invalid CachedPreparedStmt type")
bb7133 marked this conversation as resolved.
Show resolved Hide resolved
}
sessionStates.PreparedStmts[preparedID] = &sessionstates.PreparedStmtInfo{
StmtText: preparedStmt.StmtText,
StmtDB: preparedStmt.StmtDB,
}
}
for name, id := range sessionVars.PreparedStmtNameToID {
// Only text protocol statements have names.
if preparedStmtInfo, ok := sessionStates.PreparedStmts[id]; ok {
preparedStmtInfo.Name = name
}
}
for id, stmt := range tc.stmts {
// Only binary protocol statements have paramTypes.
preparedStmtInfo, ok := sessionStates.PreparedStmts[uint32(id)]
if !ok {
return errors.Errorf("prepared statement %d not found", id)
}
preparedStmtInfo.ParamTypes = stmt.GetParamsType()
}
return nil
}

// DecodeSessionStates implements SessionStatesHandler.DecodeSessionStates interface.
func (tc *TiDBContext) DecodeSessionStates(ctx context.Context, sctx sessionctx.Context, sessionStates *sessionstates.SessionStates) error {
if len(sessionStates.PreparedStmts) == 0 {
return nil
}
sessionVars := tc.Session.GetSessionVars()
savedPreparedStmtID := sessionVars.GetNextPreparedStmtID()
savedCurrentDB := sessionVars.CurrentDB
defer func() {
sessionVars.SetNextPreparedStmtID(savedPreparedStmtID - 1)
sessionVars.CurrentDB = savedCurrentDB
}()

for id, preparedStmtInfo := range sessionStates.PreparedStmts {
// Set the next id and currentDB manually.
sessionVars.SetNextPreparedStmtID(id - 1)
sessionVars.CurrentDB = preparedStmtInfo.StmtDB
if preparedStmtInfo.Name == "" {
// Binary protocol: add to sessionVars.PreparedStmts and TiDBContext.stmts.
stmt, _, _, err := tc.Prepare(preparedStmtInfo.StmtText)
if err != nil {
return err
}
// Only binary protocol uses paramsType, which is passed from the first COM_STMT_EXECUTE.
stmt.SetParamsType(preparedStmtInfo.ParamTypes)
} else {
// Text protocol: add to sessionVars.PreparedStmts and sessionVars.PreparedStmtNameToID.
stmtText := strings.ReplaceAll(preparedStmtInfo.StmtText, "\\", "\\\\")
stmtText = strings.ReplaceAll(stmtText, "'", "\\'")
// Add single quotes because the sql_mode might contain ANSI_QUOTES.
sql := fmt.Sprintf("PREPARE `%s` FROM '%s'", preparedStmtInfo.Name, stmtText)
stmts, err := tc.Parse(ctx, sql)
if err != nil {
return err
}
if _, err = tc.ExecuteStmt(ctx, stmts[0]); err != nil {
return err
}
}
}
return nil
}

type tidbResultSet struct {
recordSet sqlexec.RecordSet
columns []*ColumnInfo
Expand Down
13 changes: 4 additions & 9 deletions server/mock_conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@ import (

"github.com/pingcap/tidb/config"
"github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/session"
tmysql "github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/util/arena"
"github.com/pingcap/tidb/util/chunk"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -85,18 +84,14 @@ func CreateMockServer(t *testing.T, store kv.Storage) *Server {
}

// CreateMockConn creates a mock connection together with a session.
func CreateMockConn(t *testing.T, store kv.Storage, server *Server) MockConn {
se, err := session.CreateSession4Test(store)
func CreateMockConn(t *testing.T, server *Server) MockConn {
tc, err := server.driver.OpenCtx(uint64(0), 0, uint8(tmysql.DefaultCollationID), "", nil)
require.NoError(t, err)
tc := &TiDBContext{
Session: se,
stmts: make(map[int]*TiDBStatement),
}

cc := &clientConn{
server: server,
salt: []byte{},
collation: mysql.DefaultCollationID,
collation: tmysql.DefaultCollationID,
alloc: arena.NewAllocator(1024),
chunkAlloc: chunk.NewAllocator(),
pkt: &packetIO{
Expand Down
2 changes: 1 addition & 1 deletion server/mock_conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func TestMockConn(t *testing.T) {
defer clean()
server := CreateMockServer(t, store)
defer server.Close()
conn := CreateMockConn(t, store, server)
conn := CreateMockConn(t, server)
defer conn.Close()

require.NoError(t, conn.HandleQuery(context.Background(), "select 1"))
Expand Down
68 changes: 46 additions & 22 deletions session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,8 @@ type Session interface {
// ExecutePreparedStmt executes a prepared statement.
ExecutePreparedStmt(ctx context.Context, stmtID uint32, param []types.Datum) (sqlexec.RecordSet, error)
DropPreparedStmt(stmtID uint32) error
// SetSessionStatesHandler sets SessionStatesHandler for type stateType.
SetSessionStatesHandler(stateType sessionstates.SessionStateType, handler sessionctx.SessionStatesHandler)
SetClientCapability(uint32) // Set client capability flags.
SetConnectionID(uint64)
SetCommandValue(byte)
Expand Down Expand Up @@ -248,6 +250,9 @@ type session struct {
// regularly.
stmtStats *stmtstats.StatementStats

// Used to encode and decode each type of session states.
sessionStatesHandlers map[sessionstates.SessionStateType]sessionctx.SessionStatesHandler

// Contains a list of sessions used to collect advisory locks.
advisoryLocks map[string]*advisoryLock
}
Expand Down Expand Up @@ -2760,6 +2765,11 @@ func (s *session) RefreshVars(ctx context.Context) error {
return nil
}

// SetSessionStatesHandler implements the Session.SetSessionStatesHandler interface.
func (s *session) SetSessionStatesHandler(stateType sessionstates.SessionStateType, handler sessionctx.SessionStatesHandler) {
s.sessionStatesHandlers[stateType] = handler
}

// CreateSession4Test creates a new session environment for test.
func CreateSession4Test(store kv.Storage) (Session, error) {
se, err := CreateSession4TestWithOpt(store, nil)
Expand Down Expand Up @@ -3006,12 +3016,13 @@ func createSessionWithOpt(store kv.Storage, opt *Opt) (*session, error) {
return nil, err
}
s := &session{
store: store,
sessionVars: variable.NewSessionVars(),
ddlOwnerManager: dom.DDL().OwnerManager(),
client: store.GetClient(),
mppClient: store.GetMPPClient(),
stmtStats: stmtstats.CreateStatementStats(),
store: store,
sessionVars: variable.NewSessionVars(),
ddlOwnerManager: dom.DDL().OwnerManager(),
client: store.GetClient(),
mppClient: store.GetMPPClient(),
stmtStats: stmtstats.CreateStatementStats(),
sessionStatesHandlers: make(map[sessionstates.SessionStateType]sessionctx.SessionStatesHandler),
}
s.functionUsageMu.builtinFunctionUsage = make(telemetry.BuiltinFunctionsUsage)
if plannercore.PreparedPlanCacheEnabled() {
Expand Down Expand Up @@ -3043,11 +3054,12 @@ func createSessionWithOpt(store kv.Storage, opt *Opt) (*session, error) {
// a lock context, which cause we can't call createSession directly.
func CreateSessionWithDomain(store kv.Storage, dom *domain.Domain) (*session, error) {
s := &session{
store: store,
sessionVars: variable.NewSessionVars(),
client: store.GetClient(),
mppClient: store.GetMPPClient(),
stmtStats: stmtstats.CreateStatementStats(),
store: store,
sessionVars: variable.NewSessionVars(),
client: store.GetClient(),
mppClient: store.GetMPPClient(),
stmtStats: stmtstats.CreateStatementStats(),
sessionStatesHandlers: make(map[sessionstates.SessionStateType]sessionctx.SessionStatesHandler),
}
s.functionUsageMu.builtinFunctionUsage = make(telemetry.BuiltinFunctionsUsage)
if plannercore.PreparedPlanCacheEnabled() {
Expand Down Expand Up @@ -3505,8 +3517,8 @@ func (s *session) GetStmtStats() *stmtstats.StatementStats {
}

// EncodeSessionStates implements SessionStatesHandler.EncodeSessionStates interface.
func (s *session) EncodeSessionStates(ctx context.Context, sctx sessionctx.Context, sessionStates *sessionstates.SessionStates) (err error) {
if err = s.sessionVars.EncodeSessionStates(ctx, sessionStates); err != nil {
func (s *session) EncodeSessionStates(ctx context.Context, sctx sessionctx.Context, sessionStates *sessionstates.SessionStates) error {
if err := s.sessionVars.EncodeSessionStates(ctx, sessionStates); err != nil {
return err
}

Expand All @@ -3527,25 +3539,37 @@ func (s *session) EncodeSessionStates(ctx context.Context, sctx sessionctx.Conte
continue
}
// Get all session variables because the default values may change between versions.
if val, keep, err := variable.GetSessionStatesSystemVar(s.sessionVars, sv.Name); err == nil && keep {
if val, keep, err := variable.GetSessionStatesSystemVar(s.sessionVars, sv.Name); err != nil {
return err
} else if keep {
sessionStates.SystemVars[sv.Name] = val
}
}
return

if handler, ok := s.sessionStatesHandlers[sessionstates.StatePrepareStmt]; ok {
if err := handler.EncodeSessionStates(ctx, s, sessionStates); err != nil {
return err
}
}
return nil
}

// DecodeSessionStates implements SessionStatesHandler.DecodeSessionStates interface.
func (s *session) DecodeSessionStates(ctx context.Context, sctx sessionctx.Context, sessionStates *sessionstates.SessionStates) (err error) {
func (s *session) DecodeSessionStates(ctx context.Context, sctx sessionctx.Context, sessionStates *sessionstates.SessionStates) error {
if handler, ok := s.sessionStatesHandlers[sessionstates.StatePrepareStmt]; ok {
if err := handler.DecodeSessionStates(ctx, s, sessionStates); err != nil {
return err
}
}

// Decode session variables.
for name, val := range sessionStates.SystemVars {
if err = variable.SetSessionSystemVar(s.sessionVars, name, val); err != nil {
if err := variable.SetSessionSystemVar(s.sessionVars, name, val); err != nil {
return err
}
}

// Decode stmt ctx after session vars because setting session vars may override stmt ctx, such as warnings.
if err = s.sessionVars.DecodeSessionStates(ctx, sessionStates); err != nil {
return err
}
return err
// Decoding session vars / prepared statements may override stmt ctx, such as warnings,
// so we decode stmt ctx at last.
return s.sessionVars.DecodeSessionStates(ctx, sessionStates)
}
18 changes: 18 additions & 0 deletions sessionctx/sessionstates/session_states.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,23 @@ import (
"github.com/pingcap/tidb/types"
)

// SessionStateType is the type of session states.
type SessionStateType int

// These enums represents the types of session state handlers.
const (
// StatePrepareStmt represents prepared statements.
StatePrepareStmt SessionStateType = iota
)

// PreparedStmtInfo contains the information about prepared statements, both text and binary protocols.
type PreparedStmtInfo struct {
Name string `json:"name,omitempty"`
StmtText string `json:"text"`
StmtDB string `json:"db,omitempty"`
ParamTypes []byte `json:"types,omitempty"`
}

// QueryInfo represents the information of last executed query. It's used to expose information for test purpose.
type QueryInfo struct {
TxnScope string `json:"txn_scope"`
Expand All @@ -42,6 +59,7 @@ type SessionStates struct {
UserVars map[string]*types.Datum `json:"user-var-values,omitempty"`
UserVarTypes map[string]*ptypes.FieldType `json:"user-var-types,omitempty"`
SystemVars map[string]string `json:"sys-vars,omitempty"`
PreparedStmts map[uint32]*PreparedStmtInfo `json:"prepared-stmts,omitempty"`
PreparedStmtID uint32 `json:"prepared-stmt-id,omitempty"`
Status uint16 `json:"status,omitempty"`
CurrentDB string `json:"current-db,omitempty"`
Expand Down
Loading