Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Stop loading the entire node set into memory per tsh ssh connection #12014

Merged
merged 10 commits into from
May 2, 2022
3 changes: 3 additions & 0 deletions lib/reversetunnel/api.go
Original file line number Diff line number Diff line change
@@ -25,6 +25,7 @@ import (

"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/auth"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/teleagent"
)

@@ -94,6 +95,8 @@ type RemoteSite interface {
// CachingAccessPoint returns access point that is lightweight
// but is resilient to auth server crashes
CachingAccessPoint() (auth.RemoteProxyAccessPoint, error)
// NodeWatcher returns the node watcher that maintains the node set for the site
NodeWatcher() (*services.NodeWatcher, error)
// GetTunnelsCount returns the amount of active inbound tunnels
// from the remote cluster
GetTunnelsCount() int
30 changes: 12 additions & 18 deletions lib/reversetunnel/localsite.go
Original file line number Diff line number Diff line change
@@ -22,8 +22,6 @@ import (
"sync"
"time"

"golang.org/x/crypto/ssh"

"github.com/gravitational/teleport"
apidefaults "github.com/gravitational/teleport/api/defaults"
"github.com/gravitational/teleport/api/types"
@@ -34,11 +32,12 @@ import (
"github.com/gravitational/teleport/lib/srv/forward"
"github.com/gravitational/teleport/lib/utils"
"github.com/gravitational/teleport/lib/utils/proxy"
"github.com/prometheus/client_golang/prometheus"

"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
"github.com/prometheus/client_golang/prometheus"
log "github.com/sirupsen/logrus"
"golang.org/x/crypto/ssh"
)

func newlocalSite(srv *server, domainName string, client auth.ClientI) (*localSite, error) {
@@ -129,6 +128,11 @@ func (s *localSite) CachingAccessPoint() (auth.RemoteProxyAccessPoint, error) {
return s.accessPoint, nil
}

// NodeWatcher hands back the node watcher held by the reverse tunnel server
// (s.srv) that this local site belongs to. The returned error is always nil.
func (s *localSite) NodeWatcher() (*services.NodeWatcher, error) {
	watcher := s.srv.NodeWatcher
	return watcher, nil
}

// GetClient returns a client to the full Auth Server API.
func (s *localSite) GetClient() (auth.ClientI, error) {
return s.client, nil
@@ -522,14 +526,7 @@ func (s *localSite) periodicFunctions() {

// sshTunnelStats reports SSH tunnel statistics for the cluster.
func (s *localSite) sshTunnelStats() error {
servers, err := s.accessPoint.GetNodes(s.srv.ctx, apidefaults.Namespace)
if err != nil {
return trace.Wrap(err)
}

var missing []string

for _, server := range servers {
missing := s.srv.NodeWatcher.GetNodes(func(server services.Node) bool {
// Skip over any servers that have a TTL larger than announce TTL (10
// minutes) and are non-IoT SSH servers (they won't have tunnels).
//
@@ -538,23 +535,20 @@ func (s *localSite) sshTunnelStats() error {
// their TTL value.
ttl := s.clock.Now().Add(-1 * apidefaults.ServerAnnounceTTL)
if server.Expiry().Before(ttl) {
continue
return false
}
if !server.GetUseTunnel() {
continue
return false
}

// Check if the tunnel actually exists.
_, err := s.getRemoteConn(&sshutils.DialReq{
ServerID: fmt.Sprintf("%v.%v", server.GetName(), s.domainName),
ConnType: types.NodeTunnel,
})
if err == nil {
continue
}

missing = append(missing, server.GetName())
}
return err != nil
})

// Update Prometheus metrics and also log if any tunnels are missing.
missingSSHTunnels.Set(float64(len(missing)))
12 changes: 12 additions & 0 deletions lib/reversetunnel/peer.go
Original file line number Diff line number Diff line change
@@ -87,6 +87,14 @@ func (p *clusterPeers) CachingAccessPoint() (auth.RemoteProxyAccessPoint, error)
return peer.CachingAccessPoint()
}

// NodeWatcher delegates to the NodeWatcher of the peer selected by pickPeer.
// If no peer can be selected, the pickPeer error is wrapped and returned
// together with a nil watcher.
func (p *clusterPeers) NodeWatcher() (*services.NodeWatcher, error) {
	selected, pickErr := p.pickPeer()
	if pickErr != nil {
		return nil, trace.Wrap(pickErr)
	}
	return selected.NodeWatcher()
}

func (p *clusterPeers) GetClient() (auth.ClientI, error) {
peer, err := p.pickPeer()
if err != nil {
@@ -191,6 +199,10 @@ func (s *clusterPeer) CachingAccessPoint() (auth.RemoteProxyAccessPoint, error)
return nil, trace.ConnectionProblem(nil, "unable to fetch access point, this proxy %v has not been discovered yet, try again later", s)
}

// NodeWatcher always fails for a clusterPeer: it returns a nil watcher and a
// connection-problem error reporting that this proxy has not been discovered
// yet, so the caller should try again later.
func (s *clusterPeer) NodeWatcher() (*services.NodeWatcher, error) {
	// The message names the node watcher (not the access point) so the error
	// is distinguishable from the one returned by CachingAccessPoint.
	return nil, trace.ConnectionProblem(nil, "unable to fetch node watcher, this proxy %v has not been discovered yet, try again later", s)
}

// GetClient never yields a client for a clusterPeer; it returns nil and a
// connection-problem error reporting that this proxy has not been discovered
// yet.
func (s *clusterPeer) GetClient() (auth.ClientI, error) {
	const notDiscovered = "unable to fetch client, this proxy %v has not been discovered yet, try again later"
	return nil, trace.ConnectionProblem(nil, notDiscovered, s)
}
13 changes: 11 additions & 2 deletions lib/reversetunnel/remotesite.go
Original file line number Diff line number Diff line change
@@ -34,13 +34,14 @@ import (
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/srv/forward"
"github.com/gravitational/teleport/lib/utils"

"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
log "github.com/sirupsen/logrus"
"golang.org/x/crypto/ssh"
)

// remoteSite is a remote site that established the inbound connecton to
// remoteSite is a remote site that established the inbound connection to
// the local reverse tunnel server, and now it can provide access to the
// cluster behind it.
type remoteSite struct {
@@ -77,6 +78,9 @@ type remoteSite struct {
// the remote cluster this site belongs to.
remoteAccessPoint auth.RemoteProxyAccessPoint

// nodeWatcher provides access to the node set for the remote site
nodeWatcher *services.NodeWatcher

// remoteCA is the last remote certificate authority recorded by the client.
// It is used to detect CA rotation status changes. If the rotation
// state has been changed, the tunnel will reconnect to re-create the client
@@ -138,6 +142,11 @@ func (s *remoteSite) CachingAccessPoint() (auth.RemoteProxyAccessPoint, error) {
return s.remoteAccessPoint, nil
}

// NodeWatcher hands back the node watcher stored on this remote site
// (s.nodeWatcher). The returned error is always nil.
func (s *remoteSite) NodeWatcher() (*services.NodeWatcher, error) {
	watcher := s.nodeWatcher
	return watcher, nil
}

// GetClient hands back the client stored on this remote site
// (s.remoteClient). The returned error is always nil.
func (s *remoteSite) GetClient() (auth.ClientI, error) {
	client := s.remoteClient
	return client, nil
}
@@ -379,7 +388,7 @@ func (s *remoteSite) handleHeartbeat(conn *remoteConn, ch ssh.Channel, reqC <-ch
} else {
s.WithFields(log.Fields{"nodeID": conn.nodeID}).Debugf("Ping <- %v", conn.conn.RemoteAddr())
}
tm := time.Now().UTC()
tm := s.clock.Now().UTC()
conn.setLastHeartbeat(tm)
go s.registerHeartbeat(tm)
// Note that time.After is re-created every time a request is processed.
34 changes: 29 additions & 5 deletions lib/reversetunnel/srv.go
Original file line number Diff line number Diff line change
@@ -38,6 +38,7 @@ import (
"github.com/gravitational/teleport/lib/sshca"
"github.com/gravitational/teleport/lib/sshutils"
"github.com/gravitational/teleport/lib/utils"

"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
"github.com/prometheus/client_golang/prometheus"
@@ -201,6 +202,9 @@ type Config struct {

// LockWatcher is a lock watcher.
LockWatcher *services.LockWatcher

// NodeWatcher is a node watcher.
NodeWatcher *services.NodeWatcher
}

// CheckAndSetDefaults checks parameters and sets default values
@@ -252,6 +256,9 @@ func (cfg *Config) CheckAndSetDefaults() error {
if cfg.LockWatcher == nil {
return trace.BadParameter("missing parameter LockWatcher")
}
if cfg.NodeWatcher == nil {
return trace.BadParameter("missing parameter NodeWatcher")
}
return nil
}

@@ -891,7 +898,7 @@ func (s *server) upsertRemoteCluster(conn net.Conn, sshConn *ssh.ServerConn) (*r
// treat first connection as a registered heartbeat,
// otherwise the connection information will appear after initial
// heartbeat delay
go site.registerHeartbeat(time.Now())
go site.registerHeartbeat(s.Clock.Now())
return site, remoteConn, nil
}

@@ -1024,7 +1031,7 @@ func newRemoteSite(srv *server, domainName string, sconn ssh.Conn) (*remoteSite,
types.TunnelConnectionSpecV2{
ClusterName: domainName,
ProxyName: srv.ID,
LastHeartbeat: time.Now().UTC(),
LastHeartbeat: srv.Clock.Now().UTC(),
},
)
if err != nil {
@@ -1056,27 +1063,42 @@ func newRemoteSite(srv *server, domainName string, sconn ssh.Conn) (*remoteSite,

clt, _, err := remoteSite.getRemoteClient()
if err != nil {
cancel()
return nil, trace.Wrap(err)
}
remoteSite.remoteClient = clt

remoteVersion, err := getRemoteAuthVersion(closeContext, sconn)
if err != nil {
cancel()
return nil, trace.Wrap(err)
}

accessPoint, err := createRemoteAccessPoint(srv, clt, remoteVersion, domainName)
if err != nil {
cancel()
return nil, trace.Wrap(err)
}
remoteSite.remoteAccessPoint = accessPoint

nodeWatcher, err := services.NewNodeWatcher(closeContext, services.NodeWatcherConfig{
ResourceWatcherConfig: services.ResourceWatcherConfig{
Component: srv.Component,
Client: accessPoint,
Log: srv.Log,
},
})
rosstimothy marked this conversation as resolved.
Show resolved Hide resolved
if err != nil {
cancel()
return nil, trace.Wrap(err)
}
remoteSite.nodeWatcher = nodeWatcher
// instantiate a cache of host certificates for the forwarding server. the
// certificate cache is created in each site (instead of creating it in
// reversetunnel.server and passing it along) so that the host certificate
// is signed by the correct certificate authority.
certificateCache, err := newHostCertificateCache(srv.Config.KeyGen, srv.localAuthClient)
if err != nil {
cancel()
return nil, trace.Wrap(err)
}
remoteSite.certificateCache = certificateCache
@@ -1089,7 +1111,8 @@ func newRemoteSite(srv *server, domainName string, sconn ssh.Conn) (*remoteSite,
Clock: srv.Clock,
})
if err != nil {
return nil, err
cancel()
return nil, trace.Wrap(err)
}

go remoteSite.updateCertAuthorities(caRetry)
@@ -1102,7 +1125,8 @@ func newRemoteSite(srv *server, domainName string, sconn ssh.Conn) (*remoteSite,
Clock: srv.Clock,
})
if err != nil {
return nil, err
cancel()
return nil, trace.Wrap(err)
}

go remoteSite.updateLocks(lockRetry)
13 changes: 13 additions & 0 deletions lib/service/service.go
Original file line number Diff line number Diff line change
@@ -2801,6 +2801,17 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error {
return trace.Wrap(err)
}

nodeWatcher, err := services.NewNodeWatcher(process.ExitContext(), services.NodeWatcherConfig{
ResourceWatcherConfig: services.ResourceWatcherConfig{
Component: teleport.ComponentProxy,
Log: process.log.WithField(trace.Component, teleport.ComponentProxy),
Client: conn.Client,
},
})
if err != nil {
return trace.Wrap(err)
}

serverTLSConfig, err := conn.ServerIdentity.TLSConfig(cfg.CipherSuites)
if err != nil {
return trace.Wrap(err)
@@ -2840,6 +2851,7 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error {
Emitter: streamEmitter,
Log: process.log,
LockWatcher: lockWatcher,
NodeWatcher: nodeWatcher,
})
if err != nil {
return trace.Wrap(err)
@@ -2969,6 +2981,7 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error {
regular.SetOnHeartbeat(process.onHeartbeat(teleport.ComponentProxy)),
regular.SetEmitter(streamEmitter),
regular.SetLockWatcher(lockWatcher),
regular.SetNodeWatcher(nodeWatcher),
)
if err != nil {
return trace.Wrap(err)
13 changes: 9 additions & 4 deletions lib/services/presence.go
Original file line number Diff line number Diff line change
@@ -29,6 +29,12 @@ type ProxyGetter interface {
GetProxies() ([]types.Server, error)
}

// NodesGetter is a service that gets nodes. It declares only GetNodes, so a
// consumer that just needs to list nodes can depend on this interface alone.
type NodesGetter interface {
// GetNodes returns a list of registered servers for the given namespace.
// Optional MarshalOption values may be passed via opts.
GetNodes(ctx context.Context, namespace string, opts ...MarshalOption) ([]types.Server, error)
}

// Presence records and reports the presence of all components
// of the cluster - Nodes, Proxies and SSH nodes
type Presence interface {
@@ -43,13 +49,12 @@ type Presence interface {

// GetNode returns a node by name and namespace.
GetNode(ctx context.Context, namespace, name string) (types.Server, error)

// GetNodes returns a list of registered servers.
GetNodes(ctx context.Context, namespace string, opts ...MarshalOption) ([]types.Server, error)

// ListNodes returns a paginated list of registered servers.
ListNodes(ctx context.Context, req proto.ListNodesRequest) (nodes []types.Server, nextKey string, err error)

// NodesGetter gets nodes
NodesGetter

// DeleteAllNodes deletes all nodes in a namespace.
DeleteAllNodes(ctx context.Context, namespace string) error

Loading