Skip to content

Commit

Permalink
Support intra-task checkpointing in map tasks. (flyteorg#257)
Browse files Browse the repository at this point in the history
* overriding OutputWriter

Signed-off-by: Daniel Rammer <[email protected]>

* updated comments

Signed-off-by: Daniel Rammer <[email protected]>

* added unit test

Signed-off-by: Daniel Rammer <[email protected]>

* fixed lint issues

Signed-off-by: Daniel Rammer <[email protected]>
  • Loading branch information
hamersaw authored Apr 11, 2022
1 parent cc2b8c6 commit 7a01f94
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 6 deletions.
6 changes: 4 additions & 2 deletions go/tasks/plugins/array/catalog.go
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,8 @@ func ConstructOutputWriter(ctx context.Context, dataStore *storage.DataStore, ou
return nil, err
}

// TODO when we fix https://github.com/flyteorg/flyte/issues/1276 we should make sure that the checkpoint paths are computed correctly
// checkpoint paths are not computed here because this function is only called when writing
// existing cached outputs. if this functionality changes this will need to be revisited.
p := ioutils.NewCheckpointRemoteFilePaths(ctx, dataStore, dataReference, ioutils.NewRawOutputPaths(ctx, outputSandbox), "")
return ioutils.NewRemoteFileOutputWriter(ctx, dataStore, p), nil
}
Expand Down Expand Up @@ -523,7 +524,8 @@ func ConstructOutputReader(ctx context.Context, dataStore *storage.DataStore, ou
return nil, err
}

