Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(spanner): track precommit token for R/W multiplexed session #11229

Merged
merged 9 commits into from
Dec 23, 2024
69 changes: 41 additions & 28 deletions spanner/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,19 +107,20 @@ func parseDatabaseName(db string) (project, instance, database string, err error
// Client is a client for reading and writing data to a Cloud Spanner database.
// A client is safe to use concurrently, except for its Close method.
type Client struct {
sc *sessionClient
idleSessions *sessionPool
logger *log.Logger
qo QueryOptions
ro ReadOptions
ao []ApplyOption
txo TransactionOptions
bwo BatchWriteOptions
ct *commonTags
disableRouteToLeader bool
dro *sppb.DirectedReadOptions
otConfig *openTelemetryConfig
metricsTracerFactory *builtinMetricsTracerFactory
sc *sessionClient
idleSessions *sessionPool
logger *log.Logger
qo QueryOptions
ro ReadOptions
ao []ApplyOption
txo TransactionOptions
bwo BatchWriteOptions
ct *commonTags
disableRouteToLeader bool
enableMultiplexSessionForRW bool
harshachinta marked this conversation as resolved.
Show resolved Hide resolved
dro *sppb.DirectedReadOptions
otConfig *openTelemetryConfig
metricsTracerFactory *builtinMetricsTracerFactory
}

// DatabaseName returns the full name of a database, e.g.,
Expand Down Expand Up @@ -486,6 +487,13 @@ func newClientWithConfig(ctx context.Context, database string, config ClientConf
if config.EnableEndToEndTracing || endToEndTracingEnvironmentVariable == "true" {
md.Append(endToEndTracingHeader, "true")
}
//TODO: Uncomment this once the feature is enabled.
//if isMultiplexForRW := os.Getenv("GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS_FOR_RW"); isMultiplexForRW != "" {
// config.enableMultiplexSessionForRW, err = strconv.ParseBool(isMultiplexForRW)
// if err != nil {
// return nil, spannerErrorf(codes.InvalidArgument, "GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS_FOR_RW must be either true or false")
// }
//}

// Create a session client.
sc := newSessionClient(pool, database, config.UserAgent, sessionLabels, config.DatabaseRole, config.DisableRouteToLeader, md, config.BatchTimeout, config.Logger, config.CallOptions)
Expand Down Expand Up @@ -532,19 +540,20 @@ func newClientWithConfig(ctx context.Context, database string, config ClientConf
}

c = &Client{
sc: sc,
idleSessions: sp,
logger: config.Logger,
qo: getQueryOptions(config.QueryOptions),
ro: config.ReadOptions,
ao: config.ApplyOptions,
txo: config.TransactionOptions,
bwo: config.BatchWriteOptions,
ct: getCommonTags(sc),
disableRouteToLeader: config.DisableRouteToLeader,
dro: config.DirectedReadOptions,
otConfig: otConfig,
metricsTracerFactory: metricsTracerFactory,
sc: sc,
idleSessions: sp,
logger: config.Logger,
qo: getQueryOptions(config.QueryOptions),
ro: config.ReadOptions,
ao: config.ApplyOptions,
txo: config.TransactionOptions,
bwo: config.BatchWriteOptions,
ct: getCommonTags(sc),
disableRouteToLeader: config.DisableRouteToLeader,
dro: config.DirectedReadOptions,
otConfig: otConfig,
metricsTracerFactory: metricsTracerFactory,
enableMultiplexSessionForRW: config.enableMultiplexSessionForRW,
harshachinta marked this conversation as resolved.
Show resolved Hide resolved
}
return c, nil
}
Expand Down Expand Up @@ -1008,8 +1017,12 @@ func (c *Client) rwTransaction(ctx context.Context, f func(context.Context, *Rea
err error
)
if sh == nil || sh.getID() == "" || sh.getClient() == nil {
// Session handle hasn't been allocated or has been destroyed.
sh, err = c.idleSessions.take(ctx)
if c.enableMultiplexSessionForRW {
harshachinta marked this conversation as resolved.
Show resolved Hide resolved
sh, err = c.idleSessions.takeMultiplexed(ctx)
} else {
// Session handle hasn't been allocated or has been destroyed.
sh, err = c.idleSessions.take(ctx)
}
if err != nil {
// If session retrieval fails, just fail the transaction.
return err
Expand Down
55 changes: 53 additions & 2 deletions spanner/internal/testutil/inmem_spanner_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"sort"
"strings"
"sync"
"sync/atomic"
"time"

"cloud.google.com/go/spanner/apiv1/spannerpb"
Expand Down Expand Up @@ -333,7 +334,8 @@ type inMemSpannerServer struct {
// counters.
transactionCounters map[string]*uint64
// The transactions that have been created on this mock server.
transactions map[string]*spannerpb.Transaction
transactions map[string]*spannerpb.Transaction
multiplexedSessionTransactions map[string]*Transaction
harshachinta marked this conversation as resolved.
Show resolved Hide resolved
// The transactions that have been (manually) aborted on the server.
abortedTransactions map[string]bool
// The transactions that are marked as PartitionedDMLTransaction
Expand All @@ -359,6 +361,20 @@ type inMemSpannerServer struct {
freezed chan struct{}
}

// Transaction is a wrapper around a spannerpb.Transaction that also contains
// a sequence number that is used to generate precommit tokens.
type Transaction struct {
sequence *atomic.Int32
transaction *spannerpb.Transaction
}

func (t *Transaction) getPreCommitToken(operation string) *spannerpb.MultiplexedSessionPrecommitToken {
return &spannerpb.MultiplexedSessionPrecommitToken{
SeqNum: t.sequence.Add(1),
PrecommitToken: []byte(fmt.Sprintf("precommit-token-%v-%v", operation, t.sequence.Load())),
}
}

harshachinta marked this conversation as resolved.
Show resolved Hide resolved
// NewInMemSpannerServer creates a new in-mem test server.
func NewInMemSpannerServer() InMemSpannerServer {
res := &inMemSpannerServer{}
Expand Down Expand Up @@ -521,6 +537,7 @@ func (s *inMemSpannerServer) initDefaults() {
s.sessions = make(map[string]*spannerpb.Session)
s.sessionLastUseTime = make(map[string]time.Time)
s.transactions = make(map[string]*spannerpb.Transaction)
s.multiplexedSessionTransactions = make(map[string]*Transaction)
s.abortedTransactions = make(map[string]bool)
s.partitionedDmlTransactions = make(map[string]bool)
s.transactionCounters = make(map[string]*uint64)
Expand Down Expand Up @@ -597,6 +614,9 @@ func (s *inMemSpannerServer) beginTransaction(session *spannerpb.Session, option
ReadTimestamp: getCurrentTimestamp(),
}
s.mu.Lock()
if options.GetReadWrite() != nil && session.Multiplexed {
s.multiplexedSessionTransactions[id] = &Transaction{transaction: res, sequence: new(atomic.Int32)}
}
s.transactions[id] = res
s.partitionedDmlTransactions[id] = options.GetPartitionedDml() != nil
s.mu.Unlock()
Expand Down Expand Up @@ -634,6 +654,7 @@ func (s *inMemSpannerServer) removeTransaction(tx *spannerpb.Transaction) {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.transactions, string(tx.Id))
delete(s.multiplexedSessionTransactions, string(tx.Id))
delete(s.partitionedDmlTransactions, string(tx.Id))
}

Expand Down Expand Up @@ -870,9 +891,27 @@ func (s *inMemSpannerServer) ExecuteSql(ctx context.Context, req *spannerpb.Exec
case StatementResultError:
return nil, statementResult.Err
case StatementResultResultSet:

// if request's session is multiplexed and transaction is Read/Write then add Pre-commit Token in Metadata
if statementResult.ResultSet != nil {
s.mu.Lock()
txn, ok := s.multiplexedSessionTransactions[string(id)]
s.mu.Unlock()
if ok {
statementResult.ResultSet.PrecommitToken = txn.getPreCommitToken("ResultSetPrecommitToken")
}
}
return statementResult.ResultSet, nil
case StatementResultUpdateCount:
return statementResult.convertUpdateCountToResultSet(!isPartitionedDml), nil
res := statementResult.convertUpdateCountToResultSet(!isPartitionedDml)
// if request's session is multiplexed and transaction is Read/Write then add Pre-commit Token in Metadata
s.mu.Lock()
txn, ok := s.multiplexedSessionTransactions[string(id)]
s.mu.Unlock()
if ok {
res.PrecommitToken = txn.getPreCommitToken("ResultSetPrecommitToken")
}
return res, nil
}
return nil, gstatus.Error(codes.Internal, "Unknown result type")
}
Expand Down Expand Up @@ -938,6 +977,12 @@ func (s *inMemSpannerServer) executeStreamingSQL(req *spannerpb.ExecuteSqlReques
return nextPartialResultSetError.Err
}
}
s.mu.Lock()
txn, ok := s.multiplexedSessionTransactions[string(id)]
s.mu.Unlock()
if ok {
part.PrecommitToken = txn.getPreCommitToken("PartialResultSetPrecommitToken")
}
if err := stream.Send(part); err != nil {
return err
}
Expand Down Expand Up @@ -997,6 +1042,12 @@ func (s *inMemSpannerServer) ExecuteBatchDml(ctx context.Context, req *spannerpb
resp.ResultSets[idx] = statementResult.convertUpdateCountToResultSet(!isPartitionedDml)
}
}
s.mu.Lock()
txn, ok := s.multiplexedSessionTransactions[string(id)]
s.mu.Unlock()
if ok {
resp.PrecommitToken = txn.getPreCommitToken("ExecuteBatchDmlResponsePrecommitToken")
}
return resp, nil
}

Expand Down
1 change: 1 addition & 0 deletions spanner/kokoro/presubmit.sh
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ exit_code=0
case $JOB_TYPE in
integration-with-multiplexed-session )
GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS=true
GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS_FOR_RW=true
echo "running presubmit with multiplexed sessions enabled: $GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS"
;;
esac
Expand Down
47 changes: 27 additions & 20 deletions spanner/read.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ func stream(
func(err error) error {
return err
},
nil,
setTimestamp,
release,
gsc,
Expand All @@ -85,21 +86,23 @@ func streamWithReplaceSessionFunc(
replaceSession func(ctx context.Context) error,
setTransactionID func(transactionID),
updateTxState func(err error) error,
updatePrecommitToken func(token *sppb.MultiplexedSessionPrecommitToken),
setTimestamp func(time.Time),
release func(error),
gsc *grpcSpannerClient,
) *RowIterator {
ctx, cancel := context.WithCancel(ctx)
ctx = trace.StartSpan(ctx, "cloud.google.com/go/spanner.RowIterator")
return &RowIterator{
meterTracerFactory: meterTracerFactory,
streamd: newResumableStreamDecoder(ctx, logger, rpc, replaceSession, gsc),
rowd: &partialResultSetDecoder{},
setTransactionID: setTransactionID,
updateTxState: updateTxState,
setTimestamp: setTimestamp,
release: release,
cancel: cancel,
meterTracerFactory: meterTracerFactory,
streamd: newResumableStreamDecoder(ctx, logger, rpc, replaceSession, gsc),
rowd: &partialResultSetDecoder{},
setTransactionID: setTransactionID,
updatePrecommitToken: updatePrecommitToken,
updateTxState: updateTxState,
setTimestamp: setTimestamp,
release: release,
cancel: cancel,
}
}

Expand Down Expand Up @@ -130,18 +133,19 @@ type RowIterator struct {
// RowIterator.Next() returned an error that is not equal to iterator.Done.
Metadata *sppb.ResultSetMetadata

ctx context.Context
meterTracerFactory *builtinMetricsTracerFactory
streamd *resumableStreamDecoder
rowd *partialResultSetDecoder
setTransactionID func(transactionID)
updateTxState func(err error) error
setTimestamp func(time.Time)
release func(error)
cancel func()
err error
rows []*Row
sawStats bool
ctx context.Context
meterTracerFactory *builtinMetricsTracerFactory
streamd *resumableStreamDecoder
rowd *partialResultSetDecoder
setTransactionID func(transactionID)
updateTxState func(err error) error
updatePrecommitToken func(token *sppb.MultiplexedSessionPrecommitToken)
setTimestamp func(time.Time)
release func(error)
cancel func()
err error
rows []*Row
sawStats bool
}

// this is for safety from future changes to RowIterator making sure that it implements rowIterator interface.
Expand Down Expand Up @@ -192,6 +196,9 @@ func (r *RowIterator) Next() (*Row, error) {
}
r.setTransactionID = nil
}
if r.updatePrecommitToken != nil {
r.updatePrecommitToken(prs.GetPrecommitToken())
}
if prs.Stats != nil {
r.sawStats = true
r.QueryPlan = prs.Stats.QueryPlan
Expand Down
7 changes: 6 additions & 1 deletion spanner/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -507,6 +507,9 @@ type SessionPoolConfig struct {
// Defaults to false.
TrackSessionHandles bool

// enableMultiplexSessionForRW is a flag to enable multiplexed session for read/write transactions.
enableMultiplexSessionForRW bool
harshachinta marked this conversation as resolved.
Show resolved Hide resolved

// healthCheckSampleInterval is how often the health checker samples live
// session (for use in maintaining session pool size).
//
Expand Down Expand Up @@ -703,6 +706,7 @@ func newSessionPool(sc *sessionClient, config SessionPoolConfig) (*sessionPool,
if isMultiplexed != "" && isMultiplexed != "true" && isMultiplexed != "false" {
return nil, spannerErrorf(codes.InvalidArgument, "GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS must be either true or false")
}
harshachinta marked this conversation as resolved.
Show resolved Hide resolved

pool := &sessionPool{
sc: sc,
valid: true,
Expand All @@ -713,7 +717,7 @@ func newSessionPool(sc *sessionClient, config SessionPoolConfig) (*sessionPool,
mw: newMaintenanceWindow(config.MaxOpened),
rand: rand.New(rand.NewSource(time.Now().UnixNano())),
otConfig: sc.otConfig,
enableMultiplexSession: isMultiplexed == "true",
enableMultiplexSession: isMultiplexed == "true" || config.enableMultiplexSessionForRW,
harshachinta marked this conversation as resolved.
Show resolved Hide resolved
}

_, instance, database, err := parseDatabaseName(sc.database)
Expand Down Expand Up @@ -1291,6 +1295,7 @@ func (p *sessionPool) takeMultiplexed(ctx context.Context) (*sessionHandle, erro
if isUnimplementedError(err) {
logf(p.sc.logger, "Multiplexed session is not enabled on this project, continuing with regular sessions")
p.enableMultiplexSession = false
p.enableMultiplexSessionForRW = false
harshachinta marked this conversation as resolved.
Show resolved Hide resolved
} else {
p.mu.Unlock()
// If the error is a timeout, there is a chance that the session was
Expand Down
Loading
Loading