diff --git a/flytepropeller/cmd/kubectl-flyte/cmd/testdata/array-node-cache-serialize.yaml.golden b/flytepropeller/cmd/kubectl-flyte/cmd/testdata/array-node-cache-serialize.yaml.golden new file mode 100644 index 0000000000..9eff401240 --- /dev/null +++ b/flytepropeller/cmd/kubectl-flyte/cmd/testdata/array-node-cache-serialize.yaml.golden @@ -0,0 +1,91 @@ +tasks: +- container: + args: + - "pyflyte-fast-execute" + - "--additional-distribution" + - "s3://my-s3-bucket/flytesnacks/development/SMJBJX7BQJ6MCOABLKQT5VZXVY======/script_mode.tar.gz" + - "--dest-dir" + - "/root" + - "--" + - "pyflyte-map-execute" + - "--inputs" + - "{{.input}}" + - "--output-prefix" + - "{{.outputPrefix}}" + - "--raw-output-data-prefix" + - "{{.rawOutputDataPrefix}}" + - "--checkpoint-path" + - "{{.checkpointOutputPrefix}}" + - "--prev-checkpoint" + - "{{.prevCheckpointPrefix}}" + - "--resolver" + - "MapTaskResolver" + - "--" + - "vars" + - "" + - "resolver" + - "flytekit.core.python_auto_container.default_task_resolver" + - "task-module" + - "map-task" + - "task-name" + - "a_mappable_task" + image: "array-node:ee1ba227aa95447d04bb1761691b4d97749642dc" + resources: + limits: + - name: 1 + value: "1" + - name: 3 + value: "500Mi" + requests: + - name: 1 + value: "1" + - name: 3 + value: "300Mi" + id: + name: task-1 + project: flytesnacks + domain: development + metadata: + discoverable: true + discovery_version: "1.0" + cache_serializable: true + interface: + inputs: + variables: + a: + type: + simple: INTEGER + outputs: + variables: + o0: + type: + simple: STRING +workflow: + id: + name: workflow-with-array-node + interface: + inputs: + variables: + x: + type: + collectionType: + simple: INTEGER + nodes: + - id: node-1 + inputs: + - binding: + promise: + node_id: start-node + var: x + var: a + arrayNode: + parallelism: 0 + node: + metadata: + retries: + retries: 3 + taskNode: + referenceId: + name: task-1 + project: flytesnacks + domain: development diff --git a/flytepropeller/cmd/kubectl-flyte/cmd/testdata/array-node-cache.yaml.golden b/flytepropeller/cmd/kubectl-flyte/cmd/testdata/array-node-cache.yaml.golden new file mode 100644 index 0000000000..bb07a9dd57 --- /dev/null +++ b/flytepropeller/cmd/kubectl-flyte/cmd/testdata/array-node-cache.yaml.golden @@ -0,0 +1,91 @@ +tasks: +- container: + args: + - "pyflyte-fast-execute" + - "--additional-distribution" + - "s3://my-s3-bucket/flytesnacks/development/SMJBJX7BQJ6MCOABLKQT5VZXVY======/script_mode.tar.gz" + - "--dest-dir" + - "/root" + - "--" + - "pyflyte-map-execute" + - "--inputs" + - "{{.input}}" + - "--output-prefix" + - "{{.outputPrefix}}" + - "--raw-output-data-prefix" + - "{{.rawOutputDataPrefix}}" + - "--checkpoint-path" + - "{{.checkpointOutputPrefix}}" + - "--prev-checkpoint" + - "{{.prevCheckpointPrefix}}" + - "--resolver" + - "MapTaskResolver" + - "--" + - "vars" + - "" + - "resolver" + - "flytekit.core.python_auto_container.default_task_resolver" + - "task-module" + - "map-task" + - "task-name" + - "a_mappable_task" + image: "array-node:ee1ba227aa95447d04bb1761691b4d97749642dc" + resources: + limits: + - name: 1 + value: "1" + - name: 3 + value: "500Mi" + requests: + - name: 1 + value: "1" + - name: 3 + value: "300Mi" + id: + name: task-1 + project: flytesnacks + domain: development + metadata: + discoverable: true + discovery_version: "1.0" + cache_serializable: false + interface: + inputs: + variables: + a: + type: + simple: INTEGER + outputs: + variables: + o0: + type: + simple: STRING +workflow: + id: + name: workflow-with-array-node + interface: + 
inputs: + variables: + x: + type: + collectionType: + simple: INTEGER + nodes: + - id: node-1 + inputs: + - binding: + promise: + node_id: start-node + var: x + var: a + arrayNode: + parallelism: 1 + node: + metadata: + retries: + retries: 3 + taskNode: + referenceId: + name: task-1 + project: flytesnacks + domain: development diff --git a/flytepropeller/cmd/kubectl-flyte/cmd/testdata/array-node-inputs.yaml.golden b/flytepropeller/cmd/kubectl-flyte/cmd/testdata/array-node-inputs.yaml.golden new file mode 100755 index 0000000000..42e1766866 --- /dev/null +++ b/flytepropeller/cmd/kubectl-flyte/cmd/testdata/array-node-inputs.yaml.golden @@ -0,0 +1,13 @@ +literals: + "x": + collection: + literals: + - scalar: + primitive: + integer: "1" + - scalar: + primitive: + integer: "2" + - scalar: + primitive: + integer: "3" diff --git a/flytepropeller/cmd/kubectl-flyte/cmd/testdata/array-node.yaml.golden b/flytepropeller/cmd/kubectl-flyte/cmd/testdata/array-node.yaml.golden new file mode 100644 index 0000000000..0b370417aa --- /dev/null +++ b/flytepropeller/cmd/kubectl-flyte/cmd/testdata/array-node.yaml.golden @@ -0,0 +1,86 @@ +tasks: +- container: + args: + - "pyflyte-fast-execute" + - "--additional-distribution" + - "s3://my-s3-bucket/flytesnacks/development/SMJBJX7BQJ6MCOABLKQT5VZXVY======/script_mode.tar.gz" + - "--dest-dir" + - "/root" + - "--" + - "pyflyte-map-execute" + - "--inputs" + - "{{.input}}" + - "--output-prefix" + - "{{.outputPrefix}}" + - "--raw-output-data-prefix" + - "{{.rawOutputDataPrefix}}" + - "--checkpoint-path" + - "{{.checkpointOutputPrefix}}" + - "--prev-checkpoint" + - "{{.prevCheckpointPrefix}}" + - "--resolver" + - "MapTaskResolver" + - "--" + - "vars" + - "" + - "resolver" + - "flytekit.core.python_auto_container.default_task_resolver" + - "task-module" + - "map-task" + - "task-name" + - "a_mappable_task" + image: "array-node:ee1ba227aa95447d04bb1761691b4d97749642dc" + resources: + limits: + - name: 1 + value: "1" + - name: 3 + value: "500Mi" + requests: + - name: 1 + value: "1" + - name: 3 + value: "300Mi" + id: + name: task-1 + metadata: + discoverable: false + cache_serializable: false + interface: + inputs: + variables: + a: + type: + simple: INTEGER + outputs: + variables: + x: + type: + simple: STRING +workflow: + id: + name: workflow-with-array-node + interface: + inputs: + variables: + x: + type: + collectionType: + simple: INTEGER + nodes: + - id: node-1 + inputs: + - binding: + promise: + node_id: start-node + var: x + var: a + arrayNode: + parallelism: 1 + node: + metadata: + retries: + retries: 3 + taskNode: + referenceId: + name: task-1 diff --git a/flytepropeller/events/event_recorder.go b/flytepropeller/events/event_recorder.go index b07cc412b6..7366cd1cf4 100644 --- a/flytepropeller/events/event_recorder.go +++ b/flytepropeller/events/event_recorder.go @@ -13,7 +13,7 @@ import ( "github.com/golang/protobuf/proto" ) -const maxErrorMessageLength = 104857600 //100KB +const MaxErrorMessageLength = 104857600 //100KB const truncationIndicator = "... ..." 
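// Illustrative sketch — not part of this diff. The constants above drive
// error-message truncation before events are emitted; the actual
// truncateErrorMessage helper lives elsewhere in this file and may differ.
// Assuming a head-and-tail strategy around the indicator, a minimal version
// could look like this. Note that 104857600 bytes is 100MiB
// (100 * 1024 * 1024), considerably larger than the inline "100KB" comment
// suggests.
func truncateSketch(msg string, maxLength int) string {
	if len(msg) <= maxLength {
		return msg
	}
	// Keep the beginning and end of the message, joined by truncationIndicator.
	keep := (maxLength - len(truncationIndicator)) / 2
	return msg[:keep] + truncationIndicator + msg[len(msg)-keep:]
}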
type recordingMetrics struct { @@ -60,7 +60,7 @@ func (r *eventRecorder) sinkEvent(ctx context.Context, event proto.Message) erro func (r *eventRecorder) RecordNodeEvent(ctx context.Context, e *event.NodeExecutionEvent) error { if err, ok := e.GetOutputResult().(*event.NodeExecutionEvent_Error); ok { - truncateErrorMessage(err.Error, maxErrorMessageLength) + truncateErrorMessage(err.Error, MaxErrorMessageLength) } return r.sinkEvent(ctx, e) @@ -68,7 +68,7 @@ func (r *eventRecorder) RecordNodeEvent(ctx context.Context, e *event.NodeExecut func (r *eventRecorder) RecordTaskEvent(ctx context.Context, e *event.TaskExecutionEvent) error { if err, ok := e.GetOutputResult().(*event.TaskExecutionEvent_Error); ok { - truncateErrorMessage(err.Error, maxErrorMessageLength) + truncateErrorMessage(err.Error, MaxErrorMessageLength) } return r.sinkEvent(ctx, e) @@ -76,7 +76,7 @@ func (r *eventRecorder) RecordTaskEvent(ctx context.Context, e *event.TaskExecut func (r *eventRecorder) RecordWorkflowEvent(ctx context.Context, e *event.WorkflowExecutionEvent) error { if err, ok := e.GetOutputResult().(*event.WorkflowExecutionEvent_Error); ok { - truncateErrorMessage(err.Error, maxErrorMessageLength) + truncateErrorMessage(err.Error, MaxErrorMessageLength) } return r.sinkEvent(ctx, e) diff --git a/flytepropeller/go.mod b/flytepropeller/go.mod index 3da1b2d77d..c35782dc3e 100644 --- a/flytepropeller/go.mod +++ b/flytepropeller/go.mod @@ -146,3 +146,5 @@ require ( ) replace github.com/aws/amazon-sagemaker-operator-for-k8s => github.com/aws/amazon-sagemaker-operator-for-k8s v1.0.1-0.20210303003444-0fb33b1fd49d + +replace github.com/flyteorg/flyteidl => github.com/flyteorg/flyteidl v1.5.11-0.20230614183933-d56d4d37bf34 diff --git a/flytepropeller/go.sum b/flytepropeller/go.sum index ebafdba07a..3fa0a9b27b 100644 --- a/flytepropeller/go.sum +++ b/flytepropeller/go.sum @@ -242,8 +242,8 @@ github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5Kwzbycv github.com/fatih/color v1.13.0 h1:8LOYc1KYPPmyKMuN8QV2DNRWNbLo6LZ0iLs8+mlH53w= github.com/fatih/color v1.13.0/go.mod h1:kLAiJbzzSOZDVNGyDpeOxJ47H46qBXwg5ILebYFFOfk= github.com/felixge/httpsnoop v1.0.1/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= -github.com/flyteorg/flyteidl v1.5.10 h1:SHeiaWRt8EAVuFsat+BJswtc07HTZ4DqhfTEYSm621k= -github.com/flyteorg/flyteidl v1.5.10/go.mod h1:EtE/muM2lHHgBabjYcxqe9TWeJSL0kXwbI0RgVwI4Og= +github.com/flyteorg/flyteidl v1.5.11-0.20230614183933-d56d4d37bf34 h1:Gj5UKqJU+ozeTeYAvDWHiF4HSVufHW1W1ecymFfbbis= +github.com/flyteorg/flyteidl v1.5.11-0.20230614183933-d56d4d37bf34/go.mod h1:EtE/muM2lHHgBabjYcxqe9TWeJSL0kXwbI0RgVwI4Og= github.com/flyteorg/flyteplugins v1.1.8 h1:UVYdqDdcIqz2JIso+m3MsaPSsTZJZyZQ6Eg7nhX9r/Y= github.com/flyteorg/flyteplugins v1.1.8/go.mod h1:sRxeatEOHq1b9bTxTRNcwoIkVTAVN9dTz8toXkfcz2E= github.com/flyteorg/flytestdlib v1.0.20 h1:BrCQMlpdrFAPlADFJvCyn7gm+37df9WGYqLEB1mOlCQ= diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/array.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/array.go new file mode 100644 index 0000000000..6680e74106 --- /dev/null +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/array.go @@ -0,0 +1,24 @@ +package v1alpha1 + +type ArrayNodeSpec struct { + SubNodeSpec *NodeSpec + Parallelism uint32 + MinSuccesses *uint32 + MinSuccessRatio *float32 +} + +func (a *ArrayNodeSpec) GetSubNodeSpec() *NodeSpec { + return a.SubNodeSpec +} + +func (a *ArrayNodeSpec) GetParallelism() uint32 { + return a.Parallelism +} + +func (a *ArrayNodeSpec) GetMinSuccesses() *uint32 { 
+ return a.MinSuccesses +} + +func (a *ArrayNodeSpec) GetMinSuccessRatio() *float32 { + return a.MinSuccessRatio +} diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/iface.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/iface.go index 33aa857d5b..f5200669bd 100644 --- a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/iface.go +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/iface.go @@ -13,6 +13,7 @@ import ( "k8s.io/apimachinery/pkg/types" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + "github.com/flyteorg/flytestdlib/bitarray" "github.com/flyteorg/flytestdlib/storage" ) @@ -45,6 +46,7 @@ const ( NodeKindBranch NodeKind = "branch" // A Branch node with conditions NodeKindWorkflow NodeKind = "workflow" // Either an inline workflow or a remote workflow definition NodeKindGate NodeKind = "gate" // A Gate node with a condition + NodeKindArray NodeKind = "array" // An array node with a subtask Node NodeKindStart NodeKind = "start" // Start node is a special node NodeKindEnd NodeKind = "end" ) @@ -255,6 +257,13 @@ type ExecutableGateNode interface { GetSleep() *core.SleepCondition } +type ExecutableArrayNode interface { + GetSubNodeSpec() *NodeSpec + GetParallelism() uint32 + GetMinSuccesses() *uint32 + GetMinSuccessRatio() *float32 +} + type ExecutableWorkflowNodeStatus interface { GetWorkflowNodePhase() WorkflowNodePhase GetExecutionError() *core.ExecutionError @@ -277,6 +286,28 @@ type MutableGateNodeStatus interface { SetGateNodePhase(phase GateNodePhase) } +type ExecutableArrayNodeStatus interface { + GetArrayNodePhase() ArrayNodePhase + GetExecutionError() *core.ExecutionError + GetSubNodePhases() bitarray.CompactArray + GetSubNodeTaskPhases() bitarray.CompactArray + GetSubNodeRetryAttempts() bitarray.CompactArray + GetSubNodeSystemFailures() bitarray.CompactArray + GetTaskPhaseVersion() uint32 +} + +type MutableArrayNodeStatus interface { + Mutable + ExecutableArrayNodeStatus + SetArrayNodePhase(phase ArrayNodePhase) + SetExecutionError(executionError *core.ExecutionError) + SetSubNodePhases(subNodePhases bitarray.CompactArray) + SetSubNodeTaskPhases(subNodeTaskPhases bitarray.CompactArray) + SetSubNodeRetryAttempts(subNodeRetryAttempts bitarray.CompactArray) + SetSubNodeSystemFailures(subNodeSystemFailures bitarray.CompactArray) + SetTaskPhaseVersion(taskPhaseVersion uint32) +} + type Mutable interface { IsDirty() bool } @@ -312,6 +343,10 @@ type MutableNodeStatus interface { GetGateNodeStatus() MutableGateNodeStatus GetOrCreateGateNodeStatus() MutableGateNodeStatus ClearGateNodeStatus() + + GetArrayNodeStatus() MutableArrayNodeStatus + GetOrCreateArrayNodeStatus() MutableArrayNodeStatus + ClearArrayNodeStatus() } type ExecutionTimeInfo interface { @@ -397,6 +432,7 @@ type ExecutableNode interface { GetBranchNode() ExecutableBranchNode GetWorkflowNode() ExecutableWorkflowNode GetGateNode() ExecutableGateNode + GetArrayNode() ExecutableArrayNode GetOutputAlias() []Alias GetInputBindings() []*Binding GetResources() *v1.ResourceRequirements diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableArrayNode.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableArrayNode.go new file mode 100644 index 0000000000..fb200ff066 --- /dev/null +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableArrayNode.go @@ -0,0 +1,147 @@ +// Code generated by mockery v1.0.1. DO NOT EDIT. 
+ +package mocks + +import ( + v1alpha1 "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + mock "github.com/stretchr/testify/mock" +) + +// ExecutableArrayNode is an autogenerated mock type for the ExecutableArrayNode type +type ExecutableArrayNode struct { + mock.Mock +} + +type ExecutableArrayNode_GetMinSuccessRatio struct { + *mock.Call +} + +func (_m ExecutableArrayNode_GetMinSuccessRatio) Return(_a0 *float32) *ExecutableArrayNode_GetMinSuccessRatio { + return &ExecutableArrayNode_GetMinSuccessRatio{Call: _m.Call.Return(_a0)} +} + +func (_m *ExecutableArrayNode) OnGetMinSuccessRatio() *ExecutableArrayNode_GetMinSuccessRatio { + c_call := _m.On("GetMinSuccessRatio") + return &ExecutableArrayNode_GetMinSuccessRatio{Call: c_call} +} + +func (_m *ExecutableArrayNode) OnGetMinSuccessRatioMatch(matchers ...interface{}) *ExecutableArrayNode_GetMinSuccessRatio { + c_call := _m.On("GetMinSuccessRatio", matchers...) + return &ExecutableArrayNode_GetMinSuccessRatio{Call: c_call} +} + +// GetMinSuccessRatio provides a mock function with given fields: +func (_m *ExecutableArrayNode) GetMinSuccessRatio() *float32 { + ret := _m.Called() + + var r0 *float32 + if rf, ok := ret.Get(0).(func() *float32); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*float32) + } + } + + return r0 +} + +type ExecutableArrayNode_GetMinSuccesses struct { + *mock.Call +} + +func (_m ExecutableArrayNode_GetMinSuccesses) Return(_a0 *uint32) *ExecutableArrayNode_GetMinSuccesses { + return &ExecutableArrayNode_GetMinSuccesses{Call: _m.Call.Return(_a0)} +} + +func (_m *ExecutableArrayNode) OnGetMinSuccesses() *ExecutableArrayNode_GetMinSuccesses { + c_call := _m.On("GetMinSuccesses") + return &ExecutableArrayNode_GetMinSuccesses{Call: c_call} +} + +func (_m *ExecutableArrayNode) OnGetMinSuccessesMatch(matchers ...interface{}) *ExecutableArrayNode_GetMinSuccesses { + c_call := _m.On("GetMinSuccesses", matchers...) + return &ExecutableArrayNode_GetMinSuccesses{Call: c_call} +} + +// GetMinSuccesses provides a mock function with given fields: +func (_m *ExecutableArrayNode) GetMinSuccesses() *uint32 { + ret := _m.Called() + + var r0 *uint32 + if rf, ok := ret.Get(0).(func() *uint32); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*uint32) + } + } + + return r0 +} + +type ExecutableArrayNode_GetParallelism struct { + *mock.Call +} + +func (_m ExecutableArrayNode_GetParallelism) Return(_a0 uint32) *ExecutableArrayNode_GetParallelism { + return &ExecutableArrayNode_GetParallelism{Call: _m.Call.Return(_a0)} +} + +func (_m *ExecutableArrayNode) OnGetParallelism() *ExecutableArrayNode_GetParallelism { + c_call := _m.On("GetParallelism") + return &ExecutableArrayNode_GetParallelism{Call: c_call} +} + +func (_m *ExecutableArrayNode) OnGetParallelismMatch(matchers ...interface{}) *ExecutableArrayNode_GetParallelism { + c_call := _m.On("GetParallelism", matchers...) 
+ return &ExecutableArrayNode_GetParallelism{Call: c_call} +} + +// GetParallelism provides a mock function with given fields: +func (_m *ExecutableArrayNode) GetParallelism() uint32 { + ret := _m.Called() + + var r0 uint32 + if rf, ok := ret.Get(0).(func() uint32); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(uint32) + } + + return r0 +} + +type ExecutableArrayNode_GetSubNodeSpec struct { + *mock.Call +} + +func (_m ExecutableArrayNode_GetSubNodeSpec) Return(_a0 *v1alpha1.NodeSpec) *ExecutableArrayNode_GetSubNodeSpec { + return &ExecutableArrayNode_GetSubNodeSpec{Call: _m.Call.Return(_a0)} +} + +func (_m *ExecutableArrayNode) OnGetSubNodeSpec() *ExecutableArrayNode_GetSubNodeSpec { + c_call := _m.On("GetSubNodeSpec") + return &ExecutableArrayNode_GetSubNodeSpec{Call: c_call} +} + +func (_m *ExecutableArrayNode) OnGetSubNodeSpecMatch(matchers ...interface{}) *ExecutableArrayNode_GetSubNodeSpec { + c_call := _m.On("GetSubNodeSpec", matchers...) + return &ExecutableArrayNode_GetSubNodeSpec{Call: c_call} +} + +// GetSubNodeSpec provides a mock function with given fields: +func (_m *ExecutableArrayNode) GetSubNodeSpec() *v1alpha1.NodeSpec { + ret := _m.Called() + + var r0 *v1alpha1.NodeSpec + if rf, ok := ret.Get(0).(func() *v1alpha1.NodeSpec); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*v1alpha1.NodeSpec) + } + } + + return r0 +} diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableArrayNodeStatus.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableArrayNodeStatus.go new file mode 100644 index 0000000000..08de9e29c0 --- /dev/null +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableArrayNodeStatus.go @@ -0,0 +1,243 @@ +// Code generated by mockery v1.0.1. DO NOT EDIT. + +package mocks + +import ( + core "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + bitarray "github.com/flyteorg/flytestdlib/bitarray" + + mock "github.com/stretchr/testify/mock" + + v1alpha1 "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" +) + +// ExecutableArrayNodeStatus is an autogenerated mock type for the ExecutableArrayNodeStatus type +type ExecutableArrayNodeStatus struct { + mock.Mock +} + +type ExecutableArrayNodeStatus_GetArrayNodePhase struct { + *mock.Call +} + +func (_m ExecutableArrayNodeStatus_GetArrayNodePhase) Return(_a0 v1alpha1.ArrayNodePhase) *ExecutableArrayNodeStatus_GetArrayNodePhase { + return &ExecutableArrayNodeStatus_GetArrayNodePhase{Call: _m.Call.Return(_a0)} +} + +func (_m *ExecutableArrayNodeStatus) OnGetArrayNodePhase() *ExecutableArrayNodeStatus_GetArrayNodePhase { + c_call := _m.On("GetArrayNodePhase") + return &ExecutableArrayNodeStatus_GetArrayNodePhase{Call: c_call} +} + +func (_m *ExecutableArrayNodeStatus) OnGetArrayNodePhaseMatch(matchers ...interface{}) *ExecutableArrayNodeStatus_GetArrayNodePhase { + c_call := _m.On("GetArrayNodePhase", matchers...) 
+ return &ExecutableArrayNodeStatus_GetArrayNodePhase{Call: c_call} +} + +// GetArrayNodePhase provides a mock function with given fields: +func (_m *ExecutableArrayNodeStatus) GetArrayNodePhase() v1alpha1.ArrayNodePhase { + ret := _m.Called() + + var r0 v1alpha1.ArrayNodePhase + if rf, ok := ret.Get(0).(func() v1alpha1.ArrayNodePhase); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(v1alpha1.ArrayNodePhase) + } + + return r0 +} + +type ExecutableArrayNodeStatus_GetExecutionError struct { + *mock.Call +} + +func (_m ExecutableArrayNodeStatus_GetExecutionError) Return(_a0 *core.ExecutionError) *ExecutableArrayNodeStatus_GetExecutionError { + return &ExecutableArrayNodeStatus_GetExecutionError{Call: _m.Call.Return(_a0)} +} + +func (_m *ExecutableArrayNodeStatus) OnGetExecutionError() *ExecutableArrayNodeStatus_GetExecutionError { + c_call := _m.On("GetExecutionError") + return &ExecutableArrayNodeStatus_GetExecutionError{Call: c_call} +} + +func (_m *ExecutableArrayNodeStatus) OnGetExecutionErrorMatch(matchers ...interface{}) *ExecutableArrayNodeStatus_GetExecutionError { + c_call := _m.On("GetExecutionError", matchers...) + return &ExecutableArrayNodeStatus_GetExecutionError{Call: c_call} +} + +// GetExecutionError provides a mock function with given fields: +func (_m *ExecutableArrayNodeStatus) GetExecutionError() *core.ExecutionError { + ret := _m.Called() + + var r0 *core.ExecutionError + if rf, ok := ret.Get(0).(func() *core.ExecutionError); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*core.ExecutionError) + } + } + + return r0 +} + +type ExecutableArrayNodeStatus_GetSubNodePhases struct { + *mock.Call +} + +func (_m ExecutableArrayNodeStatus_GetSubNodePhases) Return(_a0 bitarray.CompactArray) *ExecutableArrayNodeStatus_GetSubNodePhases { + return &ExecutableArrayNodeStatus_GetSubNodePhases{Call: _m.Call.Return(_a0)} +} + +func (_m *ExecutableArrayNodeStatus) OnGetSubNodePhases() *ExecutableArrayNodeStatus_GetSubNodePhases { + c_call := _m.On("GetSubNodePhases") + return &ExecutableArrayNodeStatus_GetSubNodePhases{Call: c_call} +} + +func (_m *ExecutableArrayNodeStatus) OnGetSubNodePhasesMatch(matchers ...interface{}) *ExecutableArrayNodeStatus_GetSubNodePhases { + c_call := _m.On("GetSubNodePhases", matchers...) + return &ExecutableArrayNodeStatus_GetSubNodePhases{Call: c_call} +} + +// GetSubNodePhases provides a mock function with given fields: +func (_m *ExecutableArrayNodeStatus) GetSubNodePhases() bitarray.CompactArray { + ret := _m.Called() + + var r0 bitarray.CompactArray + if rf, ok := ret.Get(0).(func() bitarray.CompactArray); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bitarray.CompactArray) + } + + return r0 +} + +type ExecutableArrayNodeStatus_GetSubNodeRetryAttempts struct { + *mock.Call +} + +func (_m ExecutableArrayNodeStatus_GetSubNodeRetryAttempts) Return(_a0 bitarray.CompactArray) *ExecutableArrayNodeStatus_GetSubNodeRetryAttempts { + return &ExecutableArrayNodeStatus_GetSubNodeRetryAttempts{Call: _m.Call.Return(_a0)} +} + +func (_m *ExecutableArrayNodeStatus) OnGetSubNodeRetryAttempts() *ExecutableArrayNodeStatus_GetSubNodeRetryAttempts { + c_call := _m.On("GetSubNodeRetryAttempts") + return &ExecutableArrayNodeStatus_GetSubNodeRetryAttempts{Call: c_call} +} + +func (_m *ExecutableArrayNodeStatus) OnGetSubNodeRetryAttemptsMatch(matchers ...interface{}) *ExecutableArrayNodeStatus_GetSubNodeRetryAttempts { + c_call := _m.On("GetSubNodeRetryAttempts", matchers...) 
+ return &ExecutableArrayNodeStatus_GetSubNodeRetryAttempts{Call: c_call} +} + +// GetSubNodeRetryAttempts provides a mock function with given fields: +func (_m *ExecutableArrayNodeStatus) GetSubNodeRetryAttempts() bitarray.CompactArray { + ret := _m.Called() + + var r0 bitarray.CompactArray + if rf, ok := ret.Get(0).(func() bitarray.CompactArray); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bitarray.CompactArray) + } + + return r0 +} + +type ExecutableArrayNodeStatus_GetSubNodeSystemFailures struct { + *mock.Call +} + +func (_m ExecutableArrayNodeStatus_GetSubNodeSystemFailures) Return(_a0 bitarray.CompactArray) *ExecutableArrayNodeStatus_GetSubNodeSystemFailures { + return &ExecutableArrayNodeStatus_GetSubNodeSystemFailures{Call: _m.Call.Return(_a0)} +} + +func (_m *ExecutableArrayNodeStatus) OnGetSubNodeSystemFailures() *ExecutableArrayNodeStatus_GetSubNodeSystemFailures { + c_call := _m.On("GetSubNodeSystemFailures") + return &ExecutableArrayNodeStatus_GetSubNodeSystemFailures{Call: c_call} +} + +func (_m *ExecutableArrayNodeStatus) OnGetSubNodeSystemFailuresMatch(matchers ...interface{}) *ExecutableArrayNodeStatus_GetSubNodeSystemFailures { + c_call := _m.On("GetSubNodeSystemFailures", matchers...) + return &ExecutableArrayNodeStatus_GetSubNodeSystemFailures{Call: c_call} +} + +// GetSubNodeSystemFailures provides a mock function with given fields: +func (_m *ExecutableArrayNodeStatus) GetSubNodeSystemFailures() bitarray.CompactArray { + ret := _m.Called() + + var r0 bitarray.CompactArray + if rf, ok := ret.Get(0).(func() bitarray.CompactArray); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bitarray.CompactArray) + } + + return r0 +} + +type ExecutableArrayNodeStatus_GetSubNodeTaskPhases struct { + *mock.Call +} + +func (_m ExecutableArrayNodeStatus_GetSubNodeTaskPhases) Return(_a0 bitarray.CompactArray) *ExecutableArrayNodeStatus_GetSubNodeTaskPhases { + return &ExecutableArrayNodeStatus_GetSubNodeTaskPhases{Call: _m.Call.Return(_a0)} +} + +func (_m *ExecutableArrayNodeStatus) OnGetSubNodeTaskPhases() *ExecutableArrayNodeStatus_GetSubNodeTaskPhases { + c_call := _m.On("GetSubNodeTaskPhases") + return &ExecutableArrayNodeStatus_GetSubNodeTaskPhases{Call: c_call} +} + +func (_m *ExecutableArrayNodeStatus) OnGetSubNodeTaskPhasesMatch(matchers ...interface{}) *ExecutableArrayNodeStatus_GetSubNodeTaskPhases { + c_call := _m.On("GetSubNodeTaskPhases", matchers...) 
+ return &ExecutableArrayNodeStatus_GetSubNodeTaskPhases{Call: c_call} +} + +// GetSubNodeTaskPhases provides a mock function with given fields: +func (_m *ExecutableArrayNodeStatus) GetSubNodeTaskPhases() bitarray.CompactArray { + ret := _m.Called() + + var r0 bitarray.CompactArray + if rf, ok := ret.Get(0).(func() bitarray.CompactArray); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bitarray.CompactArray) + } + + return r0 +} + +type ExecutableArrayNodeStatus_GetTaskPhaseVersion struct { + *mock.Call +} + +func (_m ExecutableArrayNodeStatus_GetTaskPhaseVersion) Return(_a0 uint32) *ExecutableArrayNodeStatus_GetTaskPhaseVersion { + return &ExecutableArrayNodeStatus_GetTaskPhaseVersion{Call: _m.Call.Return(_a0)} +} + +func (_m *ExecutableArrayNodeStatus) OnGetTaskPhaseVersion() *ExecutableArrayNodeStatus_GetTaskPhaseVersion { + c_call := _m.On("GetTaskPhaseVersion") + return &ExecutableArrayNodeStatus_GetTaskPhaseVersion{Call: c_call} +} + +func (_m *ExecutableArrayNodeStatus) OnGetTaskPhaseVersionMatch(matchers ...interface{}) *ExecutableArrayNodeStatus_GetTaskPhaseVersion { + c_call := _m.On("GetTaskPhaseVersion", matchers...) + return &ExecutableArrayNodeStatus_GetTaskPhaseVersion{Call: c_call} +} + +// GetTaskPhaseVersion provides a mock function with given fields: +func (_m *ExecutableArrayNodeStatus) GetTaskPhaseVersion() uint32 { + ret := _m.Called() + + var r0 uint32 + if rf, ok := ret.Get(0).(func() uint32); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(uint32) + } + + return r0 +} diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableNode.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableNode.go index 5fbd946fae..a6f1432077 100644 --- a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableNode.go +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableNode.go @@ -51,6 +51,40 @@ func (_m *ExecutableNode) GetActiveDeadline() *time.Duration { return r0 } +type ExecutableNode_GetArrayNode struct { + *mock.Call +} + +func (_m ExecutableNode_GetArrayNode) Return(_a0 v1alpha1.ExecutableArrayNode) *ExecutableNode_GetArrayNode { + return &ExecutableNode_GetArrayNode{Call: _m.Call.Return(_a0)} +} + +func (_m *ExecutableNode) OnGetArrayNode() *ExecutableNode_GetArrayNode { + c_call := _m.On("GetArrayNode") + return &ExecutableNode_GetArrayNode{Call: c_call} +} + +func (_m *ExecutableNode) OnGetArrayNodeMatch(matchers ...interface{}) *ExecutableNode_GetArrayNode { + c_call := _m.On("GetArrayNode", matchers...) 
+ return &ExecutableNode_GetArrayNode{Call: c_call} +} + +// GetArrayNode provides a mock function with given fields: +func (_m *ExecutableNode) GetArrayNode() v1alpha1.ExecutableArrayNode { + ret := _m.Called() + + var r0 v1alpha1.ExecutableArrayNode + if rf, ok := ret.Get(0).(func() v1alpha1.ExecutableArrayNode); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.ExecutableArrayNode) + } + } + + return r0 +} + type ExecutableNode_GetBranchNode struct { *mock.Call } diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableNodeStatus.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableNodeStatus.go index 346680cfa9..886aa217cc 100644 --- a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableNodeStatus.go +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableNodeStatus.go @@ -20,6 +20,11 @@ type ExecutableNodeStatus struct { mock.Mock } +// ClearArrayNodeStatus provides a mock function with given fields: +func (_m *ExecutableNodeStatus) ClearArrayNodeStatus() { + _m.Called() +} + // ClearDynamicNodeStatus provides a mock function with given fields: func (_m *ExecutableNodeStatus) ClearDynamicNodeStatus() { _m.Called() @@ -50,6 +55,40 @@ func (_m *ExecutableNodeStatus) ClearWorkflowStatus() { _m.Called() } +type ExecutableNodeStatus_GetArrayNodeStatus struct { + *mock.Call +} + +func (_m ExecutableNodeStatus_GetArrayNodeStatus) Return(_a0 v1alpha1.MutableArrayNodeStatus) *ExecutableNodeStatus_GetArrayNodeStatus { + return &ExecutableNodeStatus_GetArrayNodeStatus{Call: _m.Call.Return(_a0)} +} + +func (_m *ExecutableNodeStatus) OnGetArrayNodeStatus() *ExecutableNodeStatus_GetArrayNodeStatus { + c_call := _m.On("GetArrayNodeStatus") + return &ExecutableNodeStatus_GetArrayNodeStatus{Call: c_call} +} + +func (_m *ExecutableNodeStatus) OnGetArrayNodeStatusMatch(matchers ...interface{}) *ExecutableNodeStatus_GetArrayNodeStatus { + c_call := _m.On("GetArrayNodeStatus", matchers...) + return &ExecutableNodeStatus_GetArrayNodeStatus{Call: c_call} +} + +// GetArrayNodeStatus provides a mock function with given fields: +func (_m *ExecutableNodeStatus) GetArrayNodeStatus() v1alpha1.MutableArrayNodeStatus { + ret := _m.Called() + + var r0 v1alpha1.MutableArrayNodeStatus + if rf, ok := ret.Get(0).(func() v1alpha1.MutableArrayNodeStatus); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.MutableArrayNodeStatus) + } + } + + return r0 +} + type ExecutableNodeStatus_GetAttempts struct { *mock.Call } @@ -384,6 +423,40 @@ func (_m *ExecutableNodeStatus) GetNodeExecutionStatus(ctx context.Context, id s return r0 } +type ExecutableNodeStatus_GetOrCreateArrayNodeStatus struct { + *mock.Call +} + +func (_m ExecutableNodeStatus_GetOrCreateArrayNodeStatus) Return(_a0 v1alpha1.MutableArrayNodeStatus) *ExecutableNodeStatus_GetOrCreateArrayNodeStatus { + return &ExecutableNodeStatus_GetOrCreateArrayNodeStatus{Call: _m.Call.Return(_a0)} +} + +func (_m *ExecutableNodeStatus) OnGetOrCreateArrayNodeStatus() *ExecutableNodeStatus_GetOrCreateArrayNodeStatus { + c_call := _m.On("GetOrCreateArrayNodeStatus") + return &ExecutableNodeStatus_GetOrCreateArrayNodeStatus{Call: c_call} +} + +func (_m *ExecutableNodeStatus) OnGetOrCreateArrayNodeStatusMatch(matchers ...interface{}) *ExecutableNodeStatus_GetOrCreateArrayNodeStatus { + c_call := _m.On("GetOrCreateArrayNodeStatus", matchers...) 
+ return &ExecutableNodeStatus_GetOrCreateArrayNodeStatus{Call: c_call} +} + +// GetOrCreateArrayNodeStatus provides a mock function with given fields: +func (_m *ExecutableNodeStatus) GetOrCreateArrayNodeStatus() v1alpha1.MutableArrayNodeStatus { + ret := _m.Called() + + var r0 v1alpha1.MutableArrayNodeStatus + if rf, ok := ret.Get(0).(func() v1alpha1.MutableArrayNodeStatus); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.MutableArrayNodeStatus) + } + } + + return r0 +} + type ExecutableNodeStatus_GetOrCreateBranchStatus struct { *mock.Call } diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableArrayNodeStatus.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableArrayNodeStatus.go new file mode 100644 index 0000000000..e052187cce --- /dev/null +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableArrayNodeStatus.go @@ -0,0 +1,310 @@ +// Code generated by mockery v1.0.1. DO NOT EDIT. + +package mocks + +import ( + core "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + bitarray "github.com/flyteorg/flytestdlib/bitarray" + + mock "github.com/stretchr/testify/mock" + + v1alpha1 "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" +) + +// MutableArrayNodeStatus is an autogenerated mock type for the MutableArrayNodeStatus type +type MutableArrayNodeStatus struct { + mock.Mock +} + +type MutableArrayNodeStatus_GetArrayNodePhase struct { + *mock.Call +} + +func (_m MutableArrayNodeStatus_GetArrayNodePhase) Return(_a0 v1alpha1.ArrayNodePhase) *MutableArrayNodeStatus_GetArrayNodePhase { + return &MutableArrayNodeStatus_GetArrayNodePhase{Call: _m.Call.Return(_a0)} +} + +func (_m *MutableArrayNodeStatus) OnGetArrayNodePhase() *MutableArrayNodeStatus_GetArrayNodePhase { + c_call := _m.On("GetArrayNodePhase") + return &MutableArrayNodeStatus_GetArrayNodePhase{Call: c_call} +} + +func (_m *MutableArrayNodeStatus) OnGetArrayNodePhaseMatch(matchers ...interface{}) *MutableArrayNodeStatus_GetArrayNodePhase { + c_call := _m.On("GetArrayNodePhase", matchers...) + return &MutableArrayNodeStatus_GetArrayNodePhase{Call: c_call} +} + +// GetArrayNodePhase provides a mock function with given fields: +func (_m *MutableArrayNodeStatus) GetArrayNodePhase() v1alpha1.ArrayNodePhase { + ret := _m.Called() + + var r0 v1alpha1.ArrayNodePhase + if rf, ok := ret.Get(0).(func() v1alpha1.ArrayNodePhase); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(v1alpha1.ArrayNodePhase) + } + + return r0 +} + +type MutableArrayNodeStatus_GetExecutionError struct { + *mock.Call +} + +func (_m MutableArrayNodeStatus_GetExecutionError) Return(_a0 *core.ExecutionError) *MutableArrayNodeStatus_GetExecutionError { + return &MutableArrayNodeStatus_GetExecutionError{Call: _m.Call.Return(_a0)} +} + +func (_m *MutableArrayNodeStatus) OnGetExecutionError() *MutableArrayNodeStatus_GetExecutionError { + c_call := _m.On("GetExecutionError") + return &MutableArrayNodeStatus_GetExecutionError{Call: c_call} +} + +func (_m *MutableArrayNodeStatus) OnGetExecutionErrorMatch(matchers ...interface{}) *MutableArrayNodeStatus_GetExecutionError { + c_call := _m.On("GetExecutionError", matchers...) 
+ return &MutableArrayNodeStatus_GetExecutionError{Call: c_call} +} + +// GetExecutionError provides a mock function with given fields: +func (_m *MutableArrayNodeStatus) GetExecutionError() *core.ExecutionError { + ret := _m.Called() + + var r0 *core.ExecutionError + if rf, ok := ret.Get(0).(func() *core.ExecutionError); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*core.ExecutionError) + } + } + + return r0 +} + +type MutableArrayNodeStatus_GetSubNodePhases struct { + *mock.Call +} + +func (_m MutableArrayNodeStatus_GetSubNodePhases) Return(_a0 bitarray.CompactArray) *MutableArrayNodeStatus_GetSubNodePhases { + return &MutableArrayNodeStatus_GetSubNodePhases{Call: _m.Call.Return(_a0)} +} + +func (_m *MutableArrayNodeStatus) OnGetSubNodePhases() *MutableArrayNodeStatus_GetSubNodePhases { + c_call := _m.On("GetSubNodePhases") + return &MutableArrayNodeStatus_GetSubNodePhases{Call: c_call} +} + +func (_m *MutableArrayNodeStatus) OnGetSubNodePhasesMatch(matchers ...interface{}) *MutableArrayNodeStatus_GetSubNodePhases { + c_call := _m.On("GetSubNodePhases", matchers...) + return &MutableArrayNodeStatus_GetSubNodePhases{Call: c_call} +} + +// GetSubNodePhases provides a mock function with given fields: +func (_m *MutableArrayNodeStatus) GetSubNodePhases() bitarray.CompactArray { + ret := _m.Called() + + var r0 bitarray.CompactArray + if rf, ok := ret.Get(0).(func() bitarray.CompactArray); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bitarray.CompactArray) + } + + return r0 +} + +type MutableArrayNodeStatus_GetSubNodeRetryAttempts struct { + *mock.Call +} + +func (_m MutableArrayNodeStatus_GetSubNodeRetryAttempts) Return(_a0 bitarray.CompactArray) *MutableArrayNodeStatus_GetSubNodeRetryAttempts { + return &MutableArrayNodeStatus_GetSubNodeRetryAttempts{Call: _m.Call.Return(_a0)} +} + +func (_m *MutableArrayNodeStatus) OnGetSubNodeRetryAttempts() *MutableArrayNodeStatus_GetSubNodeRetryAttempts { + c_call := _m.On("GetSubNodeRetryAttempts") + return &MutableArrayNodeStatus_GetSubNodeRetryAttempts{Call: c_call} +} + +func (_m *MutableArrayNodeStatus) OnGetSubNodeRetryAttemptsMatch(matchers ...interface{}) *MutableArrayNodeStatus_GetSubNodeRetryAttempts { + c_call := _m.On("GetSubNodeRetryAttempts", matchers...) + return &MutableArrayNodeStatus_GetSubNodeRetryAttempts{Call: c_call} +} + +// GetSubNodeRetryAttempts provides a mock function with given fields: +func (_m *MutableArrayNodeStatus) GetSubNodeRetryAttempts() bitarray.CompactArray { + ret := _m.Called() + + var r0 bitarray.CompactArray + if rf, ok := ret.Get(0).(func() bitarray.CompactArray); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bitarray.CompactArray) + } + + return r0 +} + +type MutableArrayNodeStatus_GetSubNodeSystemFailures struct { + *mock.Call +} + +func (_m MutableArrayNodeStatus_GetSubNodeSystemFailures) Return(_a0 bitarray.CompactArray) *MutableArrayNodeStatus_GetSubNodeSystemFailures { + return &MutableArrayNodeStatus_GetSubNodeSystemFailures{Call: _m.Call.Return(_a0)} +} + +func (_m *MutableArrayNodeStatus) OnGetSubNodeSystemFailures() *MutableArrayNodeStatus_GetSubNodeSystemFailures { + c_call := _m.On("GetSubNodeSystemFailures") + return &MutableArrayNodeStatus_GetSubNodeSystemFailures{Call: c_call} +} + +func (_m *MutableArrayNodeStatus) OnGetSubNodeSystemFailuresMatch(matchers ...interface{}) *MutableArrayNodeStatus_GetSubNodeSystemFailures { + c_call := _m.On("GetSubNodeSystemFailures", matchers...) 
+ return &MutableArrayNodeStatus_GetSubNodeSystemFailures{Call: c_call} +} + +// GetSubNodeSystemFailures provides a mock function with given fields: +func (_m *MutableArrayNodeStatus) GetSubNodeSystemFailures() bitarray.CompactArray { + ret := _m.Called() + + var r0 bitarray.CompactArray + if rf, ok := ret.Get(0).(func() bitarray.CompactArray); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bitarray.CompactArray) + } + + return r0 +} + +type MutableArrayNodeStatus_GetSubNodeTaskPhases struct { + *mock.Call +} + +func (_m MutableArrayNodeStatus_GetSubNodeTaskPhases) Return(_a0 bitarray.CompactArray) *MutableArrayNodeStatus_GetSubNodeTaskPhases { + return &MutableArrayNodeStatus_GetSubNodeTaskPhases{Call: _m.Call.Return(_a0)} +} + +func (_m *MutableArrayNodeStatus) OnGetSubNodeTaskPhases() *MutableArrayNodeStatus_GetSubNodeTaskPhases { + c_call := _m.On("GetSubNodeTaskPhases") + return &MutableArrayNodeStatus_GetSubNodeTaskPhases{Call: c_call} +} + +func (_m *MutableArrayNodeStatus) OnGetSubNodeTaskPhasesMatch(matchers ...interface{}) *MutableArrayNodeStatus_GetSubNodeTaskPhases { + c_call := _m.On("GetSubNodeTaskPhases", matchers...) + return &MutableArrayNodeStatus_GetSubNodeTaskPhases{Call: c_call} +} + +// GetSubNodeTaskPhases provides a mock function with given fields: +func (_m *MutableArrayNodeStatus) GetSubNodeTaskPhases() bitarray.CompactArray { + ret := _m.Called() + + var r0 bitarray.CompactArray + if rf, ok := ret.Get(0).(func() bitarray.CompactArray); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bitarray.CompactArray) + } + + return r0 +} + +type MutableArrayNodeStatus_GetTaskPhaseVersion struct { + *mock.Call +} + +func (_m MutableArrayNodeStatus_GetTaskPhaseVersion) Return(_a0 uint32) *MutableArrayNodeStatus_GetTaskPhaseVersion { + return &MutableArrayNodeStatus_GetTaskPhaseVersion{Call: _m.Call.Return(_a0)} +} + +func (_m *MutableArrayNodeStatus) OnGetTaskPhaseVersion() *MutableArrayNodeStatus_GetTaskPhaseVersion { + c_call := _m.On("GetTaskPhaseVersion") + return &MutableArrayNodeStatus_GetTaskPhaseVersion{Call: c_call} +} + +func (_m *MutableArrayNodeStatus) OnGetTaskPhaseVersionMatch(matchers ...interface{}) *MutableArrayNodeStatus_GetTaskPhaseVersion { + c_call := _m.On("GetTaskPhaseVersion", matchers...) + return &MutableArrayNodeStatus_GetTaskPhaseVersion{Call: c_call} +} + +// GetTaskPhaseVersion provides a mock function with given fields: +func (_m *MutableArrayNodeStatus) GetTaskPhaseVersion() uint32 { + ret := _m.Called() + + var r0 uint32 + if rf, ok := ret.Get(0).(func() uint32); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(uint32) + } + + return r0 +} + +type MutableArrayNodeStatus_IsDirty struct { + *mock.Call +} + +func (_m MutableArrayNodeStatus_IsDirty) Return(_a0 bool) *MutableArrayNodeStatus_IsDirty { + return &MutableArrayNodeStatus_IsDirty{Call: _m.Call.Return(_a0)} +} + +func (_m *MutableArrayNodeStatus) OnIsDirty() *MutableArrayNodeStatus_IsDirty { + c_call := _m.On("IsDirty") + return &MutableArrayNodeStatus_IsDirty{Call: c_call} +} + +func (_m *MutableArrayNodeStatus) OnIsDirtyMatch(matchers ...interface{}) *MutableArrayNodeStatus_IsDirty { + c_call := _m.On("IsDirty", matchers...) 
+ return &MutableArrayNodeStatus_IsDirty{Call: c_call} +} + +// IsDirty provides a mock function with given fields: +func (_m *MutableArrayNodeStatus) IsDirty() bool { + ret := _m.Called() + + var r0 bool + if rf, ok := ret.Get(0).(func() bool); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + +// SetArrayNodePhase provides a mock function with given fields: phase +func (_m *MutableArrayNodeStatus) SetArrayNodePhase(phase v1alpha1.ArrayNodePhase) { + _m.Called(phase) +} + +// SetExecutionError provides a mock function with given fields: executionError +func (_m *MutableArrayNodeStatus) SetExecutionError(executionError *core.ExecutionError) { + _m.Called(executionError) +} + +// SetSubNodePhases provides a mock function with given fields: subNodePhases +func (_m *MutableArrayNodeStatus) SetSubNodePhases(subNodePhases bitarray.CompactArray) { + _m.Called(subNodePhases) +} + +// SetSubNodeRetryAttempts provides a mock function with given fields: subNodeRetryAttempts +func (_m *MutableArrayNodeStatus) SetSubNodeRetryAttempts(subNodeRetryAttempts bitarray.CompactArray) { + _m.Called(subNodeRetryAttempts) +} + +// SetSubNodeSystemFailures provides a mock function with given fields: subNodeSystemFailures +func (_m *MutableArrayNodeStatus) SetSubNodeSystemFailures(subNodeSystemFailures bitarray.CompactArray) { + _m.Called(subNodeSystemFailures) +} + +// SetSubNodeTaskPhases provides a mock function with given fields: subNodeTaskPhases +func (_m *MutableArrayNodeStatus) SetSubNodeTaskPhases(subNodeTaskPhases bitarray.CompactArray) { + _m.Called(subNodeTaskPhases) +} + +// SetTaskPhaseVersion provides a mock function with given fields: taskPhaseVersion +func (_m *MutableArrayNodeStatus) SetTaskPhaseVersion(taskPhaseVersion uint32) { + _m.Called(taskPhaseVersion) +} diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableNodeStatus.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableNodeStatus.go index 9bb0f59b2e..56feb9c1bb 100644 --- a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableNodeStatus.go +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableNodeStatus.go @@ -18,6 +18,11 @@ type MutableNodeStatus struct { mock.Mock } +// ClearArrayNodeStatus provides a mock function with given fields: +func (_m *MutableNodeStatus) ClearArrayNodeStatus() { + _m.Called() +} + // ClearDynamicNodeStatus provides a mock function with given fields: func (_m *MutableNodeStatus) ClearDynamicNodeStatus() { _m.Called() @@ -48,6 +53,40 @@ func (_m *MutableNodeStatus) ClearWorkflowStatus() { _m.Called() } +type MutableNodeStatus_GetArrayNodeStatus struct { + *mock.Call +} + +func (_m MutableNodeStatus_GetArrayNodeStatus) Return(_a0 v1alpha1.MutableArrayNodeStatus) *MutableNodeStatus_GetArrayNodeStatus { + return &MutableNodeStatus_GetArrayNodeStatus{Call: _m.Call.Return(_a0)} +} + +func (_m *MutableNodeStatus) OnGetArrayNodeStatus() *MutableNodeStatus_GetArrayNodeStatus { + c_call := _m.On("GetArrayNodeStatus") + return &MutableNodeStatus_GetArrayNodeStatus{Call: c_call} +} + +func (_m *MutableNodeStatus) OnGetArrayNodeStatusMatch(matchers ...interface{}) *MutableNodeStatus_GetArrayNodeStatus { + c_call := _m.On("GetArrayNodeStatus", matchers...) 
+ return &MutableNodeStatus_GetArrayNodeStatus{Call: c_call} +} + +// GetArrayNodeStatus provides a mock function with given fields: +func (_m *MutableNodeStatus) GetArrayNodeStatus() v1alpha1.MutableArrayNodeStatus { + ret := _m.Called() + + var r0 v1alpha1.MutableArrayNodeStatus + if rf, ok := ret.Get(0).(func() v1alpha1.MutableArrayNodeStatus); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.MutableArrayNodeStatus) + } + } + + return r0 +} + type MutableNodeStatus_GetBranchStatus struct { *mock.Call } @@ -150,6 +189,40 @@ func (_m *MutableNodeStatus) GetGateNodeStatus() v1alpha1.MutableGateNodeStatus return r0 } +type MutableNodeStatus_GetOrCreateArrayNodeStatus struct { + *mock.Call +} + +func (_m MutableNodeStatus_GetOrCreateArrayNodeStatus) Return(_a0 v1alpha1.MutableArrayNodeStatus) *MutableNodeStatus_GetOrCreateArrayNodeStatus { + return &MutableNodeStatus_GetOrCreateArrayNodeStatus{Call: _m.Call.Return(_a0)} +} + +func (_m *MutableNodeStatus) OnGetOrCreateArrayNodeStatus() *MutableNodeStatus_GetOrCreateArrayNodeStatus { + c_call := _m.On("GetOrCreateArrayNodeStatus") + return &MutableNodeStatus_GetOrCreateArrayNodeStatus{Call: c_call} +} + +func (_m *MutableNodeStatus) OnGetOrCreateArrayNodeStatusMatch(matchers ...interface{}) *MutableNodeStatus_GetOrCreateArrayNodeStatus { + c_call := _m.On("GetOrCreateArrayNodeStatus", matchers...) + return &MutableNodeStatus_GetOrCreateArrayNodeStatus{Call: c_call} +} + +// GetOrCreateArrayNodeStatus provides a mock function with given fields: +func (_m *MutableNodeStatus) GetOrCreateArrayNodeStatus() v1alpha1.MutableArrayNodeStatus { + ret := _m.Called() + + var r0 v1alpha1.MutableArrayNodeStatus + if rf, ok := ret.Get(0).(func() v1alpha1.MutableArrayNodeStatus); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.MutableArrayNodeStatus) + } + } + + return r0 +} + type MutableNodeStatus_GetOrCreateBranchStatus struct { *mock.Call } diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/node_status.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/node_status.go index 14efd4f401..07787d6e77 100644 --- a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/node_status.go +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/node_status.go @@ -9,11 +9,11 @@ import ( "strconv" "time" - "github.com/flyteorg/flytestdlib/storage" - + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + "github.com/flyteorg/flytestdlib/bitarray" "github.com/flyteorg/flytestdlib/logger" + "github.com/flyteorg/flytestdlib/storage" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" ) @@ -219,6 +219,103 @@ func (in *GateNodeStatus) SetGateNodePhase(phase GateNodePhase) { } } +type ArrayNodePhase int + +const ( + ArrayNodePhaseNone ArrayNodePhase = iota + ArrayNodePhaseExecuting + ArrayNodePhaseFailing + ArrayNodePhaseSucceeding +) + +type ArrayNodeStatus struct { + MutableStruct + Phase ArrayNodePhase `json:"phase,omitempty"` + ExecutionError *core.ExecutionError `json:"executionError,omitempty"` + SubNodePhases bitarray.CompactArray `json:"subphase,omitempty"` + SubNodeTaskPhases bitarray.CompactArray `json:"subtphase,omitempty"` + SubNodeRetryAttempts bitarray.CompactArray `json:"subattempts,omitempty"` + SubNodeSystemFailures bitarray.CompactArray `json:"subsysfailures,omitempty"` + TaskPhaseVersion uint32 `json:"taskPhaseVersion,omitempty"` +} + +func (in *ArrayNodeStatus) GetArrayNodePhase() ArrayNodePhase { + return in.Phase +} + +func (in 
*ArrayNodeStatus) SetArrayNodePhase(phase ArrayNodePhase) { + if in.Phase != phase { + in.SetDirty() + in.Phase = phase + } +} + +func (in *ArrayNodeStatus) GetExecutionError() *core.ExecutionError { + return in.ExecutionError +} + +func (in *ArrayNodeStatus) SetExecutionError(executionError *core.ExecutionError) { + if in.ExecutionError != executionError { + in.SetDirty() + in.ExecutionError = executionError + } +} + +func (in *ArrayNodeStatus) GetSubNodePhases() bitarray.CompactArray { + return in.SubNodePhases +} + +func (in *ArrayNodeStatus) SetSubNodePhases(subNodePhases bitarray.CompactArray) { + if in.SubNodePhases != subNodePhases { + in.SetDirty() + in.SubNodePhases = subNodePhases + } +} + +func (in *ArrayNodeStatus) GetSubNodeTaskPhases() bitarray.CompactArray { + return in.SubNodeTaskPhases +} + +func (in *ArrayNodeStatus) SetSubNodeTaskPhases(subNodeTaskPhases bitarray.CompactArray) { + if in.SubNodeTaskPhases != subNodeTaskPhases { + in.SetDirty() + in.SubNodeTaskPhases = subNodeTaskPhases + } +} + +func (in *ArrayNodeStatus) GetSubNodeRetryAttempts() bitarray.CompactArray { + return in.SubNodeRetryAttempts +} + +func (in *ArrayNodeStatus) SetSubNodeRetryAttempts(subNodeRetryAttempts bitarray.CompactArray) { + if in.SubNodeRetryAttempts != subNodeRetryAttempts { + in.SetDirty() + in.SubNodeRetryAttempts = subNodeRetryAttempts + } +} + +func (in *ArrayNodeStatus) GetSubNodeSystemFailures() bitarray.CompactArray { + return in.SubNodeSystemFailures +} + +func (in *ArrayNodeStatus) SetSubNodeSystemFailures(subNodeSystemFailures bitarray.CompactArray) { + if in.SubNodeSystemFailures != subNodeSystemFailures { + in.SetDirty() + in.SubNodeSystemFailures = subNodeSystemFailures + } +} + +func (in *ArrayNodeStatus) GetTaskPhaseVersion() uint32 { + return in.TaskPhaseVersion +} + +func (in *ArrayNodeStatus) SetTaskPhaseVersion(taskPhaseVersion uint32) { + if in.TaskPhaseVersion != taskPhaseVersion { + in.SetDirty() + in.TaskPhaseVersion = taskPhaseVersion + } +} + type NodeStatus struct { MutableStruct Phase NodePhase `json:"phase,omitempty"` @@ -247,6 +344,7 @@ type NodeStatus struct { TaskNodeStatus *TaskNodeStatus `json:",omitempty"` DynamicNodeStatus *DynamicNodeStatus `json:"dynamicNodeStatus,omitempty"` GateNodeStatus *GateNodeStatus `json:"gateNodeStatus,omitempty"` + ArrayNodeStatus *ArrayNodeStatus `json:"arrayNodeStatus,omitempty"` // In case of Failing/Failed Phase, an execution error can be optionally associated with the Node Error *ExecutionError `json:"error,omitempty"` @@ -259,7 +357,9 @@ func (in *NodeStatus) IsDirty() bool { (in.TaskNodeStatus != nil && in.TaskNodeStatus.IsDirty()) || (in.DynamicNodeStatus != nil && in.DynamicNodeStatus.IsDirty()) || (in.WorkflowNodeStatus != nil && in.WorkflowNodeStatus.IsDirty()) || - (in.BranchStatus != nil && in.BranchStatus.IsDirty()) + (in.BranchStatus != nil && in.BranchStatus.IsDirty()) || + (in.GateNodeStatus != nil && in.GateNodeStatus.IsDirty()) || + (in.ArrayNodeStatus != nil && in.ArrayNodeStatus.IsDirty()) if isDirty { return true } @@ -327,6 +427,13 @@ func (in *NodeStatus) GetGateNodeStatus() MutableGateNodeStatus { return in.GateNodeStatus } +func (in *NodeStatus) GetArrayNodeStatus() MutableArrayNodeStatus { + if in.ArrayNodeStatus == nil { + return nil + } + return in.ArrayNodeStatus +} + func (in NodeStatus) VisitNodeStatuses(visitor NodeStatusVisitFn) { for n, s := range in.SubNodeStatus { visitor(n, s) @@ -365,6 +472,11 @@ func (in *NodeStatus) ClearGateNodeStatus() { in.SetDirty() } +func (in *NodeStatus) 
ClearArrayNodeStatus() { + in.ArrayNodeStatus = nil + in.SetDirty() +} + func (in *NodeStatus) GetLastUpdatedAt() *metav1.Time { return in.LastUpdatedAt } @@ -471,6 +583,17 @@ func (in *NodeStatus) GetOrCreateGateNodeStatus() MutableGateNodeStatus { return in.GateNodeStatus } +func (in *NodeStatus) GetOrCreateArrayNodeStatus() MutableArrayNodeStatus { + if in.ArrayNodeStatus == nil { + in.SetDirty() + in.ArrayNodeStatus = &ArrayNodeStatus{ + MutableStruct: MutableStruct{}, + } + } + + return in.ArrayNodeStatus +} + func (in *NodeStatus) UpdatePhase(p NodePhase, occurredAt metav1.Time, reason string, err *core.ExecutionError) { if in.Phase == p { // We will not update the phase multiple times. This prevents the comparison from returning false positive diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/nodes.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/nodes.go index 682af365d8..21c8b02610 100644 --- a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/nodes.go +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/nodes.go @@ -101,6 +101,7 @@ type NodeSpec struct { TaskRef *TaskID `json:"task,omitempty"` WorkflowNode *WorkflowNodeSpec `json:"workflow,omitempty"` GateNode *GateNodeSpec `json:"gate,omitempty"` + ArrayNode *ArrayNodeSpec `json:"array,omitempty"` InputBindings []*Binding `json:"inputBindings,omitempty"` Config *typesv1.ConfigMap `json:"config,omitempty"` RetryStrategy *RetryStrategy `json:"retry,omitempty"` @@ -206,6 +207,13 @@ func (in *NodeSpec) GetGateNode() ExecutableGateNode { return in.GateNode } +func (in *NodeSpec) GetArrayNode() ExecutableArrayNode { + if in.ArrayNode == nil { + return nil + } + return in.ArrayNode +} + func (in *NodeSpec) GetTaskID() *TaskID { return in.TaskRef } diff --git a/flytepropeller/pkg/compiler/common/mocks/node.go b/flytepropeller/pkg/compiler/common/mocks/node.go index 364a1921dc..ea9a24df16 100644 --- a/flytepropeller/pkg/compiler/common/mocks/node.go +++ b/flytepropeller/pkg/compiler/common/mocks/node.go @@ -14,6 +14,40 @@ type Node struct { mock.Mock } +type Node_GetArrayNode struct { + *mock.Call +} + +func (_m Node_GetArrayNode) Return(_a0 *core.ArrayNode) *Node_GetArrayNode { + return &Node_GetArrayNode{Call: _m.Call.Return(_a0)} +} + +func (_m *Node) OnGetArrayNode() *Node_GetArrayNode { + c_call := _m.On("GetArrayNode") + return &Node_GetArrayNode{Call: c_call} +} + +func (_m *Node) OnGetArrayNodeMatch(matchers ...interface{}) *Node_GetArrayNode { + c_call := _m.On("GetArrayNode", matchers...) 
+ return &Node_GetArrayNode{Call: c_call} +} + +// GetArrayNode provides a mock function with given fields: +func (_m *Node) GetArrayNode() *core.ArrayNode { + ret := _m.Called() + + var r0 *core.ArrayNode + if rf, ok := ret.Get(0).(func() *core.ArrayNode); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*core.ArrayNode) + } + } + + return r0 +} + type Node_GetBranchNode struct { *mock.Call } diff --git a/flytepropeller/pkg/compiler/common/mocks/node_builder.go b/flytepropeller/pkg/compiler/common/mocks/node_builder.go index 44b320dc9e..9ab7501306 100644 --- a/flytepropeller/pkg/compiler/common/mocks/node_builder.go +++ b/flytepropeller/pkg/compiler/common/mocks/node_builder.go @@ -14,6 +14,40 @@ type NodeBuilder struct { mock.Mock } +type NodeBuilder_GetArrayNode struct { + *mock.Call +} + +func (_m NodeBuilder_GetArrayNode) Return(_a0 *core.ArrayNode) *NodeBuilder_GetArrayNode { + return &NodeBuilder_GetArrayNode{Call: _m.Call.Return(_a0)} +} + +func (_m *NodeBuilder) OnGetArrayNode() *NodeBuilder_GetArrayNode { + c_call := _m.On("GetArrayNode") + return &NodeBuilder_GetArrayNode{Call: c_call} +} + +func (_m *NodeBuilder) OnGetArrayNodeMatch(matchers ...interface{}) *NodeBuilder_GetArrayNode { + c_call := _m.On("GetArrayNode", matchers...) + return &NodeBuilder_GetArrayNode{Call: c_call} +} + +// GetArrayNode provides a mock function with given fields: +func (_m *NodeBuilder) GetArrayNode() *core.ArrayNode { + ret := _m.Called() + + var r0 *core.ArrayNode + if rf, ok := ret.Get(0).(func() *core.ArrayNode); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*core.ArrayNode) + } + } + + return r0 +} + type NodeBuilder_GetBranchNode struct { *mock.Call } diff --git a/flytepropeller/pkg/compiler/common/reader.go b/flytepropeller/pkg/compiler/common/reader.go index d0ea361724..74ee17d40b 100644 --- a/flytepropeller/pkg/compiler/common/reader.go +++ b/flytepropeller/pkg/compiler/common/reader.go @@ -41,6 +41,7 @@ type Node interface { GetTask() Task GetSubWorkflow() Workflow GetGateNode() *core.GateNode + GetArrayNode() *core.ArrayNode } // An immutable task that represents the final output of the compiler. 
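With GetArrayNode now on the compiler-facing Node interface, every pass that walks a workflow graph must also descend into the wrapped sub-node; the requirements.go hunk below applies exactly that pattern. As a hedged sketch of the traversal shape (the function name is illustrative, and it assumes core is the flyteidl core package imported elsewhere in this diff):

// Illustrative only — not part of this diff. An ArrayNode wraps exactly one
// sub-node that is fanned out at runtime, so graph walks recurse into it.
func visitWithArrayNodes(node *core.Node, visit func(*core.Node)) {
	visit(node)
	if arrayNode := node.GetArrayNode(); arrayNode != nil {
		visitWithArrayNodes(arrayNode.Node, visit)
	}
}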
diff --git a/flytepropeller/pkg/compiler/requirements.go b/flytepropeller/pkg/compiler/requirements.go
index ab1b11a05c..b9f589ad79 100755
--- a/flytepropeller/pkg/compiler/requirements.go
+++ b/flytepropeller/pkg/compiler/requirements.go
@@ -86,5 +86,7 @@ func updateNodeRequirements(node *flyteNode, subWfs common.WorkflowIndex, taskId
 		if elseNode := branchN.IfElse.GetElseNode(); elseNode != nil {
 			updateNodeRequirements(elseNode, subWfs, taskIds, workflowIds, followSubworkflows, errs)
 		}
+	} else if arrayNode := node.GetArrayNode(); arrayNode != nil {
+		updateNodeRequirements(arrayNode.Node, subWfs, taskIds, workflowIds, followSubworkflows, errs)
 	}
 }
diff --git a/flytepropeller/pkg/compiler/transformers/k8s/node.go b/flytepropeller/pkg/compiler/transformers/k8s/node.go
index d9ac41dd49..9bc0f608ed 100644
--- a/flytepropeller/pkg/compiler/transformers/k8s/node.go
+++ b/flytepropeller/pkg/compiler/transformers/k8s/node.go
@@ -152,6 +152,28 @@ func buildNodeSpec(n *core.Node, tasks []*core.CompiledTask, errs errors.Compile
 			},
 		}
 	}
+	case *core.Node_ArrayNode:
+		arrayNode := n.GetArrayNode()
+
+		// build subNodeSpecs
+		subNodeSpecs, ok := buildNodeSpec(arrayNode.Node, tasks, errs)
+		if !ok {
+			return nil, ok
+		}
+
+		// build ArrayNode
+		nodeSpec.Kind = v1alpha1.NodeKindArray
+		nodeSpec.ArrayNode = &v1alpha1.ArrayNodeSpec{
+			SubNodeSpec: subNodeSpecs[0],
+			Parallelism: arrayNode.Parallelism,
+		}
+
+		switch successCriteria := arrayNode.SuccessCriteria.(type) {
+		case *core.ArrayNode_MinSuccesses:
+			nodeSpec.ArrayNode.MinSuccesses = &successCriteria.MinSuccesses
+		case *core.ArrayNode_MinSuccessRatio:
+			nodeSpec.ArrayNode.MinSuccessRatio = &successCriteria.MinSuccessRatio
+		}
 	default:
 		if n.GetId() == v1alpha1.StartNodeID {
 			nodeSpec.Kind = v1alpha1.NodeKindStart
diff --git a/flytepropeller/pkg/compiler/transformers/k8s/node_test.go b/flytepropeller/pkg/compiler/transformers/k8s/node_test.go
index a9732d9d70..e879f2cb31 100644
--- a/flytepropeller/pkg/compiler/transformers/k8s/node_test.go
+++ b/flytepropeller/pkg/compiler/transformers/k8s/node_test.go
@@ -243,6 +243,29 @@ func TestBuildNodeSpec(t *testing.T) {
 		mustBuild(t, n, 1, errs.NewScope())
 	})
+
+	t.Run("ArrayNode", func(t *testing.T) {
+		n.Node.Target = &core.Node_ArrayNode{
+			ArrayNode: &core.ArrayNode{
+				Node: &core.Node{
+					Id: "foo",
+					Target: &core.Node_TaskNode{
+						TaskNode: &core.TaskNode{
+							Reference: &core.TaskNode_ReferenceId{
+								ReferenceId: &core.Identifier{Name: "ref_1"},
+							},
+						},
+					},
+				},
+				Parallelism: 10,
+				SuccessCriteria: &core.ArrayNode_MinSuccessRatio{
+					MinSuccessRatio: 0.5,
+				},
+			},
+		}
+
+		mustBuild(t, n, 1, errs.NewScope())
+	})
 }
 
 func TestBuildTasks(t *testing.T) {
diff --git a/flytepropeller/pkg/compiler/validators/interface.go b/flytepropeller/pkg/compiler/validators/interface.go
index 1ae7ecd5b0..cdee66a45f 100644
--- a/flytepropeller/pkg/compiler/validators/interface.go
+++ b/flytepropeller/pkg/compiler/validators/interface.go
@@ -153,6 +153,14 @@ func ValidateUnderlyingInterface(w c.WorkflowBuilder, node c.NodeBuilder, errs e
 		} else {
 			errs.Collect(errors.NewNoConditionFound(node.GetId()))
 		}
+	case *core.Node_ArrayNode:
+		arrayNode := node.GetArrayNode()
+		underlyingNodeBuilder := w.GetOrCreateNodeBuilder(arrayNode.Node)
+		if underlyingIface, ok := ValidateUnderlyingInterface(w, underlyingNodeBuilder, errs.NewScope()); ok {
+			// ArrayNode interface should be inferred from the underlying node interface. flytekit
+			// will correctly wrap variables in collections as needed, leaving partials as is.
+ iface = underlyingIface + } default: errs.Collect(errors.NewValueRequiredErr(node.GetId(), "Target")) } diff --git a/flytepropeller/pkg/compiler/validators/interface_test.go b/flytepropeller/pkg/compiler/validators/interface_test.go index 5580bd5c62..9a2183ebf4 100644 --- a/flytepropeller/pkg/compiler/validators/interface_test.go +++ b/flytepropeller/pkg/compiler/validators/interface_test.go @@ -1,6 +1,7 @@ package validators import ( + "reflect" "testing" "time" @@ -374,6 +375,90 @@ func TestValidateUnderlyingInterface(t *testing.T) { assertNonEmptyInterface(t, iface, ifaceOk, errs) }) }) + + t.Run("ArrayNode", func(t *testing.T) { + // mock underlying task node + iface := &core.TypedInterface{ + Inputs: &core.VariableMap{ + Variables: map[string]*core.Variable{ + "foo": { + Type: &core.LiteralType{ + Type: &core.LiteralType_Simple{ + Simple: core.SimpleType_INTEGER, + }, + }, + }, + }, + }, + Outputs: &core.VariableMap{ + Variables: map[string]*core.Variable{ + "bar": { + Type: &core.LiteralType{ + Type: &core.LiteralType_Simple{ + Simple: core.SimpleType_FLOAT, + }, + }, + }, + }, + }, + } + + taskNode := &core.Node{ + Id: "node_1", + Target: &core.Node_TaskNode{ + TaskNode: &core.TaskNode{ + Reference: &core.TaskNode_ReferenceId{ + ReferenceId: &core.Identifier{ + Name: "Task_1", + }, + }, + }, + }, + } + + task := mocks.Task{} + task.On("GetInterface").Return(iface) + + taskNodeBuilder := &mocks.NodeBuilder{} + taskNodeBuilder.On("GetCoreNode").Return(taskNode) + taskNodeBuilder.On("GetId").Return(taskNode.Id) + taskNodeBuilder.On("GetTaskNode").Return(taskNode.Target.(*core.Node_TaskNode).TaskNode) + taskNodeBuilder.On("GetInterface").Return(nil) + taskNodeBuilder.On("SetInterface", mock.AnythingOfType("*core.TypedInterface")).Return(nil) + + wfBuilder := mocks.WorkflowBuilder{} + wfBuilder.On("GetTask", mock.MatchedBy(func(id core.Identifier) bool { + return id.String() == (&core.Identifier{ + Name: "Task_1", + }).String() + })).Return(&task, true) + wfBuilder.On("GetOrCreateNodeBuilder", mock.MatchedBy(func(node *core.Node) bool { + return node.Id == "node_1" + })).Return(taskNodeBuilder) + + // mock array node + arrayNode := &core.Node{ + Id: "node_2", + Target: &core.Node_ArrayNode{ + ArrayNode: &core.ArrayNode{ + Node: taskNode, + }, + }, + } + + nodeBuilder := mocks.NodeBuilder{} + nodeBuilder.On("GetArrayNode").Return(arrayNode.Target.(*core.Node_ArrayNode).ArrayNode) + nodeBuilder.On("GetCoreNode").Return(arrayNode) + nodeBuilder.On("GetId").Return(arrayNode.Id) + nodeBuilder.On("GetInterface").Return(nil) + nodeBuilder.On("SetInterface", mock.Anything).Return() + + // compute arrayNode interface + errs := errors.NewCompileErrors() + arrayNodeIface, ifaceOk := ValidateUnderlyingInterface(&wfBuilder, &nodeBuilder, errs.NewScope()) + assertNonEmptyInterface(t, arrayNodeIface, ifaceOk, errs) + assert.True(t, reflect.DeepEqual(arrayNodeIface, iface)) + }) } func matchIdentifier(id core.Identifier) interface{} { diff --git a/flytepropeller/pkg/controller/controller.go b/flytepropeller/pkg/controller/controller.go index d7b8d9c28b..57dcb9ba6f 100644 --- a/flytepropeller/pkg/controller/controller.go +++ b/flytepropeller/pkg/controller/controller.go @@ -26,6 +26,7 @@ import ( "github.com/flyteorg/flytepropeller/pkg/controller/executors" "github.com/flyteorg/flytepropeller/pkg/controller/nodes" errors3 "github.com/flyteorg/flytepropeller/pkg/controller/nodes/errors" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/factory" 
"github.com/flyteorg/flytepropeller/pkg/controller/nodes/recovery" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/subworkflow/launchplan" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/catalog" @@ -437,9 +438,16 @@ func New(ctx context.Context, cfg *config.Config, kubeclientset kubernetes.Inter controller.levelMonitor = NewResourceLevelMonitor(scope.NewSubScope("collector"), flyteworkflowInformer.Lister()) + recoveryClient := recovery.NewClient(adminClient) + nodeHandlerFactory, err := factory.NewHandlerFactory(ctx, launchPlanActor, launchPlanActor, + kubeClient, catalogClient, recoveryClient, &cfg.EventConfig, cfg.ClusterID, signalClient, scope) + if err != nil { + return nil, errors.Wrapf(err, "failed to create node handler factory") + } + nodeExecutor, err := nodes.NewExecutor(ctx, cfg.NodeConfig, store, controller.enqueueWorkflowForNodeUpdates, eventSink, - launchPlanActor, launchPlanActor, cfg.MaxDatasetSizeBytes, - storage.DataReference(cfg.DefaultRawOutputPrefix), kubeClient, catalogClient, recovery.NewClient(adminClient), &cfg.EventConfig, cfg.ClusterID, signalClient, scope) + launchPlanActor, launchPlanActor, cfg.MaxDatasetSizeBytes, storage.DataReference(cfg.DefaultRawOutputPrefix), kubeClient, + catalogClient, recoveryClient, &cfg.EventConfig, cfg.ClusterID, signalClient, nodeHandlerFactory, scope) if err != nil { return nil, errors.Wrapf(err, "Failed to create Controller.") } diff --git a/flytepropeller/pkg/controller/executors/node_lookup.go b/flytepropeller/pkg/controller/executors/node_lookup.go index 381b832c0e..66fc9bddf7 100644 --- a/flytepropeller/pkg/controller/executors/node_lookup.go +++ b/flytepropeller/pkg/controller/executors/node_lookup.go @@ -12,6 +12,7 @@ import ( type NodeLookup interface { GetNode(nodeID v1alpha1.NodeID) (v1alpha1.ExecutableNode, bool) GetNodeExecutionStatus(ctx context.Context, id v1alpha1.NodeID) v1alpha1.ExecutableNodeStatus + // Lookup for upstream edges, find all node ids from which this node can be reached. ToNode(id v1alpha1.NodeID) ([]v1alpha1.NodeID, error) // Lookup for downstream edges, find all node ids that can be reached from the given node id. 
diff --git a/flytepropeller/pkg/controller/nodes/array/execution_context.go b/flytepropeller/pkg/controller/nodes/array/execution_context.go new file mode 100644 index 0000000000..6731ec55b1 --- /dev/null +++ b/flytepropeller/pkg/controller/nodes/array/execution_context.go @@ -0,0 +1,48 @@ +package array + +import ( + "strconv" + + "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/flyteorg/flytepropeller/pkg/controller/executors" +) + +const ( + FlyteK8sArrayIndexVarName string = "FLYTE_K8S_ARRAY_INDEX" + JobIndexVarName string = "BATCH_JOB_ARRAY_INDEX_VAR_NAME" +) + +type arrayExecutionContext struct { + executors.ExecutionContext + executionConfig v1alpha1.ExecutionConfig + currentParallelism *uint32 +} + +func (a *arrayExecutionContext) GetExecutionConfig() v1alpha1.ExecutionConfig { + return a.executionConfig +} + +func (a *arrayExecutionContext) CurrentParallelism() uint32 { + return *a.currentParallelism +} + +func (a *arrayExecutionContext) IncrementParallelism() uint32 { + *a.currentParallelism = *a.currentParallelism + 1 + return *a.currentParallelism +} + +func newArrayExecutionContext(executionContext executors.ExecutionContext, subNodeIndex int, currentParallelism *uint32, maxParallelism uint32) *arrayExecutionContext { + executionConfig := executionContext.GetExecutionConfig() + if executionConfig.EnvironmentVariables == nil { + executionConfig.EnvironmentVariables = make(map[string]string) + } + executionConfig.EnvironmentVariables[JobIndexVarName] = FlyteK8sArrayIndexVarName + executionConfig.EnvironmentVariables[FlyteK8sArrayIndexVarName] = strconv.Itoa(subNodeIndex) + executionConfig.MaxParallelism = maxParallelism + + return &arrayExecutionContext{ + ExecutionContext: executionContext, + executionConfig: executionConfig, + currentParallelism: currentParallelism, + } +} diff --git a/flytepropeller/pkg/controller/nodes/array/handler.go b/flytepropeller/pkg/controller/nodes/array/handler.go new file mode 100644 index 0000000000..19641cb93b --- /dev/null +++ b/flytepropeller/pkg/controller/nodes/array/handler.go @@ -0,0 +1,598 @@ +package array + +import ( + "context" + "fmt" + "math" + "strconv" + + idlcore "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" + + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/ioutils" + "github.com/flyteorg/flyteplugins/go/tasks/plugins/array/errorcollector" + + "github.com/flyteorg/flytepropeller/events" + "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/flyteorg/flytepropeller/pkg/compiler/validators" + "github.com/flyteorg/flytepropeller/pkg/controller/config" + "github.com/flyteorg/flytepropeller/pkg/controller/executors" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/errors" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/k8s" + + "github.com/flyteorg/flytestdlib/bitarray" + "github.com/flyteorg/flytestdlib/logger" + "github.com/flyteorg/flytestdlib/promutils" + "github.com/flyteorg/flytestdlib/storage" +) + +var ( + nilLiteral = &idlcore.Literal{ + Value: &idlcore.Literal_Scalar{ + Scalar: &idlcore.Scalar{ + Value: &idlcore.Scalar_NoneType{ + NoneType: &idlcore.Void{}, + }, + }, + }, + } +) + 
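newArrayExecutionContext wires a two-level environment variable indirection: BATCH_JOB_ARRAY_INDEX_VAR_NAME names the variable that carries the index, and FLYTE_K8S_ARRAY_INDEX carries the subNodeIndex itself. A minimal sketch of how a task container could resolve its index under that scheme; whether flytekit resolves it in exactly this way is an assumption:

package main

import (
	"fmt"
	"os"
	"strconv"
)

func main() {
	// Simulate the environment newArrayExecutionContext injects for subNodeIndex 7.
	os.Setenv("BATCH_JOB_ARRAY_INDEX_VAR_NAME", "FLYTE_K8S_ARRAY_INDEX")
	os.Setenv("FLYTE_K8S_ARRAY_INDEX", strconv.Itoa(7))

	// First look up which variable carries the index, then read the index from it.
	indexVar := os.Getenv("BATCH_JOB_ARRAY_INDEX_VAR_NAME")
	index, err := strconv.Atoi(os.Getenv(indexVar))
	if err != nil {
		panic(err)
	}
	fmt.Println(index) // 7
}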
+//go:generate mockery -all -case=underscore + +// arrayNodeHandler is a handler implementation for processing array nodes +type arrayNodeHandler struct { + eventConfig *config.EventConfig + metrics metrics + nodeExecutor interfaces.Node + pluginStateBytesNotStarted []byte + pluginStateBytesStarted []byte +} + +// metrics encapsulates the prometheus metrics for this handler +type metrics struct { + scope promutils.Scope +} + +// newMetrics initializes a new metrics struct +func newMetrics(scope promutils.Scope) metrics { + return metrics{ + scope: scope, + } +} + +// Abort stops the array node defined in the NodeExecutionContext +func (a *arrayNodeHandler) Abort(ctx context.Context, nCtx interfaces.NodeExecutionContext, reason string) error { + arrayNode := nCtx.Node().GetArrayNode() + arrayNodeState := nCtx.NodeStateReader().GetArrayNodeState() + + externalResources := make([]*event.ExternalResourceInfo, 0, len(arrayNodeState.SubNodePhases.GetItems())) + messageCollector := errorcollector.NewErrorMessageCollector() + switch arrayNodeState.Phase { + case v1alpha1.ArrayNodePhaseExecuting, v1alpha1.ArrayNodePhaseFailing: + currentParallelism := uint32(0) + for i, nodePhaseUint64 := range arrayNodeState.SubNodePhases.GetItems() { + nodePhase := v1alpha1.NodePhase(nodePhaseUint64) + + // do not process nodes that have not started or are in a terminal state + if nodePhase == v1alpha1.NodePhaseNotYetStarted || isTerminalNodePhase(nodePhase) { + continue + } + + // create array contexts + arrayNodeExecutor, arrayExecutionContext, arrayDAGStructure, arrayNodeLookup, subNodeSpec, _, _, err := + a.buildArrayNodeContext(ctx, nCtx, &arrayNodeState, arrayNode, i, &currentParallelism) + if err != nil { + return err + } + + // abort subNode + err = arrayNodeExecutor.AbortHandler(ctx, arrayExecutionContext, arrayDAGStructure, arrayNodeLookup, subNodeSpec, reason) + if err != nil { + messageCollector.Collect(i, err.Error()) + } else { + externalResources = append(externalResources, &event.ExternalResourceInfo{ + ExternalId: buildSubNodeID(nCtx, i, 0), + Index: uint32(i), + Logs: nil, + RetryAttempt: 0, + Phase: idlcore.TaskExecution_ABORTED, + }) + } + } + } + + if messageCollector.Length() > 0 { + return fmt.Errorf(messageCollector.Summary(events.MaxErrorMessageLength)) + } + + // update aborted state for subNodes + taskExecutionEvent, err := buildTaskExecutionEvent(ctx, nCtx, idlcore.TaskExecution_ABORTED, 0, externalResources) + if err != nil { + return err + } + + if err := nCtx.EventsRecorder().RecordTaskEvent(ctx, taskExecutionEvent, a.eventConfig); err != nil { + logger.Errorf(ctx, "ArrayNode event recording failed: [%s]", err.Error()) + return err + } + + return nil +} + +// Finalize completes the array node defined in the NodeExecutionContext +func (a *arrayNodeHandler) Finalize(ctx context.Context, nCtx interfaces.NodeExecutionContext) error { + arrayNode := nCtx.Node().GetArrayNode() + arrayNodeState := nCtx.NodeStateReader().GetArrayNodeState() + + messageCollector := errorcollector.NewErrorMessageCollector() + switch arrayNodeState.Phase { + case v1alpha1.ArrayNodePhaseExecuting, v1alpha1.ArrayNodePhaseFailing, v1alpha1.ArrayNodePhaseSucceeding: + currentParallelism := uint32(0) + for i, nodePhaseUint64 := range arrayNodeState.SubNodePhases.GetItems() { + nodePhase := v1alpha1.NodePhase(nodePhaseUint64) + + // do not process nodes that have not started or are in a terminal state + if nodePhase == v1alpha1.NodePhaseNotYetStarted || isTerminalNodePhase(nodePhase) { + continue + } + + // create array contexts + arrayNodeExecutor, arrayExecutionContext, arrayDAGStructure, arrayNodeLookup, subNodeSpec, _, _, err := + a.buildArrayNodeContext(ctx, nCtx, &arrayNodeState, arrayNode, i, &currentParallelism) + if err != nil { + return err + } + + // finalize subNode + err = arrayNodeExecutor.FinalizeHandler(ctx, arrayExecutionContext, arrayDAGStructure, arrayNodeLookup, subNodeSpec) + if err != nil { + messageCollector.Collect(i, err.Error()) + } + } + } + + if messageCollector.Length() > 0 { + return fmt.Errorf(messageCollector.Summary(events.MaxErrorMessageLength)) + } + + return nil +} + +// FinalizeRequired defines whether or not this handler requires finalize to be called on node +// completion +func (a *arrayNodeHandler) FinalizeRequired() bool { + // must return true because we can't determine if finalize is required for the subNode + return true +} + +// Handle is responsible for transitioning and reporting node state to complete the node defined +// by the NodeExecutionContext +func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecutionContext) (handler.Transition, error) { + arrayNode := nCtx.Node().GetArrayNode() + arrayNodeState := nCtx.NodeStateReader().GetArrayNodeState() + currentArrayNodePhase := arrayNodeState.Phase + + var externalResources []*event.ExternalResourceInfo + taskPhaseVersion := arrayNodeState.TaskPhaseVersion + + switch currentArrayNodePhase { + case v1alpha1.ArrayNodePhaseNone: + // identify and validate array node input value lengths + literalMap, err := nCtx.InputReader().Get(ctx) + if err != nil { + return handler.UnknownTransition, err + } + + size := -1 + for _, variable := range literalMap.Literals { + literalType := validators.LiteralTypeForLiteral(variable) + switch literalType.Type.(type) { + case *idlcore.LiteralType_CollectionType: + collectionLength := len(variable.GetCollection().Literals) + + if size == -1 { + size = collectionLength + } else if size != collectionLength { + return handler.DoTransition(handler.TransitionTypeEphemeral, + handler.PhaseInfoFailure(idlcore.ExecutionError_USER, errors.InvalidArrayLength, + fmt.Sprintf("input arrays have different lengths: expecting '%d' found '%d'", size, collectionLength), nil), + ), nil + } + } + } + + if size == -1 { + return handler.DoTransition(handler.TransitionTypeEphemeral, + handler.PhaseInfoFailure(idlcore.ExecutionError_USER, errors.InvalidArrayLength, "no input array provided", nil), + ), nil + } + + // initialize ArrayNode state + maxAttempts := task.DefaultMaxAttempts + subNodeSpec := *arrayNode.GetSubNodeSpec() + if subNodeSpec.GetRetryStrategy() != nil && subNodeSpec.GetRetryStrategy().MinAttempts != nil { + maxAttempts = *subNodeSpec.GetRetryStrategy().MinAttempts + } + + for _, item := range []struct { + arrayReference *bitarray.CompactArray + maxValue int + }{ + // we use NodePhaseRecovered for the `maxValue` of `SubNodePhases` because `Phase` is + // defined as an `iota` so it is impossible to programmatically get the largest value + {arrayReference: &arrayNodeState.SubNodePhases, maxValue: int(v1alpha1.NodePhaseRecovered)}, + {arrayReference: &arrayNodeState.SubNodeTaskPhases, maxValue: len(core.Phases) - 1}, + {arrayReference: &arrayNodeState.SubNodeRetryAttempts, maxValue: maxAttempts}, + {arrayReference: &arrayNodeState.SubNodeSystemFailures, maxValue: maxAttempts}, + } { + + *item.arrayReference, err = bitarray.NewCompactArray(uint(size), bitarray.Item(item.maxValue)) + if err != nil { + return handler.UnknownTransition, err + } + } + + // initialize externalResources + externalResources = make([]*event.ExternalResourceInfo, 0, size) + for i := 0; i < size; i++ { + externalResources = append(externalResources, &event.ExternalResourceInfo{ + ExternalId: buildSubNodeID(nCtx, i, 0), + Index: uint32(i), + Logs: nil, + RetryAttempt: 0, + Phase: idlcore.TaskExecution_QUEUED, + }) + } + + // transition ArrayNode to `ArrayNodePhaseExecuting` + arrayNodeState.Phase = v1alpha1.ArrayNodePhaseExecuting + case v1alpha1.ArrayNodePhaseExecuting: + // process array node subNodes + currentParallelism := uint32(0) + messageCollector := errorcollector.NewErrorMessageCollector() + externalResources = make([]*event.ExternalResourceInfo, 0) + for i, nodePhaseUint64 := range arrayNodeState.SubNodePhases.GetItems() { + nodePhase := v1alpha1.NodePhase(nodePhaseUint64) + + // do not process nodes in terminal state + if isTerminalNodePhase(nodePhase) { + continue + } + + // create array contexts + arrayNodeExecutor, arrayExecutionContext, arrayDAGStructure, arrayNodeLookup, subNodeSpec, subNodeStatus, arrayEventRecorder, err := + a.buildArrayNodeContext(ctx, nCtx, &arrayNodeState, arrayNode, i, &currentParallelism) + if err != nil { + return handler.UnknownTransition, err + } + + // execute subNode + _, err = arrayNodeExecutor.RecursiveNodeHandler(ctx, arrayExecutionContext, arrayDAGStructure, arrayNodeLookup, subNodeSpec) + if err != nil { + return handler.UnknownTransition, err + } + + // capture subNode error if one exists + if subNodeStatus.Error != nil { + messageCollector.Collect(i, subNodeStatus.Error.Message) + } + + // process events + cacheStatus := idlcore.CatalogCacheStatus_CACHE_DISABLED + for _, nodeExecutionEvent := range arrayEventRecorder.NodeEvents() { + switch target := nodeExecutionEvent.TargetMetadata.(type) { + case *event.NodeExecutionEvent_TaskNodeMetadata: + if target.TaskNodeMetadata != nil { + cacheStatus = target.TaskNodeMetadata.CacheStatus + } + } + } + + retryAttempt := subNodeStatus.GetAttempts() + + for _, taskExecutionEvent := range arrayEventRecorder.TaskEvents() { + for _, log := range taskExecutionEvent.Logs { + log.Name = fmt.Sprintf("%s-%d", log.Name, i) + } + + externalResources = append(externalResources, &event.ExternalResourceInfo{ + ExternalId: buildSubNodeID(nCtx, i, retryAttempt), + Index: uint32(i), + Logs: taskExecutionEvent.Logs, + RetryAttempt: retryAttempt, + Phase: taskExecutionEvent.Phase, + CacheStatus: cacheStatus, + }) + } + + // update subNode state + arrayNodeState.SubNodePhases.SetItem(i, uint64(subNodeStatus.GetPhase())) + if subNodeStatus.GetTaskNodeStatus() == nil { + // resetting task phase because during retries we clear the TaskNodeStatus + arrayNodeState.SubNodeTaskPhases.SetItem(i, uint64(0)) + } else { + arrayNodeState.SubNodeTaskPhases.SetItem(i, uint64(subNodeStatus.GetTaskNodeStatus().GetPhase())) + } + arrayNodeState.SubNodeRetryAttempts.SetItem(i, uint64(subNodeStatus.GetAttempts())) + arrayNodeState.SubNodeSystemFailures.SetItem(i, uint64(subNodeStatus.GetSystemFailures())) + } + + // process phases of subNodes to determine overall `ArrayNode` phase + successCount := 0 + failedCount := 0 + failingCount := 0 + runningCount := 0 + for _, nodePhaseUint64 := range arrayNodeState.SubNodePhases.GetItems() { + nodePhase := v1alpha1.NodePhase(nodePhaseUint64) + switch nodePhase { + case v1alpha1.NodePhaseSucceeded, v1alpha1.NodePhaseRecovered, v1alpha1.NodePhaseSkipped: + successCount++ + case v1alpha1.NodePhaseFailing: + failingCount++ + case v1alpha1.NodePhaseFailed, v1alpha1.NodePhaseTimedOut: + failedCount++ + default: + runningCount++ + } + } + + // calculate minimum number of successes to succeed the ArrayNode + minSuccesses := len(arrayNodeState.SubNodePhases.GetItems()) + if arrayNode.GetMinSuccesses() != nil { + minSuccesses = int(*arrayNode.GetMinSuccesses()) + } else if minSuccessRatio := arrayNode.GetMinSuccessRatio(); minSuccessRatio != nil { + minSuccesses = int(math.Ceil(float64(*minSuccessRatio) * float64(minSuccesses))) + } + + // if there is a failing node, set the error message if it has not been previously set + if failingCount > 0 && arrayNodeState.Error == nil { + arrayNodeState.Error = &idlcore.ExecutionError{ + Message: messageCollector.Summary(events.MaxErrorMessageLength), + } + } + + if len(arrayNodeState.SubNodePhases.GetItems())-failedCount < minSuccesses { + // no chance to reach the minimum number of successes + arrayNodeState.Phase = v1alpha1.ArrayNodePhaseFailing + } else if successCount >= minSuccesses && runningCount == 0 { + // wait until all tasks have completed before declaring success + arrayNodeState.Phase = v1alpha1.ArrayNodePhaseSucceeding + } + case v1alpha1.ArrayNodePhaseFailing: + if err := a.Abort(ctx, nCtx, "ArrayNodeFailing"); err != nil { + return handler.UnknownTransition, err + } + + // fail with reported error if one exists + if arrayNodeState.Error != nil { + return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoFailureErr(arrayNodeState.Error, nil)), nil + } + + return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoFailure( + idlcore.ExecutionError_UNKNOWN, + "ArrayNodeFailing", + "Unknown reason", + nil, + )), nil + case v1alpha1.ArrayNodePhaseSucceeding: + outputLiterals := make(map[string]*idlcore.Literal) + for i, nodePhaseUint64 := range arrayNodeState.SubNodePhases.GetItems() { + nodePhase := v1alpha1.NodePhase(nodePhaseUint64) + + if nodePhase != v1alpha1.NodePhaseSucceeded { + // retrieve output variables from task template + var outputVariables map[string]*idlcore.Variable + task, err := nCtx.ExecutionContext().GetTask(*arrayNode.GetSubNodeSpec().TaskRef) + if err != nil { + // Should never happen + return handler.UnknownTransition, err + } + + if task.CoreTask() != nil && task.CoreTask().Interface != nil && task.CoreTask().Interface.Outputs != nil { + outputVariables = task.CoreTask().Interface.Outputs.Variables + } + + // append nil literal for all output variables + for name := range outputVariables { + appendLiteral(name, nilLiteral, outputLiterals, len(arrayNodeState.SubNodePhases.GetItems())) + } + } else { + // initialize subNode reader + currentAttempt := uint32(arrayNodeState.SubNodeRetryAttempts.GetItem(i)) + subDataDir, subOutputDir, err := constructOutputReferences(ctx, nCtx, strconv.Itoa(i), strconv.Itoa(int(currentAttempt))) + if err != nil { + return handler.UnknownTransition, err + } + + // checkpoint paths are not computed here because this function is only called when writing + // existing cached outputs. if this functionality changes this will need to be revisited. + outputPaths := ioutils.NewCheckpointRemoteFilePaths(ctx, nCtx.DataStore(), subOutputDir, ioutils.NewRawOutputPaths(ctx, subDataDir), "") + reader := ioutils.NewRemoteFileOutputReader(ctx, nCtx.DataStore(), outputPaths, nCtx.MaxDatasetSizeBytes()) + + // read outputs + outputs, executionErr, err := reader.Read(ctx) + if err != nil { + return handler.UnknownTransition, err + } else if executionErr != nil { + return handler.UnknownTransition, errors.Errorf(errors.IllegalStateError, nCtx.NodeID(), + "execution error ArrayNode output, bad state: %s", executionErr.String()) + } + + // copy individual subNode output literals into a collection of output literals + for name, literal := range outputs.GetLiterals() { + appendLiteral(name, literal, outputLiterals, len(arrayNodeState.SubNodePhases.GetItems())) + } + } + } + + outputLiteralMap := &idlcore.LiteralMap{ + Literals: outputLiterals, + } + + outputFile := v1alpha1.GetOutputsFile(nCtx.NodeStatus().GetOutputDir()) + if err := nCtx.DataStore().WriteProtobuf(ctx, outputFile, storage.Options{}, outputLiteralMap); err != nil { + return handler.UnknownTransition, err + } + + return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoSuccess( + &handler.ExecutionInfo{ + OutputInfo: &handler.OutputInfo{ + OutputURI: outputFile, + }, + }, + )), nil + default: + return handler.UnknownTransition, errors.Errorf(errors.IllegalStateError, nCtx.NodeID(), "invalid ArrayNode phase %+v", arrayNodeState.Phase) + } + + // if there were changes to subNode status externalResources will be populated and must be + // reported to admin through a TaskExecutionEvent. + if len(externalResources) > 0 { + // determine task phase from ArrayNodePhase + taskPhase := idlcore.TaskExecution_UNDEFINED + switch currentArrayNodePhase { + case v1alpha1.ArrayNodePhaseNone: + taskPhase = idlcore.TaskExecution_QUEUED + case v1alpha1.ArrayNodePhaseExecuting: + taskPhase = idlcore.TaskExecution_RUNNING + case v1alpha1.ArrayNodePhaseSucceeding: + taskPhase = idlcore.TaskExecution_SUCCEEDED + case v1alpha1.ArrayNodePhaseFailing: + taskPhase = idlcore.TaskExecution_FAILED + } + + // we need to increment the taskPhaseVersion if arrayNodeState.Phase does not change, otherwise + // reset it to 0. by always incrementing we report an event and ensure processing + // every time the ArrayNode is evaluated. if this overhead becomes too large, we will need + // to revisit this and only increment when a subNode state changes. 
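+ // for example: an ArrayNode that remains in ArrayNodePhaseExecuting across successive
+ // evaluations reports (RUNNING, version 0), (RUNNING, version 1), ... so every
+ // evaluation emits a distinct, non-duplicate event (an illustrative trace, not exhaustive).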
+ if currentArrayNodePhase != arrayNodeState.Phase { + arrayNodeState.TaskPhaseVersion = 0 + } else { + arrayNodeState.TaskPhaseVersion = taskPhaseVersion + 1 + } + + taskExecutionEvent, err := buildTaskExecutionEvent(ctx, nCtx, taskPhase, taskPhaseVersion, externalResources) + if err != nil { + return handler.UnknownTransition, err + } + + if err := nCtx.EventsRecorder().RecordTaskEvent(ctx, taskExecutionEvent, a.eventConfig); err != nil { + logger.Errorf(ctx, "ArrayNode event recording failed: [%s]", err.Error()) + return handler.UnknownTransition, err + } + } + + // update array node status + if err := nCtx.NodeStateWriter().PutArrayNodeState(arrayNodeState); err != nil { + logger.Errorf(ctx, "failed to store ArrayNode state with err [%s]", err.Error()) + return handler.UnknownTransition, err + } + + return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoRunning(&handler.ExecutionInfo{})), nil +} + +// Setup handles any initialization requirements for this handler +func (a *arrayNodeHandler) Setup(_ context.Context, _ interfaces.SetupContext) error { + return nil +} + +// New initializes a new arrayNodeHandler +func New(nodeExecutor interfaces.Node, eventConfig *config.EventConfig, scope promutils.Scope) (interfaces.NodeHandler, error) { + // create k8s PluginState byte mocks to reuse instead of creating for each subNode evaluation + pluginStateBytesNotStarted, err := bytesFromK8sPluginState(k8s.PluginState{Phase: k8s.PluginPhaseNotStarted}) + if err != nil { + return nil, err + } + + pluginStateBytesStarted, err := bytesFromK8sPluginState(k8s.PluginState{Phase: k8s.PluginPhaseStarted}) + if err != nil { + return nil, err + } + + arrayScope := scope.NewSubScope("array") + return &arrayNodeHandler{ + eventConfig: eventConfig, + metrics: newMetrics(arrayScope), + nodeExecutor: nodeExecutor, + pluginStateBytesNotStarted: pluginStateBytesNotStarted, + pluginStateBytesStarted: pluginStateBytesStarted, + }, nil +} + +// buildArrayNodeContext creates a custom environment to execute the ArrayNode subnode. This is uniquely required for +// the arrayNodeHandler because we require the same node execution entrypoint (ie. recursiveNodeExecutor.RecursiveNodeHandler) +// but need many different execution details, for example setting input values as a singular item rather than a collection, +// injecting environment variables for flytekit maptask execution, aggregating eventing so that rather than tracking state for +// each subnode individually it sends a single event for the whole ArrayNode, and many more. 
+func (a *arrayNodeHandler) buildArrayNodeContext(ctx context.Context, nCtx interfaces.NodeExecutionContext, arrayNodeState *handler.ArrayNodeState, arrayNode v1alpha1.ExecutableArrayNode, subNodeIndex int, currentParallelism *uint32) ( + interfaces.Node, executors.ExecutionContext, executors.DAGStructure, executors.NodeLookup, *v1alpha1.NodeSpec, *v1alpha1.NodeStatus, *arrayEventRecorder, error) { + + nodePhase := v1alpha1.NodePhase(arrayNodeState.SubNodePhases.GetItem(subNodeIndex)) + taskPhase := int(arrayNodeState.SubNodeTaskPhases.GetItem(subNodeIndex)) + + // need to initialize the inputReader every time to ensure the TaskHandler can access it for cache lookups / population + inputLiteralMap, err := constructLiteralMap(ctx, nCtx.InputReader(), subNodeIndex) + if err != nil { + return nil, nil, nil, nil, nil, nil, nil, err + } + + inputReader := newStaticInputReader(nCtx.InputReader(), inputLiteralMap) + + // if node has not yet started we automatically set to NodePhaseQueued to skip input resolution + if nodePhase == v1alpha1.NodePhaseNotYetStarted { + // TODO - to support fastcache we'll need to override the bindings to BindingScalars for the input resolution on the nCtx + // that way resolution is just reading a literal ... but does this still write a file then?!? + nodePhase = v1alpha1.NodePhaseQueued + } + + // wrap node lookup + subNodeSpec := *arrayNode.GetSubNodeSpec() + + subNodeID := fmt.Sprintf("%s-n%d", nCtx.NodeID(), subNodeIndex) + subNodeSpec.ID = subNodeID + subNodeSpec.Name = subNodeID + + // TODO - if we want to support more plugin types we need to figure out the best way to store plugin state + // currently we just mock it based on node phase, which works for all k8s plugins + // we cannot pre-allocate a bit array because max size is 256B and with 5k fanout node state = 1.28MB + pluginStateBytes := a.pluginStateBytesStarted + if taskPhase == int(core.PhaseUndefined) || taskPhase == int(core.PhaseRetryableFailure) { + pluginStateBytes = a.pluginStateBytesNotStarted + } + + // construct output references + currentAttempt := uint32(arrayNodeState.SubNodeRetryAttempts.GetItem(subNodeIndex)) + subDataDir, subOutputDir, err := constructOutputReferences(ctx, nCtx, strconv.Itoa(subNodeIndex), strconv.Itoa(int(currentAttempt))) + if err != nil { + return nil, nil, nil, nil, nil, nil, nil, err + } + + subNodeStatus := &v1alpha1.NodeStatus{ + Phase: nodePhase, + DataDir: subDataDir, + OutputDir: subOutputDir, + Attempts: currentAttempt, + SystemFailures: uint32(arrayNodeState.SubNodeSystemFailures.GetItem(subNodeIndex)), + TaskNodeStatus: &v1alpha1.TaskNodeStatus{ + Phase: taskPhase, + PluginState: pluginStateBytes, + }, + } + + // initialize mocks + arrayNodeLookup := newArrayNodeLookup(nCtx.ContextualNodeLookup(), subNodeID, &subNodeSpec, subNodeStatus) + + arrayExecutionContext := newArrayExecutionContext(nCtx.ExecutionContext(), subNodeIndex, currentParallelism, arrayNode.GetParallelism()) + + arrayEventRecorder := newArrayEventRecorder() + arrayNodeExecutionContextBuilder := newArrayNodeExecutionContextBuilder(a.nodeExecutor.GetNodeExecutionContextBuilder(), + subNodeID, subNodeIndex, subNodeStatus, inputReader, arrayEventRecorder, currentParallelism, arrayNode.GetParallelism()) + arrayNodeExecutor := a.nodeExecutor.WithNodeExecutionContextBuilder(arrayNodeExecutionContextBuilder) + + return arrayNodeExecutor, arrayExecutionContext, &arrayNodeLookup, &arrayNodeLookup, &subNodeSpec, subNodeStatus, arrayEventRecorder, nil +} diff --git 
a/flytepropeller/pkg/controller/nodes/array/handler_test.go b/flytepropeller/pkg/controller/nodes/array/handler_test.go new file mode 100644 index 0000000000..f3e6f8bd18 --- /dev/null +++ b/flytepropeller/pkg/controller/nodes/array/handler_test.go @@ -0,0 +1,888 @@ +package array + +import ( + "context" + "fmt" + "testing" + + idlcore "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" + + eventmocks "github.com/flyteorg/flytepropeller/events/mocks" + "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/flyteorg/flytepropeller/pkg/controller/config" + execmocks "github.com/flyteorg/flytepropeller/pkg/controller/executors/mocks" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes" + gatemocks "github.com/flyteorg/flytepropeller/pkg/controller/nodes/gate/mocks" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces/mocks" + recoverymocks "github.com/flyteorg/flytepropeller/pkg/controller/nodes/recovery/mocks" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/subworkflow/launchplan" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/catalog" + + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" + pluginmocks "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io/mocks" + + "github.com/flyteorg/flytestdlib/bitarray" + "github.com/flyteorg/flytestdlib/contextutils" + "github.com/flyteorg/flytestdlib/promutils" + "github.com/flyteorg/flytestdlib/promutils/labeled" + "github.com/flyteorg/flytestdlib/storage" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +var ( + taskRef = "taskRef" + arrayNodeSpec = v1alpha1.NodeSpec{ + ID: "foo", + ArrayNode: &v1alpha1.ArrayNodeSpec{ + SubNodeSpec: &v1alpha1.NodeSpec{ + TaskRef: &taskRef, + }, + }, + } +) + +func createArrayNodeHandler(ctx context.Context, t *testing.T, nodeHandler interfaces.NodeHandler, dataStore *storage.DataStore, scope promutils.Scope) (interfaces.NodeHandler, error) { + // mock components + adminClient := launchplan.NewFailFastLaunchPlanExecutor() + enqueueWorkflowFunc := func(workflowID v1alpha1.WorkflowID) {} + eventConfig := &config.EventConfig{} + mockEventSink := eventmocks.NewMockEventSink() + mockHandlerFactory := &mocks.HandlerFactory{} + mockHandlerFactory.OnGetHandlerMatch(mock.Anything).Return(nodeHandler, nil) + mockKubeClient := execmocks.NewFakeKubeClient() + mockRecoveryClient := &recoverymocks.Client{} + mockSignalClient := &gatemocks.SignalServiceClient{} + noopCatalogClient := catalog.NOOPCatalog{} + + // create node executor + nodeExecutor, err := nodes.NewExecutor(ctx, config.GetConfig().NodeConfig, dataStore, enqueueWorkflowFunc, mockEventSink, adminClient, + adminClient, 10, "s3://bucket/", mockKubeClient, noopCatalogClient, mockRecoveryClient, eventConfig, "clusterID", mockSignalClient, mockHandlerFactory, scope) + assert.NoError(t, err) + + // return ArrayNodeHandler + return New(nodeExecutor, eventConfig, scope) +} + +func createNodeExecutionContext(dataStore *storage.DataStore, eventRecorder interfaces.EventRecorder, outputVariables []string, + inputLiteralMap *idlcore.LiteralMap, arrayNodeSpec *v1alpha1.NodeSpec, arrayNodeState *handler.ArrayNodeState) interfaces.NodeExecutionContext { + + nCtx := &mocks.NodeExecutionContext{} + nCtx.OnMaxDatasetSizeBytes().Return(9999999) + + // 
ContextualNodeLookup + nodeLookup := &execmocks.NodeLookup{} + nodeLookup.OnFromNodeMatch(mock.Anything).Return(nil, nil) + nCtx.OnContextualNodeLookup().Return(nodeLookup) + + // DataStore + nCtx.OnDataStore().Return(dataStore) + + // ExecutionContext + executionContext := &execmocks.ExecutionContext{} + executionContext.OnGetEventVersion().Return(1) + executionContext.OnGetExecutionConfig().Return(v1alpha1.ExecutionConfig{}) + executionContext.OnGetExecutionID().Return( + v1alpha1.ExecutionID{ + WorkflowExecutionIdentifier: &idlcore.WorkflowExecutionIdentifier{ + Project: "project", + Domain: "domain", + Name: "name", + }, + }) + executionContext.OnGetLabels().Return(nil) + executionContext.OnGetRawOutputDataConfig().Return(v1alpha1.RawOutputDataConfig{}) + executionContext.OnIsInterruptible().Return(false) + executionContext.OnGetParentInfo().Return(nil) + outputVariableMap := make(map[string]*idlcore.Variable) + for _, outputVariable := range outputVariables { + outputVariableMap[outputVariable] = &idlcore.Variable{} + } + executionContext.OnGetTaskMatch(taskRef).Return( + &v1alpha1.TaskSpec{ + TaskTemplate: &idlcore.TaskTemplate{ + Interface: &idlcore.TypedInterface{ + Outputs: &idlcore.VariableMap{ + Variables: outputVariableMap, + }, + }, + }, + }, + nil, + ) + nCtx.OnExecutionContext().Return(executionContext) + + // EventsRecorder + nCtx.OnEventsRecorder().Return(eventRecorder) + + // InputReader + inputFilePaths := &pluginmocks.InputFilePaths{} + inputFilePaths.OnGetInputPath().Return(storage.DataReference("s3://bucket/input")) + nCtx.OnInputReader().Return( + newStaticInputReader( + inputFilePaths, + inputLiteralMap, + )) + + // Node + nCtx.OnNode().Return(arrayNodeSpec) + + // NodeExecutionMetadata + nodeExecutionMetadata := &mocks.NodeExecutionMetadata{} + nodeExecutionMetadata.OnGetNodeExecutionID().Return(&idlcore.NodeExecutionIdentifier{ + NodeId: "foo", + ExecutionId: &idlcore.WorkflowExecutionIdentifier{ + Project: "project", + Domain: "domain", + Name: "name", + }, + }) + nCtx.OnNodeExecutionMetadata().Return(nodeExecutionMetadata) + + // NodeID + nCtx.OnNodeID().Return("foo") + + // NodeStateReader + nodeStateReader := &mocks.NodeStateReader{} + nodeStateReader.OnGetArrayNodeState().Return(*arrayNodeState) + nCtx.OnNodeStateReader().Return(nodeStateReader) + + // NodeStateWriter + nodeStateWriter := &mocks.NodeStateWriter{} + nodeStateWriter.OnPutArrayNodeStateMatch(mock.Anything, mock.Anything).Run( + func(args mock.Arguments) { + *arrayNodeState = args.Get(0).(handler.ArrayNodeState) + }, + ).Return(nil) + nCtx.OnNodeStateWriter().Return(nodeStateWriter) + + // NodeStatus + nCtx.OnNodeStatus().Return(&v1alpha1.NodeStatus{ + DataDir: storage.DataReference("s3://bucket/data"), + OutputDir: storage.DataReference("s3://bucket/output"), + }) + + return nCtx +} + +func TestAbort(t *testing.T) { + ctx := context.Background() + scope := promutils.NewTestScope() + dataStore, err := storage.NewDataStore(&storage.Config{ + Type: storage.TypeMemory, + }, scope) + assert.NoError(t, err) + + nodeHandler := &mocks.NodeHandler{} + nodeHandler.OnAbortMatch(mock.Anything, mock.Anything, mock.Anything).Return(nil) + nodeHandler.OnFinalizeMatch(mock.Anything, mock.Anything, mock.Anything).Return(nil) + + // initialize ArrayNodeHandler + arrayNodeHandler, err := createArrayNodeHandler(ctx, t, nodeHandler, dataStore, scope) + assert.NoError(t, err) + + tests := []struct { + name string + inputMap map[string][]int64 + subNodePhases []v1alpha1.NodePhase + subNodeTaskPhases []core.Phase + 
expectedExternalResourcePhases []idlcore.TaskExecution_Phase + }{ + { + name: "Success", + inputMap: map[string][]int64{ + "foo": []int64{0, 1, 2}, + }, + subNodePhases: []v1alpha1.NodePhase{v1alpha1.NodePhaseSucceeded, v1alpha1.NodePhaseRunning, v1alpha1.NodePhaseNotYetStarted}, + subNodeTaskPhases: []core.Phase{core.PhaseSuccess, core.PhaseRunning, core.PhaseUndefined}, + expectedExternalResourcePhases: []idlcore.TaskExecution_Phase{idlcore.TaskExecution_ABORTED}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + // initialize universal variables + literalMap := convertMapToArrayLiterals(test.inputMap) + + size := -1 + for _, v := range test.inputMap { + if size == -1 { + size = len(v) + } else if len(v) > size { // calculating size as largest input list + size = len(v) + } + } + + // initialize ArrayNodeState + arrayNodeState := &handler.ArrayNodeState{ + Phase: v1alpha1.ArrayNodePhaseFailing, + } + for _, item := range []struct { + arrayReference *bitarray.CompactArray + maxValue int + }{ + {arrayReference: &arrayNodeState.SubNodePhases, maxValue: int(v1alpha1.NodePhaseRecovered)}, + {arrayReference: &arrayNodeState.SubNodeTaskPhases, maxValue: len(core.Phases) - 1}, + {arrayReference: &arrayNodeState.SubNodeRetryAttempts, maxValue: 1}, + {arrayReference: &arrayNodeState.SubNodeSystemFailures, maxValue: 1}, + } { + + *item.arrayReference, err = bitarray.NewCompactArray(uint(size), bitarray.Item(item.maxValue)) + assert.NoError(t, err) + } + + for i, nodePhase := range test.subNodePhases { + arrayNodeState.SubNodePhases.SetItem(i, bitarray.Item(nodePhase)) + } + for i, taskPhase := range test.subNodeTaskPhases { + arrayNodeState.SubNodeTaskPhases.SetItem(i, bitarray.Item(taskPhase)) + } + + // create NodeExecutionContext + eventRecorder := newArrayEventRecorder() + nCtx := createNodeExecutionContext(dataStore, eventRecorder, nil, literalMap, &arrayNodeSpec, arrayNodeState) + + // evaluate node + err := arrayNodeHandler.Abort(ctx, nCtx, "foo") + assert.NoError(t, err) + + nodeHandler.AssertNumberOfCalls(t, "Abort", len(test.expectedExternalResourcePhases)) + if len(test.expectedExternalResourcePhases) > 0 { + assert.Equal(t, 1, len(eventRecorder.taskEvents)) + + externalResources := eventRecorder.taskEvents[0].Metadata.GetExternalResources() + assert.Equal(t, len(test.expectedExternalResourcePhases), len(externalResources)) + for i, expectedPhase := range test.expectedExternalResourcePhases { + assert.Equal(t, expectedPhase, externalResources[i].Phase) + } + } else { + assert.Equal(t, 0, len(eventRecorder.taskEvents)) + } + }) + } +} + +func TestFinalize(t *testing.T) { + ctx := context.Background() + scope := promutils.NewTestScope() + dataStore, err := storage.NewDataStore(&storage.Config{ + Type: storage.TypeMemory, + }, scope) + assert.NoError(t, err) + + nodeHandler := &mocks.NodeHandler{} + nodeHandler.OnFinalizeMatch(mock.Anything, mock.Anything, mock.Anything).Return(nil) + + // initialize ArrayNodeHandler + arrayNodeHandler, err := createArrayNodeHandler(ctx, t, nodeHandler, dataStore, scope) + assert.NoError(t, err) + + tests := []struct { + name string + inputMap map[string][]int64 + subNodePhases []v1alpha1.NodePhase + subNodeTaskPhases []core.Phase + expectedFinalizeCalls int + }{ + { + name: "Success", + inputMap: map[string][]int64{ + "foo": []int64{0, 1, 2}, + }, + subNodePhases: []v1alpha1.NodePhase{v1alpha1.NodePhaseSucceeded, v1alpha1.NodePhaseRunning, v1alpha1.NodePhaseNotYetStarted}, + subNodeTaskPhases: []core.Phase{core.PhaseSuccess, core.PhaseRunning, core.PhaseUndefined}, + expectedFinalizeCalls: 1, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + // initialize universal variables + literalMap := convertMapToArrayLiterals(test.inputMap) + + size := -1 + for _, v := range test.inputMap { + if size == -1 { + size = len(v) + } else if len(v) > size { // calculating size as largest input list + size = len(v) + } + } + + // initialize ArrayNodeState + arrayNodeState := &handler.ArrayNodeState{ + Phase: v1alpha1.ArrayNodePhaseFailing, + } + for _, item := range []struct { + arrayReference *bitarray.CompactArray + maxValue int + }{ + {arrayReference: &arrayNodeState.SubNodePhases, maxValue: int(v1alpha1.NodePhaseRecovered)}, + {arrayReference: &arrayNodeState.SubNodeTaskPhases, maxValue: len(core.Phases) - 1}, + {arrayReference: &arrayNodeState.SubNodeRetryAttempts, maxValue: 1}, + {arrayReference: &arrayNodeState.SubNodeSystemFailures, maxValue: 1}, + } { + + *item.arrayReference, err = bitarray.NewCompactArray(uint(size), bitarray.Item(item.maxValue)) + assert.NoError(t, err) + } + + for i, nodePhase := range test.subNodePhases { + arrayNodeState.SubNodePhases.SetItem(i, bitarray.Item(nodePhase)) + } + for i, taskPhase := range test.subNodeTaskPhases { + arrayNodeState.SubNodeTaskPhases.SetItem(i, bitarray.Item(taskPhase)) + } + + // create NodeExecutionContext + eventRecorder := newArrayEventRecorder() + nCtx := createNodeExecutionContext(dataStore, eventRecorder, nil, literalMap, &arrayNodeSpec, arrayNodeState) + + // evaluate node + err := arrayNodeHandler.Finalize(ctx, nCtx) + assert.NoError(t, err) + + // validate + nodeHandler.AssertNumberOfCalls(t, "Finalize", test.expectedFinalizeCalls) + }) + } +} + +func TestHandleArrayNodePhaseNone(t *testing.T) { + ctx := context.Background() + scope := promutils.NewTestScope() + dataStore, err := storage.NewDataStore(&storage.Config{ + Type: storage.TypeMemory, + }, scope) + assert.NoError(t, err) + nodeHandler := &mocks.NodeHandler{} + + // initialize ArrayNodeHandler + arrayNodeHandler, err := createArrayNodeHandler(ctx, t, nodeHandler, dataStore, scope) + assert.NoError(t, err) + + tests := []struct { + name string + inputValues map[string][]int64 + expectedArrayNodePhase v1alpha1.ArrayNodePhase + expectedTransitionPhase handler.EPhase + expectedExternalResourcePhases []idlcore.TaskExecution_Phase + }{ + { + name: "Success", + inputValues: map[string][]int64{ + "foo": []int64{1, 2}, + }, + expectedArrayNodePhase: v1alpha1.ArrayNodePhaseExecuting, + expectedTransitionPhase: handler.EPhaseRunning, + expectedExternalResourcePhases: []idlcore.TaskExecution_Phase{idlcore.TaskExecution_QUEUED, idlcore.TaskExecution_QUEUED}, + }, + { + name: "SuccessMultipleInputs", + inputValues: map[string][]int64{ + "foo": []int64{1, 2, 3}, + "bar": []int64{4, 5, 6}, + }, + expectedArrayNodePhase: v1alpha1.ArrayNodePhaseExecuting, + expectedTransitionPhase: handler.EPhaseRunning, + expectedExternalResourcePhases: []idlcore.TaskExecution_Phase{idlcore.TaskExecution_QUEUED, idlcore.TaskExecution_QUEUED, idlcore.TaskExecution_QUEUED}, + }, + { + name: "FailureDifferentInputListLengths", + inputValues: map[string][]int64{ + "foo": []int64{1, 2}, + "bar": []int64{3}, + }, + expectedArrayNodePhase: v1alpha1.ArrayNodePhaseNone, + expectedTransitionPhase: handler.EPhaseFailed, + expectedExternalResourcePhases: nil, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + // create NodeExecutionContext + eventRecorder := newArrayEventRecorder() + literalMap := convertMapToArrayLiterals(test.inputValues) + arrayNodeState := &handler.ArrayNodeState{ + Phase: v1alpha1.ArrayNodePhaseNone, + } + nCtx := createNodeExecutionContext(dataStore, eventRecorder, nil, literalMap, &arrayNodeSpec, arrayNodeState) + + // evaluate node + transition, err := arrayNodeHandler.Handle(ctx, nCtx) + assert.NoError(t, err) + + // validate results + assert.Equal(t, test.expectedArrayNodePhase, arrayNodeState.Phase) + assert.Equal(t, test.expectedTransitionPhase, transition.Info().GetPhase()) + + if len(test.expectedExternalResourcePhases) > 0 { + assert.Equal(t, 1, len(eventRecorder.taskEvents)) + + externalResources := eventRecorder.taskEvents[0].Metadata.GetExternalResources() + assert.Equal(t, len(test.expectedExternalResourcePhases), len(externalResources)) + for i, expectedPhase := range test.expectedExternalResourcePhases { + assert.Equal(t, expectedPhase, externalResources[i].Phase) + } + } else { + assert.Equal(t, 0, len(eventRecorder.taskEvents)) + } + }) + } +} + +func TestHandleArrayNodePhaseExecuting(t *testing.T) { + ctx := context.Background() + minSuccessRatio := float32(0.5) + + // initialize universal variables + inputMap := map[string][]int64{ + "foo": []int64{0, 1}, + "bar": []int64{2, 3}, + } + literalMap := convertMapToArrayLiterals(inputMap) + + size := -1 + for _, v := range inputMap { + if size == -1 { + size = len(v) + } else if len(v) > size { // calculating size as largest input list + size = len(v) + } + } + + tests := []struct { + name string + parallelism int + minSuccessRatio *float32 + subNodePhases []v1alpha1.NodePhase + subNodeTaskPhases []core.Phase + subNodeTransitions []handler.Transition + expectedArrayNodePhase v1alpha1.ArrayNodePhase + expectedTransitionPhase handler.EPhase + expectedExternalResourcePhases []idlcore.TaskExecution_Phase + }{ + { + name: "StartAllSubNodes", + subNodePhases: []v1alpha1.NodePhase{ + v1alpha1.NodePhaseQueued, + v1alpha1.NodePhaseQueued, + }, + subNodeTaskPhases: []core.Phase{ + core.PhaseUndefined, + core.PhaseUndefined, + }, + subNodeTransitions: []handler.Transition{ + handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoRunning(&handler.ExecutionInfo{})), + handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoRunning(&handler.ExecutionInfo{})), + }, + expectedArrayNodePhase: v1alpha1.ArrayNodePhaseExecuting, + expectedTransitionPhase: handler.EPhaseRunning, + expectedExternalResourcePhases: []idlcore.TaskExecution_Phase{idlcore.TaskExecution_RUNNING, idlcore.TaskExecution_RUNNING}, + }, + { + name: "StartOneSubNodeParallelism", + parallelism: 1, + subNodePhases: []v1alpha1.NodePhase{ + v1alpha1.NodePhaseQueued, + v1alpha1.NodePhaseQueued, + }, + subNodeTaskPhases: []core.Phase{ + core.PhaseUndefined, + core.PhaseUndefined, + }, + subNodeTransitions: []handler.Transition{ + handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoRunning(&handler.ExecutionInfo{})), + handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoRunning(&handler.ExecutionInfo{})), + }, + expectedArrayNodePhase: v1alpha1.ArrayNodePhaseExecuting, + expectedTransitionPhase: handler.EPhaseRunning, + expectedExternalResourcePhases: []idlcore.TaskExecution_Phase{idlcore.TaskExecution_RUNNING, idlcore.TaskExecution_QUEUED}, + }, + { + name: "AllSubNodesSucceeded", + subNodePhases: []v1alpha1.NodePhase{ + v1alpha1.NodePhaseRunning, + v1alpha1.NodePhaseRunning, + }, + subNodeTaskPhases: []core.Phase{ + core.PhaseRunning, + core.PhaseRunning, + }, + subNodeTransitions: []handler.Transition{ + handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoSuccess(&handler.ExecutionInfo{})), + handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoSuccess(&handler.ExecutionInfo{})), + }, + expectedArrayNodePhase: v1alpha1.ArrayNodePhaseSucceeding, + expectedTransitionPhase: handler.EPhaseRunning, + expectedExternalResourcePhases: []idlcore.TaskExecution_Phase{idlcore.TaskExecution_SUCCEEDED, idlcore.TaskExecution_SUCCEEDED}, + }, + { + name: "OneSubNodeSucceededMinSuccessRatio", + minSuccessRatio: &minSuccessRatio, + subNodePhases: []v1alpha1.NodePhase{ + v1alpha1.NodePhaseRunning, + v1alpha1.NodePhaseRunning, + }, + subNodeTaskPhases: []core.Phase{ + core.PhaseRunning, + core.PhaseRunning, + }, + subNodeTransitions: []handler.Transition{ + handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoSuccess(&handler.ExecutionInfo{})), + handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoFailure(0, "", "", &handler.ExecutionInfo{})), + }, + expectedArrayNodePhase: v1alpha1.ArrayNodePhaseSucceeding, + expectedTransitionPhase: handler.EPhaseRunning, + expectedExternalResourcePhases: []idlcore.TaskExecution_Phase{idlcore.TaskExecution_SUCCEEDED, idlcore.TaskExecution_FAILED}, + }, + { + name: "OneSubNodeFailed", + subNodePhases: []v1alpha1.NodePhase{ + v1alpha1.NodePhaseRunning, + v1alpha1.NodePhaseRunning, + }, + subNodeTaskPhases: []core.Phase{ + core.PhaseRunning, + core.PhaseRunning, + }, + subNodeTransitions: []handler.Transition{ + handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoFailure(0, "", "", &handler.ExecutionInfo{})), + handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoSuccess(&handler.ExecutionInfo{})), + }, + expectedArrayNodePhase: v1alpha1.ArrayNodePhaseFailing, + expectedTransitionPhase: handler.EPhaseRunning, + expectedExternalResourcePhases: []idlcore.TaskExecution_Phase{idlcore.TaskExecution_FAILED, idlcore.TaskExecution_SUCCEEDED}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + scope := promutils.NewTestScope() + dataStore, err := storage.NewDataStore(&storage.Config{ + Type: storage.TypeMemory, + }, scope) + assert.NoError(t, err) + + // initialize ArrayNodeState + arrayNodeState := &handler.ArrayNodeState{ + Phase: v1alpha1.ArrayNodePhaseExecuting, + } + for _, item := range []struct { + arrayReference *bitarray.CompactArray + maxValue int + }{ + {arrayReference: &arrayNodeState.SubNodePhases, maxValue: int(v1alpha1.NodePhaseRecovered)}, + {arrayReference: &arrayNodeState.SubNodeTaskPhases, maxValue: len(core.Phases) - 1}, + {arrayReference: &arrayNodeState.SubNodeRetryAttempts, maxValue: 1}, + {arrayReference: &arrayNodeState.SubNodeSystemFailures, maxValue: 1}, + } { + + *item.arrayReference, err = bitarray.NewCompactArray(uint(size), bitarray.Item(item.maxValue)) + assert.NoError(t, err) + } + + for i, nodePhase := range test.subNodePhases { + arrayNodeState.SubNodePhases.SetItem(i, bitarray.Item(nodePhase)) + } + for i, taskPhase := range test.subNodeTaskPhases { + arrayNodeState.SubNodeTaskPhases.SetItem(i, bitarray.Item(taskPhase)) + } + + // create NodeExecutionContext + eventRecorder := newArrayEventRecorder() + + nodeSpec := arrayNodeSpec + nodeSpec.ArrayNode.Parallelism = uint32(test.parallelism) + nodeSpec.ArrayNode.MinSuccessRatio = test.minSuccessRatio + + nCtx := createNodeExecutionContext(dataStore, eventRecorder, nil, 
literalMap, &arrayNodeSpec, arrayNodeState) + + // initialize ArrayNodeHandler + nodeHandler := &mocks.NodeHandler{} + nodeHandler.OnFinalizeRequired().Return(false) + for i, transition := range test.subNodeTransitions { + nodeID := fmt.Sprintf("%s-n%d", nCtx.NodeID(), i) + transitionPhase := test.expectedExternalResourcePhases[i] + + nodeHandler.OnHandleMatch(mock.Anything, mock.MatchedBy(func(arrayNCtx interfaces.NodeExecutionContext) bool { + return arrayNCtx.NodeID() == nodeID // match on NodeID using index to ensure each subNode is handled independently + })).Run( + func(args mock.Arguments) { + // mock sending TaskExecutionEvent from handler to show task state transition + taskExecutionEvent := &event.TaskExecutionEvent{ + Phase: transitionPhase, + } + + err := args.Get(1).(interfaces.NodeExecutionContext).EventsRecorder().RecordTaskEvent(ctx, taskExecutionEvent, &config.EventConfig{}) + assert.NoError(t, err) + }, + ).Return(transition, nil) + } + + arrayNodeHandler, err := createArrayNodeHandler(ctx, t, nodeHandler, dataStore, scope) + assert.NoError(t, err) + + // evaluate node + transition, err := arrayNodeHandler.Handle(ctx, nCtx) + assert.NoError(t, err) + + // validate results + assert.Equal(t, test.expectedArrayNodePhase, arrayNodeState.Phase) + assert.Equal(t, test.expectedTransitionPhase, transition.Info().GetPhase()) + + if len(test.expectedExternalResourcePhases) > 0 { + assert.Equal(t, 1, len(eventRecorder.taskEvents)) + + externalResources := eventRecorder.taskEvents[0].Metadata.GetExternalResources() + assert.Equal(t, len(test.expectedExternalResourcePhases), len(externalResources)) + for i, expectedPhase := range test.expectedExternalResourcePhases { + assert.Equal(t, expectedPhase, externalResources[i].Phase) + } + } else { + assert.Equal(t, 0, len(eventRecorder.taskEvents)) + } + }) + } +} + +func TestHandleArrayNodePhaseSucceeding(t *testing.T) { + ctx := context.Background() + scope := promutils.NewTestScope() + dataStore, err := storage.NewDataStore(&storage.Config{ + Type: storage.TypeMemory, + }, scope) + assert.NoError(t, err) + nodeHandler := &mocks.NodeHandler{} + valueOne := 1 + + // initialize ArrayNodeHandler + arrayNodeHandler, err := createArrayNodeHandler(ctx, t, nodeHandler, dataStore, scope) + assert.NoError(t, err) + + tests := []struct { + name string + outputVariable string + outputValues []*int + subNodePhases []v1alpha1.NodePhase + expectedArrayNodePhase v1alpha1.ArrayNodePhase + expectedTransitionPhase handler.EPhase + }{ + { + name: "Success", + outputValues: []*int{&valueOne, nil}, + outputVariable: "foo", + subNodePhases: []v1alpha1.NodePhase{ + v1alpha1.NodePhaseSucceeded, + v1alpha1.NodePhaseFailed, + }, + expectedArrayNodePhase: v1alpha1.ArrayNodePhaseSucceeding, + expectedTransitionPhase: handler.EPhaseSuccess, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + // initialize ArrayNodeState + subNodePhases, err := bitarray.NewCompactArray(uint(len(test.subNodePhases)), bitarray.Item(v1alpha1.NodePhaseRecovered)) + assert.NoError(t, err) + for i, nodePhase := range test.subNodePhases { + subNodePhases.SetItem(i, bitarray.Item(nodePhase)) + } + + retryAttempts, err := bitarray.NewCompactArray(uint(len(test.subNodePhases)), bitarray.Item(1)) + assert.NoError(t, err) + + arrayNodeState := &handler.ArrayNodeState{ + Phase: v1alpha1.ArrayNodePhaseSucceeding, + SubNodePhases: subNodePhases, + SubNodeRetryAttempts: retryAttempts, + } + + // create NodeExecutionContext + eventRecorder := newArrayEventRecorder() + 
literalMap := &idlcore.LiteralMap{} + nCtx := createNodeExecutionContext(dataStore, eventRecorder, []string{test.outputVariable}, literalMap, &arrayNodeSpec, arrayNodeState) + + // write mocked output files + for i, outputValue := range test.outputValues { + if outputValue == nil { + continue + } + + outputFile := storage.DataReference(fmt.Sprintf("s3://bucket/output/%d/0/outputs.pb", i)) + outputLiteralMap := &idlcore.LiteralMap{ + Literals: map[string]*idlcore.Literal{ + test.outputVariable: &idlcore.Literal{ + Value: &idlcore.Literal_Scalar{ + Scalar: &idlcore.Scalar{ + Value: &idlcore.Scalar_Primitive{ + Primitive: &idlcore.Primitive{ + Value: &idlcore.Primitive_Integer{ + Integer: int64(*outputValue), + }, + }, + }, + }, + }, + }, + }, + } + + err := nCtx.DataStore().WriteProtobuf(ctx, outputFile, storage.Options{}, outputLiteralMap) + assert.NoError(t, err) + } + + // evaluate node + transition, err := arrayNodeHandler.Handle(ctx, nCtx) + assert.NoError(t, err) + + // validate results + assert.Equal(t, test.expectedArrayNodePhase, arrayNodeState.Phase) + assert.Equal(t, test.expectedTransitionPhase, transition.Info().GetPhase()) + + // validate output file + var outputs idlcore.LiteralMap + outputFile := v1alpha1.GetOutputsFile(nCtx.NodeStatus().GetOutputDir()) + err = nCtx.DataStore().ReadProtobuf(ctx, outputFile, &outputs) + assert.NoError(t, err) + + assert.Len(t, outputs.GetLiterals(), 1) + + collection := outputs.GetLiterals()[test.outputVariable].GetCollection() + assert.NotNil(t, collection) + + assert.Len(t, collection.GetLiterals(), len(test.outputValues)) + for i, outputValue := range test.outputValues { + if outputValue == nil { + assert.NotNil(t, collection.GetLiterals()[i].GetScalar()) + } else { + assert.Equal(t, int64(*outputValue), collection.GetLiterals()[i].GetScalar().GetPrimitive().GetInteger()) + } + } + }) + } +} + +func TestHandleArrayNodePhaseFailing(t *testing.T) { + ctx := context.Background() + scope := promutils.NewTestScope() + dataStore, err := storage.NewDataStore(&storage.Config{ + Type: storage.TypeMemory, + }, scope) + assert.NoError(t, err) + + nodeHandler := &mocks.NodeHandler{} + nodeHandler.OnAbortMatch(mock.Anything, mock.Anything, mock.Anything).Return(nil) + nodeHandler.OnFinalizeMatch(mock.Anything, mock.Anything, mock.Anything).Return(nil) + + // initialize ArrayNodeHandler + arrayNodeHandler, err := createArrayNodeHandler(ctx, t, nodeHandler, dataStore, scope) + assert.NoError(t, err) + + tests := []struct { + name string + subNodePhases []v1alpha1.NodePhase + expectedArrayNodePhase v1alpha1.ArrayNodePhase + expectedTransitionPhase handler.EPhase + expectedAbortCalls int + }{ + { + name: "Success", + subNodePhases: []v1alpha1.NodePhase{ + v1alpha1.NodePhaseRunning, + v1alpha1.NodePhaseSucceeded, + v1alpha1.NodePhaseFailed, + }, + expectedArrayNodePhase: v1alpha1.ArrayNodePhaseFailing, + expectedTransitionPhase: handler.EPhaseFailed, + expectedAbortCalls: 1, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + // initialize ArrayNodeState + arrayNodeState := &handler.ArrayNodeState{ + Phase: v1alpha1.ArrayNodePhaseFailing, + } + + for _, item := range []struct { + arrayReference *bitarray.CompactArray + maxValue int + }{ + {arrayReference: &arrayNodeState.SubNodePhases, maxValue: int(v1alpha1.NodePhaseRecovered)}, + {arrayReference: &arrayNodeState.SubNodeTaskPhases, maxValue: len(core.Phases) - 1}, + {arrayReference: &arrayNodeState.SubNodeRetryAttempts, maxValue: 1}, + {arrayReference: 
&arrayNodeState.SubNodeSystemFailures, maxValue: 1}, + } { + + *item.arrayReference, err = bitarray.NewCompactArray(uint(len(test.subNodePhases)), bitarray.Item(item.maxValue)) + assert.NoError(t, err) + } + + for i, nodePhase := range test.subNodePhases { + arrayNodeState.SubNodePhases.SetItem(i, bitarray.Item(nodePhase)) + } + + // create NodeExecutionContext + eventRecorder := newArrayEventRecorder() + literalMap := &idlcore.LiteralMap{} + nCtx := createNodeExecutionContext(dataStore, eventRecorder, nil, literalMap, &arrayNodeSpec, arrayNodeState) + + // evaluate node + transition, err := arrayNodeHandler.Handle(ctx, nCtx) + assert.NoError(t, err) + + // validate results + assert.Equal(t, test.expectedArrayNodePhase, arrayNodeState.Phase) + assert.Equal(t, test.expectedTransitionPhase, transition.Info().GetPhase()) + nodeHandler.AssertNumberOfCalls(t, "Abort", test.expectedAbortCalls) + }) + } +} + +func init() { + labeled.SetMetricKeys(contextutils.ProjectKey, contextutils.DomainKey, contextutils.WorkflowIDKey, contextutils.TaskIDKey) +} + +func convertMapToArrayLiterals(values map[string][]int64) *idlcore.LiteralMap { + literalMap := make(map[string]*idlcore.Literal) + for k, v := range values { + // create LiteralCollection + literalList := make([]*idlcore.Literal, 0, len(v)) + for _, x := range v { + literalList = append(literalList, &idlcore.Literal{ + Value: &idlcore.Literal_Scalar{ + Scalar: &idlcore.Scalar{ + Value: &idlcore.Scalar_Primitive{ + Primitive: &idlcore.Primitive{ + Value: &idlcore.Primitive_Integer{ + Integer: x, + }, + }, + }, + }, + }, + }) + } + + // add LiteralCollection to map + literalMap[k] = &idlcore.Literal{ + Value: &idlcore.Literal_Collection{ + Collection: &idlcore.LiteralCollection{ + Literals: literalList, + }, + }, + } + } + + return &idlcore.LiteralMap{ + Literals: literalMap, + } +} diff --git a/flytepropeller/pkg/controller/nodes/array/node_execution_context.go b/flytepropeller/pkg/controller/nodes/array/node_execution_context.go new file mode 100644 index 0000000000..af3ea42f71 --- /dev/null +++ b/flytepropeller/pkg/controller/nodes/array/node_execution_context.go @@ -0,0 +1,150 @@ +package array + +import ( + "context" + + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" + + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io" + + "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/flyteorg/flytepropeller/pkg/controller/config" + "github.com/flyteorg/flytepropeller/pkg/controller/executors" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" +) + +type arrayEventRecorder struct { + nodeEvents []*event.NodeExecutionEvent + taskEvents []*event.TaskExecutionEvent +} + +func (a *arrayEventRecorder) RecordNodeEvent(ctx context.Context, event *event.NodeExecutionEvent, eventConfig *config.EventConfig) error { + a.nodeEvents = append(a.nodeEvents, event) + return nil +} + +func (a *arrayEventRecorder) RecordTaskEvent(ctx context.Context, event *event.TaskExecutionEvent, eventConfig *config.EventConfig) error { + a.taskEvents = append(a.taskEvents, event) + return nil +} + +func (a *arrayEventRecorder) NodeEvents() []*event.NodeExecutionEvent { + return a.nodeEvents +} + +func (a *arrayEventRecorder) TaskEvents() []*event.TaskExecutionEvent { + return a.taskEvents +} + +func newArrayEventRecorder() *arrayEventRecorder { + return &arrayEventRecorder{ + nodeEvents: make([]*event.NodeExecutionEvent, 0), + taskEvents: 
make([]*event.TaskExecutionEvent, 0),
+	}
+}
+
+type staticInputReader struct {
+	io.InputFilePaths
+	input *core.LiteralMap
+}
+
+func (i staticInputReader) Get(_ context.Context) (*core.LiteralMap, error) {
+	return i.input, nil
+}
+
+func newStaticInputReader(inputPaths io.InputFilePaths, input *core.LiteralMap) staticInputReader {
+	return staticInputReader{
+		InputFilePaths: inputPaths,
+		input:          input,
+	}
+}
+
+func constructLiteralMap(ctx context.Context, inputReader io.InputReader, index int) (*core.LiteralMap, error) {
+	inputs, err := inputReader.Get(ctx)
+	if err != nil {
+		return nil, err
+	}
+
+	literals := make(map[string]*core.Literal)
+	for name, literal := range inputs.Literals {
+		if literalCollection := literal.GetCollection(); literalCollection != nil {
+			literals[name] = literalCollection.Literals[index]
+		}
+	}
+
+	return &core.LiteralMap{
+		Literals: literals,
+	}, nil
+}
+
+type arrayTaskReader struct {
+	interfaces.TaskReader
+}
+
+func (a *arrayTaskReader) Read(ctx context.Context) (*core.TaskTemplate, error) {
+	taskTemplate, err := a.TaskReader.Read(ctx)
+	if err != nil {
+		return nil, err
+	}
+
+	// convert collection-typed output variables to their element type, since each sub-node emits a single element
+	outputVariables := make(map[string]*core.Variable)
+	for key, value := range taskTemplate.Interface.Outputs.Variables {
+		switch v := value.Type.Type.(type) {
+		case *core.LiteralType_CollectionType:
+			outputVariables[key] = &core.Variable{
+				Type:        v.CollectionType,
+				Description: value.Description,
+			}
+		default:
+			outputVariables[key] = value
+		}
+	}
+
+	taskTemplate.Interface.Outputs = &core.VariableMap{
+		Variables: outputVariables,
+	}
+	return taskTemplate, nil
+}
+
+type arrayNodeExecutionContext struct {
+	interfaces.NodeExecutionContext
+	eventRecorder    interfaces.EventRecorder
+	executionContext executors.ExecutionContext
+	inputReader      io.InputReader
+	nodeStatus       *v1alpha1.NodeStatus
+	taskReader       interfaces.TaskReader
+}
+
+func (a *arrayNodeExecutionContext) EventsRecorder() interfaces.EventRecorder {
+	return a.eventRecorder
+}
+
+func (a *arrayNodeExecutionContext) ExecutionContext() executors.ExecutionContext {
+	return a.executionContext
+}
+
+func (a *arrayNodeExecutionContext) InputReader() io.InputReader {
+	return a.inputReader
+}
+
+func (a *arrayNodeExecutionContext) NodeStatus() v1alpha1.ExecutableNodeStatus {
+	return a.nodeStatus
+}
+
+func (a *arrayNodeExecutionContext) TaskReader() interfaces.TaskReader {
+	return a.taskReader
+}
+
+func newArrayNodeExecutionContext(nodeExecutionContext interfaces.NodeExecutionContext, inputReader io.InputReader, eventRecorder interfaces.EventRecorder, subNodeIndex int, nodeStatus *v1alpha1.NodeStatus, currentParallelism *uint32, maxParallelism uint32) *arrayNodeExecutionContext {
+	arrayExecutionContext := newArrayExecutionContext(nodeExecutionContext.ExecutionContext(), subNodeIndex, currentParallelism, maxParallelism)
+	return &arrayNodeExecutionContext{
+		NodeExecutionContext: nodeExecutionContext,
+		eventRecorder:        eventRecorder,
+		executionContext:     arrayExecutionContext,
+		inputReader:          inputReader,
+		nodeStatus:           nodeStatus,
+		taskReader:           &arrayTaskReader{nodeExecutionContext.TaskReader()},
+	}
+}
diff --git a/flytepropeller/pkg/controller/nodes/array/node_execution_context_builder.go b/flytepropeller/pkg/controller/nodes/array/node_execution_context_builder.go
new file mode 100644
index 0000000000..de145b95ae
--- /dev/null
+++ b/flytepropeller/pkg/controller/nodes/array/node_execution_context_builder.go
@@ -0,0 +1,55 @@
+package array
+
+import (
+	"context"
+
+ "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io" + + "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/flyteorg/flytepropeller/pkg/controller/executors" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" +) + +type arrayNodeExecutionContextBuilder struct { + nCtxBuilder interfaces.NodeExecutionContextBuilder + subNodeID v1alpha1.NodeID + subNodeIndex int + subNodeStatus *v1alpha1.NodeStatus + inputReader io.InputReader + currentParallelism *uint32 + maxParallelism uint32 + eventRecorder interfaces.EventRecorder +} + +func (a *arrayNodeExecutionContextBuilder) BuildNodeExecutionContext(ctx context.Context, executionContext executors.ExecutionContext, + nl executors.NodeLookup, currentNodeID v1alpha1.NodeID) (interfaces.NodeExecutionContext, error) { + + // create base NodeExecutionContext + nCtx, err := a.nCtxBuilder.BuildNodeExecutionContext(ctx, executionContext, nl, currentNodeID) + if err != nil { + return nil, err + } + + if currentNodeID == a.subNodeID { + // overwrite NodeExecutionContext for ArrayNode execution + nCtx = newArrayNodeExecutionContext(nCtx, a.inputReader, a.eventRecorder, a.subNodeIndex, a.subNodeStatus, a.currentParallelism, a.maxParallelism) + } + + return nCtx, nil +} + +func newArrayNodeExecutionContextBuilder(nCtxBuilder interfaces.NodeExecutionContextBuilder, subNodeID v1alpha1.NodeID, + subNodeIndex int, subNodeStatus *v1alpha1.NodeStatus, inputReader io.InputReader, eventRecorder interfaces.EventRecorder, + currentParallelism *uint32, maxParallelism uint32) interfaces.NodeExecutionContextBuilder { + + return &arrayNodeExecutionContextBuilder{ + nCtxBuilder: nCtxBuilder, + subNodeID: subNodeID, + subNodeIndex: subNodeIndex, + subNodeStatus: subNodeStatus, + inputReader: inputReader, + currentParallelism: currentParallelism, + maxParallelism: maxParallelism, + eventRecorder: eventRecorder, + } +} diff --git a/flytepropeller/pkg/controller/nodes/array/node_lookup.go b/flytepropeller/pkg/controller/nodes/array/node_lookup.go new file mode 100644 index 0000000000..061b323af4 --- /dev/null +++ b/flytepropeller/pkg/controller/nodes/array/node_lookup.go @@ -0,0 +1,40 @@ +package array + +import ( + "context" + + "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/flyteorg/flytepropeller/pkg/controller/executors" +) + +type arrayNodeLookup struct { + executors.NodeLookup + subNodeID v1alpha1.NodeID + subNodeSpec *v1alpha1.NodeSpec + subNodeStatus *v1alpha1.NodeStatus +} + +func (a *arrayNodeLookup) GetNode(nodeID v1alpha1.NodeID) (v1alpha1.ExecutableNode, bool) { + if nodeID == a.subNodeID { + return a.subNodeSpec, true + } + + return a.NodeLookup.GetNode(nodeID) +} + +func (a *arrayNodeLookup) GetNodeExecutionStatus(ctx context.Context, id v1alpha1.NodeID) v1alpha1.ExecutableNodeStatus { + if id == a.subNodeID { + return a.subNodeStatus + } + + return a.NodeLookup.GetNodeExecutionStatus(ctx, id) +} + +func newArrayNodeLookup(nodeLookup executors.NodeLookup, subNodeID v1alpha1.NodeID, subNodeSpec *v1alpha1.NodeSpec, subNodeStatus *v1alpha1.NodeStatus) arrayNodeLookup { + return arrayNodeLookup{ + NodeLookup: nodeLookup, + subNodeID: subNodeID, + subNodeSpec: subNodeSpec, + subNodeStatus: subNodeStatus, + } +} diff --git a/flytepropeller/pkg/controller/nodes/array/utils.go b/flytepropeller/pkg/controller/nodes/array/utils.go new file mode 100644 index 0000000000..a0700e5739 --- /dev/null +++ b/flytepropeller/pkg/controller/nodes/array/utils.go @@ -0,0 +1,103 @@ +package 
array + +import ( + "bytes" + "context" + "fmt" + "time" + + idlcore "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" + + "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/codex" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/k8s" + + "github.com/flyteorg/flytestdlib/storage" + + "github.com/golang/protobuf/ptypes" +) + +func appendLiteral(name string, literal *idlcore.Literal, outputLiterals map[string]*idlcore.Literal, length int) { + outputLiteral, exists := outputLiterals[name] + if !exists { + outputLiteral = &idlcore.Literal{ + Value: &idlcore.Literal_Collection{ + Collection: &idlcore.LiteralCollection{ + Literals: make([]*idlcore.Literal, 0, length), + }, + }, + } + + outputLiterals[name] = outputLiteral + } + + collection := outputLiteral.GetCollection() + collection.Literals = append(collection.Literals, literal) +} + +func buildTaskExecutionEvent(_ context.Context, nCtx interfaces.NodeExecutionContext, taskPhase idlcore.TaskExecution_Phase, taskPhaseVersion uint32, externalResources []*event.ExternalResourceInfo) (*event.TaskExecutionEvent, error) { + occurredAt, err := ptypes.TimestampProto(time.Now()) + if err != nil { + return nil, err + } + + nodeExecutionID := nCtx.NodeExecutionMetadata().GetNodeExecutionID() + workflowExecutionID := nodeExecutionID.ExecutionId + return &event.TaskExecutionEvent{ + TaskId: &idlcore.Identifier{ + ResourceType: idlcore.ResourceType_TASK, + Project: workflowExecutionID.Project, + Domain: workflowExecutionID.Domain, + Name: nCtx.NodeID(), + Version: "v1", // this value is irrelevant but necessary for the identifier to be valid + }, + ParentNodeExecutionId: nCtx.NodeExecutionMetadata().GetNodeExecutionID(), + RetryAttempt: 0, // ArrayNode will never retry + Phase: taskPhase, + PhaseVersion: taskPhaseVersion, + OccurredAt: occurredAt, + Metadata: &event.TaskExecutionMetadata{ + ExternalResources: externalResources, + }, + TaskType: "k8s-array", + EventVersion: 1, + }, nil +} + +func buildSubNodeID(nCtx interfaces.NodeExecutionContext, index int, retryAttempt uint32) string { + return fmt.Sprintf("%s-n%d-%d", nCtx.NodeID(), index, retryAttempt) +} + +func bytesFromK8sPluginState(pluginState k8s.PluginState) ([]byte, error) { + buffer := make([]byte, 0, task.MaxPluginStateSizeBytes) + bufferWriter := bytes.NewBuffer(buffer) + + codec := codex.GobStateCodec{} + if err := codec.Encode(pluginState, bufferWriter); err != nil { + return nil, err + } + + return bufferWriter.Bytes(), nil +} + +func constructOutputReferences(ctx context.Context, nCtx interfaces.NodeExecutionContext, postfix ...string) (storage.DataReference, storage.DataReference, error) { + subDataDir, err := nCtx.DataStore().ConstructReference(ctx, nCtx.NodeStatus().GetDataDir(), postfix...) + if err != nil { + return "", "", err + } + + subOutputDir, err := nCtx.DataStore().ConstructReference(ctx, nCtx.NodeStatus().GetOutputDir(), postfix...) 
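+	// the same postfix is applied to both references so a sub-node's staged
+	// input data and its outputs resolve to parallel locations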
+ if err != nil { + return "", "", err + } + + return subDataDir, subOutputDir, nil +} + +func isTerminalNodePhase(nodePhase v1alpha1.NodePhase) bool { + return nodePhase == v1alpha1.NodePhaseSucceeded || nodePhase == v1alpha1.NodePhaseFailed || nodePhase == v1alpha1.NodePhaseTimedOut || + nodePhase == v1alpha1.NodePhaseSkipped || nodePhase == v1alpha1.NodePhaseRecovered +} diff --git a/flytepropeller/pkg/controller/nodes/array/utils_test.go b/flytepropeller/pkg/controller/nodes/array/utils_test.go new file mode 100644 index 0000000000..2e3eaf6e66 --- /dev/null +++ b/flytepropeller/pkg/controller/nodes/array/utils_test.go @@ -0,0 +1,36 @@ +package array + +import ( + "testing" + + idlcore "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + + "github.com/stretchr/testify/assert" +) + +func TestAppendLiteral(t *testing.T) { + outputLiterals := make(map[string]*idlcore.Literal) + literalMaps := []map[string]*idlcore.Literal{ + map[string]*idlcore.Literal{ + "foo": nilLiteral, + "bar": nilLiteral, + }, + map[string]*idlcore.Literal{ + "foo": nilLiteral, + "bar": nilLiteral, + }, + } + + for _, m := range literalMaps { + for k, v := range m { + appendLiteral(k, v, outputLiterals, len(literalMaps)) + } + } + + for _, v := range outputLiterals { + collection, ok := v.Value.(*idlcore.Literal_Collection) + assert.True(t, ok) + + assert.Equal(t, 2, len(collection.Collection.Literals)) + } +} diff --git a/flytepropeller/pkg/controller/nodes/branch/handler.go b/flytepropeller/pkg/controller/nodes/branch/handler.go index 109290b908..ed73245521 100644 --- a/flytepropeller/pkg/controller/nodes/branch/handler.go +++ b/flytepropeller/pkg/controller/nodes/branch/handler.go @@ -5,16 +5,18 @@ import ( "fmt" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" - "github.com/flyteorg/flytepropeller/pkg/controller/config" - "github.com/flyteorg/flytepropeller/pkg/controller/nodes/common" - stdErrors "github.com/flyteorg/flytestdlib/errors" - "github.com/flyteorg/flytestdlib/logger" - "github.com/flyteorg/flytestdlib/promutils" "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/flyteorg/flytepropeller/pkg/controller/config" "github.com/flyteorg/flytepropeller/pkg/controller/executors" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/common" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/errors" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" + + stdErrors "github.com/flyteorg/flytestdlib/errors" + "github.com/flyteorg/flytestdlib/logger" + "github.com/flyteorg/flytestdlib/promutils" ) type metrics struct { @@ -22,7 +24,7 @@ type metrics struct { } type branchHandler struct { - nodeExecutor executors.Node + nodeExecutor interfaces.Node m metrics eventConfig *config.EventConfig } @@ -31,13 +33,13 @@ func (b *branchHandler) FinalizeRequired() bool { return false } -func (b *branchHandler) Setup(ctx context.Context, _ handler.SetupContext) error { +func (b *branchHandler) Setup(ctx context.Context, _ interfaces.SetupContext) error { logger.Debugf(ctx, "BranchNode::Setup: nothing to do") return nil } -func (b *branchHandler) HandleBranchNode(ctx context.Context, branchNode v1alpha1.ExecutableBranchNode, nCtx handler.NodeExecutionContext, nl executors.NodeLookup) (handler.Transition, error) { - if nCtx.NodeStateReader().GetBranchNode().FinalizedNodeID == nil { +func (b *branchHandler) HandleBranchNode(ctx context.Context, branchNode v1alpha1.ExecutableBranchNode, nCtx 
interfaces.NodeExecutionContext, nl executors.NodeLookup) (handler.Transition, error) { + if nCtx.NodeStateReader().GetBranchNodeState().FinalizedNodeID == nil { nodeInputs, err := nCtx.InputReader().Get(ctx) if err != nil { errMsg := fmt.Sprintf("Failed to read input. Error [%s]", err) @@ -79,7 +81,7 @@ func (b *branchHandler) HandleBranchNode(ctx context.Context, branchNode v1alpha } // If the branchNodestatus was already evaluated i.e, Node is in Running status - branchStatus := nCtx.NodeStateReader().GetBranchNode() + branchStatus := nCtx.NodeStateReader().GetBranchNodeState() userError := branchNode.GetElseFail() finalNodeID := branchStatus.FinalizedNodeID if finalNodeID == nil { @@ -103,7 +105,7 @@ func (b *branchHandler) HandleBranchNode(ctx context.Context, branchNode v1alpha return b.recurseDownstream(ctx, nCtx, nodeStatus, branchTakenNode) } -func (b *branchHandler) Handle(ctx context.Context, nCtx handler.NodeExecutionContext) (handler.Transition, error) { +func (b *branchHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecutionContext) (handler.Transition, error) { logger.Debug(ctx, "Starting Branch Node") branchNode := nCtx.Node().GetBranchNode() if branchNode == nil { @@ -115,7 +117,7 @@ func (b *branchHandler) Handle(ctx context.Context, nCtx handler.NodeExecutionCo return b.HandleBranchNode(ctx, branchNode, nCtx, nl) } -func (b *branchHandler) getExecutionContextForDownstream(nCtx handler.NodeExecutionContext) (executors.ExecutionContext, error) { +func (b *branchHandler) getExecutionContextForDownstream(nCtx interfaces.NodeExecutionContext) (executors.ExecutionContext, error) { newParentInfo, err := common.CreateParentInfo(nCtx.ExecutionContext().GetParentInfo(), nCtx.NodeID(), nCtx.CurrentAttempt()) if err != nil { return nil, err @@ -123,7 +125,7 @@ func (b *branchHandler) getExecutionContextForDownstream(nCtx handler.NodeExecut return executors.NewExecutionContextWithParentInfo(nCtx.ExecutionContext(), newParentInfo), nil } -func (b *branchHandler) recurseDownstream(ctx context.Context, nCtx handler.NodeExecutionContext, nodeStatus v1alpha1.ExecutableNodeStatus, branchTakenNode v1alpha1.ExecutableNode) (handler.Transition, error) { +func (b *branchHandler) recurseDownstream(ctx context.Context, nCtx interfaces.NodeExecutionContext, nodeStatus v1alpha1.ExecutableNodeStatus, branchTakenNode v1alpha1.ExecutableNode) (handler.Transition, error) { // TODO we should replace the call to RecursiveNodeHandler with a call to SingleNode Handler. The inputs are also already known ahead of time // There is no DAGStructure for the branch nodes, the branch taken node is the leaf node. The node itself may be arbitrarily complex, but in that case the node should reference a subworkflow etc // The parent of the BranchTaken Node is the actual Branch Node and all the data is just forwarded from the Branch to the executed node. 
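For readers tracing the downstream-handling paths above: recurseDownstream and its siblings derive the context used for child-node handling through the same parent-info pattern. A minimal sketch of that pattern, using only calls that appear elsewhere in this diff (the childExecContext name is illustrative and not part of the change):

// Sketch: derive the execution context under which downstream (child) nodes
// run, so their node IDs and events are attributed to this node and attempt.
func childExecContext(nCtx interfaces.NodeExecutionContext) (executors.ExecutionContext, error) {
	// the current node (and its attempt) becomes the parent of whatever executes downstream
	parentInfo, err := common.CreateParentInfo(nCtx.ExecutionContext().GetParentInfo(), nCtx.NodeID(), nCtx.CurrentAttempt())
	if err != nil {
		return nil, err
	}
	return executors.NewExecutionContextWithParentInfo(nCtx.ExecutionContext(), parentInfo), nil
}
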
@@ -167,7 +169,7 @@ func (b *branchHandler) recurseDownstream(ctx context.Context, nCtx handler.Node return handler.DoTransition(handler.TransitionTypeEphemeral, phase), nil } -func (b *branchHandler) Abort(ctx context.Context, nCtx handler.NodeExecutionContext, reason string) error { +func (b *branchHandler) Abort(ctx context.Context, nCtx interfaces.NodeExecutionContext, reason string) error { branch := nCtx.Node().GetBranchNode() if branch == nil { @@ -175,7 +177,7 @@ func (b *branchHandler) Abort(ctx context.Context, nCtx handler.NodeExecutionCon } // If the branch was already evaluated i.e, Node is in Running status - branchNodeState := nCtx.NodeStateReader().GetBranchNode() + branchNodeState := nCtx.NodeStateReader().GetBranchNodeState() if branchNodeState.Phase == v1alpha1.BranchNodeNotYetEvaluated { logger.Errorf(ctx, "No node finalized through previous branch evaluation.") return nil @@ -212,14 +214,14 @@ func (b *branchHandler) Abort(ctx context.Context, nCtx handler.NodeExecutionCon return b.nodeExecutor.AbortHandler(ctx, execContext, dag, nCtx.ContextualNodeLookup(), branchTakenNode, reason) } -func (b *branchHandler) Finalize(ctx context.Context, nCtx handler.NodeExecutionContext) error { +func (b *branchHandler) Finalize(ctx context.Context, nCtx interfaces.NodeExecutionContext) error { branch := nCtx.Node().GetBranchNode() if branch == nil { return errors.Errorf(errors.IllegalStateError, nCtx.NodeID(), "Invoked branch handler, for a non branch node.") } // If the branch was already evaluated i.e, Node is in Running status - branchNodeState := nCtx.NodeStateReader().GetBranchNode() + branchNodeState := nCtx.NodeStateReader().GetBranchNodeState() if branchNodeState.Phase == v1alpha1.BranchNodeNotYetEvaluated { logger.Errorf(ctx, "No node finalized through previous branch evaluation.") return nil @@ -256,7 +258,7 @@ func (b *branchHandler) Finalize(ctx context.Context, nCtx handler.NodeExecution return b.nodeExecutor.FinalizeHandler(ctx, execContext, dag, nCtx.ContextualNodeLookup(), branchTakenNode) } -func New(executor executors.Node, eventConfig *config.EventConfig, scope promutils.Scope) handler.Node { +func New(executor interfaces.Node, eventConfig *config.EventConfig, scope promutils.Scope) interfaces.NodeHandler { return &branchHandler{ nodeExecutor: executor, m: metrics{scope: scope}, diff --git a/flytepropeller/pkg/controller/nodes/branch/handler_test.go b/flytepropeller/pkg/controller/nodes/branch/handler_test.go index 5711de5d42..774c4a4596 100644 --- a/flytepropeller/pkg/controller/nodes/branch/handler_test.go +++ b/flytepropeller/pkg/controller/nodes/branch/handler_test.go @@ -26,7 +26,8 @@ import ( "github.com/flyteorg/flytepropeller/pkg/controller/executors" execMocks "github.com/flyteorg/flytepropeller/pkg/controller/executors/mocks" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" - "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler/mocks" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces/mocks" ) var eventConfig = &config.EventConfig{ @@ -37,6 +38,9 @@ type branchNodeStateHolder struct { s handler.BranchNodeState } +func (t *branchNodeStateHolder) ClearNodeStatus() { +} + func (t *branchNodeStateHolder) PutTaskNodeState(s handler.TaskNodeState) error { panic("not implemented") } @@ -58,6 +62,10 @@ func (t branchNodeStateHolder) PutGateNodeState(s handler.GateNodeState) error { panic("not implemented") } +func (t branchNodeStateHolder) 
PutArrayNodeState(s handler.ArrayNodeState) error { + panic("not implemented") +} + type parentInfo struct { } @@ -119,7 +127,7 @@ func createNodeContext(phase v1alpha1.BranchNodePhase, childNodeID *v1alpha1.Nod nCtx.OnEnqueueOwnerFunc().Return(nil) nr := &mocks.NodeStateReader{} - nr.OnGetBranchNode().Return(handler.BranchNodeState{ + nr.OnGetBranchNodeState().Return(handler.BranchNodeState{ FinalizedNodeID: childNodeID, Phase: phase, }) @@ -151,7 +159,7 @@ func TestBranchHandler_RecurseDownstream(t *testing.T) { tests := []struct { name string - ns executors.NodeStatus + ns interfaces.NodeStatus err error nodeStatus *mocks2.ExecutableNodeStatus branchTakenNode v1alpha1.ExecutableNode @@ -160,17 +168,17 @@ func TestBranchHandler_RecurseDownstream(t *testing.T) { childPhase v1alpha1.NodePhase upstreamNodeID string }{ - {"upstreamNodeExists", executors.NodeStatusPending, nil, + {"upstreamNodeExists", interfaces.NodeStatusPending, nil, &mocks2.ExecutableNodeStatus{}, bn, false, handler.EPhaseRunning, v1alpha1.NodePhaseQueued, "n2"}, - {"childNodeError", executors.NodeStatusUndefined, fmt.Errorf("err"), + {"childNodeError", interfaces.NodeStatusUndefined, fmt.Errorf("err"), &mocks2.ExecutableNodeStatus{}, bn, true, handler.EPhaseUndefined, v1alpha1.NodePhaseFailed, ""}, - {"childPending", executors.NodeStatusPending, nil, + {"childPending", interfaces.NodeStatusPending, nil, &mocks2.ExecutableNodeStatus{}, bn, false, handler.EPhaseRunning, v1alpha1.NodePhaseQueued, ""}, - {"childStillRunning", executors.NodeStatusRunning, nil, + {"childStillRunning", interfaces.NodeStatusRunning, nil, &mocks2.ExecutableNodeStatus{}, bn, false, handler.EPhaseRunning, v1alpha1.NodePhaseRunning, ""}, - {"childFailure", executors.NodeStatusFailed(expectedError), nil, + {"childFailure", interfaces.NodeStatusFailed(expectedError), nil, &mocks2.ExecutableNodeStatus{}, bn, false, handler.EPhaseFailed, v1alpha1.NodePhaseFailed, ""}, - {"childComplete", executors.NodeStatusComplete, nil, + {"childComplete", interfaces.NodeStatusComplete, nil, &mocks2.ExecutableNodeStatus{}, bn, false, handler.EPhaseSuccess, v1alpha1.NodePhaseSucceeded, ""}, } for _, test := range tests { @@ -188,7 +196,7 @@ func TestBranchHandler_RecurseDownstream(t *testing.T) { nCtx, _ := createNodeContext(v1alpha1.BranchNodeNotYetEvaluated, &childNodeID, n, nil, mockNodeLookup, eCtx) newParentInfo, _ := common.CreateParentInfo(parentInfo{}, nCtx.NodeID(), nCtx.CurrentAttempt()) expectedExecContext := executors.NewExecutionContextWithParentInfo(nCtx.ExecutionContext(), newParentInfo) - mockNodeExecutor := &execMocks.Node{} + mockNodeExecutor := &mocks.Node{} mockNodeExecutor.OnRecursiveNodeHandlerMatch( mock.Anything, // ctx mock.MatchedBy(func(e executors.ExecutionContext) bool { return assert.Equal(t, e, expectedExecContext) }), @@ -295,7 +303,7 @@ func TestBranchHandler_AbortNode(t *testing.T) { assert.NotNil(t, w) t.Run("NoBranchNode", func(t *testing.T) { - mockNodeExecutor := &execMocks.Node{} + mockNodeExecutor := &mocks.Node{} mockNodeExecutor.OnAbortHandlerMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(fmt.Errorf("err")) @@ -308,7 +316,7 @@ func TestBranchHandler_AbortNode(t *testing.T) { }) t.Run("BranchNodeSuccess", func(t *testing.T) { - mockNodeExecutor := &execMocks.Node{} + mockNodeExecutor := &mocks.Node{} mockNodeLookup := &execMocks.NodeLookup{} mockNodeLookup.OnToNodeMatch(mock.Anything).Return(nil, nil) eCtx := &execMocks.ExecutionContext{} @@ -329,7 +337,7 @@ func 
TestBranchHandler_AbortNode(t *testing.T) { func TestBranchHandler_Initialize(t *testing.T) { ctx := context.TODO() - mockNodeExecutor := &execMocks.Node{} + mockNodeExecutor := &mocks.Node{} branch := New(mockNodeExecutor, eventConfig, promutils.NewTestScope()) assert.NoError(t, branch.Setup(ctx, nil)) } @@ -337,7 +345,7 @@ func TestBranchHandler_Initialize(t *testing.T) { // TODO incomplete test suite, add more func TestBranchHandler_HandleNode(t *testing.T) { ctx := context.TODO() - mockNodeExecutor := &execMocks.Node{} + mockNodeExecutor := &mocks.Node{} branch := New(mockNodeExecutor, eventConfig, promutils.NewTestScope()) childNodeID := "child" childDatadir := v1alpha1.DataReference("test") diff --git a/flytepropeller/pkg/controller/nodes/dynamic/dynamic_workflow.go b/flytepropeller/pkg/controller/nodes/dynamic/dynamic_workflow.go index 6166d8722b..4612ab803f 100644 --- a/flytepropeller/pkg/controller/nodes/dynamic/dynamic_workflow.go +++ b/flytepropeller/pkg/controller/nodes/dynamic/dynamic_workflow.go @@ -5,23 +5,23 @@ import ( "fmt" "strconv" - "k8s.io/apimachinery/pkg/util/rand" - - node_common "github.com/flyteorg/flytepropeller/pkg/controller/nodes/common" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" "github.com/flyteorg/flytepropeller/pkg/compiler" "github.com/flyteorg/flytepropeller/pkg/compiler/common" "github.com/flyteorg/flytepropeller/pkg/compiler/transformers/k8s" "github.com/flyteorg/flytepropeller/pkg/controller/executors" + node_common "github.com/flyteorg/flytepropeller/pkg/controller/nodes/common" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/subworkflow/launchplan" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task" "github.com/flyteorg/flytepropeller/pkg/utils" "github.com/flyteorg/flytestdlib/errors" "github.com/flyteorg/flytestdlib/logger" "github.com/flyteorg/flytestdlib/storage" + + "k8s.io/apimachinery/pkg/util/rand" ) type dynamicWorkflowContext struct { @@ -36,7 +36,7 @@ type dynamicWorkflowContext struct { const dynamicWfNameTemplate = "dynamic_%s" func setEphemeralNodeExecutionStatusAttributes(ctx context.Context, djSpec *core.DynamicJobSpec, - nCtx handler.NodeExecutionContext, parentNodeStatus v1alpha1.ExecutableNodeStatus) error { + nCtx interfaces.NodeExecutionContext, parentNodeStatus v1alpha1.ExecutableNodeStatus) error { if nCtx.ExecutionContext().GetEventVersion() != v1alpha1.EventVersion0 { return nil } @@ -77,7 +77,7 @@ func setEphemeralNodeExecutionStatusAttributes(ctx context.Context, djSpec *core } func (d dynamicNodeTaskNodeHandler) buildDynamicWorkflowTemplate(ctx context.Context, djSpec *core.DynamicJobSpec, - nCtx handler.NodeExecutionContext, parentNodeStatus v1alpha1.ExecutableNodeStatus) (*core.WorkflowTemplate, error) { + nCtx interfaces.NodeExecutionContext, parentNodeStatus v1alpha1.ExecutableNodeStatus) (*core.WorkflowTemplate, error) { iface, err := underlyingInterface(ctx, nCtx.TaskReader()) if err != nil { @@ -127,7 +127,7 @@ func (d dynamicNodeTaskNodeHandler) buildDynamicWorkflowTemplate(ctx context.Con }, nil } -func (d dynamicNodeTaskNodeHandler) buildContextualDynamicWorkflow(ctx context.Context, nCtx handler.NodeExecutionContext) (dynamicWorkflowContext, error) { +func (d dynamicNodeTaskNodeHandler) buildContextualDynamicWorkflow(ctx context.Context, nCtx interfaces.NodeExecutionContext) 
(dynamicWorkflowContext, error) { t := d.metrics.buildDynamicWorkflow.Start(ctx) defer t.Stop() @@ -221,7 +221,7 @@ func (d dynamicNodeTaskNodeHandler) buildContextualDynamicWorkflow(ctx context.C }, nil } -func (d dynamicNodeTaskNodeHandler) buildDynamicWorkflow(ctx context.Context, nCtx handler.NodeExecutionContext, +func (d dynamicNodeTaskNodeHandler) buildDynamicWorkflow(ctx context.Context, nCtx interfaces.NodeExecutionContext, djSpec *core.DynamicJobSpec, dynamicNodeStatus v1alpha1.ExecutableNodeStatus) (*core.CompiledWorkflowClosure, *v1alpha1.FlyteWorkflow, error) { wf, err := d.buildDynamicWorkflowTemplate(ctx, djSpec, nCtx, dynamicNodeStatus) if err != nil { @@ -265,7 +265,7 @@ func (d dynamicNodeTaskNodeHandler) buildDynamicWorkflow(ctx context.Context, nC } func (d dynamicNodeTaskNodeHandler) progressDynamicWorkflow(ctx context.Context, execContext executors.ExecutionContext, dynamicWorkflow v1alpha1.ExecutableWorkflow, nl executors.NodeLookup, - nCtx handler.NodeExecutionContext, prevState handler.DynamicNodeState) (handler.Transition, handler.DynamicNodeState, error) { + nCtx interfaces.NodeExecutionContext, prevState handler.DynamicNodeState) (handler.Transition, handler.DynamicNodeState, error) { state, err := d.nodeExecutor.RecursiveNodeHandler(ctx, execContext, dynamicWorkflow, nl, dynamicWorkflow.StartNode()) if err != nil { diff --git a/flytepropeller/pkg/controller/nodes/dynamic/dynamic_workflow_test.go b/flytepropeller/pkg/controller/nodes/dynamic/dynamic_workflow_test.go index 28abfeea32..643535fd58 100644 --- a/flytepropeller/pkg/controller/nodes/dynamic/dynamic_workflow_test.go +++ b/flytepropeller/pkg/controller/nodes/dynamic/dynamic_workflow_test.go @@ -25,7 +25,7 @@ import ( mocks4 "github.com/flyteorg/flytepropeller/pkg/controller/executors/mocks" mocks6 "github.com/flyteorg/flytepropeller/pkg/controller/nodes/dynamic/mocks" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" - "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler/mocks" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces/mocks" mocks5 "github.com/flyteorg/flytepropeller/pkg/controller/nodes/subworkflow/launchplan/mocks" ) @@ -183,7 +183,7 @@ func Test_dynamicNodeHandler_buildContextualDynamicWorkflow_withLaunchPlans(t *t }, }, nil) h := &mocks6.TaskNodeHandler{} - n := &mocks4.Node{} + n := &mocks.Node{} d := dynamicNodeTaskNodeHandler{ TaskNodeHandler: h, nodeExecutor: n, @@ -255,7 +255,7 @@ func Test_dynamicNodeHandler_buildContextualDynamicWorkflow_withLaunchPlans(t *t }, }, nil) h := &mocks6.TaskNodeHandler{} - n := &mocks4.Node{} + n := &mocks.Node{} d := dynamicNodeTaskNodeHandler{ TaskNodeHandler: h, nodeExecutor: n, @@ -324,7 +324,7 @@ func Test_dynamicNodeHandler_buildContextualDynamicWorkflow_withLaunchPlans(t *t }, }, nil) h := &mocks6.TaskNodeHandler{} - n := &mocks4.Node{} + n := &mocks.Node{} d := dynamicNodeTaskNodeHandler{ TaskNodeHandler: h, nodeExecutor: n, @@ -407,7 +407,7 @@ func Test_dynamicNodeHandler_buildContextualDynamicWorkflow_withLaunchPlans(t *t }, }, nil) h := &mocks6.TaskNodeHandler{} - n := &mocks4.Node{} + n := &mocks.Node{} d := dynamicNodeTaskNodeHandler{ TaskNodeHandler: h, nodeExecutor: n, @@ -461,7 +461,7 @@ func Test_dynamicNodeHandler_buildContextualDynamicWorkflow_withLaunchPlans(t *t mockLPLauncher := &mocks5.Reader{} h := &mocks6.TaskNodeHandler{} - n := &mocks4.Node{} + n := &mocks.Node{} d := dynamicNodeTaskNodeHandler{ TaskNodeHandler: h, nodeExecutor: n, @@ -550,7 +550,7 @@ func 
Test_dynamicNodeHandler_buildContextualDynamicWorkflow_withLaunchPlans(t *t }, nil) h := &mocks6.TaskNodeHandler{} - n := &mocks4.Node{} + n := &mocks.Node{} d := dynamicNodeTaskNodeHandler{ TaskNodeHandler: h, nodeExecutor: n, diff --git a/flytepropeller/pkg/controller/nodes/dynamic/handler.go b/flytepropeller/pkg/controller/nodes/dynamic/handler.go index f027cca0e4..369b26ee65 100644 --- a/flytepropeller/pkg/controller/nodes/dynamic/handler.go +++ b/flytepropeller/pkg/controller/nodes/dynamic/handler.go @@ -4,28 +4,26 @@ import ( "context" "time" - "github.com/flyteorg/flytepropeller/pkg/controller/config" - + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/catalog" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/ioutils" - "github.com/flyteorg/flytestdlib/logger" - "github.com/flyteorg/flytestdlib/promutils/labeled" + "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/flyteorg/flytepropeller/pkg/controller/config" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/errors" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/subworkflow/launchplan" - "github.com/flyteorg/flytepropeller/pkg/utils" - - "github.com/flyteorg/flytepropeller/pkg/controller/executors" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task" + "github.com/flyteorg/flytepropeller/pkg/utils" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" stdErrors "github.com/flyteorg/flytestdlib/errors" + "github.com/flyteorg/flytestdlib/logger" "github.com/flyteorg/flytestdlib/promutils" - - "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" - "github.com/flyteorg/flytepropeller/pkg/controller/nodes/errors" - "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" + "github.com/flyteorg/flytestdlib/promutils/labeled" ) //go:generate mockery -all -case=underscore @@ -33,7 +31,7 @@ import ( const dynamicNodeID = "dynamic-node" type TaskNodeHandler interface { - handler.Node + interfaces.NodeHandler ValidateOutputAndCacheAdd(ctx context.Context, nodeID v1alpha1.NodeID, i io.InputReader, r io.OutputReader, outputCommitter io.OutputWriter, executionConfig v1alpha1.ExecutionConfig, tr ioutils.SimpleTaskReader, m catalog.Metadata) (catalog.Status, *io.ExecutionError, error) @@ -60,12 +58,12 @@ func newMetrics(scope promutils.Scope) metrics { type dynamicNodeTaskNodeHandler struct { TaskNodeHandler metrics metrics - nodeExecutor executors.Node + nodeExecutor interfaces.Node lpReader launchplan.Reader eventConfig *config.EventConfig } -func (d dynamicNodeTaskNodeHandler) handleParentNode(ctx context.Context, prevState handler.DynamicNodeState, nCtx handler.NodeExecutionContext) (handler.Transition, handler.DynamicNodeState, error) { +func (d dynamicNodeTaskNodeHandler) handleParentNode(ctx context.Context, prevState handler.DynamicNodeState, nCtx interfaces.NodeExecutionContext) (handler.Transition, handler.DynamicNodeState, error) { // It seems parent node is still running, lets call handle for parent node trns, err := d.TaskNodeHandler.Handle(ctx, nCtx) if err != nil { @@ -95,7 +93,7 @@ func (d dynamicNodeTaskNodeHandler) handleParentNode(ctx context.Context, prevSt return trns, prevState, nil } -func (d 
dynamicNodeTaskNodeHandler) produceDynamicWorkflow(ctx context.Context, nCtx handler.NodeExecutionContext) ( +func (d dynamicNodeTaskNodeHandler) produceDynamicWorkflow(ctx context.Context, nCtx interfaces.NodeExecutionContext) ( handler.Transition, handler.DynamicNodeState, error) { // The first time this is called we go ahead and evaluate the dynamic node to build the workflow. We then cache // this workflow definition and send it to be persisted by flyteadmin so that users can observe the structure. @@ -125,7 +123,7 @@ func (d dynamicNodeTaskNodeHandler) produceDynamicWorkflow(ctx context.Context, })), nextState, nil } -func (d dynamicNodeTaskNodeHandler) handleDynamicSubNodes(ctx context.Context, nCtx handler.NodeExecutionContext, prevState handler.DynamicNodeState) (handler.Transition, handler.DynamicNodeState, error) { +func (d dynamicNodeTaskNodeHandler) handleDynamicSubNodes(ctx context.Context, nCtx interfaces.NodeExecutionContext, prevState handler.DynamicNodeState) (handler.Transition, handler.DynamicNodeState, error) { dCtx, err := d.buildContextualDynamicWorkflow(ctx, nCtx) if err != nil { if stdErrors.IsCausedBy(err, utils.ErrorCodeUser) { @@ -182,7 +180,7 @@ func (d dynamicNodeTaskNodeHandler) handleDynamicSubNodes(ctx context.Context, n // DynamicNodePhaseParentFinalized: The parent has node completed successfully and the generated dynamic sub workflow has been serialized and sent as an event. // DynamicNodePhaseExecuting: The parent node has completed and finalized successfully, the sub-nodes are being handled // DynamicNodePhaseFailing: one or more of sub-nodes have failed and the failure is being handled -func (d dynamicNodeTaskNodeHandler) Handle(ctx context.Context, nCtx handler.NodeExecutionContext) (handler.Transition, error) { +func (d dynamicNodeTaskNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecutionContext) (handler.Transition, error) { ds := nCtx.NodeStateReader().GetDynamicNodeState() var err error var trns handler.Transition @@ -242,7 +240,7 @@ func (d dynamicNodeTaskNodeHandler) Handle(ctx context.Context, nCtx handler.Nod return trns, nil } -func (d dynamicNodeTaskNodeHandler) Abort(ctx context.Context, nCtx handler.NodeExecutionContext, reason string) error { +func (d dynamicNodeTaskNodeHandler) Abort(ctx context.Context, nCtx interfaces.NodeExecutionContext, reason string) error { ds := nCtx.NodeStateReader().GetDynamicNodeState() switch ds.Phase { case v1alpha1.DynamicNodePhaseFailing: @@ -269,7 +267,7 @@ func (d dynamicNodeTaskNodeHandler) Abort(ctx context.Context, nCtx handler.Node } } -func (d dynamicNodeTaskNodeHandler) finalizeParentNode(ctx context.Context, nCtx handler.NodeExecutionContext) error { +func (d dynamicNodeTaskNodeHandler) finalizeParentNode(ctx context.Context, nCtx interfaces.NodeExecutionContext) error { logger.Infof(ctx, "Finalizing Parent node RetryAttempt [%d]", nCtx.CurrentAttempt()) if err := d.TaskNodeHandler.Finalize(ctx, nCtx); err != nil { logger.Errorf(ctx, "Failed to finalize Dynamic Nodes Parent.") @@ -279,7 +277,7 @@ func (d dynamicNodeTaskNodeHandler) finalizeParentNode(ctx context.Context, nCtx } // This is a weird method. We should always finalize before we set the dynamic parent node phase as complete? 
-func (d dynamicNodeTaskNodeHandler) Finalize(ctx context.Context, nCtx handler.NodeExecutionContext) error { +func (d dynamicNodeTaskNodeHandler) Finalize(ctx context.Context, nCtx interfaces.NodeExecutionContext) error { errs := make([]error, 0, 2) ds := nCtx.NodeStateReader().GetDynamicNodeState() @@ -312,7 +310,7 @@ func (d dynamicNodeTaskNodeHandler) Finalize(ctx context.Context, nCtx handler.N return nil } -func New(underlying TaskNodeHandler, nodeExecutor executors.Node, launchPlanReader launchplan.Reader, eventConfig *config.EventConfig, scope promutils.Scope) handler.Node { +func New(underlying TaskNodeHandler, nodeExecutor interfaces.Node, launchPlanReader launchplan.Reader, eventConfig *config.EventConfig, scope promutils.Scope) interfaces.NodeHandler { return &dynamicNodeTaskNodeHandler{ TaskNodeHandler: underlying, diff --git a/flytepropeller/pkg/controller/nodes/dynamic/handler_test.go b/flytepropeller/pkg/controller/nodes/dynamic/handler_test.go index 27bc2d935f..ae0bb6912c 100644 --- a/flytepropeller/pkg/controller/nodes/dynamic/handler_test.go +++ b/flytepropeller/pkg/controller/nodes/dynamic/handler_test.go @@ -28,17 +28,20 @@ import ( "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" flyteMocks "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks" - "github.com/flyteorg/flytepropeller/pkg/controller/executors" executorMocks "github.com/flyteorg/flytepropeller/pkg/controller/executors/mocks" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/dynamic/mocks" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" - nodeMocks "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler/mocks" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" + nodeMocks "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces/mocks" ) type dynamicNodeStateHolder struct { s handler.DynamicNodeState } +func (t *dynamicNodeStateHolder) ClearNodeStatus() { +} + func (t *dynamicNodeStateHolder) PutTaskNodeState(s handler.TaskNodeState) error { panic("not implemented") } @@ -60,6 +63,10 @@ func (t dynamicNodeStateHolder) PutGateNodeState(s handler.GateNodeState) error panic("not implemented") } +func (t dynamicNodeStateHolder) PutArrayNodeState(s handler.ArrayNodeState) error { + panic("not implemented") +} + var tID = "task-1" var eventConfig = &config.EventConfig{ @@ -186,7 +193,7 @@ func Test_dynamicNodeHandler_Handle_Parent(t *testing.T) { } h := &mocks.TaskNodeHandler{} mockLPLauncher := &lpMocks.Reader{} - n := &executorMocks.Node{} + n := &nodeMocks.Node{} if tt.args.isErr { h.OnHandleMatch(mock.Anything, mock.Anything).Return(handler.UnknownTransition, fmt.Errorf("error")) } else { @@ -300,7 +307,7 @@ func Test_dynamicNodeHandler_Handle_ParentFinalize(t *testing.T) { assert.NoError(t, err) dj := &core.DynamicJobSpec{} mockLPLauncher := &lpMocks.Reader{} - n := &executorMocks.Node{} + n := &nodeMocks.Node{} assert.NoError(t, nCtx.DataStore().WriteProtobuf(context.TODO(), f, storage.Options{}, dj)) h := &mocks.TaskNodeHandler{} h.OnFinalizeMatch(mock.Anything, mock.Anything).Return(nil) @@ -320,7 +327,7 @@ func Test_dynamicNodeHandler_Handle_ParentFinalize(t *testing.T) { assert.NoError(t, err) dj := &core.DynamicJobSpec{} mockLPLauncher := &lpMocks.Reader{} - n := &executorMocks.Node{} + n := &nodeMocks.Node{} assert.NoError(t, nCtx.DataStore().WriteProtobuf(context.TODO(), f, storage.Options{}, dj)) h := &mocks.TaskNodeHandler{} 
h.OnFinalizeMatch(mock.Anything, mock.Anything).Return(fmt.Errorf("err")) @@ -546,7 +553,7 @@ func Test_dynamicNodeHandler_Handle_SubTaskV1(t *testing.T) { } type args struct { - s executors.NodeStatus + s interfaces.NodeStatus isErr bool dj *core.DynamicJobSpec validErr *io.ExecutionError @@ -565,15 +572,15 @@ func Test_dynamicNodeHandler_Handle_SubTaskV1(t *testing.T) { want want }{ {"error", args{isErr: true, dj: createDynamicJobSpec()}, want{isErr: true}}, - {"success", args{s: executors.NodeStatusSuccess, dj: createDynamicJobSpec()}, want{p: handler.EPhaseDynamicRunning, phase: v1alpha1.DynamicNodePhaseExecuting}}, - {"complete", args{s: executors.NodeStatusComplete, dj: createDynamicJobSpec(), generateOutputs: true, validCacheStatus: &validCachePopulatedStatus}, want{p: handler.EPhaseSuccess, phase: v1alpha1.DynamicNodePhaseExecuting, info: execInfoWithTaskNodeMeta}}, - {"complete-no-outputs", args{s: executors.NodeStatusComplete, dj: createDynamicJobSpec(), generateOutputs: false}, want{p: handler.EPhaseRetryableFailure, phase: v1alpha1.DynamicNodePhaseFailing}}, - {"complete-valid-error-retryable", args{s: executors.NodeStatusComplete, dj: createDynamicJobSpec(), validErr: &io.ExecutionError{IsRecoverable: true}, generateOutputs: true}, want{p: handler.EPhaseRetryableFailure, phase: v1alpha1.DynamicNodePhaseFailing, info: execInfoOutputOnly}}, - {"complete-valid-error", args{s: executors.NodeStatusComplete, dj: createDynamicJobSpec(), validErr: &io.ExecutionError{}, generateOutputs: true}, want{p: handler.EPhaseFailed, phase: v1alpha1.DynamicNodePhaseFailing, info: execInfoOutputOnly}}, - {"failed", args{s: executors.NodeStatusFailed(&core.ExecutionError{}), dj: createDynamicJobSpec()}, want{p: handler.EPhaseDynamicRunning, phase: v1alpha1.DynamicNodePhaseFailing}}, - {"running", args{s: executors.NodeStatusRunning, dj: createDynamicJobSpec()}, want{p: handler.EPhaseDynamicRunning, phase: v1alpha1.DynamicNodePhaseExecuting}}, - {"running-valid-err", args{s: executors.NodeStatusRunning, dj: createDynamicJobSpec(), validErr: &io.ExecutionError{}}, want{p: handler.EPhaseDynamicRunning, phase: v1alpha1.DynamicNodePhaseExecuting}}, - {"queued", args{s: executors.NodeStatusQueued, dj: createDynamicJobSpec()}, want{p: handler.EPhaseDynamicRunning, phase: v1alpha1.DynamicNodePhaseExecuting}}, + {"success", args{s: interfaces.NodeStatusSuccess, dj: createDynamicJobSpec()}, want{p: handler.EPhaseDynamicRunning, phase: v1alpha1.DynamicNodePhaseExecuting}}, + {"complete", args{s: interfaces.NodeStatusComplete, dj: createDynamicJobSpec(), generateOutputs: true, validCacheStatus: &validCachePopulatedStatus}, want{p: handler.EPhaseSuccess, phase: v1alpha1.DynamicNodePhaseExecuting, info: execInfoWithTaskNodeMeta}}, + {"complete-no-outputs", args{s: interfaces.NodeStatusComplete, dj: createDynamicJobSpec(), generateOutputs: false}, want{p: handler.EPhaseRetryableFailure, phase: v1alpha1.DynamicNodePhaseFailing}}, + {"complete-valid-error-retryable", args{s: interfaces.NodeStatusComplete, dj: createDynamicJobSpec(), validErr: &io.ExecutionError{IsRecoverable: true}, generateOutputs: true}, want{p: handler.EPhaseRetryableFailure, phase: v1alpha1.DynamicNodePhaseFailing, info: execInfoOutputOnly}}, + {"complete-valid-error", args{s: interfaces.NodeStatusComplete, dj: createDynamicJobSpec(), validErr: &io.ExecutionError{}, generateOutputs: true}, want{p: handler.EPhaseFailed, phase: v1alpha1.DynamicNodePhaseFailing, info: execInfoOutputOnly}}, + {"failed", args{s: 
interfaces.NodeStatusFailed(&core.ExecutionError{}), dj: createDynamicJobSpec()}, want{p: handler.EPhaseDynamicRunning, phase: v1alpha1.DynamicNodePhaseFailing}}, + {"running", args{s: interfaces.NodeStatusRunning, dj: createDynamicJobSpec()}, want{p: handler.EPhaseDynamicRunning, phase: v1alpha1.DynamicNodePhaseExecuting}}, + {"running-valid-err", args{s: interfaces.NodeStatusRunning, dj: createDynamicJobSpec(), validErr: &io.ExecutionError{}}, want{p: handler.EPhaseDynamicRunning, phase: v1alpha1.DynamicNodePhaseExecuting}}, + {"queued", args{s: interfaces.NodeStatusQueued, dj: createDynamicJobSpec()}, want{p: handler.EPhaseDynamicRunning, phase: v1alpha1.DynamicNodePhaseExecuting}}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -601,9 +608,9 @@ func Test_dynamicNodeHandler_Handle_SubTaskV1(t *testing.T) { } h.OnValidateOutputAndCacheAddMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(validCacheStatus, nil, nil) } - n := &executorMocks.Node{} + n := &nodeMocks.Node{} if tt.args.isErr { - n.OnRecursiveNodeHandlerMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(executors.NodeStatusUndefined, fmt.Errorf("error")) + n.OnRecursiveNodeHandlerMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(interfaces.NodeStatusUndefined, fmt.Errorf("error")) } else { n.OnRecursiveNodeHandlerMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tt.args.s, nil) } @@ -747,7 +754,7 @@ func Test_dynamicNodeHandler_Handle_SubTask(t *testing.T) { } type args struct { - s executors.NodeStatus + s interfaces.NodeStatus isErr bool dj *core.DynamicJobSpec validErr *io.ExecutionError @@ -764,15 +771,15 @@ func Test_dynamicNodeHandler_Handle_SubTask(t *testing.T) { want want }{ {"error", args{isErr: true, dj: createDynamicJobSpec()}, want{isErr: true}}, - {"success", args{s: executors.NodeStatusSuccess, dj: createDynamicJobSpec()}, want{p: handler.EPhaseDynamicRunning, phase: v1alpha1.DynamicNodePhaseExecuting}}, - {"complete", args{s: executors.NodeStatusComplete, dj: createDynamicJobSpec(), generateOutputs: true}, want{p: handler.EPhaseSuccess, phase: v1alpha1.DynamicNodePhaseExecuting}}, - {"complete-no-outputs", args{s: executors.NodeStatusComplete, dj: createDynamicJobSpec(), generateOutputs: false}, want{p: handler.EPhaseRetryableFailure, phase: v1alpha1.DynamicNodePhaseFailing}}, - {"complete-valid-error-retryable", args{s: executors.NodeStatusComplete, dj: createDynamicJobSpec(), validErr: &io.ExecutionError{IsRecoverable: true}, generateOutputs: true}, want{p: handler.EPhaseRetryableFailure, phase: v1alpha1.DynamicNodePhaseFailing}}, - {"complete-valid-error", args{s: executors.NodeStatusComplete, dj: createDynamicJobSpec(), validErr: &io.ExecutionError{}, generateOutputs: true}, want{p: handler.EPhaseFailed, phase: v1alpha1.DynamicNodePhaseFailing}}, - {"failed", args{s: executors.NodeStatusFailed(&core.ExecutionError{}), dj: createDynamicJobSpec()}, want{p: handler.EPhaseDynamicRunning, phase: v1alpha1.DynamicNodePhaseFailing}}, - {"running", args{s: executors.NodeStatusRunning, dj: createDynamicJobSpec()}, want{p: handler.EPhaseDynamicRunning, phase: v1alpha1.DynamicNodePhaseExecuting}}, - {"running-valid-err", args{s: executors.NodeStatusRunning, dj: createDynamicJobSpec(), validErr: &io.ExecutionError{}}, 
want{p: handler.EPhaseDynamicRunning, phase: v1alpha1.DynamicNodePhaseExecuting}}, - {"queued", args{s: executors.NodeStatusQueued, dj: createDynamicJobSpec()}, want{p: handler.EPhaseDynamicRunning, phase: v1alpha1.DynamicNodePhaseExecuting}}, + {"success", args{s: interfaces.NodeStatusSuccess, dj: createDynamicJobSpec()}, want{p: handler.EPhaseDynamicRunning, phase: v1alpha1.DynamicNodePhaseExecuting}}, + {"complete", args{s: interfaces.NodeStatusComplete, dj: createDynamicJobSpec(), generateOutputs: true}, want{p: handler.EPhaseSuccess, phase: v1alpha1.DynamicNodePhaseExecuting}}, + {"complete-no-outputs", args{s: interfaces.NodeStatusComplete, dj: createDynamicJobSpec(), generateOutputs: false}, want{p: handler.EPhaseRetryableFailure, phase: v1alpha1.DynamicNodePhaseFailing}}, + {"complete-valid-error-retryable", args{s: interfaces.NodeStatusComplete, dj: createDynamicJobSpec(), validErr: &io.ExecutionError{IsRecoverable: true}, generateOutputs: true}, want{p: handler.EPhaseRetryableFailure, phase: v1alpha1.DynamicNodePhaseFailing}}, + {"complete-valid-error", args{s: interfaces.NodeStatusComplete, dj: createDynamicJobSpec(), validErr: &io.ExecutionError{}, generateOutputs: true}, want{p: handler.EPhaseFailed, phase: v1alpha1.DynamicNodePhaseFailing}}, + {"failed", args{s: interfaces.NodeStatusFailed(&core.ExecutionError{}), dj: createDynamicJobSpec()}, want{p: handler.EPhaseDynamicRunning, phase: v1alpha1.DynamicNodePhaseFailing}}, + {"running", args{s: interfaces.NodeStatusRunning, dj: createDynamicJobSpec()}, want{p: handler.EPhaseDynamicRunning, phase: v1alpha1.DynamicNodePhaseExecuting}}, + {"running-valid-err", args{s: interfaces.NodeStatusRunning, dj: createDynamicJobSpec(), validErr: &io.ExecutionError{}}, want{p: handler.EPhaseDynamicRunning, phase: v1alpha1.DynamicNodePhaseExecuting}}, + {"queued", args{s: interfaces.NodeStatusQueued, dj: createDynamicJobSpec()}, want{p: handler.EPhaseDynamicRunning, phase: v1alpha1.DynamicNodePhaseExecuting}}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -792,9 +799,9 @@ func Test_dynamicNodeHandler_Handle_SubTask(t *testing.T) { } else { h.OnValidateOutputAndCacheAddMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(catalog.NewStatus(core.CatalogCacheStatus_CACHE_HIT, &core.CatalogMetadata{ArtifactTag: &core.CatalogArtifactTag{Name: "name", ArtifactId: "id"}}), nil, nil) } - n := &executorMocks.Node{} + n := &nodeMocks.Node{} if tt.args.isErr { - n.OnRecursiveNodeHandlerMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(executors.NodeStatusUndefined, fmt.Errorf("error")) + n.OnRecursiveNodeHandlerMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(interfaces.NodeStatusUndefined, fmt.Errorf("error")) } else { n.OnRecursiveNodeHandlerMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tt.args.s, nil) } @@ -876,7 +883,7 @@ func TestDynamicNodeTaskNodeHandler_Finalize(t *testing.T) { mockLPLauncher := &lpMocks.Reader{} h := &mocks.TaskNodeHandler{} h.OnFinalize(ctx, nCtx).Return(nil) - n := &executorMocks.Node{} + n := &nodeMocks.Node{} d := New(h, n, mockLPLauncher, eventConfig, promutils.NewTestScope()) assert.NoError(t, d.Finalize(ctx, nCtx)) assert.NotZero(t, len(h.ExpectedCalls)) @@ -1007,7 +1014,7 @@ func TestDynamicNodeTaskNodeHandler_Finalize(t *testing.T) { mockLPLauncher := 
&lpMocks.Reader{} h := &mocks.TaskNodeHandler{} h.OnFinalize(ctx, nCtx).Return(nil) - n := &executorMocks.Node{} + n := &nodeMocks.Node{} n.OnFinalizeHandlerMatch(ctx, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil) d := New(h, n, mockLPLauncher, eventConfig, promutils.NewTestScope()) assert.NoError(t, d.Finalize(ctx, nCtx)) @@ -1028,7 +1035,7 @@ func TestDynamicNodeTaskNodeHandler_Finalize(t *testing.T) { mockLPLauncher := &lpMocks.Reader{} h := &mocks.TaskNodeHandler{} h.OnFinalize(ctx, nCtx).Return(fmt.Errorf("err")) - n := &executorMocks.Node{} + n := &nodeMocks.Node{} n.OnFinalizeHandlerMatch(ctx, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil) d := New(h, n, mockLPLauncher, eventConfig, promutils.NewTestScope()) assert.Error(t, d.Finalize(ctx, nCtx)) @@ -1049,7 +1056,7 @@ func TestDynamicNodeTaskNodeHandler_Finalize(t *testing.T) { mockLPLauncher := &lpMocks.Reader{} h := &mocks.TaskNodeHandler{} h.OnFinalize(ctx, nCtx).Return(nil) - n := &executorMocks.Node{} + n := &nodeMocks.Node{} n.OnFinalizeHandlerMatch(ctx, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(fmt.Errorf("err")) d := New(h, n, mockLPLauncher, eventConfig, promutils.NewTestScope()) assert.Error(t, d.Finalize(ctx, nCtx)) diff --git a/flytepropeller/pkg/controller/nodes/dynamic/mocks/task_node_handler.go b/flytepropeller/pkg/controller/nodes/dynamic/mocks/task_node_handler.go index 49936c11d3..e8d8cc6d7c 100644 --- a/flytepropeller/pkg/controller/nodes/dynamic/mocks/task_node_handler.go +++ b/flytepropeller/pkg/controller/nodes/dynamic/mocks/task_node_handler.go @@ -9,6 +9,8 @@ import ( handler "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" + interfaces "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" + io "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io" ioutils "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/ioutils" @@ -31,7 +33,7 @@ func (_m TaskNodeHandler_Abort) Return(_a0 error) *TaskNodeHandler_Abort { return &TaskNodeHandler_Abort{Call: _m.Call.Return(_a0)} } -func (_m *TaskNodeHandler) OnAbort(ctx context.Context, executionContext handler.NodeExecutionContext, reason string) *TaskNodeHandler_Abort { +func (_m *TaskNodeHandler) OnAbort(ctx context.Context, executionContext interfaces.NodeExecutionContext, reason string) *TaskNodeHandler_Abort { c_call := _m.On("Abort", ctx, executionContext, reason) return &TaskNodeHandler_Abort{Call: c_call} } @@ -42,11 +44,11 @@ func (_m *TaskNodeHandler) OnAbortMatch(matchers ...interface{}) *TaskNodeHandle } // Abort provides a mock function with given fields: ctx, executionContext, reason -func (_m *TaskNodeHandler) Abort(ctx context.Context, executionContext handler.NodeExecutionContext, reason string) error { +func (_m *TaskNodeHandler) Abort(ctx context.Context, executionContext interfaces.NodeExecutionContext, reason string) error { ret := _m.Called(ctx, executionContext, reason) var r0 error - if rf, ok := ret.Get(0).(func(context.Context, handler.NodeExecutionContext, string) error); ok { + if rf, ok := ret.Get(0).(func(context.Context, interfaces.NodeExecutionContext, string) error); ok { r0 = rf(ctx, executionContext, reason) } else { r0 = ret.Error(0) @@ -63,7 +65,7 @@ func (_m TaskNodeHandler_Finalize) Return(_a0 error) *TaskNodeHandler_Finalize { return &TaskNodeHandler_Finalize{Call: _m.Call.Return(_a0)} } -func (_m *TaskNodeHandler) OnFinalize(ctx context.Context, executionContext handler.NodeExecutionContext) 
*TaskNodeHandler_Finalize { +func (_m *TaskNodeHandler) OnFinalize(ctx context.Context, executionContext interfaces.NodeExecutionContext) *TaskNodeHandler_Finalize { c_call := _m.On("Finalize", ctx, executionContext) return &TaskNodeHandler_Finalize{Call: c_call} } @@ -74,11 +76,11 @@ func (_m *TaskNodeHandler) OnFinalizeMatch(matchers ...interface{}) *TaskNodeHan } // Finalize provides a mock function with given fields: ctx, executionContext -func (_m *TaskNodeHandler) Finalize(ctx context.Context, executionContext handler.NodeExecutionContext) error { +func (_m *TaskNodeHandler) Finalize(ctx context.Context, executionContext interfaces.NodeExecutionContext) error { ret := _m.Called(ctx, executionContext) var r0 error - if rf, ok := ret.Get(0).(func(context.Context, handler.NodeExecutionContext) error); ok { + if rf, ok := ret.Get(0).(func(context.Context, interfaces.NodeExecutionContext) error); ok { r0 = rf(ctx, executionContext) } else { r0 = ret.Error(0) @@ -127,7 +129,7 @@ func (_m TaskNodeHandler_Handle) Return(_a0 handler.Transition, _a1 error) *Task return &TaskNodeHandler_Handle{Call: _m.Call.Return(_a0, _a1)} } -func (_m *TaskNodeHandler) OnHandle(ctx context.Context, executionContext handler.NodeExecutionContext) *TaskNodeHandler_Handle { +func (_m *TaskNodeHandler) OnHandle(ctx context.Context, executionContext interfaces.NodeExecutionContext) *TaskNodeHandler_Handle { c_call := _m.On("Handle", ctx, executionContext) return &TaskNodeHandler_Handle{Call: c_call} } @@ -138,18 +140,18 @@ func (_m *TaskNodeHandler) OnHandleMatch(matchers ...interface{}) *TaskNodeHandl } // Handle provides a mock function with given fields: ctx, executionContext -func (_m *TaskNodeHandler) Handle(ctx context.Context, executionContext handler.NodeExecutionContext) (handler.Transition, error) { +func (_m *TaskNodeHandler) Handle(ctx context.Context, executionContext interfaces.NodeExecutionContext) (handler.Transition, error) { ret := _m.Called(ctx, executionContext) var r0 handler.Transition - if rf, ok := ret.Get(0).(func(context.Context, handler.NodeExecutionContext) handler.Transition); ok { + if rf, ok := ret.Get(0).(func(context.Context, interfaces.NodeExecutionContext) handler.Transition); ok { r0 = rf(ctx, executionContext) } else { r0 = ret.Get(0).(handler.Transition) } var r1 error - if rf, ok := ret.Get(1).(func(context.Context, handler.NodeExecutionContext) error); ok { + if rf, ok := ret.Get(1).(func(context.Context, interfaces.NodeExecutionContext) error); ok { r1 = rf(ctx, executionContext) } else { r1 = ret.Error(1) @@ -166,7 +168,7 @@ func (_m TaskNodeHandler_Setup) Return(_a0 error) *TaskNodeHandler_Setup { return &TaskNodeHandler_Setup{Call: _m.Call.Return(_a0)} } -func (_m *TaskNodeHandler) OnSetup(ctx context.Context, setupContext handler.SetupContext) *TaskNodeHandler_Setup { +func (_m *TaskNodeHandler) OnSetup(ctx context.Context, setupContext interfaces.SetupContext) *TaskNodeHandler_Setup { c_call := _m.On("Setup", ctx, setupContext) return &TaskNodeHandler_Setup{Call: c_call} } @@ -177,11 +179,11 @@ func (_m *TaskNodeHandler) OnSetupMatch(matchers ...interface{}) *TaskNodeHandle } // Setup provides a mock function with given fields: ctx, setupContext -func (_m *TaskNodeHandler) Setup(ctx context.Context, setupContext handler.SetupContext) error { +func (_m *TaskNodeHandler) Setup(ctx context.Context, setupContext interfaces.SetupContext) error { ret := _m.Called(ctx, setupContext) var r0 error - if rf, ok := ret.Get(0).(func(context.Context, handler.SetupContext) error); ok 
{ + if rf, ok := ret.Get(0).(func(context.Context, interfaces.SetupContext) error); ok { r0 = rf(ctx, setupContext) } else { r0 = ret.Error(0) diff --git a/flytepropeller/pkg/controller/nodes/dynamic/utils.go b/flytepropeller/pkg/controller/nodes/dynamic/utils.go index d08845856e..6a38bafe4f 100644 --- a/flytepropeller/pkg/controller/nodes/dynamic/utils.go +++ b/flytepropeller/pkg/controller/nodes/dynamic/utils.go @@ -9,11 +9,11 @@ import ( "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/encoding" "github.com/flyteorg/flytepropeller/pkg/compiler" - "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" ) // Constructs the expected interface of a given node. -func underlyingInterface(ctx context.Context, taskReader handler.TaskReader) (*core.TypedInterface, error) { +func underlyingInterface(ctx context.Context, taskReader interfaces.TaskReader) (*core.TypedInterface, error) { t, err := taskReader.Read(ctx) iface := &core.TypedInterface{} if err != nil { diff --git a/flytepropeller/pkg/controller/nodes/dynamic/utils_test.go b/flytepropeller/pkg/controller/nodes/dynamic/utils_test.go index 6afc3cb807..291d175ac3 100644 --- a/flytepropeller/pkg/controller/nodes/dynamic/utils_test.go +++ b/flytepropeller/pkg/controller/nodes/dynamic/utils_test.go @@ -4,7 +4,7 @@ import ( "context" "testing" - mocks2 "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler/mocks" + mocks2 "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces/mocks" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" diff --git a/flytepropeller/pkg/controller/nodes/end/handler.go b/flytepropeller/pkg/controller/nodes/end/handler.go index 4f56ee840b..d77a7ab508 100644 --- a/flytepropeller/pkg/controller/nodes/end/handler.go +++ b/flytepropeller/pkg/controller/nodes/end/handler.go @@ -9,6 +9,7 @@ import ( "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/errors" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" ) type endHandler struct { @@ -18,11 +19,11 @@ func (e endHandler) FinalizeRequired() bool { return false } -func (e endHandler) Setup(ctx context.Context, setupContext handler.SetupContext) error { +func (e endHandler) Setup(ctx context.Context, setupContext interfaces.SetupContext) error { return nil } -func (e endHandler) Handle(ctx context.Context, executionContext handler.NodeExecutionContext) (handler.Transition, error) { +func (e endHandler) Handle(ctx context.Context, executionContext interfaces.NodeExecutionContext) (handler.Transition, error) { inputs, err := executionContext.InputReader().Get(ctx) if err != nil { return handler.UnknownTransition, err @@ -41,14 +42,14 @@ func (e endHandler) Handle(ctx context.Context, executionContext handler.NodeExe return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoSuccess(nil)), nil } -func (e endHandler) Abort(_ context.Context, _ handler.NodeExecutionContext, _ string) error { +func (e endHandler) Abort(_ context.Context, _ interfaces.NodeExecutionContext, _ string) error { return nil } -func (e endHandler) Finalize(_ context.Context, _ handler.NodeExecutionContext) error { +func (e endHandler) Finalize(_ context.Context, _ interfaces.NodeExecutionContext) error { return nil } -func New() handler.Node { +func New() interfaces.NodeHandler { return &endHandler{} } diff --git 
a/flytepropeller/pkg/controller/nodes/end/handler_test.go b/flytepropeller/pkg/controller/nodes/end/handler_test.go index fd18841e32..d1d500d149 100644 --- a/flytepropeller/pkg/controller/nodes/end/handler_test.go +++ b/flytepropeller/pkg/controller/nodes/end/handler_test.go @@ -21,7 +21,7 @@ import ( mocks3 "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/errors" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" - "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler/mocks" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces/mocks" flyteassert "github.com/flyteorg/flytepropeller/pkg/utils/assert" ) diff --git a/flytepropeller/pkg/controller/nodes/errors/codes.go b/flytepropeller/pkg/controller/nodes/errors/codes.go index df2be215c3..30ded68b71 100644 --- a/flytepropeller/pkg/controller/nodes/errors/codes.go +++ b/flytepropeller/pkg/controller/nodes/errors/codes.go @@ -25,4 +25,5 @@ const ( StorageError ErrorCode = "StorageError" EventRecordingFailed ErrorCode = "EventRecordingFailed" CatalogCallFailed ErrorCode = "CatalogCallFailed" + InvalidArrayLength ErrorCode = "InvalidArrayLength" ) diff --git a/flytepropeller/pkg/controller/nodes/executor.go b/flytepropeller/pkg/controller/nodes/executor.go index c447b779c6..23fbafc466 100644 --- a/flytepropeller/pkg/controller/nodes/executor.go +++ b/flytepropeller/pkg/controller/nodes/executor.go @@ -34,6 +34,7 @@ import ( "github.com/flyteorg/flytepropeller/pkg/controller/nodes/common" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/errors" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/recovery" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/subworkflow/launchplan" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task" @@ -84,1131 +85,1165 @@ type nodeMetrics struct { NodeInputGatherLatency labeled.StopWatch } -// Implements the executors.Node interface -type nodeExecutor struct { - nodeHandlerFactory HandlerFactory - enqueueWorkflow v1alpha1.EnqueueWorkflow - store *storage.DataStore - nodeRecorder events.NodeEventRecorder - taskRecorder events.TaskEventRecorder - metrics *nodeMetrics - maxDatasetSizeBytes int64 - outputResolver OutputResolver - defaultExecutionDeadline time.Duration - defaultActiveDeadline time.Duration - maxNodeRetriesForSystemFailures uint32 - interruptibleFailureThreshold uint32 - defaultDataSandbox storage.DataReference - shardSelector ioutils.ShardSelector - recoveryClient recovery.Client - eventConfig *config.EventConfig - clusterID string +// recursiveNodeExecutor implements the executors.Node interface and is the starting point for +// executing any node in the workflow.
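The hunk above splits the old monolithic executor into two layers: a recursive DAG walker (`recursiveNodeExecutor`) and a per-node executor hidden behind the new `interfaces.NodeExecutor`. A minimal, dependency-free sketch of that composition follows; `Walker`, `NodeExec`, and the method names are illustrative stand-ins, not the real Flyte types.

```go
// Dependency-free sketch of the two-layer split introduced above: a recursive
// walker that owns DAG traversal, delegating per-node work to a pluggable
// single-node executor. All names here are illustrative stand-ins.
package sketch

import "context"

// NodeExec mirrors the role of interfaces.NodeExecutor: it knows how to
// handle exactly one node and nothing about the surrounding DAG.
type NodeExec interface {
	HandleNode(ctx context.Context, nodeID string) (done bool, err error)
}

// Walker mirrors recursiveNodeExecutor: it owns traversal and delegates.
type Walker struct {
	exec     NodeExec
	children map[string][]string // adjacency list standing in for DAGStructure
}

// Walk visits a node; if the node is done it recurses into its children,
// much like RecursiveNodeHandler falling through to handleDownstream.
func (w *Walker) Walk(ctx context.Context, nodeID string) error {
	done, err := w.exec.HandleNode(ctx, nodeID)
	if err != nil || !done {
		return err
	}
	for _, child := range w.children[nodeID] {
		if err := w.Walk(ctx, child); err != nil {
			return err
		}
	}
	return nil
}
```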
+type recursiveNodeExecutor struct { + nodeExecutor interfaces.NodeExecutor + nCtxBuilder interfaces.NodeExecutionContextBuilder + enqueueWorkflow v1alpha1.EnqueueWorkflow + nodeHandlerFactory interfaces.HandlerFactory + store *storage.DataStore + metrics *nodeMetrics } -func (c *nodeExecutor) RecordTransitionLatency(ctx context.Context, dag executors.DAGStructure, nl executors.NodeLookup, node v1alpha1.ExecutableNode, nodeStatus v1alpha1.ExecutableNodeStatus) { - if nodeStatus.GetPhase() == v1alpha1.NodePhaseNotYetStarted || nodeStatus.GetPhase() == v1alpha1.NodePhaseQueued { - // Log transition latency (The most recently finished parent node endAt time to this node's queuedAt time -now-) - t, err := GetParentNodeMaxEndTime(ctx, dag, nl, node) - if err != nil { - logger.Warnf(ctx, "Failed to record transition latency for node. Error: %s", err.Error()) - return - } - if !t.IsZero() { - c.metrics.TransitionLatency.Observe(ctx, t.Time, time.Now()) - } - } else if nodeStatus.GetPhase() == v1alpha1.NodePhaseRetryableFailure && nodeStatus.GetLastUpdatedAt() != nil { - c.metrics.TransitionLatency.Observe(ctx, nodeStatus.GetLastUpdatedAt().Time, time.Now()) +func (c *recursiveNodeExecutor) SetInputsForStartNode(ctx context.Context, execContext executors.ExecutionContext, dag executors.DAGStructureWithStartNode, nl executors.NodeLookup, inputs *core.LiteralMap) (interfaces.NodeStatus, error) { + startNode := dag.StartNode() + ctx = contextutils.WithNodeID(ctx, startNode.GetID()) + if inputs == nil { + logger.Infof(ctx, "No inputs for the workflow. Skipping storing inputs") + return interfaces.NodeStatusComplete, nil } -} -func (c *nodeExecutor) IdempotentRecordEvent(ctx context.Context, nodeEvent *event.NodeExecutionEvent) error { - if nodeEvent == nil { - return fmt.Errorf("event recording attempt of Nil Node execution event") + // StartNode is special. It does not have any processing step. It just takes the workflow (or subworkflow) inputs and converts to its own outputs + nodeStatus := nl.GetNodeExecutionStatus(ctx, startNode.GetID()) + + if len(nodeStatus.GetDataDir()) == 0 { + return interfaces.NodeStatusUndefined, errors.Errorf(errors.IllegalStateError, startNode.GetID(), "no data-dir set, cannot store inputs") } + outputFile := v1alpha1.GetOutputsFile(nodeStatus.GetOutputDir()) - if nodeEvent.Id == nil { - return fmt.Errorf("event recording attempt of with nil node Event ID") + so := storage.Options{} + if err := c.store.WriteProtobuf(ctx, outputFile, so, inputs); err != nil { + logger.Errorf(ctx, "Failed to write protobuf (metadata). Error [%v]", err) + return interfaces.NodeStatusUndefined, errors.Wrapf(errors.CausedByError, startNode.GetID(), err, "Failed to store workflow inputs (as start node)") } - logger.Infof(ctx, "Recording NodeEvent [%s] phase[%s]", nodeEvent.GetId().String(), nodeEvent.Phase.String()) - err := c.nodeRecorder.RecordNodeEvent(ctx, nodeEvent, c.eventConfig) - if err != nil { - if nodeEvent.GetId().NodeId == v1alpha1.EndNodeID { - return nil - } + return interfaces.NodeStatusComplete, nil +} - if eventsErr.IsAlreadyExists(err) { - logger.Infof(ctx, "Node event phase: %s, nodeId %s already exist", - nodeEvent.Phase.String(), nodeEvent.GetId().NodeId) - return nil - } else if eventsErr.IsEventAlreadyInTerminalStateError(err) { - if IsTerminalNodePhase(nodeEvent.Phase) { - // Event was trying to record a different terminal phase for an already terminal event. ignoring. - logger.Infof(ctx, "Node event phase: %s, nodeId %s already in terminal phase. 
err: %s", - nodeEvent.Phase.String(), nodeEvent.GetId().NodeId, err.Error()) - return nil - } - logger.Warningf(ctx, "Failed to record nodeEvent, error [%s]", err.Error()) - return errors.Wrapf(errors.IllegalStateError, nodeEvent.Id.NodeId, err, "phase mis-match mismatch between propeller and control plane; Trying to record Node p: %s", nodeEvent.Phase) +func canHandleNode(phase v1alpha1.NodePhase) bool { + return phase == v1alpha1.NodePhaseNotYetStarted || + phase == v1alpha1.NodePhaseQueued || + phase == v1alpha1.NodePhaseRunning || + phase == v1alpha1.NodePhaseFailing || + phase == v1alpha1.NodePhaseTimingOut || + phase == v1alpha1.NodePhaseRetryableFailure || + phase == v1alpha1.NodePhaseSucceeding || + phase == v1alpha1.NodePhaseDynamicRunning +} + +// IsMaxParallelismAchieved checks if we have already achieved max parallelism. It returns true, if the desired max parallelism +// value is achieved, false otherwise +// MaxParallelism is defined as the maximum number of TaskNodes and LaunchPlans (together) that can be executed concurrently +// by one workflow execution. A setting of `0` indicates that it is disabled. +func IsMaxParallelismAchieved(ctx context.Context, currentNode v1alpha1.ExecutableNode, currentPhase v1alpha1.NodePhase, + execContext executors.ExecutionContext) bool { + maxParallelism := execContext.GetExecutionConfig().MaxParallelism + if maxParallelism == 0 { + logger.Debugf(ctx, "Parallelism control disabled") + return false + } + + if currentNode.GetKind() == v1alpha1.NodeKindTask || + (currentNode.GetKind() == v1alpha1.NodeKindWorkflow && currentNode.GetWorkflowNode() != nil && currentNode.GetWorkflowNode().GetLaunchPlanRefID() != nil) { + // If we are queued, let us see if we can proceed within the node parallelism bounds + if execContext.CurrentParallelism() >= maxParallelism { + logger.Infof(ctx, "Maximum Parallelism for task/launch-plan nodes achieved [%d] >= Max [%d], Round will be short-circuited.", execContext.CurrentParallelism(), maxParallelism) + return true } + // We know that Propeller goes through each workflow in a single thread, thus every node is really processed + // sequentially. So, we can continue - now that we know we are under the parallelism limits and increment the + // parallelism if the node, enters a running state + logger.Debugf(ctx, "Parallelism criteria not met, Current [%d], Max [%d]", execContext.CurrentParallelism(), maxParallelism) + } else { + logger.Debugf(ctx, "NodeKind: %s in status [%s]. Parallelism control is not applicable. Current Parallelism [%d]", + currentNode.GetKind().String(), currentPhase.String(), execContext.CurrentParallelism()) } - return err + return false } -func (c *nodeExecutor) recoverInputs(ctx context.Context, nCtx handler.NodeExecutionContext, - recovered *admin.NodeExecution, recoveredData *admin.NodeExecutionGetDataResponse) (*core.LiteralMap, error) { +// RecursiveNodeHandler This is the entrypoint of executing a node in a workflow. A workflow consists of nodes, that are +// nested within other nodes. The system follows an actor model, where the parent nodes control the execution of nested nodes +// The recursive node-handler uses a modified depth-first type of algorithm to execute non-blocked nodes. 
+func (c *recursiveNodeExecutor) RecursiveNodeHandler(ctx context.Context, execContext executors.ExecutionContext, + dag executors.DAGStructure, nl executors.NodeLookup, currentNode v1alpha1.ExecutableNode) ( + interfaces.NodeStatus, error) { - nodeInputs := recoveredData.FullInputs - if nodeInputs != nil { - if err := c.store.WriteProtobuf(ctx, nCtx.InputReader().GetInputPath(), storage.Options{}, nodeInputs); err != nil { - c.metrics.InputsWriteFailure.Inc(ctx) - logger.Errorf(ctx, "Failed to move recovered inputs for Node. Error [%v]. InputsFile [%s]", err, nCtx.InputReader().GetInputPath()) - return nil, errors.Wrapf(errors.StorageError, nCtx.NodeID(), err, "Failed to store inputs for Node. InputsFile [%s]", nCtx.InputReader().GetInputPath()) - } - } else if len(recovered.InputUri) > 0 { - // If the inputs are too large they won't be returned inline in the RecoverData call. We must fetch them before copying them. - nodeInputs = &core.LiteralMap{} - if recoveredData.FullInputs == nil { - if err := c.store.ReadProtobuf(ctx, storage.DataReference(recovered.InputUri), nodeInputs); err != nil { - return nil, errors.Wrapf(errors.InputsNotFoundError, nCtx.NodeID(), err, "failed to read data from dataDir [%v].", recovered.InputUri) - } - } + currentNodeCtx := contextutils.WithNodeID(ctx, currentNode.GetID()) + nodeStatus := nl.GetNodeExecutionStatus(ctx, currentNode.GetID()) + nodePhase := nodeStatus.GetPhase() - if err := c.store.WriteProtobuf(ctx, nCtx.InputReader().GetInputPath(), storage.Options{}, nodeInputs); err != nil { - c.metrics.InputsWriteFailure.Inc(ctx) - logger.Errorf(ctx, "Failed to move recovered inputs for Node. Error [%v]. InputsFile [%s]", err, nCtx.InputReader().GetInputPath()) - return nil, errors.Wrapf(errors.StorageError, nCtx.NodeID(), err, "Failed to store inputs for Node. InputsFile [%s]", nCtx.InputReader().GetInputPath()) - } - } + if canHandleNode(nodePhase) { + // TODO Follow-up Pull Request: + // 1. Rename this method to DAGTraversalHandleNode (accepts a DAGStructure along with the remaining arguments) + // 2. Create a new method called HandleNode (part of the interface) with the same args as the previous method, but no DAGStructure + // 3. Additionally, both methods will receive an inputs reader + // 4. The Downstream nodes handler will Resolve the Inputs + // 5. The method will delegate all other node handling to HandleNode. + // 6. Thus we can get rid of SetInputs for StartNode as well + logger.Debugf(currentNodeCtx, "Handling node Status [%v]", nodeStatus.GetPhase().String()) - return nodeInputs, nil -} + t := c.metrics.NodeExecutionTime.Start(ctx) + defer t.Stop() -func (c *nodeExecutor) attemptRecovery(ctx context.Context, nCtx handler.NodeExecutionContext) (handler.PhaseInfo, error) { - fullyQualifiedNodeID := nCtx.NodeExecutionMetadata().GetNodeExecutionID().NodeId - if nCtx.ExecutionContext().GetEventVersion() != v1alpha1.EventVersion0 { - // compute fully qualified node id (prefixed with parent id and retry attempt) to ensure uniqueness - var err error - fullyQualifiedNodeID, err = common.GenerateUniqueID(nCtx.ExecutionContext().GetParentInfo(), nCtx.NodeExecutionMetadata().GetNodeExecutionID().NodeId) - if err != nil { - return handler.PhaseInfoUndefined, err + // This is an optimization to avoid creating the nodeContext object in case the node has already been looked at.
+ // If the overhead was zero, we would just do the isDirtyCheck after the nodeContext is created + if nodeStatus.IsDirty() { + return interfaces.NodeStatusRunning, nil } - } - recovered, err := c.recoveryClient.RecoverNodeExecution(ctx, nCtx.ExecutionContext().GetExecutionConfig().RecoveryExecution.WorkflowExecutionIdentifier, fullyQualifiedNodeID) - if err != nil { - st, ok := status.FromError(err) - if !ok || st.Code() != codes.NotFound { - logger.Warnf(ctx, "Failed to recover node [%+v] with err [%+v]", nCtx.NodeExecutionMetadata().GetNodeExecutionID(), err) - } - // The node is not recoverable when it's not found in the parent execution - return handler.PhaseInfoUndefined, nil - } - if recovered == nil { - logger.Warnf(ctx, "call to recover node [%+v] returned no error but also no node", nCtx.NodeExecutionMetadata().GetNodeExecutionID()) - return handler.PhaseInfoUndefined, nil - } - if recovered.Closure == nil { - logger.Warnf(ctx, "Fetched node execution [%+v] data but was missing closure. Will not attempt to recover", - nCtx.NodeExecutionMetadata().GetNodeExecutionID()) - return handler.PhaseInfoUndefined, nil - } - // A recoverable node execution should always be in a terminal phase - switch recovered.Closure.Phase { - case core.NodeExecution_SKIPPED: - return handler.PhaseInfoUndefined, nil - case core.NodeExecution_SUCCEEDED: - fallthrough - case core.NodeExecution_RECOVERED: - logger.Debugf(ctx, "Node [%+v] can be recovered. Proceeding to copy inputs and outputs", nCtx.NodeExecutionMetadata().GetNodeExecutionID()) - default: - // The node execution may be partially recoverable through intra task checkpointing. Save the checkpoint - // uri in the task node state to pass to the task handler later on. - if metadata, ok := recovered.Closure.TargetMetadata.(*admin.NodeExecutionClosure_TaskNodeMetadata); ok { - state := nCtx.NodeStateReader().GetTaskNodeState() - state.PreviousNodeExecutionCheckpointURI = storage.DataReference(metadata.TaskNodeMetadata.CheckpointUri) - err = nCtx.NodeStateWriter().PutTaskNodeState(state) - if err != nil { - logger.Warn(ctx, "failed to save recovered checkpoint uri for [%+v]: [%+v]", - nCtx.NodeExecutionMetadata().GetNodeExecutionID(), err) - } + if IsMaxParallelismAchieved(ctx, currentNode, nodePhase, execContext) { + return interfaces.NodeStatusRunning, nil } - // if this node is a dynamic task we attempt to recover the compiled workflow from instances where the parent - // task succeeded but the dynamic task did not complete. this is important to ensure correctness since node ids - // within the compiled closure may not be generated deterministically. - if recovered.Metadata != nil && recovered.Metadata.IsDynamic && len(recovered.Closure.DynamicJobSpecUri) > 0 { - // recover node inputs - recoveredData, err := c.recoveryClient.RecoverNodeExecutionData(ctx, - nCtx.ExecutionContext().GetExecutionConfig().RecoveryExecution.WorkflowExecutionIdentifier, fullyQualifiedNodeID) - if err != nil || recoveredData == nil { - return handler.PhaseInfoUndefined, nil - } - - if _, err := c.recoverInputs(ctx, nCtx, recovered, recoveredData); err != nil { - return handler.PhaseInfoUndefined, err - } + nCtx, err := c.nCtxBuilder.BuildNodeExecutionContext(ctx, execContext, nl, currentNode.GetID()) + if err != nil { + // NodeExecution creation failure is a permanent fail / system error. + // Should a system failure always return an err? 
+ return interfaces.NodeStatusFailed(&core.ExecutionError{ + Code: "InternalError", + Message: err.Error(), + Kind: core.ExecutionError_SYSTEM, + }), nil + } - // copy previous DynamicJobSpec file - f, err := task.NewRemoteFutureFileReader(ctx, nCtx.NodeStatus().GetOutputDir(), nCtx.DataStore()) - if err != nil { - return handler.PhaseInfoUndefined, err - } + // Now depending on the node type decide + h, err := c.nodeHandlerFactory.GetHandler(nCtx.Node().GetKind()) + if err != nil { + return interfaces.NodeStatusUndefined, err + } - dynamicJobSpecReference := storage.DataReference(recovered.Closure.DynamicJobSpecUri) - if err := nCtx.DataStore().CopyRaw(ctx, dynamicJobSpecReference, f.GetLoc(), storage.Options{}); err != nil { - return handler.PhaseInfoUndefined, errors.Wrapf(errors.StorageError, nCtx.NodeID(), err, - "failed to store dynamic job spec for node. source file [%s] destination file [%s]", dynamicJobSpecReference, f.GetLoc()) - } + return c.nodeExecutor.HandleNode(currentNodeCtx, dag, nCtx, h) - // transition node phase to 'Running' and dynamic task phase to 'DynamicNodePhaseParentFinalized' - state := nCtx.NodeStateReader().GetDynamicNodeState() - state.Phase = v1alpha1.DynamicNodePhaseParentFinalized - if err := nCtx.NodeStateWriter().PutDynamicNodeState(state); err != nil { - return handler.PhaseInfoUndefined, errors.Wrapf(errors.UnknownError, nCtx.NodeID(), err, "failed to store dynamic node state") - } + // TODO we can optimize skip state handling by iterating down the graph and marking all as skipped + // Currently we treat either Skip or Success the same way. In this approach only one node will be skipped + // at a time. As we iterate down, further nodes will be skipped + } else if nodePhase == v1alpha1.NodePhaseSucceeded || nodePhase == v1alpha1.NodePhaseSkipped || nodePhase == v1alpha1.NodePhaseRecovered { + logger.Debugf(currentNodeCtx, "Node has [%v], traversing downstream.", nodePhase) + return c.handleDownstream(ctx, execContext, dag, nl, currentNode) + } else if nodePhase == v1alpha1.NodePhaseFailed { + logger.Debugf(currentNodeCtx, "Node has failed, traversing downstream.") + _, err := c.handleDownstream(ctx, execContext, dag, nl, currentNode) + if err != nil { + return interfaces.NodeStatusUndefined, err + } - return handler.PhaseInfoRunning(&handler.ExecutionInfo{}), nil + return interfaces.NodeStatusFailed(nodeStatus.GetExecutionError()), nil + } else if nodePhase == v1alpha1.NodePhaseTimedOut { + logger.Debugf(currentNodeCtx, "Node has timed out, traversing downstream.") + _, err := c.handleDownstream(ctx, execContext, dag, nl, currentNode) + if err != nil { + return interfaces.NodeStatusUndefined, err } - logger.Debugf(ctx, "Node [%+v] phase [%v] is not recoverable", nCtx.NodeExecutionMetadata().GetNodeExecutionID(), recovered.Closure.Phase) - return handler.PhaseInfoUndefined, nil + return interfaces.NodeStatusTimedOut, nil } - recoveredData, err := c.recoveryClient.RecoverNodeExecutionData(ctx, nCtx.ExecutionContext().GetExecutionConfig().RecoveryExecution.WorkflowExecutionIdentifier, fullyQualifiedNodeID) - if err != nil { - st, ok := status.FromError(err) - if !ok || st.Code() != codes.NotFound { - logger.Warnf(ctx, "Failed to attemptRecovery node execution data for [%+v] although back-end indicated node was recoverable with err [%+v]", - nCtx.NodeExecutionMetadata().GetNodeExecutionID(), err) - } - return handler.PhaseInfoUndefined, nil - } - if recoveredData == nil { - logger.Warnf(ctx, "call to attemptRecovery node [%+v] data returned no error but 
also no data", nCtx.NodeExecutionMetadata().GetNodeExecutionID()) - return handler.PhaseInfoUndefined, nil - } + return interfaces.NodeStatusUndefined, errors.Errorf(errors.IllegalStateError, currentNode.GetID(), + "Should never reach here. Current Phase: %v", nodePhase) +} - // Copy inputs to this node's expected location - nodeInputs, err := c.recoverInputs(ctx, nCtx, recovered, recoveredData) +// The space search for the next node to execute is implemented like a DFS algorithm. handleDownstream visits all the nodes downstream from +// the currentNode. Visit a node is the RecursiveNodeHandler. A visit may be partial, complete or may result in a failure. +func (c *recursiveNodeExecutor) handleDownstream(ctx context.Context, execContext executors.ExecutionContext, dag executors.DAGStructure, nl executors.NodeLookup, currentNode v1alpha1.ExecutableNode) (interfaces.NodeStatus, error) { + logger.Debugf(ctx, "Handling downstream Nodes") + // This node is success. Handle all downstream nodes + downstreamNodes, err := dag.FromNode(currentNode.GetID()) if err != nil { - return handler.PhaseInfoUndefined, err - } - - // Similarly, copy outputs' reference - so := storage.Options{} - var outputs = &core.LiteralMap{} - if recoveredData.FullOutputs != nil { - outputs = recoveredData.FullOutputs - } else if recovered.Closure.GetOutputData() != nil { - outputs = recovered.Closure.GetOutputData() - } else if len(recovered.Closure.GetOutputUri()) > 0 { - if err := c.store.ReadProtobuf(ctx, storage.DataReference(recovered.Closure.GetOutputUri()), outputs); err != nil { - return handler.PhaseInfoUndefined, errors.Wrapf(errors.InputsNotFoundError, nCtx.NodeID(), err, "failed to read output data [%v].", recovered.Closure.GetOutputUri()) - } - } else { - logger.Debugf(ctx, "No outputs found for recovered node [%+v]", nCtx.NodeExecutionMetadata().GetNodeExecutionID()) + logger.Debugf(ctx, "Error when retrieving downstream nodes, [%s]", err) + return interfaces.NodeStatusFailed(&core.ExecutionError{ + Code: errors.BadSpecificationError, + Message: fmt.Sprintf("failed to retrieve downstream nodes for [%s]", currentNode.GetID()), + Kind: core.ExecutionError_SYSTEM, + }), nil } - - outputFile := v1alpha1.GetOutputsFile(nCtx.NodeStatus().GetOutputDir()) - oi := &handler.OutputInfo{ - OutputURI: outputFile, + if len(downstreamNodes) == 0 { + logger.Debugf(ctx, "No downstream nodes found. Complete.") + return interfaces.NodeStatusComplete, nil } + // If any downstream node is failed, fail, all + // Else if all are success then success + // Else if any one is running then Downstream is still running + allCompleted := true + partialNodeCompletion := false + onFailurePolicy := execContext.GetOnFailurePolicy() + stateOnComplete := interfaces.NodeStatusComplete + for _, downstreamNodeName := range downstreamNodes { + downstreamNode, ok := nl.GetNode(downstreamNodeName) + if !ok { + return interfaces.NodeStatusFailed(&core.ExecutionError{ + Code: errors.BadSpecificationError, + Message: fmt.Sprintf("failed to retrieve downstream node [%s] for [%s]", downstreamNodeName, currentNode.GetID()), + Kind: core.ExecutionError_SYSTEM, + }), nil + } - deckFile := storage.DataReference(recovered.Closure.GetDeckUri()) - if len(deckFile) > 0 { - metadata, err := nCtx.DataStore().Head(ctx, deckFile) + state, err := c.RecursiveNodeHandler(ctx, execContext, dag, nl, downstreamNode) if err != nil { - logger.Errorf(ctx, "Failed to check the existence of deck file. 
Error: %v", err) - return handler.PhaseInfoUndefined, errors.Wrapf(errors.CausedByError, nCtx.NodeID(), err, "Failed to check the existence of deck file.") + return interfaces.NodeStatusUndefined, err + } + + if state.HasFailed() || state.HasTimedOut() { + logger.Debugf(ctx, "Some downstream node has failed. Failed: [%v]. TimedOut: [%v]. Error: [%s]", state.HasFailed(), state.HasTimedOut(), state.Err) + if onFailurePolicy == v1alpha1.WorkflowOnFailurePolicy(core.WorkflowMetadata_FAIL_AFTER_EXECUTABLE_NODES_COMPLETE) { + // If the failure policy allows other nodes to continue running, do not exit the loop, + // Keep track of the last failed state in the loop since it'll be the one to return. + // TODO: If multiple nodes fail (which this mode allows), consolidate/summarize failure states in one. + stateOnComplete = state + } else { + return state, nil + } + } else if !state.IsComplete() { + // A Failed/Timedout node is implicitly considered "complete" this means none of the downstream nodes from + // that node will ever be allowed to run. + // This else block, therefore, deals with all other states. IsComplete will return true if and only if this + // node as well as all of its downstream nodes have finished executing with success statuses. Otherwise we + // mark this node's state as not completed to ensure we will visit it again later. + allCompleted = false } - if metadata.Exists() { - oi.DeckURI = &deckFile + if state.PartiallyComplete() { + // This implies that one of the downstream nodes has just succeeded and workflow is ready for propagation + // We do not propagate in current cycle to make it possible to store the state between transitions + partialNodeCompletion = true } } - if err := c.store.WriteProtobuf(ctx, outputFile, so, outputs); err != nil { - logger.Errorf(ctx, "Failed to write protobuf (metadata). 
Error [%v]", err) - return handler.PhaseInfoUndefined, errors.Wrapf(errors.CausedByError, nCtx.NodeID(), err, "Failed to store recovered node execution outputs") + if allCompleted { + logger.Debugf(ctx, "All downstream nodes completed") + return stateOnComplete, nil } - info := &handler.ExecutionInfo{ - Inputs: nodeInputs, - OutputInfo: oi, + if partialNodeCompletion { + return interfaces.NodeStatusSuccess, nil } - if recovered.Closure.GetTaskNodeMetadata() != nil { - taskNodeInfo := &handler.TaskNodeInfo{ - TaskNodeMetadata: &event.TaskNodeMetadata{ - CatalogKey: recovered.Closure.GetTaskNodeMetadata().CatalogKey, - CacheStatus: recovered.Closure.GetTaskNodeMetadata().CacheStatus, - }, + return interfaces.NodeStatusPending, nil +} + +func (c *recursiveNodeExecutor) FinalizeHandler(ctx context.Context, execContext executors.ExecutionContext, dag executors.DAGStructure, nl executors.NodeLookup, currentNode v1alpha1.ExecutableNode) error { + nodeStatus := nl.GetNodeExecutionStatus(ctx, currentNode.GetID()) + nodePhase := nodeStatus.GetPhase() + + if nodePhase == v1alpha1.NodePhaseNotYetStarted { + logger.Infof(ctx, "Node not yet started, will not finalize") + // Nothing to be aborted + return nil + } + + if canHandleNode(nodePhase) { + ctx = contextutils.WithNodeID(ctx, currentNode.GetID()) + + // Now depending on the node type decide + h, err := c.nodeHandlerFactory.GetHandler(currentNode.GetKind()) + if err != nil { + return err } - if recoveredData.DynamicWorkflow != nil { - taskNodeInfo.TaskNodeMetadata.DynamicWorkflow = &event.DynamicWorkflowNodeMetadata{ - Id: recoveredData.DynamicWorkflow.Id, - CompiledWorkflow: recoveredData.DynamicWorkflow.CompiledWorkflow, + + nCtx, err := c.nCtxBuilder.BuildNodeExecutionContext(ctx, execContext, nl, currentNode.GetID()) + if err != nil { + return err + } + // Abort this node + err = c.nodeExecutor.Finalize(ctx, h, nCtx) + if err != nil { + return err + } + } else { + // Abort downstream nodes + downstreamNodes, err := dag.FromNode(currentNode.GetID()) + if err != nil { + logger.Debugf(ctx, "Error when retrieving downstream nodes. Error [%v]", err) + return nil + } + + errs := make([]error, 0, len(downstreamNodes)) + for _, d := range downstreamNodes { + downstreamNode, ok := nl.GetNode(d) + if !ok { + return errors.Errorf(errors.BadSpecificationError, currentNode.GetID(), "Unable to find Downstream Node [%v]", d) + } + + if err := c.FinalizeHandler(ctx, execContext, dag, nl, downstreamNode); err != nil { + logger.Infof(ctx, "Failed to abort node [%v]. Error: %v", d, err) + errs = append(errs, err) } } - info.TaskNodeInfo = taskNodeInfo - } else if recovered.Closure.GetWorkflowNodeMetadata() != nil { - logger.Warnf(ctx, "Attempted to recover node") - info.WorkflowNodeInfo = &handler.WorkflowNodeInfo{ - LaunchedWorkflowID: recovered.Closure.GetWorkflowNodeMetadata().ExecutionId, + + if len(errs) > 0 { + return errors.ErrorCollection{Errors: errs} } + + return nil } - return handler.PhaseInfoRecovered(info), nil + + return nil } -// In this method we check if the queue is ready to be processed and if so, we prime it in Admin as queued -// Before we start the node execution, we need to transition this Node status to Queued. -// This is because a node execution has to exist before task/wf executions can start. 
-func (c *nodeExecutor) preExecute(ctx context.Context, dag executors.DAGStructure, nCtx handler.NodeExecutionContext) ( - handler.PhaseInfo, error) { - logger.Debugf(ctx, "Node not yet started") - // Query the nodes information to figure out if it can be executed. - predicatePhase, err := CanExecute(ctx, dag, nCtx.ContextualNodeLookup(), nCtx.Node()) - if err != nil { - logger.Debugf(ctx, "Node failed in CanExecute. Error [%s]", err) - return handler.PhaseInfoUndefined, err +func (c *recursiveNodeExecutor) AbortHandler(ctx context.Context, execContext executors.ExecutionContext, dag executors.DAGStructure, nl executors.NodeLookup, currentNode v1alpha1.ExecutableNode, reason string) error { + nodeStatus := nl.GetNodeExecutionStatus(ctx, currentNode.GetID()) + nodePhase := nodeStatus.GetPhase() + + if nodePhase == v1alpha1.NodePhaseNotYetStarted { + logger.Infof(ctx, "Node not yet started, will not finalize") + // Nothing to be aborted + return nil } - if predicatePhase == PredicatePhaseReady { - // TODO: Performance problem, we maybe in a retry loop and do not need to resolve the inputs again. - // For now we will do this - node := nCtx.Node() - var nodeInputs *core.LiteralMap - if !node.IsStartNode() { - if nCtx.ExecutionContext().GetExecutionConfig().RecoveryExecution.WorkflowExecutionIdentifier != nil { - phaseInfo, err := c.attemptRecovery(ctx, nCtx) - if err != nil || phaseInfo.GetPhase() != handler.EPhaseUndefined { - return phaseInfo, err - } - } - nodeStatus := nCtx.NodeStatus() - dataDir := nodeStatus.GetDataDir() - t := c.metrics.NodeInputGatherLatency.Start(ctx) - defer t.Stop() - // Can execute - var err error - nodeInputs, err = Resolve(ctx, c.outputResolver, nCtx.ContextualNodeLookup(), node.GetID(), node.GetInputBindings()) - // TODO we need to handle retryable, network errors here!! - if err != nil { - c.metrics.ResolutionFailure.Inc(ctx) - logger.Warningf(ctx, "Failed to resolve inputs for Node. Error [%v]", err) - return handler.PhaseInfoFailure(core.ExecutionError_SYSTEM, "BindingResolutionFailure", err.Error(), nil), nil - } + if canHandleNode(nodePhase) { + ctx = contextutils.WithNodeID(ctx, currentNode.GetID()) - if nodeInputs != nil { - inputsFile := v1alpha1.GetInputsFile(dataDir) - if err := c.store.WriteProtobuf(ctx, inputsFile, storage.Options{}, nodeInputs); err != nil { - c.metrics.InputsWriteFailure.Inc(ctx) - logger.Errorf(ctx, "Failed to store inputs for Node. Error [%v]. InputsFile [%s]", err, inputsFile) - return handler.PhaseInfoUndefined, errors.Wrapf( - errors.StorageError, node.GetID(), err, "Failed to store inputs for Node. InputsFile [%s]", inputsFile) - } + // Now depending on the node type decide + h, err := c.nodeHandlerFactory.GetHandler(currentNode.GetKind()) + if err != nil { + return err + } + + nCtx, err := c.nCtxBuilder.BuildNodeExecutionContext(ctx, execContext, nl, currentNode.GetID()) + if err != nil { + return err + } + // Abort this node + return c.nodeExecutor.Abort(ctx, h, nCtx, reason, true) + } else if nodePhase == v1alpha1.NodePhaseSucceeded || nodePhase == v1alpha1.NodePhaseSkipped || nodePhase == v1alpha1.NodePhaseRecovered { + // Abort downstream nodes + downstreamNodes, err := dag.FromNode(currentNode.GetID()) + if err != nil { + logger.Debugf(ctx, "Error when retrieving downstream nodes. 
Error [%v]", err) + return nil + } + + errs := make([]error, 0, len(downstreamNodes)) + for _, d := range downstreamNodes { + downstreamNode, ok := nl.GetNode(d) + if !ok { + return errors.Errorf(errors.BadSpecificationError, currentNode.GetID(), "Unable to find Downstream Node [%v]", d) } - logger.Debugf(ctx, "Node Data Directory [%s].", nodeStatus.GetDataDir()) + if err := c.AbortHandler(ctx, execContext, dag, nl, downstreamNode, reason); err != nil { + logger.Infof(ctx, "Failed to abort node [%v]. Error: %v", d, err) + errs = append(errs, err) + } } - return handler.PhaseInfoQueued("node queued", nodeInputs), nil - } + if len(errs) > 0 { + return errors.ErrorCollection{Errors: errs} + } - // Now that we have resolved the inputs, we can record as a transition latency. This is because we have completed - // all the overhead that we have to compute. Any failures after this will incur this penalty, but it could be due - // to various external reasons - like queuing, overuse of quota, plugin overhead etc. - logger.Debugf(ctx, "preExecute completed in phase [%s]", predicatePhase.String()) - if predicatePhase == PredicatePhaseSkip { - return handler.PhaseInfoSkip(nil, "Node Skipped as parent node was skipped"), nil + return nil + } else { + ctx = contextutils.WithNodeID(ctx, currentNode.GetID()) + logger.Warnf(ctx, "Trying to abort a node in state [%s]", nodeStatus.GetPhase().String()) } - return handler.PhaseInfoNotReady("predecessor node not yet complete"), nil + return nil } -func isTimeoutExpired(queuedAt *metav1.Time, timeout time.Duration) bool { - if !queuedAt.IsZero() && timeout != 0 { - deadline := queuedAt.Add(timeout) - if deadline.Before(time.Now()) { - return true - } - } - return false +func (c *recursiveNodeExecutor) Initialize(ctx context.Context) error { + logger.Infof(ctx, "Initializing Core Node Executor") + s := c.newSetupContext(ctx) + return c.nodeHandlerFactory.Setup(ctx, c, s) } -func (c *nodeExecutor) isEligibleForRetry(nCtx *nodeExecContext, nodeStatus v1alpha1.ExecutableNodeStatus, err *core.ExecutionError) (currentAttempt, maxAttempts uint32, isEligible bool) { - if err.Kind == core.ExecutionError_SYSTEM { - currentAttempt = nodeStatus.GetSystemFailures() - maxAttempts = c.maxNodeRetriesForSystemFailures - isEligible = currentAttempt < c.maxNodeRetriesForSystemFailures - return - } - - currentAttempt = (nodeStatus.GetAttempts() + 1) - nodeStatus.GetSystemFailures() - if nCtx.Node().GetRetryStrategy() != nil && nCtx.Node().GetRetryStrategy().MinAttempts != nil { - maxAttempts = uint32(*nCtx.Node().GetRetryStrategy().MinAttempts) - } - isEligible = currentAttempt < maxAttempts - return +// GetNodeExecutionContextBuilder returns the current NodeExecutionContextBuilder +func (c *recursiveNodeExecutor) GetNodeExecutionContextBuilder() interfaces.NodeExecutionContextBuilder { + return c.nCtxBuilder } -func (c *nodeExecutor) execute(ctx context.Context, h handler.Node, nCtx *nodeExecContext, nodeStatus v1alpha1.ExecutableNodeStatus) (handler.PhaseInfo, error) { - logger.Debugf(ctx, "Executing node") - defer logger.Debugf(ctx, "Node execution round complete") - - t, err := h.Handle(ctx, nCtx) - if err != nil { - return handler.PhaseInfoUndefined, err +// WithNodeExecutionContextBuilder returns a new Node with the given NodeExecutionContextBuilder +func (c *recursiveNodeExecutor) WithNodeExecutionContextBuilder(nCtxBuilder interfaces.NodeExecutionContextBuilder) interfaces.Node { + return &recursiveNodeExecutor{ + nodeExecutor: c.nodeExecutor, + nCtxBuilder: nCtxBuilder, + 
enqueueWorkflow: c.enqueueWorkflow, + nodeHandlerFactory: c.nodeHandlerFactory, + store: c.store, + metrics: c.metrics, } +} - phase := t.Info() - // check for timeout for non-terminal phases - if !phase.GetPhase().IsTerminal() { - activeDeadline := c.defaultActiveDeadline - if nCtx.Node().GetActiveDeadline() != nil && *nCtx.Node().GetActiveDeadline() > 0 { - activeDeadline = *nCtx.Node().GetActiveDeadline() - } - if isTimeoutExpired(nodeStatus.GetQueuedAt(), activeDeadline) { - logger.Infof(ctx, "Node has timed out; timeout configured: %v", activeDeadline) - return handler.PhaseInfoTimedOut(nil, fmt.Sprintf("task active timeout [%s] expired", activeDeadline.String())), nil - } +// nodeExecutor implements the NodeExecutor interface and is responsible for executing a single node. +type nodeExecutor struct { + clusterID string + defaultActiveDeadline time.Duration + defaultDataSandbox storage.DataReference + defaultExecutionDeadline time.Duration + enqueueWorkflow v1alpha1.EnqueueWorkflow + eventConfig *config.EventConfig + interruptibleFailureThreshold uint32 + maxDatasetSizeBytes int64 + maxNodeRetriesForSystemFailures uint32 + metrics *nodeMetrics + nodeRecorder events.NodeEventRecorder + outputResolver OutputResolver + recoveryClient recovery.Client + shardSelector ioutils.ShardSelector + store *storage.DataStore + taskRecorder events.TaskEventRecorder +} - // Execution timeout is a retry-able error - executionDeadline := c.defaultExecutionDeadline - if nCtx.Node().GetExecutionDeadline() != nil && *nCtx.Node().GetExecutionDeadline() > 0 { - executionDeadline = *nCtx.Node().GetExecutionDeadline() +func (c *nodeExecutor) RecordTransitionLatency(ctx context.Context, dag executors.DAGStructure, nl executors.NodeLookup, node v1alpha1.ExecutableNode, nodeStatus v1alpha1.ExecutableNodeStatus) { + if nodeStatus.GetPhase() == v1alpha1.NodePhaseNotYetStarted || nodeStatus.GetPhase() == v1alpha1.NodePhaseQueued { + // Log transition latency (The most recently finished parent node endAt time to this node's queuedAt time -now-) + t, err := GetParentNodeMaxEndTime(ctx, dag, nl, node) + if err != nil { + logger.Warnf(ctx, "Failed to record transition latency for node. Error: %s", err.Error()) + return } - if isTimeoutExpired(nodeStatus.GetLastAttemptStartedAt(), executionDeadline) { - logger.Infof(ctx, "Current execution for the node timed out; timeout configured: %v", executionDeadline) - executionErr := &core.ExecutionError{Code: "TimeoutExpired", Message: fmt.Sprintf("task execution timeout [%s] expired", executionDeadline.String()), Kind: core.ExecutionError_USER} - phase = handler.PhaseInfoRetryableFailureErr(executionErr, nil) + if !t.IsZero() { + c.metrics.TransitionLatency.Observe(ctx, t.Time, time.Now()) } + } else if nodeStatus.GetPhase() == v1alpha1.NodePhaseRetryableFailure && nodeStatus.GetLastUpdatedAt() != nil { + c.metrics.TransitionLatency.Observe(ctx, nodeStatus.GetLastUpdatedAt().Time, time.Now()) } +} - if phase.GetPhase() == handler.EPhaseRetryableFailure { - currentAttempt, maxAttempts, isEligible := c.isEligibleForRetry(nCtx, nodeStatus, phase.GetErr()) - if !isEligible { - return handler.PhaseInfoFailure( - core.ExecutionError_USER, - fmt.Sprintf("RetriesExhausted|%s", phase.GetErr().Code), - fmt.Sprintf("[%d/%d] currentAttempt done. 
Last Error: %s::%s", currentAttempt, maxAttempts, phase.GetErr().Kind.String(), phase.GetErr().Message), - phase.GetInfo(), - ), nil - } +/*func (c *nodeExecutor) IdempotentRecordEvent(ctx context.Context, nodeEvent *event.NodeExecutionEvent) error { + if nodeEvent == nil { + return fmt.Errorf("event recording attempt of Nil Node execution event") + } - // Retrying to clearing all status - nCtx.nsm.clearNodeStatus() + if nodeEvent.Id == nil { + return fmt.Errorf("event recording attempt of with nil node Event ID") } - return phase, nil -} + logger.Infof(ctx, "Recording NodeEvent [%s] phase[%s]", nodeEvent.GetId().String(), nodeEvent.Phase.String()) + err := c.nodeRecorder.RecordNodeEvent(ctx, nodeEvent, c.eventConfig) + if err != nil { + if nodeEvent.GetId().NodeId == v1alpha1.EndNodeID { + return nil + } -func (c *nodeExecutor) abort(ctx context.Context, h handler.Node, nCtx handler.NodeExecutionContext, reason string) error { - logger.Debugf(ctx, "Calling aborting & finalize") - if err := h.Abort(ctx, nCtx, reason); err != nil { - finalizeErr := h.Finalize(ctx, nCtx) - if finalizeErr != nil { - return errors.ErrorCollection{Errors: []error{err, finalizeErr}} + if eventsErr.IsAlreadyExists(err) { + logger.Infof(ctx, "Node event phase: %s, nodeId %s already exist", + nodeEvent.Phase.String(), nodeEvent.GetId().NodeId) + return nil + } else if eventsErr.IsEventAlreadyInTerminalStateError(err) { + if IsTerminalNodePhase(nodeEvent.Phase) { + // Event was trying to record a different terminal phase for an already terminal event. ignoring. + logger.Infof(ctx, "Node event phase: %s, nodeId %s already in terminal phase. err: %s", + nodeEvent.Phase.String(), nodeEvent.GetId().NodeId, err.Error()) + return nil + } + logger.Warningf(ctx, "Failed to record nodeEvent, error [%s]", err.Error()) + return errors.Wrapf(errors.IllegalStateError, nodeEvent.Id.NodeId, err, "phase mis-match mismatch between propeller and control plane; Trying to record Node p: %s", nodeEvent.Phase) } - return err } + return err +}*/ - return h.Finalize(ctx, nCtx) -} +func (c *nodeExecutor) recoverInputs(ctx context.Context, nCtx interfaces.NodeExecutionContext, + recovered *admin.NodeExecution, recoveredData *admin.NodeExecutionGetDataResponse) (*core.LiteralMap, error) { -func (c *nodeExecutor) finalize(ctx context.Context, h handler.Node, nCtx handler.NodeExecutionContext) error { - return h.Finalize(ctx, nCtx) -} + nodeInputs := recoveredData.FullInputs + if nodeInputs != nil { + if err := c.store.WriteProtobuf(ctx, nCtx.InputReader().GetInputPath(), storage.Options{}, nodeInputs); err != nil { + c.metrics.InputsWriteFailure.Inc(ctx) + logger.Errorf(ctx, "Failed to move recovered inputs for Node. Error [%v]. InputsFile [%s]", err, nCtx.InputReader().GetInputPath()) + return nil, errors.Wrapf(errors.StorageError, nCtx.NodeID(), err, "Failed to store inputs for Node. InputsFile [%s]", nCtx.InputReader().GetInputPath()) + } + } else if len(recovered.InputUri) > 0 { + // If the inputs are too large they won't be returned inline in the RecoverData call. We must fetch them before copying them. 
+ nodeInputs = &core.LiteralMap{} + if recoveredData.FullInputs == nil { + if err := c.store.ReadProtobuf(ctx, storage.DataReference(recovered.InputUri), nodeInputs); err != nil { + return nil, errors.Wrapf(errors.InputsNotFoundError, nCtx.NodeID(), err, "failed to read data from dataDir [%v].", recovered.InputUri) + } + } -func (c *nodeExecutor) handleNotYetStartedNode(ctx context.Context, dag executors.DAGStructure, nCtx *nodeExecContext, _ handler.Node) (executors.NodeStatus, error) { - logger.Debugf(ctx, "Node not yet started, running pre-execute") - defer logger.Debugf(ctx, "Node pre-execute completed") - occurredAt := time.Now() - p, err := c.preExecute(ctx, dag, nCtx) - if err != nil { - logger.Errorf(ctx, "failed preExecute for node. Error: %s", err.Error()) - return executors.NodeStatusUndefined, err + if err := c.store.WriteProtobuf(ctx, nCtx.InputReader().GetInputPath(), storage.Options{}, nodeInputs); err != nil { + c.metrics.InputsWriteFailure.Inc(ctx) + logger.Errorf(ctx, "Failed to move recovered inputs for Node. Error [%v]. InputsFile [%s]", err, nCtx.InputReader().GetInputPath()) + return nil, errors.Wrapf(errors.StorageError, nCtx.NodeID(), err, "Failed to store inputs for Node. InputsFile [%s]", nCtx.InputReader().GetInputPath()) + } } - if p.GetPhase() == handler.EPhaseUndefined { - return executors.NodeStatusUndefined, errors.Errorf(errors.IllegalStateError, nCtx.NodeID(), "received undefined phase.") - } + return nodeInputs, nil +} - if p.GetPhase() == handler.EPhaseNotReady { - return executors.NodeStatusPending, nil +func (c *nodeExecutor) attemptRecovery(ctx context.Context, nCtx interfaces.NodeExecutionContext) (handler.PhaseInfo, error) { + fullyQualifiedNodeID := nCtx.NodeExecutionMetadata().GetNodeExecutionID().NodeId + if nCtx.ExecutionContext().GetEventVersion() != v1alpha1.EventVersion0 { + // compute fully qualified node id (prefixed with parent id and retry attempt) to ensure uniqueness + var err error + fullyQualifiedNodeID, err = common.GenerateUniqueID(nCtx.ExecutionContext().GetParentInfo(), nCtx.NodeExecutionMetadata().GetNodeExecutionID().NodeId) + if err != nil { + return handler.PhaseInfoUndefined, err + } } - np, err := ToNodePhase(p.GetPhase()) + recovered, err := c.recoveryClient.RecoverNodeExecution(ctx, nCtx.ExecutionContext().GetExecutionConfig().RecoveryExecution.WorkflowExecutionIdentifier, fullyQualifiedNodeID) if err != nil { - return executors.NodeStatusUndefined, errors.Wrapf(errors.IllegalStateError, nCtx.NodeID(), err, "failed to move from queued") + st, ok := status.FromError(err) + if !ok || st.Code() != codes.NotFound { + logger.Warnf(ctx, "Failed to recover node [%+v] with err [%+v]", nCtx.NodeExecutionMetadata().GetNodeExecutionID(), err) + } + // The node is not recoverable when it's not found in the parent execution + return handler.PhaseInfoUndefined, nil + } + if recovered == nil { + logger.Warnf(ctx, "call to recover node [%+v] returned no error but also no node", nCtx.NodeExecutionMetadata().GetNodeExecutionID()) + return handler.PhaseInfoUndefined, nil + } + if recovered.Closure == nil { + logger.Warnf(ctx, "Fetched node execution [%+v] data but was missing closure. 
Will not attempt to recover", + nCtx.NodeExecutionMetadata().GetNodeExecutionID()) + return handler.PhaseInfoUndefined, nil } + // A recoverable node execution should always be in a terminal phase + switch recovered.Closure.Phase { + case core.NodeExecution_SKIPPED: + return handler.PhaseInfoUndefined, nil + case core.NodeExecution_SUCCEEDED: + fallthrough + case core.NodeExecution_RECOVERED: + logger.Debugf(ctx, "Node [%+v] can be recovered. Proceeding to copy inputs and outputs", nCtx.NodeExecutionMetadata().GetNodeExecutionID()) + default: + // The node execution may be partially recoverable through intra task checkpointing. Save the checkpoint + // uri in the task node state to pass to the task handler later on. + if metadata, ok := recovered.Closure.TargetMetadata.(*admin.NodeExecutionClosure_TaskNodeMetadata); ok { + state := nCtx.NodeStateReader().GetTaskNodeState() + state.PreviousNodeExecutionCheckpointURI = storage.DataReference(metadata.TaskNodeMetadata.CheckpointUri) + err = nCtx.NodeStateWriter().PutTaskNodeState(state) + if err != nil { + logger.Warn(ctx, "failed to save recovered checkpoint uri for [%+v]: [%+v]", + nCtx.NodeExecutionMetadata().GetNodeExecutionID(), err) + } + } - nodeStatus := nCtx.NodeStatus() - if np != nodeStatus.GetPhase() { - // assert np == Queued! - logger.Infof(ctx, "Change in node state detected from [%s] -> [%s]", nodeStatus.GetPhase().String(), np.String()) - p = p.WithOccuredAt(occurredAt) + // if this node is a dynamic task we attempt to recover the compiled workflow from instances where the parent + // task succeeded but the dynamic task did not complete. this is important to ensure correctness since node ids + // within the compiled closure may not be generated deterministically. + if recovered.Metadata != nil && recovered.Metadata.IsDynamic && len(recovered.Closure.DynamicJobSpecUri) > 0 { + // recover node inputs + recoveredData, err := c.recoveryClient.RecoverNodeExecutionData(ctx, + nCtx.ExecutionContext().GetExecutionConfig().RecoveryExecution.WorkflowExecutionIdentifier, fullyQualifiedNodeID) + if err != nil || recoveredData == nil { + return handler.PhaseInfoUndefined, nil + } - nev, err := ToNodeExecutionEvent(nCtx.NodeExecutionMetadata().GetNodeExecutionID(), - p, nCtx.InputReader().GetInputPath().String(), nodeStatus, nCtx.ExecutionContext().GetEventVersion(), - nCtx.ExecutionContext().GetParentInfo(), nCtx.node, c.clusterID, nCtx.NodeStateReader().GetDynamicNodeState().Phase, - c.eventConfig) - if err != nil { - return executors.NodeStatusUndefined, errors.Wrapf(errors.IllegalStateError, nCtx.NodeID(), err, "could not convert phase info to event") - } - err = c.IdempotentRecordEvent(ctx, nev) - if err != nil { - logger.Warningf(ctx, "Failed to record nodeEvent, error [%s]", err.Error()) - return executors.NodeStatusUndefined, errors.Wrapf(errors.EventRecordingFailed, nCtx.NodeID(), err, "failed to record node event") - } - UpdateNodeStatus(np, p, nCtx.nsm, nodeStatus) - c.RecordTransitionLatency(ctx, dag, nCtx.ContextualNodeLookup(), nCtx.Node(), nodeStatus) - } + if _, err := c.recoverInputs(ctx, nCtx, recovered, recoveredData); err != nil { + return handler.PhaseInfoUndefined, err + } - if np == v1alpha1.NodePhaseQueued { - if nCtx.md.IsInterruptible() { - c.metrics.InterruptibleNodesRunning.Inc(ctx) - } - return executors.NodeStatusQueued, nil - } else if np == v1alpha1.NodePhaseSkipped { - return executors.NodeStatusSuccess, nil - } + // copy previous DynamicJobSpec file + f, err := task.NewRemoteFutureFileReader(ctx, 
nCtx.NodeStatus().GetOutputDir(), nCtx.DataStore()) + if err != nil { + return handler.PhaseInfoUndefined, err + } - return executors.NodeStatusPending, nil -} + dynamicJobSpecReference := storage.DataReference(recovered.Closure.DynamicJobSpecUri) + if err := nCtx.DataStore().CopyRaw(ctx, dynamicJobSpecReference, f.GetLoc(), storage.Options{}); err != nil { + return handler.PhaseInfoUndefined, errors.Wrapf(errors.StorageError, nCtx.NodeID(), err, + "failed to store dynamic job spec for node. source file [%s] destination file [%s]", dynamicJobSpecReference, f.GetLoc()) + } -func (c *nodeExecutor) handleQueuedOrRunningNode(ctx context.Context, nCtx *nodeExecContext, h handler.Node) (executors.NodeStatus, error) { - nodeStatus := nCtx.NodeStatus() - currentPhase := nodeStatus.GetPhase() + // transition node phase to 'Running' and dynamic task phase to 'DynamicNodePhaseParentFinalized' + state := nCtx.NodeStateReader().GetDynamicNodeState() + state.Phase = v1alpha1.DynamicNodePhaseParentFinalized + if err := nCtx.NodeStateWriter().PutDynamicNodeState(state); err != nil { + return handler.PhaseInfoUndefined, errors.Wrapf(errors.UnknownError, nCtx.NodeID(), err, "failed to store dynamic node state") + } - // case v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseRunning: - logger.Debugf(ctx, "node executing, current phase [%s]", currentPhase) - defer logger.Debugf(ctx, "node execution completed") + return handler.PhaseInfoRunning(&handler.ExecutionInfo{}), nil + } - // Since we reset node status inside execute for retryable failure, we use lastAttemptStartTime to carry that information - // across execute which is used to emit metrics - lastAttemptStartTime := nodeStatus.GetLastAttemptStartedAt() + logger.Debugf(ctx, "Node [%+v] phase [%v] is not recoverable", nCtx.NodeExecutionMetadata().GetNodeExecutionID(), recovered.Closure.Phase) + return handler.PhaseInfoUndefined, nil + } - p, err := c.execute(ctx, h, nCtx, nodeStatus) + recoveredData, err := c.recoveryClient.RecoverNodeExecutionData(ctx, nCtx.ExecutionContext().GetExecutionConfig().RecoveryExecution.WorkflowExecutionIdentifier, fullyQualifiedNodeID) if err != nil { - logger.Errorf(ctx, "failed Execute for node. 
Error: %s", err.Error()) - return executors.NodeStatusUndefined, err + st, ok := status.FromError(err) + if !ok || st.Code() != codes.NotFound { + logger.Warnf(ctx, "Failed to attemptRecovery node execution data for [%+v] although back-end indicated node was recoverable with err [%+v]", + nCtx.NodeExecutionMetadata().GetNodeExecutionID(), err) + } + return handler.PhaseInfoUndefined, nil } - - if p.GetPhase() == handler.EPhaseUndefined { - return executors.NodeStatusUndefined, errors.Errorf(errors.IllegalStateError, nCtx.NodeID(), "received undefined phase.") + if recoveredData == nil { + logger.Warnf(ctx, "call to attemptRecovery node [%+v] data returned no error but also no data", nCtx.NodeExecutionMetadata().GetNodeExecutionID()) + return handler.PhaseInfoUndefined, nil } - np, err := ToNodePhase(p.GetPhase()) + // Copy inputs to this node's expected location + nodeInputs, err := c.recoverInputs(ctx, nCtx, recovered, recoveredData) if err != nil { - return executors.NodeStatusUndefined, errors.Wrapf(errors.IllegalStateError, nCtx.NodeID(), err, "failed to move from queued") + return handler.PhaseInfoUndefined, err } - // execErr in phase-info 'p' is only available if node has failed to execute, and the current phase at that time - // will be v1alpha1.NodePhaseRunning - execErr := p.GetErr() - if execErr != nil && (currentPhase == v1alpha1.NodePhaseRunning || currentPhase == v1alpha1.NodePhaseQueued || - currentPhase == v1alpha1.NodePhaseDynamicRunning) { - endTime := time.Now() - startTime := endTime - if lastAttemptStartTime != nil { - startTime = lastAttemptStartTime.Time - } - - if execErr.GetKind() == core.ExecutionError_SYSTEM { - nodeStatus.IncrementSystemFailures() - c.metrics.SystemErrorDuration.Observe(ctx, startTime, endTime) - } else if execErr.GetKind() == core.ExecutionError_USER { - c.metrics.UserErrorDuration.Observe(ctx, startTime, endTime) - } else { - c.metrics.UnknownErrorDuration.Observe(ctx, startTime, endTime) - } - // When a node fails, we fail the workflow. Independent of number of nodes succeeding/failing, whenever a first node fails, - // the entire workflow is failed. 
- if np == v1alpha1.NodePhaseFailing { - if execErr.GetKind() == core.ExecutionError_SYSTEM { - nodeStatus.IncrementSystemFailures() - c.metrics.PermanentSystemErrorDuration.Observe(ctx, startTime, endTime) - } else if execErr.GetKind() == core.ExecutionError_USER { - c.metrics.PermanentUserErrorDuration.Observe(ctx, startTime, endTime) - } else { - c.metrics.PermanentUnknownErrorDuration.Observe(ctx, startTime, endTime) - } - } - } - finalStatus := executors.NodeStatusRunning - if np == v1alpha1.NodePhaseFailing && !h.FinalizeRequired() { - logger.Infof(ctx, "Finalize not required, moving node to Failed") - np = v1alpha1.NodePhaseFailed - finalStatus = executors.NodeStatusFailed(p.GetErr()) + // Similarly, copy outputs' reference + so := storage.Options{} + var outputs = &core.LiteralMap{} + if recoveredData.FullOutputs != nil { + outputs = recoveredData.FullOutputs + } else if recovered.Closure.GetOutputData() != nil { + outputs = recovered.Closure.GetOutputData() + } else if len(recovered.Closure.GetOutputUri()) > 0 { + if err := c.store.ReadProtobuf(ctx, storage.DataReference(recovered.Closure.GetOutputUri()), outputs); err != nil { + return handler.PhaseInfoUndefined, errors.Wrapf(errors.InputsNotFoundError, nCtx.NodeID(), err, "failed to read output data [%v].", recovered.Closure.GetOutputUri()) + } + } else { + logger.Debugf(ctx, "No outputs found for recovered node [%+v]", nCtx.NodeExecutionMetadata().GetNodeExecutionID()) } - if np == v1alpha1.NodePhaseTimingOut && !h.FinalizeRequired() { - logger.Infof(ctx, "Finalize not required, moving node to TimedOut") - np = v1alpha1.NodePhaseTimedOut - finalStatus = executors.NodeStatusTimedOut + outputFile := v1alpha1.GetOutputsFile(nCtx.NodeStatus().GetOutputDir()) + oi := &handler.OutputInfo{ + OutputURI: outputFile, } - if np == v1alpha1.NodePhaseSucceeding && !h.FinalizeRequired() { - logger.Infof(ctx, "Finalize not required, moving node to Succeeded") - np = v1alpha1.NodePhaseSucceeded - finalStatus = executors.NodeStatusSuccess + deckFile := storage.DataReference(recovered.Closure.GetDeckUri()) + if len(deckFile) > 0 { + metadata, err := nCtx.DataStore().Head(ctx, deckFile) + if err != nil { + logger.Errorf(ctx, "Failed to check the existence of deck file. Error: %v", err) + return handler.PhaseInfoUndefined, errors.Wrapf(errors.CausedByError, nCtx.NodeID(), err, "Failed to check the existence of deck file.") + } + + if metadata.Exists() { + oi.DeckURI = &deckFile + } } - if np == v1alpha1.NodePhaseRecovered { - logger.Infof(ctx, "Finalize not required, moving node to Recovered") - finalStatus = executors.NodeStatusRecovered + + if err := c.store.WriteProtobuf(ctx, outputFile, so, outputs); err != nil { + logger.Errorf(ctx, "Failed to write protobuf (metadata). 
Error [%v]", err) + return handler.PhaseInfoUndefined, errors.Wrapf(errors.CausedByError, nCtx.NodeID(), err, "Failed to store recovered node execution outputs") } - // If it is retryable failure, we do no want to send any events, as the node is essentially still running - // Similarly if the phase has not changed from the last time, events do not need to be sent - if np != nodeStatus.GetPhase() && np != v1alpha1.NodePhaseRetryableFailure { - // assert np == skipped, succeeding, failing or recovered - logger.Infof(ctx, "Change in node state detected from [%s] -> [%s], (handler phase [%s])", nodeStatus.GetPhase().String(), np.String(), p.GetPhase().String()) + info := &handler.ExecutionInfo{ + Inputs: nodeInputs, + OutputInfo: oi, + } - nev, err := ToNodeExecutionEvent(nCtx.NodeExecutionMetadata().GetNodeExecutionID(), - p, nCtx.InputReader().GetInputPath().String(), nCtx.NodeStatus(), nCtx.ExecutionContext().GetEventVersion(), - nCtx.ExecutionContext().GetParentInfo(), nCtx.node, c.clusterID, nCtx.NodeStateReader().GetDynamicNodeState().Phase, - c.eventConfig) - if err != nil { - return executors.NodeStatusUndefined, errors.Wrapf(errors.IllegalStateError, nCtx.NodeID(), err, "could not convert phase info to event") + if recovered.Closure.GetTaskNodeMetadata() != nil { + taskNodeInfo := &handler.TaskNodeInfo{ + TaskNodeMetadata: &event.TaskNodeMetadata{ + CatalogKey: recovered.Closure.GetTaskNodeMetadata().CatalogKey, + CacheStatus: recovered.Closure.GetTaskNodeMetadata().CacheStatus, + }, } + if recoveredData.DynamicWorkflow != nil { + taskNodeInfo.TaskNodeMetadata.DynamicWorkflow = &event.DynamicWorkflowNodeMetadata{ + Id: recoveredData.DynamicWorkflow.Id, + CompiledWorkflow: recoveredData.DynamicWorkflow.CompiledWorkflow, + } + } + info.TaskNodeInfo = taskNodeInfo + } else if recovered.Closure.GetWorkflowNodeMetadata() != nil { + logger.Warnf(ctx, "Attempted to recover node") + info.WorkflowNodeInfo = &handler.WorkflowNodeInfo{ + LaunchedWorkflowID: recovered.Closure.GetWorkflowNodeMetadata().ExecutionId, + } + } + return handler.PhaseInfoRecovered(info), nil +} - err = c.IdempotentRecordEvent(ctx, nev) - if err != nil { - if eventsErr.IsTooLarge(err) { - // With large enough dynamic task fanouts the reported node event, which contains the compiled - // workflow closure, can exceed the gRPC message size limit. In this case we immediately - // transition the node to failing to abort the workflow. - np = v1alpha1.NodePhaseFailing - p = handler.PhaseInfoFailure(core.ExecutionError_USER, "NodeFailed", err.Error(), p.GetInfo()) - - err = c.IdempotentRecordEvent(ctx, &event.NodeExecutionEvent{ - Id: nCtx.NodeExecutionMetadata().GetNodeExecutionID(), - Phase: core.NodeExecution_FAILED, - OccurredAt: ptypes.TimestampNow(), - OutputResult: &event.NodeExecutionEvent_Error{ - Error: &core.ExecutionError{ - Code: "NodeFailed", - Message: err.Error(), - }, - }, - ReportedAt: ptypes.TimestampNow(), - }) +// In this method we check if the queue is ready to be processed and if so, we prime it in Admin as queued +// Before we start the node execution, we need to transition this Node status to Queued. +// This is because a node execution has to exist before task/wf executions can start. +func (c *nodeExecutor) preExecute(ctx context.Context, dag executors.DAGStructure, nCtx interfaces.NodeExecutionContext) ( + handler.PhaseInfo, error) { + logger.Debugf(ctx, "Node not yet started") + // Query the nodes information to figure out if it can be executed. 
+ predicatePhase, err := CanExecute(ctx, dag, nCtx.ContextualNodeLookup(), nCtx.Node()) + if err != nil { + logger.Debugf(ctx, "Node failed in CanExecute. Error [%s]", err) + return handler.PhaseInfoUndefined, err + } - if err != nil { - return executors.NodeStatusUndefined, errors.Wrapf(errors.EventRecordingFailed, nCtx.NodeID(), err, "failed to record node event") + if predicatePhase == PredicatePhaseReady { + // TODO: Performance problem, we may be in a retry loop and do not need to resolve the inputs again. + // For now we will do this + node := nCtx.Node() + var nodeInputs *core.LiteralMap + if !node.IsStartNode() { + if nCtx.ExecutionContext().GetExecutionConfig().RecoveryExecution.WorkflowExecutionIdentifier != nil { + phaseInfo, err := c.attemptRecovery(ctx, nCtx) + if err != nil || phaseInfo.GetPhase() != handler.EPhaseUndefined { + return phaseInfo, err } - } else { - logger.Warningf(ctx, "Failed to record nodeEvent, error [%s]", err.Error()) - return executors.NodeStatusUndefined, errors.Wrapf(errors.EventRecordingFailed, nCtx.NodeID(), err, "failed to record node event") } - } + nodeStatus := nCtx.NodeStatus() + dataDir := nodeStatus.GetDataDir() + t := c.metrics.NodeInputGatherLatency.Start(ctx) + defer t.Stop() + // Can execute + var err error + nodeInputs, err = Resolve(ctx, c.outputResolver, nCtx.ContextualNodeLookup(), node.GetID(), node.GetInputBindings()) + // TODO we need to handle retryable, network errors here!! + if err != nil { + c.metrics.ResolutionFailure.Inc(ctx) + logger.Warningf(ctx, "Failed to resolve inputs for Node. Error [%v]", err) + return handler.PhaseInfoFailure(core.ExecutionError_SYSTEM, "BindingResolutionFailure", err.Error(), nil), nil + } - // We reach here only when transitioning from Queued to Running. In this case, the startedAt is not set. - if np == v1alpha1.NodePhaseRunning { - if nodeStatus.GetQueuedAt() != nil { - c.metrics.QueuingLatency.Observe(ctx, nodeStatus.GetQueuedAt().Time, time.Now()) + if nodeInputs != nil { + inputsFile := v1alpha1.GetInputsFile(dataDir) + if err := c.store.WriteProtobuf(ctx, inputsFile, storage.Options{}, nodeInputs); err != nil { + c.metrics.InputsWriteFailure.Inc(ctx) + logger.Errorf(ctx, "Failed to store inputs for Node. Error [%v]. InputsFile [%s]", err, inputsFile) + return handler.PhaseInfoUndefined, errors.Wrapf( + errors.StorageError, node.GetID(), err, "Failed to store inputs for Node. InputsFile [%s]", inputsFile) + } + } + + logger.Debugf(ctx, "Node Data Directory [%s].", nodeStatus.GetDataDir()) } - } - UpdateNodeStatus(np, p, nCtx.nsm, nodeStatus) - return finalStatus, nil -} + return handler.PhaseInfoQueued("node queued", nodeInputs), nil + } -func (c *nodeExecutor) handleRetryableFailure(ctx context.Context, nCtx *nodeExecContext, h handler.Node) (executors.NodeStatus, error) { - nodeStatus := nCtx.NodeStatus() - logger.Debugf(ctx, "node failed with retryable failure, aborting and finalizing, message: %s", nodeStatus.GetMessage()) - if err := c.abort(ctx, h, nCtx, nodeStatus.GetMessage()); err != nil { - return executors.NodeStatusUndefined, err + // Now that we have resolved the inputs, we can record as a transition latency. This is because we have completed + // all the overhead that we have to compute. Any failures after this will incur this penalty, but it could be due + // to various external reasons - like queuing, overuse of quota, plugin overhead etc.
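Aside: stripped of metrics and error wrapping, the happy path of the new `preExecute` reduces to: gate on the readiness predicate, resolve the node's input bindings, persist them as an inputs protobuf under the node's data directory, and report Queued. A compilable sketch of that shape is below; `InputResolver`, `ProtoStore`, and the `inputs.pb` path construction are simplified stand-ins for the real flytepropeller types, not their actual APIs.

```go
package preexecute

import (
	"context"
	"fmt"
)

// LiteralMap is a stand-in for core.LiteralMap.
type LiteralMap struct{ Literals map[string]any }

// InputResolver is a stand-in for the Resolve helper used by preExecute.
type InputResolver interface {
	// Resolve materializes the node's input bindings from upstream outputs.
	Resolve(ctx context.Context, nodeID string) (*LiteralMap, error)
}

// ProtoStore is a stand-in for the storage.DataStore protobuf interface.
type ProtoStore interface {
	WriteProtobuf(ctx context.Context, ref string, msg any) error
}

// prepareInputs mirrors the queued-node setup above: resolve bindings, then
// write them under <dataDir>/inputs.pb so the node handler can read them later.
func prepareInputs(ctx context.Context, r InputResolver, s ProtoStore, nodeID, dataDir string) (*LiteralMap, error) {
	inputs, err := r.Resolve(ctx, nodeID)
	if err != nil {
		// In the executor this surfaces as a BindingResolutionFailure phase.
		return nil, fmt.Errorf("failed to resolve inputs for node %s: %w", nodeID, err)
	}
	if inputs != nil {
		inputsFile := dataDir + "/inputs.pb"
		if err := s.WriteProtobuf(ctx, inputsFile, inputs); err != nil {
			return nil, fmt.Errorf("failed to store inputs at %s: %w", inputsFile, err)
		}
	}
	return inputs, nil
}
```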
+ logger.Debugf(ctx, "preExecute completed in phase [%s]", predicatePhase.String()) + if predicatePhase == PredicatePhaseSkip { + return handler.PhaseInfoSkip(nil, "Node Skipped as parent node was skipped"), nil } - // NOTE: It is important to increment attempts only after abort has been called. Increment attempt mutates the state - // Attempt is used throughout the system to determine the idempotent resource version. - nodeStatus.IncrementAttempts() - nodeStatus.UpdatePhase(v1alpha1.NodePhaseRunning, metav1.Now(), "retrying", nil) - // We are going to retry in the next round, so we should clear all current state - nodeStatus.ClearSubNodeStatus() - nodeStatus.ClearTaskStatus() - nodeStatus.ClearWorkflowStatus() - nodeStatus.ClearDynamicNodeStatus() - return executors.NodeStatusPending, nil + return handler.PhaseInfoNotReady("predecessor node not yet complete"), nil } -func (c *nodeExecutor) handleNode(ctx context.Context, dag executors.DAGStructure, nCtx *nodeExecContext, h handler.Node) (executors.NodeStatus, error) { - logger.Debugf(ctx, "Handling Node [%s]", nCtx.NodeID()) - defer logger.Debugf(ctx, "Completed node [%s]", nCtx.NodeID()) - - nodeStatus := nCtx.NodeStatus() - currentPhase := nodeStatus.GetPhase() - - // Optimization! - // If it is start node we directly move it to Queued without needing to run preExecute - if currentPhase == v1alpha1.NodePhaseNotYetStarted && !nCtx.Node().IsStartNode() { - p, err := c.handleNotYetStartedNode(ctx, dag, nCtx, h) - if err != nil { - return p, err - } - if p.NodePhase == executors.NodePhaseQueued { - logger.Infof(ctx, "Node was queued, parallelism is now [%d]", nCtx.ExecutionContext().IncrementParallelism()) +func isTimeoutExpired(queuedAt *metav1.Time, timeout time.Duration) bool { + if !queuedAt.IsZero() && timeout != 0 { + deadline := queuedAt.Add(timeout) + if deadline.Before(time.Now()) { + return true } - return p, err } + return false +} - if currentPhase == v1alpha1.NodePhaseFailing { - logger.Debugf(ctx, "node failing") - if err := c.abort(ctx, h, nCtx, "node failing"); err != nil { - return executors.NodeStatusUndefined, err - } - nodeStatus.UpdatePhase(v1alpha1.NodePhaseFailed, metav1.Now(), nodeStatus.GetMessage(), nodeStatus.GetExecutionError()) - c.metrics.FailureDuration.Observe(ctx, nodeStatus.GetStartedAt().Time, nodeStatus.GetStoppedAt().Time) - if nCtx.md.IsInterruptible() { - c.metrics.InterruptibleNodesTerminated.Inc(ctx) - } - return executors.NodeStatusFailed(nodeStatus.GetExecutionError()), nil +func (c *nodeExecutor) isEligibleForRetry(nCtx interfaces.NodeExecutionContext, nodeStatus v1alpha1.ExecutableNodeStatus, err *core.ExecutionError) (currentAttempt, maxAttempts uint32, isEligible bool) { + if err.Kind == core.ExecutionError_SYSTEM { + currentAttempt = nodeStatus.GetSystemFailures() + maxAttempts = c.maxNodeRetriesForSystemFailures + isEligible = currentAttempt < c.maxNodeRetriesForSystemFailures + return } - if currentPhase == v1alpha1.NodePhaseTimingOut { - logger.Debugf(ctx, "node timing out") - if err := c.abort(ctx, h, nCtx, "node timed out"); err != nil { - return executors.NodeStatusUndefined, err - } + currentAttempt = (nodeStatus.GetAttempts() + 1) - nodeStatus.GetSystemFailures() + if nCtx.Node().GetRetryStrategy() != nil && nCtx.Node().GetRetryStrategy().MinAttempts != nil { + maxAttempts = uint32(*nCtx.Node().GetRetryStrategy().MinAttempts) + } + isEligible = currentAttempt < maxAttempts + return +} - nodeStatus.ClearSubNodeStatus() - nodeStatus.UpdatePhase(v1alpha1.NodePhaseTimedOut, metav1.Now(), 
nodeStatus.GetMessage(), nodeStatus.GetExecutionError()) - c.metrics.TimedOutFailure.Inc(ctx) - if nCtx.md.IsInterruptible() { - c.metrics.InterruptibleNodesTerminated.Inc(ctx) - } - return executors.NodeStatusTimedOut, nil +func (c *nodeExecutor) execute(ctx context.Context, h interfaces.NodeHandler, nCtx interfaces.NodeExecutionContext, nodeStatus v1alpha1.ExecutableNodeStatus) (handler.PhaseInfo, error) { + logger.Debugf(ctx, "Executing node") + defer logger.Debugf(ctx, "Node execution round complete") + + t, err := h.Handle(ctx, nCtx) + if err != nil { + return handler.PhaseInfoUndefined, err } - if currentPhase == v1alpha1.NodePhaseSucceeding { - logger.Debugf(ctx, "node succeeding") - if err := c.finalize(ctx, h, nCtx); err != nil { - return executors.NodeStatusUndefined, err + phase := t.Info() + // check for timeout for non-terminal phases + if !phase.GetPhase().IsTerminal() { + activeDeadline := c.defaultActiveDeadline + if nCtx.Node().GetActiveDeadline() != nil && *nCtx.Node().GetActiveDeadline() > 0 { + activeDeadline = *nCtx.Node().GetActiveDeadline() } - t := metav1.Now() - - started := nodeStatus.GetStartedAt() - if started == nil { - started = &t + if isTimeoutExpired(nodeStatus.GetQueuedAt(), activeDeadline) { + logger.Infof(ctx, "Node has timed out; timeout configured: %v", activeDeadline) + return handler.PhaseInfoTimedOut(nil, fmt.Sprintf("task active timeout [%s] expired", activeDeadline.String())), nil } - stopped := nodeStatus.GetStoppedAt() - if stopped == nil { - stopped = &t + + // Execution timeout is a retry-able error + executionDeadline := c.defaultExecutionDeadline + if nCtx.Node().GetExecutionDeadline() != nil && *nCtx.Node().GetExecutionDeadline() > 0 { + executionDeadline = *nCtx.Node().GetExecutionDeadline() } - c.metrics.SuccessDuration.Observe(ctx, started.Time, stopped.Time) - nodeStatus.ClearSubNodeStatus() - nodeStatus.UpdatePhase(v1alpha1.NodePhaseSucceeded, t, "completed successfully", nil) - if nCtx.md.IsInterruptible() { - c.metrics.InterruptibleNodesTerminated.Inc(ctx) + if isTimeoutExpired(nodeStatus.GetLastAttemptStartedAt(), executionDeadline) { + logger.Infof(ctx, "Current execution for the node timed out; timeout configured: %v", executionDeadline) + executionErr := &core.ExecutionError{Code: "TimeoutExpired", Message: fmt.Sprintf("task execution timeout [%s] expired", executionDeadline.String()), Kind: core.ExecutionError_USER} + phase = handler.PhaseInfoRetryableFailureErr(executionErr, nil) + } } - return executors.NodeStatusSuccess, nil } - if currentPhase == v1alpha1.NodePhaseRetryableFailure { - return c.handleRetryableFailure(ctx, nCtx, h) - } + if phase.GetPhase() == handler.EPhaseRetryableFailure { + currentAttempt, maxAttempts, isEligible := c.isEligibleForRetry(nCtx, nodeStatus, phase.GetErr()) + if !isEligible { + return handler.PhaseInfoFailure( + core.ExecutionError_USER, + fmt.Sprintf("RetriesExhausted|%s", phase.GetErr().Code), + fmt.Sprintf("[%d/%d] currentAttempt done. Last Error: %s::%s", currentAttempt, maxAttempts, phase.GetErr().Kind.String(), phase.GetErr().Message), + phase.GetInfo(), + ), nil + } - if currentPhase == v1alpha1.NodePhaseFailed { - // This should never happen - return executors.NodeStatusFailed(nodeStatus.GetExecutionError()), nil + // Retrying, so clear all current node status + nCtx.NodeStateWriter().ClearNodeStatus() } - return c.handleQueuedOrRunningNode(ctx, nCtx, h) + return phase, nil } -// The space search for the next node to execute is implemented like a DFS algorithm.
handleDownstream visits all the nodes downstream from -// the currentNode. Visit a node is the RecursiveNodeHandler. A visit may be partial, complete or may result in a failure. -func (c *nodeExecutor) handleDownstream(ctx context.Context, execContext executors.ExecutionContext, dag executors.DAGStructure, nl executors.NodeLookup, currentNode v1alpha1.ExecutableNode) (executors.NodeStatus, error) { - logger.Debugf(ctx, "Handling downstream Nodes") - // This node is success. Handle all downstream nodes - downstreamNodes, err := dag.FromNode(currentNode.GetID()) - if err != nil { - logger.Debugf(ctx, "Error when retrieving downstream nodes, [%s]", err) - return executors.NodeStatusFailed(&core.ExecutionError{ - Code: errors.BadSpecificationError, - Message: fmt.Sprintf("failed to retrieve downstream nodes for [%s]", currentNode.GetID()), - Kind: core.ExecutionError_SYSTEM, - }), nil +func (c *nodeExecutor) Abort(ctx context.Context, h interfaces.NodeHandler, nCtx interfaces.NodeExecutionContext, reason string, finalTransition bool) error { + logger.Debugf(ctx, "Calling aborting & finalize") + if err := h.Abort(ctx, nCtx, reason); err != nil { + finalizeErr := h.Finalize(ctx, nCtx) + if finalizeErr != nil { + return errors.ErrorCollection{Errors: []error{err, finalizeErr}} + } + return err } - if len(downstreamNodes) == 0 { - logger.Debugf(ctx, "No downstream nodes found. Complete.") - return executors.NodeStatusComplete, nil + + if err := h.Finalize(ctx, nCtx); err != nil { + return err } - // If any downstream node is failed, fail, all - // Else if all are success then success - // Else if any one is running then Downstream is still running - allCompleted := true - partialNodeCompletion := false - onFailurePolicy := execContext.GetOnFailurePolicy() - stateOnComplete := executors.NodeStatusComplete - for _, downstreamNodeName := range downstreamNodes { - downstreamNode, ok := nl.GetNode(downstreamNodeName) - if !ok { - return executors.NodeStatusFailed(&core.ExecutionError{ - Code: errors.BadSpecificationError, - Message: fmt.Sprintf("failed to retrieve downstream node [%s] for [%s]", downstreamNodeName, currentNode.GetID()), - Kind: core.ExecutionError_SYSTEM, - }), nil - } - state, err := c.RecursiveNodeHandler(ctx, execContext, dag, nl, downstreamNode) - if err != nil { - return executors.NodeStatusUndefined, err + // only send event if this is the final transition for this node + if finalTransition { + nodeExecutionID := &core.NodeExecutionIdentifier{ + ExecutionId: nCtx.NodeExecutionMetadata().GetNodeExecutionID().ExecutionId, + NodeId: nCtx.NodeExecutionMetadata().GetNodeExecutionID().NodeId, } - - if state.HasFailed() || state.HasTimedOut() { - logger.Debugf(ctx, "Some downstream node has failed. Failed: [%v]. TimedOut: [%v]. Error: [%s]", state.HasFailed(), state.HasTimedOut(), state.Err) - if onFailurePolicy == v1alpha1.WorkflowOnFailurePolicy(core.WorkflowMetadata_FAIL_AFTER_EXECUTABLE_NODES_COMPLETE) { - // If the failure policy allows other nodes to continue running, do not exit the loop, - // Keep track of the last failed state in the loop since it'll be the one to return. - // TODO: If multiple nodes fail (which this mode allows), consolidate/summarize failure states in one. 
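Aside: the new `Abort` above guarantees that `Finalize` always runs even when the handler's `Abort` fails, and that neither error masks the other. A small sketch of that pattern, with a stand-in `Handler` interface and the standard library's `errors.Join` (Go 1.20+) standing in for flytepropeller's `ErrorCollection`:

```go
package abort

import (
	"context"
	"errors"
)

// Handler is a stand-in for the interfaces.NodeHandler methods used here.
type Handler interface {
	Abort(ctx context.Context, reason string) error
	Finalize(ctx context.Context) error
}

// abortAndFinalize mirrors the structure of nodeExecutor.Abort: finalization
// must happen even if abort fails, and both errors must be surfaced together.
func abortAndFinalize(ctx context.Context, h Handler, reason string) error {
	if abortErr := h.Abort(ctx, reason); abortErr != nil {
		if finalizeErr := h.Finalize(ctx); finalizeErr != nil {
			return errors.Join(abortErr, finalizeErr)
		}
		return abortErr
	}
	return h.Finalize(ctx)
}
```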
- stateOnComplete = state - } else { - return state, nil + if nCtx.ExecutionContext().GetEventVersion() != v1alpha1.EventVersion0 { + currentNodeUniqueID, err := common.GenerateUniqueID(nCtx.ExecutionContext().GetParentInfo(), nodeExecutionID.NodeId) + if err != nil { + return err } - } else if !state.IsComplete() { - // A Failed/Timedout node is implicitly considered "complete" this means none of the downstream nodes from - // that node will ever be allowed to run. - // This else block, therefore, deals with all other states. IsComplete will return true if and only if this - // node as well as all of its downstream nodes have finished executing with success statuses. Otherwise we - // mark this node's state as not completed to ensure we will visit it again later. - allCompleted = false + nodeExecutionID.NodeId = currentNodeUniqueID } - if state.PartiallyComplete() { - // This implies that one of the downstream nodes has just succeeded and workflow is ready for propagation - // We do not propagate in current cycle to make it possible to store the state between transitions - partialNodeCompletion = true + err := nCtx.EventsRecorder().RecordNodeEvent(ctx, &event.NodeExecutionEvent{ + Id: nodeExecutionID, + Phase: core.NodeExecution_ABORTED, + OccurredAt: ptypes.TimestampNow(), + OutputResult: &event.NodeExecutionEvent_Error{ + Error: &core.ExecutionError{ + Code: "NodeAborted", + Message: reason, + }, + }, + ProducerId: c.clusterID, + ReportedAt: ptypes.TimestampNow(), + }, c.eventConfig) + if err != nil && !eventsErr.IsNotFound(err) && !eventsErr.IsEventIncompatibleClusterError(err) { + if errors2.IsCausedBy(err, errors.IllegalStateError) { + logger.Debugf(ctx, "Failed to record abort event due to illegal state transition. Ignoring the error. Error: %v", err) + } else { + logger.Warningf(ctx, "Failed to record nodeEvent, error [%s]", err.Error()) + return errors.Wrapf(errors.EventRecordingFailed, nCtx.NodeID(), err, "failed to record node event") + } } } - if allCompleted { - logger.Debugf(ctx, "All downstream nodes completed") - return stateOnComplete, nil - } - - if partialNodeCompletion { - return executors.NodeStatusSuccess, nil - } + return nil +} - return executors.NodeStatusPending, nil +func (c *nodeExecutor) Finalize(ctx context.Context, h interfaces.NodeHandler, nCtx interfaces.NodeExecutionContext) error { + return h.Finalize(ctx, nCtx) } -func (c *nodeExecutor) SetInputsForStartNode(ctx context.Context, execContext executors.ExecutionContext, dag executors.DAGStructureWithStartNode, nl executors.NodeLookup, inputs *core.LiteralMap) (executors.NodeStatus, error) { - startNode := dag.StartNode() - ctx = contextutils.WithNodeID(ctx, startNode.GetID()) - if inputs == nil { - logger.Infof(ctx, "No inputs for the workflow. Skipping storing inputs") - return executors.NodeStatusComplete, nil +func (c *nodeExecutor) handleNotYetStartedNode(ctx context.Context, dag executors.DAGStructure, nCtx interfaces.NodeExecutionContext, _ interfaces.NodeHandler) (interfaces.NodeStatus, error) { + logger.Debugf(ctx, "Node not yet started, running pre-execute") + defer logger.Debugf(ctx, "Node pre-execute completed") + occurredAt := time.Now() + p, err := c.preExecute(ctx, dag, nCtx) + if err != nil { + logger.Errorf(ctx, "failed preExecute for node. Error: %s", err.Error()) + return interfaces.NodeStatusUndefined, err } - // StartNode is special. It does not have any processing step. 
It just takes the workflow (or subworkflow) inputs and converts to its own outputs - nodeStatus := nl.GetNodeExecutionStatus(ctx, startNode.GetID()) - - if len(nodeStatus.GetDataDir()) == 0 { - return executors.NodeStatusUndefined, errors.Errorf(errors.IllegalStateError, startNode.GetID(), "no data-dir set, cannot store inputs") + if p.GetPhase() == handler.EPhaseUndefined { + return interfaces.NodeStatusUndefined, errors.Errorf(errors.IllegalStateError, nCtx.NodeID(), "received undefined phase.") } - outputFile := v1alpha1.GetOutputsFile(nodeStatus.GetOutputDir()) - so := storage.Options{} - if err := c.store.WriteProtobuf(ctx, outputFile, so, inputs); err != nil { - logger.Errorf(ctx, "Failed to write protobuf (metadata). Error [%v]", err) - return executors.NodeStatusUndefined, errors.Wrapf(errors.CausedByError, startNode.GetID(), err, "Failed to store workflow inputs (as start node)") + if p.GetPhase() == handler.EPhaseNotReady { + return interfaces.NodeStatusPending, nil } - return executors.NodeStatusComplete, nil -} + np, err := ToNodePhase(p.GetPhase()) + if err != nil { + return interfaces.NodeStatusUndefined, errors.Wrapf(errors.IllegalStateError, nCtx.NodeID(), err, "failed to move from queued") + } -func canHandleNode(phase v1alpha1.NodePhase) bool { - return phase == v1alpha1.NodePhaseNotYetStarted || - phase == v1alpha1.NodePhaseQueued || - phase == v1alpha1.NodePhaseRunning || - phase == v1alpha1.NodePhaseFailing || - phase == v1alpha1.NodePhaseTimingOut || - phase == v1alpha1.NodePhaseRetryableFailure || - phase == v1alpha1.NodePhaseSucceeding || - phase == v1alpha1.NodePhaseDynamicRunning -} + nodeStatus := nCtx.NodeStatus() + if np != nodeStatus.GetPhase() { + // assert np == Queued! + logger.Infof(ctx, "Change in node state detected from [%s] -> [%s]", nodeStatus.GetPhase().String(), np.String()) + p = p.WithOccuredAt(occurredAt) -// IsMaxParallelismAchieved checks if we have already achieved max parallelism. It returns true, if the desired max parallelism -// value is achieved, false otherwise -// MaxParallelism is defined as the maximum number of TaskNodes and LaunchPlans (together) that can be executed concurrently -// by one workflow execution. A setting of `0` indicates that it is disabled. 
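Aside: the parallelism rule described in the doc comment above restates cleanly as a pure function; only task nodes and launch-plan references count toward the limit, and a limit of zero disables the check entirely. A self-contained sketch under that reading (the `NodeKind` enum and `node` struct are simplified stand-ins, not the v1alpha1 types):

```go
package parallelism

// NodeKind is a simplified stand-in for v1alpha1.NodeKind.
type NodeKind int

const (
	KindTask NodeKind = iota
	KindWorkflow
	KindBranch
)

type node struct {
	kind         NodeKind
	isLaunchPlan bool // true for workflow nodes that reference a launch plan
}

// countsTowardParallelism mirrors the rule above: only task nodes and
// launch-plan workflow nodes consume parallelism slots; inline sub-workflows,
// branches, and the like do not.
func countsTowardParallelism(n node) bool {
	return n.kind == KindTask || (n.kind == KindWorkflow && n.isLaunchPlan)
}

// maxParallelismAchieved reports whether the round should be short-circuited.
// A maxParallelism of 0 disables parallelism control entirely.
func maxParallelismAchieved(n node, current, maxParallelism uint32) bool {
	if maxParallelism == 0 {
		return false
	}
	return countsTowardParallelism(n) && current >= maxParallelism
}
```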
-func IsMaxParallelismAchieved(ctx context.Context, currentNode v1alpha1.ExecutableNode, currentPhase v1alpha1.NodePhase, - execContext executors.ExecutionContext) bool { - maxParallelism := execContext.GetExecutionConfig().MaxParallelism - if maxParallelism == 0 { - logger.Debugf(ctx, "Parallelism control disabled") - return false + nev, err := ToNodeExecutionEvent(nCtx.NodeExecutionMetadata().GetNodeExecutionID(), + p, nCtx.InputReader().GetInputPath().String(), nodeStatus, nCtx.ExecutionContext().GetEventVersion(), + nCtx.ExecutionContext().GetParentInfo(), nCtx.Node(), c.clusterID, nCtx.NodeStateReader().GetDynamicNodeState().Phase, + c.eventConfig) + if err != nil { + return interfaces.NodeStatusUndefined, errors.Wrapf(errors.IllegalStateError, nCtx.NodeID(), err, "could not convert phase info to event") + } + err = nCtx.EventsRecorder().RecordNodeEvent(ctx, nev, c.eventConfig) + if err != nil { + logger.Warningf(ctx, "Failed to record nodeEvent, error [%s]", err.Error()) + return interfaces.NodeStatusUndefined, errors.Wrapf(errors.EventRecordingFailed, nCtx.NodeID(), err, "failed to record node event") + } + UpdateNodeStatus(np, p, nCtx.NodeStateReader(), nodeStatus) + c.RecordTransitionLatency(ctx, dag, nCtx.ContextualNodeLookup(), nCtx.Node(), nodeStatus) } - if currentNode.GetKind() == v1alpha1.NodeKindTask || - (currentNode.GetKind() == v1alpha1.NodeKindWorkflow && currentNode.GetWorkflowNode() != nil && currentNode.GetWorkflowNode().GetLaunchPlanRefID() != nil) { - // If we are queued, let us see if we can proceed within the node parallelism bounds - if execContext.CurrentParallelism() >= maxParallelism { - logger.Infof(ctx, "Maximum Parallelism for task/launch-plan nodes achieved [%d] >= Max [%d], Round will be short-circuited.", execContext.CurrentParallelism(), maxParallelism) - return true + if np == v1alpha1.NodePhaseQueued { + if nCtx.NodeExecutionMetadata().IsInterruptible() { + c.metrics.InterruptibleNodesRunning.Inc(ctx) } - // We know that Propeller goes through each workflow in a single thread, thus every node is really processed - // sequentially. So, we can continue - now that we know we are under the parallelism limits and increment the - // parallelism if the node, enters a running state - logger.Debugf(ctx, "Parallelism criteria not met, Current [%d], Max [%d]", execContext.CurrentParallelism(), maxParallelism) - } else { - logger.Debugf(ctx, "NodeKind: %s in status [%s]. Parallelism control is not applicable. Current Parallelism [%d]", - currentNode.GetKind().String(), currentPhase.String(), execContext.CurrentParallelism()) + return interfaces.NodeStatusQueued, nil + } else if np == v1alpha1.NodePhaseSkipped { + return interfaces.NodeStatusSuccess, nil } - return false -} -// RecursiveNodeHandler This is the entrypoint of executing a node in a workflow. A workflow consists of nodes, that are -// nested within other nodes. The system follows an actor model, where the parent nodes control the execution of nested nodes -// The recursive node-handler uses a modified depth-first type of algorithm to execute non-blocked nodes. 
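Aside: as a toy illustration of the depth-first traversal order the comment above describes (not the real `RecursiveNodeHandler`, which also layers phase checks, parallelism limits, dirty-state short-circuits, and failure policy on top), a walk over a node-to-downstream adjacency map looks like this:

```go
package traversal

// visit walks the DAG depth-first from nodeID, calling handle exactly once
// per node. The real RecursiveNodeHandler follows this shape but decides at
// each node whether to execute it, skip it, or descend to its downstream set.
func visit(nodeID string, downstream map[string][]string, seen map[string]bool, handle func(string)) {
	if seen[nodeID] {
		return
	}
	seen[nodeID] = true
	handle(nodeID)
	for _, next := range downstream[nodeID] {
		visit(next, downstream, seen, handle)
	}
}
```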
-func (c *nodeExecutor) RecursiveNodeHandler(ctx context.Context, execContext executors.ExecutionContext, - dag executors.DAGStructure, nl executors.NodeLookup, currentNode v1alpha1.ExecutableNode) ( - executors.NodeStatus, error) { + return interfaces.NodeStatusPending, nil +} - currentNodeCtx := contextutils.WithNodeID(ctx, currentNode.GetID()) - nodeStatus := nl.GetNodeExecutionStatus(ctx, currentNode.GetID()) - nodePhase := nodeStatus.GetPhase() +func (c *nodeExecutor) handleQueuedOrRunningNode(ctx context.Context, nCtx interfaces.NodeExecutionContext, h interfaces.NodeHandler) (interfaces.NodeStatus, error) { + nodeStatus := nCtx.NodeStatus() + currentPhase := nodeStatus.GetPhase() - if canHandleNode(nodePhase) { - // TODO Follow up Pull Request, - // 1. Rename this method to DAGTraversalHandleNode (accepts a DAGStructure along-with) the remaining arguments - // 2. Create a new method called HandleNode (part of the interface) (remaining all args as the previous method, but no DAGStructure - // 3. Additional both methods will receive inputs reader - // 4. The Downstream nodes handler will Resolve the Inputs - // 5. the method will delegate all other node handling to HandleNode. - // 6. Thus we can get rid of SetInputs for StartNode as well - logger.Debugf(currentNodeCtx, "Handling node Status [%v]", nodeStatus.GetPhase().String()) + // case v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseRunning: + logger.Debugf(ctx, "node executing, current phase [%s]", currentPhase) + defer logger.Debugf(ctx, "node execution completed") - t := c.metrics.NodeExecutionTime.Start(ctx) - defer t.Stop() + // Since we reset node status inside execute for retryable failure, we use lastAttemptStartTime to carry that information + // across execute which is used to emit metrics + lastAttemptStartTime := nodeStatus.GetLastAttemptStartedAt() - // This is an optimization to avoid creating the nodeContext object in case the node has already been looked at. - // If the overhead was zero, we would just do the isDirtyCheck after the nodeContext is created - if nodeStatus.IsDirty() { - return executors.NodeStatusRunning, nil - } + p, err := c.execute(ctx, h, nCtx, nodeStatus) + if err != nil { + logger.Errorf(ctx, "failed Execute for node. Error: %s", err.Error()) + return interfaces.NodeStatusUndefined, err + } - if IsMaxParallelismAchieved(ctx, currentNode, nodePhase, execContext) { - return executors.NodeStatusRunning, nil - } + if p.GetPhase() == handler.EPhaseUndefined { + return interfaces.NodeStatusUndefined, errors.Errorf(errors.IllegalStateError, nCtx.NodeID(), "received undefined phase.") + } - nCtx, err := c.newNodeExecContextDefault(ctx, currentNode.GetID(), execContext, nl) - if err != nil { - // NodeExecution creation failure is a permanent fail / system error. - // Should a system failure always return an err? 
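Aside: `RecursiveNodeHandler` delegates per-node behavior through a factory keyed on the node kind. A minimal stand-in for that dispatch (the real `interfaces.HandlerFactory` is richer and is wired in at construction time):

```go
package dispatch

import "fmt"

type NodeKind string

// NodeHandler is a stand-in for interfaces.NodeHandler.
type NodeHandler interface {
	FinalizeRequired() bool
}

// handlerFactory mimics the role of nodeHandlerFactory.GetHandler: one
// handler implementation per node kind (task, branch, workflow, gate, ...).
type handlerFactory struct {
	handlers map[NodeKind]NodeHandler
}

func (f handlerFactory) GetHandler(kind NodeKind) (NodeHandler, error) {
	h, ok := f.handlers[kind]
	if !ok {
		return nil, fmt.Errorf("no handler registered for node kind %q", kind)
	}
	return h, nil
}
```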
- return executors.NodeStatusFailed(&core.ExecutionError{ - Code: "InternalError", - Message: err.Error(), - Kind: core.ExecutionError_SYSTEM, - }), nil - } + np, err := ToNodePhase(p.GetPhase()) + if err != nil { + return interfaces.NodeStatusUndefined, errors.Wrapf(errors.IllegalStateError, nCtx.NodeID(), err, "failed to move from queued") + } - // Now depending on the node type decide - h, err := c.nodeHandlerFactory.GetHandler(nCtx.Node().GetKind()) - if err != nil { - return executors.NodeStatusUndefined, err + // execErr in phase-info 'p' is only available if node has failed to execute, and the current phase at that time + // will be v1alpha1.NodePhaseRunning + execErr := p.GetErr() + if execErr != nil && (currentPhase == v1alpha1.NodePhaseRunning || currentPhase == v1alpha1.NodePhaseQueued || + currentPhase == v1alpha1.NodePhaseDynamicRunning) { + endTime := time.Now() + startTime := endTime + if lastAttemptStartTime != nil { + startTime = lastAttemptStartTime.Time } - return c.handleNode(currentNodeCtx, dag, nCtx, h) - - // TODO we can optimize skip state handling by iterating down the graph and marking all as skipped - // Currently we treat either Skip or Success the same way. In this approach only one node will be skipped - // at a time. As we iterate down, further nodes will be skipped - } else if nodePhase == v1alpha1.NodePhaseSucceeded || nodePhase == v1alpha1.NodePhaseSkipped || nodePhase == v1alpha1.NodePhaseRecovered { - logger.Debugf(currentNodeCtx, "Node has [%v], traversing downstream.", nodePhase) - return c.handleDownstream(ctx, execContext, dag, nl, currentNode) - } else if nodePhase == v1alpha1.NodePhaseFailed { - logger.Debugf(currentNodeCtx, "Node has failed, traversing downstream.") - _, err := c.handleDownstream(ctx, execContext, dag, nl, currentNode) - if err != nil { - return executors.NodeStatusUndefined, err + if execErr.GetKind() == core.ExecutionError_SYSTEM { + nodeStatus.IncrementSystemFailures() + c.metrics.SystemErrorDuration.Observe(ctx, startTime, endTime) + } else if execErr.GetKind() == core.ExecutionError_USER { + c.metrics.UserErrorDuration.Observe(ctx, startTime, endTime) + } else { + c.metrics.UnknownErrorDuration.Observe(ctx, startTime, endTime) } - - return executors.NodeStatusFailed(nodeStatus.GetExecutionError()), nil - } else if nodePhase == v1alpha1.NodePhaseTimedOut { - logger.Debugf(currentNodeCtx, "Node has timed out, traversing downstream.") - _, err := c.handleDownstream(ctx, execContext, dag, nl, currentNode) - if err != nil { - return executors.NodeStatusUndefined, err + // When a node fails, we fail the workflow. Independent of number of nodes succeeding/failing, whenever a first node fails, + // the entire workflow is failed. 
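Aside: the error accounting in this block applies one pattern twice, once for any failed round and once more for permanent failures: bucket the attempt duration by the error kind. A condensed sketch of that classification, with stand-in stopwatch and error-kind types (the real code uses `core.ExecutionError_ErrorKind` and `labeled.StopWatch`):

```go
package failuremetrics

import "time"

// ErrorKind is a stand-in for core.ExecutionError_ErrorKind.
type ErrorKind int

const (
	KindUnknown ErrorKind = iota
	KindUser
	KindSystem
)

// Observer is a stand-in for the stopwatch Observe method used above.
type Observer interface {
	Observe(start, end time.Time)
}

// observeFailure buckets a failed attempt's duration by who caused it, the
// same way handleQueuedOrRunningNode feeds SystemErrorDuration,
// UserErrorDuration, and UnknownErrorDuration.
func observeFailure(kind ErrorKind, start, end time.Time, system, user, unknown Observer) {
	switch kind {
	case KindSystem:
		system.Observe(start, end)
	case KindUser:
		user.Observe(start, end)
	default:
		unknown.Observe(start, end)
	}
}
```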
+ if np == v1alpha1.NodePhaseFailing { + if execErr.GetKind() == core.ExecutionError_SYSTEM { + nodeStatus.IncrementSystemFailures() + c.metrics.PermanentSystemErrorDuration.Observe(ctx, startTime, endTime) + } else if execErr.GetKind() == core.ExecutionError_USER { + c.metrics.PermanentUserErrorDuration.Observe(ctx, startTime, endTime) + } else { + c.metrics.PermanentUnknownErrorDuration.Observe(ctx, startTime, endTime) + } } - - return executors.NodeStatusTimedOut, nil + } + finalStatus := interfaces.NodeStatusRunning + if np == v1alpha1.NodePhaseFailing && !h.FinalizeRequired() { + logger.Infof(ctx, "Finalize not required, moving node to Failed") + np = v1alpha1.NodePhaseFailed + finalStatus = interfaces.NodeStatusFailed(p.GetErr()) } - return executors.NodeStatusUndefined, errors.Errorf(errors.IllegalStateError, currentNode.GetID(), - "Should never reach here. Current Phase: %v", nodePhase) -} - -func (c *nodeExecutor) FinalizeHandler(ctx context.Context, execContext executors.ExecutionContext, dag executors.DAGStructure, nl executors.NodeLookup, currentNode v1alpha1.ExecutableNode) error { - nodeStatus := nl.GetNodeExecutionStatus(ctx, currentNode.GetID()) - nodePhase := nodeStatus.GetPhase() + if np == v1alpha1.NodePhaseTimingOut && !h.FinalizeRequired() { + logger.Infof(ctx, "Finalize not required, moving node to TimedOut") + np = v1alpha1.NodePhaseTimedOut + finalStatus = interfaces.NodeStatusTimedOut + } - if nodePhase == v1alpha1.NodePhaseNotYetStarted { - logger.Infof(ctx, "Node not yet started, will not finalize") - // Nothing to be aborted - return nil + if np == v1alpha1.NodePhaseSucceeding && !h.FinalizeRequired() { + logger.Infof(ctx, "Finalize not required, moving node to Succeeded") + np = v1alpha1.NodePhaseSucceeded + finalStatus = interfaces.NodeStatusSuccess + } + if np == v1alpha1.NodePhaseRecovered { + logger.Infof(ctx, "Finalize not required, moving node to Recovered") + finalStatus = interfaces.NodeStatusRecovered } - if canHandleNode(nodePhase) { - ctx = contextutils.WithNodeID(ctx, currentNode.GetID()) + // If it is retryable failure, we do no want to send any events, as the node is essentially still running + // Similarly if the phase has not changed from the last time, events do not need to be sent + if np != nodeStatus.GetPhase() && np != v1alpha1.NodePhaseRetryableFailure { + // assert np == skipped, succeeding, failing or recovered + logger.Infof(ctx, "Change in node state detected from [%s] -> [%s], (handler phase [%s])", nodeStatus.GetPhase().String(), np.String(), p.GetPhase().String()) - // Now depending on the node type decide - h, err := c.nodeHandlerFactory.GetHandler(currentNode.GetKind()) + nev, err := ToNodeExecutionEvent(nCtx.NodeExecutionMetadata().GetNodeExecutionID(), + p, nCtx.InputReader().GetInputPath().String(), nCtx.NodeStatus(), nCtx.ExecutionContext().GetEventVersion(), + nCtx.ExecutionContext().GetParentInfo(), nCtx.Node(), c.clusterID, nCtx.NodeStateReader().GetDynamicNodeState().Phase, + c.eventConfig) if err != nil { - return err + return interfaces.NodeStatusUndefined, errors.Wrapf(errors.IllegalStateError, nCtx.NodeID(), err, "could not convert phase info to event") } - nCtx, err := c.newNodeExecContextDefault(ctx, currentNode.GetID(), execContext, nl) - if err != nil { - return err - } - // Abort this node - err = c.finalize(ctx, h, nCtx) - if err != nil { - return err - } - } else { - // Abort downstream nodes - downstreamNodes, err := dag.FromNode(currentNode.GetID()) + err = nCtx.EventsRecorder().RecordNodeEvent(ctx, nev, 
c.eventConfig) if err != nil { - logger.Debugf(ctx, "Error when retrieving downstream nodes. Error [%v]", err) - return nil - } + if eventsErr.IsTooLarge(err) { + // With large enough dynamic task fanouts the reported node event, which contains the compiled + // workflow closure, can exceed the gRPC message size limit. In this case we immediately + // transition the node to failing to abort the workflow. + np = v1alpha1.NodePhaseFailing + p = handler.PhaseInfoFailure(core.ExecutionError_USER, "NodeFailed", err.Error(), p.GetInfo()) - errs := make([]error, 0, len(downstreamNodes)) - for _, d := range downstreamNodes { - downstreamNode, ok := nl.GetNode(d) - if !ok { - return errors.Errorf(errors.BadSpecificationError, currentNode.GetID(), "Unable to find Downstream Node [%v]", d) - } + err = nCtx.EventsRecorder().RecordNodeEvent(ctx, &event.NodeExecutionEvent{ + Id: nCtx.NodeExecutionMetadata().GetNodeExecutionID(), + Phase: core.NodeExecution_FAILED, + OccurredAt: ptypes.TimestampNow(), + OutputResult: &event.NodeExecutionEvent_Error{ + Error: &core.ExecutionError{ + Code: "NodeFailed", + Message: err.Error(), + }, + }, + ReportedAt: ptypes.TimestampNow(), + }, c.eventConfig) - if err := c.FinalizeHandler(ctx, execContext, dag, nl, downstreamNode); err != nil { - logger.Infof(ctx, "Failed to abort node [%v]. Error: %v", d, err) - errs = append(errs, err) + if err != nil { + return interfaces.NodeStatusUndefined, errors.Wrapf(errors.EventRecordingFailed, nCtx.NodeID(), err, "failed to record node event") + } + } else { + logger.Warningf(ctx, "Failed to record nodeEvent, error [%s]", err.Error()) + return interfaces.NodeStatusUndefined, errors.Wrapf(errors.EventRecordingFailed, nCtx.NodeID(), err, "failed to record node event") } } - if len(errs) > 0 { - return errors.ErrorCollection{Errors: errs} + // We reach here only when transitioning from Queued to Running. In this case, the startedAt is not set. + if np == v1alpha1.NodePhaseRunning { + if nodeStatus.GetQueuedAt() != nil { + c.metrics.QueuingLatency.Observe(ctx, nodeStatus.GetQueuedAt().Time, time.Now()) + } } - - return nil } - return nil + UpdateNodeStatus(np, p, nCtx.NodeStateReader(), nodeStatus) + return finalStatus, nil } -func (c *nodeExecutor) AbortHandler(ctx context.Context, execContext executors.ExecutionContext, dag executors.DAGStructure, nl executors.NodeLookup, currentNode v1alpha1.ExecutableNode, reason string) error { - nodeStatus := nl.GetNodeExecutionStatus(ctx, currentNode.GetID()) - nodePhase := nodeStatus.GetPhase() - - if nodePhase == v1alpha1.NodePhaseNotYetStarted { - logger.Infof(ctx, "Node not yet started, will not finalize") - // Nothing to be aborted - return nil +func (c *nodeExecutor) handleRetryableFailure(ctx context.Context, nCtx interfaces.NodeExecutionContext, h interfaces.NodeHandler) (interfaces.NodeStatus, error) { + nodeStatus := nCtx.NodeStatus() + logger.Debugf(ctx, "node failed with retryable failure, aborting and finalizing, message: %s", nodeStatus.GetMessage()) + if err := c.Abort(ctx, h, nCtx, nodeStatus.GetMessage(), false); err != nil { + return interfaces.NodeStatusUndefined, err } - if canHandleNode(nodePhase) { - ctx = contextutils.WithNodeID(ctx, currentNode.GetID()) + // NOTE: It is important to increment attempts only after abort has been called. Increment attempt mutates the state + // Attempt is used throughout the system to determine the idempotent resource version. 
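Aside: the ordering constraint in the NOTE above (abort first, only then mutate the attempt counter that feeds idempotency tokens) is worth pinning down. A sketch of the sequence `handleRetryableFailure` enforces, with a stand-in status interface covering a subset of the statuses the real code clears:

```go
package retry

import "context"

// NodeStatus is a stand-in for the mutable v1alpha1.ExecutableNodeStatus;
// the real code also clears sub-node, gate-node, and array-node status.
type NodeStatus interface {
	IncrementAttempts()
	ClearTaskStatus()
	ClearWorkflowStatus()
	ClearDynamicNodeStatus()
}

type aborter func(ctx context.Context, reason string) error

// resetForRetry mirrors handleRetryableFailure: the abort (which may emit
// events keyed by the current attempt) must complete before the attempt
// counter changes, because the attempt number doubles as the idempotent
// resource version throughout the system.
func resetForRetry(ctx context.Context, abort aborter, status NodeStatus, reason string) error {
	if err := abort(ctx, reason); err != nil {
		return err
	}
	status.IncrementAttempts() // only after abort: this mutates idempotency state
	status.ClearTaskStatus()
	status.ClearWorkflowStatus()
	status.ClearDynamicNodeStatus()
	return nil
}
```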
+ nodeStatus.IncrementAttempts() + nodeStatus.UpdatePhase(v1alpha1.NodePhaseRunning, metav1.Now(), "retrying", nil) + // We are going to retry in the next round, so we should clear all current state + nodeStatus.ClearSubNodeStatus() + nodeStatus.ClearTaskStatus() + nodeStatus.ClearWorkflowStatus() + nodeStatus.ClearDynamicNodeStatus() + nodeStatus.ClearGateNodeStatus() + nodeStatus.ClearArrayNodeStatus() + return interfaces.NodeStatusPending, nil +} - // Now depending on the node type decide - h, err := c.nodeHandlerFactory.GetHandler(currentNode.GetKind()) - if err != nil { - return err - } +func (c *nodeExecutor) HandleNode(ctx context.Context, dag executors.DAGStructure, nCtx interfaces.NodeExecutionContext, h interfaces.NodeHandler) (interfaces.NodeStatus, error) { + logger.Debugf(ctx, "Handling Node [%s]", nCtx.NodeID()) + defer logger.Debugf(ctx, "Completed node [%s]", nCtx.NodeID()) + + nodeStatus := nCtx.NodeStatus() + currentPhase := nodeStatus.GetPhase() - nCtx, err := c.newNodeExecContextDefault(ctx, currentNode.GetID(), execContext, nl) + // Optimization! + // If it is start node we directly move it to Queued without needing to run preExecute + if currentPhase == v1alpha1.NodePhaseNotYetStarted && !nCtx.Node().IsStartNode() { + p, err := c.handleNotYetStartedNode(ctx, dag, nCtx, h) if err != nil { - return err + return p, err } - // Abort this node - err = c.abort(ctx, h, nCtx, reason) - if err != nil { - return err + if p.NodePhase == interfaces.NodePhaseQueued { + logger.Infof(ctx, "Node was queued, parallelism is now [%d]", nCtx.ExecutionContext().IncrementParallelism()) } - nodeExecutionID := &core.NodeExecutionIdentifier{ - ExecutionId: nCtx.NodeExecutionMetadata().GetNodeExecutionID().ExecutionId, - NodeId: nCtx.NodeExecutionMetadata().GetNodeExecutionID().NodeId, + return p, err + } + + if currentPhase == v1alpha1.NodePhaseFailing { + logger.Debugf(ctx, "node failing") + if err := c.Abort(ctx, h, nCtx, "node failing", false); err != nil { + return interfaces.NodeStatusUndefined, err } - if nCtx.ExecutionContext().GetEventVersion() != v1alpha1.EventVersion0 { - currentNodeUniqueID, err := common.GenerateUniqueID(nCtx.ExecutionContext().GetParentInfo(), nodeExecutionID.NodeId) - if err != nil { - return err - } - nodeExecutionID.NodeId = currentNodeUniqueID + nodeStatus.UpdatePhase(v1alpha1.NodePhaseFailed, metav1.Now(), nodeStatus.GetMessage(), nodeStatus.GetExecutionError()) + c.metrics.FailureDuration.Observe(ctx, nodeStatus.GetStartedAt().Time, nodeStatus.GetStoppedAt().Time) + if nCtx.NodeExecutionMetadata().IsInterruptible() { + c.metrics.InterruptibleNodesTerminated.Inc(ctx) } + return interfaces.NodeStatusFailed(nodeStatus.GetExecutionError()), nil + } - err = c.IdempotentRecordEvent(ctx, &event.NodeExecutionEvent{ - Id: nodeExecutionID, - Phase: core.NodeExecution_ABORTED, - OccurredAt: ptypes.TimestampNow(), - OutputResult: &event.NodeExecutionEvent_Error{ - Error: &core.ExecutionError{ - Code: "NodeAborted", - Message: reason, - }, - }, - ProducerId: c.clusterID, - ReportedAt: ptypes.TimestampNow(), - }) - if err != nil && !eventsErr.IsNotFound(err) && !eventsErr.IsEventIncompatibleClusterError(err) { - if errors2.IsCausedBy(err, errors.IllegalStateError) { - logger.Debugf(ctx, "Failed to record abort event due to illegal state transition. Ignoring the error. 
Error: %v", err) - } else { - logger.Warningf(ctx, "Failed to record nodeEvent, error [%s]", err.Error()) - return errors.Wrapf(errors.EventRecordingFailed, nCtx.NodeID(), err, "failed to record node event") - } - } - } else if nodePhase == v1alpha1.NodePhaseSucceeded || nodePhase == v1alpha1.NodePhaseSkipped || nodePhase == v1alpha1.NodePhaseRecovered { - // Abort downstream nodes - downstreamNodes, err := dag.FromNode(currentNode.GetID()) - if err != nil { - logger.Debugf(ctx, "Error when retrieving downstream nodes. Error [%v]", err) - return nil + if currentPhase == v1alpha1.NodePhaseTimingOut { + logger.Debugf(ctx, "node timing out") + if err := c.Abort(ctx, h, nCtx, "node timed out", false); err != nil { + return interfaces.NodeStatusUndefined, err } - errs := make([]error, 0, len(downstreamNodes)) - for _, d := range downstreamNodes { - downstreamNode, ok := nl.GetNode(d) - if !ok { - return errors.Errorf(errors.BadSpecificationError, currentNode.GetID(), "Unable to find Downstream Node [%v]", d) - } + nodeStatus.ClearSubNodeStatus() + nodeStatus.UpdatePhase(v1alpha1.NodePhaseTimedOut, metav1.Now(), nodeStatus.GetMessage(), nodeStatus.GetExecutionError()) + c.metrics.TimedOutFailure.Inc(ctx) + if nCtx.NodeExecutionMetadata().IsInterruptible() { + c.metrics.InterruptibleNodesTerminated.Inc(ctx) + } + return interfaces.NodeStatusTimedOut, nil + } - if err := c.AbortHandler(ctx, execContext, dag, nl, downstreamNode, reason); err != nil { - logger.Infof(ctx, "Failed to abort node [%v]. Error: %v", d, err) - errs = append(errs, err) - } + if currentPhase == v1alpha1.NodePhaseSucceeding { + logger.Debugf(ctx, "node succeeding") + if err := c.Finalize(ctx, h, nCtx); err != nil { + return interfaces.NodeStatusUndefined, err } + t := metav1.Now() - if len(errs) > 0 { - return errors.ErrorCollection{Errors: errs} + started := nodeStatus.GetStartedAt() + if started == nil { + started = &t + } + stopped := nodeStatus.GetStoppedAt() + if stopped == nil { + stopped = &t + } + c.metrics.SuccessDuration.Observe(ctx, started.Time, stopped.Time) + nodeStatus.ClearSubNodeStatus() + nodeStatus.UpdatePhase(v1alpha1.NodePhaseSucceeded, t, "completed successfully", nil) + if nCtx.NodeExecutionMetadata().IsInterruptible() { + c.metrics.InterruptibleNodesTerminated.Inc(ctx) } + return interfaces.NodeStatusSuccess, nil + } - return nil - } else { - ctx = contextutils.WithNodeID(ctx, currentNode.GetID()) - logger.Warnf(ctx, "Trying to abort a node in state [%s]", nodeStatus.GetPhase().String()) + if currentPhase == v1alpha1.NodePhaseRetryableFailure { + return c.handleRetryableFailure(ctx, nCtx, h) } - return nil -} + if currentPhase == v1alpha1.NodePhaseFailed { + // This should never happen + return interfaces.NodeStatusFailed(nodeStatus.GetExecutionError()), nil + } -func (c *nodeExecutor) Initialize(ctx context.Context) error { - logger.Infof(ctx, "Initializing Core Node Executor") - s := c.newSetupContext(ctx) - return c.nodeHandlerFactory.Setup(ctx, s) + return c.handleQueuedOrRunningNode(ctx, nCtx, h) } func NewExecutor(ctx context.Context, nodeConfig config.NodeConfig, store *storage.DataStore, enQWorkflow v1alpha1.EnqueueWorkflow, eventSink events.EventSink, - workflowLauncher launchplan.Executor, launchPlanReader launchplan.Reader, maxDatasetSize int64, - defaultRawOutputPrefix storage.DataReference, kubeClient executors.Client, - catalogClient catalog.Client, recoveryClient recovery.Client, eventConfig *config.EventConfig, clusterID string, signalClient service.SignalServiceClient, scope 
promutils.Scope) (executors.Node, error) { + workflowLauncher launchplan.Executor, launchPlanReader launchplan.Reader, maxDatasetSize int64, defaultRawOutputPrefix storage.DataReference, kubeClient executors.Client, + catalogClient catalog.Client, recoveryClient recovery.Client, eventConfig *config.EventConfig, clusterID string, signalClient service.SignalServiceClient, + nodeHandlerFactory interfaces.HandlerFactory, scope promutils.Scope) (interfaces.Node, error) { // TODO we may want to make this configurable. shardSelector, err := ioutils.NewBase36PrefixShardSelector(ctx) @@ -1217,46 +1252,57 @@ func NewExecutor(ctx context.Context, nodeConfig config.NodeConfig, store *stora } nodeScope := scope.NewSubScope("node") - exec := &nodeExecutor{ - store: store, - enqueueWorkflow: enQWorkflow, - nodeRecorder: events.NewNodeEventRecorder(eventSink, nodeScope, store), - taskRecorder: events.NewTaskEventRecorder(eventSink, scope.NewSubScope("task"), store), - maxDatasetSizeBytes: maxDatasetSize, - metrics: &nodeMetrics{ - Scope: nodeScope, - FailureDuration: labeled.NewStopWatch("failure_duration", "Indicates the total execution time of a failed workflow.", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), - SuccessDuration: labeled.NewStopWatch("success_duration", "Indicates the total execution time of a successful workflow.", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), - RecoveryDuration: labeled.NewStopWatch("recovery_duration", "Indicates the total execution time of a recovered workflow.", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), - UserErrorDuration: labeled.NewStopWatch("user_error_duration", "Indicates the total execution time before user error", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), - SystemErrorDuration: labeled.NewStopWatch("system_error_duration", "Indicates the total execution time before system error", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), - UnknownErrorDuration: labeled.NewStopWatch("unknown_error_duration", "Indicates the total execution time before unknown error", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), - PermanentUserErrorDuration: labeled.NewStopWatch("perma_user_error_duration", "Indicates the total execution time before non recoverable user error", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), - PermanentSystemErrorDuration: labeled.NewStopWatch("perma_system_error_duration", "Indicates the total execution time before non recoverable system error", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), - PermanentUnknownErrorDuration: labeled.NewStopWatch("perma_unknown_error_duration", "Indicates the total execution time before non recoverable unknown error", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), - InputsWriteFailure: labeled.NewCounter("inputs_write_fail", "Indicates failure in writing node inputs to metastore", nodeScope), - TimedOutFailure: labeled.NewCounter("timeout_fail", "Indicates failure due to timeout", nodeScope), - InterruptedThresholdHit: labeled.NewCounter("interrupted_threshold", "Indicates the node interruptible disabled because it hit max failure count", nodeScope), - InterruptibleNodesRunning: labeled.NewCounter("interruptible_nodes_running", "number of interruptible nodes running", nodeScope), - InterruptibleNodesTerminated: labeled.NewCounter("interruptible_nodes_terminated", "number of interruptible nodes finished running", nodeScope), - ResolutionFailure: labeled.NewCounter("input_resolve_fail", "Indicates failure in 
resolving node inputs", nodeScope), - TransitionLatency: labeled.NewStopWatch("transition_latency", "Measures the latency between the last parent node stoppedAt time and current node's queued time.", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), - QueuingLatency: labeled.NewStopWatch("queueing_latency", "Measures the latency between the time a node's been queued to the time the handler reported the executable moved to running state", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), - NodeExecutionTime: labeled.NewStopWatch("node_exec_latency", "Measures the time taken to execute one node, a node can be complex so it may encompass sub-node latency.", time.Microsecond, nodeScope, labeled.EmitUnlabeledMetric), - NodeInputGatherLatency: labeled.NewStopWatch("node_input_latency", "Measures the latency to aggregate inputs and check readiness of a node", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), - }, - outputResolver: NewRemoteFileOutputResolver(store), - defaultExecutionDeadline: nodeConfig.DefaultDeadlines.DefaultNodeExecutionDeadline.Duration, + metrics := &nodeMetrics{ + Scope: nodeScope, + FailureDuration: labeled.NewStopWatch("failure_duration", "Indicates the total execution time of a failed workflow.", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), + SuccessDuration: labeled.NewStopWatch("success_duration", "Indicates the total execution time of a successful workflow.", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), + RecoveryDuration: labeled.NewStopWatch("recovery_duration", "Indicates the total execution time of a recovered workflow.", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), + UserErrorDuration: labeled.NewStopWatch("user_error_duration", "Indicates the total execution time before user error", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), + SystemErrorDuration: labeled.NewStopWatch("system_error_duration", "Indicates the total execution time before system error", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), + UnknownErrorDuration: labeled.NewStopWatch("unknown_error_duration", "Indicates the total execution time before unknown error", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), + PermanentUserErrorDuration: labeled.NewStopWatch("perma_user_error_duration", "Indicates the total execution time before non recoverable user error", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), + PermanentSystemErrorDuration: labeled.NewStopWatch("perma_system_error_duration", "Indicates the total execution time before non recoverable system error", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), + PermanentUnknownErrorDuration: labeled.NewStopWatch("perma_unknown_error_duration", "Indicates the total execution time before non recoverable unknown error", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), + InputsWriteFailure: labeled.NewCounter("inputs_write_fail", "Indicates failure in writing node inputs to metastore", nodeScope), + TimedOutFailure: labeled.NewCounter("timeout_fail", "Indicates failure due to timeout", nodeScope), + InterruptedThresholdHit: labeled.NewCounter("interrupted_threshold", "Indicates the node interruptible disabled because it hit max failure count", nodeScope), + InterruptibleNodesRunning: labeled.NewCounter("interruptible_nodes_running", "number of interruptible nodes running", nodeScope), + InterruptibleNodesTerminated: labeled.NewCounter("interruptible_nodes_terminated", "number of interruptible nodes finished running", 
nodeScope), + ResolutionFailure: labeled.NewCounter("input_resolve_fail", "Indicates failure in resolving node inputs", nodeScope), + TransitionLatency: labeled.NewStopWatch("transition_latency", "Measures the latency between the last parent node stoppedAt time and current node's queued time.", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), + QueuingLatency: labeled.NewStopWatch("queueing_latency", "Measures the latency between the time a node's been queued to the time the handler reported the executable moved to running state", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), + NodeExecutionTime: labeled.NewStopWatch("node_exec_latency", "Measures the time taken to execute one node, a node can be complex so it may encompass sub-node latency.", time.Microsecond, nodeScope, labeled.EmitUnlabeledMetric), + NodeInputGatherLatency: labeled.NewStopWatch("node_input_latency", "Measures the latency to aggregate inputs and check readiness of a node", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), + } + + nodeExecutor := &nodeExecutor{ + clusterID: clusterID, defaultActiveDeadline: nodeConfig.DefaultDeadlines.DefaultNodeActiveDeadline.Duration, - maxNodeRetriesForSystemFailures: uint32(nodeConfig.MaxNodeRetriesOnSystemFailures), - interruptibleFailureThreshold: uint32(nodeConfig.InterruptibleFailureThreshold), defaultDataSandbox: defaultRawOutputPrefix, - shardSelector: shardSelector, - recoveryClient: recoveryClient, + defaultExecutionDeadline: nodeConfig.DefaultDeadlines.DefaultNodeExecutionDeadline.Duration, + enqueueWorkflow: enQWorkflow, eventConfig: eventConfig, - clusterID: clusterID, + interruptibleFailureThreshold: uint32(nodeConfig.InterruptibleFailureThreshold), + maxDatasetSizeBytes: maxDatasetSize, + maxNodeRetriesForSystemFailures: uint32(nodeConfig.MaxNodeRetriesOnSystemFailures), + metrics: metrics, + nodeRecorder: events.NewNodeEventRecorder(eventSink, nodeScope, store), + outputResolver: NewRemoteFileOutputResolver(store), + recoveryClient: recoveryClient, + shardSelector: shardSelector, + store: store, + taskRecorder: events.NewTaskEventRecorder(eventSink, scope.NewSubScope("task"), store), + } + + exec := &recursiveNodeExecutor{ + nodeExecutor: nodeExecutor, + nCtxBuilder: nodeExecutor, + nodeHandlerFactory: nodeHandlerFactory, + enqueueWorkflow: enQWorkflow, + store: store, + metrics: metrics, } - nodeHandlerFactory, err := NewHandlerFactory(ctx, exec, workflowLauncher, launchPlanReader, kubeClient, catalogClient, recoveryClient, eventConfig, clusterID, signalClient, nodeScope) - exec.nodeHandlerFactory = nodeHandlerFactory return exec, err } diff --git a/flytepropeller/pkg/controller/nodes/executor_test.go b/flytepropeller/pkg/controller/nodes/executor_test.go index ac5a49f401..71b7c3df70 100644 --- a/flytepropeller/pkg/controller/nodes/executor_test.go +++ b/flytepropeller/pkg/controller/nodes/executor_test.go @@ -13,8 +13,6 @@ import ( "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" - "github.com/flyteorg/flytestdlib/contextutils" - mocks3 "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io/mocks" "github.com/flyteorg/flytepropeller/events" @@ -27,13 +25,14 @@ import ( mocks4 
"github.com/flyteorg/flytepropeller/pkg/controller/executors/mocks" gatemocks "github.com/flyteorg/flytepropeller/pkg/controller/nodes/gate/mocks" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" - nodeHandlerMocks "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler/mocks" - mocks2 "github.com/flyteorg/flytepropeller/pkg/controller/nodes/mocks" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" + nodemocks "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces/mocks" recoveryMocks "github.com/flyteorg/flytepropeller/pkg/controller/nodes/recovery/mocks" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/subworkflow/launchplan" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/catalog" flyteassert "github.com/flyteorg/flytepropeller/pkg/utils/assert" + "github.com/flyteorg/flytestdlib/contextutils" "github.com/flyteorg/flytestdlib/promutils" "github.com/flyteorg/flytestdlib/promutils/labeled" "github.com/flyteorg/flytestdlib/storage" @@ -70,8 +69,10 @@ func TestSetInputsForStartNode(t *testing.T) { enQWf := func(workflowID v1alpha1.WorkflowID) {} adminClient := launchplan.NewFailFastLaunchPlanExecutor() + hf := &nodemocks.HandlerFactory{} + hf.On("Setup", mock.Anything, mock.Anything, mock.Anything).Return(nil) exec, err := NewExecutor(ctx, config.GetConfig().NodeConfig, mockStorage, enQWf, eventMocks.NewMockEventSink(), adminClient, - adminClient, 10, "s3://bucket/", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, promutils.NewTestScope()) + adminClient, 10, "s3://bucket/", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) assert.NoError(t, err) inputs := &core.LiteralMap{ Literals: map[string]*core.Literal{ @@ -87,7 +88,7 @@ func TestSetInputsForStartNode(t *testing.T) { } s, err := exec.SetInputsForStartNode(ctx, w, w, w, nil) assert.NoError(t, err) - assert.Equal(t, executors.NodeStatusComplete, s) + assert.Equal(t, interfaces.NodeStatusComplete, s) }) t.Run("WithInputs", func(t *testing.T) { @@ -99,7 +100,7 @@ func TestSetInputsForStartNode(t *testing.T) { } s, err := exec.SetInputsForStartNode(ctx, w, w, w, inputs) assert.NoError(t, err) - assert.Equal(t, executors.NodeStatusComplete, s) + assert.Equal(t, interfaces.NodeStatusComplete, s) actual := &core.LiteralMap{} if assert.NoError(t, mockStorage.ReadProtobuf(ctx, "s3://test-bucket/exec/start-node/data/0/outputs.pb", actual)) { flyteassert.EqualLiteralMap(t, inputs, actual) @@ -113,12 +114,12 @@ func TestSetInputsForStartNode(t *testing.T) { } s, err := exec.SetInputsForStartNode(ctx, w, w, w, inputs) assert.Error(t, err) - assert.Equal(t, executors.NodeStatusUndefined, s) + assert.Equal(t, interfaces.NodeStatusUndefined, s) }) failStorage := createFailingDatastore(t, testScope.NewSubScope("failing")) execFail, err := NewExecutor(ctx, config.GetConfig().NodeConfig, failStorage, enQWf, eventMocks.NewMockEventSink(), adminClient, - adminClient, 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, promutils.NewTestScope()) + adminClient, 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) assert.NoError(t, err) t.Run("StorageFailure", func(t *testing.T) { w := createDummyBaseWorkflow(mockStorage) @@ -128,7 +129,7 @@ func TestSetInputsForStartNode(t *testing.T) { } s, err := execFail.SetInputsForStartNode(ctx, 
w, w, w, inputs) assert.Error(t, err) - assert.Equal(t, executors.NodeStatusUndefined, s) + assert.Equal(t, interfaces.NodeStatusUndefined, s) }) } @@ -143,29 +144,25 @@ func TestNodeExecutor_Initialize(t *testing.T) { adminClient := launchplan.NewFailFastLaunchPlanExecutor() t.Run("happy", func(t *testing.T) { + hf := &nodemocks.HandlerFactory{} + hf.On("Setup", mock.Anything, mock.Anything, mock.Anything).Return(nil) + execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, memStore, enQWf, mockEventSink, adminClient, adminClient, - 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, promutils.NewTestScope()) + 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) assert.NoError(t, err) - exec := execIface.(*nodeExecutor) - - hf := &mocks2.HandlerFactory{} - exec.nodeHandlerFactory = hf - - hf.On("Setup", mock.Anything, mock.Anything).Return(nil) + exec := execIface.(*recursiveNodeExecutor) assert.NoError(t, exec.Initialize(ctx)) }) t.Run("error", func(t *testing.T) { + hf := &nodemocks.HandlerFactory{} + hf.On("Setup", mock.Anything, mock.Anything, mock.Anything).Return(fmt.Errorf("error")) + execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, memStore, enQWf, mockEventSink, adminClient, adminClient, - 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, promutils.NewTestScope()) + 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) assert.NoError(t, err) - exec := execIface.(*nodeExecutor) - - hf := &mocks2.HandlerFactory{} - exec.nodeHandlerFactory = hf - - hf.On("Setup", mock.Anything, mock.Anything).Return(fmt.Errorf("error")) + exec := execIface.(*recursiveNodeExecutor) assert.Error(t, exec.Initialize(ctx)) }) @@ -180,10 +177,12 @@ func TestNodeExecutor_RecursiveNodeHandler_RecurseStartNodes(t *testing.T) { store := createInmemoryDataStore(t, promutils.NewTestScope()) adminClient := launchplan.NewFailFastLaunchPlanExecutor() + hf := &nodemocks.HandlerFactory{} + hf.On("Setup", mock.Anything, mock.Anything, mock.Anything).Return(nil) execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, adminClient, - 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, promutils.NewTestScope()) + 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) assert.NoError(t, err) - exec := execIface.(*nodeExecutor) + exec := execIface.(*recursiveNodeExecutor) defaultNodeID := "n1" @@ -230,30 +229,30 @@ func TestNodeExecutor_RecursiveNodeHandler_RecurseStartNodes(t *testing.T) { name string currentNodePhase v1alpha1.NodePhase expectedNodePhase v1alpha1.NodePhase - expectedPhase executors.NodePhase + expectedPhase interfaces.NodePhase handlerReturn func() (handler.Transition, error) expectedError bool }{ // Starting at Queued - {"nys->success", v1alpha1.NodePhaseNotYetStarted, v1alpha1.NodePhaseSucceeded, executors.NodePhaseSuccess, func() (handler.Transition, error) { + {"nys->success", v1alpha1.NodePhaseNotYetStarted, v1alpha1.NodePhaseSucceeded, interfaces.NodePhaseSuccess, func() (handler.Transition, error) { return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoSuccess(nil)), nil }, false}, 
- {"queued->success", v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseSucceeded, executors.NodePhaseSuccess, func() (handler.Transition, error) { + {"queued->success", v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseSucceeded, interfaces.NodePhaseSuccess, func() (handler.Transition, error) { return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoSuccess(nil)), nil }, false}, - {"nys->error", v1alpha1.NodePhaseNotYetStarted, v1alpha1.NodePhaseNotYetStarted, executors.NodePhaseUndefined, func() (handler.Transition, error) { + {"nys->error", v1alpha1.NodePhaseNotYetStarted, v1alpha1.NodePhaseNotYetStarted, interfaces.NodePhaseUndefined, func() (handler.Transition, error) { return handler.UnknownTransition, fmt.Errorf("err") }, true}, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - hf := &mocks2.HandlerFactory{} + hf := &nodemocks.HandlerFactory{} exec.nodeHandlerFactory = hf - h := &nodeHandlerMocks.Node{} + h := &nodemocks.NodeHandler{} h.On("Handle", mock.MatchedBy(func(ctx context.Context) bool { return true }), - mock.MatchedBy(func(o handler.NodeExecutionContext) bool { return true }), + mock.MatchedBy(func(o interfaces.NodeExecutionContext) bool { return true }), ).Return(test.handlerReturn()) h.On("FinalizeRequired").Return(false) @@ -284,10 +283,12 @@ func TestNodeExecutor_RecursiveNodeHandler_RecurseEndNode(t *testing.T) { store := createInmemoryDataStore(t, promutils.NewTestScope()) adminClient := launchplan.NewFailFastLaunchPlanExecutor() + hf := &nodemocks.HandlerFactory{} + hf.On("Setup", mock.Anything, mock.Anything, mock.Anything).Return(nil) execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, adminClient, - 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, promutils.NewTestScope()) + 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) assert.NoError(t, err) - exec := execIface.(*nodeExecutor) + exec := execIface.(*recursiveNodeExecutor) // Node not yet started { @@ -335,22 +336,22 @@ func TestNodeExecutor_RecursiveNodeHandler_RecurseEndNode(t *testing.T) { name string parentNodePhase v1alpha1.NodePhase expectedNodePhase v1alpha1.NodePhase - expectedPhase executors.NodePhase + expectedPhase interfaces.NodePhase expectedError bool }{ - {"notYetStarted", v1alpha1.NodePhaseNotYetStarted, v1alpha1.NodePhaseNotYetStarted, executors.NodePhasePending, false}, - {"running", v1alpha1.NodePhaseRunning, v1alpha1.NodePhaseNotYetStarted, executors.NodePhasePending, false}, - {"queued", v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseNotYetStarted, executors.NodePhasePending, false}, - {"retryable", v1alpha1.NodePhaseRetryableFailure, v1alpha1.NodePhaseNotYetStarted, executors.NodePhasePending, false}, - {"failing", v1alpha1.NodePhaseFailing, v1alpha1.NodePhaseNotYetStarted, executors.NodePhasePending, false}, - {"skipped", v1alpha1.NodePhaseSkipped, v1alpha1.NodePhaseSkipped, executors.NodePhaseSuccess, false}, - {"success", v1alpha1.NodePhaseSucceeded, v1alpha1.NodePhaseQueued, executors.NodePhaseQueued, false}, + {"notYetStarted", v1alpha1.NodePhaseNotYetStarted, v1alpha1.NodePhaseNotYetStarted, interfaces.NodePhasePending, false}, + {"running", v1alpha1.NodePhaseRunning, v1alpha1.NodePhaseNotYetStarted, interfaces.NodePhasePending, false}, + {"queued", v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseNotYetStarted, interfaces.NodePhasePending, false}, + 
{"retryable", v1alpha1.NodePhaseRetryableFailure, v1alpha1.NodePhaseNotYetStarted, interfaces.NodePhasePending, false}, + {"failing", v1alpha1.NodePhaseFailing, v1alpha1.NodePhaseNotYetStarted, interfaces.NodePhasePending, false}, + {"skipped", v1alpha1.NodePhaseSkipped, v1alpha1.NodePhaseSkipped, interfaces.NodePhaseSuccess, false}, + {"success", v1alpha1.NodePhaseSucceeded, v1alpha1.NodePhaseQueued, interfaces.NodePhaseQueued, false}, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - hf := &mocks2.HandlerFactory{} + hf := &nodemocks.HandlerFactory{} exec.nodeHandlerFactory = hf - h := &nodeHandlerMocks.Node{} + h := &nodemocks.NodeHandler{} hf.OnGetHandler(v1alpha1.NodeKindEnd).Return(h, nil) mockWf, mockNode, mockNodeStatus := createSingleNodeWf(test.parentNodePhase, 0) @@ -426,36 +427,36 @@ func TestNodeExecutor_RecursiveNodeHandler_RecurseEndNode(t *testing.T) { name string currentNodePhase v1alpha1.NodePhase expectedNodePhase v1alpha1.NodePhase - expectedPhase executors.NodePhase + expectedPhase interfaces.NodePhase handlerReturn func() (handler.Transition, error) expectedError bool }{ // Starting at Queued - {"queued->success", v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseSucceeded, executors.NodePhaseSuccess, func() (handler.Transition, error) { + {"queued->success", v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseSucceeded, interfaces.NodePhaseSuccess, func() (handler.Transition, error) { return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoSuccess(nil)), nil }, false}, - {"queued->failed", v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseFailed, executors.NodePhaseFailed, func() (handler.Transition, error) { + {"queued->failed", v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseFailed, interfaces.NodePhaseFailed, func() (handler.Transition, error) { return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoFailure(core.ExecutionError_USER, "code", "mesage", nil)), nil }, false}, - {"queued->running", v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseRunning, executors.NodePhasePending, func() (handler.Transition, error) { + {"queued->running", v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseRunning, interfaces.NodePhasePending, func() (handler.Transition, error) { return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoRunning(nil)), nil }, false}, - {"queued->error", v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseQueued, executors.NodePhaseUndefined, func() (handler.Transition, error) { + {"queued->error", v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseQueued, interfaces.NodePhaseUndefined, func() (handler.Transition, error) { return handler.UnknownTransition, fmt.Errorf("err") }, true}, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - hf := &mocks2.HandlerFactory{} + hf := &nodemocks.HandlerFactory{} exec.nodeHandlerFactory = hf - h := &nodeHandlerMocks.Node{} + h := &nodemocks.NodeHandler{} h.On("Handle", mock.MatchedBy(func(ctx context.Context) bool { return true }), - mock.MatchedBy(func(o handler.NodeExecutionContext) bool { return true }), + mock.MatchedBy(func(o interfaces.NodeExecutionContext) bool { return true }), ).Return(test.handlerReturn()) h.OnFinalizeRequired().Return(false) @@ -471,7 +472,7 @@ func TestNodeExecutor_RecursiveNodeHandler_RecurseEndNode(t *testing.T) { } else { assert.NoError(t, err) } - if test.expectedPhase == executors.NodePhaseFailed { + if test.expectedPhase == interfaces.NodePhaseFailed { assert.NotNil(t, s.Err) } else { assert.Nil(t, s.Err) @@ -665,22 +666,23 @@ 
func TestNodeExecutor_RecursiveNodeHandler_Recurse(t *testing.T) { currentNodePhase v1alpha1.NodePhase parentNodePhase v1alpha1.NodePhase expectedNodePhase v1alpha1.NodePhase - expectedPhase executors.NodePhase + expectedPhase interfaces.NodePhase expectedError bool updateCalled bool }{ - {"notYetStarted->skipped", v1alpha1.NodePhaseNotYetStarted, v1alpha1.NodePhaseFailed, v1alpha1.NodePhaseSkipped, executors.NodePhaseFailed, false, false}, - {"notYetStarted->skipped", v1alpha1.NodePhaseNotYetStarted, v1alpha1.NodePhaseSkipped, v1alpha1.NodePhaseSkipped, executors.NodePhaseSuccess, false, true}, - {"notYetStarted->queued", v1alpha1.NodePhaseNotYetStarted, v1alpha1.NodePhaseSucceeded, v1alpha1.NodePhaseQueued, executors.NodePhasePending, false, true}, + {"notYetStarted->skipped", v1alpha1.NodePhaseNotYetStarted, v1alpha1.NodePhaseFailed, v1alpha1.NodePhaseSkipped, interfaces.NodePhaseFailed, false, false}, + {"notYetStarted->skipped", v1alpha1.NodePhaseNotYetStarted, v1alpha1.NodePhaseSkipped, v1alpha1.NodePhaseSkipped, interfaces.NodePhaseSuccess, false, true}, + {"notYetStarted->queued", v1alpha1.NodePhaseNotYetStarted, v1alpha1.NodePhaseSucceeded, v1alpha1.NodePhaseQueued, interfaces.NodePhasePending, false, true}, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - hf := &mocks2.HandlerFactory{} + hf := &nodemocks.HandlerFactory{} + hf.On("Setup", mock.Anything, mock.Anything, mock.Anything).Return(nil) - h := &nodeHandlerMocks.Node{} + h := &nodemocks.NodeHandler{} h.OnHandleMatch( mock.MatchedBy(func(ctx context.Context) bool { return true }), - mock.MatchedBy(func(o handler.NodeExecutionContext) bool { return true }), + mock.MatchedBy(func(o interfaces.NodeExecutionContext) bool { return true }), ).Return(handler.UnknownTransition, fmt.Errorf("should not be called")) h.OnFinalizeRequired().Return(false) hf.OnGetHandler(v1alpha1.NodeKindTask).Return(h, nil) @@ -691,10 +693,9 @@ func TestNodeExecutor_RecursiveNodeHandler_Recurse(t *testing.T) { adminClient := launchplan.NewFailFastLaunchPlanExecutor() execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, adminClient, - 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, promutils.NewTestScope()) + 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) assert.NoError(t, err) - exec := execIface.(*nodeExecutor) - exec.nodeHandlerFactory = hf + exec := execIface.(*recursiveNodeExecutor) execContext := executors.NewExecutionContext(mockWf, mockWf, mockWf, nil, executors.InitializeControlFlow()) s, err := exec.RecursiveNodeHandler(ctx, execContext, mockWf, mockWf, startNode) @@ -714,7 +715,7 @@ func TestNodeExecutor_RecursiveNodeHandler_Recurse(t *testing.T) { name string currentNodePhase v1alpha1.NodePhase expectedNodePhase v1alpha1.NodePhase - expectedPhase executors.NodePhase + expectedPhase interfaces.NodePhase handlerReturn func() (handler.Transition, error) finalizeReturnErr bool expectedError bool @@ -722,54 +723,54 @@ func TestNodeExecutor_RecursiveNodeHandler_Recurse(t *testing.T) { eventPhase core.NodeExecution_Phase }{ // Starting at Queued - {"queued->running", v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseRunning, executors.NodePhasePending, func() (handler.Transition, error) { + {"queued->running", v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseRunning, interfaces.NodePhasePending, func() (handler.Transition, error) { 
return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoRunning(nil)), nil }, true, false, true, core.NodeExecution_RUNNING}, - {"queued->queued", v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseQueued, executors.NodePhasePending, func() (handler.Transition, error) { + {"queued->queued", v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseQueued, interfaces.NodePhasePending, func() (handler.Transition, error) { return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoQueued("reason", &core.LiteralMap{})), nil }, true, false, false, core.NodeExecution_QUEUED}, - {"queued->failing", v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseFailing, executors.NodePhasePending, func() (handler.Transition, error) { + {"queued->failing", v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseFailing, interfaces.NodePhasePending, func() (handler.Transition, error) { return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoFailure(core.ExecutionError_USER, "code", "reason", nil)), nil }, true, false, true, core.NodeExecution_FAILED}, - {"failing->failed", v1alpha1.NodePhaseFailing, v1alpha1.NodePhaseFailed, executors.NodePhaseFailed, func() (handler.Transition, error) { + {"failing->failed", v1alpha1.NodePhaseFailing, v1alpha1.NodePhaseFailed, interfaces.NodePhaseFailed, func() (handler.Transition, error) { return handler.UnknownTransition, fmt.Errorf("error") }, false, false, false, core.NodeExecution_FAILED}, - {"failing->failed(error)", v1alpha1.NodePhaseFailing, v1alpha1.NodePhaseFailing, executors.NodePhaseUndefined, func() (handler.Transition, error) { + {"failing->failed(error)", v1alpha1.NodePhaseFailing, v1alpha1.NodePhaseFailing, interfaces.NodePhaseUndefined, func() (handler.Transition, error) { return handler.UnknownTransition, fmt.Errorf("error") }, true, true, false, core.NodeExecution_FAILING}, - {"queued->succeeding", v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseSucceeding, executors.NodePhasePending, func() (handler.Transition, error) { + {"queued->succeeding", v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseSucceeding, interfaces.NodePhasePending, func() (handler.Transition, error) { return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoSuccess(nil)), nil }, true, false, true, core.NodeExecution_SUCCEEDED}, - {"succeeding->success", v1alpha1.NodePhaseSucceeding, v1alpha1.NodePhaseSucceeded, executors.NodePhaseSuccess, func() (handler.Transition, error) { + {"succeeding->success", v1alpha1.NodePhaseSucceeding, v1alpha1.NodePhaseSucceeded, interfaces.NodePhaseSuccess, func() (handler.Transition, error) { return handler.UnknownTransition, fmt.Errorf("error") }, false, false, false, core.NodeExecution_SUCCEEDED}, - {"succeeding->success(error)", v1alpha1.NodePhaseSucceeding, v1alpha1.NodePhaseSucceeding, executors.NodePhaseUndefined, func() (handler.Transition, error) { + {"succeeding->success(error)", v1alpha1.NodePhaseSucceeding, v1alpha1.NodePhaseSucceeding, interfaces.NodePhaseUndefined, func() (handler.Transition, error) { return handler.UnknownTransition, fmt.Errorf("error") }, true, true, false, core.NodeExecution_SUCCEEDED}, - {"queued->error", v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseQueued, executors.NodePhaseUndefined, func() (handler.Transition, error) { + {"queued->error", v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseQueued, interfaces.NodePhaseUndefined, func() (handler.Transition, error) { return handler.UnknownTransition, fmt.Errorf("error") }, true, true, false, core.NodeExecution_RUNNING}, } for _, test := range tests 
{ t.Run(test.name, func(t *testing.T) { - hf := &mocks2.HandlerFactory{} + hf := &nodemocks.HandlerFactory{} + hf.On("Setup", mock.Anything, mock.Anything, mock.Anything).Return(nil) store := createInmemoryDataStore(t, promutils.NewTestScope()) adminClient := launchplan.NewFailFastLaunchPlanExecutor() execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, adminClient, - 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, promutils.NewTestScope()) + 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) assert.NoError(t, err) - exec := execIface.(*nodeExecutor) - exec.nodeHandlerFactory = hf + exec := execIface.(*recursiveNodeExecutor) called := false evRecorder := &eventMocks.NodeEventRecorder{} @@ -780,12 +781,14 @@ func TestNodeExecutor_RecursiveNodeHandler_Recurse(t *testing.T) { return true }), mock.Anything).Return(nil) - exec.nodeRecorder = evRecorder + nodeExec, ok := exec.nodeExecutor.(*nodeExecutor) + assert.True(t, ok) + nodeExec.nodeRecorder = evRecorder - h := &nodeHandlerMocks.Node{} + h := &nodemocks.NodeHandler{} h.On("Handle", mock.MatchedBy(func(ctx context.Context) bool { return true }), - mock.MatchedBy(func(o handler.NodeExecutionContext) bool { return true }), + mock.MatchedBy(func(o interfaces.NodeExecutionContext) bool { return true }), ).Return(test.handlerReturn()) h.On("FinalizeRequired").Return(true) @@ -831,57 +834,57 @@ func TestNodeExecutor_RecursiveNodeHandler_Recurse(t *testing.T) { name string currentNodePhase v1alpha1.NodePhase expectedNodePhase v1alpha1.NodePhase - expectedPhase executors.NodePhase + expectedPhase interfaces.NodePhase handlerReturn func() (handler.Transition, error) expectedError bool eventRecorded bool eventPhase core.NodeExecution_Phase attempts int }{ - {"running->running", v1alpha1.NodePhaseRunning, v1alpha1.NodePhaseRunning, executors.NodePhasePending, func() (handler.Transition, error) { + {"running->running", v1alpha1.NodePhaseRunning, v1alpha1.NodePhaseRunning, interfaces.NodePhasePending, func() (handler.Transition, error) { return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoRunning(nil)), nil }, false, false, core.NodeExecution_RUNNING, 0}, - {"running->retryableFailure", v1alpha1.NodePhaseRunning, v1alpha1.NodePhaseFailing, executors.NodePhasePending, + {"running->retryableFailure", v1alpha1.NodePhaseRunning, v1alpha1.NodePhaseFailing, interfaces.NodePhasePending, func() (handler.Transition, error) { return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoRetryableFailure(core.ExecutionError_USER, "x", "y", nil)), nil }, false, true, core.NodeExecution_FAILED, 0}, - {"retryablefailure->running", v1alpha1.NodePhaseRetryableFailure, v1alpha1.NodePhaseRunning, executors.NodePhasePending, func() (handler.Transition, error) { + {"retryablefailure->running", v1alpha1.NodePhaseRetryableFailure, v1alpha1.NodePhaseRunning, interfaces.NodePhasePending, func() (handler.Transition, error) { return handler.UnknownTransition, fmt.Errorf("should not be invoked") }, false, false, core.NodeExecution_RUNNING, 1}, - {"running->failing", v1alpha1.NodePhaseRunning, v1alpha1.NodePhaseFailing, executors.NodePhasePending, func() (handler.Transition, error) { + {"running->failing", v1alpha1.NodePhaseRunning, v1alpha1.NodePhaseFailing, interfaces.NodePhasePending, func() (handler.Transition, error) { return 
handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoFailure(core.ExecutionError_USER, "code", "reason", nil)), nil }, false, true, core.NodeExecution_FAILED, 0}, - {"running->succeeding", v1alpha1.NodePhaseRunning, v1alpha1.NodePhaseSucceeding, executors.NodePhasePending, func() (handler.Transition, error) { + {"running->succeeding", v1alpha1.NodePhaseRunning, v1alpha1.NodePhaseSucceeding, interfaces.NodePhasePending, func() (handler.Transition, error) { return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoSuccess(nil)), nil }, false, true, core.NodeExecution_SUCCEEDED, 0}, - {"running->error", v1alpha1.NodePhaseRunning, v1alpha1.NodePhaseRunning, executors.NodePhaseUndefined, func() (handler.Transition, error) { + {"running->error", v1alpha1.NodePhaseRunning, v1alpha1.NodePhaseRunning, interfaces.NodePhaseUndefined, func() (handler.Transition, error) { return handler.UnknownTransition, fmt.Errorf("error") }, true, false, core.NodeExecution_RUNNING, 0}, - {"previously-failed", v1alpha1.NodePhaseFailed, v1alpha1.NodePhaseFailed, executors.NodePhaseFailed, func() (handler.Transition, error) { + {"previously-failed", v1alpha1.NodePhaseFailed, v1alpha1.NodePhaseFailed, interfaces.NodePhaseFailed, func() (handler.Transition, error) { return handler.UnknownTransition, fmt.Errorf("error") }, false, false, core.NodeExecution_RUNNING, 0}, - {"previously-success", v1alpha1.NodePhaseSucceeded, v1alpha1.NodePhaseSucceeded, executors.NodePhaseComplete, func() (handler.Transition, error) { + {"previously-success", v1alpha1.NodePhaseSucceeded, v1alpha1.NodePhaseSucceeded, interfaces.NodePhaseComplete, func() (handler.Transition, error) { return handler.UnknownTransition, fmt.Errorf("error") }, false, false, core.NodeExecution_RUNNING, 0}, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - hf := &mocks2.HandlerFactory{} + hf := &nodemocks.HandlerFactory{} + hf.On("Setup", mock.Anything, mock.Anything, mock.Anything).Return(nil) store := createInmemoryDataStore(t, promutils.NewTestScope()) adminClient := launchplan.NewFailFastLaunchPlanExecutor() execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, adminClient, - 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, promutils.NewTestScope()) + 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) assert.NoError(t, err) - exec := execIface.(*nodeExecutor) - exec.nodeHandlerFactory = hf + exec := execIface.(*recursiveNodeExecutor) called := false evRecorder := &eventMocks.NodeEventRecorder{} @@ -891,12 +894,15 @@ func TestNodeExecutor_RecursiveNodeHandler_Recurse(t *testing.T) { called = true return true }), mock.Anything).Return(nil) - exec.nodeRecorder = evRecorder - h := &nodeHandlerMocks.Node{} + nodeExec, ok := exec.nodeExecutor.(*nodeExecutor) + assert.True(t, ok) + nodeExec.nodeRecorder = evRecorder + + h := &nodemocks.NodeHandler{} h.On("Handle", mock.MatchedBy(func(ctx context.Context) bool { return true }), - mock.MatchedBy(func(o handler.NodeExecutionContext) bool { return true }), + mock.MatchedBy(func(o interfaces.NodeExecutionContext) bool { return true }), ).Return(test.handlerReturn()) h.On("FinalizeRequired").Return(true) if test.currentNodePhase == v1alpha1.NodePhaseRetryableFailure { @@ -938,19 +944,19 @@ func TestNodeExecutor_RecursiveNodeHandler_Recurse(t *testing.T) { // Extinguished 
retries t.Run("retries-exhausted", func(t *testing.T) { - hf := &mocks2.HandlerFactory{} + hf := &nodemocks.HandlerFactory{} + hf.On("Setup", mock.Anything, mock.Anything, mock.Anything).Return(nil) store := createInmemoryDataStore(t, promutils.NewTestScope()) adminClient := launchplan.NewFailFastLaunchPlanExecutor() execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, adminClient, - 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, promutils.NewTestScope()) + 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) assert.NoError(t, err) - exec := execIface.(*nodeExecutor) - exec.nodeHandlerFactory = hf + exec := execIface.(*recursiveNodeExecutor) - h := &nodeHandlerMocks.Node{} + h := &nodemocks.NodeHandler{} h.On("Handle", mock.MatchedBy(func(ctx context.Context) bool { return true }), - mock.MatchedBy(func(o handler.NodeExecutionContext) bool { return true }), + mock.MatchedBy(func(o interfaces.NodeExecutionContext) bool { return true }), ).Return(handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoRetryableFailure(core.ExecutionError_USER, "x", "y", nil)), nil) h.On("FinalizeRequired").Return(true) h.On("Finalize", mock.Anything, mock.Anything).Return(fmt.Errorf("error")) @@ -962,26 +968,26 @@ func TestNodeExecutor_RecursiveNodeHandler_Recurse(t *testing.T) { s, err := exec.RecursiveNodeHandler(ctx, execContext, mockWf, mockWf, startNode) assert.NoError(t, err) - assert.Equal(t, executors.NodePhasePending.String(), s.NodePhase.String()) + assert.Equal(t, interfaces.NodePhasePending.String(), s.NodePhase.String()) assert.Equal(t, uint32(0), mockNodeStatus.GetAttempts()) assert.Equal(t, v1alpha1.NodePhaseFailing.String(), mockNodeStatus.GetPhase().String()) }) // Remaining retries t.Run("retries-remaining", func(t *testing.T) { - hf := &mocks2.HandlerFactory{} + hf := &nodemocks.HandlerFactory{} + hf.On("Setup", mock.Anything, mock.Anything, mock.Anything).Return(nil) store := createInmemoryDataStore(t, promutils.NewTestScope()) adminClient := launchplan.NewFailFastLaunchPlanExecutor() execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, adminClient, - 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, promutils.NewTestScope()) + 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) assert.NoError(t, err) - exec := execIface.(*nodeExecutor) - exec.nodeHandlerFactory = hf + exec := execIface.(*recursiveNodeExecutor) - h := &nodeHandlerMocks.Node{} + h := &nodemocks.NodeHandler{} h.On("Handle", mock.MatchedBy(func(ctx context.Context) bool { return true }), - mock.MatchedBy(func(o handler.NodeExecutionContext) bool { return true }), + mock.MatchedBy(func(o interfaces.NodeExecutionContext) bool { return true }), ).Return(handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoRetryableFailure(core.ExecutionError_USER, "x", "y", nil)), nil) h.On("FinalizeRequired").Return(true) h.On("Finalize", mock.Anything, mock.Anything).Return(fmt.Errorf("error")) @@ -992,7 +998,7 @@ func TestNodeExecutor_RecursiveNodeHandler_Recurse(t *testing.T) { execContext := executors.NewExecutionContext(mockWf, mockWf, nil, nil, executors.InitializeControlFlow()) s, err := 
exec.RecursiveNodeHandler(ctx, execContext, mockWf, mockWf, startNode) assert.NoError(t, err) - assert.Equal(t, executors.NodePhasePending.String(), s.NodePhase.String()) + assert.Equal(t, interfaces.NodePhasePending.String(), s.NodePhase.String()) assert.Equal(t, uint32(0), mockNodeStatus.GetAttempts()) assert.Equal(t, v1alpha1.NodePhaseFailing.String(), mockNodeStatus.GetPhase().String()) }) @@ -1006,10 +1012,12 @@ func TestNodeExecutor_RecursiveNodeHandler_NoDownstream(t *testing.T) { store := createInmemoryDataStore(t, promutils.NewTestScope()) adminClient := launchplan.NewFailFastLaunchPlanExecutor() + hf := &nodemocks.HandlerFactory{} + hf.On("Setup", mock.Anything, mock.Anything, mock.Anything).Return(nil) execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, adminClient, - 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, promutils.NewTestScope()) + 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) assert.NoError(t, err) - exec := execIface.(*nodeExecutor) + exec := execIface.(*recursiveNodeExecutor) defaultNodeID := "n1" taskID := "tID" @@ -1072,21 +1080,21 @@ func TestNodeExecutor_RecursiveNodeHandler_NoDownstream(t *testing.T) { name string currentNodePhase v1alpha1.NodePhase expectedNodePhase v1alpha1.NodePhase - expectedPhase executors.NodePhase + expectedPhase interfaces.NodePhase expectedError bool }{ - {"succeeded", v1alpha1.NodePhaseSucceeded, v1alpha1.NodePhaseSucceeded, executors.NodePhaseComplete, false}, - {"failed", v1alpha1.NodePhaseFailed, v1alpha1.NodePhaseFailed, executors.NodePhaseFailed, false}, + {"succeeded", v1alpha1.NodePhaseSucceeded, v1alpha1.NodePhaseSucceeded, interfaces.NodePhaseComplete, false}, + {"failed", v1alpha1.NodePhaseFailed, v1alpha1.NodePhaseFailed, interfaces.NodePhaseFailed, false}, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - hf := &mocks2.HandlerFactory{} + hf := &nodemocks.HandlerFactory{} exec.nodeHandlerFactory = hf - h := &nodeHandlerMocks.Node{} + h := &nodemocks.NodeHandler{} h.On("Handle", mock.MatchedBy(func(ctx context.Context) bool { return true }), - mock.MatchedBy(func(o handler.NodeExecutionContext) bool { return true }), + mock.MatchedBy(func(o interfaces.NodeExecutionContext) bool { return true }), ).Return(handler.UnknownTransition, fmt.Errorf("should not be called")) h.On("FinalizeRequired").Return(true) h.On("Finalize", mock.Anything, mock.Anything).Return(fmt.Errorf("error")) @@ -1117,10 +1125,12 @@ func TestNodeExecutor_RecursiveNodeHandler_UpstreamNotReady(t *testing.T) { store := createInmemoryDataStore(t, promutils.NewTestScope()) adminClient := launchplan.NewFailFastLaunchPlanExecutor() + hf := &nodemocks.HandlerFactory{} + hf.On("Setup", mock.Anything, mock.Anything, mock.Anything).Return(nil) execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, adminClient, - 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, promutils.NewTestScope()) + 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) assert.NoError(t, err) - exec := execIface.(*nodeExecutor) + exec := execIface.(*recursiveNodeExecutor) defaultNodeID := "n1" taskID := taskID @@ -1183,25 +1193,25 @@ func 
TestNodeExecutor_RecursiveNodeHandler_UpstreamNotReady(t *testing.T) { name string parentNodePhase v1alpha1.NodePhase expectedNodePhase v1alpha1.NodePhase - expectedPhase executors.NodePhase + expectedPhase interfaces.NodePhase expectedError bool }{ - {"notYetStarted", v1alpha1.NodePhaseNotYetStarted, v1alpha1.NodePhaseNotYetStarted, executors.NodePhasePending, false}, - {"running", v1alpha1.NodePhaseRunning, v1alpha1.NodePhaseNotYetStarted, executors.NodePhasePending, false}, - {"queued", v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseNotYetStarted, executors.NodePhasePending, false}, - {"retryable", v1alpha1.NodePhaseRetryableFailure, v1alpha1.NodePhaseNotYetStarted, executors.NodePhasePending, false}, - {"failing", v1alpha1.NodePhaseFailing, v1alpha1.NodePhaseNotYetStarted, executors.NodePhasePending, false}, - {"failing", v1alpha1.NodePhaseSucceeding, v1alpha1.NodePhaseNotYetStarted, executors.NodePhasePending, false}, - {"skipped", v1alpha1.NodePhaseSkipped, v1alpha1.NodePhaseSkipped, executors.NodePhaseSuccess, false}, + {"notYetStarted", v1alpha1.NodePhaseNotYetStarted, v1alpha1.NodePhaseNotYetStarted, interfaces.NodePhasePending, false}, + {"running", v1alpha1.NodePhaseRunning, v1alpha1.NodePhaseNotYetStarted, interfaces.NodePhasePending, false}, + {"queued", v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseNotYetStarted, interfaces.NodePhasePending, false}, + {"retryable", v1alpha1.NodePhaseRetryableFailure, v1alpha1.NodePhaseNotYetStarted, interfaces.NodePhasePending, false}, + {"failing", v1alpha1.NodePhaseFailing, v1alpha1.NodePhaseNotYetStarted, interfaces.NodePhasePending, false}, + {"succeeding", v1alpha1.NodePhaseSucceeding, v1alpha1.NodePhaseNotYetStarted, interfaces.NodePhasePending, false}, + {"skipped", v1alpha1.NodePhaseSkipped, v1alpha1.NodePhaseSkipped, interfaces.NodePhaseSuccess, false}, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - hf := &mocks2.HandlerFactory{} + hf := &nodemocks.HandlerFactory{} exec.nodeHandlerFactory = hf - h := &nodeHandlerMocks.Node{} + h := &nodemocks.NodeHandler{} h.On("Handle", mock.MatchedBy(func(ctx context.Context) bool { return true }), - mock.MatchedBy(func(o handler.NodeExecutionContext) bool { return true }), + mock.MatchedBy(func(o interfaces.NodeExecutionContext) bool { return true }), ).Return(handler.UnknownTransition, fmt.Errorf("should not be called")) h.On("FinalizeRequired").Return(true) h.On("Finalize", mock.Anything, mock.Anything).Return(fmt.Errorf("error")) @@ -1233,10 +1243,12 @@ func TestNodeExecutor_RecursiveNodeHandler_BranchNode(t *testing.T) { store := createInmemoryDataStore(t, promutils.NewTestScope()) adminClient := launchplan.NewFailFastLaunchPlanExecutor() + hf := &nodemocks.HandlerFactory{} + hf.On("Setup", mock.Anything, mock.Anything, mock.Anything).Return(nil) execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, adminClient, - 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, promutils.NewTestScope()) + 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) assert.NoError(t, err) - exec := execIface.(*nodeExecutor) + exec := execIface.(*recursiveNodeExecutor) // Node not yet started { tests := []struct { name string parentNodePhase v1alpha1.BranchNodePhase currentNodePhase v1alpha1.NodePhase phaseUpdateExpected bool - 
expectedPhase executors.NodePhase + expectedPhase interfaces.NodePhase expectedError bool }{ - {"branchSuccess", v1alpha1.BranchNodeSuccess, v1alpha1.NodePhaseNotYetStarted, true, executors.NodePhaseQueued, false}, - {"branchNotYetDone", v1alpha1.BranchNodeNotYetEvaluated, v1alpha1.NodePhaseNotYetStarted, false, executors.NodePhasePending, false}, - {"branchError", v1alpha1.BranchNodeError, v1alpha1.NodePhaseNotYetStarted, false, executors.NodePhasePending, false}, + {"branchSuccess", v1alpha1.BranchNodeSuccess, v1alpha1.NodePhaseNotYetStarted, true, interfaces.NodePhaseQueued, false}, + {"branchNotYetDone", v1alpha1.BranchNodeNotYetEvaluated, v1alpha1.NodePhaseNotYetStarted, false, interfaces.NodePhasePending, false}, + {"branchError", v1alpha1.BranchNodeError, v1alpha1.NodePhaseNotYetStarted, false, interfaces.NodePhasePending, false}, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - hf := &mocks2.HandlerFactory{} + hf := &nodemocks.HandlerFactory{} exec.nodeHandlerFactory = hf - h := &nodeHandlerMocks.Node{} + h := &nodemocks.NodeHandler{} h.OnHandleMatch( mock.MatchedBy(func(ctx context.Context) bool { return true }), - mock.MatchedBy(func(o handler.NodeExecutionContext) bool { return true }), + mock.MatchedBy(func(o interfaces.NodeExecutionContext) bool { return true }), ).Return(handler.UnknownTransition, fmt.Errorf("should not be called")) h.OnFinalizeRequired().Return(true) h.OnFinalizeMatch(mock.Anything, mock.Anything).Return(fmt.Errorf("error")) @@ -1345,11 +1357,10 @@ func TestNodeExecutor_RecursiveNodeHandler_BranchNode(t *testing.T) { func Test_nodeExecutor_RecordTransitionLatency(t *testing.T) { testScope := promutils.NewTestScope() type fields struct { - nodeHandlerFactory HandlerFactory - enqueueWorkflow v1alpha1.EnqueueWorkflow - store *storage.DataStore - nodeRecorder events.NodeEventRecorder - metrics *nodeMetrics + enqueueWorkflow v1alpha1.EnqueueWorkflow + store *storage.DataStore + nodeRecorder events.NodeEventRecorder + metrics *nodeMetrics } type args struct { w v1alpha1.ExecutableWorkflow @@ -1390,11 +1401,10 @@ func Test_nodeExecutor_RecordTransitionLatency(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c := &nodeExecutor{ - nodeHandlerFactory: tt.fields.nodeHandlerFactory, - enqueueWorkflow: tt.fields.enqueueWorkflow, - store: tt.fields.store, - nodeRecorder: tt.fields.nodeRecorder, - metrics: tt.fields.metrics, + enqueueWorkflow: tt.fields.enqueueWorkflow, + store: tt.fields.store, + nodeRecorder: tt.fields.nodeRecorder, + metrics: tt.fields.metrics, } c.RecordTransitionLatency(context.TODO(), tt.args.w, tt.args.w, tt.args.node, tt.args.nodeStatus) @@ -1493,18 +1503,14 @@ func Test_nodeExecutor_timeout(t *testing.T) { handlerReturn := func() (handler.Transition, error) { return handler.DoTransition(handler.TransitionTypeEphemeral, tt.phaseInfo), tt.err } - h := &nodeHandlerMocks.Node{} + h := &nodemocks.NodeHandler{} h.On("Handle", mock.MatchedBy(func(ctx context.Context) bool { return true }), - mock.MatchedBy(func(o handler.NodeExecutionContext) bool { return true }), + mock.MatchedBy(func(o interfaces.NodeExecutionContext) bool { return true }), ).Return(handlerReturn()) h.On("FinalizeRequired").Return(true) h.On("Finalize", mock.Anything, mock.Anything).Return(nil) - hf := &mocks2.HandlerFactory{} - hf.On("GetHandler", v1alpha1.NodeKindStart).Return(h, nil) - c.nodeHandlerFactory = hf - mockNode := &mocks.ExecutableNode{} mockNode.On("GetID").Return("node") mockNode.On("GetBranchNode").Return(nil) @@ 
-1545,18 +1551,15 @@ func Test_nodeExecutor_system_error(t *testing.T) { ns.On("ClearLastAttemptStartedAt").Return() c := &nodeExecutor{} - h := &nodeHandlerMocks.Node{} + h := &nodemocks.NodeHandler{} h.On("Handle", mock.MatchedBy(func(ctx context.Context) bool { return true }), - mock.MatchedBy(func(o handler.NodeExecutionContext) bool { return true }), + mock.MatchedBy(func(o interfaces.NodeExecutionContext) bool { return true }), ).Return(handler.DoTransition(handler.TransitionTypeEphemeral, phaseInfo), nil) h.On("FinalizeRequired").Return(true) h.On("Finalize", mock.Anything, mock.Anything).Return(nil) - hf := &mocks2.HandlerFactory{} - hf.On("GetHandler", v1alpha1.NodeKindStart).Return(h, nil) - c.nodeHandlerFactory = hf c.maxNodeRetriesForSystemFailures = 2 mockNode := &mocks.ExecutableNode{} @@ -1579,7 +1582,7 @@ func Test_nodeExecutor_abort(t *testing.T) { nCtx := &nodeExecContext{} t.Run("abort error calls finalize", func(t *testing.T) { - h := &nodeHandlerMocks.Node{} + h := &nodemocks.NodeHandler{} h.OnAbortMatch(mock.Anything, mock.Anything, mock.Anything).Return(errors.New("test error")) h.OnFinalizeRequired().Return(true) var called bool @@ -1587,13 +1590,13 @@ func Test_nodeExecutor_abort(t *testing.T) { called = true }).Return(nil) - err := exec.abort(ctx, h, nCtx, "testing") + err := exec.Abort(ctx, h, nCtx, "testing", false) assert.Equal(t, "test error", err.Error()) assert.True(t, called) }) t.Run("abort error calls finalize with error", func(t *testing.T) { - h := &nodeHandlerMocks.Node{} + h := &nodemocks.NodeHandler{} h.OnAbortMatch(mock.Anything, mock.Anything, mock.Anything).Return(errors.New("test error")) h.OnFinalizeRequired().Return(true) var called bool @@ -1601,13 +1604,13 @@ func Test_nodeExecutor_abort(t *testing.T) { called = true }).Return(errors.New("finalize error")) - err := exec.abort(ctx, h, nCtx, "testing") + err := exec.Abort(ctx, h, nCtx, "testing", false) assert.Equal(t, "0: test error\r\n1: finalize error\r\n", err.Error()) assert.True(t, called) }) t.Run("abort calls finalize when no errors", func(t *testing.T) { - h := &nodeHandlerMocks.Node{} + h := &nodemocks.NodeHandler{} h.OnAbortMatch(mock.Anything, mock.Anything, mock.Anything).Return(nil) h.OnFinalizeRequired().Return(true) var called bool @@ -1615,7 +1618,7 @@ func Test_nodeExecutor_abort(t *testing.T) { called = true }).Return(nil) - err := exec.abort(ctx, h, nCtx, "testing") + err := exec.Abort(ctx, h, nCtx, "testing", false) assert.NoError(t, err) assert.True(t, called) }) @@ -1623,7 +1626,7 @@ func Test_nodeExecutor_abort(t *testing.T) { func TestNodeExecutor_AbortHandler(t *testing.T) { ctx := context.Background() - exec := nodeExecutor{} + exec := recursiveNodeExecutor{} t.Run("not-yet-started", func(t *testing.T) { id := "id" @@ -1649,17 +1652,20 @@ func TestNodeExecutor_AbortHandler(t *testing.T) { ns.OnGetDataDir().Return(storage.DataReference("s3:/foo")) nl.OnGetNodeExecutionStatusMatch(mock.Anything, id).Return(ns) nl.OnGetNode(id).Return(n, true) - incompatibleClusterErr := fakeNodeEventRecorder{&eventsErr.EventError{Code: eventsErr.AlreadyExists, Cause: fmt.Errorf("err")}} + incompatibleClusterErr := fakeEventRecorder{nodeErr: &eventsErr.EventError{Code: eventsErr.AlreadyExists, Cause: fmt.Errorf("err")}} - hf := &mocks2.HandlerFactory{} - exec.nodeHandlerFactory = hf - h := &nodeHandlerMocks.Node{} + hf := &nodemocks.HandlerFactory{} + h := &nodemocks.NodeHandler{} h.OnAbortMatch(mock.Anything, mock.Anything, "aborting").Return(nil) h.OnFinalizeMatch(mock.Anything, 
mock.Anything).Return(nil) hf.OnGetHandlerMatch(v1alpha1.NodeKindStart).Return(h, nil) - nExec := nodeExecutor{ - nodeRecorder: incompatibleClusterErr, + nodeExecutor := &nodeExecutor{ + nodeRecorder: incompatibleClusterErr, + } + nExec := recursiveNodeExecutor{ + nodeExecutor: nodeExecutor, + nCtxBuilder: nodeExecutor, nodeHandlerFactory: hf, } @@ -1680,7 +1686,7 @@ func TestNodeExecutor_AbortHandler(t *testing.T) { func TestNodeExecutor_FinalizeHandler(t *testing.T) { ctx := context.Background() - exec := nodeExecutor{} + exec := recursiveNodeExecutor{} t.Run("not-yet-started", func(t *testing.T) { id := "id" @@ -1843,10 +1849,12 @@ func TestNodeExecutor_RecursiveNodeHandler_ParallelismLimit(t *testing.T) { store := createInmemoryDataStore(t, promutils.NewTestScope()) adminClient := launchplan.NewFailFastLaunchPlanExecutor() + hf := &nodemocks.HandlerFactory{} + hf.On("Setup", mock.Anything, mock.Anything, mock.Anything).Return(nil) execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, adminClient, - 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, promutils.NewTestScope()) + 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) assert.NoError(t, err) - exec := execIface.(*nodeExecutor) + exec := execIface.(*recursiveNodeExecutor) defaultNodeID := "n1" taskID := taskID @@ -1921,12 +1929,12 @@ func TestNodeExecutor_RecursiveNodeHandler_ParallelismLimit(t *testing.T) { cf := executors.InitializeControlFlow() eCtx := executors.NewExecutionContext(mockWf, mockWf, nil, nil, cf) - hf := &mocks2.HandlerFactory{} + hf := &nodemocks.HandlerFactory{} exec.nodeHandlerFactory = hf - h := &nodeHandlerMocks.Node{} + h := &nodemocks.NodeHandler{} h.OnHandleMatch( mock.MatchedBy(func(ctx context.Context) bool { return true }), - mock.MatchedBy(func(o handler.NodeExecutionContext) bool { return true }), + mock.MatchedBy(func(o interfaces.NodeExecutionContext) bool { return true }), ).Return(handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoSuccess(nil)), nil) h.OnFinalizeRequired().Return(false) @@ -1934,7 +1942,7 @@ func TestNodeExecutor_RecursiveNodeHandler_ParallelismLimit(t *testing.T) { s, err := exec.RecursiveNodeHandler(ctx, eCtx, mockWf, mockWf, mockNode) assert.NoError(t, err) - assert.Equal(t, s.NodePhase.String(), executors.NodePhaseSuccess.String()) + assert.Equal(t, s.NodePhase.String(), interfaces.NodePhaseSuccess.String()) }) t.Run("parallelism-met", func(t *testing.T) { @@ -1945,7 +1953,7 @@ func TestNodeExecutor_RecursiveNodeHandler_ParallelismLimit(t *testing.T) { s, err := exec.RecursiveNodeHandler(ctx, eCtx, mockWf, mockWf, mockNode) assert.NoError(t, err) - assert.Equal(t, s.NodePhase.String(), executors.NodePhaseRunning.String()) + assert.Equal(t, s.NodePhase.String(), interfaces.NodePhaseRunning.String()) }) t.Run("parallelism-met-not-yet-started", func(t *testing.T) { @@ -1956,7 +1964,7 @@ func TestNodeExecutor_RecursiveNodeHandler_ParallelismLimit(t *testing.T) { s, err := exec.RecursiveNodeHandler(ctx, eCtx, mockWf, mockWf, mockNode) assert.NoError(t, err) - assert.Equal(t, s.NodePhase.String(), executors.NodePhaseRunning.String()) + assert.Equal(t, s.NodePhase.String(), interfaces.NodePhaseRunning.String()) }) t.Run("parallelism-disabled", func(t *testing.T) { @@ -1965,12 +1973,12 @@ func TestNodeExecutor_RecursiveNodeHandler_ParallelismLimit(t *testing.T) { 
cf.IncrementParallelism() eCtx := executors.NewExecutionContext(mockWf, mockWf, nil, nil, cf) - hf := &mocks2.HandlerFactory{} + hf := &nodemocks.HandlerFactory{} exec.nodeHandlerFactory = hf - h := &nodeHandlerMocks.Node{} + h := &nodemocks.NodeHandler{} h.OnHandleMatch( mock.MatchedBy(func(ctx context.Context) bool { return true }), - mock.MatchedBy(func(o handler.NodeExecutionContext) bool { return true }), + mock.MatchedBy(func(o interfaces.NodeExecutionContext) bool { return true }), ).Return(handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoSuccess(nil)), nil) h.OnFinalizeRequired().Return(false) @@ -1978,59 +1986,10 @@ func TestNodeExecutor_RecursiveNodeHandler_ParallelismLimit(t *testing.T) { s, err := exec.RecursiveNodeHandler(ctx, eCtx, mockWf, mockWf, mockNode) assert.NoError(t, err) - assert.Equal(t, s.NodePhase.String(), executors.NodePhaseSuccess.String()) + assert.Equal(t, s.NodePhase.String(), interfaces.NodePhaseSuccess.String()) }) } -type fakeNodeEventRecorder struct { - err error -} - -func (f fakeNodeEventRecorder) RecordNodeEvent(ctx context.Context, event *event.NodeExecutionEvent, eventConfig *config.EventConfig) error { - if f.err != nil { - return f.err - } - return nil -} - -func Test_nodeExecutor_IdempotentRecordEvent(t *testing.T) { - noErrRecorder := fakeNodeEventRecorder{} - alreadyExistsError := fakeNodeEventRecorder{&eventsErr.EventError{Code: eventsErr.AlreadyExists, Cause: fmt.Errorf("err")}} - inTerminalError := fakeNodeEventRecorder{&eventsErr.EventError{Code: eventsErr.EventAlreadyInTerminalStateError, Cause: fmt.Errorf("err")}} - otherError := fakeNodeEventRecorder{&eventsErr.EventError{Code: eventsErr.ResourceExhausted, Cause: fmt.Errorf("err")}} - - tests := []struct { - name string - rec events.NodeEventRecorder - p core.NodeExecution_Phase - wantErr bool - }{ - {"aborted-success", noErrRecorder, core.NodeExecution_ABORTED, false}, - {"aborted-failure", otherError, core.NodeExecution_ABORTED, true}, - {"aborted-already", alreadyExistsError, core.NodeExecution_ABORTED, false}, - {"aborted-terminal", inTerminalError, core.NodeExecution_ABORTED, false}, - {"running-terminal", inTerminalError, core.NodeExecution_RUNNING, true}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - c := &nodeExecutor{ - nodeRecorder: tt.rec, - eventConfig: &config.EventConfig{ - RawOutputPolicy: config.RawOutputPolicyReference, - }, - } - ev := &event.NodeExecutionEvent{ - Id: &core.NodeExecutionIdentifier{}, - Phase: tt.p, - ProducerId: "propeller", - } - if err := c.IdempotentRecordEvent(context.TODO(), ev); (err != nil) != tt.wantErr { - t.Errorf("IdempotentRecordEvent() error = %v, wantErr %v", err, tt.wantErr) - } - }) - } -} - func TestRecover(t *testing.T) { recoveryID := &core.WorkflowExecutionIdentifier{ Project: "p", @@ -2087,7 +2046,7 @@ func TestRecover(t *testing.T) { }) execContext.OnGetEventVersion().Return(v1alpha1.EventVersion0) - nm := &nodeHandlerMocks.NodeExecutionMetadata{} + nm := &nodemocks.NodeExecutionMetadata{} nm.OnGetNodeExecutionID().Return(&core.NodeExecutionIdentifier{ ExecutionId: wfExecID, NodeId: nodeID, @@ -2099,7 +2058,7 @@ func TestRecover(t *testing.T) { ns := &mocks.ExecutableNodeStatus{} ns.OnGetOutputDir().Return(storage.DataReference("out")) - nCtx := &nodeHandlerMocks.NodeExecutionContext{} + nCtx := &nodemocks.NodeExecutionContext{} nCtx.OnExecutionContext().Return(execContext) nCtx.OnNodeExecutionMetadata().Return(nm) nCtx.OnInputReader().Return(ir) @@ -2156,7 +2115,7 @@ func 
TestRecover(t *testing.T) { dstDynamicJobSpecURI := "dst/foo/bar" // initialize node execution context - nCtx := &nodeHandlerMocks.NodeExecutionContext{} + nCtx := &nodemocks.NodeExecutionContext{} nCtx.OnExecutionContext().Return(execContext) nCtx.OnNodeExecutionMetadata().Return(nm) nCtx.OnInputReader().Return(ir) @@ -2182,11 +2141,11 @@ func TestRecover(t *testing.T) { nCtx.OnDataStore().Return(storageClient) - reader := &nodeHandlerMocks.NodeStateReader{} + reader := &nodemocks.NodeStateReader{} reader.OnGetDynamicNodeState().Return(handler.DynamicNodeState{}) nCtx.OnNodeStateReader().Return(reader) - writer := &nodeHandlerMocks.NodeStateWriter{} + writer := &nodemocks.NodeStateWriter{} writer.OnPutDynamicNodeStateMatch(mock.Anything).Run(func(args mock.Arguments) { state := args.Get(0).(handler.DynamicNodeState) assert.Equal(t, v1alpha1.DynamicNodePhaseParentFinalized, state.Phase) @@ -2403,10 +2362,10 @@ func TestRecover(t *testing.T) { recoveryClient: recoveryClient, } - reader := &nodeHandlerMocks.NodeStateReader{} + reader := &nodemocks.NodeStateReader{} reader.OnGetTaskNodeState().Return(handler.TaskNodeState{}) nCtx.OnNodeStateReader().Return(reader) - writer := &nodeHandlerMocks.NodeStateWriter{} + writer := &nodemocks.NodeStateWriter{} writer.OnPutTaskNodeStateMatch(mock.Anything).Run(func(args mock.Arguments) { state := args.Get(0).(handler.TaskNodeState) assert.Equal(t, state.PreviousNodeExecutionCheckpointURI.String(), "prev path") diff --git a/flytepropeller/pkg/controller/nodes/factory/handler_factory.go b/flytepropeller/pkg/controller/nodes/factory/handler_factory.go new file mode 100644 index 0000000000..9ec00da7a0 --- /dev/null +++ b/flytepropeller/pkg/controller/nodes/factory/handler_factory.go @@ -0,0 +1,96 @@ +package factory + +import ( + "context" + + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/service" + + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/catalog" + + "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/flyteorg/flytepropeller/pkg/controller/config" + "github.com/flyteorg/flytepropeller/pkg/controller/executors" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/array" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/branch" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/dynamic" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/end" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/gate" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/recovery" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/start" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/subworkflow" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/subworkflow/launchplan" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task" + + "github.com/flyteorg/flytestdlib/promutils" + + "github.com/pkg/errors" +) + +type handlerFactory struct { + handlers map[v1alpha1.NodeKind]interfaces.NodeHandler + + workflowLauncher launchplan.Executor + launchPlanReader launchplan.Reader + kubeClient executors.Client + catalogClient catalog.Client + recoveryClient recovery.Client + eventConfig *config.EventConfig + clusterID string + signalClient service.SignalServiceClient + scope promutils.Scope +} + +func (f *handlerFactory) GetHandler(kind v1alpha1.NodeKind) (interfaces.NodeHandler, error) { + h, ok := f.handlers[kind] + if !ok { + return nil, errors.Errorf("Handler not registered for NodeKind 
[%v]", kind) + } + return h, nil +} + +func (f *handlerFactory) Setup(ctx context.Context, executor interfaces.Node, setup interfaces.SetupContext) error { + t, err := task.New(ctx, f.kubeClient, f.catalogClient, f.eventConfig, f.clusterID, f.scope) + if err != nil { + return err + } + + arrayHandler, err := array.New(executor, f.eventConfig, f.scope) + if err != nil { + return err + } + + f.handlers = map[v1alpha1.NodeKind]interfaces.NodeHandler{ + v1alpha1.NodeKindBranch: branch.New(executor, f.eventConfig, f.scope), + v1alpha1.NodeKindTask: dynamic.New(t, executor, f.launchPlanReader, f.eventConfig, f.scope), + v1alpha1.NodeKindWorkflow: subworkflow.New(executor, f.workflowLauncher, f.recoveryClient, f.eventConfig, f.scope), + v1alpha1.NodeKindGate: gate.New(f.eventConfig, f.signalClient, f.scope), + v1alpha1.NodeKindArray: arrayHandler, + v1alpha1.NodeKindStart: start.New(), + v1alpha1.NodeKindEnd: end.New(), + } + + for _, v := range f.handlers { + if err := v.Setup(ctx, setup); err != nil { + return err + } + } + return nil +} + +func NewHandlerFactory(ctx context.Context, workflowLauncher launchplan.Executor, launchPlanReader launchplan.Reader, + kubeClient executors.Client, catalogClient catalog.Client, recoveryClient recovery.Client, eventConfig *config.EventConfig, + clusterID string, signalClient service.SignalServiceClient, scope promutils.Scope) (interfaces.HandlerFactory, error) { + + return &handlerFactory{ + workflowLauncher: workflowLauncher, + launchPlanReader: launchPlanReader, + kubeClient: kubeClient, + catalogClient: catalogClient, + recoveryClient: recoveryClient, + eventConfig: eventConfig, + clusterID: clusterID, + signalClient: signalClient, + scope: scope, + }, nil +} diff --git a/flytepropeller/pkg/controller/nodes/gate/handler.go b/flytepropeller/pkg/controller/nodes/gate/handler.go index 31e21c4dc5..4b8d627209 100644 --- a/flytepropeller/pkg/controller/nodes/gate/handler.go +++ b/flytepropeller/pkg/controller/nodes/gate/handler.go @@ -13,6 +13,7 @@ import ( "github.com/flyteorg/flytepropeller/pkg/controller/config" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/errors" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" "github.com/flyteorg/flytestdlib/logger" "github.com/flyteorg/flytestdlib/promutils" @@ -46,12 +47,12 @@ func newMetrics(scope promutils.Scope) metrics { } // Abort stops the gate node defined in the NodeExecutionContext -func (g *gateNodeHandler) Abort(ctx context.Context, nCtx handler.NodeExecutionContext, reason string) error { +func (g *gateNodeHandler) Abort(ctx context.Context, nCtx interfaces.NodeExecutionContext, reason string) error { return nil } // Finalize completes the gate node defined in the NodeExecutionContext -func (g *gateNodeHandler) Finalize(ctx context.Context, _ handler.NodeExecutionContext) error { +func (g *gateNodeHandler) Finalize(ctx context.Context, _ interfaces.NodeExecutionContext) error { return nil } @@ -63,7 +64,7 @@ func (g *gateNodeHandler) FinalizeRequired() bool { // Handle is responsible for transitioning and reporting node state to complete the node defined // by the NodeExecutionContext -func (g *gateNodeHandler) Handle(ctx context.Context, nCtx handler.NodeExecutionContext) (handler.Transition, error) { +func (g *gateNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecutionContext) (handler.Transition, error) { gateNode := nCtx.Node().GetGateNode() gateNodeState := 
nCtx.NodeStateReader().GetGateNodeState() @@ -197,7 +198,7 @@ func (g *gateNodeHandler) Handle(ctx context.Context, nCtx handler.NodeExecution // update gate node status if err := nCtx.NodeStateWriter().PutGateNodeState(gateNodeState); err != nil { - logger.Errorf(ctx, "failed to store TaskNode state with err [%s]", err.Error()) + logger.Errorf(ctx, "failed to store GateNode state with err [%s]", err.Error()) return handler.UnknownTransition, err } @@ -205,12 +206,12 @@ func (g *gateNodeHandler) Handle(ctx context.Context, nCtx handler.NodeExecution } // Setup handles any initialization requirements for this handler -func (g *gateNodeHandler) Setup(_ context.Context, _ handler.SetupContext) error { +func (g *gateNodeHandler) Setup(_ context.Context, _ interfaces.SetupContext) error { return nil } // New initializes a new gateNodeHandler -func New(eventConfig *config.EventConfig, signalClient service.SignalServiceClient, scope promutils.Scope) handler.Node { +func New(eventConfig *config.EventConfig, signalClient service.SignalServiceClient, scope promutils.Scope) interfaces.NodeHandler { gateScope := scope.NewSubScope("gate") return &gateNodeHandler{ signalClient: signalClient, diff --git a/flytepropeller/pkg/controller/nodes/gate/handler_test.go b/flytepropeller/pkg/controller/nodes/gate/handler_test.go index ca7e8e9dd5..b60b9db24b 100644 --- a/flytepropeller/pkg/controller/nodes/gate/handler_test.go +++ b/flytepropeller/pkg/controller/nodes/gate/handler_test.go @@ -17,7 +17,7 @@ import ( executormocks "github.com/flyteorg/flytepropeller/pkg/controller/executors/mocks" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/gate/mocks" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" - nodeMocks "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler/mocks" + nodeMocks "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces/mocks" "github.com/flyteorg/flytestdlib/contextutils" "github.com/flyteorg/flytestdlib/promutils" diff --git a/flytepropeller/pkg/controller/nodes/handler/mocks/node.go b/flytepropeller/pkg/controller/nodes/handler/mocks/node.go deleted file mode 100644 index e7e376606a..0000000000 --- a/flytepropeller/pkg/controller/nodes/handler/mocks/node.go +++ /dev/null @@ -1,182 +0,0 @@ -// Code generated by mockery v1.0.1. DO NOT EDIT. - -package mocks - -import ( - context "context" - - handler "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" - mock "github.com/stretchr/testify/mock" -) - -// Node is an autogenerated mock type for the Node type -type Node struct { - mock.Mock -} - -type Node_Abort struct { - *mock.Call -} - -func (_m Node_Abort) Return(_a0 error) *Node_Abort { - return &Node_Abort{Call: _m.Call.Return(_a0)} -} - -func (_m *Node) OnAbort(ctx context.Context, executionContext handler.NodeExecutionContext, reason string) *Node_Abort { - c_call := _m.On("Abort", ctx, executionContext, reason) - return &Node_Abort{Call: c_call} -} - -func (_m *Node) OnAbortMatch(matchers ...interface{}) *Node_Abort { - c_call := _m.On("Abort", matchers...) 
- return &Node_Abort{Call: c_call} -} - -// Abort provides a mock function with given fields: ctx, executionContext, reason -func (_m *Node) Abort(ctx context.Context, executionContext handler.NodeExecutionContext, reason string) error { - ret := _m.Called(ctx, executionContext, reason) - - var r0 error - if rf, ok := ret.Get(0).(func(context.Context, handler.NodeExecutionContext, string) error); ok { - r0 = rf(ctx, executionContext, reason) - } else { - r0 = ret.Error(0) - } - - return r0 -} - -type Node_Finalize struct { - *mock.Call -} - -func (_m Node_Finalize) Return(_a0 error) *Node_Finalize { - return &Node_Finalize{Call: _m.Call.Return(_a0)} -} - -func (_m *Node) OnFinalize(ctx context.Context, executionContext handler.NodeExecutionContext) *Node_Finalize { - c_call := _m.On("Finalize", ctx, executionContext) - return &Node_Finalize{Call: c_call} -} - -func (_m *Node) OnFinalizeMatch(matchers ...interface{}) *Node_Finalize { - c_call := _m.On("Finalize", matchers...) - return &Node_Finalize{Call: c_call} -} - -// Finalize provides a mock function with given fields: ctx, executionContext -func (_m *Node) Finalize(ctx context.Context, executionContext handler.NodeExecutionContext) error { - ret := _m.Called(ctx, executionContext) - - var r0 error - if rf, ok := ret.Get(0).(func(context.Context, handler.NodeExecutionContext) error); ok { - r0 = rf(ctx, executionContext) - } else { - r0 = ret.Error(0) - } - - return r0 -} - -type Node_FinalizeRequired struct { - *mock.Call -} - -func (_m Node_FinalizeRequired) Return(_a0 bool) *Node_FinalizeRequired { - return &Node_FinalizeRequired{Call: _m.Call.Return(_a0)} -} - -func (_m *Node) OnFinalizeRequired() *Node_FinalizeRequired { - c_call := _m.On("FinalizeRequired") - return &Node_FinalizeRequired{Call: c_call} -} - -func (_m *Node) OnFinalizeRequiredMatch(matchers ...interface{}) *Node_FinalizeRequired { - c_call := _m.On("FinalizeRequired", matchers...) - return &Node_FinalizeRequired{Call: c_call} -} - -// FinalizeRequired provides a mock function with given fields: -func (_m *Node) FinalizeRequired() bool { - ret := _m.Called() - - var r0 bool - if rf, ok := ret.Get(0).(func() bool); ok { - r0 = rf() - } else { - r0 = ret.Get(0).(bool) - } - - return r0 -} - -type Node_Handle struct { - *mock.Call -} - -func (_m Node_Handle) Return(_a0 handler.Transition, _a1 error) *Node_Handle { - return &Node_Handle{Call: _m.Call.Return(_a0, _a1)} -} - -func (_m *Node) OnHandle(ctx context.Context, executionContext handler.NodeExecutionContext) *Node_Handle { - c_call := _m.On("Handle", ctx, executionContext) - return &Node_Handle{Call: c_call} -} - -func (_m *Node) OnHandleMatch(matchers ...interface{}) *Node_Handle { - c_call := _m.On("Handle", matchers...) 
- return &Node_Handle{Call: c_call} -} - -// Handle provides a mock function with given fields: ctx, executionContext -func (_m *Node) Handle(ctx context.Context, executionContext handler.NodeExecutionContext) (handler.Transition, error) { - ret := _m.Called(ctx, executionContext) - - var r0 handler.Transition - if rf, ok := ret.Get(0).(func(context.Context, handler.NodeExecutionContext) handler.Transition); ok { - r0 = rf(ctx, executionContext) - } else { - r0 = ret.Get(0).(handler.Transition) - } - - var r1 error - if rf, ok := ret.Get(1).(func(context.Context, handler.NodeExecutionContext) error); ok { - r1 = rf(ctx, executionContext) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -type Node_Setup struct { - *mock.Call -} - -func (_m Node_Setup) Return(_a0 error) *Node_Setup { - return &Node_Setup{Call: _m.Call.Return(_a0)} -} - -func (_m *Node) OnSetup(ctx context.Context, setupContext handler.SetupContext) *Node_Setup { - c_call := _m.On("Setup", ctx, setupContext) - return &Node_Setup{Call: c_call} -} - -func (_m *Node) OnSetupMatch(matchers ...interface{}) *Node_Setup { - c_call := _m.On("Setup", matchers...) - return &Node_Setup{Call: c_call} -} - -// Setup provides a mock function with given fields: ctx, setupContext -func (_m *Node) Setup(ctx context.Context, setupContext handler.SetupContext) error { - ret := _m.Called(ctx, setupContext) - - var r0 error - if rf, ok := ret.Get(0).(func(context.Context, handler.SetupContext) error); ok { - r0 = rf(ctx, setupContext) - } else { - r0 = ret.Error(0) - } - - return r0 -} diff --git a/flytepropeller/pkg/controller/nodes/handler/mocks/node_state_reader.go b/flytepropeller/pkg/controller/nodes/handler/mocks/node_state_reader.go deleted file mode 100644 index ef86e64cc3..0000000000 --- a/flytepropeller/pkg/controller/nodes/handler/mocks/node_state_reader.go +++ /dev/null @@ -1,173 +0,0 @@ -// Code generated by mockery v1.0.1. DO NOT EDIT. - -package mocks - -import ( - handler "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" - mock "github.com/stretchr/testify/mock" -) - -// NodeStateReader is an autogenerated mock type for the NodeStateReader type -type NodeStateReader struct { - mock.Mock -} - -type NodeStateReader_GetBranchNode struct { - *mock.Call -} - -func (_m NodeStateReader_GetBranchNode) Return(_a0 handler.BranchNodeState) *NodeStateReader_GetBranchNode { - return &NodeStateReader_GetBranchNode{Call: _m.Call.Return(_a0)} -} - -func (_m *NodeStateReader) OnGetBranchNode() *NodeStateReader_GetBranchNode { - c_call := _m.On("GetBranchNode") - return &NodeStateReader_GetBranchNode{Call: c_call} -} - -func (_m *NodeStateReader) OnGetBranchNodeMatch(matchers ...interface{}) *NodeStateReader_GetBranchNode { - c_call := _m.On("GetBranchNode", matchers...) 
- return &NodeStateReader_GetBranchNode{Call: c_call} -} - -// GetBranchNode provides a mock function with given fields: -func (_m *NodeStateReader) GetBranchNode() handler.BranchNodeState { - ret := _m.Called() - - var r0 handler.BranchNodeState - if rf, ok := ret.Get(0).(func() handler.BranchNodeState); ok { - r0 = rf() - } else { - r0 = ret.Get(0).(handler.BranchNodeState) - } - - return r0 -} - -type NodeStateReader_GetDynamicNodeState struct { - *mock.Call -} - -func (_m NodeStateReader_GetDynamicNodeState) Return(_a0 handler.DynamicNodeState) *NodeStateReader_GetDynamicNodeState { - return &NodeStateReader_GetDynamicNodeState{Call: _m.Call.Return(_a0)} -} - -func (_m *NodeStateReader) OnGetDynamicNodeState() *NodeStateReader_GetDynamicNodeState { - c_call := _m.On("GetDynamicNodeState") - return &NodeStateReader_GetDynamicNodeState{Call: c_call} -} - -func (_m *NodeStateReader) OnGetDynamicNodeStateMatch(matchers ...interface{}) *NodeStateReader_GetDynamicNodeState { - c_call := _m.On("GetDynamicNodeState", matchers...) - return &NodeStateReader_GetDynamicNodeState{Call: c_call} -} - -// GetDynamicNodeState provides a mock function with given fields: -func (_m *NodeStateReader) GetDynamicNodeState() handler.DynamicNodeState { - ret := _m.Called() - - var r0 handler.DynamicNodeState - if rf, ok := ret.Get(0).(func() handler.DynamicNodeState); ok { - r0 = rf() - } else { - r0 = ret.Get(0).(handler.DynamicNodeState) - } - - return r0 -} - -type NodeStateReader_GetGateNodeState struct { - *mock.Call -} - -func (_m NodeStateReader_GetGateNodeState) Return(_a0 handler.GateNodeState) *NodeStateReader_GetGateNodeState { - return &NodeStateReader_GetGateNodeState{Call: _m.Call.Return(_a0)} -} - -func (_m *NodeStateReader) OnGetGateNodeState() *NodeStateReader_GetGateNodeState { - c_call := _m.On("GetGateNodeState") - return &NodeStateReader_GetGateNodeState{Call: c_call} -} - -func (_m *NodeStateReader) OnGetGateNodeStateMatch(matchers ...interface{}) *NodeStateReader_GetGateNodeState { - c_call := _m.On("GetGateNodeState", matchers...) - return &NodeStateReader_GetGateNodeState{Call: c_call} -} - -// GetGateNodeState provides a mock function with given fields: -func (_m *NodeStateReader) GetGateNodeState() handler.GateNodeState { - ret := _m.Called() - - var r0 handler.GateNodeState - if rf, ok := ret.Get(0).(func() handler.GateNodeState); ok { - r0 = rf() - } else { - r0 = ret.Get(0).(handler.GateNodeState) - } - - return r0 -} - -type NodeStateReader_GetTaskNodeState struct { - *mock.Call -} - -func (_m NodeStateReader_GetTaskNodeState) Return(_a0 handler.TaskNodeState) *NodeStateReader_GetTaskNodeState { - return &NodeStateReader_GetTaskNodeState{Call: _m.Call.Return(_a0)} -} - -func (_m *NodeStateReader) OnGetTaskNodeState() *NodeStateReader_GetTaskNodeState { - c_call := _m.On("GetTaskNodeState") - return &NodeStateReader_GetTaskNodeState{Call: c_call} -} - -func (_m *NodeStateReader) OnGetTaskNodeStateMatch(matchers ...interface{}) *NodeStateReader_GetTaskNodeState { - c_call := _m.On("GetTaskNodeState", matchers...) 
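The handler/mocks package deleted here is regenerated under pkg/controller/nodes/interfaces/mocks further down in this diff, so existing tests migrate with an import swap rather than a rewrite (note that the old GetBranchNode accessor reappears as GetBranchNodeState in the regenerated mock). A minimal sketch of a migrated test, assuming the regenerated NodeStateReader keeps the same mockery On*/Return helpers shown above; the test name is illustrative:

package example_test

import (
	"testing"

	"github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler"
	nodemocks "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces/mocks"

	"github.com/stretchr/testify/assert"
)

func TestMigratedNodeStateReader(t *testing.T) {
	// The mock now lives in interfaces/mocks; the state structs stay in handler.
	reader := &nodemocks.NodeStateReader{}
	reader.OnGetTaskNodeState().Return(handler.TaskNodeState{})

	// The stubbed zero-value state is returned verbatim.
	assert.Equal(t, handler.TaskNodeState{}, reader.GetTaskNodeState())
}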
- return &NodeStateReader_GetTaskNodeState{Call: c_call} -} - -// GetTaskNodeState provides a mock function with given fields: -func (_m *NodeStateReader) GetTaskNodeState() handler.TaskNodeState { - ret := _m.Called() - - var r0 handler.TaskNodeState - if rf, ok := ret.Get(0).(func() handler.TaskNodeState); ok { - r0 = rf() - } else { - r0 = ret.Get(0).(handler.TaskNodeState) - } - - return r0 -} - -type NodeStateReader_GetWorkflowNodeState struct { - *mock.Call -} - -func (_m NodeStateReader_GetWorkflowNodeState) Return(_a0 handler.WorkflowNodeState) *NodeStateReader_GetWorkflowNodeState { - return &NodeStateReader_GetWorkflowNodeState{Call: _m.Call.Return(_a0)} -} - -func (_m *NodeStateReader) OnGetWorkflowNodeState() *NodeStateReader_GetWorkflowNodeState { - c_call := _m.On("GetWorkflowNodeState") - return &NodeStateReader_GetWorkflowNodeState{Call: c_call} -} - -func (_m *NodeStateReader) OnGetWorkflowNodeStateMatch(matchers ...interface{}) *NodeStateReader_GetWorkflowNodeState { - c_call := _m.On("GetWorkflowNodeState", matchers...) - return &NodeStateReader_GetWorkflowNodeState{Call: c_call} -} - -// GetWorkflowNodeState provides a mock function with given fields: -func (_m *NodeStateReader) GetWorkflowNodeState() handler.WorkflowNodeState { - ret := _m.Called() - - var r0 handler.WorkflowNodeState - if rf, ok := ret.Get(0).(func() handler.WorkflowNodeState); ok { - r0 = rf() - } else { - r0 = ret.Get(0).(handler.WorkflowNodeState) - } - - return r0 -} diff --git a/flytepropeller/pkg/controller/nodes/handler/state.go b/flytepropeller/pkg/controller/nodes/handler/state.go index 9688f4e33a..89adfc8f8d 100644 --- a/flytepropeller/pkg/controller/nodes/handler/state.go +++ b/flytepropeller/pkg/controller/nodes/handler/state.go @@ -4,8 +4,12 @@ import ( "time" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + pluginCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" + "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + + "github.com/flyteorg/flytestdlib/bitarray" "github.com/flyteorg/flytestdlib/storage" ) @@ -46,18 +50,12 @@ type GateNodeState struct { StartedAt time.Time } -type NodeStateWriter interface { - PutTaskNodeState(s TaskNodeState) error - PutBranchNode(s BranchNodeState) error - PutDynamicNodeState(s DynamicNodeState) error - PutWorkflowNodeState(s WorkflowNodeState) error - PutGateNodeState(s GateNodeState) error -} - -type NodeStateReader interface { - GetTaskNodeState() TaskNodeState - GetBranchNode() BranchNodeState - GetDynamicNodeState() DynamicNodeState - GetWorkflowNodeState() WorkflowNodeState - GetGateNodeState() GateNodeState +type ArrayNodeState struct { + Phase v1alpha1.ArrayNodePhase + TaskPhaseVersion uint32 + Error *core.ExecutionError + SubNodePhases bitarray.CompactArray + SubNodeTaskPhases bitarray.CompactArray + SubNodeRetryAttempts bitarray.CompactArray + SubNodeSystemFailures bitarray.CompactArray } diff --git a/flytepropeller/pkg/controller/nodes/handler/transition_info.go b/flytepropeller/pkg/controller/nodes/handler/transition_info.go index 5d302f4fa5..0cce41ef43 100644 --- a/flytepropeller/pkg/controller/nodes/handler/transition_info.go +++ b/flytepropeller/pkg/controller/nodes/handler/transition_info.go @@ -52,6 +52,9 @@ type TaskNodeInfo struct { type GateNodeInfo struct { } +type ArrayNodeInfo struct { +} + type OutputInfo struct { OutputURI storage.DataReference DeckURI *storage.DataReference @@ -65,6 +68,7 @@ type ExecutionInfo struct { OutputInfo *OutputInfo TaskNodeInfo *TaskNodeInfo GateNodeInfo 
*GateNodeInfo + ArrayNodeInfo *ArrayNodeInfo } type PhaseInfo struct { diff --git a/flytepropeller/pkg/controller/nodes/handler/transition_info_test.go b/flytepropeller/pkg/controller/nodes/handler/transition_info_test.go index 579d16cb53..9f85f0628b 100644 --- a/flytepropeller/pkg/controller/nodes/handler/transition_info_test.go +++ b/flytepropeller/pkg/controller/nodes/handler/transition_info_test.go @@ -4,9 +4,10 @@ import ( "testing" "github.com/flyteorg/flyteidl/clients/go/coreutils" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + "github.com/golang/protobuf/proto" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/stretchr/testify/assert" ) diff --git a/flytepropeller/pkg/controller/nodes/handler/transition_test.go b/flytepropeller/pkg/controller/nodes/handler/transition_test.go index 32f79d2dce..61236531fc 100644 --- a/flytepropeller/pkg/controller/nodes/handler/transition_test.go +++ b/flytepropeller/pkg/controller/nodes/handler/transition_test.go @@ -6,6 +6,7 @@ import ( "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flytestdlib/storage" + "github.com/stretchr/testify/assert" ) diff --git a/flytepropeller/pkg/controller/nodes/handler_factory.go b/flytepropeller/pkg/controller/nodes/handler_factory.go deleted file mode 100644 index e13143e6b2..0000000000 --- a/flytepropeller/pkg/controller/nodes/handler_factory.go +++ /dev/null @@ -1,81 +0,0 @@ -package nodes - -import ( - "context" - - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/service" - - "github.com/flyteorg/flytepropeller/pkg/controller/config" - - "github.com/flyteorg/flytepropeller/pkg/controller/nodes/recovery" - - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/catalog" - - "github.com/flyteorg/flytepropeller/pkg/controller/nodes/dynamic" - - "github.com/flyteorg/flytestdlib/promutils" - - "github.com/pkg/errors" - - "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" - "github.com/flyteorg/flytepropeller/pkg/controller/executors" - "github.com/flyteorg/flytepropeller/pkg/controller/nodes/branch" - "github.com/flyteorg/flytepropeller/pkg/controller/nodes/end" - "github.com/flyteorg/flytepropeller/pkg/controller/nodes/gate" - "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" - "github.com/flyteorg/flytepropeller/pkg/controller/nodes/start" - "github.com/flyteorg/flytepropeller/pkg/controller/nodes/subworkflow" - "github.com/flyteorg/flytepropeller/pkg/controller/nodes/subworkflow/launchplan" - "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task" -) - -//go:generate mockery -name HandlerFactory -case=underscore - -type HandlerFactory interface { - GetHandler(kind v1alpha1.NodeKind) (handler.Node, error) - Setup(ctx context.Context, setup handler.SetupContext) error -} - -type handlerFactory struct { - handlers map[v1alpha1.NodeKind]handler.Node -} - -func (f handlerFactory) GetHandler(kind v1alpha1.NodeKind) (handler.Node, error) { - h, ok := f.handlers[kind] - if !ok { - return nil, errors.Errorf("Handler not registered for NodeKind [%v]", kind) - } - return h, nil -} - -func (f handlerFactory) Setup(ctx context.Context, setup handler.SetupContext) error { - for _, v := range f.handlers { - if err := v.Setup(ctx, setup); err != nil { - return err - } - } - return nil -} - -func NewHandlerFactory(ctx context.Context, executor executors.Node, workflowLauncher launchplan.Executor, - launchPlanReader launchplan.Reader, kubeClient executors.Client, client catalog.Client, recoveryClient recovery.Client, - 
eventConfig *config.EventConfig, clusterID string, signalClient service.SignalServiceClient, scope promutils.Scope) (HandlerFactory, error) { - - t, err := task.New(ctx, kubeClient, client, eventConfig, clusterID, scope) - if err != nil { - return nil, err - } - - f := &handlerFactory{ - handlers: map[v1alpha1.NodeKind]handler.Node{ - v1alpha1.NodeKindBranch: branch.New(executor, eventConfig, scope), - v1alpha1.NodeKindTask: dynamic.New(t, executor, launchPlanReader, eventConfig, scope), - v1alpha1.NodeKindWorkflow: subworkflow.New(executor, workflowLauncher, recoveryClient, eventConfig, scope), - v1alpha1.NodeKindGate: gate.New(eventConfig, signalClient, scope), - v1alpha1.NodeKindStart: start.New(), - v1alpha1.NodeKindEnd: end.New(), - }, - } - - return f, nil -} diff --git a/flytepropeller/pkg/controller/nodes/handler/iface.go b/flytepropeller/pkg/controller/nodes/interfaces/handler.go similarity index 54% rename from flytepropeller/pkg/controller/nodes/handler/iface.go rename to flytepropeller/pkg/controller/nodes/interfaces/handler.go index d0b3591712..16ef732742 100644 --- a/flytepropeller/pkg/controller/nodes/handler/iface.go +++ b/flytepropeller/pkg/controller/nodes/interfaces/handler.go @@ -1,13 +1,24 @@ -package handler +package interfaces import ( "context" + + "github.com/flyteorg/flytepropeller/pkg/controller/executors" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" + "github.com/flyteorg/flytestdlib/promutils" ) //go:generate mockery -all -case=underscore +// NodeExecutor defines the interface for handling a single Flyte Node of any Node type. +type NodeExecutor interface { + HandleNode(ctx context.Context, dag executors.DAGStructure, nCtx NodeExecutionContext, h NodeHandler) (NodeStatus, error) + Abort(ctx context.Context, h NodeHandler, nCtx NodeExecutionContext, reason string, finalTransition bool) error + Finalize(ctx context.Context, h NodeHandler, nCtx NodeExecutionContext) error +} + // Interface that should be implemented for a node type. -type Node interface { +type NodeHandler interface { // Method to indicate that finalize is required for this handler FinalizeRequired() bool @@ -15,7 +26,7 @@ type Node interface { Setup(ctx context.Context, setupContext SetupContext) error // Core method that should handle this node - Handle(ctx context.Context, executionContext NodeExecutionContext) (Transition, error) + Handle(ctx context.Context, executionContext NodeExecutionContext) (handler.Transition, error) // This method should be invoked to indicate the node needs to be aborted. Abort(ctx context.Context, executionContext NodeExecutionContext, reason string) error @@ -24,3 +35,9 @@ type Node interface { // It is guaranteed that Handle -> (happens before) -> Finalize. 
Abort -> finalize may be repeated multiple times Finalize(ctx context.Context, executionContext NodeExecutionContext) error } + +type SetupContext interface { + EnqueueOwner() func(string) + OwnerKind() string + MetricsScope() promutils.Scope +} diff --git a/flytepropeller/pkg/controller/nodes/interfaces/handler_factory.go b/flytepropeller/pkg/controller/nodes/interfaces/handler_factory.go new file mode 100644 index 0000000000..a323d8bf85 --- /dev/null +++ b/flytepropeller/pkg/controller/nodes/interfaces/handler_factory.go @@ -0,0 +1,14 @@ +package interfaces + +import ( + "context" + + "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" +) + +//go:generate mockery -name HandlerFactory -case=underscore + +type HandlerFactory interface { + GetHandler(kind v1alpha1.NodeKind) (NodeHandler, error) + Setup(ctx context.Context, executor Node, setup SetupContext) error +} diff --git a/flytepropeller/pkg/controller/nodes/interfaces/mocks/event_recorder.go b/flytepropeller/pkg/controller/nodes/interfaces/mocks/event_recorder.go new file mode 100644 index 0000000000..684419825b --- /dev/null +++ b/flytepropeller/pkg/controller/nodes/interfaces/mocks/event_recorder.go @@ -0,0 +1,82 @@ +// Code generated by mockery v1.0.1. DO NOT EDIT. + +package mocks + +import ( + context "context" + + config "github.com/flyteorg/flytepropeller/pkg/controller/config" + + event "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" + + mock "github.com/stretchr/testify/mock" +) + +// EventRecorder is an autogenerated mock type for the EventRecorder type +type EventRecorder struct { + mock.Mock +} + +type EventRecorder_RecordNodeEvent struct { + *mock.Call +} + +func (_m EventRecorder_RecordNodeEvent) Return(_a0 error) *EventRecorder_RecordNodeEvent { + return &EventRecorder_RecordNodeEvent{Call: _m.Call.Return(_a0)} +} + +func (_m *EventRecorder) OnRecordNodeEvent(ctx context.Context, _a1 *event.NodeExecutionEvent, eventConfig *config.EventConfig) *EventRecorder_RecordNodeEvent { + c_call := _m.On("RecordNodeEvent", ctx, _a1, eventConfig) + return &EventRecorder_RecordNodeEvent{Call: c_call} +} + +func (_m *EventRecorder) OnRecordNodeEventMatch(matchers ...interface{}) *EventRecorder_RecordNodeEvent { + c_call := _m.On("RecordNodeEvent", matchers...) + return &EventRecorder_RecordNodeEvent{Call: c_call} +} + +// RecordNodeEvent provides a mock function with given fields: ctx, _a1, eventConfig +func (_m *EventRecorder) RecordNodeEvent(ctx context.Context, _a1 *event.NodeExecutionEvent, eventConfig *config.EventConfig) error { + ret := _m.Called(ctx, _a1, eventConfig) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, *event.NodeExecutionEvent, *config.EventConfig) error); ok { + r0 = rf(ctx, _a1, eventConfig) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +type EventRecorder_RecordTaskEvent struct { + *mock.Call +} + +func (_m EventRecorder_RecordTaskEvent) Return(_a0 error) *EventRecorder_RecordTaskEvent { + return &EventRecorder_RecordTaskEvent{Call: _m.Call.Return(_a0)} +} + +func (_m *EventRecorder) OnRecordTaskEvent(ctx context.Context, _a1 *event.TaskExecutionEvent, eventConfig *config.EventConfig) *EventRecorder_RecordTaskEvent { + c_call := _m.On("RecordTaskEvent", ctx, _a1, eventConfig) + return &EventRecorder_RecordTaskEvent{Call: c_call} +} + +func (_m *EventRecorder) OnRecordTaskEventMatch(matchers ...interface{}) *EventRecorder_RecordTaskEvent { + c_call := _m.On("RecordTaskEvent", matchers...) 
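Splitting the contract types into the new interfaces package (NodeExecutor, NodeHandler, SetupContext, HandlerFactory) while the transition and state types stay in handler lets handlers such as the new array node handler take the node executor itself as a dependency without an import cycle. A minimal sketch of a handler written against the relocated interface; the transition helpers (DoTransition, PhaseInfoSuccess) are assumed to come from the existing handler package:

package example

import (
	"context"

	"github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler"
	"github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces"
)

// noopHandler is a hypothetical minimal implementation used only to
// illustrate the new method signatures.
type noopHandler struct{}

// Compile-time check that the sketch satisfies interfaces.NodeHandler.
var _ interfaces.NodeHandler = &noopHandler{}

func (h *noopHandler) FinalizeRequired() bool { return false }

func (h *noopHandler) Setup(_ context.Context, _ interfaces.SetupContext) error { return nil }

// Handle still returns handler.Transition; only the execution context type
// moved packages.
func (h *noopHandler) Handle(_ context.Context, _ interfaces.NodeExecutionContext) (handler.Transition, error) {
	return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoSuccess(nil)), nil
}

func (h *noopHandler) Abort(_ context.Context, _ interfaces.NodeExecutionContext, _ string) error {
	return nil
}

func (h *noopHandler) Finalize(_ context.Context, _ interfaces.NodeExecutionContext) error {
	return nil
}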
+ return &EventRecorder_RecordTaskEvent{Call: c_call} +} + +// RecordTaskEvent provides a mock function with given fields: ctx, _a1, eventConfig +func (_m *EventRecorder) RecordTaskEvent(ctx context.Context, _a1 *event.TaskExecutionEvent, eventConfig *config.EventConfig) error { + ret := _m.Called(ctx, _a1, eventConfig) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, *event.TaskExecutionEvent, *config.EventConfig) error); ok { + r0 = rf(ctx, _a1, eventConfig) + } else { + r0 = ret.Error(0) + } + + return r0 +} diff --git a/flytepropeller/pkg/controller/nodes/mocks/handler_factory.go b/flytepropeller/pkg/controller/nodes/interfaces/mocks/handler_factory.go similarity index 61% rename from flytepropeller/pkg/controller/nodes/mocks/handler_factory.go rename to flytepropeller/pkg/controller/nodes/interfaces/mocks/handler_factory.go index fffa3c8188..ca851495f9 100644 --- a/flytepropeller/pkg/controller/nodes/mocks/handler_factory.go +++ b/flytepropeller/pkg/controller/nodes/interfaces/mocks/handler_factory.go @@ -5,7 +5,7 @@ package mocks import ( context "context" - handler "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" + interfaces "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" mock "github.com/stretchr/testify/mock" v1alpha1 "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" @@ -20,7 +20,7 @@ type HandlerFactory_GetHandler struct { *mock.Call } -func (_m HandlerFactory_GetHandler) Return(_a0 handler.Node, _a1 error) *HandlerFactory_GetHandler { +func (_m HandlerFactory_GetHandler) Return(_a0 interfaces.NodeHandler, _a1 error) *HandlerFactory_GetHandler { return &HandlerFactory_GetHandler{Call: _m.Call.Return(_a0, _a1)} } @@ -35,15 +35,15 @@ func (_m *HandlerFactory) OnGetHandlerMatch(matchers ...interface{}) *HandlerFac } // GetHandler provides a mock function with given fields: kind -func (_m *HandlerFactory) GetHandler(kind v1alpha1.NodeKind) (handler.Node, error) { +func (_m *HandlerFactory) GetHandler(kind v1alpha1.NodeKind) (interfaces.NodeHandler, error) { ret := _m.Called(kind) - var r0 handler.Node - if rf, ok := ret.Get(0).(func(v1alpha1.NodeKind) handler.Node); ok { + var r0 interfaces.NodeHandler + if rf, ok := ret.Get(0).(func(v1alpha1.NodeKind) interfaces.NodeHandler); ok { r0 = rf(kind) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(handler.Node) + r0 = ret.Get(0).(interfaces.NodeHandler) } } @@ -65,8 +65,8 @@ func (_m HandlerFactory_Setup) Return(_a0 error) *HandlerFactory_Setup { return &HandlerFactory_Setup{Call: _m.Call.Return(_a0)} } -func (_m *HandlerFactory) OnSetup(ctx context.Context, setup handler.SetupContext) *HandlerFactory_Setup { - c_call := _m.On("Setup", ctx, setup) +func (_m *HandlerFactory) OnSetup(ctx context.Context, executor interfaces.Node, setup interfaces.SetupContext) *HandlerFactory_Setup { + c_call := _m.On("Setup", ctx, executor, setup) return &HandlerFactory_Setup{Call: c_call} } @@ -75,13 +75,13 @@ func (_m *HandlerFactory) OnSetupMatch(matchers ...interface{}) *HandlerFactory_ return &HandlerFactory_Setup{Call: c_call} } -// Setup provides a mock function with given fields: ctx, setup -func (_m *HandlerFactory) Setup(ctx context.Context, setup handler.SetupContext) error { - ret := _m.Called(ctx, setup) +// Setup provides a mock function with given fields: ctx, executor, setup +func (_m *HandlerFactory) Setup(ctx context.Context, executor interfaces.Node, setup interfaces.SetupContext) error { + ret := _m.Called(ctx, executor, setup) var r0 error - if rf, ok := 
ret.Get(0).(func(context.Context, handler.SetupContext) error); ok { - r0 = rf(ctx, setup) + if rf, ok := ret.Get(0).(func(context.Context, interfaces.Node, interfaces.SetupContext) error); ok { + r0 = rf(ctx, executor, setup) } else { r0 = ret.Error(0) } diff --git a/flytepropeller/pkg/controller/executors/mocks/node.go b/flytepropeller/pkg/controller/nodes/interfaces/mocks/node.go similarity index 68% rename from flytepropeller/pkg/controller/executors/mocks/node.go rename to flytepropeller/pkg/controller/nodes/interfaces/mocks/node.go index 8f2d5cf0c5..0413f7ebcb 100644 --- a/flytepropeller/pkg/controller/executors/mocks/node.go +++ b/flytepropeller/pkg/controller/nodes/interfaces/mocks/node.go @@ -8,6 +8,8 @@ import ( core "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" executors "github.com/flyteorg/flytepropeller/pkg/controller/executors" + interfaces "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" + mock "github.com/stretchr/testify/mock" v1alpha1 "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" @@ -82,6 +84,40 @@ func (_m *Node) FinalizeHandler(ctx context.Context, execContext executors.Execu return r0 } +type Node_GetNodeExecutionContextBuilder struct { + *mock.Call +} + +func (_m Node_GetNodeExecutionContextBuilder) Return(_a0 interfaces.NodeExecutionContextBuilder) *Node_GetNodeExecutionContextBuilder { + return &Node_GetNodeExecutionContextBuilder{Call: _m.Call.Return(_a0)} +} + +func (_m *Node) OnGetNodeExecutionContextBuilder() *Node_GetNodeExecutionContextBuilder { + c_call := _m.On("GetNodeExecutionContextBuilder") + return &Node_GetNodeExecutionContextBuilder{Call: c_call} +} + +func (_m *Node) OnGetNodeExecutionContextBuilderMatch(matchers ...interface{}) *Node_GetNodeExecutionContextBuilder { + c_call := _m.On("GetNodeExecutionContextBuilder", matchers...) 
+ return &Node_GetNodeExecutionContextBuilder{Call: c_call} +} + +// GetNodeExecutionContextBuilder provides a mock function with given fields: +func (_m *Node) GetNodeExecutionContextBuilder() interfaces.NodeExecutionContextBuilder { + ret := _m.Called() + + var r0 interfaces.NodeExecutionContextBuilder + if rf, ok := ret.Get(0).(func() interfaces.NodeExecutionContextBuilder); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(interfaces.NodeExecutionContextBuilder) + } + } + + return r0 +} + type Node_Initialize struct { *mock.Call } @@ -118,7 +154,7 @@ type Node_RecursiveNodeHandler struct { *mock.Call } -func (_m Node_RecursiveNodeHandler) Return(_a0 executors.NodeStatus, _a1 error) *Node_RecursiveNodeHandler { +func (_m Node_RecursiveNodeHandler) Return(_a0 interfaces.NodeStatus, _a1 error) *Node_RecursiveNodeHandler { return &Node_RecursiveNodeHandler{Call: _m.Call.Return(_a0, _a1)} } @@ -133,14 +169,14 @@ func (_m *Node) OnRecursiveNodeHandlerMatch(matchers ...interface{}) *Node_Recur } // RecursiveNodeHandler provides a mock function with given fields: ctx, execContext, dag, nl, currentNode -func (_m *Node) RecursiveNodeHandler(ctx context.Context, execContext executors.ExecutionContext, dag executors.DAGStructure, nl executors.NodeLookup, currentNode v1alpha1.ExecutableNode) (executors.NodeStatus, error) { +func (_m *Node) RecursiveNodeHandler(ctx context.Context, execContext executors.ExecutionContext, dag executors.DAGStructure, nl executors.NodeLookup, currentNode v1alpha1.ExecutableNode) (interfaces.NodeStatus, error) { ret := _m.Called(ctx, execContext, dag, nl, currentNode) - var r0 executors.NodeStatus - if rf, ok := ret.Get(0).(func(context.Context, executors.ExecutionContext, executors.DAGStructure, executors.NodeLookup, v1alpha1.ExecutableNode) executors.NodeStatus); ok { + var r0 interfaces.NodeStatus + if rf, ok := ret.Get(0).(func(context.Context, executors.ExecutionContext, executors.DAGStructure, executors.NodeLookup, v1alpha1.ExecutableNode) interfaces.NodeStatus); ok { r0 = rf(ctx, execContext, dag, nl, currentNode) } else { - r0 = ret.Get(0).(executors.NodeStatus) + r0 = ret.Get(0).(interfaces.NodeStatus) } var r1 error @@ -157,7 +193,7 @@ type Node_SetInputsForStartNode struct { *mock.Call } -func (_m Node_SetInputsForStartNode) Return(_a0 executors.NodeStatus, _a1 error) *Node_SetInputsForStartNode { +func (_m Node_SetInputsForStartNode) Return(_a0 interfaces.NodeStatus, _a1 error) *Node_SetInputsForStartNode { return &Node_SetInputsForStartNode{Call: _m.Call.Return(_a0, _a1)} } @@ -172,14 +208,14 @@ func (_m *Node) OnSetInputsForStartNodeMatch(matchers ...interface{}) *Node_SetI } // SetInputsForStartNode provides a mock function with given fields: ctx, execContext, dag, nl, inputs -func (_m *Node) SetInputsForStartNode(ctx context.Context, execContext executors.ExecutionContext, dag executors.DAGStructureWithStartNode, nl executors.NodeLookup, inputs *core.LiteralMap) (executors.NodeStatus, error) { +func (_m *Node) SetInputsForStartNode(ctx context.Context, execContext executors.ExecutionContext, dag executors.DAGStructureWithStartNode, nl executors.NodeLookup, inputs *core.LiteralMap) (interfaces.NodeStatus, error) { ret := _m.Called(ctx, execContext, dag, nl, inputs) - var r0 executors.NodeStatus - if rf, ok := ret.Get(0).(func(context.Context, executors.ExecutionContext, executors.DAGStructureWithStartNode, executors.NodeLookup, *core.LiteralMap) executors.NodeStatus); ok { + var r0 interfaces.NodeStatus + if rf, ok := 
ret.Get(0).(func(context.Context, executors.ExecutionContext, executors.DAGStructureWithStartNode, executors.NodeLookup, *core.LiteralMap) interfaces.NodeStatus); ok { r0 = rf(ctx, execContext, dag, nl, inputs) } else { - r0 = ret.Get(0).(executors.NodeStatus) + r0 = ret.Get(0).(interfaces.NodeStatus) } var r1 error @@ -191,3 +227,37 @@ func (_m *Node) SetInputsForStartNode(ctx context.Context, execContext executors return r0, r1 } + +type Node_WithNodeExecutionContextBuilder struct { + *mock.Call +} + +func (_m Node_WithNodeExecutionContextBuilder) Return(_a0 interfaces.Node) *Node_WithNodeExecutionContextBuilder { + return &Node_WithNodeExecutionContextBuilder{Call: _m.Call.Return(_a0)} +} + +func (_m *Node) OnWithNodeExecutionContextBuilder(_a0 interfaces.NodeExecutionContextBuilder) *Node_WithNodeExecutionContextBuilder { + c_call := _m.On("WithNodeExecutionContextBuilder", _a0) + return &Node_WithNodeExecutionContextBuilder{Call: c_call} +} + +func (_m *Node) OnWithNodeExecutionContextBuilderMatch(matchers ...interface{}) *Node_WithNodeExecutionContextBuilder { + c_call := _m.On("WithNodeExecutionContextBuilder", matchers...) + return &Node_WithNodeExecutionContextBuilder{Call: c_call} +} + +// WithNodeExecutionContextBuilder provides a mock function with given fields: _a0 +func (_m *Node) WithNodeExecutionContextBuilder(_a0 interfaces.NodeExecutionContextBuilder) interfaces.Node { + ret := _m.Called(_a0) + + var r0 interfaces.Node + if rf, ok := ret.Get(0).(func(interfaces.NodeExecutionContextBuilder) interfaces.Node); ok { + r0 = rf(_a0) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(interfaces.Node) + } + } + + return r0 +} diff --git a/flytepropeller/pkg/controller/nodes/handler/mocks/node_execution_context.go b/flytepropeller/pkg/controller/nodes/interfaces/mocks/node_execution_context.go similarity index 90% rename from flytepropeller/pkg/controller/nodes/handler/mocks/node_execution_context.go rename to flytepropeller/pkg/controller/nodes/interfaces/mocks/node_execution_context.go index 434f78caa4..fcf130e304 100644 --- a/flytepropeller/pkg/controller/nodes/handler/mocks/node_execution_context.go +++ b/flytepropeller/pkg/controller/nodes/interfaces/mocks/node_execution_context.go @@ -3,9 +3,8 @@ package mocks import ( - events "github.com/flyteorg/flytepropeller/events" executors "github.com/flyteorg/flytepropeller/pkg/controller/executors" - handler "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" + interfaces "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" io "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io" @@ -161,7 +160,7 @@ type NodeExecutionContext_EventsRecorder struct { *mock.Call } -func (_m NodeExecutionContext_EventsRecorder) Return(_a0 events.TaskEventRecorder) *NodeExecutionContext_EventsRecorder { +func (_m NodeExecutionContext_EventsRecorder) Return(_a0 interfaces.EventRecorder) *NodeExecutionContext_EventsRecorder { return &NodeExecutionContext_EventsRecorder{Call: _m.Call.Return(_a0)} } @@ -176,15 +175,15 @@ func (_m *NodeExecutionContext) OnEventsRecorderMatch(matchers ...interface{}) * } // EventsRecorder provides a mock function with given fields: -func (_m *NodeExecutionContext) EventsRecorder() events.TaskEventRecorder { +func (_m *NodeExecutionContext) EventsRecorder() interfaces.EventRecorder { ret := _m.Called() - var r0 events.TaskEventRecorder - if rf, ok := ret.Get(0).(func() events.TaskEventRecorder); ok { + var r0 interfaces.EventRecorder + if rf, ok := ret.Get(0).(func() 
interfaces.EventRecorder); ok { r0 = rf() } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(events.TaskEventRecorder) + r0 = ret.Get(0).(interfaces.EventRecorder) } } @@ -329,7 +328,7 @@ type NodeExecutionContext_NodeExecutionMetadata struct { *mock.Call } -func (_m NodeExecutionContext_NodeExecutionMetadata) Return(_a0 handler.NodeExecutionMetadata) *NodeExecutionContext_NodeExecutionMetadata { +func (_m NodeExecutionContext_NodeExecutionMetadata) Return(_a0 interfaces.NodeExecutionMetadata) *NodeExecutionContext_NodeExecutionMetadata { return &NodeExecutionContext_NodeExecutionMetadata{Call: _m.Call.Return(_a0)} } @@ -344,15 +343,15 @@ func (_m *NodeExecutionContext) OnNodeExecutionMetadataMatch(matchers ...interfa } // NodeExecutionMetadata provides a mock function with given fields: -func (_m *NodeExecutionContext) NodeExecutionMetadata() handler.NodeExecutionMetadata { +func (_m *NodeExecutionContext) NodeExecutionMetadata() interfaces.NodeExecutionMetadata { ret := _m.Called() - var r0 handler.NodeExecutionMetadata - if rf, ok := ret.Get(0).(func() handler.NodeExecutionMetadata); ok { + var r0 interfaces.NodeExecutionMetadata + if rf, ok := ret.Get(0).(func() interfaces.NodeExecutionMetadata); ok { r0 = rf() } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(handler.NodeExecutionMetadata) + r0 = ret.Get(0).(interfaces.NodeExecutionMetadata) } } @@ -395,7 +394,7 @@ type NodeExecutionContext_NodeStateReader struct { *mock.Call } -func (_m NodeExecutionContext_NodeStateReader) Return(_a0 handler.NodeStateReader) *NodeExecutionContext_NodeStateReader { +func (_m NodeExecutionContext_NodeStateReader) Return(_a0 interfaces.NodeStateReader) *NodeExecutionContext_NodeStateReader { return &NodeExecutionContext_NodeStateReader{Call: _m.Call.Return(_a0)} } @@ -410,15 +409,15 @@ func (_m *NodeExecutionContext) OnNodeStateReaderMatch(matchers ...interface{}) } // NodeStateReader provides a mock function with given fields: -func (_m *NodeExecutionContext) NodeStateReader() handler.NodeStateReader { +func (_m *NodeExecutionContext) NodeStateReader() interfaces.NodeStateReader { ret := _m.Called() - var r0 handler.NodeStateReader - if rf, ok := ret.Get(0).(func() handler.NodeStateReader); ok { + var r0 interfaces.NodeStateReader + if rf, ok := ret.Get(0).(func() interfaces.NodeStateReader); ok { r0 = rf() } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(handler.NodeStateReader) + r0 = ret.Get(0).(interfaces.NodeStateReader) } } @@ -429,7 +428,7 @@ type NodeExecutionContext_NodeStateWriter struct { *mock.Call } -func (_m NodeExecutionContext_NodeStateWriter) Return(_a0 handler.NodeStateWriter) *NodeExecutionContext_NodeStateWriter { +func (_m NodeExecutionContext_NodeStateWriter) Return(_a0 interfaces.NodeStateWriter) *NodeExecutionContext_NodeStateWriter { return &NodeExecutionContext_NodeStateWriter{Call: _m.Call.Return(_a0)} } @@ -444,15 +443,15 @@ func (_m *NodeExecutionContext) OnNodeStateWriterMatch(matchers ...interface{}) } // NodeStateWriter provides a mock function with given fields: -func (_m *NodeExecutionContext) NodeStateWriter() handler.NodeStateWriter { +func (_m *NodeExecutionContext) NodeStateWriter() interfaces.NodeStateWriter { ret := _m.Called() - var r0 handler.NodeStateWriter - if rf, ok := ret.Get(0).(func() handler.NodeStateWriter); ok { + var r0 interfaces.NodeStateWriter + if rf, ok := ret.Get(0).(func() interfaces.NodeStateWriter); ok { r0 = rf() } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(handler.NodeStateWriter) + r0 = ret.Get(0).(interfaces.NodeStateWriter) } 
} @@ -563,7 +562,7 @@ type NodeExecutionContext_TaskReader struct { *mock.Call } -func (_m NodeExecutionContext_TaskReader) Return(_a0 handler.TaskReader) *NodeExecutionContext_TaskReader { +func (_m NodeExecutionContext_TaskReader) Return(_a0 interfaces.TaskReader) *NodeExecutionContext_TaskReader { return &NodeExecutionContext_TaskReader{Call: _m.Call.Return(_a0)} } @@ -578,15 +577,15 @@ func (_m *NodeExecutionContext) OnTaskReaderMatch(matchers ...interface{}) *Node } // TaskReader provides a mock function with given fields: -func (_m *NodeExecutionContext) TaskReader() handler.TaskReader { +func (_m *NodeExecutionContext) TaskReader() interfaces.TaskReader { ret := _m.Called() - var r0 handler.TaskReader - if rf, ok := ret.Get(0).(func() handler.TaskReader); ok { + var r0 interfaces.TaskReader + if rf, ok := ret.Get(0).(func() interfaces.TaskReader); ok { r0 = rf() } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(handler.TaskReader) + r0 = ret.Get(0).(interfaces.TaskReader) } } diff --git a/flytepropeller/pkg/controller/nodes/interfaces/mocks/node_execution_context_builder.go b/flytepropeller/pkg/controller/nodes/interfaces/mocks/node_execution_context_builder.go new file mode 100644 index 0000000000..d068f19025 --- /dev/null +++ b/flytepropeller/pkg/controller/nodes/interfaces/mocks/node_execution_context_builder.go @@ -0,0 +1,58 @@ +// Code generated by mockery v1.0.1. DO NOT EDIT. + +package mocks + +import ( + context "context" + + executors "github.com/flyteorg/flytepropeller/pkg/controller/executors" + interfaces "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" + + mock "github.com/stretchr/testify/mock" +) + +// NodeExecutionContextBuilder is an autogenerated mock type for the NodeExecutionContextBuilder type +type NodeExecutionContextBuilder struct { + mock.Mock +} + +type NodeExecutionContextBuilder_BuildNodeExecutionContext struct { + *mock.Call +} + +func (_m NodeExecutionContextBuilder_BuildNodeExecutionContext) Return(_a0 interfaces.NodeExecutionContext, _a1 error) *NodeExecutionContextBuilder_BuildNodeExecutionContext { + return &NodeExecutionContextBuilder_BuildNodeExecutionContext{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *NodeExecutionContextBuilder) OnBuildNodeExecutionContext(ctx context.Context, executionContext executors.ExecutionContext, nl executors.NodeLookup, currentNodeID string) *NodeExecutionContextBuilder_BuildNodeExecutionContext { + c_call := _m.On("BuildNodeExecutionContext", ctx, executionContext, nl, currentNodeID) + return &NodeExecutionContextBuilder_BuildNodeExecutionContext{Call: c_call} +} + +func (_m *NodeExecutionContextBuilder) OnBuildNodeExecutionContextMatch(matchers ...interface{}) *NodeExecutionContextBuilder_BuildNodeExecutionContext { + c_call := _m.On("BuildNodeExecutionContext", matchers...) 
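The Node mock above gains GetNodeExecutionContextBuilder and WithNodeExecutionContextBuilder, and the new NodeExecutionContextBuilder mock here covers the hook that lets a caller (such as the array node handler) re-parameterize how per-node execution contexts are built. A test sketch wiring these generated mocks together; the node ID is illustrative:

package example_test

import (
	"context"
	"testing"

	nodemocks "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces/mocks"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/mock"
)

func TestExecutorWithContextBuilder(t *testing.T) {
	nCtx := &nodemocks.NodeExecutionContext{}

	// Stub the builder to hand back a canned context for any arguments.
	builder := &nodemocks.NodeExecutionContextBuilder{}
	builder.OnBuildNodeExecutionContextMatch(
		mock.Anything, mock.Anything, mock.Anything, mock.Anything,
	).Return(nCtx, nil)

	built, err := builder.BuildNodeExecutionContext(context.TODO(), nil, nil, "n0")
	assert.NoError(t, err)
	assert.Equal(t, nCtx, built)

	// The executor mock returns itself when re-parameterized with the builder.
	executor := &nodemocks.Node{}
	executor.OnWithNodeExecutionContextBuilder(builder).Return(executor)
	assert.NotNil(t, executor.WithNodeExecutionContextBuilder(builder))
}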
+ return &NodeExecutionContextBuilder_BuildNodeExecutionContext{Call: c_call} +} + +// BuildNodeExecutionContext provides a mock function with given fields: ctx, executionContext, nl, currentNodeID +func (_m *NodeExecutionContextBuilder) BuildNodeExecutionContext(ctx context.Context, executionContext executors.ExecutionContext, nl executors.NodeLookup, currentNodeID string) (interfaces.NodeExecutionContext, error) { + ret := _m.Called(ctx, executionContext, nl, currentNodeID) + + var r0 interfaces.NodeExecutionContext + if rf, ok := ret.Get(0).(func(context.Context, executors.ExecutionContext, executors.NodeLookup, string) interfaces.NodeExecutionContext); ok { + r0 = rf(ctx, executionContext, nl, currentNodeID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(interfaces.NodeExecutionContext) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, executors.ExecutionContext, executors.NodeLookup, string) error); ok { + r1 = rf(ctx, executionContext, nl, currentNodeID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} diff --git a/flytepropeller/pkg/controller/nodes/handler/mocks/node_execution_metadata.go b/flytepropeller/pkg/controller/nodes/interfaces/mocks/node_execution_metadata.go similarity index 100% rename from flytepropeller/pkg/controller/nodes/handler/mocks/node_execution_metadata.go rename to flytepropeller/pkg/controller/nodes/interfaces/mocks/node_execution_metadata.go diff --git a/flytepropeller/pkg/controller/nodes/interfaces/mocks/node_executor.go b/flytepropeller/pkg/controller/nodes/interfaces/mocks/node_executor.go new file mode 100644 index 0000000000..e619c8fd70 --- /dev/null +++ b/flytepropeller/pkg/controller/nodes/interfaces/mocks/node_executor.go @@ -0,0 +1,120 @@ +// Code generated by mockery v1.0.1. DO NOT EDIT. + +package mocks + +import ( + context "context" + + executors "github.com/flyteorg/flytepropeller/pkg/controller/executors" + interfaces "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" + + mock "github.com/stretchr/testify/mock" +) + +// NodeExecutor is an autogenerated mock type for the NodeExecutor type +type NodeExecutor struct { + mock.Mock +} + +type NodeExecutor_Abort struct { + *mock.Call +} + +func (_m NodeExecutor_Abort) Return(_a0 error) *NodeExecutor_Abort { + return &NodeExecutor_Abort{Call: _m.Call.Return(_a0)} +} + +func (_m *NodeExecutor) OnAbort(ctx context.Context, h interfaces.NodeHandler, nCtx interfaces.NodeExecutionContext, reason string, finalTransition bool) *NodeExecutor_Abort { + c_call := _m.On("Abort", ctx, h, nCtx, reason, finalTransition) + return &NodeExecutor_Abort{Call: c_call} +} + +func (_m *NodeExecutor) OnAbortMatch(matchers ...interface{}) *NodeExecutor_Abort { + c_call := _m.On("Abort", matchers...) 
+ return &NodeExecutor_Abort{Call: c_call} +} + +// Abort provides a mock function with given fields: ctx, h, nCtx, reason, finalTransition +func (_m *NodeExecutor) Abort(ctx context.Context, h interfaces.NodeHandler, nCtx interfaces.NodeExecutionContext, reason string, finalTransition bool) error { + ret := _m.Called(ctx, h, nCtx, reason, finalTransition) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, interfaces.NodeHandler, interfaces.NodeExecutionContext, string, bool) error); ok { + r0 = rf(ctx, h, nCtx, reason, finalTransition) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +type NodeExecutor_Finalize struct { + *mock.Call +} + +func (_m NodeExecutor_Finalize) Return(_a0 error) *NodeExecutor_Finalize { + return &NodeExecutor_Finalize{Call: _m.Call.Return(_a0)} +} + +func (_m *NodeExecutor) OnFinalize(ctx context.Context, h interfaces.NodeHandler, nCtx interfaces.NodeExecutionContext) *NodeExecutor_Finalize { + c_call := _m.On("Finalize", ctx, h, nCtx) + return &NodeExecutor_Finalize{Call: c_call} +} + +func (_m *NodeExecutor) OnFinalizeMatch(matchers ...interface{}) *NodeExecutor_Finalize { + c_call := _m.On("Finalize", matchers...) + return &NodeExecutor_Finalize{Call: c_call} +} + +// Finalize provides a mock function with given fields: ctx, h, nCtx +func (_m *NodeExecutor) Finalize(ctx context.Context, h interfaces.NodeHandler, nCtx interfaces.NodeExecutionContext) error { + ret := _m.Called(ctx, h, nCtx) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, interfaces.NodeHandler, interfaces.NodeExecutionContext) error); ok { + r0 = rf(ctx, h, nCtx) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +type NodeExecutor_HandleNode struct { + *mock.Call +} + +func (_m NodeExecutor_HandleNode) Return(_a0 interfaces.NodeStatus, _a1 error) *NodeExecutor_HandleNode { + return &NodeExecutor_HandleNode{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *NodeExecutor) OnHandleNode(ctx context.Context, dag executors.DAGStructure, nCtx interfaces.NodeExecutionContext, h interfaces.NodeHandler) *NodeExecutor_HandleNode { + c_call := _m.On("HandleNode", ctx, dag, nCtx, h) + return &NodeExecutor_HandleNode{Call: c_call} +} + +func (_m *NodeExecutor) OnHandleNodeMatch(matchers ...interface{}) *NodeExecutor_HandleNode { + c_call := _m.On("HandleNode", matchers...) + return &NodeExecutor_HandleNode{Call: c_call} +} + +// HandleNode provides a mock function with given fields: ctx, dag, nCtx, h +func (_m *NodeExecutor) HandleNode(ctx context.Context, dag executors.DAGStructure, nCtx interfaces.NodeExecutionContext, h interfaces.NodeHandler) (interfaces.NodeStatus, error) { + ret := _m.Called(ctx, dag, nCtx, h) + + var r0 interfaces.NodeStatus + if rf, ok := ret.Get(0).(func(context.Context, executors.DAGStructure, interfaces.NodeExecutionContext, interfaces.NodeHandler) interfaces.NodeStatus); ok { + r0 = rf(ctx, dag, nCtx, h) + } else { + r0 = ret.Get(0).(interfaces.NodeStatus) + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, executors.DAGStructure, interfaces.NodeExecutionContext, interfaces.NodeHandler) error); ok { + r1 = rf(ctx, dag, nCtx, h) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} diff --git a/flytepropeller/pkg/controller/nodes/interfaces/mocks/node_handler.go b/flytepropeller/pkg/controller/nodes/interfaces/mocks/node_handler.go new file mode 100644 index 0000000000..66bd61d27d --- /dev/null +++ b/flytepropeller/pkg/controller/nodes/interfaces/mocks/node_handler.go @@ -0,0 +1,184 @@ +// Code generated by mockery v1.0.1. 
DO NOT EDIT. + +package mocks + +import ( + context "context" + + handler "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" + interfaces "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" + + mock "github.com/stretchr/testify/mock" +) + +// NodeHandler is an autogenerated mock type for the NodeHandler type +type NodeHandler struct { + mock.Mock +} + +type NodeHandler_Abort struct { + *mock.Call +} + +func (_m NodeHandler_Abort) Return(_a0 error) *NodeHandler_Abort { + return &NodeHandler_Abort{Call: _m.Call.Return(_a0)} +} + +func (_m *NodeHandler) OnAbort(ctx context.Context, executionContext interfaces.NodeExecutionContext, reason string) *NodeHandler_Abort { + c_call := _m.On("Abort", ctx, executionContext, reason) + return &NodeHandler_Abort{Call: c_call} +} + +func (_m *NodeHandler) OnAbortMatch(matchers ...interface{}) *NodeHandler_Abort { + c_call := _m.On("Abort", matchers...) + return &NodeHandler_Abort{Call: c_call} +} + +// Abort provides a mock function with given fields: ctx, executionContext, reason +func (_m *NodeHandler) Abort(ctx context.Context, executionContext interfaces.NodeExecutionContext, reason string) error { + ret := _m.Called(ctx, executionContext, reason) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, interfaces.NodeExecutionContext, string) error); ok { + r0 = rf(ctx, executionContext, reason) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +type NodeHandler_Finalize struct { + *mock.Call +} + +func (_m NodeHandler_Finalize) Return(_a0 error) *NodeHandler_Finalize { + return &NodeHandler_Finalize{Call: _m.Call.Return(_a0)} +} + +func (_m *NodeHandler) OnFinalize(ctx context.Context, executionContext interfaces.NodeExecutionContext) *NodeHandler_Finalize { + c_call := _m.On("Finalize", ctx, executionContext) + return &NodeHandler_Finalize{Call: c_call} +} + +func (_m *NodeHandler) OnFinalizeMatch(matchers ...interface{}) *NodeHandler_Finalize { + c_call := _m.On("Finalize", matchers...) + return &NodeHandler_Finalize{Call: c_call} +} + +// Finalize provides a mock function with given fields: ctx, executionContext +func (_m *NodeHandler) Finalize(ctx context.Context, executionContext interfaces.NodeExecutionContext) error { + ret := _m.Called(ctx, executionContext) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, interfaces.NodeExecutionContext) error); ok { + r0 = rf(ctx, executionContext) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +type NodeHandler_FinalizeRequired struct { + *mock.Call +} + +func (_m NodeHandler_FinalizeRequired) Return(_a0 bool) *NodeHandler_FinalizeRequired { + return &NodeHandler_FinalizeRequired{Call: _m.Call.Return(_a0)} +} + +func (_m *NodeHandler) OnFinalizeRequired() *NodeHandler_FinalizeRequired { + c_call := _m.On("FinalizeRequired") + return &NodeHandler_FinalizeRequired{Call: c_call} +} + +func (_m *NodeHandler) OnFinalizeRequiredMatch(matchers ...interface{}) *NodeHandler_FinalizeRequired { + c_call := _m.On("FinalizeRequired", matchers...) 
+ return &NodeHandler_FinalizeRequired{Call: c_call} +} + +// FinalizeRequired provides a mock function with given fields: +func (_m *NodeHandler) FinalizeRequired() bool { + ret := _m.Called() + + var r0 bool + if rf, ok := ret.Get(0).(func() bool); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + +type NodeHandler_Handle struct { + *mock.Call +} + +func (_m NodeHandler_Handle) Return(_a0 handler.Transition, _a1 error) *NodeHandler_Handle { + return &NodeHandler_Handle{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *NodeHandler) OnHandle(ctx context.Context, executionContext interfaces.NodeExecutionContext) *NodeHandler_Handle { + c_call := _m.On("Handle", ctx, executionContext) + return &NodeHandler_Handle{Call: c_call} +} + +func (_m *NodeHandler) OnHandleMatch(matchers ...interface{}) *NodeHandler_Handle { + c_call := _m.On("Handle", matchers...) + return &NodeHandler_Handle{Call: c_call} +} + +// Handle provides a mock function with given fields: ctx, executionContext +func (_m *NodeHandler) Handle(ctx context.Context, executionContext interfaces.NodeExecutionContext) (handler.Transition, error) { + ret := _m.Called(ctx, executionContext) + + var r0 handler.Transition + if rf, ok := ret.Get(0).(func(context.Context, interfaces.NodeExecutionContext) handler.Transition); ok { + r0 = rf(ctx, executionContext) + } else { + r0 = ret.Get(0).(handler.Transition) + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, interfaces.NodeExecutionContext) error); ok { + r1 = rf(ctx, executionContext) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +type NodeHandler_Setup struct { + *mock.Call +} + +func (_m NodeHandler_Setup) Return(_a0 error) *NodeHandler_Setup { + return &NodeHandler_Setup{Call: _m.Call.Return(_a0)} +} + +func (_m *NodeHandler) OnSetup(ctx context.Context, setupContext interfaces.SetupContext) *NodeHandler_Setup { + c_call := _m.On("Setup", ctx, setupContext) + return &NodeHandler_Setup{Call: c_call} +} + +func (_m *NodeHandler) OnSetupMatch(matchers ...interface{}) *NodeHandler_Setup { + c_call := _m.On("Setup", matchers...) + return &NodeHandler_Setup{Call: c_call} +} + +// Setup provides a mock function with given fields: ctx, setupContext +func (_m *NodeHandler) Setup(ctx context.Context, setupContext interfaces.SetupContext) error { + ret := _m.Called(ctx, setupContext) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, interfaces.SetupContext) error); ok { + r0 = rf(ctx, setupContext) + } else { + r0 = ret.Error(0) + } + + return r0 +} diff --git a/flytepropeller/pkg/controller/nodes/interfaces/mocks/node_state_reader.go b/flytepropeller/pkg/controller/nodes/interfaces/mocks/node_state_reader.go new file mode 100644 index 0000000000..853eb5b674 --- /dev/null +++ b/flytepropeller/pkg/controller/nodes/interfaces/mocks/node_state_reader.go @@ -0,0 +1,398 @@ +// Code generated by mockery v1.0.1. DO NOT EDIT. 
+ +package mocks + +import ( + handler "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" + + mock "github.com/stretchr/testify/mock" +) + +// NodeStateReader is an autogenerated mock type for the NodeStateReader type +type NodeStateReader struct { + mock.Mock +} + +type NodeStateReader_GetArrayNodeState struct { + *mock.Call +} + +func (_m NodeStateReader_GetArrayNodeState) Return(_a0 handler.ArrayNodeState) *NodeStateReader_GetArrayNodeState { + return &NodeStateReader_GetArrayNodeState{Call: _m.Call.Return(_a0)} +} + +func (_m *NodeStateReader) OnGetArrayNodeState() *NodeStateReader_GetArrayNodeState { + c_call := _m.On("GetArrayNodeState") + return &NodeStateReader_GetArrayNodeState{Call: c_call} +} + +func (_m *NodeStateReader) OnGetArrayNodeStateMatch(matchers ...interface{}) *NodeStateReader_GetArrayNodeState { + c_call := _m.On("GetArrayNodeState", matchers...) + return &NodeStateReader_GetArrayNodeState{Call: c_call} +} + +// GetArrayNodeState provides a mock function with given fields: +func (_m *NodeStateReader) GetArrayNodeState() handler.ArrayNodeState { + ret := _m.Called() + + var r0 handler.ArrayNodeState + if rf, ok := ret.Get(0).(func() handler.ArrayNodeState); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(handler.ArrayNodeState) + } + + return r0 +} + +type NodeStateReader_GetBranchNodeState struct { + *mock.Call +} + +func (_m NodeStateReader_GetBranchNodeState) Return(_a0 handler.BranchNodeState) *NodeStateReader_GetBranchNodeState { + return &NodeStateReader_GetBranchNodeState{Call: _m.Call.Return(_a0)} +} + +func (_m *NodeStateReader) OnGetBranchNodeState() *NodeStateReader_GetBranchNodeState { + c_call := _m.On("GetBranchNodeState") + return &NodeStateReader_GetBranchNodeState{Call: c_call} +} + +func (_m *NodeStateReader) OnGetBranchNodeStateMatch(matchers ...interface{}) *NodeStateReader_GetBranchNodeState { + c_call := _m.On("GetBranchNodeState", matchers...) + return &NodeStateReader_GetBranchNodeState{Call: c_call} +} + +// GetBranchNodeState provides a mock function with given fields: +func (_m *NodeStateReader) GetBranchNodeState() handler.BranchNodeState { + ret := _m.Called() + + var r0 handler.BranchNodeState + if rf, ok := ret.Get(0).(func() handler.BranchNodeState); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(handler.BranchNodeState) + } + + return r0 +} + +type NodeStateReader_GetDynamicNodeState struct { + *mock.Call +} + +func (_m NodeStateReader_GetDynamicNodeState) Return(_a0 handler.DynamicNodeState) *NodeStateReader_GetDynamicNodeState { + return &NodeStateReader_GetDynamicNodeState{Call: _m.Call.Return(_a0)} +} + +func (_m *NodeStateReader) OnGetDynamicNodeState() *NodeStateReader_GetDynamicNodeState { + c_call := _m.On("GetDynamicNodeState") + return &NodeStateReader_GetDynamicNodeState{Call: c_call} +} + +func (_m *NodeStateReader) OnGetDynamicNodeStateMatch(matchers ...interface{}) *NodeStateReader_GetDynamicNodeState { + c_call := _m.On("GetDynamicNodeState", matchers...) 
+ return &NodeStateReader_GetDynamicNodeState{Call: c_call} +} + +// GetDynamicNodeState provides a mock function with given fields: +func (_m *NodeStateReader) GetDynamicNodeState() handler.DynamicNodeState { + ret := _m.Called() + + var r0 handler.DynamicNodeState + if rf, ok := ret.Get(0).(func() handler.DynamicNodeState); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(handler.DynamicNodeState) + } + + return r0 +} + +type NodeStateReader_GetGateNodeState struct { + *mock.Call +} + +func (_m NodeStateReader_GetGateNodeState) Return(_a0 handler.GateNodeState) *NodeStateReader_GetGateNodeState { + return &NodeStateReader_GetGateNodeState{Call: _m.Call.Return(_a0)} +} + +func (_m *NodeStateReader) OnGetGateNodeState() *NodeStateReader_GetGateNodeState { + c_call := _m.On("GetGateNodeState") + return &NodeStateReader_GetGateNodeState{Call: c_call} +} + +func (_m *NodeStateReader) OnGetGateNodeStateMatch(matchers ...interface{}) *NodeStateReader_GetGateNodeState { + c_call := _m.On("GetGateNodeState", matchers...) + return &NodeStateReader_GetGateNodeState{Call: c_call} +} + +// GetGateNodeState provides a mock function with given fields: +func (_m *NodeStateReader) GetGateNodeState() handler.GateNodeState { + ret := _m.Called() + + var r0 handler.GateNodeState + if rf, ok := ret.Get(0).(func() handler.GateNodeState); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(handler.GateNodeState) + } + + return r0 +} + +type NodeStateReader_GetTaskNodeState struct { + *mock.Call +} + +func (_m NodeStateReader_GetTaskNodeState) Return(_a0 handler.TaskNodeState) *NodeStateReader_GetTaskNodeState { + return &NodeStateReader_GetTaskNodeState{Call: _m.Call.Return(_a0)} +} + +func (_m *NodeStateReader) OnGetTaskNodeState() *NodeStateReader_GetTaskNodeState { + c_call := _m.On("GetTaskNodeState") + return &NodeStateReader_GetTaskNodeState{Call: c_call} +} + +func (_m *NodeStateReader) OnGetTaskNodeStateMatch(matchers ...interface{}) *NodeStateReader_GetTaskNodeState { + c_call := _m.On("GetTaskNodeState", matchers...) + return &NodeStateReader_GetTaskNodeState{Call: c_call} +} + +// GetTaskNodeState provides a mock function with given fields: +func (_m *NodeStateReader) GetTaskNodeState() handler.TaskNodeState { + ret := _m.Called() + + var r0 handler.TaskNodeState + if rf, ok := ret.Get(0).(func() handler.TaskNodeState); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(handler.TaskNodeState) + } + + return r0 +} + +type NodeStateReader_GetWorkflowNodeState struct { + *mock.Call +} + +func (_m NodeStateReader_GetWorkflowNodeState) Return(_a0 handler.WorkflowNodeState) *NodeStateReader_GetWorkflowNodeState { + return &NodeStateReader_GetWorkflowNodeState{Call: _m.Call.Return(_a0)} +} + +func (_m *NodeStateReader) OnGetWorkflowNodeState() *NodeStateReader_GetWorkflowNodeState { + c_call := _m.On("GetWorkflowNodeState") + return &NodeStateReader_GetWorkflowNodeState{Call: c_call} +} + +func (_m *NodeStateReader) OnGetWorkflowNodeStateMatch(matchers ...interface{}) *NodeStateReader_GetWorkflowNodeState { + c_call := _m.On("GetWorkflowNodeState", matchers...) 
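+	// Illustrative usage (editorial sketch, not part of the generated mock): a
+	// test exercising the new array-node support could stub this reader as
+	// below; `nsr` and the zero-value state are hypothetical.
+	//
+	//	nsr := &mocks.NodeStateReader{}
+	//	nsr.OnHasArrayNodeState().Return(true)
+	//	nsr.OnGetArrayNodeState().Return(handler.ArrayNodeState{})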
+ return &NodeStateReader_GetWorkflowNodeState{Call: c_call} +} + +// GetWorkflowNodeState provides a mock function with given fields: +func (_m *NodeStateReader) GetWorkflowNodeState() handler.WorkflowNodeState { + ret := _m.Called() + + var r0 handler.WorkflowNodeState + if rf, ok := ret.Get(0).(func() handler.WorkflowNodeState); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(handler.WorkflowNodeState) + } + + return r0 +} + +type NodeStateReader_HasArrayNodeState struct { + *mock.Call +} + +func (_m NodeStateReader_HasArrayNodeState) Return(_a0 bool) *NodeStateReader_HasArrayNodeState { + return &NodeStateReader_HasArrayNodeState{Call: _m.Call.Return(_a0)} +} + +func (_m *NodeStateReader) OnHasArrayNodeState() *NodeStateReader_HasArrayNodeState { + c_call := _m.On("HasArrayNodeState") + return &NodeStateReader_HasArrayNodeState{Call: c_call} +} + +func (_m *NodeStateReader) OnHasArrayNodeStateMatch(matchers ...interface{}) *NodeStateReader_HasArrayNodeState { + c_call := _m.On("HasArrayNodeState", matchers...) + return &NodeStateReader_HasArrayNodeState{Call: c_call} +} + +// HasArrayNodeState provides a mock function with given fields: +func (_m *NodeStateReader) HasArrayNodeState() bool { + ret := _m.Called() + + var r0 bool + if rf, ok := ret.Get(0).(func() bool); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + +type NodeStateReader_HasBranchNodeState struct { + *mock.Call +} + +func (_m NodeStateReader_HasBranchNodeState) Return(_a0 bool) *NodeStateReader_HasBranchNodeState { + return &NodeStateReader_HasBranchNodeState{Call: _m.Call.Return(_a0)} +} + +func (_m *NodeStateReader) OnHasBranchNodeState() *NodeStateReader_HasBranchNodeState { + c_call := _m.On("HasBranchNodeState") + return &NodeStateReader_HasBranchNodeState{Call: c_call} +} + +func (_m *NodeStateReader) OnHasBranchNodeStateMatch(matchers ...interface{}) *NodeStateReader_HasBranchNodeState { + c_call := _m.On("HasBranchNodeState", matchers...) + return &NodeStateReader_HasBranchNodeState{Call: c_call} +} + +// HasBranchNodeState provides a mock function with given fields: +func (_m *NodeStateReader) HasBranchNodeState() bool { + ret := _m.Called() + + var r0 bool + if rf, ok := ret.Get(0).(func() bool); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + +type NodeStateReader_HasDynamicNodeState struct { + *mock.Call +} + +func (_m NodeStateReader_HasDynamicNodeState) Return(_a0 bool) *NodeStateReader_HasDynamicNodeState { + return &NodeStateReader_HasDynamicNodeState{Call: _m.Call.Return(_a0)} +} + +func (_m *NodeStateReader) OnHasDynamicNodeState() *NodeStateReader_HasDynamicNodeState { + c_call := _m.On("HasDynamicNodeState") + return &NodeStateReader_HasDynamicNodeState{Call: c_call} +} + +func (_m *NodeStateReader) OnHasDynamicNodeStateMatch(matchers ...interface{}) *NodeStateReader_HasDynamicNodeState { + c_call := _m.On("HasDynamicNodeState", matchers...) 
+ return &NodeStateReader_HasDynamicNodeState{Call: c_call} +} + +// HasDynamicNodeState provides a mock function with given fields: +func (_m *NodeStateReader) HasDynamicNodeState() bool { + ret := _m.Called() + + var r0 bool + if rf, ok := ret.Get(0).(func() bool); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + +type NodeStateReader_HasGateNodeState struct { + *mock.Call +} + +func (_m NodeStateReader_HasGateNodeState) Return(_a0 bool) *NodeStateReader_HasGateNodeState { + return &NodeStateReader_HasGateNodeState{Call: _m.Call.Return(_a0)} +} + +func (_m *NodeStateReader) OnHasGateNodeState() *NodeStateReader_HasGateNodeState { + c_call := _m.On("HasGateNodeState") + return &NodeStateReader_HasGateNodeState{Call: c_call} +} + +func (_m *NodeStateReader) OnHasGateNodeStateMatch(matchers ...interface{}) *NodeStateReader_HasGateNodeState { + c_call := _m.On("HasGateNodeState", matchers...) + return &NodeStateReader_HasGateNodeState{Call: c_call} +} + +// HasGateNodeState provides a mock function with given fields: +func (_m *NodeStateReader) HasGateNodeState() bool { + ret := _m.Called() + + var r0 bool + if rf, ok := ret.Get(0).(func() bool); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + +type NodeStateReader_HasTaskNodeState struct { + *mock.Call +} + +func (_m NodeStateReader_HasTaskNodeState) Return(_a0 bool) *NodeStateReader_HasTaskNodeState { + return &NodeStateReader_HasTaskNodeState{Call: _m.Call.Return(_a0)} +} + +func (_m *NodeStateReader) OnHasTaskNodeState() *NodeStateReader_HasTaskNodeState { + c_call := _m.On("HasTaskNodeState") + return &NodeStateReader_HasTaskNodeState{Call: c_call} +} + +func (_m *NodeStateReader) OnHasTaskNodeStateMatch(matchers ...interface{}) *NodeStateReader_HasTaskNodeState { + c_call := _m.On("HasTaskNodeState", matchers...) + return &NodeStateReader_HasTaskNodeState{Call: c_call} +} + +// HasTaskNodeState provides a mock function with given fields: +func (_m *NodeStateReader) HasTaskNodeState() bool { + ret := _m.Called() + + var r0 bool + if rf, ok := ret.Get(0).(func() bool); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + +type NodeStateReader_HasWorkflowNodeState struct { + *mock.Call +} + +func (_m NodeStateReader_HasWorkflowNodeState) Return(_a0 bool) *NodeStateReader_HasWorkflowNodeState { + return &NodeStateReader_HasWorkflowNodeState{Call: _m.Call.Return(_a0)} +} + +func (_m *NodeStateReader) OnHasWorkflowNodeState() *NodeStateReader_HasWorkflowNodeState { + c_call := _m.On("HasWorkflowNodeState") + return &NodeStateReader_HasWorkflowNodeState{Call: c_call} +} + +func (_m *NodeStateReader) OnHasWorkflowNodeStateMatch(matchers ...interface{}) *NodeStateReader_HasWorkflowNodeState { + c_call := _m.On("HasWorkflowNodeState", matchers...) 
+ return &NodeStateReader_HasWorkflowNodeState{Call: c_call} +} + +// HasWorkflowNodeState provides a mock function with given fields: +func (_m *NodeStateReader) HasWorkflowNodeState() bool { + ret := _m.Called() + + var r0 bool + if rf, ok := ret.Get(0).(func() bool); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} diff --git a/flytepropeller/pkg/controller/nodes/handler/mocks/node_state_writer.go b/flytepropeller/pkg/controller/nodes/interfaces/mocks/node_state_writer.go similarity index 82% rename from flytepropeller/pkg/controller/nodes/handler/mocks/node_state_writer.go rename to flytepropeller/pkg/controller/nodes/interfaces/mocks/node_state_writer.go index ec5359550a..46c0e2a383 100644 --- a/flytepropeller/pkg/controller/nodes/handler/mocks/node_state_writer.go +++ b/flytepropeller/pkg/controller/nodes/interfaces/mocks/node_state_writer.go @@ -4,6 +4,7 @@ package mocks import ( handler "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" + mock "github.com/stretchr/testify/mock" ) @@ -12,6 +13,43 @@ type NodeStateWriter struct { mock.Mock } +// ClearNodeStatus provides a mock function with given fields: +func (_m *NodeStateWriter) ClearNodeStatus() { + _m.Called() +} + +type NodeStateWriter_PutArrayNodeState struct { + *mock.Call +} + +func (_m NodeStateWriter_PutArrayNodeState) Return(_a0 error) *NodeStateWriter_PutArrayNodeState { + return &NodeStateWriter_PutArrayNodeState{Call: _m.Call.Return(_a0)} +} + +func (_m *NodeStateWriter) OnPutArrayNodeState(s handler.ArrayNodeState) *NodeStateWriter_PutArrayNodeState { + c_call := _m.On("PutArrayNodeState", s) + return &NodeStateWriter_PutArrayNodeState{Call: c_call} +} + +func (_m *NodeStateWriter) OnPutArrayNodeStateMatch(matchers ...interface{}) *NodeStateWriter_PutArrayNodeState { + c_call := _m.On("PutArrayNodeState", matchers...) 
+ return &NodeStateWriter_PutArrayNodeState{Call: c_call} +} + +// PutArrayNodeState provides a mock function with given fields: s +func (_m *NodeStateWriter) PutArrayNodeState(s handler.ArrayNodeState) error { + ret := _m.Called(s) + + var r0 error + if rf, ok := ret.Get(0).(func(handler.ArrayNodeState) error); ok { + r0 = rf(s) + } else { + r0 = ret.Error(0) + } + + return r0 +} + type NodeStateWriter_PutBranchNode struct { *mock.Call } diff --git a/flytepropeller/pkg/controller/nodes/handler/mocks/setup_context.go b/flytepropeller/pkg/controller/nodes/interfaces/mocks/setup_context.go similarity index 100% rename from flytepropeller/pkg/controller/nodes/handler/mocks/setup_context.go rename to flytepropeller/pkg/controller/nodes/interfaces/mocks/setup_context.go diff --git a/flytepropeller/pkg/controller/nodes/handler/mocks/task_reader.go b/flytepropeller/pkg/controller/nodes/interfaces/mocks/task_reader.go similarity index 100% rename from flytepropeller/pkg/controller/nodes/handler/mocks/task_reader.go rename to flytepropeller/pkg/controller/nodes/interfaces/mocks/task_reader.go diff --git a/flytepropeller/pkg/controller/executors/node.go b/flytepropeller/pkg/controller/nodes/interfaces/node.go similarity index 73% rename from flytepropeller/pkg/controller/executors/node.go rename to flytepropeller/pkg/controller/nodes/interfaces/node.go index a8f738c3eb..7196897411 100644 --- a/flytepropeller/pkg/controller/executors/node.go +++ b/flytepropeller/pkg/controller/nodes/interfaces/node.go @@ -1,4 +1,4 @@ -package executors +package interfaces import ( "context" @@ -7,6 +7,7 @@ import ( "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/flyteorg/flytepropeller/pkg/controller/executors" ) //go:generate mockery -all -case=underscore @@ -67,22 +68,38 @@ func (p NodePhase) String() string { type Node interface { // This method is used specifically to set inputs for start node. This is because start node does not retrieve inputs // from predecessors, but the inputs are inputs to the workflow or inputs to the parent container (workflow) node. - SetInputsForStartNode(ctx context.Context, execContext ExecutionContext, dag DAGStructureWithStartNode, nl NodeLookup, inputs *core.LiteralMap) (NodeStatus, error) + SetInputsForStartNode(ctx context.Context, execContext executors.ExecutionContext, dag executors.DAGStructureWithStartNode, + nl executors.NodeLookup, inputs *core.LiteralMap) (NodeStatus, error) // This is the main entrypoint to execute a node. It recursively depth-first goes through all ready nodes and starts their execution // This returns either // - 1. It finds a blocking node (not ready, or running) // - 2. A node fails and hence the workflow will fail // - 3. The final/end node has completed and the workflow should be stopped - RecursiveNodeHandler(ctx context.Context, execContext ExecutionContext, dag DAGStructure, nl NodeLookup, currentNode v1alpha1.ExecutableNode) (NodeStatus, error) + RecursiveNodeHandler(ctx context.Context, execContext executors.ExecutionContext, dag executors.DAGStructure, + nl executors.NodeLookup, currentNode v1alpha1.ExecutableNode) (NodeStatus, error) // This aborts the given node. 
If the given node is complete then it recursively finds the running nodes and aborts them
-	AbortHandler(ctx context.Context, execContext ExecutionContext, dag DAGStructure, nl NodeLookup, currentNode v1alpha1.ExecutableNode, reason string) error
+	AbortHandler(ctx context.Context, execContext executors.ExecutionContext, dag executors.DAGStructure,
+		nl executors.NodeLookup, currentNode v1alpha1.ExecutableNode, reason string) error
 
-	FinalizeHandler(ctx context.Context, execContext ExecutionContext, dag DAGStructure, nl NodeLookup, currentNode v1alpha1.ExecutableNode) error
+	FinalizeHandler(ctx context.Context, execContext executors.ExecutionContext, dag executors.DAGStructure,
+		nl executors.NodeLookup, currentNode v1alpha1.ExecutableNode) error
 
 	// This method should be used to initialize Node executor
 	Initialize(ctx context.Context) error
+
+	// GetNodeExecutionContextBuilder returns the current NodeExecutionContextBuilder
+	GetNodeExecutionContextBuilder() NodeExecutionContextBuilder
+
+	// WithNodeExecutionContextBuilder returns a new Node with the given NodeExecutionContextBuilder
+	WithNodeExecutionContextBuilder(NodeExecutionContextBuilder) Node
+}
+
+// NodeExecutionContextBuilder defines how a NodeExecutionContext is built
+type NodeExecutionContextBuilder interface {
+	BuildNodeExecutionContext(ctx context.Context, executionContext executors.ExecutionContext,
+		nl executors.NodeLookup, currentNodeID v1alpha1.NodeID) (NodeExecutionContext, error)
 }
 
 // Helper struct to allow passing of status between functions
diff --git a/flytepropeller/pkg/controller/nodes/handler/node_exec_context.go b/flytepropeller/pkg/controller/nodes/interfaces/node_exec_context.go
similarity index 91%
rename from flytepropeller/pkg/controller/nodes/handler/node_exec_context.go
rename to flytepropeller/pkg/controller/nodes/interfaces/node_exec_context.go
index 117358dabc..3b8afc384e 100644
--- a/flytepropeller/pkg/controller/nodes/handler/node_exec_context.go
+++ b/flytepropeller/pkg/controller/nodes/interfaces/node_exec_context.go
@@ -1,20 +1,21 @@
-package handler
+package interfaces
 
 import (
 	"context"
 
 	"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
-	"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/ioutils"
-	"github.com/flyteorg/flytestdlib/promutils"
-	"github.com/flyteorg/flytestdlib/storage"
-	v1 "k8s.io/apimachinery/pkg/apis/meta/v1"
-	"k8s.io/apimachinery/pkg/types"
 
 	"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io"
+	"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/ioutils"
 
 	"github.com/flyteorg/flytepropeller/events"
 	"github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1"
 	"github.com/flyteorg/flytepropeller/pkg/controller/executors"
+
+	"github.com/flyteorg/flytestdlib/storage"
+
+	v1 "k8s.io/apimachinery/pkg/apis/meta/v1"
+	"k8s.io/apimachinery/pkg/types"
 )
 
 type TaskReader interface {
@@ -23,10 +24,9 @@ type TaskReader interface {
 	GetTaskID() *core.Identifier
 }
 
-type SetupContext interface {
-	EnqueueOwner() func(string)
-	OwnerKind() string
-	MetricsScope() promutils.Scope
+type EventRecorder interface {
+	events.TaskEventRecorder
+	events.NodeEventRecorder
 }
 
 type NodeExecutionMetadata interface {
@@ -54,7 +54,7 @@ type NodeExecutionContext interface {
 	DataStore() *storage.DataStore
 	InputReader() io.InputReader
 
-	EventsRecorder() events.TaskEventRecorder
+	EventsRecorder() EventRecorder
 	NodeID() v1alpha1.NodeID
 	Node() v1alpha1.ExecutableNode
 	CurrentAttempt() uint32
diff --git
a/flytepropeller/pkg/controller/nodes/interfaces/state.go b/flytepropeller/pkg/controller/nodes/interfaces/state.go new file mode 100644 index 0000000000..bdbcad2e1e --- /dev/null +++ b/flytepropeller/pkg/controller/nodes/interfaces/state.go @@ -0,0 +1,30 @@ +package interfaces + +import ( + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" +) + +type NodeStateWriter interface { + PutTaskNodeState(s handler.TaskNodeState) error + PutBranchNode(s handler.BranchNodeState) error + PutDynamicNodeState(s handler.DynamicNodeState) error + PutWorkflowNodeState(s handler.WorkflowNodeState) error + PutGateNodeState(s handler.GateNodeState) error + PutArrayNodeState(s handler.ArrayNodeState) error + ClearNodeStatus() +} + +type NodeStateReader interface { + HasTaskNodeState() bool + GetTaskNodeState() handler.TaskNodeState + HasBranchNodeState() bool + GetBranchNodeState() handler.BranchNodeState + HasDynamicNodeState() bool + GetDynamicNodeState() handler.DynamicNodeState + HasWorkflowNodeState() bool + GetWorkflowNodeState() handler.WorkflowNodeState + HasGateNodeState() bool + GetGateNodeState() handler.GateNodeState + HasArrayNodeState() bool + GetArrayNodeState() handler.ArrayNodeState +} diff --git a/flytepropeller/pkg/controller/nodes/node_exec_context.go b/flytepropeller/pkg/controller/nodes/node_exec_context.go index 94a83040a7..ba43d1ba77 100644 --- a/flytepropeller/pkg/controller/nodes/node_exec_context.go +++ b/flytepropeller/pkg/controller/nodes/node_exec_context.go @@ -6,23 +6,90 @@ import ( "strconv" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" - "github.com/flyteorg/flytestdlib/storage" - "k8s.io/apimachinery/pkg/types" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/ioutils" "github.com/flyteorg/flytepropeller/events" + eventsErr "github.com/flyteorg/flytepropeller/events/errors" "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/flyteorg/flytepropeller/pkg/controller/config" "github.com/flyteorg/flytepropeller/pkg/controller/executors" - "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" + nodeerrors "github.com/flyteorg/flytepropeller/pkg/controller/nodes/errors" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" "github.com/flyteorg/flytepropeller/pkg/utils" + + "github.com/flyteorg/flytestdlib/logger" + "github.com/flyteorg/flytestdlib/storage" + + "github.com/pkg/errors" + + "k8s.io/apimachinery/pkg/types" ) const NodeIDLabel = "node-id" const TaskNameLabel = "task-name" const NodeInterruptibleLabel = "interruptible" +type eventRecorder struct { + taskEventRecorder events.TaskEventRecorder + nodeEventRecorder events.NodeEventRecorder +} + +func (e eventRecorder) RecordTaskEvent(ctx context.Context, ev *event.TaskExecutionEvent, eventConfig *config.EventConfig) error { + if err := e.taskEventRecorder.RecordTaskEvent(ctx, ev, eventConfig); err != nil { + if eventsErr.IsAlreadyExists(err) { + logger.Warningf(ctx, "Failed to record taskEvent, error [%s]. Trying to record state: %s. Ignoring this error!", err.Error(), ev.Phase) + return nil + } else if eventsErr.IsEventAlreadyInTerminalStateError(err) { + if IsTerminalTaskPhase(ev.Phase) { + // Event is terminal and the stored value in flyteadmin is already terminal. This implies aborted case. So ignoring + logger.Warningf(ctx, "Failed to record taskEvent, error [%s]. Trying to record state: %s. 
Ignoring this error!", err.Error(), ev.Phase)
+				return nil
+			}
+			logger.Warningf(ctx, "Failed to record taskEvent in state: %s, error: %s", ev.Phase, err)
+			return errors.Wrapf(err, "failed to record task event, as it already exists in terminal state. Event state: %s", ev.Phase)
+		}
+		return err
+	}
+	return nil
+}
+
+func (e eventRecorder) RecordNodeEvent(ctx context.Context, nodeEvent *event.NodeExecutionEvent, eventConfig *config.EventConfig) error {
+	if nodeEvent == nil {
+		return fmt.Errorf("event recording attempt of nil node execution event")
+	}
+
+	if nodeEvent.Id == nil {
+		return fmt.Errorf("event recording attempt with nil node event ID")
+	}
+
+	logger.Infof(ctx, "Recording NodeEvent [%s] phase[%s]", nodeEvent.GetId().String(), nodeEvent.Phase.String())
+	err := e.nodeEventRecorder.RecordNodeEvent(ctx, nodeEvent, eventConfig)
+	if err != nil {
+		if nodeEvent.GetId().NodeId == v1alpha1.EndNodeID {
+			return nil
+		}
+
+		if eventsErr.IsAlreadyExists(err) {
+			logger.Infof(ctx, "Node event phase: %s, nodeId %s already exists",
+				nodeEvent.Phase.String(), nodeEvent.GetId().NodeId)
+			return nil
+		} else if eventsErr.IsEventAlreadyInTerminalStateError(err) {
+			if IsTerminalNodePhase(nodeEvent.Phase) {
+				// Event was trying to record a different terminal phase for an already terminal event. Ignoring.
+				logger.Infof(ctx, "Node event phase: %s, nodeId %s already in terminal phase. err: %s",
+					nodeEvent.Phase.String(), nodeEvent.GetId().NodeId, err.Error())
+				return nil
+			}
+			logger.Warningf(ctx, "Failed to record nodeEvent, error [%s]", err.Error())
+			return nodeerrors.Wrapf(nodeerrors.IllegalStateError, nodeEvent.Id.NodeId, err, "phase mismatch between propeller and control plane; trying to record node phase: %s", nodeEvent.Phase)
+		}
+	}
+	return err
+}
+
 type nodeExecMetadata struct {
 	v1alpha1.Meta
 	nodeExecID *core.NodeExecutionIdentifier
@@ -57,9 +124,9 @@ func (e nodeExecMetadata) GetLabels() map[string]string {
 
 type nodeExecContext struct {
 	store               *storage.DataStore
-	tr                  handler.TaskReader
-	md                  handler.NodeExecutionMetadata
-	er                  events.TaskEventRecorder
+	tr                  interfaces.TaskReader
+	md                  interfaces.NodeExecutionMetadata
+	eventRecorder       interfaces.EventRecorder
 	inputs              io.InputReader
 	node                v1alpha1.ExecutableNode
 	nodeStatus          v1alpha1.ExecutableNodeStatus
@@ -92,15 +159,15 @@ func (e nodeExecContext) EnqueueOwnerFunc() func() error {
 	return e.enqueueOwner
 }
 
-func (e nodeExecContext) TaskReader() handler.TaskReader {
+func (e nodeExecContext) TaskReader() interfaces.TaskReader {
 	return e.tr
 }
 
-func (e nodeExecContext) NodeStateReader() handler.NodeStateReader {
+func (e nodeExecContext) NodeStateReader() interfaces.NodeStateReader {
 	return e.nsm
 }
 
-func (e nodeExecContext) NodeStateWriter() handler.NodeStateWriter {
+func (e nodeExecContext) NodeStateWriter() interfaces.NodeStateWriter {
 	return e.nsm
 }
 
@@ -112,8 +179,8 @@ func (e nodeExecContext) InputReader() io.InputReader {
 	return e.inputs
 }
 
-func (e nodeExecContext) EventsRecorder() events.TaskEventRecorder {
-	return e.er
+func (e nodeExecContext) EventsRecorder() interfaces.EventRecorder {
+	return e.eventRecorder
 }
 
 func (e nodeExecContext) NodeID() v1alpha1.NodeID {
@@ -132,7 +199,7 @@ func (e nodeExecContext) NodeStatus() v1alpha1.ExecutableNodeStatus {
 	return e.nodeStatus
 }
 
-func (e nodeExecContext) NodeExecutionMetadata() handler.NodeExecutionMetadata {
+func (e nodeExecContext) NodeExecutionMetadata() interfaces.NodeExecutionMetadata {
 	return e.md
 }
 
@@ -142,7 +209,7 @@ func (e nodeExecContext) MaxDatasetSizeBytes() int64 {
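+// Illustrative note (editorial sketch, not part of this patch): the
+// eventRecorder wrapper above makes event recording effectively idempotent.
+// Assuming `taskRec` and `nodeRec` are the underlying recorders, an
+// AlreadyExists response from the admin service surfaces as a nil error:
+//
+//	rec := &eventRecorder{taskEventRecorder: taskRec, nodeEventRecorder: nodeRec}
+//	if err := rec.RecordNodeEvent(ctx, ev, eventConfig); err != nil {
+//		// only genuine failures remain, e.g. a terminal-state mismatch
+//		// reported for a non-terminal event
+//	}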
func newNodeExecContext(_ context.Context, store *storage.DataStore, execContext executors.ExecutionContext, nl executors.NodeLookup, node v1alpha1.ExecutableNode, nodeStatus v1alpha1.ExecutableNodeStatus, inputs io.InputReader, interruptible bool, interruptibleFailureThreshold uint32, - maxDatasetSize int64, er events.TaskEventRecorder, tr handler.TaskReader, nsm *nodeStateManager, + maxDatasetSize int64, taskEventRecorder events.TaskEventRecorder, nodeEventRecorder events.NodeEventRecorder, tr interfaces.TaskReader, nsm *nodeStateManager, enqueueOwner func() error, rawOutputPrefix storage.DataReference, outputShardSelector ioutils.ShardSelector) *nodeExecContext { md := nodeExecMetadata{ @@ -168,12 +235,15 @@ func newNodeExecContext(_ context.Context, store *storage.DataStore, execContext md.nodeLabels = nodeLabels return &nodeExecContext{ - md: md, - store: store, - node: node, - nodeStatus: nodeStatus, - inputs: inputs, - er: er, + md: md, + store: store, + node: node, + nodeStatus: nodeStatus, + inputs: inputs, + eventRecorder: &eventRecorder{ + taskEventRecorder: taskEventRecorder, + nodeEventRecorder: nodeEventRecorder, + }, maxDatasetSizeBytes: maxDatasetSize, tr: tr, nsm: nsm, @@ -185,14 +255,14 @@ func newNodeExecContext(_ context.Context, store *storage.DataStore, execContext } } -func (c *nodeExecutor) newNodeExecContextDefault(ctx context.Context, currentNodeID v1alpha1.NodeID, - executionContext executors.ExecutionContext, nl executors.NodeLookup) (*nodeExecContext, error) { +func (c *nodeExecutor) BuildNodeExecutionContext(ctx context.Context, executionContext executors.ExecutionContext, + nl executors.NodeLookup, currentNodeID v1alpha1.NodeID) (interfaces.NodeExecutionContext, error) { n, ok := nl.GetNode(currentNodeID) if !ok { return nil, fmt.Errorf("failed to find node with ID [%s] in execution [%s]", currentNodeID, executionContext.GetID()) } - var tr handler.TaskReader + var tr interfaces.TaskReader if n.GetKind() == v1alpha1.NodeKindTask { if n.GetTaskID() == nil { return nil, fmt.Errorf("bad state, no task-id defined for node [%s]", n.GetID()) @@ -243,7 +313,8 @@ func (c *nodeExecutor) newNodeExecContextDefault(ctx context.Context, currentNod interruptible, c.interruptibleFailureThreshold, c.maxDatasetSizeBytes, - &taskEventRecorder{TaskEventRecorder: c.taskRecorder}, + c.taskRecorder, + c.nodeRecorder, tr, newNodeStateManager(ctx, s), workflowEnqueuer, diff --git a/flytepropeller/pkg/controller/nodes/node_exec_context_test.go b/flytepropeller/pkg/controller/nodes/node_exec_context_test.go index 48fa664a37..707ce33f4c 100644 --- a/flytepropeller/pkg/controller/nodes/node_exec_context_test.go +++ b/flytepropeller/pkg/controller/nodes/node_exec_context_test.go @@ -2,22 +2,29 @@ package nodes import ( "context" + "fmt" "strconv" "testing" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" + + "github.com/flyteorg/flytepropeller/events" + eventsErr "github.com/flyteorg/flytepropeller/events/errors" + "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks" + "github.com/flyteorg/flytepropeller/pkg/controller/config" "github.com/flyteorg/flytepropeller/pkg/controller/executors" mocks2 "github.com/flyteorg/flytepropeller/pkg/controller/executors/mocks" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" 
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/ioutils" + "github.com/flyteorg/flytestdlib/promutils" "github.com/flyteorg/flytestdlib/promutils/labeled" "github.com/flyteorg/flytestdlib/storage" - "github.com/stretchr/testify/assert" - "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" - "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks" + "github.com/stretchr/testify/assert" ) type TaskReader struct{} @@ -28,6 +35,19 @@ func (t TaskReader) GetTaskID() *core.Identifier { return &core.Identifier{Project: "p", Domain: "d", Name: "task-name"} } +type fakeEventRecorder struct { + nodeErr error + taskErr error +} + +func (f fakeEventRecorder) RecordNodeEvent(ctx context.Context, event *event.NodeExecutionEvent, eventConfig *config.EventConfig) error { + return f.nodeErr +} + +func (f fakeEventRecorder) RecordTaskEvent(ctx context.Context, event *event.TaskExecutionEvent, eventConfig *config.EventConfig) error { + return f.taskErr +} + type parentInfo struct { executors.ImmutableParentInfo } @@ -90,7 +110,7 @@ func Test_NodeContext(t *testing.T) { s, _ := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope()) p := parentInfo{} execContext := executors.NewExecutionContext(w1, nil, nil, p, nil) - nCtx := newNodeExecContext(context.TODO(), s, execContext, w1, getTestNodeSpec(nil), nil, nil, false, 0, 2, nil, TaskReader{}, nil, nil, "s3://bucket", ioutils.NewConstantShardSelector([]string{"x"})) + nCtx := newNodeExecContext(context.TODO(), s, execContext, w1, getTestNodeSpec(nil), nil, nil, false, 0, 2, nil, nil, TaskReader{}, nil, nil, "s3://bucket", ioutils.NewConstantShardSelector([]string{"x"})) assert.Equal(t, "id", nCtx.NodeExecutionMetadata().GetLabels()["node-id"]) assert.Equal(t, "false", nCtx.NodeExecutionMetadata().GetLabels()["interruptible"]) assert.Equal(t, "task-name", nCtx.NodeExecutionMetadata().GetLabels()["task-name"]) @@ -118,14 +138,14 @@ func Test_NodeContextDefault(t *testing.T) { } p := parentInfo{} execContext := executors.NewExecutionContext(w1, w1, w1, p, nil) - nodeExecContext, err := nodeExecutor.newNodeExecContextDefault(context.Background(), "node-a", execContext, nodeLookup) + nodeExecContext, err := nodeExecutor.BuildNodeExecutionContext(context.Background(), execContext, nodeLookup, "node-a") assert.NoError(t, err) - assert.Equal(t, "s3://bucket-a", nodeExecContext.rawOutputPrefix.String()) + assert.Equal(t, "s3://bucket-a", nodeExecContext.RawOutputPrefix().String()) w1.RawOutputDataConfig.OutputLocationPrefix = "s3://bucket-b" - nodeExecContext, err = nodeExecutor.newNodeExecContextDefault(context.Background(), "node-a", execContext, nodeLookup) + nodeExecContext, err = nodeExecutor.BuildNodeExecutionContext(context.Background(), execContext, nodeLookup, "node-a") assert.NoError(t, err) - assert.Equal(t, "s3://bucket-b", nodeExecContext.rawOutputPrefix.String()) + assert.Equal(t, "s3://bucket-b", nodeExecContext.RawOutputPrefix().String()) } func Test_NodeContextDefaultInterruptible(t *testing.T) { @@ -148,10 +168,10 @@ func Test_NodeContextDefaultInterruptible(t *testing.T) { } verifyNodeExecContext := func(t *testing.T, executionContext executors.ExecutionContext, nl executors.NodeLookup, shouldBeInterruptible bool) { - nodeExecContext, err := nodeExecutor.newNodeExecContextDefault(context.Background(), "node-a", executionContext, nl) + nodeExecContext, err := nodeExecutor.BuildNodeExecutionContext(context.Background(), executionContext, nl, "node-a") assert.NoError(t, err) 
- assert.Equal(t, shouldBeInterruptible, nodeExecContext.md.IsInterruptible()) - labels := nodeExecContext.md.GetLabels() + assert.Equal(t, shouldBeInterruptible, nodeExecContext.NodeExecutionMetadata().IsInterruptible()) + labels := nodeExecContext.NodeExecutionMetadata().GetLabels() assert.Contains(t, labels, NodeInterruptibleLabel) assert.Equal(t, strconv.FormatBool(shouldBeInterruptible), labels[NodeInterruptibleLabel]) } @@ -284,3 +304,74 @@ func Test_NodeContextDefaultInterruptible(t *testing.T) { verifyNodeExecContext(t, execContext, nodeLookup, false) }) } + +func Test_NodeContext_RecordNodeEvent(t *testing.T) { + noErrRecorder := fakeEventRecorder{} + alreadyExistsError := fakeEventRecorder{nodeErr: &eventsErr.EventError{Code: eventsErr.AlreadyExists, Cause: fmt.Errorf("err")}} + inTerminalError := fakeEventRecorder{nodeErr: &eventsErr.EventError{Code: eventsErr.EventAlreadyInTerminalStateError, Cause: fmt.Errorf("err")}} + otherError := fakeEventRecorder{nodeErr: &eventsErr.EventError{Code: eventsErr.ResourceExhausted, Cause: fmt.Errorf("err")}} + + tests := []struct { + name string + rec events.NodeEventRecorder + p core.NodeExecution_Phase + wantErr bool + }{ + {"aborted-success", noErrRecorder, core.NodeExecution_ABORTED, false}, + {"aborted-failure", otherError, core.NodeExecution_ABORTED, true}, + {"aborted-already", alreadyExistsError, core.NodeExecution_ABORTED, false}, + {"aborted-terminal", inTerminalError, core.NodeExecution_ABORTED, false}, + {"running-terminal", inTerminalError, core.NodeExecution_RUNNING, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + eventRecorder := &eventRecorder{ + nodeEventRecorder: tt.rec, + } + + ev := &event.NodeExecutionEvent{ + Id: &core.NodeExecutionIdentifier{}, + Phase: tt.p, + ProducerId: "propeller", + } + if err := eventRecorder.RecordNodeEvent(context.TODO(), ev, &config.EventConfig{}); (err != nil) != tt.wantErr { + t.Errorf("RecordNodeEvent() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func Test_NodeContext_RecordTaskEvent(t1 *testing.T) { + noErrRecorder := fakeEventRecorder{} + alreadyExistsError := fakeEventRecorder{taskErr: &eventsErr.EventError{Code: eventsErr.AlreadyExists, Cause: fmt.Errorf("err")}} + inTerminalError := fakeEventRecorder{taskErr: &eventsErr.EventError{Code: eventsErr.EventAlreadyInTerminalStateError, Cause: fmt.Errorf("err")}} + otherError := fakeEventRecorder{taskErr: &eventsErr.EventError{Code: eventsErr.ResourceExhausted, Cause: fmt.Errorf("err")}} + + tests := []struct { + name string + rec events.TaskEventRecorder + p core.TaskExecution_Phase + wantErr bool + }{ + {"aborted-success", noErrRecorder, core.TaskExecution_ABORTED, false}, + {"aborted-failure", otherError, core.TaskExecution_ABORTED, true}, + {"aborted-already", alreadyExistsError, core.TaskExecution_ABORTED, false}, + {"aborted-terminal", inTerminalError, core.TaskExecution_ABORTED, false}, + {"running-terminal", inTerminalError, core.TaskExecution_RUNNING, true}, + } + for _, tt := range tests { + t1.Run(tt.name, func(t1 *testing.T) { + t := &eventRecorder{ + taskEventRecorder: tt.rec, + } + ev := &event.TaskExecutionEvent{ + Phase: tt.p, + } + if err := t.RecordTaskEvent(context.TODO(), ev, &config.EventConfig{ + RawOutputPolicy: config.RawOutputPolicyReference, + }); (err != nil) != tt.wantErr { + t1.Errorf("RecordTaskEvent() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} diff --git a/flytepropeller/pkg/controller/nodes/node_state_manager.go 
b/flytepropeller/pkg/controller/nodes/node_state_manager.go index 843c3b1f9b..7a961fed54 100644 --- a/flytepropeller/pkg/controller/nodes/node_state_manager.go +++ b/flytepropeller/pkg/controller/nodes/node_state_manager.go @@ -16,6 +16,7 @@ type nodeStateManager struct { d *handler.DynamicNodeState w *handler.WorkflowNodeState g *handler.GateNodeState + a *handler.ArrayNodeState } func (n *nodeStateManager) PutTaskNodeState(s handler.TaskNodeState) error { @@ -43,7 +44,40 @@ func (n *nodeStateManager) PutGateNodeState(s handler.GateNodeState) error { return nil } +func (n *nodeStateManager) PutArrayNodeState(s handler.ArrayNodeState) error { + n.a = &s + return nil +} + +func (n *nodeStateManager) HasTaskNodeState() bool { + return n.t != nil +} + +func (n *nodeStateManager) HasBranchNodeState() bool { + return n.b != nil +} + +func (n *nodeStateManager) HasDynamicNodeState() bool { + return n.d != nil +} + +func (n *nodeStateManager) HasWorkflowNodeState() bool { + return n.w != nil +} + +func (n *nodeStateManager) HasGateNodeState() bool { + return n.g != nil +} + +func (n *nodeStateManager) HasArrayNodeState() bool { + return n.a != nil +} + func (n nodeStateManager) GetTaskNodeState() handler.TaskNodeState { + if n.t != nil { + return *n.t + } + tn := n.nodeStatus.GetTaskNodeStatus() if tn != nil { return handler.TaskNodeState{ @@ -59,7 +93,11 @@ func (n nodeStateManager) GetTaskNodeState() handler.TaskNodeState { return handler.TaskNodeState{} } -func (n nodeStateManager) GetBranchNode() handler.BranchNodeState { +func (n nodeStateManager) GetBranchNodeState() handler.BranchNodeState { + if n.b != nil { + return *n.b + } + bn := n.nodeStatus.GetBranchStatus() bs := handler.BranchNodeState{} if bn != nil { @@ -70,6 +108,10 @@ func (n nodeStateManager) GetBranchNode() handler.BranchNodeState { } func (n nodeStateManager) GetDynamicNodeState() handler.DynamicNodeState { + if n.d != nil { + return *n.d + } + dn := n.nodeStatus.GetDynamicNodeStatus() ds := handler.DynamicNodeState{} if dn != nil { @@ -83,6 +125,10 @@ func (n nodeStateManager) GetDynamicNodeState() handler.DynamicNodeState { } func (n nodeStateManager) GetWorkflowNodeState() handler.WorkflowNodeState { + if n.w != nil { + return *n.w + } + wn := n.nodeStatus.GetWorkflowNodeStatus() ws := handler.WorkflowNodeState{} if wn != nil { @@ -93,6 +139,10 @@ func (n nodeStateManager) GetWorkflowNodeState() handler.WorkflowNodeState { } func (n nodeStateManager) GetGateNodeState() handler.GateNodeState { + if n.g != nil { + return *n.g + } + gn := n.nodeStatus.GetGateNodeStatus() gs := handler.GateNodeState{} if gn != nil { @@ -101,12 +151,32 @@ func (n nodeStateManager) GetGateNodeState() handler.GateNodeState { return gs } -func (n *nodeStateManager) clearNodeStatus() { +func (n nodeStateManager) GetArrayNodeState() handler.ArrayNodeState { + if n.a != nil { + return *n.a + } + + an := n.nodeStatus.GetArrayNodeStatus() + as := handler.ArrayNodeState{} + if an != nil { + as.Phase = an.GetArrayNodePhase() + as.Error = an.GetExecutionError() + as.SubNodePhases = an.GetSubNodePhases() + as.SubNodeTaskPhases = an.GetSubNodeTaskPhases() + as.SubNodeRetryAttempts = an.GetSubNodeRetryAttempts() + as.SubNodeSystemFailures = an.GetSubNodeSystemFailures() + as.TaskPhaseVersion = an.GetTaskPhaseVersion() + } + return as +} + +func (n *nodeStateManager) ClearNodeStatus() { n.t = nil n.b = nil n.d = nil n.w = nil n.g = nil + n.a = nil n.nodeStatus.ClearLastAttemptStartedAt() } diff --git a/flytepropeller/pkg/controller/nodes/setup_context.go 
b/flytepropeller/pkg/controller/nodes/setup_context.go index ef192f453c..d398637633 100644 --- a/flytepropeller/pkg/controller/nodes/setup_context.go +++ b/flytepropeller/pkg/controller/nodes/setup_context.go @@ -6,7 +6,7 @@ import ( "github.com/flyteorg/flytestdlib/promutils" "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" - "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" ) type setupContext struct { @@ -26,7 +26,7 @@ func (s *setupContext) MetricsScope() promutils.Scope { return s.scope } -func (c *nodeExecutor) newSetupContext(_ context.Context) handler.SetupContext { +func (c *recursiveNodeExecutor) newSetupContext(_ context.Context) interfaces.SetupContext { return &setupContext{ enq: c.enqueueWorkflow, scope: c.metrics.Scope, diff --git a/flytepropeller/pkg/controller/nodes/start/handler.go b/flytepropeller/pkg/controller/nodes/start/handler.go index f1dda96f49..a8535b8fd9 100644 --- a/flytepropeller/pkg/controller/nodes/start/handler.go +++ b/flytepropeller/pkg/controller/nodes/start/handler.go @@ -4,6 +4,7 @@ import ( "context" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" ) type startHandler struct { @@ -13,22 +14,22 @@ func (s startHandler) FinalizeRequired() bool { return false } -func (s startHandler) Setup(ctx context.Context, setupContext handler.SetupContext) error { +func (s startHandler) Setup(ctx context.Context, setupContext interfaces.SetupContext) error { return nil } -func (s startHandler) Handle(ctx context.Context, executionContext handler.NodeExecutionContext) (handler.Transition, error) { +func (s startHandler) Handle(ctx context.Context, executionContext interfaces.NodeExecutionContext) (handler.Transition, error) { return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoSuccess(&handler.ExecutionInfo{})), nil } -func (s startHandler) Abort(ctx context.Context, executionContext handler.NodeExecutionContext, reason string) error { +func (s startHandler) Abort(ctx context.Context, executionContext interfaces.NodeExecutionContext, reason string) error { return nil } -func (s startHandler) Finalize(ctx context.Context, executionContext handler.NodeExecutionContext) error { +func (s startHandler) Finalize(ctx context.Context, executionContext interfaces.NodeExecutionContext) error { return nil } -func New() handler.Node { +func New() interfaces.NodeHandler { return &startHandler{} } diff --git a/flytepropeller/pkg/controller/nodes/subworkflow/handler.go b/flytepropeller/pkg/controller/nodes/subworkflow/handler.go index bf7b5c393b..5f478e3986 100644 --- a/flytepropeller/pkg/controller/nodes/subworkflow/handler.go +++ b/flytepropeller/pkg/controller/nodes/subworkflow/handler.go @@ -14,9 +14,9 @@ import ( "github.com/flyteorg/flytestdlib/promutils/labeled" "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" - "github.com/flyteorg/flytepropeller/pkg/controller/executors" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/errors" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/subworkflow/launchplan" ) @@ -40,11 +40,11 @@ func (w *workflowNodeHandler) FinalizeRequired() bool { return false } -func (w *workflowNodeHandler) Setup(_ context.Context, _ handler.SetupContext) error { +func (w 
*workflowNodeHandler) Setup(_ context.Context, _ interfaces.SetupContext) error { return nil } -func (w *workflowNodeHandler) Handle(ctx context.Context, nCtx handler.NodeExecutionContext) (handler.Transition, error) { +func (w *workflowNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecutionContext) (handler.Transition, error) { logger.Debug(ctx, "Starting workflow Node") invalidWFNodeError := func() (handler.Transition, error) { @@ -112,7 +112,7 @@ func (w *workflowNodeHandler) Handle(ctx context.Context, nCtx handler.NodeExecu return invalidWFNodeError() } -func (w *workflowNodeHandler) Abort(ctx context.Context, nCtx handler.NodeExecutionContext, reason string) error { +func (w *workflowNodeHandler) Abort(ctx context.Context, nCtx interfaces.NodeExecutionContext, reason string) error { wfNode := nCtx.Node().GetWorkflowNode() if wfNode.GetSubWorkflowRef() != nil { return w.subWfHandler.HandleAbort(ctx, nCtx, reason) @@ -124,12 +124,12 @@ func (w *workflowNodeHandler) Abort(ctx context.Context, nCtx handler.NodeExecut return nil } -func (w *workflowNodeHandler) Finalize(ctx context.Context, _ handler.NodeExecutionContext) error { +func (w *workflowNodeHandler) Finalize(ctx context.Context, _ interfaces.NodeExecutionContext) error { logger.Warnf(ctx, "Subworkflow finalize invoked. Nothing to be done") return nil } -func New(executor executors.Node, workflowLauncher launchplan.Executor, recoveryClient recovery.Client, eventConfig *config.EventConfig, scope promutils.Scope) handler.Node { +func New(executor interfaces.Node, workflowLauncher launchplan.Executor, recoveryClient recovery.Client, eventConfig *config.EventConfig, scope promutils.Scope) interfaces.NodeHandler { workflowScope := scope.NewSubScope("workflow") return &workflowNodeHandler{ subWfHandler: newSubworkflowHandler(executor, eventConfig), diff --git a/flytepropeller/pkg/controller/nodes/subworkflow/handler_test.go b/flytepropeller/pkg/controller/nodes/subworkflow/handler_test.go index c776de01ca..20e40fdc03 100644 --- a/flytepropeller/pkg/controller/nodes/subworkflow/handler_test.go +++ b/flytepropeller/pkg/controller/nodes/subworkflow/handler_test.go @@ -26,7 +26,7 @@ import ( mocks2 "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks" execMocks "github.com/flyteorg/flytepropeller/pkg/controller/executors/mocks" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" - mocks3 "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler/mocks" + mocks3 "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces/mocks" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/subworkflow/launchplan" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/subworkflow/launchplan/mocks" ) @@ -39,6 +39,9 @@ var eventConfig = &config.EventConfig{ RawOutputPolicy: config.RawOutputPolicyReference, } +func (t *workflowNodeStateHolder) ClearNodeStatus() { +} + func (t *workflowNodeStateHolder) PutTaskNodeState(s handler.TaskNodeState) error { panic("not implemented") } @@ -60,6 +63,10 @@ func (t workflowNodeStateHolder) PutGateNodeState(s handler.GateNodeState) error panic("not implemented") } +func (t workflowNodeStateHolder) PutArrayNodeState(s handler.ArrayNodeState) error { + panic("not implemented") +} + var wfExecID = &core.WorkflowExecutionIdentifier{ Project: "project", Domain: "domain", diff --git a/flytepropeller/pkg/controller/nodes/subworkflow/launchplan.go b/flytepropeller/pkg/controller/nodes/subworkflow/launchplan.go index 2550dbc065..9f44dc9cbf 100644 --- 
a/flytepropeller/pkg/controller/nodes/subworkflow/launchplan.go +++ b/flytepropeller/pkg/controller/nodes/subworkflow/launchplan.go @@ -4,22 +4,22 @@ import ( "context" "fmt" - "github.com/flyteorg/flytepropeller/pkg/controller/config" - - "github.com/flyteorg/flytepropeller/pkg/controller/nodes/recovery" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" - - "github.com/flyteorg/flytepropeller/pkg/controller/nodes/common" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" - "github.com/flyteorg/flytestdlib/logger" - "github.com/flyteorg/flytestdlib/storage" "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/flyteorg/flytepropeller/pkg/controller/config" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/common" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/errors" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/recovery" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/subworkflow/launchplan" + + "github.com/flyteorg/flytestdlib/logger" + "github.com/flyteorg/flytestdlib/storage" + + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) type launchPlanHandler struct { @@ -28,7 +28,7 @@ type launchPlanHandler struct { eventConfig *config.EventConfig } -func getParentNodeExecutionID(nCtx handler.NodeExecutionContext) (*core.NodeExecutionIdentifier, error) { +func getParentNodeExecutionID(nCtx interfaces.NodeExecutionContext) (*core.NodeExecutionIdentifier, error) { nodeExecID := &core.NodeExecutionIdentifier{ ExecutionId: nCtx.NodeExecutionMetadata().GetNodeExecutionID().ExecutionId, } @@ -45,7 +45,7 @@ func getParentNodeExecutionID(nCtx handler.NodeExecutionContext) (*core.NodeExec return nodeExecID, nil } -func (l *launchPlanHandler) StartLaunchPlan(ctx context.Context, nCtx handler.NodeExecutionContext) (handler.Transition, error) { +func (l *launchPlanHandler) StartLaunchPlan(ctx context.Context, nCtx interfaces.NodeExecutionContext) (handler.Transition, error) { nodeInputs, err := nCtx.InputReader().Get(ctx) if err != nil { errMsg := fmt.Sprintf("Failed to read input. 
Error [%s]", err) @@ -125,7 +125,7 @@ func (l *launchPlanHandler) StartLaunchPlan(ctx context.Context, nCtx handler.No })), nil } -func GetChildWorkflowExecutionIDForExecution(parentNodeExecID *core.NodeExecutionIdentifier, nCtx handler.NodeExecutionContext) (*core.WorkflowExecutionIdentifier, error) { +func GetChildWorkflowExecutionIDForExecution(parentNodeExecID *core.NodeExecutionIdentifier, nCtx interfaces.NodeExecutionContext) (*core.WorkflowExecutionIdentifier, error) { // Handle launch plan if nCtx.ExecutionContext().GetDefinitionVersion() == v1alpha1.WorkflowDefinitionVersion0 { return GetChildWorkflowExecutionID( @@ -140,7 +140,7 @@ func GetChildWorkflowExecutionIDForExecution(parentNodeExecID *core.NodeExecutio ) } -func (l *launchPlanHandler) CheckLaunchPlanStatus(ctx context.Context, nCtx handler.NodeExecutionContext) (handler.Transition, error) { +func (l *launchPlanHandler) CheckLaunchPlanStatus(ctx context.Context, nCtx interfaces.NodeExecutionContext) (handler.Transition, error) { parentNodeExecutionID, err := getParentNodeExecutionID(nCtx) if err != nil { return handler.UnknownTransition, err @@ -218,7 +218,7 @@ func (l *launchPlanHandler) CheckLaunchPlanStatus(ctx context.Context, nCtx hand return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoRunning(nil)), nil } -func (l *launchPlanHandler) HandleAbort(ctx context.Context, nCtx handler.NodeExecutionContext, reason string) error { +func (l *launchPlanHandler) HandleAbort(ctx context.Context, nCtx interfaces.NodeExecutionContext, reason string) error { parentNodeExecutionID, err := getParentNodeExecutionID(nCtx) if err != nil { return err diff --git a/flytepropeller/pkg/controller/nodes/subworkflow/launchplan_test.go b/flytepropeller/pkg/controller/nodes/subworkflow/launchplan_test.go index 8416ca6997..96e50280ec 100644 --- a/flytepropeller/pkg/controller/nodes/subworkflow/launchplan_test.go +++ b/flytepropeller/pkg/controller/nodes/subworkflow/launchplan_test.go @@ -7,7 +7,7 @@ import ( "testing" mocks4 "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io/mocks" - mocks3 "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler/mocks" + mocks3 "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces/mocks" v1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/types" diff --git a/flytepropeller/pkg/controller/nodes/subworkflow/subworkflow.go b/flytepropeller/pkg/controller/nodes/subworkflow/subworkflow.go index 74beeaf792..d0dad95b9d 100644 --- a/flytepropeller/pkg/controller/nodes/subworkflow/subworkflow.go +++ b/flytepropeller/pkg/controller/nodes/subworkflow/subworkflow.go @@ -4,28 +4,28 @@ import ( "context" "fmt" - "github.com/flyteorg/flytepropeller/pkg/controller/config" - - "github.com/flyteorg/flytepropeller/pkg/controller/nodes/common" - "github.com/flyteorg/flytestdlib/logger" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" - "github.com/flyteorg/flytestdlib/storage" "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/flyteorg/flytepropeller/pkg/controller/config" "github.com/flyteorg/flytepropeller/pkg/controller/executors" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/common" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/errors" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" + + "github.com/flyteorg/flytestdlib/logger" + "github.com/flyteorg/flytestdlib/storage" ) // Subworkflow handler handles 
inline subWorkflows type subworkflowHandler struct { - nodeExecutor executors.Node + nodeExecutor interfaces.Node eventConfig *config.EventConfig } // Helper method that extracts the SubWorkflow from the ExecutionContext -func GetSubWorkflow(ctx context.Context, nCtx handler.NodeExecutionContext) (v1alpha1.ExecutableSubWorkflow, error) { +func GetSubWorkflow(ctx context.Context, nCtx interfaces.NodeExecutionContext) (v1alpha1.ExecutableSubWorkflow, error) { node := nCtx.Node() subID := *node.GetWorkflowNode().GetSubWorkflowRef() subWorkflow := nCtx.ExecutionContext().FindSubWorkflow(subID) @@ -36,7 +36,7 @@ func GetSubWorkflow(ctx context.Context, nCtx handler.NodeExecutionContext) (v1a } // Performs an additional step of passing in and setting the inputs, before handling the execution of a SubWorkflow. -func (s *subworkflowHandler) startAndHandleSubWorkflow(ctx context.Context, nCtx handler.NodeExecutionContext, subWorkflow v1alpha1.ExecutableSubWorkflow, nl executors.NodeLookup) (handler.Transition, error) { +func (s *subworkflowHandler) startAndHandleSubWorkflow(ctx context.Context, nCtx interfaces.NodeExecutionContext, subWorkflow v1alpha1.ExecutableSubWorkflow, nl executors.NodeLookup) (handler.Transition, error) { // Before starting the subworkflow, lets set the inputs for the Workflow. The inputs for a SubWorkflow are essentially // Copy of the inputs to the Node nodeInputs, err := nCtx.InputReader().Get(ctx) @@ -63,7 +63,7 @@ func (s *subworkflowHandler) startAndHandleSubWorkflow(ctx context.Context, nCtx } // Calls the recursive node executor to handle the SubWorkflow and translates the results after the success -func (s *subworkflowHandler) handleSubWorkflow(ctx context.Context, nCtx handler.NodeExecutionContext, subworkflow v1alpha1.ExecutableSubWorkflow, nl executors.NodeLookup) (handler.Transition, error) { +func (s *subworkflowHandler) handleSubWorkflow(ctx context.Context, nCtx interfaces.NodeExecutionContext, subworkflow v1alpha1.ExecutableSubWorkflow, nl executors.NodeLookup) (handler.Transition, error) { // The current node would end up becoming the parent for the sub workflow nodes. // This is done to track the lineage. 
For level zero, the CreateParentInfo will return nil newParentInfo, err := common.CreateParentInfo(nCtx.ExecutionContext().GetParentInfo(), nCtx.NodeID(), nCtx.CurrentAttempt()) @@ -135,7 +135,7 @@ func (s *subworkflowHandler) handleSubWorkflow(ctx context.Context, nCtx handler return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoRunning(nil)), nil } -func (s *subworkflowHandler) getExecutionContextForDownstream(nCtx handler.NodeExecutionContext) (executors.ExecutionContext, error) { +func (s *subworkflowHandler) getExecutionContextForDownstream(nCtx interfaces.NodeExecutionContext) (executors.ExecutionContext, error) { newParentInfo, err := common.CreateParentInfo(nCtx.ExecutionContext().GetParentInfo(), nCtx.NodeID(), nCtx.CurrentAttempt()) if err != nil { return nil, err @@ -143,7 +143,7 @@ func (s *subworkflowHandler) getExecutionContextForDownstream(nCtx handler.NodeE return executors.NewExecutionContextWithParentInfo(nCtx.ExecutionContext(), newParentInfo), nil } -func (s *subworkflowHandler) HandleFailureNodeOfSubWorkflow(ctx context.Context, nCtx handler.NodeExecutionContext, subworkflow v1alpha1.ExecutableSubWorkflow, nl executors.NodeLookup) (handler.Transition, error) { +func (s *subworkflowHandler) HandleFailureNodeOfSubWorkflow(ctx context.Context, nCtx interfaces.NodeExecutionContext, subworkflow v1alpha1.ExecutableSubWorkflow, nl executors.NodeLookup) (handler.Transition, error) { originalError := nCtx.NodeStateReader().GetWorkflowNodeState().Error if subworkflow.GetOnFailureNode() != nil { execContext, err := s.getExecutionContextForDownstream(nCtx) @@ -155,7 +155,7 @@ func (s *subworkflowHandler) HandleFailureNodeOfSubWorkflow(ctx context.Context, return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoUndefined), err } - if state.NodePhase == executors.NodePhaseRunning { + if state.NodePhase == interfaces.NodePhaseRunning { return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoRunning(nil)), nil } @@ -185,7 +185,7 @@ func (s *subworkflowHandler) HandleFailureNodeOfSubWorkflow(ctx context.Context, originalError, nil)), nil } -func (s *subworkflowHandler) HandleFailingSubWorkflow(ctx context.Context, nCtx handler.NodeExecutionContext) (handler.Transition, error) { +func (s *subworkflowHandler) HandleFailingSubWorkflow(ctx context.Context, nCtx interfaces.NodeExecutionContext) (handler.Transition, error) { subWorkflow, err := GetSubWorkflow(ctx, nCtx) if err != nil { return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoFailure(core.ExecutionError_SYSTEM, errors.SubWorkflowExecutionFailed, err.Error(), nil)), nil @@ -213,7 +213,7 @@ func (s *subworkflowHandler) HandleFailingSubWorkflow(ctx context.Context, nCtx return s.HandleFailureNodeOfSubWorkflow(ctx, nCtx, subWorkflow, nodeLookup) } -func (s *subworkflowHandler) StartSubWorkflow(ctx context.Context, nCtx handler.NodeExecutionContext) (handler.Transition, error) { +func (s *subworkflowHandler) StartSubWorkflow(ctx context.Context, nCtx interfaces.NodeExecutionContext) (handler.Transition, error) { subWorkflow, err := GetSubWorkflow(ctx, nCtx) if err != nil { return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoFailure(core.ExecutionError_SYSTEM, errors.SubWorkflowExecutionFailed, err.Error(), nil)), nil @@ -226,7 +226,7 @@ func (s *subworkflowHandler) StartSubWorkflow(ctx context.Context, nCtx handler. 
return s.startAndHandleSubWorkflow(ctx, nCtx, subWorkflow, nodeLookup) } -func (s *subworkflowHandler) CheckSubWorkflowStatus(ctx context.Context, nCtx handler.NodeExecutionContext) (handler.Transition, error) { +func (s *subworkflowHandler) CheckSubWorkflowStatus(ctx context.Context, nCtx interfaces.NodeExecutionContext) (handler.Transition, error) { subWorkflow, err := GetSubWorkflow(ctx, nCtx) if err != nil { return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoFailure(core.ExecutionError_SYSTEM, errors.SubWorkflowExecutionFailed, err.Error(), nil)), nil @@ -237,7 +237,7 @@ func (s *subworkflowHandler) CheckSubWorkflowStatus(ctx context.Context, nCtx ha return s.handleSubWorkflow(ctx, nCtx, subWorkflow, nodeLookup) } -func (s *subworkflowHandler) HandleAbort(ctx context.Context, nCtx handler.NodeExecutionContext, reason string) error { +func (s *subworkflowHandler) HandleAbort(ctx context.Context, nCtx interfaces.NodeExecutionContext, reason string) error { subWorkflow, err := GetSubWorkflow(ctx, nCtx) if err != nil { return err @@ -251,7 +251,7 @@ func (s *subworkflowHandler) HandleAbort(ctx context.Context, nCtx handler.NodeE return s.nodeExecutor.AbortHandler(ctx, execContext, subWorkflow, nodeLookup, subWorkflow.StartNode(), reason) } -func newSubworkflowHandler(nodeExecutor executors.Node, eventConfig *config.EventConfig) subworkflowHandler { +func newSubworkflowHandler(nodeExecutor interfaces.Node, eventConfig *config.EventConfig) subworkflowHandler { return subworkflowHandler{ nodeExecutor: nodeExecutor, eventConfig: eventConfig, diff --git a/flytepropeller/pkg/controller/nodes/subworkflow/subworkflow_test.go b/flytepropeller/pkg/controller/nodes/subworkflow/subworkflow_test.go index 50840776f7..def32c198d 100644 --- a/flytepropeller/pkg/controller/nodes/subworkflow/subworkflow_test.go +++ b/flytepropeller/pkg/controller/nodes/subworkflow/subworkflow_test.go @@ -13,7 +13,7 @@ import ( coreMocks "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks" execMocks "github.com/flyteorg/flytepropeller/pkg/controller/executors/mocks" - "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler/mocks" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces/mocks" ) func TestGetSubWorkflow(t *testing.T) { @@ -87,7 +87,7 @@ func Test_subworkflowHandler_HandleAbort(t *testing.T) { nCtx.OnNodeStatus().Return(ns) nCtx.OnNodeID().Return("n1") - nodeExec := &execMocks.Node{} + nodeExec := &mocks.Node{} s := newSubworkflowHandler(nodeExec, eventConfig) n := &coreMocks.ExecutableNode{} swf.OnGetID().Return("swf") @@ -120,7 +120,7 @@ func Test_subworkflowHandler_HandleAbort(t *testing.T) { nCtx.OnNodeID().Return("n1") nCtx.OnCurrentAttempt().Return(uint32(1)) - nodeExec := &execMocks.Node{} + nodeExec := &mocks.Node{} s := newSubworkflowHandler(nodeExec, eventConfig) n := &coreMocks.ExecutableNode{} swf.OnGetID().Return("swf") @@ -154,7 +154,7 @@ func Test_subworkflowHandler_HandleAbort(t *testing.T) { nCtx.OnNodeID().Return("n1") nCtx.OnCurrentAttempt().Return(uint32(1)) - nodeExec := &execMocks.Node{} + nodeExec := &mocks.Node{} s := newSubworkflowHandler(nodeExec, eventConfig) n := &coreMocks.ExecutableNode{} swf.OnGetID().Return("swf") diff --git a/flytepropeller/pkg/controller/nodes/task/handler.go b/flytepropeller/pkg/controller/nodes/task/handler.go index e0faf8f85d..c3b8833fa7 100644 --- a/flytepropeller/pkg/controller/nodes/task/handler.go +++ b/flytepropeller/pkg/controller/nodes/task/handler.go @@ -6,39 +6,39 @@ import ( 
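One theme runs through all of these hunks: handler.NodeExecutionContext, handler.SetupContext, and executors.Node are replaced by counterparts in the new nodes/interfaces package, and the mocks move with them (hence the switch from nodes/handler/mocks to nodes/interfaces/mocks in the tests above). The sketch below shows the consumer side of such an extraction; NodeExecutionContext and Node here are deliberately trimmed-down stand-ins, not the real interfaces, which carry many more methods.

package main

import (
	"context"
	"fmt"
)

// NodeExecutionContext stands in for interfaces.NodeExecutionContext,
// reduced to the one method this sketch needs.
type NodeExecutionContext interface {
	NodeID() string
}

// Node stands in for interfaces.Node, the executor contract that handlers
// (like the subworkflow handler) now import from a shared package.
type Node interface {
	AbortHandler(ctx context.Context, nCtx NodeExecutionContext, reason string) error
}

// subworkflowHandler depends only on the shared contracts, not on the
// concrete executor implementation.
type subworkflowHandler struct {
	nodeExecutor Node
}

func (s subworkflowHandler) HandleAbort(ctx context.Context, nCtx NodeExecutionContext, reason string) error {
	return s.nodeExecutor.AbortHandler(ctx, nCtx, reason)
}

type printExecutor struct{}

func (printExecutor) AbortHandler(_ context.Context, nCtx NodeExecutionContext, reason string) error {
	fmt.Printf("abort %s: %s\n", nCtx.NodeID(), reason)
	return nil
}

type fakeCtx struct{ id string }

func (f fakeCtx) NodeID() string { return f.id }

func main() {
	h := subworkflowHandler{nodeExecutor: printExecutor{}}
	_ = h.HandleAbort(context.Background(), fakeCtx{id: "n1"}, "user requested abort")
}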
"runtime/debug" "time" - eventsErr "github.com/flyteorg/flytepropeller/events/errors" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" - "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" - - "github.com/flyteorg/flytepropeller/pkg/utils" - - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/ioutils" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" + pluginMachinery "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/catalog" pluginCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/ioutils" pluginK8s "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/k8s" + + eventsErr "github.com/flyteorg/flytepropeller/events/errors" + "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" controllerConfig "github.com/flyteorg/flytepropeller/pkg/controller/config" + "github.com/flyteorg/flytepropeller/pkg/controller/executors" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/errors" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/config" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/resourcemanager" + rmConfig "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/resourcemanager/config" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/secretmanager" + "github.com/flyteorg/flytepropeller/pkg/utils" + "github.com/flyteorg/flytestdlib/contextutils" "github.com/flyteorg/flytestdlib/logger" "github.com/flyteorg/flytestdlib/promutils" "github.com/flyteorg/flytestdlib/promutils/labeled" "github.com/flyteorg/flytestdlib/storage" - "github.com/golang/protobuf/ptypes" - regErrors "github.com/pkg/errors" - "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/resourcemanager" - rmConfig "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/resourcemanager/config" + "github.com/golang/protobuf/ptypes" - "github.com/flyteorg/flytepropeller/pkg/controller/executors" - "github.com/flyteorg/flytepropeller/pkg/controller/nodes/errors" - "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" - "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/config" - "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/secretmanager" + regErrors "github.com/pkg/errors" ) const pluginContextKey = contextutils.Key("plugin") @@ -227,7 +227,7 @@ func (t *Handler) setDefault(ctx context.Context, p pluginCore.Plugin) error { return nil } -func (t *Handler) Setup(ctx context.Context, sCtx handler.SetupContext) error { +func (t *Handler) Setup(ctx context.Context, sCtx interfaces.SetupContext) error { tSCtx := t.newSetupContext(sCtx) // Create a new base resource negotiator @@ -532,7 +532,7 @@ func (t Handler) invokePlugin(ctx context.Context, p pluginCore.Plugin, tCtx *ta return pluginTrns, nil } -func (t Handler) Handle(ctx context.Context, nCtx handler.NodeExecutionContext) (handler.Transition, error) { +func (t Handler) Handle(ctx context.Context, nCtx interfaces.NodeExecutionContext) (handler.Transition, error) { ttype := nCtx.TaskReader().GetTaskType() ctx = contextutils.WithTaskType(ctx, ttype) p, err := t.ResolvePlugin(ctx, ttype, 
nCtx.ExecutionContext().GetExecutionConfig()) @@ -553,6 +553,14 @@ func (t Handler) Handle(ctx context.Context, nCtx handler.NodeExecutionContext) ts := nCtx.NodeStateReader().GetTaskNodeState() pluginTrns := &pluginRequestedTransition{} + defer func() { + // increment parallelism if the final pluginTrns is not in a terminal state + if pluginTrns != nil && !pluginTrns.pInfo.Phase().IsTerminal() { + eCtx := nCtx.ExecutionContext() + logger.Infof(ctx, "Parallelism now set to [%d].", eCtx.IncrementParallelism()) + } + }() + // We will start with the assumption that catalog is disabled pluginTrns.PopulateCacheInfo(catalog.NewFailedCatalogEntry(catalog.NewStatus(core.CatalogCacheStatus_CACHE_DISABLED, nil))) @@ -763,7 +771,7 @@ func (t Handler) Handle(ctx context.Context, nCtx handler.NodeExecutionContext) return pluginTrns.FinalTransition(ctx) } -func (t Handler) Abort(ctx context.Context, nCtx handler.NodeExecutionContext, reason string) error { +func (t Handler) Abort(ctx context.Context, nCtx interfaces.NodeExecutionContext, reason string) error { taskNodeState := nCtx.NodeStateReader().GetTaskNodeState() currentPhase := taskNodeState.PluginPhase logger.Debugf(ctx, "Abort invoked with phase [%v]", currentPhase) @@ -829,7 +837,7 @@ func (t Handler) Abort(ctx context.Context, nCtx handler.NodeExecutionContext, r return nil } -func (t Handler) Finalize(ctx context.Context, nCtx handler.NodeExecutionContext) error { +func (t Handler) Finalize(ctx context.Context, nCtx interfaces.NodeExecutionContext) error { logger.Debugf(ctx, "Finalize invoked.") ttype := nCtx.TaskReader().GetTaskType() p, err := t.ResolvePlugin(ctx, ttype, nCtx.ExecutionContext().GetExecutionConfig()) diff --git a/flytepropeller/pkg/controller/nodes/task/handler_test.go b/flytepropeller/pkg/controller/nodes/task/handler_test.go index 6da4762cbf..0a009cb658 100644 --- a/flytepropeller/pkg/controller/nodes/task/handler_test.go +++ b/flytepropeller/pkg/controller/nodes/task/handler_test.go @@ -10,9 +10,6 @@ import ( "github.com/golang/protobuf/proto" eventsErr "github.com/flyteorg/flytepropeller/events/errors" - mocks2 "github.com/flyteorg/flytepropeller/events/mocks" - - "github.com/flyteorg/flytepropeller/events" pluginK8sMocks "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/k8s/mocks" @@ -49,7 +46,8 @@ import ( flyteMocks "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks" "github.com/flyteorg/flytepropeller/pkg/controller/executors/mocks" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" - nodeMocks "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler/mocks" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" + nodeMocks "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces/mocks" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/codex" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/config" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/fakeplugins" @@ -352,19 +350,26 @@ func Test_task_ResolvePlugin(t *testing.T) { } } -type fakeBufferedTaskEventRecorder struct { +type fakeBufferedEventRecorder struct { evs []*event.TaskExecutionEvent } -func (f *fakeBufferedTaskEventRecorder) RecordTaskEvent(ctx context.Context, ev *event.TaskExecutionEvent, eventConfig *controllerConfig.EventConfig) error { +func (f *fakeBufferedEventRecorder) RecordTaskEvent(ctx context.Context, ev *event.TaskExecutionEvent, eventConfig *controllerConfig.EventConfig) error { f.evs = append(f.evs, ev) return nil } +func 
(f *fakeBufferedEventRecorder) RecordNodeEvent(ctx context.Context, ev *event.NodeExecutionEvent, eventConfig *controllerConfig.EventConfig) error { + return nil +} + type taskNodeStateHolder struct { s handler.TaskNodeState } +func (t *taskNodeStateHolder) ClearNodeStatus() { +} + func (t *taskNodeStateHolder) PutTaskNodeState(s handler.TaskNodeState) error { t.s = s return nil @@ -386,6 +391,10 @@ func (t taskNodeStateHolder) PutGateNodeState(s handler.GateNodeState) error { panic("not implemented") } +func (t taskNodeStateHolder) PutArrayNodeState(s handler.ArrayNodeState) error { + panic("not implemented") +} + func CreateNoopResourceManager(ctx context.Context, scope promutils.Scope) resourcemanager.BaseResourceManager { rmBuilder, _ := resourcemanager.GetResourceManagerBuilderByType(ctx, rmConfig.TypeNoop, scope) rm, _ := rmBuilder.BuildResourceManager(ctx) @@ -399,7 +408,7 @@ func Test_task_Handle_NoCatalog(t *testing.T) { "foo": coreutils.MustMakeLiteral("bar"), }, } - createNodeContext := func(pluginPhase pluginCore.Phase, pluginVer uint32, pluginResp fakeplugins.NextPhaseState, recorder events.TaskEventRecorder, ttype string, s *taskNodeStateHolder, allowIncrementParallelism bool) *nodeMocks.NodeExecutionContext { + createNodeContext := func(pluginPhase pluginCore.Phase, pluginVer uint32, pluginResp fakeplugins.NextPhaseState, recorder interfaces.EventRecorder, ttype string, s *taskNodeStateHolder, allowIncrementParallelism bool) *nodeMocks.NodeExecutionContext { wfExecID := &core.WorkflowExecutionIdentifier{ Project: "project", Domain: "domain", @@ -690,7 +699,7 @@ func Test_task_Handle_NoCatalog(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { state := &taskNodeStateHolder{} - ev := &fakeBufferedTaskEventRecorder{} + ev := &fakeBufferedEventRecorder{} nCtx := createNodeContext(tt.args.startingPluginPhase, uint32(tt.args.startingPluginPhaseVersion), tt.args.expectedState, ev, "test", state, tt.want.incrParallel) c := &pluginCatalogMocks.Client{} tk := Handler{ @@ -758,7 +767,7 @@ func Test_task_Handle_NoCatalog(t *testing.T) { func Test_task_Handle_Catalog(t *testing.T) { - createNodeContext := func(recorder events.TaskEventRecorder, ttype string, s *taskNodeStateHolder, overwriteCache bool) *nodeMocks.NodeExecutionContext { + createNodeContext := func(recorder interfaces.EventRecorder, ttype string, s *taskNodeStateHolder, overwriteCache bool) *nodeMocks.NodeExecutionContext { wfExecID := &core.WorkflowExecutionIdentifier{ Project: "project", Domain: "domain", @@ -950,7 +959,7 @@ func Test_task_Handle_Catalog(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { state := &taskNodeStateHolder{} - ev := &fakeBufferedTaskEventRecorder{} + ev := &fakeBufferedEventRecorder{} nCtx := createNodeContext(ev, "test", state, tt.args.catalogSkip) c := &pluginCatalogMocks.Client{} if tt.args.catalogFetch { @@ -1018,7 +1027,7 @@ func Test_task_Handle_Catalog(t *testing.T) { func Test_task_Handle_Reservation(t *testing.T) { - createNodeContext := func(recorder events.TaskEventRecorder, ttype string, s *taskNodeStateHolder, overwriteCache bool) *nodeMocks.NodeExecutionContext { + createNodeContext := func(recorder interfaces.EventRecorder, ttype string, s *taskNodeStateHolder, overwriteCache bool) *nodeMocks.NodeExecutionContext { wfExecID := &core.WorkflowExecutionIdentifier{ Project: "project", Domain: "domain", @@ -1206,7 +1215,7 @@ func Test_task_Handle_Reservation(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t 
*testing.T) { state := &taskNodeStateHolder{} - ev := &fakeBufferedTaskEventRecorder{} + ev := &fakeBufferedEventRecorder{} nCtx := createNodeContext(ev, "test", state, tt.args.catalogSkip) c := &pluginCatalogMocks.Client{} nr := &nodeMocks.NodeStateReader{} @@ -1269,7 +1278,7 @@ func Test_task_Handle_Reservation(t *testing.T) { } func Test_task_Abort(t *testing.T) { - createNodeCtx := func(ev events.TaskEventRecorder) *nodeMocks.NodeExecutionContext { + createNodeCtx := func(ev interfaces.EventRecorder) *nodeMocks.NodeExecutionContext { wfExecID := &core.WorkflowExecutionIdentifier{ Project: "project", Domain: "domain", @@ -1355,7 +1364,7 @@ func Test_task_Abort(t *testing.T) { noopRm := CreateNoopResourceManager(context.TODO(), promutils.NewTestScope()) - incompatibleClusterEventsRecorder := mocks2.TaskEventRecorder{} + incompatibleClusterEventsRecorder := nodeMocks.EventRecorder{} incompatibleClusterEventsRecorder.OnRecordTaskEventMatch(mock.Anything, mock.Anything, mock.Anything).Return( &eventsErr.EventError{ Code: eventsErr.EventIncompatibleCusterError, @@ -1365,7 +1374,7 @@ func Test_task_Abort(t *testing.T) { defaultPluginCallback func() pluginCore.Plugin } type args struct { - ev events.TaskEventRecorder + ev interfaces.EventRecorder } tests := []struct { name string @@ -1391,7 +1400,7 @@ func Test_task_Abort(t *testing.T) { p.OnGetProperties().Return(pluginCore.PluginProperties{}) p.On("Abort", mock.Anything, mock.Anything).Return(nil) return p - }}, args{ev: &fakeBufferedTaskEventRecorder{}}, false, true}, + }}, args{ev: &fakeBufferedEventRecorder{}}, false, true}, {"abort-swallows-incompatible-cluster-err", fields{defaultPluginCallback: func() pluginCore.Plugin { p := &pluginCoreMocks.Plugin{} p.On("GetID").Return("id") @@ -1416,10 +1425,10 @@ func Test_task_Abort(t *testing.T) { c = 1 if !tt.wantErr { switch tt.args.ev.(type) { - case *fakeBufferedTaskEventRecorder: - assert.Len(t, tt.args.ev.(*fakeBufferedTaskEventRecorder).evs, 1) - case *mocks2.TaskEventRecorder: - assert.Len(t, tt.args.ev.(*mocks2.TaskEventRecorder).Calls, 1) + case *fakeBufferedEventRecorder: + assert.Len(t, tt.args.ev.(*fakeBufferedEventRecorder).evs, 1) + case *nodeMocks.EventRecorder: + assert.Len(t, tt.args.ev.(*nodeMocks.EventRecorder).Calls, 1) } } } @@ -1431,7 +1440,7 @@ func Test_task_Abort(t *testing.T) { } func Test_task_Abort_v1(t *testing.T) { - createNodeCtx := func(ev events.TaskEventRecorder) *nodeMocks.NodeExecutionContext { + createNodeCtx := func(ev interfaces.EventRecorder) *nodeMocks.NodeExecutionContext { wfExecID := &core.WorkflowExecutionIdentifier{ Project: "project", Domain: "domain", @@ -1517,7 +1526,7 @@ func Test_task_Abort_v1(t *testing.T) { noopRm := CreateNoopResourceManager(context.TODO(), promutils.NewTestScope()) - incompatibleClusterEventsRecorder := mocks2.TaskEventRecorder{} + incompatibleClusterEventsRecorder := nodeMocks.EventRecorder{} incompatibleClusterEventsRecorder.OnRecordTaskEventMatch(mock.Anything, mock.Anything, mock.Anything).Return( &eventsErr.EventError{ Code: eventsErr.EventIncompatibleCusterError, @@ -1527,7 +1536,7 @@ func Test_task_Abort_v1(t *testing.T) { defaultPluginCallback func() pluginCore.Plugin } type args struct { - ev events.TaskEventRecorder + ev interfaces.EventRecorder } tests := []struct { name string @@ -1553,7 +1562,7 @@ func Test_task_Abort_v1(t *testing.T) { p.OnGetProperties().Return(pluginCore.PluginProperties{}) p.On("Abort", mock.Anything, mock.Anything).Return(nil) return p - }}, args{ev: &fakeBufferedTaskEventRecorder{}}, 
false, true}, + }}, args{ev: &fakeBufferedEventRecorder{}}, false, true}, {"abort-swallows-incompatible-cluster-err", fields{defaultPluginCallback: func() pluginCore.Plugin { p := &pluginCoreMocks.Plugin{} p.On("GetID").Return("id") @@ -1578,10 +1587,10 @@ func Test_task_Abort_v1(t *testing.T) { c = 1 if !tt.wantErr { switch tt.args.ev.(type) { - case *fakeBufferedTaskEventRecorder: - assert.Len(t, tt.args.ev.(*fakeBufferedTaskEventRecorder).evs, 1) - case *mocks2.TaskEventRecorder: - assert.Len(t, tt.args.ev.(*mocks2.TaskEventRecorder).Calls, 1) + case *fakeBufferedEventRecorder: + assert.Len(t, tt.args.ev.(*fakeBufferedEventRecorder).evs, 1) + case *nodeMocks.EventRecorder: + assert.Len(t, tt.args.ev.(*nodeMocks.EventRecorder).Calls, 1) } } } diff --git a/flytepropeller/pkg/controller/nodes/task/k8s/plugin_manager.go b/flytepropeller/pkg/controller/nodes/task/k8s/plugin_manager.go index 37534dc32f..a3397354e7 100644 --- a/flytepropeller/pkg/controller/nodes/task/k8s/plugin_manager.go +++ b/flytepropeller/pkg/controller/nodes/task/k8s/plugin_manager.go @@ -185,7 +185,6 @@ func (e *PluginManager) getPodEffectiveResourceLimits(ctx context.Context, pod * } func (e *PluginManager) LaunchResource(ctx context.Context, tCtx pluginsCore.TaskExecutionContext) (pluginsCore.Transition, error) { - tmpl, err := tCtx.TaskReader().Read(ctx) if err != nil { return pluginsCore.Transition{}, err @@ -253,7 +252,6 @@ func (e *PluginManager) LaunchResource(ctx context.Context, tCtx pluginsCore.Tas } func (e *PluginManager) CheckResourcePhase(ctx context.Context, tCtx pluginsCore.TaskExecutionContext, k8sPluginState *k8s.PluginState) (pluginsCore.Transition, error) { - o, err := e.plugin.BuildIdentityResource(ctx, tCtx.TaskExecutionMetadata()) if err != nil { logger.Errorf(ctx, "Failed to build the Resource with name: %v. Error: %v", tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName(), err) diff --git a/flytepropeller/pkg/controller/nodes/task/plugin_state_manager.go b/flytepropeller/pkg/controller/nodes/task/plugin_state_manager.go index 496f5387f3..f68d7b58aa 100644 --- a/flytepropeller/pkg/controller/nodes/task/plugin_state_manager.go +++ b/flytepropeller/pkg/controller/nodes/task/plugin_state_manager.go @@ -19,7 +19,7 @@ const ( const currentCodec = GobCodecVersion // TODO Configurable? 
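In the plugin_state_manager.go hunk that continues below, the unexported maxPluginStateSizeBytes becomes the exported MaxPluginStateSizeBytes, making the plugin-state budget visible outside the task package. The sketch that follows shows roughly how a gob-backed codec can use such a budget; encodeState is an assumed helper, and the overflow check is illustrative only (the real Put simply pre-sizes its buffer with the constant).

package main

import (
	"bytes"
	"encoding/gob"
	"fmt"
)

const MaxPluginStateSizeBytes = 256

// encodeState gob-encodes v into a buffer pre-sized to the exported budget
// and reports when the encoded form outgrows it.
func encodeState(v interface{}) ([]byte, error) {
	buf := bytes.NewBuffer(make([]byte, 0, MaxPluginStateSizeBytes))
	if err := gob.NewEncoder(buf).Encode(v); err != nil {
		return nil, err
	}
	if buf.Len() > MaxPluginStateSizeBytes {
		return nil, fmt.Errorf("encoded plugin state is %d bytes, over the %d-byte budget", buf.Len(), MaxPluginStateSizeBytes)
	}
	return buf.Bytes(), nil
}

func main() {
	type pluginState struct{ Phase uint8 }
	b, err := encodeState(pluginState{Phase: 3})
	fmt.Println(len(b), err)
}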
-const maxPluginStateSizeBytes = 256 +const MaxPluginStateSizeBytes = 256 type stateCodec interface { Encode(interface{}, io.Writer) error @@ -38,7 +38,7 @@ type pluginStateManager struct { func (p *pluginStateManager) Put(stateVersion uint8, v interface{}) error { p.newStateVersion = stateVersion if v != nil { - buf := make([]byte, 0, maxPluginStateSizeBytes) + buf := make([]byte, 0, MaxPluginStateSizeBytes) p.newState = bytes.NewBuffer(buf) return p.codec.Encode(v, p.newState) } diff --git a/flytepropeller/pkg/controller/nodes/task/setup_ctx.go b/flytepropeller/pkg/controller/nodes/task/setup_ctx.go index c788ffd4e1..4277275a26 100644 --- a/flytepropeller/pkg/controller/nodes/task/setup_ctx.go +++ b/flytepropeller/pkg/controller/nodes/task/setup_ctx.go @@ -2,14 +2,16 @@ package task import ( pluginCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" + + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" + "github.com/flyteorg/flytestdlib/promutils" - "k8s.io/apimachinery/pkg/types" - "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" + "k8s.io/apimachinery/pkg/types" ) type setupContext struct { - handler.SetupContext + interfaces.SetupContext kubeClient pluginCore.KubeClient secretManager pluginCore.SecretManager } @@ -29,7 +31,7 @@ func (s setupContext) EnqueueOwner() pluginCore.EnqueueOwner { } } -func (t *Handler) newSetupContext(sCtx handler.SetupContext) *setupContext { +func (t *Handler) newSetupContext(sCtx interfaces.SetupContext) *setupContext { return &setupContext{ SetupContext: sCtx, diff --git a/flytepropeller/pkg/controller/nodes/task/setup_ctx_test.go b/flytepropeller/pkg/controller/nodes/task/setup_ctx_test.go index 6b8d0a438c..e987bbb244 100644 --- a/flytepropeller/pkg/controller/nodes/task/setup_ctx_test.go +++ b/flytepropeller/pkg/controller/nodes/task/setup_ctx_test.go @@ -4,13 +4,13 @@ import ( "testing" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/mocks" - "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" "github.com/flyteorg/flytestdlib/promutils" "github.com/stretchr/testify/assert" ) type dummySetupCtx struct { - handler.SetupContext + interfaces.SetupContext testScopeName string } diff --git a/flytepropeller/pkg/controller/nodes/task/taskexec_context.go b/flytepropeller/pkg/controller/nodes/task/taskexec_context.go index 7abfd25991..a308647ac0 100644 --- a/flytepropeller/pkg/controller/nodes/task/taskexec_context.go +++ b/flytepropeller/pkg/controller/nodes/task/taskexec_context.go @@ -15,7 +15,7 @@ import ( "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/common" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/errors" - "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/resourcemanager" "github.com/flyteorg/flytepropeller/pkg/utils" "github.com/flyteorg/flytestdlib/logger" @@ -58,7 +58,7 @@ func (te taskExecutionID) GetGeneratedNameWith(minLength, maxLength int) (string } type taskExecutionMetadata struct { - handler.NodeExecutionMetadata + interfaces.NodeExecutionMetadata taskExecID taskExecutionID o pluginCore.TaskOverrides maxAttempts uint32 @@ -87,7 +87,7 @@ func (t taskExecutionMetadata) GetEnvironmentVariables() map[string]string { } type taskExecutionContext struct { - 
handler.NodeExecutionContext + interfaces.NodeExecutionContext tm taskExecutionMetadata rm resourcemanager.TaskResourceManager psm *pluginStateManager @@ -208,7 +208,7 @@ func convertTaskResourcesToRequirements(taskResources v1alpha1.TaskResources) *v // ComputeRawOutputPrefix constructs the output directory, where raw outputs of a task can be stored by the task. FlytePropeller may not have // access to this location and can be passed in per execution. // the function also returns the uniqueID generated -func ComputeRawOutputPrefix(ctx context.Context, length int, nCtx handler.NodeExecutionContext, currentNodeUniqueID v1alpha1.NodeID, currentAttempt uint32) (io.RawOutputPaths, string, error) { +func ComputeRawOutputPrefix(ctx context.Context, length int, nCtx interfaces.NodeExecutionContext, currentNodeUniqueID v1alpha1.NodeID, currentAttempt uint32) (io.RawOutputPaths, string, error) { uniqueID, err := encoding.FixedLengthUniqueIDForParts(length, []string{nCtx.NodeExecutionMetadata().GetOwnerID().Name, currentNodeUniqueID, strconv.Itoa(int(currentAttempt))}) if err != nil { // SHOULD never really happen @@ -223,7 +223,7 @@ func ComputeRawOutputPrefix(ctx context.Context, length int, nCtx handler.NodeEx } // ComputePreviousCheckpointPath returns the checkpoint path for the previous attempt, if this is the first attempt then returns an empty path -func ComputePreviousCheckpointPath(ctx context.Context, length int, nCtx handler.NodeExecutionContext, currentNodeUniqueID v1alpha1.NodeID, currentAttempt uint32) (storage.DataReference, error) { +func ComputePreviousCheckpointPath(ctx context.Context, length int, nCtx interfaces.NodeExecutionContext, currentNodeUniqueID v1alpha1.NodeID, currentAttempt uint32) (storage.DataReference, error) { // If first attempt for this node execution, look for a checkpoint path in a prior execution if currentAttempt == 0 { return nCtx.NodeStateReader().GetTaskNodeState().PreviousNodeExecutionCheckpointURI, nil @@ -237,7 +237,7 @@ func ComputePreviousCheckpointPath(ctx context.Context, length int, nCtx handler return ioutils.ConstructCheckpointPath(nCtx.DataStore(), prevRawOutputPrefix.GetRawOutputPrefix()), nil } -func (t *Handler) newTaskExecutionContext(ctx context.Context, nCtx handler.NodeExecutionContext, plugin pluginCore.Plugin) (*taskExecutionContext, error) { +func (t *Handler) newTaskExecutionContext(ctx context.Context, nCtx interfaces.NodeExecutionContext, plugin pluginCore.Plugin) (*taskExecutionContext, error) { id := GetTaskExecutionIdentifier(nCtx) currentNodeUniqueID := nCtx.NodeID() diff --git a/flytepropeller/pkg/controller/nodes/task/taskexec_context_test.go b/flytepropeller/pkg/controller/nodes/task/taskexec_context_test.go index bea082871a..cd30e86b5a 100644 --- a/flytepropeller/pkg/controller/nodes/task/taskexec_context_test.go +++ b/flytepropeller/pkg/controller/nodes/task/taskexec_context_test.go @@ -31,7 +31,7 @@ import ( "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" flyteMocks "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" - nodeMocks "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler/mocks" + nodeMocks "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces/mocks" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/codex" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/secretmanager" ) diff --git a/flytepropeller/pkg/controller/nodes/task/transformer.go 
b/flytepropeller/pkg/controller/nodes/task/transformer.go index 6faa93f70a..21edca7231 100644 --- a/flytepropeller/pkg/controller/nodes/task/transformer.go +++ b/flytepropeller/pkg/controller/nodes/task/transformer.go @@ -12,6 +12,7 @@ import ( "github.com/flyteorg/flytepropeller/pkg/controller/executors" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/common" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" "github.com/golang/protobuf/ptypes" timestamppb "github.com/golang/protobuf/ptypes/timestamp" @@ -75,7 +76,7 @@ type ToTaskExecutionEventInputs struct { EventConfig *config.EventConfig OutputWriter io.OutputFilePaths Info pluginCore.PhaseInfo - NodeExecutionMetadata handler.NodeExecutionMetadata + NodeExecutionMetadata interfaces.NodeExecutionMetadata ExecContext executors.ExecutionContext TaskType string PluginID string @@ -185,7 +186,7 @@ func ToTaskExecutionEvent(input ToTaskExecutionEventInputs) (*event.TaskExecutio return tev, nil } -func GetTaskExecutionIdentifier(nCtx handler.NodeExecutionContext) *core.TaskExecutionIdentifier { +func GetTaskExecutionIdentifier(nCtx interfaces.NodeExecutionContext) *core.TaskExecutionIdentifier { return &core.TaskExecutionIdentifier{ TaskId: nCtx.TaskReader().GetTaskID(), RetryAttempt: nCtx.CurrentAttempt(), diff --git a/flytepropeller/pkg/controller/nodes/task/transformer_test.go b/flytepropeller/pkg/controller/nodes/task/transformer_test.go index f8a6f12897..d9400336ec 100644 --- a/flytepropeller/pkg/controller/nodes/task/transformer_test.go +++ b/flytepropeller/pkg/controller/nodes/task/transformer_test.go @@ -24,7 +24,7 @@ import ( pluginMocks "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/mocks" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" - handlerMocks "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler/mocks" + nodemocks "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces/mocks" ) const containerTaskType = "container" @@ -60,7 +60,7 @@ func TestToTaskExecutionEvent(t *testing.T) { const outputPath = "out" out.On("GetOutputPath").Return(storage.DataReference(outputPath)) - nodeExecutionMetadata := handlerMocks.NodeExecutionMetadata{} + nodeExecutionMetadata := nodemocks.NodeExecutionMetadata{} nodeExecutionMetadata.OnIsInterruptible().Return(true) mockExecContext := &mocks2.ExecutionContext{} @@ -158,7 +158,7 @@ func TestToTaskExecutionEvent(t *testing.T) { assert.EqualValues(t, resourcePoolInfo, tev.Metadata.ResourcePoolInfo) assert.Equal(t, testClusterID, tev.ProducerId) - defaultNodeExecutionMetadata := handlerMocks.NodeExecutionMetadata{} + defaultNodeExecutionMetadata := nodemocks.NodeExecutionMetadata{} defaultNodeExecutionMetadata.OnIsInterruptible().Return(false) tev, err = ToTaskExecutionEvent(ToTaskExecutionEventInputs{ TaskExecContext: tCtx, @@ -251,7 +251,7 @@ func TestToTaskExecutionEventWithParent(t *testing.T) { const outputPath = "out" out.On("GetOutputPath").Return(storage.DataReference(outputPath)) - nodeExecutionMetadata := handlerMocks.NodeExecutionMetadata{} + nodeExecutionMetadata := nodemocks.NodeExecutionMetadata{} nodeExecutionMetadata.OnIsInterruptible().Return(true) mockExecContext := &mocks2.ExecutionContext{} diff --git a/flytepropeller/pkg/controller/nodes/task_event_recorder.go b/flytepropeller/pkg/controller/nodes/task_event_recorder.go deleted file mode 100644 index ef3ec1e93d..0000000000 --- 
a/flytepropeller/pkg/controller/nodes/task_event_recorder.go +++ /dev/null @@ -1,36 +0,0 @@ -package nodes - -import ( - "context" - - "github.com/flyteorg/flytepropeller/events" - eventsErr "github.com/flyteorg/flytepropeller/events/errors" - "github.com/flyteorg/flytepropeller/pkg/controller/config" - - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" - "github.com/flyteorg/flytestdlib/logger" - "github.com/pkg/errors" -) - -type taskEventRecorder struct { - events.TaskEventRecorder -} - -func (t taskEventRecorder) RecordTaskEvent(ctx context.Context, ev *event.TaskExecutionEvent, eventConfig *config.EventConfig) error { - if err := t.TaskEventRecorder.RecordTaskEvent(ctx, ev, eventConfig); err != nil { - if eventsErr.IsAlreadyExists(err) { - logger.Warningf(ctx, "Failed to record taskEvent, error [%s]. Trying to record state: %s. Ignoring this error!", err.Error(), ev.Phase) - return nil - } else if eventsErr.IsEventAlreadyInTerminalStateError(err) { - if IsTerminalTaskPhase(ev.Phase) { - // Event is terminal and the stored value in flyteadmin is already terminal. This implies aborted case. So ignoring - logger.Warningf(ctx, "Failed to record taskEvent, error [%s]. Trying to record state: %s. Ignoring this error!", err.Error(), ev.Phase) - return nil - } - logger.Warningf(ctx, "Failed to record taskEvent in state: %s, error: %s", ev.Phase, err) - return errors.Wrapf(err, "failed to record task event, as it already exists in terminal state. Event state: %s", ev.Phase) - } - return err - } - return nil -} diff --git a/flytepropeller/pkg/controller/nodes/task_event_recorder_test.go b/flytepropeller/pkg/controller/nodes/task_event_recorder_test.go deleted file mode 100644 index 0f4da20376..0000000000 --- a/flytepropeller/pkg/controller/nodes/task_event_recorder_test.go +++ /dev/null @@ -1,60 +0,0 @@ -package nodes - -import ( - "context" - "fmt" - "testing" - - "github.com/flyteorg/flytepropeller/events" - eventsErr "github.com/flyteorg/flytepropeller/events/errors" - "github.com/flyteorg/flytepropeller/pkg/controller/config" - - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" -) - -type fakeTaskEventsRecorder struct { - err error -} - -func (f fakeTaskEventsRecorder) RecordTaskEvent(ctx context.Context, event *event.TaskExecutionEvent, eventConfig *config.EventConfig) error { - if f.err != nil { - return f.err - } - return nil -} - -func Test_taskEventRecorder_RecordTaskEvent(t1 *testing.T) { - noErrRecorder := fakeTaskEventsRecorder{} - alreadyExistsError := fakeTaskEventsRecorder{&eventsErr.EventError{Code: eventsErr.AlreadyExists, Cause: fmt.Errorf("err")}} - inTerminalError := fakeTaskEventsRecorder{&eventsErr.EventError{Code: eventsErr.EventAlreadyInTerminalStateError, Cause: fmt.Errorf("err")}} - otherError := fakeTaskEventsRecorder{&eventsErr.EventError{Code: eventsErr.ResourceExhausted, Cause: fmt.Errorf("err")}} - - tests := []struct { - name string - rec events.TaskEventRecorder - p core.TaskExecution_Phase - wantErr bool - }{ - {"aborted-success", noErrRecorder, core.TaskExecution_ABORTED, false}, - {"aborted-failure", otherError, core.TaskExecution_ABORTED, true}, - {"aborted-already", alreadyExistsError, core.TaskExecution_ABORTED, false}, - {"aborted-terminal", inTerminalError, core.TaskExecution_ABORTED, false}, - {"running-terminal", inTerminalError, core.TaskExecution_RUNNING, true}, - } - for _, tt := range tests { - t1.Run(tt.name, func(t1 *testing.T) { - t := taskEventRecorder{ - TaskEventRecorder: 
tt.rec, - } - ev := &event.TaskExecutionEvent{ - Phase: tt.p, - } - if err := t.RecordTaskEvent(context.TODO(), ev, &config.EventConfig{ - RawOutputPolicy: config.RawOutputPolicyReference, - }); (err != nil) != tt.wantErr { - t1.Errorf("RecordTaskEvent() error = %v, wantErr %v", err, tt.wantErr) - } - }) - } -} diff --git a/flytepropeller/pkg/controller/nodes/transformers.go b/flytepropeller/pkg/controller/nodes/transformers.go index 306ab182a4..b8c06ce3d8 100644 --- a/flytepropeller/pkg/controller/nodes/transformers.go +++ b/flytepropeller/pkg/controller/nodes/transformers.go @@ -6,19 +6,21 @@ import ( "strconv" "time" - "github.com/flyteorg/flytepropeller/pkg/controller/config" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" + "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/flyteorg/flytepropeller/pkg/controller/config" "github.com/flyteorg/flytepropeller/pkg/controller/executors" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/common" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" "github.com/flyteorg/flytestdlib/logger" + "github.com/golang/protobuf/ptypes" - v1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" - "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" ) // This is used by flyteadmin to indicate that the events will now contain populated IsParent and IsDynamic bits. @@ -226,54 +228,72 @@ func ToK8sTime(t time.Time) v1.Time { return v1.Time{Time: t} } -func UpdateNodeStatus(np v1alpha1.NodePhase, p handler.PhaseInfo, n *nodeStateManager, s v1alpha1.ExecutableNodeStatus) { +func UpdateNodeStatus(np v1alpha1.NodePhase, p handler.PhaseInfo, n interfaces.NodeStateReader, s v1alpha1.ExecutableNodeStatus) { // We update the phase only if it is not already updated if np != s.GetPhase() { s.UpdatePhase(np, ToK8sTime(p.GetOccurredAt()), p.GetReason(), p.GetErr()) } // Update TaskStatus - if n.t != nil { + if n.HasTaskNodeState() { + nt := n.GetTaskNodeState() t := s.GetOrCreateTaskStatus() - t.SetPhaseVersion(n.t.PluginPhaseVersion) - t.SetPhase(int(n.t.PluginPhase)) - t.SetLastPhaseUpdatedAt(n.t.LastPhaseUpdatedAt) - t.SetPluginState(n.t.PluginState) - t.SetPluginStateVersion(n.t.PluginStateVersion) - t.SetPreviousNodeExecutionCheckpointPath(n.t.PreviousNodeExecutionCheckpointURI) - t.SetCleanupOnFailure(n.t.CleanupOnFailure) + t.SetPhaseVersion(nt.PluginPhaseVersion) + t.SetPhase(int(nt.PluginPhase)) + t.SetLastPhaseUpdatedAt(nt.LastPhaseUpdatedAt) + t.SetPluginState(nt.PluginState) + t.SetPluginStateVersion(nt.PluginStateVersion) + t.SetPreviousNodeExecutionCheckpointPath(nt.PreviousNodeExecutionCheckpointURI) + t.SetCleanupOnFailure(nt.CleanupOnFailure) } // Update dynamic node status - if n.d != nil { + if n.HasDynamicNodeState() { + nd := n.GetDynamicNodeState() t := s.GetOrCreateDynamicNodeStatus() - t.SetDynamicNodePhase(n.d.Phase) - t.SetDynamicNodeReason(n.d.Reason) - t.SetExecutionError(n.d.Error) - t.SetIsFailurePermanent(n.d.IsFailurePermanent) + t.SetDynamicNodePhase(nd.Phase) + t.SetDynamicNodeReason(nd.Reason) + t.SetExecutionError(nd.Error) + t.SetIsFailurePermanent(nd.IsFailurePermanent) } // Update branch node status - if n.b != nil { + 
if n.b != nil { +
if n.HasBranchNodeState() { + nb := n.GetBranchNodeState() t := s.GetOrCreateBranchStatus() - if n.b.Phase == v1alpha1.BranchNodeError { + if nb.Phase == v1alpha1.BranchNodeError { t.SetBranchNodeError() - } else if n.b.FinalizedNodeID != nil { - t.SetBranchNodeSuccess(*n.b.FinalizedNodeID) + } else if nb.FinalizedNodeID != nil { + t.SetBranchNodeSuccess(*nb.FinalizedNodeID) } else { logger.Warnf(context.TODO(), "branch node status neither success nor error set") } } // Update workflow node status - if n.w != nil { + if n.HasWorkflowNodeState() { + nw := n.GetWorkflowNodeState() t := s.GetOrCreateWorkflowStatus() - t.SetWorkflowNodePhase(n.w.Phase) - t.SetExecutionError(n.w.Error) + t.SetWorkflowNodePhase(nw.Phase) + t.SetExecutionError(nw.Error) } // Update gate node status - if n.g != nil { + if n.HasGateNodeState() { + ng := n.GetGateNodeState() t := s.GetOrCreateGateNodeStatus() - t.SetGateNodePhase(n.g.Phase) + t.SetGateNodePhase(ng.Phase) + } + + // Update array node status + if n.HasArrayNodeState() { + na := n.GetArrayNodeState() + t := s.GetOrCreateArrayNodeStatus() + t.SetArrayNodePhase(na.Phase) + t.SetExecutionError(na.Error) + t.SetSubNodePhases(na.SubNodePhases) + t.SetSubNodeTaskPhases(na.SubNodeTaskPhases) + t.SetSubNodeRetryAttempts(na.SubNodeRetryAttempts) + t.SetSubNodeSystemFailures(na.SubNodeSystemFailures) + t.SetTaskPhaseVersion(na.TaskPhaseVersion) } } diff --git a/flytepropeller/pkg/controller/workflow/executor.go b/flytepropeller/pkg/controller/workflow/executor.go index e3eac1e37c..11a7662135 100644 --- a/flytepropeller/pkg/controller/workflow/executor.go +++ b/flytepropeller/pkg/controller/workflow/executor.go @@ -5,23 +5,25 @@ import ( "fmt" "time" - "github.com/flyteorg/flytepropeller/pkg/controller/config" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" - "github.com/flyteorg/flytestdlib/logger" - "github.com/flyteorg/flytestdlib/promutils" - "github.com/flyteorg/flytestdlib/promutils/labeled" - "github.com/flyteorg/flytestdlib/storage" - corev1 "k8s.io/api/core/v1" - "k8s.io/client-go/tools/record" "github.com/flyteorg/flytepropeller/events" eventsErr "github.com/flyteorg/flytepropeller/events/errors" "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/flyteorg/flytepropeller/pkg/controller/config" "github.com/flyteorg/flytepropeller/pkg/controller/executors" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" "github.com/flyteorg/flytepropeller/pkg/controller/workflow/errors" "github.com/flyteorg/flytepropeller/pkg/utils" + + "github.com/flyteorg/flytestdlib/logger" + "github.com/flyteorg/flytestdlib/promutils" + "github.com/flyteorg/flytestdlib/promutils/labeled" + "github.com/flyteorg/flytestdlib/storage" + + corev1 "k8s.io/api/core/v1" + "k8s.io/client-go/tools/record" ) type workflowMetrics struct { @@ -64,7 +66,7 @@ type workflowExecutor struct { wfRecorder events.WorkflowEventRecorder k8sRecorder record.EventRecorder metadataPrefix storage.DataReference - nodeExecutor executors.Node + nodeExecutor interfaces.Node metrics *workflowMetrics eventConfig *config.EventConfig clusterID string @@ -495,7 +497,7 @@ func (c *workflowExecutor) cleanupRunningNodes(ctx context.Context, w v1alpha1.E } func NewExecutor(ctx context.Context, store *storage.DataStore, enQWorkflow v1alpha1.EnqueueWorkflow, eventSink events.EventSink, - k8sEventRecorder record.EventRecorder, metadataPrefix string, nodeExecutor executors.Node, eventConfig *config.EventConfig, 
+ k8sEventRecorder record.EventRecorder, metadataPrefix string, nodeExecutor interfaces.Node, eventConfig *config.EventConfig, clusterID string, scope promutils.Scope) (executors.Workflow, error) { basePrefix := store.GetBaseContainerFQN(ctx) if metadataPrefix != "" { diff --git a/flytepropeller/pkg/controller/workflow/executor_test.go b/flytepropeller/pkg/controller/workflow/executor_test.go index a07590ec77..ddd553dd83 100644 --- a/flytepropeller/pkg/controller/workflow/executor_test.go +++ b/flytepropeller/pkg/controller/workflow/executor_test.go @@ -32,6 +32,9 @@ import ( eventsErr "github.com/flyteorg/flytepropeller/events/errors" eventMocks "github.com/flyteorg/flytepropeller/events/mocks" mocks2 "github.com/flyteorg/flytepropeller/pkg/controller/executors/mocks" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/factory" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" + nodemocks "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces/mocks" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/catalog" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/fakeplugins" @@ -243,10 +246,13 @@ func TestWorkflowExecutor_HandleFlyteWorkflow_Error(t *testing.T) { catalogClient, err := catalog.NewCatalogClient(ctx, nil) assert.NoError(t, err) recoveryClient := &recoveryMocks.Client{} - adminClient := launchplan.NewFailFastLaunchPlanExecutor() + + handlerFactory, err := factory.NewHandlerFactory(ctx, adminClient, adminClient, fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, promutils.NewTestScope()) + assert.NoError(t, err) + nodeExec, err := nodes.NewExecutor(ctx, config.GetConfig().NodeConfig, store, enqueueWorkflow, eventSink, adminClient, adminClient, - maxOutputSize, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, promutils.NewTestScope()) + maxOutputSize, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, handlerFactory, promutils.NewTestScope()) assert.NoError(t, err) executor, err := NewExecutor(ctx, store, enqueueWorkflow, eventSink, recorder, "", nodeExec, eventConfig, testClusterID, promutils.NewTestScope()) assert.NoError(t, err) @@ -323,10 +329,13 @@ func TestWorkflowExecutor_HandleFlyteWorkflow(t *testing.T) { catalogClient, err := catalog.NewCatalogClient(ctx, nil) assert.NoError(t, err) recoveryClient := &recoveryMocks.Client{} - adminClient := launchplan.NewFailFastLaunchPlanExecutor() + + handlerFactory, err := factory.NewHandlerFactory(ctx, adminClient, adminClient, fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, promutils.NewTestScope()) + assert.NoError(t, err) + nodeExec, err := nodes.NewExecutor(ctx, config.GetConfig().NodeConfig, store, enqueueWorkflow, eventSink, adminClient, adminClient, - maxOutputSize, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, promutils.NewTestScope()) + maxOutputSize, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, handlerFactory, promutils.NewTestScope()) assert.NoError(t, err) executor, err := NewExecutor(ctx, store, enqueueWorkflow, eventSink, recorder, "", nodeExec, eventConfig, testClusterID, promutils.NewTestScope()) @@ -388,8 +397,9 @@ func BenchmarkWorkflowExecutor(b *testing.B) { assert.NoError(b, err) recoveryClient := &recoveryMocks.Client{} adminClient := 
launchplan.NewFailFastLaunchPlanExecutor() + handlerFactory := &nodemocks.HandlerFactory{} nodeExec, err := nodes.NewExecutor(ctx, config.GetConfig().NodeConfig, store, enqueueWorkflow, eventSink, adminClient, adminClient, - maxOutputSize, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, scope) + maxOutputSize, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, handlerFactory, scope) assert.NoError(b, err) executor, err := NewExecutor(ctx, store, enqueueWorkflow, eventSink, recorder, "", nodeExec, eventConfig, testClusterID, promutils.NewTestScope()) @@ -489,8 +499,19 @@ func TestWorkflowExecutor_HandleFlyteWorkflow_Failing(t *testing.T) { assert.NoError(t, err) recoveryClient := &recoveryMocks.Client{} adminClient := launchplan.NewFailFastLaunchPlanExecutor() + + h := &nodemocks.NodeHandler{} + h.OnAbortMatch(mock.Anything, mock.Anything, mock.Anything).Return(nil) + h.OnHandleMatch(mock.Anything, mock.Anything).Return(handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoSuccess(nil)), nil) + h.OnFinalizeMatch(mock.Anything, mock.Anything).Return(nil) + h.OnFinalizeRequired().Return(false) + + handlerFactory := &nodemocks.HandlerFactory{} + handlerFactory.OnSetupMatch(mock.Anything, mock.Anything, mock.Anything).Return(nil) + handlerFactory.OnGetHandlerMatch(mock.Anything).Return(h, nil) + nodeExec, err := nodes.NewExecutor(ctx, config.GetConfig().NodeConfig, store, enqueueWorkflow, eventSink, adminClient, adminClient, - maxOutputSize, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, promutils.NewTestScope()) + maxOutputSize, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, handlerFactory, promutils.NewTestScope()) assert.NoError(t, err) executor, err := NewExecutor(ctx, store, enqueueWorkflow, eventSink, recorder, "", nodeExec, eventConfig, testClusterID, promutils.NewTestScope()) assert.NoError(t, err) @@ -585,8 +606,12 @@ func TestWorkflowExecutor_HandleFlyteWorkflow_Events(t *testing.T) { assert.NoError(t, err) adminClient := launchplan.NewFailFastLaunchPlanExecutor() recoveryClient := &recoveryMocks.Client{} + + handlerFactory, err := factory.NewHandlerFactory(ctx, adminClient, adminClient, fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, promutils.NewTestScope()) + assert.NoError(t, err) + nodeExec, err := nodes.NewExecutor(ctx, config.GetConfig().NodeConfig, store, enqueueWorkflow, eventSink, adminClient, adminClient, - maxOutputSize, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, promutils.NewTestScope()) + maxOutputSize, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, handlerFactory, promutils.NewTestScope()) assert.NoError(t, err) executor, err := NewExecutor(ctx, store, enqueueWorkflow, eventSink, recorder, "metadata", nodeExec, eventConfig, testClusterID, promutils.NewTestScope()) assert.NoError(t, err) @@ -643,8 +668,16 @@ func TestWorkflowExecutor_HandleFlyteWorkflow_EventFailure(t *testing.T) { recoveryClient := &recoveryMocks.Client{} adminClient := launchplan.NewFailFastLaunchPlanExecutor() + h := &nodemocks.NodeHandler{} + h.OnAbortMatch(mock.Anything, mock.Anything, mock.Anything).Return(nil) + h.OnHandleMatch(mock.Anything, 
mock.Anything).Return(handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoSuccess(nil)), nil) + h.OnFinalizeMatch(mock.Anything, mock.Anything).Return(nil) + h.OnFinalizeRequired().Return(false) + handlerFactory := &nodemocks.HandlerFactory{} + handlerFactory.OnSetupMatch(mock.Anything, mock.Anything, mock.Anything).Return(nil) + handlerFactory.OnGetHandlerMatch(mock.Anything).Return(h, nil) nodeExec, err := nodes.NewExecutor(ctx, config.GetConfig().NodeConfig, store, enqueueWorkflow, nodeEventSink, adminClient, adminClient, - maxOutputSize, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, promutils.NewTestScope()) + maxOutputSize, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, handlerFactory, promutils.NewTestScope()) assert.NoError(t, err) t.Run("EventAlreadyInTerminalStateError", func(t *testing.T) { @@ -726,7 +759,7 @@ func TestWorkflowExecutor_HandleAbortedWorkflow(t *testing.T) { t.Run("user-initiated-fail", func(t *testing.T) { - nodeExec := &mocks2.Node{} + nodeExec := &nodemocks.Node{} wExec := &workflowExecutor{ nodeExecutor: nodeExec, metrics: newMetrics(promutils.NewTestScope()), @@ -756,7 +789,7 @@ func TestWorkflowExecutor_HandleAbortedWorkflow(t *testing.T) { t.Run("user-initiated-success", func(t *testing.T) { var evs []*event.WorkflowExecutionEvent - nodeExec := &mocks2.Node{} + nodeExec := &nodemocks.Node{} wfRecorder := &eventMocks.WorkflowEventRecorder{} wfRecorder.On("RecordWorkflowEvent", mock.Anything, mock.MatchedBy(func(ev *event.WorkflowExecutionEvent) bool { assert.Equal(t, testClusterID, ev.ProducerId) @@ -798,7 +831,7 @@ func TestWorkflowExecutor_HandleAbortedWorkflow(t *testing.T) { t.Run("user-initiated-attempts-exhausted", func(t *testing.T) { var evs []*event.WorkflowExecutionEvent - nodeExec := &mocks2.Node{} + nodeExec := &nodemocks.Node{} wfRecorder := &eventMocks.WorkflowEventRecorder{} wfRecorder.OnRecordWorkflowEventMatch(mock.Anything, mock.MatchedBy(func(ev *event.WorkflowExecutionEvent) bool { assert.Equal(t, testClusterID, ev.ProducerId) @@ -839,7 +872,7 @@ func TestWorkflowExecutor_HandleAbortedWorkflow(t *testing.T) { t.Run("failure-abort-success", func(t *testing.T) { var evs []*event.WorkflowExecutionEvent - nodeExec := &mocks2.Node{} + nodeExec := &nodemocks.Node{} wfRecorder := &eventMocks.WorkflowEventRecorder{} wfRecorder.OnRecordWorkflowEventMatch(mock.Anything, mock.MatchedBy(func(ev *event.WorkflowExecutionEvent) bool { assert.Equal(t, testClusterID, ev.ProducerId) @@ -877,7 +910,7 @@ func TestWorkflowExecutor_HandleAbortedWorkflow(t *testing.T) { t.Run("failure-abort-failed", func(t *testing.T) { - nodeExec := &mocks2.Node{} + nodeExec := &nodemocks.Node{} wExec := &workflowExecutor{ nodeExecutor: nodeExec, metrics: newMetrics(promutils.NewTestScope()),
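A closing note on the behavioral change buried in the task handler hunk earlier in this patch: Handle now increments parallelism inside a defer, so every non-terminal return path counts the task against the workflow's parallelism budget without each branch having to remember the bookkeeping. The condensed sketch below shows only that pattern; executionContext, phase, and handle are stand-ins invented for illustration.

package main

import "fmt"

// executionContext stands in for the real ExecutionContext, which tracks
// parallelism so the workflow executor can throttle active nodes.
type executionContext struct{ parallelism int }

func (e *executionContext) IncrementParallelism() int {
	e.parallelism++
	return e.parallelism
}

type phase int

const (
	phaseRunning phase = iota
	phaseSuccess
)

func (p phase) IsTerminal() bool { return p == phaseSuccess }

// handle mirrors the deferred bookkeeping: whichever return path executes,
// the increment runs exactly once, and only for non-terminal outcomes.
func handle(eCtx *executionContext, final phase) phase {
	defer func() {
		if !final.IsTerminal() {
			fmt.Printf("Parallelism now set to [%d].\n", eCtx.IncrementParallelism())
		}
	}()
	return final
}

func main() {
	eCtx := &executionContext{}
	handle(eCtx, phaseRunning) // counts against the budget
	handle(eCtx, phaseSuccess) // terminal: no increment
}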