Skip to content

Commit

Permalink
Allow checkpoint resume when recovering a workflow (flyteorg#486)
Browse files Browse the repository at this point in the history
* Update flyteidl version

Signed-off-by: Flyte-Bot <[email protected]>

* Update flyteidl version

Signed-off-by: Flyte-Bot <[email protected]>

* Fix build break

Signed-off-by: Haytham Abuelfutuh <[email protected]>

* Update flyteidl version

Signed-off-by: Flyte-Bot <[email protected]>

* Save/restore CheckpointUri from NodeExecution

Signed-off-by: Andrew Dye <[email protected]>

* Lints, generate

Signed-off-by: Andrew Dye <[email protected]>

* Fix log line

Signed-off-by: Andrew Dye <[email protected]>

Signed-off-by: Flyte-Bot <[email protected]>
Signed-off-by: Haytham Abuelfutuh <[email protected]>
Signed-off-by: Andrew Dye <[email protected]>
Co-authored-by: flyte-bot <[email protected]>
Co-authored-by: Haytham Abuelfutuh <[email protected]>
Co-authored-by: Dan Rammer <[email protected]>
  • Loading branch information
4 people authored Oct 6, 2022
1 parent 66eaf12 commit a9b831b
Show file tree
Hide file tree
Showing 13 changed files with 232 additions and 43 deletions.
2 changes: 2 additions & 0 deletions pkg/apis/flyteworkflow/v1alpha1/iface.go
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,7 @@ type ExecutableTaskNodeStatus interface {
GetPluginStateVersion() uint32
GetBarrierClockTick() uint32
GetLastPhaseUpdatedAt() time.Time
GetPreviousNodeExecutionCheckpointPath() DataReference
}

type MutableTaskNodeStatus interface {
Expand All @@ -347,6 +348,7 @@ type MutableTaskNodeStatus interface {
SetPluginState([]byte)
SetPluginStateVersion(uint32)
SetBarrierClockTick(tick uint32)
SetPreviousNodeExecutionCheckpointPath(DataReference)
}

// ExecutableWorkflowNode is an interface for a Child Workflow Node
Expand Down
37 changes: 35 additions & 2 deletions pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableTaskNodeStatus.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

42 changes: 40 additions & 2 deletions pkg/apis/flyteworkflow/v1alpha1/mocks/MutableTaskNodeStatus.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

22 changes: 16 additions & 6 deletions pkg/apis/flyteworkflow/v1alpha1/node_status.go
Original file line number Diff line number Diff line change
Expand Up @@ -711,12 +711,13 @@ func (in *CustomState) DeepCopy() *CustomState {

type TaskNodeStatus struct {
MutableStruct
Phase int `json:"phase,omitempty"`
PhaseVersion uint32 `json:"phaseVersion,omitempty"`
PluginState []byte `json:"pState,omitempty"`
PluginStateVersion uint32 `json:"psv,omitempty"`
BarrierClockTick uint32 `json:"tick,omitempty"`
LastPhaseUpdatedAt time.Time `json:"updAt,omitempty"`
Phase int `json:"phase,omitempty"`
PhaseVersion uint32 `json:"phaseVersion,omitempty"`
PluginState []byte `json:"pState,omitempty"`
PluginStateVersion uint32 `json:"psv,omitempty"`
BarrierClockTick uint32 `json:"tick,omitempty"`
LastPhaseUpdatedAt time.Time `json:"updAt,omitempty"`
PreviousNodeExecutionCheckpointPath DataReference `json:"checkpointPath,omitempty"`
}

func (in *TaskNodeStatus) GetBarrierClockTick() uint32 {
Expand All @@ -728,6 +729,11 @@ func (in *TaskNodeStatus) SetBarrierClockTick(tick uint32) {
in.SetDirty()
}

func (in *TaskNodeStatus) SetPreviousNodeExecutionCheckpointPath(path DataReference) {
in.PreviousNodeExecutionCheckpointPath = path
in.SetDirty()
}

func (in *TaskNodeStatus) SetPluginState(s []byte) {
in.PluginState = s
in.SetDirty()
Expand Down Expand Up @@ -768,6 +774,10 @@ func (in TaskNodeStatus) GetLastPhaseUpdatedAt() time.Time {
return in.LastPhaseUpdatedAt
}

func (in TaskNodeStatus) GetPreviousNodeExecutionCheckpointPath() DataReference {
return in.PreviousNodeExecutionCheckpointPath
}

func (in TaskNodeStatus) GetPhaseVersion() uint32 {
return in.PhaseVersion
}
Expand Down
13 changes: 13 additions & 0 deletions pkg/controller/nodes/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/ioutils"
errors2 "github.com/flyteorg/flytestdlib/errors"

"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin"
"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event"
"github.com/flyteorg/flytepropeller/events"
Expand Down Expand Up @@ -190,6 +191,17 @@ func (c *nodeExecutor) attemptRecovery(ctx context.Context, nCtx handler.NodeExe
case core.NodeExecution_SUCCEEDED:
logger.Debugf(ctx, "Node [%+v] can be recovered. Proceeding to copy inputs and outputs", nCtx.NodeExecutionMetadata().GetNodeExecutionID())
default:
// The node execution may be partially recoverable through intra task checkpointing. Save the checkpoint
// uri in the task node state to pass to the task handler later on.
if metadata, ok := recovered.Closure.TargetMetadata.(*admin.NodeExecutionClosure_TaskNodeMetadata); ok {
state := nCtx.NodeStateReader().GetTaskNodeState()
state.PreviousNodeExecutionCheckpointURI = storage.DataReference(metadata.TaskNodeMetadata.CheckpointUri)
err = nCtx.NodeStateWriter().PutTaskNodeState(state)
if err != nil {
logger.Warn(ctx, "failed to save recovered checkpoint uri for [%+v]: [%+v]",
nCtx.NodeExecutionMetadata().GetNodeExecutionID(), err)
}
}
logger.Debugf(ctx, "Node [%+v] phase [%v] is not recoverable", nCtx.NodeExecutionMetadata().GetNodeExecutionID(), recovered.Closure.Phase)
return handler.PhaseInfoUndefined, nil
}
Expand Down Expand Up @@ -594,6 +606,7 @@ func (c *nodeExecutor) handleQueuedOrRunningNode(ctx context.Context, nCtx *node
if np != nodeStatus.GetPhase() && np != v1alpha1.NodePhaseRetryableFailure {
// assert np == skipped, succeeding, failing or recovered
logger.Infof(ctx, "Change in node state detected from [%s] -> [%s], (handler phase [%s])", nodeStatus.GetPhase().String(), np.String(), p.GetPhase().String())

nev, err := ToNodeExecutionEvent(nCtx.NodeExecutionMetadata().GetNodeExecutionID(),
p, nCtx.InputReader().GetInputPath().String(), nCtx.NodeStatus(), nCtx.ExecutionContext().GetEventVersion(),
nCtx.ExecutionContext().GetParentInfo(), nCtx.node, c.clusterID, nCtx.NodeStateReader().GetDynamicNodeState().Phase)
Expand Down
15 changes: 15 additions & 0 deletions pkg/controller/nodes/executor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2282,13 +2282,28 @@ func TestRecover(t *testing.T) {
&admin.NodeExecution{
Closure: &admin.NodeExecutionClosure{
Phase: core.NodeExecution_FAILED,
TargetMetadata: &admin.NodeExecutionClosure_TaskNodeMetadata{
TaskNodeMetadata: &admin.TaskNodeMetadata{
CheckpointUri: "prev path",
},
},
},
}, nil)

executor := nodeExecutor{
recoveryClient: recoveryClient,
}

reader := &nodeHandlerMocks.NodeStateReader{}
reader.OnGetTaskNodeState().Return(handler.TaskNodeState{})
nCtx.OnNodeStateReader().Return(reader)
writer := &nodeHandlerMocks.NodeStateWriter{}
writer.OnPutTaskNodeStateMatch(mock.Anything).Run(func(args mock.Arguments) {
state := args.Get(0).(handler.TaskNodeState)
assert.Equal(t, state.PreviousNodeExecutionCheckpointURI.String(), "prev path")
}).Return(nil)
nCtx.OnNodeStateWriter().Return(writer)

phaseInfo, err := executor.attemptRecovery(context.TODO(), nCtx)
assert.NoError(t, err)
assert.Equal(t, phaseInfo.GetPhase(), handler.EPhaseUndefined)
Expand Down
14 changes: 8 additions & 6 deletions pkg/controller/nodes/handler/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (

"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
pluginCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core"
"github.com/flyteorg/flytestdlib/storage"

"github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1"
)
Expand All @@ -13,12 +14,13 @@ import (
// TODO eventually we could just convert this to be binary node state encoded into the node status

type TaskNodeState struct {
PluginPhase pluginCore.Phase
PluginPhaseVersion uint32
PluginState []byte
PluginStateVersion uint32
BarrierClockTick uint32
LastPhaseUpdatedAt time.Time
PluginPhase pluginCore.Phase
PluginPhaseVersion uint32
PluginState []byte
PluginStateVersion uint32
BarrierClockTick uint32
LastPhaseUpdatedAt time.Time
PreviousNodeExecutionCheckpointURI storage.DataReference
}

type BranchNodeState struct {
Expand Down
13 changes: 7 additions & 6 deletions pkg/controller/nodes/node_state_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,13 @@ func (n nodeStateManager) GetTaskNodeState() handler.TaskNodeState {
tn := n.nodeStatus.GetTaskNodeStatus()
if tn != nil {
return handler.TaskNodeState{
PluginPhase: pluginCore.Phase(tn.GetPhase()),
PluginPhaseVersion: tn.GetPhaseVersion(),
PluginStateVersion: tn.GetPluginStateVersion(),
PluginState: tn.GetPluginState(),
BarrierClockTick: tn.GetBarrierClockTick(),
LastPhaseUpdatedAt: tn.GetLastPhaseUpdatedAt(),
PluginPhase: pluginCore.Phase(tn.GetPhase()),
PluginPhaseVersion: tn.GetPhaseVersion(),
PluginStateVersion: tn.GetPluginStateVersion(),
PluginState: tn.GetPluginState(),
BarrierClockTick: tn.GetBarrierClockTick(),
LastPhaseUpdatedAt: tn.GetLastPhaseUpdatedAt(),
PreviousNodeExecutionCheckpointURI: tn.GetPreviousNodeExecutionCheckpointPath(),
}
}
return handler.TaskNodeState{}
Expand Down
56 changes: 42 additions & 14 deletions pkg/controller/nodes/task/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,12 +117,26 @@ func (p *pluginRequestedTransition) ObservedTransitionAndState(trns pluginCore.T
p.pluginStateVersion = pluginStateVersion
}

func (p *pluginRequestedTransition) ObservedExecutionError(executionError *io.ExecutionError) {
func (p *pluginRequestedTransition) ObservedExecutionError(executionError *io.ExecutionError, taskMetadata *event.TaskNodeMetadata) {
if executionError.IsRecoverable {
p.pInfo = pluginCore.PhaseInfoFailed(pluginCore.PhaseRetryableFailure, executionError.ExecutionError, p.pInfo.Info())
} else {
p.pInfo = pluginCore.PhaseInfoFailed(pluginCore.PhasePermanentFailure, executionError.ExecutionError, p.pInfo.Info())
}

if taskMetadata != nil {
p.execInfo.TaskNodeInfo = &handler.TaskNodeInfo{
TaskNodeMetadata: taskMetadata,
}
}
}

func (p *pluginRequestedTransition) ObservedFailure(taskMetadata *event.TaskNodeMetadata) {
if taskMetadata != nil {
p.execInfo.TaskNodeInfo = &handler.TaskNodeInfo{
TaskNodeMetadata: taskMetadata,
}
}
}

func (p *pluginRequestedTransition) IsPreviouslyObserved() bool {
Expand Down Expand Up @@ -159,10 +173,10 @@ func (p *pluginRequestedTransition) FinalTransition(ctx context.Context) (handle
return handler.DoTransition(p.ttype, handler.PhaseInfoSuccess(&p.execInfo)), nil
case pluginCore.PhaseRetryableFailure:
logger.Debugf(ctx, "Transitioning to RetryableFailure")
return handler.DoTransition(p.ttype, handler.PhaseInfoRetryableFailureErr(p.pInfo.Err(), nil)), nil
return handler.DoTransition(p.ttype, handler.PhaseInfoRetryableFailureErr(p.pInfo.Err(), &p.execInfo)), nil
case pluginCore.PhasePermanentFailure:
logger.Debugf(ctx, "Transitioning to Failure")
return handler.DoTransition(p.ttype, handler.PhaseInfoFailureErr(p.pInfo.Err(), nil)), nil
return handler.DoTransition(p.ttype, handler.PhaseInfoFailureErr(p.pInfo.Err(), &p.execInfo)), nil
case pluginCore.PhaseUndefined:
return handler.UnknownTransition, fmt.Errorf("error converting plugin phase, received [Undefined]")
}
Expand Down Expand Up @@ -429,7 +443,7 @@ func (t Handler) invokePlugin(ctx context.Context, p pluginCore.Plugin, tCtx *ta
pluginTrns.pInfo.Phase().String(), p.GetID(), pluginTrns.pInfo.Version(), t.cfg.MaxPluginPhaseVersions),
},
IsRecoverable: false,
})
}, nil)
return pluginTrns, nil
}
}
Expand All @@ -448,7 +462,8 @@ func (t Handler) invokePlugin(ctx context.Context, p pluginCore.Plugin, tCtx *ta
}
}

if pluginTrns.pInfo.Phase() == pluginCore.PhaseSuccess {
switch pluginTrns.pInfo.Phase() {
case pluginCore.PhaseSuccess:
// -------------------------------------
// TODO: @kumare create Issue# Remove the code after we use closures to handle dynamic nodes
// This code only exists to support Dynamic tasks. Eventually dynamic tasks will use closure nodes to execute
Expand Down Expand Up @@ -480,8 +495,12 @@ func (t Handler) invokePlugin(ctx context.Context, p pluginCore.Plugin, tCtx *ta
if err != nil {
return nil, err
}

if ee != nil {
pluginTrns.ObservedExecutionError(ee)
pluginTrns.ObservedExecutionError(ee,
&event.TaskNodeMetadata{
CheckpointUri: tCtx.ow.GetCheckpointPrefix().String(),
})
} else {
var deckURI *storage.DataReference
if tCtx.ow.GetReader() != nil {
Expand All @@ -496,10 +515,18 @@ func (t Handler) invokePlugin(ctx context.Context, p pluginCore.Plugin, tCtx *ta
}
pluginTrns.ObserveSuccess(tCtx.ow.GetOutputPath(), deckURI,
&event.TaskNodeMetadata{
CacheStatus: cacheStatus.GetCacheStatus(),
CatalogKey: cacheStatus.GetMetadata(),
CacheStatus: cacheStatus.GetCacheStatus(),
CatalogKey: cacheStatus.GetMetadata(),
CheckpointUri: tCtx.ow.GetCheckpointPrefix().String(),
})
}
case pluginCore.PhaseRetryableFailure:
fallthrough
case pluginCore.PhasePermanentFailure:
pluginTrns.ObservedFailure(
&event.TaskNodeMetadata{
CheckpointUri: tCtx.ow.GetCheckpointPrefix().String(),
})
}

return pluginTrns, nil
Expand Down Expand Up @@ -710,12 +737,13 @@ func (t Handler) Handle(ctx context.Context, nCtx handler.NodeExecutionContext)

// STEP 6: Persist the plugin state
err = nCtx.NodeStateWriter().PutTaskNodeState(handler.TaskNodeState{
PluginState: pluginTrns.pluginState,
PluginStateVersion: pluginTrns.pluginStateVersion,
PluginPhase: pluginTrns.pInfo.Phase(),
PluginPhaseVersion: pluginTrns.pInfo.Version(),
BarrierClockTick: barrierTick,
LastPhaseUpdatedAt: time.Now(),
PluginState: pluginTrns.pluginState,
PluginStateVersion: pluginTrns.pluginStateVersion,
PluginPhase: pluginTrns.pInfo.Phase(),
PluginPhaseVersion: pluginTrns.pInfo.Version(),
BarrierClockTick: barrierTick,
LastPhaseUpdatedAt: time.Now(),
PreviousNodeExecutionCheckpointURI: ts.PreviousNodeExecutionCheckpointURI,
})
if err != nil {
logger.Errorf(ctx, "Failed to store TaskNode state, err :%s", err.Error())
Expand Down
Loading

0 comments on commit a9b831b

Please sign in to comment.