From 1b3cbba667a84c28cd2baf69e3c32e82f4361804 Mon Sep 17 00:00:00 2001 From: rosstimothy <39066650+rosstimothy@users.noreply.github.com> Date: Wed, 27 Apr 2022 09:22:13 -0400 Subject: [PATCH] Create remote site cache based on remote auth version (#12130) (#12249) * Create remote site cache based on remote auth version The cache policy used for a remote site is determined based on the response from a version request. However, the version response was only returning the proxy version. If the remote site was not running the same version for both auth and proxy, then the cache policy chosen could be invalid. The reverse tunnel agent now pings its auth server and reports the auth server version in response to a version request. Fixes #12010 (cherry picked from commit 4f2ad1f14a5be4b53bc65eca4c8d859de167ba83) # Conflicts: # lib/reversetunnel/srv.go --- lib/reversetunnel/agent.go | 15 +++++-- lib/reversetunnel/srv.go | 73 ++++++++++++++--------------------- lib/reversetunnel/srv_test.go | 68 ++++++++++++++++++++++++++++++++ lib/utils/ver.go | 29 ++++++++++++++ 4 files changed, 139 insertions(+), 46 deletions(-) diff --git a/lib/reversetunnel/agent.go b/lib/reversetunnel/agent.go index 0d2cba35adec1..8cf7a07e3b511 100644 --- a/lib/reversetunnel/agent.go +++ b/lib/reversetunnel/agent.go @@ -306,16 +306,25 @@ func (a *Agent) handleGlobalRequests(ctx context.Context, requestCh <-chan *ssh. 
switch r.Type { case versionRequest: - err := r.Reply(true, []byte(teleport.Version)) + // reply with the auth server version + pong, err := a.Client.Ping(ctx) if err != nil { - log.Debugf("Failed to reply to %v request: %v.", r.Type, err) + a.log.WithError(err).Warnf("Failed to ping auth server in response to %v request.", r.Type) + if err := r.Reply(false, []byte("Failed to retrieve auth version")); err != nil { + a.log.Debugf("Failed to reply to %v request: %v.", r.Type, err) + } + continue // a failure reply was already attempted; never also send a success reply + } + + if err := r.Reply(true, []byte(pong.ServerVersion)); err != nil { + a.log.Debugf("Failed to reply to %v request: %v.", r.Type, err) continue } default: // This handles keep-alive messages and matches the behaviour of OpenSSH. err := r.Reply(false, nil) if err != nil { - log.Debugf("Failed to reply to %v request: %v.", r.Type, err) + a.log.Debugf("Failed to reply to %v request: %v.", r.Type, err) continue } } diff --git a/lib/reversetunnel/srv.go b/lib/reversetunnel/srv.go index 471e30c88de76..0aba941727cd7 100644 --- a/lib/reversetunnel/srv.go +++ b/lib/reversetunnel/srv.go @@ -39,7 +39,6 @@ import ( "github.com/gravitational/teleport/lib/sshutils" "github.com/gravitational/teleport/lib/utils" - "github.com/coreos/go-semver/semver" "github.com/gravitational/trace" "github.com/jonboulle/clockwork" "github.com/prometheus/client_golang/prometheus" @@ -638,7 +637,7 @@ func (s *server) handleHeartbeat(conn net.Conn, sconn *ssh.ServerConn, nch ssh.N // nodes it's a node dialing back. val, ok := sconn.Permissions.Extensions[extCertRole] if !ok { - log.Errorf("Failed to accept connection, missing %q extension", extCertRole) + s.log.Errorf("Failed to accept connection, missing %q extension", extCertRole) s.rejectRequest(nch, ssh.ConnectionFailed, "unknown role") return } @@ -662,7 +661,7 @@ func (s *server) handleHeartbeat(conn net.Conn, sconn *ssh.ServerConn, nch ssh.N s.handleNewCluster(conn, sconn, nch) // Unknown role. 
default: - log.Errorf("Unsupported role attempting to connect: %v", val) + s.log.Errorf("Unsupported role attempting to connect: %v", val) s.rejectRequest(nch, ssh.ConnectionFailed, fmt.Sprintf("unsupported role %v", val)) } } @@ -670,14 +669,14 @@ func (s *server) handleHeartbeat(conn net.Conn, sconn *ssh.ServerConn, nch ssh.N func (s *server) handleNewService(role types.SystemRole, conn net.Conn, sconn *ssh.ServerConn, nch ssh.NewChannel, connType types.TunnelType) { cluster, rconn, err := s.upsertServiceConn(conn, sconn, connType) if err != nil { - log.Errorf("Failed to upsert %s: %v.", role, err) + s.log.Errorf("Failed to upsert %s: %v.", role, err) sconn.Close() return } ch, req, err := nch.Accept() if err != nil { - log.Errorf("Failed to accept on channel: %v.", err) + s.log.Errorf("Failed to accept on channel: %v.", err) sconn.Close() return } @@ -689,14 +688,14 @@ func (s *server) handleNewCluster(conn net.Conn, sshConn *ssh.ServerConn, nch ss // add the incoming site (cluster) to the list of active connections: site, remoteConn, err := s.upsertRemoteCluster(conn, sshConn) if err != nil { - log.Error(trace.Wrap(err)) + s.log.Error(trace.Wrap(err)) s.rejectRequest(nch, ssh.ConnectionFailed, "failed to accept incoming cluster connection") return } // accept the request and start the heartbeat on it: ch, req, err := nch.Accept() if err != nil { - log.Error(trace.Wrap(err)) + s.log.Error(trace.Wrap(err)) sshConn.Close() return } @@ -1059,27 +1058,12 @@ func newRemoteSite(srv *server, domainName string, sconn ssh.Conn) (*remoteSite, } remoteSite.remoteClient = clt - // DELETE IN: 8.0.0 - // - // Check if the cluster that is connecting is a pre-v7 cluster. If it is, - // don't assume the newer organization of cluster configuration resources - // (RFD 28) because older proxy servers will reject that causing the cache - // to go into a re-sync loop. 
- var accessPointFunc auth.NewCachingAccessPoint - ok, err := isPreV7Cluster(closeContext, sconn) + remoteVersion, err := getRemoteAuthVersion(closeContext, sconn) if err != nil { return nil, trace.Wrap(err) } - if ok { - log.Debugf("Pre-v7 cluster connecting, loading old cache policy.") - accessPointFunc = srv.Config.NewCachingAccessPointOldProxy - } else { - accessPointFunc = srv.newAccessPoint - } - // Configure access to the cached subset of the Auth Server API of the remote - // cluster this remote site provides access to. - accessPoint, err := accessPointFunc(clt, []string{"reverse", domainName}) + accessPoint, err := createRemoteAccessPoint(srv, clt, remoteVersion, domainName) if err != nil { return nil, trace.Wrap(err) } @@ -1124,33 +1108,35 @@ func newRemoteSite(srv *server, domainName string, sconn ssh.Conn) (*remoteSite, return remoteSite, nil } -// DELETE IN: 8.0.0. -// -// isPreV7Cluster checks if the cluster is older than 7.0.0. -func isPreV7Cluster(ctx context.Context, conn ssh.Conn) (bool, error) { - version, err := sendVersionRequest(ctx, conn) +// createRemoteAccessPoint creates a new access point for the remote cluster. +// Checks if the cluster that is connecting is a pre-v7 cluster. If it is, +// don't assume the newer organization of cluster configuration resources +// (RFD 28) because older proxy servers will reject that causing the cache +// to go into a re-sync loop. 
+func createRemoteAccessPoint(srv *server, clt auth.ClientI, version, domainName string) (auth.AccessPoint, error) { + ok, err := utils.MinVerWithoutPreRelease(version, utils.VersionBeforeAlpha("7.0.0")) if err != nil { - return false, trace.Wrap(err) + return nil, trace.Wrap(err) } - remoteClusterVersion, err := semver.NewVersion(version) - if err != nil { - return false, trace.Wrap(err) + accessPointFunc := srv.Config.NewCachingAccessPoint + if !ok { + srv.log.Debugf("cluster %q running %q is connecting, loading old cache policy.", domainName, version) + accessPointFunc = srv.Config.NewCachingAccessPointOldProxy } - minClusterVersion, err := semver.NewVersion(utils.VersionBeforeAlpha("7.0.0")) + + // Configure access to the cached subset of the Auth Server API of the remote + // cluster this remote site provides access to. + accessPoint, err := accessPointFunc(clt, []string{"reverse", domainName}) if err != nil { - return false, trace.Wrap(err) - } - // Return true if the version is older than 7.0.0 - if remoteClusterVersion.LessThan(*minClusterVersion) { - return true, nil + return nil, trace.Wrap(err) } - return false, nil + return accessPoint, nil } -// sendVersionRequest sends a request for the version remote Teleport cluster. -func sendVersionRequest(ctx context.Context, sconn ssh.Conn) (string, error) { +// getRemoteAuthVersion sends a version request to the remote agent. 
+func getRemoteAuthVersion(ctx context.Context, sconn ssh.Conn) (string, error) { errorCh := make(chan error, 1) versionCh := make(chan string, 1) @@ -1164,6 +1150,7 @@ func sendVersionRequest(ctx context.Context, sconn ssh.Conn) (string, error) { errorCh <- trace.BadParameter("no response to %v request", versionRequest) return } + versionCh <- string(payload) }() @@ -1175,7 +1162,7 @@ func sendVersionRequest(ctx context.Context, sconn ssh.Conn) (string, error) { case <-time.After(defaults.WaitCopyTimeout): return "", trace.BadParameter("timeout waiting for version") case <-ctx.Done(): - return "", ctx.Err() + return "", trace.Wrap(ctx.Err()) } } diff --git a/lib/reversetunnel/srv_test.go b/lib/reversetunnel/srv_test.go index 01a04d91fe303..e6747d7f59a08 100644 --- a/lib/reversetunnel/srv_test.go +++ b/lib/reversetunnel/srv_test.go @@ -18,6 +18,7 @@ package reversetunnel import ( "context" + "errors" "net" "testing" "time" @@ -160,3 +161,70 @@ type mockAccessPoint struct { func (ap mockAccessPoint) GetCertAuthority(ctx context.Context, id types.CertAuthID, loadKeys bool, opts ...services.MarshalOption) (types.CertAuthority, error) { return ap.ca, nil } + +func TestCreateRemoteAccessPoint(t *testing.T) { + cases := []struct { + name string + version string + assertion require.ErrorAssertionFunc + oldRemoteProxy bool + }{ + { + name: "invalid version", + assertion: require.Error, + }, + { + name: "remote running 8.0.0", + assertion: require.NoError, + version: "8.0.0", + }, + { + name: "remote running 7.0.0", + assertion: require.NoError, + version: "7.0.0", + }, + { + name: "remote running 6.0.0", + assertion: require.NoError, + version: "6.0.0", + oldRemoteProxy: true, + }, + { + name: "remote running 5.0.0", + assertion: require.NoError, + version: "5.0.0", + oldRemoteProxy: true, + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + newProxyFn := func(clt auth.ClientI, cacheName []string) (auth.AccessPoint, error) { + if 
tt.oldRemoteProxy { + return nil, errors.New("expected to create an old remote proxy") + } + + return nil, nil + } + + oldProxyFn := func(clt auth.ClientI, cacheName []string) (auth.AccessPoint, error) { + if !tt.oldRemoteProxy { + return nil, errors.New("expected to create an new remote proxy") + } + + return nil, nil + } + + clt := &mockAuthClient{} + srv := &server{ + log: utils.NewLoggerForTests(), + Config: Config{ + NewCachingAccessPoint: newProxyFn, + NewCachingAccessPointOldProxy: oldProxyFn, + }, + } + _, err := createRemoteAccessPoint(srv, clt, tt.version, "test") + tt.assertion(t, err) + }) + } +} diff --git a/lib/utils/ver.go b/lib/utils/ver.go index e362c6496deeb..71e967d7fffed 100644 --- a/lib/utils/ver.go +++ b/lib/utils/ver.go @@ -45,3 +45,32 @@ func CheckVersion(currentVersion, minVersion string) error { func VersionBeforeAlpha(version string) string { return version + "-aa" } + +// MinVerWithoutPreRelease reports whether currentVersion is at least minVersion, comparing the two +// semver strings with their pre-release parts (dev, alpha, beta, etc.) erased. +func MinVerWithoutPreRelease(currentVersion, minVersion string) (bool, error) { + currentSemver, minSemver, err := versionStringToSemver(currentVersion, minVersion) + if err != nil { + return false, trace.Wrap(err) + } + + // Erase pre-release string, so only version is compared. + currentSemver.PreRelease = "" + minSemver.PreRelease = "" + + return !currentSemver.LessThan(*minSemver), nil +} + +func versionStringToSemver(ver1, ver2 string) (*semver.Version, *semver.Version, error) { + v1Semver, err := semver.NewVersion(ver1) + if err != nil { + return nil, nil, trace.Wrap(err, "unsupported version format, need semver format: %q, e.g 1.0.0", ver1) + } + + v2Semver, err := semver.NewVersion(ver2) + if err != nil { + return nil, nil, trace.Wrap(err, "unsupported version format, need semver format: %q, e.g 1.0.0", ver2) + } + + return v1Semver, v2Semver, nil +}