diff --git a/pkg/controller/executors/mocks/node_lookup.go b/pkg/controller/executors/mocks/node_lookup.go index 036a0400d9..eac909a110 100644 --- a/pkg/controller/executors/mocks/node_lookup.go +++ b/pkg/controller/executors/mocks/node_lookup.go @@ -15,6 +15,47 @@ type NodeLookup struct { mock.Mock } +type NodeLookup_FromNode struct { + *mock.Call +} + +func (_m NodeLookup_FromNode) Return(_a0 []string, _a1 error) *NodeLookup_FromNode { + return &NodeLookup_FromNode{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *NodeLookup) OnFromNode(id string) *NodeLookup_FromNode { + c_call := _m.On("FromNode", id) + return &NodeLookup_FromNode{Call: c_call} +} + +func (_m *NodeLookup) OnFromNodeMatch(matchers ...interface{}) *NodeLookup_FromNode { + c_call := _m.On("FromNode", matchers...) + return &NodeLookup_FromNode{Call: c_call} +} + +// FromNode provides a mock function with given fields: id +func (_m *NodeLookup) FromNode(id string) ([]string, error) { + ret := _m.Called(id) + + var r0 []string + if rf, ok := ret.Get(0).(func(string) []string); ok { + r0 = rf(id) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]string) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(id) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + type NodeLookup_GetNode struct { *mock.Call } @@ -89,3 +130,44 @@ func (_m *NodeLookup) GetNodeExecutionStatus(ctx context.Context, id string) v1a return r0 } + +type NodeLookup_ToNode struct { + *mock.Call +} + +func (_m NodeLookup_ToNode) Return(_a0 []string, _a1 error) *NodeLookup_ToNode { + return &NodeLookup_ToNode{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *NodeLookup) OnToNode(id string) *NodeLookup_ToNode { + c_call := _m.On("ToNode", id) + return &NodeLookup_ToNode{Call: c_call} +} + +func (_m *NodeLookup) OnToNodeMatch(matchers ...interface{}) *NodeLookup_ToNode { + c_call := _m.On("ToNode", matchers...) + return &NodeLookup_ToNode{Call: c_call} +} + +// ToNode provides a mock function with given fields: id +func (_m *NodeLookup) ToNode(id string) ([]string, error) { + ret := _m.Called(id) + + var r0 []string + if rf, ok := ret.Get(0).(func(string) []string); ok { + r0 = rf(id) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]string) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(id) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} diff --git a/pkg/controller/executors/node_lookup.go b/pkg/controller/executors/node_lookup.go index 9b49dc4ff5..381b832c0e 100644 --- a/pkg/controller/executors/node_lookup.go +++ b/pkg/controller/executors/node_lookup.go @@ -12,21 +12,27 @@ import ( type NodeLookup interface { GetNode(nodeID v1alpha1.NodeID) (v1alpha1.ExecutableNode, bool) GetNodeExecutionStatus(ctx context.Context, id v1alpha1.NodeID) v1alpha1.ExecutableNodeStatus + // Lookup for upstream edges, find all node ids from which this node can be reached. + ToNode(id v1alpha1.NodeID) ([]v1alpha1.NodeID, error) + // Lookup for downstream edges, find all node ids that can be reached from the given node id. + FromNode(id v1alpha1.NodeID) ([]v1alpha1.NodeID, error) } // Implements a contextual NodeLookup that can be composed of a disparate NodeGetter and a NodeStatusGetter type contextualNodeLookup struct { v1alpha1.NodeGetter v1alpha1.NodeStatusGetter + DAGStructure } // Returns a Contextual NodeLookup using the given NodeGetter and a separate NodeStatusGetter. // Very useful in Subworkflows where the Subworkflow is the reservoir of the nodes, but the status for these nodes // maybe stored int he Top-level workflow node itself. -func NewNodeLookup(n v1alpha1.NodeGetter, s v1alpha1.NodeStatusGetter) NodeLookup { +func NewNodeLookup(n v1alpha1.NodeGetter, s v1alpha1.NodeStatusGetter, d DAGStructure) NodeLookup { return contextualNodeLookup{ NodeGetter: n, NodeStatusGetter: s, + DAGStructure: d, } } @@ -45,6 +51,14 @@ func (s staticNodeLookup) GetNodeExecutionStatus(_ context.Context, id v1alpha1. return s.status[id] } +func (s staticNodeLookup) ToNode(id v1alpha1.NodeID) ([]v1alpha1.NodeID, error) { + return nil, nil +} + +func (s staticNodeLookup) FromNode(id v1alpha1.NodeID) ([]v1alpha1.NodeID, error) { + return nil, nil +} + // Returns a new NodeLookup useful in Testing. Not recommended to be used in production func NewTestNodeLookup(nodes map[v1alpha1.NodeID]v1alpha1.ExecutableNode, status map[v1alpha1.NodeID]v1alpha1.ExecutableNodeStatus) NodeLookup { return staticNodeLookup{ diff --git a/pkg/controller/executors/node_lookup_test.go b/pkg/controller/executors/node_lookup_test.go index a86b00b081..4bce76138c 100644 --- a/pkg/controller/executors/node_lookup_test.go +++ b/pkg/controller/executors/node_lookup_test.go @@ -18,14 +18,20 @@ type nsg struct { v1alpha1.NodeStatusGetter } +type dag struct { + DAGStructure +} + func TestNewNodeLookup(t *testing.T) { n := ng{} ns := nsg{} - nl := NewNodeLookup(n, ns) + d := dag{} + nl := NewNodeLookup(n, ns, d) assert.NotNil(t, nl) typed := nl.(contextualNodeLookup) assert.Equal(t, n, typed.NodeGetter) assert.Equal(t, ns, typed.NodeStatusGetter) + assert.Equal(t, d, typed.DAGStructure) } func TestNewTestNodeLookup(t *testing.T) { diff --git a/pkg/controller/nodes/branch/handler.go b/pkg/controller/nodes/branch/handler.go index 9b0cd7f596..109290b908 100644 --- a/pkg/controller/nodes/branch/handler.go +++ b/pkg/controller/nodes/branch/handler.go @@ -136,7 +136,11 @@ func (b *branchHandler) recurseDownstream(ctx context.Context, nCtx handler.Node childNodeStatus := nl.GetNodeExecutionStatus(ctx, branchTakenNode.GetID()) childNodeStatus.SetDataDir(nodeStatus.GetDataDir()) childNodeStatus.SetOutputDir(nodeStatus.GetOutputDir()) - dag := executors.NewLeafNodeDAGStructure(branchTakenNode.GetID(), nCtx.NodeID()) + upstreamNodeIds, err := nCtx.ContextualNodeLookup().ToNode(branchTakenNode.GetID()) + if err != nil { + return handler.UnknownTransition, err + } + dag := executors.NewLeafNodeDAGStructure(branchTakenNode.GetID(), append(upstreamNodeIds, nCtx.NodeID())...) execContext, err := b.getExecutionContextForDownstream(nCtx) if err != nil { return handler.UnknownTransition, err @@ -196,7 +200,11 @@ func (b *branchHandler) Abort(ctx context.Context, nCtx handler.NodeExecutionCon // TODO we should replace the call to RecursiveNodeHandler with a call to SingleNode Handler. The inputs are also already known ahead of time // There is no DAGStructure for the branch nodes, the branch taken node is the leaf node. The node itself may be arbitrarily complex, but in that case the node should reference a subworkflow etc // The parent of the BranchTaken Node is the actual Branch Node and all the data is just forwarded from the Branch to the executed node. - dag := executors.NewLeafNodeDAGStructure(branchTakenNode.GetID(), nCtx.NodeID()) + upstreamNodeIds, err := nCtx.ContextualNodeLookup().ToNode(branchTakenNode.GetID()) + if err != nil { + return err + } + dag := executors.NewLeafNodeDAGStructure(branchTakenNode.GetID(), append(upstreamNodeIds, nCtx.NodeID())...) execContext, err := b.getExecutionContextForDownstream(nCtx) if err != nil { return err @@ -236,7 +244,11 @@ func (b *branchHandler) Finalize(ctx context.Context, nCtx handler.NodeExecution // TODO we should replace the call to RecursiveNodeHandler with a call to SingleNode Handler. The inputs are also already known ahead of time // There is no DAGStructure for the branch nodes, the branch taken node is the leaf node. The node itself may be arbitrarily complex, but in that case the node should reference a subworkflow etc // The parent of the BranchTaken Node is the actual Branch Node and all the data is just forwarded from the Branch to the executed node. - dag := executors.NewLeafNodeDAGStructure(branchTakenNode.GetID(), nCtx.NodeID()) + upstreamNodeIds, err := nCtx.ContextualNodeLookup().ToNode(branchTakenNode.GetID()) + if err != nil { + return err + } + dag := executors.NewLeafNodeDAGStructure(branchTakenNode.GetID(), append(upstreamNodeIds, nCtx.NodeID())...) execContext, err := b.getExecutionContextForDownstream(nCtx) if err != nil { return err diff --git a/pkg/controller/nodes/branch/handler_test.go b/pkg/controller/nodes/branch/handler_test.go index bc39c1b243..5711de5d42 100644 --- a/pkg/controller/nodes/branch/handler_test.go +++ b/pkg/controller/nodes/branch/handler_test.go @@ -158,24 +158,34 @@ func TestBranchHandler_RecurseDownstream(t *testing.T) { isErr bool expectedPhase handler.EPhase childPhase v1alpha1.NodePhase - nl *execMocks.NodeLookup + upstreamNodeID string }{ + {"upstreamNodeExists", executors.NodeStatusPending, nil, + &mocks2.ExecutableNodeStatus{}, bn, false, handler.EPhaseRunning, v1alpha1.NodePhaseQueued, "n2"}, {"childNodeError", executors.NodeStatusUndefined, fmt.Errorf("err"), - &mocks2.ExecutableNodeStatus{}, bn, true, handler.EPhaseUndefined, v1alpha1.NodePhaseFailed, &execMocks.NodeLookup{}}, + &mocks2.ExecutableNodeStatus{}, bn, true, handler.EPhaseUndefined, v1alpha1.NodePhaseFailed, ""}, {"childPending", executors.NodeStatusPending, nil, - &mocks2.ExecutableNodeStatus{}, bn, false, handler.EPhaseRunning, v1alpha1.NodePhaseQueued, &execMocks.NodeLookup{}}, + &mocks2.ExecutableNodeStatus{}, bn, false, handler.EPhaseRunning, v1alpha1.NodePhaseQueued, ""}, {"childStillRunning", executors.NodeStatusRunning, nil, - &mocks2.ExecutableNodeStatus{}, bn, false, handler.EPhaseRunning, v1alpha1.NodePhaseRunning, &execMocks.NodeLookup{}}, + &mocks2.ExecutableNodeStatus{}, bn, false, handler.EPhaseRunning, v1alpha1.NodePhaseRunning, ""}, {"childFailure", executors.NodeStatusFailed(expectedError), nil, - &mocks2.ExecutableNodeStatus{}, bn, false, handler.EPhaseFailed, v1alpha1.NodePhaseFailed, &execMocks.NodeLookup{}}, + &mocks2.ExecutableNodeStatus{}, bn, false, handler.EPhaseFailed, v1alpha1.NodePhaseFailed, ""}, {"childComplete", executors.NodeStatusComplete, nil, - &mocks2.ExecutableNodeStatus{}, bn, false, handler.EPhaseSuccess, v1alpha1.NodePhaseSucceeded, &execMocks.NodeLookup{}}, + &mocks2.ExecutableNodeStatus{}, bn, false, handler.EPhaseSuccess, v1alpha1.NodePhaseSucceeded, ""}, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { eCtx := &execMocks.ExecutionContext{} eCtx.OnGetParentInfo().Return(parentInfo{}) - nCtx, _ := createNodeContext(v1alpha1.BranchNodeNotYetEvaluated, &childNodeID, n, nil, test.nl, eCtx) + + mockNodeLookup := &execMocks.NodeLookup{} + if len(test.upstreamNodeID) > 0 { + mockNodeLookup.OnToNodeMatch(childNodeID).Return([]string{test.upstreamNodeID}, nil) + } else { + mockNodeLookup.OnToNodeMatch(childNodeID).Return(nil, nil) + } + + nCtx, _ := createNodeContext(v1alpha1.BranchNodeNotYetEvaluated, &childNodeID, n, nil, mockNodeLookup, eCtx) newParentInfo, _ := common.CreateParentInfo(parentInfo{}, nCtx.NodeID(), nCtx.CurrentAttempt()) expectedExecContext := executors.NewExecutionContextWithParentInfo(nCtx.ExecutionContext(), newParentInfo) mockNodeExecutor := &execMocks.Node{} @@ -187,23 +197,27 @@ func TestBranchHandler_RecurseDownstream(t *testing.T) { fList, err1 := d.FromNode("x") dList, err2 := d.ToNode(childNodeID) b := assert.NoError(t, err1) - b = b && assert.Equal(t, fList, []v1alpha1.NodeID{}) + b = b && assert.Equal(t, []v1alpha1.NodeID{}, fList) b = b && assert.NoError(t, err2) - b = b && assert.Equal(t, dList, []v1alpha1.NodeID{nodeID}) + dListExpected := []v1alpha1.NodeID{nodeID} + if len(test.upstreamNodeID) > 0 { + dListExpected = append([]string{test.upstreamNodeID}, dListExpected...) + } + b = b && assert.Equal(t, dListExpected, dList) return b } return false }), - mock.MatchedBy(func(lookup executors.NodeLookup) bool { return assert.Equal(t, lookup, test.nl) }), + mock.MatchedBy(func(lookup executors.NodeLookup) bool { return assert.Equal(t, lookup, mockNodeLookup) }), mock.MatchedBy(func(n v1alpha1.ExecutableNode) bool { return assert.Equal(t, n.GetID(), childNodeID) }), ).Return(test.ns, test.err) childNodeStatus := &mocks2.ExecutableNodeStatus{} - if test.nl != nil { + if mockNodeLookup != nil { childNodeStatus.OnGetOutputDir().Return("parent-output-dir") test.nodeStatus.OnGetDataDir().Return("parent-data-dir") test.nodeStatus.OnGetOutputDir().Return("parent-output-dir") - test.nl.OnGetNodeExecutionStatus(ctx, childNodeID).Return(childNodeStatus) + mockNodeLookup.OnGetNodeExecutionStatus(ctx, childNodeID).Return(childNodeStatus) childNodeStatus.On("SetDataDir", storage.DataReference("parent-data-dir")).Once() childNodeStatus.On("SetOutputDir", storage.DataReference("parent-output-dir")).Once() } @@ -295,17 +309,18 @@ func TestBranchHandler_AbortNode(t *testing.T) { t.Run("BranchNodeSuccess", func(t *testing.T) { mockNodeExecutor := &execMocks.Node{} - nl := &execMocks.NodeLookup{} + mockNodeLookup := &execMocks.NodeLookup{} + mockNodeLookup.OnToNodeMatch(mock.Anything).Return(nil, nil) eCtx := &execMocks.ExecutionContext{} eCtx.OnGetParentInfo().Return(parentInfo{}) - nCtx, s := createNodeContext(v1alpha1.BranchNodeSuccess, &n1, n, nil, nl, eCtx) + nCtx, s := createNodeContext(v1alpha1.BranchNodeSuccess, &n1, n, nil, mockNodeLookup, eCtx) newParentInfo, _ := common.CreateParentInfo(parentInfo{}, nCtx.NodeID(), nCtx.CurrentAttempt()) expectedExecContext := executors.NewExecutionContextWithParentInfo(nCtx.ExecutionContext(), newParentInfo) mockNodeExecutor.OnAbortHandlerMatch(mock.Anything, mock.MatchedBy(func(e executors.ExecutionContext) bool { return assert.Equal(t, e, expectedExecContext) }), mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil) - nl.OnGetNode(*s.s.FinalizedNodeID).Return(n, true) + mockNodeLookup.OnGetNode(*s.s.FinalizedNodeID).Return(n, true) branch := New(mockNodeExecutor, eventConfig, promutils.NewTestScope()) err := branch.Abort(ctx, nCtx, "") assert.NoError(t, err) diff --git a/pkg/controller/nodes/dynamic/dynamic_workflow.go b/pkg/controller/nodes/dynamic/dynamic_workflow.go index ef1fc51c34..eb891aa279 100644 --- a/pkg/controller/nodes/dynamic/dynamic_workflow.go +++ b/pkg/controller/nodes/dynamic/dynamic_workflow.go @@ -183,7 +183,7 @@ func (d dynamicNodeTaskNodeHandler) buildContextualDynamicWorkflow(ctx context.C subWorkflow: compiledWf, subWorkflowClosure: workflowCacheContents.CompiledWorkflow, execContext: executors.NewExecutionContext(nCtx.ExecutionContext(), compiledWf, compiledWf, newParentInfo, nCtx.ExecutionContext()), - nodeLookup: executors.NewNodeLookup(compiledWf, dynamicNodeStatus), + nodeLookup: executors.NewNodeLookup(compiledWf, dynamicNodeStatus, compiledWf), dynamicJobSpecURI: string(f.GetLoc()), }, nil } @@ -216,7 +216,7 @@ func (d dynamicNodeTaskNodeHandler) buildContextualDynamicWorkflow(ctx context.C subWorkflow: dynamicWf, subWorkflowClosure: closure, execContext: executors.NewExecutionContext(nCtx.ExecutionContext(), dynamicWf, dynamicWf, newParentInfo, nCtx.ExecutionContext()), - nodeLookup: executors.NewNodeLookup(dynamicWf, dynamicNodeStatus), + nodeLookup: executors.NewNodeLookup(dynamicWf, dynamicNodeStatus, dynamicWf), dynamicJobSpecURI: string(f.GetLoc()), }, nil } diff --git a/pkg/controller/nodes/subworkflow/subworkflow.go b/pkg/controller/nodes/subworkflow/subworkflow.go index 24d74473f0..74beeaf792 100644 --- a/pkg/controller/nodes/subworkflow/subworkflow.go +++ b/pkg/controller/nodes/subworkflow/subworkflow.go @@ -209,7 +209,7 @@ func (s *subworkflowHandler) HandleFailingSubWorkflow(ctx context.Context, nCtx } status := nCtx.NodeStatus() - nodeLookup := executors.NewNodeLookup(subWorkflow, status) + nodeLookup := executors.NewNodeLookup(subWorkflow, status, subWorkflow) return s.HandleFailureNodeOfSubWorkflow(ctx, nCtx, subWorkflow, nodeLookup) } @@ -220,7 +220,7 @@ func (s *subworkflowHandler) StartSubWorkflow(ctx context.Context, nCtx handler. } status := nCtx.NodeStatus() - nodeLookup := executors.NewNodeLookup(subWorkflow, status) + nodeLookup := executors.NewNodeLookup(subWorkflow, status, subWorkflow) // assert startStatus.IsComplete() == true return s.startAndHandleSubWorkflow(ctx, nCtx, subWorkflow, nodeLookup) @@ -233,7 +233,7 @@ func (s *subworkflowHandler) CheckSubWorkflowStatus(ctx context.Context, nCtx ha } status := nCtx.NodeStatus() - nodeLookup := executors.NewNodeLookup(subWorkflow, status) + nodeLookup := executors.NewNodeLookup(subWorkflow, status, subWorkflow) return s.handleSubWorkflow(ctx, nCtx, subWorkflow, nodeLookup) } @@ -243,7 +243,7 @@ func (s *subworkflowHandler) HandleAbort(ctx context.Context, nCtx handler.NodeE return err } status := nCtx.NodeStatus() - nodeLookup := executors.NewNodeLookup(subWorkflow, status) + nodeLookup := executors.NewNodeLookup(subWorkflow, status, subWorkflow) execContext, err := s.getExecutionContextForDownstream(nCtx) if err != nil { return err diff --git a/pkg/controller/workflow/executor.go b/pkg/controller/workflow/executor.go index 92a870c1c2..e3eac1e37c 100644 --- a/pkg/controller/workflow/executor.go +++ b/pkg/controller/workflow/executor.go @@ -124,7 +124,7 @@ func (c *workflowExecutor) handleReadyWorkflow(ctx context.Context, w *v1alpha1. nodeStatus.SetDataDir(dataDir) nodeStatus.SetOutputDir(outputDir) execcontext := executors.NewExecutionContext(w, w, w, nil, executors.InitializeControlFlow()) - s, err := c.nodeExecutor.SetInputsForStartNode(ctx, execcontext, w, executors.NewNodeLookup(w, w.GetExecutionStatus()), inputs) + s, err := c.nodeExecutor.SetInputsForStartNode(ctx, execcontext, w, executors.NewNodeLookup(w, w.GetExecutionStatus(), w), inputs) if err != nil { return StatusReady, err }