diff --git a/integration/integration_test.go b/integration/integration_test.go index bb46ae378726c..04e7999a00db7 100644 --- a/integration/integration_test.go +++ b/integration/integration_test.go @@ -3711,10 +3711,10 @@ func testRotateRollback(t *testing.T, s *integrationTestSuite) { // TestRotateTrustedClusters tests CA rotation support for trusted clusters func testRotateTrustedClusters(t *testing.T, suite *integrationTestSuite) { tr := utils.NewTracer(utils.ThisFunction()).Start() - defer tr.Stop() + t.Cleanup(func() { tr.Stop() }) ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + t.Cleanup(cancel) clusterMain := "rotate-main" clusterAux := "rotate-aux" @@ -3773,7 +3773,7 @@ func testRotateTrustedClusters(t *testing.T, suite *integrationTestSuite) { require.NoError(t, err) err = aux.Process.GetAuthServer().UpsertRole(ctx, role) require.NoError(t, err) - trustedClusterToken := "trusted-clsuter-token" + trustedClusterToken := "trusted-cluster-token" err = svc.GetAuthServer().UpsertToken(ctx, services.MustCreateProvisionToken(trustedClusterToken, []types.SystemRole{types.RoleTrustedCluster}, time.Time{})) require.NoError(t, err) @@ -3789,7 +3789,7 @@ func testRotateTrustedClusters(t *testing.T, suite *integrationTestSuite) { tryCreateTrustedCluster(t, aux.Process.GetAuthServer(), trustedCluster) waitForTunnelConnections(t, svc.GetAuthServer(), aux.Secrets.SiteName, 1) - // capture credentials before has reload started to simulate old client + // capture credentials before reload has started to simulate old client initialCreds, err := GenerateUserCreds(UserCredsRequest{ Process: svc, Username: suite.me.Username, @@ -3818,24 +3818,43 @@ func testRotateTrustedClusters(t *testing.T, suite *integrationTestSuite) { }) require.NoError(t, err) - // wait until service phase update to be broadcasted (init phase does not trigger reload) + // wait until service phase update to be broadcast (init phase does not trigger reload) err = waitForProcessEvent(svc, service.TeleportPhaseChangeEvent, 10*time.Second) require.NoError(t, err) // waitForPhase waits until aux cluster detects the rotation waitForPhase := func(phase string) error { + ctx, cancel := context.WithTimeout(context.Background(), tconf.PollingPeriod*10) + defer cancel() + + watcher, err := services.NewCertAuthorityWatcher(ctx, services.CertAuthorityWatcherConfig{ + ResourceWatcherConfig: services.ResourceWatcherConfig{ + Component: teleport.ComponentProxy, + Clock: tconf.Clock, + Client: aux.GetSiteAPI(clusterAux), + }, + WatchHostCA: true, + }) + if err != nil { + return err + } + defer watcher.Close() + var lastPhase string for i := 0; i < 10; i++ { - ca, err := aux.Process.GetAuthServer().GetCertAuthority(types.CertAuthID{ - Type: types.HostCA, - DomainName: clusterMain, - }, false) - require.NoError(t, err) - if ca.GetRotation().Phase == phase { - return nil + select { + case <-ctx.Done(): + return trace.CompareFailed("failed to converge to phase %q, last phase %q", phase, lastPhase) + case cas := <-watcher.CertAuthorityC: + for _, ca := range cas { + if ca.GetClusterName() == clusterMain && + ca.GetType() == types.HostCA && + ca.GetRotation().Phase == phase { + return nil + } + lastPhase = ca.GetRotation().Phase + } } - lastPhase = ca.GetRotation().Phase - time.Sleep(tconf.PollingPeriod / 2) } return trace.CompareFailed("failed to converge to phase %q, last phase %q", phase, lastPhase) } @@ -3916,7 +3935,7 @@ func testRotateTrustedClusters(t *testing.T, suite *integrationTestSuite) { // shut down the service cancel() // close the service without waiting for the connections to drain - svc.Close() + require.NoError(t, svc.Close()) select { case err := <-runErrCh: @@ -4082,6 +4101,7 @@ func (s *integrationTestSuite) rotationConfig(disableWebService bool) *service.C tconf.PollingPeriod = 500 * time.Millisecond tconf.ClientTimeout = time.Second tconf.ShutdownTimeout = 2 * tconf.ClientTimeout + tconf.MaxRetryPeriod = time.Second return tconf } diff --git a/lib/reversetunnel/remotesite.go b/lib/reversetunnel/remotesite.go index 3346c32525889..fa2212c4f6eba 100644 --- a/lib/reversetunnel/remotesite.go +++ b/lib/reversetunnel/remotesite.go @@ -23,8 +23,6 @@ import ( "sync" "time" - "golang.org/x/crypto/ssh" - "github.com/gravitational/teleport" "github.com/gravitational/teleport/api/client" "github.com/gravitational/teleport/api/constants" @@ -35,13 +33,10 @@ import ( "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/srv/forward" "github.com/gravitational/teleport/lib/utils" - "github.com/gravitational/teleport/lib/utils/interval" - "github.com/gravitational/trace" - "github.com/jonboulle/clockwork" - "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus" + "golang.org/x/crypto/ssh" ) // remoteSite is a remote site that established the inbound connecton to @@ -416,129 +411,177 @@ func (s *remoteSite) compareAndSwapCertAuthority(ca types.CertAuthority) error { return trace.CompareFailed("remote certificate authority rotation has been updated") } -// updateCertAuthorities updates local and remote cert authorities -func (s *remoteSite) updateCertAuthorities() error { - // update main cluster cert authorities on the remote side - // remote side makes sure that only relevant fields - // are updated - hostCA, err := s.localClient.GetCertAuthority(types.CertAuthID{ - Type: types.HostCA, - DomainName: s.srv.ClusterName, - }, false) - if err != nil { - return trace.Wrap(err) - } - err = s.remoteClient.RotateExternalCertAuthority(hostCA) - if err != nil { - return trace.Wrap(err) - } +func (s *remoteSite) updateCertAuthorities(retry utils.Retry) { + s.Debugf("Watching for cert authority changes.") + + for { + startedWaiting := s.clock.Now() + select { + case t := <-retry.After(): + s.Debugf("Initiating new cert authority watch after waiting %v.", t.Sub(startedWaiting)) + retry.Inc() + case <-s.ctx.Done(): + return + } + + err := s.watchCertAuthorities() + if err != nil { + switch { + case trace.IsNotFound(err): + s.Debugf("Remote cluster %v does not support cert authorities rotation yet.", s.domainName) + case trace.IsCompareFailed(err): + s.Infof("Remote cluster has updated certificate authorities, going to force reconnect.") + s.srv.removeSite(s.domainName) + s.Close() + return + case trace.IsConnectionProblem(err): + s.Debugf("Remote cluster %v is offline.", s.domainName) + default: + s.Warningf("Could not perform cert authorities update: %v.", trace.DebugReport(err)) + } + } - userCA, err := s.localClient.GetCertAuthority(types.CertAuthID{ - Type: types.UserCA, - DomainName: s.srv.ClusterName, - }, false) - if err != nil { - return trace.Wrap(err) } - err = s.remoteClient.RotateExternalCertAuthority(userCA) +} + +func (s *remoteSite) watchCertAuthorities() error { + localWatcher, err := services.NewCertAuthorityWatcher(s.ctx, services.CertAuthorityWatcherConfig{ + ResourceWatcherConfig: services.ResourceWatcherConfig{ + Component: teleport.ComponentProxy, + Log: s, + Clock: s.clock, + Client: s.localAccessPoint, + }, + WatchUserCA: true, + WatchHostCA: true, + }) if err != nil { return trace.Wrap(err) } + defer localWatcher.Close() - // update remote cluster's host cert authoritiy on a local cluster - // local proxy is authorized to perform this operation only for - // host authorities of remote clusters. - remoteCA, err := s.remoteClient.GetCertAuthority(types.CertAuthID{ - Type: types.HostCA, - DomainName: s.domainName, - }, false) + remoteWatcher, err := services.NewCertAuthorityWatcher(s.ctx, services.CertAuthorityWatcherConfig{ + ResourceWatcherConfig: services.ResourceWatcherConfig{ + Component: teleport.ComponentProxy, + Log: s, + Clock: s.clock, + Client: s.remoteAccessPoint, + }, + WatchHostCA: true, + }) if err != nil { return trace.Wrap(err) } + defer remoteWatcher.Close() - if remoteCA.GetClusterName() != s.domainName { - return trace.BadParameter( - "remote cluster sent different cluster name %v instead of expected one %v", - remoteCA.GetClusterName(), s.domainName) - } + for { + select { + case <-s.ctx.Done(): + s.WithError(s.ctx.Err()).Debug("Context is closing.") + return trace.Wrap(s.ctx.Err()) + case <-localWatcher.Done(): + s.Warn("Local CertAuthority watcher subscription has closed") + return fmt.Errorf("local ca watcher for cluster %s has closed", s.srv.ClusterName) + case <-remoteWatcher.Done(): + s.Warn("Remote CertAuthority watcher subscription has closed") + return fmt.Errorf("remote ca watcher for cluster %s has closed", s.domainName) + case cas := <-localWatcher.CertAuthorityC: + for _, localCA := range cas { + if localCA.GetClusterName() != s.srv.ClusterName || + (localCA.GetType() != types.HostCA && + localCA.GetType() != types.UserCA) { + continue + } - oldRemoteCA, err := s.localClient.GetCertAuthority(types.CertAuthID{ - Type: types.HostCA, - DomainName: remoteCA.GetClusterName(), - }, false) + if err := s.remoteClient.RotateExternalCertAuthority(localCA); err != nil { + s.WithError(err).Warn("Failed to rotate external ca") + return trace.Wrap(err) + } + } + case cas := <-remoteWatcher.CertAuthorityC: + for _, remoteCA := range cas { + if remoteCA.GetType() != types.HostCA || + remoteCA.GetClusterName() != s.domainName { + continue + } - if err != nil && !trace.IsNotFound(err) { - return trace.Wrap(err) - } + oldRemoteCA, err := s.localClient.GetCertAuthority(types.CertAuthID{ + Type: types.HostCA, + DomainName: remoteCA.GetClusterName(), + }, false) + + if err != nil && !trace.IsNotFound(err) { + return trace.Wrap(err) + } - // if CA is changed or does not exist, update backend - if err != nil || !services.CertAuthoritiesEquivalent(oldRemoteCA, remoteCA) { - if err := s.localClient.UpsertCertAuthority(remoteCA); err != nil { - return trace.Wrap(err) + // if CA is changed or does not exist, update backend + if err != nil || !services.CertAuthoritiesEquivalent(oldRemoteCA, remoteCA) { + if err := s.localClient.UpsertCertAuthority(remoteCA); err != nil { + return trace.Wrap(err) + } + } + + // always update our local reference to the cert authority + if err := s.compareAndSwapCertAuthority(remoteCA); err != nil { + return trace.Wrap(err) + } + } } } - - // always update our local reference to the cert authority - return s.compareAndSwapCertAuthority(remoteCA) } -func (s *remoteSite) periodicUpdateCertAuthorities() { - s.Debugf("Updating remote CAs with period %v.", s.srv.PollingPeriod) - periodic := interval.New(interval.Config{ - Duration: s.srv.PollingPeriod, - FirstDuration: utils.HalfJitter(s.srv.PollingPeriod), - Jitter: utils.NewSeventhJitter(), - }) - defer periodic.Stop() +func (s *remoteSite) updateLocks(retry utils.Retry) { + s.Debugf("Watching for remote lock changes.") + for { + startedWaiting := s.clock.Now() select { + case t := <-retry.After(): + s.Debugf("Initiating new lock watch after waiting %v.", t.Sub(startedWaiting)) + retry.Inc() case <-s.ctx.Done(): - s.Debugf("Context is closing.") return - case <-periodic.Next(): - err := s.updateCertAuthorities() - if err != nil { - switch { - case trace.IsNotFound(err): - s.Debugf("Remote cluster %v does not support cert authorities rotation yet.", s.domainName) - case trace.IsCompareFailed(err): - s.Infof("Remote cluster has updated certificate authorities, going to force reconnect.") - s.srv.removeSite(s.domainName) - s.Close() - return - case trace.IsConnectionProblem(err): - s.Debugf("Remote cluster %v is offline.", s.domainName) - default: - s.Warningf("Could not perform cert authorities updated: %v.", trace.DebugReport(err)) - } + } + + if err := s.watchLocks(); err != nil { + switch { + case trace.IsNotImplemented(err): + s.Debugf("Remote cluster %v does not support locks yet.", s.domainName) + case trace.IsConnectionProblem(err): + s.Debugf("Remote cluster %v is offline.", s.domainName) + default: + s.WithError(err).Warning("Could not update remote locks.") } } } } -func (s *remoteSite) periodicUpdateLocks() { - s.Debugf("Updating remote locks with period %v.", s.srv.PollingPeriod) - periodic := interval.New(interval.Config{ - Duration: s.srv.PollingPeriod, - FirstDuration: utils.HalfJitter(s.srv.PollingPeriod), - Jitter: utils.NewSeventhJitter(), - }) - defer periodic.Stop() +func (s *remoteSite) watchLocks() error { + watcher, err := s.srv.LockWatcher.Subscribe(s.ctx) + if err != nil { + s.WithError(err).Error("Failed to subscribe to LockWatcher") + return err + } + defer func() { + if err := watcher.Close(); err != nil { + s.WithError(err).Warn("Failed to close lock watcher subscription.") + } + }() + for { select { + case <-watcher.Done(): + s.WithError(watcher.Error()).Warn("Lock watcher subscription has closed") + return trace.Wrap(watcher.Error()) case <-s.ctx.Done(): - s.Debugf("Context is closing.") - return - case <-periodic.Next(): - locks := s.srv.LockWatcher.GetCurrent() - if err := s.remoteClient.ReplaceRemoteLocks(s.ctx, s.srv.ClusterName, locks); err != nil { - switch { - case trace.IsNotImplemented(err): - s.Debugf("Remote cluster %v does not support locks yet.", s.domainName) - case trace.IsConnectionProblem(err): - s.Debugf("Remote cluster %v is offline.", s.domainName) - default: - s.WithError(err).Warning("Could not update remote locks.") + s.WithError(s.ctx.Err()).Debug("Context is closing.") + return trace.Wrap(s.ctx.Err()) + case evt := <-watcher.Events(): + switch evt.Type { + case types.OpPut, types.OpDelete: + locks := s.srv.LockWatcher.GetCurrent() + if err := s.remoteClient.ReplaceRemoteLocks(s.ctx, s.srv.ClusterName, locks); err != nil { + return trace.Wrap(err) } } } @@ -632,7 +675,7 @@ func (s *remoteSite) dialWithAgent(params DialParams) (net.Conn, error) { MACAlgorithms: s.srv.Config.MACAlgorithms, DataDir: s.srv.Config.DataDir, Address: params.Address, - UseTunnel: UseTunnel(targetConn), + UseTunnel: UseTunnel(s.Logger, targetConn), FIPS: s.srv.FIPS, HostUUID: s.srv.ID, Emitter: s.srv.Config.Emitter, @@ -657,7 +700,7 @@ func (s *remoteSite) dialWithAgent(params DialParams) (net.Conn, error) { // UseTunnel makes a channel request asking for the type of connection. If // the other side does not respond (older cluster) or takes to long to // respond, be on the safe side and assume it's not a tunnel connection. -func UseTunnel(c *sshutils.ChConn) bool { +func UseTunnel(logger *log.Logger, c *sshutils.ChConn) bool { responseCh := make(chan bool, 1) go func() { @@ -673,8 +716,7 @@ func UseTunnel(c *sshutils.ChConn) bool { case response := <-responseCh: return response case <-time.After(1 * time.Second): - // TODO: remove logrus import - logrus.Debugf("Timed out waiting for response: returning false.") + logger.Debugf("Timed out waiting for response: returning false.") return false } } diff --git a/lib/reversetunnel/srv.go b/lib/reversetunnel/srv.go index e6b490bab284c..67e16a01ee959 100644 --- a/lib/reversetunnel/srv.go +++ b/lib/reversetunnel/srv.go @@ -1071,8 +1071,31 @@ func newRemoteSite(srv *server, domainName string, sconn ssh.Conn) (*remoteSite, } remoteSite.certificateCache = certificateCache - go remoteSite.periodicUpdateCertAuthorities() - go remoteSite.periodicUpdateLocks() + caRetry, err := utils.NewLinear(utils.LinearConfig{ + First: utils.HalfJitter(srv.Config.PollingPeriod), + Step: srv.Config.PollingPeriod / 5, + Max: srv.Config.PollingPeriod, + Jitter: utils.NewHalfJitter(), + Clock: srv.Clock, + }) + if err != nil { + return nil, err + } + + go remoteSite.updateCertAuthorities(caRetry) + + lockRetry, err := utils.NewLinear(utils.LinearConfig{ + First: utils.HalfJitter(srv.Config.PollingPeriod), + Step: srv.Config.PollingPeriod / 5, + Max: srv.Config.PollingPeriod, + Jitter: utils.NewHalfJitter(), + Clock: srv.Clock, + }) + if err != nil { + return nil, err + } + + go remoteSite.updateLocks(lockRetry) return remoteSite, nil } diff --git a/lib/services/authority.go b/lib/services/authority.go index d575a21ada5e1..684895da2a770 100644 --- a/lib/services/authority.go +++ b/lib/services/authority.go @@ -258,7 +258,7 @@ type UserCertParams struct { ClientIP string } -// Check checks the user certificate parameters +// CheckAndSetDefaults checks the user certificate parameters func (c *UserCertParams) CheckAndSetDefaults() error { if c.CASigner == nil || c.CASigningAlg == "" { return trace.BadParameter("CASigner and CASigningAlg are required") @@ -387,7 +387,7 @@ func MarshalCertAuthority(certAuthority types.CertAuthority, opts ...MarshalOpti } } -// CertAuthorityNeedsMigrations returns true if the given CertAuthority needs to be migrated +// CertAuthorityNeedsMigration returns true if the given CertAuthority needs to be migrated func CertAuthorityNeedsMigration(cai types.CertAuthority) (bool, error) { ca, ok := cai.(*types.CertAuthorityV2) if !ok { diff --git a/lib/services/trust.go b/lib/services/trust.go index 3302865b9a0c7..84d4c5d4aeb5b 100644 --- a/lib/services/trust.go +++ b/lib/services/trust.go @@ -18,6 +18,15 @@ package services import "github.com/gravitational/teleport/api/types" +// AuthorityGetter defines interface for fetching cert authority resources. +type AuthorityGetter interface { + // GetCertAuthority returns cert authority by id + GetCertAuthority(id types.CertAuthID, loadKeys bool, opts ...MarshalOption) (types.CertAuthority, error) + + // GetCertAuthorities returns a list of cert authorities + GetCertAuthorities(caType types.CertAuthType, loadKeys bool, opts ...MarshalOption) ([]types.CertAuthority, error) +} + // Trust is responsible for managing certificate authorities // Each authority is managing some domain, e.g. example.com // @@ -28,6 +37,9 @@ import "github.com/gravitational/teleport/api/types" // Remote authorities have only public keys available, so they can // be only used to validate type Trust interface { + // AuthorityGetter retrieves certificate authorities + AuthorityGetter + // CreateCertAuthority inserts a new certificate authority CreateCertAuthority(ca types.CertAuthority) error @@ -45,14 +57,6 @@ type Trust interface { // DeleteAllCertAuthorities deletes cert authorities of a certain type DeleteAllCertAuthorities(caType types.CertAuthType) error - // GetCertAuthority returns certificate authority by given id. Parameter loadSigningKeys - // controls if signing keys are loaded - GetCertAuthority(id types.CertAuthID, loadSigningKeys bool, opts ...MarshalOption) (types.CertAuthority, error) - - // GetCertAuthorities returns a list of authorities of a given type - // loadSigningKeys controls whether signing keys should be loaded or not - GetCertAuthorities(caType types.CertAuthType, loadSigningKeys bool, opts ...MarshalOption) ([]types.CertAuthority, error) - // ActivateCertAuthority moves a CertAuthority from the deactivated list to // the normal list. ActivateCertAuthority(id types.CertAuthID) error diff --git a/lib/services/watcher.go b/lib/services/watcher.go index 04371c9df280c..c1d932323248f 100644 --- a/lib/services/watcher.go +++ b/lib/services/watcher.go @@ -883,3 +883,169 @@ func appsToSlice(apps map[string]types.Application) (slice []types.Application) } return slice } + +// CertAuthorityWatcherConfig is a CertAuthorityWatcher configuration. +type CertAuthorityWatcherConfig struct { + // ResourceWatcherConfig is the resource watcher configuration. + ResourceWatcherConfig + // AuthorityGetter is responsible for fetching cert authority resources. + AuthorityGetter + // CertAuthorityC receives up-to-date list of all cert authority resources. + CertAuthorityC chan []types.CertAuthority + // WatchHostCA indicates that the watcher should monitor types.HostCA + WatchHostCA bool + // WatchUserCA indicates that the watcher should monitor types.UserCA + WatchUserCA bool +} + +// CheckAndSetDefaults checks parameters and sets default values. +func (cfg *CertAuthorityWatcherConfig) CheckAndSetDefaults() error { + if err := cfg.ResourceWatcherConfig.CheckAndSetDefaults(); err != nil { + return trace.Wrap(err) + } + if cfg.AuthorityGetter == nil { + getter, ok := cfg.Client.(AuthorityGetter) + if !ok { + return trace.BadParameter("missing parameter AuthorityGetter and Client not usable as AuthorityGetter") + } + cfg.AuthorityGetter = getter + } + if cfg.CertAuthorityC == nil { + cfg.CertAuthorityC = make(chan []types.CertAuthority) + } + return nil +} + +// NewCertAuthorityWatcher returns a new instance of CertAuthorityWatcher. +func NewCertAuthorityWatcher(ctx context.Context, cfg CertAuthorityWatcherConfig) (*CertAuthorityWatcher, error) { + if err := cfg.CheckAndSetDefaults(); err != nil { + return nil, trace.Wrap(err) + } + + collector := &caCollector{ + CertAuthorityWatcherConfig: cfg, + } + + watcher, err := newResourceWatcher(ctx, collector, cfg.ResourceWatcherConfig) + if err != nil { + return nil, trace.Wrap(err) + } + + return &CertAuthorityWatcher{watcher, collector}, nil +} + +// CertAuthorityWatcher is built on top of resourceWatcher to monitor cert authority resources. +type CertAuthorityWatcher struct { + *resourceWatcher + *caCollector +} + +// caCollector accompanies resourceWatcher when monitoring cert authority resources. +type caCollector struct { + CertAuthorityWatcherConfig + host map[string]types.CertAuthority + user map[string]types.CertAuthority + lock sync.RWMutex +} + +// resourceKind specifies the resource kind to watch. +func (c *caCollector) resourceKind() string { + return types.KindCertAuthority +} + +// getResourcesAndUpdateCurrent refreshes the list of current resources. +func (c *caCollector) getResourcesAndUpdateCurrent(ctx context.Context) error { + var ( + newHost map[string]types.CertAuthority + newUser map[string]types.CertAuthority + ) + + if c.WatchHostCA { + host, err := c.AuthorityGetter.GetCertAuthorities(types.HostCA, false) + if err != nil { + return trace.Wrap(err) + } + newHost = make(map[string]types.CertAuthority, len(host)) + for _, ca := range host { + newHost[ca.GetName()] = ca + } + } + + if c.WatchUserCA { + user, err := c.AuthorityGetter.GetCertAuthorities(types.UserCA, false) + if err != nil { + return trace.Wrap(err) + } + newUser = make(map[string]types.CertAuthority, len(user)) + for _, ca := range user { + newUser[ca.GetName()] = ca + } + } + + c.lock.Lock() + c.host = newHost + c.user = newUser + c.lock.Unlock() + + c.CertAuthorityC <- casToSlice(newHost, newUser) + return nil +} + +// processEventAndUpdateCurrent is called when a watcher event is received. +func (c *caCollector) processEventAndUpdateCurrent(ctx context.Context, event types.Event) { + if event.Resource == nil || event.Resource.GetKind() != types.KindCertAuthority { + c.Log.Warnf("Unexpected event: %v.", event) + return + } + c.lock.Lock() + defer c.lock.Unlock() + switch event.Type { + case types.OpDelete: + if c.WatchHostCA && event.Resource.GetSubKind() == string(types.HostCA) { + delete(c.host, event.Resource.GetName()) + } + if c.WatchUserCA && event.Resource.GetSubKind() == string(types.UserCA) { + delete(c.user, event.Resource.GetName()) + } + + c.CertAuthorityC <- casToSlice(c.host, c.user) + case types.OpPut: + ca, ok := event.Resource.(types.CertAuthority) + if !ok { + c.Log.Warnf("Unexpected resource type %T.", event.Resource) + return + } + + if c.WatchHostCA && ca.GetType() == types.HostCA { + c.host[ca.GetName()] = ca + } + if c.WatchUserCA && ca.GetType() == types.UserCA { + c.user[ca.GetName()] = ca + } + + c.CertAuthorityC <- casToSlice(c.host, c.user) + default: + c.Log.Warnf("Unsupported event type %s.", event.Type) + return + } +} + +// GetCurrent returns the currently stored authorities. +func (c *caCollector) GetCurrent() []types.CertAuthority { + c.lock.RLock() + defer c.lock.RUnlock() + return casToSlice(c.host, c.user) +} + +func (c *caCollector) notifyStale() {} + +func casToSlice(host map[string]types.CertAuthority, user map[string]types.CertAuthority) []types.CertAuthority { + slice := make([]types.CertAuthority, 0, len(host)+len(user)) + for _, ca := range host { + slice = append(slice, ca) + } + for _, ca := range user { + slice = append(slice, ca) + } + return slice +} diff --git a/lib/services/watcher_test.go b/lib/services/watcher_test.go index 99f676e683c51..4d7b3cfbcacc9 100644 --- a/lib/services/watcher_test.go +++ b/lib/services/watcher_test.go @@ -18,6 +18,7 @@ package services_test import ( "context" + "crypto/x509/pkix" "errors" "sync" "testing" @@ -25,16 +26,17 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" - "github.com/gravitational/trace" - "github.com/jonboulle/clockwork" - "github.com/stretchr/testify/require" - "github.com/gravitational/teleport/api/constants" "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/auth/testauthority" "github.com/gravitational/teleport/lib/backend/lite" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/services/local" + "github.com/gravitational/teleport/lib/tlsca" + "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" + "github.com/stretchr/testify/require" ) var _ types.Events = (*errorWatcher)(nil) @@ -512,6 +514,16 @@ func resourceDiff(res1, res2 types.Resource) string { cmpopts.EquateEmpty()) } +func caDiff(ca1, ca2 types.CertAuthority) string { + return cmp.Diff(ca1, ca2, + cmpopts.IgnoreFields(types.Metadata{}, "ID"), + cmpopts.IgnoreFields(types.CertAuthoritySpecV2{}, "CheckingKeys", "TLSKeyPairs"), + cmpopts.IgnoreFields(types.SSHKeyPair{}, "PrivateKey"), + cmpopts.IgnoreFields(types.TLSKeyPair{}, "Key"), + cmpopts.EquateEmpty(), + ) +} + // TestDatabaseWatcher tests that database resource watcher properly receives // and dispatches updates to database resources. func TestDatabaseWatcher(t *testing.T) { @@ -704,3 +716,140 @@ func newApp(t *testing.T, name string) types.Application { require.NoError(t, err) return app } + +func TestCertAuthorityWatcher(t *testing.T) { + t.Parallel() + ctx := context.Background() + + bk, err := lite.NewWithConfig(ctx, lite.Config{ + Path: t.TempDir(), + PollStreamPeriod: 200 * time.Millisecond, + }) + require.NoError(t, err) + + type client struct { + services.Trust + types.Events + } + + caService := local.NewCAService(bk) + w, err := services.NewCertAuthorityWatcher(ctx, services.CertAuthorityWatcherConfig{ + ResourceWatcherConfig: services.ResourceWatcherConfig{ + Component: "test", + MaxRetryPeriod: 200 * time.Millisecond, + Client: &client{ + Trust: caService, + Events: local.NewEventsService(bk), + }, + }, + CertAuthorityC: make(chan []types.CertAuthority, 10), + WatchUserCA: true, + WatchHostCA: true, + }) + require.NoError(t, err) + t.Cleanup(w.Close) + + nothingWatcher, err := services.NewCertAuthorityWatcher(ctx, services.CertAuthorityWatcherConfig{ + ResourceWatcherConfig: services.ResourceWatcherConfig{ + Component: "test", + MaxRetryPeriod: 200 * time.Millisecond, + Client: &client{ + Trust: caService, + Events: local.NewEventsService(bk), + }, + }, + CertAuthorityC: make(chan []types.CertAuthority, 10), + }) + require.NoError(t, err) + t.Cleanup(nothingWatcher.Close) + + require.Empty(t, w.GetCurrent()) + require.Empty(t, nothingWatcher.GetCurrent()) + + // Initially there are no cas so watcher should send an empty list. + select { + case changeset := <-w.CertAuthorityC: + require.Len(t, changeset, 0) + require.Empty(t, nothingWatcher.GetCurrent()) + case <-w.Done(): + t.Fatal("Watcher has unexpectedly exited.") + case <-time.After(2 * time.Second): + t.Fatal("Timeout waiting for the first event.") + } + + // Add an authority. + ca1 := newCertAuthority(t, "ca1", types.HostCA) + require.NoError(t, caService.CreateCertAuthority(ca1)) + + // The first event is always the current list of apps. + select { + case changeset := <-w.CertAuthorityC: + require.Len(t, changeset, 1) + require.Empty(t, caDiff(changeset[0], ca1)) + require.Empty(t, nothingWatcher.GetCurrent()) + case <-w.Done(): + t.Fatal("Watcher has unexpectedly exited.") + case <-time.After(2 * time.Second): + t.Fatal("Timeout waiting for the first event.") + } + + // Add a second ca. + ca2 := newCertAuthority(t, "ca2", types.UserCA) + require.NoError(t, caService.CreateCertAuthority(ca2)) + + // Watcher should detect the ca list change. + select { + case changeset := <-w.CertAuthorityC: + require.Len(t, changeset, 2) + require.Empty(t, nothingWatcher.GetCurrent()) + case <-w.Done(): + t.Fatal("Watcher has unexpectedly exited.") + case <-time.After(2 * time.Second): + t.Fatal("Timeout waiting for the update event.") + } + + // Delete the first ca. + require.NoError(t, caService.DeleteCertAuthority(ca1.GetID())) + + // Watcher should detect the ca list change. + select { + case changeset := <-w.CertAuthorityC: + require.Len(t, changeset, 1) + require.Empty(t, caDiff(changeset[0], ca2)) + require.Empty(t, nothingWatcher.GetCurrent()) + case <-w.Done(): + t.Fatal("Watcher has unexpectedly exited.") + case <-time.After(2 * time.Second): + t.Fatal("Timeout waiting for the update event.") + } +} + +func newCertAuthority(t *testing.T, name string, caType types.CertAuthType) types.CertAuthority { + ta := testauthority.New() + priv, pub, err := ta.GenerateKeyPair("") + require.NoError(t, err) + + // CA for cluster1 with 1 key pair. + key, cert, err := tlsca.GenerateSelfSignedCA(pkix.Name{CommonName: name}, nil, time.Minute) + require.NoError(t, err) + + ca, err := types.NewCertAuthority(types.CertAuthoritySpecV2{ + Type: caType, + ClusterName: name, + ActiveKeys: types.CAKeySet{ + SSH: []*types.SSHKeyPair{{ + PrivateKey: priv, + PrivateKeyType: types.PrivateKeyType_RAW, + PublicKey: pub, + }}, + TLS: []*types.TLSKeyPair{{ + Cert: cert, + Key: key, + }}, + }, + Roles: nil, + SigningAlg: types.CertAuthoritySpecV2_RSA_SHA2_256, + }) + require.NoError(t, err) + return ca +}