From c8cfb78d392361342a7fe62e38428699638e5ca4 Mon Sep 17 00:00:00 2001
From: Daniel Rammer <daniel@union.ai>
Date: Wed, 22 Mar 2023 18:01:34 -0500
Subject: [PATCH 1/4] waiting for upstream nodes on branch subnode evaluation

Signed-off-by: Daniel Rammer <daniel@union.ai>
---
 pkg/controller/executors/mocks/node_lookup.go | 82 +++++++++++++++++++
 pkg/controller/executors/node_lookup.go       | 18 +++-
 pkg/controller/executors/node_lookup_test.go  |  8 +-
 pkg/controller/nodes/branch/handler.go        | 18 +++-
 pkg/controller/nodes/branch/handler_test.go   | 30 ++++---
 .../nodes/dynamic/dynamic_workflow.go         |  4 +-
 .../nodes/subworkflow/subworkflow.go          |  8 +-
 pkg/controller/workflow/executor.go           |  2 +-
 8 files changed, 145 insertions(+), 25 deletions(-)

diff --git a/pkg/controller/executors/mocks/node_lookup.go b/pkg/controller/executors/mocks/node_lookup.go
index 036a0400d..eac909a11 100644
--- a/pkg/controller/executors/mocks/node_lookup.go
+++ b/pkg/controller/executors/mocks/node_lookup.go
@@ -15,6 +15,47 @@ type NodeLookup struct {
 	mock.Mock
 }
 
+type NodeLookup_FromNode struct {
+	*mock.Call
+}
+
+func (_m NodeLookup_FromNode) Return(_a0 []string, _a1 error) *NodeLookup_FromNode {
+	return &NodeLookup_FromNode{Call: _m.Call.Return(_a0, _a1)}
+}
+
+func (_m *NodeLookup) OnFromNode(id string) *NodeLookup_FromNode {
+	c_call := _m.On("FromNode", id)
+	return &NodeLookup_FromNode{Call: c_call}
+}
+
+func (_m *NodeLookup) OnFromNodeMatch(matchers ...interface{}) *NodeLookup_FromNode {
+	c_call := _m.On("FromNode", matchers...)
+	return &NodeLookup_FromNode{Call: c_call}
+}
+
+// FromNode provides a mock function with given fields: id
+func (_m *NodeLookup) FromNode(id string) ([]string, error) {
+	ret := _m.Called(id)
+
+	var r0 []string
+	if rf, ok := ret.Get(0).(func(string) []string); ok {
+		r0 = rf(id)
+	} else {
+		if ret.Get(0) != nil {
+			r0 = ret.Get(0).([]string)
+		}
+	}
+
+	var r1 error
+	if rf, ok := ret.Get(1).(func(string) error); ok {
+		r1 = rf(id)
+	} else {
+		r1 = ret.Error(1)
+	}
+
+	return r0, r1
+}
+
 type NodeLookup_GetNode struct {
 	*mock.Call
 }
@@ -89,3 +130,44 @@ func (_m *NodeLookup) GetNodeExecutionStatus(ctx context.Context, id string) v1a
 
 	return r0
 }
