diff --git a/client.go b/client.go index 4f0ee8e..f1d18f8 100644 --- a/client.go +++ b/client.go @@ -435,7 +435,7 @@ func (c *Client) ReadWriteTransaction(ctx context.Context, f func(context.Contex ts time.Time sh *sessionHandle ) - err = runWithRetryOnAborted(ctx, func(ctx context.Context) error { + err = runWithRetryOnAbortedOrSessionNotFound(ctx, func(ctx context.Context) error { var ( err error t *ReadWriteTransaction diff --git a/client_test.go b/client_test.go index 8b83f61..36cf0b9 100644 --- a/client_test.go +++ b/client_test.go @@ -68,7 +68,7 @@ func setupMockedTestServerWithConfigAndClientOptions(t *testing.T, config Client server, opts, serverTeardown := NewMockedSpannerInMemTestServer(t) opts = append(opts, clientOptions...) ctx := context.Background() - var formattedDatabase = fmt.Sprintf("projects/%s/instances/%s/databases/%s", "[PROJECT]", "[INSTANCE]", "[DATABASE]") + formattedDatabase := fmt.Sprintf("projects/%s/instances/%s/databases/%s", "[PROJECT]", "[INSTANCE]", "[DATABASE]") client, err := NewClientWithConfig(ctx, formattedDatabase, config, opts...) if err != nil { t.Fatal(err) @@ -609,6 +609,165 @@ func TestClient_ReadWriteTransactionCommitAborted(t *testing.T) { } } +func TestClient_ReadWriteTransaction_SessionNotFoundOnCommit(t *testing.T) { + t.Parallel() + if err := testReadWriteTransaction(t, map[string]SimulatedExecutionTime{ + MethodCommitTransaction: {Errors: []error{status.Error(codes.NotFound, "Session not found")}}, + }, 2); err != nil { + t.Fatal(err) + } +} + +func TestClient_ReadWriteTransaction_SessionNotFoundOnBeginTransaction(t *testing.T) { + t.Parallel() + // We expect only 1 attempt, as the 'Session not found' error is already + //handled in the session pool where the session is prepared. + if err := testReadWriteTransaction(t, map[string]SimulatedExecutionTime{ + MethodBeginTransaction: {Errors: []error{status.Error(codes.NotFound, "Session not found")}}, + }, 1); err != nil { + t.Fatal(err) + } +} + +func TestClient_ReadWriteTransaction_SessionNotFoundOnBeginTransactionWithEmptySessionPool(t *testing.T) { + t.Parallel() + // There will be no prepared sessions in the pool, so the error will occur + // when the transaction tries to get a session from the pool. This will + // also be handled by the session pool, so the transaction itself does not + // need to retry, hence the expectedAttempts == 1. + if err := testReadWriteTransactionWithConfig(t, ClientConfig{ + SessionPoolConfig: SessionPoolConfig{WriteSessions: 0.0}, + }, map[string]SimulatedExecutionTime{ + MethodBeginTransaction: {Errors: []error{status.Error(codes.NotFound, "Session not found")}}, + }, 1); err != nil { + t.Fatal(err) + } +} + +func TestClient_ReadWriteTransaction_SessionNotFoundOnExecuteStreamingSql(t *testing.T) { + t.Parallel() + if err := testReadWriteTransaction(t, map[string]SimulatedExecutionTime{ + MethodExecuteStreamingSql: {Errors: []error{status.Error(codes.NotFound, "Session not found")}}, + }, 2); err != nil { + t.Fatal(err) + } +} + +func TestClient_ReadWriteTransaction_SessionNotFoundOnExecuteUpdate(t *testing.T) { + t.Parallel() + + server, client, teardown := setupMockedTestServer(t) + defer teardown() + server.TestSpanner.PutExecutionTime( + MethodExecuteSql, + SimulatedExecutionTime{Errors: []error{status.Error(codes.NotFound, "Session not found")}}, + ) + ctx := context.Background() + var attempts int + _, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error { + attempts++ + rowCount, err := tx.Update(ctx, NewStatement(UpdateBarSetFoo)) + if err != nil { + return err + } + if g, w := rowCount, int64(UpdateBarSetFooRowCount); g != w { + return status.Errorf(codes.FailedPrecondition, "Row count mismatch\nGot: %v\nWant: %v", g, w) + } + return nil + }) + if err != nil { + t.Fatal(err) + } + if g, w := attempts, 2; g != w { + t.Fatalf("number of attempts mismatch:\nGot%d\nWant:%d", g, w) + } +} + +func TestClient_ReadWriteTransaction_SessionNotFoundOnExecuteBatchUpdate(t *testing.T) { + t.Parallel() + + server, client, teardown := setupMockedTestServer(t) + defer teardown() + server.TestSpanner.PutExecutionTime( + MethodExecuteBatchDml, + SimulatedExecutionTime{Errors: []error{status.Error(codes.NotFound, "Session not found")}}, + ) + ctx := context.Background() + var attempts int + _, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error { + attempts++ + rowCounts, err := tx.BatchUpdate(ctx, []Statement{NewStatement(UpdateBarSetFoo)}) + if err != nil { + return err + } + if g, w := len(rowCounts), 1; g != w { + return status.Errorf(codes.FailedPrecondition, "Row counts length mismatch\nGot: %v\nWant: %v", g, w) + } + if g, w := rowCounts[0], int64(UpdateBarSetFooRowCount); g != w { + return status.Errorf(codes.FailedPrecondition, "Row count mismatch\nGot: %v\nWant: %v", g, w) + } + return nil + }) + if err != nil { + t.Fatal(err) + } + if g, w := attempts, 2; g != w { + t.Fatalf("number of attempts mismatch:\nGot%d\nWant:%d", g, w) + } +} + +func TestClient_SessionNotFound(t *testing.T) { + // Ensure we always have at least one session in the pool. + sc := SessionPoolConfig{ + MinOpened: 1, + } + server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{SessionPoolConfig: sc}) + defer teardown() + ctx := context.Background() + for { + client.idleSessions.mu.Lock() + numSessions := client.idleSessions.idleList.Len() + client.idleSessions.mu.Unlock() + if numSessions > 0 { + break + } + time.After(time.Millisecond) + } + // Remove the session from the server without the pool knowing it. + _, err := server.TestSpanner.DeleteSession(ctx, &sppb.DeleteSessionRequest{Name: client.idleSessions.idleList.Front().Value.(*session).id}) + if err != nil { + t.Fatalf("Failed to delete session unexpectedly: %v", err) + } + + _, err = client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error { + iter := tx.Query(ctx, NewStatement(SelectSingerIDAlbumIDAlbumTitleFromAlbums)) + defer iter.Stop() + rowCount := int64(0) + for { + row, err := iter.Next() + if err == iterator.Done { + break + } + if err != nil { + return err + } + var singerID, albumID int64 + var albumTitle string + if err := row.Columns(&singerID, &albumID, &albumTitle); err != nil { + return err + } + rowCount++ + } + if rowCount != SelectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount { + return spannerErrorf(codes.FailedPrecondition, "Row count mismatch, got %v, expected %v", rowCount, SelectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount) + } + return nil + }) + if err != nil { + t.Fatalf("Unexpected error during transaction: %v", err) + } +} + func TestClient_ReadWriteTransactionExecuteStreamingSqlAborted(t *testing.T) { t.Parallel() if err := testReadWriteTransaction(t, map[string]SimulatedExecutionTime{ @@ -801,6 +960,10 @@ func TestClient_ReadWriteTransactionCommitAlreadyExists(t *testing.T) { } func testReadWriteTransaction(t *testing.T, executionTimes map[string]SimulatedExecutionTime, expectedAttempts int) error { + return testReadWriteTransactionWithConfig(t, ClientConfig{SessionPoolConfig: DefaultSessionPoolConfig}, executionTimes, expectedAttempts) +} + +func testReadWriteTransactionWithConfig(t *testing.T, config ClientConfig, executionTimes map[string]SimulatedExecutionTime, expectedAttempts int) error { server, client, teardown := setupMockedTestServer(t) defer teardown() for method, exec := range executionTimes { @@ -966,3 +1129,50 @@ func TestReadWriteTransaction_WrapError(t *testing.T) { t.Fatalf("Unexpected error\nGot: %v\nWant: %v", err, msg) } } + +func TestReadWriteTransaction_WrapSessionNotFoundError(t *testing.T) { + t.Parallel() + server, client, teardown := setupMockedTestServer(t) + defer teardown() + server.TestSpanner.PutExecutionTime(MethodBeginTransaction, + SimulatedExecutionTime{ + Errors: []error{status.Error(codes.NotFound, "Session not found")}, + }) + server.TestSpanner.PutExecutionTime(MethodExecuteStreamingSql, + SimulatedExecutionTime{ + Errors: []error{status.Error(codes.NotFound, "Session not found")}, + }) + server.TestSpanner.PutExecutionTime(MethodCommitTransaction, + SimulatedExecutionTime{ + Errors: []error{status.Error(codes.NotFound, "Session not found")}, + }) + msg := "query failed" + numAttempts := 0 + ctx := context.Background() + _, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error { + numAttempts++ + iter := tx.Query(ctx, NewStatement(SelectSingerIDAlbumIDAlbumTitleFromAlbums)) + defer iter.Stop() + for { + _, err := iter.Next() + if err == iterator.Done { + break + } + if err != nil { + // Wrap the error in another error that implements the + // (xerrors|errors).Wrapper interface. + return &wrappedTestError{err, msg} + } + } + return nil + }) + if err != nil { + t.Fatalf("Unexpected error\nGot: %v\nWant: nil", err) + } + // We want 3 attempts. The 'Session not found' error on BeginTransaction + // will not retry the entire transaction, which means that we will have two + // failed attempts and then a successful attempt. + if g, w := numAttempts, 3; g != w { + t.Fatalf("Number of transaction attempts mismatch\nGot: %d\nWant: %d", g, w) + } +} diff --git a/internal/testutil/inmem_spanner_server.go b/internal/testutil/inmem_spanner_server.go index f107f18..eb1c870 100644 --- a/internal/testutil/inmem_spanner_server.go +++ b/internal/testutil/inmem_spanner_server.go @@ -431,7 +431,7 @@ func (s *inMemSpannerServer) findSession(name string) (*spannerpb.Session, error defer s.mu.Unlock() session := s.sessions[name] if session == nil { - return nil, gstatus.Error(codes.NotFound, fmt.Sprintf("Session %s not found", name)) + return nil, gstatus.Error(codes.NotFound, fmt.Sprintf("Session not found: %s", name)) } return session, nil } diff --git a/retry.go b/retry.go index cc2d520..75231af 100644 --- a/retry.go +++ b/retry.go @@ -69,11 +69,13 @@ func (r *spannerRetryer) Retry(err error) (time.Duration, bool) { return delay, true } -// runWithRetryOnAborted executes the given function and retries it if it -// returns an Aborted error. The delay between retries is the delay returned -// by Cloud Spanner, and if none is returned, the calculated delay with a -// minimum of 10ms and maximum of 32s. -func runWithRetryOnAborted(ctx context.Context, f func(context.Context) error) error { +// runWithRetryOnAbortedOrSessionNotFound executes the given function and +// retries it if it returns an Aborted or Session not found error. The retry +// is delayed if the error was Aborted. The delay between retries is the delay +// returned by Cloud Spanner, or if none is returned, the calculated delay with +// a minimum of 10ms and maximum of 32s. There is no delay before the retry if +// the error was Session not found. +func runWithRetryOnAbortedOrSessionNotFound(ctx context.Context, f func(context.Context) error) error { retryer := onCodes(DefaultRetryBackoff, codes.Aborted) funcWithRetry := func(ctx context.Context) error { for { @@ -99,6 +101,10 @@ func runWithRetryOnAborted(ctx context.Context, f func(context.Context) error) e } retryErr = err } + if isSessionNotFoundError(retryErr) { + trace.TracePrintf(ctx, nil, "Retrying after Session not found") + continue + } delay, shouldRetry := retryer.Retry(retryErr) if !shouldRetry { return err diff --git a/session.go b/session.go index 6545c8a..9a3dfae 100644 --- a/session.go +++ b/session.go @@ -130,22 +130,23 @@ func (sh *sessionHandle) getTransactionID() transactionID { func (sh *sessionHandle) destroy() { sh.mu.Lock() s := sh.session - p := s.pool tracked := sh.trackedSessionHandle sh.session = nil sh.trackedSessionHandle = nil sh.checkoutTime = time.Time{} sh.stack = nil sh.mu.Unlock() + + if s == nil { + // sessionHandle has already been destroyed.. + return + } if tracked != nil { + p := s.pool p.mu.Lock() p.trackedSessionHandles.Remove(tracked) p.mu.Unlock() } - if s == nil { - // sessionHandle has already been destroyed.. - return - } s.destroy(false) } @@ -764,7 +765,7 @@ func (p *sessionPool) createSession(ctx context.Context) (*session, error) { func (p *sessionPool) isHealthy(s *session) bool { if s.getNextCheck().Add(2 * p.hc.getInterval()).Before(time.Now()) { // TODO: figure out if we need to schedule a new healthcheck worker here. - if err := s.ping(); shouldDropSession(err) { + if err := s.ping(); isSessionNotFoundError(err) { // The session is already bad, continue to fetch/create a new one. s.destroy(false) return false @@ -923,6 +924,13 @@ func (p *sessionPool) takeWriteSession(ctx context.Context) (*sessionHandle, err } if !s.isWritePrepared() { if err = s.prepareForWrite(ctx); err != nil { + if isSessionNotFoundError(err) { + s.destroy(false) + trace.TracePrintf(ctx, map[string]interface{}{"sessionID": s.getID()}, + "Session not found for write") + return nil, toSpannerError(err) + } + s.recycle() trace.TracePrintf(ctx, map[string]interface{}{"sessionID": s.getID()}, "Error preparing session for write") @@ -1230,7 +1238,7 @@ func (hc *healthChecker) healthCheck(s *session) { s.destroy(false) return } - if err := s.ping(); shouldDropSession(err) { + if err := s.ping(); isSessionNotFoundError(err) { // Ping failed, destroy the session. s.destroy(false) } @@ -1497,23 +1505,6 @@ func (hc *healthChecker) shrinkPool(ctx context.Context, shrinkToNumSessions uin } } -// shouldDropSession returns true if a particular error leads to the removal of -// a session -func shouldDropSession(err error) bool { - if err == nil { - return false - } - // If a Cloud Spanner can no longer locate the session (for example, if - // session is garbage collected), then caller should not try to return the - // session back into the session pool. - // - // TODO: once gRPC can return auxiliary error information, stop parsing the error message. - if ErrCode(err) == codes.NotFound && strings.Contains(ErrDesc(err), "Session not found") { - return true - } - return false -} - // maxUint64 returns the maximum of two uint64. func maxUint64(a, b uint64) uint64 { if a > b { @@ -1533,9 +1524,13 @@ func minUint64(a, b uint64) uint64 { // isSessionNotFoundError returns true if the given error is a // `Session not found` error. func isSessionNotFoundError(err error) bool { + if err == nil { + return false + } // We are checking specifically for the error message `Session not found`, // as the error could also be a `Database not found`. The latter should // cause the session pool to stop preparing sessions for read/write // transactions, while the former should not. + // TODO: once gRPC can return auxiliary error information, stop parsing the error message. return ErrCode(err) == codes.NotFound && strings.Contains(err.Error(), "Session not found") } diff --git a/transaction.go b/transaction.go index 39505f2..4b8825e 100644 --- a/transaction.go +++ b/transaction.go @@ -357,7 +357,7 @@ func (t *ReadOnlyTransaction) begin(ctx context.Context) error { if err != nil && sh != nil { // Got a valid session handle, but failed to initialize transaction= // on Cloud Spanner. - if shouldDropSession(err) { + if isSessionNotFoundError(err) { sh.destroy() } // If sh.destroy was already executed, this becomes a noop. @@ -527,7 +527,7 @@ func (t *ReadOnlyTransaction) release(err error) { sh := t.sh t.mu.Unlock() if sh != nil { // sh could be nil if t.acquire() fails. - if shouldDropSession(err) { + if isSessionNotFoundError(err) { sh.destroy() } if t.singleUse { @@ -795,7 +795,7 @@ func (t *ReadWriteTransaction) release(err error) { t.mu.Lock() sh := t.sh t.mu.Unlock() - if sh != nil && shouldDropSession(err) { + if sh != nil && isSessionNotFoundError(err) { sh.destroy() } } @@ -831,7 +831,7 @@ func (t *ReadWriteTransaction) begin(ctx context.Context) error { t.state = txActive return nil } - if shouldDropSession(err) { + if isSessionNotFoundError(err) { t.sh.destroy() } return err @@ -869,7 +869,7 @@ func (t *ReadWriteTransaction) commit(ctx context.Context) (time.Time, error) { if tstamp := res.GetCommitTimestamp(); tstamp != nil { ts = time.Unix(tstamp.Seconds, int64(tstamp.Nanos)) } - if shouldDropSession(err) { + if isSessionNotFoundError(err) { t.sh.destroy() } return ts, err @@ -892,7 +892,7 @@ func (t *ReadWriteTransaction) rollback(ctx context.Context) { Session: sid, TransactionId: t.tx, }) - if shouldDropSession(err) { + if isSessionNotFoundError(err) { t.sh.destroy() } } @@ -914,6 +914,10 @@ func (t *ReadWriteTransaction) runInTransaction(ctx context.Context, f func(cont // one's wound-wait priority. return ts, err } + if isSessionNotFoundError(err) { + t.sh.destroy() + return ts, err + } // Not going to commit, according to API spec, should rollback the // transaction. t.rollback(ctx) @@ -973,7 +977,7 @@ func (t *writeOnlyTransaction) applyAtLeastOnce(ctx context.Context, ms ...*Muta Mutations: mPb, }, gax.WithGRPCOptions(grpc.Trailer(&trailers))) if err != nil && !isAbortErr(err) { - if shouldDropSession(err) { + if isSessionNotFoundError(err) { // Discard the bad session. sh.destroy() }