diff --git a/lib/reversetunnel/agent.go b/lib/reversetunnel/agent.go index 0d2cba35adec1..8cf7a07e3b511 100644 --- a/lib/reversetunnel/agent.go +++ b/lib/reversetunnel/agent.go @@ -306,16 +306,25 @@ func (a *Agent) handleGlobalRequests(ctx context.Context, requestCh <-chan *ssh. switch r.Type { case versionRequest: - err := r.Reply(true, []byte(teleport.Version)) + // reply with the auth server version + pong, err := a.Client.Ping(ctx) if err != nil { - log.Debugf("Failed to reply to %v request: %v.", r.Type, err) + a.log.WithError(err).Warnf("Failed to ping auth server in response to %v request.", r.Type) + if err := r.Reply(false, []byte("Failed to retrieve auth version")); err != nil { + a.log.Debugf("Failed to reply to %v request: %v.", r.Type, err) + continue + } + } + + if err := r.Reply(true, []byte(pong.ServerVersion)); err != nil { + a.log.Debugf("Failed to reply to %v request: %v.", r.Type, err) continue } default: // This handles keep-alive messages and matches the behaviour of OpenSSH. err := r.Reply(false, nil) if err != nil { - log.Debugf("Failed to reply to %v request: %v.", r.Type, err) + a.log.Debugf("Failed to reply to %v request: %v.", r.Type, err) continue } } diff --git a/lib/reversetunnel/srv.go b/lib/reversetunnel/srv.go index 471e30c88de76..0aba941727cd7 100644 --- a/lib/reversetunnel/srv.go +++ b/lib/reversetunnel/srv.go @@ -39,7 +39,6 @@ import ( "github.com/gravitational/teleport/lib/sshutils" "github.com/gravitational/teleport/lib/utils" - "github.com/coreos/go-semver/semver" "github.com/gravitational/trace" "github.com/jonboulle/clockwork" "github.com/prometheus/client_golang/prometheus" @@ -638,7 +637,7 @@ func (s *server) handleHeartbeat(conn net.Conn, sconn *ssh.ServerConn, nch ssh.N // nodes it's a node dialing back. val, ok := sconn.Permissions.Extensions[extCertRole] if !ok { - log.Errorf("Failed to accept connection, missing %q extension", extCertRole) + s.log.Errorf("Failed to accept connection, missing %q extension", extCertRole) s.rejectRequest(nch, ssh.ConnectionFailed, "unknown role") return } @@ -662,7 +661,7 @@ func (s *server) handleHeartbeat(conn net.Conn, sconn *ssh.ServerConn, nch ssh.N s.handleNewCluster(conn, sconn, nch) // Unknown role. default: - log.Errorf("Unsupported role attempting to connect: %v", val) + s.log.Errorf("Unsupported role attempting to connect: %v", val) s.rejectRequest(nch, ssh.ConnectionFailed, fmt.Sprintf("unsupported role %v", val)) } } @@ -670,14 +669,14 @@ func (s *server) handleHeartbeat(conn net.Conn, sconn *ssh.ServerConn, nch ssh.N func (s *server) handleNewService(role types.SystemRole, conn net.Conn, sconn *ssh.ServerConn, nch ssh.NewChannel, connType types.TunnelType) { cluster, rconn, err := s.upsertServiceConn(conn, sconn, connType) if err != nil { - log.Errorf("Failed to upsert %s: %v.", role, err) + s.log.Errorf("Failed to upsert %s: %v.", role, err) sconn.Close() return } ch, req, err := nch.Accept() if err != nil { - log.Errorf("Failed to accept on channel: %v.", err) + s.log.Errorf("Failed to accept on channel: %v.", err) sconn.Close() return } @@ -689,14 +688,14 @@ func (s *server) handleNewCluster(conn net.Conn, sshConn *ssh.ServerConn, nch ss // add the incoming site (cluster) to the list of active connections: site, remoteConn, err := s.upsertRemoteCluster(conn, sshConn) if err != nil { - log.Error(trace.Wrap(err)) + s.log.Error(trace.Wrap(err)) s.rejectRequest(nch, ssh.ConnectionFailed, "failed to accept incoming cluster connection") return } // accept the request and start the heartbeat on it: ch, req, err := nch.Accept() if err != nil { - log.Error(trace.Wrap(err)) + s.log.Error(trace.Wrap(err)) sshConn.Close() return } @@ -1059,27 +1058,12 @@ func newRemoteSite(srv *server, domainName string, sconn ssh.Conn) (*remoteSite, } remoteSite.remoteClient = clt - // DELETE IN: 8.0.0 - // - // Check if the cluster that is connecting is a pre-v7 cluster. If it is, - // don't assume the newer organization of cluster configuration resources - // (RFD 28) because older proxy servers will reject that causing the cache - // to go into a re-sync loop. - var accessPointFunc auth.NewCachingAccessPoint - ok, err := isPreV7Cluster(closeContext, sconn) + remoteVersion, err := getRemoteAuthVersion(closeContext, sconn) if err != nil { return nil, trace.Wrap(err) } - if ok { - log.Debugf("Pre-v7 cluster connecting, loading old cache policy.") - accessPointFunc = srv.Config.NewCachingAccessPointOldProxy - } else { - accessPointFunc = srv.newAccessPoint - } - // Configure access to the cached subset of the Auth Server API of the remote - // cluster this remote site provides access to. - accessPoint, err := accessPointFunc(clt, []string{"reverse", domainName}) + accessPoint, err := createRemoteAccessPoint(srv, clt, remoteVersion, domainName) if err != nil { return nil, trace.Wrap(err) } @@ -1124,33 +1108,35 @@ func newRemoteSite(srv *server, domainName string, sconn ssh.Conn) (*remoteSite, return remoteSite, nil } -// DELETE IN: 8.0.0. -// -// isPreV7Cluster checks if the cluster is older than 7.0.0. -func isPreV7Cluster(ctx context.Context, conn ssh.Conn) (bool, error) { - version, err := sendVersionRequest(ctx, conn) +// createRemoteAccessPoint creates a new access point for the remote cluster. +// Checks if the cluster that is connecting is a pre-v7 cluster. If it is, +// don't assume the newer organization of cluster configuration resources +// (RFD 28) because older proxy servers will reject that causing the cache +// to go into a re-sync loop. +func createRemoteAccessPoint(srv *server, clt auth.ClientI, version, domainName string) (auth.AccessPoint, error) { + ok, err := utils.MinVerWithoutPreRelease(version, utils.VersionBeforeAlpha("7.0.0")) if err != nil { - return false, trace.Wrap(err) + return nil, trace.Wrap(err) } - remoteClusterVersion, err := semver.NewVersion(version) - if err != nil { - return false, trace.Wrap(err) + accessPointFunc := srv.Config.NewCachingAccessPoint + if !ok { + srv.log.Debugf("cluster %q running %q is connecting, loading old cache policy.", domainName, version) + accessPointFunc = srv.Config.NewCachingAccessPointOldProxy } - minClusterVersion, err := semver.NewVersion(utils.VersionBeforeAlpha("7.0.0")) + + // Configure access to the cached subset of the Auth Server API of the remote + // cluster this remote site provides access to. + accessPoint, err := accessPointFunc(clt, []string{"reverse", domainName}) if err != nil { - return false, trace.Wrap(err) - } - // Return true if the version is older than 7.0.0 - if remoteClusterVersion.LessThan(*minClusterVersion) { - return true, nil + return nil, trace.Wrap(err) } - return false, nil + return accessPoint, nil } -// sendVersionRequest sends a request for the version remote Teleport cluster. -func sendVersionRequest(ctx context.Context, sconn ssh.Conn) (string, error) { +// getRemoteAuthVersion sends a version request to the remote agent. +func getRemoteAuthVersion(ctx context.Context, sconn ssh.Conn) (string, error) { errorCh := make(chan error, 1) versionCh := make(chan string, 1) @@ -1164,6 +1150,7 @@ func sendVersionRequest(ctx context.Context, sconn ssh.Conn) (string, error) { errorCh <- trace.BadParameter("no response to %v request", versionRequest) return } + versionCh <- string(payload) }() @@ -1175,7 +1162,7 @@ func sendVersionRequest(ctx context.Context, sconn ssh.Conn) (string, error) { case <-time.After(defaults.WaitCopyTimeout): return "", trace.BadParameter("timeout waiting for version") case <-ctx.Done(): - return "", ctx.Err() + return "", trace.Wrap(ctx.Err()) } } diff --git a/lib/reversetunnel/srv_test.go b/lib/reversetunnel/srv_test.go index 01a04d91fe303..e6747d7f59a08 100644 --- a/lib/reversetunnel/srv_test.go +++ b/lib/reversetunnel/srv_test.go @@ -18,6 +18,7 @@ package reversetunnel import ( "context" + "errors" "net" "testing" "time" @@ -160,3 +161,70 @@ type mockAccessPoint struct { func (ap mockAccessPoint) GetCertAuthority(ctx context.Context, id types.CertAuthID, loadKeys bool, opts ...services.MarshalOption) (types.CertAuthority, error) { return ap.ca, nil } + +func TestCreateRemoteAccessPoint(t *testing.T) { + cases := []struct { + name string + version string + assertion require.ErrorAssertionFunc + oldRemoteProxy bool + }{ + { + name: "invalid version", + assertion: require.Error, + }, + { + name: "remote running 8.0.0", + assertion: require.NoError, + version: "8.0.0", + }, + { + name: "remote running 7.0.0", + assertion: require.NoError, + version: "7.0.0", + }, + { + name: "remote running 6.0.0", + assertion: require.NoError, + version: "6.0.0", + oldRemoteProxy: true, + }, + { + name: "remote running 5.0.0", + assertion: require.NoError, + version: "5.0.0", + oldRemoteProxy: true, + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + newProxyFn := func(clt auth.ClientI, cacheName []string) (auth.AccessPoint, error) { + if tt.oldRemoteProxy { + return nil, errors.New("expected to create an old remote proxy") + } + + return nil, nil + } + + oldProxyFn := func(clt auth.ClientI, cacheName []string) (auth.AccessPoint, error) { + if !tt.oldRemoteProxy { + return nil, errors.New("expected to create an new remote proxy") + } + + return nil, nil + } + + clt := &mockAuthClient{} + srv := &server{ + log: utils.NewLoggerForTests(), + Config: Config{ + NewCachingAccessPoint: newProxyFn, + NewCachingAccessPointOldProxy: oldProxyFn, + }, + } + _, err := createRemoteAccessPoint(srv, clt, tt.version, "test") + tt.assertion(t, err) + }) + } +} diff --git a/lib/utils/ver.go b/lib/utils/ver.go index e362c6496deeb..71e967d7fffed 100644 --- a/lib/utils/ver.go +++ b/lib/utils/ver.go @@ -45,3 +45,32 @@ func CheckVersion(currentVersion, minVersion string) error { func VersionBeforeAlpha(version string) string { return version + "-aa" } + +// MinVerWithoutPreRelease compares semver strings, but skips prerelease. This allows to compare +// two versions and ignore dev,alpha,beta, etc. strings. +func MinVerWithoutPreRelease(currentVersion, minVersion string) (bool, error) { + currentSemver, minSemver, err := versionStringToSemver(currentVersion, minVersion) + if err != nil { + return false, trace.Wrap(err) + } + + // Erase pre-release string, so only version is compared. + currentSemver.PreRelease = "" + minSemver.PreRelease = "" + + return !currentSemver.LessThan(*minSemver), nil +} + +func versionStringToSemver(ver1, ver2 string) (*semver.Version, *semver.Version, error) { + v1Semver, err := semver.NewVersion(ver1) + if err != nil { + return nil, nil, trace.Wrap(err, "unsupported version format, need semver format: %q, e.g 1.0.0", v1Semver) + } + + v2Semver, err := semver.NewVersion(ver2) + if err != nil { + return nil, nil, trace.Wrap(err, "unsupported version format, need semver format: %q, e.g 1.0.0", v2Semver) + } + + return v1Semver, v2Semver, nil +}