diff --git a/go.mod b/go.mod index 185f13fb70..139bc2e71b 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,7 @@ require ( github.com/DiSiqueira/GoTree v1.0.1-0.20180907134536-53a8e837f295 github.com/benlaurie/objecthash v0.0.0-20180202135721-d1e3d6079fc1 github.com/fatih/color v1.10.0 - github.com/flyteorg/flyteidl v0.19.5 + github.com/flyteorg/flyteidl v0.19.14 github.com/flyteorg/flyteplugins v0.5.59 github.com/flyteorg/flytestdlib v0.3.27 github.com/ghodss/yaml v1.0.0 diff --git a/go.sum b/go.sum index 9d23b11f52..3808fc414e 100644 --- a/go.sum +++ b/go.sum @@ -230,8 +230,8 @@ github.com/fatih/color v1.10.0 h1:s36xzo75JdqLaaWoiEHk767eHiwo0598uUxyfiPkDsg= github.com/fatih/color v1.10.0/go.mod h1:ELkj/draVOlAH/xkhN6mQ50Qd0MPOk5AAr3maGEBuJM= github.com/fatih/structtag v1.2.0/go.mod h1:mBJUNpUnHmRKrKlQQlmCrh5PuhftFbNv8Ys4/aAZl94= github.com/flyteorg/flyteidl v0.19.2/go.mod h1:576W2ViEyjTpT+kEVHAGbrTP3HARNUZ/eCwrNPmdx9U= -github.com/flyteorg/flyteidl v0.19.5 h1:qNhNK6mhCTuOms7zJmBtog6bLQJhBj+iScf1IlHdqeg= -github.com/flyteorg/flyteidl v0.19.5/go.mod h1:576W2ViEyjTpT+kEVHAGbrTP3HARNUZ/eCwrNPmdx9U= +github.com/flyteorg/flyteidl v0.19.14 h1:OLg2eT9uYllcfMMjEZJoXQ+2WXcrNbUxD+yaCrz2AlI= +github.com/flyteorg/flyteidl v0.19.14/go.mod h1:576W2ViEyjTpT+kEVHAGbrTP3HARNUZ/eCwrNPmdx9U= github.com/flyteorg/flyteplugins v0.5.59 h1:Uw1xlrlx5rSTpdTMwJTo7mbqHI7X7p7CFVm3473iRjo= github.com/flyteorg/flyteplugins v0.5.59/go.mod h1:nesnW7pJhXEysFQg9TnSp36ao33ie0oA/TI4sYPaeyw= github.com/flyteorg/flytestdlib v0.3.13/go.mod h1:Tz8JCECAbX6VWGwFT6cmEQ+RJpZ/6L9pswu3fzWs220= diff --git a/pkg/apis/flyteworkflow/v1alpha1/execution_config.go b/pkg/apis/flyteworkflow/v1alpha1/execution_config.go index 6de14607a5..bb6547ba2f 100644 --- a/pkg/apis/flyteworkflow/v1alpha1/execution_config.go +++ b/pkg/apis/flyteworkflow/v1alpha1/execution_config.go @@ -1,6 +1,8 @@ package v1alpha1 -import "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" +import ( + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" +) // This contains an OutputLocationPrefix. When running against AWS, this should be something of the form // s3://my-bucket, or s3://my-bucket/ A sharding string will automatically be appended to this prefix before @@ -20,6 +22,8 @@ type ExecutionConfig struct { TaskPluginImpls map[string]TaskPluginOverride // Can be used to control the number of parallel nodes to run within the workflow. This is useful to achieve fairness. MaxParallelism uint32 + // Identifies the prior workflow execution, if any, whose node results should be recovered when running this execution.
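+ // When the embedded WorkflowExecutionIdentifier is nil, recovery is skipped and every node runs from scratch (preExecute in pkg/controller/nodes/executor.go only calls attemptRecovery for a non-nil identifier).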
+ RecoveryExecution WorkflowExecutionIdentifier } type TaskPluginOverride struct { diff --git a/pkg/apis/flyteworkflow/v1alpha1/iface.go b/pkg/apis/flyteworkflow/v1alpha1/iface.go index f54c502641..bd2094982a 100644 --- a/pkg/apis/flyteworkflow/v1alpha1/iface.go +++ b/pkg/apis/flyteworkflow/v1alpha1/iface.go @@ -64,6 +64,7 @@ const ( NodePhaseTimingOut NodePhaseTimedOut NodePhaseDynamicRunning + NodePhaseRecovered ) func (p NodePhase) String() string { @@ -92,6 +93,8 @@ func (p NodePhase) String() string { return "RetryableFailure" case NodePhaseDynamicRunning: return "DynamicRunning" + case NodePhaseRecovered: + return "Recovered" } return "Unknown" diff --git a/pkg/apis/flyteworkflow/v1alpha1/node_status.go b/pkg/apis/flyteworkflow/v1alpha1/node_status.go index aaed1357a7..c720f94c14 100644 --- a/pkg/apis/flyteworkflow/v1alpha1/node_status.go +++ b/pkg/apis/flyteworkflow/v1alpha1/node_status.go @@ -397,7 +397,7 @@ func (in *NodeStatus) GetMessage() string { } func IsPhaseTerminal(phase NodePhase) bool { - return phase == NodePhaseSucceeded || phase == NodePhaseFailed || phase == NodePhaseSkipped || phase == NodePhaseTimedOut + return phase == NodePhaseSucceeded || phase == NodePhaseFailed || phase == NodePhaseSkipped || phase == NodePhaseTimedOut || phase == NodePhaseRecovered } func (in *NodeStatus) GetOrCreateTaskStatus() MutableTaskNodeStatus { @@ -576,7 +576,7 @@ func (in *NodeStatus) GetNodeExecutionStatus(ctx context.Context, id NodeID) Exe } func (in *NodeStatus) IsTerminated() bool { - return in.GetPhase() == NodePhaseFailed || in.GetPhase() == NodePhaseSkipped || in.GetPhase() == NodePhaseSucceeded + return in.GetPhase() == NodePhaseFailed || in.GetPhase() == NodePhaseSkipped || in.GetPhase() == NodePhaseSucceeded || in.GetPhase() == NodePhaseRecovered } func (in *NodeStatus) GetDataDir() DataReference { diff --git a/pkg/compiler/common/mocks/workflow_builder.go b/pkg/compiler/common/mocks/workflow_builder.go index 806bf7789d..7764002775 100644 --- a/pkg/compiler/common/mocks/workflow_builder.go +++ b/pkg/compiler/common/mocks/workflow_builder.go @@ -363,6 +363,40 @@ func (_m *WorkflowBuilder) GetNodes() common.NodeIndex { return r0 } +type WorkflowBuilder_GetOrCreateNodeBuilder struct { + *mock.Call } + +func (_m WorkflowBuilder_GetOrCreateNodeBuilder) Return(_a0 common.NodeBuilder) *WorkflowBuilder_GetOrCreateNodeBuilder { + return &WorkflowBuilder_GetOrCreateNodeBuilder{Call: _m.Call.Return(_a0)} } + +func (_m *WorkflowBuilder) OnGetOrCreateNodeBuilder(n *core.Node) *WorkflowBuilder_GetOrCreateNodeBuilder { + c := _m.On("GetOrCreateNodeBuilder", n) + return &WorkflowBuilder_GetOrCreateNodeBuilder{Call: c} } + +func (_m *WorkflowBuilder) OnGetOrCreateNodeBuilderMatch(matchers ...interface{}) *WorkflowBuilder_GetOrCreateNodeBuilder { + c := _m.On("GetOrCreateNodeBuilder", matchers...)
+ return &WorkflowBuilder_GetOrCreateNodeBuilder{Call: c} +} + +// GetOrCreateNodeBuilder provides a mock function with given fields: n +func (_m *WorkflowBuilder) GetOrCreateNodeBuilder(n *core.Node) common.NodeBuilder { + ret := _m.Called(n) + + var r0 common.NodeBuilder + if rf, ok := ret.Get(0).(func(*core.Node) common.NodeBuilder); ok { + r0 = rf(n) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(common.NodeBuilder) + } + } + + return r0 +} + type WorkflowBuilder_GetSubWorkflow struct { *mock.Call } @@ -513,40 +547,6 @@ func (_m *WorkflowBuilder) GetUpstreamNodes() common.StringAdjacencyList { return r0 } -type WorkflowBuilder_GetOrCreateNodeBuilder struct { - *mock.Call -} - -func (_m WorkflowBuilder_GetOrCreateNodeBuilder) Return(_a0 common.NodeBuilder) *WorkflowBuilder_GetOrCreateNodeBuilder { - return &WorkflowBuilder_GetOrCreateNodeBuilder{Call: _m.Call.Return(_a0)} -} - -func (_m *WorkflowBuilder) OnGetOrCreateNodeBuilder(n *core.Node) *WorkflowBuilder_GetOrCreateNodeBuilder { - c := _m.On("GetOrCreateNodeBuilder", n) - return &WorkflowBuilder_GetOrCreateNodeBuilder{Call: c} -} - -func (_m *WorkflowBuilder) OnGetOrCreateNodeBuilderMatch(matchers ...interface{}) *WorkflowBuilder_GetOrCreateNodeBuilder { - c := _m.On("GetOrCreateNodeBuilder", matchers...) - return &WorkflowBuilder_GetOrCreateNodeBuilder{Call: c} -} - -// GetOrCreateNodeBuilder provides a mock function with given fields: n -func (_m *WorkflowBuilder) GetOrCreateNodeBuilder(n *core.Node) common.NodeBuilder { - ret := _m.Called(n) - - var r0 common.NodeBuilder - if rf, ok := ret.Get(0).(func(*core.Node) common.NodeBuilder); ok { - r0 = rf(n) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(common.NodeBuilder) - } - } - - return r0 -} - // StoreCompiledSubWorkflow provides a mock function with given fields: id, compiledWorkflow func (_m *WorkflowBuilder) StoreCompiledSubWorkflow(id core.Identifier, compiledWorkflow *core.CompiledWorkflow) { _m.Called(id, compiledWorkflow) diff --git a/pkg/compiler/test/testdata/app-workflows-work-one-python-task-w-f.pb b/pkg/compiler/test/testdata/app-workflows-work-one-python-task-w-f.pb index 039e2fadd2..c36fdb9d4f 100755 Binary files a/pkg/compiler/test/testdata/app-workflows-work-one-python-task-w-f.pb and b/pkg/compiler/test/testdata/app-workflows-work-one-python-task-w-f.pb differ diff --git a/pkg/compiler/test/testdata/app-workflows-work-one-python-task-w-f.yaml b/pkg/compiler/test/testdata/app-workflows-work-one-python-task-w-f.yaml index 3aee111b74..7a66bd6461 100644 --- a/pkg/compiler/test/testdata/app-workflows-work-one-python-task-w-f.yaml +++ b/pkg/compiler/test/testdata/app-workflows-work-one-python-task-w-f.yaml @@ -24,6 +24,12 @@ tasks: value: testValue2 - key: testKey3 value: testValue3 + - key: testKey1 + value: testValue1 + - key: testKey2 + value: testValue2 + - key: testKey3 + value: testValue3 image: myflytecontainer:abc123 resources: {} id: diff --git a/pkg/compiler/test/testdata/branch/k8s/5_myapp.workflows.cereal.mycereal_2.json b/pkg/compiler/test/testdata/branch/k8s/5_myapp.workflows.cereal.mycereal_2.json index a51a60cb55..0928fa163e 100755 --- a/pkg/compiler/test/testdata/branch/k8s/5_myapp.workflows.cereal.mycereal_2.json +++ b/pkg/compiler/test/testdata/branch/k8s/5_myapp.workflows.cereal.mycereal_2.json @@ -646,6 +646,7 @@ "rawOutputDataConfig": {}, "executionConfig": { "TaskPluginImpls": null, - "MaxParallelism": 0 + "MaxParallelism": 0, + "RecoveryExecution": {} } } \ No newline at end of file diff --git 
a/pkg/compiler/test/testdata/branch/k8s/mycereal_condition_has_no_deps.json b/pkg/compiler/test/testdata/branch/k8s/mycereal_condition_has_no_deps.json index 08c19b90fe..eb8d7830a5 100755 --- a/pkg/compiler/test/testdata/branch/k8s/mycereal_condition_has_no_deps.json +++ b/pkg/compiler/test/testdata/branch/k8s/mycereal_condition_has_no_deps.json @@ -650,6 +650,7 @@ "rawOutputDataConfig": {}, "executionConfig": { "TaskPluginImpls": null, - "MaxParallelism": 0 + "MaxParallelism": 0, + "RecoveryExecution": {} } } \ No newline at end of file diff --git a/pkg/compiler/test/testdata/branch/k8s/success_1.json b/pkg/compiler/test/testdata/branch/k8s/success_1.json index 44452b2dcb..5969c44f76 100755 --- a/pkg/compiler/test/testdata/branch/k8s/success_1.json +++ b/pkg/compiler/test/testdata/branch/k8s/success_1.json @@ -387,6 +387,7 @@ "rawOutputDataConfig": {}, "executionConfig": { "TaskPluginImpls": null, - "MaxParallelism": 0 + "MaxParallelism": 0, + "RecoveryExecution": {} } } \ No newline at end of file diff --git a/pkg/compiler/test/testdata/branch/k8s/success_10_simple.json b/pkg/compiler/test/testdata/branch/k8s/success_10_simple.json index 5515aad3e0..787cd4ae43 100755 --- a/pkg/compiler/test/testdata/branch/k8s/success_10_simple.json +++ b/pkg/compiler/test/testdata/branch/k8s/success_10_simple.json @@ -593,6 +593,7 @@ "rawOutputDataConfig": {}, "executionConfig": { "TaskPluginImpls": null, - "MaxParallelism": 0 + "MaxParallelism": 0, + "RecoveryExecution": {} } } \ No newline at end of file diff --git a/pkg/compiler/test/testdata/branch/k8s/success_2.json b/pkg/compiler/test/testdata/branch/k8s/success_2.json index c27c891e12..17affe4a1d 100755 --- a/pkg/compiler/test/testdata/branch/k8s/success_2.json +++ b/pkg/compiler/test/testdata/branch/k8s/success_2.json @@ -428,6 +428,7 @@ "rawOutputDataConfig": {}, "executionConfig": { "TaskPluginImpls": null, - "MaxParallelism": 0 + "MaxParallelism": 0, + "RecoveryExecution": {} } } \ No newline at end of file diff --git a/pkg/compiler/test/testdata/branch/k8s/success_3.json b/pkg/compiler/test/testdata/branch/k8s/success_3.json index 05c6d31ad4..841e2ef2ee 100755 --- a/pkg/compiler/test/testdata/branch/k8s/success_3.json +++ b/pkg/compiler/test/testdata/branch/k8s/success_3.json @@ -411,6 +411,7 @@ "rawOutputDataConfig": {}, "executionConfig": { "TaskPluginImpls": null, - "MaxParallelism": 0 + "MaxParallelism": 0, + "RecoveryExecution": {} } } \ No newline at end of file diff --git a/pkg/compiler/test/testdata/branch/k8s/success_4.json b/pkg/compiler/test/testdata/branch/k8s/success_4.json index 67153f43e6..3c96d00aa6 100755 --- a/pkg/compiler/test/testdata/branch/k8s/success_4.json +++ b/pkg/compiler/test/testdata/branch/k8s/success_4.json @@ -497,6 +497,7 @@ "rawOutputDataConfig": {}, "executionConfig": { "TaskPluginImpls": null, - "MaxParallelism": 0 + "MaxParallelism": 0, + "RecoveryExecution": {} } } \ No newline at end of file diff --git a/pkg/compiler/test/testdata/branch/k8s/success_5.json b/pkg/compiler/test/testdata/branch/k8s/success_5.json index e894306b0d..a7b46f188f 100755 --- a/pkg/compiler/test/testdata/branch/k8s/success_5.json +++ b/pkg/compiler/test/testdata/branch/k8s/success_5.json @@ -524,6 +524,7 @@ "rawOutputDataConfig": {}, "executionConfig": { "TaskPluginImpls": null, - "MaxParallelism": 0 + "MaxParallelism": 0, + "RecoveryExecution": {} } } \ No newline at end of file diff --git a/pkg/compiler/test/testdata/branch/k8s/success_6.json b/pkg/compiler/test/testdata/branch/k8s/success_6.json index 0553ae0d4e..6d5588cc88 
100755 --- a/pkg/compiler/test/testdata/branch/k8s/success_6.json +++ b/pkg/compiler/test/testdata/branch/k8s/success_6.json @@ -356,6 +356,7 @@ "rawOutputDataConfig": {}, "executionConfig": { "TaskPluginImpls": null, - "MaxParallelism": 0 + "MaxParallelism": 0, + "RecoveryExecution": {} } } \ No newline at end of file diff --git a/pkg/compiler/test/testdata/branch/k8s/success_7_nested.json b/pkg/compiler/test/testdata/branch/k8s/success_7_nested.json index 57fefb0cce..d931758fbd 100755 --- a/pkg/compiler/test/testdata/branch/k8s/success_7_nested.json +++ b/pkg/compiler/test/testdata/branch/k8s/success_7_nested.json @@ -440,6 +440,7 @@ "rawOutputDataConfig": {}, "executionConfig": { "TaskPluginImpls": null, - "MaxParallelism": 0 + "MaxParallelism": 0, + "RecoveryExecution": {} } } \ No newline at end of file diff --git a/pkg/compiler/test/testdata/branch/k8s/success_8_nested.json b/pkg/compiler/test/testdata/branch/k8s/success_8_nested.json index 9290bf2c34..a3ee4dcbe9 100755 --- a/pkg/compiler/test/testdata/branch/k8s/success_8_nested.json +++ b/pkg/compiler/test/testdata/branch/k8s/success_8_nested.json @@ -522,6 +522,7 @@ "rawOutputDataConfig": {}, "executionConfig": { "TaskPluginImpls": null, - "MaxParallelism": 0 + "MaxParallelism": 0, + "RecoveryExecution": {} } } \ No newline at end of file diff --git a/pkg/compiler/test/testdata/branch/k8s/success_9_nested.json b/pkg/compiler/test/testdata/branch/k8s/success_9_nested.json index d88b954812..0a9bc398f2 100755 --- a/pkg/compiler/test/testdata/branch/k8s/success_9_nested.json +++ b/pkg/compiler/test/testdata/branch/k8s/success_9_nested.json @@ -545,6 +545,7 @@ "rawOutputDataConfig": {}, "executionConfig": { "TaskPluginImpls": null, - "MaxParallelism": 0 + "MaxParallelism": 0, + "RecoveryExecution": {} } } \ No newline at end of file diff --git a/pkg/controller/config/config_flags.go b/pkg/controller/config/config_flags.go index 4fbe0a2dc8..9fae047b63 100755 --- a/pkg/controller/config/config_flags.go +++ b/pkg/controller/config/config_flags.go @@ -28,6 +28,15 @@ func (Config) elemValueOrNil(v interface{}) interface{} { return v } +func (Config) mustJsonMarshal(v interface{}) string { + raw, err := json.Marshal(v) + if err != nil { + panic(err) + } + + return string(raw) +} + func (Config) mustMarshalJSON(v json.Marshaler) string { raw, err := v.MarshalJSON() if err != nil { diff --git a/pkg/controller/config/config_flags_test.go b/pkg/controller/config/config_flags_test.go index a178be1f34..5ff467b335 100755 --- a/pkg/controller/config/config_flags_test.go +++ b/pkg/controller/config/config_flags_test.go @@ -84,7 +84,7 @@ func testDecodeJson_Config(t *testing.T, val, result interface{}) { assert.NoError(t, decode_Config(val, result)) } -func testDecodeSlice_Config(t *testing.T, vStringSlice, result interface{}) { +func testDecodeRaw_Config(t *testing.T, vStringSlice, result interface{}) { assert.NoError(t, decode_Config(vStringSlice, result)) } @@ -100,14 +100,6 @@ func TestConfig_SetFlags(t *testing.T) { assert.True(t, cmdFlags.HasFlags()) t.Run("Test_kube-config", func(t *testing.T) { - t.Run("DefaultValue", func(t *testing.T) { - // Test that default value is set properly - if vString, err := cmdFlags.GetString("kube-config"); err == nil { - assert.Equal(t, string(defaultConfig.KubeConfigPath), vString) - } else { - assert.FailNow(t, err.Error()) - } - }) t.Run("Override", func(t *testing.T) { testValue := "1" @@ -122,14 +114,6 @@ func TestConfig_SetFlags(t *testing.T) { }) }) t.Run("Test_master", func(t *testing.T) { - 
t.Run("DefaultValue", func(t *testing.T) { - // Test that default value is set properly - if vString, err := cmdFlags.GetString("master"); err == nil { - assert.Equal(t, string(defaultConfig.MasterURL), vString) - } else { - assert.FailNow(t, err.Error()) - } - }) t.Run("Override", func(t *testing.T) { testValue := "1" @@ -144,14 +128,6 @@ func TestConfig_SetFlags(t *testing.T) { }) }) t.Run("Test_workers", func(t *testing.T) { - t.Run("DefaultValue", func(t *testing.T) { - // Test that default value is set properly - if vInt, err := cmdFlags.GetInt("workers"); err == nil { - assert.Equal(t, int(defaultConfig.Workers), vInt) - } else { - assert.FailNow(t, err.Error()) - } - }) t.Run("Override", func(t *testing.T) { testValue := "1" @@ -166,14 +142,6 @@ func TestConfig_SetFlags(t *testing.T) { }) }) t.Run("Test_workflow-reeval-duration", func(t *testing.T) { - t.Run("DefaultValue", func(t *testing.T) { - // Test that default value is set properly - if vString, err := cmdFlags.GetString("workflow-reeval-duration"); err == nil { - assert.Equal(t, string(defaultConfig.WorkflowReEval.String()), vString) - } else { - assert.FailNow(t, err.Error()) - } - }) t.Run("Override", func(t *testing.T) { testValue := defaultConfig.WorkflowReEval.String() @@ -188,14 +156,6 @@ func TestConfig_SetFlags(t *testing.T) { }) }) t.Run("Test_downstream-eval-duration", func(t *testing.T) { - t.Run("DefaultValue", func(t *testing.T) { - // Test that default value is set properly - if vString, err := cmdFlags.GetString("downstream-eval-duration"); err == nil { - assert.Equal(t, string(defaultConfig.DownstreamEval.String()), vString) - } else { - assert.FailNow(t, err.Error()) - } - }) t.Run("Override", func(t *testing.T) { testValue := defaultConfig.DownstreamEval.String() @@ -210,14 +170,6 @@ func TestConfig_SetFlags(t *testing.T) { }) }) t.Run("Test_limit-namespace", func(t *testing.T) { - t.Run("DefaultValue", func(t *testing.T) { - // Test that default value is set properly - if vString, err := cmdFlags.GetString("limit-namespace"); err == nil { - assert.Equal(t, string(defaultConfig.LimitNamespace), vString) - } else { - assert.FailNow(t, err.Error()) - } - }) t.Run("Override", func(t *testing.T) { testValue := "1" @@ -232,14 +184,6 @@ func TestConfig_SetFlags(t *testing.T) { }) }) t.Run("Test_prof-port", func(t *testing.T) { - t.Run("DefaultValue", func(t *testing.T) { - // Test that default value is set properly - if vString, err := cmdFlags.GetString("prof-port"); err == nil { - assert.Equal(t, string(defaultConfig.ProfilerPort.String()), vString) - } else { - assert.FailNow(t, err.Error()) - } - }) t.Run("Override", func(t *testing.T) { testValue := defaultConfig.ProfilerPort.String() @@ -254,14 +198,6 @@ func TestConfig_SetFlags(t *testing.T) { }) }) t.Run("Test_metadata-prefix", func(t *testing.T) { - t.Run("DefaultValue", func(t *testing.T) { - // Test that default value is set properly - if vString, err := cmdFlags.GetString("metadata-prefix"); err == nil { - assert.Equal(t, string(defaultConfig.MetadataPrefix), vString) - } else { - assert.FailNow(t, err.Error()) - } - }) t.Run("Override", func(t *testing.T) { testValue := "1" @@ -276,14 +212,6 @@ func TestConfig_SetFlags(t *testing.T) { }) }) t.Run("Test_rawoutput-prefix", func(t *testing.T) { - t.Run("DefaultValue", func(t *testing.T) { - // Test that default value is set properly - if vString, err := cmdFlags.GetString("rawoutput-prefix"); err == nil { - assert.Equal(t, string(defaultConfig.DefaultRawOutputPrefix), vString) - } else { - 
assert.FailNow(t, err.Error()) - } - }) t.Run("Override", func(t *testing.T) { testValue := "1" @@ -298,14 +226,6 @@ func TestConfig_SetFlags(t *testing.T) { }) }) t.Run("Test_queue.type", func(t *testing.T) { - t.Run("DefaultValue", func(t *testing.T) { - // Test that default value is set properly - if vString, err := cmdFlags.GetString("queue.type"); err == nil { - assert.Equal(t, string(defaultConfig.Queue.Type), vString) - } else { - assert.FailNow(t, err.Error()) - } - }) t.Run("Override", func(t *testing.T) { testValue := "1" @@ -320,14 +240,6 @@ func TestConfig_SetFlags(t *testing.T) { }) }) t.Run("Test_queue.queue.type", func(t *testing.T) { - t.Run("DefaultValue", func(t *testing.T) { - // Test that default value is set properly - if vString, err := cmdFlags.GetString("queue.queue.type"); err == nil { - assert.Equal(t, string(defaultConfig.Queue.Queue.Type), vString) - } else { - assert.FailNow(t, err.Error()) - } - }) t.Run("Override", func(t *testing.T) { testValue := "1" @@ -342,14 +254,6 @@ func TestConfig_SetFlags(t *testing.T) { }) }) t.Run("Test_queue.queue.base-delay", func(t *testing.T) { - t.Run("DefaultValue", func(t *testing.T) { - // Test that default value is set properly - if vString, err := cmdFlags.GetString("queue.queue.base-delay"); err == nil { - assert.Equal(t, string(defaultConfig.Queue.Queue.BaseDelay.String()), vString) - } else { - assert.FailNow(t, err.Error()) - } - }) t.Run("Override", func(t *testing.T) { testValue := defaultConfig.Queue.Queue.BaseDelay.String() @@ -364,14 +268,6 @@ func TestConfig_SetFlags(t *testing.T) { }) }) t.Run("Test_queue.queue.max-delay", func(t *testing.T) { - t.Run("DefaultValue", func(t *testing.T) { - // Test that default value is set properly - if vString, err := cmdFlags.GetString("queue.queue.max-delay"); err == nil { - assert.Equal(t, string(defaultConfig.Queue.Queue.MaxDelay.String()), vString) - } else { - assert.FailNow(t, err.Error()) - } - }) t.Run("Override", func(t *testing.T) { testValue := defaultConfig.Queue.Queue.MaxDelay.String() @@ -386,14 +282,6 @@ func TestConfig_SetFlags(t *testing.T) { }) }) t.Run("Test_queue.queue.rate", func(t *testing.T) { - t.Run("DefaultValue", func(t *testing.T) { - // Test that default value is set properly - if vInt64, err := cmdFlags.GetInt64("queue.queue.rate"); err == nil { - assert.Equal(t, int64(defaultConfig.Queue.Queue.Rate), vInt64) - } else { - assert.FailNow(t, err.Error()) - } - }) t.Run("Override", func(t *testing.T) { testValue := "1" @@ -408,14 +296,6 @@ func TestConfig_SetFlags(t *testing.T) { }) }) t.Run("Test_queue.queue.capacity", func(t *testing.T) { - t.Run("DefaultValue", func(t *testing.T) { - // Test that default value is set properly - if vInt, err := cmdFlags.GetInt("queue.queue.capacity"); err == nil { - assert.Equal(t, int(defaultConfig.Queue.Queue.Capacity), vInt) - } else { - assert.FailNow(t, err.Error()) - } - }) t.Run("Override", func(t *testing.T) { testValue := "1" @@ -430,14 +310,6 @@ func TestConfig_SetFlags(t *testing.T) { }) }) t.Run("Test_queue.sub-queue.type", func(t *testing.T) { - t.Run("DefaultValue", func(t *testing.T) { - // Test that default value is set properly - if vString, err := cmdFlags.GetString("queue.sub-queue.type"); err == nil { - assert.Equal(t, string(defaultConfig.Queue.Sub.Type), vString) - } else { - assert.FailNow(t, err.Error()) - } - }) t.Run("Override", func(t *testing.T) { testValue := "1" @@ -452,14 +324,6 @@ func TestConfig_SetFlags(t *testing.T) { }) }) t.Run("Test_queue.sub-queue.base-delay", func(t 
*testing.T) { - t.Run("DefaultValue", func(t *testing.T) { - // Test that default value is set properly - if vString, err := cmdFlags.GetString("queue.sub-queue.base-delay"); err == nil { - assert.Equal(t, string(defaultConfig.Queue.Sub.BaseDelay.String()), vString) - } else { - assert.FailNow(t, err.Error()) - } - }) t.Run("Override", func(t *testing.T) { testValue := defaultConfig.Queue.Sub.BaseDelay.String() @@ -474,14 +338,6 @@ func TestConfig_SetFlags(t *testing.T) { }) }) t.Run("Test_queue.sub-queue.max-delay", func(t *testing.T) { - t.Run("DefaultValue", func(t *testing.T) { - // Test that default value is set properly - if vString, err := cmdFlags.GetString("queue.sub-queue.max-delay"); err == nil { - assert.Equal(t, string(defaultConfig.Queue.Sub.MaxDelay.String()), vString) - } else { - assert.FailNow(t, err.Error()) - } - }) t.Run("Override", func(t *testing.T) { testValue := defaultConfig.Queue.Sub.MaxDelay.String() @@ -496,14 +352,6 @@ func TestConfig_SetFlags(t *testing.T) { }) }) t.Run("Test_queue.sub-queue.rate", func(t *testing.T) { - t.Run("DefaultValue", func(t *testing.T) { - // Test that default value is set properly - if vInt64, err := cmdFlags.GetInt64("queue.sub-queue.rate"); err == nil { - assert.Equal(t, int64(defaultConfig.Queue.Sub.Rate), vInt64) - } else { - assert.FailNow(t, err.Error()) - } - }) t.Run("Override", func(t *testing.T) { testValue := "1" @@ -518,14 +366,6 @@ func TestConfig_SetFlags(t *testing.T) { }) }) t.Run("Test_queue.sub-queue.capacity", func(t *testing.T) { - t.Run("DefaultValue", func(t *testing.T) { - // Test that default value is set properly - if vInt, err := cmdFlags.GetInt("queue.sub-queue.capacity"); err == nil { - assert.Equal(t, int(defaultConfig.Queue.Sub.Capacity), vInt) - } else { - assert.FailNow(t, err.Error()) - } - }) t.Run("Override", func(t *testing.T) { testValue := "1" @@ -540,14 +380,6 @@ func TestConfig_SetFlags(t *testing.T) { }) }) t.Run("Test_queue.batching-interval", func(t *testing.T) { - t.Run("DefaultValue", func(t *testing.T) { - // Test that default value is set properly - if vString, err := cmdFlags.GetString("queue.batching-interval"); err == nil { - assert.Equal(t, string(defaultConfig.Queue.BatchingInterval.String()), vString) - } else { - assert.FailNow(t, err.Error()) - } - }) t.Run("Override", func(t *testing.T) { testValue := defaultConfig.Queue.BatchingInterval.String() @@ -562,14 +394,6 @@ func TestConfig_SetFlags(t *testing.T) { }) }) t.Run("Test_queue.batch-size", func(t *testing.T) { - t.Run("DefaultValue", func(t *testing.T) { - // Test that default value is set properly - if vInt, err := cmdFlags.GetInt("queue.batch-size"); err == nil { - assert.Equal(t, int(defaultConfig.Queue.BatchSize), vInt) - } else { - assert.FailNow(t, err.Error()) - } - }) t.Run("Override", func(t *testing.T) { testValue := "1" @@ -584,14 +408,6 @@ func TestConfig_SetFlags(t *testing.T) { }) }) t.Run("Test_metrics-prefix", func(t *testing.T) { - t.Run("DefaultValue", func(t *testing.T) { - // Test that default value is set properly - if vString, err := cmdFlags.GetString("metrics-prefix"); err == nil { - assert.Equal(t, string(defaultConfig.MetricsPrefix), vString) - } else { - assert.FailNow(t, err.Error()) - } - }) t.Run("Override", func(t *testing.T) { testValue := "1" @@ -606,14 +422,6 @@ func TestConfig_SetFlags(t *testing.T) { }) }) t.Run("Test_enable-admin-launcher", func(t *testing.T) { - t.Run("DefaultValue", func(t *testing.T) { - // Test that default value is set properly - if vBool, err := 
cmdFlags.GetBool("enable-admin-launcher"); err == nil { - assert.Equal(t, bool(defaultConfig.EnableAdminLauncher), vBool) - } else { - assert.FailNow(t, err.Error()) - } - }) t.Run("Override", func(t *testing.T) { testValue := "1" @@ -628,14 +436,6 @@ func TestConfig_SetFlags(t *testing.T) { }) }) t.Run("Test_max-workflow-retries", func(t *testing.T) { - t.Run("DefaultValue", func(t *testing.T) { - // Test that default value is set properly - if vInt, err := cmdFlags.GetInt("max-workflow-retries"); err == nil { - assert.Equal(t, int(defaultConfig.MaxWorkflowRetries), vInt) - } else { - assert.FailNow(t, err.Error()) - } - }) t.Run("Override", func(t *testing.T) { testValue := "1" @@ -650,14 +450,6 @@ func TestConfig_SetFlags(t *testing.T) { }) }) t.Run("Test_max-ttl-hours", func(t *testing.T) { - t.Run("DefaultValue", func(t *testing.T) { - // Test that default value is set properly - if vInt, err := cmdFlags.GetInt("max-ttl-hours"); err == nil { - assert.Equal(t, int(defaultConfig.MaxTTLInHours), vInt) - } else { - assert.FailNow(t, err.Error()) - } - }) t.Run("Override", func(t *testing.T) { testValue := "1" @@ -672,14 +464,6 @@ func TestConfig_SetFlags(t *testing.T) { }) }) t.Run("Test_gc-interval", func(t *testing.T) { - t.Run("DefaultValue", func(t *testing.T) { - // Test that default value is set properly - if vString, err := cmdFlags.GetString("gc-interval"); err == nil { - assert.Equal(t, string(defaultConfig.GCInterval.String()), vString) - } else { - assert.FailNow(t, err.Error()) - } - }) t.Run("Override", func(t *testing.T) { testValue := defaultConfig.GCInterval.String() @@ -694,14 +478,6 @@ func TestConfig_SetFlags(t *testing.T) { }) }) t.Run("Test_leader-election.enabled", func(t *testing.T) { - t.Run("DefaultValue", func(t *testing.T) { - // Test that default value is set properly - if vBool, err := cmdFlags.GetBool("leader-election.enabled"); err == nil { - assert.Equal(t, bool(defaultConfig.LeaderElection.Enabled), vBool) - } else { - assert.FailNow(t, err.Error()) - } - }) t.Run("Override", func(t *testing.T) { testValue := "1" @@ -716,14 +492,6 @@ func TestConfig_SetFlags(t *testing.T) { }) }) t.Run("Test_leader-election.lock-config-map.Namespace", func(t *testing.T) { - t.Run("DefaultValue", func(t *testing.T) { - // Test that default value is set properly - if vString, err := cmdFlags.GetString("leader-election.lock-config-map.Namespace"); err == nil { - assert.Equal(t, string(defaultConfig.LeaderElection.LockConfigMap.Namespace), vString) - } else { - assert.FailNow(t, err.Error()) - } - }) t.Run("Override", func(t *testing.T) { testValue := "1" @@ -738,14 +506,6 @@ func TestConfig_SetFlags(t *testing.T) { }) }) t.Run("Test_leader-election.lock-config-map.Name", func(t *testing.T) { - t.Run("DefaultValue", func(t *testing.T) { - // Test that default value is set properly - if vString, err := cmdFlags.GetString("leader-election.lock-config-map.Name"); err == nil { - assert.Equal(t, string(defaultConfig.LeaderElection.LockConfigMap.Name), vString) - } else { - assert.FailNow(t, err.Error()) - } - }) t.Run("Override", func(t *testing.T) { testValue := "1" @@ -760,14 +520,6 @@ func TestConfig_SetFlags(t *testing.T) { }) }) t.Run("Test_leader-election.lease-duration", func(t *testing.T) { - t.Run("DefaultValue", func(t *testing.T) { - // Test that default value is set properly - if vString, err := cmdFlags.GetString("leader-election.lease-duration"); err == nil { - assert.Equal(t, string(defaultConfig.LeaderElection.LeaseDuration.String()), vString) - } else { - 
assert.FailNow(t, err.Error()) - } - }) t.Run("Override", func(t *testing.T) { testValue := defaultConfig.LeaderElection.LeaseDuration.String() @@ -782,14 +534,6 @@ func TestConfig_SetFlags(t *testing.T) { }) }) t.Run("Test_leader-election.renew-deadline", func(t *testing.T) { - t.Run("DefaultValue", func(t *testing.T) { - // Test that default value is set properly - if vString, err := cmdFlags.GetString("leader-election.renew-deadline"); err == nil { - assert.Equal(t, string(defaultConfig.LeaderElection.RenewDeadline.String()), vString) - } else { - assert.FailNow(t, err.Error()) - } - }) t.Run("Override", func(t *testing.T) { testValue := defaultConfig.LeaderElection.RenewDeadline.String() @@ -804,14 +548,6 @@ func TestConfig_SetFlags(t *testing.T) { }) }) t.Run("Test_leader-election.retry-period", func(t *testing.T) { - t.Run("DefaultValue", func(t *testing.T) { - // Test that default value is set properly - if vString, err := cmdFlags.GetString("leader-election.retry-period"); err == nil { - assert.Equal(t, string(defaultConfig.LeaderElection.RetryPeriod.String()), vString) - } else { - assert.FailNow(t, err.Error()) - } - }) t.Run("Override", func(t *testing.T) { testValue := defaultConfig.LeaderElection.RetryPeriod.String() @@ -826,14 +562,6 @@ func TestConfig_SetFlags(t *testing.T) { }) }) t.Run("Test_publish-k8s-events", func(t *testing.T) { - t.Run("DefaultValue", func(t *testing.T) { - // Test that default value is set properly - if vBool, err := cmdFlags.GetBool("publish-k8s-events"); err == nil { - assert.Equal(t, bool(defaultConfig.PublishK8sEvents), vBool) - } else { - assert.FailNow(t, err.Error()) - } - }) t.Run("Override", func(t *testing.T) { testValue := "1" @@ -848,14 +576,6 @@ func TestConfig_SetFlags(t *testing.T) { }) }) t.Run("Test_max-output-size-bytes", func(t *testing.T) { - t.Run("DefaultValue", func(t *testing.T) { - // Test that default value is set properly - if vInt64, err := cmdFlags.GetInt64("max-output-size-bytes"); err == nil { - assert.Equal(t, int64(defaultConfig.MaxDatasetSizeBytes), vInt64) - } else { - assert.FailNow(t, err.Error()) - } - }) t.Run("Override", func(t *testing.T) { testValue := "1" @@ -870,14 +590,6 @@ func TestConfig_SetFlags(t *testing.T) { }) }) t.Run("Test_kube-client-config.burst", func(t *testing.T) { - t.Run("DefaultValue", func(t *testing.T) { - // Test that default value is set properly - if vInt, err := cmdFlags.GetInt("kube-client-config.burst"); err == nil { - assert.Equal(t, int(defaultConfig.KubeConfig.Burst), vInt) - } else { - assert.FailNow(t, err.Error()) - } - }) t.Run("Override", func(t *testing.T) { testValue := "1" @@ -892,14 +604,6 @@ func TestConfig_SetFlags(t *testing.T) { }) }) t.Run("Test_kube-client-config.timeout", func(t *testing.T) { - t.Run("DefaultValue", func(t *testing.T) { - // Test that default value is set properly - if vString, err := cmdFlags.GetString("kube-client-config.timeout"); err == nil { - assert.Equal(t, string(defaultConfig.KubeConfig.Timeout.String()), vString) - } else { - assert.FailNow(t, err.Error()) - } - }) t.Run("Override", func(t *testing.T) { testValue := defaultConfig.KubeConfig.Timeout.String() @@ -914,14 +618,6 @@ func TestConfig_SetFlags(t *testing.T) { }) }) t.Run("Test_node-config.default-deadlines.node-execution-deadline", func(t *testing.T) { - t.Run("DefaultValue", func(t *testing.T) { - // Test that default value is set properly - if vString, err := cmdFlags.GetString("node-config.default-deadlines.node-execution-deadline"); err == nil { - assert.Equal(t, 
string(defaultConfig.NodeConfig.DefaultDeadlines.DefaultNodeExecutionDeadline.String()), vString) - } else { - assert.FailNow(t, err.Error()) - } - }) t.Run("Override", func(t *testing.T) { testValue := defaultConfig.NodeConfig.DefaultDeadlines.DefaultNodeExecutionDeadline.String() @@ -936,14 +632,6 @@ func TestConfig_SetFlags(t *testing.T) { }) }) t.Run("Test_node-config.default-deadlines.node-active-deadline", func(t *testing.T) { - t.Run("DefaultValue", func(t *testing.T) { - // Test that default value is set properly - if vString, err := cmdFlags.GetString("node-config.default-deadlines.node-active-deadline"); err == nil { - assert.Equal(t, string(defaultConfig.NodeConfig.DefaultDeadlines.DefaultNodeActiveDeadline.String()), vString) - } else { - assert.FailNow(t, err.Error()) - } - }) t.Run("Override", func(t *testing.T) { testValue := defaultConfig.NodeConfig.DefaultDeadlines.DefaultNodeActiveDeadline.String() @@ -958,14 +646,6 @@ func TestConfig_SetFlags(t *testing.T) { }) }) t.Run("Test_node-config.default-deadlines.workflow-active-deadline", func(t *testing.T) { - t.Run("DefaultValue", func(t *testing.T) { - // Test that default value is set properly - if vString, err := cmdFlags.GetString("node-config.default-deadlines.workflow-active-deadline"); err == nil { - assert.Equal(t, string(defaultConfig.NodeConfig.DefaultDeadlines.DefaultWorkflowActiveDeadline.String()), vString) - } else { - assert.FailNow(t, err.Error()) - } - }) t.Run("Override", func(t *testing.T) { testValue := defaultConfig.NodeConfig.DefaultDeadlines.DefaultWorkflowActiveDeadline.String() @@ -980,14 +660,6 @@ func TestConfig_SetFlags(t *testing.T) { }) }) t.Run("Test_node-config.max-node-retries-system-failures", func(t *testing.T) { - t.Run("DefaultValue", func(t *testing.T) { - // Test that default value is set properly - if vInt64, err := cmdFlags.GetInt64("node-config.max-node-retries-system-failures"); err == nil { - assert.Equal(t, int64(defaultConfig.NodeConfig.MaxNodeRetriesOnSystemFailures), vInt64) - } else { - assert.FailNow(t, err.Error()) - } - }) t.Run("Override", func(t *testing.T) { testValue := "1" @@ -1002,14 +674,6 @@ func TestConfig_SetFlags(t *testing.T) { }) }) t.Run("Test_node-config.interruptible-failure-threshold", func(t *testing.T) { - t.Run("DefaultValue", func(t *testing.T) { - // Test that default value is set properly - if vInt64, err := cmdFlags.GetInt64("node-config.interruptible-failure-threshold"); err == nil { - assert.Equal(t, int64(defaultConfig.NodeConfig.InterruptibleFailureThreshold), vInt64) - } else { - assert.FailNow(t, err.Error()) - } - }) t.Run("Override", func(t *testing.T) { testValue := "1" @@ -1024,14 +688,6 @@ func TestConfig_SetFlags(t *testing.T) { }) }) t.Run("Test_max-streak-length", func(t *testing.T) { - t.Run("DefaultValue", func(t *testing.T) { - // Test that default value is set properly - if vInt, err := cmdFlags.GetInt("max-streak-length"); err == nil { - assert.Equal(t, int(defaultConfig.MaxStreakLength), vInt) - } else { - assert.FailNow(t, err.Error()) - } - }) t.Run("Override", func(t *testing.T) { testValue := "1" diff --git a/pkg/controller/controller.go b/pkg/controller/controller.go index f3d96ed772..e56898bcea 100644 --- a/pkg/controller/controller.go +++ b/pkg/controller/controller.go @@ -46,6 +46,7 @@ import ( informers "github.com/flyteorg/flytepropeller/pkg/client/informers/externalversions" lister "github.com/flyteorg/flytepropeller/pkg/client/listers/flyteworkflow/v1alpha1" "github.com/flyteorg/flytepropeller/pkg/controller/nodes" + 
"github.com/flyteorg/flytepropeller/pkg/controller/nodes/recovery" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/subworkflow/launchplan" "github.com/flyteorg/flytepropeller/pkg/controller/workflow" ) @@ -317,14 +318,13 @@ func getAdminClient(ctx context.Context) (client service.AdminServiceClient, err func New(ctx context.Context, cfg *config.Config, kubeclientset kubernetes.Interface, flytepropellerClientset clientset.Interface, flyteworkflowInformerFactory informers.SharedInformerFactory, kubeClient executors.Client, scope promutils.Scope) (*Controller, error) { + adminClient, err := getAdminClient(ctx) + if err != nil { + logger.Errorf(ctx, "failed to initialize Admin client, err :%s", err.Error()) + return nil, err + } var launchPlanActor launchplan.FlyteAdmin if cfg.EnableAdminLauncher { - adminClient, err := getAdminClient(ctx) - if err != nil { - logger.Errorf(ctx, "failed to initialize Admin client, err :%s", err.Error()) - return nil, err - } - launchPlanActor, err = launchplan.NewAdminLaunchPlanExecutor(ctx, adminClient, cfg.DownstreamEval.Duration, launchplan.GetAdminConfig(), scope.NewSubScope("admin_launcher")) if err != nil { @@ -421,7 +421,7 @@ func New(ctx context.Context, cfg *config.Config, kubeclientset kubernetes.Inter nodeExecutor, err := nodes.NewExecutor(ctx, cfg.NodeConfig, store, controller.enqueueWorkflowForNodeUpdates, eventSink, launchPlanActor, launchPlanActor, cfg.MaxDatasetSizeBytes, - storage.DataReference(cfg.DefaultRawOutputPrefix), kubeClient, catalogClient, scope) + storage.DataReference(cfg.DefaultRawOutputPrefix), kubeClient, catalogClient, recovery.NewClient(adminClient), scope) if err != nil { return nil, errors.Wrapf(err, "Failed to create Controller.") } diff --git a/pkg/controller/executors/node.go b/pkg/controller/executors/node.go index 1d32066dc2..a8f738c3eb 100644 --- a/pkg/controller/executors/node.go +++ b/pkg/controller/executors/node.go @@ -35,6 +35,8 @@ const ( NodePhaseTimingOut // Node failed because execution timed out NodePhaseTimedOut + // Node recovered from a prior execution. 
+ NodePhaseRecovered ) func (p NodePhase) String() string { @@ -55,6 +57,8 @@ func (p NodePhase) String() string { return "Undefined" case NodePhaseTimedOut: return "NodePhaseTimedOut" + case NodePhaseRecovered: + return "NodePhaseRecovered" } return fmt.Sprintf("Unknown - %d", p) } @@ -110,6 +114,7 @@ var NodeStatusSuccess = NodeStatus{NodePhase: NodePhaseSuccess} var NodeStatusComplete = NodeStatus{NodePhase: NodePhaseComplete} var NodeStatusUndefined = NodeStatus{NodePhase: NodePhaseUndefined} var NodeStatusTimedOut = NodeStatus{NodePhase: NodePhaseTimedOut} +var NodeStatusRecovered = NodeStatus{NodePhase: NodePhaseRecovered} func NodeStatusFailed(err *core.ExecutionError) NodeStatus { return NodeStatus{NodePhase: NodePhaseFailed, Err: err} diff --git a/pkg/controller/nodes/dynamic/dynamic_workflow_test.go b/pkg/controller/nodes/dynamic/dynamic_workflow_test.go index 08184f07af..8fd2954631 100644 --- a/pkg/controller/nodes/dynamic/dynamic_workflow_test.go +++ b/pkg/controller/nodes/dynamic/dynamic_workflow_test.go @@ -497,7 +497,7 @@ func Test_dynamicNodeHandler_buildContextualDynamicWorkflow_withLaunchPlans(t *t composedPBStore.OnWriteRawMatch( mock.MatchedBy(func(ctx context.Context) bool { return true }), storage.DataReference("s3://my-s3-bucket/foo/bar/futures_compiled.pb"), - int64(1169), + int64(1192), storage.Options{}, mock.MatchedBy(func(rdr *bytes.Reader) bool { return true })).Return(errors.New("foo")) @@ -563,6 +563,9 @@ func Test_dynamicNodeHandler_buildContextualDynamicWorkflow_withLaunchPlans(t *t immutableParentInfo.OnCurrentAttempt().Return(uint32(2)) execContext.OnGetParentInfo().Return(&immutableParentInfo) execContext.OnGetEventVersion().Return(v1alpha1.EventVersion1) + execContext.OnGetExecutionConfig().Return(v1alpha1.ExecutionConfig{ + RecoveryExecution: v1alpha1.WorkflowExecutionIdentifier{}, + }) nCtx.OnExecutionContext().Return(execContext) dCtx, err := d.buildContextualDynamicWorkflow(ctx, nCtx) diff --git a/pkg/controller/nodes/errors/codes.go b/pkg/controller/nodes/errors/codes.go index 47fe612fb0..df2be215c3 100644 --- a/pkg/controller/nodes/errors/codes.go +++ b/pkg/controller/nodes/errors/codes.go @@ -21,6 +21,7 @@ const ( RemoteChildWorkflowExecutionFailed ErrorCode = "RemoteChildWorkflowExecutionFailed" NoBranchTakenError ErrorCode = "NoBranchTakenError" OutputsNotFoundError ErrorCode = "OutputsNotFoundError" + InputsNotFoundError ErrorCode = "InputsNotFoundError" StorageError ErrorCode = "StorageError" EventRecordingFailed ErrorCode = "EventRecordingFailed" CatalogCallFailed ErrorCode = "CatalogCallFailed" diff --git a/pkg/controller/nodes/executor.go b/pkg/controller/nodes/executor.go index f6508ee195..b8c9e97a8f 100644 --- a/pkg/controller/nodes/executor.go +++ b/pkg/controller/nodes/executor.go @@ -21,6 +21,10 @@ import ( "fmt" "time" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/recovery" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/common" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/ioutils" @@ -55,6 +59,7 @@ type nodeMetrics struct { Scope promutils.Scope FailureDuration labeled.StopWatch SuccessDuration labeled.StopWatch + RecoveryDuration labeled.StopWatch UserErrorDuration labeled.StopWatch SystemErrorDuration labeled.StopWatch UnknownErrorDuration labeled.StopWatch @@ -94,6 +99,7 @@ type nodeExecutor struct { interruptibleFailureThreshold uint32 defaultDataSandbox storage.DataReference shardSelector ioutils.ShardSelector + 
recoveryClient recovery.Client } func (c *nodeExecutor) RecordTransitionLatency(ctx context.Context, dag executors.DAGStructure, nl executors.NodeLookup, node v1alpha1.ExecutableNode, nodeStatus v1alpha1.ExecutableNodeStatus) { @@ -146,6 +152,119 @@ func (c *nodeExecutor) IdempotentRecordEvent(ctx context.Context, nodeEvent *eve return err } +func (c *nodeExecutor) attemptRecovery(ctx context.Context, nCtx handler.NodeExecutionContext) (handler.PhaseInfo, error) { + recovered, err := c.recoveryClient.RecoverNodeExecution(ctx, nCtx.ExecutionContext().GetExecutionConfig().RecoveryExecution.WorkflowExecutionIdentifier, nCtx.NodeExecutionMetadata().GetNodeExecutionID()) + if err != nil { + st, ok := status.FromError(err) + if !ok || st.Code() != codes.NotFound { + logger.Warnf(ctx, "Failed to recover node [%+v] with err [%+v]", nCtx.NodeExecutionMetadata().GetNodeExecutionID(), err) + } + // The node is not recoverable when it's not found in the execution being recovered + return handler.PhaseInfoUndefined, nil + } + if recovered == nil { + logger.Warnf(ctx, "call to recover node [%+v] returned no error but also no node", nCtx.NodeExecutionMetadata().GetNodeExecutionID()) + return handler.PhaseInfoUndefined, nil + } + if recovered.Closure == nil { + logger.Warnf(ctx, "Fetched node execution [%+v] data but was missing closure. Will not attempt to recover", + nCtx.NodeExecutionMetadata().GetNodeExecutionID()) + return handler.PhaseInfoUndefined, nil + } + // A recoverable node execution should always be in a terminal phase + switch recovered.Closure.Phase { + case core.NodeExecution_SKIPPED: + return handler.PhaseInfoSkip(nil, "node execution recovery indicated original node was skipped"), nil + case core.NodeExecution_SUCCEEDED: + logger.Debugf(ctx, "Node [%+v] can be recovered. Proceeding to copy inputs and outputs", nCtx.NodeExecutionMetadata().GetNodeExecutionID()) + default: + logger.Debugf(ctx, "Node [%+v] phase [%v] is not recoverable", nCtx.NodeExecutionMetadata().GetNodeExecutionID(), recovered.Closure.Phase) + return handler.PhaseInfoUndefined, nil + } + + recoveredData, err := c.recoveryClient.RecoverNodeExecutionData(ctx, nCtx.ExecutionContext().GetExecutionConfig().RecoveryExecution.WorkflowExecutionIdentifier, nCtx.NodeExecutionMetadata().GetNodeExecutionID()) + if err != nil { + st, ok := status.FromError(err) + if !ok || st.Code() != codes.NotFound { + logger.Warnf(ctx, "Failed to recover node execution data for [%+v] although back-end indicated node was recoverable with err [%+v]", + nCtx.NodeExecutionMetadata().GetNodeExecutionID(), err) + } + return handler.PhaseInfoUndefined, nil + } + if recoveredData == nil { + logger.Warnf(ctx, "call to recover node execution data for [%+v] returned no error but also no data", nCtx.NodeExecutionMetadata().GetNodeExecutionID()) + return handler.PhaseInfoUndefined, nil + } + // Copy inputs to this node's expected location + if recoveredData.FullInputs != nil { + if err := c.store.WriteProtobuf(ctx, nCtx.InputReader().GetInputPath(), storage.Options{}, recoveredData.FullInputs); err != nil { + c.metrics.InputsWriteFailure.Inc(ctx) + logger.Errorf(ctx, "Failed to move recovered inputs for Node. Error [%v]. InputsFile [%s]", err, nCtx.InputReader().GetInputPath()) + return handler.PhaseInfoUndefined, errors.Wrapf( + errors.StorageError, nCtx.NodeID(), err, "Failed to store inputs for Node. InputsFile [%s]", nCtx.InputReader().GetInputPath()) + } + } else if len(recovered.InputUri) > 0 { + // If the inputs are too large they won't be returned inline in the RecoverNodeExecutionData call. We must fetch them before copying them. + nodeInputs := &core.LiteralMap{} + if recoveredData.FullInputs == nil { + if err := c.store.ReadProtobuf(ctx, storage.DataReference(recovered.InputUri), nodeInputs); err != nil { + return handler.PhaseInfoUndefined, errors.Wrapf(errors.InputsNotFoundError, nCtx.NodeID(), err, "failed to read data from dataDir [%v].", recovered.InputUri) + } + } + + if err := c.store.WriteProtobuf(ctx, nCtx.InputReader().GetInputPath(), storage.Options{}, nodeInputs); err != nil { + c.metrics.InputsWriteFailure.Inc(ctx) + logger.Errorf(ctx, "Failed to move recovered inputs for Node. Error [%v]. InputsFile [%s]", err, nCtx.InputReader().GetInputPath()) + return handler.PhaseInfoUndefined, errors.Wrapf( + errors.StorageError, nCtx.NodeID(), err, "Failed to store inputs for Node. InputsFile [%s]", nCtx.InputReader().GetInputPath()) + } + } + // Similarly, copy outputs' reference + so := storage.Options{} + var outputs = &core.LiteralMap{} + if recoveredData.FullOutputs != nil { + outputs = recoveredData.FullOutputs + } else if len(recovered.Closure.GetOutputUri()) > 0 { + if err := c.store.ReadProtobuf(ctx, storage.DataReference(recovered.Closure.GetOutputUri()), outputs); err != nil { + return handler.PhaseInfoUndefined, errors.Wrapf(errors.OutputsNotFoundError, nCtx.NodeID(), err, "failed to read output data [%v].", recovered.Closure.GetOutputUri()) + } + } else { + logger.Debugf(ctx, "No outputs found for recovered node [%+v]", nCtx.NodeExecutionMetadata().GetNodeExecutionID()) + } + outputFile := v1alpha1.GetOutputsFile(nCtx.NodeStatus().GetOutputDir()) + if err := c.store.WriteProtobuf(ctx, outputFile, so, outputs); err != nil { + logger.Errorf(ctx, "Failed to write protobuf (metadata). Error [%v]", err) + return handler.PhaseInfoUndefined, errors.Wrapf(errors.CausedByError, nCtx.NodeID(), err, "Failed to store recovered node execution outputs") + } + + info := &handler.ExecutionInfo{ + OutputInfo: &handler.OutputInfo{ + OutputURI: outputFile, + }, + } + if recovered.Closure.GetTaskNodeMetadata() != nil { + taskNodeInfo := &handler.TaskNodeInfo{ + TaskNodeMetadata: &event.TaskNodeMetadata{ + CatalogKey: recovered.Closure.GetTaskNodeMetadata().CatalogKey, + CacheStatus: recovered.Closure.GetTaskNodeMetadata().CacheStatus, + }, + } + if recoveredData.DynamicWorkflow != nil { + taskNodeInfo.TaskNodeMetadata.DynamicWorkflow = &event.DynamicWorkflowNodeMetadata{ + Id: recoveredData.DynamicWorkflow.Id, + CompiledWorkflow: recoveredData.DynamicWorkflow.CompiledWorkflow, + } + } + info.TaskNodeInfo = taskNodeInfo + } else if recovered.Closure.GetWorkflowNodeMetadata() != nil { + logger.Warnf(ctx, "Attempted to recover workflow node [%+v]", nCtx.NodeExecutionMetadata().GetNodeExecutionID()) + info.WorkflowNodeInfo = &handler.WorkflowNodeInfo{ + LaunchedWorkflowID: recovered.Closure.GetWorkflowNodeMetadata().ExecutionId, + } + } + return handler.PhaseInfoRecovered(info), nil +} + // In this method we check if the queue is ready to be processed and if so, we prime it in Admin as queued // Before we start the node execution, we need to transition this Node status to Queued. // This is because a node execution has to exist before task/wf executions can start.
@@ -162,10 +281,16 @@ func (c *nodeExecutor) preExecute(ctx context.Context, dag executors.DAGStructur // TODO: Performance problem, we maybe in a retry loop and do not need to resolve the inputs again. // For now we will do this node := nCtx.Node() - nodeStatus := nCtx.NodeStatus() - dataDir := nodeStatus.GetDataDir() var nodeInputs *core.LiteralMap if !node.IsStartNode() { + if nCtx.ExecutionContext().GetExecutionConfig().RecoveryExecution.WorkflowExecutionIdentifier != nil { + phaseInfo, err := c.attemptRecovery(ctx, nCtx) + if err != nil || phaseInfo.GetPhase() == handler.EPhaseRecovered { + return phaseInfo, err + } + } + nodeStatus := nCtx.NodeStatus() + dataDir := nodeStatus.GetDataDir() t := c.metrics.NodeInputGatherLatency.Start(ctx) defer t.Stop() // Can execute @@ -326,7 +451,7 @@ func (c *nodeExecutor) handleNotYetStartedNode(ctx context.Context, dag executor // assert np == Queued! logger.Infof(ctx, "Change in node state detected from [%s] -> [%s]", nodeStatus.GetPhase().String(), np.String()) nev, err := ToNodeExecutionEvent(nCtx.NodeExecutionMetadata().GetNodeExecutionID(), - p, nCtx.InputReader(), nodeStatus, nCtx.ExecutionContext().GetEventVersion(), + p, nCtx.InputReader().GetInputPath().String(), nodeStatus, nCtx.ExecutionContext().GetEventVersion(), nCtx.ExecutionContext().GetParentInfo(), nCtx.node) if err != nil { return executors.NodeStatusUndefined, errors.Wrapf(errors.IllegalStateError, nCtx.NodeID(), err, "could not convert phase info to event") @@ -429,14 +554,18 @@ func (c *nodeExecutor) handleQueuedOrRunningNode(ctx context.Context, nCtx *node np = v1alpha1.NodePhaseSucceeded finalStatus = executors.NodeStatusSuccess } + if np == v1alpha1.NodePhaseRecovered { + logger.Infof(ctx, "Finalize not required, moving node to Recovered") + finalStatus = executors.NodeStatusRecovered + } // If it is retryable failure, we do no want to send any events, as the node is essentially still running // Similarly if the phase has not changed from the last time, events do not need to be sent if np != nodeStatus.GetPhase() && np != v1alpha1.NodePhaseRetryableFailure { - // assert np == skipped, succeeding or failing + // assert np == skipped, succeeding, failing or recovered logger.Infof(ctx, "Change in node state detected from [%s] -> [%s], (handler phase [%s])", nodeStatus.GetPhase().String(), np.String(), p.GetPhase().String()) nev, err := ToNodeExecutionEvent(nCtx.NodeExecutionMetadata().GetNodeExecutionID(), - p, nCtx.InputReader(), nCtx.NodeStatus(), nCtx.ExecutionContext().GetEventVersion(), + p, nCtx.InputReader().GetInputPath().String(), nCtx.NodeStatus(), nCtx.ExecutionContext().GetEventVersion(), nCtx.ExecutionContext().GetParentInfo(), nCtx.node) if err != nil { return executors.NodeStatusUndefined, errors.Wrapf(errors.IllegalStateError, nCtx.NodeID(), err, "could not convert phase info to event") @@ -732,7 +861,7 @@ func (c *nodeExecutor) RecursiveNodeHandler(ctx context.Context, execContext exe // TODO we can optimize skip state handling by iterating down the graph and marking all as skipped // Currently we treat either Skip or Success the same way. In this approach only one node will be skipped // at a time. 
As we iterate down, further nodes will be skipped - } else if nodePhase == v1alpha1.NodePhaseSucceeded || nodePhase == v1alpha1.NodePhaseSkipped { + } else if nodePhase == v1alpha1.NodePhaseSucceeded || nodePhase == v1alpha1.NodePhaseSkipped || nodePhase == v1alpha1.NodePhaseRecovered { logger.Debugf(currentNodeCtx, "Node has [%v], traversing downstream.", nodePhase) return c.handleDownstream(ctx, execContext, dag, nl, currentNode) } else if nodePhase == v1alpha1.NodePhaseFailed { @@ -875,7 +1004,7 @@ func (c *nodeExecutor) AbortHandler(ctx context.Context, execContext executors.E return errors.Wrapf(errors.EventRecordingFailed, nCtx.NodeID(), err, "failed to record node event") } } - } else if nodePhase == v1alpha1.NodePhaseSucceeded || nodePhase == v1alpha1.NodePhaseSkipped { + } else if nodePhase == v1alpha1.NodePhaseSucceeded || nodePhase == v1alpha1.NodePhaseSkipped || nodePhase == v1alpha1.NodePhaseRecovered { // Abort downstream nodes downstreamNodes, err := dag.FromNode(currentNode.GetID()) if err != nil { @@ -918,7 +1047,7 @@ func (c *nodeExecutor) Initialize(ctx context.Context) error { func NewExecutor(ctx context.Context, nodeConfig config.NodeConfig, store *storage.DataStore, enQWorkflow v1alpha1.EnqueueWorkflow, eventSink events.EventSink, workflowLauncher launchplan.Executor, launchPlanReader launchplan.Reader, maxDatasetSize int64, defaultRawOutputPrefix storage.DataReference, kubeClient executors.Client, - catalogClient catalog.Client, scope promutils.Scope) (executors.Node, error) { + catalogClient catalog.Client, recoveryClient recovery.Client, scope promutils.Scope) (executors.Node, error) { // TODO we may want to make this configurable. shardSelector, err := ioutils.NewBase36PrefixShardSelector(ctx) @@ -937,6 +1066,7 @@ func NewExecutor(ctx context.Context, nodeConfig config.NodeConfig, store *stora Scope: nodeScope, FailureDuration: labeled.NewStopWatch("failure_duration", "Indicates the total execution time of a failed workflow.", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), SuccessDuration: labeled.NewStopWatch("success_duration", "Indicates the total execution time of a successful workflow.", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), + RecoveryDuration: labeled.NewStopWatch("recovery_duration", "Indicates the total execution time of a recovered workflow.", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), UserErrorDuration: labeled.NewStopWatch("user_error_duration", "Indicates the total execution time before user error", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), SystemErrorDuration: labeled.NewStopWatch("system_error_duration", "Indicates the total execution time before system error", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), UnknownErrorDuration: labeled.NewStopWatch("unknown_error_duration", "Indicates the total execution time before unknown error", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), @@ -961,8 +1091,9 @@ func NewExecutor(ctx context.Context, nodeConfig config.NodeConfig, store *stora interruptibleFailureThreshold: uint32(nodeConfig.InterruptibleFailureThreshold), defaultDataSandbox: defaultRawOutputPrefix, shardSelector: shardSelector, + recoveryClient: recoveryClient, } - nodeHandlerFactory, err := NewHandlerFactory(ctx, exec, workflowLauncher, launchPlanReader, kubeClient, catalogClient, nodeScope) + nodeHandlerFactory, err := NewHandlerFactory(ctx, exec, workflowLauncher, launchPlanReader, kubeClient, catalogClient, recoveryClient, nodeScope) 
exec.nodeHandlerFactory = nodeHandlerFactory return exec, err } diff --git a/pkg/controller/nodes/executor_test.go b/pkg/controller/nodes/executor_test.go index fb3cc74895..17bfa96eac 100644 --- a/pkg/controller/nodes/executor_test.go +++ b/pkg/controller/nodes/executor_test.go @@ -8,11 +8,15 @@ import ( "testing" "time" + "github.com/golang/protobuf/proto" + + mocks3 "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io/mocks" + storageMocks "github.com/flyteorg/flytestdlib/storage/mocks" + eventsErr "github.com/flyteorg/flyteidl/clients/go/events/errors" "github.com/flyteorg/flyteidl/clients/go/coreutils" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" - mocks3 "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io/mocks" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" "github.com/flyteorg/flytestdlib/promutils/labeled" @@ -36,14 +40,18 @@ import ( "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" "github.com/flyteorg/flytepropeller/pkg/controller/config" "github.com/flyteorg/flytepropeller/pkg/controller/executors" + recoveryMocks "github.com/flyteorg/flytepropeller/pkg/controller/nodes/recovery/mocks" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/subworkflow/launchplan" flyteassert "github.com/flyteorg/flytepropeller/pkg/utils/assert" ) var fakeKubeClient = mocks4.NewFakeKubeClient() var catalogClient = catalog.NOOPCatalog{} +var recoveryClient = &recoveryMocks.RecoveryClient{} const taskID = "tID" +const inputsPath = "inputs.pb" +const outputsPath = "out/outputs.pb" func TestSetInputsForStartNode(t *testing.T) { ctx := context.Background() @@ -52,7 +60,7 @@ func TestSetInputsForStartNode(t *testing.T) { adminClient := launchplan.NewFailFastLaunchPlanExecutor() exec, err := NewExecutor(ctx, config.GetConfig().NodeConfig, mockStorage, enQWf, events.NewMockEventSink(), adminClient, - adminClient, 10, "s3://bucket/", fakeKubeClient, catalogClient, promutils.NewTestScope()) + adminClient, 10, "s3://bucket/", fakeKubeClient, catalogClient, recoveryClient, promutils.NewTestScope()) assert.NoError(t, err) inputs := &core.LiteralMap{ Literals: map[string]*core.Literal{ @@ -99,7 +107,7 @@ func TestSetInputsForStartNode(t *testing.T) { failStorage := createFailingDatastore(t, testScope.NewSubScope("failing")) execFail, err := NewExecutor(ctx, config.GetConfig().NodeConfig, failStorage, enQWf, events.NewMockEventSink(), adminClient, - adminClient, 10, "s3://bucket", fakeKubeClient, catalogClient, promutils.NewTestScope()) + adminClient, 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, promutils.NewTestScope()) assert.NoError(t, err) t.Run("StorageFailure", func(t *testing.T) { w := createDummyBaseWorkflow(mockStorage) @@ -125,7 +133,7 @@ func TestNodeExecutor_Initialize(t *testing.T) { t.Run("happy", func(t *testing.T) { execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, memStore, enQWf, mockEventSink, adminClient, - adminClient, 10, "s3://bucket", fakeKubeClient, catalogClient, promutils.NewTestScope()) + adminClient, 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, promutils.NewTestScope()) assert.NoError(t, err) exec := execIface.(*nodeExecutor) @@ -139,7 +147,7 @@ func TestNodeExecutor_Initialize(t *testing.T) { t.Run("error", func(t *testing.T) { execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, memStore, enQWf, mockEventSink, adminClient, - adminClient, 10, "s3://bucket", fakeKubeClient, catalogClient, promutils.NewTestScope()) + adminClient, 10, "s3://bucket", 
fakeKubeClient, catalogClient, recoveryClient, promutils.NewTestScope()) assert.NoError(t, err) exec := execIface.(*nodeExecutor) @@ -162,7 +170,7 @@ func TestNodeExecutor_RecursiveNodeHandler_RecurseStartNodes(t *testing.T) { adminClient := launchplan.NewFailFastLaunchPlanExecutor() execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, adminClient, - 10, "s3://bucket", fakeKubeClient, catalogClient, promutils.NewTestScope()) + 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, promutils.NewTestScope()) assert.NoError(t, err) exec := execIface.(*nodeExecutor) @@ -266,7 +274,7 @@ func TestNodeExecutor_RecursiveNodeHandler_RecurseEndNode(t *testing.T) { adminClient := launchplan.NewFailFastLaunchPlanExecutor() execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, adminClient, - 10, "s3://bucket", fakeKubeClient, catalogClient, promutils.NewTestScope()) + 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, promutils.NewTestScope()) assert.NoError(t, err) exec := execIface.(*nodeExecutor) @@ -632,6 +640,7 @@ func TestNodeExecutor_RecursiveNodeHandler_Recurse(t *testing.T) { mockWf.OnGetRawOutputDataConfig().Return(v1alpha1.RawOutputDataConfig{ RawOutputDataConfig: &admin.RawOutputDataConfig{OutputLocationPrefix: ""}, }) + mockWf.OnGetExecutionConfig().Return(v1alpha1.ExecutionConfig{}) mockWfStatus.OnGetDataDir().Return(storage.DataReference("x")) mockWfStatus.OnConstructNodeDataDirMatch(mock.Anything, mock.Anything, mock.Anything).Return("x", nil) return mockWf, mockN2Status @@ -668,7 +677,7 @@ func TestNodeExecutor_RecursiveNodeHandler_Recurse(t *testing.T) { adminClient := launchplan.NewFailFastLaunchPlanExecutor() execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, - adminClient, adminClient, 10, "s3://bucket", fakeKubeClient, catalogClient, promutils.NewTestScope()) + adminClient, adminClient, 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, promutils.NewTestScope()) assert.NoError(t, err) exec := execIface.(*nodeExecutor) exec.nodeHandlerFactory = hf @@ -743,7 +752,7 @@ func TestNodeExecutor_RecursiveNodeHandler_Recurse(t *testing.T) { store := createInmemoryDataStore(t, promutils.NewTestScope()) adminClient := launchplan.NewFailFastLaunchPlanExecutor() execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, - adminClient, 10, "s3://bucket", fakeKubeClient, catalogClient, promutils.NewTestScope()) + adminClient, 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, promutils.NewTestScope()) assert.NoError(t, err) exec := execIface.(*nodeExecutor) exec.nodeHandlerFactory = hf @@ -853,7 +862,7 @@ func TestNodeExecutor_RecursiveNodeHandler_Recurse(t *testing.T) { store := createInmemoryDataStore(t, promutils.NewTestScope()) adminClient := launchplan.NewFailFastLaunchPlanExecutor() execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, - adminClient, 10, "s3://bucket", fakeKubeClient, catalogClient, promutils.NewTestScope()) + adminClient, 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, promutils.NewTestScope()) assert.NoError(t, err) exec := execIface.(*nodeExecutor) exec.nodeHandlerFactory = hf @@ -917,7 +926,7 @@ func TestNodeExecutor_RecursiveNodeHandler_Recurse(t *testing.T) { store := createInmemoryDataStore(t, promutils.NewTestScope()) adminClient := 
launchplan.NewFailFastLaunchPlanExecutor() execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, - adminClient, 10, "s3://bucket", fakeKubeClient, catalogClient, promutils.NewTestScope()) + adminClient, 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, promutils.NewTestScope()) assert.NoError(t, err) exec := execIface.(*nodeExecutor) exec.nodeHandlerFactory = hf @@ -948,7 +957,7 @@ func TestNodeExecutor_RecursiveNodeHandler_Recurse(t *testing.T) { store := createInmemoryDataStore(t, promutils.NewTestScope()) adminClient := launchplan.NewFailFastLaunchPlanExecutor() execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, - adminClient, 10, "s3://bucket", fakeKubeClient, catalogClient, promutils.NewTestScope()) + adminClient, 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, promutils.NewTestScope()) assert.NoError(t, err) exec := execIface.(*nodeExecutor) exec.nodeHandlerFactory = hf @@ -982,7 +991,7 @@ func TestNodeExecutor_RecursiveNodeHandler_NoDownstream(t *testing.T) { store := createInmemoryDataStore(t, promutils.NewTestScope()) adminClient := launchplan.NewFailFastLaunchPlanExecutor() execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, - adminClient, 10, "s3://bucket", fakeKubeClient, catalogClient, promutils.NewTestScope()) + adminClient, 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, promutils.NewTestScope()) assert.NoError(t, err) exec := execIface.(*nodeExecutor) @@ -1093,7 +1102,7 @@ func TestNodeExecutor_RecursiveNodeHandler_UpstreamNotReady(t *testing.T) { adminClient := launchplan.NewFailFastLaunchPlanExecutor() execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, adminClient, - 10, "s3://bucket", fakeKubeClient, catalogClient, promutils.NewTestScope()) + 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, promutils.NewTestScope()) assert.NoError(t, err) exec := execIface.(*nodeExecutor) @@ -1209,7 +1218,7 @@ func TestNodeExecutor_RecursiveNodeHandler_BranchNode(t *testing.T) { adminClient := launchplan.NewFailFastLaunchPlanExecutor() execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, adminClient, - 10, "s3://bucket", fakeKubeClient, catalogClient, promutils.NewTestScope()) + 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, promutils.NewTestScope()) assert.NoError(t, err) exec := execIface.(*nodeExecutor) // Node not yet started @@ -1270,6 +1279,7 @@ func TestNodeExecutor_RecursiveNodeHandler_BranchNode(t *testing.T) { RawOutputDataConfig: &admin.RawOutputDataConfig{OutputLocationPrefix: ""}, }) eCtx.OnCurrentParallelism().Return(0) + eCtx.OnGetExecutionConfig().Return(v1alpha1.ExecutionConfig{}) branchTakenNodeID := "branchTakenNode" branchTakenNode := &mocks.ExecutableNode{} @@ -1638,8 +1648,6 @@ func TestNodeExecutionEventV0(t *testing.T) { NodeExecutionId: nID, } p := handler.PhaseInfoQueued("r") - inputReader := &mocks3.InputReader{} - inputReader.OnGetInputPath().Return("reference") parentInfo := &mocks4.ImmutableParentInfo{} parentInfo.OnGetUniqueID().Return("np1") parentInfo.OnCurrentAttempt().Return(uint32(2)) @@ -1653,7 +1661,7 @@ func TestNodeExecutionEventV0(t *testing.T) { ns.OnGetPhase().Return(v1alpha1.NodePhaseNotYetStarted) nl.OnGetNodeExecutionStatusMatch(mock.Anything, id).Return(ns) 
ns.OnGetParentTaskID().Return(tID) - ev, err := ToNodeExecutionEvent(nID, p, inputReader, ns, v1alpha1.EventVersion0, parentInfo, n) + ev, err := ToNodeExecutionEvent(nID, p, "reference", ns, v1alpha1.EventVersion0, parentInfo, n) assert.NoError(t, err) assert.Equal(t, "n1", ev.Id.NodeId) assert.Equal(t, execID, ev.Id.ExecutionId) @@ -1678,8 +1686,8 @@ func TestNodeExecutionEventV1(t *testing.T) { NodeExecutionId: nID, } p := handler.PhaseInfoQueued("r") - inputReader := &mocks3.InputReader{} - inputReader.OnGetInputPath().Return("reference") parentInfo := &mocks4.ImmutableParentInfo{} parentInfo.OnGetUniqueID().Return("np1") parentInfo.OnCurrentAttempt().Return(uint32(2)) @@ -1693,7 +1701,7 @@ func TestNodeExecutionEventV1(t *testing.T) { ns.OnGetPhase().Return(v1alpha1.NodePhaseNotYetStarted) nl.OnGetNodeExecutionStatusMatch(mock.Anything, id).Return(ns) ns.OnGetParentTaskID().Return(tID) - eventOpt, err := ToNodeExecutionEvent(nID, p, inputReader, ns, v1alpha1.EventVersion1, parentInfo, n) + eventOpt, err := ToNodeExecutionEvent(nID, p, "reference", ns, v1alpha1.EventVersion1, parentInfo, n) assert.NoError(t, err) assert.Equal(t, "np1-2-n1", eventOpt.Id.NodeId) assert.Equal(t, execID, eventOpt.Id.ExecutionId) @@ -1717,7 +1725,7 @@ func TestNodeExecutor_RecursiveNodeHandler_ParallelismLimit(t *testing.T) { adminClient := launchplan.NewFailFastLaunchPlanExecutor() execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, adminClient, - 10, "s3://bucket", fakeKubeClient, catalogClient, promutils.NewTestScope()) + 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, promutils.NewTestScope()) assert.NoError(t, err) exec := execIface.(*nodeExecutor) @@ -1888,3 +1896,346 @@ func Test_nodeExecutor_IdempotentRecordEvent(t *testing.T) { }) } } + +func TestRecover(t *testing.T) { + recoveryID := &core.WorkflowExecutionIdentifier{ + Project: "p", + Domain: "d", + Name: "orig", + } + wfExecID := &core.WorkflowExecutionIdentifier{ + Project: "project", + Domain: "domain", + Name: "name", + } + nodeID := "recovering" + nodeExecID := &core.NodeExecutionIdentifier{ + ExecutionId: wfExecID, + NodeId: nodeID, + } + + fullInputs := &core.LiteralMap{ + Literals: map[string]*core.Literal{ + "innie": { + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Primitive{ + Primitive: &core.Primitive{ + Value: &core.Primitive_Integer{ + Integer: 2, + }, + }, + }, + }, + }, + }, + }, + } + fullOutputs := &core.LiteralMap{ + Literals: map[string]*core.Literal{ + "outie": { + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Primitive{ + Primitive: &core.Primitive{ + Value: &core.Primitive_StringValue{ + StringValue: "foo", + }, + }, + }, + }, + }, + }, + }, + } + + execContext := &mocks4.ExecutionContext{} + execContext.OnGetExecutionConfig().Return(v1alpha1.ExecutionConfig{ + RecoveryExecution: v1alpha1.WorkflowExecutionIdentifier{ + WorkflowExecutionIdentifier: recoveryID, + }, + }) + + nm := &nodeHandlerMocks.NodeExecutionMetadata{} + nm.OnGetNodeExecutionID().Return(&core.NodeExecutionIdentifier{ + ExecutionId: wfExecID, + NodeId: nodeID, + }) + + ir := &mocks3.InputReader{} + ir.OnGetInputPath().Return(inputsPath) + + ns := &mocks.ExecutableNodeStatus{} + ns.OnGetOutputDir().Return(storage.DataReference("out")) + + nCtx := &nodeHandlerMocks.NodeExecutionContext{} + nCtx.OnExecutionContext().Return(execContext)
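+	// This NodeExecutionContext mock is shared by every recovery sub-test below; each sub-test installs its own recovery client and datastore mocks.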
+ nCtx.OnNodeExecutionMetadata().Return(nm) + nCtx.OnInputReader().Return(ir) + nCtx.OnNodeStatus().Return(ns) + + t.Run("recover task node successfully", func(t *testing.T) { + recoveryClient := &recoveryMocks.RecoveryClient{} + recoveryClient.On("RecoverNodeExecution", mock.Anything, recoveryID, nodeExecID).Return( + &admin.NodeExecution{ + Closure: &admin.NodeExecutionClosure{ + Phase: core.NodeExecution_SUCCEEDED, + OutputResult: &admin.NodeExecutionClosure_OutputUri{ + OutputUri: "outputuri.pb", + }, + }, + }, nil) + + recoveryClient.On("RecoverNodeExecutionData", mock.Anything, recoveryID, nodeExecID).Return( + &admin.NodeExecutionGetDataResponse{ + FullInputs: fullInputs, + FullOutputs: fullOutputs, + }, nil) + + mockPBStore := &storageMocks.ComposedProtobufStore{} + mockPBStore.On("WriteProtobuf", mock.Anything, mock.MatchedBy(func(reference storage.DataReference) bool { + return reference.String() == inputsPath || reference.String() == outputsPath + }), mock.Anything, + mock.Anything).Return(nil) + storageClient := &storage.DataStore{ + ComposedProtobufStore: mockPBStore, + ReferenceConstructor: &storageMocks.ReferenceConstructor{}, + } + + executor := nodeExecutor{ + recoveryClient: recoveryClient, + store: storageClient, + } + + phaseInfo, err := executor.attemptRecovery(context.TODO(), nCtx) + assert.NoError(t, err) + assert.Equal(t, phaseInfo.GetPhase(), handler.EPhaseRecovered) + }) + t.Run("recover cached, dynamic task node successfully", func(t *testing.T) { + recoveryClient := &recoveryMocks.RecoveryClient{} + recoveryClient.On("RecoverNodeExecution", mock.Anything, recoveryID, nodeExecID).Return( + &admin.NodeExecution{ + Closure: &admin.NodeExecutionClosure{ + Phase: core.NodeExecution_SUCCEEDED, + OutputResult: &admin.NodeExecutionClosure_OutputUri{ + OutputUri: "outputuri.pb", + }, + TargetMetadata: &admin.NodeExecutionClosure_TaskNodeMetadata{ + TaskNodeMetadata: &admin.TaskNodeMetadata{ + CatalogKey: &core.CatalogMetadata{ + ArtifactTag: &core.CatalogArtifactTag{ + ArtifactId: "arty", + }, + }, + CacheStatus: core.CatalogCacheStatus_CACHE_HIT, + }, + }, + }, + }, nil) + + dynamicWorkflow := &admin.DynamicWorkflowNodeMetadata{ + Id: &core.Identifier{ + ResourceType: core.ResourceType_WORKFLOW, + Project: "p", + Domain: "d", + Name: "n", + Version: "abc123", + }, + CompiledWorkflow: &core.CompiledWorkflowClosure{ + Primary: &core.CompiledWorkflow{ + Template: &core.WorkflowTemplate{ + Metadata: &core.WorkflowMetadata{ + OnFailure: core.WorkflowMetadata_FAIL_AFTER_EXECUTABLE_NODES_COMPLETE, + }, + }, + }, + }, + } + recoveryClient.On("RecoverNodeExecutionData", mock.Anything, recoveryID, nodeExecID).Return( + &admin.NodeExecutionGetDataResponse{ + FullInputs: fullInputs, + FullOutputs: fullOutputs, + DynamicWorkflow: dynamicWorkflow, + }, nil) + + mockPBStore := &storageMocks.ComposedProtobufStore{} + mockPBStore.On("WriteProtobuf", mock.Anything, mock.MatchedBy(func(reference storage.DataReference) bool { + return reference.String() == inputsPath || reference.String() == outputsPath + }), mock.Anything, + mock.Anything).Return(nil) + storageClient := &storage.DataStore{ + ComposedProtobufStore: mockPBStore, + ReferenceConstructor: &storageMocks.ReferenceConstructor{}, + } + + executor := nodeExecutor{ + recoveryClient: recoveryClient, + store: storageClient, + } + + phaseInfo, err := executor.attemptRecovery(context.TODO(), nCtx) + assert.NoError(t, err) + assert.Equal(t, phaseInfo.GetPhase(), handler.EPhaseRecovered) + assert.True(t, 
proto.Equal(&event.TaskNodeMetadata{ + CatalogKey: &core.CatalogMetadata{ + ArtifactTag: &core.CatalogArtifactTag{ + ArtifactId: "arty", + }, + }, + CacheStatus: core.CatalogCacheStatus_CACHE_HIT, + DynamicWorkflow: &event.DynamicWorkflowNodeMetadata{ + Id: dynamicWorkflow.Id, + CompiledWorkflow: dynamicWorkflow.CompiledWorkflow, + }, + }, phaseInfo.GetInfo().TaskNodeInfo.TaskNodeMetadata)) + }) + t.Run("recover workflow node successfully", func(t *testing.T) { + recoveryClient := &recoveryMocks.RecoveryClient{} + recoveryClient.On("RecoverNodeExecution", mock.Anything, recoveryID, nodeExecID).Return( + &admin.NodeExecution{ + Closure: &admin.NodeExecutionClosure{ + Phase: core.NodeExecution_SUCCEEDED, + OutputResult: &admin.NodeExecutionClosure_OutputUri{ + OutputUri: "outputuri.pb", + }, + TargetMetadata: &admin.NodeExecutionClosure_WorkflowNodeMetadata{ + WorkflowNodeMetadata: &admin.WorkflowNodeMetadata{ + ExecutionId: &core.WorkflowExecutionIdentifier{ + Project: "p", + Domain: "d", + Name: "original_child_wf", + }, + }, + }, + }, + }, nil) + + recoveryClient.On("RecoverNodeExecutionData", mock.Anything, recoveryID, nodeExecID).Return( + &admin.NodeExecutionGetDataResponse{ + FullInputs: fullInputs, + FullOutputs: fullOutputs, + }, nil) + + mockPBStore := &storageMocks.ComposedProtobufStore{} + mockPBStore.On("WriteProtobuf", mock.Anything, mock.MatchedBy(func(reference storage.DataReference) bool { + return reference.String() == inputsPath || reference.String() == outputsPath + }), mock.Anything, + mock.Anything).Return(nil) + storageClient := &storage.DataStore{ + ComposedProtobufStore: mockPBStore, + ReferenceConstructor: &storageMocks.ReferenceConstructor{}, + } + + executor := nodeExecutor{ + recoveryClient: recoveryClient, + store: storageClient, + } + + phaseInfo, err := executor.attemptRecovery(context.TODO(), nCtx) + assert.NoError(t, err) + assert.Equal(t, phaseInfo.GetPhase(), handler.EPhaseRecovered) + assert.True(t, proto.Equal(&core.WorkflowExecutionIdentifier{ + Project: "p", + Domain: "d", + Name: "original_child_wf", + }, phaseInfo.GetInfo().WorkflowNodeInfo.LaunchedWorkflowID)) + }) + + t.Run("nothing to recover", func(t *testing.T) { + recoveryClient := &recoveryMocks.RecoveryClient{} + recoveryClient.On("RecoverNodeExecution", mock.Anything, recoveryID, nodeExecID).Return( + &admin.NodeExecution{ + Closure: &admin.NodeExecutionClosure{ + Phase: core.NodeExecution_FAILED, + }, + }, nil) + + executor := nodeExecutor{ + recoveryClient: recoveryClient, + } + + phaseInfo, err := executor.attemptRecovery(context.TODO(), nCtx) + assert.NoError(t, err) + assert.Equal(t, phaseInfo.GetPhase(), handler.EPhaseUndefined) + }) + + t.Run("Fetch inputs", func(t *testing.T) { + recoveryClient := &recoveryMocks.RecoveryClient{} + recoveryClient.On("RecoverNodeExecution", mock.Anything, recoveryID, nodeExecID).Return( + &admin.NodeExecution{ + InputUri: "inputuri", + Closure: &admin.NodeExecutionClosure{ + Phase: core.NodeExecution_SUCCEEDED, + OutputResult: &admin.NodeExecutionClosure_OutputUri{ + OutputUri: "outputuri.pb", + }, + }, + }, nil) + + recoveryClient.On("RecoverNodeExecutionData", mock.Anything, recoveryID, nodeExecID).Return( + &admin.NodeExecutionGetDataResponse{ + FullOutputs: fullOutputs, + }, nil) + + mockPBStore := &storageMocks.ComposedProtobufStore{} + mockPBStore.On("WriteProtobuf", mock.Anything, mock.MatchedBy(func(reference storage.DataReference) bool { + return reference.String() == inputsPath || reference.String() == outputsPath + }), mock.Anything, + 
mock.Anything).Return(nil) + mockPBStore.On("ReadProtobuf", mock.Anything, storage.DataReference("inputuri"), &core.LiteralMap{}).Return(nil) + + storageClient := &storage.DataStore{ + ComposedProtobufStore: mockPBStore, + ReferenceConstructor: &storageMocks.ReferenceConstructor{}, + } + + executor := nodeExecutor{ + recoveryClient: recoveryClient, + store: storageClient, + } + + phaseInfo, err := executor.attemptRecovery(context.TODO(), nCtx) + assert.NoError(t, err) + assert.Equal(t, phaseInfo.GetPhase(), handler.EPhaseRecovered) + mockPBStore.AssertNumberOfCalls(t, "ReadProtobuf", 1) + }) + t.Run("Fetch outputs", func(t *testing.T) { + recoveryClient := &recoveryMocks.RecoveryClient{} + recoveryClient.On("RecoverNodeExecution", mock.Anything, recoveryID, nodeExecID).Return( + &admin.NodeExecution{ + Closure: &admin.NodeExecutionClosure{ + Phase: core.NodeExecution_SUCCEEDED, + OutputResult: &admin.NodeExecutionClosure_OutputUri{ + OutputUri: "outputuri.pb", + }, + }, + }, nil) + + recoveryClient.On("RecoverNodeExecutionData", mock.Anything, recoveryID, nodeExecID).Return( + &admin.NodeExecutionGetDataResponse{ + FullInputs: fullInputs, + }, nil) + + mockPBStore := &storageMocks.ComposedProtobufStore{} + mockPBStore.On("WriteProtobuf", mock.Anything, mock.MatchedBy(func(reference storage.DataReference) bool { + return reference.String() == inputsPath || reference.String() == outputsPath + }), mock.Anything, + mock.Anything).Return(nil) + mockPBStore.On("ReadProtobuf", mock.Anything, storage.DataReference("outputuri.pb"), &core.LiteralMap{}).Return(nil) + + storageClient := &storage.DataStore{ + ComposedProtobufStore: mockPBStore, + ReferenceConstructor: &storageMocks.ReferenceConstructor{}, + } + + executor := nodeExecutor{ + recoveryClient: recoveryClient, + store: storageClient, + } + + phaseInfo, err := executor.attemptRecovery(context.TODO(), nCtx) + assert.NoError(t, err) + assert.Equal(t, phaseInfo.GetPhase(), handler.EPhaseRecovered) + mockPBStore.AssertNumberOfCalls(t, "ReadProtobuf", 1) + }) +} diff --git a/pkg/controller/nodes/handler/ephase_enumer.go b/pkg/controller/nodes/handler/ephase_enumer.go index a08c9e445f..d574f93e04 100644 --- a/pkg/controller/nodes/handler/ephase_enumer.go +++ b/pkg/controller/nodes/handler/ephase_enumer.go @@ -7,9 +7,9 @@ import ( "fmt" ) -const _EPhaseName = "UndefinedNotReadyQueuedRunningSkipFailedRetryableFailureSuccessTimedoutFailingDynamicRunning" +const _EPhaseName = "UndefinedNotReadyQueuedRunningSkipFailedRetryableFailureSuccessTimedoutFailingDynamicRunningRecovered" -var _EPhaseIndex = [...]uint8{0, 9, 17, 23, 30, 34, 40, 56, 63, 71, 78, 92} +var _EPhaseIndex = [...]uint8{0, 9, 17, 23, 30, 34, 40, 56, 63, 71, 78, 92, 101} func (i EPhase) String() string { if i >= EPhase(len(_EPhaseIndex)-1) { @@ -18,20 +18,21 @@ func (i EPhase) String() string { return _EPhaseName[_EPhaseIndex[i]:_EPhaseIndex[i+1]] } -var _EPhaseValues = []EPhase{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10} +var _EPhaseValues = []EPhase{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11} var _EPhaseNameToValueMap = map[string]EPhase{ - _EPhaseName[0:9]: 0, - _EPhaseName[9:17]: 1, - _EPhaseName[17:23]: 2, - _EPhaseName[23:30]: 3, - _EPhaseName[30:34]: 4, - _EPhaseName[34:40]: 5, - _EPhaseName[40:56]: 6, - _EPhaseName[56:63]: 7, - _EPhaseName[63:71]: 8, - _EPhaseName[71:78]: 9, - _EPhaseName[78:92]: 10, + _EPhaseName[0:9]: 0, + _EPhaseName[9:17]: 1, + _EPhaseName[17:23]: 2, + _EPhaseName[23:30]: 3, + _EPhaseName[30:34]: 4, + _EPhaseName[34:40]: 5, + _EPhaseName[40:56]: 6, + _EPhaseName[56:63]: 7, + 
_EPhaseName[63:71]: 8, + _EPhaseName[71:78]: 9, + _EPhaseName[78:92]: 10, + _EPhaseName[92:101]: 11, } // EPhaseString retrieves an enum value from the enum constants string name. diff --git a/pkg/controller/nodes/handler/transition_info.go b/pkg/controller/nodes/handler/transition_info.go index edff42201a..58264650d6 100644 --- a/pkg/controller/nodes/handler/transition_info.go +++ b/pkg/controller/nodes/handler/transition_info.go @@ -24,10 +24,11 @@ const ( EPhaseTimedout EPhaseFailing EPhaseDynamicRunning + EPhaseRecovered ) func (p EPhase) IsTerminal() bool { - if p == EPhaseFailed || p == EPhaseSuccess || p == EPhaseSkip || p == EPhaseTimedout { + if p == EPhaseFailed || p == EPhaseSuccess || p == EPhaseSkip || p == EPhaseTimedout || p == EPhaseRecovered { return true } return false @@ -138,6 +139,10 @@ func PhaseInfoTimedOut(info *ExecutionInfo, reason string) PhaseInfo { return phaseInfo(EPhaseTimedout, nil, info, reason) } +func PhaseInfoRecovered(info *ExecutionInfo) PhaseInfo { + return phaseInfo(EPhaseRecovered, nil, info, "successfully recovered") +} + func phaseInfoFailed(p EPhase, err *core.ExecutionError, info *ExecutionInfo) PhaseInfo { if err == nil { err = &core.ExecutionError{ diff --git a/pkg/controller/nodes/handler_factory.go b/pkg/controller/nodes/handler_factory.go index 1cbd1cc463..cef3c3cf1b 100644 --- a/pkg/controller/nodes/handler_factory.go +++ b/pkg/controller/nodes/handler_factory.go @@ -3,6 +3,8 @@ package nodes import ( "context" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/recovery" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/catalog" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/dynamic" @@ -51,9 +53,9 @@ func (f handlerFactory) Setup(ctx context.Context, setup handler.SetupContext) e } func NewHandlerFactory(ctx context.Context, executor executors.Node, workflowLauncher launchplan.Executor, - launchPlanReader launchplan.Reader, kubeClient executors.Client, client catalog.Client, scope promutils.Scope) (HandlerFactory, error) { + launchPlanReader launchplan.Reader, kubeClient executors.Client, client catalog.Client, recoveryClient recovery.Client, scope promutils.Scope) (HandlerFactory, error) { - t, err := task.New(ctx, kubeClient, client, scope) + t, err := task.New(ctx, kubeClient, client, recoveryClient, scope) if err != nil { return nil, err } @@ -62,7 +64,7 @@ func NewHandlerFactory(ctx context.Context, executor executors.Node, workflowLau handlers: map[v1alpha1.NodeKind]handler.Node{ v1alpha1.NodeKindBranch: branch.New(executor, scope), v1alpha1.NodeKindTask: dynamic.New(t, executor, launchPlanReader, scope), - v1alpha1.NodeKindWorkflow: subworkflow.New(executor, workflowLauncher, scope), + v1alpha1.NodeKindWorkflow: subworkflow.New(executor, workflowLauncher, recoveryClient, scope), v1alpha1.NodeKindStart: start.New(), v1alpha1.NodeKindEnd: end.New(), }, diff --git a/pkg/controller/nodes/predicate.go b/pkg/controller/nodes/predicate.go index 744a1088ef..d0a766bfaa 100644 --- a/pkg/controller/nodes/predicate.go +++ b/pkg/controller/nodes/predicate.go @@ -96,7 +96,8 @@ func CanExecute(ctx context.Context, dag executors.DAGStructure, nl executors.No upstreamNodeStatus.GetPhase() == v1alpha1.NodePhaseFailed || upstreamNodeStatus.GetPhase() == v1alpha1.NodePhaseTimedOut { skipped = true - } else if upstreamNodeStatus.GetPhase() != v1alpha1.NodePhaseSucceeded { + } else if !(upstreamNodeStatus.GetPhase() == v1alpha1.NodePhaseSucceeded || + upstreamNodeStatus.GetPhase() == v1alpha1.NodePhaseRecovered) { return 
PredicatePhaseNotReady, nil } } diff --git a/pkg/controller/nodes/recovery/client.go b/pkg/controller/nodes/recovery/client.go new file mode 100644 index 0000000000..d4a1954395 --- /dev/null +++ b/pkg/controller/nodes/recovery/client.go @@ -0,0 +1,46 @@ +package recovery + +import ( + "context" + + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/service" +) + +//go:generate mockery -name Client -output=mocks -case=underscore + +type Client interface { + RecoverNodeExecution(ctx context.Context, execID *core.WorkflowExecutionIdentifier, id *core.NodeExecutionIdentifier) (*admin.NodeExecution, error) + RecoverNodeExecutionData(ctx context.Context, execID *core.WorkflowExecutionIdentifier, id *core.NodeExecutionIdentifier) (*admin.NodeExecutionGetDataResponse, error) +} + +type recoveryClient struct { + adminClient service.AdminServiceClient +} + +func (c *recoveryClient) RecoverNodeExecution(ctx context.Context, execID *core.WorkflowExecutionIdentifier, nodeID *core.NodeExecutionIdentifier) (*admin.NodeExecution, error) { + origNodeID := &core.NodeExecutionIdentifier{ + ExecutionId: execID, + NodeId: nodeID.NodeId, + } + return c.adminClient.GetNodeExecution(ctx, &admin.NodeExecutionGetRequest{ + Id: origNodeID, + }) +} + +func (c *recoveryClient) RecoverNodeExecutionData(ctx context.Context, execID *core.WorkflowExecutionIdentifier, nodeID *core.NodeExecutionIdentifier) (*admin.NodeExecutionGetDataResponse, error) { + origNodeID := &core.NodeExecutionIdentifier{ + ExecutionId: execID, + NodeId: nodeID.NodeId, + } + return c.adminClient.GetNodeExecutionData(ctx, &admin.NodeExecutionGetDataRequest{ + Id: origNodeID, + }) +} + +func NewClient(adminClient service.AdminServiceClient) Client { + return &recoveryClient{ + adminClient: adminClient, + } +} diff --git a/pkg/controller/nodes/recovery/mocks/client.go b/pkg/controller/nodes/recovery/mocks/client.go new file mode 100644 index 0000000000..fe551d56be --- /dev/null +++ b/pkg/controller/nodes/recovery/mocks/client.go @@ -0,0 +1,100 @@ +// Code generated by mockery v1.0.1. DO NOT EDIT. + +package mocks + +import ( + context "context" + + admin "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" + + core "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + + mock "github.com/stretchr/testify/mock" +) + +// Client is an autogenerated mock type for the Client type +type Client struct { + mock.Mock +} + +type Client_RecoverNodeExecution struct { + *mock.Call +} + +func (_m Client_RecoverNodeExecution) Return(_a0 *admin.NodeExecution, _a1 error) *Client_RecoverNodeExecution { + return &Client_RecoverNodeExecution{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *Client) OnRecoverNodeExecution(ctx context.Context, execID *core.WorkflowExecutionIdentifier, id *core.NodeExecutionIdentifier) *Client_RecoverNodeExecution { + c := _m.On("RecoverNodeExecution", ctx, execID, id) + return &Client_RecoverNodeExecution{Call: c} +} + +func (_m *Client) OnRecoverNodeExecutionMatch(matchers ...interface{}) *Client_RecoverNodeExecution { + c := _m.On("RecoverNodeExecution", matchers...) 
+ return &Client_RecoverNodeExecution{Call: c} +} + +// RecoverNodeExecution provides a mock function with given fields: ctx, execID, id +func (_m *Client) RecoverNodeExecution(ctx context.Context, execID *core.WorkflowExecutionIdentifier, id *core.NodeExecutionIdentifier) (*admin.NodeExecution, error) { + ret := _m.Called(ctx, execID, id) + + var r0 *admin.NodeExecution + if rf, ok := ret.Get(0).(func(context.Context, *core.WorkflowExecutionIdentifier, *core.NodeExecutionIdentifier) *admin.NodeExecution); ok { + r0 = rf(ctx, execID, id) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*admin.NodeExecution) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *core.WorkflowExecutionIdentifier, *core.NodeExecutionIdentifier) error); ok { + r1 = rf(ctx, execID, id) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +type Client_RecoverNodeExecutionData struct { + *mock.Call +} + +func (_m Client_RecoverNodeExecutionData) Return(_a0 *admin.NodeExecutionGetDataResponse, _a1 error) *Client_RecoverNodeExecutionData { + return &Client_RecoverNodeExecutionData{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *Client) OnRecoverNodeExecutionData(ctx context.Context, execID *core.WorkflowExecutionIdentifier, id *core.NodeExecutionIdentifier) *Client_RecoverNodeExecutionData { + c := _m.On("RecoverNodeExecutionData", ctx, execID, id) + return &Client_RecoverNodeExecutionData{Call: c} +} + +func (_m *Client) OnRecoverNodeExecutionDataMatch(matchers ...interface{}) *Client_RecoverNodeExecutionData { + c := _m.On("RecoverNodeExecutionData", matchers...) + return &Client_RecoverNodeExecutionData{Call: c} +} + +// RecoverNodeExecutionData provides a mock function with given fields: ctx, execID, id +func (_m *Client) RecoverNodeExecutionData(ctx context.Context, execID *core.WorkflowExecutionIdentifier, id *core.NodeExecutionIdentifier) (*admin.NodeExecutionGetDataResponse, error) { + ret := _m.Called(ctx, execID, id) + + var r0 *admin.NodeExecutionGetDataResponse + if rf, ok := ret.Get(0).(func(context.Context, *core.WorkflowExecutionIdentifier, *core.NodeExecutionIdentifier) *admin.NodeExecutionGetDataResponse); ok { + r0 = rf(ctx, execID, id) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*admin.NodeExecutionGetDataResponse) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *core.WorkflowExecutionIdentifier, *core.NodeExecutionIdentifier) error); ok { + r1 = rf(ctx, execID, id) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} diff --git a/pkg/controller/nodes/recovery/mocks/recovery_client.go b/pkg/controller/nodes/recovery/mocks/recovery_client.go new file mode 100644 index 0000000000..f52b65474e --- /dev/null +++ b/pkg/controller/nodes/recovery/mocks/recovery_client.go @@ -0,0 +1,100 @@ +// Code generated by mockery v1.0.1. DO NOT EDIT. 
+ +package mocks + +import ( + context "context" + + admin "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" + + core "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + + mock "github.com/stretchr/testify/mock" +) + +// RecoveryClient is an autogenerated mock type for the RecoveryClient type +type RecoveryClient struct { + mock.Mock +} + +type RecoveryClient_RecoverNodeExecution struct { + *mock.Call +} + +func (_m RecoveryClient_RecoverNodeExecution) Return(_a0 *admin.NodeExecution, _a1 error) *RecoveryClient_RecoverNodeExecution { + return &RecoveryClient_RecoverNodeExecution{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *RecoveryClient) OnRecoverNodeExecution(ctx context.Context, execID *core.WorkflowExecutionIdentifier, id *core.NodeExecutionIdentifier) *RecoveryClient_RecoverNodeExecution { + c := _m.On("RecoverNodeExecution", ctx, execID, id) + return &RecoveryClient_RecoverNodeExecution{Call: c} +} + +func (_m *RecoveryClient) OnRecoverNodeExecutionMatch(matchers ...interface{}) *RecoveryClient_RecoverNodeExecution { + c := _m.On("RecoverNodeExecution", matchers...) + return &RecoveryClient_RecoverNodeExecution{Call: c} +} + +// RecoverNodeExecution provides a mock function with given fields: ctx, execID, id +func (_m *RecoveryClient) RecoverNodeExecution(ctx context.Context, execID *core.WorkflowExecutionIdentifier, id *core.NodeExecutionIdentifier) (*admin.NodeExecution, error) { + ret := _m.Called(ctx, execID, id) + + var r0 *admin.NodeExecution + if rf, ok := ret.Get(0).(func(context.Context, *core.WorkflowExecutionIdentifier, *core.NodeExecutionIdentifier) *admin.NodeExecution); ok { + r0 = rf(ctx, execID, id) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*admin.NodeExecution) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *core.WorkflowExecutionIdentifier, *core.NodeExecutionIdentifier) error); ok { + r1 = rf(ctx, execID, id) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +type RecoveryClient_RecoverNodeExecutionData struct { + *mock.Call +} + +func (_m RecoveryClient_RecoverNodeExecutionData) Return(_a0 *admin.NodeExecutionGetDataResponse, _a1 error) *RecoveryClient_RecoverNodeExecutionData { + return &RecoveryClient_RecoverNodeExecutionData{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *RecoveryClient) OnRecoverNodeExecutionData(ctx context.Context, execID *core.WorkflowExecutionIdentifier, id *core.NodeExecutionIdentifier) *RecoveryClient_RecoverNodeExecutionData { + c := _m.On("RecoverNodeExecutionData", ctx, execID, id) + return &RecoveryClient_RecoverNodeExecutionData{Call: c} +} + +func (_m *RecoveryClient) OnRecoverNodeExecutionDataMatch(matchers ...interface{}) *RecoveryClient_RecoverNodeExecutionData { + c := _m.On("RecoverNodeExecutionData", matchers...) 
+ return &RecoveryClient_RecoverNodeExecutionData{Call: c} +} + +// RecoverNodeExecutionData provides a mock function with given fields: ctx, execID, id +func (_m *RecoveryClient) RecoverNodeExecutionData(ctx context.Context, execID *core.WorkflowExecutionIdentifier, id *core.NodeExecutionIdentifier) (*admin.NodeExecutionGetDataResponse, error) { + ret := _m.Called(ctx, execID, id) + + var r0 *admin.NodeExecutionGetDataResponse + if rf, ok := ret.Get(0).(func(context.Context, *core.WorkflowExecutionIdentifier, *core.NodeExecutionIdentifier) *admin.NodeExecutionGetDataResponse); ok { + r0 = rf(ctx, execID, id) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*admin.NodeExecutionGetDataResponse) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *core.WorkflowExecutionIdentifier, *core.NodeExecutionIdentifier) error); ok { + r1 = rf(ctx, execID, id) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} diff --git a/pkg/controller/nodes/subworkflow/handler.go b/pkg/controller/nodes/subworkflow/handler.go index eb245338f9..b11892e9f3 100644 --- a/pkg/controller/nodes/subworkflow/handler.go +++ b/pkg/controller/nodes/subworkflow/handler.go @@ -3,6 +3,8 @@ package subworkflow import ( "context" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/recovery" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flytestdlib/promutils" @@ -125,12 +127,13 @@ func (w *workflowNodeHandler) Finalize(ctx context.Context, _ handler.NodeExecut return nil } -func New(executor executors.Node, workflowLauncher launchplan.Executor, scope promutils.Scope) handler.Node { +func New(executor executors.Node, workflowLauncher launchplan.Executor, recoveryClient recovery.Client, scope promutils.Scope) handler.Node { workflowScope := scope.NewSubScope("workflow") return &workflowNodeHandler{ subWfHandler: newSubworkflowHandler(executor), lpHandler: launchPlanHandler{ - launchPlan: workflowLauncher, + launchPlan: workflowLauncher, + recoveryClient: recoveryClient, }, metrics: newMetrics(workflowScope), } diff --git a/pkg/controller/nodes/subworkflow/handler_test.go b/pkg/controller/nodes/subworkflow/handler_test.go index 605696110d..784df59343 100644 --- a/pkg/controller/nodes/subworkflow/handler_test.go +++ b/pkg/controller/nodes/subworkflow/handler_test.go @@ -6,6 +6,8 @@ import ( "reflect" "testing" + mocks5 "github.com/flyteorg/flytepropeller/pkg/controller/nodes/recovery/mocks" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" mocks4 "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io/mocks" @@ -99,6 +101,7 @@ func createNodeContextWithVersion(phase v1alpha1.WorkflowNodePhase, n v1alpha1.E ex.OnGetEventVersion().Return(version) ex.OnGetParentInfo().Return(nil) ex.OnGetName().Return("name") + ex.OnGetExecutionConfig().Return(v1alpha1.ExecutionConfig{}) nCtx.OnExecutionContext().Return(ex) @@ -139,11 +142,12 @@ func TestWorkflowNodeHandler_StartNode_Launchplan(t *testing.T) { mockNodeStatus.OnGetAttempts().Return(attempts) wfStatus := &mocks2.MutableWorkflowNodeStatus{} mockNodeStatus.OnGetOrCreateWorkflowStatus().Return(wfStatus) + recoveryClient := &mocks5.RecoveryClient{} t.Run("happy v0", func(t *testing.T) { mockLPExec := &mocks.Executor{} - h := New(nil, mockLPExec, promutils.NewTestScope()) + h := New(nil, mockLPExec, recoveryClient, promutils.NewTestScope()) mockLPExec.OnLaunchMatch( ctx, mock.MatchedBy(func(o launchplan.LaunchContext) bool { @@ -166,7 +170,7 @@ func 
TestWorkflowNodeHandler_StartNode_Launchplan(t *testing.T) { t.Run("happy v1", func(t *testing.T) { mockLPExec := &mocks.Executor{} - h := New(nil, mockLPExec, promutils.NewTestScope()) + h := New(nil, mockLPExec, recoveryClient, promutils.NewTestScope()) mockLPExec.OnLaunchMatch( ctx, mock.MatchedBy(func(o launchplan.LaunchContext) bool { @@ -213,12 +217,13 @@ func TestWorkflowNodeHandler_CheckNodeStatus(t *testing.T) { mockNodeStatus := &mocks2.ExecutableNodeStatus{} mockNodeStatus.OnGetAttempts().Return(attempts) mockNodeStatus.OnGetDataDir().Return(dataDir) + recoveryClient := &mocks5.RecoveryClient{} t.Run("stillRunning V0", func(t *testing.T) { mockLPExec := &mocks.Executor{} - h := New(nil, mockLPExec, promutils.NewTestScope()) + h := New(nil, mockLPExec, recoveryClient, promutils.NewTestScope()) mockLPExec.OnGetStatusMatch( ctx, mock.MatchedBy(func(o *core.WorkflowExecutionIdentifier) bool { @@ -237,7 +242,7 @@ func TestWorkflowNodeHandler_CheckNodeStatus(t *testing.T) { mockLPExec := &mocks.Executor{} - h := New(nil, mockLPExec, promutils.NewTestScope()) + h := New(nil, mockLPExec, recoveryClient, promutils.NewTestScope()) mockLPExec.OnGetStatusMatch( ctx, mock.MatchedBy(func(o *core.WorkflowExecutionIdentifier) bool { @@ -280,13 +285,14 @@ func TestWorkflowNodeHandler_AbortNode(t *testing.T) { mockNodeStatus := &mocks2.ExecutableNodeStatus{} mockNodeStatus.OnGetAttempts().Return(attempts) mockNodeStatus.OnGetDataDir().Return(dataDir) + recoveryClient := &mocks5.RecoveryClient{} t.Run("abort v0", func(t *testing.T) { mockLPExec := &mocks.Executor{} nCtx := createNodeContext(v1alpha1.WorkflowNodePhaseExecuting, mockNode, mockNodeStatus) - h := New(nil, mockLPExec, promutils.NewTestScope()) + h := New(nil, mockLPExec, recoveryClient, promutils.NewTestScope()) mockLPExec.OnKillMatch( ctx, mock.MatchedBy(func(o *core.WorkflowExecutionIdentifier) bool { @@ -307,7 +313,7 @@ func TestWorkflowNodeHandler_AbortNode(t *testing.T) { mockLPExec := &mocks.Executor{} nCtx := createNodeContextV1(v1alpha1.WorkflowNodePhaseExecuting, mockNode, mockNodeStatus) - h := New(nil, mockLPExec, promutils.NewTestScope()) + h := New(nil, mockLPExec, recoveryClient, promutils.NewTestScope()) mockLPExec.OnKillMatch( ctx, mock.MatchedBy(func(o *core.WorkflowExecutionIdentifier) bool { @@ -326,7 +332,7 @@ func TestWorkflowNodeHandler_AbortNode(t *testing.T) { mockLPExec := &mocks.Executor{} expectedErr := fmt.Errorf("fail") - h := New(nil, mockLPExec, promutils.NewTestScope()) + h := New(nil, mockLPExec, recoveryClient, promutils.NewTestScope()) mockLPExec.OnKillMatch( ctx, mock.MatchedBy(func(o *core.WorkflowExecutionIdentifier) bool { diff --git a/pkg/controller/nodes/subworkflow/launchplan.go b/pkg/controller/nodes/subworkflow/launchplan.go index 89b9886c3b..dbff0fb6b2 100644 --- a/pkg/controller/nodes/subworkflow/launchplan.go +++ b/pkg/controller/nodes/subworkflow/launchplan.go @@ -4,6 +4,10 @@ import ( "context" "fmt" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/recovery" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/common" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" @@ -17,7 +21,8 @@ import ( ) type launchPlanHandler struct { - launchPlan launchplan.Executor + launchPlan launchplan.Executor + recoveryClient recovery.Client } func getParentNodeExecutionID(nCtx handler.NodeExecutionContext) (*core.NodeExecutionIdentifier, error) { @@ -59,6 +64,22 @@ func (l *launchPlanHandler) StartLaunchPlan(ctx 
context.Context, nCtx handler.No launchCtx := launchplan.LaunchContext{ ParentNodeExecution: parentNodeExecutionID, } + if nCtx.ExecutionContext().GetExecutionConfig().RecoveryExecution.WorkflowExecutionIdentifier != nil { + recovered, err := l.recoveryClient.RecoverNodeExecution(ctx, nCtx.ExecutionContext().GetExecutionConfig().RecoveryExecution.WorkflowExecutionIdentifier, nCtx.NodeExecutionMetadata().GetNodeExecutionID()) + if err != nil { + st, ok := status.FromError(err) + if !ok || st.Code() != codes.NotFound { + logger.Warnf(ctx, "Failed to recover workflow node [%+v] with err [%+v]", nCtx.NodeExecutionMetadata().GetNodeExecutionID(), err) + } + } + if recovered != nil && recovered.Closure != nil && recovered.Closure.Phase == core.NodeExecution_SUCCEEDED { + if recovered.Closure.GetWorkflowNodeMetadata() != nil { + launchCtx.RecoveryExecution = recovered.Closure.GetWorkflowNodeMetadata().ExecutionId + } else { + logger.Debugf(ctx, "Attempted to recover workflow node execution [%+v] but it was missing workflow node metadata", recovered.Id) + } + } + } err = l.launchPlan.Launch(ctx, launchCtx, childID, nCtx.Node().GetWorkflowNode().GetLaunchPlanRefID().Identifier, nodeInputs) if err != nil { if launchplan.IsAlreadyExists(err) { diff --git a/pkg/controller/nodes/subworkflow/launchplan/admin.go b/pkg/controller/nodes/subworkflow/launchplan/admin.go index 9ac4420e66..1cfb990d8e 100644 --- a/pkg/controller/nodes/subworkflow/launchplan/admin.go +++ b/pkg/controller/nodes/subworkflow/launchplan/admin.go @@ -22,6 +22,8 @@ import ( "google.golang.org/grpc/status" ) +const isRecovery = true + // IsWorkflowTerminated returns a true if the Workflow Phase is in a Terminal Phase, else returns a false func IsWorkflowTerminated(p core.WorkflowExecution_Phase) bool { return p == core.WorkflowExecution_ABORTED || p == core.WorkflowExecution_FAILED || @@ -44,9 +46,49 @@ func (e executionCacheItem) ID() string { return e.String() } +func (a *adminLaunchPlanExecutor) handleLaunchError(ctx context.Context, isRecovery bool, + executionID *core.WorkflowExecutionIdentifier, launchPlanRef *core.Identifier, err error) error { + + statusCode := status.Code(err) + if isRecovery && statusCode == codes.NotFound { + logger.Warnf(ctx, "failed to recover workflow [%s] with err %+v.
will attempt to launch instead", launchPlanRef.Name, err) + return nil + } + switch statusCode { + case codes.AlreadyExists: + _, err := a.cache.GetOrCreate(executionID.String(), executionCacheItem{WorkflowExecutionIdentifier: *executionID}) + if err != nil { + logger.Errorf(ctx, "Failed to add ExecID [%v] to auto refresh cache", executionID) + } + + return errors.Wrapf(RemoteErrorAlreadyExists, err, "ExecID %s already exists", executionID.Name) + case codes.DataLoss, codes.DeadlineExceeded, codes.Internal, codes.Unknown, codes.Canceled: + return errors.Wrapf(RemoteErrorSystem, err, "failed to launch workflow [%s], system error", launchPlanRef.Name) + default: + return errors.Wrapf(RemoteErrorUser, err, "failed to launch workflow") + } +} + func (a *adminLaunchPlanExecutor) Launch(ctx context.Context, launchCtx LaunchContext, executionID *core.WorkflowExecutionIdentifier, launchPlanRef *core.Identifier, inputs *core.LiteralMap) error { - + var err error + if launchCtx.RecoveryExecution != nil { + _, err = a.adminClient.RecoverExecution(ctx, &admin.ExecutionRecoverRequest{ + Id: launchCtx.RecoveryExecution, + Name: executionID.Name, + Metadata: &admin.ExecutionMetadata{ + ParentNodeExecution: launchCtx.ParentNodeExecution, + }, + }) + if err != nil { + launchErr := a.handleLaunchError(ctx, isRecovery, executionID, launchPlanRef, err) + if launchErr != nil { + return launchErr + } + } else { + return nil + } + } req := &admin.ExecutionCreateRequest{ Project: executionID.Project, Domain: executionID.Domain, @@ -62,21 +104,11 @@ func (a *adminLaunchPlanExecutor) Launch(ctx context.Context, launchCtx LaunchCo Inputs: inputs, }, } - _, err := a.adminClient.CreateExecution(ctx, req) + _, err = a.adminClient.CreateExecution(ctx, req) if err != nil { - statusCode := status.Code(err) - switch statusCode { - case codes.AlreadyExists: - _, err := a.cache.GetOrCreate(executionID.String(), executionCacheItem{WorkflowExecutionIdentifier: *executionID}) - if err != nil { - logger.Errorf(ctx, "Failed to add ExecID [%v] to auto refresh cache", executionID) - } - - return errors.Wrapf(RemoteErrorAlreadyExists, err, "ExecID %s already exists", executionID.Name) - case codes.DataLoss, codes.DeadlineExceeded, codes.Internal, codes.Unknown, codes.Canceled: - return errors.Wrapf(RemoteErrorSystem, err, "failed to launch workflow [%s], system error", launchPlanRef.Name) - default: - return errors.Wrapf(RemoteErrorUser, err, "failed to launch workflow") + launchErr := a.handleLaunchError(ctx, !isRecovery, executionID, launchPlanRef, err) + if launchErr != nil { + return launchErr } } diff --git a/pkg/controller/nodes/subworkflow/launchplan/admin_test.go b/pkg/controller/nodes/subworkflow/launchplan/admin_test.go index f52a5bf829..e8c79efd9d 100644 --- a/pkg/controller/nodes/subworkflow/launchplan/admin_test.go +++ b/pkg/controller/nodes/subworkflow/launchplan/admin_test.go @@ -5,6 +5,8 @@ import ( "testing" "time" + "github.com/golang/protobuf/proto" + "github.com/flyteorg/flytestdlib/cache" mocks2 "github.com/flyteorg/flytestdlib/cache/mocks" @@ -193,6 +195,92 @@ func TestAdminLaunchPlanExecutor_Launch(t *testing.T) { assert.NoError(t, err) }) + t.Run("happy recover", func(t *testing.T) { + + mockClient := &mocks.AdminServiceClient{} + parentNodeExecution := &core.NodeExecutionIdentifier{ + NodeId: "node-id", + ExecutionId: &core.WorkflowExecutionIdentifier{ + Project: "p", + Domain: "d", + Name: "orig", + }, + } + exec, err := NewAdminLaunchPlanExecutor(ctx, mockClient, time.Second, defaultAdminConfig, 
promutils.NewTestScope()) + mockClient.On("RecoverExecution", + ctx, + mock.MatchedBy(func(o *admin.ExecutionRecoverRequest) bool { + return o.Id.Project == "p" && o.Id.Domain == "d" && o.Id.Name == "w" && o.Name == "n" && + proto.Equal(o.Metadata.ParentNodeExecution, parentNodeExecution) + }), + ).Return(nil, nil) + assert.NoError(t, err) + err = exec.Launch(ctx, + LaunchContext{ + RecoveryExecution: &core.WorkflowExecutionIdentifier{ + Project: "p", + Domain: "d", + Name: "w", + }, + ParentNodeExecution: parentNodeExecution, + }, + id, + &core.Identifier{}, + nil, + ) + assert.NoError(t, err) + }) + + t.Run("recovery fails", func(t *testing.T) { + + mockClient := &mocks.AdminServiceClient{} + parentNodeExecution := &core.NodeExecutionIdentifier{ + NodeId: "node-id", + ExecutionId: &core.WorkflowExecutionIdentifier{ + Project: "p", + Domain: "d", + Name: "orig", + }, + } + exec, err := NewAdminLaunchPlanExecutor(ctx, mockClient, time.Second, defaultAdminConfig, promutils.NewTestScope()) + assert.NoError(t, err) + + recoveryErr := status.Error(codes.NotFound, "foo") + mockClient.On("RecoverExecution", + ctx, + mock.MatchedBy(func(o *admin.ExecutionRecoverRequest) bool { + return o.Id.Project == "p" && o.Id.Domain == "d" && o.Id.Name == "w" && o.Name == "n" && + proto.Equal(o.Metadata.ParentNodeExecution, parentNodeExecution) + }), + ).Return(nil, recoveryErr) + + var createCalled = false + mockClient.On("CreateExecution", + ctx, + mock.MatchedBy(func(o *admin.ExecutionCreateRequest) bool { + createCalled = true + return o.Project == "p" && o.Domain == "d" && o.Name == "n" && o.Spec.Inputs == nil && + o.Spec.Metadata.Mode == admin.ExecutionMetadata_CHILD_WORKFLOW + }), + ).Return(nil, nil) + + err = exec.Launch(ctx, + LaunchContext{ + RecoveryExecution: &core.WorkflowExecutionIdentifier{ + Project: "p", + Domain: "d", + Name: "w", + }, + ParentNodeExecution: parentNodeExecution, + }, + id, + &core.Identifier{}, + nil, + ) + assert.NoError(t, err) + assert.True(t, createCalled) + }) + t.Run("notFound", func(t *testing.T) { mockClient := &mocks.AdminServiceClient{} diff --git a/pkg/controller/nodes/subworkflow/launchplan/adminconfig_flags.go b/pkg/controller/nodes/subworkflow/launchplan/adminconfig_flags.go index fc3f2da228..3bb535e179 100755 --- a/pkg/controller/nodes/subworkflow/launchplan/adminconfig_flags.go +++ b/pkg/controller/nodes/subworkflow/launchplan/adminconfig_flags.go @@ -28,6 +28,15 @@ func (AdminConfig) elemValueOrNil(v interface{}) interface{} { return v } +func (AdminConfig) mustJsonMarshal(v interface{}) string { + raw, err := json.Marshal(v) + if err != nil { + panic(err) + } + + return string(raw) +} + func (AdminConfig) mustMarshalJSON(v json.Marshaler) string { raw, err := v.MarshalJSON() if err != nil { diff --git a/pkg/controller/nodes/subworkflow/launchplan/adminconfig_flags_test.go b/pkg/controller/nodes/subworkflow/launchplan/adminconfig_flags_test.go index 447b347448..bbff474eb1 100755 --- a/pkg/controller/nodes/subworkflow/launchplan/adminconfig_flags_test.go +++ b/pkg/controller/nodes/subworkflow/launchplan/adminconfig_flags_test.go @@ -84,7 +84,7 @@ func testDecodeJson_AdminConfig(t *testing.T, val, result interface{}) { assert.NoError(t, decode_AdminConfig(val, result)) } -func testDecodeSlice_AdminConfig(t *testing.T, vStringSlice, result interface{}) { +func testDecodeRaw_AdminConfig(t *testing.T, vStringSlice, result interface{}) { assert.NoError(t, decode_AdminConfig(vStringSlice, result)) } @@ -100,14 +100,6 @@ func TestAdminConfig_SetFlags(t *testing.T) 
{ assert.True(t, cmdFlags.HasFlags()) t.Run("Test_tps", func(t *testing.T) { - t.Run("DefaultValue", func(t *testing.T) { - // Test that default value is set properly - if vInt64, err := cmdFlags.GetInt64("tps"); err == nil { - assert.Equal(t, int64(defaultAdminConfig.TPS), vInt64) - } else { - assert.FailNow(t, err.Error()) - } - }) t.Run("Override", func(t *testing.T) { testValue := "1" @@ -122,14 +114,6 @@ func TestAdminConfig_SetFlags(t *testing.T) { }) }) t.Run("Test_burst", func(t *testing.T) { - t.Run("DefaultValue", func(t *testing.T) { - // Test that default value is set properly - if vInt, err := cmdFlags.GetInt("burst"); err == nil { - assert.Equal(t, int(defaultAdminConfig.Burst), vInt) - } else { - assert.FailNow(t, err.Error()) - } - }) t.Run("Override", func(t *testing.T) { testValue := "1" @@ -144,14 +128,6 @@ func TestAdminConfig_SetFlags(t *testing.T) { }) }) t.Run("Test_cacheSize", func(t *testing.T) { - t.Run("DefaultValue", func(t *testing.T) { - // Test that default value is set properly - if vInt, err := cmdFlags.GetInt("cacheSize"); err == nil { - assert.Equal(t, int(defaultAdminConfig.MaxCacheSize), vInt) - } else { - assert.FailNow(t, err.Error()) - } - }) t.Run("Override", func(t *testing.T) { testValue := "1" @@ -166,14 +142,6 @@ func TestAdminConfig_SetFlags(t *testing.T) { }) }) t.Run("Test_workers", func(t *testing.T) { - t.Run("DefaultValue", func(t *testing.T) { - // Test that default value is set properly - if vInt, err := cmdFlags.GetInt("workers"); err == nil { - assert.Equal(t, int(defaultAdminConfig.Workers), vInt) - } else { - assert.FailNow(t, err.Error()) - } - }) t.Run("Override", func(t *testing.T) { testValue := "1" diff --git a/pkg/controller/nodes/subworkflow/launchplan/launchplan.go b/pkg/controller/nodes/subworkflow/launchplan/launchplan.go index 931390f7d3..91c8c85884 100644 --- a/pkg/controller/nodes/subworkflow/launchplan/launchplan.go +++ b/pkg/controller/nodes/subworkflow/launchplan/launchplan.go @@ -18,6 +18,8 @@ type LaunchContext struct { Principal string // If a node launched the execution, this specifies which node execution ParentNodeExecution *core.NodeExecutionIdentifier + // If a node in recovery mode launched this execution, propagate recovery mode to the child execution. 
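+	// When this is set, the admin launch-plan executor issues a RecoverExecution call instead of CreateExecution.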
+ RecoveryExecution *core.WorkflowExecutionIdentifier } // Interface to be implemented by the remote system that can allow workflow launching capabilities diff --git a/pkg/controller/nodes/subworkflow/launchplan_test.go b/pkg/controller/nodes/subworkflow/launchplan_test.go index 431723bb20..7de30efdf1 100644 --- a/pkg/controller/nodes/subworkflow/launchplan_test.go +++ b/pkg/controller/nodes/subworkflow/launchplan_test.go @@ -6,6 +6,11 @@ import ( "reflect" "testing" + mocks4 "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io/mocks" + mocks3 "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler/mocks" + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" + "github.com/flyteorg/flyteidl/clients/go/coreutils" "github.com/flyteorg/flytestdlib/errors" @@ -20,6 +25,7 @@ import ( mocks2 "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks" execMocks "github.com/flyteorg/flytepropeller/pkg/controller/executors/mocks" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" + recoveryMocks "github.com/flyteorg/flytepropeller/pkg/controller/nodes/recovery/mocks" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/subworkflow/launchplan" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/subworkflow/launchplan/mocks" "github.com/flyteorg/flytepropeller/pkg/utils" @@ -166,6 +172,98 @@ func TestSubWorkflowHandler_StartLaunchPlan(t *testing.T) { assert.NoError(t, err) assert.Equal(t, handler.EPhaseFailed, s.Info().GetPhase()) }) + t.Run("recover successfully", func(t *testing.T) { + recoveredExecID := &core.WorkflowExecutionIdentifier{ + Project: "p", + Domain: "d", + Name: "n", + } + + mockLPExec := &mocks.Executor{} + mockLPExec.On("Launch", mock.Anything, launchplan.LaunchContext{ + ParentNodeExecution: &core.NodeExecutionIdentifier{ + NodeId: "n", + ExecutionId: &core.WorkflowExecutionIdentifier{ + Project: "project", + Domain: "domain", + Name: "name", + }, + }, + RecoveryExecution: recoveredExecID, + }, mock.Anything, mock.Anything, mock.Anything).Return(nil) + + recoveryClient := recoveryMocks.RecoveryClient{} + recoveryClient.On("RecoverNodeExecution", mock.Anything, recoveredExecID, mock.Anything).Return(&admin.NodeExecution{ + Closure: &admin.NodeExecutionClosure{ + Phase: core.NodeExecution_SUCCEEDED, + TargetMetadata: &admin.NodeExecutionClosure_WorkflowNodeMetadata{ + WorkflowNodeMetadata: &admin.WorkflowNodeMetadata{ + ExecutionId: recoveredExecID, + }, + }, + }, + }, nil) + + h := launchPlanHandler{ + launchPlan: mockLPExec, + recoveryClient: &recoveryClient, + } + mockLPExec.On("Launch", + ctx, + mock.MatchedBy(func(o launchplan.LaunchContext) bool { + return o.ParentNodeExecution.NodeId == mockNode.GetID() && + o.ParentNodeExecution.ExecutionId == wfExecID + }), + mock.MatchedBy(func(o *core.WorkflowExecutionIdentifier) bool { + return assert.Equal(t, wfExecID.Project, o.Project) && assert.Equal(t, wfExecID.Domain, o.Domain) + }), + mock.MatchedBy(func(o *core.Identifier) bool { return lpID == o }), + mock.MatchedBy(func(o *core.LiteralMap) bool { return o.Literals == nil }), + ).Return(nil) + + wfStatus := &mocks2.MutableWorkflowNodeStatus{} + mockNodeStatus.On("GetOrCreateWorkflowStatus").Return(wfStatus) + + nCtx := &mocks3.NodeExecutionContext{} + + ir := &mocks4.InputReader{} + inputs := &core.LiteralMap{} + ir.OnGetMatch(mock.Anything).Return(inputs, nil) + nCtx.OnInputReader().Return(ir) + + nm := &mocks3.NodeExecutionMetadata{} + nm.OnGetAnnotations().Return(map[string]string{}) + 
nm.OnGetNodeExecutionID().Return(&core.NodeExecutionIdentifier{ + ExecutionId: wfExecID, + NodeId: "n", + }) + nm.OnGetK8sServiceAccount().Return("service-account") + nm.OnGetLabels().Return(map[string]string{}) + nm.OnGetNamespace().Return("namespace") + nm.OnGetOwnerID().Return(types.NamespacedName{Namespace: "namespace", Name: "name"}) + nm.OnGetOwnerReference().Return(v1.OwnerReference{ + Kind: "sample", + Name: "name", + }) + + nCtx.OnNodeExecutionMetadata().Return(nm) + ectx := &execMocks.ExecutionContext{} + ectx.OnGetEventVersion().Return(1) + ectx.OnGetParentInfo().Return(nil) + ectx.OnGetExecutionConfig().Return(v1alpha1.ExecutionConfig{ + RecoveryExecution: v1alpha1.WorkflowExecutionIdentifier{ + WorkflowExecutionIdentifier: recoveredExecID, + }, + }) + nCtx.OnExecutionContext().Return(ectx) + nCtx.OnCurrentAttempt().Return(uint32(1)) + nCtx.OnNode().Return(mockNode) + + s, err := h.StartLaunchPlan(ctx, nCtx) + assert.NoError(t, err) + assert.Equal(t, s.Info().GetPhase(), handler.EPhaseRunning) + assert.Equal(t, len(recoveryClient.Calls), 1) + }) } func TestSubWorkflowHandler_CheckLaunchPlanStatus(t *testing.T) { diff --git a/pkg/controller/nodes/task/catalog/config_flags.go b/pkg/controller/nodes/task/catalog/config_flags.go index ee5d3bdef2..97e6b9a5d7 100755 --- a/pkg/controller/nodes/task/catalog/config_flags.go +++ b/pkg/controller/nodes/task/catalog/config_flags.go @@ -28,6 +28,15 @@ func (Config) elemValueOrNil(v interface{}) interface{} { return v } +func (Config) mustJsonMarshal(v interface{}) string { + raw, err := json.Marshal(v) + if err != nil { + panic(err) + } + + return string(raw) +} + func (Config) mustMarshalJSON(v json.Marshaler) string { raw, err := v.MarshalJSON() if err != nil { diff --git a/pkg/controller/nodes/task/catalog/config_flags_test.go b/pkg/controller/nodes/task/catalog/config_flags_test.go index 9b28195648..e6dc041166 100755 --- a/pkg/controller/nodes/task/catalog/config_flags_test.go +++ b/pkg/controller/nodes/task/catalog/config_flags_test.go @@ -84,7 +84,7 @@ func testDecodeJson_Config(t *testing.T, val, result interface{}) { assert.NoError(t, decode_Config(val, result)) } -func testDecodeSlice_Config(t *testing.T, vStringSlice, result interface{}) { +func testDecodeRaw_Config(t *testing.T, vStringSlice, result interface{}) { assert.NoError(t, decode_Config(vStringSlice, result)) } @@ -100,14 +100,6 @@ func TestConfig_SetFlags(t *testing.T) { assert.True(t, cmdFlags.HasFlags()) t.Run("Test_type", func(t *testing.T) { - t.Run("DefaultValue", func(t *testing.T) { - // Test that default value is set properly - if vString, err := cmdFlags.GetString("type"); err == nil { - assert.Equal(t, string(defaultConfig.Type), vString) - } else { - assert.FailNow(t, err.Error()) - } - }) t.Run("Override", func(t *testing.T) { testValue := "1" @@ -122,14 +114,6 @@ func TestConfig_SetFlags(t *testing.T) { }) }) t.Run("Test_endpoint", func(t *testing.T) { - t.Run("DefaultValue", func(t *testing.T) { - // Test that default value is set properly - if vString, err := cmdFlags.GetString("endpoint"); err == nil { - assert.Equal(t, string(defaultConfig.Endpoint), vString) - } else { - assert.FailNow(t, err.Error()) - } - }) t.Run("Override", func(t *testing.T) { testValue := "1" @@ -144,14 +128,6 @@ func TestConfig_SetFlags(t *testing.T) { }) }) t.Run("Test_insecure", func(t *testing.T) { - t.Run("DefaultValue", func(t *testing.T) { - // Test that default value is set properly - if vBool, err := cmdFlags.GetBool("insecure"); err == nil { - assert.Equal(t, 
bool(defaultConfig.Insecure), vBool) - } else { - assert.FailNow(t, err.Error()) - } - }) t.Run("Override", func(t *testing.T) { testValue := "1" @@ -166,14 +142,6 @@ func TestConfig_SetFlags(t *testing.T) { }) }) t.Run("Test_max-cache-age", func(t *testing.T) { - t.Run("DefaultValue", func(t *testing.T) { - // Test that default value is set properly - if vString, err := cmdFlags.GetString("max-cache-age"); err == nil { - assert.Equal(t, string(defaultConfig.MaxCacheAge.String()), vString) - } else { - assert.FailNow(t, err.Error()) - } - }) t.Run("Override", func(t *testing.T) { testValue := defaultConfig.MaxCacheAge.String() diff --git a/pkg/controller/nodes/task/config/config_flags.go b/pkg/controller/nodes/task/config/config_flags.go index b7e8ad8489..9d96d8a058 100755 --- a/pkg/controller/nodes/task/config/config_flags.go +++ b/pkg/controller/nodes/task/config/config_flags.go @@ -28,6 +28,15 @@ func (Config) elemValueOrNil(v interface{}) interface{} { return v } +func (Config) mustJsonMarshal(v interface{}) string { + raw, err := json.Marshal(v) + if err != nil { + panic(err) + } + + return string(raw) +} + func (Config) mustMarshalJSON(v json.Marshaler) string { raw, err := v.MarshalJSON() if err != nil { diff --git a/pkg/controller/nodes/task/config/config_flags_test.go b/pkg/controller/nodes/task/config/config_flags_test.go index b5ebaf283d..eacff539a0 100755 --- a/pkg/controller/nodes/task/config/config_flags_test.go +++ b/pkg/controller/nodes/task/config/config_flags_test.go @@ -84,7 +84,7 @@ func testDecodeJson_Config(t *testing.T, val, result interface{}) { assert.NoError(t, decode_Config(val, result)) } -func testDecodeSlice_Config(t *testing.T, vStringSlice, result interface{}) { +func testDecodeRaw_Config(t *testing.T, vStringSlice, result interface{}) { assert.NoError(t, decode_Config(vStringSlice, result)) } @@ -100,21 +100,13 @@ func TestConfig_SetFlags(t *testing.T) { assert.True(t, cmdFlags.HasFlags()) t.Run("Test_task-plugins.enabled-plugins", func(t *testing.T) { - t.Run("DefaultValue", func(t *testing.T) { - // Test that default value is set properly - if vStringSlice, err := cmdFlags.GetStringSlice("task-plugins.enabled-plugins"); err == nil { - assert.Equal(t, []string([]string{}), vStringSlice) - } else { - assert.FailNow(t, err.Error()) - } - }) t.Run("Override", func(t *testing.T) { testValue := join_Config("1,1", ",") cmdFlags.Set("task-plugins.enabled-plugins", testValue) if vStringSlice, err := cmdFlags.GetStringSlice("task-plugins.enabled-plugins"); err == nil { - testDecodeSlice_Config(t, join_Config(vStringSlice, ","), &actual.TaskPlugins.EnabledPlugins) + testDecodeRaw_Config(t, join_Config(vStringSlice, ","), &actual.TaskPlugins.EnabledPlugins) } else { assert.FailNow(t, err.Error()) @@ -122,14 +114,6 @@ func TestConfig_SetFlags(t *testing.T) { }) }) t.Run("Test_max-plugin-phase-versions", func(t *testing.T) { - t.Run("DefaultValue", func(t *testing.T) { - // Test that default value is set properly - if vInt32, err := cmdFlags.GetInt32("max-plugin-phase-versions"); err == nil { - assert.Equal(t, int32(defaultConfig.MaxPluginPhaseVersions), vInt32) - } else { - assert.FailNow(t, err.Error()) - } - }) t.Run("Override", func(t *testing.T) { testValue := "1" @@ -144,14 +128,6 @@ func TestConfig_SetFlags(t *testing.T) { }) }) t.Run("Test_barrier.enabled", func(t *testing.T) { - t.Run("DefaultValue", func(t *testing.T) { - // Test that default value is set properly - if vBool, err := cmdFlags.GetBool("barrier.enabled"); err == nil { - assert.Equal(t, 
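// mustJsonMarshal complements the existing mustMarshalJSON helper in these
// generated *_flags.go files: it accepts any value rather than only
// json.Marshaler implementations, and panics on failure so a malformed default
// surfaces at startup instead of silently registering an empty flag value. A
// hedged usage sketch; the flag name and field are illustrative, not taken
// from the generated output:
//
//	// hypothetical registration of a struct-valued default as its JSON encoding:
//	cmdFlags.String("barrier", defaultConfig.mustJsonMarshal(defaultConfig.BarrierConfig),
//		"barrier config, encoded as JSON")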
bool(defaultConfig.BarrierConfig.Enabled), vBool) - } else { - assert.FailNow(t, err.Error()) - } - }) t.Run("Override", func(t *testing.T) { testValue := "1" @@ -166,14 +142,6 @@ func TestConfig_SetFlags(t *testing.T) { }) }) t.Run("Test_barrier.cache-size", func(t *testing.T) { - t.Run("DefaultValue", func(t *testing.T) { - // Test that default value is set properly - if vInt, err := cmdFlags.GetInt("barrier.cache-size"); err == nil { - assert.Equal(t, int(defaultConfig.BarrierConfig.CacheSize), vInt) - } else { - assert.FailNow(t, err.Error()) - } - }) t.Run("Override", func(t *testing.T) { testValue := "1" @@ -188,14 +156,6 @@ func TestConfig_SetFlags(t *testing.T) { }) }) t.Run("Test_barrier.cache-ttl", func(t *testing.T) { - t.Run("DefaultValue", func(t *testing.T) { - // Test that default value is set properly - if vString, err := cmdFlags.GetString("barrier.cache-ttl"); err == nil { - assert.Equal(t, string(defaultConfig.BarrierConfig.CacheTTL.String()), vString) - } else { - assert.FailNow(t, err.Error()) - } - }) t.Run("Override", func(t *testing.T) { testValue := defaultConfig.BarrierConfig.CacheTTL.String() @@ -210,14 +170,6 @@ func TestConfig_SetFlags(t *testing.T) { }) }) t.Run("Test_backoff.base-second", func(t *testing.T) { - t.Run("DefaultValue", func(t *testing.T) { - // Test that default value is set properly - if vInt, err := cmdFlags.GetInt("backoff.base-second"); err == nil { - assert.Equal(t, int(defaultConfig.BackOffConfig.BaseSecond), vInt) - } else { - assert.FailNow(t, err.Error()) - } - }) t.Run("Override", func(t *testing.T) { testValue := "1" @@ -232,14 +184,6 @@ func TestConfig_SetFlags(t *testing.T) { }) }) t.Run("Test_backoff.max-duration", func(t *testing.T) { - t.Run("DefaultValue", func(t *testing.T) { - // Test that default value is set properly - if vString, err := cmdFlags.GetString("backoff.max-duration"); err == nil { - assert.Equal(t, string(defaultConfig.BackOffConfig.MaxDuration.String()), vString) - } else { - assert.FailNow(t, err.Error()) - } - }) t.Run("Override", func(t *testing.T) { testValue := defaultConfig.BackOffConfig.MaxDuration.String() @@ -254,14 +198,6 @@ func TestConfig_SetFlags(t *testing.T) { }) }) t.Run("Test_maxLogMessageLength", func(t *testing.T) { - t.Run("DefaultValue", func(t *testing.T) { - // Test that default value is set properly - if vInt, err := cmdFlags.GetInt("maxLogMessageLength"); err == nil { - assert.Equal(t, int(defaultConfig.MaxErrorMessageLength), vInt) - } else { - assert.FailNow(t, err.Error()) - } - }) t.Run("Override", func(t *testing.T) { testValue := "1" diff --git a/pkg/controller/nodes/task/handler.go b/pkg/controller/nodes/task/handler.go index 47af24e26f..2af6269ea5 100644 --- a/pkg/controller/nodes/task/handler.go +++ b/pkg/controller/nodes/task/handler.go @@ -6,6 +6,8 @@ import ( "runtime/debug" "time" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/recovery" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" @@ -738,7 +740,7 @@ func (t Handler) Finalize(ctx context.Context, nCtx handler.NodeExecutionContext }() } -func New(ctx context.Context, kubeClient executors.Client, client catalog.Client, scope promutils.Scope) (*Handler, error) { +func New(ctx context.Context, kubeClient executors.Client, client catalog.Client, recoveryClient recovery.Client, scope promutils.Scope) (*Handler, error) { // TODO New should take a pointer async, err := catalog.NewAsyncClient(client, *catalog.GetConfig(), 
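// The signature change to New threads the recovery client through the task
// handler's construction, mirroring how the kube and catalog clients are
// injected; the tests below substitute &mocks2.RecoveryClient{}. Presumably the
// handler keeps it for deciding whether a task's prior result can be recovered
// instead of re-run (a sketch; the struct layout is an assumption, since the
// field addition is not shown in this hunk):
//
//	type Handler struct {
//		// ...existing fields...
//		recoveryClient recovery.Client // assumed field; consulted on recovery executions
//	}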
scope.NewSubScope("async_catalog")) if err != nil { diff --git a/pkg/controller/nodes/task/handler_test.go b/pkg/controller/nodes/task/handler_test.go index 0538c322ec..b8c1958736 100644 --- a/pkg/controller/nodes/task/handler_test.go +++ b/pkg/controller/nodes/task/handler_test.go @@ -7,6 +7,8 @@ import ( "testing" "time" + mocks2 "github.com/flyteorg/flytepropeller/pkg/controller/nodes/recovery/mocks" + pluginK8sMocks "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/k8s/mocks" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" @@ -237,7 +239,7 @@ func Test_task_Setup(t *testing.T) { sCtx.On("EnqueueOwner").Return(pluginCore.EnqueueOwner(func(name types.NamespacedName) error { return nil })) sCtx.On("MetricsScope").Return(promutils.NewTestScope()) - tk, err := New(context.TODO(), mocks.NewFakeKubeClient(), &pluginCatalogMocks.Client{}, promutils.NewTestScope()) + tk, err := New(context.TODO(), mocks.NewFakeKubeClient(), &pluginCatalogMocks.Client{}, &mocks2.RecoveryClient{}, promutils.NewTestScope()) tk.cfg.TaskPlugins.EnabledPlugins = tt.enabledPlugins tk.cfg.TaskPlugins.DefaultForTaskTypes = tt.defaultForTaskTypes assert.NoError(t, err) @@ -896,7 +898,7 @@ func Test_task_Handle_Catalog(t *testing.T) { } else { c.OnPutMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(catalog.NewStatus(core.CatalogCacheStatus_CACHE_POPULATED, nil), nil) } - tk, err := New(context.TODO(), mocks.NewFakeKubeClient(), c, promutils.NewTestScope()) + tk, err := New(context.TODO(), mocks.NewFakeKubeClient(), c, &mocks2.RecoveryClient{}, promutils.NewTestScope()) assert.NoError(t, err) tk.defaultPlugins = map[pluginCore.TaskType]pluginCore.Plugin{ "test": fakeplugins.NewPhaseBasedPlugin(), @@ -1187,7 +1189,7 @@ func Test_task_Handle_Barrier(t *testing.T) { nCtx := createNodeContext(ev, "test", state, tt.args.prevTick) c := &pluginCatalogMocks.Client{} - tk, err := New(context.TODO(), mocks.NewFakeKubeClient(), c, promutils.NewTestScope()) + tk, err := New(context.TODO(), mocks.NewFakeKubeClient(), c, &mocks2.RecoveryClient{}, promutils.NewTestScope()) assert.NoError(t, err) tk.resourceManager = noopRm @@ -1662,7 +1664,7 @@ func Test_task_Finalize(t *testing.T) { } func TestNew(t *testing.T) { - got, err := New(context.TODO(), mocks.NewFakeKubeClient(), &pluginCatalogMocks.Client{}, promutils.NewTestScope()) + got, err := New(context.TODO(), mocks.NewFakeKubeClient(), &pluginCatalogMocks.Client{}, &mocks2.RecoveryClient{}, promutils.NewTestScope()) assert.NoError(t, err) assert.NotNil(t, got) assert.NotNil(t, got.defaultPlugins) diff --git a/pkg/controller/nodes/task/resourcemanager/config/config_flags.go b/pkg/controller/nodes/task/resourcemanager/config/config_flags.go index 61871d8f8b..b83ff310a4 100755 --- a/pkg/controller/nodes/task/resourcemanager/config/config_flags.go +++ b/pkg/controller/nodes/task/resourcemanager/config/config_flags.go @@ -28,6 +28,15 @@ func (Config) elemValueOrNil(v interface{}) interface{} { return v } +func (Config) mustJsonMarshal(v interface{}) string { + raw, err := json.Marshal(v) + if err != nil { + panic(err) + } + + return string(raw) +} + func (Config) mustMarshalJSON(v json.Marshaler) string { raw, err := v.MarshalJSON() if err != nil { diff --git a/pkg/controller/nodes/task/resourcemanager/config/config_flags_test.go b/pkg/controller/nodes/task/resourcemanager/config/config_flags_test.go index 28a15aff02..5ed6a99730 100755 --- a/pkg/controller/nodes/task/resourcemanager/config/config_flags_test.go +++ 
b/pkg/controller/nodes/task/resourcemanager/config/config_flags_test.go @@ -84,7 +84,7 @@ func testDecodeJson_Config(t *testing.T, val, result interface{}) { assert.NoError(t, decode_Config(val, result)) } -func testDecodeSlice_Config(t *testing.T, vStringSlice, result interface{}) { +func testDecodeRaw_Config(t *testing.T, vStringSlice, result interface{}) { assert.NoError(t, decode_Config(vStringSlice, result)) } @@ -100,14 +100,6 @@ func TestConfig_SetFlags(t *testing.T) { assert.True(t, cmdFlags.HasFlags()) t.Run("Test_type", func(t *testing.T) { - t.Run("DefaultValue", func(t *testing.T) { - // Test that default value is set properly - if vString, err := cmdFlags.GetString("type"); err == nil { - assert.Equal(t, string(defaultConfig.Type), vString) - } else { - assert.FailNow(t, err.Error()) - } - }) t.Run("Override", func(t *testing.T) { testValue := "1" @@ -122,14 +114,6 @@ func TestConfig_SetFlags(t *testing.T) { }) }) t.Run("Test_resourceMaxQuota", func(t *testing.T) { - t.Run("DefaultValue", func(t *testing.T) { - // Test that default value is set properly - if vInt, err := cmdFlags.GetInt("resourceMaxQuota"); err == nil { - assert.Equal(t, int(defaultConfig.ResourceMaxQuota), vInt) - } else { - assert.FailNow(t, err.Error()) - } - }) t.Run("Override", func(t *testing.T) { testValue := "1" @@ -144,21 +128,13 @@ func TestConfig_SetFlags(t *testing.T) { }) }) t.Run("Test_redis.hostPaths", func(t *testing.T) { - t.Run("DefaultValue", func(t *testing.T) { - // Test that default value is set properly - if vStringSlice, err := cmdFlags.GetStringSlice("redis.hostPaths"); err == nil { - assert.Equal(t, []string([]string{}), vStringSlice) - } else { - assert.FailNow(t, err.Error()) - } - }) t.Run("Override", func(t *testing.T) { testValue := join_Config("1,1", ",") cmdFlags.Set("redis.hostPaths", testValue) if vStringSlice, err := cmdFlags.GetStringSlice("redis.hostPaths"); err == nil { - testDecodeSlice_Config(t, join_Config(vStringSlice, ","), &actual.RedisConfig.HostPaths) + testDecodeRaw_Config(t, join_Config(vStringSlice, ","), &actual.RedisConfig.HostPaths) } else { assert.FailNow(t, err.Error()) @@ -166,14 +142,6 @@ func TestConfig_SetFlags(t *testing.T) { }) }) t.Run("Test_redis.primaryName", func(t *testing.T) { - t.Run("DefaultValue", func(t *testing.T) { - // Test that default value is set properly - if vString, err := cmdFlags.GetString("redis.primaryName"); err == nil { - assert.Equal(t, string(defaultConfig.RedisConfig.PrimaryName), vString) - } else { - assert.FailNow(t, err.Error()) - } - }) t.Run("Override", func(t *testing.T) { testValue := "1" @@ -188,14 +156,6 @@ func TestConfig_SetFlags(t *testing.T) { }) }) t.Run("Test_redis.hostPath", func(t *testing.T) { - t.Run("DefaultValue", func(t *testing.T) { - // Test that default value is set properly - if vString, err := cmdFlags.GetString("redis.hostPath"); err == nil { - assert.Equal(t, string(defaultConfig.RedisConfig.HostPath), vString) - } else { - assert.FailNow(t, err.Error()) - } - }) t.Run("Override", func(t *testing.T) { testValue := "1" @@ -210,14 +170,6 @@ func TestConfig_SetFlags(t *testing.T) { }) }) t.Run("Test_redis.hostKey", func(t *testing.T) { - t.Run("DefaultValue", func(t *testing.T) { - // Test that default value is set properly - if vString, err := cmdFlags.GetString("redis.hostKey"); err == nil { - assert.Equal(t, string(defaultConfig.RedisConfig.HostKey), vString) - } else { - assert.FailNow(t, err.Error()) - } - }) t.Run("Override", func(t *testing.T) { testValue := "1" @@ -232,14 +184,6 @@ func 
TestConfig_SetFlags(t *testing.T) { }) }) t.Run("Test_redis.maxRetries", func(t *testing.T) { - t.Run("DefaultValue", func(t *testing.T) { - // Test that default value is set properly - if vInt, err := cmdFlags.GetInt("redis.maxRetries"); err == nil { - assert.Equal(t, int(defaultConfig.RedisConfig.MaxRetries), vInt) - } else { - assert.FailNow(t, err.Error()) - } - }) t.Run("Override", func(t *testing.T) { testValue := "1" diff --git a/pkg/controller/nodes/task/secretmanager/config_flags.go b/pkg/controller/nodes/task/secretmanager/config_flags.go index 6cf99f8a97..cbe1087761 100755 --- a/pkg/controller/nodes/task/secretmanager/config_flags.go +++ b/pkg/controller/nodes/task/secretmanager/config_flags.go @@ -28,6 +28,15 @@ func (Config) elemValueOrNil(v interface{}) interface{} { return v } +func (Config) mustJsonMarshal(v interface{}) string { + raw, err := json.Marshal(v) + if err != nil { + panic(err) + } + + return string(raw) +} + func (Config) mustMarshalJSON(v json.Marshaler) string { raw, err := v.MarshalJSON() if err != nil { diff --git a/pkg/controller/nodes/task/secretmanager/config_flags_test.go b/pkg/controller/nodes/task/secretmanager/config_flags_test.go index 292d4eb02a..66a9dffc74 100755 --- a/pkg/controller/nodes/task/secretmanager/config_flags_test.go +++ b/pkg/controller/nodes/task/secretmanager/config_flags_test.go @@ -84,7 +84,7 @@ func testDecodeJson_Config(t *testing.T, val, result interface{}) { assert.NoError(t, decode_Config(val, result)) } -func testDecodeSlice_Config(t *testing.T, vStringSlice, result interface{}) { +func testDecodeRaw_Config(t *testing.T, vStringSlice, result interface{}) { assert.NoError(t, decode_Config(vStringSlice, result)) } @@ -100,14 +100,6 @@ func TestConfig_SetFlags(t *testing.T) { assert.True(t, cmdFlags.HasFlags()) t.Run("Test_secrets-prefix", func(t *testing.T) { - t.Run("DefaultValue", func(t *testing.T) { - // Test that default value is set properly - if vString, err := cmdFlags.GetString("secrets-prefix"); err == nil { - assert.Equal(t, string(defaultConfig.SecretFilePrefix), vString) - } else { - assert.FailNow(t, err.Error()) - } - }) t.Run("Override", func(t *testing.T) { testValue := "1" @@ -122,14 +114,6 @@ func TestConfig_SetFlags(t *testing.T) { }) }) t.Run("Test_env-prefix", func(t *testing.T) { - t.Run("DefaultValue", func(t *testing.T) { - // Test that default value is set properly - if vString, err := cmdFlags.GetString("env-prefix"); err == nil { - assert.Equal(t, string(defaultConfig.EnvironmentPrefix), vString) - } else { - assert.FailNow(t, err.Error()) - } - }) t.Run("Override", func(t *testing.T) { testValue := "1" diff --git a/pkg/controller/nodes/transformers.go b/pkg/controller/nodes/transformers.go index 129c3f4c48..21f8966de2 100644 --- a/pkg/controller/nodes/transformers.go +++ b/pkg/controller/nodes/transformers.go @@ -11,7 +11,6 @@ import ( "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io" "github.com/flyteorg/flytestdlib/logger" "github.com/golang/protobuf/ptypes" v1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -64,6 +63,8 @@ func ToNodeExecEventPhase(p handler.EPhase) core.NodeExecution_Phase { return core.NodeExecution_SUCCEEDED case handler.EPhaseFailed: return core.NodeExecution_FAILED + case handler.EPhaseRecovered: + return core.NodeExecution_RECOVERED default: return core.NodeExecution_UNDEFINED } @@ -71,7 +72,7 @@ func ToNodeExecEventPhase(p handler.EPhase) 
core.NodeExecution_Phase { func ToNodeExecutionEvent(nodeExecID *core.NodeExecutionIdentifier, info handler.PhaseInfo, - reader io.InputReader, + inputPath string, status v1alpha1.ExecutableNodeStatus, eventVersion v1alpha1.EventVersion, parentInfo executors.ImmutableParentInfo, @@ -96,7 +97,7 @@ func ToNodeExecutionEvent(nodeExecID *core.NodeExecutionIdentifier, nev := &event.NodeExecutionEvent{ Id: nodeExecID, Phase: phase, - InputUri: reader.GetInputPath().String(), + InputUri: inputPath, OccurredAt: occurredTime, } @@ -166,6 +167,8 @@ func ToNodePhase(p handler.EPhase) (v1alpha1.NodePhase, error) { return v1alpha1.NodePhaseFailing, nil case handler.EPhaseTimedout: return v1alpha1.NodePhaseTimingOut, nil + case handler.EPhaseRecovered: + return v1alpha1.NodePhaseRecovered, nil } return v1alpha1.NodePhaseNotYetStarted, fmt.Errorf("no known conversion from handlerPhase[%d] to NodePhase", p) } diff --git a/pkg/controller/workflow/executor_test.go b/pkg/controller/workflow/executor_test.go index 6bf7c5e188..b6f5fec2ab 100644 --- a/pkg/controller/workflow/executor_test.go +++ b/pkg/controller/workflow/executor_test.go @@ -46,6 +46,7 @@ import ( "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" "github.com/flyteorg/flytepropeller/pkg/controller/config" "github.com/flyteorg/flytepropeller/pkg/controller/nodes" + recoveryMocks "github.com/flyteorg/flytepropeller/pkg/controller/nodes/recovery/mocks" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/subworkflow/launchplan" ) @@ -230,10 +231,11 @@ func TestWorkflowExecutor_HandleFlyteWorkflow_Error(t *testing.T) { eventSink := events.NewMockEventSink() catalogClient, err := catalog.NewCatalogClient(ctx) assert.NoError(t, err) + recoveryClient := &recoveryMocks.RecoveryClient{} adminClient := launchplan.NewFailFastLaunchPlanExecutor() nodeExec, err := nodes.NewExecutor(ctx, config.GetConfig().NodeConfig, store, enqueueWorkflow, eventSink, adminClient, - adminClient, maxOutputSize, "s3://bucket", fakeKubeClient, catalogClient, promutils.NewTestScope()) + adminClient, maxOutputSize, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, promutils.NewTestScope()) assert.NoError(t, err) executor, err := NewExecutor(ctx, store, enqueueWorkflow, eventSink, recorder, "", nodeExec, promutils.NewTestScope()) assert.NoError(t, err) @@ -308,10 +310,11 @@ func TestWorkflowExecutor_HandleFlyteWorkflow(t *testing.T) { eventSink := events.NewMockEventSink() catalogClient, err := catalog.NewCatalogClient(ctx) assert.NoError(t, err) + recoveryClient := &recoveryMocks.RecoveryClient{} adminClient := launchplan.NewFailFastLaunchPlanExecutor() nodeExec, err := nodes.NewExecutor(ctx, config.GetConfig().NodeConfig, store, enqueueWorkflow, eventSink, adminClient, - adminClient, maxOutputSize, "s3://bucket", fakeKubeClient, catalogClient, promutils.NewTestScope()) + adminClient, maxOutputSize, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, promutils.NewTestScope()) assert.NoError(t, err) executor, err := NewExecutor(ctx, store, enqueueWorkflow, eventSink, recorder, "", nodeExec, promutils.NewTestScope()) @@ -371,9 +374,10 @@ func BenchmarkWorkflowExecutor(b *testing.B) { eventSink := events.NewMockEventSink() catalogClient, err := catalog.NewCatalogClient(ctx) assert.NoError(b, err) + recoveryClient := &recoveryMocks.RecoveryClient{} adminClient := launchplan.NewFailFastLaunchPlanExecutor() nodeExec, err := nodes.NewExecutor(ctx, config.GetConfig().NodeConfig, store, enqueueWorkflow, eventSink, adminClient, - adminClient, 
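// ToNodeExecutionEvent now receives the input URI as a plain string instead of
// an io.InputReader, removing the plugin I/O dependency from event
// construction (note the dropped pluginmachinery/io import above). Callers now
// resolve the path themselves, e.g. (remaining arguments elided):
//
//	// before: ToNodeExecutionEvent(nodeExecID, info, reader, ...)
//	// after:  ToNodeExecutionEvent(nodeExecID, info, reader.GetInputPath().String(), ...)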
maxOutputSize, "s3://bucket", fakeKubeClient, catalogClient, scope) + adminClient, maxOutputSize, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, scope) assert.NoError(b, err) executor, err := NewExecutor(ctx, store, enqueueWorkflow, eventSink, recorder, "", nodeExec, promutils.NewTestScope()) @@ -458,9 +462,10 @@ func TestWorkflowExecutor_HandleFlyteWorkflow_Failing(t *testing.T) { } catalogClient, err := catalog.NewCatalogClient(ctx) assert.NoError(t, err) + recoveryClient := &recoveryMocks.RecoveryClient{} adminClient := launchplan.NewFailFastLaunchPlanExecutor() nodeExec, err := nodes.NewExecutor(ctx, config.GetConfig().NodeConfig, store, enqueueWorkflow, eventSink, adminClient, - adminClient, maxOutputSize, "s3://bucket", fakeKubeClient, catalogClient, promutils.NewTestScope()) + adminClient, maxOutputSize, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, promutils.NewTestScope()) assert.NoError(t, err) executor, err := NewExecutor(ctx, store, enqueueWorkflow, eventSink, recorder, "", nodeExec, promutils.NewTestScope()) assert.NoError(t, err) @@ -551,8 +556,9 @@ func TestWorkflowExecutor_HandleFlyteWorkflow_Events(t *testing.T) { catalogClient, err := catalog.NewCatalogClient(ctx) assert.NoError(t, err) adminClient := launchplan.NewFailFastLaunchPlanExecutor() + recoveryClient := &recoveryMocks.RecoveryClient{} nodeExec, err := nodes.NewExecutor(ctx, config.GetConfig().NodeConfig, store, enqueueWorkflow, eventSink, adminClient, - adminClient, maxOutputSize, "s3://bucket", fakeKubeClient, catalogClient, promutils.NewTestScope()) + adminClient, maxOutputSize, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, promutils.NewTestScope()) assert.NoError(t, err) executor, err := NewExecutor(ctx, store, enqueueWorkflow, eventSink, recorder, "metadata", nodeExec, promutils.NewTestScope()) assert.NoError(t, err) @@ -605,10 +611,11 @@ func TestWorkflowExecutor_HandleFlyteWorkflow_EventFailure(t *testing.T) { nodeEventSink := events.NewMockEventSink() catalogClient, err := catalog.NewCatalogClient(ctx) assert.NoError(t, err) + recoveryClient := &recoveryMocks.RecoveryClient{} adminClient := launchplan.NewFailFastLaunchPlanExecutor() nodeExec, err := nodes.NewExecutor(ctx, config.GetConfig().NodeConfig, store, enqueueWorkflow, nodeEventSink, adminClient, - adminClient, maxOutputSize, "s3://bucket", fakeKubeClient, catalogClient, promutils.NewTestScope()) + adminClient, maxOutputSize, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, promutils.NewTestScope()) assert.NoError(t, err) t.Run("EventAlreadyInTerminalStateError", func(t *testing.T) { diff --git a/pkg/controller/workflowstore/config_flags.go b/pkg/controller/workflowstore/config_flags.go index c62d91a0d3..38a0d18263 100755 --- a/pkg/controller/workflowstore/config_flags.go +++ b/pkg/controller/workflowstore/config_flags.go @@ -28,6 +28,15 @@ func (Config) elemValueOrNil(v interface{}) interface{} { return v } +func (Config) mustJsonMarshal(v interface{}) string { + raw, err := json.Marshal(v) + if err != nil { + panic(err) + } + + return string(raw) +} + func (Config) mustMarshalJSON(v json.Marshaler) string { raw, err := v.MarshalJSON() if err != nil { diff --git a/pkg/controller/workflowstore/config_flags_test.go b/pkg/controller/workflowstore/config_flags_test.go index e70cd4659f..ec6a643ecf 100755 --- a/pkg/controller/workflowstore/config_flags_test.go +++ b/pkg/controller/workflowstore/config_flags_test.go @@ -84,7 +84,7 @@ func testDecodeJson_Config(t *testing.T, val, result 
interface{}) { assert.NoError(t, decode_Config(val, result)) } -func testDecodeSlice_Config(t *testing.T, vStringSlice, result interface{}) { +func testDecodeRaw_Config(t *testing.T, vStringSlice, result interface{}) { assert.NoError(t, decode_Config(vStringSlice, result)) } @@ -100,14 +100,6 @@ func TestConfig_SetFlags(t *testing.T) { assert.True(t, cmdFlags.HasFlags()) t.Run("Test_policy", func(t *testing.T) { - t.Run("DefaultValue", func(t *testing.T) { - // Test that default value is set properly - if vString, err := cmdFlags.GetString("policy"); err == nil { - assert.Equal(t, string(defaultConfig.Policy), vString) - } else { - assert.FailNow(t, err.Error()) - } - }) t.Run("Override", func(t *testing.T) { testValue := "1" diff --git a/pkg/webhook/config/config_flags.go b/pkg/webhook/config/config_flags.go index 7ea3dd278a..c93c9c2cc7 100755 --- a/pkg/webhook/config/config_flags.go +++ b/pkg/webhook/config/config_flags.go @@ -28,6 +28,15 @@ func (Config) elemValueOrNil(v interface{}) interface{} { return v } +func (Config) mustJsonMarshal(v interface{}) string { + raw, err := json.Marshal(v) + if err != nil { + panic(err) + } + + return string(raw) +} + func (Config) mustMarshalJSON(v json.Marshaler) string { raw, err := v.MarshalJSON() if err != nil { diff --git a/pkg/webhook/config/config_flags_test.go b/pkg/webhook/config/config_flags_test.go index 441ca3b2d3..a539e5edd5 100755 --- a/pkg/webhook/config/config_flags_test.go +++ b/pkg/webhook/config/config_flags_test.go @@ -84,7 +84,7 @@ func testDecodeJson_Config(t *testing.T, val, result interface{}) { assert.NoError(t, decode_Config(val, result)) } -func testDecodeSlice_Config(t *testing.T, vStringSlice, result interface{}) { +func testDecodeRaw_Config(t *testing.T, vStringSlice, result interface{}) { assert.NoError(t, decode_Config(vStringSlice, result)) } @@ -100,14 +100,6 @@ func TestConfig_SetFlags(t *testing.T) { assert.True(t, cmdFlags.HasFlags()) t.Run("Test_metrics-prefix", func(t *testing.T) { - t.Run("DefaultValue", func(t *testing.T) { - // Test that default value is set properly - if vString, err := cmdFlags.GetString("metrics-prefix"); err == nil { - assert.Equal(t, string(DefaultConfig.MetricsPrefix), vString) - } else { - assert.FailNow(t, err.Error()) - } - }) t.Run("Override", func(t *testing.T) { testValue := "1" @@ -122,14 +114,6 @@ func TestConfig_SetFlags(t *testing.T) { }) }) t.Run("Test_certDir", func(t *testing.T) { - t.Run("DefaultValue", func(t *testing.T) { - // Test that default value is set properly - if vString, err := cmdFlags.GetString("certDir"); err == nil { - assert.Equal(t, string(DefaultConfig.CertDir), vString) - } else { - assert.FailNow(t, err.Error()) - } - }) t.Run("Override", func(t *testing.T) { testValue := "1" @@ -144,14 +128,6 @@ func TestConfig_SetFlags(t *testing.T) { }) }) t.Run("Test_listenPort", func(t *testing.T) { - t.Run("DefaultValue", func(t *testing.T) { - // Test that default value is set properly - if vInt, err := cmdFlags.GetInt("listenPort"); err == nil { - assert.Equal(t, int(DefaultConfig.ListenPort), vInt) - } else { - assert.FailNow(t, err.Error()) - } - }) t.Run("Override", func(t *testing.T) { testValue := "1" @@ -166,14 +142,6 @@ func TestConfig_SetFlags(t *testing.T) { }) }) t.Run("Test_serviceName", func(t *testing.T) { - t.Run("DefaultValue", func(t *testing.T) { - // Test that default value is set properly - if vString, err := cmdFlags.GetString("serviceName"); err == nil { - assert.Equal(t, string(DefaultConfig.ServiceName), vString) - } else { - 
assert.FailNow(t, err.Error()) - } - }) t.Run("Override", func(t *testing.T) { testValue := "1" @@ -188,14 +156,6 @@ func TestConfig_SetFlags(t *testing.T) { }) }) t.Run("Test_secretName", func(t *testing.T) { - t.Run("DefaultValue", func(t *testing.T) { - // Test that default value is set properly - if vString, err := cmdFlags.GetString("secretName"); err == nil { - assert.Equal(t, string(DefaultConfig.SecretName), vString) - } else { - assert.FailNow(t, err.Error()) - } - }) t.Run("Override", func(t *testing.T) { testValue := "1" @@ -210,14 +170,6 @@ func TestConfig_SetFlags(t *testing.T) { }) }) t.Run("Test_awsSecretManager.sidecarImage", func(t *testing.T) { - t.Run("DefaultValue", func(t *testing.T) { - // Test that default value is set properly - if vString, err := cmdFlags.GetString("awsSecretManager.sidecarImage"); err == nil { - assert.Equal(t, string(DefaultConfig.AWSSecretManagerConfig.SidecarImage), vString) - } else { - assert.FailNow(t, err.Error()) - } - }) t.Run("Override", func(t *testing.T) { testValue := "1" diff --git a/pkg/webhook/mocks/secrets_injector.go b/pkg/webhook/mocks/secrets_injector.go index 647e17c774..22b8ba6ad5 100644 --- a/pkg/webhook/mocks/secrets_injector.go +++ b/pkg/webhook/mocks/secrets_injector.go @@ -4,9 +4,11 @@ package mocks import ( context "context" - "github.com/flyteorg/flytepropeller/pkg/webhook/config" + + config "github.com/flyteorg/flytepropeller/pkg/webhook/config" core "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + mock "github.com/stretchr/testify/mock" v1 "k8s.io/api/core/v1"
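// Taken together, the transformer hunks above thread the recovered state
// through both phase conversions in pkg/controller/nodes/transformers.go, so a
// recovered node is reported consistently in the CRD status and in emitted
// events:
//
//	phase, err := ToNodePhase(handler.EPhaseRecovered)       // v1alpha1.NodePhaseRecovered, nil
//	evPhase := ToNodeExecEventPhase(handler.EPhaseRecovered) // core.NodeExecution_RECOVERED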