Skip to content

Commit

Permalink
Improve CertAuthorityWatcher (#10403) (#12723)
Browse files Browse the repository at this point in the history
* Improve CertAuthorityWatcher

CertAuthorityWatcher and its usage are refactored to allow for
all the following:
 - eliminate retransmission of the same CAs
 - reduce memory usage by having one local watcher per proxy
 - adds the ability to filter only the CAs that are desired
 - reduce the time required to send the first CAs

watchCertAuthorities now compares all CAs it receives from the
watcher with the previous CA of the same type and only sends to
the remote site if they are not identical. This is to reduce
unnecessary network traffic which can be problematic for a
root cluster with a larger number of leafs.

The CertAuthorityWatcher is refactored to leverage a fanout
to emit events to any number of watchers, each subscription
can be for a subset of the configured CA types. The proxy
now has only one CertAuthorityWatcher that is passed around
similarly to the LockWatcher. This reduces the memory usage
for proxies, which prior to this has one local CAWatcher per
remote site.

updateCertAuthorities no longer waits on the utils.Retry it
is provided with before starting to watch CAs. By doing this
the proxy no longer has to wait ~8 minutes before it even
starts to watch CAs.

(cherry picked from commit 1ac0957)
  • Loading branch information
rosstimothy authored May 24, 2022
1 parent bf30d9c commit 73df3a3
Show file tree
Hide file tree
Showing 8 changed files with 373 additions and 246 deletions.
62 changes: 22 additions & 40 deletions integration/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4053,44 +4053,29 @@ func testRotateTrustedClusters(t *testing.T, suite *integrationTestSuite) {
require.NoError(t, err)

// waitForPhase waits until aux cluster detects the rotation
waitForPhase := func(phase string) error {
ctx, cancel := context.WithTimeout(context.Background(), tconf.PollingPeriod*10)
defer cancel()

watcher, err := services.NewCertAuthorityWatcher(ctx, services.CertAuthorityWatcherConfig{
ResourceWatcherConfig: services.ResourceWatcherConfig{
Component: teleport.ComponentProxy,
Clock: tconf.Clock,
Client: aux.GetSiteAPI(clusterAux),
},
WatchHostCA: true,
})
if err != nil {
return err
}
defer watcher.Close()
waitForPhase := func(phase string) {
require.Eventually(t, func() bool {
ca, err := aux.Process.GetAuthServer().GetCertAuthority(
ctx,
types.CertAuthID{
Type: types.HostCA,
DomainName: clusterMain,
},
false,
)
if err != nil {
return false
}

var lastPhase string
for i := 0; i < 10; i++ {
select {
case <-ctx.Done():
return trace.CompareFailed("failed to converge to phase %q, last phase %q", phase, lastPhase)
case cas := <-watcher.CertAuthorityC:
for _, ca := range cas {
if ca.GetClusterName() == clusterMain &&
ca.GetType() == types.HostCA &&
ca.GetRotation().Phase == phase {
return nil
}
lastPhase = ca.GetRotation().Phase
}
if ca.GetRotation().Phase == phase {
return true
}
}
return trace.CompareFailed("failed to converge to phase %q, last phase %q", phase, lastPhase)

return false
}, 30*time.Second, 250*time.Millisecond, "failed to converge to phase %q", phase)
}

err = waitForPhase(types.RotationPhaseInit)
require.NoError(t, err)
waitForPhase(types.RotationPhaseInit)

// update clients
err = svc.GetAuthServer().RotateCertAuthority(ctx, auth.RotateRequest{
Expand All @@ -4103,8 +4088,7 @@ func testRotateTrustedClusters(t *testing.T, suite *integrationTestSuite) {
svc, err = suite.waitForReload(serviceC, svc)
require.NoError(t, err)

err = waitForPhase(types.RotationPhaseUpdateClients)
require.NoError(t, err)
waitForPhase(types.RotationPhaseUpdateClients)

// old client should work as is
err = runAndMatch(clt, 8, []string{"echo", "hello world"}, ".*hello world.*")
Expand All @@ -4123,8 +4107,7 @@ func testRotateTrustedClusters(t *testing.T, suite *integrationTestSuite) {
svc, err = suite.waitForReload(serviceC, svc)
require.NoError(t, err)

err = waitForPhase(types.RotationPhaseUpdateServers)
require.NoError(t, err)
waitForPhase(types.RotationPhaseUpdateServers)

// new credentials will work from this phase to others
newCreds, err := GenerateUserCreds(UserCredsRequest{Process: svc, Username: suite.me.Username})
Expand Down Expand Up @@ -4152,8 +4135,7 @@ func testRotateTrustedClusters(t *testing.T, suite *integrationTestSuite) {
require.NoError(t, err)
t.Log("Service reload completed, waiting for phase.")

err = waitForPhase(types.RotationPhaseStandby)
require.NoError(t, err)
waitForPhase(types.RotationPhaseStandby)
t.Log("Phase completed.")

// new client still works
Expand Down
107 changes: 62 additions & 45 deletions lib/reversetunnel/remotesite.go
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ func (s *remoteSite) nextConn() (*remoteConn, error) {
for i := 0; i < len(s.connections); i++ {
s.lastUsed = (s.lastUsed + 1) % len(s.connections)
remoteConn := s.connections[s.lastUsed]
// connection could have been initated, but agent
// connection could have been initiated, but agent
// on the other side is not ready yet.
// Proxy assumes that connection is ready to serve when
// it has received a first heartbeat, otherwise
Expand Down Expand Up @@ -427,20 +427,12 @@ func (s *remoteSite) compareAndSwapCertAuthority(ca types.CertAuthority) error {
return trace.CompareFailed("remote certificate authority rotation has been updated")
}

func (s *remoteSite) updateCertAuthorities(retry utils.Retry) {
s.Debugf("Watching for cert authority changes.")
func (s *remoteSite) updateCertAuthorities(retry utils.Retry, remoteWatcher *services.CertAuthorityWatcher, remoteVersion string) {
defer remoteWatcher.Close()

cas := make(map[types.CertAuthType]types.CertAuthority)
for {
startedWaiting := s.clock.Now()
select {
case t := <-retry.After():
s.Debugf("Initiating new cert authority watch after waiting %v.", t.Sub(startedWaiting))
retry.Inc()
case <-s.ctx.Done():
return
}

err := s.watchCertAuthorities()
err := s.watchCertAuthorities(remoteWatcher, remoteVersion, cas)
if err != nil {
switch {
case trace.IsNotFound(err):
Expand All @@ -456,67 +448,92 @@ func (s *remoteSite) updateCertAuthorities(retry utils.Retry) {
}
}

startedWaiting := s.clock.Now()
select {
case t := <-retry.After():
s.Debugf("Initiating new cert authority watch after waiting %v.", t.Sub(startedWaiting))
retry.Inc()
case <-s.ctx.Done():
return
}
}
}

func (s *remoteSite) watchCertAuthorities() error {
localWatcher, err := services.NewCertAuthorityWatcher(s.ctx, services.CertAuthorityWatcherConfig{
ResourceWatcherConfig: services.ResourceWatcherConfig{
Component: teleport.ComponentProxy,
Log: s,
Clock: s.clock,
Client: s.localAccessPoint,
func (s *remoteSite) watchCertAuthorities(remoteWatcher *services.CertAuthorityWatcher, remoteVersion string, cas map[types.CertAuthType]types.CertAuthority) error {
localWatch, err := s.srv.CertAuthorityWatcher.Subscribe(
s.ctx,
services.CertAuthorityTarget{
Type: types.HostCA,
ClusterName: s.srv.ClusterName,
},
WatchUserCA: true,
WatchHostCA: true,
})
services.CertAuthorityTarget{
Type: types.UserCA,
ClusterName: s.srv.ClusterName,
})
if err != nil {
return trace.Wrap(err)
}
defer localWatcher.Close()
defer func() {
if err := localWatch.Close(); err != nil {
s.WithError(err).Warn("Failed to close local ca watcher subscription.")
}
}()

remoteWatcher, err := services.NewCertAuthorityWatcher(s.ctx, services.CertAuthorityWatcherConfig{
ResourceWatcherConfig: services.ResourceWatcherConfig{
Component: teleport.ComponentProxy,
Log: s,
Clock: s.clock,
Client: s.remoteAccessPoint,
remoteWatch, err := remoteWatcher.Subscribe(
s.ctx,
services.CertAuthorityTarget{
ClusterName: s.domainName,
Type: types.HostCA,
},
WatchHostCA: true,
})
)
if err != nil {
return trace.Wrap(err)
}
defer remoteWatcher.Close()
defer func() {
if err := remoteWatch.Close(); err != nil {
s.WithError(err).Warn("Failed to close remote ca watcher subscription.")
}
}()

s.Debugf("Watching for cert authority changes.")
for {
select {
case <-s.ctx.Done():
s.WithError(s.ctx.Err()).Debug("Context is closing.")
return trace.Wrap(s.ctx.Err())
case <-localWatcher.Done():
case <-localWatch.Done():
s.Warn("Local CertAuthority watcher subscription has closed")
return fmt.Errorf("local ca watcher for cluster %s has closed", s.srv.ClusterName)
case <-remoteWatcher.Done():
case <-remoteWatch.Done():
s.Warn("Remote CertAuthority watcher subscription has closed")
return fmt.Errorf("remote ca watcher for cluster %s has closed", s.domainName)
case cas := <-localWatcher.CertAuthorityC:
for _, localCA := range cas {
if localCA.GetClusterName() != s.srv.ClusterName ||
(localCA.GetType() != types.HostCA &&
localCA.GetType() != types.UserCA) {
case evt := <-localWatch.Events():
switch evt.Type {
case types.OpPut:
localCA, ok := evt.Resource.(types.CertAuthority)
if !ok {
continue
}

ca, ok := cas[localCA.GetType()]
if ok && services.CertAuthoritiesEquivalent(ca, localCA) {
continue
}

// clone to prevent a race with watcher filtering
localCA = localCA.Clone()
if err := s.remoteClient.RotateExternalCertAuthority(s.ctx, localCA); err != nil {
s.WithError(err).Warn("Failed to rotate external ca")
log.WithError(err).Warn("Failed to rotate external ca")
return trace.Wrap(err)
}

cas[localCA.GetType()] = localCA
}
case cas := <-remoteWatcher.CertAuthorityC:
for _, remoteCA := range cas {
if remoteCA.GetType() != types.HostCA ||
remoteCA.GetClusterName() != s.domainName {
case evt := <-remoteWatch.Events():
switch evt.Type {
case types.OpPut:
remoteCA, ok := evt.Resource.(types.CertAuthority)
if !ok {
continue
}

Expand Down
35 changes: 27 additions & 8 deletions lib/reversetunnel/srv.go
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,9 @@ type Config struct {

// NodeWatcher is a node watcher.
NodeWatcher *services.NodeWatcher

// CertAuthorityWatcher is a cert authority watcher.
CertAuthorityWatcher *services.CertAuthorityWatcher
}

// CheckAndSetDefaults checks parameters and sets default values
Expand Down Expand Up @@ -259,6 +262,9 @@ func (cfg *Config) CheckAndSetDefaults() error {
if cfg.NodeWatcher == nil {
return trace.BadParameter("missing parameter NodeWatcher")
}
if cfg.CertAuthorityWatcher == nil {
return trace.BadParameter("missing parameter CertAuthorityWatcher")
}
return nil
}

Expand Down Expand Up @@ -1039,6 +1045,11 @@ func newRemoteSite(srv *server, domainName string, sconn ssh.Conn) (*remoteSite,
connInfo.SetExpiry(srv.Clock.Now().Add(srv.offlineThreshold))

closeContext, cancel := context.WithCancel(srv.ctx)
defer func() {
if err != nil {
cancel()
}
}()
remoteSite := &remoteSite{
srv: srv,
domainName: domainName,
Expand All @@ -1062,20 +1073,17 @@ func newRemoteSite(srv *server, domainName string, sconn ssh.Conn) (*remoteSite,

clt, _, err := remoteSite.getRemoteClient()
if err != nil {
cancel()
return nil, trace.Wrap(err)
}
remoteSite.remoteClient = clt

remoteVersion, err := getRemoteAuthVersion(closeContext, sconn)
if err != nil {
cancel()
return nil, trace.Wrap(err)
}

accessPoint, err := createRemoteAccessPoint(srv, clt, remoteVersion, domainName)
if err != nil {
cancel()
return nil, trace.Wrap(err)
}
remoteSite.remoteAccessPoint = accessPoint
Expand All @@ -1087,7 +1095,6 @@ func newRemoteSite(srv *server, domainName string, sconn ssh.Conn) (*remoteSite,
},
})
if err != nil {
cancel()
return nil, trace.Wrap(err)
}
remoteSite.nodeWatcher = nodeWatcher
Expand All @@ -1097,7 +1104,6 @@ func newRemoteSite(srv *server, domainName string, sconn ssh.Conn) (*remoteSite,
// is signed by the correct certificate authority.
certificateCache, err := newHostCertificateCache(srv.Config.KeyGen, srv.localAuthClient)
if err != nil {
cancel()
return nil, trace.Wrap(err)
}
remoteSite.certificateCache = certificateCache
Expand All @@ -1110,11 +1116,25 @@ func newRemoteSite(srv *server, domainName string, sconn ssh.Conn) (*remoteSite,
Clock: srv.Clock,
})
if err != nil {
cancel()
return nil, trace.Wrap(err)
}

go remoteSite.updateCertAuthorities(caRetry)
remoteWatcher, err := services.NewCertAuthorityWatcher(srv.ctx, services.CertAuthorityWatcherConfig{
ResourceWatcherConfig: services.ResourceWatcherConfig{
Component: teleport.ComponentProxy,
Log: srv.log,
Clock: srv.Clock,
Client: remoteSite.remoteAccessPoint,
},
Types: []types.CertAuthType{types.HostCA},
})
if err != nil {
return nil, trace.Wrap(err)
}

go func() {
remoteSite.updateCertAuthorities(caRetry, remoteWatcher, remoteVersion)
}()

lockRetry, err := utils.NewLinear(utils.LinearConfig{
First: utils.HalfJitter(srv.Config.PollingPeriod),
Expand All @@ -1124,7 +1144,6 @@ func newRemoteSite(srv *server, domainName string, sconn ssh.Conn) (*remoteSite,
Clock: srv.Clock,
})
if err != nil {
cancel()
return nil, trace.Wrap(err)
}

Expand Down
Loading

0 comments on commit 73df3a3

Please sign in to comment.