diff --git a/spanner/client_test.go b/spanner/client_test.go index 73693f14449e..36233f1f3091 100644 --- a/spanner/client_test.go +++ b/spanner/client_test.go @@ -112,14 +112,20 @@ func TestClient_Single_InvalidArgument(t *testing.T) { } func testSingleQuery(t *testing.T, serverError error) error { - ctx := context.Background() server, client, teardown := setupMockedTestServer(t) defer teardown() if serverError != nil { server.TestSpanner.SetError(serverError) } - iter := client.Single().Query(ctx, NewStatement(SelectSingerIDAlbumIDAlbumTitleFromAlbums)) + tx := client.Single() + return executeTestQuery(t, tx) +} + +func executeTestQuery(t *testing.T, tx *ReadOnlyTransaction) error { + ctx := context.Background() + iter := tx.Query(ctx, NewStatement(SelectSingerIDAlbumIDAlbumTitleFromAlbums)) defer iter.Stop() + var rowCount int64 for { row, err := iter.Next() if err == iterator.Done { @@ -133,6 +139,10 @@ func testSingleQuery(t *testing.T, serverError error) error { if err := row.Columns(&singerID, &albumID, &albumTitle); err != nil { return err } + rowCount++ + } + if rowCount != SelectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount { + t.Fatalf("Row count mismatch\ngot: %v\nwant: %v", rowCount, SelectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount) } return nil } @@ -200,6 +210,16 @@ func TestClient_ReadOnlyTransaction_UnavailableOnCreateSessionAndInvalidArgument } } +func TestClient_ReadOnlyTransaction_SessionNotFoundOnBeginTransaction(t *testing.T) { + t.Parallel() + exec := map[string]SimulatedExecutionTime{ + MethodBeginTransaction: {Errors: []error{gstatus.Error(codes.NotFound, "Session not found")}}, + } + if err := testReadOnlyTransaction(t, exec); err != nil { + t.Fatal(err) + } +} + func testReadOnlyTransaction(t *testing.T, executionTimes map[string]SimulatedExecutionTime) error { server, client, teardown := setupMockedTestServer(t) defer teardown() @@ -208,24 +228,7 @@ func testReadOnlyTransaction(t *testing.T, executionTimes map[string]SimulatedEx } tx := client.ReadOnlyTransaction() defer tx.Close() - ctx := context.Background() - iter := tx.Query(ctx, NewStatement(SelectSingerIDAlbumIDAlbumTitleFromAlbums)) - defer iter.Stop() - for { - row, err := iter.Next() - if err == iterator.Done { - break - } - if err != nil { - return err - } - var singerID, albumID int64 - var albumTitle string - if err := row.Columns(&singerID, &albumID, &albumTitle); err != nil { - return err - } - } - return nil + return executeTestQuery(t, tx) } func TestClient_ReadWriteTransaction(t *testing.T) { diff --git a/spanner/internal/testutil/inmem_spanner_server.go b/spanner/internal/testutil/inmem_spanner_server.go index 5c4f54643307..022135b8f451 100644 --- a/spanner/internal/testutil/inmem_spanner_server.go +++ b/spanner/internal/testutil/inmem_spanner_server.go @@ -59,6 +59,7 @@ const ( MethodGetSession string = "GET_SESSION" MethodExecuteSql string = "EXECUTE_SQL" MethodExecuteStreamingSql string = "EXECUTE_STREAMING_SQL" + MethodStreamingRead string = "EXECUTE_STREAMING_READ" ) // StatementResult represents a mocked result on the test server. Th result can @@ -703,13 +704,9 @@ func (s *inMemSpannerServer) Read(ctx context.Context, req *spannerpb.ReadReques } func (s *inMemSpannerServer) StreamingRead(req *spannerpb.ReadRequest, stream spannerpb.Spanner_StreamingReadServer) error { - s.mu.Lock() - if s.stopped { - s.mu.Unlock() - return gstatus.Error(codes.Unavailable, "server has been stopped") + if err := s.simulateExecutionTime(MethodStreamingRead, req); err != nil { + return err } - s.receivedRequests <- req - s.mu.Unlock() return gstatus.Error(codes.Unimplemented, "Method not yet implemented") } diff --git a/spanner/transaction.go b/spanner/transaction.go index a53486b674f3..9f1414b2ab23 100644 --- a/spanner/transaction.go +++ b/spanner/transaction.go @@ -363,18 +363,41 @@ func (t *ReadOnlyTransaction) begin(ctx context.Context) error { sh.recycle() } }() - sh, err = t.sp.take(ctx) - if err != nil { - return err - } - res, err := sh.getClient().BeginTransaction(contextWithOutgoingMetadata(ctx, sh.getMetadata()), &sppb.BeginTransactionRequest{ - Session: sh.getID(), - Options: &sppb.TransactionOptions{ - Mode: &sppb.TransactionOptions_ReadOnly_{ - ReadOnly: buildTransactionOptionsReadOnly(t.getTimestampBound(), true), - }, + + // Create transaction options. + readOnlyOptions := buildTransactionOptionsReadOnly(t.getTimestampBound(), true) + transactionOptions := &sppb.TransactionOptions{ + Mode: &sppb.TransactionOptions_ReadOnly_{ + ReadOnly: readOnlyOptions, }, - }) + } + // Retry TakeSession and BeginTransaction on Session not found. + retryOnNotFound := gax.OnCodes([]codes.Code{codes.NotFound}, gax.Backoff{}) + beginTxWithRetry := func(ctx context.Context) (*sppb.Transaction, error) { + for { + sh, err = t.sp.take(ctx) + if err != nil { + return nil, err + } + client := sh.getClient() + ctx := contextWithOutgoingMetadata(ctx, sh.getMetadata()) + res, err := client.BeginTransaction(ctx, &sppb.BeginTransactionRequest{ + Session: sh.getID(), + Options: transactionOptions, + }) + if err == nil { + return res, nil + } + // We should not wait before retrying. + if _, shouldRetry := retryOnNotFound.Retry(err); !shouldRetry { + return nil, err + } + // Delete session and then retry with a new one. + sh.destroy() + } + } + + res, err := beginTxWithRetry(ctx) if err == nil { tx = res.Id if res.ReadTimestamp != nil { diff --git a/spanner/transaction_test.go b/spanner/transaction_test.go index 2976c96b595d..677b1b43113c 100644 --- a/spanner/transaction_test.go +++ b/spanner/transaction_test.go @@ -195,46 +195,64 @@ func TestApply_RetryOnAbort(t *testing.T) { } } -// Tests that NotFound errors cause failures, and aren't retried. +// Tests that NotFound errors cause failures, and aren't retried, except for +// BeginTransaction. func TestTransaction_NotFound(t *testing.T) { t.Parallel() ctx := context.Background() server, client, teardown := setupMockedTestServer(t) defer teardown() - wantErr := spannerErrorf(codes.NotFound, "Session not found") + errSessionNotFound := spannerErrorf(codes.NotFound, "Session not found") + // BeginTransaction should retry automatically. server.TestSpanner.PutExecutionTime(MethodBeginTransaction, SimulatedExecutionTime{ - Errors: []error{wantErr, wantErr, wantErr}, + Errors: []error{errSessionNotFound}, }) - server.TestSpanner.PutExecutionTime(MethodCommitTransaction, - SimulatedExecutionTime{ - Errors: []error{wantErr, wantErr, wantErr}, - }) - txn := client.ReadOnlyTransaction() - defer txn.Close() - - if _, _, got := txn.acquire(ctx); !testEqual(wantErr, got) { - t.Fatalf("Expect acquire to fail, got %v, want %v.", got, wantErr) + if _, _, got := txn.acquire(ctx); got != nil { + t.Fatalf("Expect acquire to succeed, got %v, want nil.", got) } + txn.Close() - // The failure should recycle the session, we expect it to be used in - // following requests. - if got := txn.Query(ctx, NewStatement("SELECT 1")); !testEqual(wantErr, got.err) { - t.Fatalf("Expect Query to fail, got %v, want %v.", got.err, wantErr) + // Query should fail with Session not found. + server.TestSpanner.PutExecutionTime(MethodExecuteStreamingSql, + SimulatedExecutionTime{ + Errors: []error{errSessionNotFound}, + }) + txn = client.ReadOnlyTransaction() + iter := txn.Query(ctx, NewStatement("SELECT 1")) + _, got := iter.Next() + if !testEqual(errSessionNotFound, got) { + t.Fatalf("Expect Query to fail\ngot: %v\nwant: %v", got, errSessionNotFound) } + iter.Stop() - if got := txn.Read(ctx, "Users", KeySets(Key{"alice"}, Key{"bob"}), []string{"name", "email"}); !testEqual(wantErr, got.err) { - t.Fatalf("Expect Read to fail, got %v, want %v.", got.err, wantErr) + // Read should fail with Session not found. + server.TestSpanner.PutExecutionTime(MethodStreamingRead, + SimulatedExecutionTime{ + Errors: []error{errSessionNotFound}, + }) + txn = client.ReadOnlyTransaction() + iter = txn.Read(ctx, "Users", KeySets(Key{"alice"}, Key{"bob"}), []string{"name", "email"}) + _, got = iter.Next() + if !testEqual(errSessionNotFound, got) { + t.Fatalf("Expect Read to fail\ngot: %v\nwant: %v", got, errSessionNotFound) } + iter.Stop() + + // Commit should fail with Session not found. + server.TestSpanner.PutExecutionTime(MethodCommitTransaction, + SimulatedExecutionTime{ + Errors: []error{errSessionNotFound}, + }) ms := []*Mutation{ Insert("Accounts", []string{"AccountId", "Nickname", "Balance"}, []interface{}{int64(1), "Foo", int64(50)}), Insert("Accounts", []string{"AccountId", "Nickname", "Balance"}, []interface{}{int64(2), "Bar", int64(1)}), } - if _, got := client.Apply(ctx, ms, ApplyAtLeastOnce()); !testEqual(wantErr, got) { - t.Fatalf("Expect Apply to fail, got %v, want %v.", got, wantErr) + if _, got := client.Apply(ctx, ms, ApplyAtLeastOnce()); !testEqual(errSessionNotFound, got) { + t.Fatalf("Expect Apply to fail\ngot: %v\nwant: %v", got, errSessionNotFound) } }