From 81ac07f63ff38da6724afdc7d6a027088135767e Mon Sep 17 00:00:00 2001 From: Katrina Rogan Date: Tue, 4 Jan 2022 16:13:13 -0800 Subject: [PATCH] Execution event cluster validation (#307) --- pkg/common/constants.go | 4 + pkg/errors/errors.go | 15 ++ pkg/errors/errors_test.go | 16 +++ pkg/manager/impl/execution_manager.go | 8 +- pkg/manager/impl/execution_manager_test.go | 37 +++++ pkg/manager/impl/node_execution_manager.go | 41 +++--- .../impl/node_execution_manager_test.go | 28 ++-- pkg/manager/impl/task_execution_manager.go | 4 + .../impl/task_execution_manager_test.go | 5 + .../impl/validation/shared_execution.go | 42 ++++++ .../impl/validation/shared_execution_test.go | 55 ++++++++ pkg/repositories/gormimpl/execution_repo.go | 18 --- .../gormimpl/execution_repo_test.go | 37 ----- pkg/repositories/interfaces/execution_repo.go | 2 - pkg/repositories/mocks/execution_repo.go | 8 -- pkg/repositories/transformers/execution.go | 49 ++++++- .../transformers/execution_test.go | 130 ++++++++++++++++++ tests/bootstrap.go | 1 + tests/execution_test.go | 4 +- 19 files changed, 397 insertions(+), 107 deletions(-) create mode 100644 pkg/manager/impl/validation/shared_execution.go create mode 100644 pkg/manager/impl/validation/shared_execution_test.go diff --git a/pkg/common/constants.go b/pkg/common/constants.go index 8e0d9717cd..6f95e5cdfd 100644 --- a/pkg/common/constants.go +++ b/pkg/common/constants.go @@ -11,3 +11,7 @@ const ( ) const MaxResponseStatusBytes = 32000 + +// DefaultProducerID is used in older versions of propeller which hard code this producer id. 
+// See https://github.com/flyteorg/flytepropeller/blob/eaf084934de5d630cd4c11aae15ecae780cc787e/pkg/controller/nodes/task/transformer.go#L114 +const DefaultProducerID = "propeller" diff --git a/pkg/errors/errors.go b/pkg/errors/errors.go index 46b0193a04..9cf9d5a2d3 100644 --- a/pkg/errors/errors.go +++ b/pkg/errors/errors.go @@ -91,3 +91,18 @@ func NewAlreadyInTerminalStateError(ctx context.Context, errorMsg string, curPha } return statusErr } + +func NewIncompatibleClusterError(ctx context.Context, errorMsg, curCluster string) FlyteAdminError { + statusErr, transformationErr := NewFlyteAdminError(codes.FailedPrecondition, errorMsg).WithDetails(&admin.EventFailureReason{ + Reason: &admin.EventFailureReason_IncompatibleCluster{ + IncompatibleCluster: &admin.EventErrorIncompatibleCluster{ + Cluster: curCluster, + }, + }, + }) + if transformationErr != nil { + logger.Panicf(ctx, "Failed to wrap grpc status in type 'Error': %v", transformationErr) + return NewFlyteAdminErrorf(codes.FailedPrecondition, errorMsg) + } + return statusErr +} diff --git a/pkg/errors/errors_test.go b/pkg/errors/errors_test.go index 239025d768..16d6d99fc9 100644 --- a/pkg/errors/errors_test.go +++ b/pkg/errors/errors_test.go @@ -27,3 +27,19 @@ func TestGrpcStatusError(t *testing.T) { _, ok = details.GetReason().(*admin.EventFailureReason_AlreadyInTerminalState) assert.True(t, ok) } + +func TestNewIncompatibleClusterError(t *testing.T) { + errorMsg := "foo" + cluster := "C1" + statusErr := NewIncompatibleClusterError(context.Background(), errorMsg, cluster) + assert.NotNil(t, statusErr) + s, ok := status.FromError(statusErr) + assert.True(t, ok) + assert.Equal(t, codes.FailedPrecondition, s.Code()) + assert.Equal(t, errorMsg, s.Message()) + + details, ok := s.Details()[0].(*admin.EventFailureReason) + assert.True(t, ok) + _, ok = details.GetReason().(*admin.EventFailureReason_IncompatibleCluster) + assert.True(t, ok) +} diff --git a/pkg/manager/impl/execution_manager.go 
b/pkg/manager/impl/execution_manager.go index 9d99795b09..e8bfbf33dd 100644 --- a/pkg/manager/impl/execution_manager.go +++ b/pkg/manager/impl/execution_manager.go @@ -1177,7 +1177,8 @@ func (m *ExecutionManager) CreateWorkflowEvent(ctx context.Context, request admi } wfExecPhase := core.WorkflowExecution_Phase(core.WorkflowExecution_Phase_value[executionModel.Phase]) - if wfExecPhase == request.Event.Phase { + // Subsequent queued events announcing a cluster reassignment are permitted. + if wfExecPhase == request.Event.Phase && request.Event.Phase != core.WorkflowExecution_QUEUED { logger.Debugf(ctx, "This phase %s was already recorded for workflow execution %v", wfExecPhase.String(), request.Event.ExecutionId) return nil, errors.NewFlyteAdminErrorf(codes.AlreadyExists, @@ -1188,6 +1189,11 @@ func (m *ExecutionManager) CreateWorkflowEvent(ctx context.Context, request admi curPhase := wfExecPhase.String() errorMsg := fmt.Sprintf("Invalid phase change from %s to %s for workflow execution %v", curPhase, request.Event.Phase.String(), request.Event.ExecutionId) return nil, errors.NewAlreadyInTerminalStateError(ctx, errorMsg, curPhase) + } else if wfExecPhase == core.WorkflowExecution_RUNNING && request.Event.Phase == core.WorkflowExecution_QUEUED { + // Cannot go back in time from RUNNING -> QUEUED + return nil, errors.NewFlyteAdminErrorf(codes.FailedPrecondition, + "Cannot go from %s to %s for workflow execution %v", + wfExecPhase.String(), request.Event.Phase.String(), request.Event.ExecutionId) } err = transformers.UpdateExecutionModelState(ctx, executionModel, request, m.config.ApplicationConfiguration().GetRemoteDataConfig().InlineEventDataPolicy, m.storageClient) diff --git a/pkg/manager/impl/execution_manager_test.go b/pkg/manager/impl/execution_manager_test.go index ea2b672df6..01e13705ec 100644 --- a/pkg/manager/impl/execution_manager_test.go +++ b/pkg/manager/impl/execution_manager_test.go @@ -1340,6 +1340,7 @@ func TestCreateWorkflowEvent(t *testing.T) { 
OutputResult: &event.WorkflowExecutionEvent_Error{ Error: &executionError, }, + ProducerId: testCluster, }, } mockDbEventWriter := &eventWriterMocks.WorkflowExecutionEventWriter{} @@ -1391,6 +1392,40 @@ func TestCreateWorkflowEvent_TerminalState(t *testing.T) { assert.True(t, ok) } +func TestCreateWorkflowEvent_NoRunningToQueued(t *testing.T) { + repository := repositoryMocks.NewMockRepository() + executionGetFunc := func(ctx context.Context, input interfaces.Identifier) (models.Execution, error) { + return models.Execution{ + ExecutionKey: models.ExecutionKey{ + Project: "project", + Domain: "domain", + Name: "name", + }, + Spec: specBytes, + Phase: core.WorkflowExecution_RUNNING.String(), + }, nil + } + + repository.ExecutionRepo().(*repositoryMocks.MockExecutionRepo).SetGetCallback(executionGetFunc) + updateExecutionFunc := func(context context.Context, execution models.Execution) error { + return nil + } + repository.ExecutionRepo().(*repositoryMocks.MockExecutionRepo).SetUpdateCallback(updateExecutionFunc) + execManager := NewExecutionManager(repository, getMockExecutionsConfigProvider(), getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockScope.NewTestScope(), &mockPublisher, mockExecutionRemoteURL, nil, nil, nil, &eventWriterMocks.WorkflowExecutionEventWriter{}) + + resp, err := execManager.CreateWorkflowEvent(context.Background(), admin.WorkflowExecutionEventRequest{ + RequestId: "1", + Event: &event.WorkflowExecutionEvent{ + ExecutionId: &executionIdentifier, + Phase: core.WorkflowExecution_QUEUED, + }, + }) + assert.Nil(t, resp) + assert.NotNil(t, err) + adminError := err.(flyteAdminErrors.FlyteAdminError) + assert.Equal(t, adminError.Code(), codes.FailedPrecondition) +} + func TestCreateWorkflowEvent_StartedRunning(t *testing.T) { repository := repositoryMocks.NewMockRepository() occurredAt := time.Now().UTC() @@ -1427,6 +1462,7 @@ func TestCreateWorkflowEvent_StartedRunning(t *testing.T) { ExecutionId: &executionIdentifier, 
OccurredAt: occurredAtTimestamp, Phase: core.WorkflowExecution_RUNNING, + ProducerId: testCluster, }, } mockDbEventWriter := &eventWriterMocks.WorkflowExecutionEventWriter{} @@ -1641,6 +1677,7 @@ func TestCreateWorkflowEvent_DatabaseUpdateError(t *testing.T) { OutputResult: &event.WorkflowExecutionEvent_Error{ Error: &executionError, }, + ProducerId: testCluster, }, }) assert.Nil(t, resp) diff --git a/pkg/manager/impl/node_execution_manager.go b/pkg/manager/impl/node_execution_manager.go index f94d846da8..e1c75d3b3a 100644 --- a/pkg/manager/impl/node_execution_manager.go +++ b/pkg/manager/impl/node_execution_manager.go @@ -83,25 +83,6 @@ func getNodeExecutionContext(ctx context.Context, identifier *core.NodeExecution func (m *NodeExecutionManager) createNodeExecutionWithEvent( ctx context.Context, request *admin.NodeExecutionEventRequest, dynamicWorkflowRemoteClosureReference string) error { - - executionID := request.Event.Id.ExecutionId - workflowExecutionExists, err := m.db.ExecutionRepo().Exists(ctx, repoInterfaces.Identifier{ - Project: executionID.Project, - Domain: executionID.Domain, - Name: executionID.Name, - }) - if err != nil || !workflowExecutionExists { - m.metrics.MissingWorkflowExecution.Inc() - logger.Debugf(ctx, "Failed to find existing execution with id [%+v] with err: %v", executionID, err) - if err != nil { - if ferr, ok := err.(errors.FlyteAdminError); ok { - return errors.NewFlyteAdminErrorf(ferr.Code(), - "Failed to get existing execution id: [%+v] with err: %v", executionID, err) - } - } - return fmt.Errorf("failed to get existing execution id: [%+v]", executionID) - } - var parentTaskExecutionID *uint if request.Event.ParentTaskMetadata != nil { taskExecutionModel, err := util.GetTaskExecutionModel(ctx, m.db, request.Event.ParentTaskMetadata.Id) @@ -233,6 +214,28 @@ func (m *NodeExecutionManager) CreateNodeEvent(ctx context.Context, request admi logger.Debugf(ctx, "Received node execution event for Node Exec Id [%+v] transitioning to phase 
[%v], w/ Metadata [%v]", request.Event.Id, request.Event.Phase, request.Event.ParentTaskMetadata) + executionID := request.Event.Id.ExecutionId + workflowExecution, err := m.db.ExecutionRepo().Get(ctx, repoInterfaces.Identifier{ + Project: executionID.Project, + Domain: executionID.Domain, + Name: executionID.Name, + }) + if err != nil { + m.metrics.MissingWorkflowExecution.Inc() + logger.Debugf(ctx, "Failed to find existing execution with id [%+v] with err: %v", executionID, err) + if err != nil { + if ferr, ok := err.(errors.FlyteAdminError); ok { + return nil, errors.NewFlyteAdminErrorf(ferr.Code(), + "Failed to get existing execution id: [%+v] with err: %v", executionID, err) + } + } + return nil, fmt.Errorf("failed to get existing execution id: [%+v]", executionID) + } + + if err := validation.ValidateCluster(ctx, workflowExecution.Cluster, request.Event.ProducerId); err != nil { + return nil, err + } + var dynamicWorkflowRemoteClosureReference string if request.Event.GetTaskNodeMetadata() != nil && request.Event.GetTaskNodeMetadata().DynamicWorkflow != nil { dynamicWorkflowRemoteClosureDataReference, err := m.uploadDynamicWorkflowClosure( diff --git a/pkg/manager/impl/node_execution_manager_test.go b/pkg/manager/impl/node_execution_manager_test.go index a35081ea4e..a32484a387 100644 --- a/pkg/manager/impl/node_execution_manager_test.go +++ b/pkg/manager/impl/node_execution_manager_test.go @@ -103,6 +103,7 @@ func addGetExecutionCallback(t *testing.T, repository repositories.RepositoryInt Domain: "domain", Name: "name", }, + Cluster: "propeller", }, nil }) } @@ -225,26 +226,19 @@ func TestCreateNodeEvent_MissingExecution(t *testing.T) { expectedErr := flyteAdminErrors.NewFlyteAdminErrorf(codes.Internal, "expected error") repository.NodeExecutionRepo().(*repositoryMocks.MockNodeExecutionRepo).SetGetCallback( func(ctx context.Context, input interfaces.NodeExecutionResource) (models.NodeExecution, error) { - return models.NodeExecution{}, 
flyteAdminErrors.NewFlyteAdminError(codes.NotFound, "foo") + return models.NodeExecution{}, expectedErr }) - repository.ExecutionRepo().(*repositoryMocks.MockExecutionRepo).ExistsFunction = - func(ctx context.Context, input interfaces.Identifier) (bool, error) { - return false, expectedErr - } + repository.ExecutionRepo().(*repositoryMocks.MockExecutionRepo).SetGetCallback( + func(ctx context.Context, input interfaces.Identifier) (models.Execution, error) { + return models.Execution{}, expectedErr + }) + + repository.ExecutionRepo().(*repositoryMocks.MockExecutionRepo).SetGetCallback(func(ctx context.Context, input interfaces.Identifier) (models.Execution, error) { + return models.Execution{}, expectedErr + }) nodeExecManager := NewNodeExecutionManager(repository, getMockExecutionsConfigProvider(), make([]string, 0), getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockNodeExecutionRemoteURL, &mockPublisher, &eventWriterMocks.NodeExecutionEventWriter{}) resp, err := nodeExecManager.CreateNodeEvent(context.Background(), request) - assert.EqualError(t, err, "Failed to get existing execution id: [project:\"project\""+ - " domain:\"domain\" name:\"name\" ] with err: expected error") - assert.Nil(t, resp) - - repository.ExecutionRepo().(*repositoryMocks.MockExecutionRepo).ExistsFunction = - func(ctx context.Context, input interfaces.Identifier) (bool, error) { - return false, nil - } - nodeExecManager = NewNodeExecutionManager(repository, getMockExecutionsConfigProvider(), make([]string, 0), getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockNodeExecutionRemoteURL, &mockPublisher, &eventWriterMocks.NodeExecutionEventWriter{}) - resp, err = nodeExecManager.CreateNodeEvent(context.Background(), request) - assert.EqualError(t, err, "failed to get existing execution id: [project:\"project\""+ - " domain:\"domain\" name:\"name\" ]") + assert.EqualError(t, err, "Failed to get existing execution id: [project:\"project\" 
domain:\"domain\" name:\"name\" ] with err: expected error") assert.Nil(t, resp) } diff --git a/pkg/manager/impl/task_execution_manager.go b/pkg/manager/impl/task_execution_manager.go index 90c6670bd8..4d54f66b1b 100644 --- a/pkg/manager/impl/task_execution_manager.go +++ b/pkg/manager/impl/task_execution_manager.go @@ -130,6 +130,10 @@ func (m *TaskExecutionManager) CreateTaskExecutionEvent(ctx context.Context, req return nil, err } + if err := validation.ValidateClusterForExecutionID(ctx, m.db, request.Event.ParentNodeExecutionId.ExecutionId, request.Event.ProducerId); err != nil { + return nil, err + } + // Get the parent node execution, if none found a MissingEntityError will be returned nodeExecutionID := request.Event.ParentNodeExecutionId taskExecutionID := core.TaskExecutionIdentifier{ diff --git a/pkg/manager/impl/task_execution_manager_test.go b/pkg/manager/impl/task_execution_manager_test.go index 9b417f0260..4adfacd3c7 100644 --- a/pkg/manager/impl/task_execution_manager_test.go +++ b/pkg/manager/impl/task_execution_manager_test.go @@ -75,6 +75,7 @@ func addGetWorkflowExecutionCallback(repository repositories.RepositoryInterface Domain: sampleNodeExecID.ExecutionId.Domain, Name: sampleNodeExecID.ExecutionId.Name, }, + Cluster: "propeller", }, nil }, ) @@ -299,6 +300,7 @@ func TestCreateTaskEvent_Update(t *testing.T) { func TestCreateTaskEvent_MissingExecution(t *testing.T) { repository := repositoryMocks.NewMockRepository() expectedErr := flyteAdminErrors.NewFlyteAdminErrorf(codes.Internal, "expected error") + addGetWorkflowExecutionCallback(repository) repository.TaskExecutionRepo().(*repositoryMocks.MockTaskExecutionRepo).SetGetCallback( func(ctx context.Context, input interfaces.GetTaskExecutionInput) (models.TaskExecution, error) { return models.TaskExecution{}, flyteAdminErrors.NewFlyteAdminError(codes.NotFound, "foo") @@ -327,6 +329,7 @@ func TestCreateTaskEvent_MissingExecution(t *testing.T) { func TestCreateTaskEvent_CreateDatabaseError(t 
*testing.T) { repository := repositoryMocks.NewMockRepository() + addGetWorkflowExecutionCallback(repository) repository.TaskExecutionRepo().(*repositoryMocks.MockTaskExecutionRepo).SetGetCallback( func(ctx context.Context, input interfaces.GetTaskExecutionInput) (models.TaskExecution, error) { return models.TaskExecution{}, flyteAdminErrors.NewFlyteAdminError(codes.NotFound, "foo") @@ -345,6 +348,7 @@ func TestCreateTaskEvent_CreateDatabaseError(t *testing.T) { func TestCreateTaskEvent_UpdateDatabaseError(t *testing.T) { repository := repositoryMocks.NewMockRepository() + addGetWorkflowExecutionCallback(repository) addGetNodeExecutionCallback(repository) repository.TaskExecutionRepo().(*repositoryMocks.MockTaskExecutionRepo).SetGetCallback( @@ -385,6 +389,7 @@ func TestCreateTaskEvent_UpdateDatabaseError(t *testing.T) { func TestCreateTaskEvent_UpdateTerminalEventError(t *testing.T) { repository := repositoryMocks.NewMockRepository() + addGetWorkflowExecutionCallback(repository) repository.TaskExecutionRepo().(*repositoryMocks.MockTaskExecutionRepo).SetGetCallback( func(ctx context.Context, input interfaces.GetTaskExecutionInput) (models.TaskExecution, error) { return models.TaskExecution{ diff --git a/pkg/manager/impl/validation/shared_execution.go b/pkg/manager/impl/validation/shared_execution.go new file mode 100644 index 0000000000..a5ef2a11bc --- /dev/null +++ b/pkg/manager/impl/validation/shared_execution.go @@ -0,0 +1,42 @@ +package validation + +import ( + "context" + "fmt" + + "github.com/flyteorg/flyteadmin/pkg/common" + "github.com/flyteorg/flyteadmin/pkg/errors" + "github.com/flyteorg/flyteadmin/pkg/repositories" + repoInterfaces "github.com/flyteorg/flyteadmin/pkg/repositories/interfaces" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + "github.com/flyteorg/flytestdlib/logger" +) + +// ValidateClusterForExecutionID validates that the execution denoted by executionId is recorded as executing on `cluster`. 
+func ValidateClusterForExecutionID(ctx context.Context, db repositories.RepositoryInterface, executionID *core.WorkflowExecutionIdentifier, cluster string) error { + workflowExecution, err := db.ExecutionRepo().Get(ctx, repoInterfaces.Identifier{ + Project: executionID.Project, + Domain: executionID.Domain, + Name: executionID.Name, + }) + if err != nil { + logger.Debugf(ctx, "Failed to find existing execution with id [%+v] with err: %v", executionID, err) + return err + } + return ValidateCluster(ctx, workflowExecution.Cluster, cluster) +} + +// ValidateCluster validates that `cluster` (the producer id reported by an event) matches `recordedCluster`; an empty or default producer id is always accepted. +func ValidateCluster(ctx context.Context, recordedCluster, cluster string) error { + // DefaultProducerID is used in older versions of propeller which hard code this producer id. + // See https://github.com/flyteorg/flytepropeller/blob/eaf084934de5d630cd4c11aae15ecae780cc787e/pkg/controller/nodes/task/transformer.go#L114 + if len(cluster) == 0 || cluster == common.DefaultProducerID { + return nil + } + if recordedCluster != cluster { + errorMsg := fmt.Sprintf("Cluster/producer from event [%s] does not match existing workflow execution cluster: [%s]", + cluster, recordedCluster) + return errors.NewIncompatibleClusterError(ctx, errorMsg, recordedCluster) + } + return nil +} diff --git a/pkg/manager/impl/validation/shared_execution_test.go b/pkg/manager/impl/validation/shared_execution_test.go new file mode 100644 index 0000000000..187f6c9f99 --- /dev/null +++ b/pkg/manager/impl/validation/shared_execution_test.go @@ -0,0 +1,55 @@ +package validation + +import ( + "context" + "testing" + + "github.com/flyteorg/flyteadmin/pkg/common" + "github.com/flyteorg/flyteadmin/pkg/errors" + "github.com/flyteorg/flyteadmin/pkg/repositories/interfaces" + repositoryMocks "github.com/flyteorg/flyteadmin/pkg/repositories/mocks" + "github.com/flyteorg/flyteadmin/pkg/repositories/models" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" +
"github.com/stretchr/testify/assert" + "google.golang.org/grpc/codes" +) + +var testCluster = "C1" + +var testExecID = &core.WorkflowExecutionIdentifier{ + Project: "p", + Domain: "d", + Name: "n", +} + +func TestValidateClusterForExecutionID(t *testing.T) { + repository := repositoryMocks.NewMockRepository() + repository.ExecutionRepo().(*repositoryMocks.MockExecutionRepo).SetGetCallback(func(ctx context.Context, input interfaces.Identifier) (models.Execution, error) { + return models.Execution{ + Cluster: testCluster, + }, nil + }) + assert.NoError(t, ValidateClusterForExecutionID(context.TODO(), repository, testExecID, testCluster)) + assert.NoError(t, ValidateClusterForExecutionID(context.TODO(), repository, testExecID, common.DefaultProducerID)) +} + +func TestValidateCluster_Nonmatching(t *testing.T) { + repository := repositoryMocks.NewMockRepository() + repository.ExecutionRepo().(*repositoryMocks.MockExecutionRepo).SetGetCallback(func(ctx context.Context, input interfaces.Identifier) (models.Execution, error) { + return models.Execution{ + Cluster: "C2", + }, nil + }) + err := ValidateClusterForExecutionID(context.TODO(), repository, testExecID, testCluster) + assert.Equal(t, codes.FailedPrecondition, err.(errors.FlyteAdminError).Code()) +} + +func TestValidateCluster_NoExecution(t *testing.T) { + repository := repositoryMocks.NewMockRepository() + expectedErr := errors.NewFlyteAdminError(codes.Internal, "foo") + repository.ExecutionRepo().(*repositoryMocks.MockExecutionRepo).SetGetCallback(func(ctx context.Context, input interfaces.Identifier) (models.Execution, error) { + return models.Execution{}, expectedErr + }) + err := ValidateClusterForExecutionID(context.TODO(), repository, testExecID, testCluster) + assert.Equal(t, expectedErr, err) +} diff --git a/pkg/repositories/gormimpl/execution_repo.go b/pkg/repositories/gormimpl/execution_repo.go index 66ba20e6e7..cd62dd64c8 100644 --- a/pkg/repositories/gormimpl/execution_repo.go +++ 
b/pkg/repositories/gormimpl/execution_repo.go @@ -110,24 +110,6 @@ func (r *ExecutionRepo) List(ctx context.Context, input interfaces.ListResourceI }, nil } -func (r *ExecutionRepo) Exists(ctx context.Context, input interfaces.Identifier) (bool, error) { - var execution models.Execution - timer := r.metrics.ExistsDuration.Start() - // Only select the id field (uint) to check for existence. - tx := r.db.Select(ID).Where(&models.Execution{ - ExecutionKey: models.ExecutionKey{ - Project: input.Project, - Domain: input.Domain, - Name: input.Name, - }, - }).Take(&execution) - timer.Stop() - if tx.Error != nil { - return false, r.errorTransformer.ToFlyteAdminError(tx.Error) - } - return true, nil -} - // Returns an instance of ExecutionRepoInterface func NewExecutionRepo( db *gorm.DB, errorTransformer adminErrors.ErrorTransformer, scope promutils.Scope) interfaces.ExecutionRepoInterface { diff --git a/pkg/repositories/gormimpl/execution_repo_test.go b/pkg/repositories/gormimpl/execution_repo_test.go index 5550e9b32a..619e8b1b70 100644 --- a/pkg/repositories/gormimpl/execution_repo_test.go +++ b/pkg/repositories/gormimpl/execution_repo_test.go @@ -343,40 +343,3 @@ func TestListExecutionsForWorkflow(t *testing.T) { assert.Equal(t, time.Hour, execution.Duration) } } - -func TestExecutionExists(t *testing.T) { - executionRepo := NewExecutionRepo(GetDbForTest(t), errors.NewTestErrorTransformer(), mockScope.NewTestScope()) - expectedExecution := models.Execution{ - BaseModel: models.BaseModel{ - ID: uint(20), - }, - ExecutionKey: models.ExecutionKey{ - Project: "project", - Domain: "domain", - Name: "1", - }, - LaunchPlanID: uint(2), - Phase: core.WorkflowExecution_SUCCEEDED.String(), - Closure: []byte{1, 2}, - WorkflowID: uint(3), - Spec: []byte{3, 4}, - } - - executions := make([]map[string]interface{}, 0) - execution := getMockExecutionResponseFromDb(expectedExecution) - executions = append(executions, execution) - - GlobalMock := mocket.Catcher.Reset() - GlobalMock.Logging 
= true - - // Only match on queries that append expected filters - GlobalMock.NewMock().WithQuery(`SELECT "id" FROM "executions" WHERE "executions"."execution_project" = $1 AND "executions"."execution_domain" = $2 AND "executions"."execution_name" = $3 LIMIT 1`).WithReply(executions) - - exists, err := executionRepo.Exists(context.Background(), interfaces.Identifier{ - Project: "project", - Domain: "domain", - Name: "1", - }) - assert.NoError(t, err) - assert.True(t, exists) -} diff --git a/pkg/repositories/interfaces/execution_repo.go b/pkg/repositories/interfaces/execution_repo.go index 475b0a2c52..770bcd3523 100644 --- a/pkg/repositories/interfaces/execution_repo.go +++ b/pkg/repositories/interfaces/execution_repo.go @@ -16,8 +16,6 @@ type ExecutionRepoInterface interface { Get(ctx context.Context, input Identifier) (models.Execution, error) // Returns executions matching query parameters. A limit must be provided for the results page size. List(ctx context.Context, input ListResourceInput) (ExecutionCollectionOutput, error) - // Returns a matching execution if it exists. - Exists(ctx context.Context, input Identifier) (bool, error) } // Response format for a query on workflows. 
diff --git a/pkg/repositories/mocks/execution_repo.go b/pkg/repositories/mocks/execution_repo.go index 4cfbdb18eb..27cb010156 100644 --- a/pkg/repositories/mocks/execution_repo.go +++ b/pkg/repositories/mocks/execution_repo.go @@ -18,7 +18,6 @@ type MockExecutionRepo struct { updateFunction UpdateExecutionFunc getFunction GetExecutionFunc listFunction ListExecutionFunc - ExistsFunction func(ctx context.Context, input interfaces.Identifier) (bool, error) } func (r *MockExecutionRepo) Create(ctx context.Context, input models.Execution) error { @@ -70,13 +69,6 @@ func (r *MockExecutionRepo) SetListCallback(listFunction ListExecutionFunc) { r.listFunction = listFunction } -func (r *MockExecutionRepo) Exists(ctx context.Context, input interfaces.Identifier) (bool, error) { - if r.ExistsFunction != nil { - return r.ExistsFunction(ctx, input) - } - return true, nil -} - func NewMockExecutionRepo() interfaces.ExecutionRepoInterface { return &MockExecutionRepo{} } diff --git a/pkg/repositories/transformers/execution.go b/pkg/repositories/transformers/execution.go index 91421b1756..d1b12126d0 100644 --- a/pkg/repositories/transformers/execution.go +++ b/pkg/repositories/transformers/execution.go @@ -2,8 +2,11 @@ package transformers import ( "context" + "fmt" "time" + "k8s.io/apimachinery/pkg/util/sets" + "github.com/flyteorg/flyteadmin/pkg/runtime/interfaces" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" @@ -19,7 +22,9 @@ import ( "github.com/flyteorg/flyteadmin/pkg/repositories/models" ) -// Request parameters for calls to CreateExecutionModel. +var clusterReassignablePhases = sets.NewString(core.WorkflowExecution_UNDEFINED.String(), core.WorkflowExecution_QUEUED.String()) + +// CreateExecutionModelInput encapsulates request parameters for calls to CreateExecutionModel. 
type CreateExecutionModelInput struct { WorkflowExecutionID core.WorkflowExecutionIdentifier RequestSpec *admin.ExecutionSpec @@ -37,7 +42,7 @@ type CreateExecutionModelInput struct { UserInputsURI storage.DataReference } -// Transforms a ExecutionCreateRequest to a Execution model +// CreateExecutionModel transforms an ExecutionCreateRequest to an Execution model func CreateExecutionModel(input CreateExecutionModelInput) (*models.Execution, error) { requestSpec := input.RequestSpec if requestSpec.Metadata == nil { @@ -104,6 +109,30 @@ func CreateExecutionModel(input CreateExecutionModelInput) (*models.Execution, e return executionModel, nil } +func reassignCluster(ctx context.Context, cluster string, executionID *core.WorkflowExecutionIdentifier, execution *models.Execution) error { + logger.Debugf(ctx, "Updating cluster for execution [%v] with existing recorded cluster [%s] and setting to cluster [%s]", + executionID, execution.Cluster, cluster) + var executionSpec admin.ExecutionSpec + err := proto.Unmarshal(execution.Spec, &executionSpec) + if err != nil { + return errors.NewFlyteAdminErrorf(codes.Internal, "Failed to unmarshal execution spec: %v", err) + } + if executionSpec.Metadata == nil { + executionSpec.Metadata = &admin.ExecutionMetadata{} + } + if executionSpec.Metadata.SystemMetadata == nil { + executionSpec.Metadata.SystemMetadata = &admin.SystemMetadata{} + } + executionSpec.Metadata.SystemMetadata.ExecutionCluster = cluster + marshaledSpec, err := proto.Marshal(&executionSpec) + if err != nil { + return errors.NewFlyteAdminErrorf(codes.Internal, "Failed to marshal execution spec: %v", err) + } + execution.Cluster = cluster // set only once the spec round-trip succeeded, so an error leaves the model unchanged + execution.Spec = marshaledSpec + return nil +} + // Updates an existing model given a WorkflowExecution event. func UpdateExecutionModelState( ctx context.Context, @@ -138,6 +167,22 @@ func UpdateExecutionModelState( } } + // Default or empty cluster values do not require updating the execution model. 
+ ignoreClusterFromEvent := len(request.Event.ProducerId) == 0 || request.Event.ProducerId == common.DefaultProducerID + logger.Debugf(ctx, "Producer Id [%v]. IgnoreClusterFromEvent [%v]", request.Event.ProducerId, ignoreClusterFromEvent) + if !ignoreClusterFromEvent { + if clusterReassignablePhases.Has(execution.Phase) { + if err := reassignCluster(ctx, request.Event.ProducerId, request.Event.ExecutionId, execution); err != nil { + return err + } + } else if execution.Cluster != request.Event.ProducerId { + errorMsg := fmt.Sprintf("Cannot accept events for running/terminated execution [%v] from cluster [%s], "+ + "expected events to originate from [%s]", + request.Event.ExecutionId, request.Event.ProducerId, execution.Cluster) + return errors.NewIncompatibleClusterError(ctx, errorMsg, execution.Cluster) + } + } + if request.Event.GetOutputUri() != "" { executionClosure.OutputResult = &admin.ExecutionClosure_Outputs{ Outputs: &admin.LiteralMapBlob{ diff --git a/pkg/repositories/transformers/execution_test.go b/pkg/repositories/transformers/execution_test.go index af22bc2ffb..bbca4b7f85 100644 --- a/pkg/repositories/transformers/execution_test.go +++ b/pkg/repositories/transformers/execution_test.go @@ -6,6 +6,11 @@ import ( "testing" "time" + "github.com/flyteorg/flyteadmin/pkg/common" + + "github.com/flyteorg/flyteadmin/pkg/errors" + "google.golang.org/grpc/codes" + commonMocks "github.com/flyteorg/flyteadmin/pkg/common/mocks" "github.com/flyteorg/flyteadmin/pkg/runtime/interfaces" "github.com/flyteorg/flytestdlib/storage" @@ -581,3 +586,128 @@ func TestFromExecutionModels(t *testing.T) { Closure: &closure, }, executions[0])) } + +func TestUpdateModelState_WithClusterInformation(t *testing.T) { + createdAt := time.Date(2018, 10, 29, 16, 0, 0, 0, time.UTC) + createdAtProto, _ := ptypes.TimestampProto(createdAt) + existingClosure := admin.ExecutionClosure{ + ComputedInputs: &core.LiteralMap{ + Literals: map[string]*core.Literal{ + "foo": {}, + }, + }, + Phase: 
core.WorkflowExecution_UNDEFINED, + CreatedAt: createdAtProto, + } + spec := testutils.GetExecutionRequest().Spec + specBytes, _ := proto.Marshal(spec) + existingClosureBytes, _ := proto.Marshal(&existingClosure) + startedAt := time.Now() + executionModel := getRunningExecutionModel(specBytes, existingClosureBytes, startedAt) + testCluster := "C1" + altCluster := "C2" + executionModel.Cluster = testCluster + occurredAt := time.Date(2018, 10, 29, 16, 10, 0, 0, time.UTC) + occurredAtProto, _ := ptypes.TimestampProto(occurredAt) + t.Run("update", func(t *testing.T) { + executionModel.Cluster = altCluster + err := UpdateExecutionModelState(context.TODO(), &executionModel, admin.WorkflowExecutionEventRequest{ + Event: &event.WorkflowExecutionEvent{ + Phase: core.WorkflowExecution_QUEUED, + OccurredAt: occurredAtProto, + ProducerId: testCluster, + }, + }, interfaces.InlineEventDataPolicyStoreInline, commonMocks.GetMockStorageClient()) + assert.NoError(t, err) + assert.Equal(t, testCluster, executionModel.Cluster) + executionModel.Cluster = testCluster + }) + t.Run("do not update", func(t *testing.T) { + err := UpdateExecutionModelState(context.TODO(), &executionModel, admin.WorkflowExecutionEventRequest{ + Event: &event.WorkflowExecutionEvent{ + Phase: core.WorkflowExecution_RUNNING, + OccurredAt: occurredAtProto, + ProducerId: altCluster, + }, + }, interfaces.InlineEventDataPolicyStoreInline, commonMocks.GetMockStorageClient()) + assert.Equal(t, err.(errors.FlyteAdminError).Code(), codes.FailedPrecondition) + }) + t.Run("matches recorded", func(t *testing.T) { + executionModel.Cluster = testCluster + err := UpdateExecutionModelState(context.TODO(), &executionModel, admin.WorkflowExecutionEventRequest{ + Event: &event.WorkflowExecutionEvent{ + Phase: core.WorkflowExecution_RUNNING, + OccurredAt: occurredAtProto, + ProducerId: testCluster, + }, + }, interfaces.InlineEventDataPolicyStoreInline, commonMocks.GetMockStorageClient()) + assert.NoError(t, err) + }) + 
t.Run("default cluster value", func(t *testing.T) { + executionModel.Cluster = testCluster + err := UpdateExecutionModelState(context.TODO(), &executionModel, admin.WorkflowExecutionEventRequest{ + Event: &event.WorkflowExecutionEvent{ + Phase: core.WorkflowExecution_RUNNING, + OccurredAt: occurredAtProto, + ProducerId: common.DefaultProducerID, + }, + }, interfaces.InlineEventDataPolicyStoreInline, commonMocks.GetMockStorageClient()) + assert.NoError(t, err) + }) +} + +func TestReassignCluster(t *testing.T) { + oldCluster := "old_cluster" + newCluster := "new_cluster" + + workflowExecutionID := core.WorkflowExecutionIdentifier{ + Project: "project", + Domain: "domain", + Name: "name", + } + + t.Run("happy case", func(t *testing.T) { + spec := testutils.GetExecutionRequest().Spec + spec.Metadata = &admin.ExecutionMetadata{ + SystemMetadata: &admin.SystemMetadata{ + ExecutionCluster: oldCluster, + }, + } + specBytes, _ := proto.Marshal(spec) + executionModel := models.Execution{ + Spec: specBytes, + Cluster: oldCluster, + } + err := reassignCluster(context.TODO(), newCluster, &workflowExecutionID, &executionModel) + assert.NoError(t, err) + assert.Equal(t, newCluster, executionModel.Cluster) + + var updatedSpec admin.ExecutionSpec + err = proto.Unmarshal(executionModel.Spec, &updatedSpec) + assert.NoError(t, err) + assert.Equal(t, newCluster, updatedSpec.Metadata.SystemMetadata.ExecutionCluster) + }) + t.Run("happy case - initialize cluster", func(t *testing.T) { + spec := testutils.GetExecutionRequest().Spec + specBytes, _ := proto.Marshal(spec) + executionModel := models.Execution{ + Spec: specBytes, + } + err := reassignCluster(context.TODO(), newCluster, &workflowExecutionID, &executionModel) + assert.NoError(t, err) + assert.Equal(t, newCluster, executionModel.Cluster) + + var updatedSpec admin.ExecutionSpec + err = proto.Unmarshal(executionModel.Spec, &updatedSpec) + assert.NoError(t, err) + assert.Equal(t, newCluster, 
updatedSpec.Metadata.SystemMetadata.ExecutionCluster) + }) + t.Run("invalid existing spec", func(t *testing.T) { + executionModel := models.Execution{ + Spec: []byte("I'm invalid"), + Cluster: oldCluster, + } + err := reassignCluster(context.TODO(), newCluster, &workflowExecutionID, &executionModel) + assert.Equal(t, err.(errors.FlyteAdminError).Code(), codes.Internal) + }) +} diff --git a/tests/bootstrap.go b/tests/bootstrap.go index c626626566..2cfb29ee59 100644 --- a/tests/bootstrap.go +++ b/tests/bootstrap.go @@ -5,6 +5,7 @@ package tests import ( "context" "fmt" + "gorm.io/gorm" database_config "github.com/flyteorg/flyteadmin/pkg/repositories/config" diff --git a/tests/execution_test.go b/tests/execution_test.go index 2d6d835bf2..079738984b 100644 --- a/tests/execution_test.go +++ b/tests/execution_test.go @@ -8,17 +8,15 @@ import ( "testing" "time" + database_config "github.com/flyteorg/flyteadmin/pkg/repositories/config" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" "github.com/flyteorg/flytestdlib/logger" - database_config "github.com/flyteorg/flyteadmin/pkg/repositories/config" "github.com/golang/protobuf/proto" "github.com/golang/protobuf/ptypes" "github.com/stretchr/testify/assert" - - ) var workflowExecutionID = &core.WorkflowExecutionIdentifier{