Skip to content

Commit

Permalink
Stop loading the enitre node set into memory per tsh ssh connection (#…
Browse files Browse the repository at this point in the history
…12014) (#12573)

* Prevent proxy from loading entire node set into memory more than once

When establishing a new session to a node, the proxy would load the
entire node set into memory in an attempt to find the matching host. For
smaller clusters this may not be that problematic. But on larger clusters,
loading >40k nodes into memory from the cache can be quite expensive.
This problem is compounded by the fact that it happened**per** session,
which could potentially cause the proxy to consume all available memory
and be OOM killed.

A new `NodeWatcher` is introduced which will maintain an in memory list
of all nodes per process. The watcher leverages the existing resource
watcher system and stores all nodes as types.Server, to eliminate the
cost incurred by unmarshalling the nodes from the cache. The `NodeWatcher`
provides a way to retrieve a filtered list of nodes in order to reduce the number
of copies made to only the matches.

(cherry picked from commit fa12352)
  • Loading branch information
rosstimothy authored May 12, 2022
1 parent 2781abc commit b911a8e
Show file tree
Hide file tree
Showing 16 changed files with 983 additions and 560 deletions.
3 changes: 3 additions & 0 deletions lib/reversetunnel/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (

"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/auth"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/teleagent"
)

Expand Down Expand Up @@ -94,6 +95,8 @@ type RemoteSite interface {
// CachingAccessPoint returns access point that is lightweight
// but is resilient to auth server crashes
CachingAccessPoint() (auth.RemoteProxyAccessPoint, error)
// NodeWatcher returns the node watcher that maintains the node set for the site
NodeWatcher() (*services.NodeWatcher, error)
// GetTunnelsCount returns the amount of active inbound tunnels
// from the remote cluster
GetTunnelsCount() int
Expand Down
30 changes: 12 additions & 18 deletions lib/reversetunnel/localsite.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@ import (
"sync"
"time"

"golang.org/x/crypto/ssh"

"github.com/gravitational/teleport"
apidefaults "github.com/gravitational/teleport/api/defaults"
"github.com/gravitational/teleport/api/types"
Expand All @@ -34,11 +32,12 @@ import (
"github.com/gravitational/teleport/lib/srv/forward"
"github.com/gravitational/teleport/lib/utils"
"github.com/gravitational/teleport/lib/utils/proxy"
"github.com/prometheus/client_golang/prometheus"

"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
"github.com/prometheus/client_golang/prometheus"
log "github.com/sirupsen/logrus"
"golang.org/x/crypto/ssh"
)

func newlocalSite(srv *server, domainName string, client auth.ClientI) (*localSite, error) {
Expand Down Expand Up @@ -129,6 +128,11 @@ func (s *localSite) CachingAccessPoint() (auth.RemoteProxyAccessPoint, error) {
return s.accessPoint, nil
}

// NodeWatcher returns a services.NodeWatcher for this cluster.
func (s *localSite) NodeWatcher() (*services.NodeWatcher, error) {
return s.srv.NodeWatcher, nil
}

// GetClient returns a client to the full Auth Server API.
func (s *localSite) GetClient() (auth.ClientI, error) {
return s.client, nil
Expand Down Expand Up @@ -522,14 +526,7 @@ func (s *localSite) periodicFunctions() {

// sshTunnelStats reports SSH tunnel statistics for the cluster.
func (s *localSite) sshTunnelStats() error {
servers, err := s.accessPoint.GetNodes(s.srv.ctx, apidefaults.Namespace)
if err != nil {
return trace.Wrap(err)
}

var missing []string

for _, server := range servers {
missing := s.srv.NodeWatcher.GetNodes(func(server services.Node) bool {
// Skip over any servers that that have a TTL larger than announce TTL (10
// minutes) and are non-IoT SSH servers (they won't have tunnels).
//
Expand All @@ -538,23 +535,20 @@ func (s *localSite) sshTunnelStats() error {
// their TTL value.
ttl := s.clock.Now().Add(-1 * apidefaults.ServerAnnounceTTL)
if server.Expiry().Before(ttl) {
continue
return false
}
if !server.GetUseTunnel() {
continue
return false
}

// Check if the tunnel actually exists.
_, err := s.getRemoteConn(&sshutils.DialReq{
ServerID: fmt.Sprintf("%v.%v", server.GetName(), s.domainName),
ConnType: types.NodeTunnel,
})
if err == nil {
continue
}

missing = append(missing, server.GetName())
}
return err != nil
})

// Update Prometheus metrics and also log if any tunnels are missing.
missingSSHTunnels.Set(float64(len(missing)))
Expand Down
12 changes: 12 additions & 0 deletions lib/reversetunnel/peer.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,14 @@ func (p *clusterPeers) CachingAccessPoint() (auth.RemoteProxyAccessPoint, error)
return peer.CachingAccessPoint()
}

func (p *clusterPeers) NodeWatcher() (*services.NodeWatcher, error) {
peer, err := p.pickPeer()
if err != nil {
return nil, trace.Wrap(err)
}
return peer.NodeWatcher()
}

func (p *clusterPeers) GetClient() (auth.ClientI, error) {
peer, err := p.pickPeer()
if err != nil {
Expand Down Expand Up @@ -191,6 +199,10 @@ func (s *clusterPeer) CachingAccessPoint() (auth.RemoteProxyAccessPoint, error)
return nil, trace.ConnectionProblem(nil, "unable to fetch access point, this proxy %v has not been discovered yet, try again later", s)
}

func (s *clusterPeer) NodeWatcher() (*services.NodeWatcher, error) {
return nil, trace.ConnectionProblem(nil, "unable to fetch access point, this proxy %v has not been discovered yet, try again later", s)
}

func (s *clusterPeer) GetClient() (auth.ClientI, error) {
return nil, trace.ConnectionProblem(nil, "unable to fetch client, this proxy %v has not been discovered yet, try again later", s)
}
Expand Down
13 changes: 11 additions & 2 deletions lib/reversetunnel/remotesite.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,14 @@ import (
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/srv/forward"
"github.com/gravitational/teleport/lib/utils"

"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
log "github.com/sirupsen/logrus"
"golang.org/x/crypto/ssh"
)

// remoteSite is a remote site that established the inbound connecton to
// remoteSite is a remote site that established the inbound connection to
// the local reverse tunnel server, and now it can provide access to the
// cluster behind it.
type remoteSite struct {
Expand Down Expand Up @@ -77,6 +78,9 @@ type remoteSite struct {
// the remote cluster this site belongs to.
remoteAccessPoint auth.RemoteProxyAccessPoint

// nodeWatcher provides access the node set for the remote site
nodeWatcher *services.NodeWatcher

// remoteCA is the last remote certificate authority recorded by the client.
// It is used to detect CA rotation status changes. If the rotation
// state has been changed, the tunnel will reconnect to re-create the client
Expand Down Expand Up @@ -138,6 +142,11 @@ func (s *remoteSite) CachingAccessPoint() (auth.RemoteProxyAccessPoint, error) {
return s.remoteAccessPoint, nil
}

// NodeWatcher returns the services.NodeWatcher for the remote cluster.
func (s *remoteSite) NodeWatcher() (*services.NodeWatcher, error) {
return s.nodeWatcher, nil
}

func (s *remoteSite) GetClient() (auth.ClientI, error) {
return s.remoteClient, nil
}
Expand Down Expand Up @@ -379,7 +388,7 @@ func (s *remoteSite) handleHeartbeat(conn *remoteConn, ch ssh.Channel, reqC <-ch
} else {
s.WithFields(log.Fields{"nodeID": conn.nodeID}).Debugf("Ping <- %v", conn.conn.RemoteAddr())
}
tm := time.Now().UTC()
tm := s.clock.Now().UTC()
conn.setLastHeartbeat(tm)
go s.registerHeartbeat(tm)
// Note that time.After is re-created everytime a request is processed.
Expand Down
34 changes: 29 additions & 5 deletions lib/reversetunnel/srv.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ import (
"github.com/gravitational/teleport/lib/sshca"
"github.com/gravitational/teleport/lib/sshutils"
"github.com/gravitational/teleport/lib/utils"

"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
"github.com/prometheus/client_golang/prometheus"
Expand Down Expand Up @@ -201,6 +202,9 @@ type Config struct {

// LockWatcher is a lock watcher.
LockWatcher *services.LockWatcher

// NodeWatcher is a node watcher.
NodeWatcher *services.NodeWatcher
}

// CheckAndSetDefaults checks parameters and sets default values
Expand Down Expand Up @@ -252,6 +256,9 @@ func (cfg *Config) CheckAndSetDefaults() error {
if cfg.LockWatcher == nil {
return trace.BadParameter("missing parameter LockWatcher")
}
if cfg.NodeWatcher == nil {
return trace.BadParameter("missing parameter NodeWatcher")
}
return nil
}

Expand Down Expand Up @@ -891,7 +898,7 @@ func (s *server) upsertRemoteCluster(conn net.Conn, sshConn *ssh.ServerConn) (*r
// treat first connection as a registered heartbeat,
// otherwise the connection information will appear after initial
// heartbeat delay
go site.registerHeartbeat(time.Now())
go site.registerHeartbeat(s.Clock.Now())
return site, remoteConn, nil
}

Expand Down Expand Up @@ -1024,7 +1031,7 @@ func newRemoteSite(srv *server, domainName string, sconn ssh.Conn) (*remoteSite,
types.TunnelConnectionSpecV2{
ClusterName: domainName,
ProxyName: srv.ID,
LastHeartbeat: time.Now().UTC(),
LastHeartbeat: srv.Clock.Now().UTC(),
},
)
if err != nil {
Expand Down Expand Up @@ -1056,27 +1063,42 @@ func newRemoteSite(srv *server, domainName string, sconn ssh.Conn) (*remoteSite,

clt, _, err := remoteSite.getRemoteClient()
if err != nil {
cancel()
return nil, trace.Wrap(err)
}
remoteSite.remoteClient = clt

remoteVersion, err := getRemoteAuthVersion(closeContext, sconn)
if err != nil {
cancel()
return nil, trace.Wrap(err)
}

accessPoint, err := createRemoteAccessPoint(srv, clt, remoteVersion, domainName)
if err != nil {
cancel()
return nil, trace.Wrap(err)
}
remoteSite.remoteAccessPoint = accessPoint

nodeWatcher, err := services.NewNodeWatcher(closeContext, services.NodeWatcherConfig{
ResourceWatcherConfig: services.ResourceWatcherConfig{
Component: srv.Component,
Client: accessPoint,
Log: srv.Log,
},
})
if err != nil {
cancel()
return nil, trace.Wrap(err)
}
remoteSite.nodeWatcher = nodeWatcher
// instantiate a cache of host certificates for the forwarding server. the
// certificate cache is created in each site (instead of creating it in
// reversetunnel.server and passing it along) so that the host certificate
// is signed by the correct certificate authority.
certificateCache, err := newHostCertificateCache(srv.Config.KeyGen, srv.localAuthClient)
if err != nil {
cancel()
return nil, trace.Wrap(err)
}
remoteSite.certificateCache = certificateCache
Expand All @@ -1089,7 +1111,8 @@ func newRemoteSite(srv *server, domainName string, sconn ssh.Conn) (*remoteSite,
Clock: srv.Clock,
})
if err != nil {
return nil, err
cancel()
return nil, trace.Wrap(err)
}

go remoteSite.updateCertAuthorities(caRetry)
Expand All @@ -1102,7 +1125,8 @@ func newRemoteSite(srv *server, domainName string, sconn ssh.Conn) (*remoteSite,
Clock: srv.Clock,
})
if err != nil {
return nil, err
cancel()
return nil, trace.Wrap(err)
}

go remoteSite.updateLocks(lockRetry)
Expand Down
13 changes: 13 additions & 0 deletions lib/service/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -2856,6 +2856,17 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error {
return trace.Wrap(err)
}

nodeWatcher, err := services.NewNodeWatcher(process.ExitContext(), services.NodeWatcherConfig{
ResourceWatcherConfig: services.ResourceWatcherConfig{
Component: teleport.ComponentProxy,
Log: process.log.WithField(trace.Component, teleport.ComponentProxy),
Client: conn.Client,
},
})
if err != nil {
return trace.Wrap(err)
}

serverTLSConfig, err := conn.ServerIdentity.TLSConfig(cfg.CipherSuites)
if err != nil {
return trace.Wrap(err)
Expand Down Expand Up @@ -2895,6 +2906,7 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error {
Emitter: streamEmitter,
Log: process.log,
LockWatcher: lockWatcher,
NodeWatcher: nodeWatcher,
})
if err != nil {
return trace.Wrap(err)
Expand Down Expand Up @@ -3024,6 +3036,7 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error {
regular.SetOnHeartbeat(process.onHeartbeat(teleport.ComponentProxy)),
regular.SetEmitter(streamEmitter),
regular.SetLockWatcher(lockWatcher),
regular.SetNodeWatcher(nodeWatcher),
)
if err != nil {
return trace.Wrap(err)
Expand Down
13 changes: 9 additions & 4 deletions lib/services/presence.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@ type ProxyGetter interface {
GetProxies() ([]types.Server, error)
}

// NodesGetter is a service that gets nodes.
type NodesGetter interface {
// GetNodes returns a list of registered servers.
GetNodes(ctx context.Context, namespace string, opts ...MarshalOption) ([]types.Server, error)
}

// Presence records and reports the presence of all components
// of the cluster - Nodes, Proxies and SSH nodes
type Presence interface {
Expand All @@ -43,13 +49,12 @@ type Presence interface {

// GetNode returns a node by name and namespace.
GetNode(ctx context.Context, namespace, name string) (types.Server, error)

// GetNodes returns a list of registered servers.
GetNodes(ctx context.Context, namespace string, opts ...MarshalOption) ([]types.Server, error)

// ListNodes returns a paginated list of registered servers.
ListNodes(ctx context.Context, req proto.ListNodesRequest) (nodes []types.Server, nextKey string, err error)

// NodesGetter gets nodes
NodesGetter

// DeleteAllNodes deletes all nodes in a namespace.
DeleteAllNodes(ctx context.Context, namespace string) error

Expand Down
Loading

0 comments on commit b911a8e

Please sign in to comment.