Skip to content

Commit

Permalink
Remove calls to depreated pool.Subjects() method
Browse files Browse the repository at this point in the history
This deprecation was kind of a pain, because x509.CertPool becomes
a black box - there is no public API to determine how many certs
have been added to the pool. To account for this, some of our method
signatures needed to be updated to report the number of certs that
were added.
  • Loading branch information
zmb3 committed Apr 14, 2022
1 parent f6bb323 commit 663e3d0
Show file tree
Hide file tree
Showing 14 changed files with 52 additions and 50 deletions.
1 change: 0 additions & 1 deletion integration/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1563,7 +1563,6 @@ func twoClustersTunnel(t *testing.T, suite *integrationTestSuite, now time.Time,
require.NoError(t, werr)
ok := roots.AppendCertsFromPEM(buffer)
require.True(t, ok)
require.Len(t, roots.Subjects(), 2)

// wait for active tunnel connections to be established
waitForActiveTunnelConnections(t, b.Tunnel, a.Secrets.SiteName, 1)
Expand Down
12 changes: 3 additions & 9 deletions lib/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -3760,7 +3760,7 @@ func WithClusterCAs(tlsConfig *tls.Config, ap AccessCache, currentClusterName st
}
}
}
pool, err := DefaultClientCertPool(ap, clusterName)
pool, totalSubjectsLen, err := DefaultClientCertPool(ap, clusterName)
if err != nil {
log.WithError(err).Errorf("Failed to retrieve client pool for %q.", clusterName)
// this falls back to the default config
Expand All @@ -3779,16 +3779,10 @@ func WithClusterCAs(tlsConfig *tls.Config, ap AccessCache, currentClusterName st
// If the number of CAs turns out too large for the handshake, drop all but
// the current cluster CA. In the unlikely case where it's wrong, the
// client will be rejected.
var totalSubjectsLen int64
for _, s := range pool.Subjects() {
// Each subject in the list gets a separate 2-byte length prefix.
totalSubjectsLen += 2
totalSubjectsLen += int64(len(s))
}
if totalSubjectsLen >= int64(math.MaxUint16) {
log.Debugf("Number of CAs in client cert pool is too large (%d) and cannot be encoded in a TLS handshake; this is due to a large number of trusted clusters; will use only the CA of the current cluster to validate.", len(pool.Subjects()))
log.Debugf("Number of CAs in client cert pool is too large and cannot be encoded in a TLS handshake; this is due to a large number of trusted clusters; will use only the CA of the current cluster to validate.")

pool, err = DefaultClientCertPool(ap, currentClusterName)
pool, _, err = DefaultClientCertPool(ap, currentClusterName)
if err != nil {
log.WithError(err).Errorf("Failed to retrieve client pool for %q.", currentClusterName)
// this falls back to the default config
Expand Down
32 changes: 17 additions & 15 deletions lib/auth/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ func (t *TLSServer) GetConfigForClient(info *tls.ClientHelloInfo) (*tls.Config,
// certificate authorities.
// TODO(klizhentas) drop connections of the TLS cert authorities
// that are not trusted
pool, err := DefaultClientCertPool(t.cfg.AccessPoint, clusterName)
pool, totalSubjectsLen, err := DefaultClientCertPool(t.cfg.AccessPoint, clusterName)
if err != nil {
var ourClusterName string
if clusterName, err := t.cfg.AccessPoint.GetClusterName(); err == nil {
Expand All @@ -298,14 +298,8 @@ func (t *TLSServer) GetConfigForClient(info *tls.ClientHelloInfo) (*tls.Config,
// This may happen with a very large (>500) number of trusted clusters, if
// the client doesn't send the correct ServerName in its ClientHelloInfo
// (see the switch at the top of this func).
var totalSubjectsLen int64
for _, s := range pool.Subjects() {
// Each subject in the list gets a separate 2-byte length prefix.
totalSubjectsLen += 2
totalSubjectsLen += int64(len(s))
}
if totalSubjectsLen >= int64(math.MaxUint16) {
return nil, trace.BadParameter("number of CAs in client cert pool is too large (%d) and cannot be encoded in a TLS handshake; this is due to a large number of trusted clusters; try updating tsh to the latest version; if that doesn't help, remove some trusted clusters", len(pool.Subjects()))
return nil, trace.BadParameter("number of CAs in client cert pool is too large and cannot be encoded in a TLS handshake; this is due to a large number of trusted clusters; try updating tsh to the latest version; if that doesn't help, remove some trusted clusters")
}

tlsCopy := t.cfg.TLS.Clone()
Expand Down Expand Up @@ -617,9 +611,12 @@ func (a *Middleware) WrapContextWithUser(ctx context.Context, conn *tls.Conn) (c
}

// ClientCertPool returns trusted x509 certificate authority pool with CAs provided as caTypes.
func ClientCertPool(client AccessCache, clusterName string, caTypes ...types.CertAuthType) (*x509.CertPool, error) {
// In addition, it returns the total length of all subjects added to the cert pool, allowing
// the caller to validate that the pool doesn't exceed the maximum 2-byte length prefix before
// using it.
func ClientCertPool(client AccessCache, clusterName string, caTypes ...types.CertAuthType) (*x509.CertPool, int64, error) {
if len(caTypes) == 0 {
return nil, trace.BadParameter("at least one CA type is required")
return nil, 0, trace.BadParameter("at least one CA type is required")
}

ctx := context.TODO()
Expand All @@ -629,7 +626,7 @@ func ClientCertPool(client AccessCache, clusterName string, caTypes ...types.Cer
for _, caType := range caTypes {
cas, err := client.GetCertAuthorities(ctx, caType, false)
if err != nil {
return nil, trace.Wrap(err)
return nil, 0, trace.Wrap(err)
}
authorities = append(authorities, cas...)
}
Expand All @@ -640,27 +637,32 @@ func ClientCertPool(client AccessCache, clusterName string, caTypes ...types.Cer
types.CertAuthID{Type: caType, DomainName: clusterName},
false)
if err != nil {
return nil, trace.Wrap(err)
return nil, 0, trace.Wrap(err)
}

authorities = append(authorities, ca)
}
}

var totalSubjectsLen int64
for _, auth := range authorities {
for _, keyPair := range auth.GetTrustedTLSKeyPairs() {
cert, err := tlsca.ParseCertificatePEM(keyPair.Cert)
if err != nil {
return nil, trace.Wrap(err)
return nil, 0, trace.Wrap(err)
}
log.Debugf("ClientCertPool -> %v", CertInfo(cert))
pool.AddCert(cert)

// Each subject in the list gets a separate 2-byte length prefix.
totalSubjectsLen += 2
totalSubjectsLen += int64(len(cert.RawSubject))
}
}
return pool, nil
return pool, totalSubjectsLen, nil
}

// DefaultClientCertPool returns default trusted x509 certificate authority pool.
func DefaultClientCertPool(client AccessCache, clusterName string) (*x509.CertPool, error) {
func DefaultClientCertPool(client AccessCache, clusterName string) (*x509.CertPool, int64, error) {
return ClientCertPool(client, clusterName, types.HostCA, types.UserCA)
}
3 changes: 2 additions & 1 deletion lib/kube/proxy/forwarder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,6 @@ func (s ForwarderSuite) TestRequestCertificate(c *check.C) {
c.Assert(err, check.IsNil)
// All fields except b.key are predictable.
c.Assert(b.Certificates[0].Certificate[0], check.DeepEquals, cl.lastCert.Raw)
c.Assert(len(b.RootCAs.Subjects()), check.Equals, 1)

// Check the KubeCSR fields.
c.Assert(cl.gotCSR.Username, check.DeepEquals, ctx.User.GetName())
Expand Down Expand Up @@ -708,6 +707,7 @@ func TestNewClusterSessionRemote(t *testing.T) {

// Make sure newClusterSession obtained a new client cert instead of using f.creds.
require.Equal(t, f.cfg.AuthClient.(*mockCSRClient).lastCert.Raw, sess.tlsConfig.Certificates[0].Certificate[0])
//lint:ignore SA1019 there's no non-deprecated public API for testing the contents of the RootCAs pool
require.Equal(t, [][]byte{f.cfg.AuthClient.(*mockCSRClient).ca.Cert.RawSubject}, sess.tlsConfig.RootCAs.Subjects())
require.Equal(t, 1, f.clientCredentials.Len())
}
Expand Down Expand Up @@ -755,6 +755,7 @@ func TestNewClusterSessionDirect(t *testing.T) {

// Make sure newClusterSession obtained a new client cert instead of using f.creds.
require.Equal(t, f.cfg.AuthClient.(*mockCSRClient).lastCert.Raw, sess.tlsConfig.Certificates[0].Certificate[0])
//lint:ignore SA1019 there's no non-deprecated public API for testing the contents of the RootCAs pool
require.Equal(t, [][]byte{f.cfg.AuthClient.(*mockCSRClient).ca.Cert.RawSubject}, sess.tlsConfig.RootCAs.Subjects())
require.Equal(t, 1, f.clientCredentials.Len())
}
Expand Down
2 changes: 1 addition & 1 deletion lib/service/desktop.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ func (process *TeleportProcess) initWindowsDesktopServiceRegistered(log *logrus.
log.Debugf("Ignoring unsupported cluster name %q.", info.ServerName)
}
}
pool, err := auth.DefaultClientCertPool(accessPoint, clusterName)
pool, _, err := auth.DefaultClientCertPool(accessPoint, clusterName)
if err != nil {
return nil, trace.Wrap(err)
}
Expand Down
6 changes: 4 additions & 2 deletions lib/service/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -2148,6 +2148,7 @@ func (process *TeleportProcess) initMetricsService() error {
return trace.BadParameter("no keypairs were provided for the metrics service with mtls enabled")
}

addedCerts := false
pool := x509.NewCertPool()
for _, caCertPath := range process.Config.Metrics.CACerts {
caCert, err := os.ReadFile(caCertPath)
Expand All @@ -2158,9 +2159,10 @@ func (process *TeleportProcess) initMetricsService() error {
if !pool.AppendCertsFromPEM(caCert) {
return trace.BadParameter("failed to parse prometheus CA certificate: %+v", caCertPath)
}
addedCerts = true
}

if len(pool.Subjects()) == 0 {
if !addedCerts {
return trace.BadParameter("no prometheus ca certs were provided for the metrics service with mtls enabled")
}

Expand Down Expand Up @@ -3358,7 +3360,7 @@ func (process *TeleportProcess) setupProxyTLSConfig(conn *Connector, tsrv revers
// order to be able to validate certificates provided by app
// access CLI clients.
var err error
tlsClone.ClientCAs, err = auth.DefaultClientCertPool(accessPoint, clusterName)
tlsClone.ClientCAs, _, err = auth.DefaultClientCertPool(accessPoint, clusterName)
if err != nil {
return nil, trace.Wrap(err)
}
Expand Down
13 changes: 8 additions & 5 deletions lib/services/authority.go
Original file line number Diff line number Diff line change
Expand Up @@ -314,10 +314,12 @@ func (c *UserCertParams) CheckAndSetDefaults() error {
return nil
}

// CertPoolFromCertAuthorities returns certificate pools from TLS certificates
// set up in the certificate authorities list
func CertPoolFromCertAuthorities(cas []types.CertAuthority) (*x509.CertPool, error) {
// CertPoolFromCertAuthorities returns a certificate pool from the TLS certificates
// set up in the certificate authorities list, as well as the number of certificates
// that were added to the pool.
func CertPoolFromCertAuthorities(cas []types.CertAuthority) (*x509.CertPool, int, error) {
certPool := x509.NewCertPool()
count := 0
for _, ca := range cas {
keyPairs := ca.GetTrustedTLSKeyPairs()
if len(keyPairs) == 0 {
Expand All @@ -326,12 +328,13 @@ func CertPoolFromCertAuthorities(cas []types.CertAuthority) (*x509.CertPool, err
for _, keyPair := range keyPairs {
cert, err := tlsca.ParseCertificatePEM(keyPair.Cert)
if err != nil {
return nil, trace.Wrap(err)
return nil, 0, trace.Wrap(err)
}
certPool.AddCert(cert)
count++
}
}
return certPool, nil
return certPool, count, nil
}

// CertPool returns certificate pools from TLS certificates
Expand Down
20 changes: 12 additions & 8 deletions lib/services/authority_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,25 +80,29 @@ func TestCertPoolFromCertAuthorities(t *testing.T) {
require.NoError(t, err)

t.Run("ca1 with 1 cert", func(t *testing.T) {
pool, err := CertPoolFromCertAuthorities([]types.CertAuthority{ca1})
pool, count, err := CertPoolFromCertAuthorities([]types.CertAuthority{ca1})
require.NotNil(t, pool)
require.NoError(t, err)
require.Len(t, pool.Subjects(), 1)
require.Equal(t, 1, count)
})
t.Run("ca2 with 2 certs", func(t *testing.T) {
pool, err := CertPoolFromCertAuthorities([]types.CertAuthority{ca2})
pool, count, err := CertPoolFromCertAuthorities([]types.CertAuthority{ca2})
require.NotNil(t, pool)
require.NoError(t, err)
require.Len(t, pool.Subjects(), 2)
require.Equal(t, 2, count)
})
t.Run("ca3 with 1 cert", func(t *testing.T) {
pool, err := CertPoolFromCertAuthorities([]types.CertAuthority{ca3})
pool, count, err := CertPoolFromCertAuthorities([]types.CertAuthority{ca3})
require.NotNil(t, pool)
require.NoError(t, err)
require.Len(t, pool.Subjects(), 1)
require.Equal(t, 1, count)
})

t.Run("ca1 + ca2 + ca3 with 4 certs total", func(t *testing.T) {
pool, err := CertPoolFromCertAuthorities([]types.CertAuthority{ca1, ca2, ca3})
pool, count, err := CertPoolFromCertAuthorities([]types.CertAuthority{ca1, ca2, ca3})
require.NotNil(t, pool)
require.NoError(t, err)
require.Len(t, pool.Subjects(), 4)
require.Equal(t, 4, count)
})
}

Expand Down
2 changes: 1 addition & 1 deletion lib/srv/app/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -775,7 +775,7 @@ func (s *Server) getConfigForClient(info *tls.ClientHelloInfo) (*tls.Config, err

// Fetch list of CAs that could have signed this certificate. If clusterName
// is empty, all CAs that this cluster knows about are returned.
pool, err := auth.DefaultClientCertPool(s.c.AccessPoint, clusterName)
pool, _, err := auth.DefaultClientCertPool(s.c.AccessPoint, clusterName)
if err != nil {
// If this request fails, return nil and fallback to the default ClientCAs.
s.log.Debugf("Failed to retrieve client pool: %v.", trace.DebugReport(err))
Expand Down
2 changes: 1 addition & 1 deletion lib/srv/db/access_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1636,7 +1636,7 @@ func (c *testContext) makeTLSConfig(t *testing.T) *tls.Config {
conf := utils.TLSConfig(nil)
conf.Certificates = append(conf.Certificates, cert)
conf.ClientAuth = tls.VerifyClientCertIfGiven
conf.ClientCAs, err = auth.DefaultClientCertPool(c.authServer, c.clusterName)
conf.ClientCAs, _, err = auth.DefaultClientCertPool(c.authServer, c.clusterName)
require.NoError(t, err)
return conf
}
Expand Down
2 changes: 1 addition & 1 deletion lib/srv/db/proxyserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -688,7 +688,7 @@ func getConfigForClient(conf *tls.Config, ap auth.ReadDatabaseAccessPoint, log l
log.Debugf("Ignoring unsupported cluster name %q.", info.ServerName)
}
}
pool, err := auth.ClientCertPool(ap, clusterName, caTypes...)
pool, _, err := auth.ClientCertPool(ap, clusterName, caTypes...)
if err != nil {
log.WithError(err).Error("Failed to retrieve client CA pool.")
return nil, nil // Fall back to the default config.
Expand Down
2 changes: 1 addition & 1 deletion lib/utils/tls.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ type TLSCredentials struct {
Cert []byte
}

// macMaxTLSCertValidityPeriod is the maximum validitiy period
// macMaxTLSCertValidityPeriod is the maximum validity period
// for a TLS certificate enforced by macOS.
// As of Go 1.18, certificates are validated via the system
// verifier and not in Go.
Expand Down
2 changes: 1 addition & 1 deletion lib/web/sessions.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ func (c *SessionContext) ClientTLSConfig(clusterName ...string) (*tls.Config, er
if err != nil {
return nil, trace.Wrap(err)
}
certPool, err = services.CertPoolFromCertAuthorities(certAuthorities)
certPool, _, err = services.CertPoolFromCertAuthorities(certAuthorities)
if err != nil {
return nil, trace.Wrap(err)
}
Expand Down
3 changes: 0 additions & 3 deletions tool/tsh/tsh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -704,9 +704,6 @@ func TestIdentityRead(t *testing.T) {
conf, err := k.TeleportClientTLSConfig(nil, []string{"one"})
require.NoError(t, err)
require.NotNil(t, conf)

// ensure that at least root CA was successfully loaded
require.Greater(t, len(conf.RootCAs.Subjects()), 0)
}

func TestFormatConnectCommand(t *testing.T) {
Expand Down

0 comments on commit 663e3d0

Please sign in to comment.