Skip to content

Commit

Permalink
[v10] Fix an issue DB rotation event get send to older remote cluster (
Browse files Browse the repository at this point in the history
  • Loading branch information
smallinsky authored Jul 1, 2022
1 parent 1b8624d commit 1ffeb51
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 162 deletions.
5 changes: 2 additions & 3 deletions integration/db_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -281,9 +281,8 @@ func (p *phaseWatcher) waitForPhase(phase string, fn func() error) error {
return trace.Wrap(err)
}

sub, err := watcher.Subscribe(ctx, services.CertAuthorityTarget{
ClusterName: p.clusterRootName,
Type: p.certType,
sub, err := watcher.Subscribe(ctx, types.CertAuthorityFilter{
p.certType: p.clusterRootName,
})
if err != nil {
return trace.Wrap(err)
Expand Down
36 changes: 15 additions & 21 deletions lib/reversetunnel/remotesite.go
Original file line number Diff line number Diff line change
Expand Up @@ -488,12 +488,12 @@ func (s *remoteSite) updateCertAuthorities(retry utils.Retry, remoteWatcher *ser
}

func (s *remoteSite) watchCertAuthorities(remoteWatcher *services.CertAuthorityWatcher, remoteVersion string, cas map[types.CertAuthType]types.CertAuthority) error {
targets, err := s.getLocalWatchedCerts(remoteVersion)
filter, err := s.getLocalWatchedCerts(remoteVersion)
if err != nil {
return trace.Wrap(err)
}

localWatch, err := s.srv.CertAuthorityWatcher.Subscribe(s.ctx, targets...)
localWatch, err := s.srv.CertAuthorityWatcher.Subscribe(s.ctx, filter)
if err != nil {
return trace.Wrap(err)
}
Expand All @@ -505,9 +505,8 @@ func (s *remoteSite) watchCertAuthorities(remoteWatcher *services.CertAuthorityW

remoteWatch, err := remoteWatcher.Subscribe(
s.ctx,
services.CertAuthorityTarget{
ClusterName: s.domainName,
Type: types.HostCA,
types.CertAuthorityFilter{
types.HostCA: s.domainName,
},
)
if err != nil {
Expand Down Expand Up @@ -587,31 +586,26 @@ func (s *remoteSite) watchCertAuthorities(remoteWatcher *services.CertAuthorityW
}

// getLocalWatchedCerts returns local certificates types that should be watched by the cert authority watcher.
func (s *remoteSite) getLocalWatchedCerts(remoteClusterVersion string) ([]services.CertAuthorityTarget, error) {
localWatchedTypes := []services.CertAuthorityTarget{
{
Type: types.HostCA,
ClusterName: s.srv.ClusterName,
},
{
Type: types.UserCA,
ClusterName: s.srv.ClusterName,
},
}

func (s *remoteSite) getLocalWatchedCerts(remoteClusterVersion string) (types.CertAuthorityFilter, error) {
// Delete in 11.0.
ver10orAbove, err := utils.MinVerWithoutPreRelease(remoteClusterVersion, constants.DatabaseCAMinVersion)
if err != nil {
return nil, trace.Wrap(err)
}

if ver10orAbove {
localWatchedTypes = append(localWatchedTypes, services.CertAuthorityTarget{ClusterName: s.srv.ClusterName, Type: types.DatabaseCA})
} else {
if !ver10orAbove {
s.Debugf("Connected to remote cluster of version %s. Database CA won't be propagated.", remoteClusterVersion)
return types.CertAuthorityFilter{
types.HostCA: s.srv.ClusterName,
types.UserCA: s.srv.ClusterName,
}, nil
}

return localWatchedTypes, nil
return types.CertAuthorityFilter{
types.HostCA: s.srv.ClusterName,
types.UserCA: s.srv.ClusterName,
types.DatabaseCA: s.srv.ClusterName,
}, nil
}

func (s *remoteSite) updateLocks(retry utils.Retry) {
Expand Down
17 changes: 8 additions & 9 deletions lib/reversetunnel/remotesite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ import (
"testing"

"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/utils"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/require"
Expand All @@ -32,25 +31,25 @@ func Test_remoteSite_getLocalWatchedCerts(t *testing.T) {
tests := []struct {
name string
clusterVersion string
want []services.CertAuthorityTarget
want types.CertAuthorityFilter
errorAssertion require.ErrorAssertionFunc
}{
{
name: "pre Database CA, only Host and User CA",
clusterVersion: "9.0.0",
want: []services.CertAuthorityTarget{
{Type: types.HostCA, ClusterName: "test"},
{Type: types.UserCA, ClusterName: "test"},
want: types.CertAuthorityFilter{
types.HostCA: "test",
types.UserCA: "test",
},
errorAssertion: require.NoError,
},
{
name: "all certs should be returned",
clusterVersion: "10.0.0",
want: []services.CertAuthorityTarget{
{Type: types.HostCA, ClusterName: "test"},
{Type: types.UserCA, ClusterName: "test"},
{Type: types.DatabaseCA, ClusterName: "test"},
want: types.CertAuthorityFilter{
types.DatabaseCA: "test",
types.HostCA: "test",
types.UserCA: "test",
},
errorAssertion: require.NoError,
},
Expand Down
59 changes: 10 additions & 49 deletions lib/services/watcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -942,16 +942,6 @@ func (cfg *CertAuthorityWatcherConfig) CheckAndSetDefaults() error {
return nil
}

// IsWatched return true if the given certificate auth type is being observer by the watcher.
func (cfg *CertAuthorityWatcherConfig) IsWatched(certType types.CertAuthType) bool {
for _, observedType := range cfg.Types {
if observedType == certType {
return true
}
}
return false
}

// NewCertAuthorityWatcher returns a new instance of CertAuthorityWatcher.
func NewCertAuthorityWatcher(ctx context.Context, cfg CertAuthorityWatcherConfig) (*CertAuthorityWatcher, error) {
if err := cfg.CheckAndSetDefaults(); err != nil {
Expand Down Expand Up @@ -994,21 +984,17 @@ type caCollector struct {
cas map[types.CertAuthType]map[string]types.CertAuthority
}

// CertAuthorityTarget lists the attributes of interactions to be disabled.
type CertAuthorityTarget struct {
// ClusterName specifies the name of the cluster to watch.
ClusterName string
// Type specifies the ca types to watch for.
Type types.CertAuthType
}

// Subscribe is used to subscribe to the lock updates.
func (c *caCollector) Subscribe(ctx context.Context, targets ...CertAuthorityTarget) (types.Watcher, error) {
watchKinds, err := caTargetToWatchKinds(targets)
if err != nil {
return nil, trace.Wrap(err)
}
sub, err := c.fanout.NewWatcher(ctx, types.Watch{Kinds: watchKinds})
func (c *caCollector) Subscribe(ctx context.Context, filter types.CertAuthorityFilter) (types.Watcher, error) {
watch := types.Watch{
Kinds: []types.WatchKind{
{
Kind: c.resourceKind(),
Filter: filter.IntoMap(),
},
},
}
sub, err := c.fanout.NewWatcher(ctx, watch)
if err != nil {
return nil, trace.Wrap(err)
}
Expand All @@ -1023,31 +1009,6 @@ func (c *caCollector) Subscribe(ctx context.Context, targets ...CertAuthorityTar
return sub, nil
}

func caTargetToWatchKinds(targets []CertAuthorityTarget) ([]types.WatchKind, error) {
watchKinds := make([]types.WatchKind, 0, len(targets))
for _, target := range targets {
kind := types.WatchKind{
Kind: types.KindCertAuthority,
// Note that watching SubKind doesn't work for types.WatchKind - to do so it would
// require a custom filter, which was recently added but - we can't use yet due to
// older clients not supporting the filter.
SubKind: string(target.Type),
}

if target.ClusterName != "" {
kind.Name = target.ClusterName
}

watchKinds = append(watchKinds, kind)
}

if len(watchKinds) == 0 {
watchKinds = []types.WatchKind{{Kind: types.KindCertAuthority}}
}

return watchKinds, nil
}

// resourceKind specifies the resource kind to watch.
func (c *caCollector) resourceKind() string {
return types.KindCertAuthority
Expand Down
142 changes: 62 additions & 80 deletions lib/services/watcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -518,17 +518,6 @@ func resourceDiff(res1, res2 types.Resource) string {
cmpopts.EquateEmpty())
}

func caDiff(ca1, ca2 types.CertAuthority) string {
return cmp.Diff(ca1, ca2,
cmpopts.IgnoreFields(types.Metadata{}, "ID"),
cmpopts.IgnoreFields(types.CertAuthoritySpecV2{}, "CheckingKeys", "TLSKeyPairs", "JWTKeyPairs"),
cmpopts.IgnoreFields(types.SSHKeyPair{}, "PrivateKey"),
cmpopts.IgnoreFields(types.TLSKeyPair{}, "Key"),
cmpopts.IgnoreFields(types.JWTKeyPair{}, "PrivateKey"),
cmpopts.EquateEmpty(),
)
}

// TestDatabaseWatcher tests that database resource watcher properly receives
// and dispatches updates to database resources.
func TestDatabaseWatcher(t *testing.T) {
Expand Down Expand Up @@ -755,82 +744,75 @@ func TestCertAuthorityWatcher(t *testing.T) {
require.NoError(t, err)
t.Cleanup(w.Close)

target := services.CertAuthorityTarget{ClusterName: "test"}
sub, err := w.Subscribe(ctx, target)
require.NoError(t, err)
t.Cleanup(func() { require.NoError(t, sub.Close()) })

// create a CA for the cluster and a type we are filtering for
// and ensure we receive the event
ca := newCertAuthority(t, "test", types.HostCA)
require.NoError(t, caService.UpsertCertAuthority(ca))
select {
case event := <-sub.Events():
caFromEvent, ok := event.Resource.(types.CertAuthority)
require.True(t, ok)
require.Empty(t, caDiff(ca, caFromEvent))
case <-time.After(time.Second):
t.Fatal("timed out waiting for event")
waitForEvent := func(t *testing.T, sub types.Watcher, caType types.CertAuthType, clusterName string, op types.OpType) {
select {
case event := <-sub.Events():
require.Equal(t, types.KindCertAuthority, event.Resource.GetKind())
require.Equal(t, string(caType), event.Resource.GetSubKind())
require.Equal(t, clusterName, event.Resource.GetName())
require.Equal(t, op, event.Type)
require.Empty(t, sub.Events()) // no more events.
case <-time.After(time.Second):
t.Fatal("timed out waiting for event")
}
}

// create a CA with a type we are filtering for another cluster that we are NOT filtering for
// and ensure that we DO NOT receive the event
require.NoError(t, caService.UpsertCertAuthority(newCertAuthority(t, "unknown", types.UserCA)))
select {
case event := <-sub.Events():
t.Fatalf("Unexpected event: %v.", event)
case <-sub.Done():
t.Fatal("CA watcher subscription has unexpectedly exited.")
case <-time.After(time.Second):
ensureNoEvents := func(t *testing.T, sub types.Watcher) {
select {
case event := <-sub.Events():
t.Fatalf("Unexpected event: %v.", event)
case <-sub.Done():
t.Fatal("CA watcher subscription has unexpectedly exited.")
case <-time.After(time.Second):
}
}

// create a CA for the cluster and a type we are filtering for
// and ensure we receive the event
ca2 := newCertAuthority(t, "test", types.UserCA)
require.NoError(t, caService.UpsertCertAuthority(ca2))
select {
case event := <-sub.Events():
caFromEvent, ok := event.Resource.(types.CertAuthority)
require.True(t, ok)
require.Empty(t, caDiff(ca2, caFromEvent))
case <-time.After(time.Second):
t.Fatal("timed out waiting for event")
}
t.Run("Subscribe all", func(t *testing.T) {
// Use nil CertAuthorityFilter to subscribe all events from the watcher.
sub, err := w.Subscribe(ctx, nil)
require.NoError(t, err)
t.Cleanup(func() { require.NoError(t, sub.Close()) })

// Create a CA and ensure we receive the event.
ca := newCertAuthority(t, "test", types.HostCA)
require.NoError(t, caService.UpsertCertAuthority(ca))
waitForEvent(t, sub, types.HostCA, "test", types.OpPut)

// Delete a CA and ensure we receive the event.
require.NoError(t, caService.DeleteCertAuthority(ca.GetID()))
waitForEvent(t, sub, types.HostCA, "test", types.OpDelete)

// Create a CA with a type that the watcher is NOT receiving and ensure
// we DO NOT receive the event.
signer := newCertAuthority(t, "test", types.JWTSigner)
require.NoError(t, caService.UpsertCertAuthority(signer))
ensureNoEvents(t, sub)
})

// delete a CA with type being watched in the cluster we are filtering for
// and ensure we receive the event
require.NoError(t, caService.DeleteCertAuthority(ca.GetID()))
select {
case event := <-sub.Events():
require.Equal(t, types.KindCertAuthority, event.Resource.GetKind())
require.Equal(t, string(types.HostCA), event.Resource.GetSubKind())
require.Equal(t, "test", event.Resource.GetName())
case <-time.After(time.Second):
t.Fatal("timed out waiting for event")
}
t.Run("Subscribe with filter", func(t *testing.T) {
sub, err := w.Subscribe(ctx,
types.CertAuthorityFilter{
types.HostCA: "test",
types.UserCA: types.Wildcard,
},
)
require.NoError(t, err)
t.Cleanup(func() { require.NoError(t, sub.Close()) })

// create a CA with a type we are NOT filtering for but for a cluster we are filtering for
// and ensure we DO NOT receive the event
signer := newCertAuthority(t, "test", types.JWTSigner)
require.NoError(t, caService.UpsertCertAuthority(signer))
select {
case event := <-sub.Events():
t.Fatalf("Unexpected event: %v.", event)
case <-sub.Done():
t.Fatal("CA watcher subscription has unexpectedly exited.")
case <-time.After(time.Second):
}
// Receives one HostCA event, matched by type and specific cluster name.
require.NoError(t, caService.UpsertCertAuthority(newCertAuthority(t, "test", types.HostCA)))
waitForEvent(t, sub, types.HostCA, "test", types.OpPut)

// delete a CA with a name we are filtering for but a type we are NOT filtering for
// and ensure we do NOT receive the event
require.NoError(t, caService.DeleteCertAuthority(signer.GetID()))
select {
case event := <-sub.Events():
t.Fatalf("Unexpected event: %v.", event)
case <-sub.Done():
t.Fatal("CA watcher subscription has unexpectedly exited.")
case <-time.After(time.Second):
}
// Receives one UserCA event, matched by type and wildcard cluster name.
require.NoError(t, caService.UpsertCertAuthority(newCertAuthority(t, "unknown", types.UserCA)))
waitForEvent(t, sub, types.UserCA, "unknown", types.OpPut)

// Should NOT receive any HostCA events from another cluster.
// Should NOT receive any DatabaseCA events.
require.NoError(t, caService.UpsertCertAuthority(newCertAuthority(t, "unknown", types.HostCA)))
require.NoError(t, caService.UpsertCertAuthority(newCertAuthority(t, "test", types.DatabaseCA)))
ensureNoEvents(t, sub)
})
}

func newCertAuthority(t *testing.T, name string, caType types.CertAuthType) types.CertAuthority {
Expand Down

0 comments on commit 1ffeb51

Please sign in to comment.