Skip to content

Commit

Permalink
Record and emit the timestamp that the last connection was establishe…
Browse files Browse the repository at this point in the history
…d to ACS
  • Loading branch information
Tianze Shan authored and danehlim committed Nov 16, 2023
1 parent 4cc80f3 commit ea9a929
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 0 deletions.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 11 additions & 0 deletions ecs-agent/acs/session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ const (
// The Session.Start() method can be used to start processing messages from ACS.
type Session interface {
Start(context.Context) error
GetLastConnectedTime() time.Time
}

// session encapsulates all arguments needed to connect to ACS and to handle messages received by ACS.
Expand Down Expand Up @@ -97,6 +98,7 @@ type session struct {
disconnectTimeout time.Duration
disconnectJitter time.Duration
inactiveInstanceReconnectDelay time.Duration
lastConnectedTime time.Time
}

// NewSession creates a new Session.
Expand Down Expand Up @@ -155,6 +157,7 @@ func NewSession(containerInstanceARN string,
disconnectTimeout: wsclient.DisconnectTimeout,
disconnectJitter: wsclient.DisconnectJitterMax,
inactiveInstanceReconnectDelay: inactiveInstanceReconnectDelay,
lastConnectedTime: time.Time{},
}
}

Expand Down Expand Up @@ -245,6 +248,9 @@ func (s *session) startSessionOnce(ctx context.Context) error {
}
defer disconnectTimer.Stop()

// Record the timestamp of the last connection to ACS.
s.lastConnectedTime = time.Now()

// Connection to ACS was successful. Moving forward, rely on ACS to send credentials to Agent at its own cadence
// and make sure Agent does not force ACS to send credentials for any subsequent reconnects to ACS.
logger.Info("Connected to ACS endpoint")
Expand Down Expand Up @@ -425,3 +431,8 @@ func formatDockerVersion(dockerVersionValue string) string {
}
return dockerVersionValue
}

// GetLastConnectedTime returns the timestamp that the last connection was established to ACS.
func (s *session) GetLastConnectedTime() time.Time {
return s.lastConnectedTime
}
92 changes: 92 additions & 0 deletions ecs-agent/acs/session/session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1286,6 +1286,98 @@ func TestSessionCallsAddUpdateRequestHandlers(t *testing.T) {
assert.True(t, addUpdateRequestHandlersCalled)
}

// TestGetLastConnectedTime tests that the Session's 'lastConnectedTime' field is updated correctly for successive
// invocations of startSessionOnce. Also tests that the Session's GetLastConnectedTime() API call works as expected.
func TestGetLastConnectedTime(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()

const numInvocations = 10
discoverEndpointClient := mock_api.NewMockECSDiscoverEndpointSDK(ctrl)
discoverEndpointClient.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return(acsURL, nil).AnyTimes()
ctx, cancel := context.WithCancel(context.Background())

mockWsClient := mock_wsclient.NewMockClientServer(ctrl)
mockClientFactory := mock_wsclient.NewMockClientFactory(ctrl)
mockClientFactory.EXPECT().
New(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
Return(mockWsClient).AnyTimes()
mockWsClient.EXPECT().SetAnyRequestHandler(gomock.Any()).AnyTimes()
mockWsClient.EXPECT().AddRequestHandler(gomock.Any()).AnyTimes()
mockWsClient.EXPECT().WriteCloseMessage().AnyTimes()
mockWsClient.EXPECT().Close().Return(nil).AnyTimes()
mockWsClient.EXPECT().Serve(gomock.Any()).Return(io.EOF).AnyTimes()

acsSession := NewSession(testconst.ContainerInstanceARN,
testconst.ClusterARN,
discoverEndpointClient,
nil,
noopFunc,
mockClientFactory,
metricsfactory.NewNopEntryFactory(),
agentVersion,
agentGitShortHash,
dockerVersion,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
)
acsSession.(*session).heartbeatTimeout = 20 * time.Millisecond
acsSession.(*session).heartbeatJitter = 10 * time.Millisecond
acsSession.(*session).disconnectTimeout = 30 * time.Millisecond
acsSession.(*session).disconnectJitter = 10 * time.Millisecond
gomock.InOrder(
// When the websocket client connects to ACS for the first time, 'sendCredentials' should be set to true.
mockWsClient.EXPECT().Connect(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(interface{},
interface{}, interface{}) {
assert.Equal(t, true, acsSession.(*session).sendCredentials)
}).Return(time.NewTimer(wsclient.DisconnectTimeout), nil),
// For all subsequent connections to ACS, 'sendCredentials' should be set to false.
mockWsClient.EXPECT().Connect(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(interface{},
interface{}, interface{}) {
assert.Equal(t, false, acsSession.(*session).sendCredentials)
}).Return(time.NewTimer(wsclient.DisconnectTimeout), nil).Times(numInvocations-1),
)

// The Session's lastConnectedTime field was initialized with time.Time{}, which is the default zero value for time.Time.
// At this point, since the Session has not connected to ACS yet, the Session's lastConnectedTime should still be zero.
assert.True(t, acsSession.GetLastConnectedTime().IsZero())

go func() {
for i := 0; i < numInvocations; i++ {
// Record the current time.
currentTime := time.Now()
// Invoke startSessionOnce() to connect to ACS.
acsSession.(*session).startSessionOnce(ctx)
// Get the timestamp recorded in Session's lastConnectedTime field.
acsSessionActualConnectedTime := acsSession.GetLastConnectedTime()
// Compare the two timestamps.
// Since the connection was started right after the first timestamp was recorded, the two timestamps should
// be very close. Allow an 1 ms to account for jitters.
assert.WithinDuration(t, currentTime, acsSessionActualConnectedTime, 1*time.Millisecond)
// Sleep for 2 ms before proceeding to the next test iteration, so that if the Session's lastConnectedTime
// field is not correctly updated, it would be caught since the allowed delta is 1 ms.
time.Sleep(2 * time.Millisecond)
}
cancel()
}()

// Wait for context to be canceled.
select {
case <-ctx.Done():
cancel()
}
}

func startFakeACSServer(closeWS <-chan bool) (*httptest.Server, chan<- string, <-chan string, <-chan error, error) {
serverChan := make(chan string, 1)
requestsChan := make(chan string, 1)
Expand Down

0 comments on commit ea9a929

Please sign in to comment.