diff --git a/go/tasks/plugins/array/catalog.go b/go/tasks/plugins/array/catalog.go index 874a267bb1..38bac1aa15 100644 --- a/go/tasks/plugins/array/catalog.go +++ b/go/tasks/plugins/array/catalog.go @@ -488,7 +488,8 @@ func ConstructOutputWriter(ctx context.Context, dataStore *storage.DataStore, ou return nil, err } - // TODO when we fix https://github.com/flyteorg/flyte/issues/1276 we should make sure that the checkpoint paths are computed correctly + // checkpoint paths are not computed here because this function is only called when writing + // existing cached outputs. if this functionality changes this will need to be revisited. p := ioutils.NewCheckpointRemoteFilePaths(ctx, dataStore, dataReference, ioutils.NewRawOutputPaths(ctx, outputSandbox), "") return ioutils.NewRemoteFileOutputWriter(ctx, dataStore, p), nil } @@ -523,7 +524,8 @@ func ConstructOutputReader(ctx context.Context, dataStore *storage.DataStore, ou return nil, err } - // TODO when we fix https://github.com/flyteorg/flyte/issues/1276 we should make so that the checkpoint paths are computed correctly + // checkpoint paths are not computed here because this function is only called when writing + // existing cached outputs. if this functionality changes this will need to be revisited. outputPath := ioutils.NewCheckpointRemoteFilePaths(ctx, dataStore, dataReference, ioutils.NewRawOutputPaths(ctx, outputSandbox), "") return ioutils.NewRemoteFileOutputReader(ctx, dataStore, outputPath, int64(999999999)), nil } diff --git a/go/tasks/plugins/array/k8s/management.go b/go/tasks/plugins/array/k8s/management.go index 1d9b98e506..480db31c47 100644 --- a/go/tasks/plugins/array/k8s/management.go +++ b/go/tasks/plugins/array/k8s/management.go @@ -167,7 +167,7 @@ func LaunchAndCheckSubTasksState(ctx context.Context, tCtx core.TaskExecutionCon originalIdx := arrayCore.CalculateOriginalIndex(childIdx, newState.GetIndexesToCache()) systemFailures := currentState.SystemFailures.GetItem(childIdx) - stCtx, err := NewSubTaskExecutionContext(tCtx, taskTemplate, childIdx, originalIdx, retryAttempt, systemFailures) + stCtx, err := NewSubTaskExecutionContext(ctx, tCtx, taskTemplate, childIdx, originalIdx, retryAttempt, systemFailures) if err != nil { return currentState, externalResources, err } @@ -337,7 +337,7 @@ func TerminateSubTasks(ctx context.Context, tCtx core.TaskExecutionContext, kube } originalIdx := arrayCore.CalculateOriginalIndex(childIdx, currentState.GetIndexesToCache()) - stCtx, err := NewSubTaskExecutionContext(tCtx, taskTemplate, childIdx, originalIdx, retryAttempt, 0) + stCtx, err := NewSubTaskExecutionContext(ctx, tCtx, taskTemplate, childIdx, originalIdx, retryAttempt, 0) if err != nil { return err } diff --git a/go/tasks/plugins/array/k8s/management_test.go b/go/tasks/plugins/array/k8s/management_test.go index 2aaf077fa6..48a0a47518 100644 --- a/go/tasks/plugins/array/k8s/management_test.go +++ b/go/tasks/plugins/array/k8s/management_test.go @@ -16,6 +16,8 @@ import ( arrayCore "github.com/flyteorg/flyteplugins/go/tasks/plugins/array/core" "github.com/flyteorg/flytestdlib/bitarray" + "github.com/flyteorg/flytestdlib/storage" + stdmocks "github.com/flyteorg/flytestdlib/storage/mocks" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" @@ -104,11 +106,17 @@ func getMockTaskExecutionContext(ctx context.Context, parallelism int) *mocks.Ta ir.OnGetInputPath().Return("/prefix/inputs.pb") ir.OnGetMatch(mock.Anything).Return(&core2.LiteralMap{}, nil) + dataStore := &storage.DataStore{ + ComposedProtobufStore: &stdmocks.ComposedProtobufStore{}, + ReferenceConstructor: &storage.URLPathConstructor{}, + } + tCtx := &mocks.TaskExecutionContext{} tCtx.OnTaskReader().Return(tr) tCtx.OnTaskExecutionMetadata().Return(tMeta) tCtx.OnOutputWriter().Return(ow) tCtx.OnInputReader().Return(ir) + tCtx.OnDataStore().Return(dataStore) return tCtx } diff --git a/go/tasks/plugins/array/k8s/subtask_exec_context.go b/go/tasks/plugins/array/k8s/subtask_exec_context.go index d5d4393101..37e2f34609 100644 --- a/go/tasks/plugins/array/k8s/subtask_exec_context.go +++ b/go/tasks/plugins/array/k8s/subtask_exec_context.go @@ -9,10 +9,13 @@ import ( pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/ioutils" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/utils" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/utils/secrets" "github.com/flyteorg/flyteplugins/go/tasks/plugins/array" podPlugin "github.com/flyteorg/flyteplugins/go/tasks/plugins/k8s/pod" + + "github.com/flyteorg/flytestdlib/storage" ) // SubTaskExecutionContext wraps the core TaskExecutionContext so that the k8s array task context @@ -22,6 +25,7 @@ type SubTaskExecutionContext struct { arrayInputReader io.InputReader metadataOverride pluginsCore.TaskExecutionMetadata originalIndex int + outputWriter io.OutputWriter subtaskReader SubTaskReader } @@ -30,6 +34,11 @@ func (s SubTaskExecutionContext) InputReader() io.InputReader { return s.arrayInputReader } +// OutputWriter overrides the base TaskExecutionContext to return a custom OutputWriter +func (s SubTaskExecutionContext) OutputWriter() io.OutputWriter { + return s.outputWriter +} + // TaskExecutionMetadata overrides the base TaskExecutionContext to return custom // TaskExecutionMetadata func (s SubTaskExecutionContext) TaskExecutionMetadata() pluginsCore.TaskExecutionMetadata { @@ -42,7 +51,7 @@ func (s SubTaskExecutionContext) TaskReader() pluginsCore.TaskReader { } // NewSubtaskExecutionContext constructs a SubTaskExecutionContext using the provided parameters -func NewSubTaskExecutionContext(tCtx pluginsCore.TaskExecutionContext, taskTemplate *core.TaskTemplate, +func NewSubTaskExecutionContext(ctx context.Context, tCtx pluginsCore.TaskExecutionContext, taskTemplate *core.TaskTemplate, executionIndex, originalIndex int, retryAttempt uint64, systemFailures uint64) (SubTaskExecutionContext, error) { subTaskExecutionMetadata, err := NewSubTaskExecutionMetadata(tCtx.TaskExecutionMetadata(), taskTemplate, executionIndex, retryAttempt, systemFailures) @@ -50,6 +59,7 @@ func NewSubTaskExecutionContext(tCtx pluginsCore.TaskExecutionContext, taskTempl return SubTaskExecutionContext{}, err } + // construct TaskTemplate subtaskTemplate := &core.TaskTemplate{} *subtaskTemplate = *taskTemplate @@ -64,11 +74,42 @@ func NewSubTaskExecutionContext(tCtx pluginsCore.TaskExecutionContext, taskTempl arrayInputReader := array.GetInputReader(tCtx, taskTemplate) subtaskReader := SubTaskReader{tCtx.TaskReader(), subtaskTemplate} + + // construct OutputWriter + dataStore := tCtx.DataStore() + checkpointPrefix, err := dataStore.ConstructReference(ctx, tCtx.OutputWriter().GetRawOutputPrefix(), strconv.Itoa(originalIndex)) + if err != nil { + return SubTaskExecutionContext{}, err + } + + checkpoint, err := dataStore.ConstructReference(ctx, checkpointPrefix, strconv.FormatUint(retryAttempt, 10)) + if err != nil { + return SubTaskExecutionContext{}, err + } + checkpointPath := ioutils.NewRawOutputPaths(ctx, checkpoint) + + var prevCheckpoint storage.DataReference + if retryAttempt == 0 { + prevCheckpoint = "" + } else { + prevCheckpoint, err = dataStore.ConstructReference(ctx, checkpointPrefix, strconv.FormatUint(retryAttempt-1, 10)) + if err != nil { + return SubTaskExecutionContext{}, err + } + } + prevCheckpointPath := ioutils.ConstructCheckpointPath(dataStore, prevCheckpoint) + + // note that we must not append the originalIndex to the original OutputPrefixPath because + // flytekit is already doing this + p := ioutils.NewCheckpointRemoteFilePaths(ctx, dataStore, tCtx.OutputWriter().GetOutputPrefixPath(), checkpointPath, prevCheckpointPath) + outputWriter := ioutils.NewRemoteFileOutputWriter(ctx, dataStore, p) + return SubTaskExecutionContext{ TaskExecutionContext: tCtx, arrayInputReader: arrayInputReader, metadataOverride: subTaskExecutionMetadata, originalIndex: originalIndex, + outputWriter: outputWriter, subtaskReader: subtaskReader, }, nil } diff --git a/go/tasks/plugins/array/k8s/subtask_exec_context_test.go b/go/tasks/plugins/array/k8s/subtask_exec_context_test.go index 079ab82915..77389514ec 100644 --- a/go/tasks/plugins/array/k8s/subtask_exec_context_test.go +++ b/go/tasks/plugins/array/k8s/subtask_exec_context_test.go @@ -7,6 +7,8 @@ import ( podPlugin "github.com/flyteorg/flyteplugins/go/tasks/plugins/k8s/pod" + "github.com/flyteorg/flytestdlib/storage" + "github.com/stretchr/testify/assert" ) @@ -22,7 +24,7 @@ func TestSubTaskExecutionContext(t *testing.T) { retryAttempt := uint64(1) systemFailures := uint64(0) - stCtx, err := NewSubTaskExecutionContext(tCtx, taskTemplate, executionIndex, originalIndex, retryAttempt, systemFailures) + stCtx, err := NewSubTaskExecutionContext(ctx, tCtx, taskTemplate, executionIndex, originalIndex, retryAttempt, systemFailures) assert.Nil(t, err) assert.Equal(t, fmt.Sprintf("notfound-%d-%d", executionIndex, retryAttempt), stCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName()) @@ -31,4 +33,6 @@ func TestSubTaskExecutionContext(t *testing.T) { assert.Nil(t, err) assert.Equal(t, int32(2), subtaskTemplate.TaskTypeVersion) assert.Equal(t, podPlugin.ContainerTaskType, subtaskTemplate.Type) + assert.Equal(t, storage.DataReference("/prefix/"), stCtx.OutputWriter().GetOutputPrefixPath()) + assert.Equal(t, storage.DataReference("/raw_prefix/5/1"), stCtx.OutputWriter().GetRawOutputPrefix()) }