Skip to content

Commit

Permalink
Fix regression in the Ray plugin (flyteorg#4239)
Browse files (browse the repository at this point in the history)
---------

Signed-off-by: Kevin Su <[email protected]>
Signed-off-by: squiishyy <[email protected]>
  • Loading branch information
pingsutw authored and squiishyy committed Oct 18, 2023
1 parent ec38454 commit 2f3fe6c
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 14 deletions.
35 changes: 26 additions & 9 deletions flyteplugins/go/tasks/plugins/k8s/ray/ray.go
Original file line number | Diff line number | Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"strconv"
"strings"
"time"

rayv1alpha1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1alpha1"
v1 "k8s.io/api/core/v1"
Expand Down Expand Up @@ -398,19 +399,35 @@ func (plugin rayJobResourceHandler) GetTaskPhase(ctx context.Context, pluginCont
return pluginsCore.PhaseInfoUndefined, err
}

switch rayJob.Status.JobStatus {
case rayv1alpha1.JobStatusPending:
return pluginsCore.PhaseInfoInitializing(rayJob.Status.StartTime.Time, pluginsCore.DefaultPhaseVersion, "job is pending", info), nil
case rayv1alpha1.JobStatusFailed:
reason := fmt.Sprintf("Failed to create Ray job: %s", rayJob.Name)
if len(rayJob.Status.JobDeploymentStatus) == 0 {
return pluginsCore.PhaseInfoQueued(time.Now(), pluginsCore.DefaultPhaseVersion, "Scheduling"), nil
}

// Kuberay creates a Ray cluster first, and then submits a Ray job to the cluster
switch rayJob.Status.JobDeploymentStatus {
case rayv1alpha1.JobDeploymentStatusInitializing:
return pluginsCore.PhaseInfoInitializing(rayJob.CreationTimestamp.Time, pluginsCore.DefaultPhaseVersion, "cluster is creating", info), nil
case rayv1alpha1.JobDeploymentStatusFailedToGetOrCreateRayCluster:
reason := fmt.Sprintf("Failed to create Ray cluster %s with error: %s", rayJob.Name, rayJob.Status.Message)
return pluginsCore.PhaseInfoFailure(flyteerr.TaskFailedWithError, reason, info), nil
case rayv1alpha1.JobStatusSucceeded:
return pluginsCore.PhaseInfoSuccess(info), nil
case rayv1alpha1.JobStatusRunning:
case rayv1alpha1.JobDeploymentStatusFailedJobDeploy:
reason := fmt.Sprintf("Failed to submit Ray job %s with error: %s", rayJob.Name, rayJob.Status.Message)
return pluginsCore.PhaseInfoFailure(flyteerr.TaskFailedWithError, reason, info), nil
case rayv1alpha1.JobDeploymentStatusWaitForDashboard:
return pluginsCore.PhaseInfoRunning(pluginsCore.DefaultPhaseVersion, info), nil
case rayv1alpha1.JobDeploymentStatusRunning, rayv1alpha1.JobDeploymentStatusComplete:
switch rayJob.Status.JobStatus {
case rayv1alpha1.JobStatusFailed:
reason := fmt.Sprintf("Failed to run Ray job %s with error: %s", rayJob.Name, rayJob.Status.Message)
return pluginsCore.PhaseInfoFailure(flyteerr.TaskFailedWithError, reason, info), nil
case rayv1alpha1.JobStatusSucceeded:
return pluginsCore.PhaseInfoSuccess(info), nil
case rayv1alpha1.JobStatusPending, rayv1alpha1.JobStatusRunning:
return pluginsCore.PhaseInfoRunning(pluginsCore.DefaultPhaseVersion, info), nil
}
}

return pluginsCore.PhaseInfoQueued(rayJob.CreationTimestamp.Time, pluginsCore.DefaultPhaseVersion, "JobCreated"), nil
return pluginsCore.PhaseInfoUndefined, nil
}

func init() {
Expand Down
17 changes: 12 additions & 5 deletions flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go
Original file line number | Diff line number | Diff line change
Expand Up @@ -420,19 +420,26 @@ func TestGetTaskPhase(t *testing.T) {

testCases := []struct {
rayJobPhase rayv1alpha1.JobStatus
rayClusterPhase rayv1alpha1.JobDeploymentStatus
expectedCorePhase pluginsCore.Phase
}{
{"", pluginsCore.PhaseQueued},
{rayv1alpha1.JobStatusPending, pluginsCore.PhaseInitializing},
{rayv1alpha1.JobStatusRunning, pluginsCore.PhaseRunning},
{rayv1alpha1.JobStatusSucceeded, pluginsCore.PhaseSuccess},
{rayv1alpha1.JobStatusFailed, pluginsCore.PhasePermanentFailure},
{"", rayv1alpha1.JobDeploymentStatusInitializing, pluginsCore.PhaseInitializing},
{rayv1alpha1.JobStatusPending, rayv1alpha1.JobDeploymentStatusFailedToGetOrCreateRayCluster, pluginsCore.PhasePermanentFailure},
{rayv1alpha1.JobStatusPending, rayv1alpha1.JobDeploymentStatusWaitForDashboard, pluginsCore.PhaseRunning},
{rayv1alpha1.JobStatusPending, rayv1alpha1.JobDeploymentStatusFailedJobDeploy, pluginsCore.PhasePermanentFailure},
{rayv1alpha1.JobStatusPending, rayv1alpha1.JobDeploymentStatusRunning, pluginsCore.PhaseRunning},
{rayv1alpha1.JobStatusPending, rayv1alpha1.JobDeploymentStatusFailedToGetJobStatus, pluginsCore.PhaseUndefined},
{rayv1alpha1.JobStatusRunning, rayv1alpha1.JobDeploymentStatusRunning, pluginsCore.PhaseRunning},
{rayv1alpha1.JobStatusFailed, rayv1alpha1.JobDeploymentStatusRunning, pluginsCore.PhasePermanentFailure},
{rayv1alpha1.JobStatusSucceeded, rayv1alpha1.JobDeploymentStatusRunning, pluginsCore.PhaseSuccess},
{rayv1alpha1.JobStatusSucceeded, rayv1alpha1.JobDeploymentStatusComplete, pluginsCore.PhaseSuccess},
}

for _, tc := range testCases {
t.Run("TestGetTaskPhase_"+string(tc.rayJobPhase), func(t *testing.T) {
rayObject := &rayv1alpha1.RayJob{}
rayObject.Status.JobStatus = tc.rayJobPhase
rayObject.Status.JobDeploymentStatus = tc.rayClusterPhase
startTime := metav1.NewTime(time.Now())
rayObject.Status.StartTime = &startTime
phaseInfo, err := rayJobResourceHandler.GetTaskPhase(ctx, pluginCtx, rayObject)
Expand Down

0 comments on commit 2f3fe6c

Please sign in to comment.