sqlproxyccl: Add support for limiting non root connections #99056

Closed
1 change: 1 addition & 0 deletions pkg/ccl/cliccl/context.go
@@ -37,6 +37,7 @@ func setProxyContextDefaults() {
proxyContext.PollConfigInterval = 30 * time.Second
proxyContext.ThrottleBaseDelay = time.Second
proxyContext.DisableConnectionRebalancing = false
proxyContext.EnableLimitNonRootConns = false
}

var testDirectorySvrContext struct {
1 change: 1 addition & 0 deletions pkg/ccl/cliccl/flags.go
@@ -27,6 +27,7 @@ func init() {
cliflagcfg.StringFlag(f, &proxyContext.DirectoryAddr, cliflags.DirectoryAddr)
cliflagcfg.BoolFlag(f, &proxyContext.SkipVerify, cliflags.SkipVerify)
cliflagcfg.BoolFlag(f, &proxyContext.Insecure, cliflags.InsecureBackend)
cliflagcfg.BoolFlag(f, &proxyContext.EnableLimitNonRootConns, cliflags.LimitNonRootConns)
cliflagcfg.DurationFlag(f, &proxyContext.ValidateAccessInterval, cliflags.ValidateAccessInterval)
cliflagcfg.DurationFlag(f, &proxyContext.PollConfigInterval, cliflags.PollConfigInterval)
cliflagcfg.DurationFlag(f, &proxyContext.ThrottleBaseDelay, cliflags.ThrottleBaseDelay)
11 changes: 11 additions & 0 deletions pkg/ccl/sqlproxyccl/authentication.go
@@ -24,6 +24,7 @@ var authenticate = func(
clientConn, crdbConn net.Conn,
proxyBackendKeyData *pgproto3.BackendKeyData,
throttleHook func(throttler.AttemptStatus) error,
limitHook func(crdbConn net.Conn) error,
) (crdbBackendKeyData *pgproto3.BackendKeyData, _ error) {
fe := pgproto3.NewBackend(pgproto3.NewChunkReader(clientConn), clientConn)
be := pgproto3.NewFrontend(pgproto3.NewChunkReader(crdbConn), crdbConn)
@@ -144,6 +145,16 @@ var authenticate = func(
// Server has authenticated the connection successfully and is ready to
// serve queries.
case *pgproto3.ReadyForQuery:
// Invoke the limit hook before sending ReadyForQuery back to the client:
// once ReadyForQuery has been sent, the client may start sending messages,
// and an error written after that point may not be read correctly.
limitError := limitHook(crdbConn)
if limitError != nil {
if err = feSend(toPgError(limitError)); err != nil {
return nil, err
}
return nil, limitError
}
if err = feSend(backendMsg); err != nil {
return nil, err
}
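The new limitHook parameter mirrors the existing throttleHook: both are optional checks that callers thread into authenticate. A minimal sketch of a hypothetical caller that stubs both out, e.g. in a test (only authenticate and the imported types are from this PR; the helper name is invented, and the pgproto3 import path is assumed):

```go
package sqlproxyccl

import (
	"net"

	"github.com/cockroachdb/cockroach/pkg/ccl/sqlproxyccl/throttler"
	"github.com/jackc/pgproto3/v2"
)

// authenticateWithNoopHooks is a hypothetical helper illustrating the new
// authenticate signature: the throttle and limit hooks are injection points
// that callers can stub out with no-ops.
func authenticateWithNoopHooks(
	clientConn, crdbConn net.Conn, keyData *pgproto3.BackendKeyData,
) (*pgproto3.BackendKeyData, error) {
	noThrottle := func(throttler.AttemptStatus) error { return nil }
	noLimit := func(net.Conn) error { return nil }
	return authenticate(clientConn, crdbConn, keyData, noThrottle, noLimit)
}
```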
4 changes: 3 additions & 1 deletion pkg/ccl/sqlproxyccl/connector.go
@@ -144,6 +144,7 @@ func (c *connector) OpenTenantConnWithAuth(
requester balancer.ConnectionHandle,
clientConn net.Conn,
throttleHook func(throttler.AttemptStatus) error,
limitHook func(crdbConn net.Conn) error,
) (retServerConnection net.Conn, sentToClient bool, retErr error) {
// Just a safety check, but this shouldn't happen since we will block the
// startup param in the frontend admitter. The only case where we actually
@@ -163,7 +164,8 @@

// Perform user authentication for non-token-based auth methods. This will
// block until the server has authenticated the client.
crdbBackendKeyData, err := authenticate(clientConn, serverConn, c.CancelInfo.proxyBackendKeyData, throttleHook)
crdbBackendKeyData, err := authenticate(clientConn, serverConn, c.CancelInfo.proxyBackendKeyData,
throttleHook, limitHook)
if err != nil {
return nil, true, err
}
99 changes: 99 additions & 0 deletions pkg/ccl/sqlproxyccl/proxy_handler.go
@@ -12,6 +12,8 @@ import (
"bytes"
"context"
"crypto/tls"
"github.com/cockroachdb/cockroach/pkg/ccl/sqlproxyccl/interceptor"
"io"
"net"
"net/http"
"regexp"
@@ -108,6 +110,9 @@ type ProxyOptions struct {
ThrottleBaseDelay time.Duration
// DisableConnectionRebalancing disables connection rebalancing for tenants.
DisableConnectionRebalancing bool
// EnableLimitNonRootConns enables limiting the number of non-root
// connections for tenants that have such a limit configured.
EnableLimitNonRootConns bool
Collaborator (review comment): nit: consider generalizing the name of this feature to something like enable-active-connection-limit. We currently exclude the root user, but we may want to exclude more users in the future. For example: we may automatically create a cockroach-cloud-ui user and we would also want to exclude that from the connection limits.
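A sketch of the generalization this comment suggests, with hypothetical names (nothing below exists in the PR): a configurable set of exempt users replaces the hard-coded root check.

```go
// Hypothetical sketch: generalize "exclude root" to a set of exempt users,
// so future internal accounts can be excluded without changing the limit
// check itself.
var connLimitExemptUsers = map[string]struct{}{
	"root": {},
	// "cockroach-cloud-ui": {}, // possible future internal user
}

// isExemptFromConnLimit reports whether the user is excluded from the
// active-connection limit.
func isExemptFromConnLimit(user string) bool {
	_, ok := connLimitExemptUsers[user]
	return ok
}
```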


// testingKnobs are knobs used for testing.
testingKnobs struct {
@@ -164,11 +169,21 @@ const throttledErrorHint string = `Connection throttling is triggered by repeate
sure the username and password are correct.
`

// TODO: Make this generic - let tenant dir provide error message
const resourceLimitedErrorHint string = `Connection limiting is triggered by insufficient resource limits. Make
sure the cluster has not hit its request unit and storage limits.
`

Collaborator (review comment): We should ensure the error message is as useful as possible. Ideally this would say something like: cluster is throttled to one connection: <cluster_name> is out of storage.

If we want to keep the sqlproxy part generic, which makes sense to me, then we should let the tenant directory pick the error message.

var throttledError = errors.WithHint(
withCode(errors.New(
"connection attempt throttled"), codeProxyRefusedConnection),
throttledErrorHint)

var connectionsLimitedError = errors.WithHint(
withCode(errors.New(
"connection attempt limited"), codeProxyRefusedConnection),
resourceLimitedErrorHint)

// newProxyHandler will create a new proxy handler with configuration based on
// the provided options.
func newProxyHandler(
@@ -273,6 +288,9 @@ func newProxyHandler(
if handler.testingKnobs.dirOpts != nil {
dirOpts = append(dirOpts, handler.testingKnobs.dirOpts...)
}
if options.EnableLimitNonRootConns {
dirOpts = append(dirOpts, tenant.WithWatchTenantsOption())
}

client := tenant.NewDirectoryClient(conn)
handler.directoryCache, err = tenant.NewDirectoryCache(ctx, stopper, client, dirOpts...)
@@ -428,6 +446,9 @@ func (handler *proxyHandler) handle(ctx context.Context, incomingConn net.Conn)
return throttledError
}
return nil
}, func(crdbConn net.Conn) error {
user := fe.Msg.Parameters["user"]
return handler.limitNonRootConnections(ctx, user, tenID, clusterName, crdbConn)
},
)
if err != nil {
@@ -465,6 +486,7 @@ func (handler *proxyHandler) handle(ctx context.Context, incomingConn net.Conn)
writeErrMarker: errClientWrite,
}


// Pass ownership of conn and crdbConn to the forwarder.
if err := f.run(clientConn, crdbConn); err != nil {
// Don't send to the client here for the same reason below.
@@ -503,6 +525,83 @@ func (handler *proxyHandler) handle(ctx context.Context, incomingConn net.Conn)
}
}

// limitNonRootConnections enforces the tenant's limit on the number of
// non-root SQL connections, if one is configured. It is a no-op when the
// feature is disabled or when the connecting user is root.
func (handler *proxyHandler) limitNonRootConnections(
	ctx context.Context, user string, tenID roachpb.TenantID, clusterName string, crdbConn net.Conn,
) error {
if !handler.EnableLimitNonRootConns || user == "root" {
return nil
}

// If this is a non-root connection and the tenant has a specified limit on
// their non-root connections, then ensure that limit is respected.
tenantResp, err := handler.directoryCache.LookupTenant(ctx, tenID, clusterName)
if err != nil {
log.Errorf(ctx, "error looking up tenant %v: %v", tenID, err.Error())
return err
}

// If the tenant has a limit specified, count the non-root connections.
if tenantResp.MaxNonRootConns != -1 {

Collaborator (review comment): What do you think of making MaxNonRootConns nullable instead of using -1 as a sentinel value?

count, err := countNonRootConnections(ctx, interceptor.NewPGConn(crdbConn))
if err != nil {
log.Errorf(ctx, "error counting non root connections for tenant %v: %v", tenID, err.Error())
return err
}
log.Infof(ctx, "non-root connections open for tenant %v: %v, maxNonRootConnections %v",
tenID, count, tenantResp.MaxNonRootConns)
// The current connection is included in the count, so if the count is greater
// than the limit then this current connection should be closed.
if int32(count) > tenantResp.MaxNonRootConns {
log.Error(ctx, "refused connection: non-root connection limit reached")
return connectionsLimitedError
}
}

return nil
}
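On the nullable suggestion above: if max_non_root_conns became nullable (for example via a google.protobuf.Int32Value wrapper), the generated Go field would surface as a pointer, and "no limit" would be nil rather than -1. A hypothetical reworking of the check, assuming a *int32 field:

```go
// Hypothetical: with a nullable max_non_root_conns field, the sentinel
// comparison becomes a nil check and the limit is dereferenced.
if limit := tenantResp.MaxNonRootConns; limit != nil {
	count, err := countNonRootConnections(ctx, interceptor.NewPGConn(crdbConn))
	if err != nil {
		return err
	}
	if int32(count) > *limit {
		return connectionsLimitedError
	}
}
```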

func countNonRootConnections(ctx context.Context, pgServerConn *interceptor.PGConn) (int, error) {
if err := writeQuery(pgServerConn, "select count(*) from [show cluster sessions] where user_name != 'root'"); err != nil {
return -1, err
}

pgServerFeConn := pgServerConn.ToFrontendConn()
if err := waitForSmallRowDescription(
ctx,
pgServerFeConn,
io.Discard,
func(msg *pgproto3.RowDescription) bool {
return true
},
); err != nil {
return -1, err
}

numNonRootSessions := 0
if err := expectDataRow(ctx, pgServerFeConn, func(msg *pgproto3.DataRow, _ int) bool {
if len(msg.Values) != 1 {
return false
}
num, err := strconv.Atoi(string(msg.Values[0]))
if err != nil {
return false
}
numNonRootSessions = num
return true
}); err != nil {
return -1, errors.Wrap(err, "expecting DataRow")
}

if err := expectCommandComplete(ctx, pgServerFeConn, "SELECT 1"); err != nil {
return -1, errors.Wrap(err, "expecting CommandComplete")
}

if err := expectReadyForQuery(ctx, pgServerFeConn); err != nil {
return -1, errors.Wrap(err, "expecting ReadyForQuery")
}

return numNonRootSessions, nil
}
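countNonRootConnections speaks raw pgwire on the backend connection, so the message sequence it waits for (RowDescription, DataRow, CommandComplete, ReadyForQuery) is exactly what the simple query protocol produces for a single-row SELECT. For illustration, the same query issued through database/sql; the DSN and driver choice are assumptions, not part of this PR:

```go
package main

import (
	"database/sql"
	"fmt"
	"log"

	_ "github.com/lib/pq" // any Postgres-wire-compatible driver works
)

func main() {
	// Hypothetical DSN; in the proxy, the query runs over the already
	// authenticated backend connection instead of a fresh one.
	db, err := sql.Open("postgres",
		"postgresql://root@localhost:26257/defaultdb?sslmode=disable")
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()

	var count int
	if err := db.QueryRow(
		"SELECT count(*) FROM [SHOW CLUSTER SESSIONS] WHERE user_name != 'root'",
	).Scan(&count); err != nil {
		log.Fatal(err)
	}
	fmt.Println("non-root sessions:", count)
}
```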

// handleCancelRequest handles a pgwire query cancel request by either
// forwarding it to a SQL node or to another proxy.
func (handler *proxyHandler) handleCancelRequest(
28 changes: 28 additions & 0 deletions pkg/ccl/sqlproxyccl/tenant/directory.proto
@@ -85,6 +85,27 @@ message WatchPodsResponse {
Pod pod = 1;
}

// WatchTenantsRequest is empty as we want to get all notifications.
message WatchTenantsRequest {}

// WatchTenantsResponse is a notification sent by the server to clients
// that are monitoring tenant metadata changes in the directory.
message WatchTenantsResponse {
Tenant tenant = 1;
}

// Tenant contains information about a tenant, such as its name, tenantID,
// and connection limit.
message Tenant {
// ClusterName is the name of the tenant's cluster.
string cluster_name = 1;
// MaxNonRootConns is the limit on non-root sql connections to this tenant.
// A value of -1 indicates no limit.
int32 max_non_root_conns = 2;
// TenantID identifies the tenant.
uint64 tenant_id = 3 [(gogoproto.customname) = "TenantID"];
}

// EnsurePodRequest is used to ensure that at least one tenant pod is in the
// RUNNING state.
message EnsurePodRequest {
@@ -107,7 +128,10 @@ message GetTenantRequest {
// GetTenantResponse is sent back when a client requests metadata for a tenant.
message GetTenantResponse {
// ClusterName is the name of the tenant's cluster.
// TODO(alyshan): Deprecate cluster_name field in favour of tenant field.
string cluster_name = 1; // add more metadata if needed
// Tenant contains the information for the tenant specified in the request.
Tenant tenant = 2;
}

// Directory specifies a service that keeps track and manages tenant backends,
@@ -120,6 +144,10 @@ service Directory {
// are sent when a tenant pod is created, destroyed, or modified. When
// WatchPods is first called, it returns notifications for all existing pods.
rpc WatchPods(WatchPodsRequest) returns (stream WatchPodsResponse);
// WatchTenants gets a stream of tenant change notifications. Notifications
// are sent when a tenant is created, destroyed, or modified. When
// WatchTenants is first called, it returns notifications for all existing tenants.
rpc WatchTenants(WatchTenantsRequest) returns (stream WatchTenantsResponse);
// EnsurePod ensures that at least one of the given tenant's pod is in the
// RUNNING state. If there is already a RUNNING pod, then the server doesn't
// have to do anything. If there isn't a RUNNING pod, then the server must
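A sketch of how a directory-cache consumer might drain the new WatchTenants stream. DirectoryClient and the message types are what standard gRPC codegen would produce for this service; the surrounding function, package name, and logging are illustrative only:

```go
package tenantwatch

import (
	"context"
	"io"
	"log"

	"github.com/cockroachdb/cockroach/pkg/ccl/sqlproxyccl/tenant"
)

// watchTenants drains tenant change notifications until the stream closes.
// DirectoryClient, WatchTenantsRequest, and WatchTenantsResponse come from
// the directory.proto codegen; field access assumes the default generated
// names.
func watchTenants(ctx context.Context, client tenant.DirectoryClient) error {
	stream, err := client.WatchTenants(ctx, &tenant.WatchTenantsRequest{})
	if err != nil {
		return err
	}
	for {
		resp, err := stream.Recv()
		if err == io.EOF {
			return nil // server closed the stream
		}
		if err != nil {
			return err
		}
		// A directory cache would refresh its entry here, picking up
		// changes such as a new MaxNonRootConns limit.
		log.Printf("tenant %d updated: max non-root conns = %d",
			resp.Tenant.TenantID, resp.Tenant.MaxNonRootConns)
	}
}
```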