Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[branch/v8] Clean up remoteSites with no active tunnels (#11435) #11706

Merged
merged 2 commits into from
Apr 5, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions lib/reversetunnel/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package reversetunnel
import (
"context"
"fmt"
"io"
"net"
"time"

Expand Down Expand Up @@ -99,6 +100,8 @@ type RemoteSite interface {
// IsClosed reports whether this RemoteSite has been closed and should no
// longer be used.
IsClosed() bool
// Closer allows the site to be closed
io.Closer
}

// Tunnel provides access to connected local or remote clusters
Expand Down
4 changes: 4 additions & 0 deletions lib/reversetunnel/fake.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,7 @@ func (s *FakeRemoteSite) Dial(params DialParams) (net.Conn, error) {
s.ConnCh <- readerConn
return writerConn, nil
}

// Close does nothing for the fake site and always returns nil.
func (s *FakeRemoteSite) Close() error { return nil }
3 changes: 3 additions & 0 deletions lib/reversetunnel/localsite.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,9 @@ func (s *localSite) DialTCP(params DialParams) (net.Conn, error) {
// IsClosed reports whether the site has been closed. A localSite is never
// closed, so the result is always false.
func (s *localSite) IsClosed() bool {
	return false
}

// Close does nothing and always returns nil, since a localSite isn't closed.
func (s *localSite) Close() error {
	return nil
}

func (s *localSite) dialWithAgent(params DialParams) (net.Conn, error) {
if params.GetUserAgent == nil {
return nil, trace.BadParameter("user agent getter missing")
Expand Down
3 changes: 3 additions & 0 deletions lib/reversetunnel/peer.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,9 @@ func (p *clusterPeers) DialTCP(params DialParams) (conn net.Conn, err error) {
// IsClosed reports whether the peers have been closed. clusterPeers is never
// closed, so the result is always false.
func (p *clusterPeers) IsClosed() bool {
	return false
}

// Close does nothing and always returns nil, since a clusterPeers isn't
// closed.
func (p *clusterPeers) Close() error {
	return nil
}

// newClusterPeer returns new cluster peer
func newClusterPeer(srv *server, connInfo types.TunnelConnection, offlineThreshold time.Duration) (*clusterPeer, error) {
clusterPeer := &clusterPeer{
Expand Down
19 changes: 13 additions & 6 deletions lib/reversetunnel/remotesite.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package reversetunnel

import (
"context"
"errors"
"fmt"
"net"
"sync"
Expand Down Expand Up @@ -151,7 +152,7 @@ func (s *remoteSite) connectionCount() int {
return len(s.connections)
}

func (s *remoteSite) hasValidConnections() bool {
func (s *remoteSite) HasValidConnections() bool {
s.RLock()
defer s.RUnlock()

Expand Down Expand Up @@ -314,13 +315,20 @@ func (s *remoteSite) fanOutProxies(proxies []types.Server) {
}
}

// handleHearbeat receives heartbeat messages from the connected agent
// handleHeartbeat receives heartbeat messages from the connected agent
// if the agent has missed several heartbeats in a row, Proxy marks
// the connection as invalid.
func (s *remoteSite) handleHeartbeat(conn *remoteConn, ch ssh.Channel, reqC <-chan *ssh.Request) {
defer func() {
s.Infof("Cluster connection closed.")
conn.Close()

if err := conn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
s.WithError(err).Warnf("Failed to close remote connection for remote site %s", s.domainName)
}

if err := s.srv.onSiteTunnelClose(s); err != nil {
s.WithError(err).Warnf("Failed to close remote site %s", s.domainName)
}
}()

firstHeartbeat := true
Expand All @@ -344,7 +352,7 @@ func (s *remoteSite) handleHeartbeat(conn *remoteConn, ch ssh.Channel, reqC <-ch
if req == nil {
s.Infof("Cluster agent disconnected.")
conn.markInvalid(trace.ConnectionProblem(nil, "agent disconnected"))
if !s.hasValidConnections() {
if !s.HasValidConnections() {
s.Debugf("Deleting connection record.")
s.deleteConnectionRecord()
}
Expand Down Expand Up @@ -431,8 +439,7 @@ func (s *remoteSite) updateCertAuthorities(retry utils.Retry) {
s.Debugf("Remote cluster %v does not support cert authorities rotation yet.", s.domainName)
case trace.IsCompareFailed(err):
s.Infof("Remote cluster has updated certificate authorities, going to force reconnect.")
s.srv.removeSite(s.domainName)
s.Close()
s.srv.onSiteTunnelClose(&alwaysClose{RemoteSite: s})
return
case trace.IsConnectionProblem(err):
s.Debugf("Remote cluster %v is offline.", s.domainName)
Expand Down
47 changes: 36 additions & 11 deletions lib/reversetunnel/srv.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"context"
"crypto/tls"
"fmt"
"io"
"net"
"strings"
"sync"
Expand Down Expand Up @@ -84,7 +85,7 @@ type server struct {
srv *sshutils.Server
limiter *limiter.Limiter

// remoteSites is the list of conencted remote clusters
// remoteSites is the list of connected remote clusters
remoteSites []*remoteSite

// localSites is the list of local (our own cluster) tunnel clients,
Expand Down Expand Up @@ -348,7 +349,7 @@ func remoteClustersMap(rc []types.RemoteCluster) map[string]types.RemoteCluster
}

// disconnectClusters disconnects reverse tunnel connections from remote clusters
// that were deleted from the the local cluster side and cleans up in memory objects.
// that were deleted from the local cluster side and cleans up in memory objects.
// In this case all local trust has been deleted, so all the tunnel connections have to be dropped.
func (s *server) disconnectClusters() error {
connectedRemoteClusters := s.getRemoteClusters()
Expand All @@ -363,9 +364,7 @@ func (s *server) disconnectClusters() error {
for _, cluster := range connectedRemoteClusters {
if _, ok := remoteMap[cluster.GetName()]; !ok {
s.log.Infof("Remote cluster %q has been deleted. Disconnecting it from the proxy.", cluster.GetName())
s.removeSite(cluster.GetName())
err := cluster.Close()
if err != nil {
if err := s.onSiteTunnelClose(&alwaysClose{RemoteSite: cluster}); err != nil {
s.log.Debugf("Failure closing cluster %q: %v.", cluster.GetName(), err)
}
}
Expand Down Expand Up @@ -956,22 +955,48 @@ func (s *server) GetSite(name string) (RemoteSite, error) {
return nil, trace.NotFound("cluster %q is not found", name)
}

func (s *server) removeSite(domainName string) error {
// alwaysClose wraps a RemoteSite and overrides HasValidConnections to always
// return false, which forces onSiteTunnelClose to remove and close the site.
type alwaysClose struct {
	// RemoteSite is the wrapped site; its other methods are promoted as-is.
	RemoteSite
}

// HasValidConnections always reports false, regardless of the wrapped site.
func (a *alwaysClose) HasValidConnections() bool { return false }

// siteCloser is the set of methods onSiteTunnelClose uses to decide whether a
// site should be closed when one of its tunnels is closed.
type siteCloser interface {
	// GetName returns the name of the site.
	GetName() string
	// HasValidConnections reports whether the site still has valid
	// connections.
	HasValidConnections() bool
	// Closer closes the site.
	io.Closer
}

// onSiteTunnelClose will close and stop tracking the site with the given name
// if it has 0 active tunnels. This is done here to ensure that no new tunnels
// can be established while cleaning up a site.
func (s *server) onSiteTunnelClose(site siteCloser) error {
s.Lock()
defer s.Unlock()

if site.HasValidConnections() {
return nil
}

for i := range s.remoteSites {
if s.remoteSites[i].domainName == domainName {
if s.remoteSites[i].domainName == site.GetName() {
s.remoteSites = append(s.remoteSites[:i], s.remoteSites[i+1:]...)
return nil
return trace.Wrap(site.Close())
}
}
for i := range s.localSites {
if s.localSites[i].domainName == domainName {
if s.localSites[i].domainName == site.GetName() {
s.localSites = append(s.localSites[:i], s.localSites[i+1:]...)
return nil
return trace.Wrap(site.Close())
}
}
return trace.NotFound("cluster %q is not found", domainName)
return trace.NotFound("site %q is not found", site.GetName())
}

// fanOutProxies is a non-blocking call that updated the watches proxies
Expand Down
1 change: 1 addition & 0 deletions lib/srv/db/access_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1033,6 +1033,7 @@ func setupTestContext(ctx context.Context, t *testing.T, withDatabases ...withDa
ConnCh: testCtx.proxyConn,
AccessPoint: proxyAuthClient,
}
t.Cleanup(func() { require.NoError(t, testCtx.fakeRemoteSite.Close()) })
tunnel := &reversetunnel.FakeServer{
Sites: []reversetunnel.RemoteSite{
testCtx.fakeRemoteSite,
Expand Down