Skip to content

Commit

Permalink
sqlproxyccl: Add support for limiting non root connections
Browse files Browse the repository at this point in the history
This commit adds a cli flag --limit-non-root-conns to the sqlproxy.
When set to true, this enables the sqlproxy to limit the number of non root
connections to tenants. This is achieved by modifying the Directory interface
to supply an additional field MaxNonRootConnections when getting tenants. And
adding a watcher for tenants. The sqlproxy is then able to check if the current
connection exceeds the MaxNonRootConnections limit for the tenant and can close
it appropriately with a user friendly error.

This functionality is required for Serverless.
Part of: https://cockroachlabs.atlassian.net/browse/CC-9288

Release note: None
  • Loading branch information
alyshanjahani-crl committed Mar 20, 2023
1 parent 87d1fed commit 57c49ba
Show file tree
Hide file tree
Showing 12 changed files with 377 additions and 14 deletions.
1 change: 1 addition & 0 deletions pkg/ccl/cliccl/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ func setProxyContextDefaults() {
proxyContext.PollConfigInterval = 30 * time.Second
proxyContext.ThrottleBaseDelay = time.Second
proxyContext.DisableConnectionRebalancing = false
proxyContext.EnableLimitNonRootConns = false
}

var testDirectorySvrContext struct {
Expand Down
1 change: 1 addition & 0 deletions pkg/ccl/cliccl/flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ func init() {
cliflagcfg.StringFlag(f, &proxyContext.DirectoryAddr, cliflags.DirectoryAddr)
cliflagcfg.BoolFlag(f, &proxyContext.SkipVerify, cliflags.SkipVerify)
cliflagcfg.BoolFlag(f, &proxyContext.Insecure, cliflags.InsecureBackend)
cliflagcfg.BoolFlag(f, &proxyContext.EnableLimitNonRootConns, cliflags.LimitNonRootConns)
cliflagcfg.DurationFlag(f, &proxyContext.ValidateAccessInterval, cliflags.ValidateAccessInterval)
cliflagcfg.DurationFlag(f, &proxyContext.PollConfigInterval, cliflags.PollConfigInterval)
cliflagcfg.DurationFlag(f, &proxyContext.ThrottleBaseDelay, cliflags.ThrottleBaseDelay)
Expand Down
8 changes: 8 additions & 0 deletions pkg/ccl/sqlproxyccl/authentication.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ var authenticate = func(
clientConn, crdbConn net.Conn,
proxyBackendKeyData *pgproto3.BackendKeyData,
throttleHook func(throttler.AttemptStatus) error,
limitHook func(crdbConn net.Conn) error,
) (crdbBackendKeyData *pgproto3.BackendKeyData, _ error) {
fe := pgproto3.NewBackend(pgproto3.NewChunkReader(clientConn), clientConn)
be := pgproto3.NewFrontend(pgproto3.NewChunkReader(crdbConn), crdbConn)
Expand Down Expand Up @@ -144,6 +145,13 @@ var authenticate = func(
// Server has authenticated the connection successfully and is ready to
// serve queries.
case *pgproto3.ReadyForQuery:
limitError := limitHook(crdbConn)
if limitError != nil {
if err = feSend(toPgError(limitError)); err != nil {
return nil, err
}
return nil, limitError
}
if err = feSend(backendMsg); err != nil {
return nil, err
}
Expand Down
4 changes: 3 additions & 1 deletion pkg/ccl/sqlproxyccl/connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ func (c *connector) OpenTenantConnWithAuth(
requester balancer.ConnectionHandle,
clientConn net.Conn,
throttleHook func(throttler.AttemptStatus) error,
limitHook func(crdbConn net.Conn) error,
) (retServerConnection net.Conn, sentToClient bool, retErr error) {
// Just a safety check, but this shouldn't happen since we will block the
// startup param in the frontend admitter. The only case where we actually
Expand All @@ -163,7 +164,8 @@ func (c *connector) OpenTenantConnWithAuth(

// Perform user authentication for non-token-based auth methods. This will
// block until the server has authenticated the client.
crdbBackendKeyData, err := authenticate(clientConn, serverConn, c.CancelInfo.proxyBackendKeyData, throttleHook)
crdbBackendKeyData, err := authenticate(clientConn, serverConn, c.CancelInfo.proxyBackendKeyData,
throttleHook, limitHook)
if err != nil {
return nil, true, err
}
Expand Down
98 changes: 98 additions & 0 deletions pkg/ccl/sqlproxyccl/proxy_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ import (
"bytes"
"context"
"crypto/tls"
"github.com/cockroachdb/cockroach/pkg/ccl/sqlproxyccl/interceptor"
"io"
"net"
"net/http"
"regexp"
Expand Down Expand Up @@ -106,6 +108,9 @@ type ProxyOptions struct {
ThrottleBaseDelay time.Duration
// DisableConnectionRebalancing disables connection rebalancing for tenants.
DisableConnectionRebalancing bool
// EnableLimitNonRootConns enables the limiting of non root connections for
// tenants who have that limit set on them.
EnableLimitNonRootConns bool

// testingKnobs are knobs used for testing.
testingKnobs struct {
Expand Down Expand Up @@ -162,11 +167,20 @@ const throttledErrorHint string = `Connection throttling is triggered by repeate
sure the username and password are correct.
`

const resourceLimitedErrorHint string = `Connection limiting is triggered by insufficient resource limits. Make
sure the cluster has not hit its request unit and storage limits.
`

var throttledError = errors.WithHint(
withCode(errors.New(
"connection attempt throttled"), codeProxyRefusedConnection),
throttledErrorHint)

var connectionsLimitedError = errors.WithHint(
withCode(errors.New(
"connection attempt limited"), codeProxyRefusedConnection),
resourceLimitedErrorHint)

// newProxyHandler will create a new proxy handler with configuration based on
// the provided options.
func newProxyHandler(
Expand Down Expand Up @@ -268,6 +282,9 @@ func newProxyHandler(
if handler.testingKnobs.dirOpts != nil {
dirOpts = append(dirOpts, handler.testingKnobs.dirOpts...)
}
if options.EnableLimitNonRootConns {
dirOpts = append(dirOpts, tenant.WithWatchTenantsOption())
}

client := tenant.NewDirectoryClient(conn)
handler.directoryCache, err = tenant.NewDirectoryCache(ctx, stopper, client, dirOpts...)
Expand Down Expand Up @@ -419,6 +436,9 @@ func (handler *proxyHandler) handle(ctx context.Context, incomingConn net.Conn)
return throttledError
}
return nil
}, func(crdbConn net.Conn) error {
user := fe.Msg.Parameters["user"]
return handler.limitNonRootConnections(ctx, user, tenID, clusterName, crdbConn)
},
)
if err != nil {
Expand Down Expand Up @@ -456,6 +476,7 @@ func (handler *proxyHandler) handle(ctx context.Context, incomingConn net.Conn)
writeErrMarker: errClientWrite,
}


// Pass ownership of conn and crdbConn to the forwarder.
if err := f.run(clientConn, crdbConn); err != nil {
// Don't send to the client here for the same reason below.
Expand Down Expand Up @@ -494,6 +515,83 @@ func (handler *proxyHandler) handle(ctx context.Context, incomingConn net.Conn)
}
}

func (handler *proxyHandler) limitNonRootConnections (ctx context.Context, user string,
tenID roachpb.TenantID, clusterName string, crdbConn net.Conn) error {
if !handler.EnableLimitNonRootConns || user == "root" {
return nil
}

// If this is a non-root connection and the tenant has a specified limit on
// their non-root connections, then ensure that limit is respected.
tenantResp, err := handler.directoryCache.LookupTenant(ctx, tenID, clusterName)
if err != nil {
log.Errorf(ctx, "error looking up tenant %v: %v", tenID, err.Error())
return err
}

// If the tenant has a limit specified, count the non-root connections.
if tenantResp.MaxNonRootConns != -1 {
count, err := countNonRootConnections(ctx, interceptor.NewPGConn(crdbConn))
if err != nil {
log.Errorf(ctx, "error counting non root connections for tenant %v: %v", tenID, err.Error())
return err
}
log.Infof(ctx, "non-root connections open for tenant %v: %v, maxNonRootConnections %v",
tenID, count, tenantResp.MaxNonRootConns)
// The current connection is included in the count, so if the count is greater
// than the limit then this current connection should be closed.
if int32(count) > tenantResp.MaxNonRootConns {
log.Error(ctx, "throttler refused connection due to limiting non root connections")
return connectionsLimitedError
}
}

return nil
}

func countNonRootConnections(ctx context.Context, pgServerConn *interceptor.PGConn) (int, error) {
if err := writeQuery(pgServerConn, "select count(*) from [show cluster sessions] where user_name != 'root'"); err != nil {
return -1, err
}

pgServerFeConn := pgServerConn.ToFrontendConn()
if err := waitForSmallRowDescription(
ctx,
pgServerFeConn,
io.Discard,
func(msg *pgproto3.RowDescription) bool {
return true
},
); err != nil {
return -1, err
}

numNonRootSessions := 0
if err := expectDataRow(ctx, pgServerFeConn, func(msg *pgproto3.DataRow, _ int) bool {
if len(msg.Values) != 1 {
return false
}
num, err := strconv.Atoi(string(msg.Values[0]))
if err != nil {
return false
}
numNonRootSessions = num
return true
}); err != nil {
return -1, errors.Wrap(err, "expecting DataRow")
}

if err := expectCommandComplete(ctx, pgServerFeConn, "SELECT 1"); err != nil {
return -1, errors.Wrap(err, "expecting CommandComplete")
}

if err := expectReadyForQuery(ctx, pgServerFeConn); err != nil {
return -1, errors.Wrap(err, "expecting ReadyForQuery")
}

return numNonRootSessions, nil
}

// handleCancelRequest handles a pgwire query cancel request by either
// forwarding it to a SQL node or to another proxy.
func (handler *proxyHandler) handleCancelRequest(
Expand Down
21 changes: 21 additions & 0 deletions pkg/ccl/sqlproxyccl/tenant/directory.proto
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,24 @@ message WatchPodsResponse {
Pod pod = 1;
}

// WatchTenantsRequest is empty as we want to get all notifications.
message WatchTenantsRequest {}

// WatchTenantsResponse represents the notifications that the server sends to
// its clients when clients want to monitor the directory server activity.
message WatchTenantsResponse {
Tenant tenant = 1;
}

message Tenant {
// ClusterName is the name of the tenant's cluster.
string cluster_name = 1;
// MaxNonRootConns is the limit on non-root sql connections to this tenant.
// A value of -1 indicates no limit.
int32 max_non_root_conns = 2;
uint64 tenant_id = 3 [(gogoproto.customname) = "TenantID"];
}

// EnsurePodRequest is used to ensure that at least one tenant pod is in the
// RUNNING state.
message EnsurePodRequest {
Expand All @@ -107,7 +125,9 @@ message GetTenantRequest {
// GetTenantResponse is sent back when a client requests metadata for a tenant.
message GetTenantResponse {
// ClusterName is the name of the tenant's cluster.
// TODO(alyshan): Deprecate cluster_name field in favour of tenant field.
string cluster_name = 1; // add more metadata if needed
Tenant tenant = 2;
}

// Directory specifies a service that keeps track and manages tenant backends,
Expand All @@ -120,6 +140,7 @@ service Directory {
// are sent when a tenant pod is created, destroyed, or modified. When
// WatchPods is first called, it returns notifications for all existing pods.
rpc WatchPods(WatchPodsRequest) returns (stream WatchPodsResponse);
rpc WatchTenants(WatchTenantsRequest) returns (stream WatchTenantsResponse);
// EnsurePod ensures that at least one of the given tenant's pod is in the
// RUNNING state. If there is already a RUNNING pod, then the server doesn't
// have to do anything. If there isn't a RUNNING pod, then the server must
Expand Down
Loading

0 comments on commit 57c49ba

Please sign in to comment.