diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/ray.go b/flyteplugins/go/tasks/plugins/k8s/ray/ray.go index 90388b46a5..95a87f4efa 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/ray.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/ray.go @@ -574,6 +574,10 @@ func (plugin rayJobResourceHandler) GetTaskPhase(ctx context.Context, pluginCont phaseInfo, err = pluginsCore.PhaseInfoRunning(pluginsCore.DefaultPhaseVersion, info), nil case rayv1.JobDeploymentStatusComplete: phaseInfo, err = pluginsCore.PhaseInfoSuccess(info), nil + case rayv1.JobDeploymentStatusSuspended: + phaseInfo, err = pluginsCore.PhaseInfoQueuedWithTaskInfo(time.Now(), pluginsCore.DefaultPhaseVersion, "Suspended", info), nil + case rayv1.JobDeploymentStatusSuspending: + phaseInfo, err = pluginsCore.PhaseInfoQueuedWithTaskInfo(time.Now(), pluginsCore.DefaultPhaseVersion, "Suspending", info), nil case rayv1.JobDeploymentStatusFailed: failInfo := fmt.Sprintf("Failed to run Ray job %s with error: [%s] %s", rayJob.Name, rayJob.Status.Reason, rayJob.Status.Message) phaseInfo, err = pluginsCore.PhaseInfoFailure(flyteerr.TaskFailedWithError, failInfo, info), nil diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go b/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go index 7b555e9f23..38b2f56785 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go @@ -755,7 +755,8 @@ func TestGetTaskPhase(t *testing.T) { {rayv1.JobDeploymentStatusRunning, pluginsCore.PhaseRunning, false}, {rayv1.JobDeploymentStatusComplete, pluginsCore.PhaseSuccess, false}, {rayv1.JobDeploymentStatusFailed, pluginsCore.PhasePermanentFailure, false}, - {rayv1.JobDeploymentStatusSuspended, pluginsCore.PhaseUndefined, true}, + {rayv1.JobDeploymentStatusSuspended, pluginsCore.PhaseQueued, false}, + {rayv1.JobDeploymentStatusSuspending, pluginsCore.PhaseQueued, false}, } for _, tc := range testCases {