From 283c08bf3626a465c41bf9b4f35cf0f326752160 Mon Sep 17 00:00:00 2001 From: rosstimothy <39066650+rosstimothy@users.noreply.github.com> Date: Wed, 19 Jan 2022 16:53:45 -0500 Subject: [PATCH] Replace cluster periodics with watchers (#9609) * Replace cluster periodics with watchers Remove periodically sending locks and certificate authorities to leaf clusters. Instead we can rely on the watcher system to only deliver resources to leaf clusters when changes occur. Fixes #8817 (cherry picked from commit 8932ed4e03c808dc127fbb24fc26fc55941b1c03) --- integration/integration_test.go | 50 +++++-- lib/reversetunnel/remotesite.go | 250 +++++++++++++++++++------------- lib/reversetunnel/srv.go | 27 +++- lib/services/authority.go | 4 +- lib/services/trust.go | 20 ++- lib/services/watcher.go | 166 +++++++++++++++++++++ lib/services/watcher_test.go | 157 +++++++++++++++++++- 7 files changed, 539 insertions(+), 135 deletions(-) diff --git a/integration/integration_test.go b/integration/integration_test.go index bb46ae378726c..04e7999a00db7 100644 --- a/integration/integration_test.go +++ b/integration/integration_test.go @@ -3711,10 +3711,10 @@ func testRotateRollback(t *testing.T, s *integrationTestSuite) { // TestRotateTrustedClusters tests CA rotation support for trusted clusters func testRotateTrustedClusters(t *testing.T, suite *integrationTestSuite) { tr := utils.NewTracer(utils.ThisFunction()).Start() - defer tr.Stop() + t.Cleanup(func() { tr.Stop() }) ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + t.Cleanup(cancel) clusterMain := "rotate-main" clusterAux := "rotate-aux" @@ -3773,7 +3773,7 @@ func testRotateTrustedClusters(t *testing.T, suite *integrationTestSuite) { require.NoError(t, err) err = aux.Process.GetAuthServer().UpsertRole(ctx, role) require.NoError(t, err) - trustedClusterToken := "trusted-clsuter-token" + trustedClusterToken := "trusted-cluster-token" err = svc.GetAuthServer().UpsertToken(ctx, 
services.MustCreateProvisionToken(trustedClusterToken, []types.SystemRole{types.RoleTrustedCluster}, time.Time{})) require.NoError(t, err) @@ -3789,7 +3789,7 @@ func testRotateTrustedClusters(t *testing.T, suite *integrationTestSuite) { tryCreateTrustedCluster(t, aux.Process.GetAuthServer(), trustedCluster) waitForTunnelConnections(t, svc.GetAuthServer(), aux.Secrets.SiteName, 1) - // capture credentials before has reload started to simulate old client + // capture credentials before reload has started to simulate old client initialCreds, err := GenerateUserCreds(UserCredsRequest{ Process: svc, Username: suite.me.Username, @@ -3818,24 +3818,43 @@ func testRotateTrustedClusters(t *testing.T, suite *integrationTestSuite) { }) require.NoError(t, err) - // wait until service phase update to be broadcasted (init phase does not trigger reload) + // wait until service phase update to be broadcast (init phase does not trigger reload) err = waitForProcessEvent(svc, service.TeleportPhaseChangeEvent, 10*time.Second) require.NoError(t, err) // waitForPhase waits until aux cluster detects the rotation waitForPhase := func(phase string) error { + ctx, cancel := context.WithTimeout(context.Background(), tconf.PollingPeriod*10) + defer cancel() + + watcher, err := services.NewCertAuthorityWatcher(ctx, services.CertAuthorityWatcherConfig{ + ResourceWatcherConfig: services.ResourceWatcherConfig{ + Component: teleport.ComponentProxy, + Clock: tconf.Clock, + Client: aux.GetSiteAPI(clusterAux), + }, + WatchHostCA: true, + }) + if err != nil { + return err + } + defer watcher.Close() + var lastPhase string for i := 0; i < 10; i++ { - ca, err := aux.Process.GetAuthServer().GetCertAuthority(types.CertAuthID{ - Type: types.HostCA, - DomainName: clusterMain, - }, false) - require.NoError(t, err) - if ca.GetRotation().Phase == phase { - return nil + select { + case <-ctx.Done(): + return trace.CompareFailed("failed to converge to phase %q, last phase %q", phase, lastPhase) + case cas := 
<-watcher.CertAuthorityC: + for _, ca := range cas { + if ca.GetClusterName() == clusterMain && + ca.GetType() == types.HostCA && + ca.GetRotation().Phase == phase { + return nil + } + lastPhase = ca.GetRotation().Phase + } } - lastPhase = ca.GetRotation().Phase - time.Sleep(tconf.PollingPeriod / 2) } return trace.CompareFailed("failed to converge to phase %q, last phase %q", phase, lastPhase) } @@ -3916,7 +3935,7 @@ func testRotateTrustedClusters(t *testing.T, suite *integrationTestSuite) { // shut down the service cancel() // close the service without waiting for the connections to drain - svc.Close() + require.NoError(t, svc.Close()) select { case err := <-runErrCh: @@ -4082,6 +4101,7 @@ func (s *integrationTestSuite) rotationConfig(disableWebService bool) *service.C tconf.PollingPeriod = 500 * time.Millisecond tconf.ClientTimeout = time.Second tconf.ShutdownTimeout = 2 * tconf.ClientTimeout + tconf.MaxRetryPeriod = time.Second return tconf } diff --git a/lib/reversetunnel/remotesite.go b/lib/reversetunnel/remotesite.go index 3346c32525889..fa2212c4f6eba 100644 --- a/lib/reversetunnel/remotesite.go +++ b/lib/reversetunnel/remotesite.go @@ -23,8 +23,6 @@ import ( "sync" "time" - "golang.org/x/crypto/ssh" - "github.com/gravitational/teleport" "github.com/gravitational/teleport/api/client" "github.com/gravitational/teleport/api/constants" @@ -35,13 +33,10 @@ import ( "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/srv/forward" "github.com/gravitational/teleport/lib/utils" - "github.com/gravitational/teleport/lib/utils/interval" - "github.com/gravitational/trace" - "github.com/jonboulle/clockwork" - "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus" + "golang.org/x/crypto/ssh" ) // remoteSite is a remote site that established the inbound connecton to @@ -416,129 +411,177 @@ func (s *remoteSite) compareAndSwapCertAuthority(ca types.CertAuthority) error { return trace.CompareFailed("remote certificate authority 
rotation has been updated") } -// updateCertAuthorities updates local and remote cert authorities -func (s *remoteSite) updateCertAuthorities() error { - // update main cluster cert authorities on the remote side - // remote side makes sure that only relevant fields - // are updated - hostCA, err := s.localClient.GetCertAuthority(types.CertAuthID{ - Type: types.HostCA, - DomainName: s.srv.ClusterName, - }, false) - if err != nil { - return trace.Wrap(err) - } - err = s.remoteClient.RotateExternalCertAuthority(hostCA) - if err != nil { - return trace.Wrap(err) - } +func (s *remoteSite) updateCertAuthorities(retry utils.Retry) { + s.Debugf("Watching for cert authority changes.") + + for { + startedWaiting := s.clock.Now() + select { + case t := <-retry.After(): + s.Debugf("Initiating new cert authority watch after waiting %v.", t.Sub(startedWaiting)) + retry.Inc() + case <-s.ctx.Done(): + return + } + + err := s.watchCertAuthorities() + if err != nil { + switch { + case trace.IsNotFound(err): + s.Debugf("Remote cluster %v does not support cert authorities rotation yet.", s.domainName) + case trace.IsCompareFailed(err): + s.Infof("Remote cluster has updated certificate authorities, going to force reconnect.") + s.srv.removeSite(s.domainName) + s.Close() + return + case trace.IsConnectionProblem(err): + s.Debugf("Remote cluster %v is offline.", s.domainName) + default: + s.Warningf("Could not perform cert authorities update: %v.", trace.DebugReport(err)) + } + } - userCA, err := s.localClient.GetCertAuthority(types.CertAuthID{ - Type: types.UserCA, - DomainName: s.srv.ClusterName, - }, false) - if err != nil { - return trace.Wrap(err) } - err = s.remoteClient.RotateExternalCertAuthority(userCA) +} + +func (s *remoteSite) watchCertAuthorities() error { + localWatcher, err := services.NewCertAuthorityWatcher(s.ctx, services.CertAuthorityWatcherConfig{ + ResourceWatcherConfig: services.ResourceWatcherConfig{ + Component: teleport.ComponentProxy, + Log: s, + Clock: 
s.clock, + Client: s.localAccessPoint, + }, + WatchUserCA: true, + WatchHostCA: true, + }) if err != nil { return trace.Wrap(err) } + defer localWatcher.Close() - // update remote cluster's host cert authoritiy on a local cluster - // local proxy is authorized to perform this operation only for - // host authorities of remote clusters. - remoteCA, err := s.remoteClient.GetCertAuthority(types.CertAuthID{ - Type: types.HostCA, - DomainName: s.domainName, - }, false) + remoteWatcher, err := services.NewCertAuthorityWatcher(s.ctx, services.CertAuthorityWatcherConfig{ + ResourceWatcherConfig: services.ResourceWatcherConfig{ + Component: teleport.ComponentProxy, + Log: s, + Clock: s.clock, + Client: s.remoteAccessPoint, + }, + WatchHostCA: true, + }) if err != nil { return trace.Wrap(err) } + defer remoteWatcher.Close() - if remoteCA.GetClusterName() != s.domainName { - return trace.BadParameter( - "remote cluster sent different cluster name %v instead of expected one %v", - remoteCA.GetClusterName(), s.domainName) - } + for { + select { + case <-s.ctx.Done(): + s.WithError(s.ctx.Err()).Debug("Context is closing.") + return trace.Wrap(s.ctx.Err()) + case <-localWatcher.Done(): + s.Warn("Local CertAuthority watcher subscription has closed") + return fmt.Errorf("local ca watcher for cluster %s has closed", s.srv.ClusterName) + case <-remoteWatcher.Done(): + s.Warn("Remote CertAuthority watcher subscription has closed") + return fmt.Errorf("remote ca watcher for cluster %s has closed", s.domainName) + case cas := <-localWatcher.CertAuthorityC: + for _, localCA := range cas { + if localCA.GetClusterName() != s.srv.ClusterName || + (localCA.GetType() != types.HostCA && + localCA.GetType() != types.UserCA) { + continue + } - oldRemoteCA, err := s.localClient.GetCertAuthority(types.CertAuthID{ - Type: types.HostCA, - DomainName: remoteCA.GetClusterName(), - }, false) + if err := s.remoteClient.RotateExternalCertAuthority(localCA); err != nil { + s.WithError(err).Warn("Failed to 
rotate external ca") + return trace.Wrap(err) + } + } + case cas := <-remoteWatcher.CertAuthorityC: + for _, remoteCA := range cas { + if remoteCA.GetType() != types.HostCA || + remoteCA.GetClusterName() != s.domainName { + continue + } - if err != nil && !trace.IsNotFound(err) { - return trace.Wrap(err) - } + oldRemoteCA, err := s.localClient.GetCertAuthority(types.CertAuthID{ + Type: types.HostCA, + DomainName: remoteCA.GetClusterName(), + }, false) + + if err != nil && !trace.IsNotFound(err) { + return trace.Wrap(err) + } - // if CA is changed or does not exist, update backend - if err != nil || !services.CertAuthoritiesEquivalent(oldRemoteCA, remoteCA) { - if err := s.localClient.UpsertCertAuthority(remoteCA); err != nil { - return trace.Wrap(err) + // if CA is changed or does not exist, update backend + if err != nil || !services.CertAuthoritiesEquivalent(oldRemoteCA, remoteCA) { + if err := s.localClient.UpsertCertAuthority(remoteCA); err != nil { + return trace.Wrap(err) + } + } + + // always update our local reference to the cert authority + if err := s.compareAndSwapCertAuthority(remoteCA); err != nil { + return trace.Wrap(err) + } + } } } - - // always update our local reference to the cert authority - return s.compareAndSwapCertAuthority(remoteCA) } -func (s *remoteSite) periodicUpdateCertAuthorities() { - s.Debugf("Updating remote CAs with period %v.", s.srv.PollingPeriod) - periodic := interval.New(interval.Config{ - Duration: s.srv.PollingPeriod, - FirstDuration: utils.HalfJitter(s.srv.PollingPeriod), - Jitter: utils.NewSeventhJitter(), - }) - defer periodic.Stop() +func (s *remoteSite) updateLocks(retry utils.Retry) { + s.Debugf("Watching for remote lock changes.") + for { + startedWaiting := s.clock.Now() select { + case t := <-retry.After(): + s.Debugf("Initiating new lock watch after waiting %v.", t.Sub(startedWaiting)) + retry.Inc() case <-s.ctx.Done(): - s.Debugf("Context is closing.") return - case <-periodic.Next(): - err := 
s.updateCertAuthorities() - if err != nil { - switch { - case trace.IsNotFound(err): - s.Debugf("Remote cluster %v does not support cert authorities rotation yet.", s.domainName) - case trace.IsCompareFailed(err): - s.Infof("Remote cluster has updated certificate authorities, going to force reconnect.") - s.srv.removeSite(s.domainName) - s.Close() - return - case trace.IsConnectionProblem(err): - s.Debugf("Remote cluster %v is offline.", s.domainName) - default: - s.Warningf("Could not perform cert authorities updated: %v.", trace.DebugReport(err)) - } + } + + if err := s.watchLocks(); err != nil { + switch { + case trace.IsNotImplemented(err): + s.Debugf("Remote cluster %v does not support locks yet.", s.domainName) + case trace.IsConnectionProblem(err): + s.Debugf("Remote cluster %v is offline.", s.domainName) + default: + s.WithError(err).Warning("Could not update remote locks.") } } } } -func (s *remoteSite) periodicUpdateLocks() { - s.Debugf("Updating remote locks with period %v.", s.srv.PollingPeriod) - periodic := interval.New(interval.Config{ - Duration: s.srv.PollingPeriod, - FirstDuration: utils.HalfJitter(s.srv.PollingPeriod), - Jitter: utils.NewSeventhJitter(), - }) - defer periodic.Stop() +func (s *remoteSite) watchLocks() error { + watcher, err := s.srv.LockWatcher.Subscribe(s.ctx) + if err != nil { + s.WithError(err).Error("Failed to subscribe to LockWatcher") + return err + } + defer func() { + if err := watcher.Close(); err != nil { + s.WithError(err).Warn("Failed to close lock watcher subscription.") + } + }() + for { select { + case <-watcher.Done(): + s.WithError(watcher.Error()).Warn("Lock watcher subscription has closed") + return trace.Wrap(watcher.Error()) case <-s.ctx.Done(): - s.Debugf("Context is closing.") - return - case <-periodic.Next(): - locks := s.srv.LockWatcher.GetCurrent() - if err := s.remoteClient.ReplaceRemoteLocks(s.ctx, s.srv.ClusterName, locks); err != nil { - switch { - case trace.IsNotImplemented(err): - s.Debugf("Remote 
cluster %v does not support locks yet.", s.domainName) - case trace.IsConnectionProblem(err): - s.Debugf("Remote cluster %v is offline.", s.domainName) - default: - s.WithError(err).Warning("Could not update remote locks.") + s.WithError(s.ctx.Err()).Debug("Context is closing.") + return trace.Wrap(s.ctx.Err()) + case evt := <-watcher.Events(): + switch evt.Type { + case types.OpPut, types.OpDelete: + locks := s.srv.LockWatcher.GetCurrent() + if err := s.remoteClient.ReplaceRemoteLocks(s.ctx, s.srv.ClusterName, locks); err != nil { + return trace.Wrap(err) } } } @@ -632,7 +675,7 @@ func (s *remoteSite) dialWithAgent(params DialParams) (net.Conn, error) { MACAlgorithms: s.srv.Config.MACAlgorithms, DataDir: s.srv.Config.DataDir, Address: params.Address, - UseTunnel: UseTunnel(targetConn), + UseTunnel: UseTunnel(s.Logger, targetConn), FIPS: s.srv.FIPS, HostUUID: s.srv.ID, Emitter: s.srv.Config.Emitter, @@ -657,7 +700,7 @@ func (s *remoteSite) dialWithAgent(params DialParams) (net.Conn, error) { // UseTunnel makes a channel request asking for the type of connection. If // the other side does not respond (older cluster) or takes to long to // respond, be on the safe side and assume it's not a tunnel connection. 
-func UseTunnel(c *sshutils.ChConn) bool { +func UseTunnel(logger *log.Logger, c *sshutils.ChConn) bool { responseCh := make(chan bool, 1) go func() { @@ -673,8 +716,7 @@ func UseTunnel(c *sshutils.ChConn) bool { case response := <-responseCh: return response case <-time.After(1 * time.Second): - // TODO: remove logrus import - logrus.Debugf("Timed out waiting for response: returning false.") + logger.Debugf("Timed out waiting for response: returning false.") return false } } diff --git a/lib/reversetunnel/srv.go b/lib/reversetunnel/srv.go index bc4e17128aa9e..e09f500702c33 100644 --- a/lib/reversetunnel/srv.go +++ b/lib/reversetunnel/srv.go @@ -1070,8 +1070,31 @@ func newRemoteSite(srv *server, domainName string, sconn ssh.Conn) (*remoteSite, } remoteSite.certificateCache = certificateCache - go remoteSite.periodicUpdateCertAuthorities() - go remoteSite.periodicUpdateLocks() + caRetry, err := utils.NewLinear(utils.LinearConfig{ + First: utils.HalfJitter(srv.Config.PollingPeriod), + Step: srv.Config.PollingPeriod / 5, + Max: srv.Config.PollingPeriod, + Jitter: utils.NewHalfJitter(), + Clock: srv.Clock, + }) + if err != nil { + return nil, err + } + + go remoteSite.updateCertAuthorities(caRetry) + + lockRetry, err := utils.NewLinear(utils.LinearConfig{ + First: utils.HalfJitter(srv.Config.PollingPeriod), + Step: srv.Config.PollingPeriod / 5, + Max: srv.Config.PollingPeriod, + Jitter: utils.NewHalfJitter(), + Clock: srv.Clock, + }) + if err != nil { + return nil, err + } + + go remoteSite.updateLocks(lockRetry) return remoteSite, nil } diff --git a/lib/services/authority.go b/lib/services/authority.go index d575a21ada5e1..684895da2a770 100644 --- a/lib/services/authority.go +++ b/lib/services/authority.go @@ -258,7 +258,7 @@ type UserCertParams struct { ClientIP string } -// Check checks the user certificate parameters +// CheckAndSetDefaults checks the user certificate parameters func (c *UserCertParams) CheckAndSetDefaults() error { if c.CASigner == nil || 
c.CASigningAlg == "" { return trace.BadParameter("CASigner and CASigningAlg are required") @@ -387,7 +387,7 @@ func MarshalCertAuthority(certAuthority types.CertAuthority, opts ...MarshalOpti } } -// CertAuthorityNeedsMigrations returns true if the given CertAuthority needs to be migrated +// CertAuthorityNeedsMigration returns true if the given CertAuthority needs to be migrated func CertAuthorityNeedsMigration(cai types.CertAuthority) (bool, error) { ca, ok := cai.(*types.CertAuthorityV2) if !ok { diff --git a/lib/services/trust.go b/lib/services/trust.go index 3302865b9a0c7..84d4c5d4aeb5b 100644 --- a/lib/services/trust.go +++ b/lib/services/trust.go @@ -18,6 +18,15 @@ package services import "github.com/gravitational/teleport/api/types" +// AuthorityGetter defines interface for fetching cert authority resources. +type AuthorityGetter interface { + // GetCertAuthority returns cert authority by id + GetCertAuthority(id types.CertAuthID, loadKeys bool, opts ...MarshalOption) (types.CertAuthority, error) + + // GetCertAuthorities returns a list of cert authorities + GetCertAuthorities(caType types.CertAuthType, loadKeys bool, opts ...MarshalOption) ([]types.CertAuthority, error) +} + // Trust is responsible for managing certificate authorities // Each authority is managing some domain, e.g. example.com // @@ -28,6 +37,9 @@ import "github.com/gravitational/teleport/api/types" // Remote authorities have only public keys available, so they can // be only used to validate type Trust interface { + // AuthorityGetter retrieves certificate authorities + AuthorityGetter + // CreateCertAuthority inserts a new certificate authority CreateCertAuthority(ca types.CertAuthority) error @@ -45,14 +57,6 @@ type Trust interface { // DeleteAllCertAuthorities deletes cert authorities of a certain type DeleteAllCertAuthorities(caType types.CertAuthType) error - // GetCertAuthority returns certificate authority by given id. 
Parameter loadSigningKeys - // controls if signing keys are loaded - GetCertAuthority(id types.CertAuthID, loadSigningKeys bool, opts ...MarshalOption) (types.CertAuthority, error) - - // GetCertAuthorities returns a list of authorities of a given type - // loadSigningKeys controls whether signing keys should be loaded or not - GetCertAuthorities(caType types.CertAuthType, loadSigningKeys bool, opts ...MarshalOption) ([]types.CertAuthority, error) - // ActivateCertAuthority moves a CertAuthority from the deactivated list to // the normal list. ActivateCertAuthority(id types.CertAuthID) error diff --git a/lib/services/watcher.go b/lib/services/watcher.go index 04371c9df280c..c1d932323248f 100644 --- a/lib/services/watcher.go +++ b/lib/services/watcher.go @@ -883,3 +883,169 @@ func appsToSlice(apps map[string]types.Application) (slice []types.Application) } return slice } + +// CertAuthorityWatcherConfig is a CertAuthorityWatcher configuration. +type CertAuthorityWatcherConfig struct { + // ResourceWatcherConfig is the resource watcher configuration. + ResourceWatcherConfig + // AuthorityGetter is responsible for fetching cert authority resources. + AuthorityGetter + // CertAuthorityC receives up-to-date list of all cert authority resources. + CertAuthorityC chan []types.CertAuthority + // WatchHostCA indicates that the watcher should monitor types.HostCA + WatchHostCA bool + // WatchUserCA indicates that the watcher should monitor types.UserCA + WatchUserCA bool +} + +// CheckAndSetDefaults checks parameters and sets default values. 
+func (cfg *CertAuthorityWatcherConfig) CheckAndSetDefaults() error { + if err := cfg.ResourceWatcherConfig.CheckAndSetDefaults(); err != nil { + return trace.Wrap(err) + } + if cfg.AuthorityGetter == nil { + getter, ok := cfg.Client.(AuthorityGetter) + if !ok { + return trace.BadParameter("missing parameter AuthorityGetter and Client not usable as AuthorityGetter") + } + cfg.AuthorityGetter = getter + } + if cfg.CertAuthorityC == nil { + cfg.CertAuthorityC = make(chan []types.CertAuthority) + } + return nil +} + +// NewCertAuthorityWatcher returns a new instance of CertAuthorityWatcher. +func NewCertAuthorityWatcher(ctx context.Context, cfg CertAuthorityWatcherConfig) (*CertAuthorityWatcher, error) { + if err := cfg.CheckAndSetDefaults(); err != nil { + return nil, trace.Wrap(err) + } + + collector := &caCollector{ + CertAuthorityWatcherConfig: cfg, + } + + watcher, err := newResourceWatcher(ctx, collector, cfg.ResourceWatcherConfig) + if err != nil { + return nil, trace.Wrap(err) + } + + return &CertAuthorityWatcher{watcher, collector}, nil +} + +// CertAuthorityWatcher is built on top of resourceWatcher to monitor cert authority resources. +type CertAuthorityWatcher struct { + *resourceWatcher + *caCollector +} + +// caCollector accompanies resourceWatcher when monitoring cert authority resources. +type caCollector struct { + CertAuthorityWatcherConfig + host map[string]types.CertAuthority + user map[string]types.CertAuthority + lock sync.RWMutex +} + +// resourceKind specifies the resource kind to watch. +func (c *caCollector) resourceKind() string { + return types.KindCertAuthority +} + +// getResourcesAndUpdateCurrent refreshes the list of current resources. 
+func (c *caCollector) getResourcesAndUpdateCurrent(ctx context.Context) error { + var ( + newHost map[string]types.CertAuthority + newUser map[string]types.CertAuthority + ) + + if c.WatchHostCA { + host, err := c.AuthorityGetter.GetCertAuthorities(types.HostCA, false) + if err != nil { + return trace.Wrap(err) + } + newHost = make(map[string]types.CertAuthority, len(host)) + for _, ca := range host { + newHost[ca.GetName()] = ca + } + } + + if c.WatchUserCA { + user, err := c.AuthorityGetter.GetCertAuthorities(types.UserCA, false) + if err != nil { + return trace.Wrap(err) + } + newUser = make(map[string]types.CertAuthority, len(user)) + for _, ca := range user { + newUser[ca.GetName()] = ca + } + } + + c.lock.Lock() + c.host = newHost + c.user = newUser + c.lock.Unlock() + + c.CertAuthorityC <- casToSlice(newHost, newUser) + return nil +} + +// processEventAndUpdateCurrent is called when a watcher event is received. +func (c *caCollector) processEventAndUpdateCurrent(ctx context.Context, event types.Event) { + if event.Resource == nil || event.Resource.GetKind() != types.KindCertAuthority { + c.Log.Warnf("Unexpected event: %v.", event) + return + } + c.lock.Lock() + defer c.lock.Unlock() + switch event.Type { + case types.OpDelete: + if c.WatchHostCA && event.Resource.GetSubKind() == string(types.HostCA) { + delete(c.host, event.Resource.GetName()) + } + if c.WatchUserCA && event.Resource.GetSubKind() == string(types.UserCA) { + delete(c.user, event.Resource.GetName()) + } + + c.CertAuthorityC <- casToSlice(c.host, c.user) + case types.OpPut: + ca, ok := event.Resource.(types.CertAuthority) + if !ok { + c.Log.Warnf("Unexpected resource type %T.", event.Resource) + return + } + + if c.WatchHostCA && ca.GetType() == types.HostCA { + c.host[ca.GetName()] = ca + } + if c.WatchUserCA && ca.GetType() == types.UserCA { + c.user[ca.GetName()] = ca + } + + c.CertAuthorityC <- casToSlice(c.host, c.user) + default: + c.Log.Warnf("Unsupported event type %s.", event.Type) + 
return + } +} + +// GetCurrent returns the currently stored authorities. +func (c *caCollector) GetCurrent() []types.CertAuthority { + c.lock.RLock() + defer c.lock.RUnlock() + return casToSlice(c.host, c.user) +} + +func (c *caCollector) notifyStale() {} + +func casToSlice(host map[string]types.CertAuthority, user map[string]types.CertAuthority) []types.CertAuthority { + slice := make([]types.CertAuthority, 0, len(host)+len(user)) + for _, ca := range host { + slice = append(slice, ca) + } + for _, ca := range user { + slice = append(slice, ca) + } + return slice +} diff --git a/lib/services/watcher_test.go b/lib/services/watcher_test.go index 99f676e683c51..4d7b3cfbcacc9 100644 --- a/lib/services/watcher_test.go +++ b/lib/services/watcher_test.go @@ -18,6 +18,7 @@ package services_test import ( "context" + "crypto/x509/pkix" "errors" "sync" "testing" @@ -25,16 +26,17 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" - "github.com/gravitational/trace" - "github.com/jonboulle/clockwork" - "github.com/stretchr/testify/require" - "github.com/gravitational/teleport/api/constants" "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/auth/testauthority" "github.com/gravitational/teleport/lib/backend/lite" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/services/local" + "github.com/gravitational/teleport/lib/tlsca" + "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" + "github.com/stretchr/testify/require" ) var _ types.Events = (*errorWatcher)(nil) @@ -512,6 +514,16 @@ func resourceDiff(res1, res2 types.Resource) string { cmpopts.EquateEmpty()) } +func caDiff(ca1, ca2 types.CertAuthority) string { + return cmp.Diff(ca1, ca2, + cmpopts.IgnoreFields(types.Metadata{}, "ID"), + cmpopts.IgnoreFields(types.CertAuthoritySpecV2{}, "CheckingKeys", "TLSKeyPairs"), + cmpopts.IgnoreFields(types.SSHKeyPair{}, 
"PrivateKey"), + cmpopts.IgnoreFields(types.TLSKeyPair{}, "Key"), + cmpopts.EquateEmpty(), + ) +} + // TestDatabaseWatcher tests that database resource watcher properly receives // and dispatches updates to database resources. func TestDatabaseWatcher(t *testing.T) { @@ -704,3 +716,140 @@ func newApp(t *testing.T, name string) types.Application { require.NoError(t, err) return app } + +func TestCertAuthorityWatcher(t *testing.T) { + t.Parallel() + ctx := context.Background() + + bk, err := lite.NewWithConfig(ctx, lite.Config{ + Path: t.TempDir(), + PollStreamPeriod: 200 * time.Millisecond, + }) + require.NoError(t, err) + + type client struct { + services.Trust + types.Events + } + + caService := local.NewCAService(bk) + w, err := services.NewCertAuthorityWatcher(ctx, services.CertAuthorityWatcherConfig{ + ResourceWatcherConfig: services.ResourceWatcherConfig{ + Component: "test", + MaxRetryPeriod: 200 * time.Millisecond, + Client: &client{ + Trust: caService, + Events: local.NewEventsService(bk), + }, + }, + CertAuthorityC: make(chan []types.CertAuthority, 10), + WatchUserCA: true, + WatchHostCA: true, + }) + require.NoError(t, err) + t.Cleanup(w.Close) + + nothingWatcher, err := services.NewCertAuthorityWatcher(ctx, services.CertAuthorityWatcherConfig{ + ResourceWatcherConfig: services.ResourceWatcherConfig{ + Component: "test", + MaxRetryPeriod: 200 * time.Millisecond, + Client: &client{ + Trust: caService, + Events: local.NewEventsService(bk), + }, + }, + CertAuthorityC: make(chan []types.CertAuthority, 10), + }) + require.NoError(t, err) + t.Cleanup(nothingWatcher.Close) + + require.Empty(t, w.GetCurrent()) + require.Empty(t, nothingWatcher.GetCurrent()) + + // Initially there are no cas so watcher should send an empty list. 
+ select { + case changeset := <-w.CertAuthorityC: + require.Len(t, changeset, 0) + require.Empty(t, nothingWatcher.GetCurrent()) + case <-w.Done(): + t.Fatal("Watcher has unexpectedly exited.") + case <-time.After(2 * time.Second): + t.Fatal("Timeout waiting for the first event.") + } + + // Add an authority. + ca1 := newCertAuthority(t, "ca1", types.HostCA) + require.NoError(t, caService.CreateCertAuthority(ca1)) + + // Watcher should detect the newly created ca. + select { + case changeset := <-w.CertAuthorityC: + require.Len(t, changeset, 1) + require.Empty(t, caDiff(changeset[0], ca1)) + require.Empty(t, nothingWatcher.GetCurrent()) + case <-w.Done(): + t.Fatal("Watcher has unexpectedly exited.") + case <-time.After(2 * time.Second): + t.Fatal("Timeout waiting for the first event.") + } + + // Add a second ca. + ca2 := newCertAuthority(t, "ca2", types.UserCA) + require.NoError(t, caService.CreateCertAuthority(ca2)) + + // Watcher should detect the ca list change. + select { + case changeset := <-w.CertAuthorityC: + require.Len(t, changeset, 2) + require.Empty(t, nothingWatcher.GetCurrent()) + case <-w.Done(): + t.Fatal("Watcher has unexpectedly exited.") + case <-time.After(2 * time.Second): + t.Fatal("Timeout waiting for the update event.") + } + + // Delete the first ca. + require.NoError(t, caService.DeleteCertAuthority(ca1.GetID())) + + // Watcher should detect the ca list change. + select { + case changeset := <-w.CertAuthorityC: + require.Len(t, changeset, 1) + require.Empty(t, caDiff(changeset[0], ca2)) + require.Empty(t, nothingWatcher.GetCurrent()) + case <-w.Done(): + t.Fatal("Watcher has unexpectedly exited.") + case <-time.After(2 * time.Second): + t.Fatal("Timeout waiting for the update event.") + } +} + +func newCertAuthority(t *testing.T, name string, caType types.CertAuthType) types.CertAuthority { + ta := testauthority.New() + priv, pub, err := ta.GenerateKeyPair("") + require.NoError(t, err) + + // CA for the given cluster name with 1 key pair. 
+ key, cert, err := tlsca.GenerateSelfSignedCA(pkix.Name{CommonName: name}, nil, time.Minute) + require.NoError(t, err) + + ca, err := types.NewCertAuthority(types.CertAuthoritySpecV2{ + Type: caType, + ClusterName: name, + ActiveKeys: types.CAKeySet{ + SSH: []*types.SSHKeyPair{{ + PrivateKey: priv, + PrivateKeyType: types.PrivateKeyType_RAW, + PublicKey: pub, + }}, + TLS: []*types.TLSKeyPair{{ + Cert: cert, + Key: key, + }}, + }, + Roles: nil, + SigningAlg: types.CertAuthoritySpecV2_RSA_SHA2_256, + }) + require.NoError(t, err) + return ca +}