Skip to content

Commit

Permalink
Improve CertAuthorityWatcher
Browse files Browse the repository at this point in the history
CertAuthorityWatcher and its usage are refactored to allow for
all the following:
 - eliminate retransmission of the same CAs
 - reduce memory usage by having one local watcher per proxy
 - add the ability to filter only the CAs that are desired
 - reduce the time required to send the first CAs

watchCertAuthorities now compares all CAs it receives from the
watcher with the previous CA of the same type and only sends to
the remote site if they are not identical. This is to reduce
unnecessary network traffic, which can be problematic for a
root cluster with a large number of leaves.

The CertAuthorityWatcher is refactored to leverage a fanout
to emit events to any number of watchers, each subscription
can be for a subset of the configured CA types. The proxy
now has only one CertAuthorityWatcher that is passed around
similarly to the LockWatcher. This reduces the memory usage
for proxies, which prior to this had one local CAWatcher per
remote site.

updateCertAuthorities no longer waits on the utils.Retry it
is provided with before starting to watch CAs. By doing this
the proxy no longer has to wait ~8 minutes before it even
starts to watch CAs.
  • Loading branch information
rosstimothy committed May 11, 2022
1 parent 3fd2277 commit 5b12c90
Show file tree
Hide file tree
Showing 10 changed files with 405 additions and 228 deletions.
25 changes: 19 additions & 6 deletions integration/db_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ func (p *phaseWatcher) waitForPhase(phase string, fn func() error) error {
Clock: p.clock,
Client: p.siteAPI,
},
WatchCertTypes: []types.CertAuthType{p.certType},
Types: []types.CertAuthType{p.certType},
})
if err != nil {
return err
Expand All @@ -280,16 +280,29 @@ func (p *phaseWatcher) waitForPhase(phase string, fn func() error) error {
return trace.Wrap(err)
}

sub, err := watcher.Subscribe(ctx, services.CertAuthorityTarget{
ClusterName: p.clusterRootName,
Type: p.certType,
})
if err != nil {
return trace.Wrap(err)
}
defer sub.Close()

var lastPhase string
for i := 0; i < 10; i++ {
select {
case <-ctx.Done():
case <-sub.Done():
return trace.CompareFailed("failed to converge to phase %q, last phase %q certType: %v err: %v", phase, lastPhase, p.certType, ctx.Err())
case cas := <-watcher.CertAuthorityC:
for _, ca := range cas {
if ca.GetClusterName() == p.clusterRootName &&
ca.GetType() == p.certType &&
ca.GetRotation().Phase == phase {
case evt := <-sub.Events():
switch evt.Type {
case types.OpPut:
ca, ok := evt.Resource.(types.CertAuthority)
if !ok {
return trace.BadParameter("expected a ca got type %T", evt.Resource)
}
if ca.GetRotation().Phase == phase {
return nil
}
lastPhase = ca.GetRotation().Phase
Expand Down
71 changes: 31 additions & 40 deletions integration/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4097,45 +4097,39 @@ func testRotateTrustedClusters(t *testing.T, suite *integrationTestSuite) {
err = waitForProcessEvent(svc, service.TeleportPhaseChangeEvent, 10*time.Second)
require.NoError(t, err)

// waitForPhase waits until aux cluster detects the rotation
waitForPhase := func(phase string) error {
ctx, cancel := context.WithTimeout(context.Background(), tconf.PollingPeriod*10)
defer cancel()
watcher, err := services.NewCertAuthorityWatcher(ctx, services.CertAuthorityWatcherConfig{
ResourceWatcherConfig: services.ResourceWatcherConfig{
Component: teleport.ComponentProxy,
Clock: tconf.Clock,
Client: aux.GetSiteAPI(clusterAux),
},
Types: []types.CertAuthType{types.HostCA},
})
require.NoError(t, err)
t.Cleanup(watcher.Close)

watcher, err := services.NewCertAuthorityWatcher(ctx, services.CertAuthorityWatcherConfig{
ResourceWatcherConfig: services.ResourceWatcherConfig{
Component: teleport.ComponentProxy,
Clock: tconf.Clock,
Client: aux.GetSiteAPI(clusterAux),
},
WatchCertTypes: []types.CertAuthType{types.HostCA},
})
if err != nil {
return err
}
defer watcher.Close()
// waitForPhase waits until aux cluster detects the rotation
waitForPhase := func(phase string) {
require.Eventually(t, func() bool {
ca, err := aux.Process.GetAuthServer().GetCertAuthority(
ctx,
types.CertAuthID{
Type: types.HostCA,
DomainName: clusterMain,
}, false)
if err != nil {
return false
}

var lastPhase string
for i := 0; i < 10; i++ {
select {
case <-ctx.Done():
return trace.CompareFailed("failed to converge to phase %q, last phase %q", phase, lastPhase)
case cas := <-watcher.CertAuthorityC:
for _, ca := range cas {
if ca.GetClusterName() == clusterMain &&
ca.GetType() == types.HostCA &&
ca.GetRotation().Phase == phase {
return nil
}
lastPhase = ca.GetRotation().Phase
}
if ca.GetRotation().Phase == phase {
return true
}
}
return trace.CompareFailed("failed to converge to phase %q, last phase %q", phase, lastPhase)

return false
}, tconf.PollingPeriod*10, tconf.PollingPeriod/2, "failed to converge to phase %q", phase)
}

err = waitForPhase(types.RotationPhaseInit)
require.NoError(t, err)
waitForPhase(types.RotationPhaseInit)

// update clients
err = svc.GetAuthServer().RotateCertAuthority(ctx, auth.RotateRequest{
Expand All @@ -4148,8 +4142,7 @@ func testRotateTrustedClusters(t *testing.T, suite *integrationTestSuite) {
svc, err = waitForReload(serviceC, svc)
require.NoError(t, err)

err = waitForPhase(types.RotationPhaseUpdateClients)
require.NoError(t, err)
waitForPhase(types.RotationPhaseUpdateClients)

// old client should work as is
err = runAndMatch(clt, 8, []string{"echo", "hello world"}, ".*hello world.*")
Expand All @@ -4168,8 +4161,7 @@ func testRotateTrustedClusters(t *testing.T, suite *integrationTestSuite) {
svc, err = waitForReload(serviceC, svc)
require.NoError(t, err)

err = waitForPhase(types.RotationPhaseUpdateServers)
require.NoError(t, err)
waitForPhase(types.RotationPhaseUpdateServers)

// new credentials will work from this phase to others
newCreds, err := GenerateUserCreds(UserCredsRequest{Process: svc, Username: suite.me.Username})
Expand Down Expand Up @@ -4197,8 +4189,7 @@ func testRotateTrustedClusters(t *testing.T, suite *integrationTestSuite) {
require.NoError(t, err)
t.Log("Service reload completed, waiting for phase.")

err = waitForPhase(types.RotationPhaseStandby)
require.NoError(t, err)
waitForPhase(types.RotationPhaseStandby)
t.Log("Phase completed.")

// new client still works
Expand Down
113 changes: 65 additions & 48 deletions lib/reversetunnel/remotesite.go
Original file line number Diff line number Diff line change
Expand Up @@ -427,20 +427,10 @@ func (s *remoteSite) compareAndSwapCertAuthority(ca types.CertAuthority) error {
return trace.CompareFailed("remote certificate authority rotation has been updated")
}

func (s *remoteSite) updateCertAuthorities(retry utils.Retry, remoteClusterVersion string) {
s.Debugf("Watching for cert authority changes.")

func (s *remoteSite) updateCertAuthorities(retry utils.Retry, remoteWatcher *services.CertAuthorityWatcher, remoteVersion string) {
cas := make(map[types.CertAuthType]types.CertAuthority)
for {
startedWaiting := s.clock.Now()
select {
case t := <-retry.After():
s.Debugf("Initiating new cert authority watch after waiting %v.", t.Sub(startedWaiting))
retry.Inc()
case <-s.ctx.Done():
return
}

err := s.watchCertAuthorities(remoteClusterVersion)
err := s.watchCertAuthorities(remoteWatcher, remoteVersion, cas)
if err != nil {
switch {
case trace.IsNotFound(err):
Expand All @@ -456,70 +446,88 @@ func (s *remoteSite) updateCertAuthorities(retry utils.Retry, remoteClusterVersi
}
}

startedWaiting := s.clock.Now()
select {
case t := <-retry.After():
s.Debugf("Initiating new cert authority watch after waiting %v.", t.Sub(startedWaiting))
retry.Inc()
case <-s.ctx.Done():
return
}
}
}

func (s *remoteSite) watchCertAuthorities(remoteClusterVersion string) error {
localWatchedTypes, err := s.getLocalWatchedCerts(remoteClusterVersion)
func (s *remoteSite) watchCertAuthorities(remoteWatcher *services.CertAuthorityWatcher, remoteVersion string, cas map[types.CertAuthType]types.CertAuthority) error {
targets, err := s.getLocalWatchedCerts(remoteVersion)
if err != nil {
return trace.Wrap(err)
}

localWatcher, err := services.NewCertAuthorityWatcher(s.ctx, services.CertAuthorityWatcherConfig{
ResourceWatcherConfig: services.ResourceWatcherConfig{
Component: teleport.ComponentProxy,
Log: s,
Clock: s.clock,
Client: s.localAccessPoint,
},
WatchCertTypes: localWatchedTypes,
})
localWatch, err := s.srv.CertAuthorityWatcher.Subscribe(s.ctx, targets...)
if err != nil {
return trace.Wrap(err)
}
defer localWatcher.Close()
defer func() {
if err := localWatch.Close(); err != nil {
s.WithError(err).Warn("Failed to close local ca watcher subscription.")
}
}()

remoteWatcher, err := services.NewCertAuthorityWatcher(s.ctx, services.CertAuthorityWatcherConfig{
ResourceWatcherConfig: services.ResourceWatcherConfig{
Component: teleport.ComponentProxy,
Log: s,
Clock: s.clock,
Client: s.remoteAccessPoint,
remoteWatch, err := remoteWatcher.Subscribe(
s.ctx,
services.CertAuthorityTarget{
ClusterName: s.domainName,
Type: types.HostCA,
},
WatchCertTypes: []types.CertAuthType{types.HostCA},
})
)
if err != nil {
return trace.Wrap(err)
}
defer remoteWatcher.Close()
defer func() {
if err := remoteWatch.Close(); err != nil {
s.WithError(err).Warn("Failed to close remote ca watcher subscription.")
}
}()

s.Debugf("Watching for cert authority changes.")
for {
select {
case <-s.ctx.Done():
s.WithError(s.ctx.Err()).Debug("Context is closing.")
return trace.Wrap(s.ctx.Err())
case <-localWatcher.Done():
case <-localWatch.Done():
s.Warn("Local CertAuthority watcher subscription has closed")
return fmt.Errorf("local ca watcher for cluster %s has closed", s.srv.ClusterName)
case <-remoteWatcher.Done():
case <-remoteWatch.Done():
s.Warn("Remote CertAuthority watcher subscription has closed")
return fmt.Errorf("remote ca watcher for cluster %s has closed", s.domainName)
case cas := <-localWatcher.CertAuthorityC:
for _, localCA := range cas {
if localCA.GetClusterName() != s.srv.ClusterName ||
!localWatcher.IsWatched(localCA.GetType()) {
case evt := <-localWatch.Events():
switch evt.Type {
case types.OpPut:
localCA, ok := evt.Resource.(types.CertAuthority)
if !ok {
continue
}

ca, ok := cas[localCA.GetType()]
if ok && services.CertAuthoritiesEquivalent(ca, localCA) {
continue
}

// clone to prevent a race with watcher filtering
localCA = localCA.Clone()
if err := s.remoteClient.RotateExternalCertAuthority(s.ctx, localCA); err != nil {
s.WithError(err).Warn("Failed to rotate external ca")
log.WithError(err).Warn("Failed to rotate external ca")
return trace.Wrap(err)
}

cas[localCA.GetType()] = localCA
}
case cas := <-remoteWatcher.CertAuthorityC:
for _, remoteCA := range cas {
if remoteCA.GetType() != types.HostCA ||
remoteCA.GetClusterName() != s.domainName {
case evt := <-remoteWatch.Events():
switch evt.Type {
case types.OpPut:
remoteCA, ok := evt.Resource.(types.CertAuthority)
if !ok {
continue
}

Expand Down Expand Up @@ -549,8 +557,17 @@ func (s *remoteSite) watchCertAuthorities(remoteClusterVersion string) error {
}

// getLocalWatchedCerts returns local certificates types that should be watched by the cert authority watcher.
func (s *remoteSite) getLocalWatchedCerts(remoteClusterVersion string) ([]types.CertAuthType, error) {
localWatchedTypes := []types.CertAuthType{types.HostCA, types.UserCA}
func (s *remoteSite) getLocalWatchedCerts(remoteClusterVersion string) ([]services.CertAuthorityTarget, error) {
localWatchedTypes := []services.CertAuthorityTarget{
{
Type: types.HostCA,
ClusterName: s.srv.ClusterName,
},
{
Type: types.UserCA,
ClusterName: s.srv.ClusterName,
},
}

// Delete in 11.0.
ver10orAbove, err := utils.MinVerWithoutPreRelease(remoteClusterVersion, constants.DatabaseCAMinVersion)
Expand All @@ -559,7 +576,7 @@ func (s *remoteSite) getLocalWatchedCerts(remoteClusterVersion string) ([]types.
}

if ver10orAbove {
localWatchedTypes = append(localWatchedTypes, types.DatabaseCA)
localWatchedTypes = append(localWatchedTypes, services.CertAuthorityTarget{ClusterName: s.srv.ClusterName, Type: types.DatabaseCA})
} else {
s.Debugf("Connected to remote cluster of version %s. Database CA won't be propagated.", remoteClusterVersion)
}
Expand Down
29 changes: 22 additions & 7 deletions lib/reversetunnel/remotesite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"testing"

"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/utils"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/require"
Expand All @@ -31,34 +32,48 @@ func Test_remoteSite_getLocalWatchedCerts(t *testing.T) {
tests := []struct {
name string
clusterVersion string
want []types.CertAuthType
wantErr bool
want []services.CertAuthorityTarget
errorAssertion require.ErrorAssertionFunc
}{
{
name: "pre Database CA, only Host and User CA",
clusterVersion: "9.0.0",
want: []types.CertAuthType{types.HostCA, types.UserCA},
want: []services.CertAuthorityTarget{
{Type: types.HostCA, ClusterName: "test"},
{Type: types.UserCA, ClusterName: "test"},
},
errorAssertion: require.NoError,
},
{
name: "all certs should be returned",
clusterVersion: "10.0.0",
want: []types.CertAuthType{types.HostCA, types.UserCA, types.DatabaseCA},
want: []services.CertAuthorityTarget{
{Type: types.HostCA, ClusterName: "test"},
{Type: types.UserCA, ClusterName: "test"},
{Type: types.DatabaseCA, ClusterName: "test"},
},
errorAssertion: require.NoError,
},
{
name: "invalid version",
clusterVersion: "foo",
wantErr: true,
errorAssertion: require.Error,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := &remoteSite{
srv: &server{
Config: Config{
ClusterName: "test",
},
},
Entry: log.NewEntry(utils.NewLoggerForTests()),
}
got, err := s.getLocalWatchedCerts(tt.clusterVersion)
if (err != nil) != tt.wantErr {
t.Errorf("getLocalWatchedCerts() error = %v, wantErr %v", err, tt.wantErr)
tt.errorAssertion(t, err)
if err != nil {
return
}

Expand Down
Loading

0 comments on commit 5b12c90

Please sign in to comment.