diff --git a/pkg/sql/conn_executor.go b/pkg/sql/conn_executor.go index ef8a0a0721df..966faab1e8ed 100644 --- a/pkg/sql/conn_executor.go +++ b/pkg/sql/conn_executor.go @@ -325,6 +325,7 @@ type Server struct { mu struct { syncutil.Mutex connectionCount int64 + rootConnectionCount int64 } } @@ -784,20 +785,38 @@ func (s *Server) SetupConn( } // IncrementConnectionCount increases connectionCount by 1. +// If isRoot is true then it also increases rootConnectionCount by 1. func (s *Server) IncrementConnectionCount() { s.mu.Lock() defer s.mu.Unlock() s.mu.connectionCount++ } +// IncrementRootConnectionCount increases both connectionCount and rootConnectionCount by 1. +func (s *Server) IncrementRootConnectionCount() { + s.mu.Lock() + defer s.mu.Unlock() + s.mu.connectionCount++ + s.mu.rootConnectionCount++ +} + // DecrementConnectionCount decreases connectionCount by 1. +// If isRoot is true then it also decreases rootConnectionCount by 1. func (s *Server) DecrementConnectionCount() { s.mu.Lock() defer s.mu.Unlock() s.mu.connectionCount-- } -// IncrementConnectionCountIfLessThan increases connectionCount by and returns true if allowedConnectionCount < max, +// DecrementRootConnectionCount decreases both connectionCount and rootConnectionCount by 1. +func (s *Server) DecrementRootConnectionCount() { + s.mu.Lock() + defer s.mu.Unlock() + s.mu.connectionCount-- + s.mu.rootConnectionCount-- +} + +// IncrementConnectionCountIfLessThan increases connectionCount by 1 and returns true if connectionCount < max, // otherwise it does nothing and returns false. func (s *Server) IncrementConnectionCountIfLessThan(max int64) bool { s.mu.Lock() @@ -816,6 +835,13 @@ func (s *Server) GetConnectionCount() int64 { return s.mu.connectionCount } +// GetNonRootConnectionCount returns the current number of non root connections. +func (s *Server) GetNonRootConnectionCount() int64 { + s.mu.Lock() + defer s.mu.Unlock() + return s.mu.connectionCount - s.mu.rootConnectionCount +} + // ConnectionHandler is the interface between the result of SetupConn // and the ServeConn below. It encapsulates the connExecutor and hides // it away from other packages. diff --git a/pkg/sql/pgwire/conn.go b/pkg/sql/pgwire/conn.go index 96b0550e0c84..c27577162695 100644 --- a/pkg/sql/pgwire/conn.go +++ b/pkg/sql/pgwire/conn.go @@ -232,13 +232,32 @@ func (c *conn) sendError(ctx context.Context, execCfg *sql.ExecutorConfig, err e } func (c *conn) checkMaxConnections(ctx context.Context, sqlServer *sql.Server) error { + // Root user is not affected by connection limits. + if c.sessionArgs.User.IsRootUser() { + sqlServer.IncrementRootConnectionCount() + return nil + } + + // First check maxNumExternalConnections. + maxExternalConnectionsValue := maxNumExternalConnections.Get(&sqlServer.GetExecutorConfig().Settings.SV) + if maxExternalConnectionsValue >= 0 && sqlServer.GetNonRootConnectionCount() >= maxExternalConnectionsValue { + // TODO(alyshan): Add another cluster setting that supplements the error message, so that + // clients can know *why* this ReadOnly setting is set and limiting their connections. + return c.sendError(ctx, sqlServer.GetExecutorConfig(), errors.WithHintf( + pgerror.New(pgcode.TooManyConnections, "cluster connections are limited"), + "the maximum number of allowed connections is %d", + maxExternalConnectionsValue, + )) + } + + // Then check maxNumNonAdminConnections. if c.sessionArgs.IsSuperuser { - // This user is a super user and is therefore not affected by connection limits. + // This user is a super user and is therefore not affected by maxNumNonAdminConnections. sqlServer.IncrementConnectionCount() return nil } - maxNumConnectionsValue := maxNumConnections.Get(&sqlServer.GetExecutorConfig().Settings.SV) + maxNumConnectionsValue := maxNumNonAdminConnections.Get(&sqlServer.GetExecutorConfig().Settings.SV) if maxNumConnectionsValue < 0 { // Unlimited connections are allowed. sqlServer.IncrementConnectionCount() @@ -249,7 +268,7 @@ func (c *conn) checkMaxConnections(ctx context.Context, sqlServer *sql.Server) e pgerror.New(pgcode.TooManyConnections, "sorry, too many clients already"), "the maximum number of allowed connections is %d and can be modified using the %s config key", maxNumConnectionsValue, - maxNumConnections.Key(), + maxNumNonAdminConnections.Key(), )) } return nil @@ -702,7 +721,12 @@ func (c *conn) processCommandsAsync( if retErr = c.checkMaxConnections(ctx, sqlServer); retErr != nil { return } - defer sqlServer.DecrementConnectionCount() + if c.sessionArgs.User.IsRootUser() { + defer sqlServer.DecrementRootConnectionCount() + } else { + defer sqlServer.DecrementConnectionCount() + } + if retErr = c.authOKMessage(); retErr != nil { return diff --git a/pkg/sql/pgwire/conn_test.go b/pkg/sql/pgwire/conn_test.go index 8207d562cb7d..e9ac8fb59b1c 100644 --- a/pkg/sql/pgwire/conn_test.go +++ b/pkg/sql/pgwire/conn_test.go @@ -1867,8 +1867,9 @@ func TestPGWireRejectsNewConnIfTooManyConns(t *testing.T) { defer testServer.Stopper().Stop(ctx) // Users. - admin := username.RootUser + rootUser := username.RootUser nonAdmin := username.TestUser + admin := "testadmin" // openConnWithUser opens a connection to the testServer for the given user // and always returns an associated cleanup function, even in case of error, @@ -1879,7 +1880,7 @@ func TestPGWireRejectsNewConnIfTooManyConns(t *testing.T) { testServer.ServingSQLAddr(), t.Name(), url.UserPassword(user, user), - user == admin, + user == rootUser, ) defer cleanup() conn, err := pgx.Connect(ctx, pgURL.String()) @@ -1923,7 +1924,7 @@ func TestPGWireRejectsNewConnIfTooManyConns(t *testing.T) { } getMaxConnections := func() int { - conn, cleanup := openConnWithUserSuccess(admin) + conn, cleanup := openConnWithUserSuccess(rootUser) defer cleanup() var maxConnections int err := conn.QueryRow(ctx, "SHOW CLUSTER SETTING server.max_connections_per_gateway").Scan(&maxConnections) @@ -1932,21 +1933,34 @@ func TestPGWireRejectsNewConnIfTooManyConns(t *testing.T) { } setMaxConnections := func(maxConnections int) { - conn, cleanup := openConnWithUserSuccess(admin) + conn, cleanup := openConnWithUserSuccess(rootUser) defer cleanup() _, err := conn.Exec(ctx, "SET CLUSTER SETTING server.max_connections_per_gateway = $1", maxConnections) require.NoError(t, err) } - createUser := func(user string) { - conn, cleanup := openConnWithUserSuccess(admin) + setMaxExternalConnections := func(maxConnections int) { + conn, cleanup := openConnWithUserSuccess(rootUser) + defer cleanup() + _, err := conn.Exec(ctx, "SET CLUSTER SETTING server.max_external_connections_per_gateway = $1", maxConnections) + require.NoError(t, err) + } + + createUser := func(user string, isAdmin bool) { + conn, cleanup := openConnWithUserSuccess(rootUser) defer cleanup() _, err := conn.Exec(ctx, fmt.Sprintf("CREATE USER %[1]s WITH PASSWORD '%[1]s'", user)) require.NoError(t, err) + + if isAdmin { + _, err := conn.Exec(ctx, fmt.Sprintf("grant admin to %[1]s", user)) + require.NoError(t, err) + } } // create nonAdmin - createUser(nonAdmin) + createUser(nonAdmin, false) + createUser(admin, true) requireConnectionCount(t, 0) // assert default value @@ -2012,6 +2026,33 @@ func TestPGWireRejectsNewConnIfTooManyConns(t *testing.T) { nonAdminCleanup2() requireConnectionCount(t, 0) }) + + t.Run("0 max_external_connections", func(t *testing.T) { + setMaxExternalConnections(0) + requireConnectionCount(t, 0) + // can connect with root + _, rootCleanup := openConnWithUserSuccess(rootUser) + requireConnectionCount(t, 1) + // can't connect with non root + openConnWithUserError(admin) + requireConnectionCount(t, 1) + rootCleanup() + requireConnectionCount(t, 0) + }) + + t.Run("1 max_external_connections", func(t *testing.T) { + setMaxExternalConnections(1) + requireConnectionCount(t, 0) + // can connect with root + _, rootCleanup := openConnWithUserSuccess(rootUser) + requireConnectionCount(t, 1) + // can connect with non root + _, nonAdminCleanup := openConnWithUserSuccess(nonAdmin) + requireConnectionCount(t, 2) + rootCleanup() + nonAdminCleanup() + requireConnectionCount(t, 0) + }) } func TestConnCloseReleasesReservedMem(t *testing.T) { diff --git a/pkg/sql/pgwire/server.go b/pkg/sql/pgwire/server.go index 485aba21bceb..4e09649e4fa1 100644 --- a/pkg/sql/pgwire/server.go +++ b/pkg/sql/pgwire/server.go @@ -82,7 +82,11 @@ var logSessionAuth = settings.RegisterBoolSetting( "if set, log SQL session login/disconnection events (note: may hinder performance on loaded nodes)", false).WithPublic() -var maxNumConnections = settings.RegisterIntSetting( +// TODO(alyshan): This setting is enforcing max number of connections with superusers not being affected by +// the limit. However, admin users connections are counted towards the max count. So we should either update the +// description to say "the maximum number of connections per gateway ... Superusers are not affected by this limit" +// or stop counting superuser connections towards the max count. +var maxNumNonAdminConnections = settings.RegisterIntSetting( settings.TenantWritable, "server.max_connections_per_gateway", "the maximum number of non-superuser SQL connections per gateway allowed at a given time "+ @@ -91,6 +95,20 @@ var maxNumConnections = settings.RegisterIntSetting( -1, // Postgres defaults to 100, but we default to -1 to match our previous behavior of unlimited. ).WithPublic() +// Note(alyshan): One might suggest this cluster setting be named server.max_non_root_connections_per_gateway +// as that reflects its current behaviour of excluding root connections from being limited. However, I chose +// to use the term "external" so that this setting can be extended to exclude connections from an arbitrary +// list of users. +// Note(alyshan): This setting is not public. +var maxNumExternalConnections = settings.RegisterIntSetting( + settings.TenantWritable, + "server.max_external_connections_per_gateway", + "the maximum number of external SQL connections per gateway allowed at a given time" + + "(note: this will only limit future connection attempts and will not affect already established connections). "+ + "Negative values result in unlimited number of connections. The root user is not affected by this limit.", + -1, + ) + const ( // ErrSSLRequired is returned when a client attempts to connect to a // secure server in cleartext.