From d5e4a166d845ecae5476f399585eb95a82b09056 Mon Sep 17 00:00:00 2001 From: John Murret Date: Mon, 21 Aug 2023 10:35:25 -0600 Subject: [PATCH] NET-4943 - Implement ProxyTracker --- agent/proxy-tracker/mock_Logger.go | 31 ++ agent/proxy-tracker/mock_SessionLimiter.go | 53 ++++ agent/proxy-tracker/proxy_tracker.go | 278 ++++++++++++++++ agent/proxy-tracker/proxy_tracker_test.go | 351 +++++++++++++++++++++ 4 files changed, 713 insertions(+) create mode 100644 agent/proxy-tracker/mock_Logger.go create mode 100644 agent/proxy-tracker/mock_SessionLimiter.go create mode 100644 agent/proxy-tracker/proxy_tracker.go create mode 100644 agent/proxy-tracker/proxy_tracker_test.go diff --git a/agent/proxy-tracker/mock_Logger.go b/agent/proxy-tracker/mock_Logger.go new file mode 100644 index 000000000000..b4d28b096e86 --- /dev/null +++ b/agent/proxy-tracker/mock_Logger.go @@ -0,0 +1,31 @@ +// Code generated by mockery v2.32.4. DO NOT EDIT. + +package proxytracker + +import mock "github.com/stretchr/testify/mock" + +// MockLogger is an autogenerated mock type for the Logger type +type MockLogger struct { + mock.Mock +} + +// Error provides a mock function with given fields: args +func (_m *MockLogger) Error(args ...interface{}) { + var _ca []interface{} + _ca = append(_ca, args...) + _m.Called(_ca...) +} + +// NewMockLogger creates a new instance of MockLogger. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockLogger(t interface { + mock.TestingT + Cleanup(func()) +}) *MockLogger { + mock := &MockLogger{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/agent/proxy-tracker/mock_SessionLimiter.go b/agent/proxy-tracker/mock_SessionLimiter.go new file mode 100644 index 000000000000..4a2c5f324a1b --- /dev/null +++ b/agent/proxy-tracker/mock_SessionLimiter.go @@ -0,0 +1,53 @@ +// Code generated by mockery v2.32.4. DO NOT EDIT. + +package proxytracker + +import ( + limiter "github.com/hashicorp/consul/agent/grpc-external/limiter" + mock "github.com/stretchr/testify/mock" +) + +// MockSessionLimiter is an autogenerated mock type for the SessionLimiter type +type MockSessionLimiter struct { + mock.Mock +} + +// BeginSession provides a mock function with given fields: +func (_m *MockSessionLimiter) BeginSession() (limiter.Session, error) { + ret := _m.Called() + + var r0 limiter.Session + var r1 error + if rf, ok := ret.Get(0).(func() (limiter.Session, error)); ok { + return rf() + } + if rf, ok := ret.Get(0).(func() limiter.Session); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(limiter.Session) + } + } + + if rf, ok := ret.Get(1).(func() error); ok { + r1 = rf() + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// NewMockSessionLimiter creates a new instance of MockSessionLimiter. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockSessionLimiter(t interface { + mock.TestingT + Cleanup(func()) +}) *MockSessionLimiter { + mock := &MockSessionLimiter{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/agent/proxy-tracker/proxy_tracker.go b/agent/proxy-tracker/proxy_tracker.go new file mode 100644 index 000000000000..00c0cf23f990 --- /dev/null +++ b/agent/proxy-tracker/proxy_tracker.go @@ -0,0 +1,278 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package proxytracker + +import ( + "errors" + "fmt" + "sync" + + "github.com/hashicorp/consul/internal/controller" + "github.com/hashicorp/consul/internal/mesh" + "github.com/hashicorp/consul/internal/resource" + + "github.com/hashicorp/consul/agent/grpc-external/limiter" + "github.com/hashicorp/consul/agent/proxycfg" + pbmesh "github.com/hashicorp/consul/proto-public/pbmesh/v1alpha1" + "github.com/hashicorp/consul/proto-public/pbresource" +) + +// Proxy implements the queue.ItemType interface so that it can be used in a controller.Event. +// It is sent on the newProxyConnectionCh channel. +// TODO(ProxyState): needs to support tenancy in the future. +// Key() is current resourceID.Name. +type ProxyConnection struct { + ProxyID *pbresource.ID +} + +func (e *ProxyConnection) Key() string { + return e.ProxyID.GetName() +} + +// proxyWatchData is a handle on all of the relevant bits that is created by calling Watch(). +// It is meant to be stored in the proxies cache by proxyID so that watches can be notified +// when the ProxyState for that proxyID has changed. +type proxyWatchData struct { + // notifyCh is the channel that the watcher receives updates from ProxyTracker. + notifyCh chan *pbmesh.ProxyState + // state is the current/last updated ProxyState for a given proxy. + state *pbmesh.ProxyState + // token is the ACL token provided by the watcher. + token string + // nodeName is the node where the given proxy resides. + nodeName string +} + +type ProxyTrackerConfig struct { + // logger will be used to write log messages. + Logger Logger + + // sessionLimiter is used to enforce xDS concurrency limits. + SessionLimiter SessionLimiter +} + +// ProxyTracker implements the Watcher and Updater interfaces. The Watcher is used by the xds server to add a new proxy +// to this server, and get back a channel for updates. The Updater is used by the ProxyState controller running on the +// server to push ProxyState updates to the notify channel. +type ProxyTracker struct { + config ProxyTrackerConfig + // proxies is a cache of the proxies connected to this server and configuration information for each one. + proxies map[resource.ReferenceKey]*proxyWatchData + // newProxyConnectionCh is the channel that the "updater" retains to receive messages from ProxyTracker that a new + // proxy has connected to ProxyTracker and a signal the "updater" should call PushChanges with a new state. + newProxyConnectionCh chan controller.Event + // shutdownCh is a channel that closes when ProxyTracker is shutdown. ShutdownChannel is never written to, only closed to + // indicate a shutdown has been initiated. + shutdownCh chan struct{} + // mu is a mutex that is used internally for locking when reading and modifying ProxyTracker state, namely the proxies map. + mu sync.Mutex +} + +// NewProxyTracker returns a ProxyTracker instance given a configuration. +func NewProxyTracker(cfg ProxyTrackerConfig) *ProxyTracker { + return &ProxyTracker{ + config: cfg, + proxies: make(map[resource.ReferenceKey]*proxyWatchData), + // buffering this channel since ProxyTracker will be registering watches for all proxies. + // using the buffer will limit errors related to controller and the proxy are both running + // but the controllers listening function is not blocking on the particular receive line. + // This channel is meant to error when the controller is "not ready" which means up and alive. + // This buffer will try to reduce false negatives and limit unnecessary erroring. + newProxyConnectionCh: make(chan controller.Event, 1000), + shutdownCh: make(chan struct{}), + } +} + +// Watch connects a proxy with ProxyTracker and returns the consumer a channel to receive updates, +// a channel to notify of xDS terminated session, and a cancel function to cancel the watch. +func (pt *ProxyTracker) Watch(proxyID *pbresource.ID, + nodeName string, token string) (<-chan *pbmesh.ProxyState, + limiter.SessionTerminatedChan, proxycfg.CancelFunc, error) { + + if err := validateArgs(proxyID, nodeName, token); err != nil { + pt.config.Logger.Error("args failed validation", err) + return nil, nil, nil, err + } + // Begin a session with the xDS session concurrency limiter. + // + // See: https://github.com/hashicorp/consul/issues/15753 + session, err := pt.config.SessionLimiter.BeginSession() + if err != nil { + pt.config.Logger.Error("failed to begin session with xDS session concurrency limiter", err) + return nil, nil, nil, err + } + + // This buffering is crucial otherwise we'd block immediately trying to + // deliver the current snapshot below if we already have one. + proxyStateChan := make(chan *pbmesh.ProxyState, 1) + watchData := &proxyWatchData{ + notifyCh: proxyStateChan, + state: nil, + token: token, + nodeName: nodeName, + } + + proxyReferenceKey := resource.NewReferenceKey(proxyID) + cancel := func() { + pt.mu.Lock() + defer pt.mu.Unlock() + pt.cancelWatchLocked(proxyReferenceKey, proxyStateChan, session) + } + + pt.mu.Lock() + defer pt.mu.Unlock() + + pt.proxies[proxyReferenceKey] = watchData + + //Send an event to the controller + err = pt.notifyNewProxyChannel(proxyID) + if err != nil { + pt.cancelWatchLocked(proxyReferenceKey, watchData.notifyCh, session) + return nil, nil, nil, err + } + + return proxyStateChan, session.Terminated(), cancel, nil +} + +// notifyNewProxyChannel attempts to send a message to newProxyConnectionCh and will return an error if there's no receiver. +// This will handle conditions where a proxy is connected but there's no controller for some reason to receive the event. +// This will error back to the proxy's Watch call and will cause the proxy call Watch again to retry connection until the controller +// is available. +func (pt *ProxyTracker) notifyNewProxyChannel(proxyID *pbresource.ID) error { + controllerEvent := controller.Event{ + Obj: &ProxyConnection{ + ProxyID: proxyID, + }, + } + select { + case pt.newProxyConnectionCh <- controllerEvent: + return nil + // using default here to return errors is only safe when we have a large buffer. + // the receiver is on a loop to read from the channel. If the sequence of + // sender blocks on the channel and then the receiver blocks on the channel is not + // aligned, then extraneous errors could be returned to the proxy that are just + // false negatives and the controller could be up and healthy. + default: + return fmt.Errorf("failed to notify the controller of the proxy connecting") + } +} + +// cancelWatchLocked does the following: +// - deletes the key from the proxies array. +// - ends the session with xDS session limiter. +// - closes the proxy state channel assigned to the proxy. +// This function assumes the state lock is already held. +func (pt *ProxyTracker) cancelWatchLocked(proxyReferenceKey resource.ReferenceKey, proxyStateChan chan *pbmesh.ProxyState, session limiter.Session) { + delete(pt.proxies, proxyReferenceKey) + session.End() + close(proxyStateChan) +} + +func validateArgs(proxyID *pbresource.ID, + nodeName string, token string) error { + if proxyID == nil { + return errors.New("proxyID is required") + } else if proxyID.Type.Kind != mesh.ProxyStateTemplateConfigurationType.Kind { + return fmt.Errorf("proxyID must be a %s", mesh.ProxyStateTemplateConfigurationType.GetKind()) + } else if nodeName == "" { + return errors.New("nodeName is required") + } else if token == "" { + return errors.New("token is required") + } + + return nil +} + +// PushChange allows pushing a computed ProxyState to xds for xds resource generation to send to a proxy. +func (pt *ProxyTracker) PushChange(proxyID *pbresource.ID, proxyState *pbmesh.ProxyState) error { + proxyReferenceKey := resource.NewReferenceKey(proxyID) + pt.mu.Lock() + defer pt.mu.Unlock() + if data, ok := pt.proxies[proxyReferenceKey]; ok { + data.state = proxyState + pt.deliverLatest(proxyID, proxyState, data.notifyCh) + } else { + return errors.New("proxyState change could not be sent because proxy is not connected") + } + + return nil +} + +func (pt *ProxyTracker) deliverLatest(proxyID *pbresource.ID, proxyState *pbmesh.ProxyState, ch chan *pbmesh.ProxyState) { + // Send if chan is empty + select { + case ch <- proxyState: + return + default: + } + + // Not empty, drain the chan of older snapshots and redeliver. For now we only + // use 1-buffered chans but this will still work if we change that later. +OUTER: + for { + select { + case <-ch: + continue + default: + break OUTER + } + } + + // Now send again + select { + case ch <- proxyState: + return + default: + // This should not be possible since we should be the only sender, enforced + // by m.mu but error and drop the update rather than panic. + pt.config.Logger.Error("failed to deliver proxyState to proxy", + "proxy", proxyID.String(), + ) + } +} + +// EventChannel returns an event channel that sends controller events when a proxy connects to a server. +func (pt *ProxyTracker) EventChannel() chan controller.Event { + return pt.newProxyConnectionCh +} + +// ShutdownChannel returns a channel that closes when ProxyTracker is shutdown. ShutdownChannel is never written to, only closed to +// indicate a shutdown has been initiated. +func (pt *ProxyTracker) ShutdownChannel() chan struct{} { + return pt.shutdownCh +} + +// ProxyConnectedToServer returns whether this id is connected to this server. +func (pt *ProxyTracker) ProxyConnectedToServer(proxyID *pbresource.ID) bool { + pt.mu.Lock() + defer pt.mu.Unlock() + proxyReferenceKey := resource.NewReferenceKey(proxyID) + _, ok := pt.proxies[proxyReferenceKey] + return ok +} + +// Shutdown removes all state and close all channels. +func (pt *ProxyTracker) Shutdown() { + pt.mu.Lock() + defer pt.mu.Unlock() + + // Close all current watchers first + for proxyID, watchData := range pt.proxies { + close(watchData.notifyCh) + delete(pt.proxies, proxyID) + } + + close(pt.newProxyConnectionCh) + close(pt.shutdownCh) +} + +//go:generate mockery --name SessionLimiter --inpackage +type SessionLimiter interface { + BeginSession() (limiter.Session, error) +} + +//go:generate mockery --name Logger --inpackage +type Logger interface { + Error(args ...any) +} diff --git a/agent/proxy-tracker/proxy_tracker_test.go b/agent/proxy-tracker/proxy_tracker_test.go new file mode 100644 index 000000000000..3913f95b213f --- /dev/null +++ b/agent/proxy-tracker/proxy_tracker_test.go @@ -0,0 +1,351 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package proxytracker + +import ( + "errors" + "fmt" + "github.com/hashicorp/consul/agent/grpc-external/limiter" + "github.com/hashicorp/consul/internal/controller" + "github.com/hashicorp/consul/internal/mesh" + "github.com/hashicorp/consul/internal/resource" + "github.com/hashicorp/consul/internal/resource/resourcetest" + pbmesh "github.com/hashicorp/consul/proto-public/pbmesh/v1alpha1" + "github.com/hashicorp/consul/proto-public/pbresource" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "testing" +) + +func TestProxyTracker_Watch(t *testing.T) { + resourceID := resourcetest.Resource(mesh.ProxyStateTemplateConfigurationType, "test").ID() + proxyReferenceKey := resource.NewReferenceKey(resourceID) + lim := NewMockSessionLimiter(t) + session1 := newMockSession(t) + session1TermCh := make(limiter.SessionTerminatedChan) + session1.On("Terminated").Return(session1TermCh) + session1.On("End").Return() + lim.On("BeginSession").Return(session1, nil) + logger := NewMockLogger(t) + + pt := NewProxyTracker(ProxyTrackerConfig{ + Logger: logger, + SessionLimiter: lim, + }) + + // Watch() + proxyStateChan, _, cancelFunc, err := pt.Watch(resourceID, "node 1", "token") + require.NoError(t, err) + + // ensure New Proxy Connection message is sent + newProxyMsg := <-pt.EventChannel() + require.Equal(t, resourceID.Name, newProxyMsg.Obj.Key()) + + // watchData is stored in the proxies array with a nil state + watchData, ok := pt.proxies[proxyReferenceKey] + require.True(t, ok) + require.NotNil(t, watchData) + require.Nil(t, watchData.state) + + // calling cancelFunc does the following: + // - closes the proxy state channel + // - and removes the map entry for the proxy + // - session is ended + cancelFunc() + + // read channel to see if there is data and it is open. + receivedState, channelOpen := <-proxyStateChan + require.Nil(t, receivedState) + require.False(t, channelOpen) + + // key is removed from proxies array + _, ok = pt.proxies[proxyReferenceKey] + require.False(t, ok) + + // session ended + session1.AssertCalled(t, "Terminated") + session1.AssertCalled(t, "End") +} + +func TestProxyTracker_Watch_ErrorConsumerNotReady(t *testing.T) { + resourceID := resourcetest.Resource(mesh.ProxyStateTemplateConfigurationType, "test").ID() + proxyReferenceKey := resource.NewReferenceKey(resourceID) + lim := NewMockSessionLimiter(t) + session1 := newMockSession(t) + session1.On("End").Return() + lim.On("BeginSession").Return(session1, nil) + logger := NewMockLogger(t) + + pt := NewProxyTracker(ProxyTrackerConfig{ + Logger: logger, + SessionLimiter: lim, + }) + + //fill up buffered channel while the consumer is not ready to simulate the error + for i := 0; i < 1000; i++ { + event := controller.Event{Obj: &ProxyConnection{ProxyID: resourcetest.Resource(mesh.ProxyStateTemplateConfigurationType, fmt.Sprintf("test%d", i)).ID()}} + pt.newProxyConnectionCh <- event + } + + // Watch() + proxyStateChan, sessionTerminatedCh, cancelFunc, err := pt.Watch(resourceID, "node 1", "token") + require.Nil(t, cancelFunc) + require.Nil(t, proxyStateChan) + require.Nil(t, sessionTerminatedCh) + require.Error(t, err) + require.Equal(t, "failed to notify the controller of the proxy connecting", err.Error()) + + // it is not stored in the proxies array + watchData, ok := pt.proxies[proxyReferenceKey] + require.False(t, ok) + require.Nil(t, watchData) +} + +func TestProxyTracker_Watch_ArgValidationErrors(t *testing.T) { + type testcase struct { + description string + proxyID *pbresource.ID + nodeName string + token string + expectedError error + } + testcases := []*testcase{ + { + description: "Empty proxyID", + proxyID: nil, + nodeName: "something", + token: "something", + expectedError: errors.New("proxyID is required"), + }, + { + description: "Empty nodeName", + proxyID: resourcetest.Resource(mesh.ProxyStateTemplateConfigurationType, "test").ID(), + nodeName: "", + token: "something", + expectedError: errors.New("nodeName is required"), + }, + { + description: "Empty token", + proxyID: resourcetest.Resource(mesh.ProxyStateTemplateConfigurationType, "test").ID(), + nodeName: "something", + token: "", + expectedError: errors.New("token is required"), + }, + { + description: "resource is not ProxyStateTemplate", + proxyID: resourcetest.Resource(mesh.ProxyConfigurationType, "test").ID(), + nodeName: "something", + token: "something else", + expectedError: errors.New("proxyID must be a ProxyStateTemplate"), + }, + } + for _, tc := range testcases { + lim := NewMockSessionLimiter(t) + lim.On("BeginSession").Return(nil, nil).Maybe() + logger := NewMockLogger(t) + logger.On("Error", mock.Anything, mock.Anything).Return(nil) + + pt := NewProxyTracker(ProxyTrackerConfig{ + Logger: logger, + SessionLimiter: lim, + }) + + // Watch() + proxyStateChan, sessionTerminateCh, cancelFunc, err := pt.Watch(tc.proxyID, tc.nodeName, tc.token) + require.Error(t, err) + require.Equal(t, tc.expectedError, err) + require.Nil(t, proxyStateChan) + require.Nil(t, sessionTerminateCh) + require.Nil(t, cancelFunc) + } +} + +func TestProxyTracker_Watch_SessionLimiterError(t *testing.T) { + resourceID := resourcetest.Resource(mesh.ProxyStateTemplateConfigurationType, "test").ID() + lim := NewMockSessionLimiter(t) + lim.On("BeginSession").Return(nil, errors.New("kaboom")) + logger := NewMockLogger(t) + logger.On("Error", mock.Anything, mock.Anything).Return(nil) + + pt := NewProxyTracker(ProxyTrackerConfig{ + Logger: logger, + SessionLimiter: lim, + }) + + // Watch() + proxyStateChan, sessionTerminateCh, cancelFunc, err := pt.Watch(resourceID, "node 1", "token") + require.Error(t, err) + require.Equal(t, "kaboom", err.Error()) + require.Nil(t, proxyStateChan) + require.Nil(t, sessionTerminateCh) + require.Nil(t, cancelFunc) +} + +func TestProxyTracker_PushChange(t *testing.T) { + resourceID := resourcetest.Resource(mesh.ProxyStateTemplateConfigurationType, "test").ID() + proxyReferenceKey := resource.NewReferenceKey(resourceID) + lim := NewMockSessionLimiter(t) + session1 := newMockSession(t) + session1TermCh := make(limiter.SessionTerminatedChan) + session1.On("Terminated").Return(session1TermCh) + lim.On("BeginSession").Return(session1, nil) + logger := NewMockLogger(t) + + pt := NewProxyTracker(ProxyTrackerConfig{ + Logger: logger, + SessionLimiter: lim, + }) + + // Watch() + proxyStateChan, _, _, err := pt.Watch(resourceID, "node 1", "token") + require.NoError(t, err) + + // PushChange + proxyState := &pbmesh.ProxyState{ + IntentionDefaultAllow: true, + } + + // using a goroutine so that the channel and main test thread do not cause + // blocking issues with each other + go func() { + err = pt.PushChange(resourceID, proxyState) + require.NoError(t, err) + }() + + // channel receives a copy + receivedState, channelOpen := <-proxyStateChan + require.True(t, channelOpen) + require.Equal(t, proxyState, receivedState) + + // it is stored in the proxies array + watchData, ok := pt.proxies[proxyReferenceKey] + require.True(t, ok) + require.Equal(t, proxyState, watchData.state) +} + +func TestProxyTracker_PushChanges_ErrorProxyNotConnected(t *testing.T) { + resourceID := resourcetest.Resource(mesh.ProxyStateTemplateConfigurationType, "test").ID() + lim := NewMockSessionLimiter(t) + logger := NewMockLogger(t) + + pt := NewProxyTracker(ProxyTrackerConfig{ + Logger: logger, + SessionLimiter: lim, + }) + + // PushChange + proxyState := &pbmesh.ProxyState{ + IntentionDefaultAllow: true, + } + + err := pt.PushChange(resourceID, proxyState) + require.Error(t, err) + require.Equal(t, "proxyState change could not be sent because proxy is not connected", err.Error()) +} + +func TestProxyTracker_ProxyConnectedToServer(t *testing.T) { + type testcase struct { + name string + shouldExist bool + preProcessingFunc func(pt *ProxyTracker, resourceID *pbresource.ID, limiter *MockSessionLimiter, session *mockSession, channel limiter.SessionTerminatedChan) + } + testsCases := []*testcase{ + { + name: "Resource that has not been sent through Watch() should return false", + shouldExist: false, + preProcessingFunc: func(pt *ProxyTracker, resourceID *pbresource.ID, limiter *MockSessionLimiter, session *mockSession, channel limiter.SessionTerminatedChan) { + session.On("Terminated").Return(channel).Maybe() + session.On("End").Return().Maybe() + limiter.On("BeginSession").Return(session, nil).Maybe() + }, + }, + { + name: "Resource used that is already passed in through Watch() should return true", + shouldExist: true, + preProcessingFunc: func(pt *ProxyTracker, resourceID *pbresource.ID, limiter *MockSessionLimiter, session *mockSession, channel limiter.SessionTerminatedChan) { + session.On("Terminated").Return(channel).Maybe() + session.On("End").Return().Maybe() + limiter.On("BeginSession").Return(session, nil) + _, _, _, _ = pt.Watch(resourceID, "node 1", "token") + }, + }, + } + + for _, tc := range testsCases { + lim := NewMockSessionLimiter(t) + session1 := newMockSession(t) + session1TermCh := make(limiter.SessionTerminatedChan) + logger := NewMockLogger(t) + + pt := NewProxyTracker(ProxyTrackerConfig{ + Logger: logger, + SessionLimiter: lim, + }) + resourceID := resourcetest.Resource(mesh.ProxyStateTemplateConfigurationType, "test").ID() + tc.preProcessingFunc(pt, resourceID, lim, session1, session1TermCh) + require.Equal(t, tc.shouldExist, pt.ProxyConnectedToServer(resourceID)) + } +} + +func TestProxyTracker_Shutdown(t *testing.T) { + resourceID := resourcetest.Resource(mesh.ProxyStateTemplateConfigurationType, "test").ID() + proxyReferenceKey := resource.NewReferenceKey(resourceID) + lim := NewMockSessionLimiter(t) + session1 := newMockSession(t) + session1TermCh := make(limiter.SessionTerminatedChan) + session1.On("Terminated").Return(session1TermCh) + session1.On("End").Return().Maybe() + lim.On("BeginSession").Return(session1, nil) + logger := NewMockLogger(t) + + pt := NewProxyTracker(ProxyTrackerConfig{ + Logger: logger, + SessionLimiter: lim, + }) + + // Watch() + proxyStateChan, _, _, err := pt.Watch(resourceID, "node 1", "token") + require.NoError(t, err) + + pt.Shutdown() + + // proxy channels are all disconnected and proxy is removed from proxies map + receivedState, channelOpen := <-proxyStateChan + require.Nil(t, receivedState) + require.False(t, channelOpen) + _, ok := pt.proxies[proxyReferenceKey] + require.False(t, ok) + + // shutdownCh is closed + select { + case <-pt.ShutdownChannel(): + default: + t.Fatalf("shutdown channel should be closed") + } + // newProxyConnectionCh is closed + select { + case <-pt.EventChannel(): + default: + t.Fatalf("shutdown channel should be closed") + } +} + +type mockSession struct { + mock.Mock +} + +func newMockSession(t *testing.T) *mockSession { + m := &mockSession{} + m.Mock.Test(t) + + t.Cleanup(func() { m.AssertExpectations(t) }) + + return m +} + +func (m *mockSession) End() { m.Called() } + +func (m *mockSession) Terminated() limiter.SessionTerminatedChan { + return m.Called().Get(0).(limiter.SessionTerminatedChan) +}