Skip to content

Commit

Permalink
Add restart count to TMDE v4 response (#4166)
Browse files Browse the repository at this point in the history
* Add restart count to TMDE v4 response

* Add warning log message if container is not found in internal state
  • Loading branch information
sparrc committed Jul 29, 2024
1 parent c2fb1c4 commit 26bd421
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 27 deletions.
36 changes: 25 additions & 11 deletions agent/handlers/task_server_setup_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1381,7 +1381,7 @@ func testTMDSRequest[R TMDSResponse](t *testing.T, tc TMDSTestCase[R]) {
// Parse the response body
var actualResponseBody R
err = json.Unmarshal(recorder.Body.Bytes(), &actualResponseBody)
require.NoError(t, err)
require.NoError(t, err, recorder.Body.String())

// Assert status code and body
assert.Equal(t, tc.expectedStatusCode, recorder.Code)
Expand Down Expand Up @@ -1801,8 +1801,9 @@ func TestV4ContainerMetadata(t *testing.T) {
setStateExpectations: func(state *mock_dockerstate.MockTaskEngineState) {
gomock.InOrder(
state.EXPECT().DockerIDByV3EndpointID(v3EndpointID).Return(containerID, true),
state.EXPECT().ContainerByID(containerID).Return(dockerContainer, true),
state.EXPECT().ContainerByID(containerID).Return(dockerContainer, true).AnyTimes(),
state.EXPECT().TaskByID(containerID).Return(task, true).Times(2),
state.EXPECT().ContainerByID(containerID).Return(dockerContainer, true).AnyTimes(),
)
},
expectedStatusCode: http.StatusOK,
Expand All @@ -1817,7 +1818,7 @@ func TestV4ContainerMetadata(t *testing.T) {
state.EXPECT().DockerIDByV3EndpointID(v3EndpointID).Return(containerID, true),
state.EXPECT().ContainerByID(containerID).Return(bridgeContainer, true),
state.EXPECT().TaskByID(containerID).Return(bridgeTask, true),
state.EXPECT().ContainerByID(containerID).Return(nil, false),
state.EXPECT().ContainerByID(containerID).Return(nil, false).AnyTimes(),
)
},
expectedStatusCode: http.StatusInternalServerError,
Expand All @@ -1832,7 +1833,7 @@ func TestV4ContainerMetadata(t *testing.T) {
state.EXPECT().DockerIDByV3EndpointID(v3EndpointID).Return(containerID, true),
state.EXPECT().ContainerByID(containerID).Return(bridgeContainerNoNetwork, true),
state.EXPECT().TaskByID(containerID).Return(bridgeTask, true),
state.EXPECT().ContainerByID(containerID).Return(bridgeContainerNoNetwork, true),
state.EXPECT().ContainerByID(containerID).Return(bridgeContainerNoNetwork, true).AnyTimes(),
)
},
expectedStatusCode: http.StatusInternalServerError,
Expand All @@ -1846,9 +1847,9 @@ func TestV4ContainerMetadata(t *testing.T) {
setStateExpectations: func(state *mock_dockerstate.MockTaskEngineState) {
gomock.InOrder(
state.EXPECT().DockerIDByV3EndpointID(v3EndpointID).Return(containerID, true),
state.EXPECT().ContainerByID(containerID).Return(bridgeContainer, true),
state.EXPECT().ContainerByID(containerID).Return(bridgeContainer, true).AnyTimes(),
state.EXPECT().TaskByID(containerID).Return(bridgeTask, true),
state.EXPECT().ContainerByID(containerID).Return(bridgeContainer, true),
state.EXPECT().ContainerByID(containerID).Return(bridgeContainer, true).AnyTimes(),
)
},
expectedStatusCode: http.StatusOK,
Expand Down Expand Up @@ -1939,8 +1940,10 @@ func TestV4TaskMetadata(t *testing.T) {
gomock.InOrder(
state.EXPECT().TaskARNByV3EndpointID(v3EndpointID).Return(taskARN, true),
state.EXPECT().TaskByArn(taskARN).Return(task, true).Times(2),
state.EXPECT().ContainerByID(containerID).Return(dockerContainer, true).AnyTimes(),
state.EXPECT().ContainerMapByArn(taskARN).Return(containerNameToDockerContainer, true),
state.EXPECT().TaskByArn(taskARN).Return(task, true),
state.EXPECT().ContainerByID(containerID).Return(dockerContainer, true).AnyTimes(),
state.EXPECT().PulledContainerMapByArn(taskARN).Return(nil, true),
)
},
Expand All @@ -1955,8 +1958,10 @@ func TestV4TaskMetadata(t *testing.T) {
gomock.InOrder(
state.EXPECT().TaskARNByV3EndpointID(v3EndpointID).Return(taskARN, true),
state.EXPECT().TaskByArn(taskARN).Return(pulledTask, true).Times(2),
state.EXPECT().ContainerByID(containerID).Return(dockerContainer, true).AnyTimes(),
state.EXPECT().ContainerMapByArn(taskARN).Return(containerNameToDockerContainer, true),
state.EXPECT().TaskByArn(taskARN).Return(pulledTask, true),
state.EXPECT().ContainerByID(containerID).Return(dockerContainer, true).AnyTimes(),
state.EXPECT().PulledContainerMapByArn(taskARN).Return(pulledContainerNameToDockerContainer, true),
)
},
Expand All @@ -1972,8 +1977,9 @@ func TestV4TaskMetadata(t *testing.T) {
state.EXPECT().TaskARNByV3EndpointID(v3EndpointID).Return(taskARN, true),
state.EXPECT().TaskByArn(taskARN).Return(bridgeTask, true).Times(2),
state.EXPECT().ContainerMapByArn(taskARN).Return(containerNameToBridgeContainer, true),
state.EXPECT().ContainerByID(containerID).Return(nil, false),
state.EXPECT().ContainerByID(containerID).Return(nil, false).AnyTimes(),
state.EXPECT().PulledContainerMapByArn(taskARN).Return(nil, true),
state.EXPECT().ContainerByID(containerID).Return(nil, false).AnyTimes(),
)
},
expectedStatusCode: http.StatusOK,
Expand All @@ -1988,8 +1994,9 @@ func TestV4TaskMetadata(t *testing.T) {
state.EXPECT().TaskARNByV3EndpointID(v3EndpointID).Return(taskARN, true),
state.EXPECT().TaskByArn(taskARN).Return(bridgeTask, true).Times(2),
state.EXPECT().ContainerMapByArn(taskARN).Return(containerNameToBridgeContainer, true),
state.EXPECT().ContainerByID(containerID).Return(bridgeContainerNoNetwork, true),
state.EXPECT().ContainerByID(containerID).Return(bridgeContainerNoNetwork, true).AnyTimes(),
state.EXPECT().PulledContainerMapByArn(taskARN).Return(nil, true),
state.EXPECT().ContainerByID(containerID).Return(bridgeContainerNoNetwork, true).AnyTimes(),
)
},
expectedStatusCode: http.StatusOK,
Expand All @@ -2004,8 +2011,9 @@ func TestV4TaskMetadata(t *testing.T) {
state.EXPECT().TaskARNByV3EndpointID(v3EndpointID).Return(taskARN, true),
state.EXPECT().TaskByArn(taskARN).Return(bridgeTask, true).Times(2),
state.EXPECT().ContainerMapByArn(taskARN).Return(containerNameToBridgeContainer, true),
state.EXPECT().ContainerByID(containerID).Return(bridgeContainer, true),
state.EXPECT().ContainerByID(containerID).Return(bridgeContainer, true).AnyTimes(),
state.EXPECT().PulledContainerMapByArn(taskARN).Return(nil, true),
state.EXPECT().ContainerByID(containerID).Return(bridgeContainer, true).AnyTimes(),
)
},
expectedStatusCode: http.StatusOK,
Expand Down Expand Up @@ -2340,8 +2348,10 @@ func TestV4TaskMetadataWithTags(t *testing.T) {
gomock.InOrder(
state.EXPECT().TaskARNByV3EndpointID(v3EndpointID).Return(taskARN, true),
state.EXPECT().TaskByArn(taskARN).Return(task, true).AnyTimes(),
state.EXPECT().ContainerByID(containerID).Return(bridgeContainer, true).AnyTimes(),
state.EXPECT().ContainerMapByArn(taskARN).Return(containerNameToDockerContainer, true),
state.EXPECT().TaskByArn(taskARN).Return(task, true).AnyTimes(),
state.EXPECT().ContainerByID(containerID).Return(bridgeContainer, true).AnyTimes(),
state.EXPECT().PulledContainerMapByArn(taskARN).Return(nil, true),
)
}
Expand Down Expand Up @@ -2495,8 +2505,9 @@ func TestV4TaskMetadataWithTags(t *testing.T) {
state.EXPECT().TaskARNByV3EndpointID(v3EndpointID).Return(taskARN, true),
state.EXPECT().TaskByArn(taskARN).Return(bridgeTask, true).Times(2),
state.EXPECT().ContainerMapByArn(taskARN).Return(containerNameToBridgeContainer, true),
state.EXPECT().ContainerByID(containerID).Return(nil, false),
state.EXPECT().ContainerByID(containerID).Return(nil, false).AnyTimes(),
state.EXPECT().PulledContainerMapByArn(taskARN).Return(nil, true),
state.EXPECT().ContainerByID(containerID).Return(nil, false).AnyTimes(),
)
},
setECSClientExpectations: happyECSClientExpectations,
Expand All @@ -2515,8 +2526,9 @@ func TestV4TaskMetadataWithTags(t *testing.T) {
state.EXPECT().TaskARNByV3EndpointID(v3EndpointID).Return(taskARN, true),
state.EXPECT().TaskByArn(taskARN).Return(bridgeTask, true).Times(2),
state.EXPECT().ContainerMapByArn(taskARN).Return(containerNameToBridgeContainer, true),
state.EXPECT().ContainerByID(containerID).Return(bridgeContainerNoNetwork, true),
state.EXPECT().ContainerByID(containerID).Return(bridgeContainerNoNetwork, true).AnyTimes(),
state.EXPECT().PulledContainerMapByArn(taskARN).Return(nil, true),
state.EXPECT().ContainerByID(containerID).Return(bridgeContainerNoNetwork, true).AnyTimes(),
)
},
setECSClientExpectations: happyECSClientExpectations,
Expand Down Expand Up @@ -3002,6 +3014,7 @@ func TestGetTaskProtection(t *testing.T) {
state.EXPECT().TaskByArn(taskARN).Return(task, true).Times(2),
state.EXPECT().ContainerMapByArn(taskARN).Return(containerNameToDockerContainer, true),
state.EXPECT().TaskByArn(taskARN).Return(task, true),
state.EXPECT().ContainerByID(containerID).Return(dockerContainer, true).AnyTimes(),
state.EXPECT().PulledContainerMapByArn(taskARN).Return(nil, true),
)
}
Expand Down Expand Up @@ -3273,6 +3286,7 @@ func TestUpdateTaskProtection(t *testing.T) {
state.EXPECT().TaskByArn(taskARN).Return(task, true).Times(2),
state.EXPECT().ContainerMapByArn(taskARN).Return(containerNameToDockerContainer, true),
state.EXPECT().TaskByArn(taskARN).Return(task, true),
state.EXPECT().ContainerByID(containerID).Return(dockerContainer, true).AnyTimes(),
state.EXPECT().PulledContainerMapByArn(taskARN).Return(nil, true),
)
}
Expand Down
43 changes: 36 additions & 7 deletions agent/handlers/v4/response.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ import (
"github.com/aws/amazon-ecs-agent/agent/engine/dockerstate"
v2 "github.com/aws/amazon-ecs-agent/agent/handlers/v2"
"github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs"
"github.com/aws/amazon-ecs-agent/ecs-agent/logger"
"github.com/aws/amazon-ecs-agent/ecs-agent/logger/field"
ni "github.com/aws/amazon-ecs-agent/ecs-agent/netlib/model/networkinterface"
tmdsresponse "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/response"
"github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/utils"
Expand All @@ -28,7 +30,7 @@ import (
)

// NewTaskResponse creates a new v4 response object for the task. It augments v2 task response
// with additional network interface fields.
// with additional fields for the v4 response.
func NewTaskResponse(
taskARN string,
state dockerstate.TaskEngineState,
Expand All @@ -55,10 +57,12 @@ func NewTaskResponse(
if err != nil {
return nil, err
}
containers = append(containers, tmdsv4.ContainerResponse{
v4Response := &tmdsv4.ContainerResponse{
ContainerResponse: &v2Resp.Containers[i],
Networks: networks,
})
}
v4Response = augmentContainerResponse(container.ID, state, v4Response)
containers = append(containers, *v4Response)
}

return &tmdsv4.TaskResponse{
Expand All @@ -69,8 +73,8 @@ func NewTaskResponse(
}, nil
}

// NewContainerResponse creates a new v4 container response based on container id. It augments
// v4 container response with additional network interface fields.
// NewContainerResponse creates a new v4 container response based on container id. It augments
// v2 container response with additional fields for v4.
func NewContainerResponse(
containerID string,
state dockerstate.TaskEngineState,
Expand All @@ -87,10 +91,11 @@ func NewContainerResponse(
if err != nil {
return nil, err
}
return &tmdsv4.ContainerResponse{
v4Response := tmdsv4.ContainerResponse{
ContainerResponse: container,
Networks: networks,
}, nil
}
return augmentContainerResponse(containerID, state, &v4Response), nil
}

// toV4NetworkResponse converts v2 network response to v4. Additional fields are only
Expand Down Expand Up @@ -121,6 +126,30 @@ func toV4NetworkResponse(
return resp, nil
}

// augmentContainerResponse augments the container response with additional fields.
func augmentContainerResponse(
containerID string,
state dockerstate.TaskEngineState,
v4Response *tmdsv4.ContainerResponse,
) *tmdsv4.ContainerResponse {
dockerContainer, ok := state.ContainerByID(containerID)
if !ok {
// did not find container, continue on and try next container(s)
// we don't return error here to avoid disrupting all of a TMDS response
// on a single missing container.
logger.Warn("V4 container response: unable to find container in internal state",
logger.Fields{
field.RuntimeID: containerID,
})
return v4Response
}
if dockerContainer.Container.RestartPolicyEnabled() {
restartCount := dockerContainer.Container.RestartTracker.GetRestartCount()
v4Response.RestartCount = &restartCount
}
return v4Response
}

// newNetworkInterfaceProperties creates the NetworkInterfaceProperties object for a given
// task.
func newNetworkInterfaceProperties(task *apitask.Task) (tmdsv4.NetworkInterfaceProperties, error) {
Expand Down
15 changes: 6 additions & 9 deletions agent/handlers/v4/response_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,11 +132,10 @@ func TestNewTaskContainerResponses(t *testing.T) {
containerNameToDockerContainer := map[string]*apicontainer.DockerContainer{
taskARN: dockerContainer,
}
gomock.InOrder(
state.EXPECT().TaskByArn(taskARN).Return(task, true),
state.EXPECT().ContainerMapByArn(taskARN).Return(containerNameToDockerContainer, true),
state.EXPECT().TaskByArn(taskARN).Return(task, true),
)
state.EXPECT().TaskByArn(taskARN).Return(task, true)
state.EXPECT().ContainerByID(containerID).Return(dockerContainer, true).AnyTimes()
state.EXPECT().ContainerMapByArn(taskARN).Return(containerNameToDockerContainer, true)
state.EXPECT().TaskByArn(taskARN).Return(task, true)

taskResponse, err := NewTaskResponse(taskARN, state, ecsClient, cluster,
availabilityZone, vpcID, containerInstanceArn, task.ServiceName, false)
Expand All @@ -150,10 +149,8 @@ func TestNewTaskContainerResponses(t *testing.T) {
assert.Equal(t, subnetGatewayIPV4Address, taskResponse.Containers[0].Networks[0].SubnetGatewayIPV4Address)
assert.Equal(t, serviceName, taskResponse.ServiceName)

gomock.InOrder(
state.EXPECT().ContainerByID(containerID).Return(dockerContainer, true),
state.EXPECT().TaskByID(containerID).Return(task, true).Times(2),
)
state.EXPECT().ContainerByID(containerID).Return(dockerContainer, true).AnyTimes()
state.EXPECT().TaskByID(containerID).Return(task, true).Times(2)
containerResponse, err := NewContainerResponse(containerID, state)
require.NoError(t, err)
_, err = json.Marshal(containerResponse)
Expand Down

0 comments on commit 26bd421

Please sign in to comment.