Skip to content

Commit

Permalink
Add exists check for workflow & node executions (flyteorg#172)
Browse files Browse the repository at this point in the history
  • Loading branch information
Katrina Rogan authored Mar 31, 2021
1 parent 7e13a8d commit 3ee1fd2
Show file tree
Hide file tree
Showing 37 changed files with 326 additions and 153 deletions.
4 changes: 2 additions & 2 deletions flyteadmin/pkg/manager/impl/execution_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ func (m *ExecutionManager) launchSingleTaskExecution(
ctx context.Context, request admin.ExecutionCreateRequest, requestedAt time.Time) (
context.Context, *models.Execution, error) {

taskModel, err := m.db.TaskRepo().Get(ctx, repositoryInterfaces.GetResourceInput{
taskModel, err := m.db.TaskRepo().Get(ctx, repositoryInterfaces.Identifier{
Project: request.Spec.LaunchPlan.Project,
Domain: request.Spec.LaunchPlan.Domain,
Name: request.Spec.LaunchPlan.Name,
Expand Down Expand Up @@ -1264,7 +1264,7 @@ func (m *ExecutionManager) TerminateExecution(
}
ctx = getExecutionContext(ctx, request.Id)
// Save the abort reason (best effort)
executionModel, err := m.db.ExecutionRepo().Get(ctx, repositoryInterfaces.GetResourceInput{
executionModel, err := m.db.ExecutionRepo().Get(ctx, repositoryInterfaces.Identifier{
Project: request.Id.Project,
Domain: request.Id.Domain,
Name: request.Id.Name,
Expand Down
42 changes: 21 additions & 21 deletions flyteadmin/pkg/manager/impl/execution_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ func setDefaultLpCallbackForExecTest(repository repositories.RepositoryInterface
}
lpClosureBytes, _ := proto.Marshal(&lpClosure)

lpGetFunc := func(input interfaces.GetResourceInput) (models.LaunchPlan, error) {
lpGetFunc := func(input interfaces.Identifier) (models.LaunchPlan, error) {
lpModel := models.LaunchPlan{
LaunchPlanKey: models.LaunchPlanKey{
Project: input.Project,
Expand Down Expand Up @@ -179,7 +179,7 @@ func getMockStorageForExecTest(ctx context.Context) *storage.DataStore {
func getMockRepositoryForExecTest() repositories.RepositoryInterface {
repository := repositoryMocks.NewMockRepository()
repository.WorkflowRepo().(*repositoryMocks.MockWorkflowRepo).SetGetCallback(
func(input interfaces.GetResourceInput) (models.Workflow, error) {
func(input interfaces.Identifier) (models.Workflow, error) {
return models.Workflow{
BaseModel: models.BaseModel{
CreatedAt: testutils.MockCreatedAtValue,
Expand Down Expand Up @@ -289,7 +289,7 @@ func TestCreateExecutionFromWorkflowNode(t *testing.T) {

getNodeExecutionCalled := false
repository.NodeExecutionRepo().(*repositoryMocks.MockNodeExecutionRepo).SetGetCallback(
func(ctx context.Context, input interfaces.GetNodeExecutionInput) (models.NodeExecution, error) {
func(ctx context.Context, input interfaces.NodeExecutionResource) (models.NodeExecution, error) {
assert.EqualValues(t, input.NodeExecutionIdentifier, parentNodeExecutionID)
getNodeExecutionCalled = true
return models.NodeExecution{
Expand All @@ -301,7 +301,7 @@ func TestCreateExecutionFromWorkflowNode(t *testing.T) {
)
getExecutionCalled := false
repository.ExecutionRepo().(*repositoryMocks.MockExecutionRepo).SetGetCallback(
func(ctx context.Context, input interfaces.GetResourceInput) (models.Execution, error) {
func(ctx context.Context, input interfaces.Identifier) (models.Execution, error) {
assert.EqualValues(t, input.Project, parentNodeExecutionID.ExecutionId.Project)
assert.EqualValues(t, input.Domain, parentNodeExecutionID.ExecutionId.Domain)
assert.EqualValues(t, input.Name, parentNodeExecutionID.ExecutionId.Name)
Expand Down Expand Up @@ -666,7 +666,7 @@ func TestCreateExecutionNoNotifications(t *testing.T) {

// The LaunchPlan is retrieved within the CreateExecution call to ExecutionManager.
// Create a callback method used by the mock to retrieve a LaunchPlan.
lpGetFunc := func(input interfaces.GetResourceInput) (models.LaunchPlan, error) {
lpGetFunc := func(input interfaces.Identifier) (models.LaunchPlan, error) {
lpModel := models.LaunchPlan{
LaunchPlanKey: models.LaunchPlanKey{
Project: input.Project,
Expand Down Expand Up @@ -754,7 +754,7 @@ func TestCreateExecutionDynamicLabelsAndAnnotations(t *testing.T) {

func makeExecutionGetFunc(
t *testing.T, closureBytes []byte, startTime *time.Time) repositoryMocks.GetExecutionFunc {
return func(ctx context.Context, input interfaces.GetResourceInput) (models.Execution, error) {
return func(ctx context.Context, input interfaces.Identifier) (models.Execution, error) {
assert.Equal(t, "project", input.Project)
assert.Equal(t, "domain", input.Domain)
assert.Equal(t, "name", input.Name)
Expand All @@ -780,7 +780,7 @@ func makeExecutionGetFunc(

func makeLegacyExecutionGetFunc(
t *testing.T, closureBytes []byte, startTime *time.Time) repositoryMocks.GetExecutionFunc {
return func(ctx context.Context, input interfaces.GetResourceInput) (models.Execution, error) {
return func(ctx context.Context, input interfaces.Identifier) (models.Execution, error) {
assert.Equal(t, "project", input.Project)
assert.Equal(t, "domain", input.Domain)
assert.Equal(t, "name", input.Name)
Expand Down Expand Up @@ -869,7 +869,7 @@ func TestRelaunchExecution_GetExistingFailure(t *testing.T) {

expectedErr := errors.New("expected error")
repository.ExecutionRepo().(*repositoryMocks.MockExecutionRepo).SetGetCallback(
func(ctx context.Context, input interfaces.GetResourceInput) (models.Execution, error) {
func(ctx context.Context, input interfaces.Identifier) (models.Execution, error) {
return models.Execution{}, expectedErr
})

Expand Down Expand Up @@ -1002,7 +1002,7 @@ func TestCreateWorkflowEvent(t *testing.T) {

func TestCreateWorkflowEvent_TerminalState(t *testing.T) {
repository := repositoryMocks.NewMockRepository()
executionGetFunc := func(ctx context.Context, input interfaces.GetResourceInput) (models.Execution, error) {
executionGetFunc := func(ctx context.Context, input interfaces.Identifier) (models.Execution, error) {
return models.Execution{
ExecutionKey: models.ExecutionKey{
Project: "project",
Expand Down Expand Up @@ -1096,7 +1096,7 @@ func TestCreateWorkflowEvent_DuplicateRunning(t *testing.T) {
occurredAt := time.Now().UTC()

repository.ExecutionRepo().(*repositoryMocks.MockExecutionRepo).SetGetCallback(
func(ctx context.Context, input interfaces.GetResourceInput) (models.Execution, error) {
func(ctx context.Context, input interfaces.Identifier) (models.Execution, error) {
return models.Execution{
ExecutionKey: models.ExecutionKey{
Project: "project",
Expand Down Expand Up @@ -1137,7 +1137,7 @@ func TestCreateWorkflowEvent_InvalidPhaseChange(t *testing.T) {
occurredAt := time.Now().UTC()

repository.ExecutionRepo().(*repositoryMocks.MockExecutionRepo).SetGetCallback(
func(ctx context.Context, input interfaces.GetResourceInput) (models.Execution, error) {
func(ctx context.Context, input interfaces.Identifier) (models.Execution, error) {
return models.Execution{
ExecutionKey: models.ExecutionKey{
Project: "project",
Expand Down Expand Up @@ -1240,7 +1240,7 @@ func TestCreateWorkflowEvent_DatabaseGetError(t *testing.T) {
startTime := time.Now()

expectedErr := errors.New("expected error")
executionGetFunc := func(ctx context.Context, input interfaces.GetResourceInput) (models.Execution, error) {
executionGetFunc := func(ctx context.Context, input interfaces.Identifier) (models.Execution, error) {
return models.Execution{}, expectedErr
}
repository.ExecutionRepo().(*repositoryMocks.MockExecutionRepo).SetGetCallback(executionGetFunc)
Expand Down Expand Up @@ -1304,7 +1304,7 @@ func TestCreateWorkflowEvent_DatabaseUpdateError(t *testing.T) {
func TestGetExecution(t *testing.T) {
repository := repositoryMocks.NewMockRepository()
startedAt := time.Date(2018, 8, 30, 0, 0, 0, 0, time.UTC)
executionGetFunc := func(ctx context.Context, input interfaces.GetResourceInput) (models.Execution, error) {
executionGetFunc := func(ctx context.Context, input interfaces.Identifier) (models.Execution, error) {
assert.Equal(t, "project", input.Project)
assert.Equal(t, "domain", input.Domain)
assert.Equal(t, "name", input.Name)
Expand Down Expand Up @@ -1338,7 +1338,7 @@ func TestGetExecution_DatabaseError(t *testing.T) {
repository := repositoryMocks.NewMockRepository()
expectedErr := errors.New("expected error")

executionGetFunc := func(ctx context.Context, input interfaces.GetResourceInput) (models.Execution, error) {
executionGetFunc := func(ctx context.Context, input interfaces.Identifier) (models.Execution, error) {
assert.Equal(t, "project", input.Project)
assert.Equal(t, "domain", input.Domain)
assert.Equal(t, "name", input.Name)
Expand All @@ -1356,7 +1356,7 @@ func TestGetExecution_DatabaseError(t *testing.T) {
func TestGetExecution_TransformerError(t *testing.T) {
repository := repositoryMocks.NewMockRepository()
startedAt := time.Date(2018, 8, 30, 0, 0, 0, 0, time.UTC)
executionGetFunc := func(ctx context.Context, input interfaces.GetResourceInput) (models.Execution, error) {
executionGetFunc := func(ctx context.Context, input interfaces.Identifier) (models.Execution, error) {
assert.Equal(t, "project", input.Project)
assert.Equal(t, "domain", input.Domain)
assert.Equal(t, "name", input.Name)
Expand Down Expand Up @@ -1927,7 +1927,7 @@ func TestGetExecutionData(t *testing.T) {
}
var closureBytes, _ = proto.Marshal(&closure)

executionGetFunc := func(ctx context.Context, input interfaces.GetResourceInput) (models.Execution, error) {
executionGetFunc := func(ctx context.Context, input interfaces.Identifier) (models.Execution, error) {
return models.Execution{
ExecutionKey: models.ExecutionKey{
Project: "project",
Expand Down Expand Up @@ -2120,7 +2120,7 @@ func TestPluginOverrides_ResourceGetFailure(t *testing.T) {
func TestGetExecution_Legacy(t *testing.T) {
repository := repositoryMocks.NewMockRepository()
startedAt := time.Date(2018, 8, 30, 0, 0, 0, 0, time.UTC)
executionGetFunc := func(ctx context.Context, input interfaces.GetResourceInput) (models.Execution, error) {
executionGetFunc := func(ctx context.Context, input interfaces.Identifier) (models.Execution, error) {
assert.Equal(t, "project", input.Project)
assert.Equal(t, "domain", input.Domain)
assert.Equal(t, "name", input.Name)
Expand Down Expand Up @@ -2152,7 +2152,7 @@ func TestGetExecution_Legacy(t *testing.T) {
func TestGetExecution_LegacyClient_OffloadedData(t *testing.T) {
repository := repositoryMocks.NewMockRepository()
startedAt := time.Date(2018, 8, 30, 0, 0, 0, 0, time.UTC)
executionGetFunc := func(ctx context.Context, input interfaces.GetResourceInput) (models.Execution, error) {
executionGetFunc := func(ctx context.Context, input interfaces.Identifier) (models.Execution, error) {
assert.Equal(t, "project", input.Project)
assert.Equal(t, "domain", input.Domain)
assert.Equal(t, "name", input.Name)
Expand Down Expand Up @@ -2199,7 +2199,7 @@ func TestGetExecutionData_LegacyModel(t *testing.T) {
}
var closureBytes, _ = proto.Marshal(closure)

executionGetFunc := func(ctx context.Context, input interfaces.GetResourceInput) (models.Execution, error) {
executionGetFunc := func(ctx context.Context, input interfaces.Identifier) (models.Execution, error) {
return models.Execution{
ExecutionKey: models.ExecutionKey{
Project: "project",
Expand Down Expand Up @@ -2742,7 +2742,7 @@ func TestCreateSingleTaskExecution(t *testing.T) {
}
repository.WorkflowRepo().(*repositoryMocks.MockWorkflowRepo).SetCreateCallback(workflowcreateFunc)

workflowGetFunc := func(input interfaces.GetResourceInput) (models.Workflow, error) {
workflowGetFunc := func(input interfaces.Identifier) (models.Workflow, error) {
if getCalledCount <= 1 {
getCalledCount++
return models.Workflow{}, flyteAdminErrors.NewFlyteAdminErrorf(codes.NotFound, "not found")
Expand All @@ -2752,7 +2752,7 @@ func TestCreateSingleTaskExecution(t *testing.T) {
}
repository.WorkflowRepo().(*repositoryMocks.MockWorkflowRepo).SetGetCallback(workflowGetFunc)
repository.TaskRepo().(*repositoryMocks.MockTaskRepo).SetGetCallback(
func(input interfaces.GetResourceInput) (models.Task, error) {
func(input interfaces.Identifier) (models.Task, error) {
createdAt := time.Now()
createdAtProto, _ := ptypes.TimestampProto(createdAt)
taskClosure := &admin.TaskClosure{
Expand Down
2 changes: 1 addition & 1 deletion flyteadmin/pkg/manager/impl/launch_plan_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ func (m *LaunchPlanManager) disableLaunchPlan(ctx context.Context, request admin

func (m *LaunchPlanManager) enableLaunchPlan(ctx context.Context, request admin.LaunchPlanUpdateRequest) (
*admin.LaunchPlanUpdateResponse, error) {
newlyActiveLaunchPlanModel, err := m.db.LaunchPlanRepo().Get(ctx, repoInterfaces.GetResourceInput{
newlyActiveLaunchPlanModel, err := m.db.LaunchPlanRepo().Get(ctx, repoInterfaces.Identifier{
Project: request.Id.Project,
Domain: request.Id.Domain,
Name: request.Id.Name,
Expand Down
24 changes: 12 additions & 12 deletions flyteadmin/pkg/manager/impl/launch_plan_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ func getMockConfigForLpTest() runtimeInterfaces.Configuration {
func setDefaultWorkflowCallbackForLpTest(repository repositories.RepositoryInterface) {
workflowSpec := testutils.GetSampleWorkflowSpecForTest()
typedInterface, _ := proto.Marshal(workflowSpec.Template.Interface)
workflowGetFunc := func(input interfaces.GetResourceInput) (models.Workflow, error) {
workflowGetFunc := func(input interfaces.Identifier) (models.Workflow, error) {
return models.Workflow{
WorkflowKey: models.WorkflowKey{
Project: input.Project,
Expand All @@ -80,7 +80,7 @@ func setDefaultWorkflowCallbackForLpTest(repository repositories.RepositoryInter
func TestCreateLaunchPlan(t *testing.T) {
repository := getMockRepositoryForLpTest()
repository.LaunchPlanRepo().(*repositoryMocks.MockLaunchPlanRepo).SetGetCallback(
func(input interfaces.GetResourceInput) (models.LaunchPlan, error) {
func(input interfaces.Identifier) (models.LaunchPlan, error) {
return models.LaunchPlan{}, errors.New("foo")
})
var createCalled bool
Expand Down Expand Up @@ -116,7 +116,7 @@ func TestLaunchPlanManager_GetLaunchPlan(t *testing.T) {
specBytes, _ := proto.Marshal(lpRequest.Spec)
closureBytes, _ := proto.Marshal(&closure)

launchPlanGetFunc := func(input interfaces.GetResourceInput) (models.LaunchPlan, error) {
launchPlanGetFunc := func(input interfaces.Identifier) (models.LaunchPlan, error) {
return models.LaunchPlan{
LaunchPlanKey: models.LaunchPlanKey{
Project: input.Project,
Expand Down Expand Up @@ -242,7 +242,7 @@ func TestLaunchPlan_ValidationError(t *testing.T) {
func TestLaunchPlan_DatabaseError(t *testing.T) {
repository := getMockRepositoryForLpTest()
repository.LaunchPlanRepo().(*repositoryMocks.MockLaunchPlanRepo).SetGetCallback(
func(input interfaces.GetResourceInput) (models.LaunchPlan, error) {
func(input interfaces.Identifier) (models.LaunchPlan, error) {
return models.LaunchPlan{}, errors.New("foo")
})
setDefaultWorkflowCallbackForLpTest(repository)
Expand Down Expand Up @@ -281,7 +281,7 @@ func TestCreateLaunchPlanInCompatibleInputs(t *testing.T) {
func TestCreateLaunchPlanValidateCreate(t *testing.T) {
repository := getMockRepositoryForLpTest()
repository.LaunchPlanRepo().(*repositoryMocks.MockLaunchPlanRepo).SetGetCallback(
func(input interfaces.GetResourceInput) (models.LaunchPlan, error) {
func(input interfaces.Identifier) (models.LaunchPlan, error) {
return models.LaunchPlan{}, errors.New("foo")
})
setDefaultWorkflowCallbackForLpTest(repository)
Expand Down Expand Up @@ -324,10 +324,10 @@ func TestCreateLaunchPlanValidateCreate(t *testing.T) {
func TestCreateLaunchPlanNoWorkflowInterface(t *testing.T) {
repository := getMockRepositoryForLpTest()
repository.LaunchPlanRepo().(*repositoryMocks.MockLaunchPlanRepo).SetGetCallback(
func(input interfaces.GetResourceInput) (models.LaunchPlan, error) {
func(input interfaces.Identifier) (models.LaunchPlan, error) {
return models.LaunchPlan{}, errors.New("foo")
})
workflowGetFunc := func(input interfaces.GetResourceInput) (models.Workflow, error) {
workflowGetFunc := func(input interfaces.Identifier) (models.Workflow, error) {
return models.Workflow{
WorkflowKey: models.WorkflowKey{
Project: input.Project,
Expand Down Expand Up @@ -365,7 +365,7 @@ func TestCreateLaunchPlanNoWorkflowInterface(t *testing.T) {
}

func makeLaunchPlanRepoGetCallback(t *testing.T) repositoryMocks.GetLaunchPlanFunc {
return func(input interfaces.GetResourceInput) (models.LaunchPlan, error) {
return func(input interfaces.Identifier) (models.LaunchPlan, error) {
assert.Equal(t, project, input.Project)
assert.Equal(t, domain, input.Domain)
assert.Equal(t, name, input.Name)
Expand Down Expand Up @@ -756,7 +756,7 @@ func TestUpdateSchedules_EnableNoSchedule(t *testing.T) {
func TestDisableLaunchPlan(t *testing.T) {
repository := getMockRepositoryForLpTest()

lpGetFunc := func(input interfaces.GetResourceInput) (models.LaunchPlan, error) {
lpGetFunc := func(input interfaces.Identifier) (models.LaunchPlan, error) {
assert.Equal(t, project, input.Project)
assert.Equal(t, domain, input.Domain)
assert.Equal(t, name, input.Name)
Expand Down Expand Up @@ -816,7 +816,7 @@ func TestDisableLaunchPlan_DatabaseError(t *testing.T) {
repository := getMockRepositoryForLpTest()
expectedError := errors.New("expected error")

lpGetFunc := func(input interfaces.GetResourceInput) (models.LaunchPlan, error) {
lpGetFunc := func(input interfaces.Identifier) (models.LaunchPlan, error) {
assert.Equal(t, project, input.Project)
assert.Equal(t, domain, input.Domain)
assert.Equal(t, name, input.Name)
Expand All @@ -832,7 +832,7 @@ func TestDisableLaunchPlan_DatabaseError(t *testing.T) {
assert.EqualError(t, err, expectedError.Error(),
"Failures on getting the existing launch plan should propagate")

lpGetFunc = func(input interfaces.GetResourceInput) (models.LaunchPlan, error) {
lpGetFunc = func(input interfaces.Identifier) (models.LaunchPlan, error) {
assert.Equal(t, project, input.Project)
assert.Equal(t, domain, input.Domain)
assert.Equal(t, name, input.Name)
Expand Down Expand Up @@ -955,7 +955,7 @@ func TestEnableLaunchPlan_DatabaseError(t *testing.T) {
repository := getMockRepositoryForLpTest()
expectedError := errors.New("expected error")

lpGetFunc := func(input interfaces.GetResourceInput) (models.LaunchPlan, error) {
lpGetFunc := func(input interfaces.Identifier) (models.LaunchPlan, error) {
assert.Equal(t, project, input.Project)
assert.Equal(t, domain, input.Domain)
assert.Equal(t, name, input.Name)
Expand Down
32 changes: 19 additions & 13 deletions flyteadmin/pkg/manager/impl/node_execution_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,24 @@ func getNodeExecutionContext(ctx context.Context, identifier *core.NodeExecution
func (m *NodeExecutionManager) createNodeExecutionWithEvent(
ctx context.Context, request *admin.NodeExecutionEventRequest) error {

executionID := request.Event.Id.ExecutionId
workflowExecutionExists, err := m.db.ExecutionRepo().Exists(ctx, repoInterfaces.Identifier{
Project: executionID.Project,
Domain: executionID.Domain,
Name: executionID.Name,
})
if err != nil || !workflowExecutionExists {
m.metrics.MissingWorkflowExecution.Inc()
logger.Debugf(ctx, "Failed to find existing execution with id [%+v] with err: %v", executionID, err)
if err != nil {
if ferr, ok := err.(errors.FlyteAdminError); ok {
return errors.NewFlyteAdminErrorf(ferr.Code(),
"Failed to get existing execution id: [%+v] with err: %v", executionID, err)
}
}
return fmt.Errorf("failed to get existing execution id: [%+v]", executionID)
}

var parentTaskExecutionID uint
if request.Event.ParentTaskMetadata != nil {
taskExecutionModel, err := util.GetTaskExecutionModel(ctx, m.db, request.Event.ParentTaskMetadata.Id)
Expand Down Expand Up @@ -186,22 +204,10 @@ func (m *NodeExecutionManager) CreateNodeEvent(ctx context.Context, request admi
logger.Debugf(ctx, "CreateNodeEvent called with invalid identifier [%+v]: %v", request.Event.Id, err)
}
ctx = getNodeExecutionContext(ctx, request.Event.Id)
executionID := request.Event.Id.ExecutionId
logger.Debugf(ctx, "Received node execution event for Node Exec Id [%+v] transitioning to phase [%v], w/ Metadata [%v]",
request.Event.Id, request.Event.Phase, request.Event.ParentTaskMetadata)

_, err := util.GetExecutionModel(ctx, m.db, *executionID)
if err != nil {
m.metrics.MissingWorkflowExecution.Inc()
logger.Debugf(ctx, "Failed to find existing execution with id [%+v] with err: %v", executionID, err)
if ferr, ok := err.(errors.FlyteAdminError); ok {
return nil, errors.NewFlyteAdminErrorf(ferr.Code(),
"Failed to get existing execution id:[%+v] with err: %v", executionID, err)
}
return nil, fmt.Errorf("failed to get existing execution id: [%+v] with err: %v", executionID, err)
}

nodeExecutionModel, err := m.db.NodeExecutionRepo().Get(ctx, repoInterfaces.GetNodeExecutionInput{
nodeExecutionModel, err := m.db.NodeExecutionRepo().Get(ctx, repoInterfaces.NodeExecutionResource{
NodeExecutionIdentifier: *request.Event.Id,
})
if err != nil {
Expand Down
Loading

0 comments on commit 3ee1fd2

Please sign in to comment.