Skip to content

Commit

Permalink
Create remote site cache based on remote auth version (#12130) (#12250)
Browse files Browse the repository at this point in the history
* Create remote site cache based on remote auth version

The cache policy used for a remote site is determined based on
the response from a version request. However the version response
was only returning the proxy version. If the remote site was not
running the same version for both auth and proxy, then the cache
policy chosen could be invalid.

The reverse tunnel agent now pings its auth server and reports
both the auth version in response to a version request.

Fixes #12010

(cherry picked from commit 4f2ad1f)
  • Loading branch information
rosstimothy authored Apr 27, 2022
1 parent f78faab commit c58ccee
Show file tree
Hide file tree
Showing 4 changed files with 139 additions and 43 deletions.
15 changes: 12 additions & 3 deletions lib/reversetunnel/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -354,16 +354,25 @@ func (a *Agent) handleGlobalRequests(ctx context.Context, requestCh <-chan *ssh.

switch r.Type {
case versionRequest:
err := r.Reply(true, []byte(teleport.Version))
// reply with the auth server version
pong, err := a.Client.Ping(ctx)
if err != nil {
log.Debugf("Failed to reply to %v request: %v.", r.Type, err)
a.log.WithError(err).Warnf("Failed to ping auth server in response to %v request.", r.Type)
if err := r.Reply(false, []byte("Failed to retrieve auth version")); err != nil {
a.log.Debugf("Failed to reply to %v request: %v.", r.Type, err)
continue
}
}

if err := r.Reply(true, []byte(pong.ServerVersion)); err != nil {
a.log.Debugf("Failed to reply to %v request: %v.", r.Type, err)
continue
}
default:
// This handles keep-alive messages and matches the behaviour of OpenSSH.
err := r.Reply(false, nil)
if err != nil {
log.Debugf("Failed to reply to %v request: %v.", r.Type, err)
a.log.Debugf("Failed to reply to %v request: %v.", r.Type, err)
continue
}
}
Expand Down
70 changes: 30 additions & 40 deletions lib/reversetunnel/srv.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,6 @@ import (
"github.com/gravitational/teleport/lib/sshca"
"github.com/gravitational/teleport/lib/sshutils"
"github.com/gravitational/teleport/lib/utils"

"github.com/coreos/go-semver/semver"
"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
"github.com/prometheus/client_golang/prometheus"
Expand Down Expand Up @@ -638,7 +636,7 @@ func (s *server) handleHeartbeat(conn net.Conn, sconn *ssh.ServerConn, nch ssh.N
// nodes it's a node dialing back.
val, ok := sconn.Permissions.Extensions[extCertRole]
if !ok {
log.Errorf("Failed to accept connection, missing %q extension", extCertRole)
s.log.Errorf("Failed to accept connection, missing %q extension", extCertRole)
s.rejectRequest(nch, ssh.ConnectionFailed, "unknown role")
return
}
Expand All @@ -664,22 +662,22 @@ func (s *server) handleHeartbeat(conn net.Conn, sconn *ssh.ServerConn, nch ssh.N
s.handleNewService(role, conn, sconn, nch, types.WindowsDesktopTunnel)
// Unknown role.
default:
log.Errorf("Unsupported role attempting to connect: %v", val)
s.log.Errorf("Unsupported role attempting to connect: %v", val)
s.rejectRequest(nch, ssh.ConnectionFailed, fmt.Sprintf("unsupported role %v", val))
}
}

func (s *server) handleNewService(role types.SystemRole, conn net.Conn, sconn *ssh.ServerConn, nch ssh.NewChannel, connType types.TunnelType) {
cluster, rconn, err := s.upsertServiceConn(conn, sconn, connType)
if err != nil {
log.Errorf("Failed to upsert %s: %v.", role, err)
s.log.Errorf("Failed to upsert %s: %v.", role, err)
sconn.Close()
return
}

ch, req, err := nch.Accept()
if err != nil {
log.Errorf("Failed to accept on channel: %v.", err)
s.log.Errorf("Failed to accept on channel: %v.", err)
sconn.Close()
return
}
Expand All @@ -691,14 +689,14 @@ func (s *server) handleNewCluster(conn net.Conn, sshConn *ssh.ServerConn, nch ss
// add the incoming site (cluster) to the list of active connections:
site, remoteConn, err := s.upsertRemoteCluster(conn, sshConn)
if err != nil {
log.Error(trace.Wrap(err))
s.log.Error(trace.Wrap(err))
s.rejectRequest(nch, ssh.ConnectionFailed, "failed to accept incoming cluster connection")
return
}
// accept the request and start the heartbeat on it:
ch, req, err := nch.Accept()
if err != nil {
log.Error(trace.Wrap(err))
s.log.Error(trace.Wrap(err))
sshConn.Close()
return
}
Expand Down Expand Up @@ -1061,25 +1059,12 @@ func newRemoteSite(srv *server, domainName string, sconn ssh.Conn) (*remoteSite,
}
remoteSite.remoteClient = clt

// Check if the cluster that is connecting is a pre-v8 cluster. If it is,
// don't assume the newer organization of cluster configuration resources
// (RFD 28) because older proxy servers will reject that causing the cache
// to go into a re-sync loop.
var accessPointFunc auth.NewRemoteProxyCachingAccessPoint
ok, err := isPreV8Cluster(closeContext, sconn)
remoteVersion, err := getRemoteAuthVersion(closeContext, sconn)
if err != nil {
return nil, trace.Wrap(err)
}
if ok {
log.Debugf("Pre-v8 cluster connecting, loading old cache policy.")
accessPointFunc = srv.Config.NewCachingAccessPointOldProxy
} else {
accessPointFunc = srv.newAccessPoint
}

// Configure access to the cached subset of the Auth Server API of the remote
// cluster this remote site provides access to.
accessPoint, err := accessPointFunc(clt, []string{"reverse", domainName})
accessPoint, err := createRemoteAccessPoint(srv, clt, remoteVersion, domainName)
if err != nil {
return nil, trace.Wrap(err)
}
Expand Down Expand Up @@ -1124,31 +1109,35 @@ func newRemoteSite(srv *server, domainName string, sconn ssh.Conn) (*remoteSite,
return remoteSite, nil
}

// isPreV8Cluster checks if the cluster is older than 8.0.0.
func isPreV8Cluster(ctx context.Context, conn ssh.Conn) (bool, error) {
version, err := sendVersionRequest(ctx, conn)
// createRemoteAccessPoint creates a new access point for the remote cluster.
// Checks if the cluster that is connecting is a pre-v8 cluster. If it is,
// don't assume the newer organization of cluster configuration resources
// (RFD 28) because older proxy servers will reject that causing the cache
// to go into a re-sync loop.
func createRemoteAccessPoint(srv *server, clt auth.ClientI, version, domainName string) (auth.RemoteProxyAccessPoint, error) {
ok, err := utils.MinVerWithoutPreRelease(version, utils.VersionBeforeAlpha("8.0.0"))
if err != nil {
return false, trace.Wrap(err)
return nil, trace.Wrap(err)
}

remoteClusterVersion, err := semver.NewVersion(version)
if err != nil {
return false, trace.Wrap(err)
accessPointFunc := srv.Config.NewCachingAccessPoint
if !ok {
srv.log.Debugf("cluster %q running %q is connecting, loading old cache policy.", domainName, version)
accessPointFunc = srv.Config.NewCachingAccessPointOldProxy
}
minClusterVersion, err := semver.NewVersion(utils.VersionBeforeAlpha("8.0.0"))

// Configure access to the cached subset of the Auth Server API of the remote
// cluster this remote site provides access to.
accessPoint, err := accessPointFunc(clt, []string{"reverse", domainName})
if err != nil {
return false, trace.Wrap(err)
}
// Return true if the version is older than 8.0.0
if remoteClusterVersion.LessThan(*minClusterVersion) {
return true, nil
return nil, trace.Wrap(err)
}

return false, nil
return accessPoint, nil
}

// sendVersionRequest sends a request for the version remote Teleport cluster.
func sendVersionRequest(ctx context.Context, sconn ssh.Conn) (string, error) {
// getRemoteAuthVersion sends a version request to the remote agent.
func getRemoteAuthVersion(ctx context.Context, sconn ssh.Conn) (string, error) {
errorCh := make(chan error, 1)
versionCh := make(chan string, 1)

Expand All @@ -1162,6 +1151,7 @@ func sendVersionRequest(ctx context.Context, sconn ssh.Conn) (string, error) {
errorCh <- trace.BadParameter("no response to %v request", versionRequest)
return
}

versionCh <- string(payload)
}()

Expand All @@ -1173,7 +1163,7 @@ func sendVersionRequest(ctx context.Context, sconn ssh.Conn) (string, error) {
case <-time.After(defaults.WaitCopyTimeout):
return "", trace.BadParameter("timeout waiting for version")
case <-ctx.Done():
return "", ctx.Err()
return "", trace.Wrap(ctx.Err())
}
}

Expand Down
68 changes: 68 additions & 0 deletions lib/reversetunnel/srv_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package reversetunnel

import (
"context"
"errors"
"net"
"testing"
"time"
Expand Down Expand Up @@ -160,3 +161,70 @@ type mockAccessPoint struct {
func (ap mockAccessPoint) GetCertAuthority(ctx context.Context, id types.CertAuthID, loadKeys bool, opts ...services.MarshalOption) (types.CertAuthority, error) {
return ap.ca, nil
}

func TestCreateRemoteAccessPoint(t *testing.T) {
cases := []struct {
name string
version string
assertion require.ErrorAssertionFunc
oldRemoteProxy bool
}{
{
name: "invalid version",
assertion: require.Error,
},
{
name: "remote running 9.0.0",
assertion: require.NoError,
version: "9.0.0",
},
{
name: "remote running 8.0.0",
assertion: require.NoError,
version: "8.0.0",
},
{
name: "remote running 7.0.0",
assertion: require.NoError,
version: "7.0.0",
oldRemoteProxy: true,
},
{
name: "remote running 6.0.0",
assertion: require.NoError,
version: "6.0.0",
oldRemoteProxy: true,
},
}

for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
newProxyFn := func(clt auth.ClientI, cacheName []string) (auth.RemoteProxyAccessPoint, error) {
if tt.oldRemoteProxy {
return nil, errors.New("expected to create an old remote proxy")
}

return nil, nil
}

oldProxyFn := func(clt auth.ClientI, cacheName []string) (auth.RemoteProxyAccessPoint, error) {
if !tt.oldRemoteProxy {
return nil, errors.New("expected to create an new remote proxy")
}

return nil, nil
}

clt := &mockAuthClient{}
srv := &server{
log: utils.NewLoggerForTests(),
Config: Config{
NewCachingAccessPoint: newProxyFn,
NewCachingAccessPointOldProxy: oldProxyFn,
},
}
_, err := createRemoteAccessPoint(srv, clt, tt.version, "test")
tt.assertion(t, err)
})
}
}
29 changes: 29 additions & 0 deletions lib/utils/ver.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,32 @@ func CheckVersion(currentVersion, minVersion string) error {
func VersionBeforeAlpha(version string) string {
return version + "-aa"
}

// MinVerWithoutPreRelease compares semver strings, but skips prerelease. This allows to compare
// two versions and ignore dev,alpha,beta, etc. strings.
func MinVerWithoutPreRelease(currentVersion, minVersion string) (bool, error) {
currentSemver, minSemver, err := versionStringToSemver(currentVersion, minVersion)
if err != nil {
return false, trace.Wrap(err)
}

// Erase pre-release string, so only version is compared.
currentSemver.PreRelease = ""
minSemver.PreRelease = ""

return !currentSemver.LessThan(*minSemver), nil
}

func versionStringToSemver(ver1, ver2 string) (*semver.Version, *semver.Version, error) {
v1Semver, err := semver.NewVersion(ver1)
if err != nil {
return nil, nil, trace.Wrap(err, "unsupported version format, need semver format: %q, e.g 1.0.0", v1Semver)
}

v2Semver, err := semver.NewVersion(ver2)
if err != nil {
return nil, nil, trace.Wrap(err, "unsupported version format, need semver format: %q, e.g 1.0.0", v2Semver)
}

return v1Semver, v2Semver, nil
}

0 comments on commit c58ccee

Please sign in to comment.