Skip to content

Commit

Permalink
SVR-396: Implement GetExecutionCount endpoint (#155)
Browse files Browse the repository at this point in the history
* get execution count proto

Signed-off-by: troychiu <[email protected]>

* update prot

Signed-off-by: troychiu <[email protected]>

* end to end work

Signed-off-by: troychiu <[email protected]>

* unit test

Signed-off-by: troychiu <[email protected]>

* execution manager unit test

Signed-off-by: troychiu <[email protected]>

* integration test

Signed-off-by: troychiu <[email protected]>

* remove redundant struct

Signed-off-by: troychiu <[email protected]>

* lint

Signed-off-by: troychiu <[email protected]>

* fix typo

Signed-off-by: troychiu <[email protected]>

* modify interface

Signed-off-by: troychiu <[email protected]>

* fix suggestions

Signed-off-by: troychiu <[email protected]>

* fix suggestions

Signed-off-by: troychiu <[email protected]>

* fix suggestion

Signed-off-by: troychiu <[email protected]>

* fix suggestion

Signed-off-by: troychiu <[email protected]>

* fix suggestions

Signed-off-by: troychiu <[email protected]>

* log level

Signed-off-by: troychiu <[email protected]>

* running executions count

Signed-off-by: troychiu <[email protected]>

* Add test for getting running executions count

Signed-off-by: troychiu <[email protected]>

* Add validation for updated_at filter in GetExecutionCounts

Signed-off-by: troychiu <[email protected]>

* Update time filter & add index on execution table for execution counts

Signed-off-by: troychiu <[email protected]>

* revert change

Signed-off-by: troychiu <[email protected]>

* Update execution index in migrations and models

Signed-off-by: troychiu <[email protected]>

* update comment

Signed-off-by: troychiu <[email protected]>

* Update time filter validation to include correct execution timestamps

Signed-off-by: troychiu <[email protected]>

* fix tests

Signed-off-by: troychiu <[email protected]>

* update migration name

Signed-off-by: troychiu <[email protected]>

* resolve suggestions

Signed-off-by: troychiu <[email protected]>

* Update variable name in addPhaseFilter function.- Rename 'runningFilter' to 'phaseFilter'

Signed-off-by: troychiu <[email protected]>

---------

Signed-off-by: troychiu <[email protected]>
  • Loading branch information
troychiu authored Mar 25, 2024
1 parent 91ff694 commit e06144a
Show file tree
Hide file tree
Showing 41 changed files with 4,381 additions and 314 deletions.
84 changes: 84 additions & 0 deletions flyteadmin/pkg/manager/impl/execution_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -1885,6 +1885,80 @@ func (m *ExecutionManager) ListExecutions(
}, nil
}

func (m *ExecutionManager) GetExecutionCounts(
ctx context.Context, request admin.ExecutionCountsGetRequest) (*admin.ExecutionCountsGetResponse, error) {
// Check required fields
if err := validation.ValidateExecutionCountsGetRequest(request); err != nil {
logger.Debugf(ctx, "ExecutionCounts request [%+v] failed validation with err: %v", request, err)
return nil, err
}
ctx = contextutils.WithProjectDomain(ctx, request.Project, request.Domain)
filters, err := util.GetDbFilters(util.FilterSpec{
Org: request.Org,
Project: request.Project,
Domain: request.Domain,
RequestFilters: request.Filters,
}, common.Execution)
if err != nil {
return nil, err
}

countExecutionByPhaseInput := repositoryInterfaces.CountResourceInput{
InlineFilters: filters,
}
countExecutionByPhaseOutput, err := m.db.ExecutionRepo().CountByPhase(ctx, countExecutionByPhaseInput)
if err != nil {
logger.Debugf(ctx, "Failed to get execution counts using input [%+v] with err %v", countExecutionByPhaseInput, err)
return nil, err
}

executionCounts, err := transformers.FromExecutionCountsByPhase(ctx, countExecutionByPhaseOutput)
if err != nil {
logger.Errorf(ctx, "Failed to transform execution by phase output [%+v] with err %v", countExecutionByPhaseOutput, err)
return nil, err
}

return &admin.ExecutionCountsGetResponse{
ExecutionCounts: executionCounts,
}, nil
}

