From 66eeb0d112bf58a2b3be0ec2ae605886d453bcd4 Mon Sep 17 00:00:00 2001
From: Michael Ye
Date: Tue, 17 Oct 2023 23:40:49 +0000
Subject: [PATCH] Fix loading CSI driver container from state if it exists

---
 agent/api/container/container.go      |  13 +++
 agent/api/container/container_test.go |  42 +++++++-
 agent/api/container/containertype.go  |   1 +
 agent/api/task/task.go                |  30 ++++++
 agent/api/task/task_test.go           | 102 ++++++++++++++++++++
 agent/ebs/watcher.go                  |   6 +-
 agent/ebs/watcher_test.go             | 133 ++++++++++++++++++++++++--
 agent/engine/data.go                  |   6 ++
 agent/engine/data_test.go             |  94 ++++++++++++++++++
 9 files changed, 417 insertions(+), 10 deletions(-)

diff --git a/agent/api/container/container.go b/agent/api/container/container.go
index 8c7af30d4ec..c77dab614e7 100644
--- a/agent/api/container/container.go
+++ b/agent/api/container/container.go
@@ -1508,3 +1508,16 @@ func (c *Container) GetContainerPortRangeMap() map[string]string {
 	defer c.lock.RUnlock()
 	return c.ContainerPortRangeMap
 }
+
+func (c *Container) IsManagedDaemonContainer() bool {
+	c.lock.RLock()
+	defer c.lock.RUnlock()
+	return c.Type == ContainerManagedDaemon
+}
+
+func (c *Container) GetImageName() string {
+	c.lock.RLock()
+	defer c.lock.RUnlock()
+	containerImage := strings.Split(c.Image, ":")[0]
+	return containerImage
+}
diff --git a/agent/api/container/container_test.go b/agent/api/container/container_test.go
index db34d75ed41..1fd96d6a6ff 100644
--- a/agent/api/container/container_test.go
+++ b/agent/api/container/container_test.go
@@ -130,13 +130,53 @@ func TestIsInternal(t *testing.T) {
 	}
 
 	for _, tc := range testCases {
-		t.Run(fmt.Sprintf("IsInternal shoukd return %t for %s", tc.internal, tc.container.String()),
+		t.Run(fmt.Sprintf("IsInternal should return %t for %s", tc.internal, tc.container.String()),
 			func(t *testing.T) {
 				assert.Equal(t, tc.internal, tc.container.IsInternal())
 			})
 	}
 }
 
+func TestIsManagedDaemonContainer(t *testing.T) {
+	testCases := []struct {
+		container       *Container
+		internal        bool
+		isManagedDaemon bool
+	}{
+		{&Container{}, false, false},
+		{&Container{Type: ContainerNormal, Image: "someImage:latest"}, false, false},
+		{&Container{Type: ContainerManagedDaemon, Image: "someImage:latest"}, true, true},
+	}
+
+	for _, tc := range testCases {
+		t.Run(fmt.Sprintf("IsManagedDaemonContainer should return %t for %s", tc.isManagedDaemon, tc.container.String()),
+			func(t *testing.T) {
+				assert.Equal(t, tc.internal, tc.container.IsInternal())
+				ok := tc.container.IsManagedDaemonContainer()
+				assert.Equal(t, tc.isManagedDaemon, ok)
+			})
+	}
+}
+
+func TestGetImageName(t *testing.T) {
+	testCases := []struct {
+		container *Container
+		imageName string
+	}{
+		{&Container{}, ""},
+		{&Container{Image: "someImage:latest"}, "someImage"},
+		{&Container{Image: "someImage"}, "someImage"},
+	}
+
+	for _, tc := range testCases {
+		t.Run(fmt.Sprintf("GetImageName should return %s for %s", tc.imageName, tc.container.String()),
+			func(t *testing.T) {
+				imageName := tc.container.GetImageName()
+				assert.Equal(t, tc.imageName, imageName)
+			})
+	}
+}
+
 // TestSetupExecutionRoleFlag tests whether or not the container appropriately
 // sets the flag for using execution roles
 func TestSetupExecutionRoleFlag(t *testing.T) {
diff --git a/agent/api/container/containertype.go b/agent/api/container/containertype.go
index 740f504229e..165692e35f8 100644
--- a/agent/api/container/containertype.go
+++ b/agent/api/container/containertype.go
@@ -53,6 +53,7 @@ var stringToContainerType = map[string]ContainerType{
 	"EMPTY_HOST_VOLUME": ContainerEmptyHostVolume,
 	"CNI_PAUSE":         ContainerCNIPause,
 	"NAMESPACE_PAUSE":   ContainerNamespacePause,
+	"MANAGED_DAEMON":    ContainerManagedDaemon,
 }
 
 // String converts the container type enum to a string
diff --git a/agent/api/task/task.go b/agent/api/task/task.go
index 17cfb13a10d..bdb2a20eca0 100644
--- a/agent/api/task/task.go
+++ b/agent/api/task/task.go
@@ -3687,3 +3687,33 @@ func (task *Task) HasActiveContainers() bool {
 	}
 	return false
 }
+
+// IsManagedDaemonTask will check if a task is a non-stopped managed daemon task
+// TODO: We'll probably want to also clean up all of the STOPPED managed daemon task from
+func (task *Task) IsManagedDaemonTask() (string, bool) {
+	task.lock.RLock()
+	defer task.lock.RUnlock()
+
+	// We'll want to obtain the last known non-stopped managed daemon task to be saved into our task engine.
+	// There can be an edge case where the task hasn't been progressed to RUNNING yet.
+	taskStatus := task.KnownStatusUnsafe
+	if !task.IsInternal || taskStatus >= apitaskstatus.TaskStopped {
+		return "", false
+	}
+
+	for _, c := range task.Containers {
+		if c.IsManagedDaemonContainer() {
+			imageName := c.GetImageName()
+			return imageName, true
+		}
+	}
+	return "", false
+}
+
+func (task *Task) IsRunning() bool {
+	task.lock.RLock()
+	defer task.lock.RUnlock()
+	taskStatus := task.KnownStatusUnsafe
+
+	return taskStatus == apitaskstatus.TaskRunning
+}
diff --git a/agent/api/task/task_test.go b/agent/api/task/task_test.go
index 5747f88af82..be439416dd3 100644
--- a/agent/api/task/task_test.go
+++ b/agent/api/task/task_test.go
@@ -5278,3 +5278,105 @@ func TestRemoveVolumeIndexOutOfBounds(t *testing.T) {
 	task.RemoveVolume(-1)
 	assert.Equal(t, len(task.Volumes), 1)
 }
+
+func TestIsManagedDaemonTask(t *testing.T) {
+
+	testTask1 := &Task{
+		Containers: []*apicontainer.Container{
+			{
+				Type:  apicontainer.ContainerManagedDaemon,
+				Image: "someImage:latest",
+			},
+		},
+		IsInternal:        true,
+		KnownStatusUnsafe: apitaskstatus.TaskRunning,
+	}
+
+	testTask2 := &Task{
+		Containers: []*apicontainer.Container{
+			{
+				Type:  apicontainer.ContainerNormal,
+				Image: "someImage",
+			},
+			{
+				Type:  apicontainer.ContainerNormal,
+				Image: "someImage:latest",
+			},
+		},
+		IsInternal:        false,
+		KnownStatusUnsafe: apitaskstatus.TaskRunning,
+	}
+
+	testTask3 := &Task{
+		Containers: []*apicontainer.Container{
+			{
+				Type:  apicontainer.ContainerManagedDaemon,
+				Image: "someImage:latest",
+			},
+		},
+		IsInternal:        true,
+		KnownStatusUnsafe: apitaskstatus.TaskStopped,
+	}
+
+	testTask4 := &Task{
+		Containers: []*apicontainer.Container{
+			{
+				Type:  apicontainer.ContainerManagedDaemon,
+				Image: "someImage:latest",
+			},
+		},
+		IsInternal:        true,
+		KnownStatusUnsafe: apitaskstatus.TaskCreated,
+	}
+
+	testTask5 := &Task{
+		Containers: []*apicontainer.Container{
+			{
+				Type:  apicontainer.ContainerNormal,
+				Image: "someImage",
+			},
+		},
+		IsInternal:        true,
+		KnownStatusUnsafe: apitaskstatus.TaskStopped,
+	}
+
+	testCases := []struct {
+		task            *Task
+		internal        bool
+		isManagedDaemon bool
+	}{
+		{
+			task:            testTask1,
+			internal:        true,
+			isManagedDaemon: true,
+		},
+		{
+			task:            testTask2,
+			internal:        false,
+			isManagedDaemon: false,
+		},
+		{
+			task:            testTask3,
+			internal:        true,
+			isManagedDaemon: false,
+		},
+		{
+			task:            testTask4,
+			internal:        true,
+			isManagedDaemon: true,
+		},
+		{
+			task:            testTask5,
+			internal:        true,
+			isManagedDaemon: false,
+		},
+	}
+
+	for _, tc := range testCases {
+		t.Run(fmt.Sprintf("IsManagedDaemonTask should return %t for %s", tc.isManagedDaemon, tc.task.String()),
+			func(t *testing.T) {
+				_, ok := tc.task.IsManagedDaemonTask()
+				assert.Equal(t, tc.isManagedDaemon, ok)
+			})
+	}
+}
diff --git a/agent/ebs/watcher.go b/agent/ebs/watcher.go
index 69c7b1c6dbb..a565e69b665 100644
--- a/agent/ebs/watcher.go
+++ b/agent/ebs/watcher.go
@@ -126,7 +126,9 @@ func (w *EBSWatcher) HandleEBSResourceAttachment(ebs *apiebs.ResourceAttachment)
 	}
 
 	// start EBS CSI Driver Managed Daemon
-	if runningCsiTask := w.taskEngine.GetDaemonTask(md.EbsCsiDriver); runningCsiTask != nil {
+	// We want to avoid creating a new CSI driver task if there's already one that's not been stopped.
+	// TODO: Include unit tests
+	if runningCsiTask := w.taskEngine.GetDaemonTask(md.EbsCsiDriver); runningCsiTask != nil && !runningCsiTask.GetKnownStatus().Terminal() {
 		log.Debugf("engine ebs CSI driver is running with taskID: %v", runningCsiTask.GetID())
 	} else {
 		if ebsCsiDaemonManager, ok := w.taskEngine.GetDaemonManagers()[md.EbsCsiDriver]; ok {
@@ -191,7 +193,7 @@ func (w *EBSWatcher) stageVolumeEBS(volID, deviceName string) error {
 	}
 	attachmentMountPath := ebsAttachment.GetAttachmentProperties(apiebs.SourceVolumeHostPathKey)
 	hostPath := filepath.Join(hostMountDir, attachmentMountPath)
-	filesystemType := ebsAttachment.GetAttachmentProperties(apiebs.FileSystemTypeName)
+	filesystemType := ebsAttachment.GetAttachmentProperties(apiebs.FileSystemKey)
 	// CSI NodeStage stub required fields
 	stubSecrets := make(map[string]string)
 	stubVolumeContext := make(map[string]string)
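The guard added to HandleEBSResourceAttachment above reuses the CSI driver daemon task recorded in the task engine only while that task is still non-terminal; once it has stopped, a fresh daemon task is created instead. A minimal sketch of that condition in isolation, assuming the *apitask.Task value returned by GetDaemonTask (the helper name below is hypothetical and not part of this patch):

	package ebs

	import apitask "github.com/aws/amazon-ecs-agent/agent/api/task"

	// shouldReuseCsiDriverTask restates the check above: reuse the stored EBS CSI
	// driver daemon task only if one exists and its known status is not terminal
	// (i.e. the daemon task has not stopped).
	func shouldReuseCsiDriverTask(existing *apitask.Task) bool {
		return existing != nil && !existing.GetKnownStatus().Terminal()
	}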
diff --git a/agent/ebs/watcher_test.go b/agent/ebs/watcher_test.go
index e454bb9dd2d..5e083220823 100644
--- a/agent/ebs/watcher_test.go
+++ b/agent/ebs/watcher_test.go
@@ -19,18 +19,24 @@ package ebs
 import (
 	"context"
 	"fmt"
+	"path/filepath"
 	"sync"
 	"testing"
 	"time"
 
+	apitask "github.com/aws/amazon-ecs-agent/agent/api/task"
 	"github.com/aws/amazon-ecs-agent/agent/engine"
 	"github.com/aws/amazon-ecs-agent/agent/engine/dockerstate"
 	mock_engine "github.com/aws/amazon-ecs-agent/agent/engine/mocks"
+	statechange "github.com/aws/amazon-ecs-agent/agent/statechange"
 	taskresourcevolume "github.com/aws/amazon-ecs-agent/agent/taskresource/volume"
 	"github.com/aws/amazon-ecs-agent/ecs-agent/acs/session/testconst"
 	"github.com/aws/amazon-ecs-agent/ecs-agent/api/attachment"
 	apiebs "github.com/aws/amazon-ecs-agent/ecs-agent/api/attachment/resource"
 	mock_ebs_discovery "github.com/aws/amazon-ecs-agent/ecs-agent/api/attachment/resource/mocks"
+	apitaskstatus "github.com/aws/amazon-ecs-agent/ecs-agent/api/task/status"
+	csi "github.com/aws/amazon-ecs-agent/ecs-agent/csiclient"
+	mock_csiclient "github.com/aws/amazon-ecs-agent/ecs-agent/csiclient/mocks"
 	md "github.com/aws/amazon-ecs-agent/ecs-agent/manageddaemon"
 
 	"github.com/golang/mock/gomock"
@@ -47,7 +53,7 @@ const (
 
 // newTestEBSWatcher creates a new EBSWatcher object for testing
 func newTestEBSWatcher(ctx context.Context, agentState dockerstate.TaskEngineState,
-	discoveryClient apiebs.EBSDiscovery, taskEngine engine.TaskEngine) *EBSWatcher {
+	discoveryClient apiebs.EBSDiscovery, taskEngine engine.TaskEngine, csiClient csi.CSIClient) *EBSWatcher {
 	derivedContext, cancel := context.WithCancel(ctx)
 	return &EBSWatcher{
 		ctx:             derivedContext,
@@ -55,6 +61,7 @@ func newTestEBSWatcher(ctx context.Context, agentState dockerstate.TaskEngineSta
 		agentState:      agentState,
 		discoveryClient: discoveryClient,
 		taskEngine:      taskEngine,
+		csiClient:       csiClient,
 	}
 }
 
@@ -72,6 +79,18 @@ func TestHandleEBSAttachmentHappyCase(t *testing.T) {
 	mockTaskEngine.EXPECT().GetDaemonTask(md.EbsCsiDriver).Return(nil).AnyTimes()
 	mockTaskEngine.EXPECT().GetDaemonManagers().Return(nil).AnyTimes()
 
+	mockCsiClient := mock_csiclient.NewMockCSIClient(mockCtrl)
+	mockCsiClient.EXPECT().NodeStageVolume(gomock.Any(),
+		taskresourcevolume.TestVolumeId,
+		gomock.Any(),
+		filepath.Join(hostMountDir, taskresourcevolume.TestSourceVolumeHostPath),
+		taskresourcevolume.TestFileSystem,
+		gomock.Any(),
+		gomock.Any(),
+		gomock.Any(),
+		gomock.Any(),
+		gomock.Any()).Return(nil).AnyTimes()
+
 	testAttachmentProperties := map[string]string{
 		apiebs.DeviceNameKey: taskresourcevolume.TestDeviceName,
 		apiebs.VolumeIdKey:   taskresourcevolume.TestVolumeId,
@@ -94,7 +113,7 @@ func TestHandleEBSAttachmentHappyCase(t *testing.T) {
 		AttachmentProperties: testAttachmentProperties,
 		AttachmentType:       apiebs.EBSTaskAttach,
 	}
-	watcher := newTestEBSWatcher(ctx, taskEngineState, mockDiscoveryClient, mockTaskEngine)
+	watcher := newTestEBSWatcher(ctx, taskEngineState, mockDiscoveryClient, mockTaskEngine, mockCsiClient)
 	var wg sync.WaitGroup
 	wg.Add(1)
 	mockDiscoveryClient.EXPECT().ConfirmEBSVolumeIsAttached(taskresourcevolume.TestDeviceName, taskresourcevolume.TestVolumeId).
@@ -141,6 +160,8 @@ func TestHandleExpiredEBSAttachment(t *testing.T) {
 	mockTaskEngine.EXPECT().GetDaemonTask(md.EbsCsiDriver).Return(nil).AnyTimes()
 	mockTaskEngine.EXPECT().GetDaemonManagers().Return(nil).AnyTimes()
 
+	mockCsiClient := mock_csiclient.NewMockCSIClient(mockCtrl)
+
 	testAttachmentProperties := map[string]string{
 		apiebs.DeviceNameKey: taskresourcevolume.TestDeviceName,
 		apiebs.VolumeIdKey:   taskresourcevolume.TestVolumeId,
@@ -163,7 +184,7 @@ func TestHandleExpiredEBSAttachment(t *testing.T) {
 		AttachmentProperties: testAttachmentProperties,
 		AttachmentType:       apiebs.EBSTaskAttach,
 	}
-	watcher := newTestEBSWatcher(ctx, taskEngineState, mockDiscoveryClient, mockTaskEngine)
+	watcher := newTestEBSWatcher(ctx, taskEngineState, mockDiscoveryClient, mockTaskEngine, mockCsiClient)
 
 	err := watcher.HandleEBSResourceAttachment(ebsAttachment)
 	assert.Error(t, err)
@@ -186,6 +207,18 @@ func TestHandleDuplicateEBSAttachment(t *testing.T) {
 	mockTaskEngine.EXPECT().GetDaemonTask(md.EbsCsiDriver).Return(nil).AnyTimes()
 	mockTaskEngine.EXPECT().GetDaemonManagers().Return(nil).AnyTimes()
 
+	mockCsiClient := mock_csiclient.NewMockCSIClient(mockCtrl)
+	mockCsiClient.EXPECT().NodeStageVolume(gomock.Any(),
+		taskresourcevolume.TestVolumeId,
+		gomock.Any(),
+		filepath.Join(hostMountDir, taskresourcevolume.TestSourceVolumeHostPath),
+		taskresourcevolume.TestFileSystem,
+		gomock.Any(),
+		gomock.Any(),
+		gomock.Any(),
+		gomock.Any(),
+		gomock.Any()).Return(nil).AnyTimes()
+
 	expiresAt := time.Now().Add(time.Millisecond * testconst.WaitTimeoutMillis)
 
 	testAttachmentProperties1 := map[string]string{
@@ -232,7 +265,7 @@ func TestHandleDuplicateEBSAttachment(t *testing.T) {
 		AttachmentProperties: testAttachmentProperties2,
 		AttachmentType:       apiebs.EBSTaskAttach,
 	}
-	watcher := newTestEBSWatcher(ctx, taskEngineState, mockDiscoveryClient, mockTaskEngine)
+	watcher := newTestEBSWatcher(ctx, taskEngineState, mockDiscoveryClient, mockTaskEngine, mockCsiClient)
 	var wg sync.WaitGroup
 	wg.Add(1)
 	mockDiscoveryClient.EXPECT().ConfirmEBSVolumeIsAttached(taskresourcevolume.TestDeviceName, taskresourcevolume.TestVolumeId).
@@ -277,6 +310,7 @@ func TestHandleInvalidTypeEBSAttachment(t *testing.T) {
 	mockTaskEngine := mock_engine.NewMockTaskEngine(mockCtrl)
 	mockTaskEngine.EXPECT().GetDaemonTask(md.EbsCsiDriver).Return(nil).AnyTimes()
 	mockTaskEngine.EXPECT().GetDaemonManagers().Return(nil).AnyTimes()
+	mockCsiClient := mock_csiclient.NewMockCSIClient(mockCtrl)
 
 	testAttachmentProperties := map[string]string{
 		apiebs.DeviceNameKey: taskresourcevolume.TestDeviceName,
@@ -300,7 +334,7 @@ func TestHandleInvalidTypeEBSAttachment(t *testing.T) {
 		AttachmentProperties: testAttachmentProperties,
 		AttachmentType:       "InvalidResourceType",
 	}
-	watcher := newTestEBSWatcher(ctx, taskEngineState, mockDiscoveryClient, mockTaskEngine)
+	watcher := newTestEBSWatcher(ctx, taskEngineState, mockDiscoveryClient, mockTaskEngine, mockCsiClient)
 
 	watcher.HandleResourceAttachment(ebsAttachment)
 
@@ -323,6 +357,7 @@ func TestHandleEBSAckTimeout(t *testing.T) {
 	mockTaskEngine := mock_engine.NewMockTaskEngine(mockCtrl)
 	mockTaskEngine.EXPECT().GetDaemonTask(md.EbsCsiDriver).Return(nil).AnyTimes()
 	mockTaskEngine.EXPECT().GetDaemonManagers().Return(nil).AnyTimes()
+	mockCsiClient := mock_csiclient.NewMockCSIClient(mockCtrl)
 
 	testAttachmentProperties := map[string]string{
 		apiebs.DeviceNameKey: taskresourcevolume.TestDeviceName,
@@ -345,7 +380,7 @@ func TestHandleEBSAckTimeout(t *testing.T) {
 		},
 		AttachmentProperties: testAttachmentProperties,
 	}
-	watcher := newTestEBSWatcher(ctx, taskEngineState, mockDiscoveryClient, mockTaskEngine)
+	watcher := newTestEBSWatcher(ctx, taskEngineState, mockDiscoveryClient, mockTaskEngine, mockCsiClient)
 
 	watcher.HandleResourceAttachment(ebsAttachment)
 	time.Sleep(time.Millisecond * testconst.WaitTimeoutMillis * 2)
@@ -367,8 +402,9 @@ func TestHandleMismatchEBSAttachment(t *testing.T) {
 	mockTaskEngine := mock_engine.NewMockTaskEngine(mockCtrl)
 	mockTaskEngine.EXPECT().GetDaemonTask(md.EbsCsiDriver).Return(nil).AnyTimes()
 	mockTaskEngine.EXPECT().GetDaemonManagers().Return(nil).AnyTimes()
+	mockCsiClient := mock_csiclient.NewMockCSIClient(mockCtrl)
 
-	watcher := newTestEBSWatcher(ctx, taskEngineState, mockDiscoveryClient, mockTaskEngine)
+	watcher := newTestEBSWatcher(ctx, taskEngineState, mockDiscoveryClient, mockTaskEngine, mockCsiClient)
 
 	testAttachmentProperties := map[string]string{
 		apiebs.DeviceNameKey: taskresourcevolume.TestDeviceName,
@@ -414,4 +450,87 @@ func TestHandleMismatchEBSAttachment(t *testing.T) {
 	assert.ErrorIs(t, ebsAttachment.GetError(), apiebs.ErrInvalidVolumeID)
 }
 
+func TestHandleEBSAttachmentWithExistingCSIDriverTask(t *testing.T) {
+	mockCtrl := gomock.NewController(t)
+	defer mockCtrl.Finish()
+
+	ctx := context.Background()
+	taskEngineState := dockerstate.NewTaskEngineState()
+	mockDiscoveryClient := mock_ebs_discovery.NewMockEBSDiscovery(mockCtrl)
+	mockTaskEngine := mock_engine.NewMockTaskEngine(mockCtrl)
+	mockTaskEngine.EXPECT().GetDaemonTask(md.EbsCsiDriver).Return(&apitask.Task{
+		Arn:               "arn:aws:ecs:us-east-1:012345678910:task/some-task-id",
+		KnownStatusUnsafe: apitaskstatus.TaskRunning,
+	}).AnyTimes()
+	mockTaskEngine.EXPECT().StateChangeEvents().Return(make(chan statechange.Event)).AnyTimes()
+
+	mockCsiClient := mock_csiclient.NewMockCSIClient(mockCtrl)
+	mockCsiClient.EXPECT().NodeStageVolume(gomock.Any(),
+		taskresourcevolume.TestVolumeId,
+		gomock.Any(),
+		filepath.Join(hostMountDir, taskresourcevolume.TestSourceVolumeHostPath),
+		taskresourcevolume.TestFileSystem,
+		gomock.Any(),
+		gomock.Any(),
+		gomock.Any(),
+		gomock.Any(),
+		gomock.Any()).Return(nil).AnyTimes()
+
+	testAttachmentProperties := map[string]string{
+		apiebs.DeviceNameKey:           taskresourcevolume.TestDeviceName,
+		apiebs.VolumeIdKey:             taskresourcevolume.TestVolumeId,
+		apiebs.VolumeNameKey:           taskresourcevolume.TestVolumeName,
+		apiebs.SourceVolumeHostPathKey: taskresourcevolume.TestSourceVolumeHostPath,
+		apiebs.FileSystemKey:           taskresourcevolume.TestFileSystem,
+		apiebs.VolumeSizeGibKey:        taskresourcevolume.TestVolumeSizeGib,
+	}
+
+	expiresAt := time.Now().Add(time.Millisecond * testconst.WaitTimeoutMillis)
+	ebsAttachment := &apiebs.ResourceAttachment{
+		AttachmentInfo: attachment.AttachmentInfo{
+			TaskARN:              taskARN,
+			TaskClusterARN:       taskClusterARN,
+			ContainerInstanceARN: containerInstanceARN,
+			ExpiresAt:            expiresAt,
+			Status:               attachment.AttachmentNone,
+			AttachmentARN:        resourceAttachmentARN,
+		},
+		AttachmentProperties: testAttachmentProperties,
+		AttachmentType:       apiebs.EBSTaskAttach,
+	}
+	watcher := newTestEBSWatcher(ctx, taskEngineState, mockDiscoveryClient, mockTaskEngine, mockCsiClient)
+	var wg sync.WaitGroup
+	wg.Add(1)
+	mockDiscoveryClient.EXPECT().ConfirmEBSVolumeIsAttached(taskresourcevolume.TestDeviceName, taskresourcevolume.TestVolumeId).
+		Do(func(deviceName, volumeID string) {
+			wg.Done()
+		}).
+		Return(taskresourcevolume.TestDeviceName, nil).
+		MinTimes(1)
+
+	err := watcher.HandleEBSResourceAttachment(ebsAttachment)
+	assert.NoError(t, err)
+
+	// Instead of starting the EBS watcher, we'll be mocking a tick of the EBS watcher's scan ticker.
+	// Otherwise, the watcher will continue to run forever and the test will panic.
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+		pendingEBS := watcher.agentState.GetAllPendingEBSAttachmentsWithKey()
+		if len(pendingEBS) > 0 {
+			foundVolumes := apiebs.ScanEBSVolumes(pendingEBS, watcher.discoveryClient)
+			watcher.StageAll(foundVolumes)
+			watcher.NotifyAttached(foundVolumes)
+		}
+	}()
+
+	wg.Wait()
+
+	assert.Len(t, taskEngineState.(*dockerstate.DockerTaskEngineState).GetAllEBSAttachments(), 1)
+	ebsAttachment, ok := taskEngineState.(*dockerstate.DockerTaskEngineState).GetEBSByVolumeId(taskresourcevolume.TestVolumeId)
+	require.True(t, ok)
+	assert.True(t, ebsAttachment.IsAttached())
+	// assert.True(t, ebsAttachment.IsSent())
+}
+
 // TODO add StageAll test
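The engine change that follows is the heart of the fix: when the agent reloads its persisted state, a managed daemon task that is still internal and non-stopped is re-registered under its image name, so a later GetDaemonTask("ebs-csi-driver") lookup (such as the watcher check above) finds it instead of launching a second CSI driver task. A self-contained sketch of that decision logic with simplified stand-in types (all names below are hypothetical; the real implementation is the data.go and task.go hunks in this patch):

	package main

	import "fmt"

	// persistedTask is a stand-in for apitask.Task with just the fields the
	// reload decision needs.
	type persistedTask struct {
		isInternal  bool
		stopped     bool
		daemonImage string // e.g. "ebs-csi-driver"; empty for ordinary tasks
	}

	// isManagedDaemonTask mirrors Task.IsManagedDaemonTask: only an internal,
	// non-stopped task that carries a managed daemon container is re-registered.
	func (t persistedTask) isManagedDaemonTask() (string, bool) {
		if !t.isInternal || t.stopped || t.daemonImage == "" {
			return "", false
		}
		return t.daemonImage, true
	}

	func main() {
		persisted := []persistedTask{
			{isInternal: true, daemonImage: "ebs-csi-driver"},                // running daemon: restored
			{isInternal: true, stopped: true, daemonImage: "ebs-csi-driver"}, // stopped daemon: skipped
			{isInternal: false},                                              // ordinary task: skipped
		}
		daemonTasks := map[string]persistedTask{} // what engine.SetDaemonTask populates
		for _, t := range persisted {
			if image, ok := t.isManagedDaemonTask(); ok {
				daemonTasks[image] = t
			}
		}
		fmt.Println(len(daemonTasks)) // 1: only the non-stopped daemon task is re-registered
	}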
diff --git a/agent/engine/data.go b/agent/engine/data.go
index 1b3146a5a21..c994ea219c4 100644
--- a/agent/engine/data.go
+++ b/agent/engine/data.go
@@ -56,6 +56,12 @@ func (engine *DockerTaskEngine) loadTasks() error {
 	for _, task := range tasks {
 		engine.state.AddTask(task)
 
+		// TODO: Will need to clean up all of the STOPPED managed daemon tasks
+		md, ok := task.IsManagedDaemonTask()
+		if ok {
+			engine.SetDaemonTask(md, task)
+		}
+
 		// Populate ip <-> task mapping if task has a local ip. This mapping is needed for serving v2 task metadata.
 		if ip := task.GetLocalIPAddress(); ip != "" {
 			engine.state.AddTaskIPAddress(ip, task.Arn)
diff --git a/agent/engine/data_test.go b/agent/engine/data_test.go
index 36665545e4b..064cd405f71 100644
--- a/agent/engine/data_test.go
+++ b/agent/engine/data_test.go
@@ -28,6 +28,7 @@ import (
 	"github.com/aws/amazon-ecs-agent/agent/engine/image"
 	"github.com/aws/amazon-ecs-agent/ecs-agent/api/attachment"
 	apicontainerstatus "github.com/aws/amazon-ecs-agent/ecs-agent/api/container/status"
+	apitaskstatus "github.com/aws/amazon-ecs-agent/ecs-agent/api/task/status"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 )
@@ -51,6 +52,13 @@ var (
 		TaskARNUnsafe:     testTaskARN,
 		KnownStatusUnsafe: apicontainerstatus.ContainerPulled,
 	}
+	testManagedDaemonContainer = &apicontainer.Container{
+		Name:              "ecs-managed-" + testContainerName,
+		Image:             "ebs-csi-driver",
+		TaskARNUnsafe:     testTaskARN,
+		Type:              apicontainer.ContainerManagedDaemon,
+		KnownStatusUnsafe: apicontainerstatus.ContainerRunning,
+	}
 	testDockerContainer = &apicontainer.DockerContainer{
 		DockerID:  testDockerID,
 		Container: testContainer,
@@ -59,6 +67,10 @@ var (
 		DockerID:  testDockerID,
 		Container: testPulledContainer,
 	}
+	testManagedDaemonDockerContainer = &apicontainer.DockerContainer{
+		DockerID:  testDockerID,
+		Container: testManagedDaemonContainer,
+	}
 	testTask = &apitask.Task{
 		Arn:                  testTaskARN,
 		Containers:           []*apicontainer.Container{testContainer},
@@ -69,6 +81,20 @@ var (
 		Containers:           []*apicontainer.Container{testContainer, testPulledContainer},
 		LocalIPAddressUnsafe: testTaskIP,
 	}
+	testTaskWithManagedDaemonContainer = &apitask.Task{
+		Arn:                  testTaskARN,
+		Containers:           []*apicontainer.Container{testManagedDaemonContainer},
+		LocalIPAddressUnsafe: testTaskIP,
+		IsInternal:           true,
+		KnownStatusUnsafe:    apitaskstatus.TaskRunning,
+	}
+	testStoppedTaskWithManagedDaemonContainer = &apitask.Task{
+		Arn:                  testTaskARN,
+		Containers:           []*apicontainer.Container{testManagedDaemonContainer},
+		LocalIPAddressUnsafe: testTaskIP,
+		IsInternal:           true,
+		KnownStatusUnsafe:    apitaskstatus.TaskStopped,
+	}
 	testImageState = &image.ImageState{
 		Image:         testImage,
 		PullSucceeded: false,
@@ -135,6 +161,74 @@ func TestLoadState(t *testing.T) {
 	assert.Equal(t, testTaskARN, arn)
 }
 
+func TestLoadStateWithManagedDaemon(t *testing.T) {
+	dataClient := newTestDataClient(t)
+
+	engine := &DockerTaskEngine{
+		state:       dockerstate.NewTaskEngineState(),
+		dataClient:  dataClient,
+		daemonTasks: make(map[string]*apitask.Task),
+	}
+
+	require.NoError(t, dataClient.SaveTask(testTaskWithManagedDaemonContainer))
+	require.NoError(t, dataClient.SaveDockerContainer(testManagedDaemonDockerContainer))
+	require.NoError(t, dataClient.SaveENIAttachment(testENIAttachment))
+	require.NoError(t, dataClient.SaveImageState(testImageState))
+
+	require.NoError(t, engine.LoadState())
+	task, ok := engine.state.TaskByArn(testTaskARN)
+	assert.True(t, ok)
+	assert.Equal(t, apicontainerstatus.ContainerRunning, task.Containers[0].GetKnownStatus())
+	_, ok = engine.state.ContainerByID(testDockerID)
+	assert.True(t, ok)
+	assert.Len(t, engine.state.AllImageStates(), 1)
+	assert.Len(t, engine.state.AllENIAttachments(), 1)
+
+	// Check ip <-> task arn mapping is loaded in state.
+	ip, ok := engine.state.GetIPAddressByTaskARN(testTaskARN)
+	require.True(t, ok)
+	assert.Equal(t, testTaskIP, ip)
+	arn, ok := engine.state.GetTaskByIPAddress(testTaskIP)
+	require.True(t, ok)
+	assert.Equal(t, testTaskARN, arn)
+
+	assert.NotNil(t, engine.GetDaemonTask("ebs-csi-driver"))
+}
+
+func TestLoadStateWithStoppedManagedDaemon(t *testing.T) {
+	dataClient := newTestDataClient(t)
+
+	engine := &DockerTaskEngine{
+		state:       dockerstate.NewTaskEngineState(),
+		dataClient:  dataClient,
+		daemonTasks: make(map[string]*apitask.Task),
+	}
+
+	require.NoError(t, dataClient.SaveTask(testStoppedTaskWithManagedDaemonContainer))
+	require.NoError(t, dataClient.SaveDockerContainer(testManagedDaemonDockerContainer))
+	require.NoError(t, dataClient.SaveENIAttachment(testENIAttachment))
+	require.NoError(t, dataClient.SaveImageState(testImageState))
+
+	require.NoError(t, engine.LoadState())
+	task, ok := engine.state.TaskByArn(testTaskARN)
+	assert.True(t, ok)
+	assert.Equal(t, apicontainerstatus.ContainerRunning, task.Containers[0].GetKnownStatus())
+	_, ok = engine.state.ContainerByID(testDockerID)
+	assert.True(t, ok)
+	assert.Len(t, engine.state.AllImageStates(), 1)
+	assert.Len(t, engine.state.AllENIAttachments(), 1)
+
+	// Check ip <-> task arn mapping is loaded in state.
+	ip, ok := engine.state.GetIPAddressByTaskARN(testTaskARN)
+	require.True(t, ok)
+	assert.Equal(t, testTaskIP, ip)
+	arn, ok := engine.state.GetTaskByIPAddress(testTaskIP)
+	require.True(t, ok)
+	assert.Equal(t, testTaskARN, arn)
+
+	assert.Nil(t, engine.GetDaemonTask("ebs-csi-driver"))
+}
+
 func TestSaveState(t *testing.T) {
 	dataClient := newTestDataClient(t)