Skip to content

Commit

Permalink
Stop loading the entire node set into memory per tsh ssh connection
Browse files Browse the repository at this point in the history
When doing a `tsh ssh node` the proxySubsys would load the entire
set of nodes into memory to determine which server to route the
request to. Under heavy load, like an automated bot that periodically
spawns numerous sessions, a proxy could easily consume all
available memory.

To prevent this, proxies now utilize a NodeWatcher that maintains
a single node set in memory. This prevents loading the nodes into
memory more than once, and also eliminates the need to unmarshal the
types.Server on each retrieval of the nodes. The NodeWatcher only
provides a GetNodes function that requires a filter function to make
it intentionally challenging to retrieve a copy of the entire node set.
  • Loading branch information
rosstimothy committed Apr 19, 2022
1 parent 57cc2ed commit 7f09c3f
Show file tree
Hide file tree
Showing 16 changed files with 450 additions and 79 deletions.
3 changes: 3 additions & 0 deletions lib/reversetunnel/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (

"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/auth"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/teleagent"
)

Expand Down Expand Up @@ -94,6 +95,8 @@ type RemoteSite interface {
// CachingAccessPoint returns access point that is lightweight
// but is resilient to auth server crashes
CachingAccessPoint() (auth.RemoteProxyAccessPoint, error)
// NodeWatcher returns the node watcher that maintains the node set for the site
NodeWatcher() (*services.NodeWatcher, error)
// GetTunnelsCount returns the amount of active inbound tunnels
// from the remote cluster
GetTunnelsCount() int
Expand Down
27 changes: 11 additions & 16 deletions lib/reversetunnel/localsite.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ import (

"golang.org/x/crypto/ssh"

"github.com/prometheus/client_golang/prometheus"

"github.com/gravitational/teleport"
apidefaults "github.com/gravitational/teleport/api/defaults"
"github.com/gravitational/teleport/api/types"
Expand All @@ -34,7 +36,6 @@ import (
"github.com/gravitational/teleport/lib/srv/forward"
"github.com/gravitational/teleport/lib/utils"
"github.com/gravitational/teleport/lib/utils/proxy"
"github.com/prometheus/client_golang/prometheus"

"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
Expand Down Expand Up @@ -129,6 +130,10 @@ func (s *localSite) CachingAccessPoint() (auth.RemoteProxyAccessPoint, error) {
return s.accessPoint, nil
}

// NodeWatcher returns the NodeWatcher stored on the site's srv field.
// The returned error is always nil.
func (s *localSite) NodeWatcher() (*services.NodeWatcher, error) {
return s.srv.NodeWatcher, nil
}

// GetClient returns a client to the full Auth Server API.
func (s *localSite) GetClient() (auth.ClientI, error) {
return s.client, nil
Expand Down Expand Up @@ -522,14 +527,7 @@ func (s *localSite) periodicFunctions() {

// sshTunnelStats reports SSH tunnel statistics for the cluster.
func (s *localSite) sshTunnelStats() error {
servers, err := s.accessPoint.GetNodes(s.srv.ctx, apidefaults.Namespace)
if err != nil {
return trace.Wrap(err)
}

var missing []string

for _, server := range servers {
missing := s.srv.NodeWatcher.GetNodes(func(server services.Node) bool {
// Skip over any servers that have a TTL larger than announce TTL (10
// minutes) and are non-IoT SSH servers (they won't have tunnels).
//
Expand All @@ -538,23 +536,20 @@ func (s *localSite) sshTunnelStats() error {
// their TTL value.
ttl := s.clock.Now().Add(-1 * apidefaults.ServerAnnounceTTL)
if server.Expiry().Before(ttl) {
continue
return false
}
if !server.GetUseTunnel() {
continue
return false
}

// Check if the tunnel actually exists.
_, err := s.getRemoteConn(&sshutils.DialReq{
ServerID: fmt.Sprintf("%v.%v", server.GetName(), s.domainName),
ConnType: types.NodeTunnel,
})
if err == nil {
continue
}

missing = append(missing, server.GetName())
}
return err != nil
})

// Update Prometheus metrics and also log if any tunnels are missing.
missingSSHTunnels.Set(float64(len(missing)))
Expand Down
12 changes: 12 additions & 0 deletions lib/reversetunnel/peer.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,14 @@ func (p *clusterPeers) CachingAccessPoint() (auth.RemoteProxyAccessPoint, error)
return peer.CachingAccessPoint()
}

// NodeWatcher picks a peer via pickPeer and delegates to that peer's
// NodeWatcher. If no peer can be picked, the wrapped pickPeer error is
// returned together with a nil watcher.
func (p *clusterPeers) NodeWatcher() (*services.NodeWatcher, error) {
	chosen, pickErr := p.pickPeer()
	if pickErr != nil {
		return nil, trace.Wrap(pickErr)
	}
	return chosen.NodeWatcher()
}

func (p *clusterPeers) GetClient() (auth.ClientI, error) {
peer, err := p.pickPeer()
if err != nil {
Expand Down Expand Up @@ -191,6 +199,10 @@ func (s *clusterPeer) CachingAccessPoint() (auth.RemoteProxyAccessPoint, error)
return nil, trace.ConnectionProblem(nil, "unable to fetch access point, this proxy %v has not been discovered yet, try again later", s)
}

// NodeWatcher always fails for a clusterPeer: it returns a nil watcher and a
// connection problem error stating that this proxy has not been discovered
// yet and that the caller should try again later.
func (s *clusterPeer) NodeWatcher() (*services.NodeWatcher, error) {
	// The message names the node watcher explicitly so this failure can be
	// told apart from the similarly worded CachingAccessPoint error.
	return nil, trace.ConnectionProblem(nil, "unable to fetch node watcher, this proxy %v has not been discovered yet, try again later", s)
}

// GetClient always returns a nil client together with a connection problem
// error stating that this proxy has not been discovered yet.
func (s *clusterPeer) GetClient() (auth.ClientI, error) {
return nil, trace.ConnectionProblem(nil, "unable to fetch client, this proxy %v has not been discovered yet, try again later", s)
}
Expand Down
18 changes: 13 additions & 5 deletions lib/reversetunnel/remotesite.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@ import (
"sync"
"time"

"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
log "github.com/sirupsen/logrus"
"golang.org/x/crypto/ssh"

"github.com/gravitational/teleport"
"github.com/gravitational/teleport/api/client"
"github.com/gravitational/teleport/api/constants"
Expand All @@ -34,10 +39,6 @@ import (
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/srv/forward"
"github.com/gravitational/teleport/lib/utils"
"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
log "github.com/sirupsen/logrus"
"golang.org/x/crypto/ssh"
)

// remoteSite is a remote site that established the inbound connection to
Expand Down Expand Up @@ -77,6 +78,9 @@ type remoteSite struct {
// the remote cluster this site belongs to.
remoteAccessPoint auth.RemoteProxyAccessPoint

// nodeWatcher provides access to the node set for the remote site
nodeWatcher *services.NodeWatcher

// remoteCA is the last remote certificate authority recorded by the client.
// It is used to detect CA rotation status changes. If the rotation
// state has been changed, the tunnel will reconnect to re-create the client
Expand Down Expand Up @@ -138,6 +142,10 @@ func (s *remoteSite) CachingAccessPoint() (auth.RemoteProxyAccessPoint, error) {
return s.remoteAccessPoint, nil
}

// NodeWatcher returns the site's nodeWatcher field.
// The returned error is always nil.
func (s *remoteSite) NodeWatcher() (*services.NodeWatcher, error) {
return s.nodeWatcher, nil
}

// GetClient returns the site's remoteClient field.
// The returned error is always nil.
func (s *remoteSite) GetClient() (auth.ClientI, error) {
return s.remoteClient, nil
}
Expand Down Expand Up @@ -379,7 +387,7 @@ func (s *remoteSite) handleHeartbeat(conn *remoteConn, ch ssh.Channel, reqC <-ch
} else {
s.WithFields(log.Fields{"nodeID": conn.nodeID}).Debugf("Ping <- %v", conn.conn.RemoteAddr())
}
tm := time.Now().UTC()
tm := s.clock.Now().UTC()
conn.setLastHeartbeat(tm)
go s.registerHeartbeat(tm)
// Note that time.After is re-created every time a request is processed.
Expand Down
43 changes: 33 additions & 10 deletions lib/reversetunnel/srv.go
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,9 @@ type Config struct {

// LockWatcher is a lock watcher.
LockWatcher *services.LockWatcher

// NodeWatcher is a node watcher.
NodeWatcher *services.NodeWatcher
}

// CheckAndSetDefaults checks parameters and sets default values
Expand Down Expand Up @@ -254,6 +257,9 @@ func (cfg *Config) CheckAndSetDefaults() error {
if cfg.LockWatcher == nil {
return trace.BadParameter("missing parameter LockWatcher")
}
if cfg.NodeWatcher == nil {
return trace.BadParameter("missing parameter NodeWatcher")
}
return nil
}

Expand Down Expand Up @@ -639,7 +645,7 @@ func (s *server) handleHeartbeat(conn net.Conn, sconn *ssh.ServerConn, nch ssh.N
// nodes it's a node dialing back.
val, ok := sconn.Permissions.Extensions[extCertRole]
if !ok {
log.Errorf("Failed to accept connection, missing %q extension", extCertRole)
s.log.Errorf("Failed to accept connection, missing %q extension", extCertRole)
s.rejectRequest(nch, ssh.ConnectionFailed, "unknown role")
return
}
Expand All @@ -665,22 +671,22 @@ func (s *server) handleHeartbeat(conn net.Conn, sconn *ssh.ServerConn, nch ssh.N
s.handleNewService(role, conn, sconn, nch, types.WindowsDesktopTunnel)
// Unknown role.
default:
log.Errorf("Unsupported role attempting to connect: %v", val)
s.log.Errorf("Unsupported role attempting to connect: %v", val)
s.rejectRequest(nch, ssh.ConnectionFailed, fmt.Sprintf("unsupported role %v", val))
}
}

func (s *server) handleNewService(role types.SystemRole, conn net.Conn, sconn *ssh.ServerConn, nch ssh.NewChannel, connType types.TunnelType) {
cluster, rconn, err := s.upsertServiceConn(conn, sconn, connType)
if err != nil {
log.Errorf("Failed to upsert %s: %v.", role, err)
s.log.Errorf("Failed to upsert %s: %v.", role, err)
sconn.Close()
return
}

ch, req, err := nch.Accept()
if err != nil {
log.Errorf("Failed to accept on channel: %v.", err)
s.log.Errorf("Failed to accept on channel: %v.", err)
sconn.Close()
return
}
Expand All @@ -692,14 +698,14 @@ func (s *server) handleNewCluster(conn net.Conn, sshConn *ssh.ServerConn, nch ss
// add the incoming site (cluster) to the list of active connections:
site, remoteConn, err := s.upsertRemoteCluster(conn, sshConn)
if err != nil {
log.Error(trace.Wrap(err))
s.log.Error(trace.Wrap(err))
s.rejectRequest(nch, ssh.ConnectionFailed, "failed to accept incoming cluster connection")
return
}
// accept the request and start the heartbeat on it:
ch, req, err := nch.Accept()
if err != nil {
log.Error(trace.Wrap(err))
s.log.Error(trace.Wrap(err))
sshConn.Close()
return
}
Expand Down Expand Up @@ -893,7 +899,7 @@ func (s *server) upsertRemoteCluster(conn net.Conn, sshConn *ssh.ServerConn) (*r
// treat first connection as a registered heartbeat,
// otherwise the connection information will appear after initial
// heartbeat delay
go site.registerHeartbeat(time.Now())
go site.registerHeartbeat(s.Clock.Now())
return site, remoteConn, nil
}

Expand Down Expand Up @@ -1026,7 +1032,7 @@ func newRemoteSite(srv *server, domainName string, sconn ssh.Conn) (*remoteSite,
types.TunnelConnectionSpecV2{
ClusterName: domainName,
ProxyName: srv.ID,
LastHeartbeat: time.Now().UTC(),
LastHeartbeat: srv.Clock.Now().UTC(),
},
)
if err != nil {
Expand Down Expand Up @@ -1058,6 +1064,7 @@ func newRemoteSite(srv *server, domainName string, sconn ssh.Conn) (*remoteSite,

clt, _, err := remoteSite.getRemoteClient()
if err != nil {
cancel()
return nil, trace.Wrap(err)
}
remoteSite.remoteClient = clt
Expand All @@ -1069,10 +1076,11 @@ func newRemoteSite(srv *server, domainName string, sconn ssh.Conn) (*remoteSite,
var accessPointFunc auth.NewRemoteProxyCachingAccessPoint
ok, err := isPreV8Cluster(closeContext, sconn)
if err != nil {
cancel()
return nil, trace.Wrap(err)
}
if ok {
log.Debugf("Pre-v8 cluster connecting, loading old cache policy.")
remoteSite.Debugf("Pre-v8 cluster connecting, loading old cache policy.")
accessPointFunc = srv.Config.NewCachingAccessPointOldProxy
} else {
accessPointFunc = srv.newAccessPoint
Expand All @@ -1082,16 +1090,29 @@ func newRemoteSite(srv *server, domainName string, sconn ssh.Conn) (*remoteSite,
// cluster this remote site provides access to.
accessPoint, err := accessPointFunc(clt, []string{"reverse", domainName})
if err != nil {
cancel()
return nil, trace.Wrap(err)
}
remoteSite.remoteAccessPoint = accessPoint

nodeWatcher, err := services.NewNodeWatcher(closeContext, services.NodeWatcherConfig{
ResourceWatcherConfig: services.ResourceWatcherConfig{
Component: srv.Component,
Client: accessPoint,
Log: srv.Log,
},
})
if err != nil {
cancel()
return nil, trace.Wrap(err)
}
remoteSite.nodeWatcher = nodeWatcher
// instantiate a cache of host certificates for the forwarding server. the
// certificate cache is created in each site (instead of creating it in
// reversetunnel.server and passing it along) so that the host certificate
// is signed by the correct certificate authority.
certificateCache, err := newHostCertificateCache(srv.Config.KeyGen, srv.localAuthClient)
if err != nil {
cancel()
return nil, trace.Wrap(err)
}
remoteSite.certificateCache = certificateCache
Expand All @@ -1104,6 +1125,7 @@ func newRemoteSite(srv *server, domainName string, sconn ssh.Conn) (*remoteSite,
Clock: srv.Clock,
})
if err != nil {
cancel()
return nil, err
}

Expand All @@ -1117,6 +1139,7 @@ func newRemoteSite(srv *server, domainName string, sconn ssh.Conn) (*remoteSite,
Clock: srv.Clock,
})
if err != nil {
cancel()
return nil, err
}

Expand Down
13 changes: 13 additions & 0 deletions lib/service/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -2785,6 +2785,17 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error {
return trace.Wrap(err)
}

nodeWatcher, err := services.NewNodeWatcher(process.ExitContext(), services.NodeWatcherConfig{
ResourceWatcherConfig: services.ResourceWatcherConfig{
Component: teleport.ComponentProxy,
Log: process.log.WithField(trace.Component, teleport.ComponentProxy),
Client: conn.Client,
},
})
if err != nil {
return trace.Wrap(err)
}

serverTLSConfig, err := conn.ServerIdentity.TLSConfig(cfg.CipherSuites)
if err != nil {
return trace.Wrap(err)
Expand Down Expand Up @@ -2824,6 +2835,7 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error {
Emitter: streamEmitter,
Log: process.log,
LockWatcher: lockWatcher,
NodeWatcher: nodeWatcher,
})
if err != nil {
return trace.Wrap(err)
Expand Down Expand Up @@ -2953,6 +2965,7 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error {
regular.SetOnHeartbeat(process.onHeartbeat(teleport.ComponentProxy)),
regular.SetEmitter(streamEmitter),
regular.SetLockWatcher(lockWatcher),
regular.SetNodeWatcher(nodeWatcher),
)
if err != nil {
return trace.Wrap(err)
Expand Down
13 changes: 9 additions & 4 deletions lib/services/presence.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@ type ProxyGetter interface {
GetProxies() ([]types.Server, error)
}

// NodeGetter is a service that gets nodes.
type NodeGetter interface {
// GetNodes returns a list of registered servers, or an error.
// It takes a context, a namespace and optional marshal options.
// NOTE(review): presumably only servers in the given namespace are
// returned — confirm against the implementations.
GetNodes(ctx context.Context, namespace string, opts ...MarshalOption) ([]types.Server, error)
}

// Presence records and reports the presence of all components
// of the cluster - Nodes, Proxies and SSH nodes
type Presence interface {
Expand All @@ -43,13 +49,12 @@ type Presence interface {

// GetNode returns a node by name and namespace.
GetNode(ctx context.Context, namespace, name string) (types.Server, error)

// GetNodes returns a list of registered servers.
GetNodes(ctx context.Context, namespace string, opts ...MarshalOption) ([]types.Server, error)

// ListNodes returns a paginated list of registered servers.
ListNodes(ctx context.Context, req proto.ListNodesRequest) (nodes []types.Server, nextKey string, err error)

// NodeGetter gets nodes
NodeGetter

// DeleteAllNodes deletes all nodes in a namespace.
DeleteAllNodes(ctx context.Context, namespace string) error

Expand Down
Loading

0 comments on commit 7f09c3f

Please sign in to comment.