Skip to content

Commit

Permalink
Handle nodeexec not found for Task (flyteorg#208)
Browse files Browse the repository at this point in the history
  • Loading branch information
anandswaminathan authored Dec 2, 2020
1 parent 2a60437 commit 7944f8a
Show file tree
Hide file tree
Showing 3 changed files with 166 additions and 12 deletions.
6 changes: 5 additions & 1 deletion flytepropeller/pkg/controller/nodes/task/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -667,9 +667,13 @@ func (t Handler) Abort(ctx context.Context, nCtx handler.NodeExecutionContext, r
}
taskExecID := tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetID()
evRecorder := nCtx.EventsRecorder()
nodeExecutionID, err := getParentNodeExecIDForTask(&taskExecID, nCtx.ExecutionContext())
if err != nil {
return err
}
if err := evRecorder.RecordTaskEvent(ctx, &event.TaskExecutionEvent{
TaskId: taskExecID.TaskId,
ParentNodeExecutionId: taskExecID.NodeExecutionId,
ParentNodeExecutionId: nodeExecutionID,
RetryAttempt: nCtx.CurrentAttempt(),
Phase: core.TaskExecution_ABORTED,
OccurredAt: ptypes.TimestampNow(),
Expand Down
142 changes: 142 additions & 0 deletions flytepropeller/pkg/controller/nodes/task/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1349,6 +1349,148 @@ func Test_task_Abort(t *testing.T) {
}
}

func Test_task_Abort_v1(t *testing.T) {
createNodeCtx := func(ev *fakeBufferedTaskEventRecorder) *nodeMocks.NodeExecutionContext {
wfExecID := &core.WorkflowExecutionIdentifier{
Project: "project",
Domain: "domain",
Name: "name",
}

nodeID := "n1"

nm := &nodeMocks.NodeExecutionMetadata{}
nm.OnGetAnnotations().Return(map[string]string{})
nm.OnGetNodeExecutionID().Return(&core.NodeExecutionIdentifier{
NodeId: nodeID,
ExecutionId: wfExecID,
})
nm.OnGetK8sServiceAccount().Return("service-account")
nm.OnGetLabels().Return(map[string]string{})
nm.OnGetNamespace().Return("namespace")
nm.OnGetOwnerID().Return(types.NamespacedName{Namespace: "namespace", Name: "name"})
nm.OnGetOwnerReference().Return(v12.OwnerReference{
Kind: "sample",
Name: "name",
})

taskID := &core.Identifier{}
tr := &nodeMocks.TaskReader{}
tr.OnGetTaskID().Return(taskID)
tr.OnGetTaskType().Return("x")

ns := &flyteMocks.ExecutableNodeStatus{}
ns.OnGetDataDir().Return(storage.DataReference("data-dir"))
ns.OnGetOutputDir().Return(storage.DataReference("output-dir"))

res := &v1.ResourceRequirements{}
n := &flyteMocks.ExecutableNode{}
ma := 5
n.OnGetRetryStrategy().Return(&v1alpha1.RetryStrategy{MinAttempts: &ma})
n.OnGetResources().Return(res)

ir := &ioMocks.InputReader{}
nCtx := &nodeMocks.NodeExecutionContext{}
nCtx.OnNodeExecutionMetadata().Return(nm)
nCtx.OnNode().Return(n)
nCtx.OnInputReader().Return(ir)
ds, err := storage.NewDataStore(
&storage.Config{
Type: storage.TypeMemory,
},
promutils.NewTestScope(),
)
assert.NoError(t, err)
nCtx.OnDataStore().Return(ds)
nCtx.OnCurrentAttempt().Return(uint32(1))
nCtx.OnTaskReader().Return(tr)
nCtx.OnMaxDatasetSizeBytes().Return(int64(1))
nCtx.OnNodeStatus().Return(ns)
nCtx.OnNodeID().Return("n1")
nCtx.OnEnqueueOwnerFunc().Return(nil)
nCtx.OnEventsRecorder().Return(ev)

executionContext := &mocks.ExecutionContext{}
executionContext.OnGetExecutionConfig().Return(v1alpha1.ExecutionConfig{})
executionContext.OnGetParentInfo().Return(nil)
executionContext.OnGetEventVersion().Return(v1alpha1.EventVersion1)
nCtx.OnExecutionContext().Return(executionContext)

nCtx.OnRawOutputPrefix().Return("s3://sandbox/")
nCtx.OnOutputShardSelector().Return(ioutils.NewConstantShardSelector([]string{"x"}))

st := bytes.NewBuffer([]byte{})
a := 45
type test struct {
A int
}
cod := codex.GobStateCodec{}
assert.NoError(t, cod.Encode(test{A: a}, st))
nr := &nodeMocks.NodeStateReader{}
nr.OnGetTaskNodeState().Return(handler.TaskNodeState{
PluginState: st.Bytes(),
})
nCtx.OnNodeStateReader().Return(nr)
return nCtx
}

noopRm := CreateNoopResourceManager(context.TODO(), promutils.NewTestScope())

type fields struct {
defaultPluginCallback func() pluginCore.Plugin
}
type args struct {
ev *fakeBufferedTaskEventRecorder
}
tests := []struct {
name string
fields fields
args args
wantErr bool
abortCalled bool
}{
{"no-plugin", fields{defaultPluginCallback: func() pluginCore.Plugin {
return nil
}}, args{nil}, true, false},

{"abort-fails", fields{defaultPluginCallback: func() pluginCore.Plugin {
p := &pluginCoreMocks.Plugin{}
p.On("GetID").Return("id")
p.On("Abort", mock.Anything, mock.Anything).Return(fmt.Errorf("error"))
return p
}}, args{nil}, true, true},
{"abort-success", fields{defaultPluginCallback: func() pluginCore.Plugin {
p := &pluginCoreMocks.Plugin{}
p.On("GetID").Return("id")
p.On("Abort", mock.Anything, mock.Anything).Return(nil)
return p
}}, args{ev: &fakeBufferedTaskEventRecorder{}}, false, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
m := tt.fields.defaultPluginCallback()
tk := Handler{
defaultPlugin: m,
resourceManager: noopRm,
}
nCtx := createNodeCtx(tt.args.ev)
if err := tk.Abort(context.TODO(), nCtx, "reason"); (err != nil) != tt.wantErr {
t.Errorf("Handler.Abort() error = %v, wantErr %v", err, tt.wantErr)
}
c := 0
if tt.abortCalled {
c = 1
if !tt.wantErr {
assert.Len(t, tt.args.ev.evs, 1)
}
}
if m != nil {
m.(*pluginCoreMocks.Plugin).AssertNumberOfCalls(t, "Abort", c)
}
})
}
}

func Test_task_Finalize(t *testing.T) {

wfExecID := &core.WorkflowExecutionIdentifier{
Expand Down
30 changes: 19 additions & 11 deletions flytepropeller/pkg/controller/nodes/task/transformer.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,22 @@ func trimErrorMessage(original string, maxLength int) string {
return original[0:maxLength/2] + original[len(original)-maxLength/2:]
}

func getParentNodeExecIDForTask(taskExecID *core.TaskExecutionIdentifier, execContext executors.ExecutionContext) (*core.NodeExecutionIdentifier, error) {
nodeExecutionID := &core.NodeExecutionIdentifier{
ExecutionId: taskExecID.NodeExecutionId.ExecutionId,
}
if execContext.GetEventVersion() != v1alpha1.EventVersion0 {
currentNodeUniqueID, err := common.GenerateUniqueID(execContext.GetParentInfo(), taskExecID.NodeExecutionId.NodeId)
if err != nil {
return nil, err
}
nodeExecutionID.NodeId = currentNodeUniqueID
} else {
nodeExecutionID.NodeId = taskExecID.NodeExecutionId.NodeId
}
return nodeExecutionID, nil
}

func ToTaskExecutionEvent(taskExecID *core.TaskExecutionIdentifier, in io.InputFilePaths, out io.OutputFilePaths, info pluginCore.PhaseInfo,
nodeExecutionMetadata handler.NodeExecutionMetadata, execContext executors.ExecutionContext) (*event.TaskExecutionEvent, error) {
// Transitions to a new phase
Expand All @@ -66,17 +82,9 @@ func ToTaskExecutionEvent(taskExecID *core.TaskExecutionIdentifier, in io.InputF
}
}

nodeExecutionID := &core.NodeExecutionIdentifier{
ExecutionId: taskExecID.NodeExecutionId.ExecutionId,
}
if execContext.GetEventVersion() != v1alpha1.EventVersion0 {
currentNodeUniqueID, err := common.GenerateUniqueID(execContext.GetParentInfo(), taskExecID.NodeExecutionId.NodeId)
if err != nil {
return nil, err
}
nodeExecutionID.NodeId = currentNodeUniqueID
} else {
nodeExecutionID.NodeId = taskExecID.NodeExecutionId.NodeId
nodeExecutionID, err := getParentNodeExecIDForTask(taskExecID, execContext)
if err != nil {
return nil, err
}
tev := &event.TaskExecutionEvent{
TaskId: taskExecID.TaskId,
Expand Down

0 comments on commit 7944f8a

Please sign in to comment.