Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Refactor to ensure branch node rename is consistent in all nodes and …
Browse files Browse the repository at this point in the history
…connections (#268)

* Bump to pick up latest plugins (#267)

Signed-off-by: Katrina Rogan <[email protected]>
Signed-off-by: Haytham Abuelfutuh <[email protected]>

* Reflect node id change in upstream and downstream connections

Signed-off-by: Haytham Abuelfutuh <[email protected]>

* wip

Signed-off-by: Haytham Abuelfutuh <[email protected]>

* fix lint issue

Signed-off-by: Ketan Umare <[email protected]>
Signed-off-by: Haytham Abuelfutuh <[email protected]>

* Fix connections for branch nodes

Signed-off-by: Haytham Abuelfutuh <[email protected]>

* Support mismatching interfaces for branches

Signed-off-by: Haytham Abuelfutuh <[email protected]>

* only fail adding nodes in compiler transformer when nodes are different

Signed-off-by: Haytham Abuelfutuh <[email protected]>

* remove replaceNodeID

Signed-off-by: Haytham Abuelfutuh <[email protected]>

* Fix unit tests

Signed-off-by: Haytham Abuelfutuh <[email protected]>

* Clean up commented code

Signed-off-by: Haytham Abuelfutuh <[email protected]>

* bump DCO

Signed-off-by: Haytham Abuelfutuh <[email protected]>

* Cleanup

Signed-off-by: Haytham Abuelfutuh <[email protected]>

* Revert config.yaml

Signed-off-by: Haytham Abuelfutuh <[email protected]>

* PR Comments

Signed-off-by: Haytham Abuelfutuh <[email protected]>

Co-authored-by: Katrina Rogan <[email protected]>
Co-authored-by: Ketan Umare <[email protected]>
  • Loading branch information
3 people authored Jun 2, 2021
1 parent 035b097 commit 0437b71
Show file tree
Hide file tree
Showing 38 changed files with 11,606 additions and 137 deletions.
28 changes: 19 additions & 9 deletions pkg/compiler/builders.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package compiler

import (
"fmt"

"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
c "github.com/flyteorg/flytepropeller/pkg/compiler/common"
)
Expand All @@ -12,12 +14,13 @@ type flyteNode = core.Node
// A builder object for the Graph struct. This contains information the compiler uses while building the final Graph
// struct.
type workflowBuilder struct {
CoreWorkflow *flyteWorkflow
LaunchPlans map[c.WorkflowIDKey]c.InterfaceProvider
Tasks c.TaskIndex
downstreamNodes c.StringAdjacencyList
upstreamNodes c.StringAdjacencyList
Nodes c.NodeIndex
CoreWorkflow *flyteWorkflow
LaunchPlans map[c.WorkflowIDKey]c.InterfaceProvider
Tasks c.TaskIndex
downstreamNodes c.StringAdjacencyList
upstreamNodes c.StringAdjacencyList
Nodes c.NodeIndex
NodeBuilderIndex c.NodeIndex

// These are references to all subgraphs and tasks passed to CompileWorkflow. They will be passed around but will
// not show in their entirety in the final Graph. The required subset of these will be added to each subgraph as
Expand All @@ -30,7 +33,7 @@ type workflowBuilder struct {

func (w workflowBuilder) GetFailureNode() c.Node {
if w.GetCoreWorkflow() != nil && w.GetCoreWorkflow().GetTemplate() != nil && w.GetCoreWorkflow().GetTemplate().FailureNode != nil {
return w.NewNodeBuilder(w.GetCoreWorkflow().GetTemplate().FailureNode)
return w.GetOrCreateNodeBuilder(w.GetCoreWorkflow().GetTemplate().FailureNode)
}

return nil
Expand All @@ -52,8 +55,15 @@ func (w workflowBuilder) GetUpstreamNodes() c.StringAdjacencyList {
return w.upstreamNodes
}

func (w workflowBuilder) NewNodeBuilder(n *flyteNode) c.NodeBuilder {
return &nodeBuilder{flyteNode: n}
// GetOrCreateNodeBuilder returns the node builder registered for n in NodeBuilderIndex, creating and registering a
// new one on first use. Builders are keyed by the formatted pointer address of n, so passing the same *flyteNode
// pointer again yields the same builder instance.
func (w workflowBuilder) GetOrCreateNodeBuilder(n *flyteNode) c.NodeBuilder {
	key := fmt.Sprintf("%p", n)

	builder, registered := w.NodeBuilderIndex[key]
	if !registered {
		// First time this pointer is seen: wrap it and remember the wrapper for subsequent lookups.
		builder = &nodeBuilder{flyteNode: n}
		w.NodeBuilderIndex[key] = builder
	}

	return builder
}

func (w workflowBuilder) GetNode(id c.NodeID) (node c.NodeBuilder, found bool) {
Expand Down
11 changes: 10 additions & 1 deletion pkg/compiler/common/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,14 @@ const (
EndNodeID = "end-node"
)

// EdgeDirection enumerates the directions of node edges. It is the type of the edgeDirection argument accepted by
// WorkflowBuilder.AddEdges.
type EdgeDirection uint8

const (
// EdgeDirectionBidirectional is the zero value of EdgeDirection.
// NOTE(review): presumably selects both upstream and downstream edges - confirm against the AddEdges implementation.
EdgeDirectionBidirectional EdgeDirection = iota
// EdgeDirectionDownstream presumably restricts processing to downstream edges - TODO confirm.
EdgeDirectionDownstream
// EdgeDirectionUpstream presumably restricts processing to upstream edges - TODO confirm.
EdgeDirectionUpstream
)

//go:generate mockery -all -output=mocks -case=underscore

// A mutable workflow used during the build of the intermediate layer.
Expand All @@ -21,8 +29,9 @@ type WorkflowBuilder interface {
AddUpstreamEdge(nodeProvider, nodeDependent NodeID)
AddDownstreamEdge(nodeProvider, nodeDependent NodeID)
AddNode(n NodeBuilder, errs errors.CompileErrors) (node NodeBuilder, ok bool)
AddEdges(n NodeBuilder, edgeDirection EdgeDirection, errs errors.CompileErrors) (ok bool)
ValidateWorkflow(fg *core.CompiledWorkflow, errs errors.CompileErrors) (Workflow, bool)
NewNodeBuilder(n *core.Node) NodeBuilder
GetOrCreateNodeBuilder(n *core.Node) NodeBuilder
}

// A mutable node used during the build of the intermediate layer.
Expand Down
54 changes: 43 additions & 11 deletions pkg/compiler/common/mocks/workflow_builder.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

151 changes: 144 additions & 7 deletions pkg/compiler/test/compiler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import (
"strings"
"testing"

"k8s.io/apimachinery/pkg/util/sets"

"github.com/go-test/deep"

"github.com/ghodss/yaml"
Expand Down Expand Up @@ -182,6 +184,113 @@ func TestDynamic(t *testing.T) {
}))
}

// getAllSubNodeIDs recursively collects the IDs of the nodes nested inside n's branch node: the `then` node of the
// primary case, the `then` node of every other case and the else node, when present. The ID of n itself is never
// included. Nested nodes are kept only if they satisfy the first predicate passed in; when no predicate is given,
// hasPromiseNodePredicate is used, which matches what this helper did before it accepted a predicate.
func getAllSubNodeIDs(n *core.Node, predicates ...nodePredicate) sets.String {
	var predicate nodePredicate = hasPromiseNodePredicate
	if len(predicates) > 0 {
		predicate = predicates[0]
	}

	res := sets.NewString()
	branchNode := n.GetBranchNode()
	if branchNode == nil {
		// TODO: Support Sub workflow
		return res
	}

	// Generated getters are used instead of direct field access so that a partially populated branch node does not
	// cause a nil pointer dereference.
	ifElse := branchNode.GetIfElse()
	subNodes := make([]*core.Node, 0, len(ifElse.GetOther())+2)
	subNodes = append(subNodes, ifElse.GetCase().GetThenNode())
	for _, other := range ifElse.GetOther() {
		subNodes = append(subNodes, other.GetThenNode())
	}

	subNodes = append(subNodes, ifElse.GetElseNode())

	for _, subNode := range subNodes {
		if subNode == nil {
			continue
		}

		if predicate(subNode) {
			res.Insert(subNode.GetId())
		}

		res = res.Union(getAllSubNodeIDs(subNode, predicate))
	}

	return res
}

// nodePredicate decides whether a node should be selected.
type nodePredicate func(n *core.Node) bool

// hasPromiseNodePredicate selects nodes that have at least one input bound (directly or through a nested
// collection/map binding) to a promise.
var hasPromiseNodePredicate = func(n *core.Node) bool {
	return hasPromiseInputs(n.GetInputs())
}

// allNodesPredicate selects every node.
var allNodesPredicate = func(n *core.Node) bool {
	return true
}

// getAllMatchingNodes returns the IDs of all nodes of the compiled workflow that satisfy predicate. Both the
// top-level nodes of the workflow template and the nodes nested in branch nodes are tested with the same predicate.
func getAllMatchingNodes(wf *core.CompiledWorkflow, predicate nodePredicate) sets.String {
	s := sets.NewString()
	for _, n := range wf.GetTemplate().GetNodes() {
		if predicate(n) {
			s.Insert(n.GetId())
		}

		// Forward the predicate so nested nodes are filtered the same way as top-level ones.
		s = s.Union(getAllSubNodeIDs(n, predicate))
	}

	return s
}

// bindingHasPromiseInputs reports whether the binding is a promise, or is a collection/map binding that contains a
// promise at any nesting depth. Any other kind of binding yields false.
func bindingHasPromiseInputs(binding *core.BindingData) bool {
	value := binding.GetValue()

	if _, isPromise := value.(*core.BindingData_Promise); isPromise {
		return true
	}

	if collection, isCollection := value.(*core.BindingData_Collection); isCollection {
		for _, item := range collection.Collection.Bindings {
			if bindingHasPromiseInputs(item) {
				return true
			}
		}

		return false
	}

	if mapping, isMap := value.(*core.BindingData_Map); isMap {
		for _, item := range mapping.Map.Bindings {
			if bindingHasPromiseInputs(item) {
				return true
			}
		}
	}

	return false
}

// hasPromiseInputs reports whether at least one of the given bindings resolves to a promise, as determined by
// bindingHasPromiseInputs. An empty or nil slice yields false.
func hasPromiseInputs(bindings []*core.Binding) bool {
	for i := range bindings {
		data := bindings[i].Binding
		if bindingHasPromiseInputs(data) {
			return true
		}
	}

	return false
}

// assertNodeIDsInConnections checks the connection set against the expected node IDs and records a test failure on t
// for each violated expectation:
//   - every ID in nodeIDsWithDeps must appear somewhere in the downstream or upstream connections (as a key or as a
//     listed ID), and
//   - every ID that appears in the connections must be a member of allNodeIDs.
//
// Both expectations are always evaluated so that a single run reports every problem. It returns true only when both
// hold.
func assertNodeIDsInConnections(t testing.TB, nodeIDsWithDeps, allNodeIDs sets.String, connections *core.ConnectionSet) bool {
	actualNodeIDs := sets.NewString()
	// Getters keep this safe for a nil connection set or nil ID lists.
	for id, lst := range connections.GetDownstream() {
		actualNodeIDs.Insert(id)
		actualNodeIDs.Insert(lst.GetIds()...)
	}

	for id, lst := range connections.GetUpstream() {
		actualNodeIDs.Insert(id)
		actualNodeIDs.Insert(lst.GetIds()...)
	}

	notFoundInConnections := nodeIDsWithDeps.Difference(actualNodeIDs)
	allNodesConnected := assert.Empty(t, notFoundInConnections, "All nodes must appear in connections")

	// Evaluated unconditionally: combining with && directly would skip this assertion whenever the first one failed.
	notFoundInNodes := actualNodeIDs.Difference(allNodeIDs)
	allConnectionsKnown := assert.Empty(t, notFoundInNodes, "All connections must correspond to existing nodes")

	return allNodesConnected && allConnectionsKnown
}

func TestBranches(t *testing.T) {
errors.SetConfig(errors.Config{IncludeSource: true})
assert.NoError(t, filepath.Walk("testdata/branch", func(path string, info os.FileInfo, err error) error {
Expand All @@ -195,16 +304,36 @@ func TestBranches(t *testing.T) {

t.Run(path, func(t *testing.T) {
// If you want to debug a single use-case. Uncomment this line.
//if !strings.HasSuffix(path, "success_1.json") {
//if !strings.HasSuffix(path, "success_8_nested.json") {
// t.SkipNow()
//}

raw, err := ioutil.ReadFile(path)
assert.NoError(t, err)
wf := &core.WorkflowClosure{}
err = jsonpb.UnmarshalString(string(raw), wf)
if !assert.NoError(t, err) {
t.FailNow()
if filepath.Ext(path) == ".json" {
err = jsonpb.UnmarshalString(string(raw), wf)
if !assert.NoError(t, err) {
t.FailNow()
}
} else if filepath.Ext(path) == ".pb" {
err = proto.Unmarshal(raw, wf)
if !assert.NoError(t, err) {
t.FailNow()
}

m := &jsonpb.Marshaler{
Indent: " ",
}
raw, err := m.MarshalToString(wf)
if !assert.NoError(t, err) {
t.FailNow()
}

err = ioutil.WriteFile(strings.TrimSuffix(path, filepath.Ext(path))+".json", []byte(raw), os.ModePerm)
if !assert.NoError(t, err) {
t.FailNow()
}
}

t.Log("Compiling Workflow")
Expand All @@ -215,8 +344,10 @@ func TestBranches(t *testing.T) {
t.FailNow()
}

marshaler := jsonpb.Marshaler{}
rawStr, err := marshaler.MarshalToString(compiledWfc)
m := &jsonpb.Marshaler{
Indent: " ",
}
rawStr, err := m.MarshalToString(compiledWfc)
if !assert.NoError(t, err) {
t.Fail()
}
Expand All @@ -238,6 +369,12 @@ func TestBranches(t *testing.T) {
}
}

allNodeIDs := getAllMatchingNodes(compiledWfc.Primary, allNodesPredicate)
nodeIDsWithDeps := getAllMatchingNodes(compiledWfc.Primary, hasPromiseNodePredicate)
if !assertNodeIDsInConnections(t, nodeIDsWithDeps, allNodeIDs, compiledWfc.Primary.Connections) {
t.FailNow()
}

inputs := map[string]interface{}{}
for varName, v := range compiledWfc.Primary.Template.Interface.Inputs.Variables {
inputs[varName] = coreutils.MustMakeDefaultLiteralForType(v.Type)
Expand All @@ -252,7 +389,7 @@ func TestBranches(t *testing.T) {
},
"namespace")
if assert.NoError(t, err) {
raw, err := json.Marshal(flyteWf)
raw, err := json.MarshalIndent(flyteWf, "", " ")
if assert.NoError(t, err) {
assert.NotEmpty(t, raw)
}
Expand Down
Loading

0 comments on commit 0437b71

Please sign in to comment.