diff --git a/flyteplugins/go/tasks/plugins/array/k8s/executor.go b/flyteplugins/go/tasks/plugins/array/k8s/executor.go index f664392fd9..3232584674 100644 --- a/flyteplugins/go/tasks/plugins/array/k8s/executor.go +++ b/flyteplugins/go/tasks/plugins/array/k8s/executor.go @@ -103,9 +103,6 @@ func (e Executor) Handle(ctx context.Context, tCtx core.TaskExecutionContext) (c case arrayCore.PhasePreLaunch: nextState = pluginState.SetPhase(arrayCore.PhaseLaunch, version+1).SetReason("Nothing to do in PreLaunch phase.") - case arrayCore.PhaseWaitingForResources: - fallthrough - case arrayCore.PhaseLaunch: // In order to maintain backwards compatibility with the state transitions // in the aws batch plugin. Forward to PhaseCheckingSubTasksExecutions where the launching @@ -113,6 +110,9 @@ func (e Executor) Handle(ctx context.Context, tCtx core.TaskExecutionContext) (c nextState = pluginState.SetPhase(arrayCore.PhaseCheckingSubTaskExecutions, version+1).SetReason("Nothing to do in Launch phase.") err = nil + case arrayCore.PhaseWaitingForResources: + fallthrough + case arrayCore.PhaseCheckingSubTaskExecutions: nextState, externalResources, err = LaunchAndCheckSubTasksState(ctx, tCtx, e.kubeClient, pluginConfig, tCtx.DataStore(), tCtx.OutputWriter().GetOutputPrefixPath(), tCtx.OutputWriter().GetRawOutputPrefix(), pluginState) diff --git a/flyteplugins/go/tasks/plugins/array/k8s/management.go b/flyteplugins/go/tasks/plugins/array/k8s/management.go index d6abaaf74b..12eea118cc 100644 --- a/flyteplugins/go/tasks/plugins/array/k8s/management.go +++ b/flyteplugins/go/tasks/plugins/array/k8s/management.go @@ -306,7 +306,7 @@ func LaunchAndCheckSubTasksState(ctx context.Context, tCtx core.TaskExecutionCon } _, version := currentState.GetPhase() - if phase == arrayCore.PhaseCheckingSubTaskExecutions { + if phase == arrayCore.PhaseCheckingSubTaskExecutions || phase == arrayCore.PhaseWaitingForResources { newSubTaskPhaseHash, err := newState.GetArrayStatus().HashCode() if err != nil { return currentState, externalResources, err @@ -316,7 +316,7 @@ func LaunchAndCheckSubTasksState(ctx context.Context, tCtx core.TaskExecutionCon version++ } - newState = newState.SetPhase(phase, version).SetReason("Task is still running") + newState = newState.SetPhase(arrayCore.PhaseCheckingSubTaskExecutions, version).SetReason("Task is still running") } else { newState = newState.SetPhase(phase, version+1) } diff --git a/flyteplugins/go/tasks/plugins/array/k8s/management_test.go b/flyteplugins/go/tasks/plugins/array/k8s/management_test.go index 9404bdfb72..2bd1d5eefe 100644 --- a/flyteplugins/go/tasks/plugins/array/k8s/management_test.go +++ b/flyteplugins/go/tasks/plugins/array/k8s/management_test.go @@ -300,7 +300,7 @@ func TestCheckSubTasksState(t *testing.T) { // validate results assert.Nil(t, err) p, _ := newState.GetPhase() - assert.Equal(t, arrayCore.PhaseWaitingForResources.String(), p.String()) + assert.Equal(t, arrayCore.PhaseCheckingSubTaskExecutions.String(), p.String()) resourceManager.AssertNumberOfCalls(t, "AllocateResource", subtaskCount) for _, subtaskPhaseIndex := range newState.GetArrayStatus().Detailed.GetItems() { assert.Equal(t, core.PhaseWaitingForResources, core.Phases[subtaskPhaseIndex])