Skip to content

Commit

Permalink
Clean up remoteSites with no active tunnels (#11435)
Browse files Browse the repository at this point in the history
Not closing a remoteSite causes a continuous retry to init the cache
on the remote auth client. Not only does this cause log spam but is
a waste of resources.

We now stop tracking and close a remoteSite when its heartBeat fails
and no active tunnels are established for the site.

Fixes #11197

(cherry picked from commit 3d5928f)

# Conflicts:
#	lib/reversetunnel/fake.go
#	lib/srv/db/access_test.go
  • Loading branch information
rosstimothy committed Apr 4, 2022
1 parent 964df34 commit 37788f7
Show file tree
Hide file tree
Showing 7 changed files with 63 additions and 17 deletions.
3 changes: 3 additions & 0 deletions lib/reversetunnel/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package reversetunnel
import (
"context"
"fmt"
"io"
"net"
"time"

Expand Down Expand Up @@ -99,6 +100,8 @@ type RemoteSite interface {
// IsClosed reports whether this RemoteSite has been closed and should no
// longer be used.
IsClosed() bool
// Closer allows the site to be closed
io.Closer
}

// Tunnel provides access to connected local or remote clusters
Expand Down
4 changes: 4 additions & 0 deletions lib/reversetunnel/fake.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,7 @@ func (s *FakeRemoteSite) Dial(params DialParams) (net.Conn, error) {
s.ConnCh <- readerConn
return writerConn, nil
}

func (s *FakeRemoteSite) Close() error {
return nil
}
3 changes: 3 additions & 0 deletions lib/reversetunnel/localsite.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,9 @@ func (s *localSite) DialTCP(params DialParams) (net.Conn, error) {
// IsClosed always returns false because localSite is never closed.
func (s *localSite) IsClosed() bool { return false }

// Close always returns nil because a localSite isn't closed.
func (s *localSite) Close() error { return nil }

func (s *localSite) dialWithAgent(params DialParams) (net.Conn, error) {
if params.GetUserAgent == nil {
return nil, trace.BadParameter("user agent getter missing")
Expand Down
3 changes: 3 additions & 0 deletions lib/reversetunnel/peer.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,9 @@ func (p *clusterPeers) DialTCP(params DialParams) (conn net.Conn, err error) {
// IsClosed always returns false because clusterPeers is never closed.
func (p *clusterPeers) IsClosed() bool { return false }

// Close always returns nil because a clusterPeers isn't closed.
func (p *clusterPeers) Close() error { return nil }

// newClusterPeer returns new cluster peer
func newClusterPeer(srv *server, connInfo types.TunnelConnection, offlineThreshold time.Duration) (*clusterPeer, error) {
clusterPeer := &clusterPeer{
Expand Down
19 changes: 13 additions & 6 deletions lib/reversetunnel/remotesite.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package reversetunnel

import (
"context"
"errors"
"fmt"
"net"
"sync"
Expand Down Expand Up @@ -151,7 +152,7 @@ func (s *remoteSite) connectionCount() int {
return len(s.connections)
}

func (s *remoteSite) hasValidConnections() bool {
func (s *remoteSite) HasValidConnections() bool {
s.RLock()
defer s.RUnlock()

Expand Down Expand Up @@ -314,13 +315,20 @@ func (s *remoteSite) fanOutProxies(proxies []types.Server) {
}
}

// handleHearbeat receives heartbeat messages from the connected agent
// handleHeartbeat receives heartbeat messages from the connected agent
// if the agent has missed several heartbeats in a row, Proxy marks
// the connection as invalid.
func (s *remoteSite) handleHeartbeat(conn *remoteConn, ch ssh.Channel, reqC <-chan *ssh.Request) {
defer func() {
s.Infof("Cluster connection closed.")
conn.Close()

if err := conn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
s.WithError(err).Warnf("Failed to close remote connection for remote site %s", s.domainName)
}

if err := s.srv.onSiteTunnelClose(s); err != nil {
s.WithError(err).Warnf("Failed to close remote site %s", s.domainName)
}
}()

firstHeartbeat := true
Expand All @@ -344,7 +352,7 @@ func (s *remoteSite) handleHeartbeat(conn *remoteConn, ch ssh.Channel, reqC <-ch
if req == nil {
s.Infof("Cluster agent disconnected.")
conn.markInvalid(trace.ConnectionProblem(nil, "agent disconnected"))
if !s.hasValidConnections() {
if !s.HasValidConnections() {
s.Debugf("Deleting connection record.")
s.deleteConnectionRecord()
}
Expand Down Expand Up @@ -431,8 +439,7 @@ func (s *remoteSite) updateCertAuthorities(retry utils.Retry) {
s.Debugf("Remote cluster %v does not support cert authorities rotation yet.", s.domainName)
case trace.IsCompareFailed(err):
s.Infof("Remote cluster has updated certificate authorities, going to force reconnect.")
s.srv.removeSite(s.domainName)
s.Close()
s.srv.onSiteTunnelClose(&alwaysClose{RemoteSite: s})
return
case trace.IsConnectionProblem(err):
s.Debugf("Remote cluster %v is offline.", s.domainName)
Expand Down
47 changes: 36 additions & 11 deletions lib/reversetunnel/srv.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"context"
"crypto/tls"
"fmt"
"io"
"net"
"strings"
"sync"
Expand Down Expand Up @@ -84,7 +85,7 @@ type server struct {
srv *sshutils.Server
limiter *limiter.Limiter

// remoteSites is the list of conencted remote clusters
// remoteSites is the list of connected remote clusters
remoteSites []*remoteSite

// localSites is the list of local (our own cluster) tunnel clients,
Expand Down Expand Up @@ -348,7 +349,7 @@ func remoteClustersMap(rc []types.RemoteCluster) map[string]types.RemoteCluster
}

// disconnectClusters disconnects reverse tunnel connections from remote clusters
// that were deleted from the the local cluster side and cleans up in memory objects.
// that were deleted from the local cluster side and cleans up in memory objects.
// In this case all local trust has been deleted, so all the tunnel connections have to be dropped.
func (s *server) disconnectClusters() error {
connectedRemoteClusters := s.getRemoteClusters()
Expand All @@ -363,9 +364,7 @@ func (s *server) disconnectClusters() error {
for _, cluster := range connectedRemoteClusters {
if _, ok := remoteMap[cluster.GetName()]; !ok {
s.log.Infof("Remote cluster %q has been deleted. Disconnecting it from the proxy.", cluster.GetName())
s.removeSite(cluster.GetName())
err := cluster.Close()
if err != nil {
if err := s.onSiteTunnelClose(&alwaysClose{RemoteSite: cluster}); err != nil {
s.log.Debugf("Failure closing cluster %q: %v.", cluster.GetName(), err)
}
}
Expand Down Expand Up @@ -956,22 +955,48 @@ func (s *server) GetSite(name string) (RemoteSite, error) {
return nil, trace.NotFound("cluster %q is not found", name)
}

func (s *server) removeSite(domainName string) error {
// alwaysClose forces onSiteTunnelClose to remove and close
// the site by always returning false from HasValidConnections.
type alwaysClose struct {
RemoteSite
}

func (a *alwaysClose) HasValidConnections() bool {
return false
}

// siteCloser is used by onSiteTunnelClose to determine if a site should be closed
// when a tunnel is closed
type siteCloser interface {
GetName() string
HasValidConnections() bool
io.Closer
}

// onSiteTunnelClose will close and stop tracking the site with the given name
// if it has 0 active tunnels. This is done here to ensure that no new tunnels
// can be established while cleaning up a site.
func (s *server) onSiteTunnelClose(site siteCloser) error {
s.Lock()
defer s.Unlock()

if site.HasValidConnections() {
return nil
}

for i := range s.remoteSites {
if s.remoteSites[i].domainName == domainName {
if s.remoteSites[i].domainName == site.GetName() {
s.remoteSites = append(s.remoteSites[:i], s.remoteSites[i+1:]...)
return nil
return trace.Wrap(site.Close())
}
}
for i := range s.localSites {
if s.localSites[i].domainName == domainName {
if s.localSites[i].domainName == site.GetName() {
s.localSites = append(s.localSites[:i], s.localSites[i+1:]...)
return nil
return trace.Wrap(site.Close())
}
}
return trace.NotFound("cluster %q is not found", domainName)
return trace.NotFound("site %q is not found", site.GetName())
}

// fanOutProxies is a non-blocking call that updated the watches proxies
Expand Down
1 change: 1 addition & 0 deletions lib/srv/db/access_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1033,6 +1033,7 @@ func setupTestContext(ctx context.Context, t *testing.T, withDatabases ...withDa
ConnCh: testCtx.proxyConn,
AccessPoint: proxyAuthClient,
}
t.Cleanup(func() { require.NoError(t, testCtx.fakeRemoteSite.Close()) })
tunnel := &reversetunnel.FakeServer{
Sites: []reversetunnel.RemoteSite{
testCtx.fakeRemoteSite,
Expand Down

0 comments on commit 37788f7

Please sign in to comment.