diff --git a/pkg/workload/BUILD.bazel b/pkg/workload/BUILD.bazel index 356fbd763d7d..4153cc874112 100644 --- a/pkg/workload/BUILD.bazel +++ b/pkg/workload/BUILD.bazel @@ -21,6 +21,7 @@ go_library( "//pkg/util/bufalloc", "//pkg/util/encoding/csv", "//pkg/util/log", + "//pkg/util/syncutil", "//pkg/util/timeutil", "//pkg/workload/histogram", "@com_github_cockroachdb_errors//:errors", diff --git a/pkg/workload/pgx_helpers.go b/pkg/workload/pgx_helpers.go index 928819bf68b4..a2a23e47c3cd 100644 --- a/pkg/workload/pgx_helpers.go +++ b/pkg/workload/pgx_helpers.go @@ -15,6 +15,7 @@ import ( "sync/atomic" "github.com/cockroachdb/cockroach/pkg/util/log" + "github.com/cockroachdb/cockroach/pkg/util/syncutil" "github.com/jackc/pgx/v4" "github.com/jackc/pgx/v4/pgxpool" "golang.org/x/sync/errgroup" @@ -25,9 +26,13 @@ type MultiConnPool struct { Pools []*pgxpool.Pool // Atomic counter used by Get(). counter uint32 - // preparedStatements is a map from name to SQL. The statements in the map - // are prepared whenever a new connection is acquired from the pool. - preparedStatements map[string]string + + mu struct { + syncutil.RWMutex + // preparedStatements is a map from name to SQL. The statements in the map + // are prepared whenever a new connection is acquired from the pool. + preparedStatements map[string]string + } } // MultiConnPoolCfg encapsulates the knobs passed to NewMultiConnPool. @@ -66,9 +71,9 @@ func (p pgxLogger) Log( func NewMultiConnPool( ctx context.Context, cfg MultiConnPoolCfg, urls ...string, ) (*MultiConnPool, error) { - m := &MultiConnPool{ - preparedStatements: map[string]string{}, - } + m := &MultiConnPool{} + m.mu.preparedStatements = map[string]string{} + connsPerURL := distribute(cfg.MaxTotalConnections, len(urls)) maxConnsPerPool := cfg.MaxConnsPerPool if maxConnsPerPool == 0 { @@ -90,7 +95,9 @@ func NewMultiConnPool( connCfg.ConnConfig.Logger = pgxLogger{} connCfg.MaxConns = int32(numConns) connCfg.BeforeAcquire = func(ctx context.Context, conn *pgx.Conn) bool { - for name, sql := range m.preparedStatements { + m.mu.RLock() + defer m.mu.RUnlock() + for name, sql := range m.mu.preparedStatements { // Note that calling `Prepare` with a name that has already been // prepared is idempotent and short-circuits before doing any // communication to the server. @@ -150,6 +157,15 @@ func NewMultiConnPool( return m, nil } +// AddPreparedStatement adds the given sql statement to the map of +// statements that will be prepared when a new connection is retrieved +// from the pool. +func (m *MultiConnPool) AddPreparedStatement(name string, statement string) { + m.mu.Lock() + defer m.mu.Unlock() + m.mu.preparedStatements[name] = statement +} + // Get returns one of the pools, in round-robin manner. func (m *MultiConnPool) Get() *pgxpool.Pool { if len(m.Pools) == 1 { diff --git a/pkg/workload/sql_runner.go b/pkg/workload/sql_runner.go index 074a5e88b9f0..c3dd76ce4dc3 100644 --- a/pkg/workload/sql_runner.go +++ b/pkg/workload/sql_runner.go @@ -119,10 +119,9 @@ func (sr *SQLRunner) Init( for i, s := range sr.stmts { stmtName := fmt.Sprintf("%s-%d", name, i+1) s.preparedName = stmtName - mcp.preparedStatements[stmtName] = s.sql + mcp.AddPreparedStatement(stmtName, s.sql) } } - sr.mcp = mcp sr.initialized = true return nil