Skip to content

Commit

Permalink
Modify array task behavior based on task type version (flyteorg#152)
Browse files Browse the repository at this point in the history
  • Loading branch information
Katrina Rogan authored Mar 9, 2021
1 parent 58b0327 commit eddc8a0
Show file tree
Hide file tree
Showing 12 changed files with 279 additions and 62 deletions.
32 changes: 8 additions & 24 deletions flyteplugins/go/tasks/plugins/array/awsbatch/transformer.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,42 +5,26 @@ import (
"sort"
"time"

"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/template"

"github.com/golang/protobuf/ptypes/duration"

"k8s.io/apimachinery/pkg/api/resource"

"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io"
"github.com/flyteorg/flytestdlib/storage"

config2 "github.com/flyteorg/flyteplugins/go/tasks/plugins/array/awsbatch/config"

"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s"
"github.com/flyteorg/flyteplugins/go/tasks/plugins/array"

"github.com/aws/aws-sdk-go/service/batch"
"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
idlCore "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
"github.com/flyteorg/flyteplugins/go/tasks/errors"
pluginCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/template"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s"
config2 "github.com/flyteorg/flyteplugins/go/tasks/plugins/array/awsbatch/config"
"github.com/golang/protobuf/ptypes/duration"
v1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/resource"
)

const (
ArrayJobIndex = "BATCH_JOB_ARRAY_INDEX_VAR_NAME"
arrayJobIDFormatter = "%v:%v"
)

// arrayJobInputReader is a thin proxy around an io.InputReader. It behaves like
// the wrapped reader except that the input path it reports is the input prefix
// path, which is what array jobs are given.
type arrayJobInputReader struct {
	io.InputReader
}

// GetInputPath shadows the embedded reader's GetInputPath and returns the input
// prefix path instead of the single input path.
func (r arrayJobInputReader) GetInputPath() storage.DataReference {
	prefix := r.InputReader.GetInputPrefixPath()
	return prefix
}

// Note that Name is not set on the result object.
// It's up to the caller to set the Name before creating the object in K8s.
func FlyteTaskToBatchInput(ctx context.Context, tCtx pluginCore.TaskExecutionContext, jobDefinition string, cfg *config2.Config) (
Expand Down Expand Up @@ -70,9 +54,9 @@ func FlyteTaskToBatchInput(ctx context.Context, tCtx pluginCore.TaskExecutionCon
if err != nil {
return nil, err
}

inputReader := array.GetInputReader(tCtx, taskTemplate)
args, err := template.ReplaceTemplateCommandArgs(ctx, tCtx.TaskExecutionMetadata(), taskTemplate.GetContainer().GetArgs(),
arrayJobInputReader{tCtx.InputReader()}, tCtx.OutputWriter())
inputReader, tCtx.OutputWriter())
taskTemplate.GetContainer().GetEnv()
if err != nil {
return nil, err
Expand Down
61 changes: 47 additions & 14 deletions flyteplugins/go/tasks/plugins/array/catalog.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package array
import (
"context"
"fmt"
"math"
"strconv"

arrayCore "github.com/flyteorg/flyteplugins/go/tasks/plugins/array/core"
Expand Down Expand Up @@ -37,34 +38,66 @@ func DetermineDiscoverability(ctx context.Context, tCtx core.TaskExecutionContex
}

// Extract the custom plugin pb
arrayJob, err := arrayCore.ToArrayJob(taskTemplate.GetCustom())
arrayJob, err := arrayCore.ToArrayJob(taskTemplate.GetCustom(), taskTemplate.TaskTypeVersion)
if err != nil {
return state, err
}

var arrayJobSize int64

// Save this in the state
state = state.SetOriginalArraySize(arrayJob.Size)
state = state.SetOriginalMinSuccesses(arrayJob.GetMinSuccesses())
if taskTemplate.TaskTypeVersion == 0 {
state = state.SetOriginalArraySize(arrayJob.Size)
arrayJobSize = arrayJob.Size
state = state.SetOriginalMinSuccesses(arrayJob.GetMinSuccesses())
} else {
inputs, err := tCtx.InputReader().Get(ctx)
if err != nil {
return state, errors.Errorf(errors.MetadataAccessFailed, "Could not read inputs and therefore failed to determine array job size")
}
size := 0
for _, literal := range inputs.Literals {
if literal.GetCollection() != nil {
size = len(literal.GetCollection().Literals)
break
}
}
if size == 0 {
// Something is wrong, we should have inferred the array size when it is not specified by the size of the
// input collection (for any input value). Non-collection type inputs are not currently supported for
// taskTypeVersion > 0.
return state, errors.Errorf(errors.BadTaskSpecification, "Unable to determine array size from inputs")
}
minSuccesses := math.Ceil(float64(arrayJob.GetMinSuccessRatio()) * float64(size))

logger.Debugf(ctx, "Computed state: size [%d] and minSuccesses [%d]", int64(size), int64(minSuccesses))
state = state.SetOriginalArraySize(int64(size))
// We can cast the min successes because we already computed the ceiling value from the ratio
state = state.SetOriginalMinSuccesses(int64(minSuccesses))

arrayJobSize = int64(size)
}

// If the task is not discoverable, then skip data catalog work and move directly to launch
if taskTemplate.Metadata == nil || !taskTemplate.Metadata.Discoverable {
logger.Infof(ctx, "Task is not discoverable, moving to launch phase...")
// Set an all set indexes to cache. This task won't try to write to catalog anyway.
state = state.SetIndexesToCache(arrayCore.InvertBitSet(bitarray.NewBitSet(uint(arrayJob.Size)), uint(arrayJob.Size)))
state = state.SetExecutionArraySize(int(arrayJob.Size))
state = state.SetIndexesToCache(arrayCore.InvertBitSet(bitarray.NewBitSet(uint(arrayJobSize)), uint(arrayJobSize)))
state = state.SetPhase(arrayCore.PhasePreLaunch, core.DefaultPhaseVersion).SetReason("Task is not discoverable.")

state.SetExecutionArraySize(int(arrayJobSize))
return state, nil
}

// Otherwise, run the data catalog steps - create and submit work items to the catalog processor,
// build input readers
inputReaders, err := ConstructInputReaders(ctx, tCtx.DataStore(), tCtx.InputReader().GetInputPrefixPath(), int(arrayJob.Size))
inputReaders, err := ConstructInputReaders(ctx, tCtx.DataStore(), tCtx.InputReader().GetInputPrefixPath(), int(arrayJobSize))
if err != nil {
return state, err
}

// build output writers
outputWriters, err := ConstructOutputWriters(ctx, tCtx.DataStore(), tCtx.OutputWriter().GetOutputPrefixPath(), tCtx.OutputWriter().GetRawOutputPrefix(), int(arrayJob.Size))
outputWriters, err := ConstructOutputWriters(ctx, tCtx.DataStore(), tCtx.OutputWriter().GetOutputPrefixPath(), tCtx.OutputWriter().GetRawOutputPrefix(), int(arrayJobSize))
if err != nil {
return state, err
}
Expand All @@ -87,8 +120,8 @@ func DetermineDiscoverability(ctx context.Context, tCtx core.TaskExecutionContex
// TODO: maybe add a config option to decide the behavior on catalog failure.
logger.Warnf(ctx, "Failing to lookup catalog. Will move on to launching the task. Error: %v", err)

state = state.SetIndexesToCache(arrayCore.InvertBitSet(bitarray.NewBitSet(uint(arrayJob.Size)), uint(arrayJob.Size)))
state = state.SetExecutionArraySize(int(arrayJob.Size))
state = state.SetIndexesToCache(arrayCore.InvertBitSet(bitarray.NewBitSet(uint(arrayJobSize)), uint(arrayJobSize)))
state = state.SetExecutionArraySize(int(arrayJobSize))
state = state.SetPhase(arrayCore.PhasePreLaunch, core.DefaultPhaseVersion).SetReason(fmt.Sprintf("Skipping cache check due to err [%v]", err))
return state, nil
}
Expand All @@ -100,11 +133,11 @@ func DetermineDiscoverability(ctx context.Context, tCtx core.TaskExecutionContex
}

cachedResults := resp.GetCachedResults()
state = state.SetIndexesToCache(arrayCore.InvertBitSet(cachedResults, uint(arrayJob.Size)))
state = state.SetExecutionArraySize(int(arrayJob.Size) - resp.GetCachedCount())
state = state.SetIndexesToCache(arrayCore.InvertBitSet(cachedResults, uint(arrayJobSize)))
state = state.SetExecutionArraySize(int(arrayJobSize) - resp.GetCachedCount())

// If all the sub-tasks are actually done, then we can just move on.
if resp.GetCachedCount() == int(arrayJob.Size) {
if resp.GetCachedCount() == int(arrayJobSize) {
state.SetPhase(arrayCore.PhaseAssembleFinalOutput, core.DefaultPhaseVersion).SetReason("All subtasks are cached. assembling final outputs.")
return state, nil
}
Expand All @@ -117,7 +150,7 @@ func DetermineDiscoverability(ctx context.Context, tCtx core.TaskExecutionContex
}

logger.Infof(ctx, "Writing indexlookup file to [%s], cached count [%d/%d], ",
indexLookupPath, resp.GetCachedCount(), arrayJob.Size)
indexLookupPath, resp.GetCachedCount(), arrayJobSize)
err = tCtx.DataStore().WriteProtobuf(ctx, indexLookupPath, storage.Options{}, indexLookup)
if err != nil {
return state, err
Expand Down Expand Up @@ -150,7 +183,7 @@ func WriteToDiscovery(ctx context.Context, tCtx core.TaskExecutionContext, state
}

// Extract the custom plugin pb
arrayJob, err := arrayCore.ToArrayJob(taskTemplate.GetCustom())
arrayJob, err := arrayCore.ToArrayJob(taskTemplate.GetCustom(), taskTemplate.TaskTypeVersion)
if err != nil {
return state, err
} else if arrayJob == nil {
Expand Down
92 changes: 92 additions & 0 deletions flyteplugins/go/tasks/plugins/array/catalog_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@ import (
"errors"
"testing"

"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/utils"
structpb "github.com/golang/protobuf/ptypes/struct"

stdErrors "github.com/flyteorg/flytestdlib/errors"

pluginErrors "github.com/flyteorg/flyteplugins/go/tasks/errors"
Expand Down Expand Up @@ -66,6 +70,32 @@ func runDetermineDiscoverabilityTest(t testing.TB, taskTemplate *core.TaskTempla

ir := &ioMocks.InputReader{}
ir.OnGetInputPrefixPath().Return("/prefix/")
dummyInputLiteral := &core.Literal{
Value: &core.Literal_Scalar{
Scalar: &core.Scalar{
Value: &core.Scalar_Primitive{
Primitive: &core.Primitive{
Value: &core.Primitive_Integer{
Integer: 3,
},
},
},
},
},
}
ir.On("Get", mock.Anything).Return(&core.LiteralMap{
Literals: map[string]*core.Literal{
"foo": {
Value: &core.Literal_Collection{
Collection: &core.LiteralCollection{
Literals: []*core.Literal{
dummyInputLiteral, dummyInputLiteral, dummyInputLiteral,
},
},
},
},
},
}, nil)

ow := &ioMocks.OutputWriter{}
ow.OnGetOutputPrefixPath().Return("/prefix/")
Expand Down Expand Up @@ -198,3 +228,65 @@ func TestDetermineDiscoverability(t *testing.T) {
}, nil)
})
}

// TestDiscoverabilityTaskType1 drives runDetermineDiscoverabilityTest with a
// task template whose TaskTypeVersion is 1 and whose custom ArrayJob only
// carries a MinSuccessRatio (no explicit size).
func TestDiscoverabilityTaskType1(t *testing.T) {
	// Catalog download response/future mocks handed to the shared test helper.
	downloadResp := &catalogMocks.DownloadResponse{}
	downloadResp.OnGetCachedCount().Return(0)
	downloadResp.OnGetResultsSize().Return(1)

	future := &catalogMocks.DownloadFuture{}
	future.OnGetResponseStatus().Return(catalog.ResponseStatusReady)
	future.OnGetResponseError().Return(nil)
	future.OnGetResponse().Return(downloadResp, nil)

	t.Run("Not discoverable", func(t *testing.T) {
		downloadResp.OnGetCachedResults().Return(bitarray.NewBitSet(1)).Once()

		// Serialize an ArrayJob that specifies only a success ratio into the
		// template's custom struct.
		arrayJobCustom := &structpb.Struct{}
		marshalErr := utils.MarshalStruct(&plugins.ArrayJob{
			SuccessCriteria: &plugins.ArrayJob_MinSuccessRatio{
				MinSuccessRatio: 0.5,
			},
		}, arrayJobCustom)
		assert.NoError(t, marshalErr)

		taskID := &core.Identifier{
			ResourceType: core.ResourceType_TASK,
			Project:      "p",
			Domain:       "d",
			Name:         "n",
			Version:      "1",
		}
		taskInterface := &core.TypedInterface{
			Inputs: &core.VariableMap{Variables: map[string]*core.Variable{
				"foo": {
					Description: "foo",
				},
			}},
			Outputs: &core.VariableMap{Variables: map[string]*core.Variable{}},
		}
		container := &core.Container{
			Command: []string{"cmd"},
			Args:    []string{"{{$inputPrefix}}"},
			Image:   "img1",
		}
		templateType1 := &core.TaskTemplate{
			Id:              taskID,
			Interface:       taskInterface,
			Target:          &core.TaskTemplate_Container{Container: container},
			TaskTypeVersion: 1,
			Custom:          arrayJobCustom,
		}

		// Expected state: every one of the three indexes is flagged in
		// IndexesToCache, sizes are 3 and min successes is 2.
		const arraySize uint = 3
		expectedState := &arrayCore.State{
			CurrentPhase:         arrayCore.PhasePreLaunch,
			PhaseVersion:         core2.DefaultPhaseVersion,
			ExecutionArraySize:   3,
			OriginalArraySize:    3,
			OriginalMinSuccesses: 2,
			IndexesToCache:       arrayCore.InvertBitSet(bitarray.NewBitSet(arraySize), arraySize),
			Reason:               "Task is not discoverable.",
		}

		runDetermineDiscoverabilityTest(t, templateType1, future, expectedState, nil)
	})
}
16 changes: 13 additions & 3 deletions flyteplugins/go/tasks/plugins/array/core/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,13 +131,23 @@ const (
ErrorK8sArrayGeneric errors.ErrorCode = "ARRAY_JOB_GENERIC_FAILURE"
)

func ToArrayJob(structObj *structpb.Struct) (*idlPlugins.ArrayJob, error) {
func ToArrayJob(structObj *structpb.Struct, taskTypeVersion int32) (*idlPlugins.ArrayJob, error) {
if structObj == nil {
if taskTypeVersion == 0 {

return &idlPlugins.ArrayJob{
Parallelism: 1,
Size: 1,
SuccessCriteria: &idlPlugins.ArrayJob_MinSuccesses{
MinSuccesses: 1,
},
}, nil
}
return &idlPlugins.ArrayJob{
Parallelism: 1,
Size: 1,
SuccessCriteria: &idlPlugins.ArrayJob_MinSuccesses{
MinSuccesses: 1,
SuccessCriteria: &idlPlugins.ArrayJob_MinSuccessRatio{
MinSuccessRatio: 1.0,
},
}, nil
}
Expand Down
29 changes: 29 additions & 0 deletions flyteplugins/go/tasks/plugins/array/core/state_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ import (
"context"
"testing"

"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins"
"github.com/golang/protobuf/proto"

"github.com/flyteorg/flytestdlib/bitarray"

"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core"
Expand Down Expand Up @@ -213,3 +216,29 @@ func Test_calculateOriginalIndex(t *testing.T) {
}
})
}

// TestToArrayJob checks the default ArrayJob that ToArrayJob returns when it is
// given a nil custom struct, for task type versions 0 and 1.
func TestToArrayJob(t *testing.T) {
	cases := []struct {
		name            string
		taskTypeVersion int32
		want            *plugins.ArrayJob
	}{
		{
			name:            "task_type_version == 0",
			taskTypeVersion: 0,
			want: &plugins.ArrayJob{
				Parallelism: 1,
				Size:        1,
				SuccessCriteria: &plugins.ArrayJob_MinSuccesses{
					MinSuccesses: 1,
				},
			},
		},
		{
			name:            "task_type_version == 1",
			taskTypeVersion: 1,
			want: &plugins.ArrayJob{
				Parallelism: 1,
				Size:        1,
				SuccessCriteria: &plugins.ArrayJob_MinSuccessRatio{
					MinSuccessRatio: 1.0,
				},
			},
		},
	}

	for _, tc := range cases {
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			got, err := ToArrayJob(nil, tc.taskTypeVersion)
			assert.NoError(t, err)
			assert.True(t, proto.Equal(got, tc.want))
		})
	}
}
31 changes: 31 additions & 0 deletions flyteplugins/go/tasks/plugins/array/inputs.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package array

import (
idlCore "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io"
"github.com/flyteorg/flytestdlib/storage"
)

// arrayJobInputReader is a thin proxy around an io.InputReader. It behaves like
// the wrapped reader except that the input path it reports is the input prefix
// path, which is what array jobs are given.
type arrayJobInputReader struct {
	io.InputReader
}

// GetInputPath shadows the embedded reader's GetInputPath and returns the input
// prefix path instead of the single input path.
func (r arrayJobInputReader) GetInputPath() storage.DataReference {
	prefix := r.InputReader.GetInputPrefixPath()
	return prefix
}

// GetInputReader picks the input reader to use for the array task described by
// taskTemplate. For task type version 0 it returns tCtx.InputReader() wrapped in
// an arrayJobInputReader; for any other version it returns tCtx.InputReader()
// unchanged.
func GetInputReader(tCtx core.TaskExecutionContext, taskTemplate *idlCore.TaskTemplate) io.InputReader {
	switch taskTemplate.GetTaskTypeVersion() {
	case 0:
		// Prior to task type version == 1, dynamic type tasks (including array
		// tasks) would write input files for each individual array task
		// instance. In this case we use a modified input reader to only pass in
		// the parent input directory.
		return arrayJobInputReader{InputReader: tCtx.InputReader()}
	default:
		return tCtx.InputReader()
	}
}
Loading

0 comments on commit eddc8a0

Please sign in to comment.