Skip to content

Commit

Permalink
Allow auth to reject outdated clients (#38182)
Browse files Browse the repository at this point in the history
Setting `TELEPORT_UNSTABLE_REJECT_OLD_CLIENTS=yes` on the Auth process now
enforces that any clients connected are running a supported version.
Clients connecting with an unsupported major version are terminated
by Auth.
  • Loading branch information
rosstimothy authored Feb 14, 2024
1 parent 21f8139 commit bf36d6a
Show file tree
Hide file tree
Showing 4 changed files with 197 additions and 4 deletions.
59 changes: 55 additions & 4 deletions lib/auth/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,11 @@ import (
"math"
"net"
"net/http"
"os"
"slices"
"time"

"github.com/coreos/go-semver/semver"
"github.com/gravitational/oxy/ratelimit"
"github.com/gravitational/trace"
om "github.com/grpc-ecosystem/go-grpc-middleware/providers/openmetrics/v2"
Expand All @@ -42,6 +44,7 @@ import (

"github.com/gravitational/teleport"
apidefaults "github.com/gravitational/teleport/api/defaults"
"github.com/gravitational/teleport/api/metadata"
"github.com/gravitational/teleport/api/types"
apiutils "github.com/gravitational/teleport/api/utils"
"github.com/gravitational/teleport/api/utils/grpc/interceptors"
Expand Down Expand Up @@ -163,14 +166,20 @@ func NewTLSServer(ctx context.Context, cfg TLSServerConfig) (*TLSServer, error)
return nil, trace.Wrap(err)
}

var oldestSupportedVersion *semver.Version
if os.Getenv("TELEPORT_UNSTABLE_REJECT_OLD_CLIENTS") == "yes" {
oldestSupportedVersion = &teleport.MinClientSemVersion
}

// authMiddleware authenticates request assuming TLS client authentication
// adds authentication information to the context
// and passes it to the API server
authMiddleware := &Middleware{
ClusterName: localClusterName.GetClusterName(),
AcceptedUsage: cfg.AcceptedUsage,
Limiter: limiter,
GRPCMetrics: grpcMetrics,
ClusterName: localClusterName.GetClusterName(),
AcceptedUsage: cfg.AcceptedUsage,
Limiter: limiter,
GRPCMetrics: grpcMetrics,
OldestSupportedVersion: oldestSupportedVersion,
}

apiServer, err := NewAPIServer(&cfg.APIConfig)
Expand Down Expand Up @@ -366,6 +375,10 @@ type Middleware struct {
// This is used by the proxy to forward the identity of the user who
// connected to the proxy to the next hop.
EnableCredentialsForwarding bool
// OldestSupportedVersion optionally allows the middleware to reject any connections
// originated from a client that is using an unsupported version. If not set, then no
// rejection occurs.
OldestSupportedVersion *semver.Version
}

// Wrap sets next handler in chain
Expand Down Expand Up @@ -404,6 +417,40 @@ func getCustomRate(endpoint string) *ratelimit.RateSet {
return nil
}

// ValidateClientVersion inspects the client version for the connection and terminates
// the [IdentityInfo.Conn] if the client is unsupported. Requires the [Middleware.OldestSupportedVersion]
// to be configured before any connection rejection occurs.
func (a *Middleware) ValidateClientVersion(ctx context.Context, info IdentityInfo) error {
if a.OldestSupportedVersion == nil {
return nil
}

clientVersionString, versionExists := metadata.ClientVersionFromContext(ctx)
if !versionExists {
return nil
}

logger := log.WithFields(logrus.Fields{"identity": info.IdentityGetter.GetIdentity().Username, "version": clientVersionString})
clientVersion, err := semver.NewVersion(clientVersionString)
if err != nil {
logger.WithError(err).Warn("Failed to determine client version")
if err := info.Conn.Close(); err != nil {
logger.WithError(err).Warn("Failed to close client connection")
}
return trace.AccessDenied("client version is unsupported")
}

if clientVersion.LessThan(*a.OldestSupportedVersion) {
logger.Info("Terminating connection of client using unsupported version")
if err := info.Conn.Close(); err != nil {
logger.WithError(err).Warn("Failed to close client connection")
}
return trace.AccessDenied("client version is unsupported")
}

return nil
}

// withAuthenticatedUser returns a new context with the ContextUser field set to
// the caller's user identity as authenticated by their client TLS certificate.
func (a *Middleware) withAuthenticatedUser(ctx context.Context) (context.Context, error) {
Expand All @@ -423,6 +470,10 @@ func (a *Middleware) withAuthenticatedUser(ctx context.Context) (context.Context
case IdentityInfo:
connState = &info.TLSInfo.State
identityGetter = info.IdentityGetter

if err := a.ValidateClientVersion(ctx, info); err != nil {
return nil, trace.Wrap(err)
}
// credentials.TLSInfo is provided if the grpc server is configured with
// credentials.NewTLS.
case credentials.TLSInfo:
Expand Down
78 changes: 78 additions & 0 deletions lib/auth/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,14 @@ import (
"testing"
"time"

"github.com/coreos/go-semver/semver"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/gravitational/trace"
"github.com/stretchr/testify/require"
"google.golang.org/grpc/metadata"

"github.com/gravitational/teleport"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/authz"
"github.com/gravitational/teleport/lib/services"
Expand Down Expand Up @@ -655,3 +659,77 @@ func (h *fakeHTTPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
require.Empty(h.t, r.Header.Get(TeleportImpersonateUserHeader))
require.Empty(h.t, r.Header.Get(TeleportImpersonateIPHeader))
}

type fakeConn struct {
net.Conn
}

func (f fakeConn) Close() error {
return nil
}

func TestValidateClientVersion(t *testing.T) {
cases := []struct {
name string
middleware Middleware
clientVersion string
errAssertion func(t *testing.T, err error)
}{
{
name: "rejection disabled",
errAssertion: func(t *testing.T, err error) {
require.NoError(t, err)
},
},
{
name: "rejection enabled and client version not specified",
middleware: Middleware{OldestSupportedVersion: &teleport.MinClientSemVersion},
errAssertion: func(t *testing.T, err error) {
require.NoError(t, err)
},
},
{
name: "client rejected",
middleware: Middleware{OldestSupportedVersion: &teleport.MinClientSemVersion},
clientVersion: semver.Version{Major: teleport.SemVersion.Major - 2}.String(),
errAssertion: func(t *testing.T, err error) {
require.True(t, trace.IsAccessDenied(err), "got %T, expected access denied error", err)
},
},
{
name: "valid client v-1",
middleware: Middleware{OldestSupportedVersion: &teleport.MinClientSemVersion},
clientVersion: semver.Version{Major: teleport.SemVersion.Major - 1}.String(),
errAssertion: func(t *testing.T, err error) {
require.NoError(t, err)
},
},
{
name: "valid client v-0",
middleware: Middleware{OldestSupportedVersion: &teleport.MinClientSemVersion},
clientVersion: semver.Version{Major: teleport.SemVersion.Major}.String(),
errAssertion: func(t *testing.T, err error) {
require.NoError(t, err)
},
},
{
name: "invalid client version",
middleware: Middleware{OldestSupportedVersion: &teleport.MinClientSemVersion},
clientVersion: "abc123",
errAssertion: func(t *testing.T, err error) {
require.True(t, trace.IsAccessDenied(err), "got %T, expected access denied error", err)
},
},
}

for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
ctx := context.Background()
if tt.clientVersion != "" {
ctx = metadata.NewIncomingContext(ctx, metadata.New(map[string]string{"version": tt.clientVersion}))
}

tt.errAssertion(t, tt.middleware.ValidateClientVersion(ctx, IdentityInfo{Conn: fakeConn{}, IdentityGetter: TestBuiltin(types.RoleNode).I}))
})
}
}
61 changes: 61 additions & 0 deletions lib/auth/tls_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ import (
"github.com/gravitational/teleport/api/client/proto"
"github.com/gravitational/teleport/api/constants"
apidefaults "github.com/gravitational/teleport/api/defaults"
"github.com/gravitational/teleport/api/metadata"
"github.com/gravitational/teleport/api/types"
eventtypes "github.com/gravitational/teleport/api/types/events"
"github.com/gravitational/teleport/api/types/wrappers"
Expand All @@ -66,6 +67,66 @@ import (
"github.com/gravitational/teleport/lib/utils"
)

func TestRejectedClients(t *testing.T) {
t.Setenv("TELEPORT_UNSTABLE_REJECT_OLD_CLIENTS", "yes")

server, err := NewTestAuthServer(TestAuthServerConfig{
Dir: t.TempDir(),
ClusterName: "cluster",
Clock: clockwork.NewFakeClock(),
})
require.NoError(t, err)

user, _, err := CreateUserAndRole(server.AuthServer, "user", []string{"role"}, nil)
require.NoError(t, err)

tlsServer, err := server.NewTestTLSServer()
require.NoError(t, err)
defer tlsServer.Close()

tlsConfig, err := tlsServer.ClientTLSConfig(TestUser(user.GetName()))
require.NoError(t, err)

clt, err := NewClient(client.Config{
DialInBackground: true,
Addrs: []string{tlsServer.Addr().String()},
Credentials: []client.Credentials{
client.LoadTLS(tlsConfig),
},
CircuitBreakerConfig: breaker.NoopBreakerConfig(),
})
require.NoError(t, err)
defer clt.Close()

t.Run("reject old version", func(t *testing.T) {
version := teleport.MinClientSemVersion
version.Major--
ctx := context.WithValue(context.Background(), metadata.DisableInterceptors{}, struct{}{})
ctx = metadata.AddMetadataToContext(ctx, map[string]string{
metadata.VersionKey: version.String(),
})
resp, err := clt.Ping(ctx)
require.True(t, trace.IsConnectionProblem(err))
require.Equal(t, proto.PingResponse{}, resp)
})

t.Run("allow valid versions", func(t *testing.T) {
version := teleport.MinClientSemVersion
version.Major--
for i := 0; i < 5; i++ {
version.Major++

ctx := context.WithValue(context.Background(), metadata.DisableInterceptors{}, struct{}{})
ctx = metadata.AddMetadataToContext(ctx, map[string]string{
metadata.VersionKey: version.String(),
})
resp, err := clt.Ping(ctx)
require.NoError(t, err)
require.NotNil(t, resp)
}
})
}

// TestRemoteBuiltinRole tests remote builtin role
// that gets mapped to remote proxy readonly role
func TestRemoteBuiltinRole(t *testing.T) {
Expand Down
3 changes: 3 additions & 0 deletions lib/auth/transport_credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,8 @@ type IdentityInfo struct {
// [TransportCredentialsConfig.Authorizer] provided to [NewTransportCredentials]
// was nil.
AuthContext *authz.Context
// Conn is the underlying [net.Conn] of the gRPC connection.
Conn net.Conn
}

// ServerHandshake does the authentication handshake for servers. It returns
Expand Down Expand Up @@ -179,6 +181,7 @@ func (c *TransportCredentials) ServerHandshake(rawConn net.Conn) (_ net.Conn, _
TLSInfo: tlsInfo,
IdentityGetter: identityGetter,
AuthContext: authCtx,
Conn: conn,
}, nil
}

Expand Down

0 comments on commit bf36d6a

Please sign in to comment.