diff --git a/pkg/ccl/cliccl/context.go b/pkg/ccl/cliccl/context.go
index a00834413fe5..5ad040b9a511 100644
--- a/pkg/ccl/cliccl/context.go
+++ b/pkg/ccl/cliccl/context.go
@@ -37,6 +37,7 @@ func setProxyContextDefaults() {
 	proxyContext.PollConfigInterval = 30 * time.Second
 	proxyContext.ThrottleBaseDelay = time.Second
 	proxyContext.DisableConnectionRebalancing = false
+	proxyContext.EnableLimitNonRootConns = false
 }
 
 var testDirectorySvrContext struct {
diff --git a/pkg/ccl/cliccl/flags.go b/pkg/ccl/cliccl/flags.go
index 8cfed934e5be..94942a3f2998 100644
--- a/pkg/ccl/cliccl/flags.go
+++ b/pkg/ccl/cliccl/flags.go
@@ -27,6 +27,7 @@ func init() {
 		cliflagcfg.StringFlag(f, &proxyContext.DirectoryAddr, cliflags.DirectoryAddr)
 		cliflagcfg.BoolFlag(f, &proxyContext.SkipVerify, cliflags.SkipVerify)
 		cliflagcfg.BoolFlag(f, &proxyContext.Insecure, cliflags.InsecureBackend)
+		cliflagcfg.BoolFlag(f, &proxyContext.EnableLimitNonRootConns, cliflags.LimitNonRootConns)
 		cliflagcfg.DurationFlag(f, &proxyContext.ValidateAccessInterval, cliflags.ValidateAccessInterval)
 		cliflagcfg.DurationFlag(f, &proxyContext.PollConfigInterval, cliflags.PollConfigInterval)
 		cliflagcfg.DurationFlag(f, &proxyContext.ThrottleBaseDelay, cliflags.ThrottleBaseDelay)
diff --git a/pkg/ccl/sqlproxyccl/authentication.go b/pkg/ccl/sqlproxyccl/authentication.go
index 6894603bdb84..cbd4e4f85b5e 100644
--- a/pkg/ccl/sqlproxyccl/authentication.go
+++ b/pkg/ccl/sqlproxyccl/authentication.go
@@ -24,6 +24,7 @@ var authenticate = func(
 	clientConn, crdbConn net.Conn,
 	proxyBackendKeyData *pgproto3.BackendKeyData,
 	throttleHook func(throttler.AttemptStatus) error,
+	limitHook func(crdbConn net.Conn) error,
 ) (crdbBackendKeyData *pgproto3.BackendKeyData, _ error) {
 	fe := pgproto3.NewBackend(pgproto3.NewChunkReader(clientConn), clientConn)
 	be := pgproto3.NewFrontend(pgproto3.NewChunkReader(crdbConn), crdbConn)
@@ -144,6 +145,16 @@ var authenticate = func(
 		// Server has authenticated the connection successfully and is ready to
 		// serve queries.
 		case *pgproto3.ReadyForQuery:
+			// Invoke the limit hook before sending ReadyForQuery back to the client:
+			// once ReadyForQuery has been sent, the client may start sending messages,
+			// and an error sent after that point may not be read correctly by the client.
+			limitError := limitHook(crdbConn)
+			if limitError != nil {
+				if err = feSend(toPgError(limitError)); err != nil {
+					return nil, err
+				}
+				return nil, limitError
+			}
 			if err = feSend(backendMsg); err != nil {
 				return nil, err
 			}
diff --git a/pkg/ccl/sqlproxyccl/connector.go b/pkg/ccl/sqlproxyccl/connector.go
index 6a891a7ab53f..2a0e2548503d 100644
--- a/pkg/ccl/sqlproxyccl/connector.go
+++ b/pkg/ccl/sqlproxyccl/connector.go
@@ -144,6 +144,7 @@ func (c *connector) OpenTenantConnWithAuth(
 	requester balancer.ConnectionHandle,
 	clientConn net.Conn,
 	throttleHook func(throttler.AttemptStatus) error,
+	limitHook func(crdbConn net.Conn) error,
 ) (retServerConnection net.Conn, sentToClient bool, retErr error) {
 	// Just a safety check, but this shouldn't happen since we will block the
 	// startup param in the frontend admitter. The only case where we actually
@@ -163,7 +164,8 @@ func (c *connector) OpenTenantConnWithAuth(
 
 	// Perform user authentication for non-token-based auth methods. This will
 	// block until the server has authenticated the client.
-	crdbBackendKeyData, err := authenticate(clientConn, serverConn, c.CancelInfo.proxyBackendKeyData, throttleHook)
+	crdbBackendKeyData, err := authenticate(clientConn, serverConn, c.CancelInfo.proxyBackendKeyData,
+		throttleHook, limitHook)
 	if err != nil {
 		return nil, true, err
 	}
diff --git a/pkg/ccl/sqlproxyccl/proxy_handler.go b/pkg/ccl/sqlproxyccl/proxy_handler.go
index 103361a897f7..5a673074ffef 100644
--- a/pkg/ccl/sqlproxyccl/proxy_handler.go
+++ b/pkg/ccl/sqlproxyccl/proxy_handler.go
@@ -12,6 +12,8 @@ import (
 	"bytes"
 	"context"
 	"crypto/tls"
+	"github.com/cockroachdb/cockroach/pkg/ccl/sqlproxyccl/interceptor"
+	"io"
 	"net"
 	"net/http"
 	"regexp"
@@ -108,6 +110,9 @@ type ProxyOptions struct {
 	ThrottleBaseDelay time.Duration
 	// DisableConnectionRebalancing disables connection rebalancing for tenants.
 	DisableConnectionRebalancing bool
+	// EnableLimitNonRootConns enables limiting of non-root connections for
+	// tenants that have such a limit set on them.
+	EnableLimitNonRootConns bool
 
 	// testingKnobs are knobs used for testing.
 	testingKnobs struct {
@@ -164,11 +169,21 @@ const throttledErrorHint string = `Connection throttling is triggered by repeated authentication failure. Make
 sure the username and password are correct.
 `
 
+// TODO: Make this generic - let the tenant directory provide the error message.
+const resourceLimitedErrorHint string = `Connection limiting is triggered by insufficient resource limits. Make
+sure the cluster has not hit its request unit and storage limits.
+`
+
 var throttledError = errors.WithHint(
 	withCode(errors.New(
 		"connection attempt throttled"), codeProxyRefusedConnection),
 	throttledErrorHint)
 
+var connectionsLimitedError = errors.WithHint(
+	withCode(errors.New(
+		"connection attempt limited"), codeProxyRefusedConnection),
+	resourceLimitedErrorHint)
+
 // newProxyHandler will create a new proxy handler with configuration based on
 // the provided options.
 func newProxyHandler(
@@ -273,6 +288,9 @@ func newProxyHandler(
 	if handler.testingKnobs.dirOpts != nil {
 		dirOpts = append(dirOpts, handler.testingKnobs.dirOpts...)
 	}
+	if options.EnableLimitNonRootConns {
+		dirOpts = append(dirOpts, tenant.WithWatchTenantsOption())
+	}
 
 	client := tenant.NewDirectoryClient(conn)
 	handler.directoryCache, err = tenant.NewDirectoryCache(ctx, stopper, client, dirOpts...)
@@ -428,6 +446,9 @@ func (handler *proxyHandler) handle(ctx context.Context, incomingConn net.Conn) error {
 				return throttledError
 			}
 			return nil
+		}, func(crdbConn net.Conn) error {
+			user := fe.Msg.Parameters["user"]
+			return handler.limitNonRootConnections(ctx, user, tenID, clusterName, crdbConn)
 		},
 	)
 	if err != nil {
@@ -465,6 +486,7 @@ func (handler *proxyHandler) handle(ctx context.Context, incomingConn net.Conn) error {
 		writeErrMarker: errClientWrite,
 	}
 
+	// Pass ownership of conn and crdbConn to the forwarder.
 	if err := f.run(clientConn, crdbConn); err != nil {
 		// Don't send to the client here for the same reason below.
@@ -503,6 +525,83 @@ func (handler *proxyHandler) handle(ctx context.Context, incomingConn net.Conn) error {
 	}
 }
 
+func (handler *proxyHandler) limitNonRootConnections(ctx context.Context, user string,
+	tenID roachpb.TenantID, clusterName string, crdbConn net.Conn) error {
+	if !handler.EnableLimitNonRootConns || user == "root" {
+		return nil
+	}
+
+	// If this is a non-root connection and the tenant has a specified limit on
+	// its non-root connections, ensure that limit is respected.
+	tenantResp, err := handler.directoryCache.LookupTenant(ctx, tenID, clusterName)
+	if err != nil {
+		log.Errorf(ctx, "error looking up tenant %v: %v", tenID, err.Error())
+		return err
+	}
+
+	// If the tenant has a limit specified, count the non-root connections.
+	if tenantResp.MaxNonRootConns != -1 {
+		count, err := countNonRootConnections(ctx, interceptor.NewPGConn(crdbConn))
+		if err != nil {
+			log.Errorf(ctx, "error counting non-root connections for tenant %v: %v", tenID, err.Error())
+			return err
+		}
+		log.Infof(ctx, "non-root connections open for tenant %v: %v, maxNonRootConnections %v",
+			tenID, count, tenantResp.MaxNonRootConns)
+		// The current connection is included in the count, so if the count is greater
+		// than the limit then this current connection should be closed.
+		if int32(count) > tenantResp.MaxNonRootConns {
+			log.Error(ctx, "refused connection due to the non-root connection limit")
+			return connectionsLimitedError
+		}
+	}
+
+	return nil
+}
+
+func countNonRootConnections(ctx context.Context, pgServerConn *interceptor.PGConn) (int, error) {
+	if err := writeQuery(pgServerConn, "select count(*) from [show cluster sessions] where user_name != 'root'"); err != nil {
+		return -1, err
+	}
+
+	pgServerFeConn := pgServerConn.ToFrontendConn()
+	if err := waitForSmallRowDescription(
+		ctx,
+		pgServerFeConn,
+		io.Discard,
+		func(msg *pgproto3.RowDescription) bool {
+			return true
+		},
+	); err != nil {
+		return -1, err
+	}
+
+	numNonRootSessions := 0
+	if err := expectDataRow(ctx, pgServerFeConn, func(msg *pgproto3.DataRow, _ int) bool {
+		if len(msg.Values) != 1 {
+			return false
+		}
+		num, err := strconv.Atoi(string(msg.Values[0]))
+		if err != nil {
+			return false
+		}
+		numNonRootSessions = num
+		return true
+	}); err != nil {
+		return -1, errors.Wrap(err, "expecting DataRow")
+	}
+
+	if err := expectCommandComplete(ctx, pgServerFeConn, "SELECT 1"); err != nil {
+		return -1, errors.Wrap(err, "expecting CommandComplete")
+	}
+
+	if err := expectReadyForQuery(ctx, pgServerFeConn); err != nil {
+		return -1, errors.Wrap(err, "expecting ReadyForQuery")
+	}
+
+	return numNonRootSessions, nil
+}
+
 // handleCancelRequest handles a pgwire query cancel request by either
 // forwarding it to a SQL node or to another proxy.
 func (handler *proxyHandler) handleCancelRequest(
diff --git a/pkg/ccl/sqlproxyccl/tenant/directory.proto b/pkg/ccl/sqlproxyccl/tenant/directory.proto
index 4ee4695c0641..bceb03ce9efa 100644
--- a/pkg/ccl/sqlproxyccl/tenant/directory.proto
+++ b/pkg/ccl/sqlproxyccl/tenant/directory.proto
@@ -85,6 +85,27 @@ message WatchPodsResponse {
   Pod pod = 1;
 }
 
+// WatchTenantsRequest is empty as we want to get all notifications.
+message WatchTenantsRequest {}
+
+// WatchTenantsResponse represents the notifications that the server sends to
+// its clients when clients want to monitor the directory server activity.
+message WatchTenantsResponse {
+  Tenant tenant = 1;
+}
+
+// Tenant contains information about a tenant, such as its name, tenant ID,
+// and connection limit.
+message Tenant {
+  // ClusterName is the name of the tenant's cluster.
+  string cluster_name = 1;
+  // MaxNonRootConns is the limit on non-root SQL connections to this tenant.
+  // A value of -1 indicates no limit.
+  int32 max_non_root_conns = 2;
+  // TenantID identifies the tenant.
+  uint64 tenant_id = 3 [(gogoproto.customname) = "TenantID"];
+}
+
 // EnsurePodRequest is used to ensure that at least one tenant pod is in the
 // RUNNING state.
 message EnsurePodRequest {
@@ -107,7 +128,10 @@ message GetTenantRequest {
 // GetTenantResponse is sent back when a client requests metadata for a tenant.
 message GetTenantResponse {
   // ClusterName is the name of the tenant's cluster.
+  // TODO(alyshan): Deprecate the cluster_name field in favour of the tenant field.
   string cluster_name = 1; // add more metadata if needed
+  // Tenant contains the information for the tenant specified in the request.
+  Tenant tenant = 2;
 }
 
 // Directory specifies a service that keeps track and manages tenant backends,
@@ -120,6 +144,10 @@ service Directory {
   // are sent when a tenant pod is created, destroyed, or modified. When
   // WatchPods is first called, it returns notifications for all existing pods.
   rpc WatchPods(WatchPodsRequest) returns (stream WatchPodsResponse);
  // WatchTenants gets a stream of tenant change notifications. Notifications
+  // are sent when a tenant is created, destroyed, or modified. When
+  // WatchTenants is first called, it returns notifications for all existing tenants.
+  rpc WatchTenants(WatchTenantsRequest) returns (stream WatchTenantsResponse);
   // EnsurePod ensures that at least one of the given tenant's pod is in the
   // RUNNING state. If there is already a RUNNING pod, then the server doesn't
   // have to do anything. If there isn't a RUNNING pod, then the server must
diff --git a/pkg/ccl/sqlproxyccl/tenant/directory_cache.go b/pkg/ccl/sqlproxyccl/tenant/directory_cache.go
index 0d3acfdeaf83..de1fccfae2a6 100644
--- a/pkg/ccl/sqlproxyccl/tenant/directory_cache.go
+++ b/pkg/ccl/sqlproxyccl/tenant/directory_cache.go
@@ -35,6 +35,12 @@ type DirectoryCache interface {
 	// deleted), this will return a GRPC NotFound error.
 	LookupTenantPods(ctx context.Context, tenantID roachpb.TenantID, clusterName string) ([]*Pod, error)
 
+	// LookupTenant returns the specified tenant.
+	//
+	// If no matching tenant is found (e.g. cluster name mismatch, or tenant was
+	// deleted), this will return a GRPC NotFound error.
+	LookupTenant(ctx context.Context, tenantID roachpb.TenantID, clusterName string) (Tenant, error)
+
 	// TryLookupTenantPods returns the IP addresses for all available SQL
 	// processes for the given tenant. It returns a GRPC NotFound error if the
 	// tenant does not exist.
@@ -54,6 +60,10 @@ type dirOptions struct {
 	deterministic bool
 	refreshDelay  time.Duration
 	podWatcher    chan *Pod
+	// watchTenants instructs the directory cache to set up a watch on tenants
+	// and to update tenant entries so that each tenant's MaxNonRootConns
+	// limit is kept up to date.
+	watchTenants bool
 }
 
 // DirOption defines an option that can be passed to directoryCache in order
@@ -84,6 +94,13 @@ func PodWatcher(podWatcher chan *Pod) func(opts *dirOptions) {
 	}
 }
 
+// WithWatchTenantsOption sets the watchTenants option to true.
+func WithWatchTenantsOption() func(opts *dirOptions) {
+	return func(opts *dirOptions) {
+		opts.watchTenants = true
+	}
+}
+
 // directoryCache tracks the network locations of SQL tenant processes. It is
 // used by the sqlproxy to route incoming traffic to the correct backend process.
 // Process information is populated and kept relatively up-to-date using a
@@ -150,6 +167,12 @@ func NewDirectoryCache(
 	if err := dir.watchPods(ctx, stopper); err != nil {
 		return nil, err
 	}
+	if dir.options.watchTenants {
+		// Start the tenant watcher on a background goroutine.
+		if err := dir.watchTenants(ctx, stopper); err != nil {
+			return nil, err
+		}
+	}
 
 	return dir, nil
 }
@@ -176,20 +199,11 @@ func (d *directoryCache) LookupTenantPods(
 	ctx context.Context, tenantID roachpb.TenantID, clusterName string,
 ) ([]*Pod, error) {
 	// Ensure that a directory entry has been created for this tenant.
-	entry, err := d.getEntry(ctx, tenantID, true /* allowCreate */)
+	entry, err := d.ensureDirectoryEntry(ctx, tenantID, clusterName)
 	if err != nil {
 		return nil, err
 	}
 
-	// Check if the cluster name matches. This can be skipped if clusterName
-	// is empty, or the ClusterName returned by the directory server is empty.
-	if clusterName != "" && entry.ClusterName != "" && clusterName != entry.ClusterName {
-		// Return a GRPC NotFound error.
-		log.Errorf(ctx, "cluster name %s doesn't match expected %s", clusterName, entry.ClusterName)
-		return nil, status.Errorf(codes.NotFound,
-			"cluster name %s doesn't match expected %s", clusterName, entry.ClusterName)
-	}
-
 	ctx, _ = d.stopper.WithCancelOnQuiesce(ctx)
 
 	tenantPods := entry.GetPods()
@@ -209,6 +223,42 @@ func (d *directoryCache) LookupTenantPods(
 	return tenantPods, nil
 }
 
+// ensureDirectoryEntry ensures that an entry has been created for the tenant.
+func (d *directoryCache) ensureDirectoryEntry(ctx context.Context,
+	tenantID roachpb.TenantID, clusterName string) (*tenantEntry, error) {
+	entry, err := d.getEntry(ctx, tenantID, true /* allowCreate */)
+	if err != nil {
+		return nil, err
+	}
+
+	// Check if the cluster name matches. This can be skipped if clusterName
+	// is empty, or the ClusterName returned by the directory server is empty.
+	if clusterName != "" && entry.ClusterName != "" && clusterName != entry.ClusterName {
+		// Return a GRPC NotFound error.
+		log.Errorf(ctx, "cluster name %s doesn't match expected %s", clusterName, entry.ClusterName)
+		return nil, status.Errorf(codes.NotFound,
+			"cluster name %s doesn't match expected %s", clusterName, entry.ClusterName)
+	}
+
+	return entry, nil
+}
+
+// LookupTenant returns the specified tenant.
+//
+// If no matching tenant is found (e.g. tenant was deleted), this will
+// return a GRPC NotFound error.
+func (d *directoryCache) LookupTenant(
+	ctx context.Context, tenantID roachpb.TenantID, clusterName string,
+) (Tenant, error) {
+	// Ensure that a directory entry has been created for this tenant.
+	entry, err := d.ensureDirectoryEntry(ctx, tenantID, clusterName)
+	if err != nil {
+		return Tenant{}, err
+	}
+
+	return entry.ToTenantProto(), nil
+}
+
 // TryLookupTenantPods returns a list of SQL pods in the RUNNING and DRAINING
 // states for the given tenant. It returns a GRPC NotFound error if the tenant
 // does not exist (e.g. it has not yet been created) or if it has not yet been
@@ -405,7 +455,7 @@ func (d *directoryCache) watchPods(ctx context.Context, stopper *stop.Stopper) error {
 
 		// Update the directory entry for the tenant with the latest
 		// information about this pod.
-		d.updateTenantEntry(ctx, resp.Pod)
+		d.updateTenantEntryWithPod(ctx, resp.Pod)
 
 		// If caller is watching pods, send to its channel now. Only do this
 		// after updating the tenant entry in the directory.
@@ -427,10 +477,67 @@ func (d *directoryCache) watchPods(ctx context.Context, stopper *stop.Stopper) error {
 	return err
 }
 
-// updateTenantEntry keeps tenant directory entries up-to-date by handling pod
+// watchTenants establishes a watcher that looks for changes to tenants.
+// Whenever a tenant is created, modified, or deleted, the watcher is notified
+// and updates the directory to reflect that change.
+func (d *directoryCache) watchTenants(ctx context.Context, stopper *stop.Stopper) error {
+	req := WatchTenantsRequest{}
+
+	err := stopper.RunAsyncTask(ctx, "watch-tenants-client", func(ctx context.Context) {
+		var client Directory_WatchTenantsClient
+		var err error
+		ctx, _ = stopper.WithCancelOnQuiesce(ctx)
+
+		watchTenantsErr := log.Every(10 * time.Second)
+		recvErr := log.Every(10 * time.Second)
+
+		for ctx.Err() == nil {
+			if client == nil {
+				client, err = d.client.WatchTenants(ctx, &req)
+				if err != nil {
+					if watchTenantsErr.ShouldLog() {
+						log.Errorf(ctx, "err creating new watch tenants client: %s", err)
+					}
+					sleepContext(ctx, time.Second)
+					continue
+				} else {
+					log.Info(ctx, "established watch on tenants")
+				}
+			}
+
+			// Read the next watcher event.
+			resp, err := client.Recv()
+			if err != nil {
+				if recvErr.ShouldLog() {
+					log.Errorf(ctx, "err receiving stream events: %s", err)
+				}
+				// If the stream ends, immediately try to establish a new one.
+				// Otherwise, wait for a second to avoid slamming the server.
+				if err != io.EOF {
+					time.Sleep(time.Second)
+				}
+				client = nil
+				continue
+			}
+
+			// Update the directory entry for the tenant with the latest
+			// information about this tenant. Note that if an entry does not exist
+			// for the tenant, we do not create one: there is no value in keeping
+			// entries for tenants that aren't being connected to through the proxy.
+			d.updateTenantEntryWithTenant(ctx, resp.Tenant)
+		}
+	})
+	if err != nil {
+		return err
+	}
+
+	return err
+}
+
+// updateTenantEntryWithPod keeps tenant directory entries up-to-date by handling pod
 // watcher events. When a pod is created, destroyed, or modified, it updates the
 // tenant's entry to reflect that change.
-func (d *directoryCache) updateTenantEntry(ctx context.Context, pod *Pod) {
+func (d *directoryCache) updateTenantEntryWithPod(ctx context.Context, pod *Pod) {
 	if pod.Addr == "" || pod.TenantID == 0 {
 		// Nothing needs to be done if there is no IP address specified.
 		// We also check on TenantID here because roachpb.MustMakeTenantID will
@@ -468,6 +575,32 @@ func (d *directoryCache) updateTenantEntry(ctx context.Context, pod *Pod) {
 	}
 }
 
+// updateTenantEntryWithTenant keeps tenant directory entries up-to-date by handling
+// tenant watcher events. When a tenant is modified, it updates the tenant's entry to
+// reflect that change. Note that if an entry does not exist, this function will not
+// create one: the watcher is only concerned with tenants that connect through the proxy.
+func (d *directoryCache) updateTenantEntryWithTenant(ctx context.Context, tenant *Tenant) {
+	// Nothing to do.
+	if tenant == nil {
+		return
+	}
+
+	entry, err := d.getEntry(ctx, roachpb.MustMakeTenantID(tenant.TenantID), false /* allowCreate */)
+	if err != nil {
+		if !grpcutil.IsContextCanceled(err) {
+			// This should only happen in case of a deleted tenant or a transient
+			// error during fetch of tenant metadata (i.e. very rarely).
+			log.Errorf(ctx, "ignoring error getting entry for tenant %d: %v", tenant.TenantID, err)
+		}
+		return
+	}
+	if entry == nil {
+		return
+	}
+
+	entry.SetMaxNonRootConnections(tenant.MaxNonRootConns)
+}
+
 // sleepContext sleeps for the given duration or until the given context is
 // canceled, whichever comes first.
 func sleepContext(ctx context.Context, delay time.Duration) {
diff --git a/pkg/ccl/sqlproxyccl/tenant/entry.go b/pkg/ccl/sqlproxyccl/tenant/entry.go
index c30861e701cc..05a25df11403 100644
--- a/pkg/ccl/sqlproxyccl/tenant/entry.go
+++ b/pkg/ccl/sqlproxyccl/tenant/entry.go
@@ -34,6 +34,19 @@ type tenantEntry struct {
 	// Full name of the tenant's cluster i.e. dim-dog.
 	ClusterName string
 
+	// mu contains fields in the tenantEntry that can be updated over time,
+	// so a lock must be obtained before accessing them.
+	mu struct {
+		syncutil.Mutex
+		// numConns is the tenant's limit on the number of non-root
+		// connections. A value of -1 indicates no limit, while a value of
+		// 0 indicates that only root connections can be made to the
+		// tenant.
+		numConns int32
+		// pods is the current list of the tenant's SQL pods.
+		pods []*Pod
+	}
+
 	// RefreshDelay is the minimum amount of time that must elapse between
 	// attempts to refresh pods for this tenant after ReportFailure is
 	// called.
@@ -47,14 +60,6 @@ type tenantEntry struct {
 	// error occurred).
 	initError error
 
-	// pods synchronizes access to information about the tenant's SQL pods.
-	// These fields can be updated over time, so a lock must be obtained before
-	// accessing them.
-	pods struct {
-		syncutil.Mutex
-		pods []*Pod
-	}
-
 	// calls synchronizes calls to the Directory service for this tenant (e.g.
 	// calls to GetTenant or ListPods). Synchronization is needed to ensure that
 	// only one thread at a time is calling on behalf of a tenant, and that
@@ -80,7 +85,12 @@ func (e *tenantEntry) Initialize(ctx context.Context, client DirectoryClient) error {
 			return
 		}
 
+		// TODO(alyshan): Once tenant directories are updated to send the Tenant field,
+		// we can take ClusterName from that.
 		e.ClusterName = tenantResp.ClusterName
+		if tenantResp.Tenant != nil {
+			e.SetMaxNonRootConnections(tenantResp.Tenant.MaxNonRootConns)
+		}
 	})
 
 	// If Initialize has already been called, return any error that occurred.
@@ -110,36 +120,59 @@ func (e *tenantEntry) RefreshPods(ctx context.Context, client DirectoryClient)
 
 // AddPod inserts the given pod into the tenant's list of pods. If it is
 // already present, then AddPod updates the pod entry and returns false.
 func (e *tenantEntry) AddPod(pod *Pod) bool {
-	e.pods.Lock()
-	defer e.pods.Unlock()
+	e.mu.Lock()
+	defer e.mu.Unlock()
 
-	for i, existing := range e.pods.pods {
+	for i, existing := range e.mu.pods {
 		if existing.Addr == pod.Addr {
 			// e.pods.pods is copy on write. Whenever modifications are made,
 			// we must make a copy to avoid accidentally mutating the slice
 			// retrieved by GetPods.
-			pods := e.pods.pods
-			e.pods.pods = make([]*Pod, len(pods))
-			copy(e.pods.pods, pods)
-			e.pods.pods[i] = pod
+			pods := e.mu.pods
+			e.mu.pods = make([]*Pod, len(pods))
+			copy(e.mu.pods, pods)
+			e.mu.pods[i] = pod
 			return false
 		}
 	}
 
-	e.pods.pods = append(e.pods.pods, pod)
+	e.mu.pods = append(e.mu.pods, pod)
 	return true
 }
 
+// SetMaxNonRootConnections sets numConns in the lock-protected mu struct.
+func (e *tenantEntry) SetMaxNonRootConnections(numNonRoot int32) {
+	e.mu.Lock()
+	defer e.mu.Unlock()
+	e.mu.numConns = numNonRoot
+}
+
+// GetMaxNonRootConnections gets numConns from the lock-protected mu struct.
+func (e *tenantEntry) GetMaxNonRootConnections() int32 {
+	e.mu.Lock()
+	defer e.mu.Unlock()
+	return e.mu.numConns
+}
+
+// ToTenantProto returns a Tenant constructed from the fields of the tenantEntry.
+func (e *tenantEntry) ToTenantProto() Tenant {
+	return Tenant{
+		ClusterName:     e.ClusterName,
+		MaxNonRootConns: e.GetMaxNonRootConnections(),
+		TenantID:        e.TenantID.ToUint64(),
+	}
+}
+
 // RemovePodByAddr removes the pod with the given IP address from the tenant's
 // list of pod addresses. If it was not present, RemovePodByAddr returns false.
 func (e *tenantEntry) RemovePodByAddr(addr string) bool {
-	e.pods.Lock()
-	defer e.pods.Unlock()
+	e.mu.Lock()
+	defer e.mu.Unlock()
 
-	for i, existing := range e.pods.pods {
+	for i, existing := range e.mu.pods {
 		if existing.Addr == addr {
-			copy(e.pods.pods[i:], e.pods.pods[i+1:])
-			e.pods.pods = e.pods.pods[:len(e.pods.pods)-1]
+			copy(e.mu.pods[i:], e.mu.pods[i+1:])
+			e.mu.pods = e.mu.pods[:len(e.mu.pods)-1]
 			return true
 		}
 	}
@@ -148,9 +181,9 @@ func (e *tenantEntry) RemovePodByAddr(addr string) bool {
 
 // GetPods gets the current list of pods within scope of lock and returns them.
 func (e *tenantEntry) GetPods() []*Pod {
-	e.pods.Lock()
-	defer e.pods.Unlock()
-	return e.pods.pods
+	e.mu.Lock()
+	defer e.mu.Unlock()
+	return e.mu.pods
 }
 
 // EnsureTenantPod ensures that at least one RUNNING SQL process exists for this
@@ -228,15 +261,15 @@ func (e *tenantEntry) fetchPodsLocked(
 
 	// Need to lock in case another thread is reading the IP addresses (e.g. in
 	// ChoosePodAddr).
-	e.pods.Lock()
-	defer e.pods.Unlock()
-	e.pods.pods = list.Pods
+	e.mu.Lock()
+	defer e.mu.Unlock()
+	e.mu.pods = list.Pods
 
-	if len(e.pods.pods) != 0 {
-		log.Infof(ctx, "fetched IP addresses: %v", e.pods.pods)
+	if len(e.mu.pods) != 0 {
+		log.Infof(ctx, "fetched IP addresses: %v", e.mu.pods)
 	}
 
-	return e.pods.pods, nil
+	return e.mu.pods, nil
 }
 
 // canRefreshLocked returns true if it's been at least X milliseconds since the
diff --git a/pkg/ccl/sqlproxyccl/tenantdirsvr/test_directory_svr.go b/pkg/ccl/sqlproxyccl/tenantdirsvr/test_directory_svr.go
index 0e8826c634aa..b8ec8c00c527 100644
--- a/pkg/ccl/sqlproxyccl/tenantdirsvr/test_directory_svr.go
+++ b/pkg/ccl/sqlproxyccl/tenantdirsvr/test_directory_svr.go
@@ -228,6 +228,14 @@ func (s *TestDirectoryServer) WatchPods(
 	return err
 }
 
+// WatchTenants is a no-op for the test directory server; it returns
+// immediately without sending any tenant notifications.
+func (s *TestDirectoryServer) WatchTenants(
+	_ *tenant.WatchTenantsRequest, server tenant.Directory_WatchTenantsServer,
+) error {
+	return nil
+}
+
 // Drain sends out DRAINING pod notifications for each process managed by the
 // test directory. This causes the proxy to start enforcing short idle
 // connection timeouts in order to drain the connections to the pod.
diff --git a/pkg/ccl/sqlproxyccl/tenantdirsvr/test_simple_directory_svr.go b/pkg/ccl/sqlproxyccl/tenantdirsvr/test_simple_directory_svr.go
index 610b352730f5..508c212bd139 100644
--- a/pkg/ccl/sqlproxyccl/tenantdirsvr/test_simple_directory_svr.go
+++ b/pkg/ccl/sqlproxyccl/tenantdirsvr/test_simple_directory_svr.go
@@ -77,6 +77,15 @@ func (d *TestSimpleDirectoryServer) ListPods(
 	}, nil
 }
 
+// WatchTenants is a no-op for the simple directory.
+//
+// WatchTenants implements the tenant.DirectoryServer interface.
+func (d *TestSimpleDirectoryServer) WatchTenants(
+	req *tenant.WatchTenantsRequest, server tenant.Directory_WatchTenantsServer,
+) error {
+	return nil
+}
+
 // WatchPods is a no-op for the simple directory.
 //
 // WatchPods implements the tenant.DirectoryServer interface.
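The test directory servers above stub WatchTenants out entirely. As a point of reference, below is a minimal sketch (not part of this change) of how a directory server that actually tracks tenants could satisfy the contract in directory.proto: stream every existing tenant first, then stream changes until the client goes away. The exampleDirectoryServer type, its tenants snapshot, and its updates channel are hypothetical placeholders; only the request, response, Tenant, and Directory_WatchTenantsServer types come from this change.

```go
package tenantdirsvr // hypothetical placement, for illustration only

import (
	"github.com/cockroachdb/cockroach/pkg/ccl/sqlproxyccl/tenant"
)

// exampleDirectoryServer is a hypothetical server that keeps a snapshot of
// tenant metadata and a channel of tenant updates.
type exampleDirectoryServer struct {
	tenants []*tenant.Tenant     // current tenant metadata
	updates chan *tenant.Tenant  // tenant change notifications
}

// WatchTenants streams the current tenants first (as directory.proto
// requires), then forwards updates until the stream is torn down.
func (s *exampleDirectoryServer) WatchTenants(
	_ *tenant.WatchTenantsRequest, server tenant.Directory_WatchTenantsServer,
) error {
	for _, t := range s.tenants {
		if err := server.Send(&tenant.WatchTenantsResponse{Tenant: t}); err != nil {
			return err
		}
	}
	for {
		select {
		case t, ok := <-s.updates:
			if !ok {
				// No more updates; end the stream cleanly.
				return nil
			}
			if err := server.Send(&tenant.WatchTenantsResponse{Tenant: t}); err != nil {
				return err
			}
		case <-server.Context().Done():
			return server.Context().Err()
		}
	}
}
```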
diff --git a/pkg/ccl/sqlproxyccl/tenantdirsvr/test_static_directory_svr.go b/pkg/ccl/sqlproxyccl/tenantdirsvr/test_static_directory_svr.go
index aac71d067b97..0b10e3f38a5a 100644
--- a/pkg/ccl/sqlproxyccl/tenantdirsvr/test_static_directory_svr.go
+++ b/pkg/ccl/sqlproxyccl/tenantdirsvr/test_static_directory_svr.go
@@ -176,6 +176,67 @@ func (d *TestStaticDirectoryServer) WatchPods(
 	)
 }
 
+// WatchTenants allows callers to monitor for tenant update events.
+//
+// WatchTenants implements the tenant.DirectoryServer interface.
+func (d *TestStaticDirectoryServer) WatchTenants(
+	req *tenant.WatchTenantsRequest, server tenant.Directory_WatchTenantsServer,
+) error {
+	d.process.Lock()
+	stopper := d.process.stopper
+	d.process.Unlock()
+
+	// This cannot happen unless WatchTenants was called directly, which we
+	// shouldn't since it is meant to be called through a GRPC client.
+	if stopper == nil {
+		return status.Errorf(codes.FailedPrecondition, "directory server has not been started")
+	}
+
+	addListener := func(ch chan *tenant.WatchTenantsResponse) *list.Element {
+		d.mu.Lock()
+		defer d.mu.Unlock()
+		return d.mu.eventListeners.PushBack(ch)
+	}
+	removeListener := func(e *list.Element) chan *tenant.WatchTenantsResponse {
+		d.mu.Lock()
+		defer d.mu.Unlock()
+		return d.mu.eventListeners.Remove(e).(chan *tenant.WatchTenantsResponse)
+	}
+
+	// Construct the channel with a small buffer to allow for a burst of
+	// notifications, and a slow receiver.
+	c := make(chan *tenant.WatchTenantsResponse, 10)
+	chElement := addListener(c)
+
+	return stopper.RunTask(
+		context.Background(),
+		"watch-tenants-server",
+		func(ctx context.Context) {
+			defer func() {
+				if ch := removeListener(chElement); ch != nil {
+					close(ch)
+				}
+			}()
+
+			for watch := true; watch; {
+				select {
+				case e, ok := <-c:
+					// Channel was closed.
+					if !ok {
+						watch = false
+						break
+					}
+					if err := server.Send(e); err != nil {
+						watch = false
+					}
+				case <-stopper.ShouldQuiesce():
+					watch = false
+				}
+			}
+		},
+	)
+}
+
 // EnsurePod returns an empty response if a tenant with the given tenant ID
 // exists, and there is at least one SQL pod. If there are no SQL pods, a GRPC
 // FailedPrecondition error will be returned. Similarly, if the tenant does not
diff --git a/pkg/cli/cliflags/flags_mt.go b/pkg/cli/cliflags/flags_mt.go
index 5ded2caf3490..5235a68d239b 100644
--- a/pkg/cli/cliflags/flags_mt.go
+++ b/pkg/cli/cliflags/flags_mt.go
@@ -64,6 +64,11 @@ var (
 		Description: "Directory address of the service doing resolution of tenants to their IP addresses.",
 	}
 
+	LimitNonRootConns = FlagInfo{
+		Name:        "limit-non-root-conns",
+		Description: "Enable the limiting of non-root connections for tenants.",
+	}
+
 	// TODO(chrisseto): Remove skip-verify as a CLI option. It should only be
 	// set internally for testing, rather than being exposed to consumers.
 	SkipVerify = FlagInfo{
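The TestStaticDirectoryServer.WatchTenants loop above only drains listener channels; something on the server side still has to push WatchTenantsResponse messages onto those channels when a tenant changes. The following is a hedged sketch of that sending side: the notifyTenantUpdateLocked name, the "callers hold d.mu" convention, and the skip-if-full policy are assumptions for illustration, while the eventListeners list, the channel type, and the 10-entry buffer come from the change above.

```go
// notifyTenantUpdateLocked fans a tenant update out to every channel that a
// WatchTenants call has registered in d.mu.eventListeners. Callers must hold
// d.mu.
func (d *TestStaticDirectoryServer) notifyTenantUpdateLocked(t *tenant.Tenant) {
	for e := d.mu.eventListeners.Front(); e != nil; e = e.Next() {
		ch, ok := e.Value.(chan *tenant.WatchTenantsResponse)
		if !ok {
			// Not a tenant-watch listener (e.g. a pod-watch channel); skip it.
			continue
		}
		select {
		case ch <- &tenant.WatchTenantsResponse{Tenant: t}:
		default:
			// The listener's buffer (10 entries) is full; drop the update
			// rather than block the directory server.
		}
	}
}
```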