+
+type NodeLookup_ToNode struct {
+	*mock.Call
+}
+
+func (_m NodeLookup_ToNode) Return(_a0 []string, _a1 error) *NodeLookup_ToNode {
+	return &NodeLookup_ToNode{Call: _m.Call.Return(_a0, _a1)}
+}
+
+func (_m *NodeLookup) OnToNode(id string) *NodeLookup_ToNode {
+	c_call := _m.On("ToNode", id)
+	return &NodeLookup_ToNode{Call: c_call}
+}
+
+func (_m *NodeLookup) OnToNodeMatch(matchers ...interface{}) *NodeLookup_ToNode {
+	c_call := _m.On("ToNode", matchers...)
+	return &NodeLookup_ToNode{Call: c_call}
+}
+
+// ToNode provides a mock function with given fields: id
+func (_m *NodeLookup) ToNode(id string) ([]string, error) {
+	ret := _m.Called(id)
+
+	var r0 []string
+	if rf, ok := ret.Get(0).(func(string) []string); ok {
+		r0 = rf(id)
+	} else {
+		if ret.Get(0) != nil {
+			r0 = ret.Get(0).([]string)
+		}
+	}
+
+	var r1 error
+	if rf, ok := ret.Get(1).(func(string) error); ok {
+		r1 = rf(id)
+	} else {
+		r1 = ret.Error(1)
+	}
+
+	return r0, r1
+}
diff --git a/pkg/controller/executors/node_lookup.go b/pkg/controller/executors/node_lookup.go
index 9b49dc4ff..0348b8b18 100644
--- a/pkg/controller/executors/node_lookup.go
+++ b/pkg/controller/executors/node_lookup.go
@@ -12,21 +12,27 @@ import (
 type NodeLookup interface {
 	GetNode(nodeID v1alpha1.NodeID) (v1alpha1.ExecutableNode, bool)
 	GetNodeExecutionStatus(ctx context.Context, id v1alpha1.NodeID) v1alpha1.ExecutableNodeStatus
+	// Lookup for upstream edges, find all node ids from which this node can be reached.
+	ToNode(id v1alpha1.NodeID) ([]v1alpha1.NodeID, error)
+	// Lookup for downstream edges, find all node ids that can be reached from the given node id.
+	FromNode(id v1alpha1.NodeID) ([]v1alpha1.NodeID, error)
 }
 
 // Implements a contextual NodeLookup that can be composed of a disparate NodeGetter and a NodeStatusGetter
 type contextualNodeLookup struct {
 	v1alpha1.NodeGetter
 	v1alpha1.NodeStatusGetter
+	DAGStructure
 }
 
 // Returns a Contextual NodeLookup using the given NodeGetter and a separate NodeStatusGetter.
 // Very useful in Subworkflows where the Subworkflow is the reservoir of the nodes, but the status for these nodes
 // maybe stored int he Top-level workflow node itself.
-func NewNodeLookup(n v1alpha1.NodeGetter, s v1alpha1.NodeStatusGetter) NodeLookup {
+func NewNodeLookup(n v1alpha1.NodeGetter, s v1alpha1.NodeStatusGetter, d DAGStructure) NodeLookup {
 	return contextualNodeLookup{
 		NodeGetter:       n,
 		NodeStatusGetter: s,
+		DAGStructure:     d,
 	}
 }
 
@@ -45,6 +51,16 @@ func (s staticNodeLookup) GetNodeExecutionStatus(_ context.Context, id v1alpha1.
 	return s.status[id]
 }
 
+// TODO @hamersaw implement
+func (s staticNodeLookup) ToNode(id v1alpha1.NodeID) ([]v1alpha1.NodeID, error) {
+	return nil, nil
+}
+
+// TODO @hamersaw implement
+func (s staticNodeLookup) FromNode(id v1alpha1.NodeID) ([]v1alpha1.NodeID, error) {
+	return nil, nil
+}
+
 // Returns a new NodeLookup useful in Testing. Not recommended to be used in production
 func NewTestNodeLookup(nodes map[v1alpha1.NodeID]v1alpha1.ExecutableNode, status map[v1alpha1.NodeID]v1alpha1.ExecutableNodeStatus) NodeLookup {
 	return staticNodeLookup{
diff --git a/pkg/controller/executors/node_lookup_test.go b/pkg/controller/executors/node_lookup_test.go
index a86b00b08..4bce76138 100644
--- a/pkg/controller/executors/node_lookup_test.go
+++ b/pkg/controller/executors/node_lookup_test.go
@@ -18,14 +18,20 @@ type nsg struct {
 	v1alpha1.NodeStatusGetter
 }
 
+type dag struct {
+	DAGStructure
+}
+
 func TestNewNodeLookup(t *testing.T) {
 	n := ng{}
 	ns := nsg{}
-	nl := NewNodeLookup(n, ns)
+	d := dag{}
+	nl := NewNodeLookup(n, ns, d)
 	assert.NotNil(t, nl)
 	typed := nl.(contextualNodeLookup)
 	assert.Equal(t, n, typed.NodeGetter)
 	assert.Equal(t, ns, typed.NodeStatusGetter)
+	assert.Equal(t, d, typed.DAGStructure)
 }
 
 func TestNewTestNodeLookup(t *testing.T) {
diff --git a/pkg/controller/nodes/branch/handler.go b/pkg/controller/nodes/branch/handler.go
index 9b0cd7f59..109290b90 100644
--- a/pkg/controller/nodes/branch/handler.go
+++ b/pkg/controller/nodes/branch/handler.go
@@ -136,7 +136,11 @@ func (b *branchHandler) recurseDownstream(ctx context.Context, nCtx handler.Node
 	childNodeStatus := nl.GetNodeExecutionStatus(ctx, branchTakenNode.GetID())
 	childNodeStatus.SetDataDir(nodeStatus.GetDataDir())
 	childNodeStatus.SetOutputDir(nodeStatus.GetOutputDir())
-	dag := executors.NewLeafNodeDAGStructure(branchTakenNode.GetID(), nCtx.NodeID())
+	upstreamNodeIds, err := nCtx.ContextualNodeLookup().ToNode(branchTakenNode.GetID())
+	if err != nil {
+		return handler.UnknownTransition, err
+	}
+	dag := executors.NewLeafNodeDAGStructure(branchTakenNode.GetID(), append(upstreamNodeIds, nCtx.NodeID())...)
 	execContext, err := b.getExecutionContextForDownstream(nCtx)
 	if err != nil {
 		return handler.UnknownTransition, err
@@ -196,7 +200,11 @@ func (b *branchHandler) Abort(ctx context.Context, nCtx handler.NodeExecutionCon
 	// TODO we should replace the call to RecursiveNodeHandler with a call to SingleNode Handler. The inputs are also already known ahead of time
 	// There is no DAGStructure for the branch nodes, the branch taken node is the leaf node. The node itself may be arbitrarily complex, but in that case the node should reference a subworkflow etc
 	// The parent of the BranchTaken Node is the actual Branch Node and all the data is just forwarded from the Branch to the executed node.
-	dag := executors.NewLeafNodeDAGStructure(branchTakenNode.GetID(), nCtx.NodeID())
+	upstreamNodeIds, err := nCtx.ContextualNodeLookup().ToNode(branchTakenNode.GetID())
+	if err != nil {
+		return err
+	}
+	dag := executors.NewLeafNodeDAGStructure(branchTakenNode.GetID(), append(upstreamNodeIds, nCtx.NodeID())...)
 	execContext, err := b.getExecutionContextForDownstream(nCtx)
 	if err != nil {
 		return err
@@ -236,7 +244,11 @@ func (b *branchHandler) Finalize(ctx context.Context, nCtx handler.NodeExecution
 	// TODO we should replace the call to RecursiveNodeHandler with a call to SingleNode Handler. The inputs are also already known ahead of time
 	// There is no DAGStructure for the branch nodes, the branch taken node is the leaf node. The node itself may be arbitrarily complex, but in that case the node should reference a subworkflow etc
 	// The parent of the BranchTaken Node is the actual Branch Node and all the data is just forwarded from the Branch to the executed node.
-	dag := executors.NewLeafNodeDAGStructure(branchTakenNode.GetID(), nCtx.NodeID())
+	upstreamNodeIds, err := nCtx.ContextualNodeLookup().ToNode(branchTakenNode.GetID())
+	if err != nil {
+		return err
+	}
+	dag := executors.NewLeafNodeDAGStructure(branchTakenNode.GetID(), append(upstreamNodeIds, nCtx.NodeID())...)
 	execContext, err := b.getExecutionContextForDownstream(nCtx)
 	if err != nil {
 		return err
diff --git a/pkg/controller/nodes/branch/handler_test.go b/pkg/controller/nodes/branch/handler_test.go
index bc39c1b24..3cebd5edd 100644
--- a/pkg/controller/nodes/branch/handler_test.go
+++ b/pkg/controller/nodes/branch/handler_test.go
@@ -158,24 +158,27 @@ func TestBranchHandler_RecurseDownstream(t *testing.T) {
 		isErr           bool
 		expectedPhase   handler.EPhase
 		childPhase      v1alpha1.NodePhase
-		nl              *execMocks.NodeLookup
 	}{
 		{"childNodeError", executors.NodeStatusUndefined, fmt.Errorf("err"),
-			&mocks2.ExecutableNodeStatus{}, bn, true, handler.EPhaseUndefined, v1alpha1.NodePhaseFailed, &execMocks.NodeLookup{}},
+			&mocks2.ExecutableNodeStatus{}, bn, true, handler.EPhaseUndefined, v1alpha1.NodePhaseFailed},
 		{"childPending", executors.NodeStatusPending, nil,
-			&mocks2.ExecutableNodeStatus{}, bn, false, handler.EPhaseRunning, v1alpha1.NodePhaseQueued, &execMocks.NodeLookup{}},
+			&mocks2.ExecutableNodeStatus{}, bn, false, handler.EPhaseRunning, v1alpha1.NodePhaseQueued},
 		{"childStillRunning", executors.NodeStatusRunning, nil,
-			&mocks2.ExecutableNodeStatus{}, bn, false, handler.EPhaseRunning, v1alpha1.NodePhaseRunning, &execMocks.NodeLookup{}},
+			&mocks2.ExecutableNodeStatus{}, bn, false, handler.EPhaseRunning, v1alpha1.NodePhaseRunning},
 		{"childFailure", executors.NodeStatusFailed(expectedError), nil,
-			&mocks2.ExecutableNodeStatus{}, bn, false, handler.EPhaseFailed, v1alpha1.NodePhaseFailed, &execMocks.NodeLookup{}},
+			&mocks2.ExecutableNodeStatus{}, bn, false, handler.EPhaseFailed, v1alpha1.NodePhaseFailed},
 		{"childComplete", executors.NodeStatusComplete, nil,
-			&mocks2.ExecutableNodeStatus{}, bn, false, handler.EPhaseSuccess, v1alpha1.NodePhaseSucceeded, &execMocks.NodeLookup{}},
+			&mocks2.ExecutableNodeStatus{}, bn, false, handler.EPhaseSuccess, v1alpha1.NodePhaseSucceeded},
 	}
 	for _, test := range tests {
 		t.Run(test.name, func(t *testing.T) {
 			eCtx := &execMocks.ExecutionContext{}
 			eCtx.OnGetParentInfo().Return(parentInfo{})
-			nCtx, _ := createNodeContext(v1alpha1.BranchNodeNotYetEvaluated, &childNodeID, n, nil, test.nl, eCtx)
+
+			mockNodeLookup := &execMocks.NodeLookup{}
+			mockNodeLookup.OnToNodeMatch(mock.Anything).Return(nil, nil)
+
+			nCtx, _ := createNodeContext(v1alpha1.BranchNodeNotYetEvaluated, &childNodeID, n, nil, mockNodeLookup, eCtx)
 			newParentInfo, _ := common.CreateParentInfo(parentInfo{}, nCtx.NodeID(), nCtx.CurrentAttempt())
 			expectedExecContext := executors.NewExecutionContextWithParentInfo(nCtx.ExecutionContext(), newParentInfo)
 			mockNodeExecutor := &execMocks.Node{}
@@ -194,16 +197,16 @@ func TestBranchHandler_RecurseDownstream(t *testing.T) {
 					}
 					return false
 				}),
-				mock.MatchedBy(func(lookup executors.NodeLookup) bool { return assert.Equal(t, lookup, test.nl) }),
+				mock.MatchedBy(func(lookup executors.NodeLookup) bool { return assert.Equal(t, lookup, mockNodeLookup) }),
 				mock.MatchedBy(func(n v1alpha1.ExecutableNode) bool { return assert.Equal(t, n.GetID(), childNodeID) }),
 			).Return(test.ns, test.err)
 
 			childNodeStatus := &mocks2.ExecutableNodeStatus{}
-			if test.nl != nil {
+			if mockNodeLookup != nil {
 				childNodeStatus.OnGetOutputDir().Return("parent-output-dir")
 				test.nodeStatus.OnGetDataDir().Return("parent-data-dir")
 				test.nodeStatus.OnGetOutputDir().Return("parent-output-dir")
-				test.nl.OnGetNodeExecutionStatus(ctx, childNodeID).Return(childNodeStatus)
+				mockNodeLookup.OnGetNodeExecutionStatus(ctx, childNodeID).Return(childNodeStatus)
 				childNodeStatus.On("SetDataDir", storage.DataReference("parent-data-dir")).Once()
 				childNodeStatus.On("SetOutputDir", storage.DataReference("parent-output-dir")).Once()
 			}
@@ -295,17 +298,18 @@ func TestBranchHandler_AbortNode(t *testing.T) {
 
 	t.Run("BranchNodeSuccess", func(t *testing.T) {
 		mockNodeExecutor := &execMocks.Node{}
-		nl := &execMocks.NodeLookup{}
+		mockNodeLookup := &execMocks.NodeLookup{}
+		mockNodeLookup.OnToNodeMatch(mock.Anything).Return(nil, nil)
 		eCtx := &execMocks.ExecutionContext{}
 		eCtx.OnGetParentInfo().Return(parentInfo{})
-		nCtx, s := createNodeContext(v1alpha1.BranchNodeSuccess, &n1, n, nil, nl, eCtx)
+		nCtx, s := createNodeContext(v1alpha1.BranchNodeSuccess, &n1, n, nil, mockNodeLookup, eCtx)
 		newParentInfo, _ := common.CreateParentInfo(parentInfo{}, nCtx.NodeID(), nCtx.CurrentAttempt())
 		expectedExecContext := executors.NewExecutionContextWithParentInfo(nCtx.ExecutionContext(), newParentInfo)
 		mockNodeExecutor.OnAbortHandlerMatch(mock.Anything,
 			mock.MatchedBy(func(e executors.ExecutionContext) bool { return assert.Equal(t, e, expectedExecContext) }),
 			mock.Anything,
 			mock.Anything, mock.Anything, mock.Anything).Return(nil)
-		nl.OnGetNode(*s.s.FinalizedNodeID).Return(n, true)
+		mockNodeLookup.OnGetNode(*s.s.FinalizedNodeID).Return(n, true)
 		branch := New(mockNodeExecutor, eventConfig, promutils.NewTestScope())
 		err := branch.Abort(ctx, nCtx, "")
 		assert.NoError(t, err)
diff --git a/pkg/controller/nodes/dynamic/dynamic_workflow.go b/pkg/controller/nodes/dynamic/dynamic_workflow.go
index ef1fc51c3..eb891aa27 100644
--- a/pkg/controller/nodes/dynamic/dynamic_workflow.go
+++ b/pkg/controller/nodes/dynamic/dynamic_workflow.go
@@ -183,7 +183,7 @@ func (d dynamicNodeTaskNodeHandler) buildContextualDynamicWorkflow(ctx context.C
 				subWorkflow:        compiledWf,
 				subWorkflowClosure: workflowCacheContents.CompiledWorkflow,
 				execContext:        executors.NewExecutionContext(nCtx.ExecutionContext(), compiledWf, compiledWf, newParentInfo, nCtx.ExecutionContext()),
-				nodeLookup:         executors.NewNodeLookup(compiledWf, dynamicNodeStatus),
+				nodeLookup:         executors.NewNodeLookup(compiledWf, dynamicNodeStatus, compiledWf),
 				dynamicJobSpecURI:  string(f.GetLoc()),
 			}, nil
 		}
@@ -216,7 +216,7 @@ func (d dynamicNodeTaskNodeHandler) buildContextualDynamicWorkflow(ctx context.C
 		subWorkflow:        dynamicWf,
 		subWorkflowClosure: closure,
 		execContext:        executors.NewExecutionContext(nCtx.ExecutionContext(), dynamicWf, dynamicWf, newParentInfo, nCtx.ExecutionContext()),
-		nodeLookup:         executors.NewNodeLookup(dynamicWf, dynamicNodeStatus),
+		nodeLookup:         executors.NewNodeLookup(dynamicWf, dynamicNodeStatus, dynamicWf),
 		dynamicJobSpecURI:  string(f.GetLoc()),
 	}, nil
 }
diff --git a/pkg/controller/nodes/subworkflow/subworkflow.go b/pkg/controller/nodes/subworkflow/subworkflow.go
index 24d74473f..74beeaf79 100644
--- a/pkg/controller/nodes/subworkflow/subworkflow.go
+++ b/pkg/controller/nodes/subworkflow/subworkflow.go
@@ -209,7 +209,7 @@ func (s *subworkflowHandler) HandleFailingSubWorkflow(ctx context.Context, nCtx
 	}
 
 	status := nCtx.NodeStatus()
-	nodeLookup := executors.NewNodeLookup(subWorkflow, status)
+	nodeLookup := executors.NewNodeLookup(subWorkflow, status, subWorkflow)
 	return s.HandleFailureNodeOfSubWorkflow(ctx, nCtx, subWorkflow, nodeLookup)
 }
 
@@ -220,7 +220,7 @@ func (s *subworkflowHandler) StartSubWorkflow(ctx context.Context, nCtx handler.
 	}
 
 	status := nCtx.NodeStatus()
-	nodeLookup := executors.NewNodeLookup(subWorkflow, status)
+	nodeLookup := executors.NewNodeLookup(subWorkflow, status, subWorkflow)
 
 	// assert startStatus.IsComplete() == true
 	return s.startAndHandleSubWorkflow(ctx, nCtx, subWorkflow, nodeLookup)
@@ -233,7 +233,7 @@ func (s *subworkflowHandler) CheckSubWorkflowStatus(ctx context.Context, nCtx ha
 	}
 
 	status := nCtx.NodeStatus()
-	nodeLookup := executors.NewNodeLookup(subWorkflow, status)
+	nodeLookup := executors.NewNodeLookup(subWorkflow, status, subWorkflow)
 	return s.handleSubWorkflow(ctx, nCtx, subWorkflow, nodeLookup)
 }
 
@@ -243,7 +243,7 @@ func (s *subworkflowHandler) HandleAbort(ctx context.Context, nCtx handler.NodeE
 		return err
 	}
 	status := nCtx.NodeStatus()
-	nodeLookup := executors.NewNodeLookup(subWorkflow, status)
+	nodeLookup := executors.NewNodeLookup(subWorkflow, status, subWorkflow)
 	execContext, err := s.getExecutionContextForDownstream(nCtx)
 	if err != nil {
 		return err
diff --git a/pkg/controller/workflow/executor.go b/pkg/controller/workflow/executor.go
index 92a870c1c..e3eac1e37 100644
--- a/pkg/controller/workflow/executor.go
+++ b/pkg/controller/workflow/executor.go
@@ -124,7 +124,7 @@ func (c *workflowExecutor) handleReadyWorkflow(ctx context.Context, w *v1alpha1.
 	nodeStatus.SetDataDir(dataDir)
 	nodeStatus.SetOutputDir(outputDir)
 	execcontext := executors.NewExecutionContext(w, w, w, nil, executors.InitializeControlFlow())
-	s, err := c.nodeExecutor.SetInputsForStartNode(ctx, execcontext, w, executors.NewNodeLookup(w, w.GetExecutionStatus()), inputs)
+	s, err := c.nodeExecutor.SetInputsForStartNode(ctx, execcontext, w, executors.NewNodeLookup(w, w.GetExecutionStatus(), w), inputs)
 	if err != nil {
 		return StatusReady, err
 	}

From edf6e72056f180604bf64114a91f5218e6291fc8 Mon Sep 17 00:00:00 2001
From: Daniel Rammer <daniel@union.ai>
Date: Wed, 22 Mar 2023 18:14:34 -0500
Subject: [PATCH 2/4] removed dead comments

Signed-off-by: Daniel Rammer <daniel@union.ai>
---
 pkg/controller/executors/node_lookup.go | 2 --
 1 file changed, 2 deletions(-)

diff --git a/pkg/controller/executors/node_lookup.go b/pkg/controller/executors/node_lookup.go
index 0348b8b18..381b832c0 100644
--- a/pkg/controller/executors/node_lookup.go
+++ b/pkg/controller/executors/node_lookup.go
@@ -51,12 +51,10 @@ func (s staticNodeLookup) GetNodeExecutionStatus(_ context.Context, id v1alpha1.
 	return s.status[id]
 }
 
-// TODO @hamersaw implement
 func (s staticNodeLookup) ToNode(id v1alpha1.NodeID) ([]v1alpha1.NodeID, error) {
 	return nil, nil
 }
 
-// TODO @hamersaw implement
 func (s staticNodeLookup) FromNode(id v1alpha1.NodeID) ([]v1alpha1.NodeID, error) {
 	return nil, nil
 }

From 155401a1c45f0403b00fb4874bdcc0978d167940 Mon Sep 17 00:00:00 2001
From: Daniel Rammer <daniel@union.ai>
Date: Thu, 23 Mar 2023 09:44:42 -0500
Subject: [PATCH 3/4] added unit test

Signed-off-by: Daniel Rammer <daniel@union.ai>
---
 pkg/controller/nodes/branch/handler_test.go | 43 +++++++++++++--------
 1 file changed, 27 insertions(+), 16 deletions(-)

diff --git a/pkg/controller/nodes/branch/handler_test.go b/pkg/controller/nodes/branch/handler_test.go
index 3cebd5edd..daae4bef4 100644
--- a/pkg/controller/nodes/branch/handler_test.go
+++ b/pkg/controller/nodes/branch/handler_test.go
@@ -150,25 +150,28 @@ func TestBranchHandler_RecurseDownstream(t *testing.T) {
 	bn.OnGetID().Return(childNodeID)
 
 	tests := []struct {
-		name            string
-		ns              executors.NodeStatus
-		err             error
-		nodeStatus      *mocks2.ExecutableNodeStatus
-		branchTakenNode v1alpha1.ExecutableNode
-		isErr           bool
-		expectedPhase   handler.EPhase
-		childPhase      v1alpha1.NodePhase
+		name              string
+		ns                executors.NodeStatus
+		err               error
+		nodeStatus        *mocks2.ExecutableNodeStatus
+		branchTakenNode   v1alpha1.ExecutableNode
+		isErr             bool
+		expectedPhase     handler.EPhase
+		childPhase        v1alpha1.NodePhase
+		upstreamNodeID    string
 	}{
+		{"upstreamNodeExists", executors.NodeStatusPending, nil,
+			&mocks2.ExecutableNodeStatus{}, bn, false, handler.EPhaseRunning, v1alpha1.NodePhaseQueued, "n2"},
 		{"childNodeError", executors.NodeStatusUndefined, fmt.Errorf("err"),
-			&mocks2.ExecutableNodeStatus{}, bn, true, handler.EPhaseUndefined, v1alpha1.NodePhaseFailed},
+			&mocks2.ExecutableNodeStatus{}, bn, true, handler.EPhaseUndefined, v1alpha1.NodePhaseFailed, ""},
 		{"childPending", executors.NodeStatusPending, nil,
-			&mocks2.ExecutableNodeStatus{}, bn, false, handler.EPhaseRunning, v1alpha1.NodePhaseQueued},
+			&mocks2.ExecutableNodeStatus{}, bn, false, handler.EPhaseRunning, v1alpha1.NodePhaseQueued, ""},
 		{"childStillRunning", executors.NodeStatusRunning, nil,
-			&mocks2.ExecutableNodeStatus{}, bn, false, handler.EPhaseRunning, v1alpha1.NodePhaseRunning},
+			&mocks2.ExecutableNodeStatus{}, bn, false, handler.EPhaseRunning, v1alpha1.NodePhaseRunning, ""},
 		{"childFailure", executors.NodeStatusFailed(expectedError), nil,
-			&mocks2.ExecutableNodeStatus{}, bn, false, handler.EPhaseFailed, v1alpha1.NodePhaseFailed},
+			&mocks2.ExecutableNodeStatus{}, bn, false, handler.EPhaseFailed, v1alpha1.NodePhaseFailed, ""},
 		{"childComplete", executors.NodeStatusComplete, nil,
-			&mocks2.ExecutableNodeStatus{}, bn, false, handler.EPhaseSuccess, v1alpha1.NodePhaseSucceeded},
+			&mocks2.ExecutableNodeStatus{}, bn, false, handler.EPhaseSuccess, v1alpha1.NodePhaseSucceeded, ""},
 	}
 	for _, test := range tests {
 		t.Run(test.name, func(t *testing.T) {
@@ -176,7 +179,11 @@ func TestBranchHandler_RecurseDownstream(t *testing.T) {
 			eCtx.OnGetParentInfo().Return(parentInfo{})
 
 			mockNodeLookup := &execMocks.NodeLookup{}
-			mockNodeLookup.OnToNodeMatch(mock.Anything).Return(nil, nil)
+			if len(test.upstreamNodeID) > 0 {
+				mockNodeLookup.OnToNodeMatch(childNodeID).Return([]string{test.upstreamNodeID}, nil)
+			} else {
+				mockNodeLookup.OnToNodeMatch(childNodeID).Return(nil, nil)
+			}
 
 			nCtx, _ := createNodeContext(v1alpha1.BranchNodeNotYetEvaluated, &childNodeID, n, nil, mockNodeLookup, eCtx)
 			newParentInfo, _ := common.CreateParentInfo(parentInfo{}, nCtx.NodeID(), nCtx.CurrentAttempt())
@@ -190,9 +197,13 @@ func TestBranchHandler_RecurseDownstream(t *testing.T) {
 						fList, err1 := d.FromNode("x")
 						dList, err2 := d.ToNode(childNodeID)
 						b := assert.NoError(t, err1)
-						b = b && assert.Equal(t, fList, []v1alpha1.NodeID{})
+						b = b && assert.Equal(t, []v1alpha1.NodeID{}, fList)
 						b = b && assert.NoError(t, err2)
-						b = b && assert.Equal(t, dList, []v1alpha1.NodeID{nodeID})
+						dListExpected := []v1alpha1.NodeID{nodeID}
+						if len(test.upstreamNodeID) > 0 {
+							dListExpected = append([]string{test.upstreamNodeID}, dListExpected...)
+						}
+						b = b && assert.Equal(t, dListExpected, dList)
 						return b
 					}
 					return false

From 0955fefff0c1316f1cc349f6b39788d81a4d06e2 Mon Sep 17 00:00:00 2001
From: Daniel Rammer <daniel@union.ai>
Date: Thu, 23 Mar 2023 09:45:29 -0500
Subject: [PATCH 4/4] fixed lint issues

Signed-off-by: Daniel Rammer <daniel@union.ai>
---
 pkg/controller/nodes/branch/handler_test.go | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/pkg/controller/nodes/branch/handler_test.go b/pkg/controller/nodes/branch/handler_test.go
index daae4bef4..5711de5d4 100644
--- a/pkg/controller/nodes/branch/handler_test.go
+++ b/pkg/controller/nodes/branch/handler_test.go
@@ -150,15 +150,15 @@ func TestBranchHandler_RecurseDownstream(t *testing.T) {
 	bn.OnGetID().Return(childNodeID)
 
 	tests := []struct {
-		name              string
-		ns                executors.NodeStatus
-		err               error
-		nodeStatus        *mocks2.ExecutableNodeStatus
-		branchTakenNode   v1alpha1.ExecutableNode
-		isErr             bool
-		expectedPhase     handler.EPhase
-		childPhase        v1alpha1.NodePhase
-		upstreamNodeID    string
+		name            string
+		ns              executors.NodeStatus
+		err             error
+		nodeStatus      *mocks2.ExecutableNodeStatus
+		branchTakenNode v1alpha1.ExecutableNode
+		isErr           bool
+		expectedPhase   handler.EPhase
+		childPhase      v1alpha1.NodePhase
+		upstreamNodeID  string
 	}{
 		{"upstreamNodeExists", executors.NodeStatusPending, nil,
 			&mocks2.ExecutableNodeStatus{}, bn, false, handler.EPhaseRunning, v1alpha1.NodePhaseQueued, "n2"},