From 26bd42119031c52d900fd7f35eb7e59c08edd9d9 Mon Sep 17 00:00:00 2001 From: Cameron Sparr Date: Wed, 22 May 2024 12:08:32 -0700 Subject: [PATCH] Add restart count to TMDE v4 response (#4166) * Add restart count to TMDE v4 response * Add warning log message if container is not found in internal state --- agent/handlers/task_server_setup_test.go | 36 ++++++++++++++------ agent/handlers/v4/response.go | 43 ++++++++++++++++++++---- agent/handlers/v4/response_test.go | 15 ++++----- 3 files changed, 67 insertions(+), 27 deletions(-) diff --git a/agent/handlers/task_server_setup_test.go b/agent/handlers/task_server_setup_test.go index f01bd9140b0..099840d96c4 100644 --- a/agent/handlers/task_server_setup_test.go +++ b/agent/handlers/task_server_setup_test.go @@ -1381,7 +1381,7 @@ func testTMDSRequest[R TMDSResponse](t *testing.T, tc TMDSTestCase[R]) { // Parse the response body var actualResponseBody R err = json.Unmarshal(recorder.Body.Bytes(), &actualResponseBody) - require.NoError(t, err) + require.NoError(t, err, recorder.Body.String()) // Assert status code and body assert.Equal(t, tc.expectedStatusCode, recorder.Code) @@ -1801,8 +1801,9 @@ func TestV4ContainerMetadata(t *testing.T) { setStateExpectations: func(state *mock_dockerstate.MockTaskEngineState) { gomock.InOrder( state.EXPECT().DockerIDByV3EndpointID(v3EndpointID).Return(containerID, true), - state.EXPECT().ContainerByID(containerID).Return(dockerContainer, true), + state.EXPECT().ContainerByID(containerID).Return(dockerContainer, true).AnyTimes(), state.EXPECT().TaskByID(containerID).Return(task, true).Times(2), + state.EXPECT().ContainerByID(containerID).Return(dockerContainer, true).AnyTimes(), ) }, expectedStatusCode: http.StatusOK, @@ -1817,7 +1818,7 @@ func TestV4ContainerMetadata(t *testing.T) { state.EXPECT().DockerIDByV3EndpointID(v3EndpointID).Return(containerID, true), state.EXPECT().ContainerByID(containerID).Return(bridgeContainer, true), state.EXPECT().TaskByID(containerID).Return(bridgeTask, true), - state.EXPECT().ContainerByID(containerID).Return(nil, false), + state.EXPECT().ContainerByID(containerID).Return(nil, false).AnyTimes(), ) }, expectedStatusCode: http.StatusInternalServerError, @@ -1832,7 +1833,7 @@ func TestV4ContainerMetadata(t *testing.T) { state.EXPECT().DockerIDByV3EndpointID(v3EndpointID).Return(containerID, true), state.EXPECT().ContainerByID(containerID).Return(bridgeContainerNoNetwork, true), state.EXPECT().TaskByID(containerID).Return(bridgeTask, true), - state.EXPECT().ContainerByID(containerID).Return(bridgeContainerNoNetwork, true), + state.EXPECT().ContainerByID(containerID).Return(bridgeContainerNoNetwork, true).AnyTimes(), ) }, expectedStatusCode: http.StatusInternalServerError, @@ -1846,9 +1847,9 @@ func TestV4ContainerMetadata(t *testing.T) { setStateExpectations: func(state *mock_dockerstate.MockTaskEngineState) { gomock.InOrder( state.EXPECT().DockerIDByV3EndpointID(v3EndpointID).Return(containerID, true), - state.EXPECT().ContainerByID(containerID).Return(bridgeContainer, true), + state.EXPECT().ContainerByID(containerID).Return(bridgeContainer, true).AnyTimes(), state.EXPECT().TaskByID(containerID).Return(bridgeTask, true), - state.EXPECT().ContainerByID(containerID).Return(bridgeContainer, true), + state.EXPECT().ContainerByID(containerID).Return(bridgeContainer, true).AnyTimes(), ) }, expectedStatusCode: http.StatusOK, @@ -1939,8 +1940,10 @@ func TestV4TaskMetadata(t *testing.T) { gomock.InOrder( state.EXPECT().TaskARNByV3EndpointID(v3EndpointID).Return(taskARN, true), state.EXPECT().TaskByArn(taskARN).Return(task, true).Times(2), + state.EXPECT().ContainerByID(containerID).Return(dockerContainer, true).AnyTimes(), state.EXPECT().ContainerMapByArn(taskARN).Return(containerNameToDockerContainer, true), state.EXPECT().TaskByArn(taskARN).Return(task, true), + state.EXPECT().ContainerByID(containerID).Return(dockerContainer, true).AnyTimes(), state.EXPECT().PulledContainerMapByArn(taskARN).Return(nil, true), ) }, @@ -1955,8 +1958,10 @@ func TestV4TaskMetadata(t *testing.T) { gomock.InOrder( state.EXPECT().TaskARNByV3EndpointID(v3EndpointID).Return(taskARN, true), state.EXPECT().TaskByArn(taskARN).Return(pulledTask, true).Times(2), + state.EXPECT().ContainerByID(containerID).Return(dockerContainer, true).AnyTimes(), state.EXPECT().ContainerMapByArn(taskARN).Return(containerNameToDockerContainer, true), state.EXPECT().TaskByArn(taskARN).Return(pulledTask, true), + state.EXPECT().ContainerByID(containerID).Return(dockerContainer, true).AnyTimes(), state.EXPECT().PulledContainerMapByArn(taskARN).Return(pulledContainerNameToDockerContainer, true), ) }, @@ -1972,8 +1977,9 @@ func TestV4TaskMetadata(t *testing.T) { state.EXPECT().TaskARNByV3EndpointID(v3EndpointID).Return(taskARN, true), state.EXPECT().TaskByArn(taskARN).Return(bridgeTask, true).Times(2), state.EXPECT().ContainerMapByArn(taskARN).Return(containerNameToBridgeContainer, true), - state.EXPECT().ContainerByID(containerID).Return(nil, false), + state.EXPECT().ContainerByID(containerID).Return(nil, false).AnyTimes(), state.EXPECT().PulledContainerMapByArn(taskARN).Return(nil, true), + state.EXPECT().ContainerByID(containerID).Return(nil, false).AnyTimes(), ) }, expectedStatusCode: http.StatusOK, @@ -1988,8 +1994,9 @@ func TestV4TaskMetadata(t *testing.T) { state.EXPECT().TaskARNByV3EndpointID(v3EndpointID).Return(taskARN, true), state.EXPECT().TaskByArn(taskARN).Return(bridgeTask, true).Times(2), state.EXPECT().ContainerMapByArn(taskARN).Return(containerNameToBridgeContainer, true), - state.EXPECT().ContainerByID(containerID).Return(bridgeContainerNoNetwork, true), + state.EXPECT().ContainerByID(containerID).Return(bridgeContainerNoNetwork, true).AnyTimes(), state.EXPECT().PulledContainerMapByArn(taskARN).Return(nil, true), + state.EXPECT().ContainerByID(containerID).Return(bridgeContainerNoNetwork, true).AnyTimes(), ) }, expectedStatusCode: http.StatusOK, @@ -2004,8 +2011,9 @@ func TestV4TaskMetadata(t *testing.T) { state.EXPECT().TaskARNByV3EndpointID(v3EndpointID).Return(taskARN, true), state.EXPECT().TaskByArn(taskARN).Return(bridgeTask, true).Times(2), state.EXPECT().ContainerMapByArn(taskARN).Return(containerNameToBridgeContainer, true), - state.EXPECT().ContainerByID(containerID).Return(bridgeContainer, true), + state.EXPECT().ContainerByID(containerID).Return(bridgeContainer, true).AnyTimes(), state.EXPECT().PulledContainerMapByArn(taskARN).Return(nil, true), + state.EXPECT().ContainerByID(containerID).Return(bridgeContainer, true).AnyTimes(), ) }, expectedStatusCode: http.StatusOK, @@ -2340,8 +2348,10 @@ func TestV4TaskMetadataWithTags(t *testing.T) { gomock.InOrder( state.EXPECT().TaskARNByV3EndpointID(v3EndpointID).Return(taskARN, true), state.EXPECT().TaskByArn(taskARN).Return(task, true).AnyTimes(), + state.EXPECT().ContainerByID(containerID).Return(bridgeContainer, true).AnyTimes(), state.EXPECT().ContainerMapByArn(taskARN).Return(containerNameToDockerContainer, true), state.EXPECT().TaskByArn(taskARN).Return(task, true).AnyTimes(), + state.EXPECT().ContainerByID(containerID).Return(bridgeContainer, true).AnyTimes(), state.EXPECT().PulledContainerMapByArn(taskARN).Return(nil, true), ) } @@ -2495,8 +2505,9 @@ func TestV4TaskMetadataWithTags(t *testing.T) { state.EXPECT().TaskARNByV3EndpointID(v3EndpointID).Return(taskARN, true), state.EXPECT().TaskByArn(taskARN).Return(bridgeTask, true).Times(2), state.EXPECT().ContainerMapByArn(taskARN).Return(containerNameToBridgeContainer, true), - state.EXPECT().ContainerByID(containerID).Return(nil, false), + state.EXPECT().ContainerByID(containerID).Return(nil, false).AnyTimes(), state.EXPECT().PulledContainerMapByArn(taskARN).Return(nil, true), + state.EXPECT().ContainerByID(containerID).Return(nil, false).AnyTimes(), ) }, setECSClientExpectations: happyECSClientExpectations, @@ -2515,8 +2526,9 @@ func TestV4TaskMetadataWithTags(t *testing.T) { state.EXPECT().TaskARNByV3EndpointID(v3EndpointID).Return(taskARN, true), state.EXPECT().TaskByArn(taskARN).Return(bridgeTask, true).Times(2), state.EXPECT().ContainerMapByArn(taskARN).Return(containerNameToBridgeContainer, true), - state.EXPECT().ContainerByID(containerID).Return(bridgeContainerNoNetwork, true), + state.EXPECT().ContainerByID(containerID).Return(bridgeContainerNoNetwork, true).AnyTimes(), state.EXPECT().PulledContainerMapByArn(taskARN).Return(nil, true), + state.EXPECT().ContainerByID(containerID).Return(bridgeContainerNoNetwork, true).AnyTimes(), ) }, setECSClientExpectations: happyECSClientExpectations, @@ -3002,6 +3014,7 @@ func TestGetTaskProtection(t *testing.T) { state.EXPECT().TaskByArn(taskARN).Return(task, true).Times(2), state.EXPECT().ContainerMapByArn(taskARN).Return(containerNameToDockerContainer, true), state.EXPECT().TaskByArn(taskARN).Return(task, true), + state.EXPECT().ContainerByID(containerID).Return(dockerContainer, true).AnyTimes(), state.EXPECT().PulledContainerMapByArn(taskARN).Return(nil, true), ) } @@ -3273,6 +3286,7 @@ func TestUpdateTaskProtection(t *testing.T) { state.EXPECT().TaskByArn(taskARN).Return(task, true).Times(2), state.EXPECT().ContainerMapByArn(taskARN).Return(containerNameToDockerContainer, true), state.EXPECT().TaskByArn(taskARN).Return(task, true), + state.EXPECT().ContainerByID(containerID).Return(dockerContainer, true).AnyTimes(), state.EXPECT().PulledContainerMapByArn(taskARN).Return(nil, true), ) } diff --git a/agent/handlers/v4/response.go b/agent/handlers/v4/response.go index e01eabf03de..ece020953ea 100644 --- a/agent/handlers/v4/response.go +++ b/agent/handlers/v4/response.go @@ -19,6 +19,8 @@ import ( "github.com/aws/amazon-ecs-agent/agent/engine/dockerstate" v2 "github.com/aws/amazon-ecs-agent/agent/handlers/v2" "github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs" + "github.com/aws/amazon-ecs-agent/ecs-agent/logger" + "github.com/aws/amazon-ecs-agent/ecs-agent/logger/field" ni "github.com/aws/amazon-ecs-agent/ecs-agent/netlib/model/networkinterface" tmdsresponse "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/response" "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/utils" @@ -28,7 +30,7 @@ import ( ) // NewTaskResponse creates a new v4 response object for the task. It augments v2 task response -// with additional network interface fields. +// with additional fields for the v4 response. func NewTaskResponse( taskARN string, state dockerstate.TaskEngineState, @@ -55,10 +57,12 @@ func NewTaskResponse( if err != nil { return nil, err } - containers = append(containers, tmdsv4.ContainerResponse{ + v4Response := &tmdsv4.ContainerResponse{ ContainerResponse: &v2Resp.Containers[i], Networks: networks, - }) + } + v4Response = augmentContainerResponse(container.ID, state, v4Response) + containers = append(containers, *v4Response) } return &tmdsv4.TaskResponse{ @@ -69,8 +73,8 @@ func NewTaskResponse( }, nil } -// NewContainerResponse creates a new v4 container response based on container id. It augments -// v4 container response with additional network interface fields. +// NewContainerResponse creates a new v4 container response based on container id. It augments +// v2 container response with additional fields for v4. func NewContainerResponse( containerID string, state dockerstate.TaskEngineState, @@ -87,10 +91,11 @@ func NewContainerResponse( if err != nil { return nil, err } - return &tmdsv4.ContainerResponse{ + v4Response := tmdsv4.ContainerResponse{ ContainerResponse: container, Networks: networks, - }, nil + } + return augmentContainerResponse(containerID, state, &v4Response), nil } // toV4NetworkResponse converts v2 network response to v4. Additional fields are only @@ -121,6 +126,30 @@ func toV4NetworkResponse( return resp, nil } +// augmentContainerResponse augments the container response with additional fields. +func augmentContainerResponse( + containerID string, + state dockerstate.TaskEngineState, + v4Response *tmdsv4.ContainerResponse, +) *tmdsv4.ContainerResponse { + dockerContainer, ok := state.ContainerByID(containerID) + if !ok { + // did not find container, continue on and try next container(s) + // we don't return error here to avoid disrupting all of a TMDS response + // on a single missing container. + logger.Warn("V4 container response: unable to find container in internal state", + logger.Fields{ + field.RuntimeID: containerID, + }) + return v4Response + } + if dockerContainer.Container.RestartPolicyEnabled() { + restartCount := dockerContainer.Container.RestartTracker.GetRestartCount() + v4Response.RestartCount = &restartCount + } + return v4Response +} + // newNetworkInterfaceProperties creates the NetworkInterfaceProperties object for a given // task. func newNetworkInterfaceProperties(task *apitask.Task) (tmdsv4.NetworkInterfaceProperties, error) { diff --git a/agent/handlers/v4/response_test.go b/agent/handlers/v4/response_test.go index 32e6b57c144..30afe1a2bf7 100644 --- a/agent/handlers/v4/response_test.go +++ b/agent/handlers/v4/response_test.go @@ -132,11 +132,10 @@ func TestNewTaskContainerResponses(t *testing.T) { containerNameToDockerContainer := map[string]*apicontainer.DockerContainer{ taskARN: dockerContainer, } - gomock.InOrder( - state.EXPECT().TaskByArn(taskARN).Return(task, true), - state.EXPECT().ContainerMapByArn(taskARN).Return(containerNameToDockerContainer, true), - state.EXPECT().TaskByArn(taskARN).Return(task, true), - ) + state.EXPECT().TaskByArn(taskARN).Return(task, true) + state.EXPECT().ContainerByID(containerID).Return(dockerContainer, true).AnyTimes() + state.EXPECT().ContainerMapByArn(taskARN).Return(containerNameToDockerContainer, true) + state.EXPECT().TaskByArn(taskARN).Return(task, true) taskResponse, err := NewTaskResponse(taskARN, state, ecsClient, cluster, availabilityZone, vpcID, containerInstanceArn, task.ServiceName, false) @@ -150,10 +149,8 @@ func TestNewTaskContainerResponses(t *testing.T) { assert.Equal(t, subnetGatewayIPV4Address, taskResponse.Containers[0].Networks[0].SubnetGatewayIPV4Address) assert.Equal(t, serviceName, taskResponse.ServiceName) - gomock.InOrder( - state.EXPECT().ContainerByID(containerID).Return(dockerContainer, true), - state.EXPECT().TaskByID(containerID).Return(task, true).Times(2), - ) + state.EXPECT().ContainerByID(containerID).Return(dockerContainer, true).AnyTimes() + state.EXPECT().TaskByID(containerID).Return(task, true).Times(2) containerResponse, err := NewContainerResponse(containerID, state) require.NoError(t, err) _, err = json.Marshal(containerResponse)