diff --git a/agent/acs/handler/acs_handler.go b/agent/acs/handler/acs_handler.go index ab6df1cd74b..325a5d1c0e6 100644 --- a/agent/acs/handler/acs_handler.go +++ b/agent/acs/handler/acs_handler.go @@ -263,19 +263,6 @@ func (acsSession *session) startACSSession(client wsclient.ClientServer) error { dataClient: acsSession.dataClient, } - // Add handler to ack task ENI attach message - eniAttachHandler := newAttachTaskENIHandler( - acsSession.ctx, - cfg.Cluster, - acsSession.containerInstanceARN, - client, - eniHandler, - ) - eniAttachHandler.start() - defer eniAttachHandler.stop() - - client.AddRequestHandler(eniAttachHandler.handlerFunc()) - // Add handler to ack instance ENI attach message instanceENIAttachHandler := newAttachInstanceENIHandler( acsSession.ctx, @@ -325,9 +312,13 @@ func (acsSession *session) startACSSession(client wsclient.ClientServer) error { responseSender := func(response interface{}) error { return client.MakeRequest(response) } - - heartbeatResponder := acssession.NewHeartbeatResponder(acsSession.doctor, responseSender) - client.AddRequestHandler(heartbeatResponder.HandlerFunc()) + responders := []wsclient.RequestResponder{ + acssession.NewAttachTaskENIResponder(eniHandler, responseSender), + acssession.NewHeartbeatResponder(acsSession.doctor, responseSender), + } + for _, r := range responders { + client.AddRequestHandler(r.HandlerFunc()) + } updater.AddAgentUpdateHandlers(client, cfg, acsSession.state, acsSession.dataClient, acsSession.taskEngine) diff --git a/agent/acs/handler/attach_eni_handler_common_test.go b/agent/acs/handler/attach_eni_handler_common_test.go index ffd6e6468eb..49930633774 100644 --- a/agent/acs/handler/attach_eni_handler_common_test.go +++ b/agent/acs/handler/attach_eni_handler_common_test.go @@ -173,3 +173,43 @@ func testHandleENIAttachment(t *testing.T, attachmentType, taskArn string) { assert.NoError(t, err) assert.Len(t, res, 1) } + +// TestHandleExpiredENIAttachmentTaskENI tests handling an expired task eni +func TestHandleExpiredENIAttachmentTaskENI(t *testing.T) { + testHandleExpiredENIAttachment(t, apieni.ENIAttachmentTypeTaskENI, taskArn) +} + +// TestHandleExpiredENIAttachmentInstanceENI tests handling an expired instance eni +func TestHandleExpiredENIAttachmentInstanceENI(t *testing.T) { + testHandleExpiredENIAttachment(t, apieni.ENIAttachmentTypeInstanceENI, "") +} + +func testHandleExpiredENIAttachment(t *testing.T, attachmentType, taskArn string) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + // Set expiresAt to a value in the past. + expiresAt := time.Unix(time.Now().Unix()-1, 0) + + taskEngineState := dockerstate.NewTaskEngineState() + dataClient := data.NewNoopClient() + + eniAttachment := &apieni.ENIAttachment{ + AttachmentInfo: attachmentinfo.AttachmentInfo{ + TaskARN: taskArn, + AttachmentARN: attachmentArn, + ExpiresAt: expiresAt, + }, + AttachmentType: attachmentType, + MACAddress: randomMAC, + } + eniHandler := &eniHandler{ + state: taskEngineState, + dataClient: dataClient, + } + + // Expect an error starting the timer because of <=0 duration. + err := eniHandler.HandleENIAttachment(eniAttachment) + assert.Error(t, err) + assert.Equal(t, true, eniAttachment.HasExpired()) +} diff --git a/agent/acs/handler/attach_instance_eni_handler_test.go b/agent/acs/handler/attach_instance_eni_handler_test.go index 0ade8f44c0d..9925106c5ad 100644 --- a/agent/acs/handler/attach_instance_eni_handler_test.go +++ b/agent/acs/handler/attach_instance_eni_handler_test.go @@ -26,6 +26,7 @@ import ( "github.com/aws/amazon-ecs-agent/agent/engine/dockerstate" mock_dockerstate "github.com/aws/amazon-ecs-agent/agent/engine/dockerstate/mocks" "github.com/aws/amazon-ecs-agent/ecs-agent/acs/model/ecsacs" + "github.com/aws/amazon-ecs-agent/ecs-agent/acs/session/testconst" "github.com/aws/amazon-ecs-agent/ecs-agent/api/attachmentinfo" apieni "github.com/aws/amazon-ecs-agent/ecs-agent/api/eni" mock_wsclient "github.com/aws/amazon-ecs-agent/ecs-agent/wsclient/mock" @@ -42,8 +43,8 @@ func TestInvalidAttachInstanceENIMessage(t *testing.T) { }{ { message: &ecsacs.AttachInstanceNetworkInterfacesMessage{ - ClusterArn: aws.String(clusterName), - ContainerInstanceArn: aws.String(containerInstanceArn), + ClusterArn: aws.String(testconst.ClusterName), + ContainerInstanceArn: aws.String(testconst.ContainerInstanceARN), ElasticNetworkInterfaces: []*ecsacs.ElasticNetworkInterface{}, WaitTimeoutMs: aws.Int64(waitTimeoutMillis), }, @@ -52,7 +53,7 @@ func TestInvalidAttachInstanceENIMessage(t *testing.T) { { message: &ecsacs.AttachInstanceNetworkInterfacesMessage{ MessageId: aws.String(eniMessageId), - ContainerInstanceArn: aws.String(containerInstanceArn), + ContainerInstanceArn: aws.String(testconst.ContainerInstanceARN), ElasticNetworkInterfaces: []*ecsacs.ElasticNetworkInterface{}, WaitTimeoutMs: aws.Int64(waitTimeoutMillis), }, @@ -61,7 +62,7 @@ func TestInvalidAttachInstanceENIMessage(t *testing.T) { { message: &ecsacs.AttachInstanceNetworkInterfacesMessage{ MessageId: aws.String(eniMessageId), - ClusterArn: aws.String(clusterName), + ClusterArn: aws.String(testconst.ClusterName), ElasticNetworkInterfaces: []*ecsacs.ElasticNetworkInterface{}, WaitTimeoutMs: aws.Int64(waitTimeoutMillis), }, @@ -70,7 +71,7 @@ func TestInvalidAttachInstanceENIMessage(t *testing.T) { { message: &ecsacs.AttachInstanceNetworkInterfacesMessage{ MessageId: aws.String(eniMessageId), - ClusterArn: aws.String(clusterName), + ClusterArn: aws.String(testconst.ClusterName), WaitTimeoutMs: aws.Int64(waitTimeoutMillis), }, description: "Message without network interfaces should be invalid", @@ -78,8 +79,8 @@ func TestInvalidAttachInstanceENIMessage(t *testing.T) { { message: &ecsacs.AttachInstanceNetworkInterfacesMessage{ MessageId: aws.String(eniMessageId), - ClusterArn: aws.String(clusterName), - ContainerInstanceArn: aws.String(containerInstanceArn), + ClusterArn: aws.String(testconst.ClusterName), + ContainerInstanceArn: aws.String(testconst.ContainerInstanceARN), ElasticNetworkInterfaces: []*ecsacs.ElasticNetworkInterface{ { MacAddress: aws.String(randomMAC), @@ -97,8 +98,8 @@ func TestInvalidAttachInstanceENIMessage(t *testing.T) { { message: &ecsacs.AttachInstanceNetworkInterfacesMessage{ MessageId: aws.String(eniMessageId), - ClusterArn: aws.String(clusterName), - ContainerInstanceArn: aws.String(containerInstanceArn), + ClusterArn: aws.String(testconst.ClusterName), + ContainerInstanceArn: aws.String(testconst.ContainerInstanceARN), ElasticNetworkInterfaces: []*ecsacs.ElasticNetworkInterface{ {}, }, @@ -109,8 +110,8 @@ func TestInvalidAttachInstanceENIMessage(t *testing.T) { { message: &ecsacs.AttachInstanceNetworkInterfacesMessage{ MessageId: aws.String(eniMessageId), - ClusterArn: aws.String(clusterName), - ContainerInstanceArn: aws.String(containerInstanceArn), + ClusterArn: aws.String(testconst.ClusterName), + ContainerInstanceArn: aws.String(testconst.ContainerInstanceARN), ElasticNetworkInterfaces: []*ecsacs.ElasticNetworkInterface{ { Ec2Id: aws.String("1"), @@ -123,8 +124,8 @@ func TestInvalidAttachInstanceENIMessage(t *testing.T) { { message: &ecsacs.AttachInstanceNetworkInterfacesMessage{ MessageId: aws.String(eniMessageId), - ClusterArn: aws.String(clusterName), - ContainerInstanceArn: aws.String(containerInstanceArn), + ClusterArn: aws.String(testconst.ClusterName), + ContainerInstanceArn: aws.String(testconst.ContainerInstanceARN), ElasticNetworkInterfaces: []*ecsacs.ElasticNetworkInterface{ { MacAddress: aws.String(randomMAC), @@ -153,7 +154,7 @@ func TestInstanceENIAckSingleMessage(t *testing.T) { ctx := context.TODO() mockWSClient := mock_wsclient.NewMockClientServer(ctrl) - handler := newAttachInstanceENIHandler(ctx, clusterName, containerInstanceArn, mockWSClient, + handler := newAttachInstanceENIHandler(ctx, testconst.ClusterName, testconst.ContainerInstanceARN, mockWSClient, &eniHandler{ state: taskEngineState, dataClient: dataClient, @@ -176,8 +177,8 @@ func TestInstanceENIAckSingleMessage(t *testing.T) { } message := &ecsacs.AttachInstanceNetworkInterfacesMessage{ MessageId: aws.String(eniMessageId), - ClusterArn: aws.String(clusterName), - ContainerInstanceArn: aws.String(containerInstanceArn), + ClusterArn: aws.String(testconst.ClusterName), + ContainerInstanceArn: aws.String(testconst.ContainerInstanceARN), ElasticNetworkInterfaces: []*ecsacs.ElasticNetworkInterface{ &mockNetInterface1, }, @@ -201,7 +202,7 @@ func TestInstanceENIAckSingleMessageDuplicateENIAttachmentMessageStartsTimer(t * ctx := context.TODO() mockWSClient := mock_wsclient.NewMockClientServer(ctrl) - handler := newAttachInstanceENIHandler(ctx, clusterName, containerInstanceArn, mockWSClient, + handler := newAttachInstanceENIHandler(ctx, testconst.ClusterName, testconst.ContainerInstanceARN, mockWSClient, &eniHandler{ state: mockState, dataClient: dataClient, @@ -237,8 +238,8 @@ func TestInstanceENIAckSingleMessageDuplicateENIAttachmentMessageStartsTimer(t * } message := &ecsacs.AttachInstanceNetworkInterfacesMessage{ MessageId: aws.String(eniMessageId), - ClusterArn: aws.String(clusterName), - ContainerInstanceArn: aws.String(containerInstanceArn), + ClusterArn: aws.String(testconst.ClusterName), + ContainerInstanceArn: aws.String(testconst.ContainerInstanceARN), ElasticNetworkInterfaces: []*ecsacs.ElasticNetworkInterface{ &mockNetInterface1, }, @@ -261,7 +262,7 @@ func TestInstanceENIAckHappyPath(t *testing.T) { dataClient := data.NewNoopClient() mockWSClient := mock_wsclient.NewMockClientServer(ctrl) - handler := newAttachInstanceENIHandler(ctx, clusterName, containerInstanceArn, mockWSClient, + handler := newAttachInstanceENIHandler(ctx, testconst.ClusterName, testconst.ContainerInstanceARN, mockWSClient, &eniHandler{ state: taskEngineState, dataClient: dataClient, @@ -284,8 +285,8 @@ func TestInstanceENIAckHappyPath(t *testing.T) { } message := &ecsacs.AttachInstanceNetworkInterfacesMessage{ MessageId: aws.String(eniMessageId), - ClusterArn: aws.String(clusterName), - ContainerInstanceArn: aws.String(containerInstanceArn), + ClusterArn: aws.String(testconst.ClusterName), + ContainerInstanceArn: aws.String(testconst.ContainerInstanceARN), ElasticNetworkInterfaces: []*ecsacs.ElasticNetworkInterface{ &mockNetInterface1, }, diff --git a/agent/acs/handler/attach_task_eni_handler.go b/agent/acs/handler/attach_task_eni_handler.go deleted file mode 100644 index 5c85a32c6fd..00000000000 --- a/agent/acs/handler/attach_task_eni_handler.go +++ /dev/null @@ -1,170 +0,0 @@ -// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"). You may -// not use this file except in compliance with the License. A copy of the -// License is located at -// -// http://aws.amazon.com/apache2.0/ -// -// or in the "license" file accompanying this file. This file is distributed -// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either -// express or implied. See the License for the specific language governing -// permissions and limitations under the License. - -package handler - -import ( - "time" - - "github.com/aws/amazon-ecs-agent/ecs-agent/acs/model/ecsacs" - acssession "github.com/aws/amazon-ecs-agent/ecs-agent/acs/session" - "github.com/aws/amazon-ecs-agent/ecs-agent/api/attachmentinfo" - apieni "github.com/aws/amazon-ecs-agent/ecs-agent/api/eni" - "github.com/aws/amazon-ecs-agent/ecs-agent/api/status" - "github.com/aws/amazon-ecs-agent/ecs-agent/wsclient" - "github.com/aws/aws-sdk-go/aws" - - "github.com/cihub/seelog" - "github.com/pkg/errors" - - "context" -) - -// attachTaskENIHandler handles task ENI attach operation for the ACS client -type attachTaskENIHandler struct { - messageBuffer chan *ecsacs.AttachTaskNetworkInterfacesMessage - ctx context.Context - cancel context.CancelFunc - cluster *string - containerInstance *string - acsClient wsclient.ClientServer - eniHandler acssession.ENIHandler -} - -// newAttachTaskENIHandler returns an instance of the attachENIHandler struct -func newAttachTaskENIHandler(ctx context.Context, - cluster string, - containerInstanceArn string, - acsClient wsclient.ClientServer, - eniHandler acssession.ENIHandler) attachTaskENIHandler { - - // Create a cancelable context from the parent context - derivedContext, cancel := context.WithCancel(ctx) - return attachTaskENIHandler{ - messageBuffer: make(chan *ecsacs.AttachTaskNetworkInterfacesMessage), - ctx: derivedContext, - cancel: cancel, - cluster: aws.String(cluster), - containerInstance: aws.String(containerInstanceArn), - acsClient: acsClient, - eniHandler: eniHandler, - } -} - -// handlerFunc returns a function to enqueue requests onto attachENIHandler buffer -func (attachTaskENIHandler *attachTaskENIHandler) handlerFunc() func(message *ecsacs.AttachTaskNetworkInterfacesMessage) { - return func(message *ecsacs.AttachTaskNetworkInterfacesMessage) { - attachTaskENIHandler.messageBuffer <- message - } -} - -// start invokes handleMessages to ack each enqueued request -func (attachTaskENIHandler *attachTaskENIHandler) start() { - go attachTaskENIHandler.handleMessages() -} - -// stop is used to invoke a cancellation function -func (attachTaskENIHandler *attachTaskENIHandler) stop() { - attachTaskENIHandler.cancel() -} - -// handleMessages handles each message one at a time -func (attachTaskENIHandler *attachTaskENIHandler) handleMessages() { - for { - select { - case <-attachTaskENIHandler.ctx.Done(): - return - case message := <-attachTaskENIHandler.messageBuffer: - if err := attachTaskENIHandler.handleSingleMessage(message); err != nil { - seelog.Warnf("Unable to handle ENI Attachment message [%s]: %v", message.String(), err) - } - } - } -} - -// handleSingleMessage acks the message received -func (attachTaskENIHandler *attachTaskENIHandler) handleSingleMessage(message *ecsacs.AttachTaskNetworkInterfacesMessage) error { - receivedAt := time.Now() - // Validate fields in the message - if err := validateAttachTaskNetworkInterfacesMessage(message); err != nil { - return errors.Wrapf(err, - "attach eni message handler: error validating AttachTaskNetworkInterface message received from ECS") - } - - // Send ACK - go sendAck(attachTaskENIHandler.acsClient, message.ClusterArn, message.ContainerInstanceArn, message.MessageId) - - expiresAt := receivedAt.Add(time.Duration(aws.Int64Value(message.WaitTimeoutMs)) * time.Millisecond) - eniAttachment := &apieni.ENIAttachment{ - AttachmentInfo: attachmentinfo.AttachmentInfo{ - TaskARN: aws.StringValue(message.TaskArn), - AttachmentARN: aws.StringValue(message.ElasticNetworkInterfaces[0].AttachmentArn), - Status: status.AttachmentNone, - ExpiresAt: expiresAt, - AttachStatusSent: false, - ClusterARN: aws.StringValue(message.ClusterArn), - ContainerInstanceARN: aws.StringValue(message.ContainerInstanceArn), - }, - AttachmentType: apieni.ENIAttachmentTypeTaskENI, - MACAddress: aws.StringValue(message.ElasticNetworkInterfaces[0].MacAddress), - } - - // Handle the attachment - return attachTaskENIHandler.eniHandler.HandleENIAttachment(eniAttachment) -} - -// validateAttachTaskNetworkInterfacesMessage performs validation checks on the -// AttachTaskNetworkInterfacesMessage -func validateAttachTaskNetworkInterfacesMessage(message *ecsacs.AttachTaskNetworkInterfacesMessage) error { - if message == nil { - return errors.Errorf("attach eni handler validation: empty AttachTaskNetworkInterface message received from ECS") - } - - messageId := aws.StringValue(message.MessageId) - if messageId == "" { - return errors.Errorf("attach eni handler validation: message id not set in AttachTaskNetworkInterface message received from ECS") - } - - clusterArn := aws.StringValue(message.ClusterArn) - if clusterArn == "" { - return errors.Errorf("attach eni handler validation: clusterArn not set in AttachTaskNetworkInterface message received from ECS") - } - - containerInstanceArn := aws.StringValue(message.ContainerInstanceArn) - if containerInstanceArn == "" { - return errors.Errorf("attach eni handler validation: containerInstanceArn not set in AttachTaskNetworkInterface message received from ECS") - } - - enis := message.ElasticNetworkInterfaces - if len(enis) != 1 { - return errors.Errorf("attach eni handler validation: incorrect number of ENIs in AttachTaskNetworkInterface message received from ECS. Obtained %d", len(enis)) - } - - eni := enis[0] - if aws.StringValue(eni.MacAddress) == "" { - return errors.Errorf("attach eni handler validation: MACAddress not listed in AttachTaskNetworkInterface message received from ECS") - } - - taskArn := aws.StringValue(message.TaskArn) - if taskArn == "" { - return errors.Errorf("attach eni handler validation: taskArn not set in AttachTaskNetworkInterface message received from ECS") - } - - timeout := aws.Int64Value(message.WaitTimeoutMs) - if timeout <= 0 { - return errors.Errorf("attach eni handler validation: invalid timeout listed in AttachTaskNetworkInterface message received from ECS") - - } - - return nil -} diff --git a/agent/acs/handler/attach_task_eni_handler_test.go b/agent/acs/handler/attach_task_eni_handler_test.go deleted file mode 100644 index 075dbc2f801..00000000000 --- a/agent/acs/handler/attach_task_eni_handler_test.go +++ /dev/null @@ -1,369 +0,0 @@ -//go:build unit -// +build unit - -// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"). You may -// not use this file except in compliance with the License. A copy of the -// License is located at -// -// http://aws.amazon.com/apache2.0/ -// -// or in the "license" file accompanying this file. This file is distributed -// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either -// express or implied. See the License for the specific language governing -// permissions and limitations under the License. - -package handler - -import ( - "context" - "sync" - "testing" - "time" - - "github.com/aws/amazon-ecs-agent/agent/data" - "github.com/aws/amazon-ecs-agent/agent/engine/dockerstate" - mock_dockerstate "github.com/aws/amazon-ecs-agent/agent/engine/dockerstate/mocks" - "github.com/aws/amazon-ecs-agent/ecs-agent/acs/model/ecsacs" - "github.com/aws/amazon-ecs-agent/ecs-agent/api/attachmentinfo" - apieni "github.com/aws/amazon-ecs-agent/ecs-agent/api/eni" - mock_wsclient "github.com/aws/amazon-ecs-agent/ecs-agent/wsclient/mock" - - "github.com/aws/aws-sdk-go/aws" - "github.com/golang/mock/gomock" - "github.com/stretchr/testify/assert" -) - -const ( - eniMessageId = "123" - randomMAC = "00:0a:95:9d:68:16" - waitTimeoutMillis = 1000 -) - -// TestAttachENIMessageWithNoMessageId checks the validator against an -// AttachTaskNetworkInterfacesMessage without a messageId -func TestAttachENIMessageWithNoMessageId(t *testing.T) { - message := &ecsacs.AttachTaskNetworkInterfacesMessage{ - ClusterArn: aws.String(clusterName), - ContainerInstanceArn: aws.String(containerInstanceArn), - ElasticNetworkInterfaces: []*ecsacs.ElasticNetworkInterface{}, - TaskArn: aws.String(taskArn), - WaitTimeoutMs: aws.Int64(waitTimeoutMillis), - } - - err := validateAttachTaskNetworkInterfacesMessage(message) - assert.Error(t, err) -} - -// TestAttachENIMessageWithNoClusterArn checks the validator against an -// AttachTaskNetworkInterfacesMessage without a ClusterArn -func TestAttachENIMessageWithNoClusterArn(t *testing.T) { - message := &ecsacs.AttachTaskNetworkInterfacesMessage{ - MessageId: aws.String(eniMessageId), - ContainerInstanceArn: aws.String(containerInstanceArn), - ElasticNetworkInterfaces: []*ecsacs.ElasticNetworkInterface{}, - TaskArn: aws.String(taskArn), - WaitTimeoutMs: aws.Int64(waitTimeoutMillis), - } - - err := validateAttachTaskNetworkInterfacesMessage(message) - assert.Error(t, err) -} - -// TestAttachENIMessageWithNoContainerInstanceArn checks the validator against an -// AttachTaskNetworkInterfacesMessage without a ContainerInstanceArn -func TestAttachENIMessageWithNoContainerInstanceArn(t *testing.T) { - message := &ecsacs.AttachTaskNetworkInterfacesMessage{ - MessageId: aws.String(eniMessageId), - ClusterArn: aws.String(clusterName), - ElasticNetworkInterfaces: []*ecsacs.ElasticNetworkInterface{}, - TaskArn: aws.String(taskArn), - WaitTimeoutMs: aws.Int64(waitTimeoutMillis), - } - - err := validateAttachTaskNetworkInterfacesMessage(message) - assert.Error(t, err) -} - -// TestAttachENIMessageWithNoInterfaces checks the validator against an -// AttachTaskNetworkInterfacesMessage without any interface -func TestAttachENIMessageWithNoInterfaces(t *testing.T) { - message := &ecsacs.AttachTaskNetworkInterfacesMessage{ - MessageId: aws.String(eniMessageId), - ClusterArn: aws.String(clusterName), - TaskArn: aws.String(taskArn), - WaitTimeoutMs: aws.Int64(waitTimeoutMillis), - } - err := validateAttachTaskNetworkInterfacesMessage(message) - assert.Error(t, err) -} - -// TestAttachENIMessageWithMultipleInterfaceschecks checks the validator against an -// AttachTaskNetworkInterfacesMessage with multiple interfaces -func TestAttachENIMessageWithMultipleInterfaces(t *testing.T) { - mockNetInterface1 := ecsacs.ElasticNetworkInterface{ - MacAddress: aws.String(randomMAC), - Ec2Id: aws.String("1"), - } - mockNetInterface2 := ecsacs.ElasticNetworkInterface{ - MacAddress: aws.String(randomMAC), - Ec2Id: aws.String("2"), - } - message := &ecsacs.AttachTaskNetworkInterfacesMessage{ - MessageId: aws.String(eniMessageId), - ClusterArn: aws.String(clusterName), - ContainerInstanceArn: aws.String(containerInstanceArn), - ElasticNetworkInterfaces: []*ecsacs.ElasticNetworkInterface{ - &mockNetInterface1, - &mockNetInterface2, - }, - TaskArn: aws.String(taskArn), - WaitTimeoutMs: aws.Int64(waitTimeoutMillis), - } - - err := validateAttachTaskNetworkInterfacesMessage(message) - assert.Error(t, err) -} - -// TestAttachENIMessageWithMissingNetworkDetails checks the validator against an -// AttachTaskNetworkInterfacesMessage without network details -func TestAttachENIMessageWithMissingNetworkDetails(t *testing.T) { - mockNetInterface1 := ecsacs.ElasticNetworkInterface{} - - message := &ecsacs.AttachTaskNetworkInterfacesMessage{ - MessageId: aws.String(eniMessageId), - ClusterArn: aws.String(clusterName), - ContainerInstanceArn: aws.String(containerInstanceArn), - ElasticNetworkInterfaces: []*ecsacs.ElasticNetworkInterface{ - &mockNetInterface1, - }, - TaskArn: aws.String(taskArn), - WaitTimeoutMs: aws.Int64(waitTimeoutMillis), - } - - err := validateAttachTaskNetworkInterfacesMessage(message) - assert.Error(t, err) -} - -// TestAttachENIMessageWithMissingMACAddress checks the validator against an -// AttachTaskNetworkInterfacesMessage without a MAC address -func TestAttachENIMessageWithMissingMACAddress(t *testing.T) { - mockNetInterface1 := ecsacs.ElasticNetworkInterface{ - Ec2Id: aws.String("1"), - } - message := &ecsacs.AttachTaskNetworkInterfacesMessage{ - MessageId: aws.String(eniMessageId), - ClusterArn: aws.String(clusterName), - ContainerInstanceArn: aws.String(containerInstanceArn), - ElasticNetworkInterfaces: []*ecsacs.ElasticNetworkInterface{ - &mockNetInterface1, - }, - TaskArn: aws.String(taskArn), - WaitTimeoutMs: aws.Int64(waitTimeoutMillis), - } - - err := validateAttachTaskNetworkInterfacesMessage(message) - assert.Error(t, err) -} - -// TODO: -// * Add TaskArn + Timeout Tests - -// TestAttachENIMessageWithMissingTaskArn checks the validator against an -// AttachTaskNetworkInterfacesMessage without a MAC address -func TestAttachENIMessageWithMissingTaskArn(t *testing.T) { - mockNetInterface1 := ecsacs.ElasticNetworkInterface{ - Ec2Id: aws.String("1"), - MacAddress: aws.String(randomMAC), - } - message := &ecsacs.AttachTaskNetworkInterfacesMessage{ - MessageId: aws.String(eniMessageId), - ClusterArn: aws.String(clusterName), - ContainerInstanceArn: aws.String(containerInstanceArn), - ElasticNetworkInterfaces: []*ecsacs.ElasticNetworkInterface{ - &mockNetInterface1, - }, - WaitTimeoutMs: aws.Int64(waitTimeoutMillis), - } - - err := validateAttachTaskNetworkInterfacesMessage(message) - assert.Error(t, err) -} - -// TestAttachENIMessageWithMissingTimeout checks the validator against an -// AttachTaskNetworkInterfacesMessage without a MAC address -func TestAttachENIMessageWithMissingTimeout(t *testing.T) { - mockNetInterface1 := ecsacs.ElasticNetworkInterface{ - Ec2Id: aws.String("1"), - } - message := &ecsacs.AttachTaskNetworkInterfacesMessage{ - MessageId: aws.String(eniMessageId), - ClusterArn: aws.String(clusterName), - ContainerInstanceArn: aws.String(containerInstanceArn), - ElasticNetworkInterfaces: []*ecsacs.ElasticNetworkInterface{ - &mockNetInterface1, - }, - TaskArn: aws.String(taskArn), - } - - err := validateAttachTaskNetworkInterfacesMessage(message) - assert.Error(t, err) -} - -// TestENIAckSingleMessage checks the ack for a single message -func TestENIAckSingleMessage(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - taskEngineState := dockerstate.NewTaskEngineState() - dataClient := data.NewNoopClient() - - ctx := context.TODO() - mockWSClient := mock_wsclient.NewMockClientServer(ctrl) - eniAttachHandler := newAttachTaskENIHandler(ctx, clusterName, containerInstanceArn, mockWSClient, - &eniHandler{ - state: taskEngineState, - dataClient: dataClient, - }, - ) - - var ackSent sync.WaitGroup - ackSent.Add(1) - mockWSClient.EXPECT().MakeRequest(gomock.Any()).Do(func(ackRequest *ecsacs.AckRequest) { - assert.Equal(t, aws.StringValue(ackRequest.MessageId), eniMessageId) - ackSent.Done() - }) - - go eniAttachHandler.start() - - mockNetInterface1 := ecsacs.ElasticNetworkInterface{ - Ec2Id: aws.String("1"), - MacAddress: aws.String(randomMAC), - AttachmentArn: aws.String("attachmentarn"), - } - message := &ecsacs.AttachTaskNetworkInterfacesMessage{ - MessageId: aws.String(eniMessageId), - ClusterArn: aws.String(clusterName), - ContainerInstanceArn: aws.String(containerInstanceArn), - ElasticNetworkInterfaces: []*ecsacs.ElasticNetworkInterface{ - &mockNetInterface1, - }, - TaskArn: aws.String(taskArn), - WaitTimeoutMs: aws.Int64(waitTimeoutMillis), - } - - eniAttachHandler.messageBuffer <- message - ackSent.Wait() - eniAttachHandler.stop() -} - -// TestENIAckSingleMessageDuplicateENIAttachmentMessageStartsTimer checks the ack for a single message -// and ensures that the ENI ack expiration timer is started -func TestENIAckSingleMessageDuplicateENIAttachmentMessageStartsTimer(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockState := mock_dockerstate.NewMockTaskEngineState(ctrl) - dataClient := data.NewNoopClient() - - ctx := context.TODO() - mockWSClient := mock_wsclient.NewMockClientServer(ctrl) - eniAttachHandler := newAttachTaskENIHandler(ctx, clusterName, containerInstanceArn, mockWSClient, - &eniHandler{ - state: mockState, - dataClient: dataClient, - }, - ) - - // Set expiresAt to a value in the past - expiresAt := time.Unix(time.Now().Unix()-1, 0) - var ackSent sync.WaitGroup - ackSent.Add(1) - mockWSClient.EXPECT().MakeRequest(gomock.Any()).Do(func(ackRequest *ecsacs.AckRequest) { - assert.Equal(t, aws.StringValue(ackRequest.MessageId), eniMessageId) - ackSent.Done() - }) - gomock.InOrder( - // Sending an attachment with ExpiresAt set in the past results in an - // error in starting the timer. - // Ensuring that statemanager.Save() is not invoked should be a strong - // enough check to ensure that the timer was started - mockState.EXPECT().ENIByMac(randomMAC).Return(&apieni.ENIAttachment{ - AttachmentInfo: attachmentinfo.AttachmentInfo{ - ExpiresAt: expiresAt, - }, - }, true), - ) - - mockNetInterface1 := ecsacs.ElasticNetworkInterface{ - Ec2Id: aws.String("1"), - MacAddress: aws.String(randomMAC), - AttachmentArn: aws.String("attachmentarn"), - } - message := &ecsacs.AttachTaskNetworkInterfacesMessage{ - MessageId: aws.String(eniMessageId), - ClusterArn: aws.String(clusterName), - ContainerInstanceArn: aws.String(containerInstanceArn), - ElasticNetworkInterfaces: []*ecsacs.ElasticNetworkInterface{ - &mockNetInterface1, - }, - TaskArn: aws.String(taskArn), - WaitTimeoutMs: aws.Int64(waitTimeoutMillis), - } - - // Expect an error starting the timer because of <=0 duration - err := eniAttachHandler.handleSingleMessage(message) - assert.Error(t, err) - ackSent.Wait() -} - -// TestENIAckHappyPath tests the happy path for a typical AttachTaskNetworkInterfacesMessage -func TestENIAckHappyPath(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - ctx := context.TODO() - taskEngineState := dockerstate.NewTaskEngineState() - dataClient := data.NewNoopClient() - - mockWSClient := mock_wsclient.NewMockClientServer(ctrl) - eniAttachHandler := newAttachTaskENIHandler(ctx, clusterName, containerInstanceArn, mockWSClient, - &eniHandler{ - state: taskEngineState, - dataClient: dataClient, - }, - ) - - var ackSent sync.WaitGroup - ackSent.Add(1) - mockWSClient.EXPECT().MakeRequest(gomock.Any()).Do(func(ackRequest *ecsacs.AckRequest) { - assert.Equal(t, aws.StringValue(ackRequest.MessageId), eniMessageId) - ackSent.Done() - eniAttachHandler.stop() - }) - - go eniAttachHandler.start() - - mockNetInterface1 := ecsacs.ElasticNetworkInterface{ - Ec2Id: aws.String("1"), - MacAddress: aws.String(randomMAC), - } - message := &ecsacs.AttachTaskNetworkInterfacesMessage{ - MessageId: aws.String(eniMessageId), - ClusterArn: aws.String(clusterName), - ContainerInstanceArn: aws.String(containerInstanceArn), - ElasticNetworkInterfaces: []*ecsacs.ElasticNetworkInterface{ - &mockNetInterface1, - }, - TaskArn: aws.String(taskArn), - WaitTimeoutMs: aws.Int64(waitTimeoutMillis), - } - - eniAttachHandler.messageBuffer <- message - - ackSent.Wait() - select { - case <-eniAttachHandler.ctx.Done(): - } -} diff --git a/agent/acs/handler/attach_task_eni_responder_test.go b/agent/acs/handler/attach_task_eni_responder_test.go new file mode 100644 index 00000000000..56f897e6435 --- /dev/null +++ b/agent/acs/handler/attach_task_eni_responder_test.go @@ -0,0 +1,156 @@ +//go:build unit +// +build unit + +// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. + +package handler + +import ( + "sync" + "testing" + "time" + + "github.com/aws/aws-sdk-go/aws" + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + + "github.com/aws/amazon-ecs-agent/agent/data" + "github.com/aws/amazon-ecs-agent/agent/engine/dockerstate" + mock_dockerstate "github.com/aws/amazon-ecs-agent/agent/engine/dockerstate/mocks" + "github.com/aws/amazon-ecs-agent/ecs-agent/acs/model/ecsacs" + acssession "github.com/aws/amazon-ecs-agent/ecs-agent/acs/session" + "github.com/aws/amazon-ecs-agent/ecs-agent/acs/session/testconst" + "github.com/aws/amazon-ecs-agent/ecs-agent/api/attachmentinfo" + apieni "github.com/aws/amazon-ecs-agent/ecs-agent/api/eni" +) + +const ( + eniMessageId = "123" + randomMAC = "00:0a:95:9d:68:16" + waitTimeoutMillis = 1000 + + interfaceProtocol = "default" + gatewayIpv4 = "192.168.1.1/24" + ipv4Address = "ipv4" +) + +var testAttachTaskENIMessage = &ecsacs.AttachTaskNetworkInterfacesMessage{ + MessageId: aws.String(eniMessageId), + ClusterArn: aws.String(testconst.ClusterName), + ContainerInstanceArn: aws.String(testconst.ContainerInstanceARN), + ElasticNetworkInterfaces: []*ecsacs.ElasticNetworkInterface{ + { + Ec2Id: aws.String("1"), + MacAddress: aws.String(randomMAC), + InterfaceAssociationProtocol: aws.String(interfaceProtocol), + SubnetGatewayIpv4Address: aws.String(gatewayIpv4), + Ipv4Addresses: []*ecsacs.IPv4AddressAssignment{ + { + Primary: aws.Bool(true), + PrivateAddress: aws.String(ipv4Address), + }, + }, + }, + }, + TaskArn: aws.String(testconst.TaskARN), + WaitTimeoutMs: aws.Int64(waitTimeoutMillis), +} + +// TestENIAckHappyPath tests the happy path for a typical AttachTaskNetworkInterfacesMessage and confirms expected +// ACK request is made +func TestENIAckHappyPath(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ackSent := make(chan *ecsacs.AckRequest) + + taskEngineState := dockerstate.NewTaskEngineState() + dataClient := data.NewNoopClient() + + testResponseSender := func(response interface{}) error { + resp := response.(*ecsacs.AckRequest) + ackSent <- resp + return nil + } + testAttachTaskENIResponder := acssession.NewAttachTaskENIResponder( + &eniHandler{ + state: taskEngineState, + dataClient: dataClient, + }, + testResponseSender) + + handleAttachMessage := testAttachTaskENIResponder.HandlerFunc().(func(*ecsacs.AttachTaskNetworkInterfacesMessage)) + go handleAttachMessage(testAttachTaskENIMessage) + + attachTaskEniAckSent := <-ackSent + assert.Equal(t, aws.StringValue(attachTaskEniAckSent.MessageId), eniMessageId) +} + +// TestENIAckSingleMessageWithDuplicateENIAttachment tests the path for an +// AttachTaskNetworkInterfacesMessage with a duplicate expired ENI and confirms: +// 1. attempt is made to start the ack timer that records the expiration of ENI attachment (i.e., ENI is not added to +// task engine state) +// 2. expected ACK request is made +func TestENIAckSingleMessageWithDuplicateENIAttachment(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ackSent := make(chan *ecsacs.AckRequest) + + mockState := mock_dockerstate.NewMockTaskEngineState(ctrl) + dataClient := data.NewNoopClient() + + // Set expiresAt to a value in the past. + expiresAt := time.Unix(time.Now().Unix()-1, 0) + + // WaitGroup is necessary to wait for a function to be called in separate goroutine before exiting the test. + wg := sync.WaitGroup{} + wg.Add(1) + + gomock.InOrder( + // Sending a duplicate expired ENI attachment still results in an attempt to start the timer. We don't really + // care if the timer is actually started or not (i.e., whether or not the ENI attachment is expired); we just + // care that an attempt was made. Attempting to start the timer means that the ENI attachment was not added to + // the task engine state. + mockState.EXPECT(). + ENIByMac(randomMAC). + Return(&apieni.ENIAttachment{ + AttachmentInfo: attachmentinfo.AttachmentInfo{ + ExpiresAt: expiresAt, + }, + }, true). + Do(func(arg0 interface{}) { + defer wg.Done() // we can exit the test now that ENIByMac function has been called + }), + ) + + testResponseSender := func(response interface{}) error { + resp := response.(*ecsacs.AckRequest) + ackSent <- resp + return nil + } + testAttachTaskENIResponder := acssession.NewAttachTaskENIResponder( + &eniHandler{ + state: mockState, + dataClient: dataClient, + }, + testResponseSender) + + handleAttachMessage := testAttachTaskENIResponder.HandlerFunc().(func(*ecsacs.AttachTaskNetworkInterfacesMessage)) + go handleAttachMessage(testAttachTaskENIMessage) + + attachTaskEniAckSent := <-ackSent + wg.Wait() + assert.Equal(t, aws.StringValue(attachTaskEniAckSent.MessageId), eniMessageId) +} diff --git a/agent/acs/handler/payload_handler_test.go b/agent/acs/handler/payload_handler_test.go index 4f3b0ae2a96..ff9d69240a6 100644 --- a/agent/acs/handler/payload_handler_test.go +++ b/agent/acs/handler/payload_handler_test.go @@ -34,6 +34,7 @@ import ( "github.com/aws/amazon-ecs-agent/agent/eventhandler" "github.com/aws/amazon-ecs-agent/agent/taskresource" "github.com/aws/amazon-ecs-agent/ecs-agent/acs/model/ecsacs" + "github.com/aws/amazon-ecs-agent/ecs-agent/acs/session/testconst" "github.com/aws/amazon-ecs-agent/ecs-agent/api/eni" "github.com/aws/amazon-ecs-agent/ecs-agent/credentials" mock_wsclient "github.com/aws/amazon-ecs-agent/ecs-agent/wsclient/mock" @@ -45,10 +46,7 @@ import ( ) const ( - clusterName = "default" - containerInstanceArn = "instance" - payloadMessageId = "123" - testTaskARN = "arn:aws:ecs:us-west-2:1234567890:task/test-cluster/abc" + payloadMessageId = "123" ) // testHelper wraps all the object required for the test @@ -78,8 +76,8 @@ func setup(t *testing.T) *testHelper { ctx, taskEngine, ecsClient, - clusterName, - containerInstanceArn, + testconst.ClusterName, + testconst.ContainerInstanceARN, mockWsClient, data.NewNoopClient(), refreshCredentialsHandler{}, @@ -156,7 +154,7 @@ func TestHandlePayloadMessageSaveData(t *testing.T) { err := tester.payloadHandler.handleSingleMessage(&ecsacs.PayloadMessage{ Tasks: []*ecsacs.Task{ { - Arn: aws.String(testTaskARN), + Arn: aws.String(testconst.TaskARN), DesiredStatus: aws.String(tc.taskDesiredStatus), }, }, @@ -309,7 +307,7 @@ func TestHandlePayloadMessageCredentialsAckedWhenTaskAdded(t *testing.T) { }), ) - refreshCredsHandler := newRefreshCredentialsHandler(tester.ctx, clusterName, containerInstanceArn, tester.mockWsClient, tester.credentialsManager, tester.mockTaskEngine) + refreshCredsHandler := newRefreshCredentialsHandler(tester.ctx, testconst.ClusterName, testconst.ContainerInstanceARN, tester.mockWsClient, tester.credentialsManager, tester.mockTaskEngine) defer refreshCredsHandler.clearAcks() refreshCredsHandler.start() tester.payloadHandler.refreshHandler = refreshCredsHandler @@ -498,7 +496,7 @@ func TestPayloadBufferHandlerWithCredentials(t *testing.T) { }), ) - refreshCredsHandler := newRefreshCredentialsHandler(tester.ctx, clusterName, containerInstanceArn, tester.mockWsClient, tester.credentialsManager, tester.mockTaskEngine) + refreshCredsHandler := newRefreshCredentialsHandler(tester.ctx, testconst.ClusterName, testconst.ContainerInstanceARN, tester.mockWsClient, tester.credentialsManager, tester.mockTaskEngine) defer refreshCredsHandler.clearAcks() refreshCredsHandler.start() tester.payloadHandler.refreshHandler = refreshCredsHandler @@ -620,7 +618,7 @@ func TestAddPayloadTaskAddsExecutionRoles(t *testing.T) { tester.cancel() }), ) - refreshCredsHandler := newRefreshCredentialsHandler(tester.ctx, clusterName, containerInstanceArn, tester.mockWsClient, tester.credentialsManager, tester.mockTaskEngine) + refreshCredsHandler := newRefreshCredentialsHandler(tester.ctx, testconst.ClusterName, testconst.ContainerInstanceARN, tester.mockWsClient, tester.credentialsManager, tester.mockTaskEngine) defer refreshCredsHandler.clearAcks() refreshCredsHandler.start() diff --git a/agent/acs/handler/refresh_credentials_handler_test.go b/agent/acs/handler/refresh_credentials_handler_test.go index 184729ae1f2..375bce36fb0 100644 --- a/agent/acs/handler/refresh_credentials_handler_test.go +++ b/agent/acs/handler/refresh_credentials_handler_test.go @@ -27,6 +27,7 @@ import ( apitask "github.com/aws/amazon-ecs-agent/agent/api/task" mock_engine "github.com/aws/amazon-ecs-agent/agent/engine/mocks" "github.com/aws/amazon-ecs-agent/ecs-agent/acs/model/ecsacs" + "github.com/aws/amazon-ecs-agent/ecs-agent/acs/session/testconst" "github.com/aws/amazon-ecs-agent/ecs-agent/credentials" mock_wsclient "github.com/aws/amazon-ecs-agent/ecs-agent/wsclient/mock" @@ -283,7 +284,7 @@ func TestHandleRefreshMessageAckedWhenCredentialsUpdated(t *testing.T) { checkAndSetDomainlessGMSATaskExecutionRoleCredentialsImpl = checkAndSetDomainlessGMSATaskExecutionRoleCredentials }() - handler := newRefreshCredentialsHandler(ctx, clusterName, containerInstanceArn, mockWsClient, credentialsManager, taskEngine) + handler := newRefreshCredentialsHandler(ctx, testconst.ClusterName, testconst.ContainerInstanceARN, mockWsClient, credentialsManager, taskEngine) go handler.sendAcks() // test adding a credentials message without the MessageId field @@ -388,7 +389,7 @@ func TestRefreshCredentialsHandlerSendPendingAcks(t *testing.T) { mockWSClient := mock_wsclient.NewMockClientServer(ctrl) mockWSClient.EXPECT().MakeRequest(gomock.Any()).Return(nil).Times(1) - handler := newRefreshCredentialsHandler(ctx, clusterName, containerInstanceArn, mockWSClient, + handler := newRefreshCredentialsHandler(ctx, testconst.ClusterName, testconst.ContainerInstanceARN, mockWSClient, credentialsManager, taskEngine) wg := sync.WaitGroup{} @@ -436,7 +437,7 @@ func TestRefreshCredentialsHandler(t *testing.T) { // Return a task from the engine for GetTaskByArn taskEngine.EXPECT().GetTaskByArn(taskArn).Return(&apitask.Task{}, true) - handler := newRefreshCredentialsHandler(ctx, clusterName, containerInstanceArn, mockWsClient, credentialsManager, taskEngine) + handler := newRefreshCredentialsHandler(ctx, testconst.ClusterName, testconst.ContainerInstanceARN, mockWsClient, credentialsManager, taskEngine) go handler.start() handler.messageBuffer <- message diff --git a/agent/acs/handler/task_manifest_handler_test.go b/agent/acs/handler/task_manifest_handler_test.go index 225822e85a4..f1325d51876 100644 --- a/agent/acs/handler/task_manifest_handler_test.go +++ b/agent/acs/handler/task_manifest_handler_test.go @@ -27,6 +27,7 @@ import ( "github.com/aws/amazon-ecs-agent/agent/data" mock_engine "github.com/aws/amazon-ecs-agent/agent/engine/mocks" "github.com/aws/amazon-ecs-agent/ecs-agent/acs/model/ecsacs" + "github.com/aws/amazon-ecs-agent/ecs-agent/acs/session/testconst" mock_wsclient "github.com/aws/amazon-ecs-agent/ecs-agent/wsclient/mock" "github.com/aws/aws-sdk-go/aws" @@ -466,7 +467,7 @@ func TestManifestHandlerSequenceNumbers(t *testing.T) { mockWSClient := mock_wsclient.NewMockClientServer(ctrl) manifestMessageIDAccessor := &manifestMessageIDAccessor{} - newTaskManifest := newTaskManifestHandler(ctx, cluster, containerInstanceArn, mockWSClient, + newTaskManifest := newTaskManifestHandler(ctx, cluster, testconst.ContainerInstanceARN, mockWSClient, data.NewNoopClient(), taskEngine, aws.Int64(tc.inputSequenceNumber), manifestMessageIDAccessor) taskList := []*task.Task{ @@ -481,8 +482,8 @@ func TestManifestHandlerSequenceNumbers(t *testing.T) { message := &ecsacs.TaskManifestMessage{ MessageId: aws.String(eniMessageId), - ClusterArn: aws.String(clusterName), - ContainerInstanceArn: aws.String(containerInstanceArn), + ClusterArn: aws.String(testconst.ClusterName), + ContainerInstanceArn: aws.String(testconst.ContainerInstanceARN), Tasks: []*ecsacs.TaskIdentifier{ { DesiredStatus: aws.String(apitaskstatus.TaskStoppedString), @@ -561,7 +562,7 @@ func TestTaskManifestHandlerSendPendingTaskManifestMessageAck(t *testing.T) { mockWSClient := mock_wsclient.NewMockClientServer(ctrl) mockWSClient.EXPECT().MakeRequest(gomock.Any()).Return(nil).Times(1) manifestMessageIDAccessor := &manifestMessageIDAccessor{} - handler := newTaskManifestHandler(ctx, cluster, containerInstanceArn, mockWSClient, + handler := newTaskManifestHandler(ctx, cluster, testconst.ContainerInstanceARN, mockWSClient, data.NewNoopClient(), taskEngine, aws.Int64(testSeqNum), manifestMessageIDAccessor) wg := sync.WaitGroup{} @@ -598,7 +599,7 @@ func TestTaskManifestHandlerHandlePendingTaskStopVerificationAck(t *testing.T) { taskEngine := mock_engine.NewMockTaskEngine(ctrl) mockWSClient := mock_wsclient.NewMockClientServer(ctrl) manifestMessageIDAccessor := &manifestMessageIDAccessor{} - handler := newTaskManifestHandler(ctx, cluster, containerInstanceArn, mockWSClient, + handler := newTaskManifestHandler(ctx, cluster, testconst.ContainerInstanceARN, mockWSClient, data.NewNoopClient(), taskEngine, aws.Int64(testSeqNum), manifestMessageIDAccessor) wg := sync.WaitGroup{} diff --git a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/acs/session/attach_task_eni_responder.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/acs/session/attach_task_eni_responder.go new file mode 100644 index 00000000000..0f840e503da --- /dev/null +++ b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/acs/session/attach_task_eni_responder.go @@ -0,0 +1,158 @@ +// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. + +package session + +import ( + "fmt" + "time" + + "github.com/aws/aws-sdk-go/aws" + "github.com/pkg/errors" + + "github.com/aws/amazon-ecs-agent/ecs-agent/acs/model/ecsacs" + "github.com/aws/amazon-ecs-agent/ecs-agent/api/attachmentinfo" + apieni "github.com/aws/amazon-ecs-agent/ecs-agent/api/eni" + "github.com/aws/amazon-ecs-agent/ecs-agent/api/status" + "github.com/aws/amazon-ecs-agent/ecs-agent/logger" + "github.com/aws/amazon-ecs-agent/ecs-agent/logger/field" + "github.com/aws/amazon-ecs-agent/ecs-agent/wsclient" +) + +const ( + AttachTaskENIMessageName = "AttachTaskNetworkInterfacesMessage" +) + +// attachTaskENIResponder implements the wsclient.RequestResponder interface for responding +// to ecsacs.AttachTaskNetworkInterfacesMessage messages sent by ACS. +type attachTaskENIResponder struct { + eniHandler ENIHandler + respond wsclient.RespondFunc +} + +// NewAttachTaskENIResponder returns an instance of the attachENIHandler struct. +func NewAttachTaskENIResponder(eniHandler ENIHandler, responseSender wsclient.RespondFunc) *attachTaskENIResponder { + r := &attachTaskENIResponder{ + eniHandler: eniHandler, + } + r.respond = ResponseToACSSender(r.Name(), responseSender) + return r +} + +func (*attachTaskENIResponder) Name() string { return "attach task ENI responder" } + +func (r *attachTaskENIResponder) HandlerFunc() wsclient.RequestHandler { + return r.handleAttachMessage +} + +func (r *attachTaskENIResponder) handleAttachMessage(message *ecsacs.AttachTaskNetworkInterfacesMessage) { + logger.Debug(fmt.Sprintf("Handling %s", AttachTaskENIMessageName)) + receivedAt := time.Now() + + // Validate fields in the message. + if err := validateAttachTaskNetworkInterfacesMessage(message); err != nil { + logger.Error(fmt.Sprintf("Error validating %s received from ECS", AttachTaskENIMessageName), logger.Fields{ + field.Error: err, + }) + return + } + + // Handle ENIs in the message. + messageID := aws.StringValue(message.MessageId) + for _, mENI := range message.ElasticNetworkInterfaces { + expiresAt := receivedAt.Add(time.Duration(aws.Int64Value(message.WaitTimeoutMs)) * time.Millisecond) + go func(eni *ecsacs.ElasticNetworkInterface) { + err := r.eniHandler.HandleENIAttachment(&apieni.ENIAttachment{ + AttachmentInfo: attachmentinfo.AttachmentInfo{ + TaskARN: aws.StringValue(message.TaskArn), + AttachmentARN: aws.StringValue(eni.AttachmentArn), + Status: status.AttachmentNone, + ExpiresAt: expiresAt, + AttachStatusSent: false, + ClusterARN: aws.StringValue(message.ClusterArn), + ContainerInstanceARN: aws.StringValue(message.ContainerInstanceArn), + }, + AttachmentType: apieni.ENIAttachmentTypeTaskENI, + MACAddress: aws.StringValue(eni.MacAddress), + }) + if err != nil { + logger.Error(fmt.Sprintf("Unable to handle %s", AttachTaskENIMessageName), logger.Fields{ + field.MessageID: messageID, + field.Error: err, + }) + } + }(mENI) + } + + // Send ACK. + go func() { + err := r.respond(&ecsacs.AckRequest{ + Cluster: message.ClusterArn, + ContainerInstance: message.ContainerInstanceArn, + MessageId: message.MessageId, + }) + if err != nil { + logger.Warn(fmt.Sprintf("Error acknowledging %s", AttachTaskENIMessageName), logger.Fields{ + field.MessageID: messageID, + field.Error: err, + }) + } + }() +} + +// validateAttachTaskNetworkInterfacesMessage performs validation checks on the +// AttachTaskNetworkInterfacesMessage. +func validateAttachTaskNetworkInterfacesMessage(message *ecsacs.AttachTaskNetworkInterfacesMessage) error { + if message == nil { + return errors.Errorf("Message is empty") + } + + messageID := aws.StringValue(message.MessageId) + if messageID == "" { + return errors.Errorf("Message ID is not set") + } + + clusterArn := aws.StringValue(message.ClusterArn) + if clusterArn == "" { + return errors.Errorf("clusterArn is not set for message ID %s", messageID) + } + + containerInstanceArn := aws.StringValue(message.ContainerInstanceArn) + if containerInstanceArn == "" { + return errors.Errorf("containerInstanceArn is not set for message ID %s", messageID) + } + + taskArn := aws.StringValue(message.TaskArn) + if taskArn == "" { + return errors.Errorf("taskArn is not set for message ID %s", messageID) + } + + timeout := aws.Int64Value(message.WaitTimeoutMs) + if timeout <= 0 { + return errors.Errorf("Invalid timeout set for message ID %s", messageID) + } + + enis := message.ElasticNetworkInterfaces + if len(enis) < 1 { + return errors.Errorf("No ENIs for message ID %s", messageID) + } + + for _, eni := range enis { + err := apieni.ValidateTaskENI(eni) + if err != nil { + return err + } + } + + return nil +} diff --git a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/acs/session/testconst/test_const.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/acs/session/testconst/test_const.go new file mode 100644 index 00000000000..19beef17e1c --- /dev/null +++ b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/acs/session/testconst/test_const.go @@ -0,0 +1,9 @@ +package testconst + +// This file contains constants that are commonly used when testing ACS session and responders. These constants +// should only be called in test files. +const ( + ClusterName = "default" + ContainerInstanceARN = "instance" + TaskARN = "arn:aws:ecs:us-west-2:1234567890:task/test-cluster/abc" +) diff --git a/agent/vendor/modules.txt b/agent/vendor/modules.txt index 5d14d4c7f2f..45fd7a97f35 100644 --- a/agent/vendor/modules.txt +++ b/agent/vendor/modules.txt @@ -10,6 +10,7 @@ github.com/Microsoft/hcsshim/osversion github.com/aws/amazon-ecs-agent/ecs-agent/acs/client github.com/aws/amazon-ecs-agent/ecs-agent/acs/model/ecsacs github.com/aws/amazon-ecs-agent/ecs-agent/acs/session +github.com/aws/amazon-ecs-agent/ecs-agent/acs/session/testconst github.com/aws/amazon-ecs-agent/ecs-agent/api/appnet github.com/aws/amazon-ecs-agent/ecs-agent/api/attachmentinfo github.com/aws/amazon-ecs-agent/ecs-agent/api/eni diff --git a/ecs-agent/acs/session/attach_task_eni_responder.go b/ecs-agent/acs/session/attach_task_eni_responder.go new file mode 100644 index 00000000000..0f840e503da --- /dev/null +++ b/ecs-agent/acs/session/attach_task_eni_responder.go @@ -0,0 +1,158 @@ +// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. + +package session + +import ( + "fmt" + "time" + + "github.com/aws/aws-sdk-go/aws" + "github.com/pkg/errors" + + "github.com/aws/amazon-ecs-agent/ecs-agent/acs/model/ecsacs" + "github.com/aws/amazon-ecs-agent/ecs-agent/api/attachmentinfo" + apieni "github.com/aws/amazon-ecs-agent/ecs-agent/api/eni" + "github.com/aws/amazon-ecs-agent/ecs-agent/api/status" + "github.com/aws/amazon-ecs-agent/ecs-agent/logger" + "github.com/aws/amazon-ecs-agent/ecs-agent/logger/field" + "github.com/aws/amazon-ecs-agent/ecs-agent/wsclient" +) + +const ( + AttachTaskENIMessageName = "AttachTaskNetworkInterfacesMessage" +) + +// attachTaskENIResponder implements the wsclient.RequestResponder interface for responding +// to ecsacs.AttachTaskNetworkInterfacesMessage messages sent by ACS. +type attachTaskENIResponder struct { + eniHandler ENIHandler + respond wsclient.RespondFunc +} + +// NewAttachTaskENIResponder returns an instance of the attachENIHandler struct. +func NewAttachTaskENIResponder(eniHandler ENIHandler, responseSender wsclient.RespondFunc) *attachTaskENIResponder { + r := &attachTaskENIResponder{ + eniHandler: eniHandler, + } + r.respond = ResponseToACSSender(r.Name(), responseSender) + return r +} + +func (*attachTaskENIResponder) Name() string { return "attach task ENI responder" } + +func (r *attachTaskENIResponder) HandlerFunc() wsclient.RequestHandler { + return r.handleAttachMessage +} + +func (r *attachTaskENIResponder) handleAttachMessage(message *ecsacs.AttachTaskNetworkInterfacesMessage) { + logger.Debug(fmt.Sprintf("Handling %s", AttachTaskENIMessageName)) + receivedAt := time.Now() + + // Validate fields in the message. + if err := validateAttachTaskNetworkInterfacesMessage(message); err != nil { + logger.Error(fmt.Sprintf("Error validating %s received from ECS", AttachTaskENIMessageName), logger.Fields{ + field.Error: err, + }) + return + } + + // Handle ENIs in the message. + messageID := aws.StringValue(message.MessageId) + for _, mENI := range message.ElasticNetworkInterfaces { + expiresAt := receivedAt.Add(time.Duration(aws.Int64Value(message.WaitTimeoutMs)) * time.Millisecond) + go func(eni *ecsacs.ElasticNetworkInterface) { + err := r.eniHandler.HandleENIAttachment(&apieni.ENIAttachment{ + AttachmentInfo: attachmentinfo.AttachmentInfo{ + TaskARN: aws.StringValue(message.TaskArn), + AttachmentARN: aws.StringValue(eni.AttachmentArn), + Status: status.AttachmentNone, + ExpiresAt: expiresAt, + AttachStatusSent: false, + ClusterARN: aws.StringValue(message.ClusterArn), + ContainerInstanceARN: aws.StringValue(message.ContainerInstanceArn), + }, + AttachmentType: apieni.ENIAttachmentTypeTaskENI, + MACAddress: aws.StringValue(eni.MacAddress), + }) + if err != nil { + logger.Error(fmt.Sprintf("Unable to handle %s", AttachTaskENIMessageName), logger.Fields{ + field.MessageID: messageID, + field.Error: err, + }) + } + }(mENI) + } + + // Send ACK. + go func() { + err := r.respond(&ecsacs.AckRequest{ + Cluster: message.ClusterArn, + ContainerInstance: message.ContainerInstanceArn, + MessageId: message.MessageId, + }) + if err != nil { + logger.Warn(fmt.Sprintf("Error acknowledging %s", AttachTaskENIMessageName), logger.Fields{ + field.MessageID: messageID, + field.Error: err, + }) + } + }() +} + +// validateAttachTaskNetworkInterfacesMessage performs validation checks on the +// AttachTaskNetworkInterfacesMessage. +func validateAttachTaskNetworkInterfacesMessage(message *ecsacs.AttachTaskNetworkInterfacesMessage) error { + if message == nil { + return errors.Errorf("Message is empty") + } + + messageID := aws.StringValue(message.MessageId) + if messageID == "" { + return errors.Errorf("Message ID is not set") + } + + clusterArn := aws.StringValue(message.ClusterArn) + if clusterArn == "" { + return errors.Errorf("clusterArn is not set for message ID %s", messageID) + } + + containerInstanceArn := aws.StringValue(message.ContainerInstanceArn) + if containerInstanceArn == "" { + return errors.Errorf("containerInstanceArn is not set for message ID %s", messageID) + } + + taskArn := aws.StringValue(message.TaskArn) + if taskArn == "" { + return errors.Errorf("taskArn is not set for message ID %s", messageID) + } + + timeout := aws.Int64Value(message.WaitTimeoutMs) + if timeout <= 0 { + return errors.Errorf("Invalid timeout set for message ID %s", messageID) + } + + enis := message.ElasticNetworkInterfaces + if len(enis) < 1 { + return errors.Errorf("No ENIs for message ID %s", messageID) + } + + for _, eni := range enis { + err := apieni.ValidateTaskENI(eni) + if err != nil { + return err + } + } + + return nil +} diff --git a/ecs-agent/acs/session/attach_task_eni_responder_test.go b/ecs-agent/acs/session/attach_task_eni_responder_test.go new file mode 100644 index 00000000000..409960c342b --- /dev/null +++ b/ecs-agent/acs/session/attach_task_eni_responder_test.go @@ -0,0 +1,215 @@ +//go:build unit +// +build unit + +// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. + +package session + +import ( + "fmt" + "testing" + + "github.com/aws/aws-sdk-go/aws" + "github.com/stretchr/testify/assert" + + "github.com/aws/amazon-ecs-agent/ecs-agent/acs/model/ecsacs" + "github.com/aws/amazon-ecs-agent/ecs-agent/acs/session/testconst" + apieni "github.com/aws/amazon-ecs-agent/ecs-agent/api/eni" +) + +const ( + eniMessageId = "123" + randomMAC = "00:0a:95:9d:68:16" + waitTimeoutMillis = 1000 + + interfaceProtocol = "default" + gatewayIpv4 = "192.168.1.1/24" + ipv4Address = "ipv4" +) + +var testAttachTaskENIMessage = &ecsacs.AttachTaskNetworkInterfacesMessage{ + MessageId: aws.String(eniMessageId), + ClusterArn: aws.String(testconst.ClusterName), + ContainerInstanceArn: aws.String(testconst.ContainerInstanceARN), + ElasticNetworkInterfaces: []*ecsacs.ElasticNetworkInterface{ + { + Ec2Id: aws.String("1"), + MacAddress: aws.String(randomMAC), + InterfaceAssociationProtocol: aws.String(interfaceProtocol), + SubnetGatewayIpv4Address: aws.String(gatewayIpv4), + Ipv4Addresses: []*ecsacs.IPv4AddressAssignment{ + { + Primary: aws.Bool(true), + PrivateAddress: aws.String(ipv4Address), + }, + }, + }, + }, + TaskArn: aws.String(testconst.TaskARN), + WaitTimeoutMs: aws.Int64(waitTimeoutMillis), +} + +// TestAttachENIEmptyMessage checks the validator against an +// empty AttachTaskNetworkInterfacesMessage +func TestAttachENIEmptyMessage(t *testing.T) { + err := validateAttachTaskNetworkInterfacesMessage(nil) + assert.EqualError(t, err, "Message is empty") +} + +// TestAttachENIMessageWithNoMessageId checks the validator against an +// AttachTaskNetworkInterfacesMessage without a messageId +func TestAttachENIMessageWithNoMessageId(t *testing.T) { + tempMessageId := testAttachTaskENIMessage.MessageId + testAttachTaskENIMessage.MessageId = nil + + err := validateAttachTaskNetworkInterfacesMessage(testAttachTaskENIMessage) + assert.EqualError(t, err, "Message ID is not set") + + testAttachTaskENIMessage.MessageId = tempMessageId +} + +// TestAttachENIMessageWithNoClusterArn checks the validator against an +// AttachTaskNetworkInterfacesMessage without a ClusterArn +func TestAttachENIMessageWithNoClusterArn(t *testing.T) { + tempClusterArn := testAttachTaskENIMessage.ClusterArn + testAttachTaskENIMessage.ClusterArn = nil + + err := validateAttachTaskNetworkInterfacesMessage(testAttachTaskENIMessage) + assert.EqualError(t, err, fmt.Sprintf("clusterArn is not set for message ID %s", + aws.StringValue(testAttachTaskENIMessage.MessageId))) + + testAttachTaskENIMessage.ClusterArn = tempClusterArn +} + +// TestAttachENIMessageWithNoContainerInstanceArn checks the validator against an +// AttachTaskNetworkInterfacesMessage without a ContainerInstanceArn +func TestAttachENIMessageWithNoContainerInstanceArn(t *testing.T) { + tempContainerInstanceArn := testAttachTaskENIMessage.ContainerInstanceArn + testAttachTaskENIMessage.ContainerInstanceArn = nil + + err := validateAttachTaskNetworkInterfacesMessage(testAttachTaskENIMessage) + assert.EqualError(t, err, fmt.Sprintf("containerInstanceArn is not set for message ID %s", + aws.StringValue(testAttachTaskENIMessage.MessageId))) + + testAttachTaskENIMessage.ContainerInstanceArn = tempContainerInstanceArn +} + +// TestAttachENIMessageWithNoInterfaces checks the validator against an +// AttachTaskNetworkInterfacesMessage without any interface +func TestAttachENIMessageWithNoInterfaces(t *testing.T) { + tempENIs := testAttachTaskENIMessage.ElasticNetworkInterfaces + testAttachTaskENIMessage.ElasticNetworkInterfaces = nil + + err := validateAttachTaskNetworkInterfacesMessage(testAttachTaskENIMessage) + assert.EqualError(t, err, fmt.Sprintf("No ENIs for message ID %s", + aws.StringValue(testAttachTaskENIMessage.MessageId))) + + testAttachTaskENIMessage.ElasticNetworkInterfaces = tempENIs +} + +// TestAttachENIMessageWithMultipleInterfaceschecks checks the validator against an +// AttachTaskNetworkInterfacesMessage with multiple interfaces +func TestAttachENIMessageWithMultipleInterfaces(t *testing.T) { + testAttachTaskENIMessage.ElasticNetworkInterfaces = append(testAttachTaskENIMessage.ElasticNetworkInterfaces, + &ecsacs.ElasticNetworkInterface{ + Ec2Id: aws.String("2"), + MacAddress: aws.String(randomMAC), + InterfaceAssociationProtocol: aws.String(interfaceProtocol), + SubnetGatewayIpv4Address: aws.String(gatewayIpv4), + Ipv4Addresses: []*ecsacs.IPv4AddressAssignment{ + { + Primary: aws.Bool(true), + PrivateAddress: aws.String(ipv4Address), + }, + }, + }) + + err := validateAttachTaskNetworkInterfacesMessage(testAttachTaskENIMessage) + assert.NoError(t, err) + + // Remove appended ENI. + testAttachTaskENIMessage.ElasticNetworkInterfaces = + testAttachTaskENIMessage.ElasticNetworkInterfaces[:len(testAttachTaskENIMessage.ElasticNetworkInterfaces)-1] +} + +// TestAttachENIMessageWithInvalidNetworkDetails checks the validator against an +// AttachTaskNetworkInterfacesMessage with invalid network details +func TestAttachENIMessageWithInvalidNetworkDetails(t *testing.T) { + tempIpv4Addresses := testAttachTaskENIMessage.ElasticNetworkInterfaces[0].Ipv4Addresses + testAttachTaskENIMessage.ElasticNetworkInterfaces[0].Ipv4Addresses = nil + err := validateAttachTaskNetworkInterfacesMessage(testAttachTaskENIMessage) + assert.EqualError(t, err, "eni message validation: no ipv4 addresses in the message") + testAttachTaskENIMessage.ElasticNetworkInterfaces[0].Ipv4Addresses = tempIpv4Addresses + + tempSubnetGatewayIpv4Address := testAttachTaskENIMessage.ElasticNetworkInterfaces[0].SubnetGatewayIpv4Address + testAttachTaskENIMessage.ElasticNetworkInterfaces[0].SubnetGatewayIpv4Address = nil + err = validateAttachTaskNetworkInterfacesMessage(testAttachTaskENIMessage) + assert.EqualError(t, err, "eni message validation: no subnet gateway ipv4 address in the message") + invalidSubnetGatewayIpv4Address := aws.String("0.0.0.INVALID") + testAttachTaskENIMessage.ElasticNetworkInterfaces[0].SubnetGatewayIpv4Address = invalidSubnetGatewayIpv4Address + err = validateAttachTaskNetworkInterfacesMessage(testAttachTaskENIMessage) + assert.EqualError(t, err, fmt.Sprintf("eni message validation: invalid subnet gateway ipv4 address %s", + aws.StringValue(invalidSubnetGatewayIpv4Address))) + testAttachTaskENIMessage.ElasticNetworkInterfaces[0].SubnetGatewayIpv4Address = tempSubnetGatewayIpv4Address + + tempMacAddress := testAttachTaskENIMessage.ElasticNetworkInterfaces[0].MacAddress + testAttachTaskENIMessage.ElasticNetworkInterfaces[0].MacAddress = nil + err = validateAttachTaskNetworkInterfacesMessage(testAttachTaskENIMessage) + assert.EqualError(t, err, "eni message validation: empty eni mac address in the message") + testAttachTaskENIMessage.ElasticNetworkInterfaces[0].MacAddress = tempMacAddress + + tempEc2Id := testAttachTaskENIMessage.ElasticNetworkInterfaces[0].Ec2Id + testAttachTaskENIMessage.ElasticNetworkInterfaces[0].Ec2Id = nil + err = validateAttachTaskNetworkInterfacesMessage(testAttachTaskENIMessage) + assert.EqualError(t, err, "eni message validation: empty eni id in the message") + testAttachTaskENIMessage.ElasticNetworkInterfaces[0].Ec2Id = tempEc2Id + + tempInterfaceAssociationProtocol := testAttachTaskENIMessage.ElasticNetworkInterfaces[0].InterfaceAssociationProtocol + unsupportedInterfaceAssociationProtocol := aws.String("unsupported") + testAttachTaskENIMessage.ElasticNetworkInterfaces[0].InterfaceAssociationProtocol = unsupportedInterfaceAssociationProtocol + err = validateAttachTaskNetworkInterfacesMessage(testAttachTaskENIMessage) + assert.EqualError(t, err, fmt.Sprintf("invalid interface association protocol: %s", + aws.StringValue(unsupportedInterfaceAssociationProtocol))) + testAttachTaskENIMessage.ElasticNetworkInterfaces[0].InterfaceAssociationProtocol = + aws.String(apieni.VLANInterfaceAssociationProtocol) + err = validateAttachTaskNetworkInterfacesMessage(testAttachTaskENIMessage) + assert.EqualError(t, err, "vlan interface properties missing") + testAttachTaskENIMessage.ElasticNetworkInterfaces[0].InterfaceAssociationProtocol = tempInterfaceAssociationProtocol +} + +// TestAttachENIMessageWithMissingTaskArn checks the validator against an +// AttachTaskNetworkInterfacesMessage without a task ARN +func TestAttachENIMessageWithMissingTaskArn(t *testing.T) { + tempTaskArn := testAttachTaskENIMessage.TaskArn + testAttachTaskENIMessage.TaskArn = nil + + err := validateAttachTaskNetworkInterfacesMessage(testAttachTaskENIMessage) + assert.EqualError(t, err, fmt.Sprintf("taskArn is not set for message ID %s", + aws.StringValue(testAttachTaskENIMessage.MessageId))) + + testAttachTaskENIMessage.TaskArn = tempTaskArn +} + +// TestAttachENIMessageWithMissingTimeout checks the validator against an +// AttachTaskNetworkInterfacesMessage without a wait timeout +func TestAttachENIMessageWithMissingTimeout(t *testing.T) { + tempWaitTimeoutMs := testAttachTaskENIMessage.WaitTimeoutMs + testAttachTaskENIMessage.WaitTimeoutMs = nil + + err := validateAttachTaskNetworkInterfacesMessage(testAttachTaskENIMessage) + assert.EqualError(t, err, fmt.Sprintf("Invalid timeout set for message ID %s", + aws.StringValue(testAttachTaskENIMessage.MessageId))) + + testAttachTaskENIMessage.WaitTimeoutMs = tempWaitTimeoutMs +} diff --git a/ecs-agent/acs/session/testconst/test_const.go b/ecs-agent/acs/session/testconst/test_const.go new file mode 100644 index 00000000000..19beef17e1c --- /dev/null +++ b/ecs-agent/acs/session/testconst/test_const.go @@ -0,0 +1,9 @@ +package testconst + +// This file contains constants that are commonly used when testing ACS session and responders. These constants +// should only be called in test files. +const ( + ClusterName = "default" + ContainerInstanceARN = "instance" + TaskARN = "arn:aws:ecs:us-west-2:1234567890:task/test-cluster/abc" +)