From 0a4ac6a4595dce85c1fe89f86a36677ff82a744d Mon Sep 17 00:00:00 2001 From: Katrina Rogan Date: Mon, 26 Jul 2021 20:58:57 -0700 Subject: [PATCH] Implement recovery endpoint (#220) --- go.mod | 8 +- go.sum | 16 +- pkg/manager/impl/execution_manager.go | 55 +++++ pkg/manager/impl/execution_manager_test.go | 223 ++++++++++++++++++ pkg/manager/interfaces/execution.go | 5 + pkg/manager/mocks/execution.go | 11 + pkg/rpc/adminservice/execution.go | 25 ++ pkg/rpc/adminservice/metrics.go | 2 + pkg/rpc/adminservice/tests/execution_test.go | 63 +++++ pkg/workflowengine/impl/propeller_executor.go | 9 +- .../impl/propeller_executor_test.go | 23 +- pkg/workflowengine/interfaces/executor.go | 1 + 12 files changed, 424 insertions(+), 17 deletions(-) diff --git a/go.mod b/go.mod index 52d41d598..f8fbfb4d9 100644 --- a/go.mod +++ b/go.mod @@ -16,10 +16,10 @@ require ( github.com/coreos/go-oidc v2.2.1+incompatible github.com/dgrijalva/jwt-go v3.2.0+incompatible github.com/evanphx/json-patch v4.9.0+incompatible - github.com/flyteorg/flyteidl v0.19.5 - github.com/flyteorg/flyteplugins v0.5.56 - github.com/flyteorg/flytepropeller v0.12.9 - github.com/flyteorg/flytestdlib v0.3.22 + github.com/flyteorg/flyteidl v0.19.14 + github.com/flyteorg/flyteplugins v0.5.59 + github.com/flyteorg/flytepropeller v0.13.3 + github.com/flyteorg/flytestdlib v0.3.27 github.com/ghodss/yaml v1.0.0 github.com/gofrs/uuid v4.0.0+incompatible // indirect github.com/gogo/protobuf v1.3.2 diff --git a/go.sum b/go.sum index 4ccd75e86..ff6d387d0 100644 --- a/go.sum +++ b/go.sum @@ -304,16 +304,16 @@ github.com/fatih/structtag v1.2.0/go.mod h1:mBJUNpUnHmRKrKlQQlmCrh5PuhftFbNv8Ys4 github.com/felixge/httpsnoop v1.0.1 h1:lvB5Jl89CsZtGIWuTcDM1E/vkVs49/Ml7JJe07l8SPQ= github.com/felixge/httpsnoop v1.0.1/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= github.com/flyteorg/flyteidl v0.19.2/go.mod h1:576W2ViEyjTpT+kEVHAGbrTP3HARNUZ/eCwrNPmdx9U= -github.com/flyteorg/flyteidl v0.19.5 h1:qNhNK6mhCTuOms7zJmBtog6bLQJhBj+iScf1IlHdqeg= -github.com/flyteorg/flyteidl v0.19.5/go.mod h1:576W2ViEyjTpT+kEVHAGbrTP3HARNUZ/eCwrNPmdx9U= -github.com/flyteorg/flyteplugins v0.5.56 h1:LF/dwMFJDSMEmOp8hd9rU4Et4oyn0K+LgMzcHOu/xrw= -github.com/flyteorg/flyteplugins v0.5.56/go.mod h1:Jp5WheQMI08luZmgcmcgyjtzakKH0tPws/t35DzpKUA= -github.com/flyteorg/flytepropeller v0.12.9 h1:ocxVxJlB8t7nP1fesJ20+4VCDM7oLF1ahqXC+E3sw2c= -github.com/flyteorg/flytepropeller v0.12.9/go.mod h1:DxQI+r+Yg6EAajDBmfKJqOjDBwiM4cJgfPSyWjiz2l0= +github.com/flyteorg/flyteidl v0.19.14 h1:OLg2eT9uYllcfMMjEZJoXQ+2WXcrNbUxD+yaCrz2AlI= +github.com/flyteorg/flyteidl v0.19.14/go.mod h1:576W2ViEyjTpT+kEVHAGbrTP3HARNUZ/eCwrNPmdx9U= +github.com/flyteorg/flyteplugins v0.5.59 h1:Uw1xlrlx5rSTpdTMwJTo7mbqHI7X7p7CFVm3473iRjo= +github.com/flyteorg/flyteplugins v0.5.59/go.mod h1:nesnW7pJhXEysFQg9TnSp36ao33ie0oA/TI4sYPaeyw= +github.com/flyteorg/flytepropeller v0.13.3 h1:nnO4d9w6UbgLCF9kn0M6LTkYpS/F5jEoEF22YcRmLYI= +github.com/flyteorg/flytepropeller v0.13.3/go.mod h1:c+OOw8L7h1/IaxoiRZ1Hmhenlc1dxIT23yzhFETRgXI= github.com/flyteorg/flytestdlib v0.3.13/go.mod h1:Tz8JCECAbX6VWGwFT6cmEQ+RJpZ/6L9pswu3fzWs220= -github.com/flyteorg/flytestdlib v0.3.17/go.mod h1:VlbQuHTE+z2N5qusfwi+6WEkeJoqr8Q0E4NtBAsdwkU= -github.com/flyteorg/flytestdlib v0.3.22 h1:nJEPaCdxzXBaeg2p4fdo3I3Ua09NedFRaUwuLafLEdw= github.com/flyteorg/flytestdlib v0.3.22/go.mod h1:1XG0DwYTUm34Yrffm1Qy9Tdr/pWQypEqTq5dUxw3/cM= +github.com/flyteorg/flytestdlib v0.3.27 h1:d3OI5qb5u8CkSs2HMTuM62K5GuTrf6FJKq8CHW6Ymbs= +github.com/flyteorg/flytestdlib v0.3.27/go.mod h1:7cDWkY3v7xsoesFcDdu6DSW5Q2U2W5KlHUbUHSwBG1Q= github.com/fogleman/gg v1.2.1-0.20190220221249-0403632d5b90/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k= github.com/fogleman/gg v1.3.0/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k= github.com/form3tech-oss/jwt-go v3.2.2+incompatible h1:TcekIExNqud5crz4xD2pavyTgWiPvpYe4Xau31I0PRk= diff --git a/pkg/manager/impl/execution_manager.go b/pkg/manager/impl/execution_manager.go index cfe757b05..ac850904b 100644 --- a/pkg/manager/impl/execution_manager.go +++ b/pkg/manager/impl/execution_manager.go @@ -753,6 +753,10 @@ func (m *ExecutionManager) launchExecutionAndPrepareModel( if overrides != nil { executeWorkflowInputs.TaskPluginOverrides = overrides } + if request.Spec.Metadata != nil && request.Spec.Metadata.ReferenceExecution != nil && + request.Spec.Metadata.Mode == admin.ExecutionMetadata_RECOVERED { + executeWorkflowInputs.RecoveryExecution = request.Spec.Metadata.ReferenceExecution + } execInfo, err := m.workflowExecutor.ExecuteWorkflow(ctx, executeWorkflowInputs) if err != nil { @@ -901,6 +905,57 @@ func (m *ExecutionManager) RelaunchExecution( }, nil } +func (m *ExecutionManager) RecoverExecution( + ctx context.Context, request admin.ExecutionRecoverRequest, requestedAt time.Time) ( + *admin.ExecutionCreateResponse, error) { + existingExecutionModel, err := util.GetExecutionModel(ctx, m.db, *request.Id) + if err != nil { + logger.Debugf(ctx, "Failed to get execution model for request [%+v] with err %v", request, err) + return nil, err + } + existingExecution, err := transformers.FromExecutionModel(*existingExecutionModel) + if err != nil { + return nil, err + } + + executionSpec := existingExecution.Spec + if executionSpec.Metadata == nil { + executionSpec.Metadata = &admin.ExecutionMetadata{} + } + var inputs *core.LiteralMap + if len(existingExecutionModel.UserInputsURI) > 0 { + inputs = &core.LiteralMap{} + if err := m.storageClient.ReadProtobuf(ctx, existingExecutionModel.UserInputsURI, inputs); err != nil { + return nil, err + } + } + if request.Metadata != nil { + executionSpec.Metadata.ParentNodeExecution = request.Metadata.ParentNodeExecution + } + executionSpec.Metadata.Mode = admin.ExecutionMetadata_RECOVERED + executionSpec.Metadata.ReferenceExecution = existingExecution.Id + var executionModel *models.Execution + ctx, executionModel, err = m.launchExecutionAndPrepareModel(ctx, admin.ExecutionCreateRequest{ + Project: request.Id.Project, + Domain: request.Id.Domain, + Name: request.Name, + Spec: executionSpec, + Inputs: inputs, + }, requestedAt) + if err != nil { + return nil, err + } + executionModel.SourceExecutionID = existingExecutionModel.ID + workflowExecutionIdentifier, err := m.createExecutionModel(ctx, executionModel) + if err != nil { + return nil, err + } + logger.Infof(ctx, "Successfully recovered [%+v] as [%+v]", request.Id, workflowExecutionIdentifier) + return &admin.ExecutionCreateResponse{ + Id: workflowExecutionIdentifier, + }, nil +} + func (m *ExecutionManager) emitScheduledWorkflowMetrics( ctx context.Context, executionModel *models.Execution, runningEventTimeProto *timestamp.Timestamp) { if executionModel == nil || runningEventTimeProto == nil { diff --git a/pkg/manager/impl/execution_manager_test.go b/pkg/manager/impl/execution_manager_test.go index 41cd189cb..08d0ed660 100644 --- a/pkg/manager/impl/execution_manager_test.go +++ b/pkg/manager/impl/execution_manager_test.go @@ -950,6 +950,229 @@ func TestRelaunchExecution_CreateFailure(t *testing.T) { assert.EqualError(t, err, expectedErr.Error()) } +func TestRecoverExecution(t *testing.T) { + // Set up mocks. + repository := getMockRepositoryForExecTest() + setDefaultLpCallbackForExecTest(repository) + execManager := NewExecutionManager(repository, getMockExecutionsConfigProvider(), getMockStorageForExecTest(context.Background()), workflowengineMocks.NewMockExecutor(), mockScope.NewTestScope(), mockScope.NewTestScope(), &mockPublisher, mockExecutionRemoteURL, nil, nil, nil, &eventWriterMocks.WorkflowExecutionEventWriter{}) + startTime := time.Now() + startTimeProto, _ := ptypes.TimestampProto(startTime) + existingClosure := admin.ExecutionClosure{ + Phase: core.WorkflowExecution_SUCCEEDED, + StartedAt: startTimeProto, + } + existingClosureBytes, _ := proto.Marshal(&existingClosure) + executionGetFunc := makeExecutionGetFunc(t, existingClosureBytes, &startTime) + repository.ExecutionRepo().(*repositoryMocks.MockExecutionRepo).SetGetCallback(executionGetFunc) + + var createCalled bool + exCreateFunc := func(ctx context.Context, input models.Execution) error { + createCalled = true + assert.Equal(t, "recovered", input.Name) + assert.Equal(t, "domain", input.Domain) + assert.Equal(t, "project", input.Project) + assert.Equal(t, uint(8), input.SourceExecutionID) + var spec admin.ExecutionSpec + err := proto.Unmarshal(input.Spec, &spec) + assert.Nil(t, err) + assert.Equal(t, admin.ExecutionMetadata_RECOVERED, spec.Metadata.Mode) + assert.Equal(t, int32(admin.ExecutionMetadata_RECOVERED), input.Mode) + return nil + } + repository.ExecutionRepo().(*repositoryMocks.MockExecutionRepo).SetCreateCallback(exCreateFunc) + + // Issue request. + response, err := execManager.RecoverExecution(context.Background(), admin.ExecutionRecoverRequest{ + Id: &core.WorkflowExecutionIdentifier{ + Project: "project", + Domain: "domain", + Name: "name", + }, + Name: "recovered", + }, requestedAt) + + // And verify response. + assert.Nil(t, err) + + expectedResponse := &admin.ExecutionCreateResponse{ + Id: &core.WorkflowExecutionIdentifier{ + Project: "project", + Domain: "domain", + Name: "recovered", + }, + } + assert.True(t, createCalled) + assert.True(t, proto.Equal(expectedResponse, response)) +} + +func TestRecoverExecution_RecoveredChildNode(t *testing.T) { + repository := getMockRepositoryForExecTest() + setDefaultLpCallbackForExecTest(repository) + execManager := NewExecutionManager(repository, getMockExecutionsConfigProvider(), getMockStorageForExecTest(context.Background()), workflowengineMocks.NewMockExecutor(), mockScope.NewTestScope(), mockScope.NewTestScope(), &mockPublisher, mockExecutionRemoteURL, nil, nil, nil, &eventWriterMocks.WorkflowExecutionEventWriter{}) + startTime := time.Now() + startTimeProto, _ := ptypes.TimestampProto(startTime) + existingClosure := admin.ExecutionClosure{ + Phase: core.WorkflowExecution_SUCCEEDED, + StartedAt: startTimeProto, + } + existingClosureBytes, _ := proto.Marshal(&existingClosure) + referencedExecutionID := uint(123) + ignoredExecutionID := uint(456) + repository.ExecutionRepo().(*repositoryMocks.MockExecutionRepo).SetGetCallback(func(ctx context.Context, input interfaces.Identifier) (models.Execution, error) { + switch input.Name { + case "name": + return models.Execution{ + Spec: specBytes, + Closure: existingClosureBytes, + BaseModel: models.BaseModel{ + ID: referencedExecutionID, + }, + }, nil + case "orig": + return models.Execution{ + BaseModel: models.BaseModel{ + ID: ignoredExecutionID, + }, + }, nil + default: + return models.Execution{}, flyteAdminErrors.NewFlyteAdminErrorf(codes.InvalidArgument, "unexpected get for execution %s", input.Name) + } + }) + + parentNodeDatabaseID := uint(12345) + var createCalled bool + exCreateFunc := func(ctx context.Context, input models.Execution) error { + createCalled = true + assert.Equal(t, "recovered", input.Name) + assert.Equal(t, "domain", input.Domain) + assert.Equal(t, "project", input.Project) + var spec admin.ExecutionSpec + err := proto.Unmarshal(input.Spec, &spec) + assert.Nil(t, err) + assert.Equal(t, admin.ExecutionMetadata_RECOVERED, spec.Metadata.Mode) + assert.Equal(t, int32(admin.ExecutionMetadata_RECOVERED), input.Mode) + assert.Equal(t, parentNodeDatabaseID, input.ParentNodeExecutionID) + assert.Equal(t, referencedExecutionID, input.SourceExecutionID) + + return nil + } + repository.ExecutionRepo().(*repositoryMocks.MockExecutionRepo).SetCreateCallback(exCreateFunc) + + parentNodeExecution := core.NodeExecutionIdentifier{ + ExecutionId: &core.WorkflowExecutionIdentifier{ + Project: "p", + Domain: "d", + Name: "orig", + }, + NodeId: "parent", + } + repository.NodeExecutionRepo().(*repositoryMocks.MockNodeExecutionRepo).SetGetCallback(func(ctx context.Context, input interfaces.NodeExecutionResource) (models.NodeExecution, error) { + assert.True(t, proto.Equal(&parentNodeExecution, &input.NodeExecutionIdentifier)) + + return models.NodeExecution{ + BaseModel: models.BaseModel{ + ID: parentNodeDatabaseID, + }, + }, nil + }) + + // Issue request. + response, err := execManager.RecoverExecution(context.Background(), admin.ExecutionRecoverRequest{ + Id: &core.WorkflowExecutionIdentifier{ + Project: "project", + Domain: "domain", + Name: "name", + }, + Name: "recovered", + Metadata: &admin.ExecutionMetadata{ + ParentNodeExecution: &parentNodeExecution, + }, + }, requestedAt) + + // And verify response. + assert.Nil(t, err) + + expectedResponse := &admin.ExecutionCreateResponse{ + Id: &core.WorkflowExecutionIdentifier{ + Project: "project", + Domain: "domain", + Name: "recovered", + }, + } + assert.True(t, createCalled) + assert.True(t, proto.Equal(expectedResponse, response)) +} + +func TestRecoverExecution_GetExistingFailure(t *testing.T) { + // Set up mocks. + repository := getMockRepositoryForExecTest() + setDefaultLpCallbackForExecTest(repository) + execManager := NewExecutionManager(repository, getMockExecutionsConfigProvider(), getMockStorageForExecTest(context.Background()), workflowengineMocks.NewMockExecutor(), mockScope.NewTestScope(), mockScope.NewTestScope(), &mockPublisher, mockExecutionRemoteURL, nil, nil, nil, &eventWriterMocks.WorkflowExecutionEventWriter{}) + + expectedErr := errors.New("expected error") + repository.ExecutionRepo().(*repositoryMocks.MockExecutionRepo).SetGetCallback( + func(ctx context.Context, input interfaces.Identifier) (models.Execution, error) { + return models.Execution{}, expectedErr + }) + + var createCalled bool + repository.ExecutionRepo().(*repositoryMocks.MockExecutionRepo).SetCreateCallback( + func(ctx context.Context, input models.Execution) error { + createCalled = true + return nil + }) + + // Issue request. + _, err := execManager.RecoverExecution(context.Background(), admin.ExecutionRecoverRequest{ + Id: &core.WorkflowExecutionIdentifier{ + Project: "project", + Domain: "domain", + Name: "name", + }, + Name: "recovered", + }, requestedAt) + + // And verify response. + assert.EqualError(t, err, expectedErr.Error()) + assert.False(t, createCalled) +} + +func TestRecoverExecution_GetExistingInputsFailure(t *testing.T) { + // Set up mocks. + repository := getMockRepositoryForExecTest() + setDefaultLpCallbackForExecTest(repository) + + expectedErr := errors.New("foo") + mockStorage := commonMocks.GetMockStorageClient() + mockStorage.ComposedProtobufStore.(*commonMocks.TestDataStore).ReadProtobufCb = func( + ctx context.Context, reference storage.DataReference, msg proto.Message) error { + return expectedErr + } + execManager := NewExecutionManager(repository, getMockExecutionsConfigProvider(), mockStorage, workflowengineMocks.NewMockExecutor(), mockScope.NewTestScope(), mockScope.NewTestScope(), &mockPublisher, mockExecutionRemoteURL, nil, nil, nil, &eventWriterMocks.WorkflowExecutionEventWriter{}) + startTime := time.Now() + startTimeProto, _ := ptypes.TimestampProto(startTime) + existingClosure := admin.ExecutionClosure{ + Phase: core.WorkflowExecution_SUCCEEDED, + StartedAt: startTimeProto, + } + existingClosureBytes, _ := proto.Marshal(&existingClosure) + executionGetFunc := makeExecutionGetFunc(t, existingClosureBytes, &startTime) + repository.ExecutionRepo().(*repositoryMocks.MockExecutionRepo).SetGetCallback(executionGetFunc) + + // Issue request. + _, err := execManager.RecoverExecution(context.Background(), admin.ExecutionRecoverRequest{ + Id: &core.WorkflowExecutionIdentifier{ + Project: "project", + Domain: "domain", + Name: "name", + }, + Name: "recovered", + }, requestedAt) + + // And verify response. + assert.EqualError(t, err, "Unable to read WorkflowClosure from location s3://flyte/metadata/admin/remote closure id : foo") +} + func TestCreateWorkflowEvent(t *testing.T) { repository := repositoryMocks.NewMockRepository() startTime := time.Now() diff --git a/pkg/manager/interfaces/execution.go b/pkg/manager/interfaces/execution.go index 10d556ff1..8850bf148 100644 --- a/pkg/manager/interfaces/execution.go +++ b/pkg/manager/interfaces/execution.go @@ -13,6 +13,11 @@ type ExecutionInterface interface { *admin.ExecutionCreateResponse, error) RelaunchExecution(ctx context.Context, request admin.ExecutionRelaunchRequest, requestedAt time.Time) ( *admin.ExecutionCreateResponse, error) + // Recreates a previously-run workflow execution that will point to the original execution so that propeller will + // only start executing from the last known failure point. Propeller can recover individual workflow execution nodes + // which previously succeeded based on the recovery (original) workflow execution id. + RecoverExecution(ctx context.Context, request admin.ExecutionRecoverRequest, requestedAt time.Time) ( + *admin.ExecutionCreateResponse, error) CreateWorkflowEvent(ctx context.Context, request admin.WorkflowExecutionEventRequest) ( *admin.WorkflowExecutionEventResponse, error) GetExecution(ctx context.Context, request admin.WorkflowExecutionGetRequest) (*admin.Execution, error) diff --git a/pkg/manager/mocks/execution.go b/pkg/manager/mocks/execution.go index 9b90775e0..cfcc19d2c 100644 --- a/pkg/manager/mocks/execution.go +++ b/pkg/manager/mocks/execution.go @@ -13,6 +13,8 @@ type CreateExecutionFunc func( type RelaunchExecutionFunc func( ctx context.Context, request admin.ExecutionRelaunchRequest, requestedAt time.Time) ( *admin.ExecutionCreateResponse, error) +type RecoverExecutionFunc func(ctx context.Context, request admin.ExecutionRecoverRequest, requestedAt time.Time) ( + *admin.ExecutionCreateResponse, error) type CreateExecutionEventFunc func(ctx context.Context, request admin.WorkflowExecutionEventRequest) ( *admin.WorkflowExecutionEventResponse, error) type GetExecutionFunc func(ctx context.Context, request admin.WorkflowExecutionGetRequest) (*admin.Execution, error) @@ -25,6 +27,7 @@ type TerminateExecutionFunc func( type MockExecutionManager struct { createExecutionFunc CreateExecutionFunc relaunchExecutionFunc RelaunchExecutionFunc + RecoverExecutionFunc RecoverExecutionFunc createExecutionEventFunc CreateExecutionEventFunc getExecutionFunc GetExecutionFunc getExecutionDataFunc GetExecutionDataFunc @@ -62,6 +65,14 @@ func (m *MockExecutionManager) SetCreateEventCallback(createEventFunc CreateExec m.createExecutionEventFunc = createEventFunc } +func (m *MockExecutionManager) RecoverExecution(ctx context.Context, request admin.ExecutionRecoverRequest, requestedAt time.Time) ( + *admin.ExecutionCreateResponse, error) { + if m.RecoverExecutionFunc != nil { + return m.RecoverExecutionFunc(ctx, request, requestedAt) + } + return &admin.ExecutionCreateResponse{}, nil +} + func (m *MockExecutionManager) CreateWorkflowEvent( ctx context.Context, request admin.WorkflowExecutionEventRequest) (*admin.WorkflowExecutionEventResponse, error) { diff --git a/pkg/rpc/adminservice/execution.go b/pkg/rpc/adminservice/execution.go index 58e1276aa..adf7913ce 100644 --- a/pkg/rpc/adminservice/execution.go +++ b/pkg/rpc/adminservice/execution.go @@ -66,6 +66,31 @@ func (m *AdminService) RelaunchExecution( return response, nil } +func (m *AdminService) RecoverExecution( + ctx context.Context, request *admin.ExecutionRecoverRequest) (*admin.ExecutionCreateResponse, error) { + defer m.interceptPanic(ctx, request) + requestedAt := time.Now() + if request == nil { + return nil, status.Errorf(codes.InvalidArgument, "Incorrect request, nil requests not allowed") + } + var response *admin.ExecutionCreateResponse + var err error + m.Metrics.executionEndpointMetrics.recover.Time(func() { + response, err = m.ExecutionManager.RecoverExecution(ctx, *request, requestedAt) + }) + audit.NewLogBuilder().WithAuthenticatedCtx(ctx).WithRequest( + "ExecutionCreateRequest", + audit.ParametersFromExecutionIdentifier(request.Id), + audit.ReadWrite, + requestedAt, + ).WithResponse(time.Now(), err).Log(ctx) + if err != nil { + return nil, util.TransformAndRecordError(err, &m.Metrics.executionEndpointMetrics.relaunch) + } + m.Metrics.executionEndpointMetrics.relaunch.Success() + return response, nil +} + func (m *AdminService) CreateWorkflowEvent( ctx context.Context, request *admin.WorkflowExecutionEventRequest) (*admin.WorkflowExecutionEventResponse, error) { defer m.interceptPanic(ctx, request) diff --git a/pkg/rpc/adminservice/metrics.go b/pkg/rpc/adminservice/metrics.go index 160d27813..3de1401c5 100644 --- a/pkg/rpc/adminservice/metrics.go +++ b/pkg/rpc/adminservice/metrics.go @@ -12,6 +12,7 @@ type executionEndpointMetrics struct { create util.RequestMetrics relaunch util.RequestMetrics + recover util.RequestMetrics createEvent util.RequestMetrics get util.RequestMetrics getData util.RequestMetrics @@ -121,6 +122,7 @@ func InitMetrics(adminScope promutils.Scope) AdminMetrics { scope: adminScope, create: util.NewRequestMetrics(adminScope, "create_execution"), relaunch: util.NewRequestMetrics(adminScope, "relaunch_execution"), + recover: util.NewRequestMetrics(adminScope, "recover_execution"), createEvent: util.NewRequestMetrics(adminScope, "create_execution_event"), get: util.NewRequestMetrics(adminScope, "get_execution"), getData: util.NewRequestMetrics(adminScope, "get_execution_data"), diff --git a/pkg/rpc/adminservice/tests/execution_test.go b/pkg/rpc/adminservice/tests/execution_test.go index 40688918a..1f6f0d6cf 100644 --- a/pkg/rpc/adminservice/tests/execution_test.go +++ b/pkg/rpc/adminservice/tests/execution_test.go @@ -139,6 +139,69 @@ func TestRelaunchExecutionError(t *testing.T) { "missing entity of type execution with identifier ") } +func TestRecoverExecutionHappyCase(t *testing.T) { + ctx := context.Background() + + mockExecutionManager := mocks.MockExecutionManager{} + mockExecutionManager.RecoverExecutionFunc = + func(ctx context.Context, + request admin.ExecutionRecoverRequest, requestedAt time.Time) (*admin.ExecutionCreateResponse, error) { + return &admin.ExecutionCreateResponse{ + Id: &core.WorkflowExecutionIdentifier{ + Project: request.Id.Project, + Domain: request.Id.Domain, + Name: request.Name, + }, + }, nil + } + + mockServer := NewMockAdminServer(NewMockAdminServerInput{ + executionManager: &mockExecutionManager, + }) + + resp, err := mockServer.RecoverExecution(ctx, &admin.ExecutionRecoverRequest{ + Id: &core.WorkflowExecutionIdentifier{ + Project: "project", + Domain: "domain", + }, + Name: "name", + }) + assert.Equal(t, "project", resp.Id.Project) + assert.Equal(t, "domain", resp.Id.Domain) + assert.Equal(t, "name", resp.Id.Name) + assert.NoError(t, err) +} + +func TestRecoverExecutionError(t *testing.T) { + ctx := context.Background() + + mockExecutionManager := mocks.MockExecutionManager{} + mockExecutionManager.RecoverExecutionFunc = + func(ctx context.Context, + request admin.ExecutionRecoverRequest, requestedAt time.Time) (*admin.ExecutionCreateResponse, error) { + return nil, repoErrors.GetMissingEntityError("execution", request.Id) + } + mockServer := NewMockAdminServer(NewMockAdminServerInput{ + executionManager: &mockExecutionManager, + }) + + resp, err := mockServer.RecoverExecution(ctx, &admin.ExecutionRecoverRequest{ + Name: "Name", + }) + assert.Nil(t, resp) + assert.EqualError(t, err, + "missing entity of type execution with identifier ") +} + +func TestRecoverExecution_InvalidRequest(t *testing.T) { + ctx := context.Background() + mockServer := NewMockAdminServer(NewMockAdminServerInput{}) + resp, err := mockServer.RecoverExecution(ctx, nil) + assert.Nil(t, resp) + assert.EqualError(t, err, + "rpc error: code = InvalidArgument desc = Incorrect request, nil requests not allowed") +} + func TestCreateWorkflowEvent(t *testing.T) { phase := core.WorkflowExecution_RUNNING mockExecutionManager := mocks.MockExecutionManager{} diff --git a/pkg/workflowengine/impl/propeller_executor.go b/pkg/workflowengine/impl/propeller_executor.go index 1dbf639c3..029684dcb 100644 --- a/pkg/workflowengine/impl/propeller_executor.go +++ b/pkg/workflowengine/impl/propeller_executor.go @@ -86,9 +86,12 @@ func (c *FlytePropeller) addPermissions(auth *admin.AuthRole, flyteWf *v1alpha1. } func addExecutionOverrides(taskPluginOverrides []*admin.PluginOverride, - workflowExecutionConfig *admin.WorkflowExecutionConfig, flyteWf *v1alpha1.FlyteWorkflow) { + workflowExecutionConfig *admin.WorkflowExecutionConfig, recoveryExecution *core.WorkflowExecutionIdentifier, flyteWf *v1alpha1.FlyteWorkflow) { executionConfig := v1alpha1.ExecutionConfig{ TaskPluginImpls: make(map[string]v1alpha1.TaskPluginOverride), + RecoveryExecution: v1alpha1.WorkflowExecutionIdentifier{ + WorkflowExecutionIdentifier: recoveryExecution, + }, } for _, override := range taskPluginOverrides { executionConfig.TaskPluginImpls[override.TaskType] = v1alpha1.TaskPluginOverride{ @@ -137,7 +140,7 @@ func (c *FlytePropeller) ExecuteWorkflow(ctx context.Context, input interfaces.E flyteWf.WorkflowMeta = &v1alpha1.WorkflowMeta{} } flyteWf.WorkflowMeta.EventVersion = c.eventVersion - addExecutionOverrides(input.TaskPluginOverrides, input.ExecutionConfig, flyteWf) + addExecutionOverrides(input.TaskPluginOverrides, input.ExecutionConfig, input.RecoveryExecution, flyteWf) if input.Reference.Spec.RawOutputDataConfig != nil { flyteWf.RawOutputDataConfig = v1alpha1.RawOutputDataConfig{ @@ -220,7 +223,7 @@ func (c *FlytePropeller) ExecuteTask(ctx context.Context, input interfaces.Execu flyteWf.Labels = labels annotations := addMapValues(input.Annotations, flyteWf.Annotations) flyteWf.Annotations = annotations - addExecutionOverrides(input.TaskPluginOverrides, input.ExecutionConfig, flyteWf) + addExecutionOverrides(input.TaskPluginOverrides, input.ExecutionConfig, nil, flyteWf) /* TODO(katrogan): uncomment once propeller has updated the flyte workflow CRD. diff --git a/pkg/workflowengine/impl/propeller_executor_test.go b/pkg/workflowengine/impl/propeller_executor_test.go index 697fcd30b..3377cb5a4 100644 --- a/pkg/workflowengine/impl/propeller_executor_test.go +++ b/pkg/workflowengine/impl/propeller_executor_test.go @@ -6,6 +6,8 @@ import ( "testing" "time" + "github.com/golang/protobuf/proto" + interfaces2 "github.com/flyteorg/flyteadmin/pkg/executioncluster/interfaces" "github.com/flyteorg/flyteadmin/pkg/executioncluster" @@ -140,6 +142,11 @@ func TestExecuteWorkflowHappyCase(t *testing.T) { FlyteClient: &FakeK8FlyteClient{}, }, nil }) + recoveryNodeExecutionID := &core.WorkflowExecutionIdentifier{ + Project: "p", + Domain: "d", + Name: "original", + } fakeFlyteWorkflow := FakeFlyteWorkflow{ createCallback: func(workflow *v1alpha1.FlyteWorkflow, opts v1.CreateOptions) (*v1alpha1.FlyteWorkflow, error) { assert.EqualValues(t, map[string]string{ @@ -159,6 +166,7 @@ func TestExecuteWorkflowHappyCase(t *testing.T) { }, workflow.ExecutionConfig.TaskPluginImpls) assert.Empty(t, opts) assert.Equal(t, workflow.ServiceAccountName, testK8sServiceAccount) + assert.True(t, proto.Equal(recoveryNodeExecutionID, workflow.ExecutionConfig.RecoveryExecution.WorkflowExecutionIdentifier)) return nil, nil }, } @@ -206,6 +214,7 @@ func TestExecuteWorkflowHappyCase(t *testing.T) { AssumableIamRole: testRole, KubernetesServiceAccount: testK8sServiceAccount, }, + RecoveryExecution: recoveryNodeExecutionID, }) assert.Nil(t, err) assert.NotNil(t, execInfo) @@ -463,7 +472,7 @@ func TestAddExecutionOverrides(t *testing.T) { }, } workflow := &v1alpha1.FlyteWorkflow{} - addExecutionOverrides(overrides, nil, workflow) + addExecutionOverrides(overrides, nil, nil, workflow) assert.EqualValues(t, workflow.ExecutionConfig.TaskPluginImpls, map[string]v1alpha1.TaskPluginOverride{ "taskType1": { PluginIDs: []string{"Plugin1", "Plugin2"}, @@ -476,7 +485,17 @@ func TestAddExecutionOverrides(t *testing.T) { MaxParallelism: 100, } workflow := &v1alpha1.FlyteWorkflow{} - addExecutionOverrides(nil, workflowExecutionConfig, workflow) + addExecutionOverrides(nil, workflowExecutionConfig, nil, workflow) assert.EqualValues(t, workflow.ExecutionConfig.MaxParallelism, uint32(100)) }) + t.Run("recovery execution", func(t *testing.T) { + recoveryExecutionID := &core.WorkflowExecutionIdentifier{ + Project: "p", + Domain: "d", + Name: "n", + } + workflow := &v1alpha1.FlyteWorkflow{} + addExecutionOverrides(nil, nil, recoveryExecutionID, workflow) + assert.True(t, proto.Equal(recoveryExecutionID, workflow.ExecutionConfig.RecoveryExecution.WorkflowExecutionIdentifier)) + }) } diff --git a/pkg/workflowengine/interfaces/executor.go b/pkg/workflowengine/interfaces/executor.go index 30adb9690..7aa44d34a 100644 --- a/pkg/workflowengine/interfaces/executor.go +++ b/pkg/workflowengine/interfaces/executor.go @@ -21,6 +21,7 @@ type ExecuteWorkflowInput struct { TaskPluginOverrides []*admin.PluginOverride ExecutionConfig *admin.WorkflowExecutionConfig Auth *admin.AuthRole + RecoveryExecution *core.WorkflowExecutionIdentifier } type ExecuteTaskInput struct {