diff --git a/go/tasks/plugins/array/awsbatch/executor.go b/go/tasks/plugins/array/awsbatch/executor.go index 00614da9e..e5172d31f 100644 --- a/go/tasks/plugins/array/awsbatch/executor.go +++ b/go/tasks/plugins/array/awsbatch/executor.go @@ -80,9 +80,7 @@ func (e Executor) Handle(ctx context.Context, tCtx core.TaskExecutionContext) (c pluginState, err = LaunchSubTasks(ctx, tCtx, e.jobStore, pluginConfig, pluginState, e.metrics) case arrayCore.PhaseCheckingSubTaskExecutions: - pluginState, err = CheckSubTasksState(ctx, tCtx.TaskExecutionMetadata(), - tCtx.OutputWriter().GetOutputPrefixPath(), tCtx.OutputWriter().GetRawOutputPrefix(), - e.jobStore, tCtx.DataStore(), pluginConfig, pluginState, e.metrics) + pluginState, err = CheckSubTasksState(ctx, tCtx, e.jobStore, pluginConfig, pluginState, e.metrics) case arrayCore.PhaseAssembleFinalOutput: pluginState.State, err = array.AssembleFinalOutputs(ctx, e.outputAssembler, tCtx, arrayCore.PhaseSuccess, version, pluginState.State) diff --git a/go/tasks/plugins/array/awsbatch/monitor.go b/go/tasks/plugins/array/awsbatch/monitor.go index 846aec420..4f09b911f 100644 --- a/go/tasks/plugins/array/awsbatch/monitor.go +++ b/go/tasks/plugins/array/awsbatch/monitor.go @@ -4,8 +4,8 @@ import ( "context" core2 "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" - "github.com/flyteorg/flytestdlib/storage" + "github.com/flyteorg/flyteplugins/go/tasks/errors" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/ioutils" "github.com/flyteorg/flyteplugins/go/tasks/plugins/array" @@ -34,19 +34,32 @@ func createSubJobList(count int) []*Job { return res } -func CheckSubTasksState(ctx context.Context, taskMeta core.TaskExecutionMetadata, outputPrefix, baseOutputSandbox storage.DataReference, jobStore *JobStore, - dataStore *storage.DataStore, cfg *config.Config, currentState *State, metrics ExecutorMetrics) (newState *State, err error) { +func CheckSubTasksState(ctx context.Context, tCtx core.TaskExecutionContext, jobStore *JobStore, + cfg *config.Config, currentState *State, metrics ExecutorMetrics) (newState *State, err error) { newState = currentState parentState := currentState.State - jobName := taskMeta.GetTaskExecutionID().GetGeneratedName() + jobName := tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName() job := jobStore.Get(jobName) + outputPrefix := tCtx.OutputWriter().GetOutputPrefixPath() + baseOutputSandbox := tCtx.OutputWriter().GetRawOutputPrefix() + dataStore := tCtx.DataStore() + // Check that the taskTemplate is valid + var taskTemplate *core2.TaskTemplate + taskTemplate, err = tCtx.TaskReader().Read(ctx) + if err != nil { + return nil, errors.Wrapf(errors.CorruptedPluginState, err, "Failed to read task template") + } else if taskTemplate == nil { + return nil, errors.Errorf(errors.BadTaskSpecification, "Required value not set, taskTemplate is nil") + } + retry := toRetryStrategy(ctx, toBackoffLimit(taskTemplate.Metadata), cfg.MinRetries, cfg.MaxRetries) + // If job isn't currently being monitored (recovering from a restart?), add it to the sync-cache and return if job == nil { logger.Info(ctx, "Job not found in cache, adding it. [%v]", jobName) _, err = jobStore.GetOrCreate(jobName, &Job{ ID: *currentState.ExternalJobID, - OwnerReference: taskMeta.GetOwnerID(), + OwnerReference: tCtx.TaskExecutionMetadata().GetOwnerID(), SubJobs: createSubJobList(currentState.GetExecutionArraySize()), }) @@ -108,6 +121,10 @@ func CheckSubTasksState(ctx context.Context, taskMeta core.TaskExecutionMetadata } else { msg.Collect(childIdx, "Job failed") } + + if subJob.Status.Phase == core.PhaseRetryableFailure && *retry.Attempts == int64(len(subJob.Attempts)) { + actualPhase = core.PhasePermanentFailure + } } else if subJob.Status.Phase.IsSuccess() { actualPhase, err = array.CheckTaskOutput(ctx, dataStore, outputPrefix, baseOutputSandbox, childIdx, originalIdx) if err != nil { diff --git a/go/tasks/plugins/array/awsbatch/monitor_test.go b/go/tasks/plugins/array/awsbatch/monitor_test.go index 621512b4a..a9d708377 100644 --- a/go/tasks/plugins/array/awsbatch/monitor_test.go +++ b/go/tasks/plugins/array/awsbatch/monitor_test.go @@ -3,6 +3,8 @@ package awsbatch import ( "testing" + "github.com/stretchr/testify/mock" + "github.com/flyteorg/flytestdlib/contextutils" "github.com/flyteorg/flytestdlib/promutils/labeled" @@ -11,6 +13,7 @@ import ( "github.com/flyteorg/flyteplugins/go/tasks/plugins/array/arraystatus" + flyteIdl "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" "github.com/aws/aws-sdk-go/aws/request" @@ -19,6 +22,7 @@ import ( arrayCore "github.com/flyteorg/flyteplugins/go/tasks/plugins/array/core" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/mocks" + ioMocks "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io/mocks" "github.com/flyteorg/flyteplugins/go/tasks/plugins/array/awsbatch/config" batchMocks "github.com/flyteorg/flyteplugins/go/tasks/plugins/array/awsbatch/mocks" "github.com/flyteorg/flytestdlib/utils" @@ -35,15 +39,39 @@ func init() { func TestCheckSubTasksState(t *testing.T) { ctx := context.Background() + tCtx := &mocks.TaskExecutionContext{} tID := &mocks.TaskExecutionID{} tID.OnGetGeneratedName().Return("generated-name") - tMeta := &mocks.TaskExecutionMetadata{} tMeta.OnGetOwnerID().Return(types.NamespacedName{ Namespace: "domain", Name: "name", }) tMeta.OnGetTaskExecutionID().Return(tID) + inMemDatastore, err := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope()) + assert.NoError(t, err) + + outputWriter := &ioMocks.OutputWriter{} + outputWriter.OnGetOutputPrefixPath().Return("") + outputWriter.OnGetRawOutputPrefix().Return("") + + taskReader := &mocks.TaskReader{} + task := &flyteIdl.TaskTemplate{ + Type: "test", + Target: &flyteIdl.TaskTemplate_Container{ + Container: &flyteIdl.Container{ + Command: []string{"command"}, + Args: []string{"{{.Input}}"}, + }, + }, + Metadata: &flyteIdl.TaskMetadata{Retries: &flyteIdl.RetryStrategy{Retries: 3}}, + } + taskReader.On("Read", mock.Anything).Return(task, nil) + + tCtx.OnOutputWriter().Return(outputWriter) + tCtx.OnTaskReader().Return(taskReader) + tCtx.OnDataStore().Return(inMemDatastore) + tCtx.OnTaskExecutionMetadata().Return(tMeta) t.Run("Not in cache", func(t *testing.T) { mBatchClient := batchMocks.NewMockAwsBatchClient() @@ -52,7 +80,7 @@ func TestCheckSubTasksState(t *testing.T) { utils.NewRateLimiter("", 10, 20)) jobStore := newJobsStore(t, batchClient) - newState, err := CheckSubTasksState(ctx, tMeta, "", "", jobStore, nil, &config.Config{}, &State{ + newState, err := CheckSubTasksState(ctx, tCtx, jobStore, &config.Config{}, &State{ State: &arrayCore.State{ CurrentPhase: arrayCore.PhaseCheckingSubTaskExecutions, ExecutionArraySize: 5, @@ -98,7 +126,7 @@ func TestCheckSubTasksState(t *testing.T) { assert.NoError(t, err) - newState, err := CheckSubTasksState(ctx, tMeta, "", "", jobStore, nil, &config.Config{}, &State{ + newState, err := CheckSubTasksState(ctx, tCtx, jobStore, &config.Config{}, &State{ State: &arrayCore.State{ CurrentPhase: arrayCore.PhaseCheckingSubTaskExecutions, ExecutionArraySize: 5, @@ -133,13 +161,10 @@ func TestCheckSubTasksState(t *testing.T) { assert.NoError(t, err) - inMemDatastore, err := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope()) - assert.NoError(t, err) - retryAttemptsArray, err := bitarray.NewCompactArray(1, bitarray.Item(1)) assert.NoError(t, err) - newState, err := CheckSubTasksState(ctx, tMeta, "", "", jobStore, inMemDatastore, &config.Config{}, &State{ + newState, err := CheckSubTasksState(ctx, tCtx, jobStore, &config.Config{}, &State{ State: &arrayCore.State{ CurrentPhase: arrayCore.PhaseCheckingSubTaskExecutions, ExecutionArraySize: 1, @@ -181,13 +206,10 @@ func TestCheckSubTasksState(t *testing.T) { assert.NoError(t, err) - inMemDatastore, err := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope()) - assert.NoError(t, err) - retryAttemptsArray, err := bitarray.NewCompactArray(2, bitarray.Item(1)) assert.NoError(t, err) - newState, err := CheckSubTasksState(ctx, tMeta, "", "", jobStore, inMemDatastore, &config.Config{}, &State{ + newState, err := CheckSubTasksState(ctx, tCtx, jobStore, &config.Config{}, &State{ State: &arrayCore.State{ CurrentPhase: arrayCore.PhaseCheckingSubTaskExecutions, ExecutionArraySize: 2, @@ -206,6 +228,49 @@ func TestCheckSubTasksState(t *testing.T) { assert.NoError(t, err) p, _ := newState.GetPhase() assert.Equal(t, arrayCore.PhaseCheckingSubTaskExecutions.String(), p.String()) + }) + + t.Run("retry limit exceeded", func(t *testing.T) { + mBatchClient := batchMocks.NewMockAwsBatchClient() + batchClient := NewCustomBatchClient(mBatchClient, "", "", + utils.NewRateLimiter("", 10, 20), + utils.NewRateLimiter("", 10, 20)) + + jobStore := newJobsStore(t, batchClient) + _, err := jobStore.GetOrCreate(tID.GetGeneratedName(), &Job{ + ID: "job-id", + Status: JobStatus{ + Phase: core.PhaseRunning, + }, + SubJobs: []*Job{ + {Status: JobStatus{Phase: core.PhaseRetryableFailure}, Attempts: []Attempt{{LogStream: "failed"}}}, + {Status: JobStatus{Phase: core.PhaseSuccess}}, + }, + }) + + assert.NoError(t, err) + + retryAttemptsArray, err := bitarray.NewCompactArray(2, bitarray.Item(1)) + assert.NoError(t, err) + newState, err := CheckSubTasksState(ctx, tCtx, jobStore, &config.Config{}, &State{ + State: &arrayCore.State{ + CurrentPhase: arrayCore.PhaseWriteToDiscoveryThenFail, + ExecutionArraySize: 2, + OriginalArraySize: 2, + OriginalMinSuccesses: 2, + ArrayStatus: arraystatus.ArrayStatus{ + Detailed: arrayCore.NewPhasesCompactArray(2), + }, + IndexesToCache: bitarray.NewBitSet(2), + RetryAttempts: retryAttemptsArray, + }, + ExternalJobID: refStr("job-id"), + JobDefinitionArn: "", + }, getAwsBatchExecutorMetrics(promutils.NewTestScope())) + + assert.NoError(t, err) + p, _ := newState.GetPhase() + assert.Equal(t, arrayCore.PhaseWriteToDiscoveryThenFail, p) }) } diff --git a/go/tasks/plugins/array/core/state_test.go b/go/tasks/plugins/array/core/state_test.go index 654979431..d7e9dfeed 100644 --- a/go/tasks/plugins/array/core/state_test.go +++ b/go/tasks/plugins/array/core/state_test.go @@ -334,6 +334,22 @@ func TestSummaryToPhase(t *testing.T) { core.PhaseSuccess: 10, }, }, + { + "FailedToRetry", + PhaseWriteToDiscoveryThenFail, + map[core.Phase]int64{ + core.PhaseSuccess: 5, + core.PhasePermanentFailure: 5, + }, + }, + { + "Retrying", + PhaseCheckingSubTaskExecutions, + map[core.Phase]int64{ + core.PhaseSuccess: 5, + core.PhaseRetryableFailure: 5, + }, + }, } for _, tt := range tests {