From 5b12c9043588e69fd6258e7c45cba8e70863180f Mon Sep 17 00:00:00 2001 From: Tim Ross Date: Wed, 11 May 2022 11:57:30 -0400 Subject: [PATCH] Improve CertAuthorityWatcher CertAuthorityWatcher and its usage are refactored to allow for all the following: - eliminate retransmission of the same CAs - reduce memory usage by having one local watcher per proxy - adds the ability to filter only the CAs that are desired - reduce the time required to send the first CAs watchCertAuthorities now compares all CAs it receives from the watcher with the previous CA of the same type and only sends to the remote site if they are not identical. This is to reduce unnecessary network traffic which can be problematic for a root cluster with a larger number of leafs. The CertAuthorityWatcher is refactored to leverage a fanout to emit events to any number of watchers, each subscription can be for a subset of the configured CA types. The proxy now has only one CertAuthorityWatcher that is passed around similarly to the LockWatcher. This reduces the memory usage for proxies, which prior to this has one local CAWatcher per remote site. updateCertAuthorities no longer waits on the utils.Retry it is provided with before starting to watch CAs. By doing this the proxy no longer has to wait ~8 minutes before it even starts to watch CAs. --- integration/db_integration_test.go | 25 ++++- integration/integration_test.go | 71 ++++++------ lib/reversetunnel/remotesite.go | 113 +++++++++++-------- lib/reversetunnel/remotesite_test.go | 29 +++-- lib/reversetunnel/srv.go | 24 +++- lib/service/service.go | 36 ++++-- lib/services/watcher.go | 139 ++++++++++++++++-------- lib/services/watcher_test.go | 157 +++++++++++++++------------ lib/srv/regular/sshserver_test.go | 17 +++ lib/web/apiserver_test.go | 22 ++++ 10 files changed, 405 insertions(+), 228 deletions(-) diff --git a/integration/db_integration_test.go b/integration/db_integration_test.go index 36778c9f980d2..28187738d1d5b 100644 --- a/integration/db_integration_test.go +++ b/integration/db_integration_test.go @@ -269,7 +269,7 @@ func (p *phaseWatcher) waitForPhase(phase string, fn func() error) error { Clock: p.clock, Client: p.siteAPI, }, - WatchCertTypes: []types.CertAuthType{p.certType}, + Types: []types.CertAuthType{p.certType}, }) if err != nil { return err @@ -280,16 +280,29 @@ func (p *phaseWatcher) waitForPhase(phase string, fn func() error) error { return trace.Wrap(err) } + sub, err := watcher.Subscribe(ctx, services.CertAuthorityTarget{ + ClusterName: p.clusterRootName, + Type: p.certType, + }) + if err != nil { + return trace.Wrap(err) + } + defer sub.Close() + var lastPhase string for i := 0; i < 10; i++ { select { case <-ctx.Done(): + case <-sub.Done(): return trace.CompareFailed("failed to converge to phase %q, last phase %q certType: %v err: %v", phase, lastPhase, p.certType, ctx.Err()) - case cas := <-watcher.CertAuthorityC: - for _, ca := range cas { - if ca.GetClusterName() == p.clusterRootName && - ca.GetType() == p.certType && - ca.GetRotation().Phase == phase { + case evt := <-sub.Events(): + switch evt.Type { + case types.OpPut: + ca, ok := evt.Resource.(types.CertAuthority) + if !ok { + return trace.BadParameter("expected a ca got type %T", evt.Resource) + } + if ca.GetRotation().Phase == phase { return nil } lastPhase = ca.GetRotation().Phase diff --git a/integration/integration_test.go b/integration/integration_test.go index 0c6a192f77e53..5c67f51128041 100644 --- a/integration/integration_test.go +++ b/integration/integration_test.go @@ -4097,45 +4097,39 @@ func testRotateTrustedClusters(t *testing.T, suite *integrationTestSuite) { err = waitForProcessEvent(svc, service.TeleportPhaseChangeEvent, 10*time.Second) require.NoError(t, err) - // waitForPhase waits until aux cluster detects the rotation - waitForPhase := func(phase string) error { - ctx, cancel := context.WithTimeout(context.Background(), tconf.PollingPeriod*10) - defer cancel() + watcher, err := services.NewCertAuthorityWatcher(ctx, services.CertAuthorityWatcherConfig{ + ResourceWatcherConfig: services.ResourceWatcherConfig{ + Component: teleport.ComponentProxy, + Clock: tconf.Clock, + Client: aux.GetSiteAPI(clusterAux), + }, + Types: []types.CertAuthType{types.HostCA}, + }) + require.NoError(t, err) + t.Cleanup(watcher.Close) - watcher, err := services.NewCertAuthorityWatcher(ctx, services.CertAuthorityWatcherConfig{ - ResourceWatcherConfig: services.ResourceWatcherConfig{ - Component: teleport.ComponentProxy, - Clock: tconf.Clock, - Client: aux.GetSiteAPI(clusterAux), - }, - WatchCertTypes: []types.CertAuthType{types.HostCA}, - }) - if err != nil { - return err - } - defer watcher.Close() + // waitForPhase waits until aux cluster detects the rotation + waitForPhase := func(phase string) { + require.Eventually(t, func() bool { + ca, err := aux.Process.GetAuthServer().GetCertAuthority( + ctx, + types.CertAuthID{ + Type: types.HostCA, + DomainName: clusterMain, + }, false) + if err != nil { + return false + } - var lastPhase string - for i := 0; i < 10; i++ { - select { - case <-ctx.Done(): - return trace.CompareFailed("failed to converge to phase %q, last phase %q", phase, lastPhase) - case cas := <-watcher.CertAuthorityC: - for _, ca := range cas { - if ca.GetClusterName() == clusterMain && - ca.GetType() == types.HostCA && - ca.GetRotation().Phase == phase { - return nil - } - lastPhase = ca.GetRotation().Phase - } + if ca.GetRotation().Phase == phase { + return true } - } - return trace.CompareFailed("failed to converge to phase %q, last phase %q", phase, lastPhase) + + return false + }, tconf.PollingPeriod*10, tconf.PollingPeriod/2, "failed to converge to phase %q", phase) } - err = waitForPhase(types.RotationPhaseInit) - require.NoError(t, err) + waitForPhase(types.RotationPhaseInit) // update clients err = svc.GetAuthServer().RotateCertAuthority(ctx, auth.RotateRequest{ @@ -4148,8 +4142,7 @@ func testRotateTrustedClusters(t *testing.T, suite *integrationTestSuite) { svc, err = waitForReload(serviceC, svc) require.NoError(t, err) - err = waitForPhase(types.RotationPhaseUpdateClients) - require.NoError(t, err) + waitForPhase(types.RotationPhaseUpdateClients) // old client should work as is err = runAndMatch(clt, 8, []string{"echo", "hello world"}, ".*hello world.*") @@ -4168,8 +4161,7 @@ func testRotateTrustedClusters(t *testing.T, suite *integrationTestSuite) { svc, err = waitForReload(serviceC, svc) require.NoError(t, err) - err = waitForPhase(types.RotationPhaseUpdateServers) - require.NoError(t, err) + waitForPhase(types.RotationPhaseUpdateServers) // new credentials will work from this phase to others newCreds, err := GenerateUserCreds(UserCredsRequest{Process: svc, Username: suite.me.Username}) @@ -4197,8 +4189,7 @@ func testRotateTrustedClusters(t *testing.T, suite *integrationTestSuite) { require.NoError(t, err) t.Log("Service reload completed, waiting for phase.") - err = waitForPhase(types.RotationPhaseStandby) - require.NoError(t, err) + waitForPhase(types.RotationPhaseStandby) t.Log("Phase completed.") // new client still works diff --git a/lib/reversetunnel/remotesite.go b/lib/reversetunnel/remotesite.go index 7cac1eb2c8ce2..1254fb1fb40b3 100644 --- a/lib/reversetunnel/remotesite.go +++ b/lib/reversetunnel/remotesite.go @@ -427,20 +427,10 @@ func (s *remoteSite) compareAndSwapCertAuthority(ca types.CertAuthority) error { return trace.CompareFailed("remote certificate authority rotation has been updated") } -func (s *remoteSite) updateCertAuthorities(retry utils.Retry, remoteClusterVersion string) { - s.Debugf("Watching for cert authority changes.") - +func (s *remoteSite) updateCertAuthorities(retry utils.Retry, remoteWatcher *services.CertAuthorityWatcher, remoteVersion string) { + cas := make(map[types.CertAuthType]types.CertAuthority) for { - startedWaiting := s.clock.Now() - select { - case t := <-retry.After(): - s.Debugf("Initiating new cert authority watch after waiting %v.", t.Sub(startedWaiting)) - retry.Inc() - case <-s.ctx.Done(): - return - } - - err := s.watchCertAuthorities(remoteClusterVersion) + err := s.watchCertAuthorities(remoteWatcher, remoteVersion, cas) if err != nil { switch { case trace.IsNotFound(err): @@ -456,70 +446,88 @@ func (s *remoteSite) updateCertAuthorities(retry utils.Retry, remoteClusterVersi } } + startedWaiting := s.clock.Now() + select { + case t := <-retry.After(): + s.Debugf("Initiating new cert authority watch after waiting %v.", t.Sub(startedWaiting)) + retry.Inc() + case <-s.ctx.Done(): + return + } } } -func (s *remoteSite) watchCertAuthorities(remoteClusterVersion string) error { - localWatchedTypes, err := s.getLocalWatchedCerts(remoteClusterVersion) +func (s *remoteSite) watchCertAuthorities(remoteWatcher *services.CertAuthorityWatcher, remoteVersion string, cas map[types.CertAuthType]types.CertAuthority) error { + targets, err := s.getLocalWatchedCerts(remoteVersion) if err != nil { return trace.Wrap(err) } - localWatcher, err := services.NewCertAuthorityWatcher(s.ctx, services.CertAuthorityWatcherConfig{ - ResourceWatcherConfig: services.ResourceWatcherConfig{ - Component: teleport.ComponentProxy, - Log: s, - Clock: s.clock, - Client: s.localAccessPoint, - }, - WatchCertTypes: localWatchedTypes, - }) + localWatch, err := s.srv.CertAuthorityWatcher.Subscribe(s.ctx, targets...) if err != nil { return trace.Wrap(err) } - defer localWatcher.Close() + defer func() { + if err := localWatch.Close(); err != nil { + s.WithError(err).Warn("Failed to close local ca watcher subscription.") + } + }() - remoteWatcher, err := services.NewCertAuthorityWatcher(s.ctx, services.CertAuthorityWatcherConfig{ - ResourceWatcherConfig: services.ResourceWatcherConfig{ - Component: teleport.ComponentProxy, - Log: s, - Clock: s.clock, - Client: s.remoteAccessPoint, + remoteWatch, err := remoteWatcher.Subscribe( + s.ctx, + services.CertAuthorityTarget{ + ClusterName: s.domainName, + Type: types.HostCA, }, - WatchCertTypes: []types.CertAuthType{types.HostCA}, - }) + ) if err != nil { return trace.Wrap(err) } - defer remoteWatcher.Close() + defer func() { + if err := remoteWatch.Close(); err != nil { + s.WithError(err).Warn("Failed to close remote ca watcher subscription.") + } + }() + s.Debugf("Watching for cert authority changes.") for { select { case <-s.ctx.Done(): s.WithError(s.ctx.Err()).Debug("Context is closing.") return trace.Wrap(s.ctx.Err()) - case <-localWatcher.Done(): + case <-localWatch.Done(): s.Warn("Local CertAuthority watcher subscription has closed") return fmt.Errorf("local ca watcher for cluster %s has closed", s.srv.ClusterName) - case <-remoteWatcher.Done(): + case <-remoteWatch.Done(): s.Warn("Remote CertAuthority watcher subscription has closed") return fmt.Errorf("remote ca watcher for cluster %s has closed", s.domainName) - case cas := <-localWatcher.CertAuthorityC: - for _, localCA := range cas { - if localCA.GetClusterName() != s.srv.ClusterName || - !localWatcher.IsWatched(localCA.GetType()) { + case evt := <-localWatch.Events(): + switch evt.Type { + case types.OpPut: + localCA, ok := evt.Resource.(types.CertAuthority) + if !ok { continue } + ca, ok := cas[localCA.GetType()] + if ok && services.CertAuthoritiesEquivalent(ca, localCA) { + continue + } + + // clone to prevent a race with watcher filtering + localCA = localCA.Clone() if err := s.remoteClient.RotateExternalCertAuthority(s.ctx, localCA); err != nil { - s.WithError(err).Warn("Failed to rotate external ca") + log.WithError(err).Warn("Failed to rotate external ca") return trace.Wrap(err) } + + cas[localCA.GetType()] = localCA } - case cas := <-remoteWatcher.CertAuthorityC: - for _, remoteCA := range cas { - if remoteCA.GetType() != types.HostCA || - remoteCA.GetClusterName() != s.domainName { + case evt := <-remoteWatch.Events(): + switch evt.Type { + case types.OpPut: + remoteCA, ok := evt.Resource.(types.CertAuthority) + if !ok { continue } @@ -549,8 +557,17 @@ func (s *remoteSite) watchCertAuthorities(remoteClusterVersion string) error { } // getLocalWatchedCerts returns local certificates types that should be watched by the cert authority watcher. -func (s *remoteSite) getLocalWatchedCerts(remoteClusterVersion string) ([]types.CertAuthType, error) { - localWatchedTypes := []types.CertAuthType{types.HostCA, types.UserCA} +func (s *remoteSite) getLocalWatchedCerts(remoteClusterVersion string) ([]services.CertAuthorityTarget, error) { + localWatchedTypes := []services.CertAuthorityTarget{ + { + Type: types.HostCA, + ClusterName: s.srv.ClusterName, + }, + { + Type: types.UserCA, + ClusterName: s.srv.ClusterName, + }, + } // Delete in 11.0. ver10orAbove, err := utils.MinVerWithoutPreRelease(remoteClusterVersion, constants.DatabaseCAMinVersion) @@ -559,7 +576,7 @@ func (s *remoteSite) getLocalWatchedCerts(remoteClusterVersion string) ([]types. } if ver10orAbove { - localWatchedTypes = append(localWatchedTypes, types.DatabaseCA) + localWatchedTypes = append(localWatchedTypes, services.CertAuthorityTarget{ClusterName: s.srv.ClusterName, Type: types.DatabaseCA}) } else { s.Debugf("Connected to remote cluster of version %s. Database CA won't be propagated.", remoteClusterVersion) } diff --git a/lib/reversetunnel/remotesite_test.go b/lib/reversetunnel/remotesite_test.go index 734f1a311ebc0..77f17dfb23276 100644 --- a/lib/reversetunnel/remotesite_test.go +++ b/lib/reversetunnel/remotesite_test.go @@ -22,6 +22,7 @@ import ( "testing" "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/utils" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/require" @@ -31,34 +32,48 @@ func Test_remoteSite_getLocalWatchedCerts(t *testing.T) { tests := []struct { name string clusterVersion string - want []types.CertAuthType - wantErr bool + want []services.CertAuthorityTarget + errorAssertion require.ErrorAssertionFunc }{ { name: "pre Database CA, only Host and User CA", clusterVersion: "9.0.0", - want: []types.CertAuthType{types.HostCA, types.UserCA}, + want: []services.CertAuthorityTarget{ + {Type: types.HostCA, ClusterName: "test"}, + {Type: types.UserCA, ClusterName: "test"}, + }, + errorAssertion: require.NoError, }, { name: "all certs should be returned", clusterVersion: "10.0.0", - want: []types.CertAuthType{types.HostCA, types.UserCA, types.DatabaseCA}, + want: []services.CertAuthorityTarget{ + {Type: types.HostCA, ClusterName: "test"}, + {Type: types.UserCA, ClusterName: "test"}, + {Type: types.DatabaseCA, ClusterName: "test"}, + }, + errorAssertion: require.NoError, }, { name: "invalid version", clusterVersion: "foo", - wantErr: true, + errorAssertion: require.Error, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { s := &remoteSite{ + srv: &server{ + Config: Config{ + ClusterName: "test", + }, + }, Entry: log.NewEntry(utils.NewLoggerForTests()), } got, err := s.getLocalWatchedCerts(tt.clusterVersion) - if (err != nil) != tt.wantErr { - t.Errorf("getLocalWatchedCerts() error = %v, wantErr %v", err, tt.wantErr) + tt.errorAssertion(t, err) + if err != nil { return } diff --git a/lib/reversetunnel/srv.go b/lib/reversetunnel/srv.go index 2380b79cf3ee5..a7226fd59ddaf 100644 --- a/lib/reversetunnel/srv.go +++ b/lib/reversetunnel/srv.go @@ -205,6 +205,9 @@ type Config struct { // NodeWatcher is a node watcher. NodeWatcher *services.NodeWatcher + + // CertAuthorityWatcher is a cert authority watcher. + CertAuthorityWatcher *services.CertAuthorityWatcher } // CheckAndSetDefaults checks parameters and sets default values @@ -259,6 +262,9 @@ func (cfg *Config) CheckAndSetDefaults() error { if cfg.NodeWatcher == nil { return trace.BadParameter("missing parameter NodeWatcher") } + if cfg.CertAuthorityWatcher == nil { + return trace.BadParameter("missing parameter CertAuthorityWatcher") + } return nil } @@ -1115,7 +1121,23 @@ func newRemoteSite(srv *server, domainName string, sconn ssh.Conn) (*remoteSite, return nil, trace.Wrap(err) } - go remoteSite.updateCertAuthorities(caRetry, remoteVersion) + remoteWatcher, err := services.NewCertAuthorityWatcher(srv.ctx, services.CertAuthorityWatcherConfig{ + ResourceWatcherConfig: services.ResourceWatcherConfig{ + Component: teleport.ComponentProxy, + Log: srv.log, + Clock: srv.Clock, + Client: remoteSite.remoteAccessPoint, + }, + Types: []types.CertAuthType{types.HostCA}, + }) + if err != nil { + return nil, trace.Wrap(err) + } + + go func() { + defer remoteWatcher.Close() + remoteSite.updateCertAuthorities(caRetry, remoteWatcher, remoteVersion) + }() lockRetry, err := utils.NewLinear(utils.LinearConfig{ First: utils.HalfJitter(srv.Config.PollingPeriod), diff --git a/lib/service/service.go b/lib/service/service.go index dff09093260a2..a4b340aa642c3 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -2849,6 +2849,19 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { return trace.Wrap(err) } + caWatcher, err := services.NewCertAuthorityWatcher(process.ExitContext(), services.CertAuthorityWatcherConfig{ + ResourceWatcherConfig: services.ResourceWatcherConfig{ + Component: teleport.ComponentProxy, + Log: process.log.WithField(trace.Component, teleport.ComponentProxy), + Client: conn.Client, + }, + AuthorityGetter: accessPoint, + Types: []types.CertAuthType{types.HostCA, types.UserCA, types.DatabaseCA}, + }) + if err != nil { + return trace.Wrap(err) + } + serverTLSConfig, err := conn.ServerIdentity.TLSConfig(cfg.CipherSuites) if err != nil { return trace.Wrap(err) @@ -2878,17 +2891,18 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { Client: conn.Client, }, }, - KeyGen: cfg.Keygen, - Ciphers: cfg.Ciphers, - KEXAlgorithms: cfg.KEXAlgorithms, - MACAlgorithms: cfg.MACAlgorithms, - DataDir: process.Config.DataDir, - PollingPeriod: process.Config.PollingPeriod, - FIPS: cfg.FIPS, - Emitter: streamEmitter, - Log: process.log, - LockWatcher: lockWatcher, - NodeWatcher: nodeWatcher, + KeyGen: cfg.Keygen, + Ciphers: cfg.Ciphers, + KEXAlgorithms: cfg.KEXAlgorithms, + MACAlgorithms: cfg.MACAlgorithms, + DataDir: process.Config.DataDir, + PollingPeriod: process.Config.PollingPeriod, + FIPS: cfg.FIPS, + Emitter: streamEmitter, + Log: process.log, + LockWatcher: lockWatcher, + NodeWatcher: nodeWatcher, + CertAuthorityWatcher: caWatcher, }) if err != nil { return trace.Wrap(err) diff --git a/lib/services/watcher.go b/lib/services/watcher.go index d694868781630..afcdb575c8bd6 100644 --- a/lib/services/watcher.go +++ b/lib/services/watcher.go @@ -919,10 +919,8 @@ type CertAuthorityWatcherConfig struct { ResourceWatcherConfig // AuthorityGetter is responsible for fetching cert authority resources. AuthorityGetter - // CertAuthorityC receives up-to-date list of all cert authority resources. - CertAuthorityC chan []types.CertAuthority - // WatchCertTypes stores all certificate types that should be monitored. - WatchCertTypes []types.CertAuthType + // Types restricts which cert authority types are retrieved via the AuthorityGetter. + Types []types.CertAuthType } // CheckAndSetDefaults checks parameters and sets default values. @@ -937,15 +935,12 @@ func (cfg *CertAuthorityWatcherConfig) CheckAndSetDefaults() error { } cfg.AuthorityGetter = getter } - if cfg.CertAuthorityC == nil { - cfg.CertAuthorityC = make(chan []types.CertAuthority) - } return nil } // IsWatched return true if the given certificate auth type is being observer by the watcher. func (cfg *CertAuthorityWatcherConfig) IsWatched(certType types.CertAuthType) bool { - for _, observedType := range cfg.WatchCertTypes { + for _, observedType := range cfg.Types { if observedType == certType { return true } @@ -961,6 +956,12 @@ func NewCertAuthorityWatcher(ctx context.Context, cfg CertAuthorityWatcherConfig collector := &caCollector{ CertAuthorityWatcherConfig: cfg, + fanout: NewFanout(), + cas: make(map[types.CertAuthType]map[string]types.CertAuthority, len(cfg.Types)), + } + + for _, t := range cfg.Types { + collector.cas[t] = make(map[string]types.CertAuthority) } watcher, err := newResourceWatcher(ctx, collector, cfg.ResourceWatcherConfig) @@ -968,6 +969,7 @@ func NewCertAuthorityWatcher(ctx context.Context, cfg CertAuthorityWatcherConfig return nil, trace.Wrap(err) } + collector.fanout.SetInit() return &CertAuthorityWatcher{watcher, collector}, nil } @@ -980,9 +982,63 @@ type CertAuthorityWatcher struct { // caCollector accompanies resourceWatcher when monitoring cert authority resources. type caCollector struct { CertAuthorityWatcherConfig + lock sync.RWMutex + cas map[types.CertAuthType]map[string]types.CertAuthority + fanout *Fanout +} + +// CertAuthorityTarget lists the attributes of interactions to be disabled. +type CertAuthorityTarget struct { + // ClusterName specifies the name of the cluster to watch. + ClusterName string + // Type specifies the ca types to watch for. + Type types.CertAuthType +} + +// Subscribe is used to subscribe to the lock updates. +func (c *caCollector) Subscribe(ctx context.Context, targets ...CertAuthorityTarget) (types.Watcher, error) { + watchKinds, err := caTargetToWatchKinds(targets) + if err != nil { + return nil, trace.Wrap(err) + } + sub, err := c.fanout.NewWatcher(ctx, types.Watch{Kinds: watchKinds}) + if err != nil { + return nil, trace.Wrap(err) + } + select { + case event := <-sub.Events(): + if event.Type != types.OpInit { + return nil, trace.BadParameter("expected init event, got %v instead", event.Type) + } + case <-sub.Done(): + return nil, trace.Wrap(sub.Error()) + } + return sub, nil +} + +func caTargetToWatchKinds(targets []CertAuthorityTarget) ([]types.WatchKind, error) { + watchKinds := make([]types.WatchKind, 0, len(targets)) + for _, target := range targets { + kind := types.WatchKind{ + Kind: types.KindCertAuthority, + // Note that watching SubKind doesn't work for types.WatchKind - to do so it would + // require a custom filter, which was recently added but - we can't use yet due to + // older clients not supporting the filter. + SubKind: string(target.Type), + } + + if target.ClusterName != "" { + kind.Name = target.ClusterName + } + + watchKinds = append(watchKinds, kind) + } + + if len(watchKinds) == 0 { + watchKinds = []types.WatchKind{{Kind: types.KindCertAuthority}} + } - collectedCAs CertAuthorityTypeMap - lock sync.RWMutex + return watchKinds, nil } // CertAuthorityMap maps clusterName -> types.CertAuthority @@ -1009,28 +1065,23 @@ func (c *caCollector) resourceKind() string { // getResourcesAndUpdateCurrent refreshes the list of current resources. func (c *caCollector) getResourcesAndUpdateCurrent(ctx context.Context) error { - updatedCerts := make(CertAuthorityTypeMap) + var cas []types.CertAuthority - for _, caType := range c.WatchCertTypes { - cas, err := c.AuthorityGetter.GetCertAuthorities(ctx, caType, false) + for _, t := range c.Types { + authorities, err := c.AuthorityGetter.GetCertAuthorities(ctx, t, false) if err != nil { return trace.Wrap(err) } - updatedCerts[caType] = make(CertAuthorityMap, len(cas)) - for _, ca := range cas { - updatedCerts[caType][ca.GetName()] = ca - } + cas = append(cas, authorities...) } c.lock.Lock() - c.collectedCAs = updatedCerts - c.lock.Unlock() + defer c.lock.Unlock() - select { - case <-ctx.Done(): - return trace.Wrap(ctx.Err()) - case c.CertAuthorityC <- updatedCerts.ToSlice(): + for _, ca := range cas { + c.cas[ca.GetType()][ca.GetName()] = ca + c.fanout.Emit(types.Event{Type: types.OpPut, Resource: ca.Clone()}) } return nil } @@ -1046,17 +1097,12 @@ func (c *caCollector) processEventAndUpdateCurrent(ctx context.Context, event ty switch event.Type { case types.OpDelete: caType := types.CertAuthType(event.Resource.GetSubKind()) - - // Check if the certificate should be processed. - _, found := c.collectedCAs[caType] - if found { - delete(c.collectedCAs[caType], event.Resource.GetName()) + if !c.watchingType(caType) { + return } - select { - case <-ctx.Done(): - case c.CertAuthorityC <- c.collectedCAs.ToSlice(): - } + delete(c.cas[caType], event.Resource.GetName()) + c.fanout.Emit(event) case types.OpPut: ca, ok := event.Resource.(types.CertAuthority) if !ok { @@ -1064,28 +1110,31 @@ func (c *caCollector) processEventAndUpdateCurrent(ctx context.Context, event ty return } - caType := ca.GetType() - _, found := c.collectedCAs[caType] - // Check if the certificate should be processed. - if found { - c.collectedCAs[caType][ca.GetName()] = ca + if !c.watchingType(ca.GetType()) { + return } - select { - case <-ctx.Done(): - case c.CertAuthorityC <- c.collectedCAs.ToSlice(): + authority, ok := c.cas[ca.GetType()][ca.GetName()] + if ok && CertAuthoritiesEquivalent(authority, ca) { + return } + + c.cas[ca.GetType()][ca.GetName()] = ca + c.fanout.Emit(event) default: c.Log.Warnf("Unsupported event type %s.", event.Type) return } } -// GetCurrent returns the currently stored authorities. -func (c *caCollector) GetCurrent() []types.CertAuthority { - c.lock.RLock() - defer c.lock.RUnlock() - return c.collectedCAs.ToSlice() +func (c *caCollector) watchingType(t types.CertAuthType) bool { + for _, caType := range c.Types { + if caType == t { + return true + } + } + + return false } func (c *caCollector) notifyStale() {} diff --git a/lib/services/watcher_test.go b/lib/services/watcher_test.go index 1338998a3c964..8864cc8dcf9ed 100644 --- a/lib/services/watcher_test.go +++ b/lib/services/watcher_test.go @@ -37,6 +37,7 @@ import ( "github.com/gravitational/teleport/lib/auth/testauthority" "github.com/gravitational/teleport/lib/backend/lite" "github.com/gravitational/teleport/lib/defaults" + "github.com/gravitational/teleport/lib/fixtures" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/services/local" "github.com/gravitational/teleport/lib/tlsca" @@ -520,9 +521,10 @@ func resourceDiff(res1, res2 types.Resource) string { func caDiff(ca1, ca2 types.CertAuthority) string { return cmp.Diff(ca1, ca2, cmpopts.IgnoreFields(types.Metadata{}, "ID"), - cmpopts.IgnoreFields(types.CertAuthoritySpecV2{}, "CheckingKeys", "TLSKeyPairs"), + cmpopts.IgnoreFields(types.CertAuthoritySpecV2{}, "CheckingKeys", "TLSKeyPairs", "JWTKeyPairs"), cmpopts.IgnoreFields(types.SSHKeyPair{}, "PrivateKey"), cmpopts.IgnoreFields(types.TLSKeyPair{}, "Key"), + cmpopts.IgnoreFields(types.JWTKeyPair{}, "PrivateKey"), cmpopts.EquateEmpty(), ) } @@ -723,10 +725,12 @@ func newApp(t *testing.T, name string) types.Application { func TestCertAuthorityWatcher(t *testing.T) { t.Parallel() ctx := context.Background() + clock := clockwork.NewFakeClock() bk, err := lite.NewWithConfig(ctx, lite.Config{ Path: t.TempDir(), PollStreamPeriod: 200 * time.Millisecond, + Clock: clock, }) require.NoError(t, err) @@ -744,85 +748,88 @@ func TestCertAuthorityWatcher(t *testing.T) { Trust: caService, Events: local.NewEventsService(bk), }, + Clock: clock, }, - CertAuthorityC: make(chan []types.CertAuthority, 10), - WatchCertTypes: []types.CertAuthType{types.HostCA, types.UserCA, types.DatabaseCA}, + Types: []types.CertAuthType{types.HostCA, types.UserCA, types.DatabaseCA}, }) require.NoError(t, err) t.Cleanup(w.Close) - nothingWatcher, err := services.NewCertAuthorityWatcher(ctx, services.CertAuthorityWatcherConfig{ - ResourceWatcherConfig: services.ResourceWatcherConfig{ - Component: "test", - MaxRetryPeriod: 200 * time.Millisecond, - Client: &client{ - Trust: caService, - Events: local.NewEventsService(bk), - }, - }, - CertAuthorityC: make(chan []types.CertAuthority, 10), - }) + target := services.CertAuthorityTarget{ClusterName: "test"} + sub, err := w.Subscribe(ctx, target) require.NoError(t, err) - t.Cleanup(nothingWatcher.Close) - - require.Empty(t, w.GetCurrent()) - require.Empty(t, nothingWatcher.GetCurrent()) + t.Cleanup(func() { require.NoError(t, sub.Close()) }) - // Initially there are no cas so watcher should send an empty list. + // create a CA for the cluster and a type we are filtering for + // and ensure we receive the event + ca := newCertAuthority(t, "test", types.HostCA) + require.NoError(t, caService.UpsertCertAuthority(ca)) select { - case changeset := <-w.CertAuthorityC: - require.Len(t, changeset, 0) - require.Empty(t, nothingWatcher.GetCurrent()) - case <-w.Done(): - t.Fatal("Watcher has unexpectedly exited.") - case <-time.After(2 * time.Second): - t.Fatal("Timeout waiting for the first event.") + case event := <-sub.Events(): + caFromEvent, ok := event.Resource.(types.CertAuthority) + require.True(t, ok) + require.Empty(t, caDiff(ca, caFromEvent)) + case <-time.After(time.Second): + t.Fatal("timed out waiting for event") } - // Add an authority. - ca1 := newCertAuthority(t, "ca1", types.HostCA) - require.NoError(t, caService.CreateCertAuthority(ca1)) - - // The first event is always the current list of apps. + // create a CA with a type we are filtering for another cluster that we are NOT filtering for + // and ensure that we DO NOT receive the event + require.NoError(t, caService.UpsertCertAuthority(newCertAuthority(t, "unknown", types.UserCA))) select { - case changeset := <-w.CertAuthorityC: - require.Len(t, changeset, 1) - require.Empty(t, caDiff(changeset[0], ca1)) - require.Empty(t, nothingWatcher.GetCurrent()) - case <-w.Done(): - t.Fatal("Watcher has unexpectedly exited.") - case <-time.After(2 * time.Second): - t.Fatal("Timeout waiting for the first event.") + case event := <-sub.Events(): + t.Fatalf("Unexpected event: %v.", event) + case <-sub.Done(): + t.Fatal("CA watcher subscription has unexpectedly exited.") + case <-time.After(time.Second): } - // Add a second ca. - ca2 := newCertAuthority(t, "ca2", types.UserCA) - require.NoError(t, caService.CreateCertAuthority(ca2)) + // create a CA for the cluster and a type we are filtering for + // and ensure we receive the event + ca2 := newCertAuthority(t, "test", types.UserCA) + require.NoError(t, caService.UpsertCertAuthority(ca2)) + select { + case event := <-sub.Events(): + caFromEvent, ok := event.Resource.(types.CertAuthority) + require.True(t, ok) + require.Empty(t, caDiff(ca2, caFromEvent)) + case <-time.After(time.Second): + t.Fatal("timed out waiting for event") + } - // Watcher should detect the ca list change. + // delete a CA with type being watched in the cluster we are filtering for + // and ensure we receive the event + require.NoError(t, caService.DeleteCertAuthority(ca.GetID())) select { - case changeset := <-w.CertAuthorityC: - require.Len(t, changeset, 2) - require.Empty(t, nothingWatcher.GetCurrent()) - case <-w.Done(): - t.Fatal("Watcher has unexpectedly exited.") - case <-time.After(2 * time.Second): - t.Fatal("Timeout waiting for the update event.") + case event := <-sub.Events(): + require.Equal(t, types.KindCertAuthority, event.Resource.GetKind()) + require.Equal(t, string(types.HostCA), event.Resource.GetSubKind()) + require.Equal(t, "test", event.Resource.GetName()) + case <-time.After(time.Second): + t.Fatal("timed out waiting for event") } - // Delete the first ca. - require.NoError(t, caService.DeleteCertAuthority(ca1.GetID())) + // create a CA with a type we are NOT filtering for but for a cluster we are filtering for + // and ensure we DO NOT receive the event + signer := newCertAuthority(t, "test", types.JWTSigner) + require.NoError(t, caService.UpsertCertAuthority(signer)) + select { + case event := <-sub.Events(): + t.Fatalf("Unexpected event: %v.", event) + case <-sub.Done(): + t.Fatal("CA watcher subscription has unexpectedly exited.") + case <-time.After(time.Second): + } - // Watcher should detect the ca list change. + // delete a CA with a name we are filtering for but a type we are NOT filtering for + // and ensure we do NOT receive the event + require.NoError(t, caService.DeleteCertAuthority(signer.GetID())) select { - case changeset := <-w.CertAuthorityC: - require.Len(t, changeset, 1) - require.Empty(t, caDiff(changeset[0], ca2)) - require.Empty(t, nothingWatcher.GetCurrent()) - case <-w.Done(): - t.Fatal("Watcher has unexpectedly exited.") - case <-time.After(2 * time.Second): - t.Fatal("Timeout waiting for the update event.") + case event := <-sub.Events(): + t.Fatalf("Unexpected event: %v.", event) + case <-sub.Done(): + t.Fatal("CA watcher subscription has unexpectedly exited.") + case <-time.After(time.Second): } } @@ -839,15 +846,25 @@ func newCertAuthority(t *testing.T, name string, caType types.CertAuthType) type Type: caType, ClusterName: name, ActiveKeys: types.CAKeySet{ - SSH: []*types.SSHKeyPair{{ - PrivateKey: priv, - PrivateKeyType: types.PrivateKeyType_RAW, - PublicKey: pub, - }}, - TLS: []*types.TLSKeyPair{{ - Cert: cert, - Key: key, - }}, + SSH: []*types.SSHKeyPair{ + { + PrivateKey: priv, + PrivateKeyType: types.PrivateKeyType_RAW, + PublicKey: pub, + }, + }, + TLS: []*types.TLSKeyPair{ + { + Cert: cert, + Key: key, + }, + }, + JWT: []*types.JWTKeyPair{ + { + PublicKey: []byte(fixtures.JWTSignerPublicKey), + PrivateKey: []byte(fixtures.JWTSignerPrivateKey), + }, + }, }, Roles: nil, SigningAlg: types.CertAuthoritySpecV2_RSA_SHA2_256, diff --git a/lib/srv/regular/sshserver_test.go b/lib/srv/regular/sshserver_test.go index 511c44fd30e17..39617b523a81b 100644 --- a/lib/srv/regular/sshserver_test.go +++ b/lib/srv/regular/sshserver_test.go @@ -1126,6 +1126,7 @@ func TestProxyRoundRobin(t *testing.T) { defer listener.Close() lockWatcher := newLockWatcher(ctx, t, proxyClient) nodeWatcher := newNodeWatcher(ctx, t, proxyClient) + caWatcher := newCertAuthorityWatcher(ctx, t, proxyClient) reverseTunnelServer, err := reversetunnel.NewServer(reversetunnel.Config{ ClusterName: f.testSrv.ClusterName(), @@ -1143,6 +1144,7 @@ func TestProxyRoundRobin(t *testing.T) { Log: logger, LockWatcher: lockWatcher, NodeWatcher: nodeWatcher, + CertAuthorityWatcher: caWatcher, }) require.NoError(t, err) logger.WithField("tun-addr", reverseTunnelAddress.String()).Info("Created reverse tunnel server.") @@ -1252,6 +1254,7 @@ func TestProxyDirectAccess(t *testing.T) { proxyClient, _ := newProxyClient(t, f.testSrv) lockWatcher := newLockWatcher(ctx, t, proxyClient) nodeWatcher := newNodeWatcher(ctx, t, proxyClient) + caWatcher := newCertAuthorityWatcher(ctx, t, proxyClient) reverseTunnelServer, err := reversetunnel.NewServer(reversetunnel.Config{ ClientTLS: proxyClient.TLSConfig(), @@ -1269,6 +1272,7 @@ func TestProxyDirectAccess(t *testing.T) { Log: logger, LockWatcher: lockWatcher, NodeWatcher: nodeWatcher, + CertAuthorityWatcher: caWatcher, }) require.NoError(t, err) @@ -1963,6 +1967,19 @@ func newNodeWatcher(ctx context.Context, t *testing.T, client types.Events) *ser return nodeWatcher } +func newCertAuthorityWatcher(ctx context.Context, t *testing.T, client types.Events) *services.CertAuthorityWatcher { + caWatcher, err := services.NewCertAuthorityWatcher(ctx, services.CertAuthorityWatcherConfig{ + ResourceWatcherConfig: services.ResourceWatcherConfig{ + Component: "test", + Client: client, + }, + Types: []types.CertAuthType{types.HostCA, types.UserCA}, + }) + require.NoError(t, err) + t.Cleanup(caWatcher.Close) + return caWatcher +} + // maxPipeSize is one larger than the maximum pipe size for most operating // systems which appears to be 65536 bytes. // diff --git a/lib/web/apiserver_test.go b/lib/web/apiserver_test.go index 302510831d18c..9c92af07c4201 100644 --- a/lib/web/apiserver_test.go +++ b/lib/web/apiserver_test.go @@ -275,6 +275,16 @@ func newWebSuite(t *testing.T) *WebSuite { }) require.NoError(t, err) + caWatcher, err := services.NewCertAuthorityWatcher(s.ctx, services.CertAuthorityWatcherConfig{ + ResourceWatcherConfig: services.ResourceWatcherConfig{ + Component: teleport.ComponentProxy, + Client: s.proxyClient, + }, + Types: []types.CertAuthType{types.HostCA, types.UserCA}, + }) + require.NoError(t, err) + defer caWatcher.Close() + revTunServer, err := reversetunnel.NewServer(reversetunnel.Config{ ID: node.ID(), Listener: revTunListener, @@ -289,6 +299,7 @@ func newWebSuite(t *testing.T) *WebSuite { DataDir: t.TempDir(), LockWatcher: proxyLockWatcher, NodeWatcher: proxyNodeWatcher, + CertAuthorityWatcher: caWatcher, }) require.NoError(t, err) s.proxyTunnel = revTunServer @@ -3738,6 +3749,16 @@ func createProxy(ctx context.Context, t *testing.T, proxyID string, node *regula require.NoError(t, err) t.Cleanup(proxyLockWatcher.Close) + proxyCAWatcher, err := services.NewCertAuthorityWatcher(ctx, services.CertAuthorityWatcherConfig{ + ResourceWatcherConfig: services.ResourceWatcherConfig{ + Component: teleport.ComponentProxy, + Client: client, + }, + Types: []types.CertAuthType{types.HostCA, types.UserCA}, + }) + require.NoError(t, err) + t.Cleanup(proxyLockWatcher.Close) + proxyNodeWatcher, err := services.NewNodeWatcher(ctx, services.NodeWatcherConfig{ ResourceWatcherConfig: services.ResourceWatcherConfig{ Component: teleport.ComponentProxy, @@ -3761,6 +3782,7 @@ func createProxy(ctx context.Context, t *testing.T, proxyID string, node *regula DataDir: t.TempDir(), LockWatcher: proxyLockWatcher, NodeWatcher: proxyNodeWatcher, + CertAuthorityWatcher: proxyCAWatcher, }) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, revTunServer.Close()) })