Skip to content

Commit

Permalink
chore(spanner): add support for multiplexed session with read write t…
Browse files Browse the repository at this point in the history
…ransactions.
  • Loading branch information
rahul2393 committed Dec 18, 2024
1 parent d448fbb commit f692dbf
Show file tree
Hide file tree
Showing 7 changed files with 220 additions and 52 deletions.
69 changes: 41 additions & 28 deletions spanner/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,19 +107,20 @@ func parseDatabaseName(db string) (project, instance, database string, err error
// Client is a client for reading and writing data to a Cloud Spanner database.
// A client is safe to use concurrently, except for its Close method.
type Client struct {
sc *sessionClient
idleSessions *sessionPool
logger *log.Logger
qo QueryOptions
ro ReadOptions
ao []ApplyOption
txo TransactionOptions
bwo BatchWriteOptions
ct *commonTags
disableRouteToLeader bool
dro *sppb.DirectedReadOptions
otConfig *openTelemetryConfig
metricsTracerFactory *builtinMetricsTracerFactory
sc *sessionClient
idleSessions *sessionPool
logger *log.Logger
qo QueryOptions
ro ReadOptions
ao []ApplyOption
txo TransactionOptions
bwo BatchWriteOptions
ct *commonTags
disableRouteToLeader bool
enableMultiplexSessionForRW bool
dro *sppb.DirectedReadOptions
otConfig *openTelemetryConfig
metricsTracerFactory *builtinMetricsTracerFactory
}

// DatabaseName returns the full name of a database, e.g.,
Expand Down Expand Up @@ -478,6 +479,13 @@ func newClientWithConfig(ctx context.Context, database string, config ClientConf
if config.EnableEndToEndTracing || endToEndTracingEnvironmentVariable == "true" {
md.Append(endToEndTracingHeader, "true")
}
//TODO: Uncomment this once the feature is enabled.
//if isMultiplexForRW := os.Getenv("GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS_FOR_RW"); isMultiplexForRW != "" {
// config.enableMultiplexSessionForRW, err = strconv.ParseBool(isMultiplexForRW)
// if err != nil {
// return nil, spannerErrorf(codes.InvalidArgument, "GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS_FOR_RW must be either true or false")
// }
//}

// Create a session client.
sc := newSessionClient(pool, database, config.UserAgent, sessionLabels, config.DatabaseRole, config.DisableRouteToLeader, md, config.BatchTimeout, config.Logger, config.CallOptions)
Expand Down Expand Up @@ -524,19 +532,20 @@ func newClientWithConfig(ctx context.Context, database string, config ClientConf
}

c = &Client{
sc: sc,
idleSessions: sp,
logger: config.Logger,
qo: getQueryOptions(config.QueryOptions),
ro: config.ReadOptions,
ao: config.ApplyOptions,
txo: config.TransactionOptions,
bwo: config.BatchWriteOptions,
ct: getCommonTags(sc),
disableRouteToLeader: config.DisableRouteToLeader,
dro: config.DirectedReadOptions,
otConfig: otConfig,
metricsTracerFactory: metricsTracerFactory,
sc: sc,
idleSessions: sp,
logger: config.Logger,
qo: getQueryOptions(config.QueryOptions),
ro: config.ReadOptions,
ao: config.ApplyOptions,
txo: config.TransactionOptions,
bwo: config.BatchWriteOptions,
ct: getCommonTags(sc),
disableRouteToLeader: config.DisableRouteToLeader,
dro: config.DirectedReadOptions,
otConfig: otConfig,
metricsTracerFactory: metricsTracerFactory,
enableMultiplexSessionForRW: config.enableMultiplexSessionForRW,
}
return c, nil
}
Expand Down Expand Up @@ -1000,8 +1009,12 @@ func (c *Client) rwTransaction(ctx context.Context, f func(context.Context, *Rea
err error
)
if sh == nil || sh.getID() == "" || sh.getClient() == nil {
// Session handle hasn't been allocated or has been destroyed.
sh, err = c.idleSessions.take(ctx)
if c.enableMultiplexSessionForRW {
sh, err = c.idleSessions.takeMultiplexed(ctx)
} else {
// Session handle hasn't been allocated or has been destroyed.
sh, err = c.idleSessions.take(ctx)
}
if err != nil {
// If session retrieval fails, just fail the transaction.
return err
Expand Down
47 changes: 45 additions & 2 deletions spanner/internal/testutil/inmem_spanner_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"sort"
"strings"
"sync"
"sync/atomic"
"time"

"cloud.google.com/go/spanner/apiv1/spannerpb"
Expand Down Expand Up @@ -332,7 +333,8 @@ type inMemSpannerServer struct {
// counters.
transactionCounters map[string]*uint64
// The transactions that have been created on this mock server.
transactions map[string]*spannerpb.Transaction
transactions map[string]*spannerpb.Transaction
multiplexedSessionTransactions map[string]*Transaction
// The transactions that have been (manually) aborted on the server.
abortedTransactions map[string]bool
// The transactions that are marked as PartitionedDMLTransaction
Expand All @@ -358,6 +360,18 @@ type inMemSpannerServer struct {
freezed chan struct{}
}

type Transaction struct {
sequence *atomic.Int32
transaction *spannerpb.Transaction
}

func (t *Transaction) getPreCommitToken(operation string) *spannerpb.MultiplexedSessionPrecommitToken {
return &spannerpb.MultiplexedSessionPrecommitToken{
SeqNum: t.sequence.Add(1),
PrecommitToken: []byte(fmt.Sprintf("precommit-token-%v-%v", operation, t.sequence.Load())),
}
}

// NewInMemSpannerServer creates a new in-mem test server.
func NewInMemSpannerServer() InMemSpannerServer {
res := &inMemSpannerServer{}
Expand Down Expand Up @@ -520,6 +534,7 @@ func (s *inMemSpannerServer) initDefaults() {
s.sessions = make(map[string]*spannerpb.Session)
s.sessionLastUseTime = make(map[string]time.Time)
s.transactions = make(map[string]*spannerpb.Transaction)
s.multiplexedSessionTransactions = make(map[string]*Transaction)
s.abortedTransactions = make(map[string]bool)
s.partitionedDmlTransactions = make(map[string]bool)
s.transactionCounters = make(map[string]*uint64)
Expand Down Expand Up @@ -596,6 +611,9 @@ func (s *inMemSpannerServer) beginTransaction(session *spannerpb.Session, option
ReadTimestamp: getCurrentTimestamp(),
}
s.mu.Lock()
if options.GetReadWrite() != nil && session.Multiplexed {
s.multiplexedSessionTransactions[id] = &Transaction{transaction: res, sequence: new(atomic.Int32)}
}
s.transactions[id] = res
s.partitionedDmlTransactions[id] = options.GetPartitionedDml() != nil
s.mu.Unlock()
Expand Down Expand Up @@ -633,6 +651,7 @@ func (s *inMemSpannerServer) removeTransaction(tx *spannerpb.Transaction) {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.transactions, string(tx.Id))
delete(s.multiplexedSessionTransactions, string(tx.Id))
delete(s.partitionedDmlTransactions, string(tx.Id))
}

Expand Down Expand Up @@ -869,9 +888,21 @@ func (s *inMemSpannerServer) ExecuteSql(ctx context.Context, req *spannerpb.Exec
case StatementResultError:
return nil, statementResult.Err
case StatementResultResultSet:
// if request's session is multiplexed and transaction is Read/Write then add Pre-commit Token in Metadata
if statementResult.ResultSet != nil {
statementResult.ResultSet.PrecommitToken = s.multiplexedSessionTransactions[string(id)].getPreCommitToken("ResultSetPrecommitToken")
}
return statementResult.ResultSet, nil
case StatementResultUpdateCount:
return statementResult.convertUpdateCountToResultSet(!isPartitionedDml), nil
res := statementResult.convertUpdateCountToResultSet(!isPartitionedDml)
// if request's session is multiplexed and transaction is Read/Write then add Pre-commit Token in Metadata
s.mu.Lock()
txn, ok := s.multiplexedSessionTransactions[string(id)]
s.mu.Unlock()
if ok {
res.PrecommitToken = txn.getPreCommitToken("ResultSetPrecommitToken")
}
return res, nil
}
return nil, gstatus.Error(codes.Internal, "Unknown result type")
}
Expand Down Expand Up @@ -937,6 +968,12 @@ func (s *inMemSpannerServer) executeStreamingSQL(req *spannerpb.ExecuteSqlReques
return nextPartialResultSetError.Err
}
}
s.mu.Lock()
txn, ok := s.multiplexedSessionTransactions[string(id)]
s.mu.Unlock()
if ok {
part.PrecommitToken = txn.getPreCommitToken("PartialResultSetPrecommitToken")
}
if err := stream.Send(part); err != nil {
return err
}
Expand Down Expand Up @@ -996,6 +1033,12 @@ func (s *inMemSpannerServer) ExecuteBatchDml(ctx context.Context, req *spannerpb
resp.ResultSets[idx] = statementResult.convertUpdateCountToResultSet(!isPartitionedDml)
}
}
s.mu.Lock()
txn, ok := s.multiplexedSessionTransactions[string(id)]
s.mu.Unlock()
if ok {
resp.PrecommitToken = txn.getPreCommitToken("ExecuteBatchDmlResponsePrecommitToken")
}
return resp, nil
}

Expand Down
1 change: 1 addition & 0 deletions spanner/kokoro/presubmit.sh
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ exit_code=0
case $JOB_TYPE in
integration-with-multiplexed-session )
GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS=true
GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS_FOR_RW=true
echo "running presubmit with multiplexed sessions enabled: $GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS"
;;
esac
Expand Down
47 changes: 27 additions & 20 deletions spanner/read.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ func stream(
func(err error) error {
return err
},
nil,
setTimestamp,
release,
)
Expand All @@ -83,20 +84,22 @@ func streamWithReplaceSessionFunc(
replaceSession func(ctx context.Context) error,
setTransactionID func(transactionID),
updateTxState func(err error) error,
updatePrecommitToken func(token *sppb.MultiplexedSessionPrecommitToken),
setTimestamp func(time.Time),
release func(error),
) *RowIterator {
ctx, cancel := context.WithCancel(ctx)
ctx = trace.StartSpan(ctx, "cloud.google.com/go/spanner.RowIterator")
return &RowIterator{
meterTracerFactory: meterTracerFactory,
streamd: newResumableStreamDecoder(ctx, logger, rpc, replaceSession),
rowd: &partialResultSetDecoder{},
setTransactionID: setTransactionID,
updateTxState: updateTxState,
setTimestamp: setTimestamp,
release: release,
cancel: cancel,
meterTracerFactory: meterTracerFactory,
streamd: newResumableStreamDecoder(ctx, logger, rpc, replaceSession),
rowd: &partialResultSetDecoder{},
setTransactionID: setTransactionID,
updatePrecommitToken: updatePrecommitToken,
updateTxState: updateTxState,
setTimestamp: setTimestamp,
release: release,
cancel: cancel,
}
}

Expand Down Expand Up @@ -127,18 +130,19 @@ type RowIterator struct {
// RowIterator.Next() returned an error that is not equal to iterator.Done.
Metadata *sppb.ResultSetMetadata

ctx context.Context
meterTracerFactory *builtinMetricsTracerFactory
streamd *resumableStreamDecoder
rowd *partialResultSetDecoder
setTransactionID func(transactionID)
updateTxState func(err error) error
setTimestamp func(time.Time)
release func(error)
cancel func()
err error
rows []*Row
sawStats bool
ctx context.Context
meterTracerFactory *builtinMetricsTracerFactory
streamd *resumableStreamDecoder
rowd *partialResultSetDecoder
setTransactionID func(transactionID)
updateTxState func(err error) error
updatePrecommitToken func(token *sppb.MultiplexedSessionPrecommitToken)
setTimestamp func(time.Time)
release func(error)
cancel func()
err error
rows []*Row
sawStats bool
}

// this is for safety from future changes to RowIterator making sure that it implements rowIterator interface.
Expand Down Expand Up @@ -189,6 +193,9 @@ func (r *RowIterator) Next() (*Row, error) {
}
r.setTransactionID = nil
}
if r.updatePrecommitToken != nil {
r.updatePrecommitToken(prs.GetPrecommitToken())
}
if prs.Stats != nil {
r.sawStats = true
r.QueryPlan = prs.Stats.QueryPlan
Expand Down
7 changes: 6 additions & 1 deletion spanner/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -507,6 +507,9 @@ type SessionPoolConfig struct {
// Defaults to false.
TrackSessionHandles bool

// enableMultiplexSessionForRW is a flag to enable multiplexed session for read/write transactions.
enableMultiplexSessionForRW bool

// healthCheckSampleInterval is how often the health checker samples live
// session (for use in maintaining session pool size).
//
Expand Down Expand Up @@ -703,6 +706,7 @@ func newSessionPool(sc *sessionClient, config SessionPoolConfig) (*sessionPool,
if isMultiplexed != "" && isMultiplexed != "true" && isMultiplexed != "false" {
return nil, spannerErrorf(codes.InvalidArgument, "GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS must be either true or false")
}

pool := &sessionPool{
sc: sc,
valid: true,
Expand All @@ -713,7 +717,7 @@ func newSessionPool(sc *sessionClient, config SessionPoolConfig) (*sessionPool,
mw: newMaintenanceWindow(config.MaxOpened),
rand: rand.New(rand.NewSource(time.Now().UnixNano())),
otConfig: sc.otConfig,
enableMultiplexSession: isMultiplexed == "true",
enableMultiplexSession: isMultiplexed == "true" || config.enableMultiplexSessionForRW,
}

_, instance, database, err := parseDatabaseName(sc.database)
Expand Down Expand Up @@ -1291,6 +1295,7 @@ func (p *sessionPool) takeMultiplexed(ctx context.Context) (*sessionHandle, erro
if isUnimplementedError(err) {
logf(p.sc.logger, "Multiplexed session is not enabled on this project, continuing with regular sessions")
p.enableMultiplexSession = false
p.enableMultiplexSessionForRW = false
} else {
p.mu.Unlock()
// If the error is a timeout, there is a chance that the session was
Expand Down
Loading

0 comments on commit f692dbf

Please sign in to comment.