Skip to content

Commit

Permalink
chore(spanner): track precommit token for R/W multiplexed session (#1…
Browse files Browse the repository at this point in the history
…1229)

* chore(spanner): add support for multiplexed session with read write transactions.

* fix tests

* incorporate changes

* disable multiplxed session for ReadWrite only when unimplemented error is because of multiplex from server

* re-trigger
  • Loading branch information
rahul2393 authored Dec 23, 2024
1 parent 7cbffad commit e9a8e3a
Show file tree
Hide file tree
Showing 7 changed files with 333 additions and 61 deletions.
80 changes: 52 additions & 28 deletions spanner/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,19 +107,20 @@ func parseDatabaseName(db string) (project, instance, database string, err error
// Client is a client for reading and writing data to a Cloud Spanner database.
// A client is safe to use concurrently, except for its Close method.
type Client struct {
sc *sessionClient
idleSessions *sessionPool
logger *log.Logger
qo QueryOptions
ro ReadOptions
ao []ApplyOption
txo TransactionOptions
bwo BatchWriteOptions
ct *commonTags
disableRouteToLeader bool
dro *sppb.DirectedReadOptions
otConfig *openTelemetryConfig
metricsTracerFactory *builtinMetricsTracerFactory
sc *sessionClient
idleSessions *sessionPool
logger *log.Logger
qo QueryOptions
ro ReadOptions
ao []ApplyOption
txo TransactionOptions
bwo BatchWriteOptions
ct *commonTags
disableRouteToLeader bool
enableMultiplexedSessionForRW bool
dro *sppb.DirectedReadOptions
otConfig *openTelemetryConfig
metricsTracerFactory *builtinMetricsTracerFactory
}

// DatabaseName returns the full name of a database, e.g.,
Expand Down Expand Up @@ -487,6 +488,21 @@ func newClientWithConfig(ctx context.Context, database string, config ClientConf
md.Append(endToEndTracingHeader, "true")
}

if isMultiplexed := strings.ToLower(os.Getenv("GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS")); isMultiplexed != "" {
config.SessionPoolConfig.enableMultiplexSession, err = strconv.ParseBool(isMultiplexed)
if err != nil {
return nil, spannerErrorf(codes.InvalidArgument, "GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS must be either true or false")
}
}
//TODO: Uncomment this once the feature is enabled.
//if isMultiplexForRW := os.Getenv("GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS_FOR_RW"); isMultiplexForRW != "" {
// config.enableMultiplexedSessionForRW, err = strconv.ParseBool(isMultiplexForRW)
// if err != nil {
// return nil, spannerErrorf(codes.InvalidArgument, "GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS_FOR_RW must be either true or false")
// }
// config.enableMultiplexedSessionForRW = config.enableMultiplexedSessionForRW && config.SessionPoolConfig.enableMultiplexSession
//}

// Create a session client.
sc := newSessionClient(pool, database, config.UserAgent, sessionLabels, config.DatabaseRole, config.DisableRouteToLeader, md, config.BatchTimeout, config.Logger, config.CallOptions)

Expand Down Expand Up @@ -532,19 +548,20 @@ func newClientWithConfig(ctx context.Context, database string, config ClientConf
}

c = &Client{
sc: sc,
idleSessions: sp,
logger: config.Logger,
qo: getQueryOptions(config.QueryOptions),
ro: config.ReadOptions,
ao: config.ApplyOptions,
txo: config.TransactionOptions,
bwo: config.BatchWriteOptions,
ct: getCommonTags(sc),
disableRouteToLeader: config.DisableRouteToLeader,
dro: config.DirectedReadOptions,
otConfig: otConfig,
metricsTracerFactory: metricsTracerFactory,
sc: sc,
idleSessions: sp,
logger: config.Logger,
qo: getQueryOptions(config.QueryOptions),
ro: config.ReadOptions,
ao: config.ApplyOptions,
txo: config.TransactionOptions,
bwo: config.BatchWriteOptions,
ct: getCommonTags(sc),
disableRouteToLeader: config.DisableRouteToLeader,
dro: config.DirectedReadOptions,
otConfig: otConfig,
metricsTracerFactory: metricsTracerFactory,
enableMultiplexedSessionForRW: config.enableMultiplexedSessionForRW,
}
return c, nil
}
Expand Down Expand Up @@ -1008,8 +1025,12 @@ func (c *Client) rwTransaction(ctx context.Context, f func(context.Context, *Rea
err error
)
if sh == nil || sh.getID() == "" || sh.getClient() == nil {
// Session handle hasn't been allocated or has been destroyed.
sh, err = c.idleSessions.take(ctx)
if c.enableMultiplexedSessionForRW {
sh, err = c.idleSessions.takeMultiplexed(ctx)
} else {
// Session handle hasn't been allocated or has been destroyed.
sh, err = c.idleSessions.take(ctx)
}
if err != nil {
// If session retrieval fails, just fail the transaction.
return err
Expand Down Expand Up @@ -1050,6 +1071,9 @@ func (c *Client) rwTransaction(ctx context.Context, f func(context.Context, *Rea
resp, err = t.runInTransaction(ctx, f)
return err
})
if isUnimplementedErrorForMultiplexedRW(err) {
c.enableMultiplexedSessionForRW = false
}
return resp, err
}

Expand Down
35 changes: 33 additions & 2 deletions spanner/internal/testutil/inmem_spanner_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"sort"
"strings"
"sync"
"sync/atomic"
"time"

"cloud.google.com/go/spanner/apiv1/spannerpb"
Expand Down Expand Up @@ -333,7 +334,8 @@ type inMemSpannerServer struct {
// counters.
transactionCounters map[string]*uint64
// The transactions that have been created on this mock server.
transactions map[string]*spannerpb.Transaction
transactions map[string]*spannerpb.Transaction
multiplexedSessionTransactionsToSeqNo map[string]*atomic.Int32
// The transactions that have been (manually) aborted on the server.
abortedTransactions map[string]bool
// The transactions that are marked as PartitionedDMLTransaction
Expand Down Expand Up @@ -521,11 +523,25 @@ func (s *inMemSpannerServer) initDefaults() {
s.sessions = make(map[string]*spannerpb.Session)
s.sessionLastUseTime = make(map[string]time.Time)
s.transactions = make(map[string]*spannerpb.Transaction)
s.multiplexedSessionTransactionsToSeqNo = make(map[string]*atomic.Int32)
s.abortedTransactions = make(map[string]bool)
s.partitionedDmlTransactions = make(map[string]bool)
s.transactionCounters = make(map[string]*uint64)
}

func (s *inMemSpannerServer) getPreCommitToken(transactionID, operation string) *spannerpb.MultiplexedSessionPrecommitToken {
s.mu.Lock()
defer s.mu.Unlock()
sequence, ok := s.multiplexedSessionTransactionsToSeqNo[transactionID]
if !ok {
return nil
}
return &spannerpb.MultiplexedSessionPrecommitToken{
SeqNum: sequence.Add(1),
PrecommitToken: []byte(fmt.Sprintf("precommit-token-%v-%v", operation, sequence.Load())),
}
}

func (s *inMemSpannerServer) generateSessionNameLocked(database string, isMultiplexed bool) string {
s.sessionCounter++
if isMultiplexed {
Expand Down Expand Up @@ -597,6 +613,9 @@ func (s *inMemSpannerServer) beginTransaction(session *spannerpb.Session, option
ReadTimestamp: getCurrentTimestamp(),
}
s.mu.Lock()
if options.GetReadWrite() != nil && session.Multiplexed {
s.multiplexedSessionTransactionsToSeqNo[id] = new(atomic.Int32)
}
s.transactions[id] = res
s.partitionedDmlTransactions[id] = options.GetPartitionedDml() != nil
s.mu.Unlock()
Expand Down Expand Up @@ -634,6 +653,7 @@ func (s *inMemSpannerServer) removeTransaction(tx *spannerpb.Transaction) {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.transactions, string(tx.Id))
delete(s.multiplexedSessionTransactionsToSeqNo, string(tx.Id))
delete(s.partitionedDmlTransactions, string(tx.Id))
}

Expand Down Expand Up @@ -870,9 +890,16 @@ func (s *inMemSpannerServer) ExecuteSql(ctx context.Context, req *spannerpb.Exec
case StatementResultError:
return nil, statementResult.Err
case StatementResultResultSet:

// if request's session is multiplexed and transaction is Read/Write then add Pre-commit Token in Metadata
if statementResult.ResultSet != nil {
statementResult.ResultSet.PrecommitToken = s.getPreCommitToken(string(id), "ResultSetPrecommitToken")
}
return statementResult.ResultSet, nil
case StatementResultUpdateCount:
return statementResult.convertUpdateCountToResultSet(!isPartitionedDml), nil
res := statementResult.convertUpdateCountToResultSet(!isPartitionedDml)
res.PrecommitToken = s.getPreCommitToken(string(id), "ResultSetPrecommitToken")
return res, nil
}
return nil, gstatus.Error(codes.Internal, "Unknown result type")
}
Expand Down Expand Up @@ -938,6 +965,9 @@ func (s *inMemSpannerServer) executeStreamingSQL(req *spannerpb.ExecuteSqlReques
return nextPartialResultSetError.Err
}
}
// For every PartialResultSet, if request's session is multiplexed and transaction is Read/Write then add Pre-commit Token in Metadata
// and increment the sequence number
part.PrecommitToken = s.getPreCommitToken(string(id), "PartialResultSetPrecommitToken")
if err := stream.Send(part); err != nil {
return err
}
Expand Down Expand Up @@ -997,6 +1027,7 @@ func (s *inMemSpannerServer) ExecuteBatchDml(ctx context.Context, req *spannerpb
resp.ResultSets[idx] = statementResult.convertUpdateCountToResultSet(!isPartitionedDml)
}
}
resp.PrecommitToken = s.getPreCommitToken(string(id), "ExecuteBatchDmlResponsePrecommitToken")
return resp, nil
}

