diff --git a/integration/integration_test.go b/integration/integration_test.go index 0cfb5910117a1..742b03b8fa752 100644 --- a/integration/integration_test.go +++ b/integration/integration_test.go @@ -3658,44 +3658,27 @@ func testRotateTrustedClusters(t *testing.T, suite *integrationTestSuite) { require.NoError(t, err) // waitForPhase waits until aux cluster detects the rotation - waitForPhase := func(phase string) error { - ctx, cancel := context.WithTimeout(context.Background(), tconf.PollingPeriod*10) - defer cancel() - - watcher, err := services.NewCertAuthorityWatcher(ctx, services.CertAuthorityWatcherConfig{ - ResourceWatcherConfig: services.ResourceWatcherConfig{ - Component: teleport.ComponentProxy, - Clock: tconf.Clock, - Client: aux.GetSiteAPI(clusterAux), - }, - WatchHostCA: true, - }) - if err != nil { - return err - } - defer watcher.Close() + waitForPhase := func(phase string) { + require.Eventually(t, func() bool { + ca, err := aux.Process.GetAuthServer().GetCertAuthority( + ctx, + types.CertAuthID{ + Type: types.HostCA, + DomainName: clusterMain, + }, false) + if err != nil { + return false + } - var lastPhase string - for i := 0; i < 10; i++ { - select { - case <-ctx.Done(): - return trace.CompareFailed("failed to converge to phase %q, last phase %q", phase, lastPhase) - case cas := <-watcher.CertAuthorityC: - for _, ca := range cas { - if ca.GetClusterName() == clusterMain && - ca.GetType() == types.HostCA && - ca.GetRotation().Phase == phase { - return nil - } - lastPhase = ca.GetRotation().Phase - } + if ca.GetRotation().Phase == phase { + return true } - } - return trace.CompareFailed("failed to converge to phase %q, last phase %q", phase, lastPhase) + + return false + }, 30*time.Second, 250*time.Millisecond, "failed to converge to phase %q", phase) } - err = waitForPhase(types.RotationPhaseInit) - require.NoError(t, err) + waitForPhase(types.RotationPhaseInit) // update clients err = svc.GetAuthServer().RotateCertAuthority(ctx, auth.RotateRequest{ @@ -3708,8 +3691,7 @@ func testRotateTrustedClusters(t *testing.T, suite *integrationTestSuite) { svc, err = suite.waitForReload(serviceC, svc) require.NoError(t, err) - err = waitForPhase(types.RotationPhaseUpdateClients) - require.NoError(t, err) + waitForPhase(types.RotationPhaseUpdateClients) // old client should work as is err = runAndMatch(clt, 8, []string{"echo", "hello world"}, ".*hello world.*") @@ -3728,8 +3710,7 @@ func testRotateTrustedClusters(t *testing.T, suite *integrationTestSuite) { svc, err = suite.waitForReload(serviceC, svc) require.NoError(t, err) - err = waitForPhase(types.RotationPhaseUpdateServers) - require.NoError(t, err) + waitForPhase(types.RotationPhaseUpdateServers) // new credentials will work from this phase to others newCreds, err := GenerateUserCreds(UserCredsRequest{Process: svc, Username: suite.me.Username}) @@ -3757,8 +3738,7 @@ func testRotateTrustedClusters(t *testing.T, suite *integrationTestSuite) { require.NoError(t, err) t.Log("Service reload completed, waiting for phase.") - err = waitForPhase(types.RotationPhaseStandby) - require.NoError(t, err) + waitForPhase(types.RotationPhaseStandby) t.Log("Phase completed.") // new client still works diff --git a/lib/reversetunnel/remotesite.go b/lib/reversetunnel/remotesite.go index 65dbe23e85c9f..20366c59f55de 100644 --- a/lib/reversetunnel/remotesite.go +++ b/lib/reversetunnel/remotesite.go @@ -427,20 +427,12 @@ func (s *remoteSite) compareAndSwapCertAuthority(ca types.CertAuthority) error { return trace.CompareFailed("remote certificate authority rotation has been updated") } -func (s *remoteSite) updateCertAuthorities(retry utils.Retry) { - s.Debugf("Watching for cert authority changes.") +func (s *remoteSite) updateCertAuthorities(retry utils.Retry, remoteWatcher *services.CertAuthorityWatcher, remoteVersion string) { + defer remoteWatcher.Close() + cas := make(map[types.CertAuthType]types.CertAuthority) for { - startedWaiting := s.clock.Now() - select { - case t := <-retry.After(): - s.Debugf("Initiating new cert authority watch after waiting %v.", t.Sub(startedWaiting)) - retry.Inc() - case <-s.ctx.Done(): - return - } - - err := s.watchCertAuthorities() + err := s.watchCertAuthorities(remoteWatcher, remoteVersion, cas) if err != nil { switch { case trace.IsNotFound(err): @@ -456,67 +448,92 @@ func (s *remoteSite) updateCertAuthorities(retry utils.Retry) { } } + startedWaiting := s.clock.Now() + select { + case t := <-retry.After(): + s.Debugf("Initiating new cert authority watch after waiting %v.", t.Sub(startedWaiting)) + retry.Inc() + case <-s.ctx.Done(): + return + } } } -func (s *remoteSite) watchCertAuthorities() error { - localWatcher, err := services.NewCertAuthorityWatcher(s.ctx, services.CertAuthorityWatcherConfig{ - ResourceWatcherConfig: services.ResourceWatcherConfig{ - Component: teleport.ComponentProxy, - Log: s, - Clock: s.clock, - Client: s.localAccessPoint, +func (s *remoteSite) watchCertAuthorities(remoteWatcher *services.CertAuthorityWatcher, remoteVersion string, cas map[types.CertAuthType]types.CertAuthority) error { + localWatch, err := s.srv.CertAuthorityWatcher.Subscribe( + s.ctx, + services.CertAuthorityTarget{ + Type: types.HostCA, + ClusterName: s.srv.ClusterName, }, - WatchUserCA: true, - WatchHostCA: true, - }) + services.CertAuthorityTarget{ + Type: types.UserCA, + ClusterName: s.srv.ClusterName, + }) if err != nil { return trace.Wrap(err) } - defer localWatcher.Close() + defer func() { + if err := localWatch.Close(); err != nil { + s.WithError(err).Warn("Failed to close local ca watcher subscription.") + } + }() - remoteWatcher, err := services.NewCertAuthorityWatcher(s.ctx, services.CertAuthorityWatcherConfig{ - ResourceWatcherConfig: services.ResourceWatcherConfig{ - Component: teleport.ComponentProxy, - Log: s, - Clock: s.clock, - Client: s.remoteAccessPoint, + remoteWatch, err := remoteWatcher.Subscribe( + s.ctx, + services.CertAuthorityTarget{ + ClusterName: s.domainName, + Type: types.HostCA, }, - WatchHostCA: true, - }) + ) if err != nil { return trace.Wrap(err) } - defer remoteWatcher.Close() + defer func() { + if err := remoteWatch.Close(); err != nil { + s.WithError(err).Warn("Failed to close remote ca watcher subscription.") + } + }() + s.Debugf("Watching for cert authority changes.") for { select { case <-s.ctx.Done(): s.WithError(s.ctx.Err()).Debug("Context is closing.") return trace.Wrap(s.ctx.Err()) - case <-localWatcher.Done(): + case <-localWatch.Done(): s.Warn("Local CertAuthority watcher subscription has closed") return fmt.Errorf("local ca watcher for cluster %s has closed", s.srv.ClusterName) - case <-remoteWatcher.Done(): + case <-remoteWatch.Done(): s.Warn("Remote CertAuthority watcher subscription has closed") return fmt.Errorf("remote ca watcher for cluster %s has closed", s.domainName) - case cas := <-localWatcher.CertAuthorityC: - for _, localCA := range cas { - if localCA.GetClusterName() != s.srv.ClusterName || - (localCA.GetType() != types.HostCA && - localCA.GetType() != types.UserCA) { + case evt := <-localWatch.Events(): + switch evt.Type { + case types.OpPut: + localCA, ok := evt.Resource.(types.CertAuthority) + if !ok { continue } + ca, ok := cas[localCA.GetType()] + if ok && services.CertAuthoritiesEquivalent(ca, localCA) { + continue + } + + // clone to prevent a race with watcher filtering + localCA = localCA.Clone() if err := s.remoteClient.RotateExternalCertAuthority(s.ctx, localCA); err != nil { - s.WithError(err).Warn("Failed to rotate external ca") + log.WithError(err).Warn("Failed to rotate external ca") return trace.Wrap(err) } + + cas[localCA.GetType()] = localCA } - case cas := <-remoteWatcher.CertAuthorityC: - for _, remoteCA := range cas { - if remoteCA.GetType() != types.HostCA || - remoteCA.GetClusterName() != s.domainName { + case evt := <-remoteWatch.Events(): + switch evt.Type { + case types.OpPut: + remoteCA, ok := evt.Resource.(types.CertAuthority) + if !ok { continue } diff --git a/lib/reversetunnel/srv.go b/lib/reversetunnel/srv.go index af89c620caf0c..39f2329562a9b 100644 --- a/lib/reversetunnel/srv.go +++ b/lib/reversetunnel/srv.go @@ -205,6 +205,9 @@ type Config struct { // NodeWatcher is a node watcher. NodeWatcher *services.NodeWatcher + + // CertAuthorityWatcher is a cert authority watcher. + CertAuthorityWatcher *services.CertAuthorityWatcher } // CheckAndSetDefaults checks parameters and sets default values @@ -259,6 +262,9 @@ func (cfg *Config) CheckAndSetDefaults() error { if cfg.NodeWatcher == nil { return trace.BadParameter("missing parameter NodeWatcher") } + if cfg.CertAuthorityWatcher == nil { + return trace.BadParameter("missing parameter CertAuthorityWatcher") + } return nil } @@ -1037,6 +1043,11 @@ func newRemoteSite(srv *server, domainName string, sconn ssh.Conn) (*remoteSite, connInfo.SetExpiry(srv.Clock.Now().Add(srv.offlineThreshold)) closeContext, cancel := context.WithCancel(srv.ctx) + defer func() { + if err != nil { + cancel() + } + }() remoteSite := &remoteSite{ srv: srv, domainName: domainName, @@ -1060,20 +1071,17 @@ func newRemoteSite(srv *server, domainName string, sconn ssh.Conn) (*remoteSite, clt, _, err := remoteSite.getRemoteClient() if err != nil { - cancel() return nil, trace.Wrap(err) } remoteSite.remoteClient = clt remoteVersion, err := getRemoteAuthVersion(closeContext, sconn) if err != nil { - cancel() return nil, trace.Wrap(err) } accessPoint, err := createRemoteAccessPoint(srv, clt, remoteVersion, domainName) if err != nil { - cancel() return nil, trace.Wrap(err) } remoteSite.remoteAccessPoint = accessPoint @@ -1085,7 +1093,6 @@ func newRemoteSite(srv *server, domainName string, sconn ssh.Conn) (*remoteSite, }, }) if err != nil { - cancel() return nil, trace.Wrap(err) } remoteSite.nodeWatcher = nodeWatcher @@ -1095,7 +1102,6 @@ func newRemoteSite(srv *server, domainName string, sconn ssh.Conn) (*remoteSite, // is signed by the correct certificate authority. certificateCache, err := newHostCertificateCache(srv.Config.KeyGen, srv.localAuthClient) if err != nil { - cancel() return nil, trace.Wrap(err) } remoteSite.certificateCache = certificateCache @@ -1108,11 +1114,25 @@ func newRemoteSite(srv *server, domainName string, sconn ssh.Conn) (*remoteSite, Clock: srv.Clock, }) if err != nil { - cancel() return nil, trace.Wrap(err) } - go remoteSite.updateCertAuthorities(caRetry) + remoteWatcher, err := services.NewCertAuthorityWatcher(srv.ctx, services.CertAuthorityWatcherConfig{ + ResourceWatcherConfig: services.ResourceWatcherConfig{ + Component: teleport.ComponentProxy, + Log: srv.log, + Clock: srv.Clock, + Client: remoteSite.remoteAccessPoint, + }, + Types: []types.CertAuthType{types.HostCA}, + }) + if err != nil { + return nil, trace.Wrap(err) + } + + go func() { + remoteSite.updateCertAuthorities(caRetry, remoteWatcher, remoteVersion) + }() lockRetry, err := utils.NewLinear(utils.LinearConfig{ First: utils.HalfJitter(srv.Config.PollingPeriod), @@ -1122,7 +1142,6 @@ func newRemoteSite(srv *server, domainName string, sconn ssh.Conn) (*remoteSite, Clock: srv.Clock, }) if err != nil { - cancel() return nil, trace.Wrap(err) } diff --git a/lib/service/service.go b/lib/service/service.go index 4eecb1f7cf7a2..50b5a6d0605d4 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -2505,6 +2505,19 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { return trace.Wrap(err) } + caWatcher, err := services.NewCertAuthorityWatcher(process.ExitContext(), services.CertAuthorityWatcherConfig{ + ResourceWatcherConfig: services.ResourceWatcherConfig{ + Component: teleport.ComponentProxy, + Log: process.log.WithField(trace.Component, teleport.ComponentProxy), + Client: conn.Client, + }, + AuthorityGetter: accessPoint, + Types: []types.CertAuthType{types.HostCA, types.UserCA}, + }) + if err != nil { + return trace.Wrap(err) + } + // register SSH reverse tunnel server that accepts connections // from remote teleport nodes var tsrv reversetunnel.Server @@ -2528,17 +2541,18 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { Client: conn.Client, }, }, - KeyGen: cfg.Keygen, - Ciphers: cfg.Ciphers, - KEXAlgorithms: cfg.KEXAlgorithms, - MACAlgorithms: cfg.MACAlgorithms, - DataDir: process.Config.DataDir, - PollingPeriod: process.Config.PollingPeriod, - FIPS: cfg.FIPS, - Emitter: streamEmitter, - Log: process.log, - LockWatcher: lockWatcher, - NodeWatcher: nodeWatcher, + KeyGen: cfg.Keygen, + Ciphers: cfg.Ciphers, + KEXAlgorithms: cfg.KEXAlgorithms, + MACAlgorithms: cfg.MACAlgorithms, + DataDir: process.Config.DataDir, + PollingPeriod: process.Config.PollingPeriod, + FIPS: cfg.FIPS, + Emitter: streamEmitter, + Log: process.log, + LockWatcher: lockWatcher, + NodeWatcher: nodeWatcher, + CertAuthorityWatcher: caWatcher, }) if err != nil { return trace.Wrap(err) diff --git a/lib/services/watcher.go b/lib/services/watcher.go index 73e083f2f1360..9d63dbe4994e3 100644 --- a/lib/services/watcher.go +++ b/lib/services/watcher.go @@ -661,12 +661,8 @@ type CertAuthorityWatcherConfig struct { ResourceWatcherConfig // AuthorityGetter is responsible for fetching cert authority resources. AuthorityGetter - // CertAuthorityC receives up-to-date list of all cert authority resources. - CertAuthorityC chan []types.CertAuthority - // WatchHostCA indicates that the watcher should monitor types.HostCA - WatchHostCA bool - // WatchUserCA indicates that the watcher should monitor types.UserCA - WatchUserCA bool + // Types restricts which cert authority types are retrieved via the AuthorityGetter. + Types []types.CertAuthType } // CheckAndSetDefaults checks parameters and sets default values. @@ -681,12 +677,19 @@ func (cfg *CertAuthorityWatcherConfig) CheckAndSetDefaults() error { } cfg.AuthorityGetter = getter } - if cfg.CertAuthorityC == nil { - cfg.CertAuthorityC = make(chan []types.CertAuthority) - } return nil } +// IsWatched return true if the given certificate auth type is being observer by the watcher. +func (cfg *CertAuthorityWatcherConfig) IsWatched(certType types.CertAuthType) bool { + for _, observedType := range cfg.Types { + if observedType == certType { + return true + } + } + return false +} + // NewCertAuthorityWatcher returns a new instance of CertAuthorityWatcher. func NewCertAuthorityWatcher(ctx context.Context, cfg CertAuthorityWatcherConfig) (*CertAuthorityWatcher, error) { if err := cfg.CheckAndSetDefaults(); err != nil { @@ -695,6 +698,12 @@ func NewCertAuthorityWatcher(ctx context.Context, cfg CertAuthorityWatcherConfig collector := &caCollector{ CertAuthorityWatcherConfig: cfg, + fanout: NewFanout(), + cas: make(map[types.CertAuthType]map[string]types.CertAuthority, len(cfg.Types)), + } + + for _, t := range cfg.Types { + collector.cas[t] = make(map[string]types.CertAuthority) } watcher, err := newResourceWatcher(ctx, collector, cfg.ResourceWatcherConfig) @@ -702,6 +711,7 @@ func NewCertAuthorityWatcher(ctx context.Context, cfg CertAuthorityWatcherConfig return nil, trace.Wrap(err) } + collector.fanout.SetInit() return &CertAuthorityWatcher{watcher, collector}, nil } @@ -714,9 +724,66 @@ type CertAuthorityWatcher struct { // caCollector accompanies resourceWatcher when monitoring cert authority resources. type caCollector struct { CertAuthorityWatcherConfig - host map[string]types.CertAuthority - user map[string]types.CertAuthority + fanout *Fanout + + // lock protects concurrent access to cas lock sync.RWMutex + // cas maps ca type -> cluster -> ca + cas map[types.CertAuthType]map[string]types.CertAuthority +} + +// CertAuthorityTarget lists the attributes of interactions to be disabled. +type CertAuthorityTarget struct { + // ClusterName specifies the name of the cluster to watch. + ClusterName string + // Type specifies the ca types to watch for. + Type types.CertAuthType +} + +// Subscribe is used to subscribe to the lock updates. +func (c *caCollector) Subscribe(ctx context.Context, targets ...CertAuthorityTarget) (types.Watcher, error) { + watchKinds, err := caTargetToWatchKinds(targets) + if err != nil { + return nil, trace.Wrap(err) + } + sub, err := c.fanout.NewWatcher(ctx, types.Watch{Kinds: watchKinds}) + if err != nil { + return nil, trace.Wrap(err) + } + select { + case event := <-sub.Events(): + if event.Type != types.OpInit { + return nil, trace.BadParameter("expected init event, got %v instead", event.Type) + } + case <-sub.Done(): + return nil, trace.Wrap(sub.Error()) + } + return sub, nil +} + +func caTargetToWatchKinds(targets []CertAuthorityTarget) ([]types.WatchKind, error) { + watchKinds := make([]types.WatchKind, 0, len(targets)) + for _, target := range targets { + kind := types.WatchKind{ + Kind: types.KindCertAuthority, + // Note that watching SubKind doesn't work for types.WatchKind - to do so it would + // require a custom filter, which was recently added but - we can't use yet due to + // older clients not supporting the filter. + SubKind: string(target.Type), + } + + if target.ClusterName != "" { + kind.Name = target.ClusterName + } + + watchKinds = append(watchKinds, kind) + } + + if len(watchKinds) == 0 { + watchKinds = []types.WatchKind{{Kind: types.KindCertAuthority}} + } + + return watchKinds, nil } // resourceKind specifies the resource kind to watch. @@ -726,42 +793,27 @@ func (c *caCollector) resourceKind() string { // getResourcesAndUpdateCurrent refreshes the list of current resources. func (c *caCollector) getResourcesAndUpdateCurrent(ctx context.Context) error { - var ( - newHost map[string]types.CertAuthority - newUser map[string]types.CertAuthority - ) + var cas []types.CertAuthority - if c.WatchHostCA { - host, err := c.AuthorityGetter.GetCertAuthorities(ctx, types.HostCA, false) + for _, t := range c.Types { + authorities, err := c.AuthorityGetter.GetCertAuthorities(ctx, t, false) if err != nil { return trace.Wrap(err) } - newHost = make(map[string]types.CertAuthority, len(host)) - for _, ca := range host { - newHost[ca.GetName()] = ca - } - } - if c.WatchUserCA { - user, err := c.AuthorityGetter.GetCertAuthorities(ctx, types.UserCA, false) - if err != nil { - return trace.Wrap(err) - } - newUser = make(map[string]types.CertAuthority, len(user)) - for _, ca := range user { - newUser[ca.GetName()] = ca - } + cas = append(cas, authorities...) } c.lock.Lock() - c.host = newHost - c.user = newUser - c.lock.Unlock() + defer c.lock.Unlock() - select { - case <-ctx.Done(): - return trace.Wrap(ctx.Err()) - case c.CertAuthorityC <- casToSlice(newHost, newUser): + for _, ca := range cas { + if !c.watchingType(ca.GetType()) { + continue + } + + c.cas[ca.GetType()][ca.GetName()] = ca + c.fanout.Emit(types.Event{Type: types.OpPut, Resource: ca.Clone()}) } return nil } @@ -776,17 +828,13 @@ func (c *caCollector) processEventAndUpdateCurrent(ctx context.Context, event ty defer c.lock.Unlock() switch event.Type { case types.OpDelete: - if c.WatchHostCA && event.Resource.GetSubKind() == string(types.HostCA) { - delete(c.host, event.Resource.GetName()) - } - if c.WatchUserCA && event.Resource.GetSubKind() == string(types.UserCA) { - delete(c.user, event.Resource.GetName()) + caType := types.CertAuthType(event.Resource.GetSubKind()) + if !c.watchingType(caType) { + return } - select { - case <-ctx.Done(): - case c.CertAuthorityC <- casToSlice(c.host, c.user): - } + delete(c.cas[caType], event.Resource.GetName()) + c.fanout.Emit(event) case types.OpPut: ca, ok := event.Resource.(types.CertAuthority) if !ok { @@ -794,28 +842,31 @@ func (c *caCollector) processEventAndUpdateCurrent(ctx context.Context, event ty return } - if c.WatchHostCA && ca.GetType() == types.HostCA { - c.host[ca.GetName()] = ca - } - if c.WatchUserCA && ca.GetType() == types.UserCA { - c.user[ca.GetName()] = ca + if !c.watchingType(ca.GetType()) { + return } - select { - case <-ctx.Done(): - case c.CertAuthorityC <- casToSlice(c.host, c.user): + authority, ok := c.cas[ca.GetType()][ca.GetName()] + if ok && CertAuthoritiesEquivalent(authority, ca) { + return } + + c.cas[ca.GetType()][ca.GetName()] = ca + c.fanout.Emit(event) default: c.Log.Warnf("Unsupported event type %s.", event.Type) return } } -// GetCurrent returns the currently stored authorities. -func (c *caCollector) GetCurrent() []types.CertAuthority { - c.lock.RLock() - defer c.lock.RUnlock() - return casToSlice(c.host, c.user) +func (c *caCollector) watchingType(t types.CertAuthType) bool { + for _, caType := range c.Types { + if caType == t { + return true + } + } + + return false } func (c *caCollector) notifyStale() {} @@ -874,17 +925,6 @@ type nodeCollector struct { rw sync.RWMutex } -func casToSlice(host map[string]types.CertAuthority, user map[string]types.CertAuthority) []types.CertAuthority { - slice := make([]types.CertAuthority, 0, len(host)+len(user)) - for _, ca := range host { - slice = append(slice, ca) - } - for _, ca := range user { - slice = append(slice, ca) - } - return slice -} - // Node is a readonly subset of the types.Server interface which // users may filter by in GetNodes. type Node interface { diff --git a/lib/services/watcher_test.go b/lib/services/watcher_test.go index 3bfecf88c67e8..ecfca276e342f 100644 --- a/lib/services/watcher_test.go +++ b/lib/services/watcher_test.go @@ -37,6 +37,7 @@ import ( "github.com/gravitational/teleport/lib/auth/testauthority" "github.com/gravitational/teleport/lib/backend/lite" "github.com/gravitational/teleport/lib/defaults" + "github.com/gravitational/teleport/lib/fixtures" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/services/local" "github.com/gravitational/teleport/lib/tlsca" @@ -517,9 +518,10 @@ func resourceDiff(res1, res2 types.Resource) string { func caDiff(ca1, ca2 types.CertAuthority) string { return cmp.Diff(ca1, ca2, cmpopts.IgnoreFields(types.Metadata{}, "ID"), - cmpopts.IgnoreFields(types.CertAuthoritySpecV2{}, "CheckingKeys", "TLSKeyPairs"), + cmpopts.IgnoreFields(types.CertAuthoritySpecV2{}, "CheckingKeys", "TLSKeyPairs", "JWTKeyPairs"), cmpopts.IgnoreFields(types.SSHKeyPair{}, "PrivateKey"), cmpopts.IgnoreFields(types.TLSKeyPair{}, "Key"), + cmpopts.IgnoreFields(types.JWTKeyPair{}, "PrivateKey"), cmpopts.EquateEmpty(), ) } @@ -527,10 +529,12 @@ func caDiff(ca1, ca2 types.CertAuthority) string { func TestCertAuthorityWatcher(t *testing.T) { t.Parallel() ctx := context.Background() + clock := clockwork.NewFakeClock() bk, err := lite.NewWithConfig(ctx, lite.Config{ Path: t.TempDir(), PollStreamPeriod: 200 * time.Millisecond, + Clock: clock, }) require.NoError(t, err) @@ -548,86 +552,88 @@ func TestCertAuthorityWatcher(t *testing.T) { Trust: caService, Events: local.NewEventsService(bk, nil), }, + Clock: clock, }, - CertAuthorityC: make(chan []types.CertAuthority, 10), - WatchUserCA: true, - WatchHostCA: true, + Types: []types.CertAuthType{types.HostCA, types.UserCA}, }) require.NoError(t, err) t.Cleanup(w.Close) - nothingWatcher, err := services.NewCertAuthorityWatcher(ctx, services.CertAuthorityWatcherConfig{ - ResourceWatcherConfig: services.ResourceWatcherConfig{ - Component: "test", - MaxRetryPeriod: 200 * time.Millisecond, - Client: &client{ - Trust: caService, - Events: local.NewEventsService(bk, nil), - }, - }, - CertAuthorityC: make(chan []types.CertAuthority, 10), - }) + target := services.CertAuthorityTarget{ClusterName: "test"} + sub, err := w.Subscribe(ctx, target) require.NoError(t, err) - t.Cleanup(nothingWatcher.Close) - - require.Empty(t, w.GetCurrent()) - require.Empty(t, nothingWatcher.GetCurrent()) + t.Cleanup(func() { require.NoError(t, sub.Close()) }) - // Initially there are no cas so watcher should send an empty list. + // create a CA for the cluster and a type we are filtering for + // and ensure we receive the event + ca := newCertAuthority(t, "test", types.HostCA) + require.NoError(t, caService.UpsertCertAuthority(ca)) select { - case changeset := <-w.CertAuthorityC: - require.Len(t, changeset, 0) - require.Empty(t, nothingWatcher.GetCurrent()) - case <-w.Done(): - t.Fatal("Watcher has unexpectedly exited.") - case <-time.After(2 * time.Second): - t.Fatal("Timeout waiting for the first event.") + case event := <-sub.Events(): + caFromEvent, ok := event.Resource.(types.CertAuthority) + require.True(t, ok) + require.Empty(t, caDiff(ca, caFromEvent)) + case <-time.After(time.Second): + t.Fatal("timed out waiting for event") } - // Add an authority. - ca1 := newCertAuthority(t, "ca1", types.HostCA) - require.NoError(t, caService.CreateCertAuthority(ca1)) - - // The first event is always the current list of apps. + // create a CA with a type we are filtering for another cluster that we are NOT filtering for + // and ensure that we DO NOT receive the event + require.NoError(t, caService.UpsertCertAuthority(newCertAuthority(t, "unknown", types.UserCA))) select { - case changeset := <-w.CertAuthorityC: - require.Len(t, changeset, 1) - require.Empty(t, caDiff(changeset[0], ca1)) - require.Empty(t, nothingWatcher.GetCurrent()) - case <-w.Done(): - t.Fatal("Watcher has unexpectedly exited.") - case <-time.After(2 * time.Second): - t.Fatal("Timeout waiting for the first event.") + case event := <-sub.Events(): + t.Fatalf("Unexpected event: %v.", event) + case <-sub.Done(): + t.Fatal("CA watcher subscription has unexpectedly exited.") + case <-time.After(time.Second): } - // Add a second ca. - ca2 := newCertAuthority(t, "ca2", types.UserCA) - require.NoError(t, caService.CreateCertAuthority(ca2)) + // create a CA for the cluster and a type we are filtering for + // and ensure we receive the event + ca2 := newCertAuthority(t, "test", types.UserCA) + require.NoError(t, caService.UpsertCertAuthority(ca2)) + select { + case event := <-sub.Events(): + caFromEvent, ok := event.Resource.(types.CertAuthority) + require.True(t, ok) + require.Empty(t, caDiff(ca2, caFromEvent)) + case <-time.After(time.Second): + t.Fatal("timed out waiting for event") + } - // Watcher should detect the ca list change. + // delete a CA with type being watched in the cluster we are filtering for + // and ensure we receive the event + require.NoError(t, caService.DeleteCertAuthority(ca.GetID())) select { - case changeset := <-w.CertAuthorityC: - require.Len(t, changeset, 2) - require.Empty(t, nothingWatcher.GetCurrent()) - case <-w.Done(): - t.Fatal("Watcher has unexpectedly exited.") - case <-time.After(2 * time.Second): - t.Fatal("Timeout waiting for the update event.") + case event := <-sub.Events(): + require.Equal(t, types.KindCertAuthority, event.Resource.GetKind()) + require.Equal(t, string(types.HostCA), event.Resource.GetSubKind()) + require.Equal(t, "test", event.Resource.GetName()) + case <-time.After(time.Second): + t.Fatal("timed out waiting for event") } - // Delete the first ca. - require.NoError(t, caService.DeleteCertAuthority(ca1.GetID())) + // create a CA with a type we are NOT filtering for but for a cluster we are filtering for + // and ensure we DO NOT receive the event + signer := newCertAuthority(t, "test", types.JWTSigner) + require.NoError(t, caService.UpsertCertAuthority(signer)) + select { + case event := <-sub.Events(): + t.Fatalf("Unexpected event: %v.", event) + case <-sub.Done(): + t.Fatal("CA watcher subscription has unexpectedly exited.") + case <-time.After(time.Second): + } - // Watcher should detect the ca list change. + // delete a CA with a name we are filtering for but a type we are NOT filtering for + // and ensure we do NOT receive the event + require.NoError(t, caService.DeleteCertAuthority(signer.GetID())) select { - case changeset := <-w.CertAuthorityC: - require.Len(t, changeset, 1) - require.Empty(t, caDiff(changeset[0], ca2)) - require.Empty(t, nothingWatcher.GetCurrent()) - case <-w.Done(): - t.Fatal("Watcher has unexpectedly exited.") - case <-time.After(2 * time.Second): - t.Fatal("Timeout waiting for the update event.") + case event := <-sub.Events(): + t.Fatalf("Unexpected event: %v.", event) + case <-sub.Done(): + t.Fatal("CA watcher subscription has unexpectedly exited.") + case <-time.After(time.Second): } } @@ -644,15 +650,25 @@ func newCertAuthority(t *testing.T, name string, caType types.CertAuthType) type Type: caType, ClusterName: name, ActiveKeys: types.CAKeySet{ - SSH: []*types.SSHKeyPair{{ - PrivateKey: priv, - PrivateKeyType: types.PrivateKeyType_RAW, - PublicKey: pub, - }}, - TLS: []*types.TLSKeyPair{{ - Cert: cert, - Key: key, - }}, + SSH: []*types.SSHKeyPair{ + { + PrivateKey: priv, + PrivateKeyType: types.PrivateKeyType_RAW, + PublicKey: pub, + }, + }, + TLS: []*types.TLSKeyPair{ + { + Cert: cert, + Key: key, + }, + }, + JWT: []*types.JWTKeyPair{ + { + PublicKey: []byte(fixtures.JWTSignerPublicKey), + PrivateKey: []byte(fixtures.JWTSignerPrivateKey), + }, + }, }, Roles: nil, SigningAlg: types.CertAuthoritySpecV2_RSA_SHA2_256, diff --git a/lib/srv/regular/sshserver_test.go b/lib/srv/regular/sshserver_test.go index c96dea7a41753..ff902b78b2541 100644 --- a/lib/srv/regular/sshserver_test.go +++ b/lib/srv/regular/sshserver_test.go @@ -947,7 +947,7 @@ func TestProxyReverseTunnel(t *testing.T) { listener, reverseTunnelAddress := mustListen(t) defer listener.Close() lockWatcher := newLockWatcher(ctx, t, proxyClient) - + caWatcher := newCertAuthorityWatcher(ctx, t, proxyClient) nodeWatcher := newNodeWatcher(ctx, t, proxyClient) reverseTunnelServer, err := reversetunnel.NewServer(reversetunnel.Config{ @@ -966,6 +966,7 @@ func TestProxyReverseTunnel(t *testing.T) { Emitter: proxyClient, Log: logger, LockWatcher: lockWatcher, + CertAuthorityWatcher: caWatcher, NodeWatcher: nodeWatcher, }) require.NoError(t, err) @@ -1141,6 +1142,7 @@ func TestProxyRoundRobin(t *testing.T) { defer listener.Close() lockWatcher := newLockWatcher(ctx, t, proxyClient) nodeWatcher := newNodeWatcher(ctx, t, proxyClient) + caWatcher := newCertAuthorityWatcher(ctx, t, proxyClient) reverseTunnelServer, err := reversetunnel.NewServer(reversetunnel.Config{ ClusterName: f.testSrv.ClusterName(), @@ -1158,6 +1160,7 @@ func TestProxyRoundRobin(t *testing.T) { Log: logger, LockWatcher: lockWatcher, NodeWatcher: nodeWatcher, + CertAuthorityWatcher: caWatcher, }) require.NoError(t, err) logger.WithField("tun-addr", reverseTunnelAddress.String()).Info("Created reverse tunnel server.") @@ -1266,6 +1269,7 @@ func TestProxyDirectAccess(t *testing.T) { proxyClient, _ := newProxyClient(t, f.testSrv) lockWatcher := newLockWatcher(ctx, t, proxyClient) nodeWatcher := newNodeWatcher(ctx, t, proxyClient) + caWatcher := newCertAuthorityWatcher(ctx, t, proxyClient) reverseTunnelServer, err := reversetunnel.NewServer(reversetunnel.Config{ ClientTLS: proxyClient.TLSConfig(), @@ -1283,6 +1287,7 @@ func TestProxyDirectAccess(t *testing.T) { Log: logger, LockWatcher: lockWatcher, NodeWatcher: nodeWatcher, + CertAuthorityWatcher: caWatcher, }) require.NoError(t, err) @@ -1988,6 +1993,19 @@ func newNodeWatcher(ctx context.Context, t *testing.T, client types.Events) *ser return nodeWatcher } +func newCertAuthorityWatcher(ctx context.Context, t *testing.T, client types.Events) *services.CertAuthorityWatcher { + caWatcher, err := services.NewCertAuthorityWatcher(ctx, services.CertAuthorityWatcherConfig{ + ResourceWatcherConfig: services.ResourceWatcherConfig{ + Component: "test", + Client: client, + }, + Types: []types.CertAuthType{types.HostCA, types.UserCA}, + }) + require.NoError(t, err) + t.Cleanup(caWatcher.Close) + return caWatcher +} + // maxPipeSize is one larger than the maximum pipe size for most operating // systems which appears to be 65536 bytes. // diff --git a/lib/web/apiserver_test.go b/lib/web/apiserver_test.go index dbc0b5d958b63..693f064b03948 100644 --- a/lib/web/apiserver_test.go +++ b/lib/web/apiserver_test.go @@ -258,6 +258,16 @@ func newWebSuite(t *testing.T) *WebSuite { }) require.NoError(t, err) + caWatcher, err := services.NewCertAuthorityWatcher(s.ctx, services.CertAuthorityWatcherConfig{ + ResourceWatcherConfig: services.ResourceWatcherConfig{ + Component: teleport.ComponentProxy, + Client: s.proxyClient, + }, + Types: []types.CertAuthType{types.HostCA, types.UserCA}, + }) + require.NoError(t, err) + defer caWatcher.Close() + revTunServer, err := reversetunnel.NewServer(reversetunnel.Config{ ID: node.ID(), Listener: revTunListener, @@ -272,6 +282,7 @@ func newWebSuite(t *testing.T) *WebSuite { DataDir: t.TempDir(), LockWatcher: proxyLockWatcher, NodeWatcher: proxyNodeWatcher, + CertAuthorityWatcher: caWatcher, }) require.NoError(t, err) s.proxyTunnel = revTunServer @@ -2874,6 +2885,16 @@ func createProxy(ctx context.Context, t *testing.T, proxyID string, node *regula require.NoError(t, err) t.Cleanup(proxyLockWatcher.Close) + proxyCAWatcher, err := services.NewCertAuthorityWatcher(ctx, services.CertAuthorityWatcherConfig{ + ResourceWatcherConfig: services.ResourceWatcherConfig{ + Component: teleport.ComponentProxy, + Client: client, + }, + Types: []types.CertAuthType{types.HostCA, types.UserCA}, + }) + require.NoError(t, err) + t.Cleanup(proxyLockWatcher.Close) + proxyNodeWatcher, err := services.NewNodeWatcher(ctx, services.NodeWatcherConfig{ ResourceWatcherConfig: services.ResourceWatcherConfig{ Component: teleport.ComponentProxy, @@ -2897,6 +2918,7 @@ func createProxy(ctx context.Context, t *testing.T, proxyID string, node *regula DataDir: t.TempDir(), LockWatcher: proxyLockWatcher, NodeWatcher: proxyNodeWatcher, + CertAuthorityWatcher: proxyCAWatcher, }) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, revTunServer.Close()) })