Skip to content

Commit

Permalink
Add Raw AWS Batch Task (flyteorg#228)
Browse files Browse the repository at this point in the history
* Add Raw AWS Batch Task

Signed-off-by: Kevin Su <[email protected]>

* Fix test

Signed-off-by: Kevin Su <[email protected]>

* Fix lint

Signed-off-by: Kevin Su <[email protected]>

* Fix lint

Signed-off-by: Kevin Su <[email protected]>

* Add tests

Signed-off-by: Kevin Su <[email protected]>

* Fix lint

Signed-off-by: Kevin Su <[email protected]>

* Remove log

Signed-off-by: Kevin Su <[email protected]>

* Updated tests and added comment

Signed-off-by: Kevin Su <[email protected]>

* address comment

Signed-off-by: Kevin Su <[email protected]>

* Fix tests

Signed-off-by: Kevin Su <[email protected]>

* Convert pb to string

Signed-off-by: Kevin Su <[email protected]>

* Use job definition as cache key

Signed-off-by: Kevin Su <[email protected]>

* Fixed tests

Signed-off-by: Kevin Su <[email protected]>

* Fixed tests

Signed-off-by: Kevin Su <[email protected]>

* One more test

Signed-off-by: Kevin Su <[email protected]>

* Hash job definition

Signed-off-by: Kevin Su <[email protected]>

* address comments

Signed-off-by: Kevin Su <[email protected]>

* Updated dependency

Signed-off-by: Kevin Su <[email protected]>

* Fixed test

Signed-off-by: Kevin Su <[email protected]>

* lint fixed

Signed-off-by: Kevin Su <[email protected]>

* Reorder

Signed-off-by: Kevin Su <[email protected]>
  • Loading branch information
pingsutw authored Feb 7, 2022
1 parent e1345cd commit e9688b8
Show file tree
Hide file tree
Showing 14 changed files with 143 additions and 67 deletions.
23 changes: 10 additions & 13 deletions flyteplugins/go/tasks/plugins/array/awsbatch/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,13 @@ import (
"context"
"fmt"

definition2 "github.com/flyteorg/flyteplugins/go/tasks/plugins/array/awsbatch/definition"

"github.com/flyteorg/flyteplugins/go/tasks/aws"
"github.com/flyteorg/flytestdlib/utils"

"github.com/flyteorg/flytestdlib/logger"

a "github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/service/batch"
"github.com/flyteorg/flyteplugins/go/tasks/aws"
definition2 "github.com/flyteorg/flyteplugins/go/tasks/plugins/array/awsbatch/definition"
"github.com/flyteorg/flytestdlib/logger"
"github.com/flyteorg/flytestdlib/utils"
)

//go:generate mockery -all -case=underscore
Expand All @@ -35,7 +32,7 @@ type Client interface {
GetJobDetailsBatch(ctx context.Context, ids []JobID) ([]*batch.JobDetail, error)

// Registers a new Job Definition with AWS Batch provided a name, image and role.
RegisterJobDefinition(ctx context.Context, name, image, role string) (arn string, err error)
RegisterJobDefinition(ctx context.Context, name, image, role string, platformCapabilities string) (arn string, err error)

// Gets the single region this client interacts with.
GetRegion() string
Expand Down Expand Up @@ -68,12 +65,13 @@ func (b client) GetAccountID() string {
}

// Registers a new job definition. There is no deduping on AWS side (even for the same name).
func (b *client) RegisterJobDefinition(ctx context.Context, name, image, role string) (arn definition2.JobDefinitionArn, err error) {
logger.Infof(ctx, "Registering job definition with name [%v], image [%v], role [%v]", name, image, role)
func (b *client) RegisterJobDefinition(ctx context.Context, name, image, role string, platformCapabilities string) (arn definition2.JobDefinitionArn, err error) {
logger.Infof(ctx, "Registering job definition with name [%v], image [%v], role [%v], platformCapabilities [%v]", name, image, role, platformCapabilities)

res, err := b.Batch.RegisterJobDefinitionWithContext(ctx, &batch.RegisterJobDefinitionInput{
Type: refStr(batch.JobDefinitionTypeContainer),
JobDefinitionName: refStr(name),
Type: refStr(batch.JobDefinitionTypeContainer),
JobDefinitionName: refStr(name),
PlatformCapabilities: refStrSlice([]string{platformCapabilities}),
ContainerProperties: &batch.ContainerProperties{
Image: refStr(image),
JobRoleArn: refStr(role),
Expand All @@ -83,7 +81,6 @@ func (b *client) RegisterJobDefinition(ctx context.Context, name, image, role st
Memory: refInt(100),
},
})

if err != nil {
return "", err
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ func TestClient_GetJobDetailsBatch(t *testing.T) {

func TestClient_RegisterJobDefinition(t *testing.T) {
c := newClientWithMockBatch()
j, err := c.RegisterJobDefinition(context.TODO(), "name-abc", "img", "admin-role")
j, err := c.RegisterJobDefinition(context.TODO(), "name-abc", "img", "admin-role", defaultComputeEngine)
assert.NoError(t, err)
assert.NotNil(t, j)
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,13 @@ type CacheKey interface {
}

type cacheKey struct {
role string
image string
role string
image string
platformCapabilities string
}

func (k cacheKey) String() string {
return fmt.Sprintf("%v-%v", k.image, k.role)
return fmt.Sprintf("%v-%v-%v", k.image, k.role, k.platformCapabilities)
}

type cache struct {
Expand All @@ -52,10 +53,11 @@ func (c cache) Put(key CacheKey, definition JobDefinitionArn) error {
}

// Creates a new deterministic cache key.
func NewCacheKey(role, image string) CacheKey {
func NewCacheKey(role, image, platformCapabilities string) CacheKey {
return cacheKey{
role: role,
image: image,
role: role,
image: image,
platformCapabilities: platformCapabilities,
}
}

Expand Down
4 changes: 2 additions & 2 deletions flyteplugins/go/tasks/plugins/array/awsbatch/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -189,10 +189,10 @@ func init() {
pluginmachinery.PluginRegistry().RegisterCorePlugin(
core.PluginEntry{
ID: executorName,
RegisteredTaskTypes: []core.TaskType{arrayTaskType},
RegisteredTaskTypes: []core.TaskType{arrayTaskType, array.AwsBatchTaskType},
LoadPlugin: createNewExecutorPlugin,
IsDefault: false,
DefaultForTaskTypes: []core.TaskType{arrayTaskType},
DefaultForTaskTypes: []core.TaskType{arrayTaskType, array.AwsBatchTaskType},
})
}

Expand Down
15 changes: 11 additions & 4 deletions flyteplugins/go/tasks/plugins/array/awsbatch/job_definition.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ import (
"github.com/flyteorg/flyteplugins/go/tasks/plugins/array/awsbatch/definition"
)

const defaultComputeEngine = "EC2"
const platformCapabilitiesConfigKey = "platformCapabilities"

func getContainerImage(_ context.Context, task *core.TaskTemplate) string {
if task.GetContainer() != nil && len(task.GetContainer().Image) > 0 {
return task.GetContainer().Image
Expand Down Expand Up @@ -51,11 +54,15 @@ func EnsureJobDefinition(ctx context.Context, tCtx pluginCore.TaskExecutionConte
}

role := awsUtils.GetRoleFromSecurityContext(cfg.RoleAnnotationKey, tCtx.TaskExecutionMetadata())
platformCapabilities := taskTemplate.GetConfig()[platformCapabilitiesConfigKey]
if len(platformCapabilities) == 0 {
platformCapabilities = defaultComputeEngine
}

cacheKey := definition.NewCacheKey(role, containerImage)
cacheKey := definition.NewCacheKey(role, containerImage, platformCapabilities)
if existingArn, found := definitionCache.Get(cacheKey); found {
logger.Infof(ctx, "Found an existing job definition for Image [%v] and Role [%v]. Arn [%v]",
containerImage, role, existingArn)
logger.Infof(ctx, "Found an existing job definition for Image [%v], Role [%v], JobDefinitionInput [%v]. Arn [%v]",
containerImage, role, platformCapabilities, existingArn)

nextState = currentState.SetJobDefinitionArn(existingArn)
nextState.State = nextState.SetPhase(arrayCore.PhaseLaunch, 0).SetReason("AWS job definition already exist.")
Expand All @@ -64,7 +71,7 @@ func EnsureJobDefinition(ctx context.Context, tCtx pluginCore.TaskExecutionConte

name := definition.GetJobDefinitionSafeName(containerImageRepository(containerImage))

arn, err := client.RegisterJobDefinition(ctx, name, containerImage, role)
arn, err := client.RegisterJobDefinition(ctx, name, containerImage, role, platformCapabilities)
if err != nil {
return currentState, err
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ func TestEnsureJobDefinition(t *testing.T) {

t.Run("Found", func(t *testing.T) {
dCache := definition.NewCache(10)
assert.NoError(t, dCache.Put(definition.NewCacheKey("", "img1"), "their-arn"))
assert.NoError(t, dCache.Put(definition.NewCacheKey("", "img1", defaultComputeEngine), "their-arn"))

nextState, err := EnsureJobDefinition(ctx, tCtx, cfg, batchClient, dCache, &State{
State: &arrayCore.State{},
Expand All @@ -100,6 +100,11 @@ func TestEnsureJobDefinition(t *testing.T) {
assert.NotNil(t, nextState)
assert.Equal(t, "their-arn", nextState.JobDefinitionArn)
})

t.Run("Test New Cache Key", func(t *testing.T) {
cacheKey := definition.NewCacheKey("default", "img1", defaultComputeEngine)
assert.Equal(t, cacheKey.String(), "img1-default-EC2")
})
}

func TestEnsureJobDefinitionWithSecurityContext(t *testing.T) {
Expand All @@ -115,6 +120,7 @@ func TestEnsureJobDefinitionWithSecurityContext(t *testing.T) {
Target: &core.TaskTemplate_Container{
Container: createSampleContainerTask(),
},
Config: map[string]string{platformCapabilitiesConfigKey: defaultComputeEngine},
}, nil)

overrides := &mocks.TaskOverrides{}
Expand Down Expand Up @@ -158,7 +164,7 @@ func TestEnsureJobDefinitionWithSecurityContext(t *testing.T) {

t.Run("Found", func(t *testing.T) {
dCache := definition.NewCache(10)
assert.NoError(t, dCache.Put(definition.NewCacheKey("new-role", "img1"), "their-arn"))
assert.NoError(t, dCache.Put(definition.NewCacheKey("new-role", "img1", defaultComputeEngine), "their-arn"))

nextState, err := EnsureJobDefinition(ctx, tCtx, cfg, batchClient, dCache, &State{
State: &arrayCore.State{},
Expand Down
18 changes: 9 additions & 9 deletions flyteplugins/go/tasks/plugins/array/awsbatch/mocks/client.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

25 changes: 17 additions & 8 deletions flyteplugins/go/tasks/plugins/array/awsbatch/transformer.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (
"sort"
"time"

"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/utils"

"github.com/flyteorg/flyteplugins/go/tasks/plugins/array"

"github.com/aws/aws-sdk-go/service/batch"
Expand Down Expand Up @@ -85,14 +87,21 @@ func FlyteTaskToBatchInput(ctx context.Context, tCtx pluginCore.TaskExecutionCon
}
resources := flytek8s.ApplyResourceOverrides(*res, *platformResources, assignResources)

return &batch.SubmitJobInput{
JobName: refStr(tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName()),
JobDefinition: refStr(jobDefinition),
JobQueue: refStr(jobConfig.DynamicTaskQueue),
RetryStrategy: toRetryStrategy(ctx, toBackoffLimit(taskTemplate.Metadata), cfg.MinRetries, cfg.MaxRetries),
ContainerOverrides: toContainerOverrides(ctx, append(cmd, args...), &resources, envVars),
Timeout: toTimeout(taskTemplate.Metadata.GetTimeout(), cfg.DefaultTimeOut.Duration),
}, nil
submitJobInput := &batch.SubmitJobInput{}
if taskTemplate.GetCustom() != nil {
err = utils.UnmarshalStructToObj(taskTemplate.GetCustom(), &submitJobInput)
if err != nil {
return nil, errors.Errorf(errors.BadTaskSpecification,
"invalid TaskSpecification [%v], Err: [%v]", taskTemplate.GetCustom(), err.Error())
}
}
submitJobInput.SetJobName(tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName()).
SetJobDefinition(jobDefinition).SetJobQueue(jobConfig.DynamicTaskQueue).
SetRetryStrategy(toRetryStrategy(ctx, toBackoffLimit(taskTemplate.Metadata), cfg.MinRetries, cfg.MaxRetries)).
SetContainerOverrides(toContainerOverrides(ctx, append(cmd, args...), &resources, envVars)).
SetTimeout(toTimeout(taskTemplate.Metadata.GetTimeout(), cfg.DefaultTimeOut.Duration))

return submitJobInput, nil
}

func UpdateBatchInputForArray(_ context.Context, batchInput *batch.SubmitJobInput, arraySize int64) *batch.SubmitJobInput {
Expand Down
11 changes: 10 additions & 1 deletion flyteplugins/go/tasks/plugins/array/awsbatch/transformer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (

"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/mocks"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/utils"
"github.com/flyteorg/flyteplugins/go/tasks/plugins/array"
"github.com/flyteorg/flyteplugins/go/tasks/plugins/array/awsbatch/config"

v12 "k8s.io/api/core/v1"
Expand Down Expand Up @@ -191,11 +192,11 @@ func TestArrayJobToBatchInput(t *testing.T) {
Target: &core.TaskTemplate_Container{
Container: createSampleContainerTask(),
},
Type: arrayTaskType,
}

tr := &mocks.TaskReader{}
tr.OnReadMatch(mock.Anything).Return(taskTemplate, nil)

taskCtx.OnTaskReader().Return(tr)

ctx := context.Background()
Expand All @@ -205,6 +206,14 @@ func TestArrayJobToBatchInput(t *testing.T) {
batchInput = UpdateBatchInputForArray(ctx, batchInput, input.Size)
assert.NotNil(t, batchInput)
assert.Equal(t, *expectedBatchInput, *batchInput)

taskTemplate.Type = array.AwsBatchTaskType
tr.OnReadMatch(mock.Anything).Return(taskTemplate, nil)
taskCtx.OnTaskReader().Return(tr)

ctx = context.Background()
_, err = FlyteTaskToBatchInput(ctx, taskCtx, "", &config.Config{})
assert.NoError(t, err)
}

func Test_getEnvVarsForTask(t *testing.T) {
Expand Down
17 changes: 16 additions & 1 deletion flyteplugins/go/tasks/plugins/array/catalog.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import (
"math"
"strconv"

idlPlugins "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins"

arrayCore "github.com/flyteorg/flyteplugins/go/tasks/plugins/array/core"

"github.com/flyteorg/flytestdlib/bitarray"
Expand All @@ -21,6 +23,8 @@ import (
idlCore "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
)

const AwsBatchTaskType = "aws-batch"

// DetermineDiscoverability checks if there are any previously cached tasks. If there are we will only submit an
// ArrayJob for the non-cached tasks. The ArrayJob is now a different size, and each task will get a new index location
// which is different than their original location. To find the original index we construct an indexLookup array.
Expand All @@ -38,7 +42,18 @@ func DetermineDiscoverability(ctx context.Context, tCtx core.TaskExecutionContex
}

// Extract the custom plugin pb
arrayJob, err := arrayCore.ToArrayJob(taskTemplate.GetCustom(), taskTemplate.TaskTypeVersion)
var arrayJob *idlPlugins.ArrayJob
if taskTemplate.Type == AwsBatchTaskType {
arrayJob = &idlPlugins.ArrayJob{
Parallelism: 1,
Size: 1,
SuccessCriteria: &idlPlugins.ArrayJob_MinSuccesses{
MinSuccesses: 1,
},
}
} else {
arrayJob, err = arrayCore.ToArrayJob(taskTemplate.GetCustom(), taskTemplate.TaskTypeVersion)
}
if err != nil {
return state, err
}
Expand Down
14 changes: 14 additions & 0 deletions flyteplugins/go/tasks/plugins/array/catalog_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,20 @@ func TestDetermineDiscoverability(t *testing.T) {
},
}

t.Run("Run AWS Batch single job", func(t *testing.T) {
toCache := arrayCore.InvertBitSet(bitarray.NewBitSet(1), 1)
template.Type = AwsBatchTaskType
runDetermineDiscoverabilityTest(t, template, f, &arrayCore.State{
CurrentPhase: arrayCore.PhasePreLaunch,
PhaseVersion: core2.DefaultPhaseVersion,
ExecutionArraySize: 1,
OriginalArraySize: 1,
OriginalMinSuccesses: 1,
IndexesToCache: toCache,
Reason: "Task is not discoverable.",
}, nil)
})

t.Run("Not discoverable", func(t *testing.T) {
toCache := arrayCore.InvertBitSet(bitarray.NewBitSet(1), 1)

Expand Down
2 changes: 1 addition & 1 deletion flyteplugins/go/tasks/plugins/array/inputs.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func (i arrayJobInputReader) GetInputPath() storage.DataReference {
}

func GetInputReader(tCtx core.TaskExecutionContext, taskTemplate *idlCore.TaskTemplate) io.InputReader {
if taskTemplate.GetTaskTypeVersion() == 0 {
if taskTemplate.GetTaskTypeVersion() == 0 && taskTemplate.Type != AwsBatchTaskType {
// Prior to task type version == 1, dynamic type tasks (including array tasks) would write input files for each
// individual array task instance. In this case we use a modified input reader to only pass in the parent input
// directory.
Expand Down
Loading

0 comments on commit e9688b8

Please sign in to comment.