Skip to content

Commit

Permalink
spanner: make ReadWriteTransaction retry on Session not found error
Browse files Browse the repository at this point in the history
Updates #1527

Ref: googleapis/google-cloud-go#1527
Change-Id: Iea12342ca098c8056abc2206b91edbeda630e718
Reviewed-on: https://code-review.googlesource.com/c/gocloud/+/45910
Reviewed-by: kokoro <[email protected]>
Reviewed-by: Hengfeng Li <[email protected]>
  • Loading branch information
110y authored and olavloite committed Jan 19, 2020
1 parent 9b5d7a6 commit 6d9834c
Show file tree
Hide file tree
Showing 6 changed files with 254 additions and 39 deletions.
2 changes: 1 addition & 1 deletion client.go
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@ func (c *Client) ReadWriteTransaction(ctx context.Context, f func(context.Contex
ts time.Time
sh *sessionHandle
)
err = runWithRetryOnAborted(ctx, func(ctx context.Context) error {
err = runWithRetryOnAbortedOrSessionNotFound(ctx, func(ctx context.Context) error {
var (
err error
t *ReadWriteTransaction
Expand Down
212 changes: 211 additions & 1 deletion client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ func setupMockedTestServerWithConfigAndClientOptions(t *testing.T, config Client
server, opts, serverTeardown := NewMockedSpannerInMemTestServer(t)
opts = append(opts, clientOptions...)
ctx := context.Background()
var formattedDatabase = fmt.Sprintf("projects/%s/instances/%s/databases/%s", "[PROJECT]", "[INSTANCE]", "[DATABASE]")
formattedDatabase := fmt.Sprintf("projects/%s/instances/%s/databases/%s", "[PROJECT]", "[INSTANCE]", "[DATABASE]")
client, err := NewClientWithConfig(ctx, formattedDatabase, config, opts...)
if err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -609,6 +609,165 @@ func TestClient_ReadWriteTransactionCommitAborted(t *testing.T) {
}
}

func TestClient_ReadWriteTransaction_SessionNotFoundOnCommit(t *testing.T) {
t.Parallel()
if err := testReadWriteTransaction(t, map[string]SimulatedExecutionTime{
MethodCommitTransaction: {Errors: []error{status.Error(codes.NotFound, "Session not found")}},
}, 2); err != nil {
t.Fatal(err)
}
}

func TestClient_ReadWriteTransaction_SessionNotFoundOnBeginTransaction(t *testing.T) {
t.Parallel()
// We expect only 1 attempt, as the 'Session not found' error is already
//handled in the session pool where the session is prepared.
if err := testReadWriteTransaction(t, map[string]SimulatedExecutionTime{
MethodBeginTransaction: {Errors: []error{status.Error(codes.NotFound, "Session not found")}},
}, 1); err != nil {
t.Fatal(err)
}
}

func TestClient_ReadWriteTransaction_SessionNotFoundOnBeginTransactionWithEmptySessionPool(t *testing.T) {
t.Parallel()
// There will be no prepared sessions in the pool, so the error will occur
// when the transaction tries to get a session from the pool. This will
// also be handled by the session pool, so the transaction itself does not
// need to retry, hence the expectedAttempts == 1.
if err := testReadWriteTransactionWithConfig(t, ClientConfig{
SessionPoolConfig: SessionPoolConfig{WriteSessions: 0.0},
}, map[string]SimulatedExecutionTime{
MethodBeginTransaction: {Errors: []error{status.Error(codes.NotFound, "Session not found")}},
}, 1); err != nil {
t.Fatal(err)
}
}

func TestClient_ReadWriteTransaction_SessionNotFoundOnExecuteStreamingSql(t *testing.T) {
t.Parallel()
if err := testReadWriteTransaction(t, map[string]SimulatedExecutionTime{
MethodExecuteStreamingSql: {Errors: []error{status.Error(codes.NotFound, "Session not found")}},
}, 2); err != nil {
t.Fatal(err)
}
}

func TestClient_ReadWriteTransaction_SessionNotFoundOnExecuteUpdate(t *testing.T) {
t.Parallel()

server, client, teardown := setupMockedTestServer(t)
defer teardown()
server.TestSpanner.PutExecutionTime(
MethodExecuteSql,
SimulatedExecutionTime{Errors: []error{status.Error(codes.NotFound, "Session not found")}},
)
ctx := context.Background()
var attempts int
_, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
attempts++
rowCount, err := tx.Update(ctx, NewStatement(UpdateBarSetFoo))
if err != nil {
return err
}
if g, w := rowCount, int64(UpdateBarSetFooRowCount); g != w {
return status.Errorf(codes.FailedPrecondition, "Row count mismatch\nGot: %v\nWant: %v", g, w)
}
return nil
})
if err != nil {
t.Fatal(err)
}
if g, w := attempts, 2; g != w {
t.Fatalf("number of attempts mismatch:\nGot%d\nWant:%d", g, w)
}
}

func TestClient_ReadWriteTransaction_SessionNotFoundOnExecuteBatchUpdate(t *testing.T) {
t.Parallel()

server, client, teardown := setupMockedTestServer(t)
defer teardown()
server.TestSpanner.PutExecutionTime(
MethodExecuteBatchDml,
SimulatedExecutionTime{Errors: []error{status.Error(codes.NotFound, "Session not found")}},
)
ctx := context.Background()
var attempts int
_, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
attempts++
rowCounts, err := tx.BatchUpdate(ctx, []Statement{NewStatement(UpdateBarSetFoo)})
if err != nil {
return err
}
if g, w := len(rowCounts), 1; g != w {
return status.Errorf(codes.FailedPrecondition, "Row counts length mismatch\nGot: %v\nWant: %v", g, w)
}
if g, w := rowCounts[0], int64(UpdateBarSetFooRowCount); g != w {
return status.Errorf(codes.FailedPrecondition, "Row count mismatch\nGot: %v\nWant: %v", g, w)
}
return nil
})
if err != nil {
t.Fatal(err)
}
if g, w := attempts, 2; g != w {
t.Fatalf("number of attempts mismatch:\nGot%d\nWant:%d", g, w)
}
}

func TestClient_SessionNotFound(t *testing.T) {
// Ensure we always have at least one session in the pool.
sc := SessionPoolConfig{
MinOpened: 1,
}
server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{SessionPoolConfig: sc})
defer teardown()
ctx := context.Background()
for {
client.idleSessions.mu.Lock()
numSessions := client.idleSessions.idleList.Len()
client.idleSessions.mu.Unlock()
if numSessions > 0 {
break
}
time.After(time.Millisecond)
}
// Remove the session from the server without the pool knowing it.
_, err := server.TestSpanner.DeleteSession(ctx, &sppb.DeleteSessionRequest{Name: client.idleSessions.idleList.Front().Value.(*session).id})
if err != nil {
t.Fatalf("Failed to delete session unexpectedly: %v", err)
}

_, err = client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
iter := tx.Query(ctx, NewStatement(SelectSingerIDAlbumIDAlbumTitleFromAlbums))
defer iter.Stop()
rowCount := int64(0)
for {
row, err := iter.Next()
if err == iterator.Done {
break
}
if err != nil {
return err
}
var singerID, albumID int64
var albumTitle string
if err := row.Columns(&singerID, &albumID, &albumTitle); err != nil {
return err
}
rowCount++
}
if rowCount != SelectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount {
return spannerErrorf(codes.FailedPrecondition, "Row count mismatch, got %v, expected %v", rowCount, SelectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount)
}
return nil
})
if err != nil {
t.Fatalf("Unexpected error during transaction: %v", err)
}
}

