From 3b9f57c8a875dd75b26ab34bb4ce6ef5aa8eef95 Mon Sep 17 00:00:00 2001 From: Ketan Umare <16888709+kumare3@users.noreply.github.com> Date: Thu, 18 Mar 2021 14:04:20 -0700 Subject: [PATCH] Limit parallelism of task nodes (#232) * wip: Limit parallelism of task nodes Signed-off-by: Ketan Umare * Unit test fixes Signed-off-by: Ketan Umare * can be deleted * Setting parallelism to 10 to test * Updated logic to use task handler to mark parallelism Signed-off-by: Ketan Umare * Test for incr parallelism Signed-off-by: Ketan Umare * linter fix Signed-off-by: Ketan Umare * fixed lint error Signed-off-by: Ketan Umare * updated go.sum Signed-off-by: Ketan Umare * Update pkg/controller/nodes/executor.go Co-authored-by: Haytham Abuelfutuh Co-authored-by: Haytham Abuelfutuh Signed-off-by: Ketan Umare --- README.md | 12 +- go.sum | 3 - .../{admin.go => execution_config.go} | 13 ++ ...admin_test.go => execution_config_test.go} | 0 pkg/apis/flyteworkflow/v1alpha1/workflow.go | 12 -- .../v1alpha1/zz_generated.deepcopy.go | 1 + pkg/controller/executors/execution_context.go | 37 +++- .../executors/execution_context_test.go | 2 +- .../executors/mocks/control_flow.go | 74 ++++++++ .../executors/mocks/execution_context.go | 64 +++++++ .../nodes/dynamic/dynamic_workflow.go | 2 +- pkg/controller/nodes/executor.go | 48 ++++- pkg/controller/nodes/executor_test.go | 176 ++++++++++++++++-- .../nodes/node_exec_context_test.go | 4 +- pkg/controller/nodes/resolve_test.go | 1 + pkg/controller/nodes/task/handler.go | 4 + pkg/controller/nodes/task/handler_test.go | 12 +- pkg/controller/workflow/executor.go | 8 +- 18 files changed, 414 insertions(+), 59 deletions(-) rename pkg/apis/flyteworkflow/v1alpha1/{admin.go => execution_config.go} (53%) rename pkg/apis/flyteworkflow/v1alpha1/{admin_test.go => execution_config_test.go} (100%) create mode 100644 pkg/controller/executors/mocks/control_flow.go diff --git a/README.md b/README.md index 434068de3..562f252b1 100644 --- a/README.md +++ b/README.md @@ -2,12 +2,12 @@ Flyte Propeller =============== [![Current Release](https://img.shields.io/github/release/flyteorg/flytepropeller.svg)](https://github.com/flyteorg/flytepropeller/releases/latest) ![Master](https://github.com/flyteorg/flytepropeller/workflows/Master/badge.svg) -[![GoDoc](https://godoc.org/github.com/lyft/flytepropeller?status.svg)](https://pkg.go.dev/mod/github.com/lyft/flytepropeller) +[![GoDoc](https://godoc.org/github.com/flyteorg/flytepropeller?status.svg)](https://pkg.go.dev/mod/github.com/flyteorg/flytepropeller) [![License](https://img.shields.io/badge/LICENSE-Apache2.0-ff69b4.svg)](http://www.apache.org/licenses/LICENSE-2.0.html) [![CodeCoverage](https://img.shields.io/codecov/c/github/flyteorg/flytepropeller.svg)](https://codecov.io/gh/flyteorg/flytepropeller) -[![Go Report Card](https://goreportcard.com/badge/github.com/lyft/flytepropeller)](https://goreportcard.com/report/github.com/lyft/flytepropeller) -![Commit activity](https://img.shields.io/github/commit-activity/w/lyft/flytepropeller.svg?style=plastic) -![Commit since last release](https://img.shields.io/github/commits-since/lyft/flytepropeller/latest.svg?style=plastic) +[![Go Report Card](https://goreportcard.com/badge/github.com/flyteorg/flytepropeller)](https://goreportcard.com/report/github.com/flyteorg/flytepropeller) +![Commit activity](https://img.shields.io/github/commit-activity/w/flyteorg/flytepropeller.svg?style=plastic) +![Commit since last release](https://img.shields.io/github/commits-since/flyteorg/flytepropeller/latest.svg?style=plastic) Kubernetes operator to executes Flyte graphs natively on kubernetes @@ -89,7 +89,7 @@ To delete a specific workflow $ kubectl-flyte delete --namespace flytekit-development flytekit-development-ff806e973581f4508bf1 ``` -To delete all completed workflows - they have to be either success/failed with a special isCompleted label set on them. The Label is set `here ` +To delete all completed workflows - they have to be either success/failed with a special isCompleted label set on them. The Label is set `here ` ``` $ kubectl-flyte delete --namespace flytekit-development --all-completed @@ -97,7 +97,7 @@ To delete all completed workflows - they have to be either success/failed with a Running propeller locally ------------------------- -use the config.yaml in root found `here `. Cd into this folder and then run +use the config.yaml in root found `here `. Cd into this folder and then run ``` $ flytepropeller --logtostderr diff --git a/go.sum b/go.sum index d2abe147f..543224836 100644 --- a/go.sum +++ b/go.sum @@ -72,7 +72,6 @@ github.com/Azure/go-autorest/logger v0.2.0/go.mod h1:T9E3cAhj2VqvPOtCYAvby9aBXkZ github.com/Azure/go-autorest/tracing v0.5.0/go.mod h1:r/s2XiOKccPW3HrqB+W0TQzfbtp2fGCgRFtBroKn4Dk= github.com/Azure/go-autorest/tracing v0.6.0 h1:TYi4+3m5t6K48TGI9AUdb+IzbnSxvnvUMfuitfgcfuo= github.com/Azure/go-autorest/tracing v0.6.0/go.mod h1:+vhtPC754Xsa23ID7GlGsrdKBpUA79WCAKPPZVC2DeU= -github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/DiSiqueira/GoTree v1.0.1-0.20180907134536-53a8e837f295 h1:xJ0dAkuxJXfwdH7IaSzBEbSQxEDz36YUmt7+CB4zoNA= @@ -231,7 +230,6 @@ github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5Kwzbycv github.com/fatih/color v1.10.0 h1:s36xzo75JdqLaaWoiEHk767eHiwo0598uUxyfiPkDsg= github.com/fatih/color v1.10.0/go.mod h1:ELkj/draVOlAH/xkhN6mQ50Qd0MPOk5AAr3maGEBuJM= github.com/fatih/structtag v1.2.0/go.mod h1:mBJUNpUnHmRKrKlQQlmCrh5PuhftFbNv8Ys4/aAZl94= -github.com/flyteorg/flyteidl v0.18.15 h1:sXrlwTRaRjQsXYMNrY/S930SKdKtu4XnpNFEu8I4tn4= github.com/flyteorg/flyteidl v0.18.15/go.mod h1:b5Fq4Z8a5b0mF6pEwTd48ufvikUGVkWSjZiMT0ZtqKI= github.com/flyteorg/flyteidl v0.18.20 h1:OGOb2FOHWL363Qp8uzbJeFbQBKYPT30+afv+8BnBlGs= github.com/flyteorg/flyteidl v0.18.20/go.mod h1:b5Fq4Z8a5b0mF6pEwTd48ufvikUGVkWSjZiMT0ZtqKI= @@ -1231,7 +1229,6 @@ honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWh honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= honnef.co/go/tools v0.0.1-2020.1.3/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k= -honnef.co/go/tools v0.0.1-2020.1.4 h1:UoveltGrhghAA7ePc+e+QYDHXrBps2PqFZiHkGR/xK8= honnef.co/go/tools v0.0.1-2020.1.4/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k= k8s.io/api v0.0.0-20210217171935-8e2decd92398/go.mod h1:60tmSUpHxGPFerNHbo/ayI2lKxvtrhbxFyXuEIWJd78= k8s.io/api v0.18.2/go.mod h1:SJCWI7OLzhZSvbY7U8zwNl9UA4o1fizoug34OV/2r78= diff --git a/pkg/apis/flyteworkflow/v1alpha1/admin.go b/pkg/apis/flyteworkflow/v1alpha1/execution_config.go similarity index 53% rename from pkg/apis/flyteworkflow/v1alpha1/admin.go rename to pkg/apis/flyteworkflow/v1alpha1/execution_config.go index 88077f666..6de14607a 100644 --- a/pkg/apis/flyteworkflow/v1alpha1/admin.go +++ b/pkg/apis/flyteworkflow/v1alpha1/execution_config.go @@ -13,3 +13,16 @@ type RawOutputDataConfig struct { func (in *RawOutputDataConfig) DeepCopyInto(out *RawOutputDataConfig) { *out = *in } + +// This contains workflow-execution specifications and overrides. +type ExecutionConfig struct { + // Maps individual task types to their alternate (non-default) plugin handlers by name. + TaskPluginImpls map[string]TaskPluginOverride + // Can be used to control the number of parallel nodes to run within the workflow. This is useful to achieve fairness. + MaxParallelism uint32 +} + +type TaskPluginOverride struct { + PluginIDs []string + MissingPluginBehavior admin.PluginOverride_MissingPluginBehavior +} diff --git a/pkg/apis/flyteworkflow/v1alpha1/admin_test.go b/pkg/apis/flyteworkflow/v1alpha1/execution_config_test.go similarity index 100% rename from pkg/apis/flyteworkflow/v1alpha1/admin_test.go rename to pkg/apis/flyteworkflow/v1alpha1/execution_config_test.go diff --git a/pkg/apis/flyteworkflow/v1alpha1/workflow.go b/pkg/apis/flyteworkflow/v1alpha1/workflow.go index f5a23f462..79251b709 100644 --- a/pkg/apis/flyteworkflow/v1alpha1/workflow.go +++ b/pkg/apis/flyteworkflow/v1alpha1/workflow.go @@ -9,7 +9,6 @@ import ( "k8s.io/apimachinery/pkg/types" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/golang/protobuf/jsonpb" "github.com/pkg/errors" @@ -306,14 +305,3 @@ type FlyteWorkflowList struct { metav1.ListMeta `json:"metadata"` Items []FlyteWorkflow `json:"items"` } - -// This contains workflow-execution specifications and overrides. -type ExecutionConfig struct { - // Maps individual task types to their alternate (non-default) plugin handlers by name. - TaskPluginImpls map[string]TaskPluginOverride -} - -type TaskPluginOverride struct { - PluginIDs []string - MissingPluginBehavior admin.PluginOverride_MissingPluginBehavior -} diff --git a/pkg/apis/flyteworkflow/v1alpha1/zz_generated.deepcopy.go b/pkg/apis/flyteworkflow/v1alpha1/zz_generated.deepcopy.go index 4b6dceb29..53da1d12d 100644 --- a/pkg/apis/flyteworkflow/v1alpha1/zz_generated.deepcopy.go +++ b/pkg/apis/flyteworkflow/v1alpha1/zz_generated.deepcopy.go @@ -151,6 +151,7 @@ func (in *ExecutionConfig) DeepCopyInto(out *ExecutionConfig) { (*out)[key] = *val.DeepCopy() } } + out.MaxParallelism = in.MaxParallelism return } diff --git a/pkg/controller/executors/execution_context.go b/pkg/controller/executors/execution_context.go index adc910ff2..53c0bbbd1 100644 --- a/pkg/controller/executors/execution_context.go +++ b/pkg/controller/executors/execution_context.go @@ -28,14 +28,21 @@ type ImmutableParentInfo interface { CurrentAttempt() uint32 } +type ControlFlow interface { + CurrentParallelism() uint32 + IncrementParallelism() uint32 +} + type ExecutionContext interface { ImmutableExecutionContext TaskDetailsGetter SubWorkflowGetter ParentInfoGetter + ControlFlow } type execContext struct { + ControlFlow ImmutableExecutionContext TaskDetailsGetter SubWorkflowGetter @@ -59,24 +66,40 @@ func (p *parentExecutionInfo) CurrentAttempt() uint32 { return p.currentAttempts } +type controlFlow struct { + // We could use atomic.Uint32, but this is not required for current Propeller. As every round is run in a single + // thread and using atomic will introduce memory barriers + v uint32 +} + +func (c *controlFlow) CurrentParallelism() uint32 { + return c.v +} + +func (c *controlFlow) IncrementParallelism() uint32 { + c.v = c.v + 1 + return c.v +} + func NewExecutionContextWithTasksGetter(prevExecContext ExecutionContext, taskGetter TaskDetailsGetter) ExecutionContext { - return NewExecutionContext(prevExecContext, taskGetter, prevExecContext, prevExecContext.GetParentInfo()) + return NewExecutionContext(prevExecContext, taskGetter, prevExecContext, prevExecContext.GetParentInfo(), prevExecContext) } func NewExecutionContextWithWorkflowGetter(prevExecContext ExecutionContext, getter SubWorkflowGetter) ExecutionContext { - return NewExecutionContext(prevExecContext, prevExecContext, getter, prevExecContext.GetParentInfo()) + return NewExecutionContext(prevExecContext, prevExecContext, getter, prevExecContext.GetParentInfo(), prevExecContext) } func NewExecutionContextWithParentInfo(prevExecContext ExecutionContext, parentInfo ImmutableParentInfo) ExecutionContext { - return NewExecutionContext(prevExecContext, prevExecContext, prevExecContext, parentInfo) + return NewExecutionContext(prevExecContext, prevExecContext, prevExecContext, parentInfo, prevExecContext) } -func NewExecutionContext(immExecContext ImmutableExecutionContext, tasksGetter TaskDetailsGetter, workflowGetter SubWorkflowGetter, parentInfo ImmutableParentInfo) ExecutionContext { +func NewExecutionContext(immExecContext ImmutableExecutionContext, tasksGetter TaskDetailsGetter, workflowGetter SubWorkflowGetter, parentInfo ImmutableParentInfo, flow ControlFlow) ExecutionContext { return execContext{ ImmutableExecutionContext: immExecContext, TaskDetailsGetter: tasksGetter, SubWorkflowGetter: workflowGetter, parentInfo: parentInfo, + ControlFlow: flow, } } @@ -86,3 +109,9 @@ func NewParentInfo(uniqueID string, currentAttempts uint32) ImmutableParentInfo uniqueID: uniqueID, } } + +func InitializeControlFlow() ControlFlow { + return &controlFlow{ + v: 0, + } +} diff --git a/pkg/controller/executors/execution_context_test.go b/pkg/controller/executors/execution_context_test.go index 22dcd1764..363905707 100644 --- a/pkg/controller/executors/execution_context_test.go +++ b/pkg/controller/executors/execution_context_test.go @@ -28,7 +28,7 @@ func TestExecutionContext(t *testing.T) { subWfGetter := subWfGetter{} immutableParentInfo := immutableParentInfo{} - ec := NewExecutionContext(eCtx, taskGetter, subWfGetter, immutableParentInfo) + ec := NewExecutionContext(eCtx, taskGetter, subWfGetter, immutableParentInfo, InitializeControlFlow()) assert.NotNil(t, ec) typed := ec.(execContext) assert.Equal(t, typed.ImmutableExecutionContext, eCtx) diff --git a/pkg/controller/executors/mocks/control_flow.go b/pkg/controller/executors/mocks/control_flow.go new file mode 100644 index 000000000..8d9bfff64 --- /dev/null +++ b/pkg/controller/executors/mocks/control_flow.go @@ -0,0 +1,74 @@ +// Code generated by mockery v1.0.1. DO NOT EDIT. + +package mocks + +import mock "github.com/stretchr/testify/mock" + +// ControlFlow is an autogenerated mock type for the ControlFlow type +type ControlFlow struct { + mock.Mock +} + +type ControlFlow_CurrentParallelism struct { + *mock.Call +} + +func (_m ControlFlow_CurrentParallelism) Return(_a0 uint32) *ControlFlow_CurrentParallelism { + return &ControlFlow_CurrentParallelism{Call: _m.Call.Return(_a0)} +} + +func (_m *ControlFlow) OnCurrentParallelism() *ControlFlow_CurrentParallelism { + c := _m.On("CurrentParallelism") + return &ControlFlow_CurrentParallelism{Call: c} +} + +func (_m *ControlFlow) OnCurrentParallelismMatch(matchers ...interface{}) *ControlFlow_CurrentParallelism { + c := _m.On("CurrentParallelism", matchers...) + return &ControlFlow_CurrentParallelism{Call: c} +} + +// CurrentParallelism provides a mock function with given fields: +func (_m *ControlFlow) CurrentParallelism() uint32 { + ret := _m.Called() + + var r0 uint32 + if rf, ok := ret.Get(0).(func() uint32); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(uint32) + } + + return r0 +} + +type ControlFlow_IncrementParallelism struct { + *mock.Call +} + +func (_m ControlFlow_IncrementParallelism) Return(_a0 uint32) *ControlFlow_IncrementParallelism { + return &ControlFlow_IncrementParallelism{Call: _m.Call.Return(_a0)} +} + +func (_m *ControlFlow) OnIncrementParallelism() *ControlFlow_IncrementParallelism { + c := _m.On("IncrementParallelism") + return &ControlFlow_IncrementParallelism{Call: c} +} + +func (_m *ControlFlow) OnIncrementParallelismMatch(matchers ...interface{}) *ControlFlow_IncrementParallelism { + c := _m.On("IncrementParallelism", matchers...) + return &ControlFlow_IncrementParallelism{Call: c} +} + +// IncrementParallelism provides a mock function with given fields: +func (_m *ControlFlow) IncrementParallelism() uint32 { + ret := _m.Called() + + var r0 uint32 + if rf, ok := ret.Get(0).(func() uint32); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(uint32) + } + + return r0 +} diff --git a/pkg/controller/executors/mocks/execution_context.go b/pkg/controller/executors/mocks/execution_context.go index d56ee4696..8772a229b 100644 --- a/pkg/controller/executors/mocks/execution_context.go +++ b/pkg/controller/executors/mocks/execution_context.go @@ -18,6 +18,38 @@ type ExecutionContext struct { mock.Mock } +type ExecutionContext_CurrentParallelism struct { + *mock.Call +} + +func (_m ExecutionContext_CurrentParallelism) Return(_a0 uint32) *ExecutionContext_CurrentParallelism { + return &ExecutionContext_CurrentParallelism{Call: _m.Call.Return(_a0)} +} + +func (_m *ExecutionContext) OnCurrentParallelism() *ExecutionContext_CurrentParallelism { + c := _m.On("CurrentParallelism") + return &ExecutionContext_CurrentParallelism{Call: c} +} + +func (_m *ExecutionContext) OnCurrentParallelismMatch(matchers ...interface{}) *ExecutionContext_CurrentParallelism { + c := _m.On("CurrentParallelism", matchers...) + return &ExecutionContext_CurrentParallelism{Call: c} +} + +// CurrentParallelism provides a mock function with given fields: +func (_m *ExecutionContext) CurrentParallelism() uint32 { + ret := _m.Called() + + var r0 uint32 + if rf, ok := ret.Get(0).(func() uint32); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(uint32) + } + + return r0 +} + type ExecutionContext_FindSubWorkflow struct { *mock.Call } @@ -579,6 +611,38 @@ func (_m *ExecutionContext) GetTask(id string) (v1alpha1.ExecutableTask, error) return r0, r1 } +type ExecutionContext_IncrementParallelism struct { + *mock.Call +} + +func (_m ExecutionContext_IncrementParallelism) Return(_a0 uint32) *ExecutionContext_IncrementParallelism { + return &ExecutionContext_IncrementParallelism{Call: _m.Call.Return(_a0)} +} + +func (_m *ExecutionContext) OnIncrementParallelism() *ExecutionContext_IncrementParallelism { + c := _m.On("IncrementParallelism") + return &ExecutionContext_IncrementParallelism{Call: c} +} + +func (_m *ExecutionContext) OnIncrementParallelismMatch(matchers ...interface{}) *ExecutionContext_IncrementParallelism { + c := _m.On("IncrementParallelism", matchers...) + return &ExecutionContext_IncrementParallelism{Call: c} +} + +// IncrementParallelism provides a mock function with given fields: +func (_m *ExecutionContext) IncrementParallelism() uint32 { + ret := _m.Called() + + var r0 uint32 + if rf, ok := ret.Get(0).(func() uint32); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(uint32) + } + + return r0 +} + type ExecutionContext_IsInterruptible struct { *mock.Call } diff --git a/pkg/controller/nodes/dynamic/dynamic_workflow.go b/pkg/controller/nodes/dynamic/dynamic_workflow.go index a65d121d3..852db85ec 100644 --- a/pkg/controller/nodes/dynamic/dynamic_workflow.go +++ b/pkg/controller/nodes/dynamic/dynamic_workflow.go @@ -207,7 +207,7 @@ func (d dynamicNodeTaskNodeHandler) buildContextualDynamicWorkflow(ctx context.C return dynamicWorkflowContext{ isDynamic: true, subWorkflow: dynamicWf, - execContext: executors.NewExecutionContext(nCtx.ExecutionContext(), dynamicWf, dynamicWf, newParentInfo), + execContext: executors.NewExecutionContext(nCtx.ExecutionContext(), dynamicWf, dynamicWf, newParentInfo, nCtx.ExecutionContext()), nodeLookup: executors.NewNodeLookup(dynamicWf, dynamicNodeStatus), }, nil } diff --git a/pkg/controller/nodes/executor.go b/pkg/controller/nodes/executor.go index 33dcd2e7b..07d62fa92 100644 --- a/pkg/controller/nodes/executor.go +++ b/pkg/controller/nodes/executor.go @@ -1,3 +1,19 @@ +// Core Nodes Executor implementation +// This module implements the core Nodes executor. +// This executor is the starting point for executing any node in the workflow. Since Nodes in a workflow are composable, +// i.e., one node may contain other nodes, the Node Handler is recursive in nature. +// This executor handles the core logic for all nodes, but specific logic for handling different kinds of nodes is delegated +// to the respective node handlers +// +// Available node handlers are +// - Task: Arguably the most important handler as it handles all tasks. These include all plugins. The goal of the workflow is +// is to run tasks, thus every workflow will contain atleast one TaskNode (except for the case, where the workflow +// is purely a meta-workflow and can run other workflows +// - SubWorkflow: This is one of the most important handlers. It can executes Workflows that are nested inside a workflow +// - DynamicTask Handler: This is just a decorator on the Task Handler. It handles cases, in which the Task returns a futures +// file. Every Task is actually executed through the DynamicTaskHandler +// - Branch Handler: This handler is used to execute branches +// - Start & End Node handler: these are nominal handlers for the start and end node and do no really carry a lot of logic package nodes import ( @@ -62,6 +78,7 @@ type nodeMetrics struct { NodeInputGatherLatency labeled.StopWatch } +// Implements the executors.Node interface type nodeExecutor struct { nodeHandlerFactory HandlerFactory enqueueWorkflow v1alpha1.EnqueueWorkflow @@ -137,7 +154,7 @@ func (c *nodeExecutor) preExecute(ctx context.Context, dag executors.DAGStructur if predicatePhase == PredicatePhaseReady { // TODO: Performance problem, we maybe in a retry loop and do not need to resolve the inputs again. - // For now we will do this. + // For now we will do this node := nCtx.Node() nodeStatus := nCtx.NodeStatus() dataDir := nodeStatus.GetDataDir() @@ -182,7 +199,7 @@ func (c *nodeExecutor) preExecute(ctx context.Context, dag executors.DAGStructur return handler.PhaseInfoNotReady("predecessor node not yet complete"), nil } -func (c *nodeExecutor) isTimeoutExpired(queuedAt *metav1.Time, timeout time.Duration) bool { +func isTimeoutExpired(queuedAt *metav1.Time, timeout time.Duration) bool { if !queuedAt.IsZero() && timeout != 0 { deadline := queuedAt.Add(timeout) if deadline.Before(time.Now()) { @@ -224,7 +241,7 @@ func (c *nodeExecutor) execute(ctx context.Context, h handler.Node, nCtx *nodeEx if nCtx.Node().GetActiveDeadline() != nil && *nCtx.Node().GetActiveDeadline() > 0 { activeDeadline = *nCtx.Node().GetActiveDeadline() } - if c.isTimeoutExpired(nodeStatus.GetQueuedAt(), activeDeadline) { + if isTimeoutExpired(nodeStatus.GetQueuedAt(), activeDeadline) { logger.Errorf(ctx, "Node has timed out; timeout configured: %v", activeDeadline) return handler.PhaseInfoTimedOut(nil, fmt.Sprintf("task active timeout [%s] expired", activeDeadline.String())), nil } @@ -234,7 +251,7 @@ func (c *nodeExecutor) execute(ctx context.Context, h handler.Node, nCtx *nodeEx if nCtx.Node().GetExecutionDeadline() != nil && *nCtx.Node().GetExecutionDeadline() > 0 { executionDeadline = *nCtx.Node().GetExecutionDeadline() } - if c.isTimeoutExpired(nodeStatus.GetLastAttemptStartedAt(), executionDeadline) { + if isTimeoutExpired(nodeStatus.GetLastAttemptStartedAt(), executionDeadline) { logger.Errorf(ctx, "Current execution for the node timed out; timeout configured: %v", executionDeadline) executionErr := &core.ExecutionError{Code: "TimeoutExpired", Message: fmt.Sprintf("task execution timeout [%s] expired", executionDeadline.String()), Kind: core.ExecutionError_USER} phase = handler.PhaseInfoRetryableFailureErr(executionErr, nil) @@ -659,11 +676,32 @@ func (c *nodeExecutor) RecursiveNodeHandler(ctx context.Context, execContext exe // This is an optimization to avoid creating the nodeContext object in case the node has already been looked at. // If the overhead was zero, we would just do the isDirtyCheck after the nodeContext is created - nodeStatus := nl.GetNodeExecutionStatus(ctx, currentNode.GetID()) if nodeStatus.IsDirty() { return executors.NodeStatusRunning, nil } + // Now if the node is of type task, then let us check if we are within the parallelism limit, only if the node + // has been queued already + if currentNode.GetKind() == v1alpha1.NodeKindTask && nodeStatus.GetPhase() == v1alpha1.NodePhaseQueued { + maxParallelism := execContext.GetExecutionConfig().MaxParallelism + if maxParallelism > 0 { + // If we are queued, let us see if we can proceed within the node parallelism bounds + if execContext.CurrentParallelism() >= maxParallelism { + logger.Infof(ctx, "Maximum Parallelism for task nodes achieved [%d] >= Max [%d], Round will be short-circuited.", execContext.CurrentParallelism(), maxParallelism) + return executors.NodeStatusRunning, nil + } + // We know that Propeller goes through each workflow in a single thread, thus every node is really processed + // sequentially. So, we can continue - now that we know we are under the parallelism limits and increment the + // parallelism if the node, enters a running state + logger.Debugf(ctx, "Parallelism criteria not met, Current [%d], Max [%d]", execContext.CurrentParallelism(), maxParallelism) + } else { + logger.Debugf(ctx, "Parallelism control disabled") + } + } else { + logger.Debugf(ctx, "NodeKind: %s in status [%s]. Parallelism control is not applicable. Current Parallelism [%d]", + currentNode.GetKind().String(), nodeStatus.GetPhase().String(), execContext.CurrentParallelism()) + } + nCtx, err := c.newNodeExecContextDefault(ctx, currentNode.GetID(), execContext, nl) if err != nil { // NodeExecution creation failure is a permanent fail / system error. diff --git a/pkg/controller/nodes/executor_test.go b/pkg/controller/nodes/executor_test.go index c598f72bb..e5bad53ee 100644 --- a/pkg/controller/nodes/executor_test.go +++ b/pkg/controller/nodes/executor_test.go @@ -239,7 +239,7 @@ func TestNodeExecutor_RecursiveNodeHandler_RecurseStartNodes(t *testing.T) { hf.On("GetHandler", v1alpha1.NodeKindStart).Return(h, nil) mockWf, startNode, startNodeStatus := createStartNodeWf(test.currentNodePhase, 0) - executionContext := executors.NewExecutionContext(mockWf, nil, nil, nil) + executionContext := executors.NewExecutionContext(mockWf, nil, nil, nil, executors.InitializeControlFlow()) s, err := exec.RecursiveNodeHandler(ctx, executionContext, mockWf, mockWf, startNode) if test.expectedError { assert.Error(t, err) @@ -328,7 +328,7 @@ func TestNodeExecutor_RecursiveNodeHandler_RecurseEndNode(t *testing.T) { hf.On("GetHandler", v1alpha1.NodeKindEnd).Return(h, nil) mockWf, mockNode, mockNodeStatus := createSingleNodeWf(test.parentNodePhase, 0) - execContext := executors.NewExecutionContext(mockWf, nil, nil, nil) + execContext := executors.NewExecutionContext(mockWf, nil, nil, nil, executors.InitializeControlFlow()) s, err := exec.RecursiveNodeHandler(ctx, execContext, mockWf, mockWf, mockNode) if test.expectedError { assert.Error(t, err) @@ -392,7 +392,7 @@ func TestNodeExecutor_RecursiveNodeHandler_RecurseEndNode(t *testing.T) { RawOutputDataConfig: &admin.RawOutputDataConfig{OutputLocationPrefix: ""}, }, } - executionContext := executors.NewExecutionContext(w, nil, nil, nil) + executionContext := executors.NewExecutionContext(w, nil, nil, nil, executors.InitializeControlFlow()) return w, executionContext, n, ns } @@ -665,7 +665,7 @@ func TestNodeExecutor_RecursiveNodeHandler_Recurse(t *testing.T) { exec := execIface.(*nodeExecutor) exec.nodeHandlerFactory = hf - execContext := executors.NewExecutionContext(mockWf, mockWf, mockWf, nil) + execContext := executors.NewExecutionContext(mockWf, mockWf, mockWf, nil, executors.InitializeControlFlow()) s, err := exec.RecursiveNodeHandler(ctx, execContext, mockWf, mockWf, startNode) if test.expectedError { assert.Error(t, err) @@ -769,7 +769,7 @@ func TestNodeExecutor_RecursiveNodeHandler_Recurse(t *testing.T) { startNode := mockWf.StartNode() startStatus := mockWf.GetNodeExecutionStatus(ctx, startNode.GetID()) assert.Equal(t, v1alpha1.NodePhaseSucceeded, startStatus.GetPhase()) - execContext := executors.NewExecutionContext(mockWf, mockWf, nil, nil) + execContext := executors.NewExecutionContext(mockWf, mockWf, nil, nil, executors.InitializeControlFlow()) s, err := exec.RecursiveNodeHandler(ctx, execContext, mockWf, mockWf, startNode) if test.expectedError { assert.Error(t, err) @@ -878,7 +878,7 @@ func TestNodeExecutor_RecursiveNodeHandler_Recurse(t *testing.T) { mockWf, _, mockNodeStatus := createSingleNodeWf(test.currentNodePhase, 1) execErr := mockNodeStatus.GetExecutionError() startNode := mockWf.StartNode() - execContext := executors.NewExecutionContext(mockWf, mockWf, nil, nil) + execContext := executors.NewExecutionContext(mockWf, mockWf, nil, nil, executors.InitializeControlFlow()) s, err := exec.RecursiveNodeHandler(ctx, execContext, mockWf, mockWf, startNode) if test.expectedError { assert.Error(t, err) @@ -925,7 +925,7 @@ func TestNodeExecutor_RecursiveNodeHandler_Recurse(t *testing.T) { mockWf, _, mockNodeStatus := createSingleNodeWf(v1alpha1.NodePhaseRunning, 0) startNode := mockWf.StartNode() - execContext := executors.NewExecutionContext(mockWf, mockWf, nil, nil) + execContext := executors.NewExecutionContext(mockWf, mockWf, nil, nil, executors.InitializeControlFlow()) s, err := exec.RecursiveNodeHandler(ctx, execContext, mockWf, mockWf, startNode) assert.NoError(t, err) @@ -956,7 +956,7 @@ func TestNodeExecutor_RecursiveNodeHandler_Recurse(t *testing.T) { mockWf, _, mockNodeStatus := createSingleNodeWf(v1alpha1.NodePhaseRunning, 1) startNode := mockWf.StartNode() - execContext := executors.NewExecutionContext(mockWf, mockWf, nil, nil) + execContext := executors.NewExecutionContext(mockWf, mockWf, nil, nil, executors.InitializeControlFlow()) s, err := exec.RecursiveNodeHandler(ctx, execContext, mockWf, mockWf, startNode) assert.NoError(t, err) assert.Equal(t, executors.NodePhasePending.String(), s.NodePhase.String()) @@ -1061,7 +1061,7 @@ func TestNodeExecutor_RecursiveNodeHandler_NoDownstream(t *testing.T) { hf.On("GetHandler", v1alpha1.NodeKindTask).Return(h, nil) mockWf, mockNode, mockNodeStatus := createSingleNodeWf(test.currentNodePhase, 1) - execContext := executors.NewExecutionContext(mockWf, nil, nil, nil) + execContext := executors.NewExecutionContext(mockWf, nil, nil, nil, executors.InitializeControlFlow()) s, err := exec.RecursiveNodeHandler(ctx, execContext, mockWf, mockWf, mockNode) if test.expectedError { assert.Error(t, err) @@ -1169,7 +1169,7 @@ func TestNodeExecutor_RecursiveNodeHandler_UpstreamNotReady(t *testing.T) { hf.On("GetHandler", v1alpha1.NodeKindTask).Return(h, nil) mockWf, mockNode, mockNodeStatus := createSingleNodeWf(test.parentNodePhase, 0) - execContext := executors.NewExecutionContext(mockWf, mockWf, nil, nil) + execContext := executors.NewExecutionContext(mockWf, mockWf, nil, nil, executors.InitializeControlFlow()) s, err := exec.RecursiveNodeHandler(ctx, execContext, mockWf, mockWf, mockNode) if test.expectedError { assert.Error(t, err) @@ -1251,6 +1251,7 @@ func TestNodeExecutor_RecursiveNodeHandler_BranchNode(t *testing.T) { eCtx.OnGetRawOutputDataConfig().Return(v1alpha1.RawOutputDataConfig{ RawOutputDataConfig: &admin.RawOutputDataConfig{OutputLocationPrefix: ""}, }) + eCtx.OnCurrentParallelism().Return(0) branchTakenNodeID := "branchTakenNode" branchTakenNode := &mocks.ExecutableNode{} @@ -1519,8 +1520,8 @@ func Test_nodeExecutor_system_error(t *testing.T) { nCtx := &nodeExecContext{node: mockNode, nsm: &nodeStateManager{nodeStatus: ns}} phaseInfo, err := c.execute(context.TODO(), h, nCtx, ns) - assert.Equal(t, handler.EPhaseRetryableFailure, phaseInfo.GetPhase()) assert.NoError(t, err) + assert.Equal(t, handler.EPhaseRetryableFailure, phaseInfo.GetPhase()) assert.Equal(t, core.ExecutionError_SYSTEM, phaseInfo.GetErr().GetKind()) } @@ -1633,15 +1634,15 @@ func TestNodeExecutionEventV0(t *testing.T) { ns.OnGetPhase().Return(v1alpha1.NodePhaseNotYetStarted) nl.OnGetNodeExecutionStatusMatch(mock.Anything, id).Return(ns) ns.OnGetParentTaskID().Return(tID) - event, err := ToNodeExecutionEvent(nID, p, inputReader, ns, v1alpha1.EventVersion0, parentInfo, n) + ev, err := ToNodeExecutionEvent(nID, p, inputReader, ns, v1alpha1.EventVersion0, parentInfo, n) assert.NoError(t, err) - assert.Equal(t, "n1", event.Id.NodeId) - assert.Equal(t, execID, event.Id.ExecutionId) - assert.Empty(t, event.SpecNodeId) - assert.Nil(t, event.ParentNodeMetadata) - assert.Equal(t, tID, event.ParentTaskMetadata.Id) - assert.Empty(t, event.NodeName) - assert.Empty(t, event.RetryGroup) + assert.Equal(t, "n1", ev.Id.NodeId) + assert.Equal(t, execID, ev.Id.ExecutionId) + assert.Empty(t, ev.SpecNodeId) + assert.Nil(t, ev.ParentNodeMetadata) + assert.Equal(t, tID, ev.ParentTaskMetadata.Id) + assert.Empty(t, ev.NodeName) + assert.Empty(t, ev.RetryGroup) } func TestNodeExecutionEventV1(t *testing.T) { @@ -1686,3 +1687,140 @@ func TestNodeExecutionEventV1(t *testing.T) { assert.Equal(t, "name", eventOpt.NodeName) assert.Equal(t, "2", eventOpt.RetryGroup) } + +func TestNodeExecutor_RecursiveNodeHandler_ParallelismLimit(t *testing.T) { + ctx := context.Background() + enQWf := func(workflowID v1alpha1.WorkflowID) { + } + mockEventSink := events.NewMockEventSink().(*events.MockEventSink) + + store := createInmemoryDataStore(t, promutils.NewTestScope()) + + adminClient := launchplan.NewFailFastLaunchPlanExecutor() + execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, adminClient, + 10, "s3://bucket", fakeKubeClient, catalogClient, promutils.NewTestScope()) + assert.NoError(t, err) + exec := execIface.(*nodeExecutor) + + defaultNodeID := "n1" + taskID := taskID + createSingleNodeWf := func(p v1alpha1.NodePhase, maxParallelism uint32) (v1alpha1.ExecutableWorkflow, v1alpha1.ExecutableNode, v1alpha1.ExecutableNodeStatus) { + maxAttempts := 1 + n := &v1alpha1.NodeSpec{ + ID: defaultNodeID, + TaskRef: &taskID, + Kind: v1alpha1.NodeKindTask, + RetryStrategy: &v1alpha1.RetryStrategy{ + MinAttempts: &maxAttempts, + }, + } + + var err *v1alpha1.ExecutionError + if p == v1alpha1.NodePhaseFailing || p == v1alpha1.NodePhaseFailed { + err = &v1alpha1.ExecutionError{ExecutionError: &core.ExecutionError{Code: "test", Message: "test"}} + } + ns := &v1alpha1.NodeStatus{ + Phase: p, + LastAttemptStartedAt: &v1.Time{}, + Error: err, + } + + startNode := &v1alpha1.NodeSpec{ + Kind: v1alpha1.NodeKindStart, + ID: v1alpha1.StartNodeID, + } + return &v1alpha1.FlyteWorkflow{ + Tasks: map[v1alpha1.TaskID]*v1alpha1.TaskSpec{ + taskID: { + TaskTemplate: &core.TaskTemplate{}, + }, + }, + ExecutionConfig: v1alpha1.ExecutionConfig{ + MaxParallelism: maxParallelism, + }, + Status: v1alpha1.WorkflowStatus{ + NodeStatus: map[v1alpha1.NodeID]*v1alpha1.NodeStatus{ + defaultNodeID: ns, + v1alpha1.StartNodeID: { + Phase: v1alpha1.NodePhaseSucceeded, + }, + }, + DataDir: "data", + }, + WorkflowSpec: &v1alpha1.WorkflowSpec{ + ID: "wf", + Nodes: map[v1alpha1.NodeID]*v1alpha1.NodeSpec{ + defaultNodeID: n, + v1alpha1.StartNodeID: startNode, + }, + Connections: v1alpha1.Connections{ + UpstreamEdges: map[v1alpha1.NodeID][]v1alpha1.NodeID{ + defaultNodeID: {v1alpha1.StartNodeID}, + }, + DownstreamEdges: map[v1alpha1.NodeID][]v1alpha1.NodeID{ + v1alpha1.StartNodeID: {defaultNodeID}, + }, + }, + }, + DataReferenceConstructor: store, + RawOutputDataConfig: v1alpha1.RawOutputDataConfig{ + RawOutputDataConfig: &admin.RawOutputDataConfig{OutputLocationPrefix: ""}, + }, + }, n, ns + + } + + t.Run("parallelism-not-met", func(t *testing.T) { + mockWf, mockNode, _ := createSingleNodeWf(v1alpha1.NodePhaseQueued, 1) + cf := executors.InitializeControlFlow() + eCtx := executors.NewExecutionContext(mockWf, mockWf, nil, nil, cf) + + hf := &mocks2.HandlerFactory{} + exec.nodeHandlerFactory = hf + h := &nodeHandlerMocks.Node{} + h.OnHandleMatch( + mock.MatchedBy(func(ctx context.Context) bool { return true }), + mock.MatchedBy(func(o handler.NodeExecutionContext) bool { return true }), + ).Return(handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoSuccess(nil)), nil) + h.OnFinalizeRequired().Return(false) + + hf.OnGetHandler(v1alpha1.NodeKindTask).Return(h, nil) + + s, err := exec.RecursiveNodeHandler(ctx, eCtx, mockWf, mockWf, mockNode) + assert.NoError(t, err) + assert.Equal(t, s.NodePhase.String(), executors.NodePhaseSuccess.String()) + }) + + t.Run("parallelism-met", func(t *testing.T) { + mockWf, mockNode, _ := createSingleNodeWf(v1alpha1.NodePhaseQueued, 1) + cf := executors.InitializeControlFlow() + cf.IncrementParallelism() + eCtx := executors.NewExecutionContext(mockWf, mockWf, nil, nil, cf) + + s, err := exec.RecursiveNodeHandler(ctx, eCtx, mockWf, mockWf, mockNode) + assert.NoError(t, err) + assert.Equal(t, s.NodePhase.String(), executors.NodePhaseRunning.String()) + }) + + t.Run("parallelism-disabled", func(t *testing.T) { + mockWf, mockNode, _ := createSingleNodeWf(v1alpha1.NodePhaseQueued, 0) + cf := executors.InitializeControlFlow() + cf.IncrementParallelism() + eCtx := executors.NewExecutionContext(mockWf, mockWf, nil, nil, cf) + + hf := &mocks2.HandlerFactory{} + exec.nodeHandlerFactory = hf + h := &nodeHandlerMocks.Node{} + h.OnHandleMatch( + mock.MatchedBy(func(ctx context.Context) bool { return true }), + mock.MatchedBy(func(o handler.NodeExecutionContext) bool { return true }), + ).Return(handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoSuccess(nil)), nil) + h.OnFinalizeRequired().Return(false) + + hf.OnGetHandler(v1alpha1.NodeKindTask).Return(h, nil) + + s, err := exec.RecursiveNodeHandler(ctx, eCtx, mockWf, mockWf, mockNode) + assert.NoError(t, err) + assert.Equal(t, s.NodePhase.String(), executors.NodePhaseSuccess.String()) + }) +} diff --git a/pkg/controller/nodes/node_exec_context_test.go b/pkg/controller/nodes/node_exec_context_test.go index 03d9852a9..31e6dfe79 100644 --- a/pkg/controller/nodes/node_exec_context_test.go +++ b/pkg/controller/nodes/node_exec_context_test.go @@ -56,7 +56,7 @@ func Test_NodeContext(t *testing.T) { } s, _ := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope()) p := parentInfo{} - execContext := executors.NewExecutionContext(w1, nil, nil, p) + execContext := executors.NewExecutionContext(w1, nil, nil, p, nil) nCtx := newNodeExecContext(context.TODO(), s, execContext, w1, n, nil, nil, false, 0, nil, TaskReader{}, nil, nil, "s3://bucket", ioutils.NewConstantShardSelector([]string{"x"})) assert.Equal(t, "id", nCtx.NodeExecutionMetadata().GetLabels()["node-id"]) assert.Equal(t, "false", nCtx.NodeExecutionMetadata().GetLabels()["interruptible"]) @@ -111,7 +111,7 @@ func Test_NodeContextDefault(t *testing.T) { enqueueWorkflow: func(workflowID v1alpha1.WorkflowID) {}, } p := parentInfo{} - execContext := executors.NewExecutionContext(w1, w1, w1, p) + execContext := executors.NewExecutionContext(w1, w1, w1, p, nil) nodeExecContext, err := nodeExecutor.newNodeExecContextDefault(context.Background(), "node-a", execContext, nodeLookup) assert.NoError(t, err) assert.Equal(t, "s3://bucket-a", nodeExecContext.rawOutputPrefix.String()) diff --git a/pkg/controller/nodes/resolve_test.go b/pkg/controller/nodes/resolve_test.go index 95fea0d42..35daed9d5 100644 --- a/pkg/controller/nodes/resolve_test.go +++ b/pkg/controller/nodes/resolve_test.go @@ -22,6 +22,7 @@ import ( var testScope = promutils.NewScope("test") type dummyBaseWorkflow struct { + executors.ControlFlow DummyStartNode v1alpha1.ExecutableNode ID v1alpha1.WorkflowID ToNodeCb func(name v1alpha1.NodeID) ([]v1alpha1.NodeID, error) diff --git a/pkg/controller/nodes/task/handler.go b/pkg/controller/nodes/task/handler.go index 78177871a..25c0fd88a 100644 --- a/pkg/controller/nodes/task/handler.go +++ b/pkg/controller/nodes/task/handler.go @@ -623,6 +623,10 @@ func (t Handler) Handle(ctx context.Context, nCtx handler.NodeExecutionContext) return handler.UnknownTransition, err } + if !pluginTrns.pInfo.Phase().IsTerminal() { + eCtx := nCtx.ExecutionContext() + logger.Infof(ctx, "Parallelism now set to [%d].", eCtx.IncrementParallelism()) + } return pluginTrns.FinalTransition(ctx) } diff --git a/pkg/controller/nodes/task/handler_test.go b/pkg/controller/nodes/task/handler_test.go index 18be27422..d5981e4f3 100644 --- a/pkg/controller/nodes/task/handler_test.go +++ b/pkg/controller/nodes/task/handler_test.go @@ -358,7 +358,7 @@ func CreateNoopResourceManager(ctx context.Context, scope promutils.Scope) resou func Test_task_Handle_NoCatalog(t *testing.T) { - createNodeContext := func(pluginPhase pluginCore.Phase, pluginVer uint32, pluginResp fakeplugins.NextPhaseState, recorder events.TaskEventRecorder, ttype string, s *taskNodeStateHolder) *nodeMocks.NodeExecutionContext { + createNodeContext := func(pluginPhase pluginCore.Phase, pluginVer uint32, pluginResp fakeplugins.NextPhaseState, recorder events.TaskEventRecorder, ttype string, s *taskNodeStateHolder, allowIncrementParallelism bool) *nodeMocks.NodeExecutionContext { wfExecID := &core.WorkflowExecutionIdentifier{ Project: "project", Domain: "domain", @@ -448,6 +448,9 @@ func Test_task_Handle_NoCatalog(t *testing.T) { executionContext.OnGetExecutionConfig().Return(v1alpha1.ExecutionConfig{}) executionContext.OnGetEventVersion().Return(v1alpha1.EventVersion0) executionContext.OnGetParentInfo().Return(nil) + if allowIncrementParallelism { + executionContext.OnIncrementParallelism().Return(1) + } nCtx.OnExecutionContext().Return(executionContext) st := bytes.NewBuffer([]byte{}) @@ -477,6 +480,7 @@ func Test_task_Handle_NoCatalog(t *testing.T) { event bool eventPhase core.TaskExecution_Phase skipStateUpdate bool + incrParallel bool } tests := []struct { name string @@ -583,6 +587,7 @@ func Test_task_Handle_NoCatalog(t *testing.T) { handlerPhase: handler.EPhaseRunning, event: true, eventPhase: core.TaskExecution_RUNNING, + incrParallel: true, }, }, { @@ -604,6 +609,7 @@ func Test_task_Handle_NoCatalog(t *testing.T) { handlerPhase: handler.EPhaseRunning, event: false, skipStateUpdate: true, + incrParallel: true, }, }, { @@ -619,6 +625,7 @@ func Test_task_Handle_NoCatalog(t *testing.T) { handlerPhase: handler.EPhaseUndefined, event: false, wantErr: true, + incrParallel: true, }, }, } @@ -626,7 +633,7 @@ func Test_task_Handle_NoCatalog(t *testing.T) { t.Run(tt.name, func(t *testing.T) { state := &taskNodeStateHolder{} ev := &fakeBufferedTaskEventRecorder{} - nCtx := createNodeContext(tt.args.startingPluginPhase, uint32(tt.args.startingPluginPhaseVersion), tt.args.expectedState, ev, "test", state) + nCtx := createNodeContext(tt.args.startingPluginPhase, uint32(tt.args.startingPluginPhaseVersion), tt.args.expectedState, ev, "test", state, tt.want.incrParallel) c := &pluginCatalogMocks.Client{} tk := Handler{ cfg: &config.Config{MaxErrorMessageLength: 100}, @@ -998,6 +1005,7 @@ func Test_task_Handle_Barrier(t *testing.T) { executionContext.OnGetExecutionConfig().Return(v1alpha1.ExecutionConfig{}) executionContext.OnGetEventVersion().Return(v1alpha1.EventVersion0) executionContext.OnGetParentInfo().Return(nil) + executionContext.OnIncrementParallelism().Return(1) nCtx.OnExecutionContext().Return(executionContext) nCtx.OnRawOutputPrefix().Return("s3://sandbox/") diff --git a/pkg/controller/workflow/executor.go b/pkg/controller/workflow/executor.go index 206221bbe..a9df79f05 100644 --- a/pkg/controller/workflow/executor.go +++ b/pkg/controller/workflow/executor.go @@ -119,7 +119,7 @@ func (c *workflowExecutor) handleReadyWorkflow(ctx context.Context, w *v1alpha1. logger.Infof(ctx, "Setting the MetadataDir for StartNode [%v]", dataDir) nodeStatus.SetDataDir(dataDir) nodeStatus.SetOutputDir(outputDir) - execcontext := executors.NewExecutionContext(w, w, w, nil) + execcontext := executors.NewExecutionContext(w, w, w, nil, executors.InitializeControlFlow()) s, err := c.nodeExecutor.SetInputsForStartNode(ctx, execcontext, w, executors.NewNodeLookup(w, w.GetExecutionStatus()), inputs) if err != nil { return StatusReady, err @@ -139,7 +139,7 @@ func (c *workflowExecutor) handleRunningWorkflow(ctx context.Context, w *v1alpha Code: errors.IllegalStateError.String(), Message: "Start node not found"}), nil } - execcontext := executors.NewExecutionContext(w, w, w, nil) + execcontext := executors.NewExecutionContext(w, w, w, nil, executors.InitializeControlFlow()) state, err := c.nodeExecutor.RecursiveNodeHandler(ctx, execcontext, w, w, startNode) if err != nil { return StatusRunning, err @@ -166,7 +166,7 @@ func (c *workflowExecutor) handleRunningWorkflow(ctx context.Context, w *v1alpha func (c *workflowExecutor) handleFailureNode(ctx context.Context, w *v1alpha1.FlyteWorkflow) (Status, error) { execErr := executionErrorOrDefault(w.GetExecutionStatus().GetExecutionError(), w.GetExecutionStatus().GetMessage()) errorNode := w.GetOnFailureNode() - execcontext := executors.NewExecutionContext(w, w, w, nil) + execcontext := executors.NewExecutionContext(w, w, w, nil, executors.InitializeControlFlow()) state, err := c.nodeExecutor.RecursiveNodeHandler(ctx, execcontext, w, w, errorNode) if err != nil { return StatusFailureNode(execErr), err @@ -468,7 +468,7 @@ func (c *workflowExecutor) cleanupRunningNodes(ctx context.Context, w v1alpha1.E return errors.Errorf(errors.IllegalStateError, w.GetID(), "StartNode not found in running workflow?") } - execcontext := executors.NewExecutionContext(w, w, w, nil) + execcontext := executors.NewExecutionContext(w, w, w, nil, executors.InitializeControlFlow()) if err := c.nodeExecutor.AbortHandler(ctx, execcontext, w, w, startNode, reason); err != nil { return errors.Errorf(errors.CausedByError, w.GetID(), "Failed to propagate Abort for workflow. Error: %v", err) }