Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove unnecessary joins for list node and task execution entities in flyteadmin db queries #5935

Merged
merged 3 commits into from
Oct 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@
func (c *CloudEventWrappedPublisher) getLatestTaskExecutions(ctx context.Context, nodeExecutionID *core.NodeExecutionIdentifier) (*admin.TaskExecution, error) {
ctx = getNodeExecutionContext(ctx, nodeExecutionID)

identifierFilters, err := util.GetNodeExecutionIdentifierFilters(ctx, nodeExecutionID)
identifierFilters, err := util.GetNodeExecutionIdentifierFilters(ctx, nodeExecutionID, common.TaskExecution)

Check warning on line 210 in flyteadmin/pkg/async/cloudevent/implementations/cloudevent_publisher.go

View check run for this annotation

Codecov / codecov/patch

flyteadmin/pkg/async/cloudevent/implementations/cloudevent_publisher.go#L210

Added line #L210 was not covered by tests
if err != nil {
return nil, err
}
Expand Down
9 changes: 8 additions & 1 deletion flyteadmin/pkg/common/filters.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,13 @@ var executionIdentifierFields = map[string]bool{
"name": true,
}

// Entities that have special case handling for execution identifier fields.
var executionIdentifierEntities = map[Entity]bool{
Execution: true,
NodeExecution: true,
TaskExecution: true,
}

var entityMetadataFields = map[string]bool{
"description": true,
"state": true,
Expand Down Expand Up @@ -253,7 +260,7 @@ func (f *inlineFilterImpl) GetGormJoinTableQueryExpr(tableName string) (GormQuer

func customizeField(field string, entity Entity) string {
// Execution identifier fields have to be customized because we differ from convention in those column names.
if entity == Execution && executionIdentifierFields[field] {
if executionIdentifierEntities[entity] && executionIdentifierFields[field] {
return fmt.Sprintf("execution_%s", field)
}
// admin_tag table has been migrated to an execution_tag table, so we need to customize the field name.
Expand Down
17 changes: 11 additions & 6 deletions flyteadmin/pkg/manager/impl/node_execution_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -407,11 +407,16 @@ func (m *NodeExecutionManager) listNodeExecutions(
return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument,
"invalid pagination token %s for ListNodeExecutions", requestToken)
}
joinTableEntities := make(map[common.Entity]bool)
for _, filter := range filters {
joinTableEntities[filter.GetEntity()] = true
}
listInput := repoInterfaces.ListResourceInput{
Limit: int(limit),
Offset: offset,
InlineFilters: filters,
SortParameter: sortParameter,
Limit: int(limit),
Offset: offset,
InlineFilters: filters,
SortParameter: sortParameter,
JoinTableEntities: joinTableEntities,
}

listInput.MapFilters = mapFilters
Expand Down Expand Up @@ -445,7 +450,7 @@ func (m *NodeExecutionManager) ListNodeExecutions(
}
ctx = getExecutionContext(ctx, request.WorkflowExecutionId)

identifierFilters, err := util.GetWorkflowExecutionIdentifierFilters(ctx, request.WorkflowExecutionId)
identifierFilters, err := util.GetWorkflowExecutionIdentifierFilters(ctx, request.WorkflowExecutionId, common.NodeExecution)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -483,7 +488,7 @@ func (m *NodeExecutionManager) ListNodeExecutionsForTask(
}
ctx = getTaskExecutionContext(ctx, request.TaskExecutionId)
identifierFilters, err := util.GetWorkflowExecutionIdentifierFilters(
ctx, request.TaskExecutionId.NodeExecutionId.ExecutionId)
ctx, request.TaskExecutionId.NodeExecutionId.ExecutionId, common.NodeExecution)
if err != nil {
return nil, err
}
Expand Down
145 changes: 136 additions & 9 deletions flyteadmin/pkg/manager/impl/node_execution_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -784,17 +784,17 @@ func TestListNodeExecutionsLevelZero(t *testing.T) {
assert.Equal(t, 1, input.Limit)
assert.Equal(t, 2, input.Offset)
assert.Len(t, input.InlineFilters, 3)
assert.Equal(t, common.Execution, input.InlineFilters[0].GetEntity())
assert.Equal(t, common.NodeExecution, input.InlineFilters[0].GetEntity())
queryExpr, _ := input.InlineFilters[0].GetGormQueryExpr()
assert.Equal(t, "project", queryExpr.Args)
assert.Equal(t, "execution_project = ?", queryExpr.Query)

assert.Equal(t, common.Execution, input.InlineFilters[1].GetEntity())
assert.Equal(t, common.NodeExecution, input.InlineFilters[1].GetEntity())
queryExpr, _ = input.InlineFilters[1].GetGormQueryExpr()
assert.Equal(t, "domain", queryExpr.Args)
assert.Equal(t, "execution_domain = ?", queryExpr.Query)

assert.Equal(t, common.Execution, input.InlineFilters[2].GetEntity())
assert.Equal(t, common.NodeExecution, input.InlineFilters[2].GetEntity())
queryExpr, _ = input.InlineFilters[2].GetGormQueryExpr()
assert.Equal(t, "name", queryExpr.Args)
assert.Equal(t, "execution_name = ?", queryExpr.Query)
Expand All @@ -806,6 +806,10 @@ func TestListNodeExecutionsLevelZero(t *testing.T) {
"parent_task_execution_id": nil,
}, filter)

assert.EqualValues(t, input.JoinTableEntities, map[common.Entity]bool{
common.NodeExecution: true,
})

assert.Equal(t, "execution_domain asc", input.SortParameter.GetGormOrderExpr())
return interfaces.NodeExecutionCollectionOutput{
NodeExecutions: []models.NodeExecution{
Expand Down Expand Up @@ -904,17 +908,17 @@ func TestListNodeExecutionsWithParent(t *testing.T) {
assert.Equal(t, 1, input.Limit)
assert.Equal(t, 2, input.Offset)
assert.Len(t, input.InlineFilters, 4)
assert.Equal(t, common.Execution, input.InlineFilters[0].GetEntity())
assert.Equal(t, common.NodeExecution, input.InlineFilters[0].GetEntity())
queryExpr, _ := input.InlineFilters[0].GetGormQueryExpr()
assert.Equal(t, "project", queryExpr.Args)
assert.Equal(t, "execution_project = ?", queryExpr.Query)

assert.Equal(t, common.Execution, input.InlineFilters[1].GetEntity())
assert.Equal(t, common.NodeExecution, input.InlineFilters[1].GetEntity())
queryExpr, _ = input.InlineFilters[1].GetGormQueryExpr()
assert.Equal(t, "domain", queryExpr.Args)
assert.Equal(t, "execution_domain = ?", queryExpr.Query)

assert.Equal(t, common.Execution, input.InlineFilters[2].GetEntity())
assert.Equal(t, common.NodeExecution, input.InlineFilters[2].GetEntity())
queryExpr, _ = input.InlineFilters[2].GetGormQueryExpr()
assert.Equal(t, "name", queryExpr.Args)
assert.Equal(t, "execution_name = ?", queryExpr.Query)
Expand Down Expand Up @@ -979,6 +983,129 @@ func TestListNodeExecutionsWithParent(t *testing.T) {
assert.Equal(t, "3", nodeExecutions.Token)
}

func TestListNodeExecutions_WithJoinTableFilter(t *testing.T) {
repository := repositoryMocks.NewMockRepository()
expectedClosure := admin.NodeExecutionClosure{
Phase: core.NodeExecution_SUCCEEDED,
}
expectedMetadata := admin.NodeExecutionMetaData{
SpecNodeId: "spec_node_id",
RetryGroup: "retry_group",
}
metadataBytes, _ := proto.Marshal(&expectedMetadata)
closureBytes, _ := proto.Marshal(&expectedClosure)

repository.NodeExecutionRepo().(*repositoryMocks.MockNodeExecutionRepo).SetListCallback(
func(ctx context.Context, input interfaces.ListResourceInput) (
interfaces.NodeExecutionCollectionOutput, error) {
assert.Equal(t, 1, input.Limit)
assert.Equal(t, 2, input.Offset)
assert.Len(t, input.InlineFilters, 4)
assert.Equal(t, common.NodeExecution, input.InlineFilters[0].GetEntity())
queryExpr, _ := input.InlineFilters[0].GetGormQueryExpr()
assert.Equal(t, "project", queryExpr.Args)
assert.Equal(t, "execution_project = ?", queryExpr.Query)

assert.Equal(t, common.NodeExecution, input.InlineFilters[1].GetEntity())
queryExpr, _ = input.InlineFilters[1].GetGormQueryExpr()
assert.Equal(t, "domain", queryExpr.Args)
assert.Equal(t, "execution_domain = ?", queryExpr.Query)

assert.Equal(t, common.NodeExecution, input.InlineFilters[2].GetEntity())
queryExpr, _ = input.InlineFilters[2].GetGormQueryExpr()
assert.Equal(t, "name", queryExpr.Args)
assert.Equal(t, "execution_name = ?", queryExpr.Query)

assert.Equal(t, common.Execution, input.InlineFilters[3].GetEntity())
queryExpr, _ = input.InlineFilters[3].GetGormQueryExpr()
assert.Equal(t, "SUCCEEDED", queryExpr.Args)
assert.Equal(t, "phase = ?", queryExpr.Query)

assert.Len(t, input.MapFilters, 1)
filter := input.MapFilters[0].GetFilter()
assert.Equal(t, map[string]interface{}{
"parent_id": nil,
"parent_task_execution_id": nil,
}, filter)

assert.EqualValues(t, input.JoinTableEntities, map[common.Entity]bool{
common.NodeExecution: true,
common.Execution: true,
})

assert.Equal(t, "execution_domain asc", input.SortParameter.GetGormOrderExpr())
return interfaces.NodeExecutionCollectionOutput{
NodeExecutions: []models.NodeExecution{
{
NodeExecutionKey: models.NodeExecutionKey{
NodeID: "node id",
ExecutionKey: models.ExecutionKey{
Project: "project",
Domain: "domain",
Name: "name",
},
},
Phase: core.NodeExecution_SUCCEEDED.String(),
InputURI: "input uri",
StartedAt: &occurredAt,
Closure: closureBytes,
NodeExecutionMetadata: metadataBytes,
},
},
}, nil
})
repository.NodeExecutionRepo().(*repositoryMocks.MockNodeExecutionRepo).SetGetWithChildrenCallback(
func(
ctx context.Context, input interfaces.NodeExecutionResource) (models.NodeExecution, error) {
return models.NodeExecution{
NodeExecutionKey: models.NodeExecutionKey{
NodeID: "node id",
ExecutionKey: models.ExecutionKey{
Project: "project",
Domain: "domain",
Name: "name",
},
},
Phase: core.NodeExecution_SUCCEEDED.String(),
InputURI: "input uri",
StartedAt: &occurredAt,
Closure: closureBytes,
NodeExecutionMetadata: metadataBytes,
}, nil
})
nodeExecManager := NewNodeExecutionManager(repository, getMockExecutionsConfigProvider(), make([]string, 0), getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockNodeExecutionRemoteURL, nil, nil, &eventWriterMocks.NodeExecutionEventWriter{})
nodeExecutions, err := nodeExecManager.ListNodeExecutions(context.Background(), &admin.NodeExecutionListRequest{
WorkflowExecutionId: &core.WorkflowExecutionIdentifier{
Project: "project",
Domain: "domain",
Name: "name",
},
Limit: 1,
Token: "2",
SortBy: &admin.Sort{
Direction: admin.Sort_ASCENDING,
Key: "execution_domain",
},
Filters: "eq(execution.phase, SUCCEEDED)",
})
assert.NoError(t, err)
assert.Len(t, nodeExecutions.NodeExecutions, 1)
assert.True(t, proto.Equal(&admin.NodeExecution{
Id: &core.NodeExecutionIdentifier{
NodeId: "node id",
ExecutionId: &core.WorkflowExecutionIdentifier{
Project: "project",
Domain: "domain",
Name: "name",
},
},
InputUri: "input uri",
Closure: &expectedClosure,
Metadata: &expectedMetadata,
}, nodeExecutions.NodeExecutions[0]))
assert.Equal(t, "3", nodeExecutions.Token)
}

func TestListNodeExecutions_InvalidParams(t *testing.T) {
nodeExecManager := NewNodeExecutionManager(nil, getMockExecutionsConfigProvider(), make([]string, 0), getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockNodeExecutionRemoteURL, nil, nil, &eventWriterMocks.NodeExecutionEventWriter{})
_, err := nodeExecManager.ListNodeExecutions(context.Background(), &admin.NodeExecutionListRequest{
Expand Down Expand Up @@ -1120,17 +1247,17 @@ func TestListNodeExecutionsForTask(t *testing.T) {
assert.Equal(t, 1, input.Limit)
assert.Equal(t, 2, input.Offset)
assert.Len(t, input.InlineFilters, 4)
assert.Equal(t, common.Execution, input.InlineFilters[0].GetEntity())
assert.Equal(t, common.NodeExecution, input.InlineFilters[0].GetEntity())
queryExpr, _ := input.InlineFilters[0].GetGormQueryExpr()
assert.Equal(t, "project", queryExpr.Args)
assert.Equal(t, "execution_project = ?", queryExpr.Query)

assert.Equal(t, common.Execution, input.InlineFilters[1].GetEntity())
assert.Equal(t, common.NodeExecution, input.InlineFilters[1].GetEntity())
queryExpr, _ = input.InlineFilters[1].GetGormQueryExpr()
assert.Equal(t, "domain", queryExpr.Args)
assert.Equal(t, "execution_domain = ?", queryExpr.Query)

assert.Equal(t, common.Execution, input.InlineFilters[2].GetEntity())
assert.Equal(t, common.NodeExecution, input.InlineFilters[2].GetEntity())
queryExpr, _ = input.InlineFilters[2].GetGormQueryExpr()
assert.Equal(t, "name", queryExpr.Args)
assert.Equal(t, "execution_name = ?", queryExpr.Query)
Expand Down
2 changes: 1 addition & 1 deletion flyteadmin/pkg/manager/impl/signal_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ func (s *SignalManager) ListSignals(ctx context.Context, request *admin.SignalLi
}
ctx = getExecutionContext(ctx, request.WorkflowExecutionId)

identifierFilters, err := util.GetWorkflowExecutionIdentifierFilters(ctx, request.WorkflowExecutionId)
identifierFilters, err := util.GetWorkflowExecutionIdentifierFilters(ctx, request.WorkflowExecutionId, common.Signal)
if err != nil {
return nil, err
}
Expand Down
15 changes: 10 additions & 5 deletions flyteadmin/pkg/manager/impl/task_execution_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ func (m *TaskExecutionManager) ListTaskExecutions(
}
ctx = getNodeExecutionContext(ctx, request.NodeExecutionId)

identifierFilters, err := util.GetNodeExecutionIdentifierFilters(ctx, request.NodeExecutionId)
identifierFilters, err := util.GetNodeExecutionIdentifierFilters(ctx, request.NodeExecutionId, common.TaskExecution)
if err != nil {
return nil, err
}
Expand All @@ -267,12 +267,17 @@ func (m *TaskExecutionManager) ListTaskExecutions(
return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument,
"invalid pagination token %s for ListTaskExecutions", request.Token)
}
joinTableEntities := make(map[common.Entity]bool)
for _, filter := range filters {
joinTableEntities[filter.GetEntity()] = true
}

output, err := m.db.TaskExecutionRepo().List(ctx, repoInterfaces.ListResourceInput{
InlineFilters: filters,
Offset: offset,
Limit: int(request.Limit),
SortParameter: sortParameter,
InlineFilters: filters,
Offset: offset,
Limit: int(request.Limit),
SortParameter: sortParameter,
JoinTableEntities: joinTableEntities,
})
if err != nil {
logger.Debugf(ctx, "Failed to list task executions with request [%+v] with err %v",
Expand Down
Loading
Loading