Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pgwire: add server.max_connections_per_gateway public cluster setting #76401

Merged
merged 1 commit into from
Feb 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/generated/settings/settings-for-tenants.txt
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ server.eventlog.enabled boolean true if set, logged notable events are also stor
server.eventlog.ttl duration 2160h0m0s if nonzero, entries in system.eventlog older than this duration are deleted every 10m0s. Should not be lowered below 24 hours.
server.host_based_authentication.configuration string host-based authentication configuration to use during connection authentication
server.identity_map.configuration string system-identity to database-username mappings
server.max_connections_per_gateway integer -1 the maximum number of non-superuser SQL connections per gateway allowed at a given time (note: this will only limit future connection attempts and will not affect already established connections). Negative values result in unlimited number of connections. Superusers are not affected by this limit.
server.oidc_authentication.autologin boolean false if true, logged-out visitors to the DB Console will be automatically redirected to the OIDC login endpoint
server.oidc_authentication.button_text string Login with your OIDC provider text to show on button on DB Console login page to login with your OIDC provider (only shown if OIDC is enabled)
server.oidc_authentication.claim_json_key string sets JSON key of principal to extract from payload after OIDC authentication completes (usually email or sid)
Expand Down
1 change: 1 addition & 0 deletions docs/generated/settings/settings.html
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
<tr><td><code>server.eventlog.ttl</code></td><td>duration</td><td><code>2160h0m0s</code></td><td>if nonzero, entries in system.eventlog older than this duration are deleted every 10m0s. Should not be lowered below 24 hours.</td></tr>
<tr><td><code>server.host_based_authentication.configuration</code></td><td>string</td><td><code></code></td><td>host-based authentication configuration to use during connection authentication</td></tr>
<tr><td><code>server.identity_map.configuration</code></td><td>string</td><td><code></code></td><td>system-identity to database-username mappings</td></tr>
<tr><td><code>server.max_connections_per_gateway</code></td><td>integer</td><td><code>-1</code></td><td>the maximum number of non-superuser SQL connections per gateway allowed at a given time (note: this will only limit future connection attempts and will not affect already established connections). Negative values result in unlimited number of connections. Superusers are not affected by this limit.</td></tr>
<tr><td><code>server.oidc_authentication.autologin</code></td><td>boolean</td><td><code>false</code></td><td>if true, logged-out visitors to the DB Console will be automatically redirected to the OIDC login endpoint</td></tr>
<tr><td><code>server.oidc_authentication.button_text</code></td><td>string</td><td><code>Login with your OIDC provider</code></td><td>text to show on button on DB Console login page to login with your OIDC provider (only shown if OIDC is enabled)</td></tr>
<tr><td><code>server.oidc_authentication.claim_json_key</code></td><td>string</td><td><code></code></td><td>sets JSON key of principal to extract from payload after OIDC authentication completes (usually email or sid)</td></tr>
Expand Down
38 changes: 38 additions & 0 deletions pkg/sql/conn_executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,11 @@ type Server struct {

// TelemetryLoggingMetrics is used to track metrics for logging to the telemetry channel.
TelemetryLoggingMetrics *TelemetryLoggingMetrics

mu struct {
syncutil.Mutex
connectionCount int64
}
}

// Metrics collects timeseries data about SQL activity.
Expand Down Expand Up @@ -662,6 +667,39 @@ func (s *Server) SetupConn(
return ConnectionHandler{ex}, nil
}

// IncrementConnectionCount increases connectionCount by 1.
func (s *Server) IncrementConnectionCount() {
s.mu.Lock()
defer s.mu.Unlock()
s.mu.connectionCount++
}

// DecrementConnectionCount decreases connectionCount by 1.
func (s *Server) DecrementConnectionCount() {
s.mu.Lock()
defer s.mu.Unlock()
s.mu.connectionCount--
}

// IncrementConnectionCountIfLessThan increases connectionCount by and returns true if allowedConnectionCount < max,
// otherwise it does nothing and returns false.
func (s *Server) IncrementConnectionCountIfLessThan(max int64) bool {
s.mu.Lock()
defer s.mu.Unlock()
lt := s.mu.connectionCount < max
if lt {
s.mu.connectionCount++
}
return lt
}

// GetConnectionCount returns the current number of connections.
func (s *Server) GetConnectionCount() int64 {
s.mu.Lock()
defer s.mu.Unlock()
return s.mu.connectionCount
}

