From 0c0b505928840d3ca069e559b515599b2a29bb3f Mon Sep 17 00:00:00 2001 From: Catalin Toda Date: Thu, 1 Oct 2020 14:03:51 -0700 Subject: [PATCH] Set labels in workflow metadata --- pkg/manager/impl/execution_manager.go | 35 +++++++++++++++++++ pkg/manager/impl/execution_manager_test.go | 13 +++++++ .../impl/propeller_executor_test.go | 3 ++ pkg/workflowengine/interfaces/executor.go | 2 ++ 4 files changed, 53 insertions(+) diff --git a/pkg/manager/impl/execution_manager.go b/pkg/manager/impl/execution_manager.go index 202050579..a50bcafa4 100644 --- a/pkg/manager/impl/execution_manager.go +++ b/pkg/manager/impl/execution_manager.go @@ -455,6 +455,7 @@ func (m *ExecutionManager) launchSingleTaskExecution( logger.Errorf(ctx, "Failed to get quality of service for [%+v] with error: %v", workflowExecutionID, err) return nil, nil, err } + executeTaskInputs := workflowengineInterfaces.ExecuteTaskInput{ ExecutionID: &workflowExecutionID, WfClosure: *workflow.Closure.CompiledWorkflow, @@ -464,9 +465,15 @@ func (m *ExecutionManager) launchSingleTaskExecution( Auth: request.Spec.AuthRole, QueueingBudget: qualityOfService.QueuingBudget, } + if request.Spec.Labels != nil { executeTaskInputs.Labels = request.Spec.Labels.Values } + executeTaskInputs.Labels, err = m.addProjectLabels(ctx, request.Project, executeTaskInputs.Labels) + if err != nil { + return nil, nil, err + } + if request.Spec.Annotations != nil { executeTaskInputs.Annotations = request.Spec.Annotations.Values } @@ -632,6 +639,12 @@ func (m *ExecutionManager) launchExecutionAndPrepareModel( if err != nil { return nil, nil, err } + + executeWorkflowInputs.Labels, err = m.addProjectLabels(ctx, request.Project, executeWorkflowInputs.Labels) + if err != nil { + return nil, nil, err + } + err = m.addPluginOverrides(ctx, &workflowExecutionID, launchPlan.GetSpec().WorkflowId.Name, launchPlan.Id.Name, &executeWorkflowInputs) if err != nil { @@ -1346,3 +1359,25 @@ func NewExecutionManager( qualityOfServiceAllocator: executions.NewQualityOfServiceAllocator(config, resourceManager), } } + +func (m *ExecutionManager) addProjectLabels(ctx context.Context, projectName string, initialLabels map[string]string) (map[string]string, error) { + project, err := m.db.ProjectRepo().Get(ctx, projectName) + if err != nil { + logger.Errorf(ctx, "Failed to get project for [%+v] with error: %v", project, err) + return nil, err + } + // passing nil domain as not needed to retrieve labels + projectLabels := transformers.FromProjectModel(project, nil).Labels.GetValues() + + if initialLabels == nil { + initialLabels = make(map[string]string) + } + + // Add the project labels only if not set before + for k, v := range projectLabels { + if _, ok := initialLabels[k]; !ok { + initialLabels[k] = v + } + } + return initialLabels, nil +} diff --git a/pkg/manager/impl/execution_manager_test.go b/pkg/manager/impl/execution_manager_test.go index c3d2a90a7..56fd12d0e 100644 --- a/pkg/manager/impl/execution_manager_test.go +++ b/pkg/manager/impl/execution_manager_test.go @@ -35,6 +35,7 @@ import ( "github.com/lyft/flyteadmin/pkg/repositories/interfaces" repositoryMocks "github.com/lyft/flyteadmin/pkg/repositories/mocks" "github.com/lyft/flyteadmin/pkg/repositories/models" + "github.com/lyft/flyteadmin/pkg/repositories/transformers" runtimeInterfaces "github.com/lyft/flyteadmin/pkg/runtime/interfaces" runtimeIFaceMocks "github.com/lyft/flyteadmin/pkg/runtime/interfaces/mocks" runtimeMocks "github.com/lyft/flyteadmin/pkg/runtime/mocks" @@ -198,6 +199,17 @@ func getMockRepositoryForExecTest() repositories.RepositoryInterface { func TestCreateExecution(t *testing.T) { repository := getMockRepositoryForExecTest() + labels := admin.Labels{ + Values: map[string]string{ + "label3": "3", + "label2": "1", // common label, will be dropped + }} + repository.ProjectRepo().(*repositoryMocks.MockProjectRepo).GetFunction = func( + ctx context.Context, projectID string) (models.Project, error) { + return transformers.CreateProjectModel(&admin.Project{ + Labels: &labels}), nil + } + principal := "principal" repository.ExecutionRepo().(*repositoryMocks.MockExecutionRepo).SetCreateCallback( func(ctx context.Context, input models.Execution) error { @@ -214,6 +226,7 @@ func TestCreateExecution(t *testing.T) { assert.EqualValues(t, map[string]string{ "label1": "1", "label2": "2", + "label3": "3", }, inputs.Labels) assert.EqualValues(t, map[string]string{ "annotation3": "3", diff --git a/pkg/workflowengine/impl/propeller_executor_test.go b/pkg/workflowengine/impl/propeller_executor_test.go index 3a9217b3f..fa7479943 100644 --- a/pkg/workflowengine/impl/propeller_executor_test.go +++ b/pkg/workflowengine/impl/propeller_executor_test.go @@ -198,6 +198,9 @@ func TestExecuteWorkflowHappyCase(t *testing.T) { MissingPluginBehavior: admin.PluginOverride_USE_DEFAULT, }, }, + ProjectLabels: map[string]string{ + "customlabel": "labelval", + }, }) assert.Nil(t, err) assert.NotNil(t, execInfo) diff --git a/pkg/workflowengine/interfaces/executor.go b/pkg/workflowengine/interfaces/executor.go index bd468ba28..ec200b1a7 100644 --- a/pkg/workflowengine/interfaces/executor.go +++ b/pkg/workflowengine/interfaces/executor.go @@ -19,6 +19,7 @@ type ExecuteWorkflowInput struct { Annotations map[string]string QueueingBudget time.Duration TaskPluginOverrides []*admin.PluginOverride + ProjectLabels map[string]string } type ExecuteTaskInput struct { @@ -31,6 +32,7 @@ type ExecuteTaskInput struct { Labels map[string]string Annotations map[string]string QueueingBudget time.Duration + ProjectLabels map[string]string } type TerminateWorkflowInput struct {