Expand Down
1 change: 1 addition & 0 deletions spanner/kokoro/presubmit.sh
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ exit_code=0
case $JOB_TYPE in
integration-with-multiplexed-session )
GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS=true
GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS_FOR_RW=true
echo "running presubmit with multiplexed sessions enabled: $GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS"
;;
esac
Expand Down
47 changes: 27 additions & 20 deletions spanner/read.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ func stream(
func(err error) error {
return err
},
nil,
setTimestamp,
release,
gsc,
Expand All @@ -85,21 +86,23 @@ func streamWithReplaceSessionFunc(
replaceSession func(ctx context.Context) error,
setTransactionID func(transactionID),
updateTxState func(err error) error,
updatePrecommitToken func(token *sppb.MultiplexedSessionPrecommitToken),
setTimestamp func(time.Time),
release func(error),
gsc *grpcSpannerClient,
) *RowIterator {
ctx, cancel := context.WithCancel(ctx)
ctx = trace.StartSpan(ctx, "cloud.google.com/go/spanner.RowIterator")
return &RowIterator{
meterTracerFactory: meterTracerFactory,
streamd: newResumableStreamDecoder(ctx, logger, rpc, replaceSession, gsc),
rowd: &partialResultSetDecoder{},
setTransactionID: setTransactionID,
updateTxState: updateTxState,
setTimestamp: setTimestamp,
release: release,
cancel: cancel,
meterTracerFactory: meterTracerFactory,
streamd: newResumableStreamDecoder(ctx, logger, rpc, replaceSession, gsc),
rowd: &partialResultSetDecoder{},
setTransactionID: setTransactionID,
updatePrecommitToken: updatePrecommitToken,
updateTxState: updateTxState,
setTimestamp: setTimestamp,
release: release,
cancel: cancel,
}
}

Expand Down Expand Up @@ -130,18 +133,19 @@ type RowIterator struct {
// RowIterator.Next() returned an error that is not equal to iterator.Done.
Metadata *sppb.ResultSetMetadata

ctx context.Context
meterTracerFactory *builtinMetricsTracerFactory
streamd *resumableStreamDecoder
rowd *partialResultSetDecoder
setTransactionID func(transactionID)
updateTxState func(err error) error
setTimestamp func(time.Time)
release func(error)
cancel func()
err error
rows []*Row
sawStats bool
ctx context.Context
meterTracerFactory *builtinMetricsTracerFactory
streamd *resumableStreamDecoder
rowd *partialResultSetDecoder
setTransactionID func(transactionID)
updateTxState func(err error) error
updatePrecommitToken func(token *sppb.MultiplexedSessionPrecommitToken)
setTimestamp func(time.Time)
release func(error)
cancel func()
err error
rows []*Row
sawStats bool
}

// this is for safety from future changes to RowIterator making sure that it implements rowIterator interface.
Expand Down Expand Up @@ -192,6 +196,9 @@ func (r *RowIterator) Next() (*Row, error) {
}
r.setTransactionID = nil
}
if r.updatePrecommitToken != nil {
r.updatePrecommitToken(prs.GetPrecommitToken())
}
if prs.Stats != nil {
r.sawStats = true
r.QueryPlan = prs.Stats.QueryPlan
Expand Down
25 changes: 15 additions & 10 deletions spanner/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ import (
"log"
"math"
"math/rand"
"os"
"runtime/debug"
"strings"
"sync"
Expand Down Expand Up @@ -507,6 +506,11 @@ type SessionPoolConfig struct {
// Defaults to false.
TrackSessionHandles bool

enableMultiplexSession bool

// enableMultiplexedSessionForRW is a flag to enable multiplexed session for read/write transactions, is used in testing
enableMultiplexedSessionForRW bool

// healthCheckSampleInterval is how often the health checker samples live
// session (for use in maintaining session pool size).
//
Expand Down Expand Up @@ -699,10 +703,7 @@ func newSessionPool(sc *sessionClient, config SessionPoolConfig) (*sessionPool,
if config.MultiplexSessionCheckInterval == 0 {
config.MultiplexSessionCheckInterval = 10 * time.Minute
}
isMultiplexed := strings.ToLower(os.Getenv("GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS"))
if isMultiplexed != "" && isMultiplexed != "true" && isMultiplexed != "false" {
return nil, spannerErrorf(codes.InvalidArgument, "GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS must be either true or false")
}

pool := &sessionPool{
sc: sc,
valid: true,
Expand All @@ -713,7 +714,7 @@ func newSessionPool(sc *sessionClient, config SessionPoolConfig) (*sessionPool,
mw: newMaintenanceWindow(config.MaxOpened),
rand: rand.New(rand.NewSource(time.Now().UnixNano())),
otConfig: sc.otConfig,
enableMultiplexSession: isMultiplexed == "true",
enableMultiplexSession: config.enableMultiplexSession,
}

_, instance, database, err := parseDatabaseName(sc.database)
Expand Down Expand Up @@ -1944,15 +1945,19 @@ func isSessionNotFoundError(err error) bool {
return strings.Contains(err.Error(), "Session not found")
}

// isUnimplementedError returns true if the gRPC error code is Unimplemented.
func isUnimplementedError(err error) bool {
if err == nil {
return false
}
if ErrCode(err) == codes.Unimplemented {
return true
return ErrCode(err) == codes.Unimplemented
}

// isUnimplementedErrorForMultiplexedRW returns true if the gRPC error code is Unimplemented and related to use of multiplexed session with ReadWrite txn.
func isUnimplementedErrorForMultiplexedRW(err error) bool {
if err == nil {
return false
}
return false
return ErrCode(err) == codes.Unimplemented && strings.Contains(err.Error(), "Transaction type read_write not supported with multiplexed sessions")
}

func isFailedInlineBeginTransaction(err error) bool {
Expand Down
Loading

0 comments on commit e9a8e3a

Please sign in to comment.