From 0025cbea7084089597691d88fb698bd8d7ae45f0 Mon Sep 17 00:00:00 2001 From: Dan Rammer Date: Fri, 28 Apr 2023 14:37:10 -0500 Subject: [PATCH] Added support for aborting task nodes reported as failures (#541) * added CleanupOnFailure support for TaskNodeStatus to support aborting failed task nodes Signed-off-by: Daniel Rammer * updated flyteplugins and generated Signed-off-by: Daniel Rammer * updated flyteplugins Signed-off-by: Daniel Rammer --------- Signed-off-by: Daniel Rammer --- go.mod | 2 +- go.sum | 4 +- pkg/apis/flyteworkflow/v1alpha1/iface.go | 2 + .../mocks/ExecutableTaskNodeStatus.go | 32 ++++++++++++++++ .../v1alpha1/mocks/MutableTaskNodeStatus.go | 37 +++++++++++++++++++ .../flyteworkflow/v1alpha1/node_status.go | 10 +++++ pkg/controller/nodes/handler/state.go | 1 + pkg/controller/nodes/node_state_manager.go | 1 + pkg/controller/nodes/task/handler.go | 6 ++- pkg/controller/nodes/transformers.go | 1 + 10 files changed, 91 insertions(+), 5 deletions(-) diff --git a/go.mod b/go.mod index a9fa8bc964..d0d6abb601 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require ( github.com/benlaurie/objecthash v0.0.0-20180202135721-d1e3d6079fc1 github.com/fatih/color v1.13.0 github.com/flyteorg/flyteidl v1.3.14 - github.com/flyteorg/flyteplugins v1.0.49 + github.com/flyteorg/flyteplugins v1.0.52 github.com/flyteorg/flytestdlib v1.0.15 github.com/ghodss/yaml v1.0.0 github.com/go-redis/redis v6.15.7+incompatible diff --git a/go.sum b/go.sum index 05eafa4e92..97d3b363c1 100644 --- a/go.sum +++ b/go.sum @@ -262,8 +262,8 @@ github.com/fatih/color v1.13.0/go.mod h1:kLAiJbzzSOZDVNGyDpeOxJ47H46qBXwg5ILebYF github.com/felixge/httpsnoop v1.0.1/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= github.com/flyteorg/flyteidl v1.3.14 h1:o5M0g/r6pXTPu5PEurbYxbQmuOu3hqqsaI2M6uvK0N8= github.com/flyteorg/flyteidl v1.3.14/go.mod h1:Pkt2skI1LiHs/2ZoekBnyPhuGOFMiuul6HHcKGZBsbM= -github.com/flyteorg/flyteplugins v1.0.49 h1:lUmT4kqYamkJY2tO6nCWRCnVv2M2QNLIap5bFYAol7s= -github.com/flyteorg/flyteplugins v1.0.49/go.mod h1:ztsonku5fKwyxcIg1k69PTiBVjRI6d3nK5DnC+iwx08= +github.com/flyteorg/flyteplugins v1.0.52 h1:AWNrRYgm0bCzOws+bIfJDfPBZqBmTdABxW78r8q3kP4= +github.com/flyteorg/flyteplugins v1.0.52/go.mod h1:ztsonku5fKwyxcIg1k69PTiBVjRI6d3nK5DnC+iwx08= github.com/flyteorg/flytestdlib v1.0.15 h1:kv9jDQmytbE84caY+pkZN8trJU2ouSAmESzpTEhfTt0= github.com/flyteorg/flytestdlib v1.0.15/go.mod h1:ghw/cjY0sEWIIbyCtcJnL/Gt7ZS7gf9SUi0CCPhbz3s= github.com/flyteorg/stow v0.3.6 h1:jt50ciM14qhKBaIrB+ppXXY+SXB59FNREFgTJqCyqIk= diff --git a/pkg/apis/flyteworkflow/v1alpha1/iface.go b/pkg/apis/flyteworkflow/v1alpha1/iface.go index c52361b239..8b6c02318f 100644 --- a/pkg/apis/flyteworkflow/v1alpha1/iface.go +++ b/pkg/apis/flyteworkflow/v1alpha1/iface.go @@ -359,6 +359,7 @@ type ExecutableTaskNodeStatus interface { GetBarrierClockTick() uint32 GetLastPhaseUpdatedAt() time.Time GetPreviousNodeExecutionCheckpointPath() DataReference + GetCleanupOnFailure() bool } type MutableTaskNodeStatus interface { @@ -371,6 +372,7 @@ type MutableTaskNodeStatus interface { SetPluginStateVersion(uint32) SetBarrierClockTick(tick uint32) SetPreviousNodeExecutionCheckpointPath(DataReference) + SetCleanupOnFailure(bool) } // ExecutableWorkflowNode is an interface for a Child Workflow Node diff --git a/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableTaskNodeStatus.go b/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableTaskNodeStatus.go index 3d52a25ea9..899cab1b44 100644 --- a/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableTaskNodeStatus.go +++ b/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableTaskNodeStatus.go @@ -46,6 +46,38 @@ func (_m *ExecutableTaskNodeStatus) GetBarrierClockTick() uint32 { return r0 } +type ExecutableTaskNodeStatus_GetCleanupOnFailure struct { + *mock.Call +} + +func (_m ExecutableTaskNodeStatus_GetCleanupOnFailure) Return(_a0 bool) *ExecutableTaskNodeStatus_GetCleanupOnFailure { + return &ExecutableTaskNodeStatus_GetCleanupOnFailure{Call: _m.Call.Return(_a0)} +} + +func (_m *ExecutableTaskNodeStatus) OnGetCleanupOnFailure() *ExecutableTaskNodeStatus_GetCleanupOnFailure { + c_call := _m.On("GetCleanupOnFailure") + return &ExecutableTaskNodeStatus_GetCleanupOnFailure{Call: c_call} +} + +func (_m *ExecutableTaskNodeStatus) OnGetCleanupOnFailureMatch(matchers ...interface{}) *ExecutableTaskNodeStatus_GetCleanupOnFailure { + c_call := _m.On("GetCleanupOnFailure", matchers...) + return &ExecutableTaskNodeStatus_GetCleanupOnFailure{Call: c_call} +} + +// GetCleanupOnFailure provides a mock function with given fields: +func (_m *ExecutableTaskNodeStatus) GetCleanupOnFailure() bool { + ret := _m.Called() + + var r0 bool + if rf, ok := ret.Get(0).(func() bool); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + type ExecutableTaskNodeStatus_GetLastPhaseUpdatedAt struct { *mock.Call } diff --git a/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableTaskNodeStatus.go b/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableTaskNodeStatus.go index 83c0fca19c..4eea4400d4 100644 --- a/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableTaskNodeStatus.go +++ b/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableTaskNodeStatus.go @@ -46,6 +46,38 @@ func (_m *MutableTaskNodeStatus) GetBarrierClockTick() uint32 { return r0 } +type MutableTaskNodeStatus_GetCleanupOnFailure struct { + *mock.Call +} + +func (_m MutableTaskNodeStatus_GetCleanupOnFailure) Return(_a0 bool) *MutableTaskNodeStatus_GetCleanupOnFailure { + return &MutableTaskNodeStatus_GetCleanupOnFailure{Call: _m.Call.Return(_a0)} +} + +func (_m *MutableTaskNodeStatus) OnGetCleanupOnFailure() *MutableTaskNodeStatus_GetCleanupOnFailure { + c_call := _m.On("GetCleanupOnFailure") + return &MutableTaskNodeStatus_GetCleanupOnFailure{Call: c_call} +} + +func (_m *MutableTaskNodeStatus) OnGetCleanupOnFailureMatch(matchers ...interface{}) *MutableTaskNodeStatus_GetCleanupOnFailure { + c_call := _m.On("GetCleanupOnFailure", matchers...) + return &MutableTaskNodeStatus_GetCleanupOnFailure{Call: c_call} +} + +// GetCleanupOnFailure provides a mock function with given fields: +func (_m *MutableTaskNodeStatus) GetCleanupOnFailure() bool { + ret := _m.Called() + + var r0 bool + if rf, ok := ret.Get(0).(func() bool); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + type MutableTaskNodeStatus_GetLastPhaseUpdatedAt struct { *mock.Call } @@ -277,6 +309,11 @@ func (_m *MutableTaskNodeStatus) SetBarrierClockTick(tick uint32) { _m.Called(tick) } +// SetCleanupOnFailure provides a mock function with given fields: _a0 +func (_m *MutableTaskNodeStatus) SetCleanupOnFailure(_a0 bool) { + _m.Called(_a0) +} + // SetLastPhaseUpdatedAt provides a mock function with given fields: updatedAt func (_m *MutableTaskNodeStatus) SetLastPhaseUpdatedAt(updatedAt time.Time) { _m.Called(updatedAt) diff --git a/pkg/apis/flyteworkflow/v1alpha1/node_status.go b/pkg/apis/flyteworkflow/v1alpha1/node_status.go index 7aea3f2b8e..282add2389 100644 --- a/pkg/apis/flyteworkflow/v1alpha1/node_status.go +++ b/pkg/apis/flyteworkflow/v1alpha1/node_status.go @@ -765,6 +765,7 @@ type TaskNodeStatus struct { BarrierClockTick uint32 `json:"tick,omitempty"` LastPhaseUpdatedAt time.Time `json:"updAt,omitempty"` PreviousNodeExecutionCheckpointPath DataReference `json:"checkpointPath,omitempty"` + CleanupOnFailure bool `json:"clean,omitempty"` } func (in *TaskNodeStatus) GetBarrierClockTick() uint32 { @@ -795,6 +796,11 @@ func (in *TaskNodeStatus) SetPluginStateVersion(v uint32) { in.SetDirty() } +func (in *TaskNodeStatus) SetCleanupOnFailure(cleanupOnFailure bool) { + in.CleanupOnFailure = cleanupOnFailure + in.SetDirty() +} + func (in *TaskNodeStatus) GetPluginState() []byte { return in.PluginState } @@ -829,6 +835,10 @@ func (in TaskNodeStatus) GetPhaseVersion() uint32 { return in.PhaseVersion } +func (in TaskNodeStatus) GetCleanupOnFailure() bool { + return in.CleanupOnFailure +} + func (in *TaskNodeStatus) UpdatePhase(phase int, phaseVersion uint32) { if in.Phase != phase || in.PhaseVersion != phaseVersion { in.SetDirty() diff --git a/pkg/controller/nodes/handler/state.go b/pkg/controller/nodes/handler/state.go index 697b4959d6..2ca4fb015d 100644 --- a/pkg/controller/nodes/handler/state.go +++ b/pkg/controller/nodes/handler/state.go @@ -20,6 +20,7 @@ type TaskNodeState struct { PluginStateVersion uint32 LastPhaseUpdatedAt time.Time PreviousNodeExecutionCheckpointURI storage.DataReference + CleanupOnFailure bool } type BranchNodeState struct { diff --git a/pkg/controller/nodes/node_state_manager.go b/pkg/controller/nodes/node_state_manager.go index 84bd289b59..89347f79ba 100644 --- a/pkg/controller/nodes/node_state_manager.go +++ b/pkg/controller/nodes/node_state_manager.go @@ -53,6 +53,7 @@ func (n nodeStateManager) GetTaskNodeState() handler.TaskNodeState { PluginState: tn.GetPluginState(), LastPhaseUpdatedAt: tn.GetLastPhaseUpdatedAt(), PreviousNodeExecutionCheckpointURI: tn.GetPreviousNodeExecutionCheckpointPath(), + CleanupOnFailure: tn.GetCleanupOnFailure(), } } return handler.TaskNodeState{} diff --git a/pkg/controller/nodes/task/handler.go b/pkg/controller/nodes/task/handler.go index 5fe48de1c9..d7300818b4 100644 --- a/pkg/controller/nodes/task/handler.go +++ b/pkg/controller/nodes/task/handler.go @@ -747,6 +747,7 @@ func (t Handler) Handle(ctx context.Context, nCtx handler.NodeExecutionContext) PluginPhaseVersion: pluginTrns.pInfo.Version(), LastPhaseUpdatedAt: time.Now(), PreviousNodeExecutionCheckpointURI: ts.PreviousNodeExecutionCheckpointURI, + CleanupOnFailure: ts.CleanupOnFailure || pluginTrns.pInfo.CleanupOnFailure(), }) if err != nil { logger.Errorf(ctx, "Failed to store TaskNode state, err :%s", err.Error()) @@ -761,10 +762,11 @@ func (t Handler) Handle(ctx context.Context, nCtx handler.NodeExecutionContext) } func (t Handler) Abort(ctx context.Context, nCtx handler.NodeExecutionContext, reason string) error { - currentPhase := nCtx.NodeStateReader().GetTaskNodeState().PluginPhase + taskNodeState := nCtx.NodeStateReader().GetTaskNodeState() + currentPhase := taskNodeState.PluginPhase logger.Debugf(ctx, "Abort invoked with phase [%v]", currentPhase) - if currentPhase.IsTerminal() { + if currentPhase.IsTerminal() && !(currentPhase.IsFailure() && taskNodeState.CleanupOnFailure) { logger.Debugf(ctx, "Returning immediately from Abort since task is already in terminal phase.", currentPhase) return nil } diff --git a/pkg/controller/nodes/transformers.go b/pkg/controller/nodes/transformers.go index 51742c2a38..ae615a9f30 100644 --- a/pkg/controller/nodes/transformers.go +++ b/pkg/controller/nodes/transformers.go @@ -240,6 +240,7 @@ func UpdateNodeStatus(np v1alpha1.NodePhase, p handler.PhaseInfo, n *nodeStateMa t.SetPluginState(n.t.PluginState) t.SetPluginStateVersion(n.t.PluginStateVersion) t.SetPreviousNodeExecutionCheckpointPath(n.t.PreviousNodeExecutionCheckpointURI) + t.SetCleanupOnFailure(n.t.CleanupOnFailure) } // Update dynamic node status