diff --git a/flyteadmin/pkg/manager/impl/execution_manager.go b/flyteadmin/pkg/manager/impl/execution_manager.go index bd85d8bee7..1dd3f8119f 100644 --- a/flyteadmin/pkg/manager/impl/execution_manager.go +++ b/flyteadmin/pkg/manager/impl/execution_manager.go @@ -371,7 +371,7 @@ func (m *ExecutionManager) launchSingleTaskExecution( ctx context.Context, request admin.ExecutionCreateRequest, requestedAt time.Time) ( context.Context, *models.Execution, error) { - taskModel, err := m.db.TaskRepo().Get(ctx, repositoryInterfaces.GetResourceInput{ + taskModel, err := m.db.TaskRepo().Get(ctx, repositoryInterfaces.Identifier{ Project: request.Spec.LaunchPlan.Project, Domain: request.Spec.LaunchPlan.Domain, Name: request.Spec.LaunchPlan.Name, @@ -1264,7 +1264,7 @@ func (m *ExecutionManager) TerminateExecution( } ctx = getExecutionContext(ctx, request.Id) // Save the abort reason (best effort) - executionModel, err := m.db.ExecutionRepo().Get(ctx, repositoryInterfaces.GetResourceInput{ + executionModel, err := m.db.ExecutionRepo().Get(ctx, repositoryInterfaces.Identifier{ Project: request.Id.Project, Domain: request.Id.Domain, Name: request.Id.Name, diff --git a/flyteadmin/pkg/manager/impl/execution_manager_test.go b/flyteadmin/pkg/manager/impl/execution_manager_test.go index bb41d83158..9ba7ae75c5 100644 --- a/flyteadmin/pkg/manager/impl/execution_manager_test.go +++ b/flyteadmin/pkg/manager/impl/execution_manager_test.go @@ -131,7 +131,7 @@ func setDefaultLpCallbackForExecTest(repository repositories.RepositoryInterface } lpClosureBytes, _ := proto.Marshal(&lpClosure) - lpGetFunc := func(input interfaces.GetResourceInput) (models.LaunchPlan, error) { + lpGetFunc := func(input interfaces.Identifier) (models.LaunchPlan, error) { lpModel := models.LaunchPlan{ LaunchPlanKey: models.LaunchPlanKey{ Project: input.Project, @@ -179,7 +179,7 @@ func getMockStorageForExecTest(ctx context.Context) *storage.DataStore { func getMockRepositoryForExecTest() repositories.RepositoryInterface { repository := repositoryMocks.NewMockRepository() repository.WorkflowRepo().(*repositoryMocks.MockWorkflowRepo).SetGetCallback( - func(input interfaces.GetResourceInput) (models.Workflow, error) { + func(input interfaces.Identifier) (models.Workflow, error) { return models.Workflow{ BaseModel: models.BaseModel{ CreatedAt: testutils.MockCreatedAtValue, @@ -289,7 +289,7 @@ func TestCreateExecutionFromWorkflowNode(t *testing.T) { getNodeExecutionCalled := false repository.NodeExecutionRepo().(*repositoryMocks.MockNodeExecutionRepo).SetGetCallback( - func(ctx context.Context, input interfaces.GetNodeExecutionInput) (models.NodeExecution, error) { + func(ctx context.Context, input interfaces.NodeExecutionResource) (models.NodeExecution, error) { assert.EqualValues(t, input.NodeExecutionIdentifier, parentNodeExecutionID) getNodeExecutionCalled = true return models.NodeExecution{ @@ -301,7 +301,7 @@ func TestCreateExecutionFromWorkflowNode(t *testing.T) { ) getExecutionCalled := false repository.ExecutionRepo().(*repositoryMocks.MockExecutionRepo).SetGetCallback( - func(ctx context.Context, input interfaces.GetResourceInput) (models.Execution, error) { + func(ctx context.Context, input interfaces.Identifier) (models.Execution, error) { assert.EqualValues(t, input.Project, parentNodeExecutionID.ExecutionId.Project) assert.EqualValues(t, input.Domain, parentNodeExecutionID.ExecutionId.Domain) assert.EqualValues(t, input.Name, parentNodeExecutionID.ExecutionId.Name) @@ -666,7 +666,7 @@ func TestCreateExecutionNoNotifications(t *testing.T) { // The LaunchPlan is retrieved within the CreateExecution call to ExecutionManager. // Create a callback method used by the mock to retrieve a LaunchPlan. - lpGetFunc := func(input interfaces.GetResourceInput) (models.LaunchPlan, error) { + lpGetFunc := func(input interfaces.Identifier) (models.LaunchPlan, error) { lpModel := models.LaunchPlan{ LaunchPlanKey: models.LaunchPlanKey{ Project: input.Project, @@ -754,7 +754,7 @@ func TestCreateExecutionDynamicLabelsAndAnnotations(t *testing.T) { func makeExecutionGetFunc( t *testing.T, closureBytes []byte, startTime *time.Time) repositoryMocks.GetExecutionFunc { - return func(ctx context.Context, input interfaces.GetResourceInput) (models.Execution, error) { + return func(ctx context.Context, input interfaces.Identifier) (models.Execution, error) { assert.Equal(t, "project", input.Project) assert.Equal(t, "domain", input.Domain) assert.Equal(t, "name", input.Name) @@ -780,7 +780,7 @@ func makeExecutionGetFunc( func makeLegacyExecutionGetFunc( t *testing.T, closureBytes []byte, startTime *time.Time) repositoryMocks.GetExecutionFunc { - return func(ctx context.Context, input interfaces.GetResourceInput) (models.Execution, error) { + return func(ctx context.Context, input interfaces.Identifier) (models.Execution, error) { assert.Equal(t, "project", input.Project) assert.Equal(t, "domain", input.Domain) assert.Equal(t, "name", input.Name) @@ -869,7 +869,7 @@ func TestRelaunchExecution_GetExistingFailure(t *testing.T) { expectedErr := errors.New("expected error") repository.ExecutionRepo().(*repositoryMocks.MockExecutionRepo).SetGetCallback( - func(ctx context.Context, input interfaces.GetResourceInput) (models.Execution, error) { + func(ctx context.Context, input interfaces.Identifier) (models.Execution, error) { return models.Execution{}, expectedErr }) @@ -1002,7 +1002,7 @@ func TestCreateWorkflowEvent(t *testing.T) { func TestCreateWorkflowEvent_TerminalState(t *testing.T) { repository := repositoryMocks.NewMockRepository() - executionGetFunc := func(ctx context.Context, input interfaces.GetResourceInput) (models.Execution, error) { + executionGetFunc := func(ctx context.Context, input interfaces.Identifier) (models.Execution, error) { return models.Execution{ ExecutionKey: models.ExecutionKey{ Project: "project", @@ -1096,7 +1096,7 @@ func TestCreateWorkflowEvent_DuplicateRunning(t *testing.T) { occurredAt := time.Now().UTC() repository.ExecutionRepo().(*repositoryMocks.MockExecutionRepo).SetGetCallback( - func(ctx context.Context, input interfaces.GetResourceInput) (models.Execution, error) { + func(ctx context.Context, input interfaces.Identifier) (models.Execution, error) { return models.Execution{ ExecutionKey: models.ExecutionKey{ Project: "project", @@ -1137,7 +1137,7 @@ func TestCreateWorkflowEvent_InvalidPhaseChange(t *testing.T) { occurredAt := time.Now().UTC() repository.ExecutionRepo().(*repositoryMocks.MockExecutionRepo).SetGetCallback( - func(ctx context.Context, input interfaces.GetResourceInput) (models.Execution, error) { + func(ctx context.Context, input interfaces.Identifier) (models.Execution, error) { return models.Execution{ ExecutionKey: models.ExecutionKey{ Project: "project", @@ -1240,7 +1240,7 @@ func TestCreateWorkflowEvent_DatabaseGetError(t *testing.T) { startTime := time.Now() expectedErr := errors.New("expected error") - executionGetFunc := func(ctx context.Context, input interfaces.GetResourceInput) (models.Execution, error) { + executionGetFunc := func(ctx context.Context, input interfaces.Identifier) (models.Execution, error) { return models.Execution{}, expectedErr } repository.ExecutionRepo().(*repositoryMocks.MockExecutionRepo).SetGetCallback(executionGetFunc) @@ -1304,7 +1304,7 @@ func TestCreateWorkflowEvent_DatabaseUpdateError(t *testing.T) { func TestGetExecution(t *testing.T) { repository := repositoryMocks.NewMockRepository() startedAt := time.Date(2018, 8, 30, 0, 0, 0, 0, time.UTC) - executionGetFunc := func(ctx context.Context, input interfaces.GetResourceInput) (models.Execution, error) { + executionGetFunc := func(ctx context.Context, input interfaces.Identifier) (models.Execution, error) { assert.Equal(t, "project", input.Project) assert.Equal(t, "domain", input.Domain) assert.Equal(t, "name", input.Name) @@ -1338,7 +1338,7 @@ func TestGetExecution_DatabaseError(t *testing.T) { repository := repositoryMocks.NewMockRepository() expectedErr := errors.New("expected error") - executionGetFunc := func(ctx context.Context, input interfaces.GetResourceInput) (models.Execution, error) { + executionGetFunc := func(ctx context.Context, input interfaces.Identifier) (models.Execution, error) { assert.Equal(t, "project", input.Project) assert.Equal(t, "domain", input.Domain) assert.Equal(t, "name", input.Name) @@ -1356,7 +1356,7 @@ func TestGetExecution_DatabaseError(t *testing.T) { func TestGetExecution_TransformerError(t *testing.T) { repository := repositoryMocks.NewMockRepository() startedAt := time.Date(2018, 8, 30, 0, 0, 0, 0, time.UTC) - executionGetFunc := func(ctx context.Context, input interfaces.GetResourceInput) (models.Execution, error) { + executionGetFunc := func(ctx context.Context, input interfaces.Identifier) (models.Execution, error) { assert.Equal(t, "project", input.Project) assert.Equal(t, "domain", input.Domain) assert.Equal(t, "name", input.Name) @@ -1927,7 +1927,7 @@ func TestGetExecutionData(t *testing.T) { } var closureBytes, _ = proto.Marshal(&closure) - executionGetFunc := func(ctx context.Context, input interfaces.GetResourceInput) (models.Execution, error) { + executionGetFunc := func(ctx context.Context, input interfaces.Identifier) (models.Execution, error) { return models.Execution{ ExecutionKey: models.ExecutionKey{ Project: "project", @@ -2120,7 +2120,7 @@ func TestPluginOverrides_ResourceGetFailure(t *testing.T) { func TestGetExecution_Legacy(t *testing.T) { repository := repositoryMocks.NewMockRepository() startedAt := time.Date(2018, 8, 30, 0, 0, 0, 0, time.UTC) - executionGetFunc := func(ctx context.Context, input interfaces.GetResourceInput) (models.Execution, error) { + executionGetFunc := func(ctx context.Context, input interfaces.Identifier) (models.Execution, error) { assert.Equal(t, "project", input.Project) assert.Equal(t, "domain", input.Domain) assert.Equal(t, "name", input.Name) @@ -2152,7 +2152,7 @@ func TestGetExecution_Legacy(t *testing.T) { func TestGetExecution_LegacyClient_OffloadedData(t *testing.T) { repository := repositoryMocks.NewMockRepository() startedAt := time.Date(2018, 8, 30, 0, 0, 0, 0, time.UTC) - executionGetFunc := func(ctx context.Context, input interfaces.GetResourceInput) (models.Execution, error) { + executionGetFunc := func(ctx context.Context, input interfaces.Identifier) (models.Execution, error) { assert.Equal(t, "project", input.Project) assert.Equal(t, "domain", input.Domain) assert.Equal(t, "name", input.Name) @@ -2199,7 +2199,7 @@ func TestGetExecutionData_LegacyModel(t *testing.T) { } var closureBytes, _ = proto.Marshal(closure) - executionGetFunc := func(ctx context.Context, input interfaces.GetResourceInput) (models.Execution, error) { + executionGetFunc := func(ctx context.Context, input interfaces.Identifier) (models.Execution, error) { return models.Execution{ ExecutionKey: models.ExecutionKey{ Project: "project", @@ -2742,7 +2742,7 @@ func TestCreateSingleTaskExecution(t *testing.T) { } repository.WorkflowRepo().(*repositoryMocks.MockWorkflowRepo).SetCreateCallback(workflowcreateFunc) - workflowGetFunc := func(input interfaces.GetResourceInput) (models.Workflow, error) { + workflowGetFunc := func(input interfaces.Identifier) (models.Workflow, error) { if getCalledCount <= 1 { getCalledCount++ return models.Workflow{}, flyteAdminErrors.NewFlyteAdminErrorf(codes.NotFound, "not found") @@ -2752,7 +2752,7 @@ func TestCreateSingleTaskExecution(t *testing.T) { } repository.WorkflowRepo().(*repositoryMocks.MockWorkflowRepo).SetGetCallback(workflowGetFunc) repository.TaskRepo().(*repositoryMocks.MockTaskRepo).SetGetCallback( - func(input interfaces.GetResourceInput) (models.Task, error) { + func(input interfaces.Identifier) (models.Task, error) { createdAt := time.Now() createdAtProto, _ := ptypes.TimestampProto(createdAt) taskClosure := &admin.TaskClosure{ diff --git a/flyteadmin/pkg/manager/impl/launch_plan_manager.go b/flyteadmin/pkg/manager/impl/launch_plan_manager.go index ca46065136..a3614268d0 100644 --- a/flyteadmin/pkg/manager/impl/launch_plan_manager.go +++ b/flyteadmin/pkg/manager/impl/launch_plan_manager.go @@ -279,7 +279,7 @@ func (m *LaunchPlanManager) disableLaunchPlan(ctx context.Context, request admin func (m *LaunchPlanManager) enableLaunchPlan(ctx context.Context, request admin.LaunchPlanUpdateRequest) ( *admin.LaunchPlanUpdateResponse, error) { - newlyActiveLaunchPlanModel, err := m.db.LaunchPlanRepo().Get(ctx, repoInterfaces.GetResourceInput{ + newlyActiveLaunchPlanModel, err := m.db.LaunchPlanRepo().Get(ctx, repoInterfaces.Identifier{ Project: request.Id.Project, Domain: request.Id.Domain, Name: request.Id.Name, diff --git a/flyteadmin/pkg/manager/impl/launch_plan_manager_test.go b/flyteadmin/pkg/manager/impl/launch_plan_manager_test.go index 5914b8a9b0..ac4797cbe6 100644 --- a/flyteadmin/pkg/manager/impl/launch_plan_manager_test.go +++ b/flyteadmin/pkg/manager/impl/launch_plan_manager_test.go @@ -63,7 +63,7 @@ func getMockConfigForLpTest() runtimeInterfaces.Configuration { func setDefaultWorkflowCallbackForLpTest(repository repositories.RepositoryInterface) { workflowSpec := testutils.GetSampleWorkflowSpecForTest() typedInterface, _ := proto.Marshal(workflowSpec.Template.Interface) - workflowGetFunc := func(input interfaces.GetResourceInput) (models.Workflow, error) { + workflowGetFunc := func(input interfaces.Identifier) (models.Workflow, error) { return models.Workflow{ WorkflowKey: models.WorkflowKey{ Project: input.Project, @@ -80,7 +80,7 @@ func setDefaultWorkflowCallbackForLpTest(repository repositories.RepositoryInter func TestCreateLaunchPlan(t *testing.T) { repository := getMockRepositoryForLpTest() repository.LaunchPlanRepo().(*repositoryMocks.MockLaunchPlanRepo).SetGetCallback( - func(input interfaces.GetResourceInput) (models.LaunchPlan, error) { + func(input interfaces.Identifier) (models.LaunchPlan, error) { return models.LaunchPlan{}, errors.New("foo") }) var createCalled bool @@ -116,7 +116,7 @@ func TestLaunchPlanManager_GetLaunchPlan(t *testing.T) { specBytes, _ := proto.Marshal(lpRequest.Spec) closureBytes, _ := proto.Marshal(&closure) - launchPlanGetFunc := func(input interfaces.GetResourceInput) (models.LaunchPlan, error) { + launchPlanGetFunc := func(input interfaces.Identifier) (models.LaunchPlan, error) { return models.LaunchPlan{ LaunchPlanKey: models.LaunchPlanKey{ Project: input.Project, @@ -242,7 +242,7 @@ func TestLaunchPlan_ValidationError(t *testing.T) { func TestLaunchPlan_DatabaseError(t *testing.T) { repository := getMockRepositoryForLpTest() repository.LaunchPlanRepo().(*repositoryMocks.MockLaunchPlanRepo).SetGetCallback( - func(input interfaces.GetResourceInput) (models.LaunchPlan, error) { + func(input interfaces.Identifier) (models.LaunchPlan, error) { return models.LaunchPlan{}, errors.New("foo") }) setDefaultWorkflowCallbackForLpTest(repository) @@ -281,7 +281,7 @@ func TestCreateLaunchPlanInCompatibleInputs(t *testing.T) { func TestCreateLaunchPlanValidateCreate(t *testing.T) { repository := getMockRepositoryForLpTest() repository.LaunchPlanRepo().(*repositoryMocks.MockLaunchPlanRepo).SetGetCallback( - func(input interfaces.GetResourceInput) (models.LaunchPlan, error) { + func(input interfaces.Identifier) (models.LaunchPlan, error) { return models.LaunchPlan{}, errors.New("foo") }) setDefaultWorkflowCallbackForLpTest(repository) @@ -324,10 +324,10 @@ func TestCreateLaunchPlanValidateCreate(t *testing.T) { func TestCreateLaunchPlanNoWorkflowInterface(t *testing.T) { repository := getMockRepositoryForLpTest() repository.LaunchPlanRepo().(*repositoryMocks.MockLaunchPlanRepo).SetGetCallback( - func(input interfaces.GetResourceInput) (models.LaunchPlan, error) { + func(input interfaces.Identifier) (models.LaunchPlan, error) { return models.LaunchPlan{}, errors.New("foo") }) - workflowGetFunc := func(input interfaces.GetResourceInput) (models.Workflow, error) { + workflowGetFunc := func(input interfaces.Identifier) (models.Workflow, error) { return models.Workflow{ WorkflowKey: models.WorkflowKey{ Project: input.Project, @@ -365,7 +365,7 @@ func TestCreateLaunchPlanNoWorkflowInterface(t *testing.T) { } func makeLaunchPlanRepoGetCallback(t *testing.T) repositoryMocks.GetLaunchPlanFunc { - return func(input interfaces.GetResourceInput) (models.LaunchPlan, error) { + return func(input interfaces.Identifier) (models.LaunchPlan, error) { assert.Equal(t, project, input.Project) assert.Equal(t, domain, input.Domain) assert.Equal(t, name, input.Name) @@ -756,7 +756,7 @@ func TestUpdateSchedules_EnableNoSchedule(t *testing.T) { func TestDisableLaunchPlan(t *testing.T) { repository := getMockRepositoryForLpTest() - lpGetFunc := func(input interfaces.GetResourceInput) (models.LaunchPlan, error) { + lpGetFunc := func(input interfaces.Identifier) (models.LaunchPlan, error) { assert.Equal(t, project, input.Project) assert.Equal(t, domain, input.Domain) assert.Equal(t, name, input.Name) @@ -816,7 +816,7 @@ func TestDisableLaunchPlan_DatabaseError(t *testing.T) { repository := getMockRepositoryForLpTest() expectedError := errors.New("expected error") - lpGetFunc := func(input interfaces.GetResourceInput) (models.LaunchPlan, error) { + lpGetFunc := func(input interfaces.Identifier) (models.LaunchPlan, error) { assert.Equal(t, project, input.Project) assert.Equal(t, domain, input.Domain) assert.Equal(t, name, input.Name) @@ -832,7 +832,7 @@ func TestDisableLaunchPlan_DatabaseError(t *testing.T) { assert.EqualError(t, err, expectedError.Error(), "Failures on getting the existing launch plan should propagate") - lpGetFunc = func(input interfaces.GetResourceInput) (models.LaunchPlan, error) { + lpGetFunc = func(input interfaces.Identifier) (models.LaunchPlan, error) { assert.Equal(t, project, input.Project) assert.Equal(t, domain, input.Domain) assert.Equal(t, name, input.Name) @@ -955,7 +955,7 @@ func TestEnableLaunchPlan_DatabaseError(t *testing.T) { repository := getMockRepositoryForLpTest() expectedError := errors.New("expected error") - lpGetFunc := func(input interfaces.GetResourceInput) (models.LaunchPlan, error) { + lpGetFunc := func(input interfaces.Identifier) (models.LaunchPlan, error) { assert.Equal(t, project, input.Project) assert.Equal(t, domain, input.Domain) assert.Equal(t, name, input.Name) diff --git a/flyteadmin/pkg/manager/impl/node_execution_manager.go b/flyteadmin/pkg/manager/impl/node_execution_manager.go index 9cbff6830f..3a6e9610ba 100644 --- a/flyteadmin/pkg/manager/impl/node_execution_manager.go +++ b/flyteadmin/pkg/manager/impl/node_execution_manager.go @@ -80,6 +80,24 @@ func getNodeExecutionContext(ctx context.Context, identifier *core.NodeExecution func (m *NodeExecutionManager) createNodeExecutionWithEvent( ctx context.Context, request *admin.NodeExecutionEventRequest) error { + executionID := request.Event.Id.ExecutionId + workflowExecutionExists, err := m.db.ExecutionRepo().Exists(ctx, repoInterfaces.Identifier{ + Project: executionID.Project, + Domain: executionID.Domain, + Name: executionID.Name, + }) + if err != nil || !workflowExecutionExists { + m.metrics.MissingWorkflowExecution.Inc() + logger.Debugf(ctx, "Failed to find existing execution with id [%+v] with err: %v", executionID, err) + if err != nil { + if ferr, ok := err.(errors.FlyteAdminError); ok { + return errors.NewFlyteAdminErrorf(ferr.Code(), + "Failed to get existing execution id: [%+v] with err: %v", executionID, err) + } + } + return fmt.Errorf("failed to get existing execution id: [%+v]", executionID) + } + var parentTaskExecutionID uint if request.Event.ParentTaskMetadata != nil { taskExecutionModel, err := util.GetTaskExecutionModel(ctx, m.db, request.Event.ParentTaskMetadata.Id) @@ -186,22 +204,10 @@ func (m *NodeExecutionManager) CreateNodeEvent(ctx context.Context, request admi logger.Debugf(ctx, "CreateNodeEvent called with invalid identifier [%+v]: %v", request.Event.Id, err) } ctx = getNodeExecutionContext(ctx, request.Event.Id) - executionID := request.Event.Id.ExecutionId logger.Debugf(ctx, "Received node execution event for Node Exec Id [%+v] transitioning to phase [%v], w/ Metadata [%v]", request.Event.Id, request.Event.Phase, request.Event.ParentTaskMetadata) - _, err := util.GetExecutionModel(ctx, m.db, *executionID) - if err != nil { - m.metrics.MissingWorkflowExecution.Inc() - logger.Debugf(ctx, "Failed to find existing execution with id [%+v] with err: %v", executionID, err) - if ferr, ok := err.(errors.FlyteAdminError); ok { - return nil, errors.NewFlyteAdminErrorf(ferr.Code(), - "Failed to get existing execution id:[%+v] with err: %v", executionID, err) - } - return nil, fmt.Errorf("failed to get existing execution id: [%+v] with err: %v", executionID, err) - } - - nodeExecutionModel, err := m.db.NodeExecutionRepo().Get(ctx, repoInterfaces.GetNodeExecutionInput{ + nodeExecutionModel, err := m.db.NodeExecutionRepo().Get(ctx, repoInterfaces.NodeExecutionResource{ NodeExecutionIdentifier: *request.Event.Id, }) if err != nil { diff --git a/flyteadmin/pkg/manager/impl/node_execution_manager_test.go b/flyteadmin/pkg/manager/impl/node_execution_manager_test.go index 6770ee2ed0..57d5b9b109 100644 --- a/flyteadmin/pkg/manager/impl/node_execution_manager_test.go +++ b/flyteadmin/pkg/manager/impl/node_execution_manager_test.go @@ -65,7 +65,7 @@ var mockNodeExecutionRemoteURL = dataMocks.NewMockRemoteURL() func addGetExecutionCallback(t *testing.T, repository repositories.RepositoryInterface) { repository.ExecutionRepo().(*repositoryMocks.MockExecutionRepo).SetGetCallback( - func(ctx context.Context, input interfaces.GetResourceInput) (models.Execution, error) { + func(ctx context.Context, input interfaces.Identifier) (models.Execution, error) { assert.Equal(t, "project", input.Project) assert.Equal(t, "domain", input.Domain) assert.Equal(t, "name", input.Name) @@ -86,7 +86,7 @@ func TestCreateNodeEvent(t *testing.T) { repository := repositoryMocks.NewMockRepository() addGetExecutionCallback(t, repository) repository.NodeExecutionRepo().(*repositoryMocks.MockNodeExecutionRepo).SetGetCallback( - func(ctx context.Context, input interfaces.GetNodeExecutionInput) (models.NodeExecution, error) { + func(ctx context.Context, input interfaces.NodeExecutionResource) (models.NodeExecution, error) { assert.True(t, proto.Equal(&core.NodeExecutionIdentifier{ NodeId: "node id", ExecutionId: &workflowExecutionIdentifier, @@ -146,7 +146,7 @@ func TestCreateNodeEvent_Update(t *testing.T) { repository := repositoryMocks.NewMockRepository() addGetExecutionCallback(t, repository) repository.NodeExecutionRepo().(*repositoryMocks.MockNodeExecutionRepo).SetGetCallback( - func(ctx context.Context, input interfaces.GetNodeExecutionInput) (models.NodeExecution, error) { + func(ctx context.Context, input interfaces.NodeExecutionResource) (models.NodeExecution, error) { assert.True(t, proto.Equal(&core.NodeExecutionIdentifier{ NodeId: "node id", ExecutionId: &workflowExecutionIdentifier, @@ -200,23 +200,37 @@ func TestCreateNodeEvent_Update(t *testing.T) { func TestCreateNodeEvent_MissingExecution(t *testing.T) { repository := repositoryMocks.NewMockRepository() - expectedErr := errors.New("expected error") - repository.ExecutionRepo().(*repositoryMocks.MockExecutionRepo).SetGetCallback( - func(ctx context.Context, input interfaces.GetResourceInput) (models.Execution, error) { - return models.Execution{}, expectedErr + expectedErr := flyteAdminErrors.NewFlyteAdminErrorf(codes.Internal, "expected error") + repository.NodeExecutionRepo().(*repositoryMocks.MockNodeExecutionRepo).SetGetCallback( + func(ctx context.Context, input interfaces.NodeExecutionResource) (models.NodeExecution, error) { + return models.NodeExecution{}, flyteAdminErrors.NewFlyteAdminError(codes.NotFound, "foo") }) + repository.ExecutionRepo().(*repositoryMocks.MockExecutionRepo).ExistsFunction = + func(ctx context.Context, input interfaces.Identifier) (bool, error) { + return false, expectedErr + } nodeExecManager := NewNodeExecutionManager(repository, getMockExecutionsConfigProvider(), getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockNodeExecutionRemoteURL, &mockPublisher) resp, err := nodeExecManager.CreateNodeEvent(context.Background(), request) - assert.EqualError(t, err, "failed to get existing execution id: [project:\"project\""+ + assert.EqualError(t, err, "Failed to get existing execution id: [project:\"project\""+ " domain:\"domain\" name:\"name\" ] with err: expected error") assert.Nil(t, resp) + + repository.ExecutionRepo().(*repositoryMocks.MockExecutionRepo).ExistsFunction = + func(ctx context.Context, input interfaces.Identifier) (bool, error) { + return false, nil + } + nodeExecManager = NewNodeExecutionManager(repository, getMockExecutionsConfigProvider(), getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockNodeExecutionRemoteURL, &mockPublisher) + resp, err = nodeExecManager.CreateNodeEvent(context.Background(), request) + assert.EqualError(t, err, "failed to get existing execution id: [project:\"project\""+ + " domain:\"domain\" name:\"name\" ]") + assert.Nil(t, resp) } func TestCreateNodeEvent_CreateDatabaseError(t *testing.T) { repository := repositoryMocks.NewMockRepository() addGetExecutionCallback(t, repository) repository.NodeExecutionRepo().(*repositoryMocks.MockNodeExecutionRepo).SetGetCallback( - func(ctx context.Context, input interfaces.GetNodeExecutionInput) (models.NodeExecution, error) { + func(ctx context.Context, input interfaces.NodeExecutionResource) (models.NodeExecution, error) { return models.NodeExecution{}, flyteAdminErrors.NewFlyteAdminError(codes.NotFound, "foo") }) @@ -235,7 +249,7 @@ func TestCreateNodeEvent_UpdateDatabaseError(t *testing.T) { repository := repositoryMocks.NewMockRepository() addGetExecutionCallback(t, repository) repository.NodeExecutionRepo().(*repositoryMocks.MockNodeExecutionRepo).SetGetCallback( - func(ctx context.Context, input interfaces.GetNodeExecutionInput) (models.NodeExecution, error) { + func(ctx context.Context, input interfaces.NodeExecutionResource) (models.NodeExecution, error) { assert.True(t, proto.Equal(&core.NodeExecutionIdentifier{ NodeId: "node id", ExecutionId: &workflowExecutionIdentifier, @@ -270,7 +284,7 @@ func TestCreateNodeEvent_UpdateTerminalEventError(t *testing.T) { repository := repositoryMocks.NewMockRepository() addGetExecutionCallback(t, repository) repository.NodeExecutionRepo().(*repositoryMocks.MockNodeExecutionRepo).SetGetCallback( - func(ctx context.Context, input interfaces.GetNodeExecutionInput) (models.NodeExecution, error) { + func(ctx context.Context, input interfaces.NodeExecutionResource) (models.NodeExecution, error) { assert.True(t, proto.Equal(&core.NodeExecutionIdentifier{ NodeId: "node id", ExecutionId: &workflowExecutionIdentifier, @@ -305,7 +319,7 @@ func TestCreateNodeEvent_UpdateDuplicateEventError(t *testing.T) { repository := repositoryMocks.NewMockRepository() addGetExecutionCallback(t, repository) repository.NodeExecutionRepo().(*repositoryMocks.MockNodeExecutionRepo).SetGetCallback( - func(ctx context.Context, input interfaces.GetNodeExecutionInput) (models.NodeExecution, error) { + func(ctx context.Context, input interfaces.NodeExecutionResource) (models.NodeExecution, error) { assert.True(t, proto.Equal(&core.NodeExecutionIdentifier{ NodeId: "node id", ExecutionId: &workflowExecutionIdentifier, @@ -334,7 +348,7 @@ func TestCreateNodeEvent_FirstEventIsTerminal(t *testing.T) { repository := repositoryMocks.NewMockRepository() addGetExecutionCallback(t, repository) repository.NodeExecutionRepo().(*repositoryMocks.MockNodeExecutionRepo).SetGetCallback( - func(ctx context.Context, input interfaces.GetNodeExecutionInput) (models.NodeExecution, error) { + func(ctx context.Context, input interfaces.NodeExecutionResource) (models.NodeExecution, error) { return models.NodeExecution{}, flyteAdminErrors.NewFlyteAdminError(codes.NotFound, "foo") }) nodeExecManager := NewNodeExecutionManager(repository, getMockExecutionsConfigProvider(), getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockNodeExecutionRemoteURL, &mockPublisher) @@ -372,7 +386,7 @@ func TestGetNodeExecution(t *testing.T) { metadataBytes, _ := proto.Marshal(&expectedMetadata) closureBytes, _ := proto.Marshal(&expectedClosure) repository.NodeExecutionRepo().(*repositoryMocks.MockNodeExecutionRepo).SetGetCallback( - func(ctx context.Context, input interfaces.GetNodeExecutionInput) (models.NodeExecution, error) { + func(ctx context.Context, input interfaces.NodeExecutionResource) (models.NodeExecution, error) { workflowExecutionIdentifier := core.WorkflowExecutionIdentifier{ Project: "project", Domain: "domain", @@ -423,7 +437,7 @@ func TestGetNodeExecutionParentNode(t *testing.T) { metadataBytes, _ := proto.Marshal(&expectedMetadata) closureBytes, _ := proto.Marshal(&expectedClosure) repository.NodeExecutionRepo().(*repositoryMocks.MockNodeExecutionRepo).SetGetCallback( - func(ctx context.Context, input interfaces.GetNodeExecutionInput) (models.NodeExecution, error) { + func(ctx context.Context, input interfaces.NodeExecutionResource) (models.NodeExecution, error) { workflowExecutionIdentifier := core.WorkflowExecutionIdentifier{ Project: "project", Domain: "domain", @@ -479,7 +493,7 @@ func TestGetNodeExecution_DatabaseError(t *testing.T) { repository := repositoryMocks.NewMockRepository() expectedErr := errors.New("expected error") repository.NodeExecutionRepo().(*repositoryMocks.MockNodeExecutionRepo).SetGetCallback( - func(ctx context.Context, input interfaces.GetNodeExecutionInput) (models.NodeExecution, error) { + func(ctx context.Context, input interfaces.NodeExecutionResource) (models.NodeExecution, error) { return models.NodeExecution{}, expectedErr }) nodeExecManager := NewNodeExecutionManager(repository, getMockExecutionsConfigProvider(), getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockNodeExecutionRemoteURL, nil) @@ -493,7 +507,7 @@ func TestGetNodeExecution_DatabaseError(t *testing.T) { func TestGetNodeExecution_TransformerError(t *testing.T) { repository := repositoryMocks.NewMockRepository() repository.NodeExecutionRepo().(*repositoryMocks.MockNodeExecutionRepo).SetGetCallback( - func(ctx context.Context, input interfaces.GetNodeExecutionInput) (models.NodeExecution, error) { + func(ctx context.Context, input interfaces.NodeExecutionResource) (models.NodeExecution, error) { return models.NodeExecution{ NodeExecutionKey: models.NodeExecutionKey{ NodeID: "node id", @@ -622,7 +636,7 @@ func TestListNodeExecutionsWithParent(t *testing.T) { metadataBytes, _ := proto.Marshal(&expectedMetadata) closureBytes, _ := proto.Marshal(&expectedClosure) parentID := uint(12) - repository.NodeExecutionRepo().(*repositoryMocks.MockNodeExecutionRepo).SetGetCallback(func(ctx context.Context, input interfaces.GetNodeExecutionInput) (execution models.NodeExecution, e error) { + repository.NodeExecutionRepo().(*repositoryMocks.MockNodeExecutionRepo).SetGetCallback(func(ctx context.Context, input interfaces.NodeExecutionResource) (execution models.NodeExecution, e error) { assert.Equal(t, "parent_1", input.NodeExecutionIdentifier.NodeId) return models.NodeExecution{ BaseModel: models.BaseModel{ @@ -959,7 +973,7 @@ func TestGetNodeExecutionData(t *testing.T) { closureBytes, _ := proto.Marshal(&expectedClosure) repository.NodeExecutionRepo().(*repositoryMocks.MockNodeExecutionRepo).SetGetCallback( - func(ctx context.Context, input interfaces.GetNodeExecutionInput) (models.NodeExecution, error) { + func(ctx context.Context, input interfaces.NodeExecutionResource) (models.NodeExecution, error) { workflowExecutionIdentifier := core.WorkflowExecutionIdentifier{ Project: "project", Domain: "domain", diff --git a/flyteadmin/pkg/manager/impl/task_execution_manager.go b/flyteadmin/pkg/manager/impl/task_execution_manager.go index f96c0c78a8..1c6bdb454a 100644 --- a/flyteadmin/pkg/manager/impl/task_execution_manager.go +++ b/flyteadmin/pkg/manager/impl/task_execution_manager.go @@ -61,9 +61,25 @@ func getTaskExecutionContext(ctx context.Context, identifier *core.TaskExecution } func (m *TaskExecutionManager) createTaskExecution( - ctx context.Context, nodeExecutionModel *models.NodeExecution, request *admin.TaskExecutionEventRequest) ( + ctx context.Context, request *admin.TaskExecutionEventRequest) ( models.TaskExecution, error) { + nodeExecutionID := request.Event.ParentNodeExecutionId + nodeExecutionExists, err := m.db.NodeExecutionRepo().Exists(ctx, repoInterfaces.NodeExecutionResource{ + NodeExecutionIdentifier: *nodeExecutionID, + }) + if err != nil || !nodeExecutionExists { + m.metrics.MissingTaskExecution.Inc() + logger.Debugf(ctx, "Failed to get existing node execution [%+v] with err %v", nodeExecutionID, err) + if err != nil { + if ferr, ok := err.(errors.FlyteAdminError); ok { + return models.TaskExecution{}, errors.NewFlyteAdminErrorf(ferr.Code(), + "Failed to get existing node execution id: [%+v] with err: %v", nodeExecutionID, err) + } + } + return models.TaskExecution{}, fmt.Errorf("failed to get existing node execution id: [%+v]", nodeExecutionID) + } + taskExecutionModel, err := transformers.CreateTaskExecutionModel( transformers.CreateTaskExecutionModelInput{ Request: request, @@ -73,13 +89,13 @@ func (m *TaskExecutionManager) createTaskExecution( return models.TaskExecution{}, err } if err := m.db.TaskExecutionRepo().Create(ctx, *taskExecutionModel); err != nil { - logger.Debugf(ctx, "Failed to create task execution with task id [%+v] and node execution model [%+v] with err %v", - request.Event.TaskId, nodeExecutionModel, err) + logger.Debugf(ctx, "Failed to create task execution with task id [%+v] with err %v", + request.Event.TaskId, err) return models.TaskExecution{}, err } m.metrics.TaskExecutionsCreated.Inc() - m.metrics.ClosureSizeBytes.Observe(float64(len(nodeExecutionModel.Closure))) + m.metrics.ClosureSizeBytes.Observe(float64(len(taskExecutionModel.Closure))) logger.Debugf(ctx, "created task execution: %+v", request.Event.TaskId) return *taskExecutionModel, nil } @@ -116,16 +132,6 @@ func (m *TaskExecutionManager) CreateTaskExecutionEvent(ctx context.Context, req ctx = getTaskExecutionContext(ctx, &taskExecutionID) logger.Debugf(ctx, "Received task execution event for [%+v] transitioning to phase [%v]", taskExecutionID, request.Event.Phase) - nodeExecutionModel, err := util.GetNodeExecutionModel(ctx, m.db, nodeExecutionID) - if err != nil { - m.metrics.MissingTaskExecution.Inc() - logger.Debugf(ctx, "Failed to get existing node execution [%+v] with err %v", nodeExecutionID, err) - if ferr, ok := err.(errors.FlyteAdminError); ok { - return nil, errors.NewFlyteAdminErrorf(ferr.Code(), - "Failed to get existing execution node id:[%+v] with err: %v", nodeExecutionID, err) - } - return nil, fmt.Errorf("failed to get existing node execution id: [%+v] with err: %v", nodeExecutionID, err) - } // See if the task execution exists // - if it does check if the new phase is applicable and then update @@ -139,7 +145,7 @@ func (m *TaskExecutionManager) CreateTaskExecutionEvent(ctx context.Context, req logger.Debugf(ctx, "Failed to find existing task execution [%+v] with err %v", taskExecutionID, err) return nil, err } - _, err := m.createTaskExecution(ctx, nodeExecutionModel, &request) + _, err := m.createTaskExecution(ctx, &request) if err != nil { return nil, err } diff --git a/flyteadmin/pkg/manager/impl/task_execution_manager_test.go b/flyteadmin/pkg/manager/impl/task_execution_manager_test.go index f546255f77..9b417f0260 100644 --- a/flyteadmin/pkg/manager/impl/task_execution_manager_test.go +++ b/flyteadmin/pkg/manager/impl/task_execution_manager_test.go @@ -68,7 +68,7 @@ var retryAttemptValue = uint32(1) func addGetWorkflowExecutionCallback(repository repositories.RepositoryInterface) { repository.ExecutionRepo().(*repositoryMocks.MockExecutionRepo).SetGetCallback( - func(ctx context.Context, input interfaces.GetResourceInput) (models.Execution, error) { + func(ctx context.Context, input interfaces.Identifier) (models.Execution, error) { return models.Execution{ ExecutionKey: models.ExecutionKey{ Project: sampleNodeExecID.ExecutionId.Project, @@ -83,7 +83,7 @@ func addGetWorkflowExecutionCallback(repository repositories.RepositoryInterface func addGetNodeExecutionCallback(repository repositories.RepositoryInterface) { repository.NodeExecutionRepo().(*repositoryMocks.MockNodeExecutionRepo).SetGetCallback( - func(ctx context.Context, input interfaces.GetNodeExecutionInput) (models.NodeExecution, error) { + func(ctx context.Context, input interfaces.NodeExecutionResource) (models.NodeExecution, error) { return models.NodeExecution{ NodeExecutionKey: models.NodeExecutionKey{ NodeID: sampleNodeExecID.NodeId, @@ -100,7 +100,7 @@ func addGetNodeExecutionCallback(repository repositories.RepositoryInterface) { func addGetTaskCallback(repository repositories.RepositoryInterface) { repository.TaskRepo().(*repositoryMocks.MockTaskRepo).SetGetCallback( - func(input interfaces.GetResourceInput) (models.Task, error) { + func(input interfaces.Identifier) (models.Task, error) { return models.Task{ TaskKey: models.TaskKey{ Project: sampleTaskID.Project, @@ -298,17 +298,31 @@ func TestCreateTaskEvent_Update(t *testing.T) { func TestCreateTaskEvent_MissingExecution(t *testing.T) { repository := repositoryMocks.NewMockRepository() - expectedErr := errors.New("expected error") - repository.NodeExecutionRepo().(*repositoryMocks.MockNodeExecutionRepo).SetGetCallback( - func(ctx context.Context, input interfaces.GetNodeExecutionInput) (models.NodeExecution, error) { - return models.NodeExecution{}, expectedErr + expectedErr := flyteAdminErrors.NewFlyteAdminErrorf(codes.Internal, "expected error") + repository.TaskExecutionRepo().(*repositoryMocks.MockTaskExecutionRepo).SetGetCallback( + func(ctx context.Context, input interfaces.GetTaskExecutionInput) (models.TaskExecution, error) { + return models.TaskExecution{}, flyteAdminErrors.NewFlyteAdminError(codes.NotFound, "foo") }) + repository.NodeExecutionRepo().(*repositoryMocks.MockNodeExecutionRepo).ExistsFunction = func( + ctx context.Context, input interfaces.NodeExecutionResource) (bool, error) { + return false, expectedErr + } taskExecManager := NewTaskExecutionManager(repository, getMockExecutionsConfigProvider(), getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockTaskExecutionRemoteURL, nil) resp, err := taskExecManager.CreateTaskExecutionEvent(context.Background(), taskEventRequest) - assert.EqualError(t, err, "failed to get existing node execution id: [node_id:\"node-id\""+ + assert.EqualError(t, err, "Failed to get existing node execution id: [node_id:\"node-id\""+ " execution_id:<project:\"project\" domain:\"domain\" name:\"name\" > ] "+ "with err: expected error") assert.Nil(t, resp) + + repository.NodeExecutionRepo().(*repositoryMocks.MockNodeExecutionRepo).ExistsFunction = func( + ctx context.Context, input interfaces.NodeExecutionResource) (bool, error) { + return false, nil + } + taskExecManager = NewTaskExecutionManager(repository, getMockExecutionsConfigProvider(), getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockTaskExecutionRemoteURL, nil) + resp, err = taskExecManager.CreateTaskExecutionEvent(context.Background(), taskEventRequest) + assert.EqualError(t, err, "failed to get existing node execution id: [node_id:\"node-id\""+ + " execution_id:<project:\"project\" domain:\"domain\" name:\"name\" > ]") + assert.Nil(t, resp) } func TestCreateTaskEvent_CreateDatabaseError(t *testing.T) { diff --git a/flyteadmin/pkg/manager/impl/task_manager_test.go b/flyteadmin/pkg/manager/impl/task_manager_test.go index a11b399ec0..5e01aec1a6 100644 --- a/flyteadmin/pkg/manager/impl/task_manager_test.go +++ b/flyteadmin/pkg/manager/impl/task_manager_test.go @@ -71,7 +71,7 @@ func getMockTaskRepository() repositories.RepositoryInterface { func TestCreateTask(t *testing.T) { mockRepository := getMockTaskRepository() mockRepository.TaskRepo().(*repositoryMocks.MockTaskRepo).SetGetCallback( - func(input interfaces.GetResourceInput) (models.Task, error) { + func(input interfaces.Identifier) (models.Task, error) { return models.Task{}, errors.New("foo") }) var createCalled bool @@ -123,7 +123,7 @@ func TestCreateTask_CompilerError(t *testing.T) { func TestCreateTask_DatabaseError(t *testing.T) { repository := getMockTaskRepository() repository.TaskRepo().(*repositoryMocks.MockTaskRepo).SetGetCallback( - func(input interfaces.GetResourceInput) (models.Task, error) { + func(input interfaces.Identifier) (models.Task, error) { return models.Task{}, errors.New("foo") }) expectedErr := errors.New("expected error") @@ -141,7 +141,7 @@ func TestCreateTask_DatabaseError(t *testing.T) { func TestGetTask(t *testing.T) { repository := getMockTaskRepository() - taskGetFunc := func(input interfaces.GetResourceInput) (models.Task, error) { + taskGetFunc := func(input interfaces.Identifier) (models.Task, error) { assert.Equal(t, "project", input.Project) assert.Equal(t, "domain", input.Domain) assert.Equal(t, "name", input.Name) @@ -176,7 +176,7 @@ func TestGetTask(t *testing.T) { func TestGetTask_DatabaseError(t *testing.T) { repository := getMockTaskRepository() expectedErr := errors.New("expected error") - taskGetFunc := func(input interfaces.GetResourceInput) (models.Task, error) { + taskGetFunc := func(input interfaces.Identifier) (models.Task, error) { return models.Task{}, expectedErr } repository.TaskRepo().(*repositoryMocks.MockTaskRepo).SetGetCallback(taskGetFunc) @@ -190,7 +190,7 @@ func TestGetTask_DatabaseError(t *testing.T) { func TestGetTask_TransformerError(t *testing.T) { repository := getMockTaskRepository() - taskGetFunc := func(input interfaces.GetResourceInput) (models.Task, error) { + taskGetFunc := func(input interfaces.Identifier) (models.Task, error) { assert.Equal(t, "project", input.Project) assert.Equal(t, "domain", input.Domain) assert.Equal(t, "name", input.Name) diff --git a/flyteadmin/pkg/manager/impl/util/shared.go b/flyteadmin/pkg/manager/impl/util/shared.go index ef6b3ebf61..da67d9f925 100644 --- a/flyteadmin/pkg/manager/impl/util/shared.go +++ b/flyteadmin/pkg/manager/impl/util/shared.go @@ -44,7 +44,7 @@ func GetTask(ctx context.Context, repo repositories.RepositoryInterface, identif func GetWorkflowModel( ctx context.Context, repo repositories.RepositoryInterface, identifier core.Identifier) (models.Workflow, error) { - workflowModel, err := (repo).WorkflowRepo().Get(ctx, repoInterfaces.GetResourceInput{ + workflowModel, err := (repo).WorkflowRepo().Get(ctx, repoInterfaces.Identifier{ Project: identifier.Project, Domain: identifier.Domain, Name: identifier.Name, @@ -93,7 +93,7 @@ func GetWorkflow( func GetLaunchPlanModel( ctx context.Context, repo repositories.RepositoryInterface, identifier core.Identifier) (models.LaunchPlan, error) { - launchPlanModel, err := (repo).LaunchPlanRepo().Get(ctx, repoInterfaces.GetResourceInput{ + launchPlanModel, err := (repo).LaunchPlanRepo().Get(ctx, repoInterfaces.Identifier{ Project: identifier.Project, Domain: identifier.Domain, Name: identifier.Name, @@ -179,7 +179,7 @@ func ListActiveLaunchPlanVersionsFilters(project, domain string) ([]common.Inlin func GetExecutionModel( ctx context.Context, repo repositories.RepositoryInterface, identifier core.WorkflowExecutionIdentifier) ( *models.Execution, error) { - executionModel, err := repo.ExecutionRepo().Get(ctx, repoInterfaces.GetResourceInput{ + executionModel, err := repo.ExecutionRepo().Get(ctx, repoInterfaces.Identifier{ Project: identifier.Project, Domain: identifier.Domain, Name: identifier.Name, @@ -192,7 +192,7 @@ func GetExecutionModel( func GetNodeExecutionModel(ctx context.Context, repo repositories.RepositoryInterface, nodeExecutionIdentifier *core.NodeExecutionIdentifier) ( *models.NodeExecution, error) { - nodeExecutionModel, err := repo.NodeExecutionRepo().Get(ctx, repoInterfaces.GetNodeExecutionInput{ + nodeExecutionModel, err := repo.NodeExecutionRepo().Get(ctx, repoInterfaces.NodeExecutionResource{ NodeExecutionIdentifier: *nodeExecutionIdentifier, }) @@ -205,7 +205,7 @@ func GetNodeExecutionModel(ctx context.Context, repo repositories.RepositoryInte func GetTaskModel(ctx context.Context, repo repositories.RepositoryInterface, taskIdentifier *core.Identifier) ( *models.Task, error) { - taskModel, err := repo.TaskRepo().Get(ctx, repoInterfaces.GetResourceInput{ + taskModel, err := repo.TaskRepo().Get(ctx, repoInterfaces.Identifier{ Project: taskIdentifier.Project, Domain: taskIdentifier.Domain, Name: taskIdentifier.Name, diff --git a/flyteadmin/pkg/manager/impl/util/shared_test.go b/flyteadmin/pkg/manager/impl/util/shared_test.go index f9337d8019..25ba0044aa 100644 --- a/flyteadmin/pkg/manager/impl/util/shared_test.go +++ b/flyteadmin/pkg/manager/impl/util/shared_test.go @@ -52,7 +52,7 @@ func TestPopulateExecutionID_ExistingName(t *testing.T) { func TestGetTask(t *testing.T) { repository := repositoryMocks.NewMockRepository() - taskGetFunc := func(input interfaces.GetResourceInput) (models.Task, error) { + taskGetFunc := func(input interfaces.Identifier) (models.Task, error) { assert.Equal(t, project, input.Project) assert.Equal(t, domain, input.Domain) assert.Equal(t, name, input.Name) @@ -85,7 +85,7 @@ func TestGetTask(t *testing.T) { func TestGetTask_DatabaseError(t *testing.T) { repository := repositoryMocks.NewMockRepository() - taskGetFunc := func(input interfaces.GetResourceInput) (models.Task, error) { + taskGetFunc := func(input interfaces.Identifier) (models.Task, error) { return models.Task{}, errExpected } repository.TaskRepo().(*repositoryMocks.MockTaskRepo).SetGetCallback(taskGetFunc) @@ -102,7 +102,7 @@ func TestGetTask_DatabaseError(t *testing.T) { func TestGetTask_TransformerError(t *testing.T) { repository := repositoryMocks.NewMockRepository() - taskGetFunc := func(input interfaces.GetResourceInput) (models.Task, error) { + taskGetFunc := func(input interfaces.Identifier) (models.Task, error) { assert.Equal(t, project, input.Project) assert.Equal(t, domain, input.Domain) assert.Equal(t, name, input.Name) @@ -131,7 +131,7 @@ func TestGetTask_TransformerError(t *testing.T) { func TestGetWorkflowModel(t *testing.T) { repository := repositoryMocks.NewMockRepository() - workflowGetFunc := func(input interfaces.GetResourceInput) (models.Workflow, error) { + workflowGetFunc := func(input interfaces.Identifier) (models.Workflow, error) { assert.Equal(t, project, input.Project) assert.Equal(t, domain, input.Domain) assert.Equal(t, name, input.Name) @@ -165,7 +165,7 @@ func TestGetWorkflowModel(t *testing.T) { func TestGetWorkflowModel_DatabaseError(t *testing.T) { repository := repositoryMocks.NewMockRepository() - workflowGetFunc := func(input interfaces.GetResourceInput) (models.Workflow, error) { + workflowGetFunc := func(input interfaces.Identifier) (models.Workflow, error) { return models.Workflow{}, errExpected } repository.WorkflowRepo().(*repositoryMocks.MockWorkflowRepo).SetGetCallback(workflowGetFunc) @@ -208,7 +208,7 @@ func TestFetchAndGetWorkflowClosure_RemoteReadError(t *testing.T) { func TestGetWorkflow(t *testing.T) { repository := repositoryMocks.NewMockRepository() - workflowGetFunc := func(input interfaces.GetResourceInput) (models.Workflow, error) { + workflowGetFunc := func(input interfaces.Identifier) (models.Workflow, error) { assert.Equal(t, project, input.Project) assert.Equal(t, domain, input.Domain) assert.Equal(t, name, input.Name) @@ -249,7 +249,7 @@ func TestGetWorkflow(t *testing.T) { func TestGetLaunchPlanModel(t *testing.T) { repository := repositoryMocks.NewMockRepository() - getLaunchPlanFunc := func(input interfaces.GetResourceInput) (models.LaunchPlan, error) { + getLaunchPlanFunc := func(input interfaces.Identifier) (models.LaunchPlan, error) { assert.Equal(t, project, input.Project) assert.Equal(t, domain, input.Domain) assert.Equal(t, name, input.Name) @@ -281,7 +281,7 @@ func TestGetLaunchPlanModel(t *testing.T) { func TestGetLaunchPlanModel_DatabaseError(t *testing.T) { repository := repositoryMocks.NewMockRepository() - getLaunchPlanFunc := func(input interfaces.GetResourceInput) (models.LaunchPlan, error) { + getLaunchPlanFunc := func(input interfaces.Identifier) (models.LaunchPlan, error) { return models.LaunchPlan{}, errExpected } repository.LaunchPlanRepo().(*repositoryMocks.MockLaunchPlanRepo).SetGetCallback(getLaunchPlanFunc) @@ -298,7 +298,7 @@ func TestGetLaunchPlanModel_DatabaseError(t *testing.T) { func TestGetLaunchPlan(t *testing.T) { repository := repositoryMocks.NewMockRepository() - getLaunchPlanFunc := func(input interfaces.GetResourceInput) (models.LaunchPlan, error) { + getLaunchPlanFunc := func(input interfaces.Identifier) (models.LaunchPlan, error) { assert.Equal(t, project, input.Project) assert.Equal(t, domain, input.Domain) assert.Equal(t, name, input.Name) @@ -330,7 +330,7 @@ func TestGetLaunchPlan(t *testing.T) { func TestGetLaunchPlan_TransformerError(t *testing.T) { repository := repositoryMocks.NewMockRepository() - getLaunchPlanFunc := func(input interfaces.GetResourceInput) (models.LaunchPlan, error) { + getLaunchPlanFunc := func(input interfaces.Identifier) (models.LaunchPlan, error) { assert.Equal(t, project, input.Project) assert.Equal(t, domain, input.Domain) assert.Equal(t, name, input.Name) diff --git a/flyteadmin/pkg/manager/impl/util/single_task_execution.go b/flyteadmin/pkg/manager/impl/util/single_task_execution.go index d245563843..88d0af12cd 100644 --- a/flyteadmin/pkg/manager/impl/util/single_task_execution.go +++ b/flyteadmin/pkg/manager/impl/util/single_task_execution.go @@ -81,7 +81,7 @@ func CreateOrGetWorkflowModel( Name: generateWorkflowNameFromTask(taskIdentifier.Name), Version: taskIdentifier.Version, } - workflowModel, err := db.WorkflowRepo().Get(ctx, repositoryInterfaces.GetResourceInput{ + workflowModel, err := db.WorkflowRepo().Get(ctx, repositoryInterfaces.Identifier{ Project: workflowIdentifier.Project, Domain: workflowIdentifier.Domain, Name: workflowIdentifier.Name, @@ -145,7 +145,7 @@ func CreateOrGetWorkflowModel( logger.Warningf(ctx, "Failed to set skeleton workflow state to system-generated: %v", err) return nil, err } - workflowModel, err = db.WorkflowRepo().Get(ctx, repositoryInterfaces.GetResourceInput{ + workflowModel, err = db.WorkflowRepo().Get(ctx, repositoryInterfaces.Identifier{ Project: workflowIdentifier.Project, Domain: workflowIdentifier.Domain, Name: workflowIdentifier.Name, diff --git a/flyteadmin/pkg/manager/impl/util/single_task_execution_test.go b/flyteadmin/pkg/manager/impl/util/single_task_execution_test.go index be083a43b8..cfaaeb4a29 100644 --- a/flyteadmin/pkg/manager/impl/util/single_task_execution_test.go +++ b/flyteadmin/pkg/manager/impl/util/single_task_execution_test.go @@ -78,7 +78,7 @@ func TestCreateOrGetWorkflowModel(t *testing.T) { } repository.WorkflowRepo().(*repositoryMocks.MockWorkflowRepo).SetCreateCallback(workflowcreateFunc) - workflowGetFunc := func(input interfaces.GetResourceInput) (models.Workflow, error) { + workflowGetFunc := func(input interfaces.Identifier) (models.Workflow, error) { if getCalledCount == 0 { getCalledCount++ return models.Workflow{}, flyteAdminErrors.NewFlyteAdminErrorf(codes.NotFound, "not found") @@ -176,7 +176,7 @@ func TestCreateOrGetLaunchPlan(t *testing.T) { } repository.LaunchPlanRepo().(*repositoryMocks.MockLaunchPlanRepo).SetCreateCallback(launchPlanCreateFunc) - launchPlanGetFunc := func(input interfaces.GetResourceInput) (models.LaunchPlan, error) { + launchPlanGetFunc := func(input interfaces.Identifier) (models.LaunchPlan, error) { if getCalledCount == 0 { getCalledCount++ return models.LaunchPlan{}, flyteAdminErrors.NewFlyteAdminErrorf(codes.NotFound, "not found") diff --git a/flyteadmin/pkg/manager/impl/workflow_manager_test.go b/flyteadmin/pkg/manager/impl/workflow_manager_test.go index 606564d0b7..03ff87c02b 100644 --- a/flyteadmin/pkg/manager/impl/workflow_manager_test.go +++ b/flyteadmin/pkg/manager/impl/workflow_manager_test.go @@ -79,7 +79,7 @@ func getMockRepository(workflowOnGet bool) repositories.RepositoryInterface { mockRepo := repositoryMocks.NewMockRepository() if !workflowOnGet { mockRepo.(*repositoryMocks.MockRepository).WorkflowRepo().(*repositoryMocks.MockWorkflowRepo).SetGetCallback( - func(input interfaces.GetResourceInput) (models.Workflow, error) { + func(input interfaces.Identifier) (models.Workflow, error) { return models.Workflow{}, adminErrors.NewFlyteAdminError(codes.NotFound, "not found") }) } @@ -261,7 +261,7 @@ func TestCreateWorkflow_DatabaseError(t *testing.T) { func TestGetWorkflow(t *testing.T) { repository := repositoryMocks.NewMockRepository() - workflowGetFunc := func(input interfaces.GetResourceInput) (models.Workflow, error) { + workflowGetFunc := func(input interfaces.Identifier) (models.Workflow, error) { assert.Equal(t, "project", input.Project) assert.Equal(t, "domain", input.Domain) assert.Equal(t, "name", input.Name) @@ -308,7 +308,7 @@ func TestGetWorkflow(t *testing.T) { func TestGetWorkflow_DatabaseError(t *testing.T) { repository := repositoryMocks.NewMockRepository() expectedErr := errors.New("expected error") - workflowGetFunc := func(input interfaces.GetResourceInput) (models.Workflow, error) { + workflowGetFunc := func(input interfaces.Identifier) (models.Workflow, error) { return models.Workflow{}, expectedErr } repository.WorkflowRepo().(*repositoryMocks.MockWorkflowRepo).SetGetCallback(workflowGetFunc) @@ -324,7 +324,7 @@ func TestGetWorkflow_DatabaseError(t *testing.T) { func TestGetWorkflow_TransformerError(t *testing.T) { repository := repositoryMocks.NewMockRepository() - workflowGetFunc := func(input interfaces.GetResourceInput) (models.Workflow, error) { + workflowGetFunc := func(input interfaces.Identifier) (models.Workflow, error) { assert.Equal(t, "project", input.Project) assert.Equal(t, "domain", input.Domain) assert.Equal(t, "name", input.Name) diff --git a/flyteadmin/pkg/repositories/gormimpl/common.go b/flyteadmin/pkg/repositories/gormimpl/common.go index 1ecd5c65da..65ce05fbc7 100644 --- a/flyteadmin/pkg/repositories/gormimpl/common.go +++ b/flyteadmin/pkg/repositories/gormimpl/common.go @@ -14,16 +14,10 @@ import ( const Project = "project" const Domain = "domain" const Name = "name" -const Version = "version" -const Closure = "closure" const Description = "description" const ResourceType = "resource_type" const State = "state" - -const ProjectID = "project_id" -const ProjectName = "project_name" -const DomainID = "domain_id" -const DomainName = "domain_name" +const ID = "id" const executionTableName = "executions" const namedEntityMetadataTableName = "named_entity_metadata" diff --git a/flyteadmin/pkg/repositories/gormimpl/execution_repo.go b/flyteadmin/pkg/repositories/gormimpl/execution_repo.go index 8c5abb3f1d..ac78d8002b 100644 --- a/flyteadmin/pkg/repositories/gormimpl/execution_repo.go +++ b/flyteadmin/pkg/repositories/gormimpl/execution_repo.go @@ -32,7 +32,7 @@ func (r *ExecutionRepo) Create(ctx context.Context, input models.Execution) erro return nil } -func (r *ExecutionRepo) Get(ctx context.Context, input interfaces.GetResourceInput) (models.Execution, error) { +func (r *ExecutionRepo) Get(ctx context.Context, input interfaces.Identifier) (models.Execution, error) { var execution models.Execution timer := r.metrics.GetDuration.Start() tx := r.db.Where(&models.Execution{ @@ -129,6 +129,24 @@ func (r *ExecutionRepo) List(ctx context.Context, input interfaces.ListResourceI }, nil } +func (r *ExecutionRepo) Exists(ctx context.Context, input interfaces.Identifier) (bool, error) { + var execution models.Execution + timer := r.metrics.ExistsDuration.Start() + // Only select the id field (uint) to check for existence. + tx := r.db.Select(ID).Where(&models.Execution{ + ExecutionKey: models.ExecutionKey{ + Project: input.Project, + Domain: input.Domain, + Name: input.Name, + }, + }).Take(&execution) + timer.Stop() + if tx.Error != nil { + return false, r.errorTransformer.ToFlyteAdminError(tx.Error) + } + return !tx.RecordNotFound(), nil +} + // Returns an instance of ExecutionRepoInterface func NewExecutionRepo( db *gorm.DB, errorTransformer errors.ErrorTransformer, scope promutils.Scope) interfaces.ExecutionRepoInterface { diff --git a/flyteadmin/pkg/repositories/gormimpl/execution_repo_test.go b/flyteadmin/pkg/repositories/gormimpl/execution_repo_test.go index 69a6a818f4..304987a964 100644 --- a/flyteadmin/pkg/repositories/gormimpl/execution_repo_test.go +++ b/flyteadmin/pkg/repositories/gormimpl/execution_repo_test.go @@ -164,7 +164,7 @@ func TestGetExecution(t *testing.T) { GlobalMock.NewMock().WithQuery(`SELECT * FROM "executions" WHERE "executions"."deleted_at" IS NULL AND ` + `(("executions"."execution_project" = project) AND ("executions"."execution_domain" = domain) AND ` + `("executions"."execution_name" = 1)) LIMIT 1`).WithReply(executions) - output, err := executionRepo.Get(context.Background(), interfaces.GetResourceInput{ + output, err := executionRepo.Get(context.Background(), interfaces.Identifier{ Project: "project", Domain: "domain", Name: "1", @@ -379,3 +379,39 @@ func TestListExecutionsForWorkflow(t *testing.T) { assert.Equal(t, time.Hour, execution.Duration) } } + +func TestExecutionExists(t *testing.T) { + executionRepo := NewExecutionRepo(GetDbForTest(t), errors.NewTestErrorTransformer(), mockScope.NewTestScope()) + expectedExecution := models.Execution{ + BaseModel: models.BaseModel{ + ID: uint(20), + }, + ExecutionKey: models.ExecutionKey{ + Project: "project", + Domain: "domain", + Name: "1", + }, + LaunchPlanID: uint(2), + Phase: core.WorkflowExecution_SUCCEEDED.String(), + Closure: []byte{1, 2}, + WorkflowID: uint(3), + Spec: []byte{3, 4}, + } + + executions := make([]map[string]interface{}, 0) + execution := getMockExecutionResponseFromDb(expectedExecution) + executions = append(executions, execution) + + GlobalMock := mocket.Catcher.Reset() + // Only match on queries that append expected filters + GlobalMock.NewMock().WithQuery(`SELECT id FROM "executions" WHERE "executions"."deleted_at" IS NULL AND ` + + `(("executions"."execution_project" = project) AND ("executions"."execution_domain" = domain) AND ` + + `("executions"."execution_name" = 1)) LIMIT 1`).WithReply(executions) + exists, err := executionRepo.Exists(context.Background(), interfaces.Identifier{ + Project: "project", + Domain: "domain", + Name: "1", + }) + assert.NoError(t, err) + assert.True(t, exists) +} diff --git a/flyteadmin/pkg/repositories/gormimpl/launch_plan_repo.go b/flyteadmin/pkg/repositories/gormimpl/launch_plan_repo.go index 957274df34..277033f0ba 100644 --- a/flyteadmin/pkg/repositories/gormimpl/launch_plan_repo.go +++ b/flyteadmin/pkg/repositories/gormimpl/launch_plan_repo.go @@ -48,7 +48,7 @@ func (r *LaunchPlanRepo) Update(ctx context.Context, input models.LaunchPlan) er return nil } -func (r *LaunchPlanRepo) Get(ctx context.Context, input interfaces.GetResourceInput) (models.LaunchPlan, error) { +func (r *LaunchPlanRepo) Get(ctx context.Context, input interfaces.Identifier) (models.LaunchPlan, error) { var launchPlan models.LaunchPlan timer := r.metrics.GetDuration.Start() tx := r.db.Where(&models.LaunchPlan{ diff --git a/flyteadmin/pkg/repositories/gormimpl/launch_plan_repo_test.go b/flyteadmin/pkg/repositories/gormimpl/launch_plan_repo_test.go index 7371ebea85..9ecdaf7e42 100644 --- a/flyteadmin/pkg/repositories/gormimpl/launch_plan_repo_test.go +++ b/flyteadmin/pkg/repositories/gormimpl/launch_plan_repo_test.go @@ -78,7 +78,7 @@ func TestGetLaunchPlan(t *testing.T) { `SELECT * FROM "launch_plans" WHERE "launch_plans"."deleted_at" IS NULL AND ` + `(("launch_plans"."project" = project) AND ("launch_plans"."domain" = domain) AND ` + `("launch_plans"."name" = name) AND ("launch_plans"."version" = XYZ)) LIMIT 1`).WithReply(launchPlans) - output, err := launchPlanRepo.Get(context.Background(), interfaces.GetResourceInput{ + output, err := launchPlanRepo.Get(context.Background(), interfaces.Identifier{ Project: project, Domain: domain, Name: name, diff --git a/flyteadmin/pkg/repositories/gormimpl/metrics.go b/flyteadmin/pkg/repositories/gormimpl/metrics.go index 49626efecc..f00225b4b8 100644 --- a/flyteadmin/pkg/repositories/gormimpl/metrics.go +++ b/flyteadmin/pkg/repositories/gormimpl/metrics.go @@ -15,6 +15,7 @@ type gormMetrics struct { ListDuration promutils.StopWatch ListIdentifiersDuration promutils.StopWatch DeleteDuration promutils.StopWatch + ExistsDuration promutils.StopWatch } func newMetrics(scope promutils.Scope) gormMetrics { @@ -31,5 +32,6 @@ func newMetrics(scope promutils.Scope) gormMetrics { ListIdentifiersDuration: scope.MustNewStopWatch( "list_identifiers", "time taken to list identifier entries", time.Millisecond), DeleteDuration: scope.MustNewStopWatch("delete", "time taken to delete an individual entry", time.Millisecond), + ExistsDuration: scope.MustNewStopWatch("exists", "time taken to determine whether an individual entry exists", time.Millisecond), } } diff --git a/flyteadmin/pkg/repositories/gormimpl/node_execution_repo.go b/flyteadmin/pkg/repositories/gormimpl/node_execution_repo.go index 65a33f1894..d3d1615efc 100644 --- a/flyteadmin/pkg/repositories/gormimpl/node_execution_repo.go +++ b/flyteadmin/pkg/repositories/gormimpl/node_execution_repo.go @@ -44,7 +44,7 @@ func (r *NodeExecutionRepo) Create(ctx context.Context, event *models.NodeExecut return nil } -func (r *NodeExecutionRepo) Get(ctx context.Context, input interfaces.GetNodeExecutionInput) (models.NodeExecution, error) { +func (r *NodeExecutionRepo) Get(ctx context.Context, input interfaces.NodeExecutionResource) (models.NodeExecution, error) { var nodeExecution models.NodeExecution timer := r.metrics.GetDuration.Start() tx := r.db.Where(&models.NodeExecution{ @@ -166,6 +166,26 @@ func (r *NodeExecutionRepo) ListEvents( }, nil } +func (r *NodeExecutionRepo) Exists(ctx context.Context, input interfaces.NodeExecutionResource) (bool, error) { + var nodeExecution models.NodeExecution + timer := r.metrics.ExistsDuration.Start() + tx := r.db.Select(ID).Where(&models.NodeExecution{ + NodeExecutionKey: models.NodeExecutionKey{ + NodeID: input.NodeExecutionIdentifier.NodeId, + ExecutionKey: models.ExecutionKey{ + Project: input.NodeExecutionIdentifier.ExecutionId.Project, + Domain: input.NodeExecutionIdentifier.ExecutionId.Domain, + Name: input.NodeExecutionIdentifier.ExecutionId.Name, + }, + }, + }).Take(&nodeExecution) + timer.Stop() + if tx.Error != nil { + return false, r.errorTransformer.ToFlyteAdminError(tx.Error) + } + return !tx.RecordNotFound(), nil +} + // Returns an instance of NodeExecutionRepoInterface func NewNodeExecutionRepo( db *gorm.DB, errorTransformer errors.ErrorTransformer, diff --git a/flyteadmin/pkg/repositories/gormimpl/node_execution_repo_test.go b/flyteadmin/pkg/repositories/gormimpl/node_execution_repo_test.go index 4ce7da445a..5dfcad1343 100644 --- a/flyteadmin/pkg/repositories/gormimpl/node_execution_repo_test.go +++ b/flyteadmin/pkg/repositories/gormimpl/node_execution_repo_test.go @@ -187,7 +187,7 @@ func TestGetNodeExecution(t *testing.T) { `(("node_executions"."execution_project" = execution_project) AND ("node_executions"."execution_domain" ` + `= execution_domain) AND ("node_executions"."execution_name" = execution_name) AND ("node_executions".` + `"node_id" = 1)) LIMIT 1`).WithReply(nodeExecutions) - output, err := nodeExecutionRepo.Get(context.Background(), interfaces.GetNodeExecutionInput{ + output, err := nodeExecutionRepo.Get(context.Background(), interfaces.NodeExecutionResource{ NodeExecutionIdentifier: core.NodeExecutionIdentifier{ NodeId: "1", ExecutionId: &core.WorkflowExecutionIdentifier{ @@ -432,3 +432,46 @@ func TestListNodeExecutionEvents_MissingParameters(t *testing.T) { }) assert.EqualError(t, err, "missing and/or invalid parameters: filters") } + +func TestNodeExecutionExists(t *testing.T) { + nodeExecutionRepo := NewNodeExecutionRepo(GetDbForTest(t), errors.NewTestErrorTransformer(), mockScope.NewTestScope()) + id := uint(10) + expectedNodeExecution := models.NodeExecution{ + NodeExecutionKey: models.NodeExecutionKey{ + NodeID: "1", + ExecutionKey: models.ExecutionKey{ + Project: "project", + Domain: "domain", + Name: "1", + }, + }, + BaseModel: models.BaseModel{ + ID: id, + }, + Phase: nodePhase, + Closure: []byte("closure"), + } + + nodeExecutions := make([]map[string]interface{}, 0) + nodeExecution := getMockNodeExecutionResponseFromDb(expectedNodeExecution) + nodeExecutions = append(nodeExecutions, nodeExecution) + + GlobalMock := mocket.Catcher.Reset() + GlobalMock.NewMock().WithQuery( + `SELECT id FROM "node_executions" WHERE "node_executions"."deleted_at" IS NULL AND ` + + `(("node_executions"."execution_project" = execution_project) AND ("node_executions"."execution_domain" = ` + + `execution_domain) AND ("node_executions"."execution_name" = execution_name) AND ` + + `("node_executions"."node_id" = 1)) LIMIT 1`).WithReply(nodeExecutions) + exists, err := nodeExecutionRepo.Exists(context.Background(), interfaces.NodeExecutionResource{ + NodeExecutionIdentifier: core.NodeExecutionIdentifier{ + NodeId: "1", + ExecutionId: &core.WorkflowExecutionIdentifier{ + Project: "execution_project", + Domain: "execution_domain", + Name: "execution_name", + }, + }, + }) + assert.NoError(t, err) + assert.True(t, exists) +} diff --git a/flyteadmin/pkg/repositories/gormimpl/task_repo.go b/flyteadmin/pkg/repositories/gormimpl/task_repo.go index bc6bd5a5ae..36c3e46376 100644 --- a/flyteadmin/pkg/repositories/gormimpl/task_repo.go +++ b/flyteadmin/pkg/repositories/gormimpl/task_repo.go @@ -30,7 +30,7 @@ func (r *TaskRepo) Create(ctx context.Context, input models.Task) error { return nil } -func (r *TaskRepo) Get(ctx context.Context, input interfaces.GetResourceInput) (models.Task, error) { +func (r *TaskRepo) Get(ctx context.Context, input interfaces.Identifier) (models.Task, error) { var task models.Task timer := r.metrics.GetDuration.Start() tx := r.db.Where(&models.Task{ diff --git a/flyteadmin/pkg/repositories/gormimpl/task_repo_test.go b/flyteadmin/pkg/repositories/gormimpl/task_repo_test.go index 0dd7a69186..b3cd88ae0a 100644 --- a/flyteadmin/pkg/repositories/gormimpl/task_repo_test.go +++ b/flyteadmin/pkg/repositories/gormimpl/task_repo_test.go @@ -58,7 +58,7 @@ func TestGetTask(t *testing.T) { `SELECT * FROM "tasks" WHERE "tasks"."deleted_at" IS NULL AND (("tasks"."project" = project) ` + `AND ("tasks"."domain" = domain) AND ("tasks"."name" = name) AND ("tasks"."version" = XYZ)) LIMIT 1`). WithReply(tasks) - output, err := taskRepo.Get(context.Background(), interfaces.GetResourceInput{ + output, err := taskRepo.Get(context.Background(), interfaces.Identifier{ Project: project, Domain: domain, Name: name, diff --git a/flyteadmin/pkg/repositories/gormimpl/workflow_repo.go b/flyteadmin/pkg/repositories/gormimpl/workflow_repo.go index 80ebce3db3..aecbd310b6 100644 --- a/flyteadmin/pkg/repositories/gormimpl/workflow_repo.go +++ b/flyteadmin/pkg/repositories/gormimpl/workflow_repo.go @@ -30,7 +30,7 @@ func (r *WorkflowRepo) Create(ctx context.Context, input models.Workflow) error return nil } -func (r *WorkflowRepo) Get(ctx context.Context, input interfaces.GetResourceInput) (models.Workflow, error) { +func (r *WorkflowRepo) Get(ctx context.Context, input interfaces.Identifier) (models.Workflow, error) { var workflow models.Workflow timer := r.metrics.GetDuration.Start() tx := r.db.Where(&models.Workflow{ diff --git a/flyteadmin/pkg/repositories/gormimpl/workflow_repo_test.go b/flyteadmin/pkg/repositories/gormimpl/workflow_repo_test.go index b53b44d750..68bb8b54e4 100644 --- a/flyteadmin/pkg/repositories/gormimpl/workflow_repo_test.go +++ b/flyteadmin/pkg/repositories/gormimpl/workflow_repo_test.go @@ -56,7 +56,7 @@ func TestGetWorkflow(t *testing.T) { GlobalMock.NewMock().WithQuery( `(("workflows"."project" = project) AND ("workflows"."domain" = domain) AND ` + `("workflows"."name" = name) AND ("workflows"."version" = XYZ))`).WithReply(workflows) - output, err := workflowRepo.Get(context.Background(), interfaces.GetResourceInput{ + output, err := workflowRepo.Get(context.Background(), interfaces.Identifier{ Project: project, Domain: domain, Name: name, diff --git a/flyteadmin/pkg/repositories/interfaces/common.go b/flyteadmin/pkg/repositories/interfaces/common.go index f66024615e..b2bb1222c2 100644 --- a/flyteadmin/pkg/repositories/interfaces/common.go +++ b/flyteadmin/pkg/repositories/interfaces/common.go @@ -5,7 +5,7 @@ import ( ) // Parameters for getting an individual resource. -type GetResourceInput struct { +type Identifier struct { Project string Domain string Name string diff --git a/flyteadmin/pkg/repositories/interfaces/execution_repo.go b/flyteadmin/pkg/repositories/interfaces/execution_repo.go index b414265de1..349a917c42 100644 --- a/flyteadmin/pkg/repositories/interfaces/execution_repo.go +++ b/flyteadmin/pkg/repositories/interfaces/execution_repo.go @@ -16,9 +16,11 @@ type ExecutionRepoInterface interface { // This updates only an existing execution model with all non-empty fields in the input. UpdateExecution(ctx context.Context, execution models.Execution) error // Returns a matching execution if it exists. - Get(ctx context.Context, input GetResourceInput) (models.Execution, error) + Get(ctx context.Context, input Identifier) (models.Execution, error) // Returns executions matching query parameters. A limit must be provided for the results page size. List(ctx context.Context, input ListResourceInput) (ExecutionCollectionOutput, error) + // Returns a matching execution if it exists. + Exists(ctx context.Context, input Identifier) (bool, error) } // Response format for a query on workflows. diff --git a/flyteadmin/pkg/repositories/interfaces/launch_plan_repo.go b/flyteadmin/pkg/repositories/interfaces/launch_plan_repo.go index ef1dd02296..61afa93993 100644 --- a/flyteadmin/pkg/repositories/interfaces/launch_plan_repo.go +++ b/flyteadmin/pkg/repositories/interfaces/launch_plan_repo.go @@ -18,7 +18,7 @@ type LaunchPlanRepoInterface interface { // (and deactivates the formerly active version if the toDisable model exists). SetActive(ctx context.Context, toEnable models.LaunchPlan, toDisable *models.LaunchPlan) error // Returns a matching launch plan if it exists. - Get(ctx context.Context, input GetResourceInput) (models.LaunchPlan, error) + Get(ctx context.Context, input Identifier) (models.LaunchPlan, error) // Returns launch plan revisions matching query parameters. A limit must be provided for the results page size. List(ctx context.Context, input ListResourceInput) (LaunchPlanCollectionOutput, error) // Returns a list of identifiers for launch plans. A limit must be provided for the results page size. diff --git a/flyteadmin/pkg/repositories/interfaces/node_execution_repo.go b/flyteadmin/pkg/repositories/interfaces/node_execution_repo.go index 415cdb358c..6c8aecef42 100644 --- a/flyteadmin/pkg/repositories/interfaces/node_execution_repo.go +++ b/flyteadmin/pkg/repositories/interfaces/node_execution_repo.go @@ -15,14 +15,16 @@ type NodeExecutionRepoInterface interface { // This execution and event correspond to entire graph (workflow) executions. Update(ctx context.Context, event *models.NodeExecutionEvent, execution *models.NodeExecution) error // Returns a matching execution if it exists. - Get(ctx context.Context, input GetNodeExecutionInput) (models.NodeExecution, error) + Get(ctx context.Context, input NodeExecutionResource) (models.NodeExecution, error) // Returns node executions matching query parameters. A limit must be provided for the results page size. List(ctx context.Context, input ListResourceInput) (NodeExecutionCollectionOutput, error) // Return node execution events matching query parameters. A limit must be provided for the results page size. ListEvents(ctx context.Context, input ListResourceInput) (NodeExecutionEventCollectionOutput, error) + // Returns whether a matching execution exists. + Exists(ctx context.Context, input NodeExecutionResource) (bool, error) } -type GetNodeExecutionInput struct { +type NodeExecutionResource struct { NodeExecutionIdentifier core.NodeExecutionIdentifier } diff --git a/flyteadmin/pkg/repositories/interfaces/task_repo.go b/flyteadmin/pkg/repositories/interfaces/task_repo.go index 0c44f2130d..f4d35377ef 100644 --- a/flyteadmin/pkg/repositories/interfaces/task_repo.go +++ b/flyteadmin/pkg/repositories/interfaces/task_repo.go @@ -11,7 +11,7 @@ type TaskRepoInterface interface { // Inserts a task model into the database store. Create(ctx context.Context, input models.Task) error // Returns a matching task if it exists. - Get(ctx context.Context, input GetResourceInput) (models.Task, error) + Get(ctx context.Context, input Identifier) (models.Task, error) // Returns task revisions matching query parameters. A limit must be provided for the results page size. List(ctx context.Context, input ListResourceInput) (TaskCollectionOutput, error) // Returns tasks with only the project, name, and domain filled in. diff --git a/flyteadmin/pkg/repositories/interfaces/workflow_repo.go b/flyteadmin/pkg/repositories/interfaces/workflow_repo.go index 65912fc98b..55206124ad 100644 --- a/flyteadmin/pkg/repositories/interfaces/workflow_repo.go +++ b/flyteadmin/pkg/repositories/interfaces/workflow_repo.go @@ -11,7 +11,7 @@ type WorkflowRepoInterface interface { // Inserts a workflow model into the database store. Create(ctx context.Context, input models.Workflow) error // Returns a matching workflow if it exists. - Get(ctx context.Context, input GetResourceInput) (models.Workflow, error) + Get(ctx context.Context, input Identifier) (models.Workflow, error) // Returns workflow revisions matching query parameters. A limit must be provided for the results page size. List(ctx context.Context, input ListResourceInput) (WorkflowCollectionOutput, error) ListIdentifiers(ctx context.Context, input ListResourceInput) (WorkflowCollectionOutput, error) diff --git a/flyteadmin/pkg/repositories/mocks/execution_repo.go b/flyteadmin/pkg/repositories/mocks/execution_repo.go index 0101b9e781..f06bfafa64 100644 --- a/flyteadmin/pkg/repositories/mocks/execution_repo.go +++ b/flyteadmin/pkg/repositories/mocks/execution_repo.go @@ -10,7 +10,7 @@ import ( type CreateExecutionFunc func(ctx context.Context, input models.Execution) error type UpdateFunc func(ctx context.Context, event models.ExecutionEvent, execution models.Execution) error type UpdateExecutionFunc func(ctx context.Context, execution models.Execution) error -type GetExecutionFunc func(ctx context.Context, input interfaces.GetResourceInput) (models.Execution, error) +type GetExecutionFunc func(ctx context.Context, input interfaces.Identifier) (models.Execution, error) type ListExecutionFunc func(ctx context.Context, input interfaces.ListResourceInput) ( interfaces.ExecutionCollectionOutput, error) @@ -20,6 +20,7 @@ type MockExecutionRepo struct { updateExecutionFunc UpdateExecutionFunc getFunction GetExecutionFunc listFunction ListExecutionFunc + ExistsFunction func(ctx context.Context, input interfaces.Identifier) (bool, error) } func (r *MockExecutionRepo) Create(ctx context.Context, input models.Execution) error { @@ -55,7 +56,7 @@ func (r *MockExecutionRepo) SetUpdateExecutionCallback(updateExecutionFunc Updat r.updateExecutionFunc = updateExecutionFunc } -func (r *MockExecutionRepo) Get(ctx context.Context, input interfaces.GetResourceInput) (models.Execution, error) { +func (r *MockExecutionRepo) Get(ctx context.Context, input interfaces.Identifier) (models.Execution, error) { if r.getFunction != nil { return r.getFunction(ctx, input) } @@ -78,6 +79,13 @@ func (r *MockExecutionRepo) SetListCallback(listFunction ListExecutionFunc) { r.listFunction = listFunction } +func (r *MockExecutionRepo) Exists(ctx context.Context, input interfaces.Identifier) (bool, error) { + if r.ExistsFunction != nil { + return r.ExistsFunction(ctx, input) + } + return true, nil +} + func NewMockExecutionRepo() interfaces.ExecutionRepoInterface { return &MockExecutionRepo{} } diff --git a/flyteadmin/pkg/repositories/mocks/launch_plan_repo.go b/flyteadmin/pkg/repositories/mocks/launch_plan_repo.go index eacbf902aa..1d53fdc224 100644 --- a/flyteadmin/pkg/repositories/mocks/launch_plan_repo.go +++ b/flyteadmin/pkg/repositories/mocks/launch_plan_repo.go @@ -11,7 +11,7 @@ import ( type CreateLaunchPlanFunc func(input models.LaunchPlan) error type UpdateLaunchPlanFunc func(input models.LaunchPlan) error type SetActiveLaunchPlanFunc func(toEnable models.LaunchPlan, toDisable *models.LaunchPlan) error -type GetLaunchPlanFunc func(input interfaces.GetResourceInput) (models.LaunchPlan, error) +type GetLaunchPlanFunc func(input interfaces.Identifier) (models.LaunchPlan, error) type ListLaunchPlanFunc func(input interfaces.ListResourceInput) (interfaces.LaunchPlanCollectionOutput, error) type ListLaunchPlanIdentifiersFunc func(input interfaces.ListResourceInput) ( interfaces.LaunchPlanCollectionOutput, error) @@ -60,7 +60,7 @@ func (r *MockLaunchPlanRepo) SetSetActiveCallback(setActiveFunction SetActiveLau } func (r *MockLaunchPlanRepo) Get( - ctx context.Context, input interfaces.GetResourceInput) (models.LaunchPlan, error) { + ctx context.Context, input interfaces.Identifier) (models.LaunchPlan, error) { if r.getFunction != nil { return r.getFunction(input) } diff --git a/flyteadmin/pkg/repositories/mocks/node_execution_repo.go b/flyteadmin/pkg/repositories/mocks/node_execution_repo.go index 0e0179fc79..fd821d557f 100644 --- a/flyteadmin/pkg/repositories/mocks/node_execution_repo.go +++ b/flyteadmin/pkg/repositories/mocks/node_execution_repo.go @@ -9,7 +9,7 @@ import ( type CreateNodeExecutionFunc func(ctx context.Context, event *models.NodeExecutionEvent, input *models.NodeExecution) error type UpdateNodeExecutionFunc func(ctx context.Context, event *models.NodeExecutionEvent, nodeExecution *models.NodeExecution) error -type GetNodeExecutionFunc func(ctx context.Context, input interfaces.GetNodeExecutionInput) (models.NodeExecution, error) +type GetNodeExecutionFunc func(ctx context.Context, input interfaces.NodeExecutionResource) (models.NodeExecution, error) type ListNodeExecutionFunc func(ctx context.Context, input interfaces.ListResourceInput) ( interfaces.NodeExecutionCollectionOutput, error) type ListNodeExecutionEventFunc func(ctx context.Context, input interfaces.ListResourceInput) ( @@ -21,6 +21,7 @@ type MockNodeExecutionRepo struct { getFunction GetNodeExecutionFunc listFunction ListNodeExecutionFunc listEventFunction ListNodeExecutionEventFunc + ExistsFunction func(ctx context.Context, input interfaces.NodeExecutionResource) (bool, error) } func (r *MockNodeExecutionRepo) Create(ctx context.Context, event *models.NodeExecutionEvent, input *models.NodeExecution) error { @@ -45,7 +46,7 @@ func (r *MockNodeExecutionRepo) SetUpdateCallback(updateFunction UpdateNodeExecu r.updateFunction = updateFunction } -func (r *MockNodeExecutionRepo) Get(ctx context.Context, input interfaces.GetNodeExecutionInput) (models.NodeExecution, error) { +func (r *MockNodeExecutionRepo) Get(ctx context.Context, input interfaces.NodeExecutionResource) (models.NodeExecution, error) { if r.getFunction != nil { return r.getFunction(ctx, input) } @@ -80,6 +81,13 @@ func (r *MockNodeExecutionRepo) SetListEventCallback(listEventFunction ListNodeE r.listEventFunction = listEventFunction } +func (r *MockNodeExecutionRepo) Exists(ctx context.Context, input interfaces.NodeExecutionResource) (bool, error) { + if r.ExistsFunction != nil { + return r.ExistsFunction(ctx, input) + } + return true, nil +} + func NewMockNodeExecutionRepo() interfaces.NodeExecutionRepoInterface { return &MockNodeExecutionRepo{} } diff --git a/flyteadmin/pkg/repositories/mocks/task_repo.go b/flyteadmin/pkg/repositories/mocks/task_repo.go index 6ca6f7406c..5bbefedb9a 100644 --- a/flyteadmin/pkg/repositories/mocks/task_repo.go +++ b/flyteadmin/pkg/repositories/mocks/task_repo.go @@ -9,7 +9,7 @@ import ( ) type CreateTaskFunc func(input models.Task) error -type GetTaskFunc func(input interfaces.GetResourceInput) (models.Task, error) +type GetTaskFunc func(input interfaces.Identifier) (models.Task, error) type ListTaskFunc func(input interfaces.ListResourceInput) (interfaces.TaskCollectionOutput, error) type ListTaskIdentifiersFunc func(input interfaces.ListResourceInput) (interfaces.TaskCollectionOutput, error) @@ -31,7 +31,7 @@ func (r *MockTaskRepo) SetCreateCallback(createFunction CreateTaskFunc) { r.createFunction = createFunction } -func (r *MockTaskRepo) Get(ctx context.Context, input interfaces.GetResourceInput) (models.Task, error) { +func (r *MockTaskRepo) Get(ctx context.Context, input interfaces.Identifier) (models.Task, error) { if r.getFunction != nil { return r.getFunction(input) } diff --git a/flyteadmin/pkg/repositories/mocks/workflow_repo.go b/flyteadmin/pkg/repositories/mocks/workflow_repo.go index b0775b996a..d029f52a33 100644 --- a/flyteadmin/pkg/repositories/mocks/workflow_repo.go +++ b/flyteadmin/pkg/repositories/mocks/workflow_repo.go @@ -9,7 +9,7 @@ import ( ) type CreateWorkflowFunc func(input models.Workflow) error -type GetWorkflowFunc func(input interfaces.GetResourceInput) (models.Workflow, error) +type GetWorkflowFunc func(input interfaces.Identifier) (models.Workflow, error) type ListWorkflowFunc func(input interfaces.ListResourceInput) (interfaces.WorkflowCollectionOutput, error) type ListIdentifiersFunc func(input interfaces.ListResourceInput) (interfaces.WorkflowCollectionOutput, error) @@ -31,7 +31,7 @@ func (r *MockWorkflowRepo) SetCreateCallback(createFunction CreateWorkflowFunc) r.createFunction = createFunction } -func (r *MockWorkflowRepo) Get(ctx context.Context, input interfaces.GetResourceInput) (models.Workflow, error) { +func (r *MockWorkflowRepo) Get(ctx context.Context, input interfaces.Identifier) (models.Workflow, error) { if r.getFunction != nil { return r.getFunction(input) }