func (m *ExecutionManager) GetRunningExecutionsCount(
ctx context.Context, request admin.RunningExecutionsCountGetRequest) (*admin.RunningExecutionsCountGetResponse, error) {
// Check required fields
if err := validation.ValidateRunningExecutionsGetRequest(request); err != nil {
logger.Debugf(ctx, "RunningExecutionsCount request [%+v] failed validation with err: %v", request, err)
return nil, err
}
ctx = contextutils.WithProjectDomain(ctx, request.Project, request.Domain)
filters, err := util.GetDbFilters(util.FilterSpec{
Org: request.Org,
Project: request.Project,
Domain: request.Domain,
}, common.Execution)
if err != nil {
return nil, err
}

// Add filter to fetch only RUNNING executions
if filters, err = addPhaseFilter(filters, core.WorkflowExecution_RUNNING); err != nil {
return nil, err
}

countRunningExecutionsInput := repositoryInterfaces.CountResourceInput{
InlineFilters: filters,
}
countRunningExecutionsOutput, err := m.db.ExecutionRepo().Count(ctx, countRunningExecutionsInput)
if err != nil {
logger.Debugf(ctx, "Failed to get running executions count using input [%+v] with err %v", countRunningExecutionsOutput, err)
return nil, err
}

return &admin.RunningExecutionsCountGetResponse{
Count: countRunningExecutionsOutput,
}, nil
}

// publishNotifications will only forward major errors because the assumption made is all of the objects
// that are being manipulated have already been validated/manipulated by Flyte itself.
// Note: This method should be refactored somewhere else once the interaction with pushing to SNS.
Expand Down Expand Up @@ -2109,3 +2183,13 @@ func addStateFilter(filters []common.InlineFilter) ([]common.InlineFilter, error
}
return filters, nil
}

func addPhaseFilter(filters []common.InlineFilter, phase core.WorkflowExecution_Phase) ([]common.InlineFilter, error) {
phaseFilter, err := common.NewSingleValueFilter(common.Execution, common.Equal, shared.Phase,
phase.String())
if err != nil {
return filters, err
}
filters = append(filters, phaseFilter)
return filters, nil
}
235 changes: 227 additions & 8 deletions flyteadmin/pkg/manager/impl/execution_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,15 @@ import (
)

const (
principal = "principal"
rawOutput = "raw_output"
principal = "principal"
rawOutput = "raw_output"
executionOrgQueryExpr = "execution_org = ?"
executionProjectQueryExpr = "execution_project = ?"
executionDomainQueryExpr = "execution_domain = ?"
executionNameQueryExpr = "execution_name = ?"
executionCreatedAtFilter = "gte(execution_created_at,2021-01-01T00:00:00Z)"
executionCreatedAtValue = "2021-01-01T00:00:00Z"
executionCreatedAtQueryExpr = "execution_created_at >= ?"
)

