From c143eaa4986e891ab06c9da43ff22b05458c26c5 Mon Sep 17 00:00:00 2001 From: Dan Rammer Date: Fri, 14 Apr 2023 10:04:32 -0500 Subject: [PATCH] Set PrimaryContainerKey annotation by default (#337) * added pod plugin optimizations Signed-off-by: Daniel Rammer * fixed unit tests Signed-off-by: Daniel Rammer --------- Signed-off-by: Daniel Rammer --- .../plugins/array/k8s/integration_test.go | 16 +++++++++++- .../plugins/array/k8s/management_test.go | 7 ++++++ .../tasks/plugins/k8s/pod/container_test.go | 25 +++++++++++++++++++ .../go/tasks/plugins/k8s/pod/plugin.go | 21 +++++++++++++--- 4 files changed, 65 insertions(+), 4 deletions(-) diff --git a/flyteplugins/go/tasks/plugins/array/k8s/integration_test.go b/flyteplugins/go/tasks/plugins/array/k8s/integration_test.go index 17edf2077f..1c1f8a9778 100644 --- a/flyteplugins/go/tasks/plugins/array/k8s/integration_test.go +++ b/flyteplugins/go/tasks/plugins/array/k8s/integration_test.go @@ -77,8 +77,18 @@ func advancePodPhases(ctx context.Context, store *storage.DataStore, outputWrite for _, pod := range podList.Items { newPhase := nextHappyPodPhase(pod.Status.Phase) + primaryContainerName := pod.Annotations["primary_container_name"] + if len(primaryContainerName) <= 0 { + primaryContainerName = "foo" + } pod.Status.ContainerStatuses = []v1.ContainerStatus{ - {ContainerID: "cont_123"}, + v1.ContainerStatus{ + Name: primaryContainerName, + ContainerID: primaryContainerName, + State: v1.ContainerState{ + Running: &v1.ContainerStateRunning{}, + }, + }, } if pod.Status.Phase != newPhase && newPhase == v1.PodSucceeded { @@ -95,6 +105,10 @@ func advancePodPhases(ctx context.Context, store *storage.DataStore, outputWrite } } + pod.Status.ContainerStatuses[0].State = v1.ContainerState{ + Terminated: &v1.ContainerStateTerminated{}, + } + ref := outputWriter.GetOutputPath() if idx > -1 { ref, err = store.ConstructReference(ctx, outputWriter.GetOutputPrefixPath(), strconv.Itoa(idx), "outputs.pb") diff --git a/flyteplugins/go/tasks/plugins/array/k8s/management_test.go b/flyteplugins/go/tasks/plugins/array/k8s/management_test.go index fc422132b8..a2dd715faf 100644 --- a/flyteplugins/go/tasks/plugins/array/k8s/management_test.go +++ b/flyteplugins/go/tasks/plugins/array/k8s/management_test.go @@ -164,6 +164,13 @@ func TestCheckSubTasksState(t *testing.T) { pod.Spec.Containers = append(pod.Spec.Containers, v1.Container{Name: "foo"}) pod.Status.Phase = v1.PodRunning + pod.Status.ContainerStatuses = []v1.ContainerStatus{ + v1.ContainerStatus{ + State: v1.ContainerState{ + Running: &v1.ContainerStateRunning{}, + }, + }, + } _ = fakeKubeClient.Create(ctx, pod) _ = fakeKubeCache.Create(ctx, pod) } diff --git a/flyteplugins/go/tasks/plugins/k8s/pod/container_test.go b/flyteplugins/go/tasks/plugins/k8s/pod/container_test.go index 96872f115c..0624c817c4 100644 --- a/flyteplugins/go/tasks/plugins/k8s/pod/container_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/pod/container_test.go @@ -156,6 +156,14 @@ func TestContainerTaskExecutor_GetTaskStatus(t *testing.T) { ctx := context.TODO() t.Run("running", func(t *testing.T) { j.Status.Phase = v1.PodRunning + j.Status.ContainerStatuses = []v1.ContainerStatus{ + { + State: v1.ContainerState{ + Running: &v1.ContainerStateRunning{}, + }, + }, + } + phaseInfo, err := DefaultPodPlugin.GetTaskPhase(ctx, taskCtx, j) assert.NoError(t, err) assert.Equal(t, pluginsCore.PhaseRunning, phaseInfo.Phase()) @@ -193,6 +201,23 @@ func TestContainerTaskExecutor_GetTaskStatus(t *testing.T) { assert.Equal(t, "Unschedulable", ec) }) + t.Run("successOptimized", func(t *testing.T) { + j.Status.Phase = v1.PodRunning + j.Status.ContainerStatuses = []v1.ContainerStatus{ + { + State: v1.ContainerState{ + Terminated: &v1.ContainerStateTerminated{ + ExitCode: 0, + }, + }, + }, + } + + phaseInfo, err := DefaultPodPlugin.GetTaskPhase(ctx, taskCtx, j) + assert.NoError(t, err) + assert.Equal(t, pluginsCore.PhaseSuccess, phaseInfo.Phase()) + }) + t.Run("success", func(t *testing.T) { j.Status.Phase = v1.PodSucceeded phaseInfo, err := DefaultPodPlugin.GetTaskPhase(ctx, taskCtx, j) diff --git a/flyteplugins/go/tasks/plugins/k8s/pod/plugin.go b/flyteplugins/go/tasks/plugins/k8s/pod/plugin.go index d1c509fb7c..ac0d8c24cf 100644 --- a/flyteplugins/go/tasks/plugins/k8s/pod/plugin.go +++ b/flyteplugins/go/tasks/plugins/k8s/pod/plugin.go @@ -121,8 +121,10 @@ func (p plugin) BuildResource(ctx context.Context, taskCtx pluginsCore.TaskExecu return nil, err } - // set primary container name if this is executed as a sidecar - if taskTemplate.Type == SidecarTaskType { + // set primaryContainerKey annotation if this is a Sidecar task or, as an optimization, if there is only a single + // container. this plugin marks the task complete if the primary Container is complete, so if there is only one + // container we can mark the task as complete before the Pod has been marked complete. + if taskTemplate.Type == SidecarTaskType || len(podSpec.Containers) == 1 { objectMeta.Annotations[flytek8s.PrimaryContainerKey] = primaryContainerName } @@ -187,7 +189,20 @@ func (plugin) GetTaskPhaseWithLogs(ctx context.Context, pluginContext k8s.Plugin default: primaryContainerName, exists := r.GetAnnotations()[flytek8s.PrimaryContainerKey] if !exists { - // if the primary container annotation dos not exist, then the task requires all containers + // if all of the containers in the Pod are complete, as an optimization, we can declare the task as + // succeeded rather than waiting for the Pod to be marked completed. + allSuccessfullyTerminated := len(pod.Status.ContainerStatuses) > 0 + for _, s := range pod.Status.ContainerStatuses { + if s.State.Waiting != nil || s.State.Running != nil || (s.State.Terminated != nil && s.State.Terminated.ExitCode != 0) { + allSuccessfullyTerminated = false + } + } + + if allSuccessfullyTerminated { + return flytek8s.DemystifySuccess(pod.Status, info) + } + + // if the primary container annotation does not exist, then the task requires all containers // to succeed to declare success. therefore, if the pod is not in one of the above states we // fallback to declaring the task as 'running'. phaseInfo = pluginsCore.PhaseInfoRunning(pluginsCore.DefaultPhaseVersion, &info)