Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

The status of the AWS batch job should become failed once the retry limit is exceeded #291

Merged
merged 15 commits into from
Dec 1, 2022
4 changes: 1 addition & 3 deletions go/tasks/plugins/array/awsbatch/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,7 @@ func (e Executor) Handle(ctx context.Context, tCtx core.TaskExecutionContext) (c
pluginState, err = LaunchSubTasks(ctx, tCtx, e.jobStore, pluginConfig, pluginState, e.metrics)

case arrayCore.PhaseCheckingSubTaskExecutions:
pluginState, err = CheckSubTasksState(ctx, tCtx.TaskExecutionMetadata(),
tCtx.OutputWriter().GetOutputPrefixPath(), tCtx.OutputWriter().GetRawOutputPrefix(),
e.jobStore, tCtx.DataStore(), pluginConfig, pluginState, e.metrics)
pluginState, err = CheckSubTasksState(ctx, tCtx, e.jobStore, pluginConfig, pluginState, e.metrics)

case arrayCore.PhaseAssembleFinalOutput:
pluginState.State, err = array.AssembleFinalOutputs(ctx, e.outputAssembler, tCtx, arrayCore.PhaseSuccess, version, pluginState.State)
Expand Down
27 changes: 22 additions & 5 deletions go/tasks/plugins/array/awsbatch/monitor.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ import (
"context"

core2 "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
"github.com/flyteorg/flytestdlib/storage"

"github.com/flyteorg/flyteplugins/go/tasks/errors"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/ioutils"
"github.com/flyteorg/flyteplugins/go/tasks/plugins/array"
Expand Down Expand Up @@ -34,19 +34,32 @@ func createSubJobList(count int) []*Job {
return res
}

func CheckSubTasksState(ctx context.Context, taskMeta core.TaskExecutionMetadata, outputPrefix, baseOutputSandbox storage.DataReference, jobStore *JobStore,
dataStore *storage.DataStore, cfg *config.Config, currentState *State, metrics ExecutorMetrics) (newState *State, err error) {
func CheckSubTasksState(ctx context.Context, tCtx core.TaskExecutionContext, jobStore *JobStore,
cfg *config.Config, currentState *State, metrics ExecutorMetrics) (newState *State, err error) {
newState = currentState
parentState := currentState.State
jobName := taskMeta.GetTaskExecutionID().GetGeneratedName()
jobName := tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName()
job := jobStore.Get(jobName)
outputPrefix := tCtx.OutputWriter().GetOutputPrefixPath()
baseOutputSandbox := tCtx.OutputWriter().GetRawOutputPrefix()
dataStore := tCtx.DataStore()
// Check that the taskTemplate is valid
var taskTemplate *core2.TaskTemplate
taskTemplate, err = tCtx.TaskReader().Read(ctx)
if err != nil {
return nil, errors.Wrapf(errors.CorruptedPluginState, err, "Failed to read task template")
} else if taskTemplate == nil {
return nil, errors.Errorf(errors.BadTaskSpecification, "Required value not set, taskTemplate is nil")
}
retry := toRetryStrategy(ctx, toBackoffLimit(taskTemplate.Metadata), cfg.MinRetries, cfg.MaxRetries)

// If job isn't currently being monitored (recovering from a restart?), add it to the sync-cache and return
if job == nil {
logger.Info(ctx, "Job not found in cache, adding it. [%v]", jobName)

_, err = jobStore.GetOrCreate(jobName, &Job{
ID: *currentState.ExternalJobID,
OwnerReference: taskMeta.GetOwnerID(),
OwnerReference: tCtx.TaskExecutionMetadata().GetOwnerID(),
SubJobs: createSubJobList(currentState.GetExecutionArraySize()),
})

Expand Down Expand Up @@ -108,6 +121,10 @@ func CheckSubTasksState(ctx context.Context, taskMeta core.TaskExecutionMetadata
} else {
msg.Collect(childIdx, "Job failed")
}

if subJob.Status.Phase == core.PhaseRetryableFailure && *retry.Attempts == int64(len(subJob.Attempts)) {
actualPhase = core.PhasePermanentFailure
}
} else if subJob.Status.Phase.IsSuccess() {
actualPhase, err = array.CheckTaskOutput(ctx, dataStore, outputPrefix, baseOutputSandbox, childIdx, originalIdx)
if err != nil {
Expand Down
87 changes: 76 additions & 11 deletions go/tasks/plugins/array/awsbatch/monitor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package awsbatch
import (
"testing"

"github.com/stretchr/testify/mock"

"github.com/flyteorg/flytestdlib/contextutils"
"github.com/flyteorg/flytestdlib/promutils/labeled"

Expand All @@ -11,6 +13,7 @@ import (

"github.com/flyteorg/flyteplugins/go/tasks/plugins/array/arraystatus"

flyteIdl "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core"

"github.com/aws/aws-sdk-go/aws/request"
Expand All @@ -19,6 +22,7 @@ import (
arrayCore "github.com/flyteorg/flyteplugins/go/tasks/plugins/array/core"

"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/mocks"
ioMocks "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io/mocks"
"github.com/flyteorg/flyteplugins/go/tasks/plugins/array/awsbatch/config"
batchMocks "github.com/flyteorg/flyteplugins/go/tasks/plugins/array/awsbatch/mocks"
"github.com/flyteorg/flytestdlib/utils"
Expand All @@ -35,15 +39,39 @@ func init() {

func TestCheckSubTasksState(t *testing.T) {
ctx := context.Background()
tCtx := &mocks.TaskExecutionContext{}
tID := &mocks.TaskExecutionID{}
tID.OnGetGeneratedName().Return("generated-name")

tMeta := &mocks.TaskExecutionMetadata{}
tMeta.OnGetOwnerID().Return(types.NamespacedName{
Namespace: "domain",
Name: "name",
})
tMeta.OnGetTaskExecutionID().Return(tID)
inMemDatastore, err := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope())
assert.NoError(t, err)

outputWriter := &ioMocks.OutputWriter{}
outputWriter.OnGetOutputPrefixPath().Return("")
outputWriter.OnGetRawOutputPrefix().Return("")

taskReader := &mocks.TaskReader{}
task := &flyteIdl.TaskTemplate{
Type: "test",
Target: &flyteIdl.TaskTemplate_Container{
Container: &flyteIdl.Container{
Command: []string{"command"},
Args: []string{"{{.Input}}"},
},
},
Metadata: &flyteIdl.TaskMetadata{Retries: &flyteIdl.RetryStrategy{Retries: 3}},
}
taskReader.On("Read", mock.Anything).Return(task, nil)

tCtx.OnOutputWriter().Return(outputWriter)
tCtx.OnTaskReader().Return(taskReader)
tCtx.OnDataStore().Return(inMemDatastore)
tCtx.OnTaskExecutionMetadata().Return(tMeta)

t.Run("Not in cache", func(t *testing.T) {
mBatchClient := batchMocks.NewMockAwsBatchClient()
Expand All @@ -52,7 +80,7 @@ func TestCheckSubTasksState(t *testing.T) {
utils.NewRateLimiter("", 10, 20))

jobStore := newJobsStore(t, batchClient)
newState, err := CheckSubTasksState(ctx, tMeta, "", "", jobStore, nil, &config.Config{}, &State{
newState, err := CheckSubTasksState(ctx, tCtx, jobStore, &config.Config{}, &State{
State: &arrayCore.State{
CurrentPhase: arrayCore.PhaseCheckingSubTaskExecutions,
ExecutionArraySize: 5,
Expand Down Expand Up @@ -98,7 +126,7 @@ func TestCheckSubTasksState(t *testing.T) {

assert.NoError(t, err)

newState, err := CheckSubTasksState(ctx, tMeta, "", "", jobStore, nil, &config.Config{}, &State{
newState, err := CheckSubTasksState(ctx, tCtx, jobStore, &config.Config{}, &State{
State: &arrayCore.State{
CurrentPhase: arrayCore.PhaseCheckingSubTaskExecutions,
ExecutionArraySize: 5,
Expand Down Expand Up @@ -133,13 +161,10 @@ func TestCheckSubTasksState(t *testing.T) {

assert.NoError(t, err)

inMemDatastore, err := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope())
assert.NoError(t, err)

retryAttemptsArray, err := bitarray.NewCompactArray(1, bitarray.Item(1))
assert.NoError(t, err)

newState, err := CheckSubTasksState(ctx, tMeta, "", "", jobStore, inMemDatastore, &config.Config{}, &State{
newState, err := CheckSubTasksState(ctx, tCtx, jobStore, &config.Config{}, &State{
State: &arrayCore.State{
CurrentPhase: arrayCore.PhaseCheckingSubTaskExecutions,
ExecutionArraySize: 1,
Expand Down Expand Up @@ -181,13 +206,10 @@ func TestCheckSubTasksState(t *testing.T) {

assert.NoError(t, err)

inMemDatastore, err := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope())
assert.NoError(t, err)

retryAttemptsArray, err := bitarray.NewCompactArray(2, bitarray.Item(1))
assert.NoError(t, err)

newState, err := CheckSubTasksState(ctx, tMeta, "", "", jobStore, inMemDatastore, &config.Config{}, &State{
newState, err := CheckSubTasksState(ctx, tCtx, jobStore, &config.Config{}, &State{
State: &arrayCore.State{
CurrentPhase: arrayCore.PhaseCheckingSubTaskExecutions,
ExecutionArraySize: 2,
Expand All @@ -206,6 +228,49 @@ func TestCheckSubTasksState(t *testing.T) {
assert.NoError(t, err)
p, _ := newState.GetPhase()
assert.Equal(t, arrayCore.PhaseCheckingSubTaskExecutions.String(), p.String())
})

t.Run("retry limit exceeded", func(t *testing.T) {
mBatchClient := batchMocks.NewMockAwsBatchClient()
batchClient := NewCustomBatchClient(mBatchClient, "", "",
utils.NewRateLimiter("", 10, 20),
utils.NewRateLimiter("", 10, 20))

jobStore := newJobsStore(t, batchClient)
_, err := jobStore.GetOrCreate(tID.GetGeneratedName(), &Job{
ID: "job-id",
Status: JobStatus{
Phase: core.PhaseRunning,
},
SubJobs: []*Job{
{Status: JobStatus{Phase: core.PhaseRetryableFailure}, Attempts: []Attempt{{LogStream: "failed"}}},
{Status: JobStatus{Phase: core.PhaseSuccess}},
},
})

assert.NoError(t, err)

retryAttemptsArray, err := bitarray.NewCompactArray(2, bitarray.Item(1))
assert.NoError(t, err)

newState, err := CheckSubTasksState(ctx, tCtx, jobStore, &config.Config{}, &State{
State: &arrayCore.State{
CurrentPhase: arrayCore.PhaseWriteToDiscoveryThenFail,
ExecutionArraySize: 2,
OriginalArraySize: 2,
OriginalMinSuccesses: 2,
ArrayStatus: arraystatus.ArrayStatus{
Detailed: arrayCore.NewPhasesCompactArray(2),
},
IndexesToCache: bitarray.NewBitSet(2),
RetryAttempts: retryAttemptsArray,
},
ExternalJobID: refStr("job-id"),
JobDefinitionArn: "",
}, getAwsBatchExecutorMetrics(promutils.NewTestScope()))

assert.NoError(t, err)
p, _ := newState.GetPhase()
assert.Equal(t, arrayCore.PhaseWriteToDiscoveryThenFail, p)
})
}
16 changes: 16 additions & 0 deletions go/tasks/plugins/array/core/state_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,22 @@ func TestSummaryToPhase(t *testing.T) {
core.PhaseSuccess: 10,
},
},
{
"FailedToRetry",
PhaseWriteToDiscoveryThenFail,
map[core.Phase]int64{
core.PhaseSuccess: 5,
core.PhasePermanentFailure: 5,
},
},
{
"Retrying",
PhaseCheckingSubTaskExecutions,
map[core.Phase]int64{
core.PhaseSuccess: 5,
core.PhaseRetryableFailure: 5,
},
},
}

for _, tt := range tests {
Expand Down