From e9a8e3ad04b655dcf202c989f0295892b62e2b37 Mon Sep 17 00:00:00 2001 From: rahul2393 Date: Mon, 23 Dec 2024 11:09:57 +0530 Subject: [PATCH] chore(spanner): track precommit token for R/W multiplexed session (#11229) * chore(spanner): add support for multiplexed session with read write transactions. * fix tests * incorporate changes * disable multiplxed session for ReadWrite only when unimplemented error is because of multiplex from server * re-trigger --- spanner/client.go | 80 +++++--- .../internal/testutil/inmem_spanner_server.go | 35 +++- spanner/kokoro/presubmit.sh | 1 + spanner/read.go | 47 +++-- spanner/session.go | 25 ++- spanner/transaction.go | 27 ++- spanner/transaction_test.go | 179 ++++++++++++++++++ 7 files changed, 333 insertions(+), 61 deletions(-) diff --git a/spanner/client.go b/spanner/client.go index ef53bafdb42c..1ce356266108 100644 --- a/spanner/client.go +++ b/spanner/client.go @@ -107,19 +107,20 @@ func parseDatabaseName(db string) (project, instance, database string, err error // Client is a client for reading and writing data to a Cloud Spanner database. // A client is safe to use concurrently, except for its Close method. type Client struct { - sc *sessionClient - idleSessions *sessionPool - logger *log.Logger - qo QueryOptions - ro ReadOptions - ao []ApplyOption - txo TransactionOptions - bwo BatchWriteOptions - ct *commonTags - disableRouteToLeader bool - dro *sppb.DirectedReadOptions - otConfig *openTelemetryConfig - metricsTracerFactory *builtinMetricsTracerFactory + sc *sessionClient + idleSessions *sessionPool + logger *log.Logger + qo QueryOptions + ro ReadOptions + ao []ApplyOption + txo TransactionOptions + bwo BatchWriteOptions + ct *commonTags + disableRouteToLeader bool + enableMultiplexedSessionForRW bool + dro *sppb.DirectedReadOptions + otConfig *openTelemetryConfig + metricsTracerFactory *builtinMetricsTracerFactory } // DatabaseName returns the full name of a database, e.g., @@ -487,6 +488,21 @@ func newClientWithConfig(ctx context.Context, database string, config ClientConf md.Append(endToEndTracingHeader, "true") } + if isMultiplexed := strings.ToLower(os.Getenv("GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS")); isMultiplexed != "" { + config.SessionPoolConfig.enableMultiplexSession, err = strconv.ParseBool(isMultiplexed) + if err != nil { + return nil, spannerErrorf(codes.InvalidArgument, "GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS must be either true or false") + } + } + //TODO: Uncomment this once the feature is enabled. + //if isMultiplexForRW := os.Getenv("GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS_FOR_RW"); isMultiplexForRW != "" { + // config.enableMultiplexedSessionForRW, err = strconv.ParseBool(isMultiplexForRW) + // if err != nil { + // return nil, spannerErrorf(codes.InvalidArgument, "GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS_FOR_RW must be either true or false") + // } + // config.enableMultiplexedSessionForRW = config.enableMultiplexedSessionForRW && config.SessionPoolConfig.enableMultiplexSession + //} + // Create a session client. sc := newSessionClient(pool, database, config.UserAgent, sessionLabels, config.DatabaseRole, config.DisableRouteToLeader, md, config.BatchTimeout, config.Logger, config.CallOptions) @@ -532,19 +548,20 @@ func newClientWithConfig(ctx context.Context, database string, config ClientConf } c = &Client{ - sc: sc, - idleSessions: sp, - logger: config.Logger, - qo: getQueryOptions(config.QueryOptions), - ro: config.ReadOptions, - ao: config.ApplyOptions, - txo: config.TransactionOptions, - bwo: config.BatchWriteOptions, - ct: getCommonTags(sc), - disableRouteToLeader: config.DisableRouteToLeader, - dro: config.DirectedReadOptions, - otConfig: otConfig, - metricsTracerFactory: metricsTracerFactory, + sc: sc, + idleSessions: sp, + logger: config.Logger, + qo: getQueryOptions(config.QueryOptions), + ro: config.ReadOptions, + ao: config.ApplyOptions, + txo: config.TransactionOptions, + bwo: config.BatchWriteOptions, + ct: getCommonTags(sc), + disableRouteToLeader: config.DisableRouteToLeader, + dro: config.DirectedReadOptions, + otConfig: otConfig, + metricsTracerFactory: metricsTracerFactory, + enableMultiplexedSessionForRW: config.enableMultiplexedSessionForRW, } return c, nil } @@ -1008,8 +1025,12 @@ func (c *Client) rwTransaction(ctx context.Context, f func(context.Context, *Rea err error ) if sh == nil || sh.getID() == "" || sh.getClient() == nil { - // Session handle hasn't been allocated or has been destroyed. - sh, err = c.idleSessions.take(ctx) + if c.enableMultiplexedSessionForRW { + sh, err = c.idleSessions.takeMultiplexed(ctx) + } else { + // Session handle hasn't been allocated or has been destroyed. + sh, err = c.idleSessions.take(ctx) + } if err != nil { // If session retrieval fails, just fail the transaction. return err @@ -1050,6 +1071,9 @@ func (c *Client) rwTransaction(ctx context.Context, f func(context.Context, *Rea resp, err = t.runInTransaction(ctx, f) return err }) + if isUnimplementedErrorForMultiplexedRW(err) { + c.enableMultiplexedSessionForRW = false + } return resp, err } diff --git a/spanner/internal/testutil/inmem_spanner_server.go b/spanner/internal/testutil/inmem_spanner_server.go index 08be3b21742c..86770e4d2948 100644 --- a/spanner/internal/testutil/inmem_spanner_server.go +++ b/spanner/internal/testutil/inmem_spanner_server.go @@ -25,6 +25,7 @@ import ( "sort" "strings" "sync" + "sync/atomic" "time" "cloud.google.com/go/spanner/apiv1/spannerpb" @@ -333,7 +334,8 @@ type inMemSpannerServer struct { // counters. transactionCounters map[string]*uint64 // The transactions that have been created on this mock server. - transactions map[string]*spannerpb.Transaction + transactions map[string]*spannerpb.Transaction + multiplexedSessionTransactionsToSeqNo map[string]*atomic.Int32 // The transactions that have been (manually) aborted on the server. abortedTransactions map[string]bool // The transactions that are marked as PartitionedDMLTransaction @@ -521,11 +523,25 @@ func (s *inMemSpannerServer) initDefaults() { s.sessions = make(map[string]*spannerpb.Session) s.sessionLastUseTime = make(map[string]time.Time) s.transactions = make(map[string]*spannerpb.Transaction) + s.multiplexedSessionTransactionsToSeqNo = make(map[string]*atomic.Int32) s.abortedTransactions = make(map[string]bool) s.partitionedDmlTransactions = make(map[string]bool) s.transactionCounters = make(map[string]*uint64) } +func (s *inMemSpannerServer) getPreCommitToken(transactionID, operation string) *spannerpb.MultiplexedSessionPrecommitToken { + s.mu.Lock() + defer s.mu.Unlock() + sequence, ok := s.multiplexedSessionTransactionsToSeqNo[transactionID] + if !ok { + return nil + } + return &spannerpb.MultiplexedSessionPrecommitToken{ + SeqNum: sequence.Add(1), + PrecommitToken: []byte(fmt.Sprintf("precommit-token-%v-%v", operation, sequence.Load())), + } +} + func (s *inMemSpannerServer) generateSessionNameLocked(database string, isMultiplexed bool) string { s.sessionCounter++ if isMultiplexed { @@ -597,6 +613,9 @@ func (s *inMemSpannerServer) beginTransaction(session *spannerpb.Session, option ReadTimestamp: getCurrentTimestamp(), } s.mu.Lock() + if options.GetReadWrite() != nil && session.Multiplexed { + s.multiplexedSessionTransactionsToSeqNo[id] = new(atomic.Int32) + } s.transactions[id] = res s.partitionedDmlTransactions[id] = options.GetPartitionedDml() != nil s.mu.Unlock() @@ -634,6 +653,7 @@ func (s *inMemSpannerServer) removeTransaction(tx *spannerpb.Transaction) { s.mu.Lock() defer s.mu.Unlock() delete(s.transactions, string(tx.Id)) + delete(s.multiplexedSessionTransactionsToSeqNo, string(tx.Id)) delete(s.partitionedDmlTransactions, string(tx.Id)) } @@ -870,9 +890,16 @@ func (s *inMemSpannerServer) ExecuteSql(ctx context.Context, req *spannerpb.Exec case StatementResultError: return nil, statementResult.Err case StatementResultResultSet: + + // if request's session is multiplexed and transaction is Read/Write then add Pre-commit Token in Metadata + if statementResult.ResultSet != nil { + statementResult.ResultSet.PrecommitToken = s.getPreCommitToken(string(id), "ResultSetPrecommitToken") + } return statementResult.ResultSet, nil case StatementResultUpdateCount: - return statementResult.convertUpdateCountToResultSet(!isPartitionedDml), nil + res := statementResult.convertUpdateCountToResultSet(!isPartitionedDml) + res.PrecommitToken = s.getPreCommitToken(string(id), "ResultSetPrecommitToken") + return res, nil } return nil, gstatus.Error(codes.Internal, "Unknown result type") } @@ -938,6 +965,9 @@ func (s *inMemSpannerServer) executeStreamingSQL(req *spannerpb.ExecuteSqlReques return nextPartialResultSetError.Err } } + // For every PartialResultSet, if request's session is multiplexed and transaction is Read/Write then add Pre-commit Token in Metadata + // and increment the sequence number + part.PrecommitToken = s.getPreCommitToken(string(id), "PartialResultSetPrecommitToken") if err := stream.Send(part); err != nil { return err } @@ -997,6 +1027,7 @@ func (s *inMemSpannerServer) ExecuteBatchDml(ctx context.Context, req *spannerpb resp.ResultSets[idx] = statementResult.convertUpdateCountToResultSet(!isPartitionedDml) } } + resp.PrecommitToken = s.getPreCommitToken(string(id), "ExecuteBatchDmlResponsePrecommitToken") return resp, nil } diff --git a/spanner/kokoro/presubmit.sh b/spanner/kokoro/presubmit.sh index f92440377a06..f5ac7c1b389a 100755 --- a/spanner/kokoro/presubmit.sh +++ b/spanner/kokoro/presubmit.sh @@ -46,6 +46,7 @@ exit_code=0 case $JOB_TYPE in integration-with-multiplexed-session ) GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS=true + GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS_FOR_RW=true echo "running presubmit with multiplexed sessions enabled: $GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS" ;; esac diff --git a/spanner/read.go b/spanner/read.go index 32c3f488050a..3e1ffaaccbc5 100644 --- a/spanner/read.go +++ b/spanner/read.go @@ -68,6 +68,7 @@ func stream( func(err error) error { return err }, + nil, setTimestamp, release, gsc, @@ -85,6 +86,7 @@ func streamWithReplaceSessionFunc( replaceSession func(ctx context.Context) error, setTransactionID func(transactionID), updateTxState func(err error) error, + updatePrecommitToken func(token *sppb.MultiplexedSessionPrecommitToken), setTimestamp func(time.Time), release func(error), gsc *grpcSpannerClient, @@ -92,14 +94,15 @@ func streamWithReplaceSessionFunc( ctx, cancel := context.WithCancel(ctx) ctx = trace.StartSpan(ctx, "cloud.google.com/go/spanner.RowIterator") return &RowIterator{ - meterTracerFactory: meterTracerFactory, - streamd: newResumableStreamDecoder(ctx, logger, rpc, replaceSession, gsc), - rowd: &partialResultSetDecoder{}, - setTransactionID: setTransactionID, - updateTxState: updateTxState, - setTimestamp: setTimestamp, - release: release, - cancel: cancel, + meterTracerFactory: meterTracerFactory, + streamd: newResumableStreamDecoder(ctx, logger, rpc, replaceSession, gsc), + rowd: &partialResultSetDecoder{}, + setTransactionID: setTransactionID, + updatePrecommitToken: updatePrecommitToken, + updateTxState: updateTxState, + setTimestamp: setTimestamp, + release: release, + cancel: cancel, } } @@ -130,18 +133,19 @@ type RowIterator struct { // RowIterator.Next() returned an error that is not equal to iterator.Done. Metadata *sppb.ResultSetMetadata - ctx context.Context - meterTracerFactory *builtinMetricsTracerFactory - streamd *resumableStreamDecoder - rowd *partialResultSetDecoder - setTransactionID func(transactionID) - updateTxState func(err error) error - setTimestamp func(time.Time) - release func(error) - cancel func() - err error - rows []*Row - sawStats bool + ctx context.Context + meterTracerFactory *builtinMetricsTracerFactory + streamd *resumableStreamDecoder + rowd *partialResultSetDecoder + setTransactionID func(transactionID) + updateTxState func(err error) error + updatePrecommitToken func(token *sppb.MultiplexedSessionPrecommitToken) + setTimestamp func(time.Time) + release func(error) + cancel func() + err error + rows []*Row + sawStats bool } // this is for safety from future changes to RowIterator making sure that it implements rowIterator interface. @@ -192,6 +196,9 @@ func (r *RowIterator) Next() (*Row, error) { } r.setTransactionID = nil } + if r.updatePrecommitToken != nil { + r.updatePrecommitToken(prs.GetPrecommitToken()) + } if prs.Stats != nil { r.sawStats = true r.QueryPlan = prs.Stats.QueryPlan diff --git a/spanner/session.go b/spanner/session.go index e07f67dedf3e..2165fdee04d8 100644 --- a/spanner/session.go +++ b/spanner/session.go @@ -24,7 +24,6 @@ import ( "log" "math" "math/rand" - "os" "runtime/debug" "strings" "sync" @@ -507,6 +506,11 @@ type SessionPoolConfig struct { // Defaults to false. TrackSessionHandles bool + enableMultiplexSession bool + + // enableMultiplexedSessionForRW is a flag to enable multiplexed session for read/write transactions, is used in testing + enableMultiplexedSessionForRW bool + // healthCheckSampleInterval is how often the health checker samples live // session (for use in maintaining session pool size). // @@ -699,10 +703,7 @@ func newSessionPool(sc *sessionClient, config SessionPoolConfig) (*sessionPool, if config.MultiplexSessionCheckInterval == 0 { config.MultiplexSessionCheckInterval = 10 * time.Minute } - isMultiplexed := strings.ToLower(os.Getenv("GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS")) - if isMultiplexed != "" && isMultiplexed != "true" && isMultiplexed != "false" { - return nil, spannerErrorf(codes.InvalidArgument, "GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS must be either true or false") - } + pool := &sessionPool{ sc: sc, valid: true, @@ -713,7 +714,7 @@ func newSessionPool(sc *sessionClient, config SessionPoolConfig) (*sessionPool, mw: newMaintenanceWindow(config.MaxOpened), rand: rand.New(rand.NewSource(time.Now().UnixNano())), otConfig: sc.otConfig, - enableMultiplexSession: isMultiplexed == "true", + enableMultiplexSession: config.enableMultiplexSession, } _, instance, database, err := parseDatabaseName(sc.database) @@ -1944,15 +1945,19 @@ func isSessionNotFoundError(err error) bool { return strings.Contains(err.Error(), "Session not found") } -// isUnimplementedError returns true if the gRPC error code is Unimplemented. func isUnimplementedError(err error) bool { if err == nil { return false } - if ErrCode(err) == codes.Unimplemented { - return true + return ErrCode(err) == codes.Unimplemented +} + +// isUnimplementedErrorForMultiplexedRW returns true if the gRPC error code is Unimplemented and related to use of multiplexed session with ReadWrite txn. +func isUnimplementedErrorForMultiplexedRW(err error) bool { + if err == nil { + return false } - return false + return ErrCode(err) == codes.Unimplemented && strings.Contains(err.Error(), "Transaction type read_write not supported with multiplexed sessions") } func isFailedInlineBeginTransaction(err error) bool { diff --git a/spanner/transaction.go b/spanner/transaction.go index f70b29050bb9..dbc1c5e969ef 100644 --- a/spanner/transaction.go +++ b/spanner/transaction.go @@ -50,6 +50,8 @@ type txReadEnv interface { getTransactionSelector() *sppb.TransactionSelector // sets the transactionID setTransactionID(id transactionID) + // updatePrecommitToken updates the precommit token for the transaction + updatePrecommitToken(token *sppb.MultiplexedSessionPrecommitToken) // sets the transaction's read timestamp setTimestamp(time.Time) // release should be called at the end of every transactional read to deal @@ -355,6 +357,7 @@ func (t *txReadOnly) ReadWithOptions(ctx context.Context, table string, keys Key func(err error) error { return t.updateTxState(err) }, + t.updatePrecommitToken, t.setTimestamp, t.release, client.(*grpcSpannerClient), @@ -643,6 +646,7 @@ func (t *txReadOnly) query(ctx context.Context, statement Statement, options Que func(err error) error { return t.updateTxState(err) }, + t.updatePrecommitToken, t.setTimestamp, t.release, client.(*grpcSpannerClient)) @@ -871,6 +875,11 @@ func (t *ReadOnlyTransaction) begin(ctx context.Context) error { return err } +// no-op for ReadOnlyTransaction. +func (t *ReadOnlyTransaction) updatePrecommitToken(token *sppb.MultiplexedSessionPrecommitToken) { + return +} + // acquire implements txReadEnv.acquire. func (t *ReadOnlyTransaction) acquire(ctx context.Context) (*sessionHandle, *sppb.TransactionSelector, error) { if err := checkNestedTxn(ctx); err != nil { @@ -1155,7 +1164,9 @@ type ReadWriteTransaction struct { // tx is the transaction ID in Cloud Spanner that uniquely identifies the // ReadWriteTransaction. It is set only once in ReadWriteTransaction.begin() // during the initialization of ReadWriteTransaction. - tx transactionID + tx transactionID + precommitToken *sppb.MultiplexedSessionPrecommitToken + // txReadyOrClosed is for broadcasting that transaction ID has been returned // by Cloud Spanner or that transaction is closed. txReadyOrClosed chan struct{} @@ -1252,6 +1263,7 @@ func (t *ReadWriteTransaction) update(ctx context.Context, stmt Statement, opts return 0, errInlineBeginTransactionFailed() } } + t.updatePrecommitToken(resultSet.GetPrecommitToken()) if resultSet.Stats == nil { return 0, spannerErrorf(codes.InvalidArgument, "query passed to Update: %q", stmt.SQL) } @@ -1371,6 +1383,7 @@ func (t *ReadWriteTransaction) batchUpdateWithOptions(ctx context.Context, stmts t.setTransactionID(nil) return counts, errInlineBeginTransactionFailed() } + t.updatePrecommitToken(resp.PrecommitToken) if resp.Status != nil && resp.Status.Code != 0 { return counts, t.txReadOnly.updateTxState(spannerError(codes.Code(uint32(resp.Status.Code)), resp.Status.Message)) } @@ -1486,6 +1499,17 @@ func (t *ReadWriteTransaction) setTransactionID(tx transactionID) { t.txReadyOrClosed = make(chan struct{}) } +func (t *ReadWriteTransaction) updatePrecommitToken(token *sppb.MultiplexedSessionPrecommitToken) { + if token == nil { + return + } + t.mu.Lock() + defer t.mu.Unlock() + if t.precommitToken == nil || token.SeqNum > t.precommitToken.SeqNum { + t.precommitToken = token + } +} + // release implements txReadEnv.release. func (t *ReadWriteTransaction) release(err error) { t.mu.Lock() @@ -1676,6 +1700,7 @@ func (t *ReadWriteTransaction) commit(ctx context.Context, options CommitOptions Transaction: &sppb.CommitRequest_TransactionId{ TransactionId: t.tx, }, + PrecommitToken: t.precommitToken, RequestOptions: createRequestOptions(t.txOpts.CommitPriority, "", t.txOpts.TransactionTag), Mutations: mPb, ReturnCommitStats: options.ReturnCommitStats, diff --git a/spanner/transaction_test.go b/spanner/transaction_test.go index 3c12ac7eca3c..1acd7f72f390 100644 --- a/spanner/transaction_test.go +++ b/spanner/transaction_test.go @@ -302,6 +302,185 @@ func TestReadWriteTransaction_ErrorReturned(t *testing.T) { } } +func TestClient_ReadWriteTransaction_UnimplementedErrorWithMultiplexedSessionSwitchesToRegular(t *testing.T) { + ctx := context.Background() + server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{ + DisableNativeMetrics: true, + SessionPoolConfig: SessionPoolConfig{ + MinOpened: 1, + MaxOpened: 1, + enableMultiplexSession: true, + enableMultiplexedSessionForRW: true, + }, + }) + defer teardown() + + for _, sumulatdError := range []error{ + status.Error(codes.Unimplemented, "other Unimplemented error which should not turn off multiplexed session"), + status.Error(codes.Unimplemented, "Transaction type read_write not supported with multiplexed sessions")} { + server.TestSpanner.PutExecutionTime( + MethodExecuteStreamingSql, + SimulatedExecutionTime{Errors: []error{sumulatdError}}, + ) + _, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error { + iter := tx.Query(ctx, NewStatement(SelectFooFromBar)) + defer iter.Stop() + for { + _, err := iter.Next() + if err == iterator.Done { + break + } + if err != nil { + return err + } + } + return nil + }) + requests := drainRequestsFromServer(server.TestSpanner) + foundMultiplexedSession := false + for _, req := range requests { + if sqlReq, ok := req.(*sppb.ExecuteSqlRequest); ok { + if strings.Contains(sqlReq.Session, "multiplexed") { + foundMultiplexedSession = true + break + } + } + } + + // Assert that the error is an Unimplemented error. + if status.Code(err) != codes.Unimplemented { + t.Fatalf("Expected Unimplemented error, got: %v", err) + } + if !foundMultiplexedSession { + t.Fatalf("Expected first transaction to use a multiplexed session, but it did not") + } + server.TestSpanner.Reset() + } + + // Attempt a second read-write transaction. + _, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error { + iter := tx.Query(ctx, NewStatement(SelectFooFromBar)) + defer iter.Stop() + for { + _, err := iter.Next() + if err == iterator.Done { + break + } + if err != nil { + return err + } + } + return nil + }) + + if err != nil { + t.Fatalf("Unexpected error in second transaction: %v", err) + } + + // Check that the second transaction used a regular session. + requests := drainRequestsFromServer(server.TestSpanner) + foundMultiplexedSession := false + for _, req := range requests { + if sqlReq, ok := req.(*sppb.CommitRequest); ok { + if strings.Contains(sqlReq.Session, "multiplexed") { + foundMultiplexedSession = true + break + } + } + } + + if foundMultiplexedSession { + t.Fatalf("Expected second transaction to use a regular session, but it did not") + } +} + +func TestReadWriteTransaction_PrecommitToken(t *testing.T) { + t.Parallel() + ctx := context.Background() + server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{ + DisableNativeMetrics: true, + SessionPoolConfig: SessionPoolConfig{ + MinOpened: 1, + MaxOpened: 1, + enableMultiplexSession: true, + enableMultiplexedSessionForRW: true, + }, + }) + defer teardown() + + type testCase struct { + name string + query bool + update bool + batchUpdate bool + expectedPrecommitToken string + expectedSequenceNumber int32 + } + + testCases := []testCase{ + {"Only Query", true, false, false, "PartialResultSetPrecommitToken", 3}, //since mock server is returning 3 rows + {"Query and Update", true, true, false, "ResultSetPrecommitToken", 4}, + {"Query, Update, and Batch Update", true, true, true, "ExecuteBatchDmlResponsePrecommitToken", 5}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + _, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error { + if tc.query { + iter := tx.Query(ctx, NewStatement(SelectSingerIDAlbumIDAlbumTitleFromAlbums)) + defer iter.Stop() + for { + _, err := iter.Next() + if err == iterator.Done { + break + } + if err != nil { + return err + } + } + } + + if tc.update { + if _, err := tx.Update(ctx, Statement{SQL: UpdateBarSetFoo}); err != nil { + return err + } + } + + if tc.batchUpdate { + if _, err := tx.BatchUpdate(ctx, []Statement{{SQL: UpdateBarSetFoo}}); err != nil { + return err + } + } + + return nil + }) + if err != nil { + t.Fatalf("%s failed: %v", tc.name, err) + } + + requests := drainRequestsFromServer(server.TestSpanner) + var commitReq *sppb.CommitRequest + for _, req := range requests { + if c, ok := req.(*sppb.CommitRequest); ok { + commitReq = c + break + } + } + if commitReq.PrecommitToken == nil || len(commitReq.PrecommitToken.GetPrecommitToken()) == 0 { + t.Fatalf("Expected commit request to contain a valid precommitToken, got: %v", commitReq.PrecommitToken) + } + // Validate that the precommit token contains the test argument. + if !strings.Contains(string(commitReq.PrecommitToken.GetPrecommitToken()), tc.expectedPrecommitToken) { + t.Fatalf("Precommit token does not contain the expected test argument") + } + // Validate that the sequence number is as expected. + if got, want := commitReq.PrecommitToken.GetSeqNum(), tc.expectedSequenceNumber; got != want { + t.Fatalf("Precommit token sequence number mismatch: got %d, want %d", got, want) + } + }) + } +} + func TestBatchDML_WithMultipleDML(t *testing.T) { t.Parallel() ctx := context.Background()