var spec = testutils.GetExecutionRequest().Spec
Expand Down Expand Up @@ -2995,13 +3002,13 @@ func TestListExecutions(t *testing.T) {
for _, filter := range input.InlineFilters {
assert.Equal(t, common.Execution, filter.GetEntity())
queryExpr, _ := filter.GetGormQueryExpr()
if queryExpr.Args == projectValue && queryExpr.Query == "execution_project = ?" {
if queryExpr.Args == projectValue && queryExpr.Query == executionProjectQueryExpr {
projectFilter = true
}
if queryExpr.Args == domainValue && queryExpr.Query == "execution_domain = ?" {
if queryExpr.Args == domainValue && queryExpr.Query == executionDomainQueryExpr {
domainFilter = true
}
if queryExpr.Args == nameValue && queryExpr.Query == "execution_name = ?" {
if queryExpr.Args == nameValue && queryExpr.Query == executionNameQueryExpr {
nameFilter = true
}
}
Expand Down Expand Up @@ -3165,6 +3172,218 @@ func TestListExecutions_TransformerError(t *testing.T) {
assert.Nil(t, executionList)
}

func TestGetExecutionCounts(t *testing.T) {
repository := repositoryMocks.NewMockRepository()
getExecutionCountsFunc := func(
ctx context.Context, input interfaces.CountResourceInput) (interfaces.ExecutionCountsByPhaseOutput, error) {
var orgFilter, projectFilter, domainFilter, updatedAtFilter, nameFilter bool
for _, filter := range input.InlineFilters {
assert.Equal(t, common.Execution, filter.GetEntity())
queryExpr, _ := filter.GetGormQueryExpr()
if queryExpr.Args == orgValue && queryExpr.Query == executionOrgQueryExpr {
orgFilter = true
}
if queryExpr.Args == projectValue && queryExpr.Query == executionProjectQueryExpr {
projectFilter = true
}
if queryExpr.Args == domainValue && queryExpr.Query == executionDomainQueryExpr {
domainFilter = true
}
if queryExpr.Args == executionCreatedAtValue && queryExpr.Query == executionCreatedAtQueryExpr {
updatedAtFilter = true
}
if queryExpr.Args == nameValue && queryExpr.Query == executionNameQueryExpr {
nameFilter = true
}
}
assert.True(t, orgFilter, "Missing org equality filter")
assert.True(t, projectFilter, "Missing project equality filter")
assert.True(t, domainFilter, "Missing domain equality filter")
assert.True(t, updatedAtFilter, "Missing updated at filter")
assert.False(t, nameFilter, "Included name equality filter")
return interfaces.ExecutionCountsByPhaseOutput{
{
Phase: "FAILED",
Count: int64(3),
},
{
Phase: "SUCCEEDED",
Count: int64(4),
},
}, nil
}
repository.ExecutionRepo().(*repositoryMocks.MockExecutionRepo).SetCountByPhaseCallback(getExecutionCountsFunc)
r := plugins.NewRegistry()
r.RegisterDefault(plugins.PluginIDWorkflowExecutor, &defaultTestExecutor)
execManager := NewExecutionManager(repository, r, getMockExecutionsConfigProvider(), getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockScope.NewTestScope(), &mockPublisher, mockExecutionRemoteURL, nil, nil, nil, nil, &eventWriterMocks.WorkflowExecutionEventWriter{}, artifacts.NewArtifactRegistry(context.Background(), nil))

executionCountsGetResponse, err := execManager.GetExecutionCounts(context.Background(), admin.ExecutionCountsGetRequest{
Org: orgValue,
Project: projectValue,
Domain: domainValue,
Filters: executionCreatedAtFilter,
})
executionCounts := executionCountsGetResponse.ExecutionCounts
assert.NoError(t, err)
assert.NotNil(t, executionCounts)
assert.Len(t, executionCounts, 2)

assert.Equal(t, core.WorkflowExecution_FAILED, executionCounts[0].Phase)
assert.Equal(t, int64(3), executionCounts[0].Count)
assert.Equal(t, core.WorkflowExecution_SUCCEEDED, executionCounts[1].Phase)
assert.Equal(t, int64(4), executionCounts[1].Count)
}

func TestGetExecutionCounts_MissingParameters(t *testing.T) {
r := plugins.NewRegistry()
r.RegisterDefault(plugins.PluginIDWorkflowExecutor, &defaultTestExecutor)
execManager := NewExecutionManager(repositoryMocks.NewMockRepository(), r, getMockExecutionsConfigProvider(), getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockScope.NewTestScope(), &mockPublisher, mockExecutionRemoteURL, nil, nil, nil, nil, &eventWriterMocks.WorkflowExecutionEventWriter{}, artifacts.NewArtifactRegistry(context.Background(), nil))

// Test missing domain
_, err := execManager.GetExecutionCounts(context.Background(), admin.ExecutionCountsGetRequest{
Project: projectValue,
Filters: executionCreatedAtFilter,
})
assert.Error(t, err)
assert.Equal(t, codes.InvalidArgument, err.(flyteAdminErrors.FlyteAdminError).Code())

// Test missing project
_, err = execManager.GetExecutionCounts(context.Background(), admin.ExecutionCountsGetRequest{
Domain: domainValue,
Filters: executionCreatedAtFilter,
})
assert.Error(t, err)
assert.Equal(t, codes.InvalidArgument, err.(flyteAdminErrors.FlyteAdminError).Code())

// Filter is optional
_, err = execManager.GetExecutionCounts(context.Background(), admin.ExecutionCountsGetRequest{
Project: projectValue,
Domain: domainValue,
})
assert.NoError(t, err)
}

func TestGetExecutionCounts_DatabaseError(t *testing.T) {
repository := repositoryMocks.NewMockRepository()
expectedErr := errors.New("expected error")
getExecutionCountsFunc := func(
ctx context.Context, input interfaces.CountResourceInput) (interfaces.ExecutionCountsByPhaseOutput, error) {
return interfaces.ExecutionCountsByPhaseOutput{}, expectedErr
}
repository.ExecutionRepo().(*repositoryMocks.MockExecutionRepo).SetCountByPhaseCallback(getExecutionCountsFunc)
r := plugins.NewRegistry()
r.RegisterDefault(plugins.PluginIDWorkflowExecutor, &defaultTestExecutor)
execManager := NewExecutionManager(repository, r, getMockExecutionsConfigProvider(), getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockScope.NewTestScope(), &mockPublisher, mockExecutionRemoteURL, nil, nil, nil, nil, &eventWriterMocks.WorkflowExecutionEventWriter{}, artifacts.NewArtifactRegistry(context.Background(), nil))
_, err := execManager.GetExecutionCounts(context.Background(), admin.ExecutionCountsGetRequest{
Project: projectValue,
Domain: domainValue,
Filters: executionCreatedAtFilter,
})
assert.EqualError(t, err, expectedErr.Error())
}

func TestGetExecutionCounts_TransformerError(t *testing.T) {
repository := repositoryMocks.NewMockRepository()
getExecutionCountsFunc := func(
ctx context.Context, input interfaces.CountResourceInput) (interfaces.ExecutionCountsByPhaseOutput, error) {
return interfaces.ExecutionCountsByPhaseOutput{
{
Phase: "INVALID_PHASE",
Count: int64(3),
},
}, nil
}
repository.ExecutionRepo().(*repositoryMocks.MockExecutionRepo).SetCountByPhaseCallback(getExecutionCountsFunc)
r := plugins.NewRegistry()
r.RegisterDefault(plugins.PluginIDWorkflowExecutor, &defaultTestExecutor)
execManager := NewExecutionManager(repository, r, getMockExecutionsConfigProvider(), getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockScope.NewTestScope(), &mockPublisher, mockExecutionRemoteURL, nil, nil, nil, nil, &eventWriterMocks.WorkflowExecutionEventWriter{}, artifacts.NewArtifactRegistry(context.Background(), nil))

executionCountsGetResponse, err := execManager.GetExecutionCounts(context.Background(), admin.ExecutionCountsGetRequest{
Project: projectValue,
Domain: domainValue,
Filters: executionCreatedAtFilter,
})
assert.EqualError(t, err, "Failed to transform INVALID_PHASE into an execution phase.")
assert.Nil(t, executionCountsGetResponse)
}

func TestGetRunningExecutionsCount(t *testing.T) {
repository := repositoryMocks.NewMockRepository()
getRunningExecutionsCountFunc := func(
ctx context.Context, input interfaces.CountResourceInput) (int64, error) {
var orgFilter, projectFilter, domainFilter, nameFilter bool
for _, filter := range input.InlineFilters {
assert.Equal(t, common.Execution, filter.GetEntity())
queryExpr, _ := filter.GetGormQueryExpr()
if queryExpr.Args == orgValue && queryExpr.Query == executionOrgQueryExpr {
orgFilter = true
}
if queryExpr.Args == projectValue && queryExpr.Query == executionProjectQueryExpr {
projectFilter = true
}
if queryExpr.Args == domainValue && queryExpr.Query == executionDomainQueryExpr {
domainFilter = true
}
if queryExpr.Args == nameValue && queryExpr.Query == executionNameQueryExpr {
nameFilter = true
}
}
assert.True(t, orgFilter, "Missing org equality filter")
assert.True(t, projectFilter, "Missing project equality filter")
assert.True(t, domainFilter, "Missing domain equality filter")
assert.False(t, nameFilter, "Included name equality filter")
return 3, nil
}
repository.ExecutionRepo().(*repositoryMocks.MockExecutionRepo).SetCountCallback(getRunningExecutionsCountFunc)
r := plugins.NewRegistry()
r.RegisterDefault(plugins.PluginIDWorkflowExecutor, &defaultTestExecutor)
execManager := NewExecutionManager(repository, r, getMockExecutionsConfigProvider(), getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockScope.NewTestScope(), &mockPublisher, mockExecutionRemoteURL, nil, nil, nil, nil, &eventWriterMocks.WorkflowExecutionEventWriter{}, artifacts.NewArtifactRegistry(context.Background(), nil))

runningExecutionsCountGetResponse, err := execManager.GetRunningExecutionsCount(context.Background(), admin.RunningExecutionsCountGetRequest{
Org: orgValue,
Project: projectValue,
Domain: domainValue,
})
assert.NoError(t, err)
assert.NotNil(t, runningExecutionsCountGetResponse)
assert.Equal(t, int64(3), runningExecutionsCountGetResponse.Count)
}

func TestGetRunningExecutionsCount_MissingParameters(t *testing.T) {
r := plugins.NewRegistry()
r.RegisterDefault(plugins.PluginIDWorkflowExecutor, &defaultTestExecutor)
execManager := NewExecutionManager(repositoryMocks.NewMockRepository(), r, getMockExecutionsConfigProvider(), getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockScope.NewTestScope(), &mockPublisher, mockExecutionRemoteURL, nil, nil, nil, nil, &eventWriterMocks.WorkflowExecutionEventWriter{}, artifacts.NewArtifactRegistry(context.Background(), nil))
_, err := execManager.GetRunningExecutionsCount(context.Background(), admin.RunningExecutionsCountGetRequest{
Project: projectValue,
})
assert.Error(t, err)
assert.Equal(t, codes.InvalidArgument, err.(flyteAdminErrors.FlyteAdminError).Code())

_, err = execManager.GetRunningExecutionsCount(context.Background(), admin.RunningExecutionsCountGetRequest{
Domain: domainValue,
})
assert.Error(t, err)
assert.Equal(t, codes.InvalidArgument, err.(flyteAdminErrors.FlyteAdminError).Code())
}

func TestGetRunningExecutionsCount_DatabaseError(t *testing.T) {
repository := repositoryMocks.NewMockRepository()
expectedErr := errors.New("expected error")
getRunningExecutionsCountFunc := func(
ctx context.Context, input interfaces.CountResourceInput) (int64, error) {
return 0, expectedErr
}
repository.ExecutionRepo().(*repositoryMocks.MockExecutionRepo).SetCountCallback(getRunningExecutionsCountFunc)
r := plugins.NewRegistry()
r.RegisterDefault(plugins.PluginIDWorkflowExecutor, &defaultTestExecutor)
execManager := NewExecutionManager(repository, r, getMockExecutionsConfigProvider(), getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockScope.NewTestScope(), &mockPublisher, mockExecutionRemoteURL, nil, nil, nil, nil, &eventWriterMocks.WorkflowExecutionEventWriter{}, artifacts.NewArtifactRegistry(context.Background(), nil))
_, err := execManager.GetRunningExecutionsCount(context.Background(), admin.RunningExecutionsCountGetRequest{
Project: projectValue,
Domain: domainValue,
})
assert.EqualError(t, err, expectedErr.Error())
}

func TestExecutionManager_PublishNotifications(t *testing.T) {
repository := repositoryMocks.NewMockRepository()
queue := executions.NewQueueAllocator(getMockExecutionsConfigProvider(), repository)
Expand Down Expand Up @@ -3979,13 +4198,13 @@ func TestListExecutions_LegacyModel(t *testing.T) {
for _, filter := range input.InlineFilters {
assert.Equal(t, common.Execution, filter.GetEntity())
queryExpr, _ := filter.GetGormQueryExpr()
if queryExpr.Args == projectValue && queryExpr.Query == "execution_project = ?" {
if queryExpr.Args == projectValue && queryExpr.Query == executionProjectQueryExpr {
projectFilter = true
}
if queryExpr.Args == domainValue && queryExpr.Query == "execution_domain = ?" {
if queryExpr.Args == domainValue && queryExpr.Query == executionDomainQueryExpr {
domainFilter = true
}
if queryExpr.Args == nameValue && queryExpr.Query == "execution_name = ?" {
if queryExpr.Args == nameValue && queryExpr.Query == executionNameQueryExpr {
nameFilter = true
}
}
Expand Down
8 changes: 6 additions & 2 deletions flyteadmin/pkg/manager/impl/shared/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ const (
Attributes = "attributes"
MatchingAttributes = "matching_attributes"
// Parent of a node execution in the node executions table
ParentID = "parent_id"
WorkflowClosure = "workflow_closure"
ParentID = "parent_id"
WorkflowClosure = "workflow_closure"
Phase = "phase"
StartedAt = "started_at"
ExecutionCreatedAt = "execution_created_at"
ExecutionUpdatedAt = "execution_updated_at"
)
2 changes: 1 addition & 1 deletion flyteadmin/pkg/manager/impl/task_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ import (
)

// Static values for test
const orgValue = ""
const orgValue = "foobar"
const projectValue = "foo"
const domainValue = "bar"
const nameValue = "baz"
Expand Down
Loading

0 comments on commit e06144a

Please sign in to comment.