Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve CertAuthorityWatcher #10403

Merged
merged 4 commits into from
May 17, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 20 additions & 6 deletions integration/db_integration_test.go
Original file line number Diff line number Diff line change
@@ -269,7 +269,7 @@ func (p *phaseWatcher) waitForPhase(phase string, fn func() error) error {
Clock: p.clock,
Client: p.siteAPI,
},
WatchCertTypes: []types.CertAuthType{p.certType},
Types: []types.CertAuthType{p.certType},
})
if err != nil {
return err
@@ -280,16 +280,30 @@ func (p *phaseWatcher) waitForPhase(phase string, fn func() error) error {
return trace.Wrap(err)
}

sub, err := watcher.Subscribe(ctx, services.CertAuthorityTarget{
ClusterName: p.clusterRootName,
Type: p.certType,
})
if err != nil {
return trace.Wrap(err)
}
defer sub.Close()

var lastPhase string
for i := 0; i < 10; i++ {
select {
case <-ctx.Done():
return trace.CompareFailed("failed to converge to phase %q, last phase %q certType: %v err: %v", phase, lastPhase, p.certType, ctx.Err())
case cas := <-watcher.CertAuthorityC:
for _, ca := range cas {
if ca.GetClusterName() == p.clusterRootName &&
ca.GetType() == p.certType &&
ca.GetRotation().Phase == phase {
case <-sub.Done():
return trace.CompareFailed("failed to converge to phase %q, last phase %q certType: %v err: %v", phase, lastPhase, p.certType, sub.Error())
case evt := <-sub.Events():
switch evt.Type {
case types.OpPut:
ca, ok := evt.Resource.(types.CertAuthority)
if !ok {
return trace.BadParameter("expected a ca got type %T", evt.Resource)
}
if ca.GetRotation().Phase == phase {
return nil
}
lastPhase = ca.GetRotation().Phase
71 changes: 31 additions & 40 deletions integration/integration_test.go
Original file line number Diff line number Diff line change
@@ -4097,45 +4097,39 @@ func testRotateTrustedClusters(t *testing.T, suite *integrationTestSuite) {
err = waitForProcessEvent(svc, service.TeleportPhaseChangeEvent, 10*time.Second)
require.NoError(t, err)

// waitForPhase waits until aux cluster detects the rotation
waitForPhase := func(phase string) error {
ctx, cancel := context.WithTimeout(context.Background(), tconf.PollingPeriod*10)
defer cancel()
watcher, err := services.NewCertAuthorityWatcher(ctx, services.CertAuthorityWatcherConfig{
ResourceWatcherConfig: services.ResourceWatcherConfig{
Component: teleport.ComponentProxy,
Clock: tconf.Clock,
Client: aux.GetSiteAPI(clusterAux),
},
Types: []types.CertAuthType{types.HostCA},
})
require.NoError(t, err)
t.Cleanup(watcher.Close)

watcher, err := services.NewCertAuthorityWatcher(ctx, services.CertAuthorityWatcherConfig{
ResourceWatcherConfig: services.ResourceWatcherConfig{
Component: teleport.ComponentProxy,
Clock: tconf.Clock,
Client: aux.GetSiteAPI(clusterAux),
},
WatchCertTypes: []types.CertAuthType{types.HostCA},
})
if err != nil {
return err
}
defer watcher.Close()
// waitForPhase waits until aux cluster detects the rotation
waitForPhase := func(phase string) {
require.Eventually(t, func() bool {
ca, err := aux.Process.GetAuthServer().GetCertAuthority(
ctx,
types.CertAuthID{
Type: types.HostCA,
DomainName: clusterMain,
}, false)
if err != nil {
return false
}

var lastPhase string
for i := 0; i < 10; i++ {
select {
case <-ctx.Done():
return trace.CompareFailed("failed to converge to phase %q, last phase %q", phase, lastPhase)
case cas := <-watcher.CertAuthorityC:
for _, ca := range cas {
if ca.GetClusterName() == clusterMain &&
ca.GetType() == types.HostCA &&
ca.GetRotation().Phase == phase {
return nil
}
lastPhase = ca.GetRotation().Phase
}
if ca.GetRotation().Phase == phase {
return true
}
}
return trace.CompareFailed("failed to converge to phase %q, last phase %q", phase, lastPhase)

return false
}, tconf.PollingPeriod*10, tconf.PollingPeriod/2, "failed to converge to phase %q", phase)
}

err = waitForPhase(types.RotationPhaseInit)
require.NoError(t, err)
waitForPhase(types.RotationPhaseInit)

// update clients
err = svc.GetAuthServer().RotateCertAuthority(ctx, auth.RotateRequest{
@@ -4148,8 +4142,7 @@ func testRotateTrustedClusters(t *testing.T, suite *integrationTestSuite) {
svc, err = waitForReload(serviceC, svc)
require.NoError(t, err)

err = waitForPhase(types.RotationPhaseUpdateClients)
require.NoError(t, err)
waitForPhase(types.RotationPhaseUpdateClients)

// old client should work as is
err = runAndMatch(clt, 8, []string{"echo", "hello world"}, ".*hello world.*")
@@ -4168,8 +4161,7 @@ func testRotateTrustedClusters(t *testing.T, suite *integrationTestSuite) {
svc, err = waitForReload(serviceC, svc)
require.NoError(t, err)

err = waitForPhase(types.RotationPhaseUpdateServers)
require.NoError(t, err)
waitForPhase(types.RotationPhaseUpdateServers)

// new credentials will work from this phase to others
newCreds, err := GenerateUserCreds(UserCredsRequest{Process: svc, Username: suite.me.Username})
@@ -4197,8 +4189,7 @@ func testRotateTrustedClusters(t *testing.T, suite *integrationTestSuite) {
require.NoError(t, err)
t.Log("Service reload completed, waiting for phase.")

err = waitForPhase(types.RotationPhaseStandby)
require.NoError(t, err)
waitForPhase(types.RotationPhaseStandby)
t.Log("Phase completed.")

// new client still works
113 changes: 66 additions & 47 deletions lib/reversetunnel/remotesite.go
Original file line number Diff line number Diff line change
@@ -427,20 +427,12 @@ func (s *remoteSite) compareAndSwapCertAuthority(ca types.CertAuthority) error {
return trace.CompareFailed("remote certificate authority rotation has been updated")
}

func (s *remoteSite) updateCertAuthorities(retry utils.Retry, remoteClusterVersion string) {
s.Debugf("Watching for cert authority changes.")
func (s *remoteSite) updateCertAuthorities(retry utils.Retry, remoteWatcher *services.CertAuthorityWatcher, remoteVersion string) {
defer remoteWatcher.Close()

cas := make(map[types.CertAuthType]types.CertAuthority)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit:

Suggested change
cas := make(map[types.CertAuthType]types.CertAuthority)
cas := make(CertAuthorityMap)

for {
startedWaiting := s.clock.Now()
select {
case t := <-retry.After():
s.Debugf("Initiating new cert authority watch after waiting %v.", t.Sub(startedWaiting))
retry.Inc()
case <-s.ctx.Done():
return
}

err := s.watchCertAuthorities(remoteClusterVersion)
err := s.watchCertAuthorities(remoteWatcher, remoteVersion, cas)
if err != nil {
switch {
case trace.IsNotFound(err):
@@ -456,70 +448,88 @@ func (s *remoteSite) updateCertAuthorities(retry utils.Retry, remoteClusterVersi
}
}

startedWaiting := s.clock.Now()
select {
case t := <-retry.After():
s.Debugf("Initiating new cert authority watch after waiting %v.", t.Sub(startedWaiting))
retry.Inc()
case <-s.ctx.Done():
return
}
}
}

func (s *remoteSite) watchCertAuthorities(remoteClusterVersion string) error {
localWatchedTypes, err := s.getLocalWatchedCerts(remoteClusterVersion)
func (s *remoteSite) watchCertAuthorities(remoteWatcher *services.CertAuthorityWatcher, remoteVersion string, cas map[types.CertAuthType]types.CertAuthority) error {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit:

Suggested change
func (s *remoteSite) watchCertAuthorities(remoteWatcher *services.CertAuthorityWatcher, remoteVersion string, cas map[types.CertAuthType]types.CertAuthority) error {
func (s *remoteSite) watchCertAuthorities(remoteWatcher *services.CertAuthorityWatcher, remoteVersion string, cas CertAuthorityMap) error {

targets, err := s.getLocalWatchedCerts(remoteVersion)
if err != nil {
return trace.Wrap(err)
}

localWatcher, err := services.NewCertAuthorityWatcher(s.ctx, services.CertAuthorityWatcherConfig{
ResourceWatcherConfig: services.ResourceWatcherConfig{
Component: teleport.ComponentProxy,
Log: s,
Clock: s.clock,
Client: s.localAccessPoint,
},
WatchCertTypes: localWatchedTypes,
})
localWatch, err := s.srv.CertAuthorityWatcher.Subscribe(s.ctx, targets...)
if err != nil {
return trace.Wrap(err)
}
defer localWatcher.Close()
defer func() {
if err := localWatch.Close(); err != nil {
s.WithError(err).Warn("Failed to close local ca watcher subscription.")
}
}()

remoteWatcher, err := services.NewCertAuthorityWatcher(s.ctx, services.CertAuthorityWatcherConfig{
ResourceWatcherConfig: services.ResourceWatcherConfig{
Component: teleport.ComponentProxy,
Log: s,
Clock: s.clock,
Client: s.remoteAccessPoint,
remoteWatch, err := remoteWatcher.Subscribe(
s.ctx,
services.CertAuthorityTarget{
ClusterName: s.domainName,
Type: types.HostCA,
},
WatchCertTypes: []types.CertAuthType{types.HostCA},
})
)
if err != nil {
return trace.Wrap(err)
}
defer remoteWatcher.Close()
defer func() {
if err := remoteWatch.Close(); err != nil {
s.WithError(err).Warn("Failed to close remote ca watcher subscription.")
}
}()

s.Debugf("Watching for cert authority changes.")
for {
select {
case <-s.ctx.Done():
s.WithError(s.ctx.Err()).Debug("Context is closing.")
return trace.Wrap(s.ctx.Err())
case <-localWatcher.Done():
case <-localWatch.Done():
s.Warn("Local CertAuthority watcher subscription has closed")
return fmt.Errorf("local ca watcher for cluster %s has closed", s.srv.ClusterName)
case <-remoteWatcher.Done():
case <-remoteWatch.Done():
s.Warn("Remote CertAuthority watcher subscription has closed")
return fmt.Errorf("remote ca watcher for cluster %s has closed", s.domainName)
case cas := <-localWatcher.CertAuthorityC:
for _, localCA := range cas {
if localCA.GetClusterName() != s.srv.ClusterName ||
!localWatcher.IsWatched(localCA.GetType()) {
case evt := <-localWatch.Events():
switch evt.Type {
case types.OpPut:
localCA, ok := evt.Resource.(types.CertAuthority)
if !ok {
continue
}

ca, ok := cas[localCA.GetType()]
if ok && services.CertAuthoritiesEquivalent(ca, localCA) {
continue
}

// clone to prevent a race with watcher filtering
localCA = localCA.Clone()
if err := s.remoteClient.RotateExternalCertAuthority(s.ctx, localCA); err != nil {
s.WithError(err).Warn("Failed to rotate external ca")
log.WithError(err).Warn("Failed to rotate external ca")
return trace.Wrap(err)
}

cas[localCA.GetType()] = localCA
}
case cas := <-remoteWatcher.CertAuthorityC:
for _, remoteCA := range cas {
if remoteCA.GetType() != types.HostCA ||
remoteCA.GetClusterName() != s.domainName {
case evt := <-remoteWatch.Events():
switch evt.Type {
case types.OpPut:
remoteCA, ok := evt.Resource.(types.CertAuthority)
if !ok {
continue
}

@@ -549,8 +559,17 @@ func (s *remoteSite) watchCertAuthorities(remoteClusterVersion string) error {
}

// getLocalWatchedCerts returns local certificates types that should be watched by the cert authority watcher.
func (s *remoteSite) getLocalWatchedCerts(remoteClusterVersion string) ([]types.CertAuthType, error) {
localWatchedTypes := []types.CertAuthType{types.HostCA, types.UserCA}
func (s *remoteSite) getLocalWatchedCerts(remoteClusterVersion string) ([]services.CertAuthorityTarget, error) {
localWatchedTypes := []services.CertAuthorityTarget{
{
Type: types.HostCA,
ClusterName: s.srv.ClusterName,
},
{
Type: types.UserCA,
ClusterName: s.srv.ClusterName,
},
}

// Delete in 11.0.
ver10orAbove, err := utils.MinVerWithoutPreRelease(remoteClusterVersion, constants.DatabaseCAMinVersion)
@@ -559,7 +578,7 @@ func (s *remoteSite) getLocalWatchedCerts(remoteClusterVersion string) ([]types.
}

if ver10orAbove {
localWatchedTypes = append(localWatchedTypes, types.DatabaseCA)
localWatchedTypes = append(localWatchedTypes, services.CertAuthorityTarget{ClusterName: s.srv.ClusterName, Type: types.DatabaseCA})
} else {
s.Debugf("Connected to remote cluster of version %s. Database CA won't be propagated.", remoteClusterVersion)
}
29 changes: 22 additions & 7 deletions lib/reversetunnel/remotesite_test.go
Original file line number Diff line number Diff line change
@@ -22,6 +22,7 @@ import (
"testing"

"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/utils"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/require"
@@ -31,34 +32,48 @@ func Test_remoteSite_getLocalWatchedCerts(t *testing.T) {
tests := []struct {
name string
clusterVersion string
want []types.CertAuthType
wantErr bool
want []services.CertAuthorityTarget
errorAssertion require.ErrorAssertionFunc
}{
{
name: "pre Database CA, only Host and User CA",
clusterVersion: "9.0.0",
want: []types.CertAuthType{types.HostCA, types.UserCA},
want: []services.CertAuthorityTarget{
{Type: types.HostCA, ClusterName: "test"},
{Type: types.UserCA, ClusterName: "test"},
},
errorAssertion: require.NoError,
},
{
name: "all certs should be returned",
clusterVersion: "10.0.0",
want: []types.CertAuthType{types.HostCA, types.UserCA, types.DatabaseCA},
want: []services.CertAuthorityTarget{
{Type: types.HostCA, ClusterName: "test"},
{Type: types.UserCA, ClusterName: "test"},
{Type: types.DatabaseCA, ClusterName: "test"},
},
errorAssertion: require.NoError,
},
{
name: "invalid version",
clusterVersion: "foo",
wantErr: true,
errorAssertion: require.Error,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := &remoteSite{
srv: &server{
Config: Config{
ClusterName: "test",
},
},
Entry: log.NewEntry(utils.NewLoggerForTests()),
}
got, err := s.getLocalWatchedCerts(tt.clusterVersion)
if (err != nil) != tt.wantErr {
t.Errorf("getLocalWatchedCerts() error = %v, wantErr %v", err, tt.wantErr)
tt.errorAssertion(t, err)
if err != nil {
return
}

35 changes: 27 additions & 8 deletions lib/reversetunnel/srv.go
Original file line number Diff line number Diff line change
@@ -205,6 +205,9 @@ type Config struct {

// NodeWatcher is a node watcher.
NodeWatcher *services.NodeWatcher

// CertAuthorityWatcher is a cert authority watcher.
CertAuthorityWatcher *services.CertAuthorityWatcher
}

// CheckAndSetDefaults checks parameters and sets default values
@@ -259,6 +262,9 @@ func (cfg *Config) CheckAndSetDefaults() error {
if cfg.NodeWatcher == nil {
return trace.BadParameter("missing parameter NodeWatcher")
}
if cfg.CertAuthorityWatcher == nil {
return trace.BadParameter("missing parameter CertAuthorityWatcher")
}
return nil
}

@@ -1040,6 +1046,11 @@ func newRemoteSite(srv *server, domainName string, sconn ssh.Conn) (*remoteSite,
connInfo.SetExpiry(srv.Clock.Now().Add(srv.offlineThreshold))

closeContext, cancel := context.WithCancel(srv.ctx)
defer func() {
if err != nil {
cancel()
}
}()
remoteSite := &remoteSite{
srv: srv,
domainName: domainName,
@@ -1063,20 +1074,17 @@ func newRemoteSite(srv *server, domainName string, sconn ssh.Conn) (*remoteSite,

clt, _, err := remoteSite.getRemoteClient()
if err != nil {
cancel()
return nil, trace.Wrap(err)
}
remoteSite.remoteClient = clt

remoteVersion, err := getRemoteAuthVersion(closeContext, sconn)
if err != nil {
cancel()
return nil, trace.Wrap(err)
}

accessPoint, err := createRemoteAccessPoint(srv, clt, remoteVersion, domainName)
if err != nil {
cancel()
return nil, trace.Wrap(err)
}
remoteSite.remoteAccessPoint = accessPoint
@@ -1088,7 +1096,6 @@ func newRemoteSite(srv *server, domainName string, sconn ssh.Conn) (*remoteSite,
},
})
if err != nil {
cancel()
return nil, trace.Wrap(err)
}
remoteSite.nodeWatcher = nodeWatcher
@@ -1098,7 +1105,6 @@ func newRemoteSite(srv *server, domainName string, sconn ssh.Conn) (*remoteSite,
// is signed by the correct certificate authority.
certificateCache, err := newHostCertificateCache(srv.Config.KeyGen, srv.localAuthClient)
if err != nil {
cancel()
return nil, trace.Wrap(err)
}
remoteSite.certificateCache = certificateCache
@@ -1111,11 +1117,25 @@ func newRemoteSite(srv *server, domainName string, sconn ssh.Conn) (*remoteSite,
Clock: srv.Clock,
})
if err != nil {
cancel()
return nil, trace.Wrap(err)
}

go remoteSite.updateCertAuthorities(caRetry, remoteVersion)
remoteWatcher, err := services.NewCertAuthorityWatcher(srv.ctx, services.CertAuthorityWatcherConfig{
ResourceWatcherConfig: services.ResourceWatcherConfig{
Component: teleport.ComponentProxy,
Log: srv.log,
Clock: srv.Clock,
Client: remoteSite.remoteAccessPoint,
},
Types: []types.CertAuthType{types.HostCA},
})
if err != nil {
return nil, trace.Wrap(err)
rosstimothy marked this conversation as resolved.
Show resolved Hide resolved
}

go func() {
remoteSite.updateCertAuthorities(caRetry, remoteWatcher, remoteVersion)
}()

lockRetry, err := utils.NewLinear(utils.LinearConfig{
First: utils.HalfJitter(srv.Config.PollingPeriod),
@@ -1125,7 +1145,6 @@ func newRemoteSite(srv *server, domainName string, sconn ssh.Conn) (*remoteSite,
Clock: srv.Clock,
})
if err != nil {
cancel()
return nil, trace.Wrap(err)
}

36 changes: 25 additions & 11 deletions lib/service/service.go
Original file line number Diff line number Diff line change
@@ -2853,6 +2853,19 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error {
return trace.Wrap(err)
}

caWatcher, err := services.NewCertAuthorityWatcher(process.ExitContext(), services.CertAuthorityWatcherConfig{
ResourceWatcherConfig: services.ResourceWatcherConfig{
Component: teleport.ComponentProxy,
Log: process.log.WithField(trace.Component, teleport.ComponentProxy),
Client: conn.Client,
},
AuthorityGetter: accessPoint,
Types: []types.CertAuthType{types.HostCA, types.UserCA, types.DatabaseCA},
})
if err != nil {
return trace.Wrap(err)
}

serverTLSConfig, err := conn.ServerIdentity.TLSConfig(cfg.CipherSuites)
if err != nil {
return trace.Wrap(err)
@@ -2882,17 +2895,18 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error {
Client: conn.Client,
},
},
KeyGen: cfg.Keygen,
Ciphers: cfg.Ciphers,
KEXAlgorithms: cfg.KEXAlgorithms,
MACAlgorithms: cfg.MACAlgorithms,
DataDir: process.Config.DataDir,
PollingPeriod: process.Config.PollingPeriod,
FIPS: cfg.FIPS,
Emitter: streamEmitter,
Log: process.log,
LockWatcher: lockWatcher,
NodeWatcher: nodeWatcher,
KeyGen: cfg.Keygen,
Ciphers: cfg.Ciphers,
KEXAlgorithms: cfg.KEXAlgorithms,
MACAlgorithms: cfg.MACAlgorithms,
DataDir: process.Config.DataDir,
PollingPeriod: process.Config.PollingPeriod,
FIPS: cfg.FIPS,
Emitter: streamEmitter,
Log: process.log,
LockWatcher: lockWatcher,
NodeWatcher: nodeWatcher,
CertAuthorityWatcher: caWatcher,
})
if err != nil {
return trace.Wrap(err)
151 changes: 95 additions & 56 deletions lib/services/watcher.go
Original file line number Diff line number Diff line change
@@ -919,10 +919,8 @@ type CertAuthorityWatcherConfig struct {
ResourceWatcherConfig
// AuthorityGetter is responsible for fetching cert authority resources.
AuthorityGetter
// CertAuthorityC receives up-to-date list of all cert authority resources.
CertAuthorityC chan []types.CertAuthority
// WatchCertTypes stores all certificate types that should be monitored.
WatchCertTypes []types.CertAuthType
// Types restricts which cert authority types are retrieved via the AuthorityGetter.
Types []types.CertAuthType
}

// CheckAndSetDefaults checks parameters and sets default values.
@@ -937,15 +935,12 @@ func (cfg *CertAuthorityWatcherConfig) CheckAndSetDefaults() error {
}
cfg.AuthorityGetter = getter
}
if cfg.CertAuthorityC == nil {
cfg.CertAuthorityC = make(chan []types.CertAuthority)
}
return nil
}

// IsWatched return true if the given certificate auth type is being observer by the watcher.
func (cfg *CertAuthorityWatcherConfig) IsWatched(certType types.CertAuthType) bool {
for _, observedType := range cfg.WatchCertTypes {
for _, observedType := range cfg.Types {
if observedType == certType {
return true
}
@@ -961,13 +956,20 @@ func NewCertAuthorityWatcher(ctx context.Context, cfg CertAuthorityWatcherConfig

collector := &caCollector{
CertAuthorityWatcherConfig: cfg,
fanout: NewFanout(),
cas: make(map[types.CertAuthType]map[string]types.CertAuthority, len(cfg.Types)),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit:

Suggested change
cas: make(map[types.CertAuthType]map[string]types.CertAuthority, len(cfg.Types)),
cas: make(types.CertAuthorityTypeMap, len(cfg.Types)),

I also don't mind renaming the CertAuthorityTypeMap if you have any better idea.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, I didn't know these were added, but looking at them, they seem really specific to the CAWatcher internals and I'm not entirely convinced that they should even be exported.

// CertAuthorityMap maps clusterName -> types.CertAuthority
type CertAuthorityMap map[string]types.CertAuthority

// CertAuthorityTypeMap maps types.CertAuthType -> map(clusterName -> types.CertAuthority)
type CertAuthorityTypeMap map[types.CertAuthType]CertAuthorityMap

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not even entirely sure they are needed any more after migrating to use the fanout. Thoughts on just removing them?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't mind removing them, just map[types.CertAuthType]map[string]types.CertAuthority looks a bit "scary" in my opinion.

}

for _, t := range cfg.Types {
collector.cas[t] = make(map[string]types.CertAuthority)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit:

Suggested change
collector.cas[t] = make(map[string]types.CertAuthority)
collector.cas[t] = make(CertAuthorityMap)

}

watcher, err := newResourceWatcher(ctx, collector, cfg.ResourceWatcherConfig)
if err != nil {
return nil, trace.Wrap(err)
}

collector.fanout.SetInit()
return &CertAuthorityWatcher{watcher, collector}, nil
}

@@ -980,26 +982,66 @@ type CertAuthorityWatcher struct {
// caCollector accompanies resourceWatcher when monitoring cert authority resources.
type caCollector struct {
CertAuthorityWatcherConfig
fanout *Fanout

collectedCAs CertAuthorityTypeMap
lock sync.RWMutex
// lock protects concurrent access to cas
lock sync.RWMutex
// cas maps ca type -> cluster -> ca
cas map[types.CertAuthType]map[string]types.CertAuthority
}

// CertAuthorityMap maps clusterName -> types.CertAuthority
type CertAuthorityMap map[string]types.CertAuthority
// CertAuthorityTarget lists the attributes of interactions to be disabled.
type CertAuthorityTarget struct {
// ClusterName specifies the name of the cluster to watch.
ClusterName string
// Type specifies the ca types to watch for.
Type types.CertAuthType
}

// CertAuthorityTypeMap maps types.CertAuthType -> map(clusterName -> types.CertAuthority)
type CertAuthorityTypeMap map[types.CertAuthType]CertAuthorityMap
// Subscribe is used to subscribe to the lock updates.
func (c *caCollector) Subscribe(ctx context.Context, targets ...CertAuthorityTarget) (types.Watcher, error) {
watchKinds, err := caTargetToWatchKinds(targets)
if err != nil {
return nil, trace.Wrap(err)
}
sub, err := c.fanout.NewWatcher(ctx, types.Watch{Kinds: watchKinds})
if err != nil {
return nil, trace.Wrap(err)
}
select {
case event := <-sub.Events():
if event.Type != types.OpInit {
return nil, trace.BadParameter("expected init event, got %v instead", event.Type)
}
case <-sub.Done():
return nil, trace.Wrap(sub.Error())
}
return sub, nil
}

// ToSlice converts CertAuthorityTypeMap to a slice.
func (cat *CertAuthorityTypeMap) ToSlice() []types.CertAuthority {
slice := make([]types.CertAuthority, 0)
for _, cert := range *cat {
for _, ca := range cert {
slice = append(slice, ca)
func caTargetToWatchKinds(targets []CertAuthorityTarget) ([]types.WatchKind, error) {
watchKinds := make([]types.WatchKind, 0, len(targets))
for _, target := range targets {
kind := types.WatchKind{
Kind: types.KindCertAuthority,
// Note that watching SubKind doesn't work for types.WatchKind - to do so it would
// require a custom filter, which was recently added but - we can't use yet due to
// older clients not supporting the filter.
SubKind: string(target.Type),
}

if target.ClusterName != "" {
kind.Name = target.ClusterName
}

watchKinds = append(watchKinds, kind)
}
return slice

if len(watchKinds) == 0 {
watchKinds = []types.WatchKind{{Kind: types.KindCertAuthority}}
}

return watchKinds, nil
}

// resourceKind specifies the resource kind to watch.
@@ -1009,28 +1051,27 @@ func (c *caCollector) resourceKind() string {

// getResourcesAndUpdateCurrent refreshes the list of current resources.
func (c *caCollector) getResourcesAndUpdateCurrent(ctx context.Context) error {
updatedCerts := make(CertAuthorityTypeMap)
var cas []types.CertAuthority

for _, caType := range c.WatchCertTypes {
cas, err := c.AuthorityGetter.GetCertAuthorities(ctx, caType, false)
for _, t := range c.Types {
authorities, err := c.AuthorityGetter.GetCertAuthorities(ctx, t, false)
if err != nil {
return trace.Wrap(err)
}

updatedCerts[caType] = make(CertAuthorityMap, len(cas))
for _, ca := range cas {
updatedCerts[caType][ca.GetName()] = ca
}
cas = append(cas, authorities...)
}

c.lock.Lock()
c.collectedCAs = updatedCerts
c.lock.Unlock()
defer c.lock.Unlock()

select {
case <-ctx.Done():
return trace.Wrap(ctx.Err())
case c.CertAuthorityC <- updatedCerts.ToSlice():
for _, ca := range cas {
if !c.watchingType(ca.GetType()) {
continue
}

c.cas[ca.GetType()][ca.GetName()] = ca
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you know if GetCertAuthorities is guaranteed to doublecheck the type of the CAs? If not we should probably doublecheck it here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch. Updated to only add them if configured to watch for that ca type now

c.fanout.Emit(types.Event{Type: types.OpPut, Resource: ca.Clone()})
}
return nil
}
@@ -1046,46 +1087,44 @@ func (c *caCollector) processEventAndUpdateCurrent(ctx context.Context, event ty
switch event.Type {
case types.OpDelete:
caType := types.CertAuthType(event.Resource.GetSubKind())

// Check if the certificate should be processed.
_, found := c.collectedCAs[caType]
if found {
delete(c.collectedCAs[caType], event.Resource.GetName())
if !c.watchingType(caType) {
return
}

select {
case <-ctx.Done():
case c.CertAuthorityC <- c.collectedCAs.ToSlice():
}
delete(c.cas[caType], event.Resource.GetName())
rosstimothy marked this conversation as resolved.
Show resolved Hide resolved
c.fanout.Emit(event)
case types.OpPut:
ca, ok := event.Resource.(types.CertAuthority)
if !ok {
c.Log.Warnf("Unexpected resource type %T.", event.Resource)
return
}

caType := ca.GetType()
_, found := c.collectedCAs[caType]
// Check if the certificate should be processed.
if found {
c.collectedCAs[caType][ca.GetName()] = ca
if !c.watchingType(ca.GetType()) {
return
}

select {
case <-ctx.Done():
case c.CertAuthorityC <- c.collectedCAs.ToSlice():
authority, ok := c.cas[ca.GetType()][ca.GetName()]
if ok && CertAuthoritiesEquivalent(authority, ca) {
return
}

c.cas[ca.GetType()][ca.GetName()] = ca
c.fanout.Emit(event)
rosstimothy marked this conversation as resolved.
Show resolved Hide resolved
default:
c.Log.Warnf("Unsupported event type %s.", event.Type)
return
}
}

// GetCurrent returns the currently stored authorities.
func (c *caCollector) GetCurrent() []types.CertAuthority {
c.lock.RLock()
defer c.lock.RUnlock()
return c.collectedCAs.ToSlice()
func (c *caCollector) watchingType(t types.CertAuthType) bool {
for _, caType := range c.Types {
if caType == t {
return true
}
}

return false
}

func (c *caCollector) notifyStale() {}
157 changes: 87 additions & 70 deletions lib/services/watcher_test.go
Original file line number Diff line number Diff line change
@@ -37,6 +37,7 @@ import (
"github.com/gravitational/teleport/lib/auth/testauthority"
"github.com/gravitational/teleport/lib/backend/lite"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/fixtures"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/services/local"
"github.com/gravitational/teleport/lib/tlsca"
@@ -520,9 +521,10 @@ func resourceDiff(res1, res2 types.Resource) string {
func caDiff(ca1, ca2 types.CertAuthority) string {
return cmp.Diff(ca1, ca2,
cmpopts.IgnoreFields(types.Metadata{}, "ID"),
cmpopts.IgnoreFields(types.CertAuthoritySpecV2{}, "CheckingKeys", "TLSKeyPairs"),
cmpopts.IgnoreFields(types.CertAuthoritySpecV2{}, "CheckingKeys", "TLSKeyPairs", "JWTKeyPairs"),
cmpopts.IgnoreFields(types.SSHKeyPair{}, "PrivateKey"),
cmpopts.IgnoreFields(types.TLSKeyPair{}, "Key"),
cmpopts.IgnoreFields(types.JWTKeyPair{}, "PrivateKey"),
cmpopts.EquateEmpty(),
)
}
@@ -723,10 +725,12 @@ func newApp(t *testing.T, name string) types.Application {
func TestCertAuthorityWatcher(t *testing.T) {
t.Parallel()
ctx := context.Background()
clock := clockwork.NewFakeClock()

bk, err := lite.NewWithConfig(ctx, lite.Config{
Path: t.TempDir(),
PollStreamPeriod: 200 * time.Millisecond,
Clock: clock,
})
require.NoError(t, err)

@@ -744,85 +748,88 @@ func TestCertAuthorityWatcher(t *testing.T) {
Trust: caService,
Events: local.NewEventsService(bk),
},
Clock: clock,
},
CertAuthorityC: make(chan []types.CertAuthority, 10),
WatchCertTypes: []types.CertAuthType{types.HostCA, types.UserCA, types.DatabaseCA},
Types: []types.CertAuthType{types.HostCA, types.UserCA, types.DatabaseCA},
})
require.NoError(t, err)
t.Cleanup(w.Close)

nothingWatcher, err := services.NewCertAuthorityWatcher(ctx, services.CertAuthorityWatcherConfig{
ResourceWatcherConfig: services.ResourceWatcherConfig{
Component: "test",
MaxRetryPeriod: 200 * time.Millisecond,
Client: &client{
Trust: caService,
Events: local.NewEventsService(bk),
},
},
CertAuthorityC: make(chan []types.CertAuthority, 10),
})
target := services.CertAuthorityTarget{ClusterName: "test"}
sub, err := w.Subscribe(ctx, target)
require.NoError(t, err)
t.Cleanup(nothingWatcher.Close)

require.Empty(t, w.GetCurrent())
require.Empty(t, nothingWatcher.GetCurrent())
t.Cleanup(func() { require.NoError(t, sub.Close()) })

// Initially there are no cas so watcher should send an empty list.
// create a CA for the cluster and a type we are filtering for
// and ensure we receive the event
ca := newCertAuthority(t, "test", types.HostCA)
require.NoError(t, caService.UpsertCertAuthority(ca))
select {
case changeset := <-w.CertAuthorityC:
require.Len(t, changeset, 0)
require.Empty(t, nothingWatcher.GetCurrent())
case <-w.Done():
t.Fatal("Watcher has unexpectedly exited.")
case <-time.After(2 * time.Second):
t.Fatal("Timeout waiting for the first event.")
case event := <-sub.Events():
caFromEvent, ok := event.Resource.(types.CertAuthority)
require.True(t, ok)
require.Empty(t, caDiff(ca, caFromEvent))
case <-time.After(time.Second):
t.Fatal("timed out waiting for event")
}

// Add an authority.
ca1 := newCertAuthority(t, "ca1", types.HostCA)
require.NoError(t, caService.CreateCertAuthority(ca1))

// The first event is always the current list of apps.
// create a CA with a type we are filtering for another cluster that we are NOT filtering for
// and ensure that we DO NOT receive the event
require.NoError(t, caService.UpsertCertAuthority(newCertAuthority(t, "unknown", types.UserCA)))
select {
case changeset := <-w.CertAuthorityC:
require.Len(t, changeset, 1)
require.Empty(t, caDiff(changeset[0], ca1))
require.Empty(t, nothingWatcher.GetCurrent())
case <-w.Done():
t.Fatal("Watcher has unexpectedly exited.")
case <-time.After(2 * time.Second):
t.Fatal("Timeout waiting for the first event.")
case event := <-sub.Events():
t.Fatalf("Unexpected event: %v.", event)
case <-sub.Done():
t.Fatal("CA watcher subscription has unexpectedly exited.")
case <-time.After(time.Second):
}

// Add a second ca.
ca2 := newCertAuthority(t, "ca2", types.UserCA)
require.NoError(t, caService.CreateCertAuthority(ca2))
// create a CA for the cluster and a type we are filtering for
// and ensure we receive the event
ca2 := newCertAuthority(t, "test", types.UserCA)
require.NoError(t, caService.UpsertCertAuthority(ca2))
select {
case event := <-sub.Events():
caFromEvent, ok := event.Resource.(types.CertAuthority)
require.True(t, ok)
require.Empty(t, caDiff(ca2, caFromEvent))
case <-time.After(time.Second):
t.Fatal("timed out waiting for event")
}

// Watcher should detect the ca list change.
// delete a CA with type being watched in the cluster we are filtering for
// and ensure we receive the event
require.NoError(t, caService.DeleteCertAuthority(ca.GetID()))
select {
case changeset := <-w.CertAuthorityC:
require.Len(t, changeset, 2)
require.Empty(t, nothingWatcher.GetCurrent())
case <-w.Done():
t.Fatal("Watcher has unexpectedly exited.")
case <-time.After(2 * time.Second):
t.Fatal("Timeout waiting for the update event.")
case event := <-sub.Events():
require.Equal(t, types.KindCertAuthority, event.Resource.GetKind())
require.Equal(t, string(types.HostCA), event.Resource.GetSubKind())
require.Equal(t, "test", event.Resource.GetName())
case <-time.After(time.Second):
t.Fatal("timed out waiting for event")
}

// Delete the first ca.
require.NoError(t, caService.DeleteCertAuthority(ca1.GetID()))
// create a CA with a type we are NOT filtering for but for a cluster we are filtering for
// and ensure we DO NOT receive the event
signer := newCertAuthority(t, "test", types.JWTSigner)
require.NoError(t, caService.UpsertCertAuthority(signer))
select {
case event := <-sub.Events():
t.Fatalf("Unexpected event: %v.", event)
case <-sub.Done():
t.Fatal("CA watcher subscription has unexpectedly exited.")
case <-time.After(time.Second):
}

// Watcher should detect the ca list change.
// delete a CA with a name we are filtering for but a type we are NOT filtering for
// and ensure we do NOT receive the event
require.NoError(t, caService.DeleteCertAuthority(signer.GetID()))
select {
case changeset := <-w.CertAuthorityC:
require.Len(t, changeset, 1)
require.Empty(t, caDiff(changeset[0], ca2))
require.Empty(t, nothingWatcher.GetCurrent())
case <-w.Done():
t.Fatal("Watcher has unexpectedly exited.")
case <-time.After(2 * time.Second):
t.Fatal("Timeout waiting for the update event.")
case event := <-sub.Events():
t.Fatalf("Unexpected event: %v.", event)
case <-sub.Done():
t.Fatal("CA watcher subscription has unexpectedly exited.")
case <-time.After(time.Second):
}
}

@@ -839,15 +846,25 @@ func newCertAuthority(t *testing.T, name string, caType types.CertAuthType) type
Type: caType,
ClusterName: name,
ActiveKeys: types.CAKeySet{
SSH: []*types.SSHKeyPair{{
PrivateKey: priv,
PrivateKeyType: types.PrivateKeyType_RAW,
PublicKey: pub,
}},
TLS: []*types.TLSKeyPair{{
Cert: cert,
Key: key,
}},
SSH: []*types.SSHKeyPair{
{
PrivateKey: priv,
PrivateKeyType: types.PrivateKeyType_RAW,
PublicKey: pub,
},
},
TLS: []*types.TLSKeyPair{
{
Cert: cert,
Key: key,
},
},
JWT: []*types.JWTKeyPair{
{
PublicKey: []byte(fixtures.JWTSignerPublicKey),
PrivateKey: []byte(fixtures.JWTSignerPrivateKey),
},
},
},
Roles: nil,
SigningAlg: types.CertAuthoritySpecV2_RSA_SHA2_256,
19 changes: 19 additions & 0 deletions lib/srv/regular/sshserver_test.go
Original file line number Diff line number Diff line change
@@ -1126,6 +1126,7 @@ func TestProxyRoundRobin(t *testing.T) {
defer listener.Close()
lockWatcher := newLockWatcher(ctx, t, proxyClient)
nodeWatcher := newNodeWatcher(ctx, t, proxyClient)
caWatcher := newCertAuthorityWatcher(ctx, t, proxyClient)

reverseTunnelServer, err := reversetunnel.NewServer(reversetunnel.Config{
ClusterName: f.testSrv.ClusterName(),
@@ -1143,6 +1144,7 @@ func TestProxyRoundRobin(t *testing.T) {
Log: logger,
LockWatcher: lockWatcher,
NodeWatcher: nodeWatcher,
CertAuthorityWatcher: caWatcher,
})
require.NoError(t, err)
logger.WithField("tun-addr", reverseTunnelAddress.String()).Info("Created reverse tunnel server.")
@@ -1252,6 +1254,7 @@ func TestProxyDirectAccess(t *testing.T) {
proxyClient, _ := newProxyClient(t, f.testSrv)
lockWatcher := newLockWatcher(ctx, t, proxyClient)
nodeWatcher := newNodeWatcher(ctx, t, proxyClient)
caWatcher := newCertAuthorityWatcher(ctx, t, proxyClient)

reverseTunnelServer, err := reversetunnel.NewServer(reversetunnel.Config{
ClientTLS: proxyClient.TLSConfig(),
@@ -1269,6 +1272,7 @@ func TestProxyDirectAccess(t *testing.T) {
Log: logger,
LockWatcher: lockWatcher,
NodeWatcher: nodeWatcher,
CertAuthorityWatcher: caWatcher,
})
require.NoError(t, err)

@@ -1863,6 +1867,7 @@ func TestIgnorePuTTYSimpleChannel(t *testing.T) {
proxyClient, _ := newProxyClient(t, f.testSrv)
lockWatcher := newLockWatcher(ctx, t, proxyClient)
nodeWatcher := newNodeWatcher(ctx, t, proxyClient)
caWatcher := newCertAuthorityWatcher(ctx, t, proxyClient)

reverseTunnelServer, err := reversetunnel.NewServer(reversetunnel.Config{
ClientTLS: proxyClient.TLSConfig(),
@@ -1880,6 +1885,7 @@ func TestIgnorePuTTYSimpleChannel(t *testing.T) {
Log: logger,
LockWatcher: lockWatcher,
NodeWatcher: nodeWatcher,
CertAuthorityWatcher: caWatcher,
})
require.NoError(t, err)

@@ -2098,6 +2104,19 @@ func newNodeWatcher(ctx context.Context, t *testing.T, client types.Events) *ser
return nodeWatcher
}

func newCertAuthorityWatcher(ctx context.Context, t *testing.T, client types.Events) *services.CertAuthorityWatcher {
caWatcher, err := services.NewCertAuthorityWatcher(ctx, services.CertAuthorityWatcherConfig{
ResourceWatcherConfig: services.ResourceWatcherConfig{
Component: "test",
Client: client,
},
Types: []types.CertAuthType{types.HostCA, types.UserCA},
})
require.NoError(t, err)
t.Cleanup(caWatcher.Close)
return caWatcher
}

// maxPipeSize is one larger than the maximum pipe size for most operating
// systems which appears to be 65536 bytes.
//
22 changes: 22 additions & 0 deletions lib/web/apiserver_test.go
Original file line number Diff line number Diff line change
@@ -275,6 +275,16 @@ func newWebSuite(t *testing.T) *WebSuite {
})
require.NoError(t, err)

caWatcher, err := services.NewCertAuthorityWatcher(s.ctx, services.CertAuthorityWatcherConfig{
ResourceWatcherConfig: services.ResourceWatcherConfig{
Component: teleport.ComponentProxy,
Client: s.proxyClient,
},
Types: []types.CertAuthType{types.HostCA, types.UserCA},
})
require.NoError(t, err)
defer caWatcher.Close()

revTunServer, err := reversetunnel.NewServer(reversetunnel.Config{
ID: node.ID(),
Listener: revTunListener,
@@ -289,6 +299,7 @@ func newWebSuite(t *testing.T) *WebSuite {
DataDir: t.TempDir(),
LockWatcher: proxyLockWatcher,
NodeWatcher: proxyNodeWatcher,
CertAuthorityWatcher: caWatcher,
})
require.NoError(t, err)
s.proxyTunnel = revTunServer
@@ -3746,6 +3757,16 @@ func createProxy(ctx context.Context, t *testing.T, proxyID string, node *regula
require.NoError(t, err)
t.Cleanup(proxyLockWatcher.Close)

proxyCAWatcher, err := services.NewCertAuthorityWatcher(ctx, services.CertAuthorityWatcherConfig{
ResourceWatcherConfig: services.ResourceWatcherConfig{
Component: teleport.ComponentProxy,
Client: client,
},
Types: []types.CertAuthType{types.HostCA, types.UserCA},
})
require.NoError(t, err)
t.Cleanup(proxyLockWatcher.Close)

proxyNodeWatcher, err := services.NewNodeWatcher(ctx, services.NodeWatcherConfig{
ResourceWatcherConfig: services.ResourceWatcherConfig{
Component: teleport.ComponentProxy,
@@ -3769,6 +3790,7 @@ func createProxy(ctx context.Context, t *testing.T, proxyID string, node *regula
DataDir: t.TempDir(),
LockWatcher: proxyLockWatcher,
NodeWatcher: proxyNodeWatcher,
CertAuthorityWatcher: proxyCAWatcher,
})
require.NoError(t, err)
t.Cleanup(func() { require.NoError(t, revTunServer.Close()) })