Retry to get session when session is invalid in SessionPool (#252) #261

Merged 2 commits on Mar 3, 2023
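This change makes the session pool recover from sessions that the server has invalidated (#252): SessionPoolConf gains an internal retryGetSessionTimes field (defaulting to 1), and query execution in SessionPool is now routed through a new executeWithRetry helper that, when a response comes back with E_SESSION_INVALID, creates and pings a replacement session and re-runs the statement. From the caller's side nothing changes; a minimal usage sketch (the address, port, and credentials below are placeholder values, not part of this PR):

```go
package main

import (
	"fmt"

	nebula "github.com/vesoft-inc/nebula-go/v3"
)

func main() {
	// Placeholder endpoint and credentials.
	hostAddress := nebula.HostAddress{Host: "127.0.0.1", Port: 9669}
	conf, err := nebula.NewSessionPoolConf(
		"root", "nebula", []nebula.HostAddress{hostAddress}, "client_test")
	if err != nil {
		panic(err)
	}

	pool, err := nebula.NewSessionPool(*conf, nebula.DefaultLogger{})
	if err != nil {
		panic(err)
	}
	defer pool.Close()

	// Even if the server invalidated the session backing this call, the pool
	// now retries with a new session (retryGetSessionTimes defaults to 1).
	resultSet, err := pool.Execute("SHOW HOSTS;")
	if err != nil {
		panic(err)
	}
	fmt.Println("succeeded:", resultSet.IsSucceed())
}
```

If the server kills the pool's session between calls (as the new test below does with `KILL SESSIONS`), the `Execute` call above should now succeed on the retry instead of returning a session-invalid error.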
20 changes: 6 additions & 14 deletions client_test.go
@@ -1202,15 +1202,11 @@ func TestReconnect(t *testing.T) {
 	defer pool.Close()
 
 	// Create session
-	var sessionList []*Session
-
-	for i := 0; i < 3; i++ {
-		session, err := pool.GetSession(username, password)
-		if err != nil {
-			t.Errorf("fail to create a new session from connection pool, %s", err.Error())
-		}
-		sessionList = append(sessionList, session)
+	session, err := pool.GetSession(username, password)
+	if err != nil {
+		t.Errorf("fail to create a new session from connection pool, %s", err.Error())
 	}
+	defer session.Release()
 
 	// Send query to server periodically
 	for i := 0; i < timeoutConfig.MaxConnPoolSize; i++ {
@@ -1221,7 +1217,7 @@ func TestReconnect(t *testing.T) {
 		if i == 7 {
 			stopContainer(t, "nebula-docker-compose_graphd1_1")
 		}
-		_, err := sessionList[0].Execute("SHOW HOSTS;")
+		_, err := session.Execute("SHOW HOSTS;")
 		fmt.Println("Sending query...")
 
 		if err != nil {
@@ -1230,7 +1226,7 @@ func TestReconnect(t *testing.T) {
 		}
 	}
 
-	resp, err := sessionList[0].Execute("SHOW HOSTS;")
+	resp, err := session.Execute("SHOW HOSTS;")
 	if err != nil {
 		t.Fatalf(err.Error())
 		return
@@ -1240,10 +1236,6 @@ func TestReconnect(t *testing.T) {
 	startContainer(t, "nebula-docker-compose_graphd_1")
 	startContainer(t, "nebula-docker-compose_graphd1_1")
 
-	for i := 0; i < len(sessionList); i++ {
-		sessionList[i].Release()
-	}
-
 	// Wait for graphd to be up
 	time.Sleep(5 * time.Second)
 }
32 changes: 17 additions & 15 deletions configs.go
@@ -109,12 +109,13 @@ func openAndReadFile(path string) ([]byte, error) {
 // SessionPoolConf is the configs of a session pool
 // Note that the space name is bound to the session pool for its lifetime
 type SessionPoolConf struct {
-	username     string        // username for authentication
-	password     string        // password for authentication
-	serviceAddrs []HostAddress // service addresses for session pool
-	hostIndex    int           // index of the host in ServiceAddrs that the next new session will connect to
-	spaceName    string        // The space name that all sessions in the pool are bound to
-	sslConfig    *tls.Config   // Optional SSL config for the connection
+	username             string        // username for authentication
+	password             string        // password for authentication
+	serviceAddrs         []HostAddress // service addresses for session pool
+	hostIndex            int           // index of the host in ServiceAddrs that the next new session will connect to
+	spaceName            string        // The space name that all sessions in the pool are bound to
+	sslConfig            *tls.Config   // Optional SSL config for the connection
+	retryGetSessionTimes int           // The max times to retry get new session when executing a query
 
 	// Basic pool configs
 	// Socket timeout and Socket connection timeout, unit: seconds
@@ -138,15 +139,16 @@ func NewSessionPoolConf(
 	spaceName string, opts ...SessionPoolConfOption) (*SessionPoolConf, error) {
 	// Set default values for basic pool configs
 	newPoolConf := SessionPoolConf{
-		username:     username,
-		password:     password,
-		serviceAddrs: serviceAddrs,
-		spaceName:    spaceName,
-		timeOut:      0 * time.Millisecond,
-		idleTime:     0 * time.Millisecond,
-		maxSize:      30,
-		minSize:      1,
-		hostIndex:    0,
+		username:             username,
+		password:             password,
+		serviceAddrs:         serviceAddrs,
+		spaceName:            spaceName,
+		retryGetSessionTimes: 1,
+		timeOut:              0 * time.Millisecond,
+		idleTime:             0 * time.Millisecond,
+		maxSize:              30,
+		minSize:              1,
+		hostIndex:            0,
 	}
 
 	// Iterate the given options and apply them to the config.
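Note that retryGetSessionTimes is unexported and this change adds no SessionPoolConfOption for it, so callers outside the package get the default of 1 set above; package-internal code (such as the test added in session_pool_test.go) sets the field directly. A small hypothetical sketch, assuming it lives inside package nebula_go next to the existing tests (address and port are the test suite's package-level variables):

```go
// Hypothetical package-internal helper mirroring TestRetryGetSession below:
// it builds a pool conf that allows two replacement attempts instead of one.
func newRetryingPoolConf() (*SessionPoolConf, error) {
	conf, err := NewSessionPoolConf(
		"root", "nebula",
		[]HostAddress{{Host: address, Port: port}},
		"client_test")
	if err != nil {
		return nil, err
	}
	// NewSessionPoolConf defaults retryGetSessionTimes to 1.
	conf.retryGetSessionTimes = 2
	return conf, nil
}
```

With the field set to 2, the executeWithRetry helper added below in session_pool.go attempts up to two fresh sessions before giving up.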
4 changes: 2 additions & 2 deletions session.go
@@ -198,15 +198,15 @@ func (session *Session) ExecuteJsonWithParameter(stmt string, params map[string]
 }
 
 func (session *Session) reConnect() error {
-	newconnection, err := session.connPool.getIdleConn()
+	newConnection, err := session.connPool.getIdleConn()
 	if err != nil {
 		err = fmt.Errorf(err.Error())
 		return err
 	}
 
 	// Release connection to pool
 	session.connPool.release(session.connection)
-	session.connection = newconnection
+	session.connection = newConnection
 	return nil
 }
 
56 changes: 55 additions & 1 deletion session_pool.go
@@ -12,10 +12,12 @@ import (
 	"container/list"
 	"crypto/tls"
 	"fmt"
+	"strconv"
 	"sync"
 	"time"
 
 	"github.com/vesoft-inc/nebula-go/v3/nebula"
+	"github.com/vesoft-inc/nebula-go/v3/nebula/graph"
 )
 
 // SessionPool is a pool that manages sessions internally.
@@ -119,7 +121,15 @@ func (pool *SessionPool) ExecuteWithParameter(stmt string, params map[string]int
 	}
 
 	// Execute the query
-	resp, err := session.connection.executeWithParameter(session.sessionID, stmt, paramsMap)
+	execFunc := func(s *Session) (*graph.ExecutionResponse, error) {
+		resp, err := s.connection.executeWithParameter(s.sessionID, stmt, paramsMap)
+		if err != nil {
+			return nil, err
+		}
+		return resp, nil
+	}
+
+	resp, err := pool.executeWithRetry(session, execFunc, pool.conf.retryGetSessionTimes)
 	if err != nil {
 		return nil, err
 	}
@@ -385,6 +395,50 @@ func (pool *SessionPool) getIdleSession() (*Session, error) {
 		" session pool and the total session count has reached the limit")
 }
 
+// retryGetSession tries to create a new session when the current session is invalid.
+func (pool *SessionPool) executeWithRetry(
+	session *Session,
+	f func(*Session) (*graph.ExecutionResponse, error),
+	retry int) (*graph.ExecutionResponse, error) {
+	pool.rwLock.Lock()
+	defer pool.rwLock.Unlock()
+
+	resp, err := f(session)
+	if err != nil {
+		pool.removeSessionFromList(&pool.activeSessions, session)
+		return nil, err
+	}
+
+	if resp.ErrorCode == nebula.ErrorCode_SUCCEEDED {
+		return resp, nil
+	} else if ErrorCode(resp.ErrorCode) != ErrorCode_E_SESSION_INVALID { // only retry when the session is invalid
+		return resp, err
+	}
+
+	// remove invalid session regardless of the retry is successful or not
+	defer pool.removeSessionFromList(&pool.activeSessions, session)
+	// If the session is invalid, close it and get a new session
+	for i := 0; i < retry; i++ {
+		pool.log.Info("retry to get sessions")
+		newSession, err := pool.newSession()
+		if err != nil {
+			return nil, err
+		}
+
+		pingErr := newSession.Ping()
+		if pingErr != nil {
+			pool.log.Error("failed to ping the session, error: " + pingErr.Error())
+			continue
+		}
+		pool.log.Info("retry to get sessions successfully")
+		pool.addSessionToList(&pool.activeSessions, newSession)
+
+		return f(newSession)
+	}
+	pool.log.Error(fmt.Sprintf("failed to get session after " + strconv.Itoa(retry) + " retries"))
+	return nil, fmt.Errorf("failed to get session after %d retries", retry)
+}
+
 // startCleaner starts sessionCleaner if idleTime > 0.
 func (pool *SessionPool) startCleaner() {
 	if pool.conf.idleTime > 0 && pool.cleanerChan == nil {
44 changes: 44 additions & 0 deletions session_pool_test.go
@@ -319,6 +319,50 @@ func TestIdleSessionCleaner(t *testing.T) {
 		sessionPool.conf.minSize, sessionPool.GetTotalSessionCount())
 }
 
+func TestRetryGetSession(t *testing.T) {
+	err := prepareSpace("client_test")
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer dropSpace("client_test")
+
+	hostAddress := HostAddress{Host: address, Port: port}
+	config, err := NewSessionPoolConf(
+		"root",
+		"nebula",
+		[]HostAddress{hostAddress},
+		"client_test")
+	if err != nil {
+		t.Errorf("failed to create session pool config, %s", err.Error())
+	}
+	config.minSize = 2
+	config.maxSize = 2
+	config.retryGetSessionTimes = 1
+
+	// create session pool
+	sessionPool, err := NewSessionPool(*config, DefaultLogger{})
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer sessionPool.Close()
+
+	// kill all sessions in the cluster
+	resultSet, err := sessionPool.Execute("SHOW SESSIONS | KILL SESSIONS $-.SessionId")
+	if err != nil {
+		t.Fatal(err)
+	}
+	assert.True(t, resultSet.IsSucceed(), fmt.Errorf("error code: %d, error msg: %s",
+		resultSet.GetErrorCode(), resultSet.GetErrorMsg()))
+
+	// execute query, it should retry to get session
+	resultSet, err = sessionPool.Execute("SHOW HOSTS;")
+	if err != nil {
+		t.Fatal(err)
+	}
+	assert.True(t, resultSet.IsSucceed(), fmt.Errorf("error code: %d, error msg: %s",
+		resultSet.GetErrorCode(), resultSet.GetErrorMsg()))
+}
+
 func BenchmarkConcurrency(b *testing.B) {
 	err := prepareSpace("client_test")
 	if err != nil {