diff --git a/spanner/client.go b/spanner/client.go index a95b02431ad4..bfe00c2dcec0 100644 --- a/spanner/client.go +++ b/spanner/client.go @@ -563,6 +563,10 @@ func (c *Client) rwTransaction(ctx context.Context, f func(context.Context, *Rea } } if t.shouldExplicitBegin(attempt) { + // Make sure we set the current session handle before calling BeginTransaction. + // Note that the t.begin(ctx) call could change the session that is being used by the transaction, as the + // BeginTransaction RPC invocation will be retried on a new session if it returns SessionNotFound. + t.txReadOnly.sh = sh if err = t.begin(ctx); err != nil { trace.TracePrintf(ctx, nil, "Error while BeginTransaction during retrying a ReadWrite transaction: %v", ToSpannerError(err)) return ToSpannerError(err) @@ -571,9 +575,9 @@ func (c *Client) rwTransaction(ctx context.Context, f func(context.Context, *Rea t = &ReadWriteTransaction{ txReadyOrClosed: make(chan struct{}), } + t.txReadOnly.sh = sh } attempt++ - t.txReadOnly.sh = sh t.txReadOnly.sp = c.idleSessions t.txReadOnly.txReadEnv = t t.txReadOnly.qo = c.qo diff --git a/spanner/client_test.go b/spanner/client_test.go index 6f963102805e..1141a6ebb5bc 100644 --- a/spanner/client_test.go +++ b/spanner/client_test.go @@ -727,6 +727,202 @@ func TestClient_ReadOnlyTransaction_SessionNotFoundOnBeginTransaction_WithMaxOne } } +func TestClient_ReadWriteTransaction_SessionNotFoundForFirstStatement(t *testing.T) { + ctx := context.Background() + server, client, teardown := setupMockedTestServer(t) + defer teardown() + server.TestSpanner.PutExecutionTime( + MethodExecuteStreamingSql, + SimulatedExecutionTime{Errors: []error{newSessionNotFoundError("projects/p/instances/i/databases/d/sessions/s")}}, + ) + + expectedAttempts := 2 + var attempts int + _, err := client.ReadWriteTransaction( + ctx, func(ctx context.Context, tx *ReadWriteTransaction) error { + attempts++ + iter := tx.Query(ctx, NewStatement(SelectFooFromBar)) + defer iter.Stop() + for { + _, err := iter.Next() + if err == iterator.Done { + break + } + if err != nil { + return err + } + } + return nil + }) + if err != nil { + t.Fatal(err) + } + if expectedAttempts != attempts { + t.Fatalf("unexpected number of attempts: %d, expected %d", attempts, expectedAttempts) + } + requests := drainRequestsFromServer(server.TestSpanner) + if err := compareRequests([]interface{}{ + &sppb.BatchCreateSessionsRequest{}, + &sppb.ExecuteSqlRequest{}, + &sppb.BeginTransactionRequest{}, + &sppb.ExecuteSqlRequest{}, + &sppb.CommitRequest{}, + }, requests); err != nil { + t.Fatal(err) + } +} + +func TestClient_ReadWriteTransaction_SessionNotFoundForFirstStatement_AndThenSessionNotFoundForBeginTransaction(t *testing.T) { + ctx := context.Background() + server, client, teardown := setupMockedTestServer(t) + defer teardown() + server.TestSpanner.PutExecutionTime( + MethodExecuteStreamingSql, + SimulatedExecutionTime{Errors: []error{newSessionNotFoundError("projects/p/instances/i/databases/d/sessions/s")}}, + ) + server.TestSpanner.PutExecutionTime( + MethodBeginTransaction, + SimulatedExecutionTime{Errors: []error{newSessionNotFoundError("projects/p/instances/i/databases/d/sessions/s")}}, + ) + + expectedAttempts := 2 + var attempts int + _, err := client.ReadWriteTransaction( + ctx, func(ctx context.Context, tx *ReadWriteTransaction) error { + attempts++ + iter := tx.Query(ctx, NewStatement(SelectFooFromBar)) + defer iter.Stop() + for { + _, err := iter.Next() + if err == iterator.Done { + break + } + if err != nil { + return err + } + } + return nil + }) + if err != nil { + t.Fatal(err) + } + if expectedAttempts != attempts { + t.Fatalf("unexpected number of attempts: %d, expected %d", attempts, expectedAttempts) + } + requests := drainRequestsFromServer(server.TestSpanner) + if err := compareRequests([]interface{}{ + &sppb.BatchCreateSessionsRequest{}, + &sppb.ExecuteSqlRequest{}, + &sppb.BeginTransactionRequest{}, + &sppb.BeginTransactionRequest{}, + &sppb.ExecuteSqlRequest{}, + &sppb.CommitRequest{}, + }, requests); err != nil { + t.Fatal(err) + } +} + +func TestClient_ReadWriteTransaction_AbortedForFirstStatement_AndThenSessionNotFoundForBeginTransaction(t *testing.T) { + ctx := context.Background() + server, client, teardown := setupMockedTestServer(t) + defer teardown() + server.TestSpanner.PutExecutionTime( + MethodExecuteStreamingSql, + SimulatedExecutionTime{Errors: []error{status.Error(codes.Aborted, "Transaction aborted")}}, + ) + server.TestSpanner.PutExecutionTime( + MethodBeginTransaction, + SimulatedExecutionTime{Errors: []error{newSessionNotFoundError("projects/p/instances/i/databases/d/sessions/s")}}, + ) + + expectedAttempts := 2 + var attempts int + _, err := client.ReadWriteTransaction( + ctx, func(ctx context.Context, tx *ReadWriteTransaction) error { + attempts++ + iter := tx.Query(ctx, NewStatement(SelectFooFromBar)) + defer iter.Stop() + for { + _, err := iter.Next() + if err == iterator.Done { + break + } + if err != nil { + return err + } + } + return nil + }) + if err != nil { + t.Fatal(err) + } + if expectedAttempts != attempts { + t.Fatalf("unexpected number of attempts: %d, expected %d", attempts, expectedAttempts) + } + requests := drainRequestsFromServer(server.TestSpanner) + if err := compareRequests([]interface{}{ + &sppb.BatchCreateSessionsRequest{}, + &sppb.ExecuteSqlRequest{}, + &sppb.BeginTransactionRequest{}, + &sppb.BeginTransactionRequest{}, + &sppb.ExecuteSqlRequest{}, + &sppb.CommitRequest{}, + }, requests); err != nil { + t.Fatal(err) + } +} + +func TestClient_ReadWriteTransaction_SessionNotFoundForFirstStatement_DoesNotLeakSession(t *testing.T) { + ctx := context.Background() + server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{ + SessionPoolConfig: SessionPoolConfig{ + MinOpened: 1, + MaxOpened: 1, + }, + }) + defer teardown() + server.TestSpanner.PutExecutionTime( + MethodExecuteStreamingSql, + SimulatedExecutionTime{Errors: []error{newSessionNotFoundError("projects/p/instances/i/databases/d/sessions/s")}}, + ) + + expectedAttempts := 2 + var attempts int + _, err := client.ReadWriteTransaction( + ctx, func(ctx context.Context, tx *ReadWriteTransaction) error { + attempts++ + iter := tx.Query(ctx, NewStatement(SelectFooFromBar)) + defer iter.Stop() + for { + _, err := iter.Next() + if err == iterator.Done { + break + } + if err != nil { + return err + } + } + return nil + }) + if err != nil { + t.Fatal(err) + } + if expectedAttempts != attempts { + t.Fatalf("unexpected number of attempts: %d, expected %d", attempts, expectedAttempts) + } + requests := drainRequestsFromServer(server.TestSpanner) + if err := compareRequests([]interface{}{ + &sppb.BatchCreateSessionsRequest{}, + &sppb.ExecuteSqlRequest{}, + &sppb.BatchCreateSessionsRequest{}, // We need to create more sessions, as the one used first was destroyed. + &sppb.BeginTransactionRequest{}, + &sppb.ExecuteSqlRequest{}, + &sppb.CommitRequest{}, + }, requests); err != nil { + t.Fatal(err) + } +} + func TestClient_ReadOnlyTransaction_QueryOptions(t *testing.T) { for _, tt := range queryOptionsTestCases() { t.Run(tt.name, func(t *testing.T) { diff --git a/spanner/integration_test.go b/spanner/integration_test.go index 7aa7d622df66..5d016ede945d 100644 --- a/spanner/integration_test.go +++ b/spanner/integration_test.go @@ -42,12 +42,14 @@ import ( adminpb "cloud.google.com/go/spanner/admin/database/apiv1/databasepb" instance "cloud.google.com/go/spanner/admin/instance/apiv1" "cloud.google.com/go/spanner/admin/instance/apiv1/instancepb" + v1 "cloud.google.com/go/spanner/apiv1" sppb "cloud.google.com/go/spanner/apiv1/spannerpb" "cloud.google.com/go/spanner/internal" "go.opencensus.io/stats/view" "go.opencensus.io/tag" "google.golang.org/api/iterator" "google.golang.org/api/option" + "google.golang.org/api/option/internaloption" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/peer" @@ -846,6 +848,55 @@ func TestIntegration_SingleUse_WithQueryOptions(t *testing.T) { } } +func TestIntegration_TransactionWasStartedInDifferentSession(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) + defer cancel() + // Set up testing environment. + client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, statements[testDialect][singerDDLStatements]) + defer cleanup() + + attempts := 0 + _, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, transaction *ReadWriteTransaction) error { + attempts++ + if attempts == 1 { + deleteTestSession(ctx, t, transaction.sh.getID()) + } + if _, err := readAll(transaction.Query(ctx, NewStatement("select * from singers"))); err != nil { + return err + } + return nil + }) + if err != nil { + t.Fatal(err) + } + if g, w := attempts, 2; g != w { + t.Fatalf("attempts mismatch\nGot: %v\nWant: %v", g, w) + } +} + +func deleteTestSession(ctx context.Context, t *testing.T, sessionName string) { + var opts []option.ClientOption + if emulatorAddr := os.Getenv("SPANNER_EMULATOR_HOST"); emulatorAddr != "" { + emulatorOpts := []option.ClientOption{ + option.WithEndpoint(emulatorAddr), + option.WithGRPCDialOption(grpc.WithInsecure()), + option.WithoutAuthentication(), + internaloption.SkipDialSettingsValidation(), + } + opts = append(emulatorOpts, opts...) + } + gapic, err := v1.NewClient(ctx, opts...) + if err != nil { + t.Fatalf("could not create gapic client: %v", err) + } + defer gapic.Close() + if err := gapic.DeleteSession(ctx, &sppb.DeleteSessionRequest{Name: sessionName}); err != nil { + t.Fatal(err) + } +} + func TestIntegration_SingleUse_ReadingWithLimit(t *testing.T) { t.Parallel() diff --git a/spanner/internal/testutil/inmem_spanner_server.go b/spanner/internal/testutil/inmem_spanner_server.go index 922ae6ad1328..b1adf02f2182 100644 --- a/spanner/internal/testutil/inmem_spanner_server.go +++ b/spanner/internal/testutil/inmem_spanner_server.go @@ -581,13 +581,17 @@ func (s *inMemSpannerServer) beginTransaction(session *spannerpb.Session, option return res } -func (s *inMemSpannerServer) getTransactionByID(id []byte) (*spannerpb.Transaction, error) { +func (s *inMemSpannerServer) getTransactionByID(session *spannerpb.Session, id []byte) (*spannerpb.Transaction, error) { s.mu.Lock() defer s.mu.Unlock() tx, ok := s.transactions[string(id)] if !ok { return nil, gstatus.Error(codes.NotFound, "Transaction not found") } + if !strings.HasPrefix(string(id), session.Name) { + return nil, gstatus.Error(codes.InvalidArgument, "Transaction was started in a different session.") + } + aborted, ok := s.abortedTransactions[string(id)] if ok && aborted { return nil, newAbortedErrorWithMinimalRetryDelay() @@ -813,7 +817,7 @@ func (s *inMemSpannerServer) ExecuteSql(ctx context.Context, req *spannerpb.Exec var id []byte s.updateSessionLastUseTime(session.Name) if id = s.getTransactionID(session, req.Transaction); id != nil { - _, err = s.getTransactionByID(id) + _, err = s.getTransactionByID(session, id) if err != nil { return nil, err } @@ -860,7 +864,7 @@ func (s *inMemSpannerServer) executeStreamingSQL(req *spannerpb.ExecuteSqlReques s.updateSessionLastUseTime(session.Name) var id []byte if id = s.getTransactionID(session, req.Transaction); id != nil { - _, err = s.getTransactionByID(id) + _, err = s.getTransactionByID(session, id) if err != nil { return err } @@ -932,7 +936,7 @@ func (s *inMemSpannerServer) ExecuteBatchDml(ctx context.Context, req *spannerpb s.updateSessionLastUseTime(session.Name) var id []byte if id = s.getTransactionID(session, req.Transaction); id != nil { - _, err = s.getTransactionByID(id) + _, err = s.getTransactionByID(session, id) if err != nil { return nil, err } @@ -1031,7 +1035,7 @@ func (s *inMemSpannerServer) Commit(ctx context.Context, req *spannerpb.CommitRe if req.GetSingleUseTransaction() != nil { tx = s.beginTransaction(session, req.GetSingleUseTransaction()) } else if req.GetTransactionId() != nil { - tx, err = s.getTransactionByID(req.GetTransactionId()) + tx, err = s.getTransactionByID(session, req.GetTransactionId()) if err != nil { return nil, err } @@ -1064,7 +1068,7 @@ func (s *inMemSpannerServer) Rollback(ctx context.Context, req *spannerpb.Rollba return nil, err } s.updateSessionLastUseTime(session.Name) - tx, err := s.getTransactionByID(req.TransactionId) + tx, err := s.getTransactionByID(session, req.TransactionId) if err != nil { return nil, err } @@ -1091,7 +1095,7 @@ func (s *inMemSpannerServer) PartitionQuery(ctx context.Context, req *spannerpb. var tx *spannerpb.Transaction s.updateSessionLastUseTime(session.Name) if id = s.getTransactionID(session, req.Transaction); id != nil { - tx, err = s.getTransactionByID(id) + tx, err = s.getTransactionByID(session, id) if err != nil { return nil, err } diff --git a/spanner/transaction.go b/spanner/transaction.go index 85de18327f89..81d7e036521c 100644 --- a/spanner/transaction.go +++ b/spanner/transaction.go @@ -1380,15 +1380,13 @@ func (t *ReadWriteTransaction) begin(ctx context.Context) error { }() // Retry the BeginTransaction call if a 'Session not found' is returned. for { - if sh == nil || sh.getID() == "" || sh.getClient() == nil { + tx, err = beginTransaction(contextWithOutgoingMetadata(ctx, sh.getMetadata(), t.disableRouteToLeader), sh.getID(), sh.getClient(), t.txOpts) + if isSessionNotFoundError(err) { + sh.destroy() sh, err = t.sp.take(ctx) if err != nil { return err } - } - tx, err = beginTransaction(contextWithOutgoingMetadata(ctx, sh.getMetadata(), t.disableRouteToLeader), sh.getID(), sh.getClient(), t.txOpts) - if isSessionNotFoundError(err) { - sh.destroy() continue } else { err = ToSpannerError(err) @@ -1399,7 +1397,7 @@ func (t *ReadWriteTransaction) begin(ctx context.Context) error { t.mu.Lock() t.tx = tx t.sh = sh - // State transite to txActive. + // Transition state to txActive. t.state = txActive t.mu.Unlock() }