From 2e2e8f74f2a39d84bf564e632946804717e34436 Mon Sep 17 00:00:00 2001 From: Andrew Dye Date: Wed, 7 Sep 2022 12:52:26 -0700 Subject: [PATCH] Fix query for NodeExecutionsRepo.Count (#472) Signed-off-by: Andrew Dye Signed-off-by: Andrew Dye --- .../pkg/repositories/gormimpl/node_execution_repo.go | 8 +++++--- .../pkg/repositories/gormimpl/node_execution_repo_test.go | 4 ++-- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/flyteadmin/pkg/repositories/gormimpl/node_execution_repo.go b/flyteadmin/pkg/repositories/gormimpl/node_execution_repo.go index 15188ac14d..65cd8a774c 100644 --- a/flyteadmin/pkg/repositories/gormimpl/node_execution_repo.go +++ b/flyteadmin/pkg/repositories/gormimpl/node_execution_repo.go @@ -164,12 +164,14 @@ func (r *NodeExecutionRepo) Exists(ctx context.Context, input interfaces.NodeExe func (r *NodeExecutionRepo) Count(ctx context.Context, input interfaces.CountResourceInput) (int64, error) { var err error - tx := r.db.Model(&models.NodeExecution{}) + tx := r.db.Model(&models.NodeExecution{}).Preload("ChildNodeExecutions") // Add join condition (joining multiple tables is fine even we only filter on a subset of table attributes). // (this query isn't called for deletes). - tx = tx.Joins(innerJoinNodeExecToNodeEvents) - tx = tx.Joins(innerJoinExecToNodeExec) + tx = tx.Joins(fmt.Sprintf("INNER JOIN %s ON %s.execution_project = %s.execution_project AND "+ + "%s.execution_domain = %s.execution_domain AND %s.execution_name = %s.execution_name", + executionTableName, nodeExecutionTableName, executionTableName, + nodeExecutionTableName, executionTableName, nodeExecutionTableName, executionTableName)) // Apply filters tx, err = applyScopedFilters(tx, input.InlineFilters, input.MapFilters) diff --git a/flyteadmin/pkg/repositories/gormimpl/node_execution_repo_test.go b/flyteadmin/pkg/repositories/gormimpl/node_execution_repo_test.go index 9b58960437..d3f778f10f 100644 --- a/flyteadmin/pkg/repositories/gormimpl/node_execution_repo_test.go +++ b/flyteadmin/pkg/repositories/gormimpl/node_execution_repo_test.go @@ -390,7 +390,7 @@ func TestCountNodeExecutions_Filters(t *testing.T) { GlobalMock := mocket.Catcher.Reset() GlobalMock.NewMock().WithQuery( - `SELECT count(*) FROM "node_executions" INNER JOIN node_executions ON node_event_executions.node_execution_id = node_executions.id INNER JOIN executions ON node_executions.execution_project = executions.execution_project AND node_executions.execution_domain = executions.execution_domain AND node_executions.execution_name = executions.execution_name WHERE node_executions.phase = $1 AND "error_code" IS NULL`).WithReply([]map[string]interface{}{{"rows": 3}}) + `SELECT count(*) FROM "node_executions" INNER JOIN executions ON node_executions.execution_project = executions.execution_project AND node_executions.execution_domain = executions.execution_domain AND node_executions.execution_name = executions.execution_name WHERE node_executions.phase = $1 AND "node_executions"."error_code" IS NULL`).WithReply([]map[string]interface{}{{"rows": 3}}) count, err := nodeExecutionRepo.Count(context.Background(), interfaces.CountResourceInput{ InlineFilters: []common.InlineFilter{ @@ -398,7 +398,7 @@ func TestCountNodeExecutions_Filters(t *testing.T) { }, MapFilters: []common.MapFilter{ common.NewMapFilter(map[string]interface{}{ - "error_code": nil, + "\"node_executions\".\"error_code\"": nil, }), }, })