diff --git a/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper.go b/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper.go index d9ab78db9e..d97f6e0d37 100755 --- a/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper.go +++ b/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper.go @@ -3,6 +3,7 @@ package flytek8s import ( "context" "fmt" + "strings" "time" "github.com/lyft/flytestdlib/logger" @@ -14,6 +15,7 @@ import ( ) const PodKind = "pod" +const OOMKilled = "OOMKilled" func ToK8sPodSpec(ctx context.Context, taskExecutionMetadata pluginsCore.TaskExecutionMetadata, taskReader pluginsCore.TaskReader, inputs io.InputReader, outputPaths io.OutputFilePaths) (*v1.PodSpec, error) { @@ -167,6 +169,22 @@ func DemystifyPending(status v1.PodStatus) (pluginsCore.PhaseInfo, error) { return pluginsCore.PhaseInfoQueued(time.Now(), pluginsCore.DefaultPhaseVersion, "Scheduling"), nil } +func DemystifySuccess(status v1.PodStatus, info pluginsCore.TaskInfo) (pluginsCore.PhaseInfo, error) { + for _, status := range status.ContainerStatuses { + if status.State.Terminated != nil && strings.Contains(status.State.Terminated.Reason, OOMKilled) { + return pluginsCore.PhaseInfoRetryableFailure("OOMKilled", + "Pod reported success despite being OOMKilled", &info), nil + } + } + for _, status := range status.InitContainerStatuses { + if status.State.Terminated != nil && strings.Contains(status.State.Terminated.Reason, OOMKilled) { + return pluginsCore.PhaseInfoRetryableFailure("OOMKilled", + "Pod reported success despite being OOMKilled", &info), nil + } + } + return pluginsCore.PhaseInfoSuccess(&info), nil +} + func ConvertPodFailureToError(status v1.PodStatus) (code, message string) { code = "UnknownError" message = "Container/Pod failed. No message received from kubernetes." diff --git a/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go b/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go index b659bf3d76..b63a8699ea 100755 --- a/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go +++ b/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go @@ -12,7 +12,7 @@ import ( "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" "github.com/stretchr/testify/assert" - "k8s.io/api/core/v1" + v1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/resource" metaV1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -338,6 +338,48 @@ func TestDemystifyPending(t *testing.T) { }) } +func TestDemystifySuccess(t *testing.T) { + t.Run("OOMKilled", func(t *testing.T) { + phaseInfo, err := DemystifySuccess(v1.PodStatus{ + ContainerStatuses: []v1.ContainerStatus{ + { + State: v1.ContainerState{ + Terminated: &v1.ContainerStateTerminated{ + Reason: OOMKilled, + }, + }, + }, + }, + }, pluginsCore.TaskInfo{}) + assert.Nil(t, err) + assert.Equal(t, pluginsCore.PhaseRetryableFailure, phaseInfo.Phase()) + assert.Equal(t, "OOMKilled", phaseInfo.Err().Code) + }) + + t.Run("InitContainer OOMKilled", func(t *testing.T) { + phaseInfo, err := DemystifySuccess(v1.PodStatus{ + InitContainerStatuses: []v1.ContainerStatus{ + { + State: v1.ContainerState{ + Terminated: &v1.ContainerStateTerminated{ + Reason: OOMKilled, + }, + }, + }, + }, + }, pluginsCore.TaskInfo{}) + assert.Nil(t, err) + assert.Equal(t, pluginsCore.PhaseRetryableFailure, phaseInfo.Phase()) + assert.Equal(t, "OOMKilled", phaseInfo.Err().Code) + }) + + t.Run("success", func(t *testing.T) { + phaseInfo, err := DemystifySuccess(v1.PodStatus{}, pluginsCore.TaskInfo{}) + assert.Nil(t, err) + assert.Equal(t, pluginsCore.PhaseSuccess, phaseInfo.Phase()) + }) +} + func TestConvertPodFailureToError(t *testing.T) { t.Run("unknown-error", func(t *testing.T) { code, _ := ConvertPodFailureToError(v1.PodStatus{}) diff --git a/flyteplugins/go/tasks/plugins/array/k8s/monitor.go b/flyteplugins/go/tasks/plugins/array/k8s/monitor.go index d6fe6e770c..c9226ffb16 100644 --- a/flyteplugins/go/tasks/plugins/array/k8s/monitor.go +++ b/flyteplugins/go/tasks/plugins/array/k8s/monitor.go @@ -162,7 +162,7 @@ func CheckPodStatus(ctx context.Context, client core.KubeClient, name k8sTypes.N } switch pod.Status.Phase { case v1.PodSucceeded: - return core.PhaseInfoSuccess(&taskInfo), nil + return flytek8s.DemystifySuccess(pod.Status, taskInfo) case v1.PodFailed: code, message := flytek8s.ConvertPodFailureToError(pod.Status) return core.PhaseInfoRetryableFailure(code, message, &taskInfo), nil diff --git a/flyteplugins/go/tasks/plugins/k8s/container/container.go b/flyteplugins/go/tasks/plugins/k8s/container/container.go index 91ecd3f2a9..e487cf1efe 100755 --- a/flyteplugins/go/tasks/plugins/k8s/container/container.go +++ b/flyteplugins/go/tasks/plugins/k8s/container/container.go @@ -6,7 +6,7 @@ import ( "github.com/lyft/flyteplugins/go/tasks/pluginmachinery" "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/flytek8s" - "k8s.io/api/core/v1" + v1 "k8s.io/api/core/v1" "github.com/lyft/flyteplugins/go/tasks/logs" pluginsCore "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core" @@ -37,7 +37,7 @@ func (containerTaskExecutor) GetTaskPhase(ctx context.Context, pluginContext k8s } switch pod.Status.Phase { case v1.PodSucceeded: - return pluginsCore.PhaseInfoSuccess(&info), nil + return flytek8s.DemystifySuccess(pod.Status, info) case v1.PodFailed: code, message := flytek8s.ConvertPodFailureToError(pod.Status) return pluginsCore.PhaseInfoRetryableFailure(code, message, &info), nil diff --git a/flyteplugins/go/tasks/plugins/k8s/sidecar/sidecar.go b/flyteplugins/go/tasks/plugins/k8s/sidecar/sidecar.go index be113c4c59..f86ddc726e 100755 --- a/flyteplugins/go/tasks/plugins/k8s/sidecar/sidecar.go +++ b/flyteplugins/go/tasks/plugins/k8s/sidecar/sidecar.go @@ -143,7 +143,7 @@ func (sidecarResourceHandler) GetTaskPhase(ctx context.Context, pluginContext k8 } switch pod.Status.Phase { case k8sv1.PodSucceeded: - return pluginsCore.PhaseInfoSuccess(&info), nil + return flytek8s.DemystifySuccess(pod.Status, info) case k8sv1.PodFailed: code, message := flytek8s.ConvertPodFailureToError(pod.Status) return pluginsCore.PhaseInfoRetryableFailure(code, message, &info), nil