diff --git a/flyteadmin/go.mod b/flyteadmin/go.mod index 792330face..957e352b4e 100644 --- a/flyteadmin/go.mod +++ b/flyteadmin/go.mod @@ -26,7 +26,7 @@ require ( github.com/kelseyhightower/envconfig v1.4.0 // indirect github.com/lib/pq v1.3.0 github.com/lyft/flyteidl v0.18.6 - github.com/lyft/flytepropeller v0.3.16 + github.com/lyft/flytepropeller v0.3.17 github.com/lyft/flytestdlib v0.3.9 github.com/magiconair/properties v1.8.1 github.com/mitchellh/mapstructure v1.1.2 diff --git a/flyteadmin/go.sum b/flyteadmin/go.sum index 5151ad89e7..068f8cf07e 100644 --- a/flyteadmin/go.sum +++ b/flyteadmin/go.sum @@ -462,17 +462,11 @@ github.com/lyft/apimachinery v0.0.0-20191031200210-047e3ea32d7f/go.mod h1:llRdnz github.com/lyft/datacatalog v0.2.1/go.mod h1:ktrPvzTDUwHO5Lv0hLH38zLHnOJ++rGoAO0iQ/sIPJ4= github.com/lyft/flyteidl v0.17.0/go.mod h1:/zQXxuHO11u/saxTTZc8oYExIGEShXB+xCB1/F1Cu20= github.com/lyft/flyteidl v0.18.0/go.mod h1:/zQXxuHO11u/saxTTZc8oYExIGEShXB+xCB1/F1Cu20= -github.com/lyft/flyteidl v0.18.1/go.mod h1:/zQXxuHO11u/saxTTZc8oYExIGEShXB+xCB1/F1Cu20= -github.com/lyft/flyteidl v0.18.3 h1:+O0rDCXoiui5X56DtoqquW0rqjN75jDWqAEyvcqmarI= -github.com/lyft/flyteidl v0.18.3/go.mod h1:/zQXxuHO11u/saxTTZc8oYExIGEShXB+xCB1/F1Cu20= github.com/lyft/flyteidl v0.18.6 h1:HGbxHI8avEDvoPqcO2+/BoJVcP9sjOj4qwJ/wNRWuoA= github.com/lyft/flyteidl v0.18.6/go.mod h1:/zQXxuHO11u/saxTTZc8oYExIGEShXB+xCB1/F1Cu20= -github.com/lyft/flyteplugins v0.4.4/go.mod h1:8zhqFG9BzbHNQGEXzGYltTJLD+KTmQZkanxXgeFI25c= github.com/lyft/flyteplugins v0.5.1/go.mod h1:8zhqFG9BzbHNQGEXzGYltTJLD+KTmQZkanxXgeFI25c= -github.com/lyft/flytepropeller v0.3.7 h1:l2AguhyhiUDCvqjHYF8XJw46gPW9j4XNZwJEAJdiEtI= -github.com/lyft/flytepropeller v0.3.7/go.mod h1:8sNP7ZnEngNRYBMewmH4PtiRR0pus8RkjNoPqelyKX8= -github.com/lyft/flytepropeller v0.3.16 h1:a6KbvtDRMMVEUlVTqQ9h9IOehUerk3dT+pvsN5Ql/4o= -github.com/lyft/flytepropeller v0.3.16/go.mod h1:GArCzcLAZ48OacGUsHUA3f028ixoU8CVZOMikyjEdNY= +github.com/lyft/flytepropeller v0.3.17 h1:a2PVqWjnn8oNEeayAqNizMAtEixl/F3S4vd8z4kbiqI= +github.com/lyft/flytepropeller v0.3.17/go.mod h1:T8Utxqv7B5USAX9c/Qh0lBbKXHFSgOwwaISOd9h36P4= github.com/lyft/flytestdlib v0.3.0 h1:nIkX4MlyYdcLLzaF35RI2P5BhARt+qMgHoFto8eVNzU= github.com/lyft/flytestdlib v0.3.0/go.mod h1:LJPPJlkFj+wwVWMrQT3K5JZgNhZi2mULsCG4ZYhinhU= github.com/lyft/flytestdlib v0.3.9 h1:NaKp9xkeWWwhVvqTOcR/FqlASy1N2gu/kN7PVe4S7YI= diff --git a/flyteadmin/pkg/common/testutils/common.go b/flyteadmin/pkg/common/testutils/common.go new file mode 100644 index 0000000000..e41abb4a0d --- /dev/null +++ b/flyteadmin/pkg/common/testutils/common.go @@ -0,0 +1,21 @@ +package testutils + +import "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" + +// Convenience method to wrap verbose boilerplate for initializing a PluginOverrides MatchingAttributes. +func GetPluginOverridesAttributes(vals map[string][]string) *admin.MatchingAttributes { + overrides := make([]*admin.PluginOverride, 0, len(vals)) + for taskType, pluginIDs := range vals { + overrides = append(overrides, &admin.PluginOverride{ + TaskType: taskType, + PluginId: pluginIDs, + }) + } + return &admin.MatchingAttributes{ + Target: &admin.MatchingAttributes_PluginOverrides{ + PluginOverrides: &admin.PluginOverrides{ + Overrides: overrides, + }, + }, + } +} diff --git a/flyteadmin/pkg/manager/impl/execution_manager.go b/flyteadmin/pkg/manager/impl/execution_manager.go index d43f52129e..2020505799 100644 --- a/flyteadmin/pkg/manager/impl/execution_manager.go +++ b/flyteadmin/pkg/manager/impl/execution_manager.go @@ -173,6 +173,27 @@ func (m *ExecutionManager) addLabelsAndAnnotations(requestSpec *admin.ExecutionS return nil } +func (m *ExecutionManager) addPluginOverrides(ctx context.Context, executionID *core.WorkflowExecutionIdentifier, + workflowName, launchPlanName string, partiallyPopulatedInputs *workflowengineInterfaces.ExecuteWorkflowInput) error { + override, err := m.resourceManager.GetResource(ctx, interfaces.ResourceRequest{ + Project: executionID.Project, + Domain: executionID.Domain, + Workflow: workflowName, + LaunchPlan: launchPlanName, + ResourceType: admin.MatchableResource_PLUGIN_OVERRIDE, + }) + if err != nil { + ec, ok := err.(errors.FlyteAdminError) + if !ok || ec.Code() != codes.NotFound { + return err + } + } + if override != nil && override.Attributes != nil && override.Attributes.GetPluginOverrides() != nil { + partiallyPopulatedInputs.TaskPluginOverrides = override.Attributes.GetPluginOverrides().Overrides + } + return nil +} + func (m *ExecutionManager) offloadInputs(ctx context.Context, literalMap *core.LiteralMap, identifier *core.WorkflowExecutionIdentifier, key string) (storage.DataReference, error) { if literalMap == nil { literalMap = &core.LiteralMap{} @@ -611,6 +632,11 @@ func (m *ExecutionManager) launchExecutionAndPrepareModel( if err != nil { return nil, nil, err } + err = m.addPluginOverrides(ctx, &workflowExecutionID, launchPlan.GetSpec().WorkflowId.Name, launchPlan.Id.Name, + &executeWorkflowInputs) + if err != nil { + return nil, nil, err + } execInfo, err := m.workflowExecutor.ExecuteWorkflow(ctx, executeWorkflowInputs) if err != nil { diff --git a/flyteadmin/pkg/manager/impl/execution_manager_test.go b/flyteadmin/pkg/manager/impl/execution_manager_test.go index 66b0b0410e..c3d2a90a73 100644 --- a/flyteadmin/pkg/manager/impl/execution_manager_test.go +++ b/flyteadmin/pkg/manager/impl/execution_manager_test.go @@ -25,6 +25,7 @@ import ( "github.com/golang/protobuf/proto" notificationMocks "github.com/lyft/flyteadmin/pkg/async/notifications/mocks" + commonTestUtils "github.com/lyft/flyteadmin/pkg/common/testutils" dataMocks "github.com/lyft/flyteadmin/pkg/data/mocks" flyteAdminErrors "github.com/lyft/flyteadmin/pkg/errors" "github.com/lyft/flyteadmin/pkg/manager/impl/executions" @@ -2104,6 +2105,80 @@ func TestAddLabelsAndAnnotationsRuntimeLimitsObserved(t *testing.T) { assert.EqualError(t, err, "Labels has too many entries [2 > 1]") } +func TestAddPluginOverrides(t *testing.T) { + executionID := &core.WorkflowExecutionIdentifier{ + Project: project, + Domain: domain, + Name: "unused", + } + workflowName := "workflow_name" + launchPlanName := "launch_plan_name" + + db := repositoryMocks.NewMockRepository() + db.ResourceRepo().(*repositoryMocks.MockResourceRepo).GetFunction = func(ctx context.Context, ID interfaces.ResourceID) ( + models.Resource, error) { + assert.Equal(t, project, ID.Project) + assert.Equal(t, domain, ID.Domain) + assert.Equal(t, workflowName, ID.Workflow) + assert.Equal(t, launchPlanName, ID.LaunchPlan) + existingAttributes := commonTestUtils.GetPluginOverridesAttributes(map[string][]string{ + "python": {"plugin a"}, + "hive": {"plugin b"}, + }) + bytes, err := proto.Marshal(existingAttributes) + if err != nil { + t.Fatal(err) + } + return models.Resource{ + Project: project, + Domain: domain, + Attributes: bytes, + }, nil + } + partiallyPopulatedInputs := workflowengineInterfaces.ExecuteWorkflowInput{} + + execManager := NewExecutionManager( + db, getMockExecutionsConfigProvider(), getMockStorageForExecTest(context.Background()), workflowengineMocks.NewMockExecutor(), + mockScope.NewTestScope(), mockScope.NewTestScope(), &mockPublisher, mockExecutionRemoteURL, nil, nil) + + err := execManager.(*ExecutionManager).addPluginOverrides( + context.Background(), executionID, workflowName, launchPlanName, &partiallyPopulatedInputs) + assert.NoError(t, err) + assert.Len(t, partiallyPopulatedInputs.TaskPluginOverrides, 2) + for _, override := range partiallyPopulatedInputs.TaskPluginOverrides { + if override.TaskType == "python" { + assert.EqualValues(t, []string{"plugin a"}, override.PluginId) + } else if override.TaskType == "hive" { + assert.EqualValues(t, []string{"plugin b"}, override.PluginId) + } else { + t.Errorf("Unexpected task type [%s] plugin override committed to db", override.TaskType) + } + } +} + +func TestPluginOverrides_ResourceGetFailure(t *testing.T) { + executionID := &core.WorkflowExecutionIdentifier{ + Project: project, + Domain: domain, + Name: "unused", + } + workflowName := "workflow_name" + launchPlanName := "launch_plan_name" + + db := repositoryMocks.NewMockRepository() + db.ResourceRepo().(*repositoryMocks.MockResourceRepo).GetFunction = func(ctx context.Context, ID interfaces.ResourceID) ( + models.Resource, error) { + return models.Resource{}, flyteAdminErrors.NewFlyteAdminErrorf(codes.Aborted, "uh oh") + } + execManager := NewExecutionManager( + db, getMockExecutionsConfigProvider(), getMockStorageForExecTest(context.Background()), workflowengineMocks.NewMockExecutor(), + mockScope.NewTestScope(), mockScope.NewTestScope(), &mockPublisher, mockExecutionRemoteURL, nil, nil) + + err := execManager.(*ExecutionManager).addPluginOverrides( + context.Background(), executionID, workflowName, launchPlanName, &workflowengineInterfaces.ExecuteWorkflowInput{}) + assert.Error(t, err, "uh oh") +} + func TestGetExecution_Legacy(t *testing.T) { repository := repositoryMocks.NewMockRepository() startedAt := time.Date(2018, 8, 30, 0, 0, 0, 0, time.UTC) diff --git a/flyteadmin/pkg/manager/impl/resources/resource_manager.go b/flyteadmin/pkg/manager/impl/resources/resource_manager.go index 6427f9ec29..686212bff9 100644 --- a/flyteadmin/pkg/manager/impl/resources/resource_manager.go +++ b/flyteadmin/pkg/manager/impl/resources/resource_manager.go @@ -3,6 +3,8 @@ package resources import ( "context" + "github.com/lyft/flyteadmin/pkg/repositories/models" + "github.com/gogo/protobuf/proto" "github.com/lyft/flyteadmin/pkg/errors" "github.com/lyft/flytestdlib/contextutils" @@ -54,6 +56,41 @@ func (m *ResourceManager) GetResource(ctx context.Context, request interfaces.Re }, nil } +func (m *ResourceManager) createOrMergeUpdateWorkflowAttributes( + ctx context.Context, request admin.WorkflowAttributesUpdateRequest, model models.Resource, + resourceType admin.MatchableResource) (*admin.WorkflowAttributesUpdateResponse, error) { + resourceID := repo_interface.ResourceID{ + Project: model.Project, + Domain: model.Domain, + Workflow: model.Workflow, + LaunchPlan: model.LaunchPlan, + ResourceType: model.ResourceType, + } + existing, err := m.db.ResourceRepo().GetRaw(ctx, resourceID) + if err != nil { + ec, ok := err.(errors.FlyteAdminError) + if ok && ec.Code() == codes.NotFound { + // Proceed with the default CreateOrUpdate call since there's no existing model to update. + err = m.db.ResourceRepo().CreateOrUpdate(ctx, model) + if err != nil { + return nil, err + } + return &admin.WorkflowAttributesUpdateResponse{}, nil + } + return nil, err + } + updatedModel, err := transformers.MergeUpdateWorkflowAttributes( + ctx, existing, resourceType, &resourceID, request.Attributes) + if err != nil { + return nil, err + } + err = m.db.ResourceRepo().CreateOrUpdate(ctx, updatedModel) + if err != nil { + return nil, err + } + return &admin.WorkflowAttributesUpdateResponse{}, nil +} + func (m *ResourceManager) UpdateWorkflowAttributes( ctx context.Context, request admin.WorkflowAttributesUpdateRequest) ( *admin.WorkflowAttributesUpdateResponse, error) { @@ -67,6 +104,9 @@ func (m *ResourceManager) UpdateWorkflowAttributes( if err != nil { return nil, err } + if request.Attributes.GetMatchingAttributes().GetPluginOverrides() != nil { + return m.createOrMergeUpdateWorkflowAttributes(ctx, request, model, admin.MatchableResource_PLUGIN_OVERRIDE) + } err = m.db.ResourceRepo().CreateOrUpdate(ctx, model) if err != nil { return nil, err @@ -109,6 +149,41 @@ func (m *ResourceManager) DeleteWorkflowAttributes(ctx context.Context, return &admin.WorkflowAttributesDeleteResponse{}, nil } +func (m *ResourceManager) createOrMergeUpdateProjectDomainAttributes( + ctx context.Context, request admin.ProjectDomainAttributesUpdateRequest, model models.Resource, + resourceType admin.MatchableResource) (*admin.ProjectDomainAttributesUpdateResponse, error) { + resourceID := repo_interface.ResourceID{ + Project: model.Project, + Domain: model.Domain, + Workflow: model.Workflow, + LaunchPlan: model.LaunchPlan, + ResourceType: model.ResourceType, + } + existing, err := m.db.ResourceRepo().GetRaw(ctx, resourceID) + if err != nil { + ec, ok := err.(errors.FlyteAdminError) + if ok && ec.Code() == codes.NotFound { + // Proceed with the default CreateOrUpdate call since there's no existing model to update. + err = m.db.ResourceRepo().CreateOrUpdate(ctx, model) + if err != nil { + return nil, err + } + return &admin.ProjectDomainAttributesUpdateResponse{}, nil + } + return nil, err + } + updatedModel, err := transformers.MergeUpdateProjectDomainAttributes( + ctx, existing, resourceType, &resourceID, request.Attributes) + if err != nil { + return nil, err + } + err = m.db.ResourceRepo().CreateOrUpdate(ctx, updatedModel) + if err != nil { + return nil, err + } + return &admin.ProjectDomainAttributesUpdateResponse{}, nil +} + func (m *ResourceManager) UpdateProjectDomainAttributes( ctx context.Context, request admin.ProjectDomainAttributesUpdateRequest) ( *admin.ProjectDomainAttributesUpdateResponse, error) { @@ -123,6 +198,9 @@ func (m *ResourceManager) UpdateProjectDomainAttributes( if err != nil { return nil, err } + if request.Attributes.GetMatchingAttributes().GetPluginOverrides() != nil { + return m.createOrMergeUpdateProjectDomainAttributes(ctx, request, model, admin.MatchableResource_PLUGIN_OVERRIDE) + } err = m.db.ResourceRepo().CreateOrUpdate(ctx, model) if err != nil { return nil, err diff --git a/flyteadmin/pkg/manager/impl/resources/resource_manager_test.go b/flyteadmin/pkg/manager/impl/resources/resource_manager_test.go index 7f0a8493ad..95ce25807c 100644 --- a/flyteadmin/pkg/manager/impl/resources/resource_manager_test.go +++ b/flyteadmin/pkg/manager/impl/resources/resource_manager_test.go @@ -2,12 +2,17 @@ package resources import ( "context" + "fmt" "testing" + "github.com/lyft/flyteadmin/pkg/errors" + "google.golang.org/grpc/codes" + "github.com/lyft/flyteadmin/pkg/manager/interfaces" repoInterfaces "github.com/lyft/flyteadmin/pkg/repositories/interfaces" "github.com/golang/protobuf/proto" + commonTestUtils "github.com/lyft/flyteadmin/pkg/common/testutils" "github.com/lyft/flyteadmin/pkg/manager/impl/testutils" "github.com/lyft/flyteadmin/pkg/repositories/mocks" "github.com/lyft/flyteadmin/pkg/repositories/models" @@ -47,6 +52,97 @@ func TestUpdateWorkflowAttributes(t *testing.T) { assert.True(t, createOrUpdateCalled) } +func TestUpdateWorkflowAttributes_CreateOrMerge(t *testing.T) { + request := admin.WorkflowAttributesUpdateRequest{ + Attributes: &admin.WorkflowAttributes{ + Project: project, + Domain: domain, + Workflow: workflow, + MatchingAttributes: commonTestUtils.GetPluginOverridesAttributes(map[string][]string{"python": {"plugin a"}}), + }, + } + + t.Run("create only", func(t *testing.T) { + db := mocks.NewMockRepository() + db.ResourceRepo().(*mocks.MockResourceRepo).GetFunction = func(ctx context.Context, ID repoInterfaces.ResourceID) ( + models.Resource, error) { + return models.Resource{}, errors.NewFlyteAdminError(codes.NotFound, "foo") + } + var createOrUpdateCalled bool + db.ResourceRepo().(*mocks.MockResourceRepo).CreateOrUpdateFunction = func(ctx context.Context, input models.Resource) error { + assert.Equal(t, project, input.Project) + assert.Equal(t, domain, input.Domain) + assert.Equal(t, workflow, input.Workflow) + + var attributesToBeSaved admin.MatchingAttributes + err := proto.Unmarshal(input.Attributes, &attributesToBeSaved) + if err != nil { + t.Fatal(err) + } + assert.Len(t, attributesToBeSaved.GetPluginOverrides().Overrides, 1) + assert.True(t, proto.Equal(attributesToBeSaved.GetPluginOverrides().Overrides[0], &admin.PluginOverride{ + TaskType: "python", + PluginId: []string{"plugin a"}})) + + createOrUpdateCalled = true + return nil + } + manager := NewResourceManager(db, testutils.GetApplicationConfigWithDefaultDomains()) + _, err := manager.UpdateWorkflowAttributes(context.Background(), request) + assert.NoError(t, err) + assert.True(t, createOrUpdateCalled) + }) + t.Run("merge update", func(t *testing.T) { + db := mocks.NewMockRepository() + db.ResourceRepo().(*mocks.MockResourceRepo).GetFunction = func(ctx context.Context, ID repoInterfaces.ResourceID) ( + models.Resource, error) { + existingAttributes := commonTestUtils.GetPluginOverridesAttributes(map[string][]string{ + "hive": {"plugin b"}, + "python": {"plugin c"}, + }) + bytes, err := proto.Marshal(existingAttributes) + if err != nil { + t.Fatal(err) + } + return models.Resource{ + Project: project, + Domain: domain, + Workflow: workflow, + Attributes: bytes, + }, nil + } + var createOrUpdateCalled bool + db.ResourceRepo().(*mocks.MockResourceRepo).CreateOrUpdateFunction = func(ctx context.Context, input models.Resource) error { + assert.Equal(t, project, input.Project) + assert.Equal(t, domain, input.Domain) + assert.Equal(t, workflow, input.Workflow) + + var attributesToBeSaved admin.MatchingAttributes + err := proto.Unmarshal(input.Attributes, &attributesToBeSaved) + if err != nil { + t.Fatal(err) + } + + assert.Len(t, attributesToBeSaved.GetPluginOverrides().Overrides, 2) + for _, override := range attributesToBeSaved.GetPluginOverrides().Overrides { + if override.TaskType == "python" { + assert.EqualValues(t, []string{"plugin a"}, override.PluginId) + } else if override.TaskType == "hive" { + assert.EqualValues(t, []string{"plugin b"}, override.PluginId) + } else { + t.Error(fmt.Sprintf("Unexpected task type [%s] plugin override committed to db", override.TaskType)) + } + } + createOrUpdateCalled = true + return nil + } + manager := NewResourceManager(db, testutils.GetApplicationConfigWithDefaultDomains()) + _, err := manager.UpdateWorkflowAttributes(context.Background(), request) + assert.NoError(t, err) + assert.True(t, createOrUpdateCalled) + }) +} + func TestGetWorkflowAttributes(t *testing.T) { request := admin.WorkflowAttributesGetRequest{ Project: project, @@ -131,6 +227,93 @@ func TestUpdateProjectDomainAttributes(t *testing.T) { assert.True(t, createOrUpdateCalled) } +func TestUpdateProjectDomainAttributes_CreateOrMerge(t *testing.T) { + request := admin.ProjectDomainAttributesUpdateRequest{ + Attributes: &admin.ProjectDomainAttributes{ + Project: project, + Domain: domain, + MatchingAttributes: commonTestUtils.GetPluginOverridesAttributes(map[string][]string{"python": {"plugin a"}}), + }, + } + + t.Run("create only", func(t *testing.T) { + db := mocks.NewMockRepository() + db.ResourceRepo().(*mocks.MockResourceRepo).GetFunction = func(ctx context.Context, ID repoInterfaces.ResourceID) ( + models.Resource, error) { + return models.Resource{}, errors.NewFlyteAdminError(codes.NotFound, "foo") + } + var createOrUpdateCalled bool + db.ResourceRepo().(*mocks.MockResourceRepo).CreateOrUpdateFunction = func(ctx context.Context, input models.Resource) error { + assert.Equal(t, project, input.Project) + assert.Equal(t, domain, input.Domain) + + var attributesToBeSaved admin.MatchingAttributes + err := proto.Unmarshal(input.Attributes, &attributesToBeSaved) + if err != nil { + t.Fatal(err) + } + assert.Len(t, attributesToBeSaved.GetPluginOverrides().Overrides, 1) + assert.True(t, proto.Equal(attributesToBeSaved.GetPluginOverrides().Overrides[0], &admin.PluginOverride{ + TaskType: "python", + PluginId: []string{"plugin a"}})) + + createOrUpdateCalled = true + return nil + } + manager := NewResourceManager(db, testutils.GetApplicationConfigWithDefaultDomains()) + _, err := manager.UpdateProjectDomainAttributes(context.Background(), request) + assert.NoError(t, err) + assert.True(t, createOrUpdateCalled) + }) + t.Run("merge update", func(t *testing.T) { + db := mocks.NewMockRepository() + db.ResourceRepo().(*mocks.MockResourceRepo).GetFunction = func(ctx context.Context, ID repoInterfaces.ResourceID) ( + models.Resource, error) { + existingAttributes := commonTestUtils.GetPluginOverridesAttributes(map[string][]string{ + "hive": {"plugin b"}, + "python": {"plugin c"}, + }) + bytes, err := proto.Marshal(existingAttributes) + if err != nil { + t.Fatal(err) + } + return models.Resource{ + Project: project, + Domain: domain, + Attributes: bytes, + }, nil + } + var createOrUpdateCalled bool + db.ResourceRepo().(*mocks.MockResourceRepo).CreateOrUpdateFunction = func(ctx context.Context, input models.Resource) error { + assert.Equal(t, project, input.Project) + assert.Equal(t, domain, input.Domain) + + var attributesToBeSaved admin.MatchingAttributes + err := proto.Unmarshal(input.Attributes, &attributesToBeSaved) + if err != nil { + t.Fatal(err) + } + + assert.Len(t, attributesToBeSaved.GetPluginOverrides().Overrides, 2) + for _, override := range attributesToBeSaved.GetPluginOverrides().Overrides { + if override.TaskType == "python" { + assert.EqualValues(t, []string{"plugin a"}, override.PluginId) + } else if override.TaskType == "hive" { + assert.EqualValues(t, []string{"plugin b"}, override.PluginId) + } else { + t.Error(fmt.Sprintf("Unexpected task type [%s] plugin override committed to db", override.TaskType)) + } + } + createOrUpdateCalled = true + return nil + } + manager := NewResourceManager(db, testutils.GetApplicationConfigWithDefaultDomains()) + _, err := manager.UpdateProjectDomainAttributes(context.Background(), request) + assert.NoError(t, err) + assert.True(t, createOrUpdateCalled) + }) +} + func TestGetProjectDomainAttributes(t *testing.T) { request := admin.ProjectDomainAttributesGetRequest{ Project: project, diff --git a/flyteadmin/pkg/manager/impl/validation/attributes_validator.go b/flyteadmin/pkg/manager/impl/validation/attributes_validator.go index 40bbfbe3de..32a4bbfc20 100644 --- a/flyteadmin/pkg/manager/impl/validation/attributes_validator.go +++ b/flyteadmin/pkg/manager/impl/validation/attributes_validator.go @@ -26,6 +26,8 @@ func validateMatchingAttributes(attributes *admin.MatchingAttributes, identifier return admin.MatchableResource_EXECUTION_QUEUE, nil } else if attributes.GetExecutionClusterLabel() != nil { return admin.MatchableResource_EXECUTION_CLUSTER_LABEL, nil + } else if attributes.GetPluginOverrides() != nil { + return admin.MatchableResource_PLUGIN_OVERRIDE, nil } return defaultMatchableResource, errors.NewFlyteAdminErrorf(codes.InvalidArgument, "Unrecognized matching attributes type for request %s", identifier) diff --git a/flyteadmin/pkg/manager/impl/validation/attributes_validator_test.go b/flyteadmin/pkg/manager/impl/validation/attributes_validator_test.go index 12cd1cdf24..3e0f937823 100644 --- a/flyteadmin/pkg/manager/impl/validation/attributes_validator_test.go +++ b/flyteadmin/pkg/manager/impl/validation/attributes_validator_test.go @@ -68,6 +68,23 @@ func TestValidateMatchingAttributes(t *testing.T) { admin.MatchableResource_EXECUTION_QUEUE, nil, }, + { + &admin.MatchingAttributes{ + Target: &admin.MatchingAttributes_PluginOverrides{ + PluginOverrides: &admin.PluginOverrides{ + Overrides: []*admin.PluginOverride{ + { + TaskType: "python", + PluginId: []string{"foo"}, + }, + }, + }, + }, + }, + "foo", + admin.MatchableResource_PLUGIN_OVERRIDE, + nil, + }, } for _, tc := range testCases { matchableResource, err := validateMatchingAttributes(tc.attributes, tc.identifier) diff --git a/flyteadmin/pkg/repositories/transformers/resource.go b/flyteadmin/pkg/repositories/transformers/resource.go index fb40ca05f1..6c5736a5ab 100644 --- a/flyteadmin/pkg/repositories/transformers/resource.go +++ b/flyteadmin/pkg/repositories/transformers/resource.go @@ -1,7 +1,11 @@ package transformers import ( + "context" + "github.com/golang/protobuf/proto" + repoInterfaces "github.com/lyft/flyteadmin/pkg/repositories/interfaces" + "github.com/lyft/flytestdlib/logger" "github.com/lyft/flyteadmin/pkg/errors" "github.com/lyft/flyteadmin/pkg/repositories/models" @@ -24,6 +28,61 @@ func WorkflowAttributesToResourceModel(attributes admin.WorkflowAttributes, reso }, nil } +func mergeUpdatePluginOverrides(existingAttributes admin.MatchingAttributes, + newMatchingAttributes *admin.MatchingAttributes) *admin.MatchingAttributes { + taskPluginOverrides := make(map[string]*admin.PluginOverride) + if existingAttributes.GetPluginOverrides() != nil && len(existingAttributes.GetPluginOverrides().Overrides) > 0 { + for _, pluginOverride := range existingAttributes.GetPluginOverrides().Overrides { + taskPluginOverrides[pluginOverride.TaskType] = pluginOverride + } + } + if newMatchingAttributes.GetPluginOverrides() != nil && + len(newMatchingAttributes.GetPluginOverrides().Overrides) > 0 { + for _, pluginOverride := range newMatchingAttributes.GetPluginOverrides().Overrides { + taskPluginOverrides[pluginOverride.TaskType] = pluginOverride + } + } + + updatedPluginOverrides := make([]*admin.PluginOverride, 0, len(taskPluginOverrides)) + for _, pluginOverride := range taskPluginOverrides { + updatedPluginOverrides = append(updatedPluginOverrides, pluginOverride) + } + return &admin.MatchingAttributes{ + Target: &admin.MatchingAttributes_PluginOverrides{ + PluginOverrides: &admin.PluginOverrides{ + Overrides: updatedPluginOverrides, + }, + }, + } +} + +func MergeUpdateWorkflowAttributes(ctx context.Context, model models.Resource, resource admin.MatchableResource, + resourceID *repoInterfaces.ResourceID, workflowAttributes *admin.WorkflowAttributes) (models.Resource, error) { + switch resource { + case admin.MatchableResource_PLUGIN_OVERRIDE: + var existingAttributes admin.MatchingAttributes + err := proto.Unmarshal(model.Attributes, &existingAttributes) + if err != nil { + return models.Resource{}, errors.NewFlyteAdminErrorf(codes.Internal, + "Unable to unmarshal existing resource attributes for [%+v] with err: %v", resourceID, err) + } + updatedAttributes := mergeUpdatePluginOverrides(existingAttributes, workflowAttributes.GetMatchingAttributes()) + marshaledAttributes, err := proto.Marshal(updatedAttributes) + if err != nil { + return models.Resource{}, errors.NewFlyteAdminErrorf(codes.Internal, + "Failed to marshal merge-updated attributes for [%+v] with err: %v", resourceID, err) + } + model.Attributes = marshaledAttributes + return model, nil + default: + logger.Warningf(ctx, "Tried to merge-update an unsupported resource type [%s] for [%+v]", + resource.String(), resourceID) + return models.Resource{}, errors.NewFlyteAdminErrorf(codes.Internal, + "Tried to merge-update an unsupported resource type [%s] for [%+v]", + resource.String(), resourceID) + } +} + func FromResourceModelToWorkflowAttributes(model models.Resource) (admin.WorkflowAttributes, error) { var attributes admin.MatchingAttributes err := proto.Unmarshal(model.Attributes, &attributes) @@ -53,6 +112,33 @@ func ProjectDomainAttributesToResourceModel(attributes admin.ProjectDomainAttrib }, nil } +func MergeUpdateProjectDomainAttributes(ctx context.Context, model models.Resource, resource admin.MatchableResource, + resourceID *repoInterfaces.ResourceID, attributes *admin.ProjectDomainAttributes) (models.Resource, error) { + switch resource { + case admin.MatchableResource_PLUGIN_OVERRIDE: + var existingAttributes admin.MatchingAttributes + err := proto.Unmarshal(model.Attributes, &existingAttributes) + if err != nil { + return models.Resource{}, errors.NewFlyteAdminErrorf(codes.Internal, + "Unable to unmarshal existing resource attributes for [%+v] with err: %v", resourceID, err) + } + updatedAttributes := mergeUpdatePluginOverrides(existingAttributes, attributes.GetMatchingAttributes()) + marshaledAttributes, err := proto.Marshal(updatedAttributes) + if err != nil { + return models.Resource{}, errors.NewFlyteAdminErrorf(codes.Internal, + "Failed to marshal merge-updated attributes for [%+v] with err: %v", resourceID, err) + } + model.Attributes = marshaledAttributes + return model, nil + default: + logger.Warningf(ctx, "Tried to merge-update an unsupported resource type [%s] for [%+v]", + resource.String(), resourceID) + return models.Resource{}, errors.NewFlyteAdminErrorf(codes.Internal, + "Tried to merge-update an unsupported resource type [%s] for [%+v]", + resource.String(), resourceID) + } +} + func FromResourceModelToProjectDomainAttributes(model models.Resource) (admin.ProjectDomainAttributes, error) { var attributes admin.MatchingAttributes err := proto.Unmarshal(model.Attributes, &attributes) diff --git a/flyteadmin/pkg/repositories/transformers/resource_test.go b/flyteadmin/pkg/repositories/transformers/resource_test.go index 7d0fbd30eb..7d281ce616 100644 --- a/flyteadmin/pkg/repositories/transformers/resource_test.go +++ b/flyteadmin/pkg/repositories/transformers/resource_test.go @@ -1,11 +1,15 @@ package transformers import ( + "context" "testing" + "github.com/lyft/flyteadmin/pkg/common/testutils" + "github.com/golang/protobuf/proto" "github.com/lyft/flyteadmin/pkg/errors" + repoInterfaces "github.com/lyft/flyteadmin/pkg/repositories/interfaces" "github.com/lyft/flyteadmin/pkg/repositories/models" "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" "google.golang.org/grpc/codes" @@ -13,6 +17,10 @@ import ( "github.com/stretchr/testify/assert" ) +const resourceProject = "project" +const resourceDomain = "domain" +const resourceWorkflow = "workflow" + var matchingClusterResourceAttributes = &admin.MatchingAttributes{ Target: &admin.MatchingAttributes_ClusterResourceAttributes{ ClusterResourceAttributes: &admin.ClusterResourceAttributes{ @@ -24,9 +32,9 @@ var matchingClusterResourceAttributes = &admin.MatchingAttributes{ } var workflowAttributes = admin.WorkflowAttributes{ - Project: "project", - Domain: "domain", - Workflow: "workflow", + Project: resourceProject, + Domain: resourceDomain, + Workflow: resourceWorkflow, MatchingAttributes: matchingClusterResourceAttributes, } @@ -43,8 +51,8 @@ var matchingExecutionQueueAttributes = &admin.MatchingAttributes{ } var projectDomainAttributes = admin.ProjectDomainAttributes{ - Project: "project", - Domain: "domain", + Project: resourceProject, + Domain: resourceDomain, MatchingAttributes: matchingExecutionQueueAttributes, } @@ -55,18 +63,77 @@ func TestToProjectDomainAttributesModel(t *testing.T) { model, err := ProjectDomainAttributesToResourceModel(projectDomainAttributes, admin.MatchableResource_EXECUTION_QUEUE) assert.Nil(t, err) assert.EqualValues(t, models.Resource{ - Project: "project", - Domain: "domain", + Project: resourceProject, + Domain: resourceDomain, ResourceType: admin.MatchableResource_EXECUTION_QUEUE.String(), Priority: models.ResourcePriorityProjectDomainLevel, Attributes: marshalledExecutionQueueAttributes, }, model) } +func TestMergeUpdateProjectDomainAttributes(t *testing.T) { + t.Run("plugin override", func(t *testing.T) { + existingWorkflowAttributes, _ := proto.Marshal(testutils.GetPluginOverridesAttributes(map[string][]string{ + "python": {"plugin_a"}, + "hive": {"plugin_b"}, + })) + + existingModel := models.Resource{ + ID: 1, + Project: resourceProject, + Domain: resourceDomain, + Workflow: resourceWorkflow, + ResourceType: "PLUGIN_OVERRIDE", + Attributes: existingWorkflowAttributes, + } + mergeUpdatedModel, err := MergeUpdateProjectDomainAttributes(context.Background(), existingModel, + admin.MatchableResource_PLUGIN_OVERRIDE, &repoInterfaces.ResourceID{}, &admin.ProjectDomainAttributes{ + Project: resourceProject, + Domain: resourceDomain, + MatchingAttributes: testutils.GetPluginOverridesAttributes(map[string][]string{ + "sidecar": {"plugin_c"}, + "hive": {"plugin_d"}, + }), + }) + assert.NoError(t, err) + var updatedAttributes admin.MatchingAttributes + err = proto.Unmarshal(mergeUpdatedModel.Attributes, &updatedAttributes) + assert.NoError(t, err) + var sawPythonTask, sawSidecarTask, sawHiveTask bool + for _, override := range updatedAttributes.GetPluginOverrides().GetOverrides() { + if override.TaskType == "python" { + sawPythonTask = true + assert.EqualValues(t, []string{"plugin_a"}, override.PluginId) + } else if override.TaskType == "sidecar" { + sawSidecarTask = true + assert.EqualValues(t, []string{"plugin_c"}, override.PluginId) + } else if override.TaskType == "hive" { + sawHiveTask = true + assert.EqualValues(t, []string{"plugin_d"}, override.PluginId) + } + } + assert.True(t, sawPythonTask, "Missing python task from finalized attributes") + assert.True(t, sawSidecarTask, "Missing sidecar task from finalized attributes") + assert.True(t, sawHiveTask, "Missing hive task from finalized attributes") + }) + t.Run("unsupported resource type", func(t *testing.T) { + existingModel := models.Resource{ + ID: 1, + Project: resourceProject, + Domain: resourceDomain, + Workflow: resourceWorkflow, + ResourceType: "PLUGIN_OVERRIDE", + } + _, err := MergeUpdateProjectDomainAttributes(context.Background(), existingModel, + admin.MatchableResource_TASK_RESOURCE, &repoInterfaces.ResourceID{}, &admin.ProjectDomainAttributes{}) + assert.Error(t, err, "unsupported resource type") + }) +} + func TestFromProjectDomainAttributesModel(t *testing.T) { model := models.Resource{ - Project: "project", - Domain: "domain", + Project: resourceProject, + Domain: resourceDomain, ResourceType: admin.MatchableResource_EXECUTION_QUEUE.String(), Attributes: marshalledExecutionQueueAttributes, } @@ -77,8 +144,8 @@ func TestFromProjectDomainAttributesModel(t *testing.T) { func TestFromProjectDomainAttributesModel_InvalidResourceAttributes(t *testing.T) { model := models.Resource{ - Project: "project", - Domain: "domain", + Project: resourceProject, + Domain: resourceDomain, ResourceType: admin.MatchableResource_EXECUTION_QUEUE.String(), Attributes: []byte("i'm invalid!"), } @@ -91,19 +158,79 @@ func TestToWorkflowAttributesModel(t *testing.T) { model, err := WorkflowAttributesToResourceModel(workflowAttributes, admin.MatchableResource_EXECUTION_QUEUE) assert.Nil(t, err) assert.EqualValues(t, models.Resource{ - Project: "project", - Domain: "domain", - Workflow: "workflow", + Project: resourceProject, + Domain: resourceDomain, + Workflow: resourceWorkflow, ResourceType: admin.MatchableResource_EXECUTION_QUEUE.String(), Priority: models.ResourcePriorityWorkflowLevel, Attributes: marshalledClusterResourceAttributes, }, model) } +func TestMergeUpdateWorkflowAttributes(t *testing.T) { + t.Run("plugin override", func(t *testing.T) { + existingWorkflowAttributes, _ := proto.Marshal(testutils.GetPluginOverridesAttributes(map[string][]string{ + "python": {"plugin_a"}, + "hive": {"plugin_b"}, + })) + + existingModel := models.Resource{ + ID: 1, + Project: resourceProject, + Domain: resourceDomain, + Workflow: resourceWorkflow, + ResourceType: "PLUGIN_OVERRIDE", + Attributes: existingWorkflowAttributes, + } + mergeUpdatedModel, err := MergeUpdateWorkflowAttributes(context.Background(), existingModel, + admin.MatchableResource_PLUGIN_OVERRIDE, &repoInterfaces.ResourceID{}, &admin.WorkflowAttributes{ + Project: resourceProject, + Domain: resourceDomain, + Workflow: resourceWorkflow, + MatchingAttributes: testutils.GetPluginOverridesAttributes(map[string][]string{ + "sidecar": {"plugin_c"}, + "hive": {"plugin_d"}, + }), + }) + assert.NoError(t, err) + var updatedAttributes admin.MatchingAttributes + err = proto.Unmarshal(mergeUpdatedModel.Attributes, &updatedAttributes) + assert.NoError(t, err) + var sawPythonTask, sawSidecarTask, sawHiveTask bool + for _, override := range updatedAttributes.GetPluginOverrides().GetOverrides() { + if override.TaskType == "python" { + sawPythonTask = true + assert.EqualValues(t, []string{"plugin_a"}, override.PluginId) + } else if override.TaskType == "sidecar" { + sawSidecarTask = true + assert.EqualValues(t, []string{"plugin_c"}, override.PluginId) + } else if override.TaskType == "hive" { + sawHiveTask = true + assert.EqualValues(t, []string{"plugin_d"}, override.PluginId) + } + } + assert.True(t, sawPythonTask, "Missing python task from finalized attributes") + assert.True(t, sawSidecarTask, "Missing sidecar task from finalized attributes") + assert.True(t, sawHiveTask, "Missing hive task from finalized attributes") + }) + t.Run("unsupported resource type", func(t *testing.T) { + existingModel := models.Resource{ + ID: 1, + Project: resourceProject, + Domain: resourceDomain, + Workflow: resourceWorkflow, + ResourceType: "TASK_RESOURCE", + } + _, err := MergeUpdateWorkflowAttributes(context.Background(), existingModel, + admin.MatchableResource_TASK_RESOURCE, &repoInterfaces.ResourceID{}, &admin.WorkflowAttributes{}) + assert.Error(t, err, "unsupported resource type") + }) +} + func TestFromWorkflowAttributesModel(t *testing.T) { model := models.Resource{ - Project: "project", - Domain: "domain", + Project: resourceProject, + Domain: resourceDomain, Workflow: "workflow", ResourceType: admin.MatchableResource_EXECUTION_QUEUE.String(), Attributes: marshalledClusterResourceAttributes, @@ -115,8 +242,8 @@ func TestFromWorkflowAttributesModel(t *testing.T) { func TestFromWorkflowAttributesModel_InvalidResourceAttributes(t *testing.T) { model := models.Resource{ - Project: "project", - Domain: "domain", + Project: resourceProject, + Domain: resourceDomain, Workflow: "workflow", ResourceType: admin.MatchableResource_EXECUTION_QUEUE.String(), Attributes: []byte("i'm invalid!"), diff --git a/flyteadmin/pkg/workflowengine/impl/propeller_executor.go b/flyteadmin/pkg/workflowengine/impl/propeller_executor.go index 1bc4ba06cd..f952293114 100644 --- a/flyteadmin/pkg/workflowengine/impl/propeller_executor.go +++ b/flyteadmin/pkg/workflowengine/impl/propeller_executor.go @@ -91,6 +91,23 @@ func (c *FlytePropeller) addPermissions(launchPlan admin.LaunchPlan, flyteWf *v1 } } +func addExecutionOverrides(taskPluginOverrides []*admin.PluginOverride, flyteWf *v1alpha1.FlyteWorkflow) { + executionConfig := v1alpha1.ExecutionConfig{ + TaskPluginImpls: make(map[string]v1alpha1.TaskPluginOverride), + } + if len(taskPluginOverrides) == 0 { + return + } + for _, override := range taskPluginOverrides { + executionConfig.TaskPluginImpls[override.TaskType] = v1alpha1.TaskPluginOverride{ + PluginIDs: override.PluginId, + MissingPluginBehavior: override.MissingPluginBehavior, + } + + } + flyteWf.ExecutionConfig = executionConfig +} + func (c *FlytePropeller) ExecuteWorkflow(ctx context.Context, input interfaces.ExecuteWorkflowInput) (*interfaces.ExecutionInfo, error) { if input.ExecutionID == nil { c.metrics.InvalidExecutionID.Inc() @@ -121,6 +138,7 @@ func (c *FlytePropeller) ExecuteWorkflow(ctx context.Context, input interfaces.E flyteWf.Labels = labels annotations := addMapValues(input.Annotations, flyteWf.Annotations) flyteWf.Annotations = annotations + addExecutionOverrides(input.TaskPluginOverrides, flyteWf) if input.Reference.Spec.RawOutputDataConfig != nil { flyteWf.RawOutputDataConfig = v1alpha1.RawOutputDataConfig{ diff --git a/flyteadmin/pkg/workflowengine/impl/propeller_executor_test.go b/flyteadmin/pkg/workflowengine/impl/propeller_executor_test.go index beea62cbc8..3a9217b3f1 100644 --- a/flyteadmin/pkg/workflowengine/impl/propeller_executor_test.go +++ b/flyteadmin/pkg/workflowengine/impl/propeller_executor_test.go @@ -147,6 +147,13 @@ func TestExecuteWorkflowHappyCase(t *testing.T) { "customannotation": "annotationval", } assert.EqualValues(t, expectedAnnotations, workflow.Annotations) + + assert.EqualValues(t, map[string]v1alpha1.TaskPluginOverride{ + "python": { + PluginIDs: []string{"plugin a"}, + MissingPluginBehavior: admin.PluginOverride_USE_DEFAULT, + }, + }, workflow.ExecutionConfig.TaskPluginImpls) return nil, nil }, } @@ -184,6 +191,13 @@ func TestExecuteWorkflowHappyCase(t *testing.T) { Annotations: map[string]string{ "customannotation": "annotationval", }, + TaskPluginOverrides: []*admin.PluginOverride{ + { + TaskType: "python", + PluginId: []string{"plugin a"}, + MissingPluginBehavior: admin.PluginOverride_USE_DEFAULT, + }, + }, }) assert.Nil(t, err) assert.NotNil(t, execInfo) diff --git a/flyteadmin/pkg/workflowengine/interfaces/executor.go b/flyteadmin/pkg/workflowengine/interfaces/executor.go index 471ec98b53..bd468ba28b 100644 --- a/flyteadmin/pkg/workflowengine/interfaces/executor.go +++ b/flyteadmin/pkg/workflowengine/interfaces/executor.go @@ -10,14 +10,15 @@ import ( ) type ExecuteWorkflowInput struct { - ExecutionID *core.WorkflowExecutionIdentifier - WfClosure core.CompiledWorkflowClosure - Inputs *core.LiteralMap - Reference admin.LaunchPlan - AcceptedAt time.Time - Labels map[string]string - Annotations map[string]string - QueueingBudget time.Duration + ExecutionID *core.WorkflowExecutionIdentifier + WfClosure core.CompiledWorkflowClosure + Inputs *core.LiteralMap + Reference admin.LaunchPlan + AcceptedAt time.Time + Labels map[string]string + Annotations map[string]string + QueueingBudget time.Duration + TaskPluginOverrides []*admin.PluginOverride } type ExecuteTaskInput struct {