diff --git a/flyteplugins/go/tasks/plugins/array/awsbatch/transformer.go b/flyteplugins/go/tasks/plugins/array/awsbatch/transformer.go index 3177935417..597f6d9888 100644 --- a/flyteplugins/go/tasks/plugins/array/awsbatch/transformer.go +++ b/flyteplugins/go/tasks/plugins/array/awsbatch/transformer.go @@ -5,25 +5,19 @@ import ( "sort" "time" - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/template" - - "github.com/golang/protobuf/ptypes/duration" - - "k8s.io/apimachinery/pkg/api/resource" - - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io" - "github.com/flyteorg/flytestdlib/storage" - - config2 "github.com/flyteorg/flyteplugins/go/tasks/plugins/array/awsbatch/config" - - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s" + "github.com/flyteorg/flyteplugins/go/tasks/plugins/array" "github.com/aws/aws-sdk-go/service/batch" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" idlCore "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flyteplugins/go/tasks/errors" pluginCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/template" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s" + config2 "github.com/flyteorg/flyteplugins/go/tasks/plugins/array/awsbatch/config" + "github.com/golang/protobuf/ptypes/duration" v1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/resource" ) const ( @@ -31,16 +25,6 @@ const ( arrayJobIDFormatter = "%v:%v" ) -// A proxy inputreader that overrides the inputpath to be the inputpathprefix for array jobs -type arrayJobInputReader struct { - io.InputReader -} - -// We override the inputpath to return the prefix path for array jobs -func (i arrayJobInputReader) GetInputPath() storage.DataReference { - return i.GetInputPrefixPath() -} - // Note that Name is not set on the result object. // It's up to the caller to set the Name before creating the object in K8s. 
func FlyteTaskToBatchInput(ctx context.Context, tCtx pluginCore.TaskExecutionContext, jobDefinition string, cfg *config2.Config) ( @@ -70,9 +54,9 @@ func FlyteTaskToBatchInput(ctx context.Context, tCtx pluginCore.TaskExecutionCon if err != nil { return nil, err } - + inputReader := array.GetInputReader(tCtx, taskTemplate) args, err := template.ReplaceTemplateCommandArgs(ctx, tCtx.TaskExecutionMetadata(), taskTemplate.GetContainer().GetArgs(), - arrayJobInputReader{tCtx.InputReader()}, tCtx.OutputWriter()) + inputReader, tCtx.OutputWriter()) taskTemplate.GetContainer().GetEnv() if err != nil { return nil, err diff --git a/flyteplugins/go/tasks/plugins/array/catalog.go b/flyteplugins/go/tasks/plugins/array/catalog.go index 932086ded5..6dc10f4d07 100644 --- a/flyteplugins/go/tasks/plugins/array/catalog.go +++ b/flyteplugins/go/tasks/plugins/array/catalog.go @@ -3,6 +3,7 @@ package array import ( "context" "fmt" + "math" "strconv" arrayCore "github.com/flyteorg/flyteplugins/go/tasks/plugins/array/core" @@ -37,34 +38,66 @@ func DetermineDiscoverability(ctx context.Context, tCtx core.TaskExecutionContex } // Extract the custom plugin pb - arrayJob, err := arrayCore.ToArrayJob(taskTemplate.GetCustom()) + arrayJob, err := arrayCore.ToArrayJob(taskTemplate.GetCustom(), taskTemplate.TaskTypeVersion) if err != nil { return state, err } + var arrayJobSize int64 + // Save this in the state - state = state.SetOriginalArraySize(arrayJob.Size) - state = state.SetOriginalMinSuccesses(arrayJob.GetMinSuccesses()) + if taskTemplate.TaskTypeVersion == 0 { + state = state.SetOriginalArraySize(arrayJob.Size) + arrayJobSize = arrayJob.Size + state = state.SetOriginalMinSuccesses(arrayJob.GetMinSuccesses()) + } else { + inputs, err := tCtx.InputReader().Get(ctx) + if err != nil { + return state, errors.Wrapf(errors.MetadataAccessFailed, err, "Could not read inputs and therefore failed to determine array job size") + } + size := 0 + for _, literal := range inputs.Literals { + if
literal.GetCollection() != nil { + size = len(literal.GetCollection().Literals) + break + } + } + if size == 0 { + // Something is wrong, we should have inferred the array size when it is not specified by the size of the + // input collection (for any input value). Non-collection type inputs are not currently supported for + // taskTypeVersion > 0. + return state, errors.Errorf(errors.BadTaskSpecification, "Unable to determine array size from inputs") + } + minSuccesses := math.Ceil(float64(arrayJob.GetMinSuccessRatio()) * float64(size)) + + logger.Debugf(ctx, "Computed state: size [%d] and minSuccesses [%d]", int64(size), int64(minSuccesses)) + state = state.SetOriginalArraySize(int64(size)) + // We can cast the min successes because we already computed the ceiling value from the ratio + state = state.SetOriginalMinSuccesses(int64(minSuccesses)) + + arrayJobSize = int64(size) + } // If the task is not discoverable, then skip data catalog work and move directly to launch if taskTemplate.Metadata == nil || !taskTemplate.Metadata.Discoverable { logger.Infof(ctx, "Task is not discoverable, moving to launch phase...") // Set an all set indexes to cache. This task won't try to write to catalog anyway. 
- state = state.SetIndexesToCache(arrayCore.InvertBitSet(bitarray.NewBitSet(uint(arrayJob.Size)), uint(arrayJob.Size))) - state = state.SetExecutionArraySize(int(arrayJob.Size)) + state = state.SetIndexesToCache(arrayCore.InvertBitSet(bitarray.NewBitSet(uint(arrayJobSize)), uint(arrayJobSize))) state = state.SetPhase(arrayCore.PhasePreLaunch, core.DefaultPhaseVersion).SetReason("Task is not discoverable.") + + state = state.SetExecutionArraySize(int(arrayJobSize)) return state, nil } // Otherwise, run the data catalog steps - create and submit work items to the catalog processor, // build input readers - inputReaders, err := ConstructInputReaders(ctx, tCtx.DataStore(), tCtx.InputReader().GetInputPrefixPath(), int(arrayJob.Size)) + inputReaders, err := ConstructInputReaders(ctx, tCtx.DataStore(), tCtx.InputReader().GetInputPrefixPath(), int(arrayJobSize)) if err != nil { return state, err } // build output writers - outputWriters, err := ConstructOutputWriters(ctx, tCtx.DataStore(), tCtx.OutputWriter().GetOutputPrefixPath(), tCtx.OutputWriter().GetRawOutputPrefix(), int(arrayJob.Size)) + outputWriters, err := ConstructOutputWriters(ctx, tCtx.DataStore(), tCtx.OutputWriter().GetOutputPrefixPath(), tCtx.OutputWriter().GetRawOutputPrefix(), int(arrayJobSize)) if err != nil { return state, err } @@ -87,8 +120,8 @@ func DetermineDiscoverability(ctx context.Context, tCtx core.TaskExecutionContex // TODO: maybe add a config option to decide the behavior on catalog failure. logger.Warnf(ctx, "Failing to lookup catalog. Will move on to launching the task.
Error: %v", err) - state = state.SetIndexesToCache(arrayCore.InvertBitSet(bitarray.NewBitSet(uint(arrayJob.Size)), uint(arrayJob.Size))) - state = state.SetExecutionArraySize(int(arrayJob.Size)) + state = state.SetIndexesToCache(arrayCore.InvertBitSet(bitarray.NewBitSet(uint(arrayJobSize)), uint(arrayJobSize))) + state = state.SetExecutionArraySize(int(arrayJobSize)) state = state.SetPhase(arrayCore.PhasePreLaunch, core.DefaultPhaseVersion).SetReason(fmt.Sprintf("Skipping cache check due to err [%v]", err)) return state, nil } @@ -100,11 +133,11 @@ func DetermineDiscoverability(ctx context.Context, tCtx core.TaskExecutionContex } cachedResults := resp.GetCachedResults() - state = state.SetIndexesToCache(arrayCore.InvertBitSet(cachedResults, uint(arrayJob.Size))) - state = state.SetExecutionArraySize(int(arrayJob.Size) - resp.GetCachedCount()) + state = state.SetIndexesToCache(arrayCore.InvertBitSet(cachedResults, uint(arrayJobSize))) + state = state.SetExecutionArraySize(int(arrayJobSize) - resp.GetCachedCount()) // If all the sub-tasks are actually done, then we can just move on. - if resp.GetCachedCount() == int(arrayJob.Size) { + if resp.GetCachedCount() == int(arrayJobSize) { state.SetPhase(arrayCore.PhaseAssembleFinalOutput, core.DefaultPhaseVersion).SetReason("All subtasks are cached. 
assembling final outputs.") return state, nil } @@ -117,7 +150,7 @@ func DetermineDiscoverability(ctx context.Context, tCtx core.TaskExecutionContex } logger.Infof(ctx, "Writing indexlookup file to [%s], cached count [%d/%d], ", - indexLookupPath, resp.GetCachedCount(), arrayJob.Size) + indexLookupPath, resp.GetCachedCount(), arrayJobSize) err = tCtx.DataStore().WriteProtobuf(ctx, indexLookupPath, storage.Options{}, indexLookup) if err != nil { return state, err @@ -150,7 +183,7 @@ func WriteToDiscovery(ctx context.Context, tCtx core.TaskExecutionContext, state } // Extract the custom plugin pb - arrayJob, err := arrayCore.ToArrayJob(taskTemplate.GetCustom()) + arrayJob, err := arrayCore.ToArrayJob(taskTemplate.GetCustom(), taskTemplate.TaskTypeVersion) if err != nil { return state, err } else if arrayJob == nil { diff --git a/flyteplugins/go/tasks/plugins/array/catalog_test.go b/flyteplugins/go/tasks/plugins/array/catalog_test.go index 04d4cbdadb..aad365dd04 100644 --- a/flyteplugins/go/tasks/plugins/array/catalog_test.go +++ b/flyteplugins/go/tasks/plugins/array/catalog_test.go @@ -5,6 +5,10 @@ import ( "errors" "testing" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/utils" + structpb "github.com/golang/protobuf/ptypes/struct" + stdErrors "github.com/flyteorg/flytestdlib/errors" pluginErrors "github.com/flyteorg/flyteplugins/go/tasks/errors" @@ -66,6 +70,32 @@ func runDetermineDiscoverabilityTest(t testing.TB, taskTemplate *core.TaskTempla ir := &ioMocks.InputReader{} ir.OnGetInputPrefixPath().Return("/prefix/") + dummyInputLiteral := &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Primitive{ + Primitive: &core.Primitive{ + Value: &core.Primitive_Integer{ + Integer: 3, + }, + }, + }, + }, + }, + } + ir.On("Get", mock.Anything).Return(&core.LiteralMap{ + Literals: map[string]*core.Literal{ + "foo": { + Value: &core.Literal_Collection{ + 
Collection: &core.LiteralCollection{ + Literals: []*core.Literal{ + dummyInputLiteral, dummyInputLiteral, dummyInputLiteral, + }, + }, + }, + }, + }, + }, nil) ow := &ioMocks.OutputWriter{} ow.OnGetOutputPrefixPath().Return("/prefix/") @@ -198,3 +228,65 @@ func TestDetermineDiscoverability(t *testing.T) { }, nil) }) } + +func TestDiscoverabilityTaskType1(t *testing.T) { + + download := &catalogMocks.DownloadResponse{} + download.OnGetCachedCount().Return(0) + download.OnGetResultsSize().Return(1) + + f := &catalogMocks.DownloadFuture{} + f.OnGetResponseStatus().Return(catalog.ResponseStatusReady) + f.OnGetResponseError().Return(nil) + f.OnGetResponse().Return(download, nil) + + t.Run("Not discoverable", func(t *testing.T) { + download.OnGetCachedResults().Return(bitarray.NewBitSet(1)).Once() + toCache := arrayCore.InvertBitSet(bitarray.NewBitSet(uint(3)), uint(3)) + + arrayJob := &plugins.ArrayJob{ + SuccessCriteria: &plugins.ArrayJob_MinSuccessRatio{ + MinSuccessRatio: 0.5, + }, + } + var arrayJobCustom structpb.Struct + err := utils.MarshalStruct(arrayJob, &arrayJobCustom) + assert.NoError(t, err) + templateType1 := &core.TaskTemplate{ + Id: &core.Identifier{ + ResourceType: core.ResourceType_TASK, + Project: "p", + Domain: "d", + Name: "n", + Version: "1", + }, + Interface: &core.TypedInterface{ + Inputs: &core.VariableMap{Variables: map[string]*core.Variable{ + "foo": { + Description: "foo", + }, + }}, + Outputs: &core.VariableMap{Variables: map[string]*core.Variable{}}, + }, + Target: &core.TaskTemplate_Container{ + Container: &core.Container{ + Command: []string{"cmd"}, + Args: []string{"{{$inputPrefix}}"}, + Image: "img1", + }, + }, + TaskTypeVersion: 1, + Custom: &arrayJobCustom, + } + + runDetermineDiscoverabilityTest(t, templateType1, f, &arrayCore.State{ + CurrentPhase: arrayCore.PhasePreLaunch, + PhaseVersion: core2.DefaultPhaseVersion, + ExecutionArraySize: 3, + OriginalArraySize: 3, + OriginalMinSuccesses: 2, + IndexesToCache: toCache, + Reason: "Task 
is not discoverable.", + }, nil) + }) +} diff --git a/flyteplugins/go/tasks/plugins/array/core/state.go b/flyteplugins/go/tasks/plugins/array/core/state.go index d3a5e2d4dc..44fd901a06 100644 --- a/flyteplugins/go/tasks/plugins/array/core/state.go +++ b/flyteplugins/go/tasks/plugins/array/core/state.go @@ -131,13 +131,23 @@ const ( ErrorK8sArrayGeneric errors.ErrorCode = "ARRAY_JOB_GENERIC_FAILURE" ) -func ToArrayJob(structObj *structpb.Struct) (*idlPlugins.ArrayJob, error) { +func ToArrayJob(structObj *structpb.Struct, taskTypeVersion int32) (*idlPlugins.ArrayJob, error) { if structObj == nil { + if taskTypeVersion == 0 { + + return &idlPlugins.ArrayJob{ + Parallelism: 1, + Size: 1, + SuccessCriteria: &idlPlugins.ArrayJob_MinSuccesses{ + MinSuccesses: 1, + }, + }, nil + } return &idlPlugins.ArrayJob{ Parallelism: 1, Size: 1, - SuccessCriteria: &idlPlugins.ArrayJob_MinSuccesses{ - MinSuccesses: 1, + SuccessCriteria: &idlPlugins.ArrayJob_MinSuccessRatio{ + MinSuccessRatio: 1.0, }, }, nil } diff --git a/flyteplugins/go/tasks/plugins/array/core/state_test.go b/flyteplugins/go/tasks/plugins/array/core/state_test.go index 15e0c214c2..d7d2210501 100644 --- a/flyteplugins/go/tasks/plugins/array/core/state_test.go +++ b/flyteplugins/go/tasks/plugins/array/core/state_test.go @@ -4,6 +4,9 @@ import ( "context" "testing" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins" + "github.com/golang/protobuf/proto" + "github.com/flyteorg/flytestdlib/bitarray" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" @@ -213,3 +216,29 @@ func Test_calculateOriginalIndex(t *testing.T) { } }) } + +func TestToArrayJob(t *testing.T) { + t.Run("task_type_version == 0", func(t *testing.T) { + arrayJob, err := ToArrayJob(nil, 0) + assert.NoError(t, err) + assert.True(t, proto.Equal(arrayJob, &plugins.ArrayJob{ + Parallelism: 1, + Size: 1, + SuccessCriteria: &plugins.ArrayJob_MinSuccesses{ + MinSuccesses: 1, + }, + })) + }) + + t.Run("task_type_version == 1", func(t 
*testing.T) { + arrayJob, err := ToArrayJob(nil, 1) + assert.NoError(t, err) + assert.True(t, proto.Equal(arrayJob, &plugins.ArrayJob{ + Parallelism: 1, + Size: 1, + SuccessCriteria: &plugins.ArrayJob_MinSuccessRatio{ + MinSuccessRatio: 1.0, + }, + })) + }) +} diff --git a/flyteplugins/go/tasks/plugins/array/inputs.go b/flyteplugins/go/tasks/plugins/array/inputs.go new file mode 100644 index 0000000000..539b200c71 --- /dev/null +++ b/flyteplugins/go/tasks/plugins/array/inputs.go @@ -0,0 +1,31 @@ +package array + +import ( + idlCore "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io" + "github.com/flyteorg/flytestdlib/storage" +) + +// A proxy inputreader that overrides the inputpath to be the inputpathprefix for array jobs +type arrayJobInputReader struct { + io.InputReader +} + +// We override the inputpath to return the prefix path for array jobs +func (i arrayJobInputReader) GetInputPath() storage.DataReference { + return i.GetInputPrefixPath() +} + +func GetInputReader(tCtx core.TaskExecutionContext, taskTemplate *idlCore.TaskTemplate) io.InputReader { + var inputReader io.InputReader + if taskTemplate.GetTaskTypeVersion() == 0 { + // Prior to task type version == 1, dynamic type tasks (including array tasks) would write input files for each + // individual array task instance. In this case we use a modified input reader to only pass in the parent input + // directory. 
+ inputReader = arrayJobInputReader{tCtx.InputReader()} + } else { + inputReader = tCtx.InputReader() + } + return inputReader +} diff --git a/flyteplugins/go/tasks/plugins/array/inputs_test.go b/flyteplugins/go/tasks/plugins/array/inputs_test.go new file mode 100644 index 0000000000..42b0f8703e --- /dev/null +++ b/flyteplugins/go/tasks/plugins/array/inputs_test.go @@ -0,0 +1,41 @@ +package array + +import ( + "testing" + + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + pluginsCoreMock "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/mocks" + pluginsIOMock "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io/mocks" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + + "github.com/flyteorg/flytestdlib/storage" +) + +func TestGetInputReader(t *testing.T) { + + inputReader := &pluginsIOMock.InputReader{} + inputReader.On("GetInputPrefixPath").Return(storage.DataReference("test-data-prefix")) + inputReader.On("GetInputPath").Return(storage.DataReference("test-data-reference")) + inputReader.On("Get", mock.Anything).Return(&core.LiteralMap{}, nil) + + t.Run("task_type_version == 0", func(t *testing.T) { + taskCtx := &pluginsCoreMock.TaskExecutionContext{} + taskCtx.On("InputReader").Return(inputReader) + + inputReader := GetInputReader(taskCtx, &core.TaskTemplate{ + TaskTypeVersion: 0, + }) + assert.Equal(t, inputReader.GetInputPath().String(), "test-data-prefix") + }) + + t.Run("task_type_version == 1", func(t *testing.T) { + taskCtx := &pluginsCoreMock.TaskExecutionContext{} + taskCtx.On("InputReader").Return(inputReader) + + inputReader := GetInputReader(taskCtx, &core.TaskTemplate{ + TaskTypeVersion: 1, + }) + assert.Equal(t, inputReader.GetInputPath().String(), "test-data-reference") + }) +} diff --git a/flyteplugins/go/tasks/plugins/array/k8s/executor.go b/flyteplugins/go/tasks/plugins/array/k8s/executor.go index 74a9456c36..2790ee5f12 100644 --- a/flyteplugins/go/tasks/plugins/array/k8s/executor.go +++ 
b/flyteplugins/go/tasks/plugins/array/k8s/executor.go @@ -106,6 +106,7 @@ func (e Executor) Handle(ctx context.Context, tCtx core.TaskExecutionContext) (c err = nil case arrayCore.PhaseCheckingSubTaskExecutions: + nextState, logLinks, err = LaunchAndCheckSubTasksState(ctx, tCtx, e.kubeClient, pluginConfig, tCtx.DataStore(), tCtx.OutputWriter().GetOutputPrefixPath(), tCtx.OutputWriter().GetRawOutputPrefix(), pluginState) diff --git a/flyteplugins/go/tasks/plugins/array/k8s/launcher.go b/flyteplugins/go/tasks/plugins/array/k8s/launcher.go index 7aac0e91d5..d9db715250 100644 --- a/flyteplugins/go/tasks/plugins/array/k8s/launcher.go +++ b/flyteplugins/go/tasks/plugins/array/k8s/launcher.go @@ -19,6 +19,7 @@ const ( ErrBuildPodTemplate errors2.ErrorCode = "POD_TEMPLATE_FAILED" ErrReplaceCmdTemplate errors2.ErrorCode = "CMD_TEMPLATE_FAILED" ErrSubmitJob errors2.ErrorCode = "SUBMIT_JOB_FAILED" + ErrGetTaskTypeVersion errors2.ErrorCode = "GET_TASK_TYPE_VERSION_FAILED" JobIndexVarName string = "BATCH_JOB_ARRAY_INDEX_VAR_NAME" FlyteK8sArrayIndexVarName string = "FLYTE_K8S_ARRAY_INDEX" ) diff --git a/flyteplugins/go/tasks/plugins/array/k8s/monitor.go b/flyteplugins/go/tasks/plugins/array/k8s/monitor.go index 26bb61cc55..a2afde9d7d 100644 --- a/flyteplugins/go/tasks/plugins/array/k8s/monitor.go +++ b/flyteplugins/go/tasks/plugins/array/k8s/monitor.go @@ -112,6 +112,7 @@ func LaunchAndCheckSubTasksState(ctx context.Context, tCtx core.TaskExecutionCon var monitorResult MonitorResult monitorResult, err = task.Monitor(ctx, tCtx, kubeClient, dataStore, outputPrefix, baseOutputDataSandbox) + logLinks = task.LogLinks if monitorResult != MonitorSuccess { if err != nil { diff --git a/flyteplugins/go/tasks/plugins/array/k8s/task.go b/flyteplugins/go/tasks/plugins/array/k8s/task.go index a0f0964bff..40c7747cdb 100644 --- a/flyteplugins/go/tasks/plugins/array/k8s/task.go +++ b/flyteplugins/go/tasks/plugins/array/k8s/task.go @@ -75,8 +75,15 @@ func (t Task) Launch(ctx context.Context, tCtx 
core.TaskExecutionContext, kubeCl }) pod.Spec.Containers[0].Env = append(pod.Spec.Containers[0].Env, arrayJobEnvVars...) + taskTemplate, err := tCtx.TaskReader().Read(ctx) + if err != nil { + return LaunchError, errors2.Wrapf(ErrGetTaskTypeVersion, err, "Unable to read task template") + } else if taskTemplate == nil { + return LaunchError, errors2.Errorf(ErrGetTaskTypeVersion, "Missing task template") + } + inputReader := array.GetInputReader(tCtx, taskTemplate) pod.Spec.Containers[0].Args, err = template.ReplaceTemplateCommandArgs(ctx, tCtx.TaskExecutionMetadata(), args, - arrayJobInputReader{tCtx.InputReader()}, tCtx.OutputWriter()) + inputReader, tCtx.OutputWriter()) if err != nil { return LaunchError, errors2.Wrapf(ErrReplaceCmdTemplate, err, "Failed to replace cmd args") } @@ -116,7 +123,7 @@ func (t Task) Launch(ctx context.Context, tCtx core.TaskExecutionContext, kubeCl return LaunchSuccess, nil } -func (t Task) Monitor(ctx context.Context, tCtx core.TaskExecutionContext, kubeClient core.KubeClient, dataStore *storage.DataStore, outputPrefix, baseOutputDataSandbox storage.DataReference) (MonitorResult, error) { +func (t *Task) Monitor(ctx context.Context, tCtx core.TaskExecutionContext, kubeClient core.KubeClient, dataStore *storage.DataStore, outputPrefix, baseOutputDataSandbox storage.DataReference) (MonitorResult, error) { indexStr := strconv.Itoa(t.ChildIdx) podName := formatSubTaskName(ctx, tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName(), indexStr) phaseInfo, err := CheckPodStatus(ctx, kubeClient, diff --git a/flyteplugins/go/tasks/plugins/array/k8s/transformer.go b/flyteplugins/go/tasks/plugins/array/k8s/transformer.go index 95ae13ab4a..70061249c7 100644 --- a/flyteplugins/go/tasks/plugins/array/k8s/transformer.go +++ b/flyteplugins/go/tasks/plugins/array/k8s/transformer.go @@ -3,34 +3,20 @@ package k8s import ( "context" - core2 "github.com/flyteorg/flyteplugins/go/tasks/plugins/array/core" - -
"github.com/flyteorg/flytestdlib/storage" - - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io" + "github.com/flyteorg/flyteplugins/go/tasks/plugins/array" idlPlugins "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "github.com/flyteorg/flyteplugins/go/tasks/errors" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s" + core2 "github.com/flyteorg/flyteplugins/go/tasks/plugins/array/core" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" v1 "k8s.io/api/core/v1" ) const PodKind = "pod" -// A proxy inputreader that overrides the inputpath to be the inputpathprefix for array jobs -type arrayJobInputReader struct { - io.InputReader -} - -// We override the inputpath to return the prefix path for array jobs -func (i arrayJobInputReader) GetInputPath() storage.DataReference { - return i.GetInputPrefixPath() -} - // Note that Name is not set on the result object. // It's up to the caller to set the Name before creating the object in K8s. func FlyteArrayJobToK8sPodTemplate(ctx context.Context, tCtx core.TaskExecutionContext) ( @@ -49,15 +35,16 @@ func FlyteArrayJobToK8sPodTemplate(ctx context.Context, tCtx core.TaskExecutionC "Required value not set, taskTemplate Container") } + inputReader := array.GetInputReader(tCtx, taskTemplate) var arrayJob *idlPlugins.ArrayJob if taskTemplate.GetCustom() != nil { - arrayJob, err = core2.ToArrayJob(taskTemplate.GetCustom()) + arrayJob, err = core2.ToArrayJob(taskTemplate.GetCustom(), taskTemplate.TaskTypeVersion) if err != nil { return v1.Pod{}, nil, err } } - podSpec, err := flytek8s.ToK8sPodSpec(ctx, tCtx.TaskExecutionMetadata(), tCtx.TaskReader(), arrayJobInputReader{tCtx.InputReader()}, + podSpec, err := flytek8s.ToK8sPodSpec(ctx, tCtx.TaskExecutionMetadata(), tCtx.TaskReader(), inputReader, tCtx.OutputWriter()) if err != nil { return v1.Pod{}, nil, err