From e9688b82f13708750f25ff042ec71160209c9728 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Mon, 7 Feb 2022 08:10:32 +0800 Subject: [PATCH] Add Raw AWS Batch Task (#228) * Add Raw AWS Batch Task Signed-off-by: Kevin Su * Fix test Signed-off-by: Kevin Su * Fix lint Signed-off-by: Kevin Su * Fix lint Signed-off-by: Kevin Su * Add tests Signed-off-by: Kevin Su * Fix lint Signed-off-by: Kevin Su * Remove log Signed-off-by: Kevin Su * Updated tests and added comment Signed-off-by: Kevin Su * address comment Signed-off-by: Kevin Su * Fix tests Signed-off-by: Kevin Su * Convert pb to string Signed-off-by: Kevin Su * Use job definition as cache key Signed-off-by: Kevin Su * Fixed tests Signed-off-by: Kevin Su * Fixed tests Signed-off-by: Kevin Su * One more test Signed-off-by: Kevin Su * Hash job definition Signed-off-by: Kevin Su * address comments Signed-off-by: Kevin Su * Updated dependency Signed-off-by: Kevin Su * Fixed test Signed-off-by: Kevin Su * lint fixed Signed-off-by: Kevin Su * Reorder Signed-off-by: Kevin Su --- .../go/tasks/plugins/array/awsbatch/client.go | 23 ++++++------- .../plugins/array/awsbatch/client_test.go | 2 +- .../awsbatch/definition/job_def_cache.go | 14 ++++---- .../tasks/plugins/array/awsbatch/executor.go | 4 +-- .../plugins/array/awsbatch/job_definition.go | 15 ++++++--- .../array/awsbatch/job_definition_test.go | 10 ++++-- .../plugins/array/awsbatch/mocks/client.go | 18 +++++------ .../plugins/array/awsbatch/transformer.go | 25 ++++++++++----- .../array/awsbatch/transformer_test.go | 11 ++++++- .../go/tasks/plugins/array/catalog.go | 17 +++++++++- .../go/tasks/plugins/array/catalog_test.go | 14 ++++++++ flyteplugins/go/tasks/plugins/array/inputs.go | 2 +- .../go/tasks/plugins/array/outputs.go | 32 ++++++++++++------- .../go/tasks/plugins/array/outputs_test.go | 23 ++++++++----- 14 files changed, 143 insertions(+), 67 deletions(-) diff --git a/flyteplugins/go/tasks/plugins/array/awsbatch/client.go b/flyteplugins/go/tasks/plugins/array/awsbatch/client.go index ccedda6f35..b76d4bb2b7 100644 --- a/flyteplugins/go/tasks/plugins/array/awsbatch/client.go +++ b/flyteplugins/go/tasks/plugins/array/awsbatch/client.go @@ -9,16 +9,13 @@ import ( "context" "fmt" - definition2 "github.com/flyteorg/flyteplugins/go/tasks/plugins/array/awsbatch/definition" - - "github.com/flyteorg/flyteplugins/go/tasks/aws" - "github.com/flyteorg/flytestdlib/utils" - - "github.com/flyteorg/flytestdlib/logger" - a "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/service/batch" + "github.com/flyteorg/flyteplugins/go/tasks/aws" + definition2 "github.com/flyteorg/flyteplugins/go/tasks/plugins/array/awsbatch/definition" + "github.com/flyteorg/flytestdlib/logger" + "github.com/flyteorg/flytestdlib/utils" ) //go:generate mockery -all -case=underscore @@ -35,7 +32,7 @@ type Client interface { GetJobDetailsBatch(ctx context.Context, ids []JobID) ([]*batch.JobDetail, error) // Registers a new Job Definition with AWS Batch provided a name, image and role. - RegisterJobDefinition(ctx context.Context, name, image, role string) (arn string, err error) + RegisterJobDefinition(ctx context.Context, name, image, role string, platformCapabilities string) (arn string, err error) // Gets the single region this client interacts with. GetRegion() string @@ -68,12 +65,13 @@ func (b client) GetAccountID() string { } // Registers a new job definition. There is no deduping on AWS side (even for the same name). -func (b *client) RegisterJobDefinition(ctx context.Context, name, image, role string) (arn definition2.JobDefinitionArn, err error) { - logger.Infof(ctx, "Registering job definition with name [%v], image [%v], role [%v]", name, image, role) +func (b *client) RegisterJobDefinition(ctx context.Context, name, image, role string, platformCapabilities string) (arn definition2.JobDefinitionArn, err error) { + logger.Infof(ctx, "Registering job definition with name [%v], image [%v], role [%v], platformCapabilities [%v]", name, image, role, platformCapabilities) res, err := b.Batch.RegisterJobDefinitionWithContext(ctx, &batch.RegisterJobDefinitionInput{ - Type: refStr(batch.JobDefinitionTypeContainer), - JobDefinitionName: refStr(name), + Type: refStr(batch.JobDefinitionTypeContainer), + JobDefinitionName: refStr(name), + PlatformCapabilities: refStrSlice([]string{platformCapabilities}), ContainerProperties: &batch.ContainerProperties{ Image: refStr(image), JobRoleArn: refStr(role), @@ -83,7 +81,6 @@ func (b *client) RegisterJobDefinition(ctx context.Context, name, image, role st Memory: refInt(100), }, }) - if err != nil { return "", err } diff --git a/flyteplugins/go/tasks/plugins/array/awsbatch/client_test.go b/flyteplugins/go/tasks/plugins/array/awsbatch/client_test.go index 5c494e07a5..f4dd72fe6c 100644 --- a/flyteplugins/go/tasks/plugins/array/awsbatch/client_test.go +++ b/flyteplugins/go/tasks/plugins/array/awsbatch/client_test.go @@ -81,7 +81,7 @@ func TestClient_GetJobDetailsBatch(t *testing.T) { func TestClient_RegisterJobDefinition(t *testing.T) { c := newClientWithMockBatch() - j, err := c.RegisterJobDefinition(context.TODO(), "name-abc", "img", "admin-role") + j, err := c.RegisterJobDefinition(context.TODO(), "name-abc", "img", "admin-role", defaultComputeEngine) assert.NoError(t, err) assert.NotNil(t, j) } diff --git a/flyteplugins/go/tasks/plugins/array/awsbatch/definition/job_def_cache.go b/flyteplugins/go/tasks/plugins/array/awsbatch/definition/job_def_cache.go index 6456b5b342..ece102ebe0 100644 --- a/flyteplugins/go/tasks/plugins/array/awsbatch/definition/job_def_cache.go +++ b/flyteplugins/go/tasks/plugins/array/awsbatch/definition/job_def_cache.go @@ -27,12 +27,13 @@ type CacheKey interface { } type cacheKey struct { - role string - image string + role string + image string + platformCapabilities string } func (k cacheKey) String() string { - return fmt.Sprintf("%v-%v", k.image, k.role) + return fmt.Sprintf("%v-%v-%v", k.image, k.role, k.platformCapabilities) } type cache struct { @@ -52,10 +53,11 @@ func (c cache) Put(key CacheKey, definition JobDefinitionArn) error { } // Creates a new deterministic cache key. -func NewCacheKey(role, image string) CacheKey { +func NewCacheKey(role, image, platformCapabilities string) CacheKey { return cacheKey{ - role: role, - image: image, + role: role, + image: image, + platformCapabilities: platformCapabilities, } } diff --git a/flyteplugins/go/tasks/plugins/array/awsbatch/executor.go b/flyteplugins/go/tasks/plugins/array/awsbatch/executor.go index 104149f296..ec4b5bc3c7 100644 --- a/flyteplugins/go/tasks/plugins/array/awsbatch/executor.go +++ b/flyteplugins/go/tasks/plugins/array/awsbatch/executor.go @@ -189,10 +189,10 @@ func init() { pluginmachinery.PluginRegistry().RegisterCorePlugin( core.PluginEntry{ ID: executorName, - RegisteredTaskTypes: []core.TaskType{arrayTaskType}, + RegisteredTaskTypes: []core.TaskType{arrayTaskType, array.AwsBatchTaskType}, LoadPlugin: createNewExecutorPlugin, IsDefault: false, - DefaultForTaskTypes: []core.TaskType{arrayTaskType}, + DefaultForTaskTypes: []core.TaskType{arrayTaskType, array.AwsBatchTaskType}, }) } diff --git a/flyteplugins/go/tasks/plugins/array/awsbatch/job_definition.go b/flyteplugins/go/tasks/plugins/array/awsbatch/job_definition.go index 9809073370..401be2148b 100644 --- a/flyteplugins/go/tasks/plugins/array/awsbatch/job_definition.go +++ b/flyteplugins/go/tasks/plugins/array/awsbatch/job_definition.go @@ -16,6 +16,9 @@ import ( "github.com/flyteorg/flyteplugins/go/tasks/plugins/array/awsbatch/definition" ) +const defaultComputeEngine = "EC2" +const platformCapabilitiesConfigKey = "platformCapabilities" + func getContainerImage(_ context.Context, task *core.TaskTemplate) string { if task.GetContainer() != nil && len(task.GetContainer().Image) > 0 { return task.GetContainer().Image @@ -51,11 +54,15 @@ func EnsureJobDefinition(ctx context.Context, tCtx pluginCore.TaskExecutionConte } role := awsUtils.GetRoleFromSecurityContext(cfg.RoleAnnotationKey, tCtx.TaskExecutionMetadata()) + platformCapabilities := taskTemplate.GetConfig()[platformCapabilitiesConfigKey] + if len(platformCapabilities) == 0 { + platformCapabilities = defaultComputeEngine + } - cacheKey := definition.NewCacheKey(role, containerImage) + cacheKey := definition.NewCacheKey(role, containerImage, platformCapabilities) if existingArn, found := definitionCache.Get(cacheKey); found { - logger.Infof(ctx, "Found an existing job definition for Image [%v] and Role [%v]. Arn [%v]", - containerImage, role, existingArn) + logger.Infof(ctx, "Found an existing job definition for Image [%v], Role [%v], JobDefinitionInput [%v]. Arn [%v]", + containerImage, role, platformCapabilities, existingArn) nextState = currentState.SetJobDefinitionArn(existingArn) nextState.State = nextState.SetPhase(arrayCore.PhaseLaunch, 0).SetReason("AWS job definition already exist.") @@ -64,7 +71,7 @@ func EnsureJobDefinition(ctx context.Context, tCtx pluginCore.TaskExecutionConte name := definition.GetJobDefinitionSafeName(containerImageRepository(containerImage)) - arn, err := client.RegisterJobDefinition(ctx, name, containerImage, role) + arn, err := client.RegisterJobDefinition(ctx, name, containerImage, role, platformCapabilities) if err != nil { return currentState, err } diff --git a/flyteplugins/go/tasks/plugins/array/awsbatch/job_definition_test.go b/flyteplugins/go/tasks/plugins/array/awsbatch/job_definition_test.go index a7dbd97754..6936baa275 100644 --- a/flyteplugins/go/tasks/plugins/array/awsbatch/job_definition_test.go +++ b/flyteplugins/go/tasks/plugins/array/awsbatch/job_definition_test.go @@ -91,7 +91,7 @@ func TestEnsureJobDefinition(t *testing.T) { t.Run("Found", func(t *testing.T) { dCache := definition.NewCache(10) - assert.NoError(t, dCache.Put(definition.NewCacheKey("", "img1"), "their-arn")) + assert.NoError(t, dCache.Put(definition.NewCacheKey("", "img1", defaultComputeEngine), "their-arn")) nextState, err := EnsureJobDefinition(ctx, tCtx, cfg, batchClient, dCache, &State{ State: &arrayCore.State{}, @@ -100,6 +100,11 @@ func TestEnsureJobDefinition(t *testing.T) { assert.NotNil(t, nextState) assert.Equal(t, "their-arn", nextState.JobDefinitionArn) }) + + t.Run("Test New Cache Key", func(t *testing.T) { + cacheKey := definition.NewCacheKey("default", "img1", defaultComputeEngine) + assert.Equal(t, cacheKey.String(), "img1-default-EC2") + }) } func TestEnsureJobDefinitionWithSecurityContext(t *testing.T) { @@ -115,6 +120,7 @@ func TestEnsureJobDefinitionWithSecurityContext(t *testing.T) { Target: &core.TaskTemplate_Container{ Container: createSampleContainerTask(), }, + Config: map[string]string{platformCapabilitiesConfigKey: defaultComputeEngine}, }, nil) overrides := &mocks.TaskOverrides{} @@ -158,7 +164,7 @@ func TestEnsureJobDefinitionWithSecurityContext(t *testing.T) { t.Run("Found", func(t *testing.T) { dCache := definition.NewCache(10) - assert.NoError(t, dCache.Put(definition.NewCacheKey("new-role", "img1"), "their-arn")) + assert.NoError(t, dCache.Put(definition.NewCacheKey("new-role", "img1", defaultComputeEngine), "their-arn")) nextState, err := EnsureJobDefinition(ctx, tCtx, cfg, batchClient, dCache, &State{ State: &arrayCore.State{}, diff --git a/flyteplugins/go/tasks/plugins/array/awsbatch/mocks/client.go b/flyteplugins/go/tasks/plugins/array/awsbatch/mocks/client.go index 70faae703c..8e65d5e596 100644 --- a/flyteplugins/go/tasks/plugins/array/awsbatch/mocks/client.go +++ b/flyteplugins/go/tasks/plugins/array/awsbatch/mocks/client.go @@ -128,8 +128,8 @@ func (_m Client_RegisterJobDefinition) Return(arn string, err error) *Client_Reg return &Client_RegisterJobDefinition{Call: _m.Call.Return(arn, err)} } -func (_m *Client) OnRegisterJobDefinition(ctx context.Context, name string, image string, role string) *Client_RegisterJobDefinition { - c := _m.On("RegisterJobDefinition", ctx, name, image, role) +func (_m *Client) OnRegisterJobDefinition(ctx context.Context, name string, image string, role string, platformCapabilities string) *Client_RegisterJobDefinition { + c := _m.On("RegisterJobDefinition", ctx, name, image, role, platformCapabilities) return &Client_RegisterJobDefinition{Call: c} } @@ -138,20 +138,20 @@ func (_m *Client) OnRegisterJobDefinitionMatch(matchers ...interface{}) *Client_ return &Client_RegisterJobDefinition{Call: c} } -// RegisterJobDefinition provides a mock function with given fields: ctx, name, image, role -func (_m *Client) RegisterJobDefinition(ctx context.Context, name string, image string, role string) (string, error) { - ret := _m.Called(ctx, name, image, role) +// RegisterJobDefinition provides a mock function with given fields: ctx, name, image, role, structObj +func (_m *Client) RegisterJobDefinition(ctx context.Context, name string, image string, role string, platformCapabilities string) (string, error) { + ret := _m.Called(ctx, name, image, role, platformCapabilities) var r0 string - if rf, ok := ret.Get(0).(func(context.Context, string, string, string) string); ok { - r0 = rf(ctx, name, image, role) + if rf, ok := ret.Get(0).(func(context.Context, string, string, string, string) string); ok { + r0 = rf(ctx, name, image, role, platformCapabilities) } else { r0 = ret.Get(0).(string) } var r1 error - if rf, ok := ret.Get(1).(func(context.Context, string, string, string) error); ok { - r1 = rf(ctx, name, image, role) + if rf, ok := ret.Get(1).(func(context.Context, string, string, string, string) error); ok { + r1 = rf(ctx, name, image, role, platformCapabilities) } else { r1 = ret.Error(1) } diff --git a/flyteplugins/go/tasks/plugins/array/awsbatch/transformer.go b/flyteplugins/go/tasks/plugins/array/awsbatch/transformer.go index e68ada5434..f978d92ef5 100644 --- a/flyteplugins/go/tasks/plugins/array/awsbatch/transformer.go +++ b/flyteplugins/go/tasks/plugins/array/awsbatch/transformer.go @@ -5,6 +5,8 @@ import ( "sort" "time" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/utils" + "github.com/flyteorg/flyteplugins/go/tasks/plugins/array" "github.com/aws/aws-sdk-go/service/batch" @@ -85,14 +87,21 @@ func FlyteTaskToBatchInput(ctx context.Context, tCtx pluginCore.TaskExecutionCon } resources := flytek8s.ApplyResourceOverrides(*res, *platformResources, assignResources) - return &batch.SubmitJobInput{ - JobName: refStr(tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName()), - JobDefinition: refStr(jobDefinition), - JobQueue: refStr(jobConfig.DynamicTaskQueue), - RetryStrategy: toRetryStrategy(ctx, toBackoffLimit(taskTemplate.Metadata), cfg.MinRetries, cfg.MaxRetries), - ContainerOverrides: toContainerOverrides(ctx, append(cmd, args...), &resources, envVars), - Timeout: toTimeout(taskTemplate.Metadata.GetTimeout(), cfg.DefaultTimeOut.Duration), - }, nil + submitJobInput := &batch.SubmitJobInput{} + if taskTemplate.GetCustom() != nil { + err = utils.UnmarshalStructToObj(taskTemplate.GetCustom(), &submitJobInput) + if err != nil { + return nil, errors.Errorf(errors.BadTaskSpecification, + "invalid TaskSpecification [%v], Err: [%v]", taskTemplate.GetCustom(), err.Error()) + } + } + submitJobInput.SetJobName(tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName()). + SetJobDefinition(jobDefinition).SetJobQueue(jobConfig.DynamicTaskQueue). + SetRetryStrategy(toRetryStrategy(ctx, toBackoffLimit(taskTemplate.Metadata), cfg.MinRetries, cfg.MaxRetries)). + SetContainerOverrides(toContainerOverrides(ctx, append(cmd, args...), &resources, envVars)). + SetTimeout(toTimeout(taskTemplate.Metadata.GetTimeout(), cfg.DefaultTimeOut.Duration)) + + return submitJobInput, nil } func UpdateBatchInputForArray(_ context.Context, batchInput *batch.SubmitJobInput, arraySize int64) *batch.SubmitJobInput { diff --git a/flyteplugins/go/tasks/plugins/array/awsbatch/transformer_test.go b/flyteplugins/go/tasks/plugins/array/awsbatch/transformer_test.go index f78fecf65f..29fc8022cc 100644 --- a/flyteplugins/go/tasks/plugins/array/awsbatch/transformer_test.go +++ b/flyteplugins/go/tasks/plugins/array/awsbatch/transformer_test.go @@ -22,6 +22,7 @@ import ( "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/mocks" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/utils" + "github.com/flyteorg/flyteplugins/go/tasks/plugins/array" "github.com/flyteorg/flyteplugins/go/tasks/plugins/array/awsbatch/config" v12 "k8s.io/api/core/v1" @@ -191,11 +192,11 @@ func TestArrayJobToBatchInput(t *testing.T) { Target: &core.TaskTemplate_Container{ Container: createSampleContainerTask(), }, + Type: arrayTaskType, } tr := &mocks.TaskReader{} tr.OnReadMatch(mock.Anything).Return(taskTemplate, nil) - taskCtx.OnTaskReader().Return(tr) ctx := context.Background() @@ -205,6 +206,14 @@ func TestArrayJobToBatchInput(t *testing.T) { batchInput = UpdateBatchInputForArray(ctx, batchInput, input.Size) assert.NotNil(t, batchInput) assert.Equal(t, *expectedBatchInput, *batchInput) + + taskTemplate.Type = array.AwsBatchTaskType + tr.OnReadMatch(mock.Anything).Return(taskTemplate, nil) + taskCtx.OnTaskReader().Return(tr) + + ctx = context.Background() + _, err = FlyteTaskToBatchInput(ctx, taskCtx, "", &config.Config{}) + assert.NoError(t, err) } func Test_getEnvVarsForTask(t *testing.T) { diff --git a/flyteplugins/go/tasks/plugins/array/catalog.go b/flyteplugins/go/tasks/plugins/array/catalog.go index 781f87227b..3695304b3c 100644 --- a/flyteplugins/go/tasks/plugins/array/catalog.go +++ b/flyteplugins/go/tasks/plugins/array/catalog.go @@ -6,6 +6,8 @@ import ( "math" "strconv" + idlPlugins "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins" + arrayCore "github.com/flyteorg/flyteplugins/go/tasks/plugins/array/core" "github.com/flyteorg/flytestdlib/bitarray" @@ -21,6 +23,8 @@ import ( idlCore "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" ) +const AwsBatchTaskType = "aws-batch" + // DetermineDiscoverability checks if there are any previously cached tasks. If there are we will only submit an // ArrayJob for the non-cached tasks. The ArrayJob is now a different size, and each task will get a new index location // which is different than their original location. To find the original index we construct an indexLookup array. @@ -38,7 +42,18 @@ func DetermineDiscoverability(ctx context.Context, tCtx core.TaskExecutionContex } // Extract the custom plugin pb - arrayJob, err := arrayCore.ToArrayJob(taskTemplate.GetCustom(), taskTemplate.TaskTypeVersion) + var arrayJob *idlPlugins.ArrayJob + if taskTemplate.Type == AwsBatchTaskType { + arrayJob = &idlPlugins.ArrayJob{ + Parallelism: 1, + Size: 1, + SuccessCriteria: &idlPlugins.ArrayJob_MinSuccesses{ + MinSuccesses: 1, + }, + } + } else { + arrayJob, err = arrayCore.ToArrayJob(taskTemplate.GetCustom(), taskTemplate.TaskTypeVersion) + } if err != nil { return state, err } diff --git a/flyteplugins/go/tasks/plugins/array/catalog_test.go b/flyteplugins/go/tasks/plugins/array/catalog_test.go index aad365dd04..3edb82056a 100644 --- a/flyteplugins/go/tasks/plugins/array/catalog_test.go +++ b/flyteplugins/go/tasks/plugins/array/catalog_test.go @@ -175,6 +175,20 @@ func TestDetermineDiscoverability(t *testing.T) { }, } + t.Run("Run AWS Batch single job", func(t *testing.T) { + toCache := arrayCore.InvertBitSet(bitarray.NewBitSet(1), 1) + template.Type = AwsBatchTaskType + runDetermineDiscoverabilityTest(t, template, f, &arrayCore.State{ + CurrentPhase: arrayCore.PhasePreLaunch, + PhaseVersion: core2.DefaultPhaseVersion, + ExecutionArraySize: 1, + OriginalArraySize: 1, + OriginalMinSuccesses: 1, + IndexesToCache: toCache, + Reason: "Task is not discoverable.", + }, nil) + }) + t.Run("Not discoverable", func(t *testing.T) { toCache := arrayCore.InvertBitSet(bitarray.NewBitSet(1), 1) diff --git a/flyteplugins/go/tasks/plugins/array/inputs.go b/flyteplugins/go/tasks/plugins/array/inputs.go index 248398b42c..ac0c2acb5f 100644 --- a/flyteplugins/go/tasks/plugins/array/inputs.go +++ b/flyteplugins/go/tasks/plugins/array/inputs.go @@ -20,7 +20,7 @@ func (i arrayJobInputReader) GetInputPath() storage.DataReference { } func GetInputReader(tCtx core.TaskExecutionContext, taskTemplate *idlCore.TaskTemplate) io.InputReader { - if taskTemplate.GetTaskTypeVersion() == 0 { + if taskTemplate.GetTaskTypeVersion() == 0 && taskTemplate.Type != AwsBatchTaskType { // Prior to task type version == 1, dynamic type tasks (including array tasks) would write input files for each // individual array task instance. In this case we use a modified input reader to only pass in the parent input // directory. diff --git a/flyteplugins/go/tasks/plugins/array/outputs.go b/flyteplugins/go/tasks/plugins/array/outputs.go index ab31021952..4177fcb616 100644 --- a/flyteplugins/go/tasks/plugins/array/outputs.go +++ b/flyteplugins/go/tasks/plugins/array/outputs.go @@ -44,10 +44,11 @@ func (o OutputAssembler) Queue(ctx context.Context, id workqueue.WorkItemID, ite } type outputAssembleItem struct { - outputPaths io.OutputFilePaths - varNames []string - finalPhases bitarray.CompactArray - dataStore *storage.DataStore + outputPaths io.OutputFilePaths + varNames []string + finalPhases bitarray.CompactArray + dataStore *storage.DataStore + isAwsSingleJob bool } type assembleOutputsWorker struct { @@ -75,13 +76,21 @@ func (w assembleOutputsWorker) Process(ctx context.Context, workItem workqueue.W } if executionError == nil && output != nil { - appendSubTaskOutput(finalOutputs, output, int64(i.finalPhases.ItemsCount)) - continue + if i.isAwsSingleJob { + // We will only have one output.pb when running aws single job, so we don't need + // to aggregate outputs here + finalOutputs.Literals = output.GetLiterals() + } else { + appendSubTaskOutput(finalOutputs, output, int64(i.finalPhases.ItemsCount)) + continue + } } } // TODO: Do we need the names of the outputs in the literalMap here? - appendEmptyOutputs(finalOutputs, i.varNames) + if !i.isAwsSingleJob { + appendEmptyOutputs(finalOutputs, i.varNames) + } } ow := ioutils.NewRemoteFileOutputWriter(ctx, i.dataStore, i.outputPaths) @@ -185,10 +194,11 @@ func AssembleFinalOutputs(ctx context.Context, assemblyQueue OutputAssembler, tC state.GetIndexesToCache(), uint(state.GetOriginalArraySize())) err = assemblyQueue.Queue(ctx, workItemID, &outputAssembleItem{ - varNames: varNames, - finalPhases: finalPhases, - outputPaths: tCtx.OutputWriter(), - dataStore: tCtx.DataStore(), + varNames: varNames, + finalPhases: finalPhases, + outputPaths: tCtx.OutputWriter(), + dataStore: tCtx.DataStore(), + isAwsSingleJob: taskTemplate.Type == AwsBatchTaskType, }) if err != nil { diff --git a/flyteplugins/go/tasks/plugins/array/outputs_test.go b/flyteplugins/go/tasks/plugins/array/outputs_test.go index 743a3366b2..bda6379723 100644 --- a/flyteplugins/go/tasks/plugins/array/outputs_test.go +++ b/flyteplugins/go/tasks/plugins/array/outputs_test.go @@ -98,10 +98,11 @@ func Test_assembleOutputsWorker_Process(t *testing.T) { phases.SetItem(3, bitarray.Item(pluginCore.PhasePermanentFailure)) item := &outputAssembleItem{ - outputPaths: ow, - varNames: []string{"var1", "var2"}, - finalPhases: phases, - dataStore: memStore, + outputPaths: ow, + varNames: []string{"var1", "var2"}, + finalPhases: phases, + dataStore: memStore, + isAwsSingleJob: false, } w := assembleOutputsWorker{} @@ -386,10 +387,11 @@ func Test_assembleErrorsWorker_Process(t *testing.T) { phases.SetItem(3, bitarray.Item(pluginCore.PhasePermanentFailure)) item := &outputAssembleItem{ - varNames: []string{"var1", "var2"}, - finalPhases: phases, - outputPaths: ow, - dataStore: memStore, + varNames: []string{"var1", "var2"}, + finalPhases: phases, + outputPaths: ow, + dataStore: memStore, + isAwsSingleJob: false, } w := assembleErrorsWorker{ @@ -398,6 +400,11 @@ func Test_assembleErrorsWorker_Process(t *testing.T) { actual, err := w.Process(ctx, item) assert.NoError(t, err) assert.Equal(t, workqueue.WorkStatusSucceeded, actual) + + item.isAwsSingleJob = true + actual, err = w.Process(ctx, item) + assert.NoError(t, err) + assert.Equal(t, workqueue.WorkStatusSucceeded, actual) } func TestNewOutputAssembler(t *testing.T) {