From a79c03f354b3ff6c676205cece9d8393ea9382dc Mon Sep 17 00:00:00 2001 From: Haytham AbuelFutuh Date: Thu, 2 Jan 2020 15:35:19 -0800 Subject: [PATCH] Statemachine & Status fixes (Discretization, status size... etc.) (#49) This PR fixes a few issues uncovered during the investigation of the statemachine inconsistency issues last week. Specifically: - [X] Ensure each node can a progress at most once per round (IsDirty flag) - [X] Remove ParentTaskID and DataDir from NodeStatus field (Causing workflow etcd. obj size to bloat) - [X] Add Parent RetryAttempt in the generated hierarchal name of dynamic sub-nodes to ensure retries do not reuse an existing sub-node status. Details: https://docs.google.com/document/d/1ISaxIZeYLcBaeapEmeTqb-g0x04pJbf5t3i30qMfk6U/edit?usp=sharing --- cmd/kubectl-flyte/cmd/get.go | 9 +- cmd/kubectl-flyte/cmd/printers/node.go | 13 +- cmd/kubectl-flyte/cmd/printers/workflow.go | 5 +- pkg/apis/flyteworkflow/v1alpha1/iface.go | 16 +- .../v1alpha1/mocks/BaseWorkflowWithStatus.go | 16 +- .../v1alpha1/mocks/ExecutableNodeStatus.go | 21 +- .../v1alpha1/mocks/ExecutableWorkflow.go | 16 +- .../mocks/ExecutableWorkflowStatus.go | 32 +-- .../flyteworkflow/v1alpha1/mocks/Mutable.go | 42 +++ .../v1alpha1/mocks/MutableBranchNodeStatus.go | 32 +++ .../mocks/MutableDynamicNodeStatus.go | 32 +++ .../v1alpha1/mocks/MutableNodeStatus.go | 37 +++ .../mocks/MutableSubWorkflowNodeStatus.go | 32 +++ .../v1alpha1/mocks/MutableTaskNodeStatus.go | 32 +++ .../mocks/MutableWorkflowNodeStatus.go | 32 +++ .../v1alpha1/mocks/NodeStatusGetter.go | 16 +- .../flyteworkflow/v1alpha1/node_status.go | 240 ++++++++++++++---- pkg/apis/flyteworkflow/v1alpha1/workflow.go | 15 +- .../flyteworkflow/v1alpha1/workflow_status.go | 38 ++- .../v1alpha1/zz_generated.deepcopy.go | 20 ++ .../common/mocks/interface_provider.go | 6 +- pkg/compiler/common/mocks/node.go | 9 +- pkg/compiler/common/mocks/node_builder.go | 9 +- pkg/compiler/common/mocks/task.go | 6 +- pkg/compiler/common/mocks/workflow.go | 9 +- pkg/compiler/common/mocks/workflow_builder.go | 12 +- pkg/controller/executors/mocks/client.go | 38 ++- pkg/controller/executors/mocks/node.go | 62 ++++- pkg/controller/executors/mocks/workflow.go | 56 +++- pkg/controller/executors/node.go | 2 +- pkg/controller/nodes/branch/evaluator.go | 2 +- pkg/controller/nodes/branch/evaluator_test.go | 18 ++ pkg/controller/nodes/branch/handler.go | 8 +- pkg/controller/nodes/branch/handler_test.go | 2 +- pkg/controller/nodes/dynamic/handler.go | 27 +- pkg/controller/nodes/dynamic/handler_test.go | 26 +- .../nodes/dynamic/mocks/task_node_handler.go | 18 +- pkg/controller/nodes/dynamic/subworkflow.go | 16 +- .../nodes/dynamic/subworkflow_test.go | 6 +- pkg/controller/nodes/dynamic/utils.go | 12 +- pkg/controller/nodes/dynamic/utils_test.go | 14 +- pkg/controller/nodes/executor.go | 56 ++-- pkg/controller/nodes/executor_test.go | 27 +- pkg/controller/nodes/handler/mocks/node.go | 9 +- .../handler/mocks/node_execution_context.go | 17 +- .../handler/mocks/node_execution_metadata.go | 12 +- .../nodes/handler/mocks/node_state_reader.go | 6 +- .../nodes/handler/mocks/node_state_writer.go | 6 +- .../nodes/handler/mocks/setup_context.go | 6 +- .../nodes/handler/mocks/task_reader.go | 9 +- pkg/controller/nodes/mocks/handler_factory.go | 11 +- pkg/controller/nodes/mocks/output_resolver.go | 19 +- pkg/controller/nodes/node_state_manager.go | 2 + pkg/controller/nodes/output_resolver.go | 6 +- pkg/controller/nodes/predicate.go | 14 +- pkg/controller/nodes/predicate_test.go | 92 +++---- pkg/controller/nodes/resolve.go | 4 +- pkg/controller/nodes/resolve_test.go | 12 +- pkg/controller/nodes/subworkflow/handler.go | 2 +- .../nodes/subworkflow/handler_test.go | 6 +- .../nodes/subworkflow/launchplan.go | 6 +- .../subworkflow/launchplan/mocks/Executor.go | 16 +- .../nodes/subworkflow/launchplan_test.go | 16 +- .../nodes/subworkflow/subworkflow.go | 15 +- .../nodes/task/k8s/plugin_manager.go | 1 + .../resourcemanager/config/config_flags.go | 10 +- .../config/config_flags_test.go | 40 +-- pkg/controller/workflow/executor.go | 10 +- pkg/controller/workflow/executor_test.go | 9 +- pkg/utils/encoder.go | 2 +- 70 files changed, 1110 insertions(+), 385 deletions(-) create mode 100644 pkg/apis/flyteworkflow/v1alpha1/mocks/Mutable.go diff --git a/cmd/kubectl-flyte/cmd/get.go b/cmd/kubectl-flyte/cmd/get.go index 023677631a..4973fe89b7 100644 --- a/cmd/kubectl-flyte/cmd/get.go +++ b/cmd/kubectl-flyte/cmd/get.go @@ -1,6 +1,7 @@ package cmd import ( + "context" "fmt" "sort" "strings" @@ -33,9 +34,11 @@ func NewGetCommand(opts *RootOptions) *cobra.Command { Short: "Gets a single workflow or lists all workflows currently in execution", Long: `use labels to filter`, RunE: func(cmd *cobra.Command, args []string) error { + ctx := context.Background() + if len(args) > 0 { name := args[0] - return getOpts.getWorkflow(name) + return getOpts.getWorkflow(ctx, name) } return getOpts.listWorkflows() }, @@ -49,7 +52,7 @@ func NewGetCommand(opts *RootOptions) *cobra.Command { return getCmd } -func (g *GetOpts) getWorkflow(name string) error { +func (g *GetOpts) getWorkflow(ctx context.Context, name string) error { parts := strings.Split(name, "/") if len(parts) > 1 { g.ConfigOverrides.Context.Namespace = parts[0] @@ -61,7 +64,7 @@ func (g *GetOpts) getWorkflow(name string) error { } wp := printers.WorkflowPrinter{} tree := gotree.New("Workflow") - if err := wp.Print(tree, w); err != nil { + if err := wp.Print(ctx, tree, w); err != nil { return err } fmt.Print(tree.Print()) diff --git a/cmd/kubectl-flyte/cmd/printers/node.go b/cmd/kubectl-flyte/cmd/printers/node.go index 6dd3a0bd9d..d7f6b220f7 100644 --- a/cmd/kubectl-flyte/cmd/printers/node.go +++ b/cmd/kubectl-flyte/cmd/printers/node.go @@ -1,6 +1,7 @@ package printers import ( + "context" "fmt" "strconv" "strings" @@ -79,7 +80,7 @@ func (p NodePrinter) BranchNodeInfo(node v1alpha1.ExecutableNode, nodeStatus v1a } -func (p NodePrinter) traverseNode(tree gotree.Tree, w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode, nodeStatus v1alpha1.ExecutableNodeStatus) error { +func (p NodePrinter) traverseNode(ctx context.Context, tree gotree.Tree, w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode, nodeStatus v1alpha1.ExecutableNodeStatus) error { switch node.GetKind() { case v1alpha1.NodeKindBranch: subTree := tree.Add(strings.Join(p.BranchNodeInfo(node, nodeStatus), " | ")) @@ -89,7 +90,7 @@ func (p NodePrinter) traverseNode(tree gotree.Tree, w v1alpha1.ExecutableWorkflo if !ok { return fmt.Errorf("failed to find branch node %s", *nodeID) } - if err := p.traverseNode(subTree, w, ifNode, nodeStatus.GetNodeExecutionStatus(*nodeID)); err != nil { + if err := p.traverseNode(ctx, subTree, w, ifNode, nodeStatus.GetNodeExecutionStatus(ctx, *nodeID)); err != nil { return err } } @@ -113,7 +114,7 @@ func (p NodePrinter) traverseNode(tree gotree.Tree, w v1alpha1.ExecutableWorkflo s := w.FindSubWorkflow(*node.GetWorkflowNode().GetSubWorkflowRef()) wp := WorkflowPrinter{} cw := executors.NewSubContextualWorkflow(w, s, nodeStatus) - return wp.Print(tree, cw) + return wp.Print(ctx, tree, cw) } case v1alpha1.NodeKindTask: sub := tree.Add(strings.Join(p.NodeInfo(w.GetName(), node, nodeStatus), " | ")) @@ -126,10 +127,10 @@ func (p NodePrinter) traverseNode(tree gotree.Tree, w v1alpha1.ExecutableWorkflo return nil } -func (p NodePrinter) PrintList(tree gotree.Tree, w v1alpha1.ExecutableWorkflow, nodes []v1alpha1.ExecutableNode) error { +func (p NodePrinter) PrintList(ctx context.Context, tree gotree.Tree, w v1alpha1.ExecutableWorkflow, nodes []v1alpha1.ExecutableNode) error { for _, n := range nodes { - s := w.GetNodeExecutionStatus(n.GetID()) - if err := p.traverseNode(tree, w, n, s); err != nil { + s := w.GetNodeExecutionStatus(ctx, n.GetID()) + if err := p.traverseNode(ctx, tree, w, n, s); err != nil { return err } } diff --git a/cmd/kubectl-flyte/cmd/printers/workflow.go b/cmd/kubectl-flyte/cmd/printers/workflow.go index f15da36bd1..8dcc8efb78 100644 --- a/cmd/kubectl-flyte/cmd/printers/workflow.go +++ b/cmd/kubectl-flyte/cmd/printers/workflow.go @@ -1,6 +1,7 @@ package printers import ( + "context" "fmt" "time" @@ -37,7 +38,7 @@ func CalculateWorkflowRuntime(s v1alpha1.ExecutableWorkflowStatus) string { type WorkflowPrinter struct { } -func (p WorkflowPrinter) Print(tree gotree.Tree, w v1alpha1.ExecutableWorkflow) error { +func (p WorkflowPrinter) Print(ctx context.Context, tree gotree.Tree, w v1alpha1.ExecutableWorkflow) error { sortedNodes, err := visualize.TopologicalSort(w) if err != nil { return err @@ -49,7 +50,7 @@ func (p WorkflowPrinter) Print(tree gotree.Tree, w v1alpha1.ExecutableWorkflow) tree.AddTree(newTree) } np := NodePrinter{} - return np.PrintList(newTree, w, sortedNodes) + return np.PrintList(ctx, newTree, w, sortedNodes) } func (p WorkflowPrinter) PrintShort(tree gotree.Tree, w v1alpha1.ExecutableWorkflow) error { diff --git a/pkg/apis/flyteworkflow/v1alpha1/iface.go b/pkg/apis/flyteworkflow/v1alpha1/iface.go index 4edf182fa7..238ff8b124 100644 --- a/pkg/apis/flyteworkflow/v1alpha1/iface.go +++ b/pkg/apis/flyteworkflow/v1alpha1/iface.go @@ -165,6 +165,7 @@ type ExecutableBranchNodeStatus interface { } type MutableBranchNodeStatus interface { + Mutable ExecutableBranchNodeStatus SetBranchNodeError() @@ -178,6 +179,7 @@ type ExecutableDynamicNodeStatus interface { } type MutableDynamicNodeStatus interface { + Mutable ExecutableDynamicNodeStatus SetDynamicNodePhase(phase DynamicNodePhase) @@ -198,11 +200,17 @@ type ExecutableWorkflowNodeStatus interface { } type MutableWorkflowNodeStatus interface { + Mutable ExecutableWorkflowNodeStatus SetWorkflowNodePhase(phase WorkflowNodePhase) } +type Mutable interface { + IsDirty() bool +} + type MutableNodeStatus interface { + Mutable // Mutation API's SetDataDir(DataReference) SetParentNodeID(n *NodeID) @@ -225,6 +233,7 @@ type MutableNodeStatus interface { GetDynamicNodeStatus() MutableDynamicNodeStatus ClearDynamicNodeStatus() ClearLastAttemptStartedAt() + ClearSubNodeStatus() } // Interface for a Node p. This provides a mutable API. @@ -247,7 +256,6 @@ type ExecutableNodeStatus interface { GetTaskNodeStatus() ExecutableTaskNodeStatus IsCached() bool - IsDirty() bool } type ExecutableSubWorkflowNodeStatus interface { @@ -255,6 +263,7 @@ type ExecutableSubWorkflowNodeStatus interface { } type MutableSubWorkflowNodeStatus interface { + Mutable ExecutableSubWorkflowNodeStatus SetPhase(phase WorkflowPhase) } @@ -268,6 +277,7 @@ type ExecutableTaskNodeStatus interface { } type MutableTaskNodeStatus interface { + Mutable ExecutableTaskNodeStatus SetPhase(phase int) SetPhaseVersion(version uint32) @@ -320,7 +330,7 @@ type ExecutableWorkflowStatus interface { SetOutputReference(reference DataReference) IncFailedAttempts() SetMessage(msg string) - ConstructNodeDataDir(ctx context.Context, constructor storage.ReferenceConstructor, name NodeID) (storage.DataReference, error) + ConstructNodeDataDir(ctx context.Context, name NodeID) (storage.DataReference, error) } type BaseWorkflow interface { @@ -381,7 +391,7 @@ type ExecutableWorkflow interface { } type NodeStatusGetter interface { - GetNodeExecutionStatus(id NodeID) ExecutableNodeStatus + GetNodeExecutionStatus(ctx context.Context, id NodeID) ExecutableNodeStatus } type NodeStatusMap = map[NodeID]ExecutableNodeStatus diff --git a/pkg/apis/flyteworkflow/v1alpha1/mocks/BaseWorkflowWithStatus.go b/pkg/apis/flyteworkflow/v1alpha1/mocks/BaseWorkflowWithStatus.go index 006f759c25..dc80fea299 100644 --- a/pkg/apis/flyteworkflow/v1alpha1/mocks/BaseWorkflowWithStatus.go +++ b/pkg/apis/flyteworkflow/v1alpha1/mocks/BaseWorkflowWithStatus.go @@ -3,6 +3,8 @@ package mocks import ( + context "context" + v1alpha1 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" mock "github.com/stretchr/testify/mock" ) @@ -134,8 +136,8 @@ func (_m BaseWorkflowWithStatus_GetNodeExecutionStatus) Return(_a0 v1alpha1.Exec return &BaseWorkflowWithStatus_GetNodeExecutionStatus{Call: _m.Call.Return(_a0)} } -func (_m *BaseWorkflowWithStatus) OnGetNodeExecutionStatus(id string) *BaseWorkflowWithStatus_GetNodeExecutionStatus { - c := _m.On("GetNodeExecutionStatus", id) +func (_m *BaseWorkflowWithStatus) OnGetNodeExecutionStatus(ctx context.Context, id string) *BaseWorkflowWithStatus_GetNodeExecutionStatus { + c := _m.On("GetNodeExecutionStatus", ctx, id) return &BaseWorkflowWithStatus_GetNodeExecutionStatus{Call: c} } @@ -144,13 +146,13 @@ func (_m *BaseWorkflowWithStatus) OnGetNodeExecutionStatusMatch(matchers ...inte return &BaseWorkflowWithStatus_GetNodeExecutionStatus{Call: c} } -// GetNodeExecutionStatus provides a mock function with given fields: id -func (_m *BaseWorkflowWithStatus) GetNodeExecutionStatus(id string) v1alpha1.ExecutableNodeStatus { - ret := _m.Called(id) +// GetNodeExecutionStatus provides a mock function with given fields: ctx, id +func (_m *BaseWorkflowWithStatus) GetNodeExecutionStatus(ctx context.Context, id string) v1alpha1.ExecutableNodeStatus { + ret := _m.Called(ctx, id) var r0 v1alpha1.ExecutableNodeStatus - if rf, ok := ret.Get(0).(func(string) v1alpha1.ExecutableNodeStatus); ok { - r0 = rf(id) + if rf, ok := ret.Get(0).(func(context.Context, string) v1alpha1.ExecutableNodeStatus); ok { + r0 = rf(ctx, id) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(v1alpha1.ExecutableNodeStatus) diff --git a/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableNodeStatus.go b/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableNodeStatus.go index 6662759f58..c6ac1236dd 100644 --- a/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableNodeStatus.go +++ b/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableNodeStatus.go @@ -3,6 +3,8 @@ package mocks import ( + context "context" + core "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" mock "github.com/stretchr/testify/mock" @@ -28,6 +30,11 @@ func (_m *ExecutableNodeStatus) ClearLastAttemptStartedAt() { _m.Called() } +// ClearSubNodeStatus provides a mock function with given fields: +func (_m *ExecutableNodeStatus) ClearSubNodeStatus() { + _m.Called() +} + // ClearTaskStatus provides a mock function with given fields: func (_m *ExecutableNodeStatus) ClearTaskStatus() { _m.Called() @@ -278,8 +285,8 @@ func (_m ExecutableNodeStatus_GetNodeExecutionStatus) Return(_a0 v1alpha1.Execut return &ExecutableNodeStatus_GetNodeExecutionStatus{Call: _m.Call.Return(_a0)} } -func (_m *ExecutableNodeStatus) OnGetNodeExecutionStatus(id string) *ExecutableNodeStatus_GetNodeExecutionStatus { - c := _m.On("GetNodeExecutionStatus", id) +func (_m *ExecutableNodeStatus) OnGetNodeExecutionStatus(ctx context.Context, id string) *ExecutableNodeStatus_GetNodeExecutionStatus { + c := _m.On("GetNodeExecutionStatus", ctx, id) return &ExecutableNodeStatus_GetNodeExecutionStatus{Call: c} } @@ -288,13 +295,13 @@ func (_m *ExecutableNodeStatus) OnGetNodeExecutionStatusMatch(matchers ...interf return &ExecutableNodeStatus_GetNodeExecutionStatus{Call: c} } -// GetNodeExecutionStatus provides a mock function with given fields: id -func (_m *ExecutableNodeStatus) GetNodeExecutionStatus(id string) v1alpha1.ExecutableNodeStatus { - ret := _m.Called(id) +// GetNodeExecutionStatus provides a mock function with given fields: ctx, id +func (_m *ExecutableNodeStatus) GetNodeExecutionStatus(ctx context.Context, id string) v1alpha1.ExecutableNodeStatus { + ret := _m.Called(ctx, id) var r0 v1alpha1.ExecutableNodeStatus - if rf, ok := ret.Get(0).(func(string) v1alpha1.ExecutableNodeStatus); ok { - r0 = rf(id) + if rf, ok := ret.Get(0).(func(context.Context, string) v1alpha1.ExecutableNodeStatus); ok { + r0 = rf(ctx, id) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(v1alpha1.ExecutableNodeStatus) diff --git a/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableWorkflow.go b/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableWorkflow.go index de51f07bad..ff0397a59c 100644 --- a/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableWorkflow.go +++ b/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableWorkflow.go @@ -3,6 +3,8 @@ package mocks import ( + context "context" + mock "github.com/stretchr/testify/mock" types "k8s.io/apimachinery/pkg/types" @@ -468,8 +470,8 @@ func (_m ExecutableWorkflow_GetNodeExecutionStatus) Return(_a0 v1alpha1.Executab return &ExecutableWorkflow_GetNodeExecutionStatus{Call: _m.Call.Return(_a0)} } -func (_m *ExecutableWorkflow) OnGetNodeExecutionStatus(id string) *ExecutableWorkflow_GetNodeExecutionStatus { - c := _m.On("GetNodeExecutionStatus", id) +func (_m *ExecutableWorkflow) OnGetNodeExecutionStatus(ctx context.Context, id string) *ExecutableWorkflow_GetNodeExecutionStatus { + c := _m.On("GetNodeExecutionStatus", ctx, id) return &ExecutableWorkflow_GetNodeExecutionStatus{Call: c} } @@ -478,13 +480,13 @@ func (_m *ExecutableWorkflow) OnGetNodeExecutionStatusMatch(matchers ...interfac return &ExecutableWorkflow_GetNodeExecutionStatus{Call: c} } -// GetNodeExecutionStatus provides a mock function with given fields: id -func (_m *ExecutableWorkflow) GetNodeExecutionStatus(id string) v1alpha1.ExecutableNodeStatus { - ret := _m.Called(id) +// GetNodeExecutionStatus provides a mock function with given fields: ctx, id +func (_m *ExecutableWorkflow) GetNodeExecutionStatus(ctx context.Context, id string) v1alpha1.ExecutableNodeStatus { + ret := _m.Called(ctx, id) var r0 v1alpha1.ExecutableNodeStatus - if rf, ok := ret.Get(0).(func(string) v1alpha1.ExecutableNodeStatus); ok { - r0 = rf(id) + if rf, ok := ret.Get(0).(func(context.Context, string) v1alpha1.ExecutableNodeStatus); ok { + r0 = rf(ctx, id) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(v1alpha1.ExecutableNodeStatus) diff --git a/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableWorkflowStatus.go b/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableWorkflowStatus.go index 6cf3283f63..5593b96c8c 100644 --- a/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableWorkflowStatus.go +++ b/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableWorkflowStatus.go @@ -26,8 +26,8 @@ func (_m ExecutableWorkflowStatus_ConstructNodeDataDir) Return(_a0 storage.DataR return &ExecutableWorkflowStatus_ConstructNodeDataDir{Call: _m.Call.Return(_a0, _a1)} } -func (_m *ExecutableWorkflowStatus) OnConstructNodeDataDir(ctx context.Context, constructor storage.ReferenceConstructor, name string) *ExecutableWorkflowStatus_ConstructNodeDataDir { - c := _m.On("ConstructNodeDataDir", ctx, constructor, name) +func (_m *ExecutableWorkflowStatus) OnConstructNodeDataDir(ctx context.Context, name string) *ExecutableWorkflowStatus_ConstructNodeDataDir { + c := _m.On("ConstructNodeDataDir", ctx, name) return &ExecutableWorkflowStatus_ConstructNodeDataDir{Call: c} } @@ -36,20 +36,20 @@ func (_m *ExecutableWorkflowStatus) OnConstructNodeDataDirMatch(matchers ...inte return &ExecutableWorkflowStatus_ConstructNodeDataDir{Call: c} } -// ConstructNodeDataDir provides a mock function with given fields: ctx, constructor, name -func (_m *ExecutableWorkflowStatus) ConstructNodeDataDir(ctx context.Context, constructor storage.ReferenceConstructor, name string) (storage.DataReference, error) { - ret := _m.Called(ctx, constructor, name) +// ConstructNodeDataDir provides a mock function with given fields: ctx, name +func (_m *ExecutableWorkflowStatus) ConstructNodeDataDir(ctx context.Context, name string) (storage.DataReference, error) { + ret := _m.Called(ctx, name) var r0 storage.DataReference - if rf, ok := ret.Get(0).(func(context.Context, storage.ReferenceConstructor, string) storage.DataReference); ok { - r0 = rf(ctx, constructor, name) + if rf, ok := ret.Get(0).(func(context.Context, string) storage.DataReference); ok { + r0 = rf(ctx, name) } else { r0 = ret.Get(0).(storage.DataReference) } var r1 error - if rf, ok := ret.Get(1).(func(context.Context, storage.ReferenceConstructor, string) error); ok { - r1 = rf(ctx, constructor, name) + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, name) } else { r1 = ret.Error(1) } @@ -163,8 +163,8 @@ func (_m ExecutableWorkflowStatus_GetNodeExecutionStatus) Return(_a0 v1alpha1.Ex return &ExecutableWorkflowStatus_GetNodeExecutionStatus{Call: _m.Call.Return(_a0)} } -func (_m *ExecutableWorkflowStatus) OnGetNodeExecutionStatus(id string) *ExecutableWorkflowStatus_GetNodeExecutionStatus { - c := _m.On("GetNodeExecutionStatus", id) +func (_m *ExecutableWorkflowStatus) OnGetNodeExecutionStatus(ctx context.Context, id string) *ExecutableWorkflowStatus_GetNodeExecutionStatus { + c := _m.On("GetNodeExecutionStatus", ctx, id) return &ExecutableWorkflowStatus_GetNodeExecutionStatus{Call: c} } @@ -173,13 +173,13 @@ func (_m *ExecutableWorkflowStatus) OnGetNodeExecutionStatusMatch(matchers ...in return &ExecutableWorkflowStatus_GetNodeExecutionStatus{Call: c} } -// GetNodeExecutionStatus provides a mock function with given fields: id -func (_m *ExecutableWorkflowStatus) GetNodeExecutionStatus(id string) v1alpha1.ExecutableNodeStatus { - ret := _m.Called(id) +// GetNodeExecutionStatus provides a mock function with given fields: ctx, id +func (_m *ExecutableWorkflowStatus) GetNodeExecutionStatus(ctx context.Context, id string) v1alpha1.ExecutableNodeStatus { + ret := _m.Called(ctx, id) var r0 v1alpha1.ExecutableNodeStatus - if rf, ok := ret.Get(0).(func(string) v1alpha1.ExecutableNodeStatus); ok { - r0 = rf(id) + if rf, ok := ret.Get(0).(func(context.Context, string) v1alpha1.ExecutableNodeStatus); ok { + r0 = rf(ctx, id) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(v1alpha1.ExecutableNodeStatus) diff --git a/pkg/apis/flyteworkflow/v1alpha1/mocks/Mutable.go b/pkg/apis/flyteworkflow/v1alpha1/mocks/Mutable.go new file mode 100644 index 0000000000..9a2666806f --- /dev/null +++ b/pkg/apis/flyteworkflow/v1alpha1/mocks/Mutable.go @@ -0,0 +1,42 @@ +// Code generated by mockery v1.0.1. DO NOT EDIT. + +package mocks + +import mock "github.com/stretchr/testify/mock" + +// Mutable is an autogenerated mock type for the Mutable type +type Mutable struct { + mock.Mock +} + +type Mutable_IsDirty struct { + *mock.Call +} + +func (_m Mutable_IsDirty) Return(_a0 bool) *Mutable_IsDirty { + return &Mutable_IsDirty{Call: _m.Call.Return(_a0)} +} + +func (_m *Mutable) OnIsDirty() *Mutable_IsDirty { + c := _m.On("IsDirty") + return &Mutable_IsDirty{Call: c} +} + +func (_m *Mutable) OnIsDirtyMatch(matchers ...interface{}) *Mutable_IsDirty { + c := _m.On("IsDirty", matchers...) + return &Mutable_IsDirty{Call: c} +} + +// IsDirty provides a mock function with given fields: +func (_m *Mutable) IsDirty() bool { + ret := _m.Called() + + var r0 bool + if rf, ok := ret.Get(0).(func() bool); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} diff --git a/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableBranchNodeStatus.go b/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableBranchNodeStatus.go index 3022b16c6e..b6ca6d4404 100644 --- a/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableBranchNodeStatus.go +++ b/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableBranchNodeStatus.go @@ -78,6 +78,38 @@ func (_m *MutableBranchNodeStatus) GetPhase() v1alpha1.BranchNodePhase { return r0 } +type MutableBranchNodeStatus_IsDirty struct { + *mock.Call +} + +func (_m MutableBranchNodeStatus_IsDirty) Return(_a0 bool) *MutableBranchNodeStatus_IsDirty { + return &MutableBranchNodeStatus_IsDirty{Call: _m.Call.Return(_a0)} +} + +func (_m *MutableBranchNodeStatus) OnIsDirty() *MutableBranchNodeStatus_IsDirty { + c := _m.On("IsDirty") + return &MutableBranchNodeStatus_IsDirty{Call: c} +} + +func (_m *MutableBranchNodeStatus) OnIsDirtyMatch(matchers ...interface{}) *MutableBranchNodeStatus_IsDirty { + c := _m.On("IsDirty", matchers...) + return &MutableBranchNodeStatus_IsDirty{Call: c} +} + +// IsDirty provides a mock function with given fields: +func (_m *MutableBranchNodeStatus) IsDirty() bool { + ret := _m.Called() + + var r0 bool + if rf, ok := ret.Get(0).(func() bool); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + // SetBranchNodeError provides a mock function with given fields: func (_m *MutableBranchNodeStatus) SetBranchNodeError() { _m.Called() diff --git a/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableDynamicNodeStatus.go b/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableDynamicNodeStatus.go index d9d20bc7a6..70ccd23a1f 100644 --- a/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableDynamicNodeStatus.go +++ b/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableDynamicNodeStatus.go @@ -76,6 +76,38 @@ func (_m *MutableDynamicNodeStatus) GetDynamicNodeReason() string { return r0 } +type MutableDynamicNodeStatus_IsDirty struct { + *mock.Call +} + +func (_m MutableDynamicNodeStatus_IsDirty) Return(_a0 bool) *MutableDynamicNodeStatus_IsDirty { + return &MutableDynamicNodeStatus_IsDirty{Call: _m.Call.Return(_a0)} +} + +func (_m *MutableDynamicNodeStatus) OnIsDirty() *MutableDynamicNodeStatus_IsDirty { + c := _m.On("IsDirty") + return &MutableDynamicNodeStatus_IsDirty{Call: c} +} + +func (_m *MutableDynamicNodeStatus) OnIsDirtyMatch(matchers ...interface{}) *MutableDynamicNodeStatus_IsDirty { + c := _m.On("IsDirty", matchers...) + return &MutableDynamicNodeStatus_IsDirty{Call: c} +} + +// IsDirty provides a mock function with given fields: +func (_m *MutableDynamicNodeStatus) IsDirty() bool { + ret := _m.Called() + + var r0 bool + if rf, ok := ret.Get(0).(func() bool); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + // SetDynamicNodePhase provides a mock function with given fields: phase func (_m *MutableDynamicNodeStatus) SetDynamicNodePhase(phase v1alpha1.DynamicNodePhase) { _m.Called(phase) diff --git a/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableNodeStatus.go b/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableNodeStatus.go index 27521d7bff..bab9e4feb5 100644 --- a/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableNodeStatus.go +++ b/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableNodeStatus.go @@ -28,6 +28,11 @@ func (_m *MutableNodeStatus) ClearLastAttemptStartedAt() { _m.Called() } +// ClearSubNodeStatus provides a mock function with given fields: +func (_m *MutableNodeStatus) ClearSubNodeStatus() { + _m.Called() +} + // ClearTaskStatus provides a mock function with given fields: func (_m *MutableNodeStatus) ClearTaskStatus() { _m.Called() @@ -342,6 +347,38 @@ func (_m *MutableNodeStatus) IncrementAttempts() uint32 { return r0 } +type MutableNodeStatus_IsDirty struct { + *mock.Call +} + +func (_m MutableNodeStatus_IsDirty) Return(_a0 bool) *MutableNodeStatus_IsDirty { + return &MutableNodeStatus_IsDirty{Call: _m.Call.Return(_a0)} +} + +func (_m *MutableNodeStatus) OnIsDirty() *MutableNodeStatus_IsDirty { + c := _m.On("IsDirty") + return &MutableNodeStatus_IsDirty{Call: c} +} + +func (_m *MutableNodeStatus) OnIsDirtyMatch(matchers ...interface{}) *MutableNodeStatus_IsDirty { + c := _m.On("IsDirty", matchers...) + return &MutableNodeStatus_IsDirty{Call: c} +} + +// IsDirty provides a mock function with given fields: +func (_m *MutableNodeStatus) IsDirty() bool { + ret := _m.Called() + + var r0 bool + if rf, ok := ret.Get(0).(func() bool); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + // ResetDirty provides a mock function with given fields: func (_m *MutableNodeStatus) ResetDirty() { _m.Called() diff --git a/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableSubWorkflowNodeStatus.go b/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableSubWorkflowNodeStatus.go index 2da9d0ed4b..2ac0c8006c 100644 --- a/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableSubWorkflowNodeStatus.go +++ b/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableSubWorkflowNodeStatus.go @@ -44,6 +44,38 @@ func (_m *MutableSubWorkflowNodeStatus) GetPhase() v1alpha1.WorkflowPhase { return r0 } +type MutableSubWorkflowNodeStatus_IsDirty struct { + *mock.Call +} + +func (_m MutableSubWorkflowNodeStatus_IsDirty) Return(_a0 bool) *MutableSubWorkflowNodeStatus_IsDirty { + return &MutableSubWorkflowNodeStatus_IsDirty{Call: _m.Call.Return(_a0)} +} + +func (_m *MutableSubWorkflowNodeStatus) OnIsDirty() *MutableSubWorkflowNodeStatus_IsDirty { + c := _m.On("IsDirty") + return &MutableSubWorkflowNodeStatus_IsDirty{Call: c} +} + +func (_m *MutableSubWorkflowNodeStatus) OnIsDirtyMatch(matchers ...interface{}) *MutableSubWorkflowNodeStatus_IsDirty { + c := _m.On("IsDirty", matchers...) + return &MutableSubWorkflowNodeStatus_IsDirty{Call: c} +} + +// IsDirty provides a mock function with given fields: +func (_m *MutableSubWorkflowNodeStatus) IsDirty() bool { + ret := _m.Called() + + var r0 bool + if rf, ok := ret.Get(0).(func() bool); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + // SetPhase provides a mock function with given fields: phase func (_m *MutableSubWorkflowNodeStatus) SetPhase(phase v1alpha1.WorkflowPhase) { _m.Called(phase) diff --git a/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableTaskNodeStatus.go b/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableTaskNodeStatus.go index fad9a8f2ac..8889034718 100644 --- a/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableTaskNodeStatus.go +++ b/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableTaskNodeStatus.go @@ -171,6 +171,38 @@ func (_m *MutableTaskNodeStatus) GetPluginStateVersion() uint32 { return r0 } +type MutableTaskNodeStatus_IsDirty struct { + *mock.Call +} + +func (_m MutableTaskNodeStatus_IsDirty) Return(_a0 bool) *MutableTaskNodeStatus_IsDirty { + return &MutableTaskNodeStatus_IsDirty{Call: _m.Call.Return(_a0)} +} + +func (_m *MutableTaskNodeStatus) OnIsDirty() *MutableTaskNodeStatus_IsDirty { + c := _m.On("IsDirty") + return &MutableTaskNodeStatus_IsDirty{Call: c} +} + +func (_m *MutableTaskNodeStatus) OnIsDirtyMatch(matchers ...interface{}) *MutableTaskNodeStatus_IsDirty { + c := _m.On("IsDirty", matchers...) + return &MutableTaskNodeStatus_IsDirty{Call: c} +} + +// IsDirty provides a mock function with given fields: +func (_m *MutableTaskNodeStatus) IsDirty() bool { + ret := _m.Called() + + var r0 bool + if rf, ok := ret.Get(0).(func() bool); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + // SetBarrierClockTick provides a mock function with given fields: tick func (_m *MutableTaskNodeStatus) SetBarrierClockTick(tick uint32) { _m.Called(tick) diff --git a/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableWorkflowNodeStatus.go b/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableWorkflowNodeStatus.go index 34a15c91a5..33b19a0ccd 100644 --- a/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableWorkflowNodeStatus.go +++ b/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableWorkflowNodeStatus.go @@ -44,6 +44,38 @@ func (_m *MutableWorkflowNodeStatus) GetWorkflowNodePhase() v1alpha1.WorkflowNod return r0 } +type MutableWorkflowNodeStatus_IsDirty struct { + *mock.Call +} + +func (_m MutableWorkflowNodeStatus_IsDirty) Return(_a0 bool) *MutableWorkflowNodeStatus_IsDirty { + return &MutableWorkflowNodeStatus_IsDirty{Call: _m.Call.Return(_a0)} +} + +func (_m *MutableWorkflowNodeStatus) OnIsDirty() *MutableWorkflowNodeStatus_IsDirty { + c := _m.On("IsDirty") + return &MutableWorkflowNodeStatus_IsDirty{Call: c} +} + +func (_m *MutableWorkflowNodeStatus) OnIsDirtyMatch(matchers ...interface{}) *MutableWorkflowNodeStatus_IsDirty { + c := _m.On("IsDirty", matchers...) + return &MutableWorkflowNodeStatus_IsDirty{Call: c} +} + +// IsDirty provides a mock function with given fields: +func (_m *MutableWorkflowNodeStatus) IsDirty() bool { + ret := _m.Called() + + var r0 bool + if rf, ok := ret.Get(0).(func() bool); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + // SetWorkflowNodePhase provides a mock function with given fields: phase func (_m *MutableWorkflowNodeStatus) SetWorkflowNodePhase(phase v1alpha1.WorkflowNodePhase) { _m.Called(phase) diff --git a/pkg/apis/flyteworkflow/v1alpha1/mocks/NodeStatusGetter.go b/pkg/apis/flyteworkflow/v1alpha1/mocks/NodeStatusGetter.go index 1445e3ea8c..308dea1b66 100644 --- a/pkg/apis/flyteworkflow/v1alpha1/mocks/NodeStatusGetter.go +++ b/pkg/apis/flyteworkflow/v1alpha1/mocks/NodeStatusGetter.go @@ -3,6 +3,8 @@ package mocks import ( + context "context" + v1alpha1 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" mock "github.com/stretchr/testify/mock" ) @@ -20,8 +22,8 @@ func (_m NodeStatusGetter_GetNodeExecutionStatus) Return(_a0 v1alpha1.Executable return &NodeStatusGetter_GetNodeExecutionStatus{Call: _m.Call.Return(_a0)} } -func (_m *NodeStatusGetter) OnGetNodeExecutionStatus(id string) *NodeStatusGetter_GetNodeExecutionStatus { - c := _m.On("GetNodeExecutionStatus", id) +func (_m *NodeStatusGetter) OnGetNodeExecutionStatus(ctx context.Context, id string) *NodeStatusGetter_GetNodeExecutionStatus { + c := _m.On("GetNodeExecutionStatus", ctx, id) return &NodeStatusGetter_GetNodeExecutionStatus{Call: c} } @@ -30,13 +32,13 @@ func (_m *NodeStatusGetter) OnGetNodeExecutionStatusMatch(matchers ...interface{ return &NodeStatusGetter_GetNodeExecutionStatus{Call: c} } -// GetNodeExecutionStatus provides a mock function with given fields: id -func (_m *NodeStatusGetter) GetNodeExecutionStatus(id string) v1alpha1.ExecutableNodeStatus { - ret := _m.Called(id) +// GetNodeExecutionStatus provides a mock function with given fields: ctx, id +func (_m *NodeStatusGetter) GetNodeExecutionStatus(ctx context.Context, id string) v1alpha1.ExecutableNodeStatus { + ret := _m.Called(ctx, id) var r0 v1alpha1.ExecutableNodeStatus - if rf, ok := ret.Get(0).(func(string) v1alpha1.ExecutableNodeStatus); ok { - r0 = rf(id) + if rf, ok := ret.Get(0).(func(context.Context, string) v1alpha1.ExecutableNodeStatus); ok { + r0 = rf(ctx, id) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(v1alpha1.ExecutableNodeStatus) diff --git a/pkg/apis/flyteworkflow/v1alpha1/node_status.go b/pkg/apis/flyteworkflow/v1alpha1/node_status.go index 4f4dd6d8a0..1087d9a45a 100644 --- a/pkg/apis/flyteworkflow/v1alpha1/node_status.go +++ b/pkg/apis/flyteworkflow/v1alpha1/node_status.go @@ -2,14 +2,37 @@ package v1alpha1 import ( "bytes" + "context" "encoding/json" "reflect" + "github.com/lyft/flytestdlib/storage" + + "github.com/lyft/flytestdlib/logger" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" ) +type MutableStruct struct { + isDirty bool +} + +func (in *MutableStruct) SetDirty() { + in.isDirty = true +} + +// For testing only +func (in *MutableStruct) ResetDirty() { + in.isDirty = false +} + +func (in MutableStruct) IsDirty() bool { + return in.isDirty +} + type BranchNodeStatus struct { + MutableStruct Phase BranchNodePhase `json:"phase"` FinalizedNodeID *NodeID `json:"finalNodeId"` } @@ -19,10 +42,12 @@ func (in *BranchNodeStatus) GetPhase() BranchNodePhase { } func (in *BranchNodeStatus) SetBranchNodeError() { + in.SetDirty() in.Phase = BranchNodeError } func (in *BranchNodeStatus) SetBranchNodeSuccess(id NodeID) { + in.SetDirty() in.Phase = BranchNodeSuccess in.FinalizedNodeID = &id } @@ -60,34 +85,41 @@ const ( ) type DynamicNodeStatus struct { + MutableStruct Phase DynamicNodePhase `json:"phase"` Reason string `json:"reason"` } -func (s *DynamicNodeStatus) GetDynamicNodePhase() DynamicNodePhase { - return s.Phase +func (in *DynamicNodeStatus) GetDynamicNodePhase() DynamicNodePhase { + return in.Phase } -func (s *DynamicNodeStatus) GetDynamicNodeReason() string { - return s.Reason +func (in *DynamicNodeStatus) GetDynamicNodeReason() string { + return in.Reason } -func (s *DynamicNodeStatus) SetDynamicNodeReason(reason string) { - s.Reason = reason +func (in *DynamicNodeStatus) SetDynamicNodeReason(reason string) { + if in.Reason != reason { + in.SetDirty() + in.Reason = reason + } } -func (s *DynamicNodeStatus) SetDynamicNodePhase(phase DynamicNodePhase) { - s.Phase = phase +func (in *DynamicNodeStatus) SetDynamicNodePhase(phase DynamicNodePhase) { + if in.Phase != phase { + in.SetDirty() + in.Phase = phase + } } -func (s *DynamicNodeStatus) Equals(o *DynamicNodeStatus) bool { - if s == nil && o == nil { +func (in *DynamicNodeStatus) Equals(o *DynamicNodeStatus) bool { + if in == nil && o == nil { return true } - if s == nil || o == nil { + if in == nil || o == nil { return false } - return s.Phase == o.Phase && s.Reason == o.Reason + return in.Phase == o.Phase && in.Reason == o.Reason } type WorkflowNodePhase int @@ -98,6 +130,7 @@ const ( ) type WorkflowNodeStatus struct { + MutableStruct Phase WorkflowNodePhase `json:"phase"` } @@ -106,10 +139,14 @@ func (in *WorkflowNodeStatus) GetWorkflowNodePhase() WorkflowNodePhase { } func (in *WorkflowNodeStatus) SetWorkflowNodePhase(phase WorkflowNodePhase) { - in.Phase = phase + if in.Phase != phase { + in.SetDirty() + in.Phase = phase + } } type NodeStatus struct { + MutableStruct Phase NodePhase `json:"phase"` QueuedAt *metav1.Time `json:"queuedAt,omitempty"` StartedAt *metav1.Time `json:"startedAt,omitempty"` @@ -117,23 +154,70 @@ type NodeStatus struct { LastUpdatedAt *metav1.Time `json:"lastUpdatedAt,omitempty"` LastAttemptStartedAt *metav1.Time `json:"laStartedAt,omitempty"` Message string `json:"message,omitempty"` - DataDir DataReference `json:"dataDir,omitempty"` + DataDir DataReference `json:"-"` Attempts uint32 `json:"attempts"` Cached bool `json:"cached"` - dirty bool // This is useful only for branch nodes. If this is set, then it can be used to determine if execution can proceed ParentNode *NodeID `json:"parentNode,omitempty"` - ParentTask *TaskExecutionIdentifier `json:"parentTask,omitempty"` + ParentTask *TaskExecutionIdentifier `json:"-"` BranchStatus *BranchNodeStatus `json:"branchStatus,omitempty"` SubNodeStatus map[NodeID]*NodeStatus `json:"subNodeStatus,omitempty"` // We can store the outputs at this layer // TODO not used delete WorkflowNodeStatus *WorkflowNodeStatus `json:"workflowNodeStatus,omitempty"` - TaskNodeStatus *TaskNodeStatus `json:",omitempty"` - // TODO not used delete + + TaskNodeStatus *TaskNodeStatus `json:",omitempty"` DynamicNodeStatus *DynamicNodeStatus `json:"dynamicNodeStatus,omitempty"` + + // Not Persisted + DataReferenceConstructor storage.ReferenceConstructor `json:"-"` +} + +func (in *NodeStatus) IsDirty() bool { + isDirty := in.MutableStruct.IsDirty() || + (in.TaskNodeStatus != nil && in.TaskNodeStatus.IsDirty()) || + (in.DynamicNodeStatus != nil && in.DynamicNodeStatus.IsDirty()) || + (in.WorkflowNodeStatus != nil && in.WorkflowNodeStatus.IsDirty()) || + (in.BranchStatus != nil && in.BranchStatus.IsDirty()) + if isDirty { + return true + } + + for _, sub := range in.SubNodeStatus { + if sub.IsDirty() { + return true + } + } + + return false +} + +// ResetDirty is for unit tests, shouldn't be used in actual logic. +func (in *NodeStatus) ResetDirty() { + in.MutableStruct.ResetDirty() + + if in.TaskNodeStatus != nil { + in.TaskNodeStatus.ResetDirty() + } + + if in.DynamicNodeStatus != nil { + in.DynamicNodeStatus.ResetDirty() + } + + if in.WorkflowNodeStatus != nil { + in.WorkflowNodeStatus.ResetDirty() + } + + if in.BranchStatus != nil { + in.BranchStatus.ResetDirty() + } + + // Reset SubNodeStatus Dirty + for _, subStatus := range in.SubNodeStatus { + subStatus.ResetDirty() + } } func (in *NodeStatus) GetBranchStatus() MutableBranchNodeStatus { @@ -172,14 +256,22 @@ func (in NodeStatus) GetDynamicNodeStatus() MutableDynamicNodeStatus { func (in *NodeStatus) ClearWorkflowStatus() { in.WorkflowNodeStatus = nil + in.SetDirty() } func (in *NodeStatus) ClearTaskStatus() { in.TaskNodeStatus = nil + in.SetDirty() } func (in *NodeStatus) ClearLastAttemptStartedAt() { in.LastAttemptStartedAt = nil + in.SetDirty() +} + +func (in *NodeStatus) ClearSubNodeStatus() { + in.SubNodeStatus = nil + in.SetDirty() } func (in *NodeStatus) GetLastUpdatedAt() *metav1.Time { @@ -196,35 +288,25 @@ func (in *NodeStatus) GetAttempts() uint32 { func (in *NodeStatus) SetCached() { in.Cached = true - in.setDirty() + in.SetDirty() } -func (in *NodeStatus) setDirty() { - in.dirty = true -} func (in *NodeStatus) IsCached() bool { return in.Cached } -func (in *NodeStatus) IsDirty() bool { - return in.dirty -} - -// ResetDirty is for unit tests, shouldn't be used in actual logic. -func (in *NodeStatus) ResetDirty() { - in.dirty = false -} - func (in *NodeStatus) IncrementAttempts() uint32 { in.Attempts++ - in.setDirty() + in.SetDirty() return in.Attempts } func (in *NodeStatus) GetOrCreateDynamicNodeStatus() MutableDynamicNodeStatus { if in.DynamicNodeStatus == nil { - in.setDirty() - in.DynamicNodeStatus = &DynamicNodeStatus{} + in.SetDirty() + in.DynamicNodeStatus = &DynamicNodeStatus{ + MutableStruct: MutableStruct{}, + } } return in.DynamicNodeStatus @@ -232,14 +314,17 @@ func (in *NodeStatus) GetOrCreateDynamicNodeStatus() MutableDynamicNodeStatus { func (in *NodeStatus) ClearDynamicNodeStatus() { in.DynamicNodeStatus = nil + in.SetDirty() } func (in *NodeStatus) GetOrCreateBranchStatus() MutableBranchNodeStatus { if in.BranchStatus == nil { - in.BranchStatus = &BranchNodeStatus{} + in.SetDirty() + in.BranchStatus = &BranchNodeStatus{ + MutableStruct: MutableStruct{}, + } } - in.setDirty() return in.BranchStatus } @@ -248,7 +333,6 @@ func (in *NodeStatus) GetWorkflowNodeStatus() ExecutableWorkflowNodeStatus { return nil } - in.setDirty() return in.WorkflowNodeStatus } @@ -266,10 +350,12 @@ func IsPhaseTerminal(phase NodePhase) bool { func (in *NodeStatus) GetOrCreateTaskStatus() MutableTaskNodeStatus { if in.TaskNodeStatus == nil { - in.TaskNodeStatus = &TaskNodeStatus{} + in.SetDirty() + in.TaskNodeStatus = &TaskNodeStatus{ + MutableStruct: MutableStruct{}, + } } - in.setDirty() return in.TaskNodeStatus } @@ -307,11 +393,8 @@ func (in *NodeStatus) UpdatePhase(p NodePhase, occurredAt metav1.Time, reason st in.StoppedAt = &n } - if in.Phase != p { - in.LastUpdatedAt = &n - } - - in.setDirty() + in.LastUpdatedAt = &n + in.SetDirty() } func (in *NodeStatus) GetStartedAt() *metav1.Time { @@ -338,23 +421,31 @@ func (in *NodeStatus) GetParentTaskID() *core.TaskExecutionIdentifier { } func (in *NodeStatus) SetParentNodeID(n *NodeID) { - in.ParentNode = n - in.setDirty() + if in.ParentNode == nil || in.ParentNode != n { + in.ParentNode = n + in.SetDirty() + } } func (in *NodeStatus) SetParentTaskID(t *core.TaskExecutionIdentifier) { - in.ParentTask = &TaskExecutionIdentifier{ - TaskExecutionIdentifier: t, + if in.ParentTask == nil || in.ParentTask.TaskExecutionIdentifier != t { + in.ParentTask = &TaskExecutionIdentifier{ + TaskExecutionIdentifier: t, + } + + // We do not need to set Dirty here because this field is not persisted. + //in.SetDirty() } - in.setDirty() } func (in *NodeStatus) GetOrCreateWorkflowStatus() MutableWorkflowNodeStatus { if in.WorkflowNodeStatus == nil { - in.WorkflowNodeStatus = &WorkflowNodeStatus{} + in.SetDirty() + in.WorkflowNodeStatus = &WorkflowNodeStatus{ + MutableStruct: MutableStruct{}, + } } - in.setDirty() return in.WorkflowNodeStatus } @@ -367,19 +458,44 @@ func (in NodeStatus) GetTaskNodeStatus() ExecutableTaskNodeStatus { return in.TaskNodeStatus } -func (in *NodeStatus) GetNodeExecutionStatus(id NodeID) ExecutableNodeStatus { +func (in *NodeStatus) GetNodeExecutionStatus(ctx context.Context, id NodeID) ExecutableNodeStatus { n, ok := in.SubNodeStatus[id] if ok { + n.SetParentTaskID(in.GetParentTaskID()) + n.DataReferenceConstructor = in.DataReferenceConstructor + if len(n.GetDataDir()) == 0 { + dataDir, err := in.DataReferenceConstructor.ConstructReference(ctx, in.GetDataDir(), id) + if err != nil { + logger.Errorf(ctx, "Failed to construct data dir for node [%v]", id) + return n + } + + n.SetDataDir(dataDir) + } + return n } + if in.SubNodeStatus == nil { in.SubNodeStatus = make(map[NodeID]*NodeStatus) } - newNodeStatus := &NodeStatus{} + + newNodeStatus := &NodeStatus{ + MutableStruct: MutableStruct{}, + } newNodeStatus.SetParentTaskID(in.GetParentTaskID()) newNodeStatus.SetParentNodeID(in.GetParentNodeID()) + dataDir, err := in.DataReferenceConstructor.ConstructReference(ctx, in.GetDataDir(), id) + if err != nil { + logger.Errorf(ctx, "Failed to construct data dir for node [%v]", id) + return n + } + + newNodeStatus.SetDataDir(dataDir) + newNodeStatus.DataReferenceConstructor = in.DataReferenceConstructor in.SubNodeStatus[id] = newNodeStatus + in.SetDirty() return newNodeStatus } @@ -393,7 +509,9 @@ func (in *NodeStatus) GetDataDir() DataReference { func (in *NodeStatus) SetDataDir(d DataReference) { in.DataDir = d - in.setDirty() + + // We do not need to set Dirty here because this field is not persisted. + //in.SetDirty() } func (in *NodeStatus) Equals(other *NodeStatus) bool { @@ -402,6 +520,10 @@ func (in *NodeStatus) Equals(other *NodeStatus) bool { return false } + if in.IsDirty() != other.IsDirty() { + return false + } + if in.Phase == other.Phase { if in.Phase == NodePhaseSucceeded || in.Phase == NodePhaseFailed { return true @@ -482,6 +604,7 @@ func (in *CustomState) DeepCopy() *CustomState { } type TaskNodeStatus struct { + MutableStruct Phase int `json:"phase,omitempty"` PhaseVersion uint32 `json:"phaseVersion,omitempty"` PluginState []byte `json:"pState,omitempty"` @@ -495,14 +618,17 @@ func (in *TaskNodeStatus) GetBarrierClockTick() uint32 { func (in *TaskNodeStatus) SetBarrierClockTick(tick uint32) { in.BarrierClockTick = tick + in.SetDirty() } func (in *TaskNodeStatus) SetPluginState(s []byte) { in.PluginState = s + in.SetDirty() } func (in *TaskNodeStatus) SetPluginStateVersion(v uint32) { in.PluginStateVersion = v + in.SetDirty() } func (in *TaskNodeStatus) GetPluginState() []byte { @@ -515,10 +641,12 @@ func (in *TaskNodeStatus) GetPluginStateVersion() uint32 { func (in *TaskNodeStatus) SetPhase(phase int) { in.Phase = phase + in.SetDirty() } func (in *TaskNodeStatus) SetPhaseVersion(version uint32) { in.PhaseVersion = version + in.SetDirty() } func (in TaskNodeStatus) GetPhase() int { @@ -530,6 +658,10 @@ func (in TaskNodeStatus) GetPhaseVersion() uint32 { } func (in *TaskNodeStatus) UpdatePhase(phase int, phaseVersion uint32) { + if in.Phase != phase || in.PhaseVersion != phaseVersion { + in.SetDirty() + } + in.Phase = phase in.PhaseVersion = phaseVersion } diff --git a/pkg/apis/flyteworkflow/v1alpha1/workflow.go b/pkg/apis/flyteworkflow/v1alpha1/workflow.go index ba497adf7c..2da5bd62a6 100644 --- a/pkg/apis/flyteworkflow/v1alpha1/workflow.go +++ b/pkg/apis/flyteworkflow/v1alpha1/workflow.go @@ -2,8 +2,11 @@ package v1alpha1 import ( "bytes" + "context" "encoding/json" + "github.com/lyft/flytestdlib/storage" + "k8s.io/apimachinery/pkg/types" "github.com/golang/protobuf/jsonpb" @@ -39,6 +42,9 @@ type FlyteWorkflow struct { ServiceAccountName string `json:"serviceAccountName,omitempty" protobuf:"bytes,8,opt,name=serviceAccountName"` // Status is the only mutable section in the workflow. It holds all the execution information Status WorkflowStatus `json:"status,omitempty"` + + // non-Serialized fields + DataReferenceConstructor storage.ReferenceConstructor `json:"-"` } var FlyteWorkflowGVK = SchemeGroupVersion.WithKind(FlyteWorkflowKind) @@ -61,7 +67,9 @@ func (in *FlyteWorkflow) GetTask(id TaskID) (ExecutableTask, error) { } func (in *FlyteWorkflow) GetExecutionStatus() ExecutableWorkflowStatus { - return &in.Status + s := &in.Status + s.DataReferenceConstructor = in.DataReferenceConstructor + return s } func (in *FlyteWorkflow) GetK8sWorkflowID() types.NamespacedName { @@ -83,8 +91,8 @@ func (in *FlyteWorkflow) FindSubWorkflow(subID WorkflowID) ExecutableSubWorkflow return s } -func (in *FlyteWorkflow) GetNodeExecutionStatus(id NodeID) ExecutableNodeStatus { - return in.Status.GetNodeExecutionStatus(id) +func (in *FlyteWorkflow) GetNodeExecutionStatus(ctx context.Context, id NodeID) ExecutableNodeStatus { + return in.GetExecutionStatus().GetNodeExecutionStatus(ctx, id) } func (in *FlyteWorkflow) GetServiceAccountName() string { @@ -190,6 +198,7 @@ func (in *WorkflowSpec) FromNode(name NodeID) ([]NodeID, error) { if _, ok := in.Nodes[name]; !ok { return nil, errors.Errorf("Bad Node [%v], is not defined in the Workflow [%v]", name, in.ID) } + downstreamNodes := in.Connections.DownstreamEdges[name] return downstreamNodes, nil } diff --git a/pkg/apis/flyteworkflow/v1alpha1/workflow_status.go b/pkg/apis/flyteworkflow/v1alpha1/workflow_status.go index 4027d4d958..5b3a342643 100644 --- a/pkg/apis/flyteworkflow/v1alpha1/workflow_status.go +++ b/pkg/apis/flyteworkflow/v1alpha1/workflow_status.go @@ -3,6 +3,8 @@ package v1alpha1 import ( "context" + "github.com/lyft/flytestdlib/logger" + "github.com/lyft/flytestdlib/storage" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" ) @@ -27,6 +29,9 @@ type WorkflowStatus struct { // that spin in an error loop. The value should be set at the global level and will be enforced. At the end of // the retries the workflow will fail FailedAttempts uint32 `json:"failedAttempts,omitempty"` + + // non-Serialized fields + DataReferenceConstructor storage.ReferenceConstructor `json:"-"` } func IsWorkflowPhaseTerminal(p WorkflowPhase) bool { @@ -84,21 +89,46 @@ func (in *WorkflowStatus) GetMessage() string { return in.Message } -func (in *WorkflowStatus) GetNodeExecutionStatus(id NodeID) ExecutableNodeStatus { +func (in *WorkflowStatus) GetNodeExecutionStatus(ctx context.Context, id NodeID) ExecutableNodeStatus { n, ok := in.NodeStatus[id] if ok { + n.DataReferenceConstructor = in.DataReferenceConstructor + if len(n.GetDataDir()) == 0 { + dataDir, err := in.ConstructNodeDataDir(ctx, id) + if err != nil { + logger.Errorf(ctx, "Failed to construct data dir for node [%v]", id) + return n + } + + n.SetDataDir(dataDir) + } + return n } + if in.NodeStatus == nil { in.NodeStatus = make(map[NodeID]*NodeStatus) } - newNodeStatus := &NodeStatus{} + + newNodeStatus := &NodeStatus{ + MutableStruct: MutableStruct{}, + } + + dataDir, err := in.ConstructNodeDataDir(ctx, id) + if err != nil { + logger.Errorf(ctx, "Failed to construct data dir for node [%v], exec id [%v]", id) + return n + } + + newNodeStatus.SetDataDir(dataDir) + newNodeStatus.DataReferenceConstructor = in.DataReferenceConstructor + in.NodeStatus[id] = newNodeStatus return newNodeStatus } -func (in *WorkflowStatus) ConstructNodeDataDir(ctx context.Context, constructor storage.ReferenceConstructor, name NodeID) (storage.DataReference, error) { - return constructor.ConstructReference(ctx, in.GetDataDir(), name, "data") +func (in *WorkflowStatus) ConstructNodeDataDir(ctx context.Context, name NodeID) (storage.DataReference, error) { + return in.DataReferenceConstructor.ConstructReference(ctx, in.GetDataDir(), name, "data") } func (in *WorkflowStatus) GetDataDir() DataReference { diff --git a/pkg/apis/flyteworkflow/v1alpha1/zz_generated.deepcopy.go b/pkg/apis/flyteworkflow/v1alpha1/zz_generated.deepcopy.go index 261b6a717e..5fb6f32daa 100644 --- a/pkg/apis/flyteworkflow/v1alpha1/zz_generated.deepcopy.go +++ b/pkg/apis/flyteworkflow/v1alpha1/zz_generated.deepcopy.go @@ -80,6 +80,7 @@ func (in *BranchNodeSpec) DeepCopy() *BranchNodeSpec { // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. func (in *BranchNodeStatus) DeepCopyInto(out *BranchNodeStatus) { *out = *in + out.MutableStruct = in.MutableStruct if in.FinalizedNodeID != nil { in, out := &in.FinalizedNodeID, &out.FinalizedNodeID *out = new(string) @@ -111,6 +112,7 @@ func (in *Connections) DeepCopy() *Connections { // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. func (in *DynamicNodeStatus) DeepCopyInto(out *DynamicNodeStatus) { *out = *in + out.MutableStruct = in.MutableStruct return } @@ -284,6 +286,22 @@ func (in *Inputs) DeepCopy() *Inputs { return out } +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *MutableStruct) DeepCopyInto(out *MutableStruct) { + *out = *in + return +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new MutableStruct. +func (in *MutableStruct) DeepCopy() *MutableStruct { + if in == nil { + return nil + } + out := new(MutableStruct) + in.DeepCopyInto(out) + return out +} + // DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new NodeMetadata. func (in *NodeMetadata) DeepCopy() *NodeMetadata { if in == nil { @@ -397,6 +415,7 @@ func (in *NodeSpec) DeepCopy() *NodeSpec { // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. func (in *NodeStatus) DeepCopyInto(out *NodeStatus) { *out = *in + out.MutableStruct = in.MutableStruct if in.QueuedAt != nil { in, out := &in.QueuedAt, &out.QueuedAt *out = (*in).DeepCopy() @@ -562,6 +581,7 @@ func (in *WorkflowNodeSpec) DeepCopy() *WorkflowNodeSpec { // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. func (in *WorkflowNodeStatus) DeepCopyInto(out *WorkflowNodeStatus) { *out = *in + out.MutableStruct = in.MutableStruct return } diff --git a/pkg/compiler/common/mocks/interface_provider.go b/pkg/compiler/common/mocks/interface_provider.go index 4191fbc7e9..82eddf0722 100644 --- a/pkg/compiler/common/mocks/interface_provider.go +++ b/pkg/compiler/common/mocks/interface_provider.go @@ -2,8 +2,10 @@ package mocks -import core "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" -import mock "github.com/stretchr/testify/mock" +import ( + core "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + mock "github.com/stretchr/testify/mock" +) // InterfaceProvider is an autogenerated mock type for the InterfaceProvider type type InterfaceProvider struct { diff --git a/pkg/compiler/common/mocks/node.go b/pkg/compiler/common/mocks/node.go index 4920bb5954..cf1406d1dd 100644 --- a/pkg/compiler/common/mocks/node.go +++ b/pkg/compiler/common/mocks/node.go @@ -2,9 +2,12 @@ package mocks -import common "github.com/lyft/flytepropeller/pkg/compiler/common" -import core "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" -import mock "github.com/stretchr/testify/mock" +import ( + core "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + common "github.com/lyft/flytepropeller/pkg/compiler/common" + + mock "github.com/stretchr/testify/mock" +) // Node is an autogenerated mock type for the Node type type Node struct { diff --git a/pkg/compiler/common/mocks/node_builder.go b/pkg/compiler/common/mocks/node_builder.go index 8fe6599211..f8bb224b0d 100644 --- a/pkg/compiler/common/mocks/node_builder.go +++ b/pkg/compiler/common/mocks/node_builder.go @@ -2,9 +2,12 @@ package mocks -import common "github.com/lyft/flytepropeller/pkg/compiler/common" -import core "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" -import mock "github.com/stretchr/testify/mock" +import ( + core "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + common "github.com/lyft/flytepropeller/pkg/compiler/common" + + mock "github.com/stretchr/testify/mock" +) // NodeBuilder is an autogenerated mock type for the NodeBuilder type type NodeBuilder struct { diff --git a/pkg/compiler/common/mocks/task.go b/pkg/compiler/common/mocks/task.go index 8c0f0dabe3..a6a461f2ce 100644 --- a/pkg/compiler/common/mocks/task.go +++ b/pkg/compiler/common/mocks/task.go @@ -2,8 +2,10 @@ package mocks -import core "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" -import mock "github.com/stretchr/testify/mock" +import ( + core "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + mock "github.com/stretchr/testify/mock" +) // Task is an autogenerated mock type for the Task type type Task struct { diff --git a/pkg/compiler/common/mocks/workflow.go b/pkg/compiler/common/mocks/workflow.go index 5268df3999..3bec3a59b7 100644 --- a/pkg/compiler/common/mocks/workflow.go +++ b/pkg/compiler/common/mocks/workflow.go @@ -2,9 +2,12 @@ package mocks -import common "github.com/lyft/flytepropeller/pkg/compiler/common" -import core "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" -import mock "github.com/stretchr/testify/mock" +import ( + core "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + common "github.com/lyft/flytepropeller/pkg/compiler/common" + + mock "github.com/stretchr/testify/mock" +) // Workflow is an autogenerated mock type for the Workflow type type Workflow struct { diff --git a/pkg/compiler/common/mocks/workflow_builder.go b/pkg/compiler/common/mocks/workflow_builder.go index d2cd9fd30d..adb196fdaa 100644 --- a/pkg/compiler/common/mocks/workflow_builder.go +++ b/pkg/compiler/common/mocks/workflow_builder.go @@ -2,10 +2,14 @@ package mocks -import common "github.com/lyft/flytepropeller/pkg/compiler/common" -import core "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" -import errors "github.com/lyft/flytepropeller/pkg/compiler/errors" -import mock "github.com/stretchr/testify/mock" +import ( + core "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + common "github.com/lyft/flytepropeller/pkg/compiler/common" + + errors "github.com/lyft/flytepropeller/pkg/compiler/errors" + + mock "github.com/stretchr/testify/mock" +) // WorkflowBuilder is an autogenerated mock type for the WorkflowBuilder type type WorkflowBuilder struct { diff --git a/pkg/controller/executors/mocks/client.go b/pkg/controller/executors/mocks/client.go index 6e96c33c5c..8b7c95c156 100644 --- a/pkg/controller/executors/mocks/client.go +++ b/pkg/controller/executors/mocks/client.go @@ -1,4 +1,4 @@ -// Code generated by mockery v1.0.0. DO NOT EDIT. +// Code generated by mockery v1.0.1. DO NOT EDIT. package mocks @@ -14,6 +14,24 @@ type Client struct { mock.Mock } +type Client_GetCache struct { + *mock.Call +} + +func (_m Client_GetCache) Return(_a0 cache.Cache) *Client_GetCache { + return &Client_GetCache{Call: _m.Call.Return(_a0)} +} + +func (_m *Client) OnGetCache() *Client_GetCache { + c := _m.On("GetCache") + return &Client_GetCache{Call: c} +} + +func (_m *Client) OnGetCacheMatch(matchers ...interface{}) *Client_GetCache { + c := _m.On("GetCache", matchers...) + return &Client_GetCache{Call: c} +} + // GetCache provides a mock function with given fields: func (_m *Client) GetCache() cache.Cache { ret := _m.Called() @@ -30,6 +48,24 @@ func (_m *Client) GetCache() cache.Cache { return r0 } +type Client_GetClient struct { + *mock.Call +} + +func (_m Client_GetClient) Return(_a0 client.Client) *Client_GetClient { + return &Client_GetClient{Call: _m.Call.Return(_a0)} +} + +func (_m *Client) OnGetClient() *Client_GetClient { + c := _m.On("GetClient") + return &Client_GetClient{Call: c} +} + +func (_m *Client) OnGetClientMatch(matchers ...interface{}) *Client_GetClient { + c := _m.On("GetClient", matchers...) + return &Client_GetClient{Call: c} +} + // GetClient provides a mock function with given fields: func (_m *Client) GetClient() client.Client { ret := _m.Called() diff --git a/pkg/controller/executors/mocks/node.go b/pkg/controller/executors/mocks/node.go index 173ed3bf0e..73d154e1d5 100644 --- a/pkg/controller/executors/mocks/node.go +++ b/pkg/controller/executors/mocks/node.go @@ -1,4 +1,4 @@ -// Code generated by mockery v1.0.0. DO NOT EDIT. +// Code generated by mockery v1.0.1. DO NOT EDIT. package mocks @@ -18,6 +18,24 @@ type Node struct { mock.Mock } +type Node_AbortHandler struct { + *mock.Call +} + +func (_m Node_AbortHandler) Return(_a0 error) *Node_AbortHandler { + return &Node_AbortHandler{Call: _m.Call.Return(_a0)} +} + +func (_m *Node) OnAbortHandler(ctx context.Context, w v1alpha1.ExecutableWorkflow, currentNode v1alpha1.ExecutableNode, reason string) *Node_AbortHandler { + c := _m.On("AbortHandler", ctx, w, currentNode, reason) + return &Node_AbortHandler{Call: c} +} + +func (_m *Node) OnAbortHandlerMatch(matchers ...interface{}) *Node_AbortHandler { + c := _m.On("AbortHandler", matchers...) + return &Node_AbortHandler{Call: c} +} + // AbortHandler provides a mock function with given fields: ctx, w, currentNode, reason func (_m *Node) AbortHandler(ctx context.Context, w v1alpha1.ExecutableWorkflow, currentNode v1alpha1.ExecutableNode, reason string) error { ret := _m.Called(ctx, w, currentNode, reason) @@ -96,6 +114,24 @@ func (_m *Node) Initialize(ctx context.Context) error { return r0 } +type Node_RecursiveNodeHandler struct { + *mock.Call +} + +func (_m Node_RecursiveNodeHandler) Return(_a0 executors.NodeStatus, _a1 error) *Node_RecursiveNodeHandler { + return &Node_RecursiveNodeHandler{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *Node) OnRecursiveNodeHandler(ctx context.Context, w v1alpha1.ExecutableWorkflow, currentNode v1alpha1.ExecutableNode) *Node_RecursiveNodeHandler { + c := _m.On("RecursiveNodeHandler", ctx, w, currentNode) + return &Node_RecursiveNodeHandler{Call: c} +} + +func (_m *Node) OnRecursiveNodeHandlerMatch(matchers ...interface{}) *Node_RecursiveNodeHandler { + c := _m.On("RecursiveNodeHandler", matchers...) + return &Node_RecursiveNodeHandler{Call: c} +} + // RecursiveNodeHandler provides a mock function with given fields: ctx, w, currentNode func (_m *Node) RecursiveNodeHandler(ctx context.Context, w v1alpha1.ExecutableWorkflow, currentNode v1alpha1.ExecutableNode) (executors.NodeStatus, error) { ret := _m.Called(ctx, w, currentNode) @@ -117,19 +153,37 @@ func (_m *Node) RecursiveNodeHandler(ctx context.Context, w v1alpha1.ExecutableW return r0, r1 } +type Node_SetInputsForStartNode struct { + *mock.Call +} + +func (_m Node_SetInputsForStartNode) Return(_a0 executors.NodeStatus, _a1 error) *Node_SetInputsForStartNode { + return &Node_SetInputsForStartNode{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *Node) OnSetInputsForStartNode(ctx context.Context, w v1alpha1.ExecutableWorkflow, inputs *core.LiteralMap) *Node_SetInputsForStartNode { + c := _m.On("SetInputsForStartNode", ctx, w, inputs) + return &Node_SetInputsForStartNode{Call: c} +} + +func (_m *Node) OnSetInputsForStartNodeMatch(matchers ...interface{}) *Node_SetInputsForStartNode { + c := _m.On("SetInputsForStartNode", matchers...) + return &Node_SetInputsForStartNode{Call: c} +} + // SetInputsForStartNode provides a mock function with given fields: ctx, w, inputs -func (_m *Node) SetInputsForStartNode(ctx context.Context, w v1alpha1.BaseWorkflowWithStatus, inputs *core.LiteralMap) (executors.NodeStatus, error) { +func (_m *Node) SetInputsForStartNode(ctx context.Context, w v1alpha1.ExecutableWorkflow, inputs *core.LiteralMap) (executors.NodeStatus, error) { ret := _m.Called(ctx, w, inputs) var r0 executors.NodeStatus - if rf, ok := ret.Get(0).(func(context.Context, v1alpha1.BaseWorkflowWithStatus, *core.LiteralMap) executors.NodeStatus); ok { + if rf, ok := ret.Get(0).(func(context.Context, v1alpha1.ExecutableWorkflow, *core.LiteralMap) executors.NodeStatus); ok { r0 = rf(ctx, w, inputs) } else { r0 = ret.Get(0).(executors.NodeStatus) } var r1 error - if rf, ok := ret.Get(1).(func(context.Context, v1alpha1.BaseWorkflowWithStatus, *core.LiteralMap) error); ok { + if rf, ok := ret.Get(1).(func(context.Context, v1alpha1.ExecutableWorkflow, *core.LiteralMap) error); ok { r1 = rf(ctx, w, inputs) } else { r1 = ret.Error(1) diff --git a/pkg/controller/executors/mocks/workflow.go b/pkg/controller/executors/mocks/workflow.go index 7cfc5f7849..559677e6a9 100644 --- a/pkg/controller/executors/mocks/workflow.go +++ b/pkg/controller/executors/mocks/workflow.go @@ -1,4 +1,4 @@ -// Code generated by mockery v1.0.0. DO NOT EDIT. +// Code generated by mockery v1.0.1. DO NOT EDIT. package mocks @@ -15,6 +15,24 @@ type Workflow struct { mock.Mock } +type Workflow_HandleAbortedWorkflow struct { + *mock.Call +} + +func (_m Workflow_HandleAbortedWorkflow) Return(_a0 error) *Workflow_HandleAbortedWorkflow { + return &Workflow_HandleAbortedWorkflow{Call: _m.Call.Return(_a0)} +} + +func (_m *Workflow) OnHandleAbortedWorkflow(ctx context.Context, w *v1alpha1.FlyteWorkflow, maxRetries uint32) *Workflow_HandleAbortedWorkflow { + c := _m.On("HandleAbortedWorkflow", ctx, w, maxRetries) + return &Workflow_HandleAbortedWorkflow{Call: c} +} + +func (_m *Workflow) OnHandleAbortedWorkflowMatch(matchers ...interface{}) *Workflow_HandleAbortedWorkflow { + c := _m.On("HandleAbortedWorkflow", matchers...) + return &Workflow_HandleAbortedWorkflow{Call: c} +} + // HandleAbortedWorkflow provides a mock function with given fields: ctx, w, maxRetries func (_m *Workflow) HandleAbortedWorkflow(ctx context.Context, w *v1alpha1.FlyteWorkflow, maxRetries uint32) error { ret := _m.Called(ctx, w, maxRetries) @@ -29,6 +47,24 @@ func (_m *Workflow) HandleAbortedWorkflow(ctx context.Context, w *v1alpha1.Flyte return r0 } +type Workflow_HandleFlyteWorkflow struct { + *mock.Call +} + +func (_m Workflow_HandleFlyteWorkflow) Return(_a0 error) *Workflow_HandleFlyteWorkflow { + return &Workflow_HandleFlyteWorkflow{Call: _m.Call.Return(_a0)} +} + +func (_m *Workflow) OnHandleFlyteWorkflow(ctx context.Context, w *v1alpha1.FlyteWorkflow) *Workflow_HandleFlyteWorkflow { + c := _m.On("HandleFlyteWorkflow", ctx, w) + return &Workflow_HandleFlyteWorkflow{Call: c} +} + +func (_m *Workflow) OnHandleFlyteWorkflowMatch(matchers ...interface{}) *Workflow_HandleFlyteWorkflow { + c := _m.On("HandleFlyteWorkflow", matchers...) + return &Workflow_HandleFlyteWorkflow{Call: c} +} + // HandleFlyteWorkflow provides a mock function with given fields: ctx, w func (_m *Workflow) HandleFlyteWorkflow(ctx context.Context, w *v1alpha1.FlyteWorkflow) error { ret := _m.Called(ctx, w) @@ -43,6 +79,24 @@ func (_m *Workflow) HandleFlyteWorkflow(ctx context.Context, w *v1alpha1.FlyteWo return r0 } +type Workflow_Initialize struct { + *mock.Call +} + +func (_m Workflow_Initialize) Return(_a0 error) *Workflow_Initialize { + return &Workflow_Initialize{Call: _m.Call.Return(_a0)} +} + +func (_m *Workflow) OnInitialize(ctx context.Context) *Workflow_Initialize { + c := _m.On("Initialize", ctx) + return &Workflow_Initialize{Call: c} +} + +func (_m *Workflow) OnInitializeMatch(matchers ...interface{}) *Workflow_Initialize { + c := _m.On("Initialize", matchers...) + return &Workflow_Initialize{Call: c} +} + // Initialize provides a mock function with given fields: ctx func (_m *Workflow) Initialize(ctx context.Context) error { ret := _m.Called(ctx) diff --git a/pkg/controller/executors/node.go b/pkg/controller/executors/node.go index 41b4bbed49..9cd38eeb64 100644 --- a/pkg/controller/executors/node.go +++ b/pkg/controller/executors/node.go @@ -62,7 +62,7 @@ func (p NodePhase) String() string { type Node interface { // This method is used specifically to set inputs for start node. This is because start node does not retrieve inputs // from predecessors, but the inputs are inputs to the workflow or inputs to the parent container (workflow) node. - SetInputsForStartNode(ctx context.Context, w v1alpha1.BaseWorkflowWithStatus, inputs *core.LiteralMap) (NodeStatus, error) + SetInputsForStartNode(ctx context.Context, w v1alpha1.ExecutableWorkflow, inputs *core.LiteralMap) (NodeStatus, error) // This is the main entrypoint to execute a node. It recursively depth-first goes through all ready nodes and starts their execution // This returns either diff --git a/pkg/controller/nodes/branch/evaluator.go b/pkg/controller/nodes/branch/evaluator.go index 56ba1a9d52..89183ab394 100644 --- a/pkg/controller/nodes/branch/evaluator.go +++ b/pkg/controller/nodes/branch/evaluator.go @@ -123,7 +123,7 @@ func DecideBranch(ctx context.Context, w v1alpha1.BaseWorkflowWithStatus, nodeID if !ok { return nil, errors.Errorf(errors.DownstreamNodeNotFoundError, nodeID, "Downstream node [%v] not found", skippedNodeID) } - nStatus := w.GetNodeExecutionStatus(n.GetID()) + nStatus := w.GetNodeExecutionStatus(ctx, n.GetID()) logger.Infof(ctx, "Branch Setting Node[%v] status to Skipped!", skippedNodeID) nStatus.UpdatePhase(v1alpha1.NodePhaseSkipped, v1.Now(), "Branch evaluated to false") } diff --git a/pkg/controller/nodes/branch/evaluator_test.go b/pkg/controller/nodes/branch/evaluator_test.go index f2959a4390..5f044c29c7 100644 --- a/pkg/controller/nodes/branch/evaluator_test.go +++ b/pkg/controller/nodes/branch/evaluator_test.go @@ -4,6 +4,9 @@ import ( "context" "testing" + "github.com/lyft/flytestdlib/promutils" + "github.com/lyft/flytestdlib/storage" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" "github.com/stretchr/testify/assert" @@ -340,12 +343,16 @@ func TestEvaluateIfBlock(t *testing.T) { func TestDecideBranch(t *testing.T) { ctx := context.Background() + dataStore, err := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope()) + assert.NoError(t, err) + t.Run("EmptyIfBlock", func(t *testing.T) { w := &v1alpha1.FlyteWorkflow{ WorkflowSpec: &v1alpha1.WorkflowSpec{ ID: "w1", Nodes: map[v1alpha1.NodeID]*v1alpha1.NodeSpec{}, }, + DataReferenceConstructor: dataStore, } branchNode := &v1alpha1.BranchNodeSpec{} b, err := DecideBranch(ctx, w, "n1", branchNode, nil) @@ -359,6 +366,7 @@ func TestDecideBranch(t *testing.T) { ID: "w1", Nodes: map[v1alpha1.NodeID]*v1alpha1.NodeSpec{}, }, + DataReferenceConstructor: dataStore, } exp, inputs := getComparisonExpression(1.0, core.ComparisonExpression_EQ, 1.0) branchNode := &v1alpha1.BranchNodeSpec{ @@ -390,6 +398,7 @@ func TestDecideBranch(t *testing.T) { }, }, }, + DataReferenceConstructor: dataStore, } exp, inputs := getComparisonExpression(1.0, core.ComparisonExpression_EQ, 1.0) branchNode := &v1alpha1.BranchNodeSpec{ @@ -413,6 +422,7 @@ func TestDecideBranch(t *testing.T) { t.Run("RepeatedCondition", func(t *testing.T) { n1 := "n1" n2 := "n2" + w := &v1alpha1.FlyteWorkflow{ WorkflowSpec: &v1alpha1.WorkflowSpec{ ID: "w1", @@ -425,7 +435,9 @@ func TestDecideBranch(t *testing.T) { }, }, }, + DataReferenceConstructor: dataStore, } + exp, inputs := getComparisonExpression(1.0, core.ComparisonExpression_EQ, 1.0) branchNode := &v1alpha1.BranchNodeSpec{ If: v1alpha1.IfBlock{ @@ -474,6 +486,7 @@ func TestDecideBranch(t *testing.T) { }, }, }, + DataReferenceConstructor: dataStore, } exp1, inputs := getComparisonExpression(1, core.ComparisonExpression_NEQ, 1) exp2, _ := getComparisonExpression(1, core.ComparisonExpression_EQ, 1) @@ -525,6 +538,7 @@ func TestDecideBranch(t *testing.T) { }, }, }, + DataReferenceConstructor: dataStore, } exp1, inputs := getComparisonExpression(1, core.ComparisonExpression_NEQ, 1) exp2, _ := getComparisonExpression(1, core.ComparisonExpression_NEQ, 1) @@ -574,6 +588,7 @@ func TestDecideBranch(t *testing.T) { }, }, }, + DataReferenceConstructor: dataStore, } exp1, inputs := getComparisonExpression(1, core.ComparisonExpression_NEQ, 1) exp2, _ := getComparisonExpression(1, core.ComparisonExpression_NEQ, 1) @@ -624,7 +639,9 @@ func TestDecideBranch(t *testing.T) { }, }, }, + DataReferenceConstructor: dataStore, } + exp1, inputs := getComparisonExpression(1, core.ComparisonExpression_NEQ, 1) exp2, _ := getComparisonExpression(1, core.ComparisonExpression_NEQ, 1) branchNode := &v1alpha1.BranchNodeSpec{ @@ -656,6 +673,7 @@ func TestDecideBranch(t *testing.T) { }, }, } + b, err := DecideBranch(ctx, w, "n", branchNode, inputs) assert.Error(t, err) assert.Nil(t, b) diff --git a/pkg/controller/nodes/branch/handler.go b/pkg/controller/nodes/branch/handler.go index 108dfa7d48..58afc969fa 100644 --- a/pkg/controller/nodes/branch/handler.go +++ b/pkg/controller/nodes/branch/handler.go @@ -66,11 +66,11 @@ func (b *branchHandler) Handle(ctx context.Context, nCtx handler.NodeExecutionCo return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoFailure(errors.DownstreamNodeNotFoundError, errMsg, nil)), nil } i := nCtx.NodeID() - childNodeStatus := w.GetNodeExecutionStatus(finalNode.GetID()) + childNodeStatus := w.GetNodeExecutionStatus(ctx, finalNode.GetID()) childNodeStatus.SetParentNodeID(&i) logger.Debugf(ctx, "Recursing down branchNodestatus node") - nodeStatus := w.GetNodeExecutionStatus(nCtx.NodeID()) + nodeStatus := w.GetNodeExecutionStatus(ctx, nCtx.NodeID()) return b.recurseDownstream(ctx, nCtx, nodeStatus, finalNode) } @@ -96,7 +96,7 @@ func (b *branchHandler) Handle(ctx context.Context, nCtx handler.NodeExecutionCo } // Recurse downstream - nodeStatus := w.GetNodeExecutionStatus(nCtx.NodeID()) + nodeStatus := w.GetNodeExecutionStatus(ctx, nCtx.NodeID()) return b.recurseDownstream(ctx, nCtx, nodeStatus, branchTakenNode) } @@ -109,7 +109,7 @@ func (b *branchHandler) recurseDownstream(ctx context.Context, nCtx handler.Node if downstreamStatus.IsComplete() { // For branch node we set the output node to be the same as the child nodes output - childNodeStatus := w.GetNodeExecutionStatus(branchTakenNode.GetID()) + childNodeStatus := w.GetNodeExecutionStatus(ctx, branchTakenNode.GetID()) nodeStatus.SetDataDir(childNodeStatus.GetDataDir()) phase := handler.PhaseInfoSuccess(&handler.ExecutionInfo{ OutputInfo: &handler.OutputInfo{OutputURI: v1alpha1.GetOutputsFile(childNodeStatus.GetDataDir())}, diff --git a/pkg/controller/nodes/branch/handler_test.go b/pkg/controller/nodes/branch/handler_test.go index de24fe7344..84f09ba4aa 100644 --- a/pkg/controller/nodes/branch/handler_test.go +++ b/pkg/controller/nodes/branch/handler_test.go @@ -182,7 +182,7 @@ func TestBranchHandler_RecurseDownstream(t *testing.T) { } assert.Equal(t, test.expectedPhase, h.Info().GetPhase()) if test.nodeStatus != nil { - assert.Equal(t, w.GetNodeExecutionStatus(test.branchTakenNode.GetID()).GetDataDir(), test.nodeStatus.GetDataDir()) + assert.Equal(t, w.GetNodeExecutionStatus(ctx, test.branchTakenNode.GetID()).GetDataDir(), test.nodeStatus.GetDataDir()) } }) } diff --git a/pkg/controller/nodes/dynamic/handler.go b/pkg/controller/nodes/dynamic/handler.go index c381116e30..b47b23f8d2 100644 --- a/pkg/controller/nodes/dynamic/handler.go +++ b/pkg/controller/nodes/dynamic/handler.go @@ -3,6 +3,7 @@ package dynamic import ( "context" "fmt" + "strconv" "time" "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/catalog" @@ -134,13 +135,14 @@ func (d dynamicNodeTaskNodeHandler) Handle(ctx context.Context, nCtx handler.Nod var trns handler.Transition newState := ds logger.Infof(ctx, "Dynamic handler.Handle's called with phase %v.", ds.Phase) - if ds.Phase == v1alpha1.DynamicNodePhaseExecuting { + switch ds.Phase { + case v1alpha1.DynamicNodePhaseExecuting: trns, newState, err = d.handleDynamicSubNodes(ctx, nCtx, ds) if err != nil { logger.Errorf(ctx, "handling dynamic subnodes failed with error: %s", err.Error()) return trns, err } - } else if ds.Phase == v1alpha1.DynamicNodePhaseFailing { + case v1alpha1.DynamicNodePhaseFailing: err = d.Abort(ctx, nCtx, ds.Reason) if err != nil { logger.Errorf(ctx, "Failing to abort dynamic workflow") @@ -148,7 +150,7 @@ func (d dynamicNodeTaskNodeHandler) Handle(ctx context.Context, nCtx handler.Nod } trns = handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoRetryableFailure("DynamicNodeFailed", ds.Reason, nil)) - } else { + default: trns, newState, err = d.handleParentNode(ctx, ds, nCtx) if err != nil { logger.Errorf(ctx, "handling parent node failed with error: %s", err.Error()) @@ -211,33 +213,32 @@ func (d dynamicNodeTaskNodeHandler) Finalize(ctx context.Context, nCtx handler.N } } -func (d dynamicNodeTaskNodeHandler) buildDynamicWorkflowTemplate(ctx context.Context, djSpec *core.DynamicJobSpec, nCtx handler.NodeExecutionContext, parentNodeStatus v1alpha1.ExecutableNodeStatus) ( - *core.WorkflowTemplate, error) { +func (d dynamicNodeTaskNodeHandler) buildDynamicWorkflowTemplate(ctx context.Context, djSpec *core.DynamicJobSpec, + nCtx handler.NodeExecutionContext, parentNodeStatus v1alpha1.ExecutableNodeStatus) (*core.WorkflowTemplate, error) { iface, err := underlyingInterface(ctx, nCtx.TaskReader()) if err != nil { return nil, err } + currentAttemptStr := strconv.Itoa(int(nCtx.CurrentAttempt())) // Modify node IDs to include lineage, the entire system assumes node IDs are unique per parent WF. // We keep track of the original node ids because that's where inputs are written to. parentNodeID := nCtx.NodeID() for _, n := range djSpec.Nodes { - newID, err := hierarchicalNodeID(parentNodeID, n.Id) + newID, err := hierarchicalNodeID(parentNodeID, currentAttemptStr, n.Id) if err != nil { return nil, err } // Instantiate a nodeStatus using the modified name but set its data directory using the original name. - subNodeStatus := parentNodeStatus.GetNodeExecutionStatus(newID) + subNodeStatus := parentNodeStatus.GetNodeExecutionStatus(ctx, newID) originalNodePath, err := nCtx.DataStore().ConstructReference(ctx, nCtx.NodeStatus().GetDataDir(), n.Id) if err != nil { return nil, err } subNodeStatus.SetDataDir(originalNodePath) - subNodeStatus.ResetDirty() - n.Id = newID } @@ -263,7 +264,7 @@ func (d dynamicNodeTaskNodeHandler) buildDynamicWorkflowTemplate(ctx context.Con } for _, o := range djSpec.Outputs { - err = updateBindingNodeIDsWithLineage(parentNodeID, o.Binding) + err = updateBindingNodeIDsWithLineage(parentNodeID, currentAttemptStr, o.Binding) if err != nil { return nil, err } @@ -295,7 +296,7 @@ func (d dynamicNodeTaskNodeHandler) buildContextualDynamicWorkflow(ctx context.C // TODO: This is a hack to set parent task execution id, we should move to node-node relationship. execID := task.GetTaskExecutionIdentifier(nCtx) - nStatus := nCtx.NodeStatus().GetNodeExecutionStatus(dynamicNodeID) + nStatus := nCtx.NodeStatus().GetNodeExecutionStatus(ctx, dynamicNodeID) nStatus.SetDataDir(nCtx.NodeStatus().GetDataDir()) nStatus.SetParentTaskID(execID) @@ -349,7 +350,7 @@ func (d dynamicNodeTaskNodeHandler) buildContextualDynamicWorkflow(ctx context.C logger.Errorf(ctx, "Failed to cache Dynamic workflow [%s]", err.Error()) } - return newContextualWorkflow(nCtx.Workflow(), subwf, nStatus, subwf.Tasks, subwf.SubWorkflows), true, nil + return newContextualWorkflow(nCtx.Workflow(), subwf, nStatus, subwf.Tasks, subwf.SubWorkflows, nCtx.DataStore()), true, nil } func (d dynamicNodeTaskNodeHandler) progressDynamicWorkflow(ctx context.Context, dynamicWorkflow v1alpha1.ExecutableWorkflow, @@ -384,7 +385,7 @@ func (d dynamicNodeTaskNodeHandler) progressDynamicWorkflow(ctx context.Context, var o *handler.OutputInfo // If the WF interface has outputs, validate that the outputs file was written. if outputBindings := dynamicWorkflow.GetOutputBindings(); len(outputBindings) > 0 { - endNodeStatus := dynamicWorkflow.GetNodeExecutionStatus(v1alpha1.EndNodeID) + endNodeStatus := dynamicWorkflow.GetNodeExecutionStatus(ctx, v1alpha1.EndNodeID) if endNodeStatus == nil { return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoFailure("MalformedDynamicWorkflow", "no end-node found in dynamic workflow", nil)), handler.DynamicNodeState{Phase: v1alpha1.DynamicNodePhaseFailing, Reason: "no end-node found in dynamic workflow"}, diff --git a/pkg/controller/nodes/dynamic/handler_test.go b/pkg/controller/nodes/dynamic/handler_test.go index 6a951a848b..e04ae781f9 100644 --- a/pkg/controller/nodes/dynamic/handler_test.go +++ b/pkg/controller/nodes/dynamic/handler_test.go @@ -105,18 +105,22 @@ func Test_dynamicNodeHandler_Handle_Parent(t *testing.T) { n := &flyteMocks.ExecutableNode{} n.On("GetResources").Return(res) + dataStore, err := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope()) + assert.NoError(t, err) + ir := &ioMocks.InputReader{} nCtx := &nodeMocks.NodeExecutionContext{} nCtx.On("NodeExecutionMetadata").Return(nm) nCtx.On("Node").Return(n) nCtx.On("InputReader").Return(ir) - nCtx.On("DataStore").Return(storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope())) + nCtx.On("DataReferenceConstructor").Return(storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope())) nCtx.On("CurrentAttempt").Return(uint32(1)) nCtx.On("TaskReader").Return(tr) nCtx.On("MaxDatasetSizeBytes").Return(int64(1)) nCtx.On("NodeStatus").Return(ns) nCtx.On("NodeID").Return("n1") nCtx.On("EnqueueOwner").Return(nil) + nCtx.OnDataStore().Return(dataStore) r := &nodeMocks.NodeStateReader{} r.On("GetDynamicNodeState").Return(handler.DynamicNodeState{}) @@ -258,6 +262,8 @@ func createDynamicJobSpec() *core.DynamicJobSpec { func Test_dynamicNodeHandler_Handle_SubTask(t *testing.T) { createNodeContext := func(ttype string, finalOutput storage.DataReference) *nodeMocks.NodeExecutionContext { + ctx := context.TODO() + wfExecID := &core.WorkflowExecutionIdentifier{ Project: "project", Domain: "domain", @@ -308,17 +314,21 @@ func Test_dynamicNodeHandler_Handle_SubTask(t *testing.T) { tID := "task-1" n.On("GetTaskID").Return(&tID) + dataStore, err := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope()) + assert.NoError(t, err) + ir := &ioMocks.InputReader{} nCtx := &nodeMocks.NodeExecutionContext{} nCtx.On("NodeExecutionMetadata").Return(nm) nCtx.On("Node").Return(n) nCtx.On("InputReader").Return(ir) - nCtx.On("DataStore").Return(storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope())) + nCtx.On("DataReferenceConstructor").Return(storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope())) nCtx.On("CurrentAttempt").Return(uint32(1)) nCtx.On("TaskReader").Return(tr) nCtx.On("MaxDatasetSizeBytes").Return(int64(1)) nCtx.On("NodeID").Return("n1") nCtx.On("EnqueueOwnerFunc").Return(func() error { return nil }) + nCtx.OnDataStore().Return(dataStore) endNodeStatus := &flyteMocks.ExecutableNodeStatus{} endNodeStatus.On("GetDataDir").Return(storage.DataReference("end-node")) @@ -332,19 +342,19 @@ func Test_dynamicNodeHandler_Handle_SubTask(t *testing.T) { dynamicNS := &flyteMocks.ExecutableNodeStatus{} dynamicNS.On("SetDataDir", mock.Anything).Return() dynamicNS.On("SetParentTaskID", mock.Anything).Return() - dynamicNS.On("GetNodeExecutionStatus", "n1-Node_1").Return(subNs) - dynamicNS.On("GetNodeExecutionStatus", "n1-Node_2").Return(subNs) - dynamicNS.On("GetNodeExecutionStatus", "n1-Node_3").Return(subNs) - dynamicNS.On("GetNodeExecutionStatus", v1alpha1.EndNodeID).Return(endNodeStatus) + dynamicNS.OnGetNodeExecutionStatus(ctx, "n1-1-Node_1").Return(subNs) + dynamicNS.OnGetNodeExecutionStatus(ctx, "n1-1-Node_2").Return(subNs) + dynamicNS.OnGetNodeExecutionStatus(ctx, "n1-1-Node_3").Return(subNs) + dynamicNS.OnGetNodeExecutionStatus(ctx, v1alpha1.EndNodeID).Return(endNodeStatus) ns := &flyteMocks.ExecutableNodeStatus{} ns.On("GetDataDir").Return(storage.DataReference("data-dir")) - ns.On("GetNodeExecutionStatus", dynamicNodeID).Return(dynamicNS) + ns.OnGetNodeExecutionStatus(ctx, dynamicNodeID).Return(dynamicNS) nCtx.On("NodeStatus").Return(ns) w := &flyteMocks.ExecutableWorkflow{} ws := &flyteMocks.ExecutableWorkflowStatus{} - ws.On("GetNodeExecutionStatus", "n1").Return(ns) + ws.OnGetNodeExecutionStatus(ctx, "n1").Return(ns) w.On("GetExecutionStatus").Return(ws) nCtx.On("Workflow").Return(w) diff --git a/pkg/controller/nodes/dynamic/mocks/task_node_handler.go b/pkg/controller/nodes/dynamic/mocks/task_node_handler.go index adf7063acf..e1b7f2294f 100644 --- a/pkg/controller/nodes/dynamic/mocks/task_node_handler.go +++ b/pkg/controller/nodes/dynamic/mocks/task_node_handler.go @@ -2,13 +2,19 @@ package mocks -import catalog "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/catalog" -import context "context" -import core "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core" +import ( + context "context" -import handler "github.com/lyft/flytepropeller/pkg/controller/nodes/handler" -import io "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/io" -import mock "github.com/stretchr/testify/mock" + catalog "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/catalog" + + core "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core" + + handler "github.com/lyft/flytepropeller/pkg/controller/nodes/handler" + + io "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/io" + + mock "github.com/stretchr/testify/mock" +) // TaskNodeHandler is an autogenerated mock type for the TaskNodeHandler type type TaskNodeHandler struct { diff --git a/pkg/controller/nodes/dynamic/subworkflow.go b/pkg/controller/nodes/dynamic/subworkflow.go index 3869a5c219..9d2350755d 100644 --- a/pkg/controller/nodes/dynamic/subworkflow.go +++ b/pkg/controller/nodes/dynamic/subworkflow.go @@ -22,13 +22,14 @@ func newContextualWorkflow(baseWorkflow v1alpha1.ExecutableWorkflow, subwf v1alpha1.ExecutableSubWorkflow, status v1alpha1.ExecutableNodeStatus, tasks map[v1alpha1.TaskID]*v1alpha1.TaskSpec, - workflows map[v1alpha1.WorkflowID]*v1alpha1.WorkflowSpec) v1alpha1.ExecutableWorkflow { + workflows map[v1alpha1.WorkflowID]*v1alpha1.WorkflowSpec, + refConstructor storage.ReferenceConstructor) v1alpha1.ExecutableWorkflow { return &contextualWorkflow{ ExecutableWorkflow: executors.NewSubContextualWorkflow(baseWorkflow, subwf, status), extraTasks: tasks, extraWorkflows: workflows, - status: newContextualWorkflowStatus(baseWorkflow.GetExecutionStatus(), status), + status: newContextualWorkflowStatus(baseWorkflow.GetExecutionStatus(), status, refConstructor), } } @@ -55,7 +56,8 @@ func (w contextualWorkflow) FindSubWorkflow(id v1alpha1.WorkflowID) v1alpha1.Exe // A contextual workflow status to override some of the implementations. type ContextualWorkflowStatus struct { v1alpha1.ExecutableWorkflowStatus - baseStatus v1alpha1.ExecutableNodeStatus + baseStatus v1alpha1.ExecutableNodeStatus + referenceConstructor storage.ReferenceConstructor } func (w ContextualWorkflowStatus) GetDataDir() v1alpha1.DataReference { @@ -74,16 +76,16 @@ func (w ContextualWorkflowStatus) GetDataDir() v1alpha1.DataReference { // |_ sub-node2/inputs.pb // TODO: This is just a stop-gap until we transition the DynamicJobSpec to be a full-fledged workflow spec. // TODO: this will allow us to have proper data bindings between nodes then we can stop making assumptions about data refs. -func (w ContextualWorkflowStatus) ConstructNodeDataDir(ctx context.Context, constructor storage.ReferenceConstructor, - name v1alpha1.NodeID) (storage.DataReference, error) { - return constructor.ConstructReference(ctx, w.GetDataDir(), name) +func (w ContextualWorkflowStatus) ConstructNodeDataDir(ctx context.Context, name v1alpha1.NodeID) (storage.DataReference, error) { + return w.referenceConstructor.ConstructReference(ctx, w.GetDataDir(), name) } func newContextualWorkflowStatus(baseWfStatus v1alpha1.ExecutableWorkflowStatus, - baseStatus v1alpha1.ExecutableNodeStatus) *ContextualWorkflowStatus { + baseStatus v1alpha1.ExecutableNodeStatus, constructor storage.ReferenceConstructor) *ContextualWorkflowStatus { return &ContextualWorkflowStatus{ ExecutableWorkflowStatus: baseWfStatus, baseStatus: baseStatus, + referenceConstructor: constructor, } } diff --git a/pkg/controller/nodes/dynamic/subworkflow_test.go b/pkg/controller/nodes/dynamic/subworkflow_test.go index 9433b6ee58..f55016ee75 100644 --- a/pkg/controller/nodes/dynamic/subworkflow_test.go +++ b/pkg/controller/nodes/dynamic/subworkflow_test.go @@ -23,7 +23,7 @@ func TestNewContextualWorkflow(t *testing.T) { wf.On("GetExecutionStatus").Return(&mocks.ExecutableWorkflowStatus{}) subwf := &mocks.ExecutableSubWorkflow{} - cWF := newContextualWorkflow(wf, subwf, nil, nil, nil) + cWF := newContextualWorkflow(wf, subwf, nil, nil, nil, nil) cWF.GetAnnotations() assert.True(t, calledBase) @@ -43,9 +43,9 @@ func TestConstructNodeDataDir(t *testing.T) { ds, err := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope()) assert.NoError(t, err) - cWF := newContextualWorkflowStatus(wfStatus, nodeStatus) + cWF := newContextualWorkflowStatus(wfStatus, nodeStatus, ds) - dataDir, err := cWF.ConstructNodeDataDir(context.TODO(), ds, "my_node") + dataDir, err := cWF.ConstructNodeDataDir(context.TODO(), "my_node") assert.NoError(t, err) assert.NotNil(t, dataDir) assert.Equal(t, "fk://right/my_node", dataDir.String()) diff --git a/pkg/controller/nodes/dynamic/utils.go b/pkg/controller/nodes/dynamic/utils.go index b17f701d58..317faf6b6f 100644 --- a/pkg/controller/nodes/dynamic/utils.go +++ b/pkg/controller/nodes/dynamic/utils.go @@ -27,27 +27,27 @@ func underlyingInterface(ctx context.Context, taskReader handler.TaskReader) (*c return iface, nil } -func hierarchicalNodeID(parentNodeID, nodeID string) (string, error) { - return utils.FixedLengthUniqueIDForParts(20, parentNodeID, nodeID) +func hierarchicalNodeID(parentNodeID, retryAttempt, nodeID string) (string, error) { + return utils.FixedLengthUniqueIDForParts(20, parentNodeID, retryAttempt, nodeID) } -func updateBindingNodeIDsWithLineage(parentNodeID string, binding *core.BindingData) (err error) { +func updateBindingNodeIDsWithLineage(parentNodeID, retryAttempt string, binding *core.BindingData) (err error) { switch b := binding.Value.(type) { case *core.BindingData_Promise: - b.Promise.NodeId, err = hierarchicalNodeID(parentNodeID, b.Promise.NodeId) + b.Promise.NodeId, err = hierarchicalNodeID(parentNodeID, retryAttempt, b.Promise.NodeId) if err != nil { return err } case *core.BindingData_Collection: for _, item := range b.Collection.Bindings { - err = updateBindingNodeIDsWithLineage(parentNodeID, item) + err = updateBindingNodeIDsWithLineage(parentNodeID, retryAttempt, item) if err != nil { return err } } case *core.BindingData_Map: for _, item := range b.Map.Bindings { - err = updateBindingNodeIDsWithLineage(parentNodeID, item) + err = updateBindingNodeIDsWithLineage(parentNodeID, retryAttempt, item) if err != nil { return err } diff --git a/pkg/controller/nodes/dynamic/utils_test.go b/pkg/controller/nodes/dynamic/utils_test.go index f5f43bf216..d36fc93c39 100644 --- a/pkg/controller/nodes/dynamic/utils_test.go +++ b/pkg/controller/nodes/dynamic/utils_test.go @@ -15,15 +15,21 @@ import ( func TestHierarchicalNodeID(t *testing.T) { t.Run("empty parent", func(t *testing.T) { - actual, err := hierarchicalNodeID("", "abc") + actual, err := hierarchicalNodeID("", "0", "abc") assert.NoError(t, err) - assert.Equal(t, "-abc", actual) + assert.Equal(t, "0-abc", actual) }) t.Run("long result", func(t *testing.T) { - actual, err := hierarchicalNodeID("abcdefghijklmnopqrstuvwxyz", "abc") + actual, err := hierarchicalNodeID("abcdefghijklmnopqrstuvwxyz", "0", "abc") assert.NoError(t, err) - assert.Equal(t, "fpa3kc3y", actual) + assert.Equal(t, "fkm1vhcq", actual) + }) + + t.Run("Real case", func(t *testing.T) { + actual, err := hierarchicalNodeID("ensure-tables-task", "0", "2499f2af-7c23-42fd-8e62-01bf93cea82d") + assert.NoError(t, err) + assert.Equal(t, "fyvhfkda", actual) }) } diff --git a/pkg/controller/nodes/executor.go b/pkg/controller/nodes/executor.go index 2226bbb418..b6da592f56 100644 --- a/pkg/controller/nodes/executor.go +++ b/pkg/controller/nodes/executor.go @@ -116,7 +116,6 @@ func (c *nodeExecutor) preExecute(ctx context.Context, w v1alpha1.ExecutableWork } if predicatePhase == PredicatePhaseReady { - // TODO: Performance problem, we maybe in a retry loop and do not need to resolve the inputs again. // For now we will do this. dataDir := nodeStatus.GetDataDir() @@ -146,8 +145,10 @@ func (c *nodeExecutor) preExecute(ctx context.Context, w v1alpha1.ExecutableWork logger.Debugf(ctx, "Node Data Directory [%s].", nodeStatus.GetDataDir()) } + return handler.PhaseInfoQueued("node queued"), nil } + // Now that we have resolved the inputs, we can record as a transition latency. This is because we have completed // all the overhead that we have to compute. Any failures after this will incur this penalty, but it could be due // to various external reasons - like queuing, overuse of quota, plugin overhead etc. @@ -155,6 +156,7 @@ func (c *nodeExecutor) preExecute(ctx context.Context, w v1alpha1.ExecutableWork if predicatePhase == PredicatePhaseSkip { return handler.PhaseInfoSkip(nil, "Node Skipped as parent node was skipped"), nil } + return handler.PhaseInfoNotReady("predecessor node not yet complete"), nil } @@ -248,7 +250,12 @@ func (c *nodeExecutor) handleNode(ctx context.Context, w v1alpha1.ExecutableWork NodeId: node.GetID(), ExecutionId: w.GetExecutionID().WorkflowExecutionIdentifier, } - nodeStatus := w.GetNodeExecutionStatus(node.GetID()) + + nodeStatus := w.GetNodeExecutionStatus(ctx, node.GetID()) + + if nodeStatus.IsDirty() { + return executors.NodeStatusRunning, nil + } // Now depending on the node type decide h, err := c.nodeHandlerFactory.GetHandler(node.GetKind()) @@ -256,16 +263,6 @@ func (c *nodeExecutor) handleNode(ctx context.Context, w v1alpha1.ExecutableWork return executors.NodeStatusUndefined, err } - if len(nodeStatus.GetDataDir()) == 0 { - // Predicate ready, lets Resolve the data - dataDir, err := w.GetExecutionStatus().ConstructNodeDataDir(ctx, c.store, node.GetID()) - if err != nil { - return executors.NodeStatusUndefined, err - } - - nodeStatus.SetDataDir(dataDir) - } - nCtx, err := c.newNodeExecContextDefault(ctx, w, node, nodeStatus) if err != nil { return executors.NodeStatusUndefined, err @@ -283,9 +280,11 @@ func (c *nodeExecutor) handleNode(ctx context.Context, w v1alpha1.ExecutableWork logger.Errorf(ctx, "failed preExecute for node. Error: %s", err.Error()) return executors.NodeStatusUndefined, err } + if p.GetPhase() == handler.EPhaseUndefined { return executors.NodeStatusUndefined, errors.Errorf(errors.IllegalStateError, node.GetID(), "received undefined phase.") } + if p.GetPhase() == handler.EPhaseNotReady { return executors.NodeStatusPending, nil } @@ -294,6 +293,7 @@ func (c *nodeExecutor) handleNode(ctx context.Context, w v1alpha1.ExecutableWork if err != nil { return executors.NodeStatusUndefined, errors.Wrapf(errors.IllegalStateError, node.GetID(), err, "failed to move from queued") } + if np != nodeStatus.GetPhase() { // assert np == Queued! logger.Infof(ctx, "Change in node state detected from [%s] -> [%s]", nodeStatus.GetPhase().String(), np.String()) @@ -309,11 +309,13 @@ func (c *nodeExecutor) handleNode(ctx context.Context, w v1alpha1.ExecutableWork UpdateNodeStatus(np, p, nCtx.nsm, nodeStatus) c.RecordTransitionLatency(ctx, w, node, nodeStatus) } + if np == v1alpha1.NodePhaseQueued { return executors.NodeStatusQueued, nil } else if np == v1alpha1.NodePhaseSkipped { return executors.NodeStatusSuccess, nil } + return executors.NodeStatusPending, nil } @@ -335,6 +337,7 @@ func (c *nodeExecutor) handleNode(ctx context.Context, w v1alpha1.ExecutableWork return executors.NodeStatusUndefined, err } + nodeStatus.ClearSubNodeStatus() nodeStatus.UpdatePhase(v1alpha1.NodePhaseTimedOut, v1.Now(), nodeStatus.GetMessage()) c.metrics.TimedOutFailure.Inc(ctx) return executors.NodeStatusTimedOut, nil @@ -346,6 +349,7 @@ func (c *nodeExecutor) handleNode(ctx context.Context, w v1alpha1.ExecutableWork return executors.NodeStatusUndefined, err } + nodeStatus.ClearSubNodeStatus() nodeStatus.UpdatePhase(v1alpha1.NodePhaseSucceeded, v1.Now(), "completed successfully") c.metrics.SuccessDuration.Observe(ctx, nodeStatus.GetStartedAt().Time, nodeStatus.GetStoppedAt().Time) return executors.NodeStatusSuccess, nil @@ -359,9 +363,10 @@ func (c *nodeExecutor) handleNode(ctx context.Context, w v1alpha1.ExecutableWork nodeStatus.UpdatePhase(v1alpha1.NodePhaseRunning, v1.Now(), "retrying") // We are going to retry in the next round, so we should clear all current state - nodeStatus.ClearDynamicNodeStatus() + nodeStatus.ClearSubNodeStatus() nodeStatus.ClearTaskStatus() nodeStatus.ClearWorkflowStatus() + nodeStatus.ClearDynamicNodeStatus() nodeStatus.ClearLastAttemptStartedAt() return executors.NodeStatusPending, nil } @@ -371,11 +376,6 @@ func (c *nodeExecutor) handleNode(ctx context.Context, w v1alpha1.ExecutableWork return executors.NodeStatusFailed(fmt.Errorf(nodeStatus.GetMessage())), nil } - if currentPhase == v1alpha1.NodePhaseFailed { - // This should never happen - return executors.NodeStatusSuccess, nil - } - // case v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseRunning, v1alpha1.NodePhaseRetryableFailure: logger.Debugf(ctx, "node executing, current phase [%s]", currentPhase) defer logger.Debugf(ctx, "node execution completed") @@ -495,33 +495,39 @@ func (c *nodeExecutor) handleDownstream(ctx context.Context, w v1alpha1.Executab return executors.NodeStatusPending, nil } -func (c *nodeExecutor) SetInputsForStartNode(ctx context.Context, w v1alpha1.BaseWorkflowWithStatus, inputs *core.LiteralMap) (executors.NodeStatus, error) { +func (c *nodeExecutor) SetInputsForStartNode(ctx context.Context, w v1alpha1.ExecutableWorkflow, inputs *core.LiteralMap) (executors.NodeStatus, error) { startNode := w.StartNode() if startNode == nil { return executors.NodeStatusFailed(errors.Errorf(errors.BadSpecificationError, v1alpha1.StartNodeID, "Start node not found")), nil } + ctx = contextutils.WithNodeID(ctx, startNode.GetID()) if inputs == nil { logger.Infof(ctx, "No inputs for the workflow. Skipping storing inputs") return executors.NodeStatusComplete, nil } + // StartNode is special. It does not have any processing step. It just takes the workflow (or subworkflow) inputs and converts to its own outputs - nodeStatus := w.GetNodeExecutionStatus(startNode.GetID()) + nodeStatus := w.GetNodeExecutionStatus(ctx, startNode.GetID()) + if nodeStatus.GetDataDir() == "" { return executors.NodeStatusUndefined, errors.Errorf(errors.IllegalStateError, startNode.GetID(), "no data-dir set, cannot store inputs") } + outputFile := v1alpha1.GetOutputsFile(nodeStatus.GetDataDir()) so := storage.Options{} if err := c.store.WriteProtobuf(ctx, outputFile, so, inputs); err != nil { logger.Errorf(ctx, "Failed to write protobuf (metadata). Error [%v]", err) return executors.NodeStatusUndefined, errors.Wrapf(errors.CausedByError, startNode.GetID(), err, "Failed to store workflow inputs (as start node)") } + return executors.NodeStatusComplete, nil } func (c *nodeExecutor) RecursiveNodeHandler(ctx context.Context, w v1alpha1.ExecutableWorkflow, currentNode v1alpha1.ExecutableNode) (executors.NodeStatus, error) { currentNodeCtx := contextutils.WithNodeID(ctx, currentNode.GetID()) - nodeStatus := w.GetNodeExecutionStatus(currentNode.GetID()) + nodeStatus := w.GetNodeExecutionStatus(ctx, currentNode.GetID()) + switch nodeStatus.GetPhase() { case v1alpha1.NodePhaseNotYetStarted, v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseRunning, v1alpha1.NodePhaseFailing, v1alpha1.NodePhaseTimingOut, v1alpha1.NodePhaseRetryableFailure, v1alpha1.NodePhaseSucceeding: logger.Debugf(currentNodeCtx, "Handling node Status [%v]", nodeStatus.GetPhase().String()) @@ -546,11 +552,11 @@ func (c *nodeExecutor) RecursiveNodeHandler(ctx context.Context, w v1alpha1.Exec } func (c *nodeExecutor) FinalizeHandler(ctx context.Context, w v1alpha1.ExecutableWorkflow, currentNode v1alpha1.ExecutableNode) error { - nodeStatus := w.GetNodeExecutionStatus(currentNode.GetID()) + nodeStatus := w.GetExecutionStatus().GetNodeExecutionStatus(ctx, currentNode.GetID()) + switch nodeStatus.GetPhase() { case v1alpha1.NodePhaseFailing, v1alpha1.NodePhaseSucceeding, v1alpha1.NodePhaseRetryableFailure: ctx = contextutils.WithNodeID(ctx, currentNode.GetID()) - nodeStatus := w.GetNodeExecutionStatus(currentNode.GetID()) // Now depending on the node type decide h, err := c.nodeHandlerFactory.GetHandler(currentNode.GetKind()) @@ -599,11 +605,11 @@ func (c *nodeExecutor) FinalizeHandler(ctx context.Context, w v1alpha1.Executabl } func (c *nodeExecutor) AbortHandler(ctx context.Context, w v1alpha1.ExecutableWorkflow, currentNode v1alpha1.ExecutableNode, reason string) error { - nodeStatus := w.GetNodeExecutionStatus(currentNode.GetID()) + nodeStatus := w.GetExecutionStatus().GetNodeExecutionStatus(ctx, currentNode.GetID()) + switch nodeStatus.GetPhase() { case v1alpha1.NodePhaseRunning, v1alpha1.NodePhaseFailing, v1alpha1.NodePhaseSucceeding, v1alpha1.NodePhaseRetryableFailure, v1alpha1.NodePhaseQueued: ctx = contextutils.WithNodeID(ctx, currentNode.GetID()) - nodeStatus := w.GetNodeExecutionStatus(currentNode.GetID()) // Now depending on the node type decide h, err := c.nodeHandlerFactory.GetHandler(currentNode.GetKind()) diff --git a/pkg/controller/nodes/executor_test.go b/pkg/controller/nodes/executor_test.go index e4a42b01c1..00c9d310df 100644 --- a/pkg/controller/nodes/executor_test.go +++ b/pkg/controller/nodes/executor_test.go @@ -56,7 +56,7 @@ func TestSetInputsForStartNode(t *testing.T) { } t.Run("NoInputs", func(t *testing.T) { - w := createDummyBaseWorkflow() + w := createDummyBaseWorkflow(mockStorage) w.DummyStartNode = &v1alpha1.NodeSpec{ ID: v1alpha1.StartNodeID, } @@ -66,8 +66,8 @@ func TestSetInputsForStartNode(t *testing.T) { }) t.Run("WithInputs", func(t *testing.T) { - w := createDummyBaseWorkflow() - w.GetNodeExecutionStatus(v1alpha1.StartNodeID).SetDataDir("s3://test-bucket/exec/start-node/data") + w := createDummyBaseWorkflow(mockStorage) + w.GetNodeExecutionStatus(ctx, v1alpha1.StartNodeID).SetDataDir("s3://test-bucket/exec/start-node/data") w.DummyStartNode = &v1alpha1.NodeSpec{ ID: v1alpha1.StartNodeID, } @@ -81,7 +81,7 @@ func TestSetInputsForStartNode(t *testing.T) { }) t.Run("DataDirNotSet", func(t *testing.T) { - w := createDummyBaseWorkflow() + w := createDummyBaseWorkflow(mockStorage) w.DummyStartNode = &v1alpha1.NodeSpec{ ID: v1alpha1.StartNodeID, } @@ -94,8 +94,8 @@ func TestSetInputsForStartNode(t *testing.T) { execFail, err := NewExecutor(ctx, config.GetConfig().DefaultDeadlines, failStorage, enQWf, events.NewMockEventSink(), launchplan.NewFailFastLaunchPlanExecutor(), 10, fakeKubeClient, catalogClient, promutils.NewTestScope()) assert.NoError(t, err) t.Run("StorageFailure", func(t *testing.T) { - w := createDummyBaseWorkflow() - w.GetNodeExecutionStatus(v1alpha1.StartNodeID).SetDataDir("s3://test-bucket/exec/start-node/data") + w := createDummyBaseWorkflow(mockStorage) + w.GetNodeExecutionStatus(ctx, v1alpha1.StartNodeID).SetDataDir("s3://test-bucket/exec/start-node/data") w.DummyStartNode = &v1alpha1.NodeSpec{ ID: v1alpha1.StartNodeID, } @@ -184,6 +184,7 @@ func TestNodeExecutor_RecursiveNodeHandler_RecurseStartNodes(t *testing.T) { }, }, }, + DataReferenceConstructor: store, }, startNode, startNodeStatus } @@ -280,6 +281,7 @@ func TestNodeExecutor_RecursiveNodeHandler_RecurseEndNode(t *testing.T) { }, }, }, + DataReferenceConstructor: store, }, n, ns } @@ -362,6 +364,7 @@ func TestNodeExecutor_RecursiveNodeHandler_RecurseEndNode(t *testing.T) { }, }, }, + DataReferenceConstructor: store, }, n, ns } @@ -406,7 +409,7 @@ func TestNodeExecutor_RecursiveNodeHandler_RecurseEndNode(t *testing.T) { mockWf, _, mockNodeStatus := createSingleNodeWf(test.currentNodePhase, 0) startNode := mockWf.StartNode() - startStatus := mockWf.GetNodeExecutionStatus(startNode.GetID()) + startStatus := mockWf.GetNodeExecutionStatus(ctx, startNode.GetID()) assert.Equal(t, v1alpha1.NodePhaseSucceeded, startStatus.GetPhase()) s, err := exec.RecursiveNodeHandler(ctx, mockWf, startNode) if test.expectedError { @@ -431,6 +434,7 @@ func TestNodeExecutor_RecursiveNodeHandler_Recurse(t *testing.T) { defaultNodeID := "n1" taskID := taskID + store := createInmemoryDataStore(t, promutils.NewTestScope()) createSingleNodeWf := func(p v1alpha1.NodePhase, maxAttempts int) (v1alpha1.ExecutableWorkflow, v1alpha1.ExecutableNode, v1alpha1.ExecutableNodeStatus) { n := &v1alpha1.NodeSpec{ ID: defaultNodeID, @@ -478,6 +482,7 @@ func TestNodeExecutor_RecursiveNodeHandler_Recurse(t *testing.T) { }, }, }, + DataReferenceConstructor: store, }, n, ns } @@ -540,8 +545,8 @@ func TestNodeExecutor_RecursiveNodeHandler_Recurse(t *testing.T) { mockWf := &mocks.ExecutableWorkflow{} mockWf.On("StartNode").Return(mockNodeN0) mockWf.On("GetNode", nodeN2).Return(mockNode, true) - mockWf.On("GetNodeExecutionStatus", nodeN0).Return(mockN0Status) - mockWf.On("GetNodeExecutionStatus", nodeN2).Return(mockN2Status) + mockWf.OnGetNodeExecutionStatusMatch(mock.Anything, nodeN0).Return(mockN0Status) + mockWf.OnGetNodeExecutionStatusMatch(mock.Anything, nodeN2).Return(mockN2Status) mockWf.On("GetConnections").Return(connections) mockWf.On("GetID").Return("w1") mockWf.On("FromNode", nodeN0).Return([]string{nodeN2}, nil) @@ -687,7 +692,7 @@ func TestNodeExecutor_RecursiveNodeHandler_Recurse(t *testing.T) { mockWf, _, mockNodeStatus := createSingleNodeWf(test.currentNodePhase, 0) startNode := mockWf.StartNode() - startStatus := mockWf.GetNodeExecutionStatus(startNode.GetID()) + startStatus := mockWf.GetNodeExecutionStatus(ctx, startNode.GetID()) assert.Equal(t, v1alpha1.NodePhaseSucceeded, startStatus.GetPhase()) s, err := exec.RecursiveNodeHandler(ctx, mockWf, startNode) if test.expectedError { @@ -915,6 +920,7 @@ func TestNodeExecutor_RecursiveNodeHandler_NoDownstream(t *testing.T) { }, }, }, + DataReferenceConstructor: store, }, n, ns } @@ -1012,6 +1018,7 @@ func TestNodeExecutor_RecursiveNodeHandler_UpstreamNotReady(t *testing.T) { }, }, }, + DataReferenceConstructor: store, }, n, ns } diff --git a/pkg/controller/nodes/handler/mocks/node.go b/pkg/controller/nodes/handler/mocks/node.go index 22026d8fb6..dcd1ae51ea 100644 --- a/pkg/controller/nodes/handler/mocks/node.go +++ b/pkg/controller/nodes/handler/mocks/node.go @@ -2,9 +2,12 @@ package mocks -import context "context" -import handler "github.com/lyft/flytepropeller/pkg/controller/nodes/handler" -import mock "github.com/stretchr/testify/mock" +import ( + context "context" + + handler "github.com/lyft/flytepropeller/pkg/controller/nodes/handler" + mock "github.com/stretchr/testify/mock" +) // Node is an autogenerated mock type for the Node type type Node struct { diff --git a/pkg/controller/nodes/handler/mocks/node_execution_context.go b/pkg/controller/nodes/handler/mocks/node_execution_context.go index 92ec92bc8b..db5219d400 100644 --- a/pkg/controller/nodes/handler/mocks/node_execution_context.go +++ b/pkg/controller/nodes/handler/mocks/node_execution_context.go @@ -2,12 +2,17 @@ package mocks -import events "github.com/lyft/flyteidl/clients/go/events" -import handler "github.com/lyft/flytepropeller/pkg/controller/nodes/handler" -import io "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/io" -import mock "github.com/stretchr/testify/mock" -import storage "github.com/lyft/flytestdlib/storage" -import v1alpha1 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" +import ( + events "github.com/lyft/flyteidl/clients/go/events" + io "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/io" + handler "github.com/lyft/flytepropeller/pkg/controller/nodes/handler" + + mock "github.com/stretchr/testify/mock" + + storage "github.com/lyft/flytestdlib/storage" + + v1alpha1 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" +) // NodeExecutionContext is an autogenerated mock type for the NodeExecutionContext type type NodeExecutionContext struct { diff --git a/pkg/controller/nodes/handler/mocks/node_execution_metadata.go b/pkg/controller/nodes/handler/mocks/node_execution_metadata.go index 6418701a7c..44ac549054 100644 --- a/pkg/controller/nodes/handler/mocks/node_execution_metadata.go +++ b/pkg/controller/nodes/handler/mocks/node_execution_metadata.go @@ -2,10 +2,14 @@ package mocks -import mock "github.com/stretchr/testify/mock" -import types "k8s.io/apimachinery/pkg/types" -import v1 "k8s.io/apimachinery/pkg/apis/meta/v1" -import v1alpha1 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" +import ( + mock "github.com/stretchr/testify/mock" + types "k8s.io/apimachinery/pkg/types" + + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + v1alpha1 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" +) // NodeExecutionMetadata is an autogenerated mock type for the NodeExecutionMetadata type type NodeExecutionMetadata struct { diff --git a/pkg/controller/nodes/handler/mocks/node_state_reader.go b/pkg/controller/nodes/handler/mocks/node_state_reader.go index 64316cd2e1..85987c4820 100644 --- a/pkg/controller/nodes/handler/mocks/node_state_reader.go +++ b/pkg/controller/nodes/handler/mocks/node_state_reader.go @@ -2,8 +2,10 @@ package mocks -import handler "github.com/lyft/flytepropeller/pkg/controller/nodes/handler" -import mock "github.com/stretchr/testify/mock" +import ( + handler "github.com/lyft/flytepropeller/pkg/controller/nodes/handler" + mock "github.com/stretchr/testify/mock" +) // NodeStateReader is an autogenerated mock type for the NodeStateReader type type NodeStateReader struct { diff --git a/pkg/controller/nodes/handler/mocks/node_state_writer.go b/pkg/controller/nodes/handler/mocks/node_state_writer.go index bc8163c13b..42d051ba45 100644 --- a/pkg/controller/nodes/handler/mocks/node_state_writer.go +++ b/pkg/controller/nodes/handler/mocks/node_state_writer.go @@ -2,8 +2,10 @@ package mocks -import handler "github.com/lyft/flytepropeller/pkg/controller/nodes/handler" -import mock "github.com/stretchr/testify/mock" +import ( + handler "github.com/lyft/flytepropeller/pkg/controller/nodes/handler" + mock "github.com/stretchr/testify/mock" +) // NodeStateWriter is an autogenerated mock type for the NodeStateWriter type type NodeStateWriter struct { diff --git a/pkg/controller/nodes/handler/mocks/setup_context.go b/pkg/controller/nodes/handler/mocks/setup_context.go index 89f3959365..bbc2031757 100644 --- a/pkg/controller/nodes/handler/mocks/setup_context.go +++ b/pkg/controller/nodes/handler/mocks/setup_context.go @@ -2,8 +2,10 @@ package mocks -import mock "github.com/stretchr/testify/mock" -import promutils "github.com/lyft/flytestdlib/promutils" +import ( + promutils "github.com/lyft/flytestdlib/promutils" + mock "github.com/stretchr/testify/mock" +) // SetupContext is an autogenerated mock type for the SetupContext type type SetupContext struct { diff --git a/pkg/controller/nodes/handler/mocks/task_reader.go b/pkg/controller/nodes/handler/mocks/task_reader.go index 94bdb0a915..71d721ca31 100644 --- a/pkg/controller/nodes/handler/mocks/task_reader.go +++ b/pkg/controller/nodes/handler/mocks/task_reader.go @@ -2,10 +2,13 @@ package mocks -import context "context" -import core "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" +import ( + context "context" -import mock "github.com/stretchr/testify/mock" + core "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + + mock "github.com/stretchr/testify/mock" +) // TaskReader is an autogenerated mock type for the TaskReader type type TaskReader struct { diff --git a/pkg/controller/nodes/mocks/handler_factory.go b/pkg/controller/nodes/mocks/handler_factory.go index d8a0c9a6d2..3b7bc6e0e8 100644 --- a/pkg/controller/nodes/mocks/handler_factory.go +++ b/pkg/controller/nodes/mocks/handler_factory.go @@ -2,11 +2,14 @@ package mocks -import context "context" -import handler "github.com/lyft/flytepropeller/pkg/controller/nodes/handler" -import mock "github.com/stretchr/testify/mock" +import ( + context "context" -import v1alpha1 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + handler "github.com/lyft/flytepropeller/pkg/controller/nodes/handler" + mock "github.com/stretchr/testify/mock" + + v1alpha1 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" +) // HandlerFactory is an autogenerated mock type for the HandlerFactory type type HandlerFactory struct { diff --git a/pkg/controller/nodes/mocks/output_resolver.go b/pkg/controller/nodes/mocks/output_resolver.go index 46333b52d5..061e876877 100644 --- a/pkg/controller/nodes/mocks/output_resolver.go +++ b/pkg/controller/nodes/mocks/output_resolver.go @@ -2,11 +2,14 @@ package mocks -import context "context" -import core "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" -import mock "github.com/stretchr/testify/mock" +import ( + context "context" -import v1alpha1 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + core "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + mock "github.com/stretchr/testify/mock" + + v1alpha1 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" +) // OutputResolver is an autogenerated mock type for the OutputResolver type type OutputResolver struct { @@ -21,7 +24,7 @@ func (_m OutputResolver_ExtractOutput) Return(values *core.Literal, err error) * return &OutputResolver_ExtractOutput{Call: _m.Call.Return(values, err)} } -func (_m *OutputResolver) OnExtractOutput(ctx context.Context, w v1alpha1.ExecutableWorkflow, n v1alpha1.ExecutableNode, bindToVar string) *OutputResolver_ExtractOutput { +func (_m *OutputResolver) OnExtractOutput(ctx context.Context, w v1alpha1.BaseWorkflowWithStatus, n v1alpha1.ExecutableNode, bindToVar string) *OutputResolver_ExtractOutput { c := _m.On("ExtractOutput", ctx, w, n, bindToVar) return &OutputResolver_ExtractOutput{Call: c} } @@ -32,11 +35,11 @@ func (_m *OutputResolver) OnExtractOutputMatch(matchers ...interface{}) *OutputR } // ExtractOutput provides a mock function with given fields: ctx, w, n, bindToVar -func (_m *OutputResolver) ExtractOutput(ctx context.Context, w v1alpha1.ExecutableWorkflow, n v1alpha1.ExecutableNode, bindToVar string) (*core.Literal, error) { +func (_m *OutputResolver) ExtractOutput(ctx context.Context, w v1alpha1.BaseWorkflowWithStatus, n v1alpha1.ExecutableNode, bindToVar string) (*core.Literal, error) { ret := _m.Called(ctx, w, n, bindToVar) var r0 *core.Literal - if rf, ok := ret.Get(0).(func(context.Context, v1alpha1.ExecutableWorkflow, v1alpha1.ExecutableNode, string) *core.Literal); ok { + if rf, ok := ret.Get(0).(func(context.Context, v1alpha1.BaseWorkflowWithStatus, v1alpha1.ExecutableNode, string) *core.Literal); ok { r0 = rf(ctx, w, n, bindToVar) } else { if ret.Get(0) != nil { @@ -45,7 +48,7 @@ func (_m *OutputResolver) ExtractOutput(ctx context.Context, w v1alpha1.Executab } var r1 error - if rf, ok := ret.Get(1).(func(context.Context, v1alpha1.ExecutableWorkflow, v1alpha1.ExecutableNode, string) error); ok { + if rf, ok := ret.Get(1).(func(context.Context, v1alpha1.BaseWorkflowWithStatus, v1alpha1.ExecutableNode, string) error); ok { r1 = rf(ctx, w, n, bindToVar) } else { r1 = ret.Error(1) diff --git a/pkg/controller/nodes/node_state_manager.go b/pkg/controller/nodes/node_state_manager.go index bb3b79feab..b8ec926a7e 100644 --- a/pkg/controller/nodes/node_state_manager.go +++ b/pkg/controller/nodes/node_state_manager.go @@ -85,6 +85,8 @@ func (n nodeStateManager) clearNodeStatus() { n.t = nil n.b = nil n.d = nil + n.w = nil + n.nodeStatus.ClearLastAttemptStartedAt() } func newNodeStateManager(_ context.Context, status v1alpha1.ExecutableNodeStatus) *nodeStateManager { diff --git a/pkg/controller/nodes/output_resolver.go b/pkg/controller/nodes/output_resolver.go index bb33e341c0..d14815816b 100644 --- a/pkg/controller/nodes/output_resolver.go +++ b/pkg/controller/nodes/output_resolver.go @@ -20,7 +20,7 @@ type VarName = string type OutputResolver interface { // Extracts a subset of node outputs to literals. - ExtractOutput(ctx context.Context, w v1alpha1.ExecutableWorkflow, n v1alpha1.ExecutableNode, + ExtractOutput(ctx context.Context, w v1alpha1.BaseWorkflowWithStatus, n v1alpha1.ExecutableNode, bindToVar VarName) (values *core.Literal, err error) } @@ -37,9 +37,9 @@ type remoteFileOutputResolver struct { store *storage.DataStore } -func (r remoteFileOutputResolver) ExtractOutput(ctx context.Context, w v1alpha1.ExecutableWorkflow, n v1alpha1.ExecutableNode, +func (r remoteFileOutputResolver) ExtractOutput(ctx context.Context, w v1alpha1.BaseWorkflowWithStatus, n v1alpha1.ExecutableNode, bindToVar VarName) (values *core.Literal, err error) { - nodeStatus := w.GetNodeExecutionStatus(n.GetID()) + nodeStatus := w.GetNodeExecutionStatus(ctx, n.GetID()) outputsFileRef := v1alpha1.GetOutputsFile(nodeStatus.GetDataDir()) index, actualVar, err := ParseVarName(bindToVar) diff --git a/pkg/controller/nodes/predicate.go b/pkg/controller/nodes/predicate.go index 216a254285..499df73963 100644 --- a/pkg/controller/nodes/predicate.go +++ b/pkg/controller/nodes/predicate.go @@ -43,15 +43,17 @@ func CanExecute(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha logger.Debugf(ctx, "Start Node id is assumed to be ready.") return PredicatePhaseReady, nil } - nodeStatus := w.GetNodeExecutionStatus(nodeID) + + nodeStatus := w.GetNodeExecutionStatus(ctx, nodeID) parentNodeID := nodeStatus.GetParentNodeID() upstreamNodes, ok := w.GetConnections().UpstreamEdges[nodeID] if !ok { return PredicatePhaseUndefined, errors.Errorf(errors.BadSpecificationError, nodeID, "Unable to find upstream nodes for Node") } + skipped := false for _, upstreamNodeID := range upstreamNodes { - upstreamNodeStatus := w.GetNodeExecutionStatus(upstreamNodeID) + upstreamNodeStatus := w.GetNodeExecutionStatus(ctx, upstreamNodeID) if upstreamNodeStatus.IsDirty() { return PredicatePhaseNotReady, nil @@ -62,11 +64,13 @@ func CanExecute(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha if !ok { return PredicatePhaseUndefined, errors.Errorf(errors.BadSpecificationError, nodeID, "Upstream node [%v] of node [%v] not defined", upstreamNodeID, nodeID) } + // This only happens if current node is the child node of a branch node if upstreamNode.GetBranchNode() == nil || upstreamNodeStatus.GetOrCreateBranchStatus().GetPhase() != v1alpha1.BranchNodeSuccess { logger.Debugf(ctx, "Branch sub node is expected to have parent branch node in succeeded state") return PredicatePhaseUndefined, errors.Errorf(errors.IllegalStateError, nodeID, "Upstream node [%v] is set as parent, but is not a branch node of [%v] or in illegal state.", upstreamNodeID, nodeID) } + continue } @@ -76,9 +80,11 @@ func CanExecute(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha return PredicatePhaseNotReady, nil } } + if skipped { return PredicatePhaseSkip, nil } + return PredicatePhaseReady, nil } @@ -90,7 +96,7 @@ func GetParentNodeMaxEndTime(ctx context.Context, w v1alpha1.ExecutableWorkflow, return zeroTime, nil } - nodeStatus := w.GetNodeExecutionStatus(node.GetID()) + nodeStatus := w.GetNodeExecutionStatus(ctx, node.GetID()) parentNodeID := nodeStatus.GetParentNodeID() upstreamNodes, ok := w.GetConnections().UpstreamEdges[nodeID] if !ok { @@ -99,7 +105,7 @@ func GetParentNodeMaxEndTime(ctx context.Context, w v1alpha1.ExecutableWorkflow, var latest v1.Time for _, upstreamNodeID := range upstreamNodes { - upstreamNodeStatus := w.GetNodeExecutionStatus(upstreamNodeID) + upstreamNodeStatus := w.GetNodeExecutionStatus(ctx, upstreamNodeID) if parentNodeID != nil && *parentNodeID == upstreamNodeID { upstreamNode, ok := w.GetNode(upstreamNodeID) if !ok { diff --git a/pkg/controller/nodes/predicate_test.go b/pkg/controller/nodes/predicate_test.go index ae151010b1..3519936062 100644 --- a/pkg/controller/nodes/predicate_test.go +++ b/pkg/controller/nodes/predicate_test.go @@ -38,7 +38,7 @@ func TestCanExecute(t *testing.T) { mockNode := &mocks.BaseNode{} mockNode.On("GetID").Return(nodeN2) mockWf := &mocks.ExecutableWorkflow{} - mockWf.On("GetNodeExecutionStatus", nodeN2).Return(mockNodeStatus) + mockWf.OnGetNodeExecutionStatus(ctx, nodeN2).Return(mockNodeStatus) mockWf.On("GetConnections").Return(&v1alpha1.Connections{}) mockWf.On("GetID").Return("w1") @@ -65,9 +65,9 @@ func TestCanExecute(t *testing.T) { mockN1Status.On("IsDirty").Return(false) mockWf := &mocks.ExecutableWorkflow{} - mockWf.On("GetNodeExecutionStatus", nodeN0).Return(mockN0Status) - mockWf.On("GetNodeExecutionStatus", nodeN1).Return(mockN1Status) - mockWf.On("GetNodeExecutionStatus", nodeN2).Return(mockN2Status) + mockWf.OnGetNodeExecutionStatus(ctx, nodeN0).Return(mockN0Status) + mockWf.OnGetNodeExecutionStatus(ctx, nodeN1).Return(mockN1Status) + mockWf.OnGetNodeExecutionStatus(ctx, nodeN2).Return(mockN2Status) mockWf.On("GetConnections").Return(connections) mockWf.On("GetID").Return("w1") @@ -95,9 +95,9 @@ func TestCanExecute(t *testing.T) { mockN1Status.On("IsDirty").Return(false) mockWf := &mocks.ExecutableWorkflow{} - mockWf.On("GetNodeExecutionStatus", nodeN0).Return(mockN0Status) - mockWf.On("GetNodeExecutionStatus", nodeN1).Return(mockN1Status) - mockWf.On("GetNodeExecutionStatus", nodeN2).Return(mockN2Status) + mockWf.OnGetNodeExecutionStatus(ctx, nodeN0).Return(mockN0Status) + mockWf.OnGetNodeExecutionStatus(ctx, nodeN1).Return(mockN1Status) + mockWf.OnGetNodeExecutionStatus(ctx, nodeN2).Return(mockN2Status) mockWf.On("GetConnections").Return(connections) mockWf.On("GetID").Return("w1") @@ -125,9 +125,9 @@ func TestCanExecute(t *testing.T) { mockN1Status.On("IsDirty").Return(false) mockWf := &mocks.ExecutableWorkflow{} - mockWf.On("GetNodeExecutionStatus", nodeN0).Return(mockN0Status) - mockWf.On("GetNodeExecutionStatus", nodeN1).Return(mockN1Status) - mockWf.On("GetNodeExecutionStatus", nodeN2).Return(mockN2Status) + mockWf.OnGetNodeExecutionStatus(ctx, nodeN0).Return(mockN0Status) + mockWf.OnGetNodeExecutionStatus(ctx, nodeN1).Return(mockN1Status) + mockWf.OnGetNodeExecutionStatus(ctx, nodeN2).Return(mockN2Status) mockWf.On("GetConnections").Return(connections) mockWf.On("GetID").Return("w1") @@ -155,9 +155,9 @@ func TestCanExecute(t *testing.T) { mockN1Status.On("IsDirty").Return(true) mockWf := &mocks.ExecutableWorkflow{} - mockWf.On("GetNodeExecutionStatus", nodeN0).Return(mockN0Status) - mockWf.On("GetNodeExecutionStatus", nodeN1).Return(mockN1Status) - mockWf.On("GetNodeExecutionStatus", nodeN2).Return(mockN2Status) + mockWf.OnGetNodeExecutionStatus(ctx, nodeN0).Return(mockN0Status) + mockWf.OnGetNodeExecutionStatus(ctx, nodeN1).Return(mockN1Status) + mockWf.OnGetNodeExecutionStatus(ctx, nodeN2).Return(mockN2Status) mockWf.On("GetConnections").Return(connections) mockWf.On("GetID").Return("w1") @@ -185,9 +185,9 @@ func TestCanExecute(t *testing.T) { mockN1Status.On("IsDirty").Return(false) mockWf := &mocks.ExecutableWorkflow{} - mockWf.On("GetNodeExecutionStatus", nodeN0).Return(mockN0Status) - mockWf.On("GetNodeExecutionStatus", nodeN1).Return(mockN1Status) - mockWf.On("GetNodeExecutionStatus", nodeN2).Return(mockN2Status) + mockWf.OnGetNodeExecutionStatus(ctx, nodeN0).Return(mockN0Status) + mockWf.OnGetNodeExecutionStatus(ctx, nodeN1).Return(mockN1Status) + mockWf.OnGetNodeExecutionStatus(ctx, nodeN2).Return(mockN2Status) mockWf.On("GetConnections").Return(connections) mockWf.On("GetID").Return("w1") @@ -215,9 +215,9 @@ func TestCanExecute(t *testing.T) { mockN1Status.On("IsDirty").Return(false) mockWf := &mocks.ExecutableWorkflow{} - mockWf.On("GetNodeExecutionStatus", nodeN0).Return(mockN0Status) - mockWf.On("GetNodeExecutionStatus", nodeN1).Return(mockN1Status) - mockWf.On("GetNodeExecutionStatus", nodeN2).Return(mockN2Status) + mockWf.OnGetNodeExecutionStatus(ctx, nodeN0).Return(mockN0Status) + mockWf.OnGetNodeExecutionStatus(ctx, nodeN1).Return(mockN1Status) + mockWf.OnGetNodeExecutionStatus(ctx, nodeN2).Return(mockN2Status) mockWf.On("GetConnections").Return(connections) mockWf.On("GetID").Return("w1") @@ -245,9 +245,9 @@ func TestCanExecute(t *testing.T) { mockN1Status.On("IsDirty").Return(false) mockWf := &mocks.ExecutableWorkflow{} - mockWf.On("GetNodeExecutionStatus", nodeN0).Return(mockN0Status) - mockWf.On("GetNodeExecutionStatus", nodeN1).Return(mockN1Status) - mockWf.On("GetNodeExecutionStatus", nodeN2).Return(mockN2Status) + mockWf.OnGetNodeExecutionStatus(ctx, nodeN0).Return(mockN0Status) + mockWf.OnGetNodeExecutionStatus(ctx, nodeN1).Return(mockN1Status) + mockWf.OnGetNodeExecutionStatus(ctx, nodeN2).Return(mockN2Status) mockWf.On("GetConnections").Return(connections) mockWf.On("GetID").Return("w1") @@ -276,9 +276,9 @@ func TestCanExecute(t *testing.T) { mockN1Status.On("IsDirty").Return(false) mockWf := &mocks.ExecutableWorkflow{} - mockWf.On("GetNodeExecutionStatus", nodeN0).Return(mockN0Status) - mockWf.On("GetNodeExecutionStatus", nodeN1).Return(mockN1Status) - mockWf.On("GetNodeExecutionStatus", nodeN2).Return(mockN2Status) + mockWf.OnGetNodeExecutionStatus(ctx, nodeN0).Return(mockN0Status) + mockWf.OnGetNodeExecutionStatus(ctx, nodeN1).Return(mockN1Status) + mockWf.OnGetNodeExecutionStatus(ctx, nodeN2).Return(mockN2Status) mockWf.On("GetConnections").Return(connections) mockWf.On("GetID").Return("w1") @@ -309,9 +309,9 @@ func TestCanExecute(t *testing.T) { mockN1Status.On("IsDirty").Return(false) mockWf := &mocks.ExecutableWorkflow{} - mockWf.On("GetNodeExecutionStatus", nodeN0).Return(mockN0Status) - mockWf.On("GetNodeExecutionStatus", nodeN1).Return(mockN1Status) - mockWf.On("GetNodeExecutionStatus", nodeN2).Return(mockN2Status) + mockWf.OnGetNodeExecutionStatus(ctx, nodeN0).Return(mockN0Status) + mockWf.OnGetNodeExecutionStatus(ctx, nodeN1).Return(mockN1Status) + mockWf.OnGetNodeExecutionStatus(ctx, nodeN2).Return(mockN2Status) mockWf.On("GetConnections").Return(connections) mockWf.On("GetNode", nodeN0).Return(nil, false) mockWf.On("GetID").Return("w1") @@ -343,9 +343,9 @@ func TestCanExecute(t *testing.T) { mockN1Status.On("IsDirty").Return(false) mockWf := &mocks.ExecutableWorkflow{} - mockWf.On("GetNodeExecutionStatus", nodeN0).Return(mockN0Status) - mockWf.On("GetNodeExecutionStatus", nodeN1).Return(mockN1Status) - mockWf.On("GetNodeExecutionStatus", nodeN2).Return(mockN2Status) + mockWf.OnGetNodeExecutionStatus(ctx, nodeN0).Return(mockN0Status) + mockWf.OnGetNodeExecutionStatus(ctx, nodeN1).Return(mockN1Status) + mockWf.OnGetNodeExecutionStatus(ctx, nodeN2).Return(mockN2Status) mockWf.On("GetConnections").Return(connections) mockWf.On("GetNode", nodeN0).Return(mockN0Node, true) mockWf.On("GetID").Return("w1") @@ -381,9 +381,9 @@ func TestCanExecute(t *testing.T) { mockN1Status.On("IsDirty").Return(false) mockWf := &mocks.ExecutableWorkflow{} - mockWf.On("GetNodeExecutionStatus", nodeN0).Return(mockN0Status) - mockWf.On("GetNodeExecutionStatus", nodeN1).Return(mockN1Status) - mockWf.On("GetNodeExecutionStatus", nodeN2).Return(mockN2Status) + mockWf.OnGetNodeExecutionStatus(ctx, nodeN0).Return(mockN0Status) + mockWf.OnGetNodeExecutionStatus(ctx, nodeN1).Return(mockN1Status) + mockWf.OnGetNodeExecutionStatus(ctx, nodeN2).Return(mockN2Status) mockWf.On("GetConnections").Return(connections) mockWf.On("GetNode", nodeN0).Return(mockN0Node, true) mockWf.On("GetID").Return("w1") @@ -419,9 +419,9 @@ func TestCanExecute(t *testing.T) { mockN1Status.On("IsDirty").Return(false) mockWf := &mocks.ExecutableWorkflow{} - mockWf.On("GetNodeExecutionStatus", nodeN0).Return(mockN0Status) - mockWf.On("GetNodeExecutionStatus", nodeN1).Return(mockN1Status) - mockWf.On("GetNodeExecutionStatus", nodeN2).Return(mockN2Status) + mockWf.OnGetNodeExecutionStatus(ctx, nodeN0).Return(mockN0Status) + mockWf.OnGetNodeExecutionStatus(ctx, nodeN1).Return(mockN1Status) + mockWf.OnGetNodeExecutionStatus(ctx, nodeN2).Return(mockN2Status) mockWf.On("GetConnections").Return(connections) mockWf.On("GetNode", nodeN0).Return(mockN0Node, true) mockWf.On("GetID").Return("w1") @@ -458,9 +458,9 @@ func TestCanExecute(t *testing.T) { mockN1Status.On("IsDirty").Return(false) mockWf := &mocks.ExecutableWorkflow{} - mockWf.On("GetNodeExecutionStatus", nodeN0).Return(mockN0Status) - mockWf.On("GetNodeExecutionStatus", nodeN1).Return(mockN1Status) - mockWf.On("GetNodeExecutionStatus", nodeN2).Return(mockN2Status) + mockWf.OnGetNodeExecutionStatus(ctx, nodeN0).Return(mockN0Status) + mockWf.OnGetNodeExecutionStatus(ctx, nodeN1).Return(mockN1Status) + mockWf.OnGetNodeExecutionStatus(ctx, nodeN2).Return(mockN2Status) mockWf.On("GetConnections").Return(connections) mockWf.On("GetNode", nodeN0).Return(mockN0Node, true) mockWf.On("GetID").Return("w1") @@ -497,9 +497,9 @@ func TestCanExecute(t *testing.T) { mockN1Status.On("IsDirty").Return(false) mockWf := &mocks.ExecutableWorkflow{} - mockWf.On("GetNodeExecutionStatus", nodeN0).Return(mockN0Status) - mockWf.On("GetNodeExecutionStatus", nodeN1).Return(mockN1Status) - mockWf.On("GetNodeExecutionStatus", nodeN2).Return(mockN2Status) + mockWf.OnGetNodeExecutionStatus(ctx, nodeN0).Return(mockN0Status) + mockWf.OnGetNodeExecutionStatus(ctx, nodeN1).Return(mockN1Status) + mockWf.OnGetNodeExecutionStatus(ctx, nodeN2).Return(mockN2Status) mockWf.On("GetConnections").Return(connections) mockWf.On("GetNode", nodeN0).Return(mockN0Node, true) mockWf.On("GetID").Return("w1") @@ -536,9 +536,9 @@ func TestCanExecute(t *testing.T) { mockN1Status.On("IsDirty").Return(false) mockWf := &mocks.ExecutableWorkflow{} - mockWf.On("GetNodeExecutionStatus", nodeN0).Return(mockN0Status) - mockWf.On("GetNodeExecutionStatus", nodeN1).Return(mockN1Status) - mockWf.On("GetNodeExecutionStatus", nodeN2).Return(mockN2Status) + mockWf.OnGetNodeExecutionStatus(ctx, nodeN0).Return(mockN0Status) + mockWf.OnGetNodeExecutionStatus(ctx, nodeN1).Return(mockN1Status) + mockWf.OnGetNodeExecutionStatus(ctx, nodeN2).Return(mockN2Status) mockWf.On("GetConnections").Return(connections) mockWf.On("GetNode", nodeN0).Return(mockN0Node, true) mockWf.On("GetID").Return("w1") diff --git a/pkg/controller/nodes/resolve.go b/pkg/controller/nodes/resolve.go index 196f6827f5..38b152f4f1 100644 --- a/pkg/controller/nodes/resolve.go +++ b/pkg/controller/nodes/resolve.go @@ -9,7 +9,7 @@ import ( "github.com/lyft/flytepropeller/pkg/controller/nodes/errors" ) -func ResolveBindingData(ctx context.Context, outputResolver OutputResolver, w v1alpha1.ExecutableWorkflow, bindingData *core.BindingData) (*core.Literal, error) { +func ResolveBindingData(ctx context.Context, outputResolver OutputResolver, w v1alpha1.BaseWorkflowWithStatus, bindingData *core.BindingData) (*core.Literal, error) { literal := &core.Literal{} if bindingData == nil { return nil, nil @@ -75,7 +75,7 @@ func ResolveBindingData(ctx context.Context, outputResolver OutputResolver, w v1 return literal, nil } -func Resolve(ctx context.Context, outputResolver OutputResolver, w v1alpha1.ExecutableWorkflow, nodeID v1alpha1.NodeID, bindings []*v1alpha1.Binding) (*core.LiteralMap, error) { +func Resolve(ctx context.Context, outputResolver OutputResolver, w v1alpha1.BaseWorkflowWithStatus, nodeID v1alpha1.NodeID, bindings []*v1alpha1.Binding) (*core.LiteralMap, error) { literalMap := make(map[string]*core.Literal, len(bindings)) for _, binding := range bindings { varName := binding.GetVar() diff --git a/pkg/controller/nodes/resolve_test.go b/pkg/controller/nodes/resolve_test.go index 87a3541e1b..529c252493 100644 --- a/pkg/controller/nodes/resolve_test.go +++ b/pkg/controller/nodes/resolve_test.go @@ -25,6 +25,7 @@ type dummyBaseWorkflow struct { FromNodeCb func(name v1alpha1.NodeID) ([]v1alpha1.NodeID, error) GetNodeCb func(nodeId v1alpha1.NodeID) (v1alpha1.ExecutableNode, bool) Status map[v1alpha1.NodeID]*v1alpha1.NodeStatus + DataStore *storage.DataStore } func (d *dummyBaseWorkflow) GetOutputBindings() []*v1alpha1.Binding { @@ -101,13 +102,17 @@ func (d *dummyBaseWorkflow) GetExecutionStatus() v1alpha1.ExecutableWorkflowStat return nil } -func (d *dummyBaseWorkflow) GetNodeExecutionStatus(id v1alpha1.NodeID) v1alpha1.ExecutableNodeStatus { +func (d *dummyBaseWorkflow) GetNodeExecutionStatus(_ context.Context, id v1alpha1.NodeID) v1alpha1.ExecutableNodeStatus { n, ok := d.Status[id] if ok { + n.DataReferenceConstructor = d.DataStore return n } - n = &v1alpha1.NodeStatus{} + n = &v1alpha1.NodeStatus{ + MutableStruct: v1alpha1.MutableStruct{}, + } d.Status[id] = n + n.DataReferenceConstructor = d.DataStore return n } @@ -127,12 +132,13 @@ func (d *dummyBaseWorkflow) GetNode(nodeID v1alpha1.NodeID) (v1alpha1.Executable return d.GetNodeCb(nodeID) } -func createDummyBaseWorkflow() *dummyBaseWorkflow { +func createDummyBaseWorkflow(dataStore *storage.DataStore) *dummyBaseWorkflow { return &dummyBaseWorkflow{ ID: "w1", Status: map[v1alpha1.NodeID]*v1alpha1.NodeStatus{ v1alpha1.StartNodeID: {}, }, + DataStore: dataStore, } } diff --git a/pkg/controller/nodes/subworkflow/handler.go b/pkg/controller/nodes/subworkflow/handler.go index 235cff5ad1..9a485a466a 100644 --- a/pkg/controller/nodes/subworkflow/handler.go +++ b/pkg/controller/nodes/subworkflow/handler.go @@ -81,7 +81,7 @@ func (w *workflowNodeHandler) Handle(ctx context.Context, nCtx handler.NodeExecu if wfNode.GetSubWorkflowRef() != nil { wf := nCtx.Workflow() - status := wf.GetNodeExecutionStatus(nCtx.NodeID()) + status := wf.GetNodeExecutionStatus(ctx, nCtx.NodeID()) return w.subWfHandler.CheckSubWorkflowStatus(ctx, nCtx, wf, status) } else if wfNode.GetLaunchPlanRefID() != nil { return w.lpHandler.CheckLaunchPlanStatus(ctx, nCtx) diff --git a/pkg/controller/nodes/subworkflow/handler_test.go b/pkg/controller/nodes/subworkflow/handler_test.go index 29b412a506..029fa3f7b0 100644 --- a/pkg/controller/nodes/subworkflow/handler_test.go +++ b/pkg/controller/nodes/subworkflow/handler_test.go @@ -132,7 +132,7 @@ func TestWorkflowNodeHandler_StartNode_Launchplan(t *testing.T) { Project: "z", } mockWf := &mocks2.ExecutableWorkflow{} - mockWf.On("GetNodeExecutionStatus", nodeID).Return(mockNodeStatus) + mockWf.OnGetNodeExecutionStatus(ctx, nodeID).Return(mockNodeStatus) mockWf.On("GetExecutionID").Return(v1alpha1.WorkflowExecutionIdentifier{ WorkflowExecutionIdentifier: parentID, }) @@ -195,7 +195,7 @@ func TestWorkflowNodeHandler_CheckNodeStatus(t *testing.T) { Project: "z", } mockWf := &mocks2.ExecutableWorkflow{} - mockWf.On("GetNodeExecutionStatus", nodeID).Return(mockNodeStatus) + mockWf.OnGetNodeExecutionStatus(ctx, nodeID).Return(mockNodeStatus) mockWf.On("GetExecutionID").Return(v1alpha1.WorkflowExecutionIdentifier{ WorkflowExecutionIdentifier: parentID, }) @@ -255,7 +255,7 @@ func TestWorkflowNodeHandler_AbortNode(t *testing.T) { Project: "z", } mockWf := &mocks2.ExecutableWorkflow{} - mockWf.On("GetNodeExecutionStatus", nodeID).Return(mockNodeStatus) + mockWf.OnGetNodeExecutionStatus(ctx, nodeID).Return(mockNodeStatus) mockWf.On("GetName").Return("test") mockWf.On("GetExecutionID").Return(v1alpha1.WorkflowExecutionIdentifier{ WorkflowExecutionIdentifier: parentID, diff --git a/pkg/controller/nodes/subworkflow/launchplan.go b/pkg/controller/nodes/subworkflow/launchplan.go index c50f03dece..f2fe430736 100644 --- a/pkg/controller/nodes/subworkflow/launchplan.go +++ b/pkg/controller/nodes/subworkflow/launchplan.go @@ -26,7 +26,7 @@ func (l *launchPlanHandler) StartLaunchPlan(ctx context.Context, nCtx handler.No } w := nCtx.Workflow() - nodeStatus := w.GetNodeExecutionStatus(nCtx.NodeID()) + nodeStatus := w.GetNodeExecutionStatus(ctx, nCtx.NodeID()) childID, err := GetChildWorkflowExecutionID( w.GetExecutionID().WorkflowExecutionIdentifier, nCtx.NodeID(), @@ -69,7 +69,7 @@ func (l *launchPlanHandler) CheckLaunchPlanStatus(ctx context.Context, nCtx hand // Handle launch plan w := nCtx.Workflow() - nodeStatus := w.GetNodeExecutionStatus(nCtx.NodeID()) + nodeStatus := w.GetNodeExecutionStatus(ctx, nCtx.NodeID()) childID, err := GetChildWorkflowExecutionID( w.GetExecutionID().WorkflowExecutionIdentifier, nCtx.NodeID(), @@ -150,7 +150,7 @@ func (l *launchPlanHandler) CheckLaunchPlanStatus(ctx context.Context, nCtx hand } func (l *launchPlanHandler) HandleAbort(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode) error { - nodeStatus := w.GetNodeExecutionStatus(node.GetID()) + nodeStatus := w.GetNodeExecutionStatus(ctx, node.GetID()) childID, err := GetChildWorkflowExecutionID( w.GetExecutionID().WorkflowExecutionIdentifier, node.GetID(), diff --git a/pkg/controller/nodes/subworkflow/launchplan/mocks/Executor.go b/pkg/controller/nodes/subworkflow/launchplan/mocks/Executor.go index 346085e169..c85036471e 100644 --- a/pkg/controller/nodes/subworkflow/launchplan/mocks/Executor.go +++ b/pkg/controller/nodes/subworkflow/launchplan/mocks/Executor.go @@ -2,11 +2,17 @@ package mocks -import admin "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" -import context "context" -import core "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" -import launchplan "github.com/lyft/flytepropeller/pkg/controller/nodes/subworkflow/launchplan" -import mock "github.com/stretchr/testify/mock" +import ( + context "context" + + admin "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" + + core "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + + launchplan "github.com/lyft/flytepropeller/pkg/controller/nodes/subworkflow/launchplan" + + mock "github.com/stretchr/testify/mock" +) // Executor is an autogenerated mock type for the Executor type type Executor struct { diff --git a/pkg/controller/nodes/subworkflow/launchplan_test.go b/pkg/controller/nodes/subworkflow/launchplan_test.go index c9772104d3..51f7e5bd72 100644 --- a/pkg/controller/nodes/subworkflow/launchplan_test.go +++ b/pkg/controller/nodes/subworkflow/launchplan_test.go @@ -64,7 +64,7 @@ func TestSubWorkflowHandler_StartLaunchPlan(t *testing.T) { Project: "z", } mockWf := &mocks2.ExecutableWorkflow{} - mockWf.On("GetNodeExecutionStatus", nodeID).Return(mockNodeStatus) + mockWf.OnGetNodeExecutionStatus(ctx, nodeID).Return(mockNodeStatus) mockWf.On("GetExecutionID").Return(v1alpha1.WorkflowExecutionIdentifier{ WorkflowExecutionIdentifier: parentID, }) @@ -211,7 +211,7 @@ func TestSubWorkflowHandler_CheckLaunchPlanStatus(t *testing.T) { Project: "z", } mockWf := &mocks2.ExecutableWorkflow{} - mockWf.On("GetNodeExecutionStatus", nodeID).Return(mockNodeStatus) + mockWf.OnGetNodeExecutionStatus(ctx, nodeID).Return(mockNodeStatus) mockWf.On("GetExecutionID").Return(v1alpha1.WorkflowExecutionIdentifier{ WorkflowExecutionIdentifier: parentID, }) @@ -297,7 +297,7 @@ func TestSubWorkflowHandler_CheckLaunchPlanStatus(t *testing.T) { }, nil) nCtx := createNodeContext(v1alpha1.WorkflowNodePhaseExecuting, mockWf, mockNode) - nCtx.On("DataStore").Return(mockStore) + nCtx.OnDataStore().Return(mockStore) s, err := h.CheckLaunchPlanStatus(ctx, nCtx) assert.NoError(t, err) assert.Equal(t, handler.EPhaseSuccess, s.Info().GetPhase()) @@ -338,7 +338,7 @@ func TestSubWorkflowHandler_CheckLaunchPlanStatus(t *testing.T) { }, nil) nCtx := createNodeContext(v1alpha1.WorkflowNodePhaseExecuting, mockWf, mockNode) - nCtx.On("DataStore").Return(mockStore) + nCtx.OnDataStore().Return(mockStore) s, err := h.CheckLaunchPlanStatus(ctx, nCtx) assert.NoError(t, err) assert.Equal(t, s.Info().GetPhase(), handler.EPhaseSuccess) @@ -497,7 +497,7 @@ func TestSubWorkflowHandler_CheckLaunchPlanStatus(t *testing.T) { }, nil) nCtx := createNodeContext(v1alpha1.WorkflowNodePhaseExecuting, mockWf, mockNode) - nCtx.On("DataStore").Return(mockStore) + nCtx.OnDataStore().Return(mockStore) s, err := h.CheckLaunchPlanStatus(ctx, nCtx) assert.Error(t, err) assert.Equal(t, s.Info().GetPhase(), handler.EPhaseUndefined) @@ -530,7 +530,7 @@ func TestSubWorkflowHandler_CheckLaunchPlanStatus(t *testing.T) { }, nil) nCtx := createNodeContext(v1alpha1.WorkflowNodePhaseExecuting, mockWf, mockNode) - nCtx.On("DataStore").Return(mockStore) + nCtx.OnDataStore().Return(mockStore) s, err := h.CheckLaunchPlanStatus(ctx, nCtx) assert.NotNil(t, err) assert.Equal(t, handler.EPhaseUndefined, s.Info().GetPhase()) @@ -563,7 +563,7 @@ func TestSubWorkflowHandler_CheckLaunchPlanStatus(t *testing.T) { }, nil) nCtx := createNodeContext(v1alpha1.WorkflowNodePhaseExecuting, mockWf, mockNode) - nCtx.On("DataStore").Return(mockStore) + nCtx.OnDataStore().Return(mockStore) s, err := h.CheckLaunchPlanStatus(ctx, nCtx) assert.Error(t, err) assert.Equal(t, s.Info().GetPhase().String(), handler.EPhaseUndefined.String()) @@ -605,7 +605,7 @@ func TestLaunchPlanHandler_HandleAbort(t *testing.T) { } mockWf := &mocks2.ExecutableWorkflow{} mockWf.On("GetName").Return("test") - mockWf.On("GetNodeExecutionStatus", nodeID).Return(mockNodeStatus) + mockWf.OnGetNodeExecutionStatus(ctx, nodeID).Return(mockNodeStatus) mockWf.On("GetExecutionID").Return(v1alpha1.WorkflowExecutionIdentifier{ WorkflowExecutionIdentifier: parentID, }) diff --git a/pkg/controller/nodes/subworkflow/subworkflow.go b/pkg/controller/nodes/subworkflow/subworkflow.go index 1f8cfb594f..7b672ba875 100644 --- a/pkg/controller/nodes/subworkflow/subworkflow.go +++ b/pkg/controller/nodes/subworkflow/subworkflow.go @@ -42,7 +42,7 @@ func (s *subworkflowHandler) DoInlineSubWorkflow(ctx context.Context, nCtx handl // If the WF interface has outputs, validate that the outputs file was written. var oInfo *handler.OutputInfo if outputBindings := w.GetOutputBindings(); len(outputBindings) > 0 { - endNodeStatus := w.GetNodeExecutionStatus(v1alpha1.EndNodeID) + endNodeStatus := w.GetNodeExecutionStatus(ctx, v1alpha1.EndNodeID) store := nCtx.DataStore() if endNodeStatus == nil { return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoFailure(errors.SubWorkflowExecutionFailed, "No end node found in subworkflow.", nil)), err @@ -113,7 +113,7 @@ func (s *subworkflowHandler) StartSubWorkflow(ctx context.Context, nCtx handler. } w := nCtx.Workflow() - status := w.GetNodeExecutionStatus(node.GetID()) + status := w.GetNodeExecutionStatus(ctx, node.GetID()) contextualSubWorkflow := executors.NewSubContextualWorkflow(w, subWorkflow, status) startNode := contextualSubWorkflow.StartNode() if startNode == nil { @@ -123,10 +123,9 @@ func (s *subworkflowHandler) StartSubWorkflow(ctx context.Context, nCtx handler. // Before starting the subworkflow, lets set the inputs for the Workflow. The inputs for a SubWorkflow are essentially // Copy of the inputs to the Node - nodeStatus := contextualSubWorkflow.GetNodeExecutionStatus(startNode.GetID()) + nodeStatus := contextualSubWorkflow.GetNodeExecutionStatus(ctx, startNode.GetID()) if len(nodeStatus.GetDataDir()) == 0 { - store := nCtx.DataStore() - dataDir, err := contextualSubWorkflow.GetExecutionStatus().ConstructNodeDataDir(ctx, store, startNode.GetID()) + dataDir, err := contextualSubWorkflow.GetExecutionStatus().ConstructNodeDataDir(ctx, startNode.GetID()) if err != nil { err = errors2.Wrapf(err, "Failed to create metadata store key. Error [%v]", err) return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoUndefined), err @@ -171,12 +170,12 @@ func (s *subworkflowHandler) CheckSubWorkflowStatus(ctx context.Context, nCtx ha return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoFailure(errors.SubWorkflowExecutionFailed, errMsg, nil)), nil } - parentNodeStatus := w.GetNodeExecutionStatus(nCtx.NodeID()) + parentNodeStatus := w.GetNodeExecutionStatus(ctx, nCtx.NodeID()) return s.DoInlineSubWorkflow(ctx, nCtx, contextualSubWorkflow, parentNodeStatus, startNode) } func (s *subworkflowHandler) HandleSubWorkflowFailingNode(ctx context.Context, nCtx handler.NodeExecutionContext, w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode) (handler.Transition, error) { - status := w.GetNodeExecutionStatus(node.GetID()) + status := w.GetNodeExecutionStatus(ctx, node.GetID()) subID := *node.GetWorkflowNode().GetSubWorkflowRef() subWorkflow := w.FindSubWorkflow(subID) if subWorkflow == nil { @@ -193,7 +192,7 @@ func (s *subworkflowHandler) HandleAbort(ctx context.Context, nCtx handler.NodeE return fmt.Errorf("no sub workflow [%s] found in node [%s]", workflowID, nCtx.NodeID()) } - nodeStatus := w.GetNodeExecutionStatus(nCtx.NodeID()) + nodeStatus := w.GetNodeExecutionStatus(ctx, nCtx.NodeID()) contextualSubWorkflow := executors.NewSubContextualWorkflow(w, subWorkflow, nodeStatus) startNode := w.StartNode() diff --git a/pkg/controller/nodes/task/k8s/plugin_manager.go b/pkg/controller/nodes/task/k8s/plugin_manager.go index b2f45e630b..8d364008a6 100644 --- a/pkg/controller/nodes/task/k8s/plugin_manager.go +++ b/pkg/controller/nodes/task/k8s/plugin_manager.go @@ -321,6 +321,7 @@ func (e *PluginManager) Finalize(ctx context.Context, tCtx pluginsCore.TaskExecu logger.Errorf(ctx, "Failed to build the Resource with name: %v. Error: %v, when finalizing.", tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName(), err) return nil } + AddObjectMetadata(tCtx.TaskExecutionMetadata(), o, config.GetK8sPluginConfig()) nsName := k8stypes.NamespacedName{Namespace: o.GetNamespace(), Name: o.GetName()} // Attempt to get resource from informer cache, if not found, retrieve it from API server. diff --git a/pkg/controller/nodes/task/resourcemanager/config/config_flags.go b/pkg/controller/nodes/task/resourcemanager/config/config_flags.go index 832d417bb0..624361a37d 100755 --- a/pkg/controller/nodes/task/resourcemanager/config/config_flags.go +++ b/pkg/controller/nodes/task/resourcemanager/config/config_flags.go @@ -41,10 +41,10 @@ func (Config) mustMarshalJSON(v json.Marshaler) string { // flags is json-name.json-sub-name... etc. func (cfg Config) GetPFlagSet(prefix string) *pflag.FlagSet { cmdFlags := pflag.NewFlagSet("Config", pflag.ExitOnError) - cmdFlags.String(fmt.Sprintf("%v%v", prefix, "resourceManagerType"), defaultConfig.Type, "Which resource manager to use") - cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "resourceQuota"), defaultConfig.ResourceMaxQuota, "Global limit for concurrent Qubole queries") - cmdFlags.String(fmt.Sprintf("%v%v", prefix, "redisConfig.hostPath"), defaultConfig.RedisConfig.HostPath, "Redis host location") - cmdFlags.String(fmt.Sprintf("%v%v", prefix, "redisConfig.hostKey"), defaultConfig.RedisConfig.HostKey, "Key for local Redis access") - cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "redisConfig.maxRetries"), defaultConfig.RedisConfig.MaxRetries, "See Redis client options for more info") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "type"), defaultConfig.Type, "Which resource manager to use") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "resourceMaxQuota"), defaultConfig.ResourceMaxQuota, "Global limit for concurrent Qubole queries") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "redis.hostPath"), defaultConfig.RedisConfig.HostPath, "Redis host location") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "redis.hostKey"), defaultConfig.RedisConfig.HostKey, "Key for local Redis access") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "redis.maxRetries"), defaultConfig.RedisConfig.MaxRetries, "See Redis client options for more info") return cmdFlags } diff --git a/pkg/controller/nodes/task/resourcemanager/config/config_flags_test.go b/pkg/controller/nodes/task/resourcemanager/config/config_flags_test.go index fb661c3538..a65ba6e2f0 100755 --- a/pkg/controller/nodes/task/resourcemanager/config/config_flags_test.go +++ b/pkg/controller/nodes/task/resourcemanager/config/config_flags_test.go @@ -99,10 +99,10 @@ func TestConfig_SetFlags(t *testing.T) { cmdFlags := actual.GetPFlagSet("") assert.True(t, cmdFlags.HasFlags()) - t.Run("Test_resourceManagerType", func(t *testing.T) { + t.Run("Test_type", func(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly - if vString, err := cmdFlags.GetString("resourceManagerType"); err == nil { + if vString, err := cmdFlags.GetString("type"); err == nil { assert.Equal(t, string(defaultConfig.Type), vString) } else { assert.FailNow(t, err.Error()) @@ -112,8 +112,8 @@ func TestConfig_SetFlags(t *testing.T) { t.Run("Override", func(t *testing.T) { testValue := "1" - cmdFlags.Set("resourceManagerType", testValue) - if vString, err := cmdFlags.GetString("resourceManagerType"); err == nil { + cmdFlags.Set("type", testValue) + if vString, err := cmdFlags.GetString("type"); err == nil { testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.Type) } else { @@ -121,10 +121,10 @@ func TestConfig_SetFlags(t *testing.T) { } }) }) - t.Run("Test_resourceQuota", func(t *testing.T) { + t.Run("Test_resourceMaxQuota", func(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly - if vInt, err := cmdFlags.GetInt("resourceQuota"); err == nil { + if vInt, err := cmdFlags.GetInt("resourceMaxQuota"); err == nil { assert.Equal(t, int(defaultConfig.ResourceMaxQuota), vInt) } else { assert.FailNow(t, err.Error()) @@ -134,8 +134,8 @@ func TestConfig_SetFlags(t *testing.T) { t.Run("Override", func(t *testing.T) { testValue := "1" - cmdFlags.Set("resourceQuota", testValue) - if vInt, err := cmdFlags.GetInt("resourceQuota"); err == nil { + cmdFlags.Set("resourceMaxQuota", testValue) + if vInt, err := cmdFlags.GetInt("resourceMaxQuota"); err == nil { testDecodeJson_Config(t, fmt.Sprintf("%v", vInt), &actual.ResourceMaxQuota) } else { @@ -143,10 +143,10 @@ func TestConfig_SetFlags(t *testing.T) { } }) }) - t.Run("Test_redisConfig.hostPath", func(t *testing.T) { + t.Run("Test_redis.hostPath", func(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly - if vString, err := cmdFlags.GetString("redisConfig.hostPath"); err == nil { + if vString, err := cmdFlags.GetString("redis.hostPath"); err == nil { assert.Equal(t, string(defaultConfig.RedisConfig.HostPath), vString) } else { assert.FailNow(t, err.Error()) @@ -156,8 +156,8 @@ func TestConfig_SetFlags(t *testing.T) { t.Run("Override", func(t *testing.T) { testValue := "1" - cmdFlags.Set("redisConfig.hostPath", testValue) - if vString, err := cmdFlags.GetString("redisConfig.hostPath"); err == nil { + cmdFlags.Set("redis.hostPath", testValue) + if vString, err := cmdFlags.GetString("redis.hostPath"); err == nil { testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.RedisConfig.HostPath) } else { @@ -165,10 +165,10 @@ func TestConfig_SetFlags(t *testing.T) { } }) }) - t.Run("Test_redisConfig.hostKey", func(t *testing.T) { + t.Run("Test_redis.hostKey", func(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly - if vString, err := cmdFlags.GetString("redisConfig.hostKey"); err == nil { + if vString, err := cmdFlags.GetString("redis.hostKey"); err == nil { assert.Equal(t, string(defaultConfig.RedisConfig.HostKey), vString) } else { assert.FailNow(t, err.Error()) @@ -178,8 +178,8 @@ func TestConfig_SetFlags(t *testing.T) { t.Run("Override", func(t *testing.T) { testValue := "1" - cmdFlags.Set("redisConfig.hostKey", testValue) - if vString, err := cmdFlags.GetString("redisConfig.hostKey"); err == nil { + cmdFlags.Set("redis.hostKey", testValue) + if vString, err := cmdFlags.GetString("redis.hostKey"); err == nil { testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.RedisConfig.HostKey) } else { @@ -187,10 +187,10 @@ func TestConfig_SetFlags(t *testing.T) { } }) }) - t.Run("Test_redisConfig.maxRetries", func(t *testing.T) { + t.Run("Test_redis.maxRetries", func(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly - if vInt, err := cmdFlags.GetInt("redisConfig.maxRetries"); err == nil { + if vInt, err := cmdFlags.GetInt("redis.maxRetries"); err == nil { assert.Equal(t, int(defaultConfig.RedisConfig.MaxRetries), vInt) } else { assert.FailNow(t, err.Error()) @@ -200,8 +200,8 @@ func TestConfig_SetFlags(t *testing.T) { t.Run("Override", func(t *testing.T) { testValue := "1" - cmdFlags.Set("redisConfig.maxRetries", testValue) - if vInt, err := cmdFlags.GetInt("redisConfig.maxRetries"); err == nil { + cmdFlags.Set("redis.maxRetries", testValue) + if vInt, err := cmdFlags.GetInt("redis.maxRetries"); err == nil { testDecodeJson_Config(t, fmt.Sprintf("%v", vInt), &actual.RedisConfig.MaxRetries) } else { diff --git a/pkg/controller/workflow/executor.go b/pkg/controller/workflow/executor.go index 8b4dbeb1b4..7bf67a57ff 100644 --- a/pkg/controller/workflow/executor.go +++ b/pkg/controller/workflow/executor.go @@ -90,7 +90,7 @@ func (c *workflowExecutor) handleReadyWorkflow(ctx context.Context, w *v1alpha1. } // Before starting the subworkflow, lets set the inputs for the Workflow. The inputs for a SubWorkflow are essentially // Copy of the inputs to the Node - nodeStatus := w.GetNodeExecutionStatus(startNode.GetID()) + nodeStatus := w.GetNodeExecutionStatus(ctx, startNode.GetID()) dataDir, err := c.store.ConstructReference(ctx, ref, startNode.GetID(), "data") if err != nil { return StatusFailing(errors.Wrapf(errors.CausedByError, w.GetID(), err, "failed to create metadata prefix for start node.")), nil @@ -168,7 +168,7 @@ func (c *workflowExecutor) handleFailingWorkflow(ctx context.Context, w *v1alpha func (c *workflowExecutor) handleSucceedingWorkflow(ctx context.Context, w *v1alpha1.FlyteWorkflow) Status { logger.Infof(ctx, "Workflow completed successfully") - endNodeStatus := w.GetNodeExecutionStatus(v1alpha1.EndNodeID) + endNodeStatus := w.GetNodeExecutionStatus(ctx, v1alpha1.EndNodeID) if endNodeStatus.GetPhase() == v1alpha1.NodePhaseSucceeded { if endNodeStatus.GetDataDir() != "" { w.Status.SetOutputReference(v1alpha1.GetOutputsFile(endNodeStatus.GetDataDir())) @@ -240,7 +240,7 @@ func (c *workflowExecutor) TransitionToPhase(ctx context.Context, execID *core.W c.metrics.FailureDuration.Observe(ctx, wStatus.GetStartedAt().Time, wStatus.GetStoppedAt().Time) case v1alpha1.WorkflowPhaseSucceeding: wfEvent.Phase = core.WorkflowExecution_SUCCEEDING - endNodeStatus := wStatus.GetNodeExecutionStatus(v1alpha1.EndNodeID) + endNodeStatus := wStatus.GetNodeExecutionStatus(ctx, v1alpha1.EndNodeID) // Workflow completion latency is recorded as the time it takes for the workflow to transition from end // node started time to workflow success being sent to the control plane. if endNodeStatus != nil && endNodeStatus.GetStartedAt() != nil { @@ -295,6 +295,8 @@ func (c *workflowExecutor) HandleFlyteWorkflow(ctx context.Context, w *v1alpha1. logger.Infof(ctx, "Handling Workflow [%s], id: [%s], p [%s]", w.GetName(), w.GetExecutionID(), w.GetExecutionStatus().GetPhase().String()) defer logger.Infof(ctx, "Handling Workflow [%s] Done", w.GetName()) + w.DataReferenceConstructor = c.store + wStatus := w.GetExecutionStatus() // Initialize the Status if not already initialized switch wStatus.GetPhase() { @@ -352,6 +354,8 @@ func (c *workflowExecutor) HandleFlyteWorkflow(ctx context.Context, w *v1alpha1. } func (c *workflowExecutor) HandleAbortedWorkflow(ctx context.Context, w *v1alpha1.FlyteWorkflow, maxRetries uint32) error { + w.DataReferenceConstructor = c.store + if !w.Status.IsTerminated() { reason := "User initiated workflow abort." c.metrics.IncompleteWorkflowAborted.Inc(ctx) diff --git a/pkg/controller/workflow/executor_test.go b/pkg/controller/workflow/executor_test.go index b205ef3f90..319ce2089c 100644 --- a/pkg/controller/workflow/executor_test.go +++ b/pkg/controller/workflow/executor_test.go @@ -314,7 +314,7 @@ func TestWorkflowExecutor_HandleFlyteWorkflow(t *testing.T) { w := &v1alpha1.FlyteWorkflow{} if assert.NoError(t, json.Unmarshal(wJSON, w)) { // For benchmark workflow, we know how many rounds it needs - // Number of rounds = 14 + // Number of rounds = 28 // + WF (x1) // | start-node: Succeeded, successfully completed | (x1) // | add-one-and-print-0: Succeeded, completed successfully || add-one-and-print-3: Succeeded, completed successfully || print-every-time-0: Succeeded, completed successfully | (x3) @@ -323,7 +323,7 @@ func TestWorkflowExecutor_HandleFlyteWorkflow(t *testing.T) { // | add-one-and-print-2: Succeeded, completed successfully | (x3) // + WF (x2) // Also there is some overlap - for i := 0; i < 14; i++ { + for i := 0; i < 28; i++ { err := executor.HandleFlyteWorkflow(ctx, w) if err != nil { t.Log(err) @@ -544,8 +544,8 @@ func TestWorkflowExecutor_HandleFlyteWorkflow_Events(t *testing.T) { w := &v1alpha1.FlyteWorkflow{} if assert.NoError(t, json.Unmarshal(wJSON, w)) { // For benchmark workflow, we know how many rounds it needs - // Number of rounds = 14 ? - for i := 0; i < 14; i++ { + // Number of rounds = 28 ? + for i := 0; i < 28; i++ { err := executor.HandleFlyteWorkflow(ctx, w) assert.NoError(t, err) fmt.Printf("Round[%d] Workflow[%v]\n", i, w.Status.Phase.String()) @@ -642,5 +642,4 @@ func TestWorkflowExecutor_HandleFlyteWorkflow_EventFailure(t *testing.T) { assert.Error(t, err) assert.True(t, wfErrors.Matches(err, wfErrors.EventRecordingError)) }) - } diff --git a/pkg/utils/encoder.go b/pkg/utils/encoder.go index 8d92d92f76..5f49368d89 100644 --- a/pkg/utils/encoder.go +++ b/pkg/utils/encoder.go @@ -38,7 +38,7 @@ func FixedLengthUniqueID(inputID string, maxLength int) (string, error) { func FixedLengthUniqueIDForParts(maxLength int, parts ...string) (string, error) { b := strings.Builder{} for i, p := range parts { - if i > 0 { + if i > 0 && b.Len() > 0 { _, err := b.WriteRune('-') if err != nil { return "", err