// TODO when we fix https://github.com/flyteorg/flyte/issues/1276 we should make so that the checkpoint paths are computed correctly
// checkpoint paths are not computed here because this function is only called when writing
// existing cached outputs. if this functionality changes this will need to be revisited.
outputPath := ioutils.NewCheckpointRemoteFilePaths(ctx, dataStore, dataReference, ioutils.NewRawOutputPaths(ctx, outputSandbox), "")
return ioutils.NewRemoteFileOutputReader(ctx, dataStore, outputPath, int64(999999999)), nil
}
4 changes: 2 additions & 2 deletions go/tasks/plugins/array/k8s/management.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ func LaunchAndCheckSubTasksState(ctx context.Context, tCtx core.TaskExecutionCon

originalIdx := arrayCore.CalculateOriginalIndex(childIdx, newState.GetIndexesToCache())
systemFailures := currentState.SystemFailures.GetItem(childIdx)
stCtx, err := NewSubTaskExecutionContext(tCtx, taskTemplate, childIdx, originalIdx, retryAttempt, systemFailures)
stCtx, err := NewSubTaskExecutionContext(ctx, tCtx, taskTemplate, childIdx, originalIdx, retryAttempt, systemFailures)
if err != nil {
return currentState, externalResources, err
}
Expand Down Expand Up @@ -337,7 +337,7 @@ func TerminateSubTasks(ctx context.Context, tCtx core.TaskExecutionContext, kube
}

originalIdx := arrayCore.CalculateOriginalIndex(childIdx, currentState.GetIndexesToCache())
stCtx, err := NewSubTaskExecutionContext(tCtx, taskTemplate, childIdx, originalIdx, retryAttempt, 0)
stCtx, err := NewSubTaskExecutionContext(ctx, tCtx, taskTemplate, childIdx, originalIdx, retryAttempt, 0)
if err != nil {
return err
}
Expand Down
8 changes: 8 additions & 0 deletions go/tasks/plugins/array/k8s/management_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ import (
arrayCore "github.com/flyteorg/flyteplugins/go/tasks/plugins/array/core"

"github.com/flyteorg/flytestdlib/bitarray"
"github.com/flyteorg/flytestdlib/storage"
stdmocks "github.com/flyteorg/flytestdlib/storage/mocks"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
Expand Down Expand Up @@ -104,11 +106,17 @@ func getMockTaskExecutionContext(ctx context.Context, parallelism int) *mocks.Ta
ir.OnGetInputPath().Return("/prefix/inputs.pb")
ir.OnGetMatch(mock.Anything).Return(&core2.LiteralMap{}, nil)

dataStore := &storage.DataStore{
ComposedProtobufStore: &stdmocks.ComposedProtobufStore{},
ReferenceConstructor: &storage.URLPathConstructor{},
}

tCtx := &mocks.TaskExecutionContext{}
tCtx.OnTaskReader().Return(tr)
tCtx.OnTaskExecutionMetadata().Return(tMeta)
tCtx.OnOutputWriter().Return(ow)
tCtx.OnInputReader().Return(ir)
tCtx.OnDataStore().Return(dataStore)
return tCtx
}

Expand Down
43 changes: 42 additions & 1 deletion go/tasks/plugins/array/k8s/subtask_exec_context.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,13 @@ import (

pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/ioutils"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/utils"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/utils/secrets"
"github.com/flyteorg/flyteplugins/go/tasks/plugins/array"
podPlugin "github.com/flyteorg/flyteplugins/go/tasks/plugins/k8s/pod"

"github.com/flyteorg/flytestdlib/storage"
)

// SubTaskExecutionContext wraps the core TaskExecutionContext so that the k8s array task context
Expand All @@ -22,6 +25,7 @@ type SubTaskExecutionContext struct {
arrayInputReader io.InputReader
metadataOverride pluginsCore.TaskExecutionMetadata
originalIndex int
outputWriter io.OutputWriter
subtaskReader SubTaskReader
}

Expand All @@ -30,6 +34,11 @@ func (s SubTaskExecutionContext) InputReader() io.InputReader {
return s.arrayInputReader
}

// OutputWriter overrides the base TaskExecutionContext to return a custom OutputWriter
func (s SubTaskExecutionContext) OutputWriter() io.OutputWriter {
return s.outputWriter
}

// TaskExecutionMetadata overrides the base TaskExecutionContext to return custom
// TaskExecutionMetadata
func (s SubTaskExecutionContext) TaskExecutionMetadata() pluginsCore.TaskExecutionMetadata {
Expand All @@ -42,14 +51,15 @@ func (s SubTaskExecutionContext) TaskReader() pluginsCore.TaskReader {
}

// NewSubtaskExecutionContext constructs a SubTaskExecutionContext using the provided parameters
func NewSubTaskExecutionContext(tCtx pluginsCore.TaskExecutionContext, taskTemplate *core.TaskTemplate,
func NewSubTaskExecutionContext(ctx context.Context, tCtx pluginsCore.TaskExecutionContext, taskTemplate *core.TaskTemplate,
executionIndex, originalIndex int, retryAttempt uint64, systemFailures uint64) (SubTaskExecutionContext, error) {

subTaskExecutionMetadata, err := NewSubTaskExecutionMetadata(tCtx.TaskExecutionMetadata(), taskTemplate, executionIndex, retryAttempt, systemFailures)
if err != nil {
return SubTaskExecutionContext{}, err
}

// construct TaskTemplate
subtaskTemplate := &core.TaskTemplate{}
*subtaskTemplate = *taskTemplate

Expand All @@ -64,11 +74,42 @@ func NewSubTaskExecutionContext(tCtx pluginsCore.TaskExecutionContext, taskTempl

arrayInputReader := array.GetInputReader(tCtx, taskTemplate)
subtaskReader := SubTaskReader{tCtx.TaskReader(), subtaskTemplate}

// construct OutputWriter
dataStore := tCtx.DataStore()
checkpointPrefix, err := dataStore.ConstructReference(ctx, tCtx.OutputWriter().GetRawOutputPrefix(), strconv.Itoa(originalIndex))
if err != nil {
return SubTaskExecutionContext{}, err
}

checkpoint, err := dataStore.ConstructReference(ctx, checkpointPrefix, strconv.FormatUint(retryAttempt, 10))
if err != nil {
return SubTaskExecutionContext{}, err
}
checkpointPath := ioutils.NewRawOutputPaths(ctx, checkpoint)

var prevCheckpoint storage.DataReference
if retryAttempt == 0 {
prevCheckpoint = ""
} else {
prevCheckpoint, err = dataStore.ConstructReference(ctx, checkpointPrefix, strconv.FormatUint(retryAttempt-1, 10))
if err != nil {
return SubTaskExecutionContext{}, err
}
}
prevCheckpointPath := ioutils.ConstructCheckpointPath(dataStore, prevCheckpoint)

// note that we must not append the originalIndex to the original OutputPrefixPath because
// flytekit is already doing this
p := ioutils.NewCheckpointRemoteFilePaths(ctx, dataStore, tCtx.OutputWriter().GetOutputPrefixPath(), checkpointPath, prevCheckpointPath)
outputWriter := ioutils.NewRemoteFileOutputWriter(ctx, dataStore, p)

return SubTaskExecutionContext{
TaskExecutionContext: tCtx,
arrayInputReader: arrayInputReader,
metadataOverride: subTaskExecutionMetadata,
originalIndex: originalIndex,
outputWriter: outputWriter,
subtaskReader: subtaskReader,
}, nil
}
Expand Down
6 changes: 5 additions & 1 deletion go/tasks/plugins/array/k8s/subtask_exec_context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import (

podPlugin "github.com/flyteorg/flyteplugins/go/tasks/plugins/k8s/pod"

"github.com/flyteorg/flytestdlib/storage"

"github.com/stretchr/testify/assert"
)

Expand All @@ -22,7 +24,7 @@ func TestSubTaskExecutionContext(t *testing.T) {
retryAttempt := uint64(1)
systemFailures := uint64(0)

stCtx, err := NewSubTaskExecutionContext(tCtx, taskTemplate, executionIndex, originalIndex, retryAttempt, systemFailures)
stCtx, err := NewSubTaskExecutionContext(ctx, tCtx, taskTemplate, executionIndex, originalIndex, retryAttempt, systemFailures)
assert.Nil(t, err)

assert.Equal(t, fmt.Sprintf("notfound-%d-%d", executionIndex, retryAttempt), stCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName())
Expand All @@ -31,4 +33,6 @@ func TestSubTaskExecutionContext(t *testing.T) {
assert.Nil(t, err)
assert.Equal(t, int32(2), subtaskTemplate.TaskTypeVersion)
assert.Equal(t, podPlugin.ContainerTaskType, subtaskTemplate.Type)
assert.Equal(t, storage.DataReference("/prefix/"), stCtx.OutputWriter().GetOutputPrefixPath())
assert.Equal(t, storage.DataReference("/raw_prefix/5/1"), stCtx.OutputWriter().GetRawOutputPrefix())
}

0 comments on commit 7a01f94

Please sign in to comment.