// ConnectionHandler is the interface between the result of SetupConn
// and the ServeConn below. It encapsulates the connExecutor and hides
// it away from other packages.
Expand Down
1 change: 1 addition & 0 deletions pkg/sql/pgwire/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ go_test(
"@com_github_cockroachdb_errors//:errors",
"@com_github_cockroachdb_errors//stdstrings",
"@com_github_cockroachdb_redact//:redact",
"@com_github_jackc_pgconn//:pgconn",
"@com_github_jackc_pgproto3_v2//:pgproto3",
"@com_github_jackc_pgx_v4//:pgx",
"@com_github_lib_pq//:pq",
Expand Down
32 changes: 12 additions & 20 deletions pkg/sql/pgwire/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,16 +96,11 @@ func (c *conn) handleAuthentication(
return nil, authOpt.testingAuthHook(ctx)
}

sendError := func(err error) error {
_ /* err */ = writeErr(ctx, &execCfg.Settings.SV, err, &c.msgBuilder, c.conn)
return err
}

// Retrieve the authentication method.
tlsState, hbaEntry, authMethod, err := c.findAuthenticationMethod(authOpt)
if err != nil {
ac.LogAuthFailed(ctx, eventpb.AuthFailReason_METHOD_NOT_FOUND, err)
return nil, sendError(pgerror.WithCandidateCode(err, pgcode.InvalidAuthorizationSpecification))
return nil, c.sendError(ctx, execCfg, pgerror.WithCandidateCode(err, pgcode.InvalidAuthorizationSpecification))
}

ac.SetAuthMethod(hbaEntry.Method.String())
Expand All @@ -118,7 +113,7 @@ func (c *conn) handleAuthentication(
connClose = behaviors.ConnClose
if err != nil {
ac.LogAuthFailed(ctx, eventpb.AuthFailReason_UNKNOWN, err)
return connClose, sendError(pgerror.WithCandidateCode(err, pgcode.InvalidAuthorizationSpecification))
return connClose, c.sendError(ctx, execCfg, pgerror.WithCandidateCode(err, pgcode.InvalidAuthorizationSpecification))
}

// Choose the system identity that we'll use below for mapping
Expand All @@ -136,7 +131,7 @@ func (c *conn) handleAuthentication(
if err := c.chooseDbRole(ctx, ac, behaviors.MapRole, systemIdentity); err != nil {
log.Warningf(ctx, "unable to map incoming identity %q to any database user: %+v", systemIdentity, err)
ac.LogAuthFailed(ctx, eventpb.AuthFailReason_USER_NOT_FOUND, err)
return connClose, sendError(pgerror.WithCandidateCode(err, pgcode.InvalidAuthorizationSpecification))
return connClose, c.sendError(ctx, execCfg, pgerror.WithCandidateCode(err, pgcode.InvalidAuthorizationSpecification))
}

// Once chooseDbRole() returns, we know that the actual DB username
Expand All @@ -156,33 +151,26 @@ func (c *conn) handleAuthentication(
if err != nil {
log.Warningf(ctx, "user retrieval failed for user=%q: %+v", dbUser, err)
ac.LogAuthFailed(ctx, eventpb.AuthFailReason_USER_RETRIEVAL_ERROR, err)
return connClose, sendError(pgerror.WithCandidateCode(err, pgcode.InvalidAuthorizationSpecification))
return connClose, c.sendError(ctx, execCfg, pgerror.WithCandidateCode(err, pgcode.InvalidAuthorizationSpecification))
}
c.sessionArgs.IsSuperuser = isSuperuser

if !exists {
ac.LogAuthFailed(ctx, eventpb.AuthFailReason_USER_NOT_FOUND, nil)
return connClose, sendError(pgerror.WithCandidateCode(
security.NewErrPasswordUserAuthFailed(dbUser),
pgcode.InvalidAuthorizationSpecification,
))
return connClose, c.sendError(ctx, execCfg, pgerror.WithCandidateCode(security.NewErrPasswordUserAuthFailed(dbUser), pgcode.InvalidAuthorizationSpecification))
}

if !canLoginSQL {
ac.LogAuthFailed(ctx, eventpb.AuthFailReason_LOGIN_DISABLED, nil)
return connClose, sendError(pgerror.Newf(
pgcode.InvalidAuthorizationSpecification,
"%s does not have login privilege",
dbUser,
))
return connClose, c.sendError(ctx, execCfg, pgerror.Newf(pgcode.InvalidAuthorizationSpecification, "%s does not have login privilege", dbUser))
}

// At this point, we know that the requested user exists and is
// allowed to log in. Now we can delegate to the selected AuthMethod
// implementation to complete the authentication.
if err := behaviors.Authenticate(ctx, systemIdentity, true /* public */, pwRetrievalFn); err != nil {
ac.LogAuthFailed(ctx, eventpb.AuthFailReason_CREDENTIALS_INVALID, err)
return connClose, sendError(pgerror.WithCandidateCode(err, pgcode.InvalidAuthorizationSpecification))
return connClose, c.sendError(ctx, execCfg, pgerror.WithCandidateCode(err, pgcode.InvalidAuthorizationSpecification))
}

// Add all the defaults to this session's defaults. If there is an
Expand All @@ -209,9 +197,13 @@ func (c *conn) handleAuthentication(
}

ac.LogAuthOK(ctx)
return connClose, nil
}

func (c *conn) authOKMessage() error {
c.msgBuilder.initMsg(pgwirebase.ServerMsgAuth)
c.msgBuilder.putInt32(authOK)
return connClose, c.msgBuilder.finishMsg(c.conn)
return c.msgBuilder.finishMsg(c.conn)
}

// chooseDbRole uses the provided RoleMapper to map an incoming
Expand Down
42 changes: 42 additions & 0 deletions pkg/sql/pgwire/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,39 @@ func (c *conn) GetErr() error {
return nil
}

func (c *conn) sendError(ctx context.Context, execCfg *sql.ExecutorConfig, err error) error {
// We could, but do not, report server-side network errors while
// trying to send the client error. This is because clients that
// receive error payload are highly correlated with clients
// disconnecting abruptly.
_ /* err */ = writeErr(ctx, &execCfg.Settings.SV, err, &c.msgBuilder, c.conn)
return err
}

func (c *conn) checkMaxConnections(ctx context.Context, sqlServer *sql.Server) error {
if c.sessionArgs.IsSuperuser {
// This user is a super user and is therefore not affected by connection limits.
sqlServer.IncrementConnectionCount()
return nil
}

maxNumConnectionsValue := maxNumConnections.Get(&sqlServer.GetExecutorConfig().Settings.SV)
if maxNumConnectionsValue < 0 {
// Unlimited connections are allowed.
sqlServer.IncrementConnectionCount()
return nil
}
if !sqlServer.IncrementConnectionCountIfLessThan(maxNumConnectionsValue) {
return c.sendError(ctx, sqlServer.GetExecutorConfig(), errors.WithHintf(
pgerror.New(pgcode.TooManyConnections, "sorry, too many clients already"),
"the maximum number of allowed connections is %d and can be modified using the %s config key",
maxNumConnectionsValue,
maxNumConnections.Key(),
))
}
return nil
}

func (c *conn) authLogEnabled() bool {
return c.alwaysLogAuthActivity || logSessionAuth.Get(c.sv)
}
Expand Down Expand Up @@ -664,6 +697,15 @@ func (c *conn) processCommandsAsync(
return
}

if retErr = c.checkMaxConnections(ctx, sqlServer); retErr != nil {
return
}
defer sqlServer.DecrementConnectionCount()

if retErr = c.authOKMessage(); retErr != nil {
return
}

// Inform the client of the default session settings.
connHandler, retErr = c.sendInitialConnData(ctx, sqlServer, onDefaultIntSizeChange)
if retErr != nil {
Expand Down
148 changes: 148 additions & 0 deletions pkg/sql/pgwire/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ import (
"github.com/cockroachdb/cockroach/pkg/util/mon"
"github.com/cockroachdb/cockroach/pkg/util/timeutil"
"github.com/cockroachdb/errors"
"github.com/jackc/pgconn"
"github.com/jackc/pgproto3/v2"
"github.com/jackc/pgx/v4"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -1783,3 +1784,150 @@ func TestRoleDefaultSettings(t *testing.T) {
})
}
}

func TestPGWireRejectsNewConnIfTooManyConns(t *testing.T) {
defer leaktest.AfterTest(t)()
defer log.Scope(t).Close(t)

ctx := context.Background()
testServer, _, _ := serverutils.StartServer(t, base.TestServerArgs{})
defer testServer.Stopper().Stop(ctx)

// Users.
admin := security.RootUser
nonAdmin := security.TestUser

// openConnWithUser opens a connection to the testServer for the given user
// and always returns an associated cleanup function, even in case of error,
// which should be called. The returned cleanup function is idempotent.
openConnWithUser := func(user string) (*pgx.Conn, func(), error) {
pgURL, cleanup := sqlutils.PGUrlWithOptionalClientCerts(
t,
testServer.ServingSQLAddr(),
t.Name(),
url.UserPassword(user, user),
user == admin,
)
defer cleanup()
conn, err := pgx.Connect(ctx, pgURL.String())
if err != nil {
return nil, func() {}, err
}
return conn, func() {
require.NoError(t, conn.Close(ctx))
}, nil
}

openConnWithUserSuccess := func(user string) (*pgx.Conn, func()) {
conn, cleanup, err := openConnWithUser(user)
if err != nil {
defer cleanup()
t.FailNow()
}
return conn, cleanup
}

openConnWithUserError := func(user string) {
_, cleanup, err := openConnWithUser(user)
defer cleanup()
require.Error(t, err)
var pgErr *pgconn.PgError
require.ErrorAs(t, err, &pgErr)
require.Equal(t, pgcode.TooManyConnections.String(), pgErr.Code)
}

getConnectionCount := func() int {
return int(testServer.SQLServer().(*sql.Server).GetConnectionCount())
}

getMaxConnections := func() int {
conn, cleanup := openConnWithUserSuccess(admin)
defer cleanup()
var maxConnections int
err := conn.QueryRow(ctx, "SHOW CLUSTER SETTING server.max_connections_per_gateway").Scan(&maxConnections)
require.NoError(t, err)
return maxConnections
}

setMaxConnections := func(maxConnections int) {
conn, cleanup := openConnWithUserSuccess(admin)
defer cleanup()
_, err := conn.Exec(ctx, "SET CLUSTER SETTING server.max_connections_per_gateway = $1", maxConnections)
require.NoError(t, err)
}

createUser := func(user string) {
conn, cleanup := openConnWithUserSuccess(admin)
defer cleanup()
_, err := conn.Exec(ctx, fmt.Sprintf("CREATE USER %[1]s WITH PASSWORD '%[1]s'", user))
require.NoError(t, err)
}

// create nonAdmin
createUser(nonAdmin)
require.Equal(t, 0, getConnectionCount())

// assert default value
require.Equal(t, -1, getMaxConnections())
require.Equal(t, 0, getConnectionCount())

t.Run("0 max_connections", func(t *testing.T) {
setMaxConnections(0)
require.Equal(t, 0, getConnectionCount())
// can't connect with nonAdmin
openConnWithUserError(nonAdmin)
require.Equal(t, 0, getConnectionCount())
// can connect with admin
_, adminCleanup := openConnWithUserSuccess(admin)
require.Equal(t, 1, getConnectionCount())
adminCleanup()
require.Equal(t, 0, getConnectionCount())
})

t.Run("1 max_connections nonAdmin -> admin", func(t *testing.T) {
setMaxConnections(1)
require.Equal(t, 0, getConnectionCount())
// can connect with nonAdmin
_, nonAdminCleanup := openConnWithUserSuccess(nonAdmin)
require.Equal(t, 1, getConnectionCount())
// can connect with admin
_, adminCleanup := openConnWithUserSuccess(admin)
require.Equal(t, 2, getConnectionCount())
adminCleanup()
require.Equal(t, 1, getConnectionCount())
nonAdminCleanup()
require.Equal(t, 0, getConnectionCount())
})

t.Run("1 max_connections admin -> nonAdmin", func(t *testing.T) {
setMaxConnections(1)
require.Equal(t, 0, getConnectionCount())
// can connect with admin
_, adminCleanup := openConnWithUserSuccess(admin)
require.Equal(t, 1, getConnectionCount())
// can't connect with nonAdmin
openConnWithUserError(nonAdmin)
require.Equal(t, 1, getConnectionCount())
adminCleanup()
require.Equal(t, 0, getConnectionCount())
})

t.Run("-1 max_connections", func(t *testing.T) {
setMaxConnections(-1)
require.Equal(t, 0, getConnectionCount())
// can connect with multiple nonAdmin
_, nonAdminCleanup1 := openConnWithUserSuccess(nonAdmin)
require.Equal(t, 1, getConnectionCount())
_, nonAdminCleanup2 := openConnWithUserSuccess(nonAdmin)
require.Equal(t, 2, getConnectionCount())
// can connect with admin
_, adminCleanup := openConnWithUserSuccess(admin)
require.Equal(t, 3, getConnectionCount())
adminCleanup()
require.Equal(t, 2, getConnectionCount())
nonAdminCleanup1()
require.Equal(t, 1, getConnectionCount())
nonAdminCleanup2()
require.Equal(t, 0, getConnectionCount())
})
}
7 changes: 5 additions & 2 deletions pkg/sql/pgwire/pgwire_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1278,8 +1278,11 @@ func TestPGPreparedExec(t *testing.T) {
defer s.Stopper().Stop(context.Background())

runTests := func(
t *testing.T, query string, tests []preparedExecTest, execFunc func(...interface{},
) (gosql.Result, error)) {
t *testing.T,
query string,
tests []preparedExecTest,
execFunc func(...interface{}) (gosql.Result, error),
) {
for idx, test := range tests {
t.Run(fmt.Sprintf("%d", idx), func(t *testing.T) {
if testing.Verbose() || log.V(1) {
Expand Down
Loading