diff --git a/go/tasks/plugins/array/awsbatch/monitor_test.go b/go/tasks/plugins/array/awsbatch/monitor_test.go index e9cedb90b..a9d708377 100644 --- a/go/tasks/plugins/array/awsbatch/monitor_test.go +++ b/go/tasks/plugins/array/awsbatch/monitor_test.go @@ -3,6 +3,8 @@ package awsbatch import ( "testing" + "github.com/stretchr/testify/mock" + "github.com/flyteorg/flytestdlib/contextutils" "github.com/flyteorg/flytestdlib/promutils/labeled" @@ -11,6 +13,7 @@ import ( "github.com/flyteorg/flyteplugins/go/tasks/plugins/array/arraystatus" + flyteIdl "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" "github.com/aws/aws-sdk-go/aws/request" @@ -19,6 +22,7 @@ import ( arrayCore "github.com/flyteorg/flyteplugins/go/tasks/plugins/array/core" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/mocks" + ioMocks "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io/mocks" "github.com/flyteorg/flyteplugins/go/tasks/plugins/array/awsbatch/config" batchMocks "github.com/flyteorg/flyteplugins/go/tasks/plugins/array/awsbatch/mocks" "github.com/flyteorg/flytestdlib/utils" @@ -35,15 +39,39 @@ func init() { func TestCheckSubTasksState(t *testing.T) { ctx := context.Background() + tCtx := &mocks.TaskExecutionContext{} tID := &mocks.TaskExecutionID{} tID.OnGetGeneratedName().Return("generated-name") - tMeta := &mocks.TaskExecutionMetadata{} tMeta.OnGetOwnerID().Return(types.NamespacedName{ Namespace: "domain", Name: "name", }) tMeta.OnGetTaskExecutionID().Return(tID) + inMemDatastore, err := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope()) + assert.NoError(t, err) + + outputWriter := &ioMocks.OutputWriter{} + outputWriter.OnGetOutputPrefixPath().Return("") + outputWriter.OnGetRawOutputPrefix().Return("") + + taskReader := &mocks.TaskReader{} + task := &flyteIdl.TaskTemplate{ + Type: "test", + Target: &flyteIdl.TaskTemplate_Container{ + Container: &flyteIdl.Container{ + Command: []string{"command"}, + Args: []string{"{{.Input}}"}, + }, + }, + Metadata: &flyteIdl.TaskMetadata{Retries: &flyteIdl.RetryStrategy{Retries: 3}}, + } + taskReader.On("Read", mock.Anything).Return(task, nil) + + tCtx.OnOutputWriter().Return(outputWriter) + tCtx.OnTaskReader().Return(taskReader) + tCtx.OnDataStore().Return(inMemDatastore) + tCtx.OnTaskExecutionMetadata().Return(tMeta) t.Run("Not in cache", func(t *testing.T) { mBatchClient := batchMocks.NewMockAwsBatchClient() @@ -52,7 +80,7 @@ func TestCheckSubTasksState(t *testing.T) { utils.NewRateLimiter("", 10, 20)) jobStore := newJobsStore(t, batchClient) - newState, err := CheckSubTasksState(ctx, tMeta, "", "", jobStore, nil, &config.Config{}, &State{ + newState, err := CheckSubTasksState(ctx, tCtx, jobStore, &config.Config{}, &State{ State: &arrayCore.State{ CurrentPhase: arrayCore.PhaseCheckingSubTaskExecutions, ExecutionArraySize: 5, @@ -61,7 +89,7 @@ func TestCheckSubTasksState(t *testing.T) { }, ExternalJobID: refStr("job-id"), JobDefinitionArn: "", - }, getAwsBatchExecutorMetrics(promutils.NewTestScope()), 3) + }, getAwsBatchExecutorMetrics(promutils.NewTestScope())) assert.NoError(t, err) p, _ := newState.GetPhase() @@ -98,7 +126,7 @@ func TestCheckSubTasksState(t *testing.T) { assert.NoError(t, err) - newState, err := CheckSubTasksState(ctx, tMeta, "", "", jobStore, nil, &config.Config{}, &State{ + newState, err := CheckSubTasksState(ctx, tCtx, jobStore, &config.Config{}, &State{ State: &arrayCore.State{ CurrentPhase: arrayCore.PhaseCheckingSubTaskExecutions, ExecutionArraySize: 5, @@ -107,7 +135,7 @@ func TestCheckSubTasksState(t *testing.T) { }, ExternalJobID: refStr("job-id"), JobDefinitionArn: "", - }, getAwsBatchExecutorMetrics(promutils.NewTestScope()), 3) + }, getAwsBatchExecutorMetrics(promutils.NewTestScope())) assert.NoError(t, err) p, _ := newState.GetPhase() @@ -133,13 +161,10 @@ func TestCheckSubTasksState(t *testing.T) { assert.NoError(t, err) - inMemDatastore, err := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope()) - assert.NoError(t, err) - retryAttemptsArray, err := bitarray.NewCompactArray(1, bitarray.Item(1)) assert.NoError(t, err) - newState, err := CheckSubTasksState(ctx, tMeta, "", "", jobStore, inMemDatastore, &config.Config{}, &State{ + newState, err := CheckSubTasksState(ctx, tCtx, jobStore, &config.Config{}, &State{ State: &arrayCore.State{ CurrentPhase: arrayCore.PhaseCheckingSubTaskExecutions, ExecutionArraySize: 1, @@ -153,7 +178,7 @@ func TestCheckSubTasksState(t *testing.T) { }, ExternalJobID: refStr("job-id"), JobDefinitionArn: "", - }, getAwsBatchExecutorMetrics(promutils.NewTestScope()), 3) + }, getAwsBatchExecutorMetrics(promutils.NewTestScope())) assert.NoError(t, err) p, _ := newState.GetPhase() @@ -181,13 +206,10 @@ func TestCheckSubTasksState(t *testing.T) { assert.NoError(t, err) - inMemDatastore, err := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope()) - assert.NoError(t, err) - retryAttemptsArray, err := bitarray.NewCompactArray(2, bitarray.Item(1)) assert.NoError(t, err) - newState, err := CheckSubTasksState(ctx, tMeta, "", "", jobStore, inMemDatastore, &config.Config{}, &State{ + newState, err := CheckSubTasksState(ctx, tCtx, jobStore, &config.Config{}, &State{ State: &arrayCore.State{ CurrentPhase: arrayCore.PhaseCheckingSubTaskExecutions, ExecutionArraySize: 2, @@ -201,7 +223,7 @@ func TestCheckSubTasksState(t *testing.T) { }, ExternalJobID: refStr("job-id"), JobDefinitionArn: "", - }, getAwsBatchExecutorMetrics(promutils.NewTestScope()), 3) + }, getAwsBatchExecutorMetrics(promutils.NewTestScope())) assert.NoError(t, err) p, _ := newState.GetPhase() @@ -228,13 +250,10 @@ func TestCheckSubTasksState(t *testing.T) { assert.NoError(t, err) - inMemDatastore, err := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope()) - assert.NoError(t, err) - retryAttemptsArray, err := bitarray.NewCompactArray(2, bitarray.Item(1)) assert.NoError(t, err) - newState, err := CheckSubTasksState(ctx, tMeta, "", "", jobStore, inMemDatastore, &config.Config{}, &State{ + newState, err := CheckSubTasksState(ctx, tCtx, jobStore, &config.Config{}, &State{ State: &arrayCore.State{ CurrentPhase: arrayCore.PhaseWriteToDiscoveryThenFail, ExecutionArraySize: 2, @@ -248,7 +267,7 @@ func TestCheckSubTasksState(t *testing.T) { }, ExternalJobID: refStr("job-id"), JobDefinitionArn: "", - }, getAwsBatchExecutorMetrics(promutils.NewTestScope()), 1) + }, getAwsBatchExecutorMetrics(promutils.NewTestScope())) assert.NoError(t, err) p, _ := newState.GetPhase()