func TestClient_ReadWriteTransactionExecuteStreamingSqlAborted(t *testing.T) {
t.Parallel()
if err := testReadWriteTransaction(t, map[string]SimulatedExecutionTime{
Expand Down Expand Up @@ -801,6 +960,10 @@ func TestClient_ReadWriteTransactionCommitAlreadyExists(t *testing.T) {
}

func testReadWriteTransaction(t *testing.T, executionTimes map[string]SimulatedExecutionTime, expectedAttempts int) error {
return testReadWriteTransactionWithConfig(t, ClientConfig{SessionPoolConfig: DefaultSessionPoolConfig}, executionTimes, expectedAttempts)
}

func testReadWriteTransactionWithConfig(t *testing.T, config ClientConfig, executionTimes map[string]SimulatedExecutionTime, expectedAttempts int) error {
server, client, teardown := setupMockedTestServer(t)
defer teardown()
for method, exec := range executionTimes {
Expand Down Expand Up @@ -966,3 +1129,50 @@ func TestReadWriteTransaction_WrapError(t *testing.T) {
t.Fatalf("Unexpected error\nGot: %v\nWant: %v", err, msg)
}
}

func TestReadWriteTransaction_WrapSessionNotFoundError(t *testing.T) {
t.Parallel()
server, client, teardown := setupMockedTestServer(t)
defer teardown()
server.TestSpanner.PutExecutionTime(MethodBeginTransaction,
SimulatedExecutionTime{
Errors: []error{status.Error(codes.NotFound, "Session not found")},
})
server.TestSpanner.PutExecutionTime(MethodExecuteStreamingSql,
SimulatedExecutionTime{
Errors: []error{status.Error(codes.NotFound, "Session not found")},
})
server.TestSpanner.PutExecutionTime(MethodCommitTransaction,
SimulatedExecutionTime{
Errors: []error{status.Error(codes.NotFound, "Session not found")},
})
msg := "query failed"
numAttempts := 0
ctx := context.Background()
_, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
numAttempts++
iter := tx.Query(ctx, NewStatement(SelectSingerIDAlbumIDAlbumTitleFromAlbums))
defer iter.Stop()
for {
_, err := iter.Next()
if err == iterator.Done {
break
}
if err != nil {
// Wrap the error in another error that implements the
// (xerrors|errors).Wrapper interface.
return &wrappedTestError{err, msg}
}
}
return nil
})
if err != nil {
t.Fatalf("Unexpected error\nGot: %v\nWant: nil", err)
}
// We want 3 attempts. The 'Session not found' error on BeginTransaction
// will not retry the entire transaction, which means that we will have two
// failed attempts and then a successful attempt.
if g, w := numAttempts, 3; g != w {
t.Fatalf("Number of transaction attempts mismatch\nGot: %d\nWant: %d", g, w)
}
}
2 changes: 1 addition & 1 deletion internal/testutil/inmem_spanner_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,7 @@ func (s *inMemSpannerServer) findSession(name string) (*spannerpb.Session, error
defer s.mu.Unlock()
session := s.sessions[name]
if session == nil {
return nil, gstatus.Error(codes.NotFound, fmt.Sprintf("Session %s not found", name))
return nil, gstatus.Error(codes.NotFound, fmt.Sprintf("Session not found: %s", name))
}
return session, nil
}
Expand Down
16 changes: 11 additions & 5 deletions retry.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,13 @@ func (r *spannerRetryer) Retry(err error) (time.Duration, bool) {
return delay, true
}

// runWithRetryOnAborted executes the given function and retries it if it
// returns an Aborted error. The delay between retries is the delay returned
// by Cloud Spanner, and if none is returned, the calculated delay with a
// minimum of 10ms and maximum of 32s.
func runWithRetryOnAborted(ctx context.Context, f func(context.Context) error) error {
// runWithRetryOnAbortedOrSessionNotFound executes the given function and
// retries it if it returns an Aborted or Session not found error. The retry
// is delayed if the error was Aborted. The delay between retries is the delay
// returned by Cloud Spanner, or if none is returned, the calculated delay with
// a minimum of 10ms and maximum of 32s. There is no delay before the retry if
// the error was Session not found.
func runWithRetryOnAbortedOrSessionNotFound(ctx context.Context, f func(context.Context) error) error {
retryer := onCodes(DefaultRetryBackoff, codes.Aborted)
funcWithRetry := func(ctx context.Context) error {
for {
Expand All @@ -99,6 +101,10 @@ func runWithRetryOnAborted(ctx context.Context, f func(context.Context) error) e
}
retryErr = err
}
if isSessionNotFoundError(retryErr) {
trace.TracePrintf(ctx, nil, "Retrying after Session not found")
continue
}
delay, shouldRetry := retryer.Retry(retryErr)
if !shouldRetry {
return err
Expand Down
43 changes: 19 additions & 24 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,22 +130,23 @@ func (sh *sessionHandle) getTransactionID() transactionID {
func (sh *sessionHandle) destroy() {
sh.mu.Lock()
s := sh.session
p := s.pool
tracked := sh.trackedSessionHandle
sh.session = nil
sh.trackedSessionHandle = nil
sh.checkoutTime = time.Time{}
sh.stack = nil
sh.mu.Unlock()

if s == nil {
// sessionHandle has already been destroyed..
return
}
if tracked != nil {
p := s.pool
p.mu.Lock()
p.trackedSessionHandles.Remove(tracked)
p.mu.Unlock()
}
if s == nil {
// sessionHandle has already been destroyed..
return
}
s.destroy(false)
}

Expand Down Expand Up @@ -764,7 +765,7 @@ func (p *sessionPool) createSession(ctx context.Context) (*session, error) {
func (p *sessionPool) isHealthy(s *session) bool {
if s.getNextCheck().Add(2 * p.hc.getInterval()).Before(time.Now()) {
// TODO: figure out if we need to schedule a new healthcheck worker here.
if err := s.ping(); shouldDropSession(err) {
if err := s.ping(); isSessionNotFoundError(err) {
// The session is already bad, continue to fetch/create a new one.
s.destroy(false)
return false
Expand Down Expand Up @@ -923,6 +924,13 @@ func (p *sessionPool) takeWriteSession(ctx context.Context) (*sessionHandle, err
}
if !s.isWritePrepared() {
if err = s.prepareForWrite(ctx); err != nil {
if isSessionNotFoundError(err) {
s.destroy(false)
trace.TracePrintf(ctx, map[string]interface{}{"sessionID": s.getID()},
"Session not found for write")
return nil, toSpannerError(err)
}

s.recycle()
trace.TracePrintf(ctx, map[string]interface{}{"sessionID": s.getID()},
"Error preparing session for write")
Expand Down Expand Up @@ -1230,7 +1238,7 @@ func (hc *healthChecker) healthCheck(s *session) {
s.destroy(false)
return
}
if err := s.ping(); shouldDropSession(err) {
if err := s.ping(); isSessionNotFoundError(err) {
// Ping failed, destroy the session.
s.destroy(false)
}
Expand Down Expand Up @@ -1497,23 +1505,6 @@ func (hc *healthChecker) shrinkPool(ctx context.Context, shrinkToNumSessions uin
}
}

// shouldDropSession returns true if a particular error leads to the removal of
// a session
func shouldDropSession(err error) bool {
if err == nil {
return false
}
// If a Cloud Spanner can no longer locate the session (for example, if
// session is garbage collected), then caller should not try to return the
// session back into the session pool.
//
// TODO: once gRPC can return auxiliary error information, stop parsing the error message.
if ErrCode(err) == codes.NotFound && strings.Contains(ErrDesc(err), "Session not found") {
return true
}
return false
}

// maxUint64 returns the maximum of two uint64.
func maxUint64(a, b uint64) uint64 {
if a > b {
Expand All @@ -1533,9 +1524,13 @@ func minUint64(a, b uint64) uint64 {
// isSessionNotFoundError returns true if the given error is a
// `Session not found` error.
func isSessionNotFoundError(err error) bool {
if err == nil {
return false
}
// We are checking specifically for the error message `Session not found`,
// as the error could also be a `Database not found`. The latter should
// cause the session pool to stop preparing sessions for read/write
// transactions, while the former should not.
// TODO: once gRPC can return auxiliary error information, stop parsing the error message.
return ErrCode(err) == codes.NotFound && strings.Contains(err.Error(), "Session not found")
}
Loading

0 comments on commit 6d9834c

Please sign in to comment.