diff --git a/flyteadmin/Gopkg.lock b/flyteadmin/Gopkg.lock index bc5bc3d56..15d388e9b 100644 --- a/flyteadmin/Gopkg.lock +++ b/flyteadmin/Gopkg.lock @@ -54,7 +54,7 @@ version = "v1.0.7" [[projects]] - digest = "1:f5eadf80885cb45a08e0baf4c86079016a27bf9df179fba726a806d8a4dcd1c8" + digest = "1:39cc2836b0733b809ab56eb991d3d58c96c08b6381fded8dc636c53656a0f5cc" name = "github.com/aws/aws-sdk-go" packages = [ "aws", @@ -110,8 +110,8 @@ "service/sts/stsiface", ] pruneopts = "UT" - revision = "a27888d3507a35a68e9240d5d7cf0d9552e6ef3b" - version = "v1.28.2" + revision = "5abff723756e51686a6a4377e7c5b1d74a9e18a3" + version = "v1.28.5" [[projects]] digest = "1:390dd851aedb37e9ef896e7e4e5c8553576798f8be769657e2f507090c96449d" @@ -456,7 +456,7 @@ version = "v1.3.0" [[projects]] - digest = "1:9f01843d5ce98395a8af3da88dc6c67d169e0e94d253337026c423e2b18cdf4f" + digest = "1:93e28b61815df4d76c520ffe48f5995503c581f3752a823be85141cc278b6745" name = "github.com/lyft/flyteidl" packages = [ "clients/go/admin", @@ -469,9 +469,9 @@ "gen/pb-go/flyteidl/service", ] pruneopts = "UT" - revision = "28c0dfb6608b70262aac9cb1ff83a750521ded8e" + revision = "ed373f903580e70fa1a08820b61d124163cb52da" source = "https://github.com/lyft/flyteidl" - version = "v0.16.5" + version = "v0.16.6" [[projects]] digest = "1:aef670729de939f423811651dcc81328bb82f15ab3a6026db7c32e18b4dc9d9a" @@ -726,12 +726,12 @@ version = "v1.0.5" [[projects]] - digest = "1:8d0b79a29be9946ea00b2f2b778ec5b36db07453d97abe9d754b25bc2cc49a0e" + digest = "1:83d0e0f3f46dc86daf27e4d7c834c8c492db787c53c09d94183de7478db92142" name = "github.com/spf13/viper" packages = ["."] pruneopts = "UT" - revision = "eabbc68a3ecd5cf8c11a2f84dbda5e7a38493b2f" - version = "v1.6.1" + revision = "4525543ce4fe90f7970f5e2cdc300b8ffc8c0582" + version = "v1.6.2" [[projects]] digest = "1:ac83cf90d08b63ad5f7e020ef480d319ae890c208f8524622a2f3136e2686b02" @@ -787,7 +787,7 @@ [[projects]] branch = "master" - digest = "1:08644aeff6284192f9b529821257b83245e3a3a48e216074929629f62c3c5799" + digest = "1:096ab5deba256597722219f6bebd2efeec569368ba281b5df172937478049a6b" name = "golang.org/x/crypto" packages = [ "ed25519", @@ -796,7 +796,7 @@ "ssh/terminal", ] pruneopts = "UT" - revision = "61a87790db17894570dfb32dbaa0a4af9ce60cb4" + revision = "530e935923ad688be97c15eeb8e5ee42ebf2b54a" [[projects]] branch = "master" @@ -832,14 +832,14 @@ [[projects]] branch = "master" - digest = "1:c8faf148023f13e62a93d3ed6c01fc586b3c6b52a40c5c9b0f2220bae71f390a" + digest = "1:f7b9c8ac6e14581494d8f20a044878b1d23aac99e49bc4f1ac831a973fb9ccfd" name = "golang.org/x/sys" packages = [ "unix", "windows", ] pruneopts = "UT" - revision = "86b910548bc16777f40503131aa424ae0a092199" + revision = "59e60aa80a0c64fa4b088976ee16ad7f04252c25" [[projects]] digest = "1:8d8faad6b12a3a4c819a3f9618cb6ee1fa1cfc33253abeeea8b55336721e3405" @@ -887,7 +887,7 @@ [[projects]] branch = "master" - digest = "1:92be4e1210ff0dfe76525dbfeadeab782f41622cf8510b6a132ccffe04203ab3" + digest = "1:8da46a0888e0645a407823ba529400ab5150831745941b93b74bb69c5f756da2" name = "google.golang.org/api" packages = [ "googleapi", @@ -901,7 +901,7 @@ "transport/http/internal/propagation", ] pruneopts = "UT" - revision = "abe50cf84d67c0039411f2fb036cfa09991a6e32" + revision = "b4cd77d6a56cb166fb09be9d2264a9f42264fffd" [[projects]] digest = "1:3c03b58f57452764a4499c55c582346c0ee78c8a5033affe5bdfd9efd3da5bd1" @@ -933,13 +933,14 @@ "protobuf/field_mask", ] pruneopts = "UT" - revision = "e1de0a7b01eb2fc11d735e4bfb79d2e53ec9edb3" + revision = "32f20d992d240fbca6ef7dec6c05d1f024314e02" [[projects]] - digest = "1:24c7bbb52f35264c506e5d404a257408937953cab7f39b898bd7edde0103b52a" + digest = "1:3abb0cc9fcc8c2a4bbe49b6d2c5894d7ba782b17a3463fdbfc5820b764633643" name = "google.golang.org/grpc" packages = [ ".", + "attributes", "backoff", "balancer", "balancer/base", @@ -978,8 +979,8 @@ "tap", ] pruneopts = "UT" - revision = "1a3960e4bd028ac0cec0a2afd27d7d8e67c11514" - version = "v1.25.1" + revision = "f5b0812e6fe574d90da76b205e9eb51f6ddb1919" + version = "v1.26.0" [[projects]] digest = "1:1048ae210f190cd7b6aea19a92a055bd6112b025dd49f560579dfdfd76c8c42e" diff --git a/flyteadmin/Gopkg.toml b/flyteadmin/Gopkg.toml index 922396a05..9152bdcb1 100644 --- a/flyteadmin/Gopkg.toml +++ b/flyteadmin/Gopkg.toml @@ -60,7 +60,7 @@ [[override]] name = "github.com/lyft/flyteidl" source = "https://github.com/lyft/flyteidl" - version = "=0.16.5" + version = "^0.16.6" [[constraint]] name = "github.com/lyft/flytepropeller" @@ -92,11 +92,6 @@ name = "gopkg.in/gormigrate.v1" version = "1.2.1" -[[override]] - name = "k8s.io/apimachinery" - source = "https://github.com/lyft/apimachinery" - revision = "047e3ea32d7fb5984f444d7dd9510cfd362d7d7c" - [[override]] name = "k8s.io/api" revision = "b49a72c274e072a6e385d55c671acb3717186ce5" @@ -122,6 +117,11 @@ go-tests = true unused-packages = true +[[override]] + name = "k8s.io/apimachinery" + source = "https://github.com/lyft/apimachinery" + revision = "047e3ea32d7fb5984f444d7dd9510cfd362d7d7c" + [[constraint]] name = "github.com/graymeta/stow" revision = "903027f87de7054953efcdb8ba70d5dc02df38c7" diff --git a/flyteadmin/pkg/clusterresource/controller.go b/flyteadmin/pkg/clusterresource/controller.go index ca5d79902..2a3b20674 100644 --- a/flyteadmin/pkg/clusterresource/controller.go +++ b/flyteadmin/pkg/clusterresource/controller.go @@ -11,7 +11,8 @@ import ( "strings" "time" - "github.com/lyft/flyteadmin/pkg/resourcematching" + "github.com/lyft/flyteadmin/pkg/manager/impl/resources" + managerinterfaces "github.com/lyft/flyteadmin/pkg/manager/interfaces" "github.com/lyft/flyteadmin/pkg/executioncluster/interfaces" @@ -70,6 +71,7 @@ type controller struct { db repositories.RepositoryInterface config runtimeInterfaces.Configuration executionCluster interfaces.ClusterInterface + resourceManager managerinterfaces.ResourceInterface poller chan struct{} metrics controllerMetrics lastAppliedTemplateDir string @@ -175,17 +177,16 @@ func (c *controller) getCustomTemplateValues( } collectedErrs := make([]error, 0) // All override values saved in the database take precedence over the domain-specific defaults. - attributes, err := resourcematching.GetOverrideValuesToApply(ctx, resourcematching.GetOverrideValuesInput{ - Db: c.db, - Project: project, - Domain: domain, - Resource: admin.MatchableResource_CLUSTER_RESOURCE, + resource, err := c.resourceManager.GetResource(ctx, managerinterfaces.ResourceRequest{ + Project: project, + Domain: domain, + ResourceType: admin.MatchableResource_CLUSTER_RESOURCE, }) if err != nil { collectedErrs = append(collectedErrs, err) } - if attributes != nil && attributes.GetClusterResourceAttributes() != nil { - for templateKey, templateValue := range attributes.GetClusterResourceAttributes().Attributes { + if resource != nil && resource.Attributes != nil && resource.Attributes.GetClusterResourceAttributes() != nil { + for templateKey, templateValue := range resource.Attributes.GetClusterResourceAttributes().Attributes { customTemplateValues[fmt.Sprintf(templateVariableFormat, templateKey)] = templateValue } } @@ -285,7 +286,7 @@ func (c *controller) syncNamespace(ctx context.Context, namespace NamespaceName, err = target.Client.Create(ctx, k8sObjCopy) if err != nil { if k8serrors.IsAlreadyExists(err) { - logger.Debugf(ctx, "Resource [%+v] in namespace [%s] already exists - attempting update instead", + logger.Debugf(ctx, "Type [%+v] in namespace [%s] already exists - attempting update instead", k8sObj.GetObjectKind().GroupVersionKind().Kind, namespace) c.metrics.AppliedTemplateExists.Inc() // Use a strategic-merge-patch to mimic `kubectl apply` behavior. @@ -423,6 +424,7 @@ func NewClusterResourceController(db repositories.RepositoryInterface, execution db: db, config: config, executionCluster: executionCluster, + resourceManager: resources.NewResourceManager(db), poller: make(chan struct{}), metrics: newMetrics(scope), appliedTemplates: make(map[string]map[string]time.Time), diff --git a/flyteadmin/pkg/clusterresource/controller_test.go b/flyteadmin/pkg/clusterresource/controller_test.go index c91fc5d4c..a69d83a90 100644 --- a/flyteadmin/pkg/clusterresource/controller_test.go +++ b/flyteadmin/pkg/clusterresource/controller_test.go @@ -7,6 +7,9 @@ import ( "testing" "time" + "github.com/lyft/flyteadmin/pkg/manager/impl/resources" + "github.com/lyft/flyteadmin/pkg/repositories/interfaces" + "github.com/lyft/flyteadmin/pkg/repositories/transformers" "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" @@ -161,16 +164,16 @@ func TestGetCustomTemplateValues(t *testing.T) { }, }, } - projectDomainModel, err := transformers.ToProjectDomainAttributesModel(projectDomainAttributes, admin.MatchableResource_CLUSTER_RESOURCE) + resourceModel, err := transformers.ProjectDomainAttributesToResourceModel(projectDomainAttributes, admin.MatchableResource_CLUSTER_RESOURCE) assert.Nil(t, err) - mockRepository.ProjectDomainAttributesRepo().(*repositoryMocks.MockProjectDomainAttributesRepo).GetFunction = func( - ctx context.Context, project, domain, resource string) (models.ProjectDomainAttributes, error) { - assert.Equal(t, "project-foo", project) - assert.Equal(t, "domain-bar", domain) - return projectDomainModel, nil + mockRepository.ResourceRepo().(*repositoryMocks.MockResourceRepo).GetFunction = func(ctx context.Context, ID interfaces.ResourceID) (resource models.Resource, e error) { + assert.Equal(t, "project-foo", ID.Project) + assert.Equal(t, "domain-bar", ID.Domain) + return resourceModel, nil } testController := controller{ - db: mockRepository, + db: mockRepository, + resourceManager: resources.NewResourceManager(mockRepository), } domainTemplateValues := templateValuesType{ "{{ var1 }}": "i'm getting overwritten", @@ -191,7 +194,8 @@ func TestGetCustomTemplateValues(t *testing.T) { func TestGetCustomTemplateValues_NothingToOverride(t *testing.T) { mockRepository := repositoryMocks.NewMockRepository() testController := controller{ - db: mockRepository, + db: mockRepository, + resourceManager: resources.NewResourceManager(mockRepository), } customTemplateValues, err := testController.getCustomTemplateValues(context.Background(), "project-foo", "domain-bar", templateValuesType{ "{{ var1 }}": "val1", @@ -207,14 +211,14 @@ func TestGetCustomTemplateValues_NothingToOverride(t *testing.T) { func TestGetCustomTemplateValues_InvalidDBModel(t *testing.T) { mockRepository := repositoryMocks.NewMockRepository() - mockRepository.ProjectDomainAttributesRepo().(*repositoryMocks.MockProjectDomainAttributesRepo).GetFunction = func( - ctx context.Context, project, domain, resource string) (models.ProjectDomainAttributes, error) { - return models.ProjectDomainAttributes{ + mockRepository.ResourceRepo().(*repositoryMocks.MockResourceRepo).GetFunction = func(ctx context.Context, ID interfaces.ResourceID) (resource models.Resource, e error) { + return models.Resource{ Attributes: []byte("i'm invalid"), }, nil } testController := controller{ - db: mockRepository, + db: mockRepository, + resourceManager: resources.NewResourceManager(mockRepository), } _, err := testController.getCustomTemplateValues(context.Background(), "project-foo", "domain-bar", templateValuesType{ "{{ var1 }}": "val1", diff --git a/flyteadmin/pkg/manager/impl/execution_manager.go b/flyteadmin/pkg/manager/impl/execution_manager.go index 5d8ea4fca..16477798c 100644 --- a/flyteadmin/pkg/manager/impl/execution_manager.go +++ b/flyteadmin/pkg/manager/impl/execution_manager.go @@ -6,6 +6,8 @@ import ( "strconv" "time" + "github.com/lyft/flyteadmin/pkg/manager/impl/resources" + "github.com/golang/protobuf/ptypes" "github.com/golang/protobuf/ptypes/timestamp" dataInterfaces "github.com/lyft/flyteadmin/pkg/data/interfaces" @@ -185,6 +187,130 @@ func (m *ExecutionManager) offloadInputs(ctx context.Context, literalMap *core.L return inputsURI, nil } +func createTaskDefaultLimits(ctx context.Context, task *core.CompiledTask) runtimeInterfaces.TaskResourceSet { + // The values below should never be used (deduce it from the request; request should be set by the time we get here). + // Setting them here just in case we end up with requests not set. We are not adding to config because it would add + // more confusion as its mostly not used. + cpuLimit := "500m" + memoryLimit := "500Mi" + resourceEntries := task.Template.GetContainer().Resources.Requests + var cpuIndex, memoryIndex = -1, -1 + for idx, entry := range resourceEntries { + switch entry.Name { + case core.Resources_CPU: + cpuIndex = idx + + case core.Resources_MEMORY: + memoryIndex = idx + } + } + + if cpuIndex < 0 || memoryIndex < 0 { + logger.Errorf(ctx, "Cpu request and Memory request missing for %s", task.Template.Id) + } + + if cpuIndex >= 0 { + cpuLimit = resourceEntries[cpuIndex].Value + } + if memoryIndex >= 0 { + memoryLimit = resourceEntries[memoryIndex].Value + } + + return runtimeInterfaces.TaskResourceSet{CPU: cpuLimit, Memory: memoryLimit} +} + +func assignResourcesIfUnset(ctx context.Context, identifier *core.Identifier, + platformValues runtimeInterfaces.TaskResourceSet, + resourceEntries []*core.Resources_ResourceEntry, taskResourceSpec *admin.TaskResourceSpec) []*core.Resources_ResourceEntry { + var cpuIndex, memoryIndex = -1, -1 + for idx, entry := range resourceEntries { + switch entry.Name { + case core.Resources_CPU: + cpuIndex = idx + case core.Resources_MEMORY: + memoryIndex = idx + } + } + if cpuIndex > 0 && memoryIndex > 0 { + // nothing to do + return resourceEntries + } + + if cpuIndex < 0 && platformValues.CPU != "" { + logger.Debugf(ctx, "Setting 'cpu' for [%+v] to %s", identifier, platformValues.CPU) + cpuValue := platformValues.CPU + if taskResourceSpec != nil && len(taskResourceSpec.Cpu) > 0 { + // Use the custom attributes from the database rather than the platform defaults from the application config + cpuValue = taskResourceSpec.Cpu + } + cpuResource := &core.Resources_ResourceEntry{ + Name: core.Resources_CPU, + Value: cpuValue, + } + resourceEntries = append(resourceEntries, cpuResource) + } + if memoryIndex < 0 && platformValues.Memory != "" { + memoryValue := platformValues.Memory + if taskResourceSpec != nil && len(taskResourceSpec.Memory) > 0 { + // Use the custom attributes from the database rather than the platform defaults from the application config + memoryValue = taskResourceSpec.Memory + } + memoryResource := &core.Resources_ResourceEntry{ + Name: core.Resources_MEMORY, + Value: memoryValue, + } + logger.Debugf(ctx, "Setting 'memory' for [%+v] to %s", identifier, platformValues.Memory) + resourceEntries = append(resourceEntries, memoryResource) + } + return resourceEntries +} + +// Assumes input contains a compiled task with a valid container resource execConfig. +// +// Note: The system will assign a system-default value for request but for limit it will deduce it from the request +// itself => Limit := Min([Some-Multiplier X Request], System-Max). For now we are using a multiplier of 1. In +// general we recommend the users to set limits close to requests for more predictability in the system. +func setCompiledTaskDefaults(ctx context.Context, taskConfig runtimeInterfaces.TaskResourceConfiguration, task *core.CompiledTask, + db repositories.RepositoryInterface, workflowName string) { + resourceManager := resources.NewResourceManager(db) + if task == nil { + logger.Warningf(ctx, "Can't set default resources for nil task.") + return + } + if task.Template == nil || task.Template.GetContainer() == nil || task.Template.GetContainer().Resources == nil { + // Nothing to do + logger.Debugf(ctx, "Not setting default resources for task [%+v], no container resources found to check", task) + return + } + resource, err := resourceManager.GetResource(ctx, interfaces.ResourceRequest{ + Project: task.Template.Id.Project, + Domain: task.Template.Id.Domain, + Workflow: workflowName, + ResourceType: admin.MatchableResource_TASK_RESOURCE, + }) + + if err != nil { + logger.Warningf(ctx, "Failed to fetch override values when assigning task resource default values for [%+v]: %v", + task.Template, err) + } + logger.Debugf(ctx, "Assigning task requested resources for [%+v]", task.Template.Id) + var taskResourceSpec *admin.TaskResourceSpec + if resource != nil && resource.Attributes != nil && resource.Attributes.GetTaskResourceAttributes() != nil { + taskResourceSpec = resource.Attributes.GetTaskResourceAttributes().Defaults + } + task.Template.GetContainer().Resources.Requests = assignResourcesIfUnset( + ctx, task.Template.Id, taskConfig.GetDefaults(), task.Template.GetContainer().Resources.Requests, + taskResourceSpec) + + logger.Debugf(ctx, "Assigning task resource limits for [%+v]", task.Template.Id) + if resource != nil && resource.Attributes != nil && resource.Attributes.GetTaskResourceAttributes() != nil { + taskResourceSpec = resource.Attributes.GetTaskResourceAttributes().Limits + } + task.Template.GetContainer().Resources.Limits = assignResourcesIfUnset( + ctx, task.Template.Id, createTaskDefaultLimits(ctx, task), task.Template.GetContainer().Resources.Limits, + taskResourceSpec) +} + func (m *ExecutionManager) launchExecutionAndPrepareModel( ctx context.Context, request admin.ExecutionCreateRequest, requestedAt time.Time) ( context.Context, *models.Execution, error) { @@ -243,7 +369,7 @@ func (m *ExecutionManager) launchExecutionAndPrepareModel( // Dynamically assign task resource defaults. for _, task := range workflow.Closure.CompiledWorkflow.Tasks { - validation.SetDefaults(ctx, m.config.TaskResourceConfiguration(), task, m.db, name) + setCompiledTaskDefaults(ctx, m.config.TaskResourceConfiguration(), task, m.db, name) } // Dynamically assign execution queues. diff --git a/flyteadmin/pkg/manager/impl/execution_manager_test.go b/flyteadmin/pkg/manager/impl/execution_manager_test.go index 314c49cca..af5274a47 100644 --- a/flyteadmin/pkg/manager/impl/execution_manager_test.go +++ b/flyteadmin/pkg/manager/impl/execution_manager_test.go @@ -2355,3 +2355,186 @@ func TestListExecutions_LegacyModel(t *testing.T) { } assert.Empty(t, executionList.Token) } + +func TestAssignResourcesIfUnset(t *testing.T) { + platformValues := runtimeInterfaces.TaskResourceSet{ + CPU: "200m", + GPU: "8", + Memory: "200Gi", + } + taskResourceSpec := &admin.TaskResourceSpec{ + Cpu: "400m", + Memory: "400Gi", + } + assignedResources := assignResourcesIfUnset(context.Background(), &core.Identifier{ + Project: "project", + Domain: "domain", + Name: "name", + Version: "version", + }, platformValues, []*core.Resources_ResourceEntry{}, taskResourceSpec) + + assert.EqualValues(t, []*core.Resources_ResourceEntry{ + { + Name: core.Resources_CPU, + Value: taskResourceSpec.Cpu, + }, + { + Name: core.Resources_MEMORY, + Value: taskResourceSpec.Memory, + }, + }, assignedResources) +} + +func TestSetDefaults(t *testing.T) { + task := &core.CompiledTask{ + Template: &core.TaskTemplate{ + Target: &core.TaskTemplate_Container{ + Container: &core.Container{ + Resources: &core.Resources{ + Requests: []*core.Resources_ResourceEntry{ + { + Name: core.Resources_CPU, + Value: "200m", + }, + }, + }, + }, + }, + Id: &core.Identifier{ + Project: "project", + Domain: "domain", + Name: "task_name", + Version: "version", + }, + }, + } + + taskConfig := runtimeMocks.MockTaskResourceConfiguration{} + taskConfig.Defaults = runtimeInterfaces.TaskResourceSet{ + CPU: "200m", + GPU: "8", + Memory: "200Gi", + } + taskConfig.Limits = runtimeInterfaces.TaskResourceSet{ + CPU: "300m", + GPU: "8", + Memory: "500Gi", + } + setCompiledTaskDefaults(context.Background(), &taskConfig, task, repositoryMocks.NewMockRepository(), "workflow") + assert.True(t, proto.Equal( + &core.Container{ + Resources: &core.Resources{ + Requests: []*core.Resources_ResourceEntry{ + { + Name: core.Resources_CPU, + Value: "200m", + }, + { + Name: core.Resources_MEMORY, + Value: "200Gi", + }, + }, + Limits: []*core.Resources_ResourceEntry{ + { + Name: core.Resources_CPU, + Value: "200m", + }, + { + Name: core.Resources_MEMORY, + Value: "200Gi", + }, + }, + }, + }, + task.Template.GetContainer()), fmt.Sprintf("%+v", task.Template.GetContainer())) +} + +func TestSetDefaults_MissingDefaults(t *testing.T) { + task := &core.CompiledTask{ + Template: &core.TaskTemplate{ + Target: &core.TaskTemplate_Container{ + Container: &core.Container{ + Resources: &core.Resources{ + Requests: []*core.Resources_ResourceEntry{ + { + Name: core.Resources_CPU, + Value: "200m", + }, + }, + }, + }, + }, + Id: &core.Identifier{ + Project: "project", + Domain: "domain", + Name: "task_name", + Version: "version", + }, + }, + } + + taskConfig := runtimeMocks.MockTaskResourceConfiguration{} + taskConfig.Defaults = runtimeInterfaces.TaskResourceSet{ + CPU: "200m", + GPU: "8", + Memory: "200Gi", + } + taskConfig.Limits = runtimeInterfaces.TaskResourceSet{ + CPU: "300m", + GPU: "8", + } + setCompiledTaskDefaults(context.Background(), &taskConfig, task, repositoryMocks.NewMockRepository(), "workflow") + assert.True(t, proto.Equal( + &core.Container{ + Resources: &core.Resources{ + Requests: []*core.Resources_ResourceEntry{ + { + Name: core.Resources_CPU, + Value: "200m", + }, + { + Name: core.Resources_MEMORY, + Value: "200Gi", + }, + }, + Limits: []*core.Resources_ResourceEntry{ + { + Name: core.Resources_CPU, + Value: "200m", + }, + { + Name: core.Resources_MEMORY, + Value: "200Gi", + }, + }, + }, + }, + task.Template.GetContainer()), fmt.Sprintf("%+v", task.Template.GetContainer())) +} + +func TestCreateTaskDefaultLimits(t *testing.T) { + task := &core.CompiledTask{ + Template: &core.TaskTemplate{ + Target: &core.TaskTemplate_Container{ + Container: &core.Container{ + Resources: &core.Resources{ + Requests: []*core.Resources_ResourceEntry{ + { + Name: core.Resources_CPU, + Value: "200m", + }, + { + Name: core.Resources_MEMORY, + Value: "200Mi", + }, + }, + }, + }, + }, + }, + } + + defaultLimits := createTaskDefaultLimits(context.Background(), task) + assert.Equal(t, "200Mi", defaultLimits.Memory) + assert.Equal(t, "200m", defaultLimits.CPU) +} diff --git a/flyteadmin/pkg/manager/impl/executions/queues.go b/flyteadmin/pkg/manager/impl/executions/queues.go index aaa6c119b..de68f02db 100644 --- a/flyteadmin/pkg/manager/impl/executions/queues.go +++ b/flyteadmin/pkg/manager/impl/executions/queues.go @@ -4,8 +4,10 @@ import ( "context" "math/rand" + "github.com/lyft/flyteadmin/pkg/manager/impl/resources" + "github.com/lyft/flyteadmin/pkg/manager/interfaces" + "github.com/lyft/flyteadmin/pkg/repositories" - "github.com/lyft/flyteadmin/pkg/resourcematching" "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" "github.com/lyft/flytestdlib/logger" @@ -31,9 +33,10 @@ type QueueAllocator interface { } type queueAllocatorImpl struct { - queueConfigMap queueConfig - config runtimeInterfaces.Configuration - db repositories.RepositoryInterface + queueConfigMap queueConfig + config runtimeInterfaces.Configuration + db repositories.RepositoryInterface + resourceManager interfaces.ResourceInterface } func (q *queueAllocatorImpl) refreshExecutionQueues(executionQueues []runtimeInterfaces.ExecutionQueue) { @@ -60,20 +63,20 @@ func (q *queueAllocatorImpl) GetQueue(ctx context.Context, identifier core.Ident executionQueues := q.config.QueueConfiguration().GetExecutionQueues() q.refreshExecutionQueues(executionQueues) - attributes, err := resourcematching.GetOverrideValuesToApply(ctx, resourcematching.GetOverrideValuesInput{ - Db: q.db, - Project: identifier.Project, - Domain: identifier.Domain, - Workflow: identifier.Name, - Resource: admin.MatchableResource_EXECUTION_QUEUE, + resource, err := q.resourceManager.GetResource(ctx, interfaces.ResourceRequest{ + Project: identifier.Project, + Domain: identifier.Domain, + Workflow: identifier.Name, + ResourceType: admin.MatchableResource_EXECUTION_QUEUE, }) + if err != nil { logger.Warningf(ctx, "Failed to fetch override values when assigning execution queue for [%+v] with err: %v", identifier, err) } - if attributes != nil && attributes.GetExecutionQueueAttributes() != nil { - for _, tag := range attributes.GetExecutionQueueAttributes().Tags { + if resource != nil && resource.Attributes != nil && resource.Attributes.GetExecutionQueueAttributes() != nil { + for _, tag := range resource.Attributes.GetExecutionQueueAttributes().Tags { matches, ok := q.queueConfigMap[tag] if !ok { continue @@ -108,8 +111,9 @@ func (q *queueAllocatorImpl) GetQueue(ctx context.Context, identifier core.Ident func NewQueueAllocator(config runtimeInterfaces.Configuration, db repositories.RepositoryInterface) QueueAllocator { queueAllocator := queueAllocatorImpl{ - config: config, - db: db, + config: config, + db: db, + resourceManager: resources.NewResourceManager(db), } return &queueAllocator } diff --git a/flyteadmin/pkg/manager/impl/executions/queues_test.go b/flyteadmin/pkg/manager/impl/executions/queues_test.go index f107b5d32..0f9128a27 100644 --- a/flyteadmin/pkg/manager/impl/executions/queues_test.go +++ b/flyteadmin/pkg/manager/impl/executions/queues_test.go @@ -4,6 +4,8 @@ import ( "context" "testing" + "github.com/lyft/flyteadmin/pkg/repositories/interfaces" + "github.com/golang/protobuf/proto" "github.com/lyft/flyteadmin/pkg/errors" "github.com/lyft/flyteadmin/pkg/repositories/mocks" @@ -31,16 +33,14 @@ func TestGetQueue(t *testing.T) { }, } db := mocks.NewMockRepository() - db.WorkflowAttributesRepo().(*mocks.MockWorkflowAttributesRepo).GetFunction = func( - ctx context.Context, project, domain, workflow, resource string) ( - models.WorkflowAttributes, error) { - response := models.WorkflowAttributes{ - Project: project, - Domain: domain, - Workflow: workflow, - Resource: resource, + db.ResourceRepo().(*mocks.MockResourceRepo).GetFunction = func(ctx context.Context, ID interfaces.ResourceID) (resource models.Resource, e error) { + response := models.Resource{ + Project: ID.Project, + Domain: ID.Domain, + Workflow: ID.Workflow, + ResourceType: ID.ResourceType, } - if project == testProject && domain == testDomain && workflow == testWorkflow { + if ID.Project == testProject && ID.Domain == testDomain && ID.Workflow == testWorkflow { matchingAttributes := &admin.MatchingAttributes{ Target: &admin.MatchingAttributes_ExecutionQueueAttributes{ ExecutionQueueAttributes: &admin.ExecutionQueueAttributes{ @@ -122,11 +122,9 @@ func TestGetQueueDefaults(t *testing.T) { }, } db := mocks.NewMockRepository() - db.WorkflowAttributesRepo().(*mocks.MockWorkflowAttributesRepo).GetFunction = func( - ctx context.Context, project, domain, workflow, resource string) ( - models.WorkflowAttributes, error) { - if project == testProject && domain == testDomain && workflow == "workflow" && - resource == admin.MatchableResource_EXECUTION_QUEUE.String() { + db.ResourceRepo().(*mocks.MockResourceRepo).GetFunction = func(ctx context.Context, ID interfaces.ResourceID) (resource models.Resource, e error) { + if ID.Project == testProject && ID.Domain == testDomain && ID.Workflow == "workflow" && + ID.ResourceType == admin.MatchableResource_EXECUTION_QUEUE.String() { matchingAttributes := &admin.MatchingAttributes{ Target: &admin.MatchingAttributes_ExecutionQueueAttributes{ ExecutionQueueAttributes: &admin.ExecutionQueueAttributes{ @@ -135,19 +133,15 @@ func TestGetQueueDefaults(t *testing.T) { }, } marshalledMatchingAttributes, _ := proto.Marshal(matchingAttributes) - return models.WorkflowAttributes{ - Project: project, - Domain: domain, - Workflow: workflow, - Resource: resource, - Attributes: marshalledMatchingAttributes, + return models.Resource{ + Project: ID.Project, + Domain: ID.Domain, + Workflow: ID.Workflow, + ResourceType: ID.ResourceType, + Attributes: marshalledMatchingAttributes, }, nil } - return models.WorkflowAttributes{}, errors.NewFlyteAdminError(codes.NotFound, "foo") - } - db.ProjectDomainAttributesRepo().(*mocks.MockProjectDomainAttributesRepo).GetFunction = func( - ctx context.Context, project, domain, resource string) (models.ProjectDomainAttributes, error) { - if project == testProject && domain == testDomain && resource == admin.MatchableResource_EXECUTION_QUEUE.String() { + if ID.Project == testProject && ID.Domain == testDomain && ID.ResourceType == admin.MatchableResource_EXECUTION_QUEUE.String() { matchingAttributes := &admin.MatchingAttributes{ Target: &admin.MatchingAttributes_ExecutionQueueAttributes{ ExecutionQueueAttributes: &admin.ExecutionQueueAttributes{ @@ -156,18 +150,15 @@ func TestGetQueueDefaults(t *testing.T) { }, } marshalledMatchingAttributes, _ := proto.Marshal(matchingAttributes) - return models.ProjectDomainAttributes{ - Project: project, - Domain: domain, - Resource: resource, - Attributes: marshalledMatchingAttributes, + return models.Resource{ + Project: ID.Project, + Domain: ID.Domain, + ResourceType: ID.ResourceType, + Attributes: marshalledMatchingAttributes, }, nil } - return models.ProjectDomainAttributes{}, errors.NewFlyteAdminError(codes.NotFound, "foo") - } - db.ProjectAttributesRepo().(*mocks.MockProjectAttributesRepo).GetFunction = func( - ctx context.Context, project, resource string) (models.ProjectAttributes, error) { - if project == testProject && resource == admin.MatchableResource_EXECUTION_QUEUE.String() { + + if ID.Project == testProject && ID.ResourceType == admin.MatchableResource_EXECUTION_QUEUE.String() { matchingAttributes := &admin.MatchingAttributes{ Target: &admin.MatchingAttributes_ExecutionQueueAttributes{ ExecutionQueueAttributes: &admin.ExecutionQueueAttributes{ @@ -176,13 +167,13 @@ func TestGetQueueDefaults(t *testing.T) { }, } marshalledMatchingAttributes, _ := proto.Marshal(matchingAttributes) - return models.ProjectAttributes{ - Project: project, - Resource: resource, - Attributes: marshalledMatchingAttributes, + return models.Resource{ + Project: ID.Project, + ResourceType: ID.ResourceType, + Attributes: marshalledMatchingAttributes, }, nil } - return models.ProjectAttributes{}, errors.NewFlyteAdminError(codes.NotFound, "foo") + return models.Resource{}, errors.NewFlyteAdminError(codes.NotFound, "foo") } queueAllocator := NewQueueAllocator(runtimeMocks.NewMockConfigurationProvider( diff --git a/flyteadmin/pkg/manager/impl/project_attributes_manager.go b/flyteadmin/pkg/manager/impl/project_attributes_manager.go deleted file mode 100644 index e035af9e6..000000000 --- a/flyteadmin/pkg/manager/impl/project_attributes_manager.go +++ /dev/null @@ -1,75 +0,0 @@ -package impl - -import ( - "context" - - "github.com/lyft/flytestdlib/logger" - - "github.com/lyft/flyteadmin/pkg/manager/impl/validation" - "github.com/lyft/flyteadmin/pkg/repositories/transformers" - - "github.com/lyft/flyteadmin/pkg/manager/interfaces" - "github.com/lyft/flyteadmin/pkg/repositories" - "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" -) - -type ProjectAttributesManager struct { - db repositories.RepositoryInterface -} - -func (m *ProjectAttributesManager) UpdateProjectAttributes( - ctx context.Context, request admin.ProjectAttributesUpdateRequest) ( - *admin.ProjectAttributesUpdateResponse, error) { - var resource admin.MatchableResource - var err error - if resource, err = validation.ValidateProjectAttributesUpdateRequest(request); err != nil { - return nil, err - } - - model, err := transformers.ToProjectAttributesModel(*request.Attributes, resource) - if err != nil { - return nil, err - } - err = m.db.ProjectAttributesRepo().CreateOrUpdate(ctx, model) - if err != nil { - return nil, err - } - - return &admin.ProjectAttributesUpdateResponse{}, nil -} - -func (m *ProjectAttributesManager) GetProjectAttributes(ctx context.Context, request admin.ProjectAttributesGetRequest) ( - *admin.ProjectAttributesGetResponse, error) { - if err := validation.ValidateProjectAttributesGetRequest(request); err != nil { - return nil, err - } - projectAttributesModel, err := m.db.ProjectAttributesRepo().Get(ctx, request.Project, request.ResourceType.String()) - if err != nil { - return nil, err - } - projectAttributes, err := transformers.FromProjectAttributesModel(projectAttributesModel) - if err != nil { - return nil, err - } - return &admin.ProjectAttributesGetResponse{ - Attributes: &projectAttributes, - }, nil -} - -func (m *ProjectAttributesManager) DeleteProjectAttributes(ctx context.Context, - request admin.ProjectAttributesDeleteRequest) (*admin.ProjectAttributesDeleteResponse, error) { - if err := validation.ValidateProjectAttributesDeleteRequest(request); err != nil { - return nil, err - } - if err := m.db.ProjectAttributesRepo().Delete(ctx, request.Project, request.ResourceType.String()); err != nil { - return nil, err - } - logger.Infof(ctx, "Deleted project attributes for: %s (%s)", request.Project, request.ResourceType.String()) - return &admin.ProjectAttributesDeleteResponse{}, nil -} - -func NewProjectAttributesManager(db repositories.RepositoryInterface) interfaces.ProjectAttributesInterface { - return &ProjectAttributesManager{ - db: db, - } -} diff --git a/flyteadmin/pkg/manager/impl/project_attributes_manager_test.go b/flyteadmin/pkg/manager/impl/project_attributes_manager_test.go deleted file mode 100644 index 7d8c264b7..000000000 --- a/flyteadmin/pkg/manager/impl/project_attributes_manager_test.go +++ /dev/null @@ -1,82 +0,0 @@ -package impl - -import ( - "context" - "testing" - - "github.com/golang/protobuf/proto" - "github.com/lyft/flyteadmin/pkg/manager/impl/testutils" - "github.com/lyft/flyteadmin/pkg/repositories/mocks" - "github.com/lyft/flyteadmin/pkg/repositories/models" - "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" - "github.com/stretchr/testify/assert" -) - -func TestUpdateProjectAttributes(t *testing.T) { - request := admin.ProjectAttributesUpdateRequest{ - Attributes: &admin.ProjectAttributes{ - Project: "project", - MatchingAttributes: testutils.ExecutionQueueAttributes, - }, - } - db := mocks.NewMockRepository() - expectedSerializedAttrs, _ := proto.Marshal(testutils.ExecutionQueueAttributes) - var createOrUpdateCalled bool - db.ProjectAttributesRepo().(*mocks.MockProjectAttributesRepo).CreateOrUpdateFunction = func( - ctx context.Context, input models.ProjectAttributes) error { - assert.Equal(t, "project", input.Project) - assert.Equal(t, admin.MatchableResource_EXECUTION_QUEUE.String(), input.Resource) - assert.EqualValues(t, expectedSerializedAttrs, input.Attributes) - createOrUpdateCalled = true - return nil - } - manager := NewProjectAttributesManager(db) - _, err := manager.UpdateProjectAttributes(context.Background(), request) - assert.Nil(t, err) - assert.True(t, createOrUpdateCalled) -} - -func TestGetProjectAttributes(t *testing.T) { - request := admin.ProjectAttributesGetRequest{ - Project: "project", - ResourceType: admin.MatchableResource_EXECUTION_QUEUE, - } - db := mocks.NewMockRepository() - db.ProjectAttributesRepo().(*mocks.MockProjectAttributesRepo).GetFunction = func( - ctx context.Context, project, resource string) (models.ProjectAttributes, error) { - assert.Equal(t, "project", project) - assert.Equal(t, admin.MatchableResource_EXECUTION_QUEUE.String(), resource) - expectedSerializedAttrs, _ := proto.Marshal(testutils.ExecutionQueueAttributes) - return models.ProjectAttributes{ - Project: project, - Resource: resource, - Attributes: expectedSerializedAttrs, - }, nil - } - manager := NewProjectAttributesManager(db) - response, err := manager.GetProjectAttributes(context.Background(), request) - assert.Nil(t, err) - assert.True(t, proto.Equal(&admin.ProjectAttributesGetResponse{ - Attributes: &admin.ProjectAttributes{ - Project: "project", - MatchingAttributes: testutils.ExecutionQueueAttributes, - }, - }, response)) -} - -func TestDeleteProjectAttributes(t *testing.T) { - request := admin.ProjectAttributesDeleteRequest{ - Project: "project", - ResourceType: admin.MatchableResource_EXECUTION_QUEUE, - } - db := mocks.NewMockRepository() - db.ProjectAttributesRepo().(*mocks.MockProjectAttributesRepo).DeleteFunction = func( - ctx context.Context, project, resource string) error { - assert.Equal(t, "project", project) - assert.Equal(t, admin.MatchableResource_EXECUTION_QUEUE.String(), resource) - return nil - } - manager := NewProjectAttributesManager(db) - _, err := manager.DeleteProjectAttributes(context.Background(), request) - assert.Nil(t, err) -} diff --git a/flyteadmin/pkg/manager/impl/project_domain_attributes_manager.go b/flyteadmin/pkg/manager/impl/project_domain_attributes_manager.go deleted file mode 100644 index 45fe906b7..000000000 --- a/flyteadmin/pkg/manager/impl/project_domain_attributes_manager.go +++ /dev/null @@ -1,83 +0,0 @@ -package impl - -import ( - "context" - - "github.com/lyft/flytestdlib/logger" - - "github.com/lyft/flytestdlib/contextutils" - - "github.com/lyft/flyteadmin/pkg/manager/impl/validation" - "github.com/lyft/flyteadmin/pkg/repositories/transformers" - - "github.com/lyft/flyteadmin/pkg/manager/interfaces" - "github.com/lyft/flyteadmin/pkg/repositories" - "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" -) - -type ProjectDomainAttributesManager struct { - db repositories.RepositoryInterface -} - -func (m *ProjectDomainAttributesManager) UpdateProjectDomainAttributes( - ctx context.Context, request admin.ProjectDomainAttributesUpdateRequest) ( - *admin.ProjectDomainAttributesUpdateResponse, error) { - var resource admin.MatchableResource - var err error - if resource, err = validation.ValidateProjectDomainAttributesUpdateRequest(request); err != nil { - return nil, err - } - ctx = contextutils.WithProjectDomain(ctx, request.Attributes.Project, request.Attributes.Domain) - - model, err := transformers.ToProjectDomainAttributesModel(*request.Attributes, resource) - if err != nil { - return nil, err - } - err = m.db.ProjectDomainAttributesRepo().CreateOrUpdate(ctx, model) - if err != nil { - return nil, err - } - - return &admin.ProjectDomainAttributesUpdateResponse{}, nil -} - -func (m *ProjectDomainAttributesManager) GetProjectDomainAttributes( - ctx context.Context, request admin.ProjectDomainAttributesGetRequest) ( - *admin.ProjectDomainAttributesGetResponse, error) { - if err := validation.ValidateProjectDomainAttributesGetRequest(request); err != nil { - return nil, err - } - projectAttributesModel, err := m.db.ProjectDomainAttributesRepo().Get( - ctx, request.Project, request.Domain, request.ResourceType.String()) - if err != nil { - return nil, err - } - projectAttributes, err := transformers.FromProjectDomainAttributesModel(projectAttributesModel) - if err != nil { - return nil, err - } - return &admin.ProjectDomainAttributesGetResponse{ - Attributes: &projectAttributes, - }, nil -} - -func (m *ProjectDomainAttributesManager) DeleteProjectDomainAttributes(ctx context.Context, - request admin.ProjectDomainAttributesDeleteRequest) (*admin.ProjectDomainAttributesDeleteResponse, error) { - if err := validation.ValidateProjectDomainAttributesDeleteRequest(request); err != nil { - return nil, err - } - if err := m.db.ProjectDomainAttributesRepo().Delete( - ctx, request.Project, request.Domain, request.ResourceType.String()); err != nil { - return nil, err - } - logger.Infof(ctx, "Deleted project-domain attributes for: %s-%s (%s)", request.Project, - request.Domain, request.ResourceType.String()) - return &admin.ProjectDomainAttributesDeleteResponse{}, nil -} - -func NewProjectDomainAttributesManager( - db repositories.RepositoryInterface) interfaces.ProjectDomainAttributesInterface { - return &ProjectDomainAttributesManager{ - db: db, - } -} diff --git a/flyteadmin/pkg/manager/impl/project_domain_attributes_manager_test.go b/flyteadmin/pkg/manager/impl/project_domain_attributes_manager_test.go deleted file mode 100644 index e5031afe1..000000000 --- a/flyteadmin/pkg/manager/impl/project_domain_attributes_manager_test.go +++ /dev/null @@ -1,90 +0,0 @@ -package impl - -import ( - "context" - "testing" - - "github.com/golang/protobuf/proto" - "github.com/lyft/flyteadmin/pkg/manager/impl/testutils" - "github.com/lyft/flyteadmin/pkg/repositories/mocks" - "github.com/lyft/flyteadmin/pkg/repositories/models" - "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" - "github.com/stretchr/testify/assert" -) - -func TestUpdateProjectDomainAttributes(t *testing.T) { - request := admin.ProjectDomainAttributesUpdateRequest{ - Attributes: &admin.ProjectDomainAttributes{ - Project: "project", - Domain: "domain", - MatchingAttributes: testutils.ExecutionQueueAttributes, - }, - } - db := mocks.NewMockRepository() - expectedSerializedAttrs, _ := proto.Marshal(testutils.ExecutionQueueAttributes) - var createOrUpdateCalled bool - db.ProjectDomainAttributesRepo().(*mocks.MockProjectDomainAttributesRepo).CreateOrUpdateFunction = func( - ctx context.Context, input models.ProjectDomainAttributes) error { - assert.Equal(t, "project", input.Project) - assert.Equal(t, "domain", input.Domain) - assert.Equal(t, admin.MatchableResource_EXECUTION_QUEUE.String(), input.Resource) - assert.EqualValues(t, expectedSerializedAttrs, input.Attributes) - createOrUpdateCalled = true - return nil - } - manager := NewProjectDomainAttributesManager(db) - _, err := manager.UpdateProjectDomainAttributes(context.Background(), request) - assert.Nil(t, err) - assert.True(t, createOrUpdateCalled) -} - -func TestGetProjectDomainAttributes(t *testing.T) { - request := admin.ProjectDomainAttributesGetRequest{ - Project: "project", - Domain: "domain", - ResourceType: admin.MatchableResource_EXECUTION_QUEUE, - } - db := mocks.NewMockRepository() - db.ProjectDomainAttributesRepo().(*mocks.MockProjectDomainAttributesRepo).GetFunction = func( - ctx context.Context, project, domain, resource string) (models.ProjectDomainAttributes, error) { - assert.Equal(t, "project", project) - assert.Equal(t, "domain", domain) - assert.Equal(t, admin.MatchableResource_EXECUTION_QUEUE.String(), resource) - expectedSerializedAttrs, _ := proto.Marshal(testutils.ExecutionQueueAttributes) - return models.ProjectDomainAttributes{ - Project: project, - Domain: domain, - Resource: resource, - Attributes: expectedSerializedAttrs, - }, nil - } - manager := NewProjectDomainAttributesManager(db) - response, err := manager.GetProjectDomainAttributes(context.Background(), request) - assert.Nil(t, err) - assert.True(t, proto.Equal(&admin.ProjectDomainAttributesGetResponse{ - Attributes: &admin.ProjectDomainAttributes{ - Project: "project", - Domain: "domain", - MatchingAttributes: testutils.ExecutionQueueAttributes, - }, - }, response)) -} - -func TestDeleteProjectDomainAttributes(t *testing.T) { - request := admin.ProjectDomainAttributesDeleteRequest{ - Project: "project", - Domain: "domain", - ResourceType: admin.MatchableResource_EXECUTION_QUEUE, - } - db := mocks.NewMockRepository() - db.ProjectDomainAttributesRepo().(*mocks.MockProjectDomainAttributesRepo).DeleteFunction = func( - ctx context.Context, project, domain, resource string) error { - assert.Equal(t, "project", project) - assert.Equal(t, "domain", domain) - assert.Equal(t, admin.MatchableResource_EXECUTION_QUEUE.String(), resource) - return nil - } - manager := NewProjectDomainAttributesManager(db) - _, err := manager.DeleteProjectDomainAttributes(context.Background(), request) - assert.Nil(t, err) -} diff --git a/flyteadmin/pkg/manager/impl/resources/resource_manager.go b/flyteadmin/pkg/manager/impl/resources/resource_manager.go new file mode 100644 index 000000000..f6f504140 --- /dev/null +++ b/flyteadmin/pkg/manager/impl/resources/resource_manager.go @@ -0,0 +1,168 @@ +package resources + +import ( + "context" + + "github.com/gogo/protobuf/proto" + "github.com/lyft/flyteadmin/pkg/errors" + "github.com/lyft/flytestdlib/contextutils" + "google.golang.org/grpc/codes" + + "github.com/lyft/flytestdlib/logger" + + "github.com/lyft/flyteadmin/pkg/manager/impl/validation" + repo_interface "github.com/lyft/flyteadmin/pkg/repositories/interfaces" + "github.com/lyft/flyteadmin/pkg/repositories/transformers" + + "github.com/lyft/flyteadmin/pkg/manager/interfaces" + "github.com/lyft/flyteadmin/pkg/repositories" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" +) + +type ResourceManager struct { + db repositories.RepositoryInterface +} + +func (m *ResourceManager) GetResource(ctx context.Context, request interfaces.ResourceRequest) (*interfaces.ResourceResponse, error) { + resource, err := m.db.ResourceRepo().Get(ctx, repo_interface.ResourceID{ + ResourceType: request.ResourceType.String(), + Project: request.Project, + Domain: request.Domain, + Workflow: request.Workflow, + LaunchPlan: request.LaunchPlan, + }) + if err != nil { + return nil, err + } + + var attributes admin.MatchingAttributes + err = proto.Unmarshal(resource.Attributes, &attributes) + if err != nil { + return nil, errors.NewFlyteAdminErrorf( + codes.Internal, "Failed to decode resource attribute with err: %v", err) + } + return &interfaces.ResourceResponse{ + ResourceType: resource.ResourceType, + Project: resource.Project, + Domain: resource.Domain, + Workflow: resource.Workflow, + LaunchPlan: resource.LaunchPlan, + Attributes: &attributes, + }, nil +} + +func (m *ResourceManager) UpdateWorkflowAttributes( + ctx context.Context, request admin.WorkflowAttributesUpdateRequest) ( + *admin.WorkflowAttributesUpdateResponse, error) { + var resource admin.MatchableResource + var err error + if resource, err = validation.ValidateWorkflowAttributesUpdateRequest(request); err != nil { + return nil, err + } + + model, err := transformers.WorkflowAttributesToResourceModel(*request.Attributes, resource) + if err != nil { + return nil, err + } + err = m.db.ResourceRepo().CreateOrUpdate(ctx, model) + if err != nil { + return nil, err + } + + return &admin.WorkflowAttributesUpdateResponse{}, nil +} + +func (m *ResourceManager) GetWorkflowAttributes( + ctx context.Context, request admin.WorkflowAttributesGetRequest) ( + *admin.WorkflowAttributesGetResponse, error) { + if err := validation.ValidateWorkflowAttributesGetRequest(request); err != nil { + return nil, err + } + projectAttributesModel, err := m.db.ResourceRepo().Get( + ctx, repo_interface.ResourceID{Project: request.Project, Domain: request.Domain, Workflow: request.Workflow, ResourceType: request.ResourceType.String()}) + if err != nil { + return nil, err + } + workflowAttributes, err := transformers.FromResourceModelToWorkflowAttributes(projectAttributesModel) + if err != nil { + return nil, err + } + return &admin.WorkflowAttributesGetResponse{ + Attributes: &workflowAttributes, + }, nil +} + +func (m *ResourceManager) DeleteWorkflowAttributes(ctx context.Context, + request admin.WorkflowAttributesDeleteRequest) (*admin.WorkflowAttributesDeleteResponse, error) { + if err := validation.ValidateWorkflowAttributesDeleteRequest(request); err != nil { + return nil, err + } + if err := m.db.ResourceRepo().Delete( + ctx, repo_interface.ResourceID{Project: request.Project, Domain: request.Domain, Workflow: request.Workflow, ResourceType: request.ResourceType.String()}); err != nil { + return nil, err + } + logger.Infof(ctx, "Deleted workflow attributes for: %s-%s-%s (%s)", request.Project, + request.Domain, request.Workflow, request.ResourceType.String()) + return &admin.WorkflowAttributesDeleteResponse{}, nil +} + +func (m *ResourceManager) UpdateProjectDomainAttributes( + ctx context.Context, request admin.ProjectDomainAttributesUpdateRequest) ( + *admin.ProjectDomainAttributesUpdateResponse, error) { + var resource admin.MatchableResource + var err error + if resource, err = validation.ValidateProjectDomainAttributesUpdateRequest(request); err != nil { + return nil, err + } + ctx = contextutils.WithProjectDomain(ctx, request.Attributes.Project, request.Attributes.Domain) + + model, err := transformers.ProjectDomainAttributesToResourceModel(*request.Attributes, resource) + if err != nil { + return nil, err + } + err = m.db.ResourceRepo().CreateOrUpdate(ctx, model) + if err != nil { + return nil, err + } + return &admin.ProjectDomainAttributesUpdateResponse{}, nil +} + +func (m *ResourceManager) GetProjectDomainAttributes( + ctx context.Context, request admin.ProjectDomainAttributesGetRequest) ( + *admin.ProjectDomainAttributesGetResponse, error) { + if err := validation.ValidateProjectDomainAttributesGetRequest(request); err != nil { + return nil, err + } + projectAttributesModel, err := m.db.ResourceRepo().Get( + ctx, repo_interface.ResourceID{Project: request.Project, Domain: request.Domain, ResourceType: request.ResourceType.String()}) + if err != nil { + return nil, err + } + projectAttributes, err := transformers.FromResourceModelToProjectDomainAttributes(projectAttributesModel) + if err != nil { + return nil, err + } + return &admin.ProjectDomainAttributesGetResponse{ + Attributes: &projectAttributes, + }, nil +} + +func (m *ResourceManager) DeleteProjectDomainAttributes(ctx context.Context, + request admin.ProjectDomainAttributesDeleteRequest) (*admin.ProjectDomainAttributesDeleteResponse, error) { + if err := validation.ValidateProjectDomainAttributesDeleteRequest(request); err != nil { + return nil, err + } + if err := m.db.ResourceRepo().Delete( + ctx, repo_interface.ResourceID{Project: request.Project, Domain: request.Domain, ResourceType: request.ResourceType.String()}); err != nil { + return nil, err + } + logger.Infof(ctx, "Deleted project-domain attributes for: %s-%s (%s)", request.Project, + request.Domain, request.ResourceType.String()) + return &admin.ProjectDomainAttributesDeleteResponse{}, nil +} + +func NewResourceManager(db repositories.RepositoryInterface) interfaces.ResourceInterface { + return &ResourceManager{ + db: db, + } +} diff --git a/flyteadmin/pkg/manager/impl/resources/resource_manager_test.go b/flyteadmin/pkg/manager/impl/resources/resource_manager_test.go new file mode 100644 index 000000000..b3a77d215 --- /dev/null +++ b/flyteadmin/pkg/manager/impl/resources/resource_manager_test.go @@ -0,0 +1,221 @@ +package resources + +import ( + "context" + "testing" + + interfaces2 "github.com/lyft/flyteadmin/pkg/manager/interfaces" + "github.com/lyft/flyteadmin/pkg/repositories/interfaces" + + "github.com/golang/protobuf/proto" + "github.com/lyft/flyteadmin/pkg/manager/impl/testutils" + "github.com/lyft/flyteadmin/pkg/repositories/mocks" + "github.com/lyft/flyteadmin/pkg/repositories/models" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" + "github.com/stretchr/testify/assert" +) + +const project = "project" +const domain = "domain" +const workflow = "workflow" + +func TestUpdateWorkflowAttributes(t *testing.T) { + request := admin.WorkflowAttributesUpdateRequest{ + Attributes: &admin.WorkflowAttributes{ + Project: project, + Domain: domain, + Workflow: workflow, + MatchingAttributes: testutils.ExecutionQueueAttributes, + }, + } + db := mocks.NewMockRepository() + expectedSerializedAttrs, _ := proto.Marshal(testutils.ExecutionQueueAttributes) + var createOrUpdateCalled bool + db.ResourceRepo().(*mocks.MockResourceRepo).CreateOrUpdateFunction = func( + ctx context.Context, input models.Resource) error { + assert.Equal(t, project, input.Project) + assert.Equal(t, domain, input.Domain) + assert.Equal(t, workflow, input.Workflow) + assert.Equal(t, admin.MatchableResource_EXECUTION_QUEUE.String(), input.ResourceType) + assert.EqualValues(t, expectedSerializedAttrs, input.Attributes) + createOrUpdateCalled = true + return nil + } + manager := NewResourceManager(db) + _, err := manager.UpdateWorkflowAttributes(context.Background(), request) + assert.Nil(t, err) + assert.True(t, createOrUpdateCalled) +} + +func TestGetWorkflowAttributes(t *testing.T) { + request := admin.WorkflowAttributesGetRequest{ + Project: project, + Domain: domain, + Workflow: workflow, + ResourceType: admin.MatchableResource_EXECUTION_QUEUE, + } + db := mocks.NewMockRepository() + db.ResourceRepo().(*mocks.MockResourceRepo).GetFunction = func( + ctx context.Context, ID interfaces.ResourceID) (models.Resource, error) { + assert.Equal(t, project, ID.Project) + assert.Equal(t, domain, ID.Domain) + assert.Equal(t, workflow, ID.Workflow) + assert.Equal(t, admin.MatchableResource_EXECUTION_QUEUE.String(), ID.ResourceType) + expectedSerializedAttrs, _ := proto.Marshal(testutils.ExecutionQueueAttributes) + return models.Resource{ + Project: project, + Domain: domain, + Workflow: workflow, + ResourceType: "resource", + Attributes: expectedSerializedAttrs, + }, nil + } + manager := NewResourceManager(db) + response, err := manager.GetWorkflowAttributes(context.Background(), request) + assert.Nil(t, err) + assert.True(t, proto.Equal(&admin.WorkflowAttributesGetResponse{ + Attributes: &admin.WorkflowAttributes{ + Project: project, + Domain: domain, + Workflow: workflow, + MatchingAttributes: testutils.ExecutionQueueAttributes, + }, + }, response)) +} + +func TestDeleteWorkflowAttributes(t *testing.T) { + request := admin.WorkflowAttributesDeleteRequest{ + Project: project, + Domain: domain, + Workflow: workflow, + ResourceType: admin.MatchableResource_EXECUTION_QUEUE, + } + db := mocks.NewMockRepository() + db.ResourceRepo().(*mocks.MockResourceRepo).DeleteFunction = func( + ctx context.Context, ID interfaces.ResourceID) error { + assert.Equal(t, project, ID.Project) + assert.Equal(t, domain, ID.Domain) + assert.Equal(t, workflow, ID.Workflow) + assert.Equal(t, admin.MatchableResource_EXECUTION_QUEUE.String(), ID.ResourceType) + return nil + } + manager := NewResourceManager(db) + _, err := manager.DeleteWorkflowAttributes(context.Background(), request) + assert.Nil(t, err) +} + +func TestUpdateProjectDomainAttributes(t *testing.T) { + request := admin.ProjectDomainAttributesUpdateRequest{ + Attributes: &admin.ProjectDomainAttributes{ + Project: project, + Domain: domain, + MatchingAttributes: testutils.ExecutionQueueAttributes, + }, + } + db := mocks.NewMockRepository() + expectedSerializedAttrs, _ := proto.Marshal(testutils.ExecutionQueueAttributes) + var createOrUpdateCalled bool + db.ResourceRepo().(*mocks.MockResourceRepo).CreateOrUpdateFunction = func( + ctx context.Context, input models.Resource) error { + assert.Equal(t, project, input.Project) + assert.Equal(t, domain, input.Domain) + assert.Equal(t, "", input.Workflow) + assert.Equal(t, admin.MatchableResource_EXECUTION_QUEUE.String(), input.ResourceType) + assert.EqualValues(t, expectedSerializedAttrs, input.Attributes) + createOrUpdateCalled = true + return nil + } + manager := NewResourceManager(db) + _, err := manager.UpdateProjectDomainAttributes(context.Background(), request) + assert.Nil(t, err) + assert.True(t, createOrUpdateCalled) +} + +func TestGetProjectDomainAttributes(t *testing.T) { + request := admin.ProjectDomainAttributesGetRequest{ + Project: project, + Domain: domain, + ResourceType: admin.MatchableResource_EXECUTION_QUEUE, + } + db := mocks.NewMockRepository() + db.ResourceRepo().(*mocks.MockResourceRepo).GetFunction = func( + ctx context.Context, ID interfaces.ResourceID) (models.Resource, error) { + assert.Equal(t, project, ID.Project) + assert.Equal(t, domain, ID.Domain) + assert.Equal(t, "", ID.Workflow) + assert.Equal(t, admin.MatchableResource_EXECUTION_QUEUE.String(), ID.ResourceType) + expectedSerializedAttrs, _ := proto.Marshal(testutils.ExecutionQueueAttributes) + return models.Resource{ + Project: project, + Domain: domain, + ResourceType: "resource", + Attributes: expectedSerializedAttrs, + }, nil + } + manager := NewResourceManager(db) + response, err := manager.GetProjectDomainAttributes(context.Background(), request) + assert.Nil(t, err) + assert.True(t, proto.Equal(&admin.ProjectDomainAttributesGetResponse{ + Attributes: &admin.ProjectDomainAttributes{ + Project: project, + Domain: domain, + MatchingAttributes: testutils.ExecutionQueueAttributes, + }, + }, response)) +} + +func TestDeleteProjectDomainAttributes(t *testing.T) { + request := admin.ProjectDomainAttributesDeleteRequest{ + Project: project, + Domain: domain, + ResourceType: admin.MatchableResource_EXECUTION_QUEUE, + } + db := mocks.NewMockRepository() + db.ResourceRepo().(*mocks.MockResourceRepo).DeleteFunction = func( + ctx context.Context, ID interfaces.ResourceID) error { + assert.Equal(t, project, ID.Project) + assert.Equal(t, domain, ID.Domain) + assert.Equal(t, admin.MatchableResource_EXECUTION_QUEUE.String(), ID.ResourceType) + return nil + } + manager := NewResourceManager(db) + _, err := manager.DeleteProjectDomainAttributes(context.Background(), request) + assert.Nil(t, err) +} + +func TestGetResource(t *testing.T) { + request := interfaces2.ResourceRequest{ + Project: project, + Domain: domain, + Workflow: workflow, + LaunchPlan: "launch_plan", + ResourceType: admin.MatchableResource_EXECUTION_QUEUE, + } + db := mocks.NewMockRepository() + db.ResourceRepo().(*mocks.MockResourceRepo).GetFunction = func( + ctx context.Context, ID interfaces.ResourceID) (models.Resource, error) { + assert.Equal(t, project, ID.Project) + assert.Equal(t, domain, ID.Domain) + assert.Equal(t, workflow, ID.Workflow) + assert.Equal(t, "launch_plan", ID.LaunchPlan) + assert.Equal(t, admin.MatchableResource_EXECUTION_QUEUE.String(), ID.ResourceType) + expectedSerializedAttrs, _ := proto.Marshal(testutils.ExecutionQueueAttributes) + return models.Resource{ + Project: ID.Project, + Domain: ID.Domain, + Workflow: ID.Workflow, + LaunchPlan: ID.LaunchPlan, + ResourceType: ID.ResourceType, + Attributes: expectedSerializedAttrs, + }, nil + } + manager := NewResourceManager(db) + response, err := manager.GetResource(context.Background(), request) + assert.Nil(t, err) + assert.Equal(t, request.Project, response.Project) + assert.Equal(t, request.Domain, response.Domain) + assert.Equal(t, request.Workflow, response.Workflow) + assert.Equal(t, request.LaunchPlan, response.LaunchPlan) + assert.Equal(t, request.ResourceType.String(), response.ResourceType) + assert.True(t, proto.Equal(response.Attributes, testutils.ExecutionQueueAttributes)) +} diff --git a/flyteadmin/pkg/manager/impl/validation/attributes_validator.go b/flyteadmin/pkg/manager/impl/validation/attributes_validator.go index 4ae18ae8b..9ad8db767 100644 --- a/flyteadmin/pkg/manager/impl/validation/attributes_validator.go +++ b/flyteadmin/pkg/manager/impl/validation/attributes_validator.go @@ -26,34 +26,6 @@ func validateMatchingAttributes(attributes *admin.MatchingAttributes, identifier "Unrecognized matching attributes type for request %s", identifier) } -func ValidateProjectAttributesUpdateRequest(request admin.ProjectAttributesUpdateRequest) ( - admin.MatchableResource, error) { - if request.Attributes == nil { - return defaultMatchableResource, shared.GetMissingArgumentError(shared.Attributes) - } - if err := ValidateEmptyStringField(request.Attributes.Project, shared.Project); err != nil { - return defaultMatchableResource, err - } - - return validateMatchingAttributes(request.Attributes.MatchingAttributes, request.Attributes.Project) -} - -func ValidateProjectAttributesGetRequest(request admin.ProjectAttributesGetRequest) error { - if err := ValidateEmptyStringField(request.Project, shared.Project); err != nil { - return err - } - - return nil -} - -func ValidateProjectAttributesDeleteRequest(request admin.ProjectAttributesDeleteRequest) error { - if err := ValidateEmptyStringField(request.Project, shared.Project); err != nil { - return err - } - - return nil -} - func ValidateProjectDomainAttributesUpdateRequest(request admin.ProjectDomainAttributesUpdateRequest) ( admin.MatchableResource, error) { if request.Attributes == nil { diff --git a/flyteadmin/pkg/manager/impl/validation/attributes_validator_test.go b/flyteadmin/pkg/manager/impl/validation/attributes_validator_test.go index 56053b656..e0c767e20 100644 --- a/flyteadmin/pkg/manager/impl/validation/attributes_validator_test.go +++ b/flyteadmin/pkg/manager/impl/validation/attributes_validator_test.go @@ -70,49 +70,6 @@ func TestValidateMatchingAttributes(t *testing.T) { } } -func TestValidateProjectAttributesUpdateRequest(t *testing.T) { - _, err := ValidateProjectAttributesUpdateRequest(admin.ProjectAttributesUpdateRequest{}) - assert.Equal(t, "missing attributes", err.Error()) - - _, err = ValidateProjectAttributesUpdateRequest(admin.ProjectAttributesUpdateRequest{ - Attributes: &admin.ProjectAttributes{}}) - assert.Equal(t, "missing project", err.Error()) - - matchableResource, err := ValidateProjectAttributesUpdateRequest(admin.ProjectAttributesUpdateRequest{ - Attributes: &admin.ProjectAttributes{ - Project: "project", - MatchingAttributes: &admin.MatchingAttributes{ - Target: &admin.MatchingAttributes_TaskResourceAttributes{ - TaskResourceAttributes: &admin.TaskResourceAttributes{ - Defaults: &admin.TaskResourceSpec{ - Cpu: "1", - }, - }, - }, - }, - }}) - assert.Equal(t, admin.MatchableResource_TASK_RESOURCE, matchableResource) - assert.Nil(t, err) -} - -func TestValidateProjectAttributesGetRequest(t *testing.T) { - err := ValidateProjectAttributesGetRequest(admin.ProjectAttributesGetRequest{}) - assert.Equal(t, "missing project", err.Error()) - - assert.Nil(t, ValidateProjectAttributesGetRequest(admin.ProjectAttributesGetRequest{ - Project: "project", - })) -} - -func TestValidateProjectAttributesDeleteRequest(t *testing.T) { - err := ValidateProjectAttributesDeleteRequest(admin.ProjectAttributesDeleteRequest{}) - assert.Equal(t, "missing project", err.Error()) - - assert.Nil(t, ValidateProjectAttributesDeleteRequest(admin.ProjectAttributesDeleteRequest{ - Project: "project", - })) -} - func TestValidateProjectDomainAttributesUpdateRequest(t *testing.T) { _, err := ValidateProjectDomainAttributesUpdateRequest(admin.ProjectDomainAttributesUpdateRequest{}) assert.Equal(t, "missing attributes", err.Error()) diff --git a/flyteadmin/pkg/manager/impl/validation/task_validator.go b/flyteadmin/pkg/manager/impl/validation/task_validator.go index 562768529..6280770c8 100644 --- a/flyteadmin/pkg/manager/impl/validation/task_validator.go +++ b/flyteadmin/pkg/manager/impl/validation/task_validator.go @@ -4,8 +4,6 @@ package validation import ( "context" - "github.com/lyft/flyteadmin/pkg/resourcematching" - "github.com/lyft/flyteadmin/pkg/repositories" "github.com/lyft/flyteadmin/pkg/common" @@ -191,18 +189,18 @@ func validateTaskResources( if ok && limitQuantity.Value() < defaultQuantity.Value() { // Only assert the requested limit is greater than than the requested default when the limit is actually set return errors.NewFlyteAdminErrorf(codes.InvalidArgument, - "Resource %v for [%+v] cannot set default > limit", resourceName, identifier) + "Type %v for [%+v] cannot set default > limit", resourceName, identifier) } platformLimit, platformLimitOk := platformTaskResourceLimits[resourceName] if ok && platformLimitOk && limitQuantity.Value() > platformLimit.Value() { // Also check that the requested limit is less than the platform task limit. return errors.NewFlyteAdminErrorf(codes.InvalidArgument, - "Resource %v for [%+v] cannot set limit > platform limit", resourceName, identifier) + "Type %v for [%+v] cannot set limit > platform limit", resourceName, identifier) } if platformLimitOk && defaultQuantity.Value() > platformTaskResourceLimits[resourceName].Value() { // Also check that the requested limit is less than the platform task limit. return errors.NewFlyteAdminErrorf(codes.InvalidArgument, - "Resource %v for [%+v] cannot set default > platform limit", resourceName, identifier) + "Type %v for [%+v] cannot set default > platform limit", resourceName, identifier) } case core.Resources_GPU: limitQuantity, ok := requestedResourceLimits[resourceName] @@ -214,7 +212,7 @@ func validateTaskResources( platformLimit, platformLimitOk := platformTaskResourceLimits[resourceName] if platformLimitOk && defaultQuantity.Value() > platformLimit.Value() { return errors.NewFlyteAdminErrorf(codes.InvalidArgument, - "Resource %v for [%+v] cannot set default > platform limit", resourceName, identifier) + "Type %v for [%+v] cannot set default > platform limit", resourceName, identifier) } } } @@ -222,131 +220,6 @@ func validateTaskResources( return nil } -func assignResourcesIfUnset(ctx context.Context, identifier *core.Identifier, - platformValues runtimeInterfaces.TaskResourceSet, - resourceEntries []*core.Resources_ResourceEntry, taskResourceSpec *admin.TaskResourceSpec) []*core.Resources_ResourceEntry { - var cpuIndex, memoryIndex = -1, -1 - for idx, entry := range resourceEntries { - switch entry.Name { - case core.Resources_CPU: - cpuIndex = idx - case core.Resources_MEMORY: - memoryIndex = idx - } - } - if cpuIndex > 0 && memoryIndex > 0 { - // nothing to do - return resourceEntries - } - - if cpuIndex < 0 && platformValues.CPU != "" { - logger.Debugf(ctx, "Setting 'cpu' for [%+v] to %s", identifier, platformValues.CPU) - cpuValue := platformValues.CPU - if taskResourceSpec != nil && len(taskResourceSpec.Cpu) > 0 { - // Use the custom attributes from the database rather than the platform defaults from the application config - cpuValue = taskResourceSpec.Cpu - } - cpuResource := &core.Resources_ResourceEntry{ - Name: core.Resources_CPU, - Value: cpuValue, - } - resourceEntries = append(resourceEntries, cpuResource) - } - if memoryIndex < 0 && platformValues.Memory != "" { - memoryValue := platformValues.Memory - if taskResourceSpec != nil && len(taskResourceSpec.Memory) > 0 { - // Use the custom attributes from the database rather than the platform defaults from the application config - memoryValue = taskResourceSpec.Memory - } - memoryResource := &core.Resources_ResourceEntry{ - Name: core.Resources_MEMORY, - Value: memoryValue, - } - logger.Debugf(ctx, "Setting 'memory' for [%+v] to %s", identifier, platformValues.Memory) - resourceEntries = append(resourceEntries, memoryResource) - } - return resourceEntries -} - -// Assumes input contains a compiled task with a valid container resource execConfig. -// -// Note: The system will assign a system-default value for request but for limit it will deduce it from the request -// itself => Limit := Min([Some-Multiplier X Request], System-Max). For now we are using a multiplier of 1. In -// general we recommend the users to set limits close to requests for more predictability in the system. -func SetDefaults(ctx context.Context, taskConfig runtime.TaskResourceConfiguration, task *core.CompiledTask, - db repositories.RepositoryInterface, workflowName string) { - if task == nil { - logger.Warningf(ctx, "Can't set default resources for nil task.") - return - } - if task.Template == nil || task.Template.GetContainer() == nil || task.Template.GetContainer().Resources == nil { - // Nothing to do - logger.Debugf(ctx, "Not setting default resources for task [%+v], no container resources found to check", task) - return - } - - attributes, err := resourcematching.GetOverrideValuesToApply(ctx, resourcematching.GetOverrideValuesInput{ - Db: db, - Project: task.Template.Id.Project, - Domain: task.Template.Id.Domain, - Workflow: workflowName, - Resource: admin.MatchableResource_TASK_RESOURCE, - }) - if err != nil { - logger.Warningf(ctx, "Failed to fetch override values when assigning task resource default values for [%+v]: %v", - task.Template, err) - } - - logger.Debugf(ctx, "Assigning task requested resources for [%+v]", task.Template.Id) - var taskResourceSpec *admin.TaskResourceSpec - if attributes != nil && attributes.GetTaskResourceAttributes() != nil { - taskResourceSpec = attributes.GetTaskResourceAttributes().Defaults - } - task.Template.GetContainer().Resources.Requests = assignResourcesIfUnset( - ctx, task.Template.Id, taskConfig.GetDefaults(), task.Template.GetContainer().Resources.Requests, - taskResourceSpec) - - logger.Debugf(ctx, "Assigning task resource limits for [%+v]", task.Template.Id) - if attributes != nil && attributes.GetTaskResourceAttributes() != nil { - taskResourceSpec = attributes.GetTaskResourceAttributes().Limits - } - task.Template.GetContainer().Resources.Limits = assignResourcesIfUnset( - ctx, task.Template.Id, createTaskDefaultLimits(ctx, task), task.Template.GetContainer().Resources.Limits, - taskResourceSpec) -} - -func createTaskDefaultLimits(ctx context.Context, task *core.CompiledTask) runtimeInterfaces.TaskResourceSet { - // The values below should never be used (deduce it from the request; request should be set by the time we get here). - // Setting them here just in case we end up with requests not set. We are not adding to config because it would add - // more confusion as its mostly not used. - cpuLimit := "500m" - memoryLimit := "500Mi" - resourceEntries := task.Template.GetContainer().Resources.Requests - var cpuIndex, memoryIndex = -1, -1 - for idx, entry := range resourceEntries { - switch entry.Name { - case core.Resources_CPU: - cpuIndex = idx - - case core.Resources_MEMORY: - memoryIndex = idx - } - } - - if cpuIndex < 0 || memoryIndex < 0 { - logger.Errorf(ctx, "Cpu request and Memory request missing for %s", task.Template.Id) - } - - if cpuIndex >= 0 { - cpuLimit = resourceEntries[cpuIndex].Value - } - if memoryIndex >= 0 { - memoryLimit = resourceEntries[memoryIndex].Value - } - - return runtimeInterfaces.TaskResourceSet{CPU: cpuLimit, Memory: memoryLimit} -} - func validateTaskType(taskID core.Identifier, taskType string, whitelistConfig runtime.WhitelistConfiguration) error { taskTypeWhitelist := whitelistConfig.GetTaskTypeWhitelist() if taskTypeWhitelist == nil { diff --git a/flyteadmin/pkg/manager/impl/validation/task_validator_test.go b/flyteadmin/pkg/manager/impl/validation/task_validator_test.go index 15151dc37..94b5125ec 100644 --- a/flyteadmin/pkg/manager/impl/validation/task_validator_test.go +++ b/flyteadmin/pkg/manager/impl/validation/task_validator_test.go @@ -3,14 +3,8 @@ package validation import ( "context" "errors" - "fmt" "testing" - "github.com/lyft/flyteadmin/pkg/repositories/mocks" - "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" - - "github.com/golang/protobuf/proto" - "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" "k8s.io/apimachinery/pkg/api/resource" @@ -300,7 +294,7 @@ func TestValidateTaskResources_LimitLessThanRequested(t *testing.T) { Value: "1Gi", }, }) - assert.EqualError(t, err, "Resource CPU for [name:\"name\" ] cannot set default > limit") + assert.EqualError(t, err, "Type CPU for [name:\"name\" ] cannot set default > limit") } func TestValidateTaskResources_LimitGreaterThanConfig(t *testing.T) { @@ -320,7 +314,7 @@ func TestValidateTaskResources_LimitGreaterThanConfig(t *testing.T) { Value: "1.5Gi", }, }) - assert.EqualError(t, err, "Resource CPU for [name:\"name\" ] cannot set limit > platform limit") + assert.EqualError(t, err, "Type CPU for [name:\"name\" ] cannot set limit > platform limit") } func TestValidateTaskResources_DefaultGreaterThanConfig(t *testing.T) { @@ -335,7 +329,7 @@ func TestValidateTaskResources_DefaultGreaterThanConfig(t *testing.T) { Value: "1.5Gi", }, }, []*core.Resources_ResourceEntry{}) - assert.EqualError(t, err, "Resource CPU for [name:\"name\" ] cannot set default > platform limit") + assert.EqualError(t, err, "Type CPU for [name:\"name\" ] cannot set default > platform limit") } func TestValidateTaskResources_GPULimitNotEqualToRequested(t *testing.T) { @@ -374,7 +368,7 @@ func TestValidateTaskResources_GPULimitGreaterThanConfig(t *testing.T) { Value: "2", }, }) - assert.EqualError(t, err, "Resource GPU for [name:\"name\" ] cannot set default > platform limit") + assert.EqualError(t, err, "Type GPU for [name:\"name\" ] cannot set default > platform limit") } func TestValidateTaskResources_GPUDefaultGreaterThanConfig(t *testing.T) { @@ -389,7 +383,7 @@ func TestValidateTaskResources_GPUDefaultGreaterThanConfig(t *testing.T) { Value: "2", }, }, []*core.Resources_ResourceEntry{}) - assert.EqualError(t, err, "Resource GPU for [name:\"name\" ] cannot set default > platform limit") + assert.EqualError(t, err, "Type GPU for [name:\"name\" ] cannot set default > platform limit") } func TestIsWholeNumber(t *testing.T) { @@ -416,186 +410,3 @@ func TestIsWholeNumber(t *testing.T) { "%s should not be treated as a whole number", fraction) } } - -func TestAssignResourcesIfUnset(t *testing.T) { - platformValues := runtimeInterfaces.TaskResourceSet{ - CPU: "200m", - GPU: "8", - Memory: "200Gi", - } - taskResourceSpec := &admin.TaskResourceSpec{ - Cpu: "400m", - Memory: "400Gi", - } - assignedResources := assignResourcesIfUnset(context.Background(), &core.Identifier{ - Project: "project", - Domain: "domain", - Name: "name", - Version: "version", - }, platformValues, []*core.Resources_ResourceEntry{}, taskResourceSpec) - - assert.EqualValues(t, []*core.Resources_ResourceEntry{ - { - Name: core.Resources_CPU, - Value: taskResourceSpec.Cpu, - }, - { - Name: core.Resources_MEMORY, - Value: taskResourceSpec.Memory, - }, - }, assignedResources) -} - -func TestSetDefaults(t *testing.T) { - task := &core.CompiledTask{ - Template: &core.TaskTemplate{ - Target: &core.TaskTemplate_Container{ - Container: &core.Container{ - Resources: &core.Resources{ - Requests: []*core.Resources_ResourceEntry{ - { - Name: core.Resources_CPU, - Value: "200m", - }, - }, - }, - }, - }, - Id: &core.Identifier{ - Project: "project", - Domain: "domain", - Name: "task_name", - Version: "version", - }, - }, - } - - taskConfig := runtimeMocks.MockTaskResourceConfiguration{} - taskConfig.Defaults = runtimeInterfaces.TaskResourceSet{ - CPU: "200m", - GPU: "8", - Memory: "200Gi", - } - taskConfig.Limits = runtimeInterfaces.TaskResourceSet{ - CPU: "300m", - GPU: "8", - Memory: "500Gi", - } - SetDefaults(context.Background(), &taskConfig, task, mocks.NewMockRepository(), "workflow") - assert.True(t, proto.Equal( - &core.Container{ - Resources: &core.Resources{ - Requests: []*core.Resources_ResourceEntry{ - { - Name: core.Resources_CPU, - Value: "200m", - }, - { - Name: core.Resources_MEMORY, - Value: "200Gi", - }, - }, - Limits: []*core.Resources_ResourceEntry{ - { - Name: core.Resources_CPU, - Value: "200m", - }, - { - Name: core.Resources_MEMORY, - Value: "200Gi", - }, - }, - }, - }, - task.Template.GetContainer()), fmt.Sprintf("%+v", task.Template.GetContainer())) -} - -func TestSetDefaults_MissingDefaults(t *testing.T) { - task := &core.CompiledTask{ - Template: &core.TaskTemplate{ - Target: &core.TaskTemplate_Container{ - Container: &core.Container{ - Resources: &core.Resources{ - Requests: []*core.Resources_ResourceEntry{ - { - Name: core.Resources_CPU, - Value: "200m", - }, - }, - }, - }, - }, - Id: &core.Identifier{ - Project: "project", - Domain: "domain", - Name: "task_name", - Version: "version", - }, - }, - } - - taskConfig := runtimeMocks.MockTaskResourceConfiguration{} - taskConfig.Defaults = runtimeInterfaces.TaskResourceSet{ - CPU: "200m", - GPU: "8", - Memory: "200Gi", - } - taskConfig.Limits = runtimeInterfaces.TaskResourceSet{ - CPU: "300m", - GPU: "8", - } - SetDefaults(context.Background(), &taskConfig, task, mocks.NewMockRepository(), "workflow") - assert.True(t, proto.Equal( - &core.Container{ - Resources: &core.Resources{ - Requests: []*core.Resources_ResourceEntry{ - { - Name: core.Resources_CPU, - Value: "200m", - }, - { - Name: core.Resources_MEMORY, - Value: "200Gi", - }, - }, - Limits: []*core.Resources_ResourceEntry{ - { - Name: core.Resources_CPU, - Value: "200m", - }, - { - Name: core.Resources_MEMORY, - Value: "200Gi", - }, - }, - }, - }, - task.Template.GetContainer()), fmt.Sprintf("%+v", task.Template.GetContainer())) -} - -func TestCreateTaskDefaultLimits(t *testing.T) { - task := &core.CompiledTask{ - Template: &core.TaskTemplate{ - Target: &core.TaskTemplate_Container{ - Container: &core.Container{ - Resources: &core.Resources{ - Requests: []*core.Resources_ResourceEntry{ - { - Name: core.Resources_CPU, - Value: "200m", - }, - { - Name: core.Resources_MEMORY, - Value: "200Mi", - }, - }, - }, - }, - }, - }, - } - - defaultLimits := createTaskDefaultLimits(context.Background(), task) - assert.Equal(t, "200Mi", defaultLimits.Memory) - assert.Equal(t, "200m", defaultLimits.CPU) -} diff --git a/flyteadmin/pkg/manager/impl/workflow_attributes_manager.go b/flyteadmin/pkg/manager/impl/workflow_attributes_manager.go deleted file mode 100644 index f5a285a7c..000000000 --- a/flyteadmin/pkg/manager/impl/workflow_attributes_manager.go +++ /dev/null @@ -1,79 +0,0 @@ -package impl - -import ( - "context" - - "github.com/lyft/flytestdlib/logger" - - "github.com/lyft/flyteadmin/pkg/manager/impl/validation" - "github.com/lyft/flyteadmin/pkg/repositories/transformers" - - "github.com/lyft/flyteadmin/pkg/manager/interfaces" - "github.com/lyft/flyteadmin/pkg/repositories" - "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" -) - -type WorkflowAttributesManager struct { - db repositories.RepositoryInterface -} - -func (m *WorkflowAttributesManager) UpdateWorkflowAttributes( - ctx context.Context, request admin.WorkflowAttributesUpdateRequest) ( - *admin.WorkflowAttributesUpdateResponse, error) { - var resource admin.MatchableResource - var err error - if resource, err = validation.ValidateWorkflowAttributesUpdateRequest(request); err != nil { - return nil, err - } - - model, err := transformers.ToWorkflowAttributesModel(*request.Attributes, resource) - if err != nil { - return nil, err - } - err = m.db.WorkflowAttributesRepo().CreateOrUpdate(ctx, model) - if err != nil { - return nil, err - } - - return &admin.WorkflowAttributesUpdateResponse{}, nil -} - -func (m *WorkflowAttributesManager) GetWorkflowAttributes( - ctx context.Context, request admin.WorkflowAttributesGetRequest) ( - *admin.WorkflowAttributesGetResponse, error) { - if err := validation.ValidateWorkflowAttributesGetRequest(request); err != nil { - return nil, err - } - projectAttributesModel, err := m.db.WorkflowAttributesRepo().Get( - ctx, request.Project, request.Domain, request.Workflow, request.ResourceType.String()) - if err != nil { - return nil, err - } - projectAttributes, err := transformers.FromWorkflowAttributesModel(projectAttributesModel) - if err != nil { - return nil, err - } - return &admin.WorkflowAttributesGetResponse{ - Attributes: &projectAttributes, - }, nil -} - -func (m *WorkflowAttributesManager) DeleteWorkflowAttributes(ctx context.Context, - request admin.WorkflowAttributesDeleteRequest) (*admin.WorkflowAttributesDeleteResponse, error) { - if err := validation.ValidateWorkflowAttributesDeleteRequest(request); err != nil { - return nil, err - } - if err := m.db.WorkflowAttributesRepo().Delete( - ctx, request.Project, request.Domain, request.Workflow, request.ResourceType.String()); err != nil { - return nil, err - } - logger.Infof(ctx, "Deleted workflow attributes for: %s-%s-%s (%s)", request.Project, - request.Domain, request.Workflow, request.ResourceType.String()) - return &admin.WorkflowAttributesDeleteResponse{}, nil -} - -func NewWorkflowAttributesManager(db repositories.RepositoryInterface) interfaces.WorkflowAttributesInterface { - return &WorkflowAttributesManager{ - db: db, - } -} diff --git a/flyteadmin/pkg/manager/impl/workflow_attributes_manager_test.go b/flyteadmin/pkg/manager/impl/workflow_attributes_manager_test.go deleted file mode 100644 index 12f40ad5c..000000000 --- a/flyteadmin/pkg/manager/impl/workflow_attributes_manager_test.go +++ /dev/null @@ -1,98 +0,0 @@ -package impl - -import ( - "context" - "testing" - - "github.com/golang/protobuf/proto" - "github.com/lyft/flyteadmin/pkg/manager/impl/testutils" - "github.com/lyft/flyteadmin/pkg/repositories/mocks" - "github.com/lyft/flyteadmin/pkg/repositories/models" - "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" - "github.com/stretchr/testify/assert" -) - -func TestUpdateWorkflowAttributes(t *testing.T) { - request := admin.WorkflowAttributesUpdateRequest{ - Attributes: &admin.WorkflowAttributes{ - Project: "project", - Domain: "domain", - Workflow: "workflow", - MatchingAttributes: testutils.ExecutionQueueAttributes, - }, - } - db := mocks.NewMockRepository() - expectedSerializedAttrs, _ := proto.Marshal(testutils.ExecutionQueueAttributes) - var createOrUpdateCalled bool - db.WorkflowAttributesRepo().(*mocks.MockWorkflowAttributesRepo).CreateOrUpdateFunction = func( - ctx context.Context, input models.WorkflowAttributes) error { - assert.Equal(t, "project", input.Project) - assert.Equal(t, "domain", input.Domain) - assert.Equal(t, "workflow", input.Workflow) - assert.Equal(t, admin.MatchableResource_EXECUTION_QUEUE.String(), input.Resource) - assert.EqualValues(t, expectedSerializedAttrs, input.Attributes) - createOrUpdateCalled = true - return nil - } - manager := NewWorkflowAttributesManager(db) - _, err := manager.UpdateWorkflowAttributes(context.Background(), request) - assert.Nil(t, err) - assert.True(t, createOrUpdateCalled) -} - -func TestGetWorkflowAttributes(t *testing.T) { - request := admin.WorkflowAttributesGetRequest{ - Project: "project", - Domain: "domain", - Workflow: "workflow", - ResourceType: admin.MatchableResource_EXECUTION_QUEUE, - } - db := mocks.NewMockRepository() - db.WorkflowAttributesRepo().(*mocks.MockWorkflowAttributesRepo).GetFunction = func( - ctx context.Context, project, domain, workflow, resource string) (models.WorkflowAttributes, error) { - assert.Equal(t, "project", project) - assert.Equal(t, "domain", domain) - assert.Equal(t, "workflow", workflow) - assert.Equal(t, admin.MatchableResource_EXECUTION_QUEUE.String(), resource) - expectedSerializedAttrs, _ := proto.Marshal(testutils.ExecutionQueueAttributes) - return models.WorkflowAttributes{ - Project: project, - Domain: domain, - Workflow: workflow, - Resource: resource, - Attributes: expectedSerializedAttrs, - }, nil - } - manager := NewWorkflowAttributesManager(db) - response, err := manager.GetWorkflowAttributes(context.Background(), request) - assert.Nil(t, err) - assert.True(t, proto.Equal(&admin.WorkflowAttributesGetResponse{ - Attributes: &admin.WorkflowAttributes{ - Project: "project", - Domain: "domain", - Workflow: "workflow", - MatchingAttributes: testutils.ExecutionQueueAttributes, - }, - }, response)) -} - -func TestDeleteWorkflowAttributes(t *testing.T) { - request := admin.WorkflowAttributesDeleteRequest{ - Project: "project", - Domain: "domain", - Workflow: "workflow", - ResourceType: admin.MatchableResource_EXECUTION_QUEUE, - } - db := mocks.NewMockRepository() - db.WorkflowAttributesRepo().(*mocks.MockWorkflowAttributesRepo).DeleteFunction = func( - ctx context.Context, project, domain, workflow, resource string) error { - assert.Equal(t, "project", project) - assert.Equal(t, "domain", domain) - assert.Equal(t, "workflow", workflow) - assert.Equal(t, admin.MatchableResource_EXECUTION_QUEUE.String(), resource) - return nil - } - manager := NewWorkflowAttributesManager(db) - _, err := manager.DeleteWorkflowAttributes(context.Background(), request) - assert.Nil(t, err) -} diff --git a/flyteadmin/pkg/manager/interfaces/project_attributes.go b/flyteadmin/pkg/manager/interfaces/project_attributes.go deleted file mode 100644 index 85910d699..000000000 --- a/flyteadmin/pkg/manager/interfaces/project_attributes.go +++ /dev/null @@ -1,17 +0,0 @@ -package interfaces - -import ( - "context" - - "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" -) - -// Interface for managing project-specific attributes. -type ProjectAttributesInterface interface { - UpdateProjectAttributes(ctx context.Context, request admin.ProjectAttributesUpdateRequest) ( - *admin.ProjectAttributesUpdateResponse, error) - GetProjectAttributes(ctx context.Context, request admin.ProjectAttributesGetRequest) ( - *admin.ProjectAttributesGetResponse, error) - DeleteProjectAttributes(ctx context.Context, request admin.ProjectAttributesDeleteRequest) ( - *admin.ProjectAttributesDeleteResponse, error) -} diff --git a/flyteadmin/pkg/manager/interfaces/project_domain_attributes.go b/flyteadmin/pkg/manager/interfaces/project_domain_attributes.go deleted file mode 100644 index 4253f13eb..000000000 --- a/flyteadmin/pkg/manager/interfaces/project_domain_attributes.go +++ /dev/null @@ -1,17 +0,0 @@ -package interfaces - -import ( - "context" - - "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" -) - -// Interface for managing projects and domain -specific attributes. -type ProjectDomainAttributesInterface interface { - UpdateProjectDomainAttributes(ctx context.Context, request admin.ProjectDomainAttributesUpdateRequest) ( - *admin.ProjectDomainAttributesUpdateResponse, error) - GetProjectDomainAttributes(ctx context.Context, request admin.ProjectDomainAttributesGetRequest) ( - *admin.ProjectDomainAttributesGetResponse, error) - DeleteProjectDomainAttributes(ctx context.Context, request admin.ProjectDomainAttributesDeleteRequest) ( - *admin.ProjectDomainAttributesDeleteResponse, error) -} diff --git a/flyteadmin/pkg/manager/interfaces/resource.go b/flyteadmin/pkg/manager/interfaces/resource.go new file mode 100644 index 000000000..a61fe7ab6 --- /dev/null +++ b/flyteadmin/pkg/manager/interfaces/resource.go @@ -0,0 +1,44 @@ +package interfaces + +import ( + "context" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" +) + +// Interface for managing project, domain and workflow -specific attributes. +type ResourceInterface interface { + GetResource(ctx context.Context, request ResourceRequest) (*ResourceResponse, error) + + UpdateWorkflowAttributes(ctx context.Context, request admin.WorkflowAttributesUpdateRequest) ( + *admin.WorkflowAttributesUpdateResponse, error) + GetWorkflowAttributes(ctx context.Context, request admin.WorkflowAttributesGetRequest) ( + *admin.WorkflowAttributesGetResponse, error) + DeleteWorkflowAttributes(ctx context.Context, request admin.WorkflowAttributesDeleteRequest) ( + *admin.WorkflowAttributesDeleteResponse, error) + + UpdateProjectDomainAttributes(ctx context.Context, request admin.ProjectDomainAttributesUpdateRequest) ( + *admin.ProjectDomainAttributesUpdateResponse, error) + GetProjectDomainAttributes(ctx context.Context, request admin.ProjectDomainAttributesGetRequest) ( + *admin.ProjectDomainAttributesGetResponse, error) + DeleteProjectDomainAttributes(ctx context.Context, request admin.ProjectDomainAttributesDeleteRequest) ( + *admin.ProjectDomainAttributesDeleteResponse, error) +} + +// TODO we can move this to flyteidl, once we are exposing an endpoint +type ResourceRequest struct { + Project string + Domain string + Workflow string + LaunchPlan string + ResourceType admin.MatchableResource +} + +type ResourceResponse struct { + Project string + Domain string + Workflow string + LaunchPlan string + ResourceType string + Attributes *admin.MatchingAttributes +} diff --git a/flyteadmin/pkg/manager/interfaces/workflow_attributes.go b/flyteadmin/pkg/manager/interfaces/workflow_attributes.go deleted file mode 100644 index 885ac0a9d..000000000 --- a/flyteadmin/pkg/manager/interfaces/workflow_attributes.go +++ /dev/null @@ -1,17 +0,0 @@ -package interfaces - -import ( - "context" - - "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" -) - -// Interface for managing project, domain and workflow -specific attributes. -type WorkflowAttributesInterface interface { - UpdateWorkflowAttributes(ctx context.Context, request admin.WorkflowAttributesUpdateRequest) ( - *admin.WorkflowAttributesUpdateResponse, error) - GetWorkflowAttributes(ctx context.Context, request admin.WorkflowAttributesGetRequest) ( - *admin.WorkflowAttributesGetResponse, error) - DeleteWorkflowAttributes(ctx context.Context, request admin.WorkflowAttributesDeleteRequest) ( - *admin.WorkflowAttributesDeleteResponse, error) -} diff --git a/flyteadmin/pkg/manager/mocks/project_domain.go b/flyteadmin/pkg/manager/mocks/resource.go similarity index 55% rename from flyteadmin/pkg/manager/mocks/project_domain.go rename to flyteadmin/pkg/manager/mocks/resource.go index 058687eb4..6eb54c189 100644 --- a/flyteadmin/pkg/manager/mocks/project_domain.go +++ b/flyteadmin/pkg/manager/mocks/resource.go @@ -3,6 +3,8 @@ package mocks import ( "context" + "github.com/lyft/flyteadmin/pkg/manager/interfaces" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" ) @@ -13,17 +15,36 @@ type GetProjectDomainFunc func(ctx context.Context, request admin.ProjectDomainA type DeleteProjectDomainFunc func(ctx context.Context, request admin.ProjectDomainAttributesDeleteRequest) ( *admin.ProjectDomainAttributesDeleteResponse, error) -type MockProjectDomainAttributesManager struct { +type MockResourceManager struct { updateProjectDomainFunc UpdateProjectDomainFunc GetFunc GetProjectDomainFunc DeleteFunc DeleteProjectDomainFunc } -func (m *MockProjectDomainAttributesManager) SetUpdateProjectDomainAttributes(updateProjectDomainFunc UpdateProjectDomainFunc) { +func (m *MockResourceManager) GetResource(ctx context.Context, request interfaces.ResourceRequest) (*interfaces.ResourceResponse, error) { + panic("implement me") +} + +func (m *MockResourceManager) UpdateWorkflowAttributes(ctx context.Context, request admin.WorkflowAttributesUpdateRequest) ( + *admin.WorkflowAttributesUpdateResponse, error) { + panic("implement me") +} + +func (m *MockResourceManager) GetWorkflowAttributes(ctx context.Context, request admin.WorkflowAttributesGetRequest) ( + *admin.WorkflowAttributesGetResponse, error) { + panic("implement me") +} + +func (m *MockResourceManager) DeleteWorkflowAttributes(ctx context.Context, request admin.WorkflowAttributesDeleteRequest) ( + *admin.WorkflowAttributesDeleteResponse, error) { + panic("implement me") +} + +func (m *MockResourceManager) SetUpdateProjectDomainAttributes(updateProjectDomainFunc UpdateProjectDomainFunc) { m.updateProjectDomainFunc = updateProjectDomainFunc } -func (m *MockProjectDomainAttributesManager) UpdateProjectDomainAttributes( +func (m *MockResourceManager) UpdateProjectDomainAttributes( ctx context.Context, request admin.ProjectDomainAttributesUpdateRequest) ( *admin.ProjectDomainAttributesUpdateResponse, error) { if m.updateProjectDomainFunc != nil { @@ -32,7 +53,7 @@ func (m *MockProjectDomainAttributesManager) UpdateProjectDomainAttributes( return nil, nil } -func (m *MockProjectDomainAttributesManager) GetProjectDomainAttributes( +func (m *MockResourceManager) GetProjectDomainAttributes( ctx context.Context, request admin.ProjectDomainAttributesGetRequest) ( *admin.ProjectDomainAttributesGetResponse, error) { if m.GetFunc != nil { @@ -41,7 +62,7 @@ func (m *MockProjectDomainAttributesManager) GetProjectDomainAttributes( return nil, nil } -func (m *MockProjectDomainAttributesManager) DeleteProjectDomainAttributes( +func (m *MockResourceManager) DeleteProjectDomainAttributes( ctx context.Context, request admin.ProjectDomainAttributesDeleteRequest) ( *admin.ProjectDomainAttributesDeleteResponse, error) { if m.DeleteFunc != nil { diff --git a/flyteadmin/pkg/repositories/config/migrations.go b/flyteadmin/pkg/repositories/config/migrations.go index b761680a2..1ca585767 100644 --- a/flyteadmin/pkg/repositories/config/migrations.go +++ b/flyteadmin/pkg/repositories/config/migrations.go @@ -151,33 +151,12 @@ var Migrations = []*gormigrate.Migration{ }, // Add ProjectAttributes with custom resource attributes. { - ID: "2019-12-30-project-attributes", + ID: "2020-01-10-resource", Migrate: func(tx *gorm.DB) error { - return tx.AutoMigrate(&models.ProjectAttributes{}).Error + return tx.AutoMigrate(&models.Resource{}).Error }, Rollback: func(tx *gorm.DB) error { - return tx.DropTable("project_attributes").Error - }, - }, - - // Add ProjectDomainAttributes with custom resource attributes. - { - ID: "2019-12-30-project-domain-attributes", - Migrate: func(tx *gorm.DB) error { - return tx.AutoMigrate(&models.ProjectDomainAttributes{}).Error - }, - Rollback: func(tx *gorm.DB) error { - return tx.DropTable("project_domain_attributes").Error - }, - }, - // Add WorkflowAttributes with custom resource attributes. - { - ID: "2019-12-30-workflow-attributes", - Migrate: func(tx *gorm.DB) error { - return tx.AutoMigrate(&models.WorkflowAttributes{}).Error - }, - Rollback: func(tx *gorm.DB) error { - return tx.DropTable("workflow_attributes").Error + return tx.DropTable("resources").Error }, }, } diff --git a/flyteadmin/pkg/repositories/factory.go b/flyteadmin/pkg/repositories/factory.go index 720b77b43..3e14ffc39 100644 --- a/flyteadmin/pkg/repositories/factory.go +++ b/flyteadmin/pkg/repositories/factory.go @@ -28,9 +28,7 @@ type RepositoryInterface interface { LaunchPlanRepo() interfaces.LaunchPlanRepoInterface ExecutionRepo() interfaces.ExecutionRepoInterface ProjectRepo() interfaces.ProjectRepoInterface - ProjectAttributesRepo() interfaces.ProjectAttributesRepoInterface - ProjectDomainAttributesRepo() interfaces.ProjectDomainAttributesRepoInterface - WorkflowAttributesRepo() interfaces.WorkflowAttributesRepoInterface + ResourceRepo() interfaces.ResourceRepoInterface NodeExecutionRepo() interfaces.NodeExecutionRepoInterface TaskExecutionRepo() interfaces.TaskExecutionRepoInterface NamedEntityRepo() interfaces.NamedEntityRepoInterface diff --git a/flyteadmin/pkg/repositories/gormimpl/project_attributes_repo.go b/flyteadmin/pkg/repositories/gormimpl/project_attributes_repo.go deleted file mode 100644 index ce3d5c670..000000000 --- a/flyteadmin/pkg/repositories/gormimpl/project_attributes_repo.go +++ /dev/null @@ -1,88 +0,0 @@ -package gormimpl - -import ( - "context" - - "github.com/jinzhu/gorm" - "github.com/lyft/flyteadmin/pkg/repositories/errors" - "github.com/lyft/flyteadmin/pkg/repositories/interfaces" - "github.com/lyft/flyteadmin/pkg/repositories/models" - "github.com/lyft/flytestdlib/promutils" - "google.golang.org/grpc/codes" - - flyteAdminErrors "github.com/lyft/flyteadmin/pkg/errors" -) - -type ProjectAttributesRepo struct { - db *gorm.DB - errorTransformer errors.ErrorTransformer - metrics gormMetrics -} - -func (r *ProjectAttributesRepo) CreateOrUpdate(ctx context.Context, input models.ProjectAttributes) error { - timer := r.metrics.GetDuration.Start() - var record models.ProjectAttributes - tx := r.db.FirstOrCreate(&record, models.ProjectAttributes{ - Project: input.Project, - Resource: input.Resource, - }) - timer.Stop() - if tx.Error != nil { - return r.errorTransformer.ToFlyteAdminError(tx.Error) - } - - timer = r.metrics.UpdateDuration.Start() - record.Attributes = input.Attributes - tx = r.db.Save(&record) - timer.Stop() - if tx.Error != nil { - return r.errorTransformer.ToFlyteAdminError(tx.Error) - } - return nil -} - -func (r *ProjectAttributesRepo) Get(ctx context.Context, project, resource string) (models.ProjectAttributes, error) { - var model models.ProjectAttributes - timer := r.metrics.GetDuration.Start() - tx := r.db.Where(&models.ProjectAttributes{ - Project: project, - Resource: resource, - }).First(&model) - timer.Stop() - if tx.Error != nil { - return models.ProjectAttributes{}, r.errorTransformer.ToFlyteAdminError(tx.Error) - } - if tx.RecordNotFound() { - return models.ProjectAttributes{}, flyteAdminErrors.NewFlyteAdminErrorf(codes.NotFound, - "project [%s] not found", project) - } - return model, nil -} - -func (r *ProjectAttributesRepo) Delete(ctx context.Context, project, resource string) error { - var tx *gorm.DB - r.metrics.DeleteDuration.Time(func() { - tx = r.db.Where(&models.ProjectAttributes{ - Project: project, - Resource: resource, - }).Unscoped().Delete(models.ProjectAttributes{}) - }) - if tx.Error != nil { - return r.errorTransformer.ToFlyteAdminError(tx.Error) - } - if tx.RecordNotFound() { - return flyteAdminErrors.NewFlyteAdminErrorf(codes.NotFound, - "project [%s] not found", project) - } - return nil -} - -func NewProjectAttributesRepo(db *gorm.DB, errorTransformer errors.ErrorTransformer, - scope promutils.Scope) interfaces.ProjectAttributesRepoInterface { - metrics := newMetrics(scope) - return &ProjectAttributesRepo{ - db: db, - errorTransformer: errorTransformer, - metrics: metrics, - } -} diff --git a/flyteadmin/pkg/repositories/gormimpl/project_attributes_repo_test.go b/flyteadmin/pkg/repositories/gormimpl/project_attributes_repo_test.go deleted file mode 100644 index 85b0740dd..000000000 --- a/flyteadmin/pkg/repositories/gormimpl/project_attributes_repo_test.go +++ /dev/null @@ -1,71 +0,0 @@ -package gormimpl - -import ( - "context" - "testing" - - mocket "github.com/Selvatico/go-mocket" - "github.com/lyft/flyteadmin/pkg/repositories/errors" - "github.com/lyft/flyteadmin/pkg/repositories/models" - mockScope "github.com/lyft/flytestdlib/promutils" - "github.com/stretchr/testify/assert" -) - -const testProjectAttr = "project" -const testResourceAttr = "resource" - -func TestCreateProjectAttributes(t *testing.T) { - projectRepo := NewProjectAttributesRepo(GetDbForTest(t), errors.NewTestErrorTransformer(), mockScope.NewTestScope()) - GlobalMock := mocket.Catcher.Reset() - - query := GlobalMock.NewMock() - query.WithQuery( - `INSERT INTO "project_attributes" ` + - `("created_at","updated_at","deleted_at","project","resource","attributes") VALUES (?,?,?,?,?,?)`) - - err := projectRepo.CreateOrUpdate(context.Background(), models.ProjectAttributes{ - Project: testProjectAttr, - Resource: testResourceAttr, - Attributes: []byte("attrs"), - }) - assert.NoError(t, err) - assert.True(t, query.Triggered) -} - -func TestGetProjectAttributes(t *testing.T) { - projectRepo := NewProjectAttributesRepo(GetDbForTest(t), errors.NewTestErrorTransformer(), mockScope.NewTestScope()) - GlobalMock := mocket.Catcher.Reset() - - response := make(map[string]interface{}) - response["project"] = testProjectAttr - response["resource"] = testResourceAttr - response["attributes"] = []byte("attrs") - - query := GlobalMock.NewMock() - query.WithQuery(`SELECT * FROM "project_attributes" WHERE "project_attributes"."deleted_at" ` + - `IS NULL AND (("project_attributes"."project" = project) AND ("project_attributes"."resource" = resource)) ` + - `ORDER BY "project_attributes"."id" ASC LIMIT 1`).WithReply( - []map[string]interface{}{ - response, - }) - - output, err := projectRepo.Get(context.Background(), "project", "resource") - assert.Nil(t, err) - assert.Equal(t, testProjectAttr, output.Project) - assert.Equal(t, testResourceAttr, output.Resource) - assert.Equal(t, []byte("attrs"), output.Attributes) -} - -func TestDeleteProjectAttributes(t *testing.T) { - projectRepo := NewProjectAttributesRepo(GetDbForTest(t), errors.NewTestErrorTransformer(), mockScope.NewTestScope()) - GlobalMock := mocket.Catcher.Reset() - - query := GlobalMock.NewMock() - fakeResponse := query.WithQuery( - `DELETE FROM "project_attributes" WHERE ("project_attributes"."project" = ?) AND ` + - `("project_attributes"."resource" = ?)`) - - err := projectRepo.Delete(context.Background(), "project", "resource") - assert.Nil(t, err) - assert.True(t, fakeResponse.Triggered) -} diff --git a/flyteadmin/pkg/repositories/gormimpl/project_domain_attributes_repo.go b/flyteadmin/pkg/repositories/gormimpl/project_domain_attributes_repo.go deleted file mode 100644 index 7d8fe7eaa..000000000 --- a/flyteadmin/pkg/repositories/gormimpl/project_domain_attributes_repo.go +++ /dev/null @@ -1,92 +0,0 @@ -package gormimpl - -import ( - "context" - - "github.com/jinzhu/gorm" - "github.com/lyft/flyteadmin/pkg/repositories/errors" - "github.com/lyft/flyteadmin/pkg/repositories/interfaces" - "github.com/lyft/flyteadmin/pkg/repositories/models" - "github.com/lyft/flytestdlib/promutils" - "google.golang.org/grpc/codes" - - flyteAdminErrors "github.com/lyft/flyteadmin/pkg/errors" -) - -type ProjectDomainAttributesRepo struct { - db *gorm.DB - errorTransformer errors.ErrorTransformer - metrics gormMetrics -} - -func (r *ProjectDomainAttributesRepo) CreateOrUpdate(ctx context.Context, input models.ProjectDomainAttributes) error { - timer := r.metrics.GetDuration.Start() - var record models.ProjectDomainAttributes - tx := r.db.FirstOrCreate(&record, models.ProjectDomainAttributes{ - Project: input.Project, - Domain: input.Domain, - Resource: input.Resource, - }) - timer.Stop() - if tx.Error != nil { - return r.errorTransformer.ToFlyteAdminError(tx.Error) - } - - timer = r.metrics.UpdateDuration.Start() - record.Attributes = input.Attributes - tx = r.db.Save(&record) - timer.Stop() - if tx.Error != nil { - return r.errorTransformer.ToFlyteAdminError(tx.Error) - } - return nil -} - -func (r *ProjectDomainAttributesRepo) Get(ctx context.Context, project, domain, resource string) ( - models.ProjectDomainAttributes, error) { - var model models.ProjectDomainAttributes - timer := r.metrics.GetDuration.Start() - tx := r.db.Where(&models.ProjectDomainAttributes{ - Project: project, - Domain: domain, - Resource: resource, - }).First(&model) - timer.Stop() - if tx.Error != nil { - return models.ProjectDomainAttributes{}, r.errorTransformer.ToFlyteAdminError(tx.Error) - } - if tx.RecordNotFound() { - return models.ProjectDomainAttributes{}, flyteAdminErrors.NewFlyteAdminErrorf(codes.NotFound, - "project-domain [%s-%s] not found", project, domain) - } - return model, nil -} - -func (r *ProjectDomainAttributesRepo) Delete(ctx context.Context, project, domain, resource string) error { - var tx *gorm.DB - r.metrics.DeleteDuration.Time(func() { - tx = r.db.Where(&models.ProjectDomainAttributes{ - Project: project, - Domain: domain, - Resource: resource, - }).Unscoped().Delete(models.ProjectDomainAttributes{}) - }) - if tx.Error != nil { - return r.errorTransformer.ToFlyteAdminError(tx.Error) - } - if tx.RecordNotFound() { - return flyteAdminErrors.NewFlyteAdminErrorf(codes.NotFound, - "project-domain [%s-%s] not found", project, domain) - } - return nil -} - -func NewProjectDomainAttributesRepo(db *gorm.DB, errorTransformer errors.ErrorTransformer, - scope promutils.Scope) interfaces.ProjectDomainAttributesRepoInterface { - metrics := newMetrics(scope) - return &ProjectDomainAttributesRepo{ - db: db, - errorTransformer: errorTransformer, - metrics: metrics, - } -} diff --git a/flyteadmin/pkg/repositories/gormimpl/project_domain_attributes_repo_test.go b/flyteadmin/pkg/repositories/gormimpl/project_domain_attributes_repo_test.go deleted file mode 100644 index b51d9d8e4..000000000 --- a/flyteadmin/pkg/repositories/gormimpl/project_domain_attributes_repo_test.go +++ /dev/null @@ -1,72 +0,0 @@ -package gormimpl - -import ( - "context" - "testing" - - mocket "github.com/Selvatico/go-mocket" - "github.com/lyft/flyteadmin/pkg/repositories/errors" - "github.com/lyft/flyteadmin/pkg/repositories/models" - mockScope "github.com/lyft/flytestdlib/promutils" - "github.com/stretchr/testify/assert" -) - -func TestCreateProjectDomainAttributes(t *testing.T) { - projectDomainRepo := NewProjectDomainAttributesRepo(GetDbForTest(t), errors.NewTestErrorTransformer(), mockScope.NewTestScope()) - GlobalMock := mocket.Catcher.Reset() - - query := GlobalMock.NewMock() - query.WithQuery( - `INSERT INTO "project_domain_attributes" ` + - `("created_at","updated_at","deleted_at","project","domain","resource","attributes") VALUES (?,?,?,?,?,?,?)`) - - err := projectDomainRepo.CreateOrUpdate(context.Background(), models.ProjectDomainAttributes{ - Project: "project", - Domain: "domain", - Resource: "resource", - Attributes: []byte("attrs"), - }) - assert.NoError(t, err) - assert.True(t, query.Triggered) -} - -func TestGetProjectDomainAttributes(t *testing.T) { - projectDomainRepo := NewProjectDomainAttributesRepo(GetDbForTest(t), errors.NewTestErrorTransformer(), mockScope.NewTestScope()) - GlobalMock := mocket.Catcher.Reset() - - response := make(map[string]interface{}) - response["project"] = "project" - response["domain"] = "domain" - response["resource"] = "resource" - response["attributes"] = []byte("attrs") - - query := GlobalMock.NewMock() - query.WithQuery(`SELECT * FROM "project_domain_attributes" WHERE "project_domain_attributes"."deleted_at" ` + - `IS NULL AND (("project_domain_attributes"."project" = project) AND ("project_domain_attributes"."domain" = ` + - `domain) AND ("project_domain_attributes"."resource" = resource)) ORDER BY "project_domain_attributes"."id" ` + - `ASC LIMIT 1`).WithReply( - []map[string]interface{}{ - response, - }) - - output, err := projectDomainRepo.Get(context.Background(), "project", "domain", "resource") - assert.Nil(t, err) - assert.Equal(t, "project", output.Project) - assert.Equal(t, "domain", output.Domain) - assert.Equal(t, "resource", output.Resource) - assert.Equal(t, []byte("attrs"), output.Attributes) -} - -func TestDeleteProjectDomainAttributes(t *testing.T) { - projectDomainRepo := NewProjectDomainAttributesRepo(GetDbForTest(t), errors.NewTestErrorTransformer(), mockScope.NewTestScope()) - GlobalMock := mocket.Catcher.Reset() - - query := GlobalMock.NewMock() - fakeResponse := query.WithQuery( - `DELETE FROM "project_domain_attributes" WHERE ("project_domain_attributes"."project" = ?) AND ` + - `("project_domain_attributes"."domain" = ?) AND ("project_domain_attributes"."resource" = ?)`) - - err := projectDomainRepo.Delete(context.Background(), "project", "domain", "resource") - assert.Nil(t, err) - assert.True(t, fakeResponse.Triggered) -} diff --git a/flyteadmin/pkg/repositories/gormimpl/resource_repo.go b/flyteadmin/pkg/repositories/gormimpl/resource_repo.go new file mode 100644 index 000000000..57bfb7c52 --- /dev/null +++ b/flyteadmin/pkg/repositories/gormimpl/resource_repo.go @@ -0,0 +1,171 @@ +package gormimpl + +import ( + "context" + "fmt" + + "github.com/jinzhu/gorm" + "github.com/lyft/flyteadmin/pkg/repositories/errors" + "github.com/lyft/flyteadmin/pkg/repositories/interfaces" + "github.com/lyft/flyteadmin/pkg/repositories/models" + "github.com/lyft/flytestdlib/promutils" + "google.golang.org/grpc/codes" + + flyteAdminErrors "github.com/lyft/flyteadmin/pkg/errors" +) + +type ResourceRepo struct { + db *gorm.DB + errorTransformer errors.ErrorTransformer + metrics gormMetrics +} + +/* + The data in the Resource repo maps to the following rules: + * Domain and ResourceType can never be empty. + * Empty string can be interpreted as all. Example: "" for Project field can be interpreted as all Projects for a domain. + * One cannot provide specific value for Project, unless a specific value for Domain is provided. + ** Project is always scoped within a domain. + ** Example: Domain="" Project="Lyft" is invalid. + * One cannot provide specific value for Workflow, unless a specific value for Domain and Project is provided. + ** Workflow is always scoped within a domain and project. + ** Example: Domain="staging" Project="" Workflow="W1" is invalid. + * One cannot provide specific value for Launch plan, unless a specific value for Domain, Project and Workflow is provided. + ** Launch plan is always scoped within a domain, project and workflow. + ** Example: Domain="staging" Project="Lyft" Workflow="" LaunchPlan= "l1" is invalid. +*/ +func validateCreateOrUpdateResourceInput(project, domain, workflow, launchPlan, resourceType string) bool { + if domain == "" || resourceType == "" { + return false + } + if project == "" && (workflow != "" || launchPlan != "") { + return false + } + if workflow == "" && launchPlan != "" { + return false + } + return true +} + +func (r *ResourceRepo) CreateOrUpdate(ctx context.Context, input models.Resource) error { + if !validateCreateOrUpdateResourceInput(input.Project, input.Domain, input.Workflow, input.LaunchPlan, input.ResourceType) { + return errors.GetInvalidInputError(fmt.Sprintf("%v", input)) + } + if input.Priority == 0 { + return errors.GetInvalidInputError(fmt.Sprintf("invalid priority %v", input)) + } + timer := r.metrics.GetDuration.Start() + var record models.Resource + tx := r.db.FirstOrCreate(&record, models.Resource{ + Project: input.Project, + Domain: input.Domain, + Workflow: input.Workflow, + LaunchPlan: input.LaunchPlan, + ResourceType: input.ResourceType, + Priority: input.Priority, + }) + timer.Stop() + if tx.Error != nil { + return r.errorTransformer.ToFlyteAdminError(tx.Error) + } + + timer = r.metrics.UpdateDuration.Start() + record.Attributes = input.Attributes + tx = r.db.Save(&record) + timer.Stop() + if tx.Error != nil { + return r.errorTransformer.ToFlyteAdminError(tx.Error) + } + return nil +} + +func (r *ResourceRepo) Get(ctx context.Context, ID interfaces.ResourceID) (models.Resource, error) { + if !validateCreateOrUpdateResourceInput(ID.Project, ID.Domain, ID.Workflow, ID.LaunchPlan, ID.ResourceType) { + return models.Resource{}, r.errorTransformer.ToFlyteAdminError(errors.GetInvalidInputError(fmt.Sprintf("%v", ID))) + } + var resources []models.Resource + timer := r.metrics.GetDuration.Start() + + txWhereClause := "resource_type = ? AND domain = ? AND project IN (?) AND workflow IN (?) AND launch_plan IN (?)" + project := []string{""} + if ID.Project != "" { + project = append(project, ID.Project) + } + + workflow := []string{""} + if ID.Workflow != "" { + workflow = append(workflow, ID.Workflow) + } + + launchPlan := []string{""} + if ID.LaunchPlan != "" { + launchPlan = append(launchPlan, ID.LaunchPlan) + } + + tx := r.db.Where(txWhereClause, ID.ResourceType, ID.Domain, project, workflow, launchPlan) + tx.Order("priority desc").First(&resources) + timer.Stop() + + if tx.Error != nil { + return models.Resource{}, r.errorTransformer.ToFlyteAdminError(tx.Error) + } + if tx.RecordNotFound() || len(resources) == 0 { + return models.Resource{}, flyteAdminErrors.NewFlyteAdminErrorf(codes.NotFound, + "%+v", ID) + } + return resources[0], nil +} + +func (r *ResourceRepo) GetRaw(ctx context.Context, ID interfaces.ResourceID) (models.Resource, error) { + if ID.Domain == "" || ID.ResourceType == "" { + return models.Resource{}, r.errorTransformer.ToFlyteAdminError(errors.GetInvalidInputError(fmt.Sprintf("%v", ID))) + } + var model models.Resource + timer := r.metrics.GetDuration.Start() + tx := r.db.Where(&models.Resource{ + Project: ID.Project, + Domain: ID.Domain, + Workflow: ID.Workflow, + LaunchPlan: ID.LaunchPlan, + ResourceType: ID.ResourceType, + }).First(&model) + timer.Stop() + if tx.Error != nil { + return models.Resource{}, r.errorTransformer.ToFlyteAdminError(tx.Error) + } + if tx.RecordNotFound() { + return models.Resource{}, flyteAdminErrors.NewFlyteAdminErrorf(codes.NotFound, + "%v", ID) + } + return model, nil +} + +func (r *ResourceRepo) Delete(ctx context.Context, ID interfaces.ResourceID) error { + var tx *gorm.DB + r.metrics.DeleteDuration.Time(func() { + tx = r.db.Where(&models.Resource{ + Project: ID.Project, + Domain: ID.Domain, + Workflow: ID.Workflow, + LaunchPlan: ID.LaunchPlan, + ResourceType: ID.ResourceType, + }).Unscoped().Delete(models.Resource{}) + }) + if tx.Error != nil { + return r.errorTransformer.ToFlyteAdminError(tx.Error) + } + if tx.RecordNotFound() { + return flyteAdminErrors.NewFlyteAdminErrorf(codes.NotFound, "%v", ID) + } + return nil +} + +func NewResourceRepo(db *gorm.DB, errorTransformer errors.ErrorTransformer, + scope promutils.Scope) interfaces.ResourceRepoInterface { + metrics := newMetrics(scope) + return &ResourceRepo{ + db: db, + errorTransformer: errorTransformer, + metrics: metrics, + } +} diff --git a/flyteadmin/pkg/repositories/gormimpl/resource_repo_test.go b/flyteadmin/pkg/repositories/gormimpl/resource_repo_test.go new file mode 100644 index 000000000..d7fae4fa0 --- /dev/null +++ b/flyteadmin/pkg/repositories/gormimpl/resource_repo_test.go @@ -0,0 +1,137 @@ +package gormimpl + +import ( + "context" + "testing" + + "github.com/lyft/flyteadmin/pkg/repositories/interfaces" + + mocket "github.com/Selvatico/go-mocket" + "github.com/lyft/flyteadmin/pkg/repositories/errors" + "github.com/lyft/flyteadmin/pkg/repositories/models" + mockScope "github.com/lyft/flytestdlib/promutils" + "github.com/stretchr/testify/assert" +) + +func TestCreateWorkflowAttributes(t *testing.T) { + resourceRepo := NewResourceRepo(GetDbForTest(t), errors.NewTestErrorTransformer(), mockScope.NewTestScope()) + GlobalMock := mocket.Catcher.Reset() + + query := GlobalMock.NewMock() + query.WithQuery( + `INSERT INTO "resources" ("created_at","updated_at","deleted_at","project","domain",` + + `"workflow","launch_plan","resource_type","priority","attributes") VALUES (?,?,?,?,?,?,?,?,?,?)`) + + err := resourceRepo.CreateOrUpdate(context.Background(), models.Resource{ + Project: "project", + Domain: "domain", + Workflow: "workflow", + ResourceType: "resource", + Priority: models.ResourcePriorityLaunchPlanLevel, + Attributes: []byte("attrs"), + }) + assert.NoError(t, err) + assert.True(t, query.Triggered) +} + +func TestGetWorkflowAttributes(t *testing.T) { + resourceRepo := NewResourceRepo(GetDbForTest(t), errors.NewTestErrorTransformer(), mockScope.NewTestScope()) + GlobalMock := mocket.Catcher.Reset() + + response := make(map[string]interface{}) + response["project"] = "project" + response["domain"] = "domain" + response["workflow"] = "workflow" + response["resource_type"] = "resource-type" + response["attributes"] = []byte("attrs") + + query := GlobalMock.NewMock() + query.WithQuery(`SELECT * FROM "resources" WHERE "resources"."deleted_at" IS NULL AND` + + ` ((resource_type = resource AND domain = domain AND project IN (,project)` + + ` AND workflow IN (,workflow) AND launch_plan IN ())) ORDER BY` + + ` priority desc,"resources"."id" ASC LIMIT 1`).WithReply( + []map[string]interface{}{ + response, + }) + + output, err := resourceRepo.Get(context.Background(), interfaces.ResourceID{Project: "project", Domain: "domain", Workflow: "workflow", ResourceType: "resource"}) + assert.Nil(t, err) + assert.Equal(t, "project", output.Project) + assert.Equal(t, "domain", output.Domain) + assert.Equal(t, "workflow", output.Workflow) + assert.Equal(t, "resource-type", output.ResourceType) + assert.Equal(t, []byte("attrs"), output.Attributes) +} + +func TestProjectDomainAttributes(t *testing.T) { + resourceRepo := NewResourceRepo(GetDbForTest(t), errors.NewTestErrorTransformer(), mockScope.NewTestScope()) + GlobalMock := mocket.Catcher.Reset() + + response := make(map[string]interface{}) + response[project] = project + response[domain] = domain + response["resource_type"] = "resource-type" + response["attributes"] = []byte("attrs") + + query := GlobalMock.NewMock() + query.WithQuery(`SELECT * FROM "resources" WHERE "resources"."deleted_at" IS NULL AND` + + ` ((resource_type = resource AND domain = domain AND project IN (,project)` + + ` AND workflow IN () AND launch_plan IN ())) ORDER BY` + + ` priority desc,"resources"."id" ASC LIMIT 1`).WithReply( + []map[string]interface{}{ + response, + }) + + output, err := resourceRepo.Get(context.Background(), interfaces.ResourceID{Project: "project", Domain: "domain", ResourceType: "resource"}) + assert.Nil(t, err) + assert.Equal(t, project, output.Project) + assert.Equal(t, domain, output.Domain) + assert.Equal(t, "", output.Workflow) + assert.Equal(t, "resource-type", output.ResourceType) + assert.Equal(t, []byte("attrs"), output.Attributes) +} + +func TestGetRawWorkflowAttributes(t *testing.T) { + resourceRepo := NewResourceRepo(GetDbForTest(t), errors.NewTestErrorTransformer(), mockScope.NewTestScope()) + GlobalMock := mocket.Catcher.Reset() + + response := make(map[string]interface{}) + response[project] = project + response[domain] = domain + response["workflow"] = "workflow" + response["resource_type"] = "resource" + response["launch_plan"] = "launch_plan" + response["attributes"] = []byte("attrs") + + query := GlobalMock.NewMock() + query.WithQuery(`SELECT * FROM "resources" WHERE "resources"."deleted_at" IS NULL AND (("resources"."project" = project) AND ` + + `("resources"."domain" = domain) AND ("resources"."workflow" = workflow) AND ` + + `("resources"."launch_plan" = launch_plan) AND ("resources"."resource_type" = resource)) ORDER BY "resources"."id" ASC LIMIT 1`).WithReply( + []map[string]interface{}{ + response, + }) + + output, err := resourceRepo.GetRaw(context.Background(), interfaces.ResourceID{Project: "project", Domain: "domain", Workflow: "workflow", LaunchPlan: "launch_plan", ResourceType: "resource"}) + assert.Nil(t, err) + assert.Equal(t, project, output.Project) + assert.Equal(t, domain, output.Domain) + assert.Equal(t, "workflow", output.Workflow) + assert.Equal(t, "launch_plan", output.LaunchPlan) + assert.Equal(t, "resource", output.ResourceType) + assert.Equal(t, []byte("attrs"), output.Attributes) +} + +func TestDeleteWorkflowAttributes(t *testing.T) { + resourceRepo := NewResourceRepo(GetDbForTest(t), errors.NewTestErrorTransformer(), mockScope.NewTestScope()) + GlobalMock := mocket.Catcher.Reset() + + query := GlobalMock.NewMock() + fakeResponse := query.WithQuery( + `DELETE FROM "resources" WHERE ("resources"."project" = ?) AND ` + + `("resources"."domain" = ?) AND ("resources"."workflow" = ?) AND ` + + `("resources"."launch_plan" = ?) AND ("resources"."resource_type" = ?)`) + + err := resourceRepo.Delete(context.Background(), interfaces.ResourceID{Project: "project", Domain: "domain", Workflow: "workflow", LaunchPlan: "launch_plan", ResourceType: "resource"}) + assert.Nil(t, err) + assert.True(t, fakeResponse.Triggered) +} diff --git a/flyteadmin/pkg/repositories/gormimpl/workflow_attributes_repo.go b/flyteadmin/pkg/repositories/gormimpl/workflow_attributes_repo.go deleted file mode 100644 index 1e32c228d..000000000 --- a/flyteadmin/pkg/repositories/gormimpl/workflow_attributes_repo.go +++ /dev/null @@ -1,95 +0,0 @@ -package gormimpl - -import ( - "context" - - "github.com/jinzhu/gorm" - "github.com/lyft/flyteadmin/pkg/repositories/errors" - "github.com/lyft/flyteadmin/pkg/repositories/interfaces" - "github.com/lyft/flyteadmin/pkg/repositories/models" - "github.com/lyft/flytestdlib/promutils" - "google.golang.org/grpc/codes" - - flyteAdminErrors "github.com/lyft/flyteadmin/pkg/errors" -) - -type WorkflowAttributesRepo struct { - db *gorm.DB - errorTransformer errors.ErrorTransformer - metrics gormMetrics -} - -func (r *WorkflowAttributesRepo) CreateOrUpdate(ctx context.Context, input models.WorkflowAttributes) error { - timer := r.metrics.GetDuration.Start() - var record models.WorkflowAttributes - tx := r.db.FirstOrCreate(&record, models.WorkflowAttributes{ - Project: input.Project, - Domain: input.Domain, - Workflow: input.Workflow, - Resource: input.Resource, - }) - timer.Stop() - if tx.Error != nil { - return r.errorTransformer.ToFlyteAdminError(tx.Error) - } - - timer = r.metrics.UpdateDuration.Start() - record.Attributes = input.Attributes - tx = r.db.Save(&record) - timer.Stop() - if tx.Error != nil { - return r.errorTransformer.ToFlyteAdminError(tx.Error) - } - return nil -} - -func (r *WorkflowAttributesRepo) Get(ctx context.Context, project, domain, workflow, resource string) ( - models.WorkflowAttributes, error) { - var model models.WorkflowAttributes - timer := r.metrics.GetDuration.Start() - tx := r.db.Where(&models.WorkflowAttributes{ - Project: project, - Domain: domain, - Workflow: workflow, - Resource: resource, - }).First(&model) - timer.Stop() - if tx.Error != nil { - return models.WorkflowAttributes{}, r.errorTransformer.ToFlyteAdminError(tx.Error) - } - if tx.RecordNotFound() { - return models.WorkflowAttributes{}, flyteAdminErrors.NewFlyteAdminErrorf(codes.NotFound, - "project-domain-workflow [%s-%s-%s] not found", project, domain, workflow) - } - return model, nil -} - -func (r *WorkflowAttributesRepo) Delete(ctx context.Context, project, domain, workflow, resource string) error { - var tx *gorm.DB - r.metrics.DeleteDuration.Time(func() { - tx = r.db.Where(&models.WorkflowAttributes{ - Project: project, - Domain: domain, - Workflow: workflow, - Resource: resource, - }).Unscoped().Delete(models.WorkflowAttributes{}) - }) - if tx.Error != nil { - return r.errorTransformer.ToFlyteAdminError(tx.Error) - } - if tx.RecordNotFound() { - return flyteAdminErrors.NewFlyteAdminErrorf(codes.NotFound, - "project-domain-workflow [%s-%s-%s] not found", project, domain, workflow) - } - return nil -} - -func NewWorkflowAttributesRepo(db *gorm.DB, errorTransformer errors.ErrorTransformer, - scope promutils.Scope) interfaces.WorkflowAttributesRepoInterface { - metrics := newMetrics(scope) - return &WorkflowAttributesRepo{ - db: db, - errorTransformer: errorTransformer, - metrics: metrics, - } -} diff --git a/flyteadmin/pkg/repositories/gormimpl/workflow_attributes_repo_test.go b/flyteadmin/pkg/repositories/gormimpl/workflow_attributes_repo_test.go deleted file mode 100644 index 577792ddc..000000000 --- a/flyteadmin/pkg/repositories/gormimpl/workflow_attributes_repo_test.go +++ /dev/null @@ -1,77 +0,0 @@ -package gormimpl - -import ( - "context" - "testing" - - mocket "github.com/Selvatico/go-mocket" - "github.com/lyft/flyteadmin/pkg/repositories/errors" - "github.com/lyft/flyteadmin/pkg/repositories/models" - mockScope "github.com/lyft/flytestdlib/promutils" - "github.com/stretchr/testify/assert" -) - -func TestCreateWorkflowAttributes(t *testing.T) { - workflowRepo := NewWorkflowAttributesRepo(GetDbForTest(t), errors.NewTestErrorTransformer(), mockScope.NewTestScope()) - GlobalMock := mocket.Catcher.Reset() - GlobalMock.Logging = true - - query := GlobalMock.NewMock() - query.WithQuery( - `INSERT INTO "workflow_attributes" ("id","created_at","updated_at","deleted_at","project","domain",` + - `"workflow","resource","attributes") VALUES (?,?,?,?,?,?,?,?,?)`) - - err := workflowRepo.CreateOrUpdate(context.Background(), models.WorkflowAttributes{ - Project: "project", - Domain: "domain", - Workflow: "workflow", - Resource: "resource", - Attributes: []byte("attrs"), - }) - assert.NoError(t, err) - assert.True(t, query.Triggered) -} - -func TestGetWorkflowAttributes(t *testing.T) { - workflowRepo := NewWorkflowAttributesRepo(GetDbForTest(t), errors.NewTestErrorTransformer(), mockScope.NewTestScope()) - GlobalMock := mocket.Catcher.Reset() - - response := make(map[string]interface{}) - response["project"] = "project" - response["domain"] = "domain" - response["workflow"] = "workflow" - response["resource"] = "resource" - response["attributes"] = []byte("attrs") - - query := GlobalMock.NewMock() - query.WithQuery(`SELECT * FROM "workflow_attributes" WHERE "workflow_attributes"."deleted_at" IS NULL AND` + - ` (("workflow_attributes"."project" = project) AND ("workflow_attributes"."domain" = domain) AND ` + - `("workflow_attributes"."workflow" = workflow) AND ("workflow_attributes"."resource" = resource)) ORDER BY ` + - `"workflow_attributes"."id" ASC LIMIT 1`).WithReply( - []map[string]interface{}{ - response, - }) - - output, err := workflowRepo.Get(context.Background(), "project", "domain", "workflow", "resource") - assert.Nil(t, err) - assert.Equal(t, "project", output.Project) - assert.Equal(t, "domain", output.Domain) - assert.Equal(t, "workflow", output.Workflow) - assert.Equal(t, "resource", output.Resource) - assert.Equal(t, []byte("attrs"), output.Attributes) -} - -func TestDeleteWorkflowAttributes(t *testing.T) { - workflowRepo := NewWorkflowAttributesRepo(GetDbForTest(t), errors.NewTestErrorTransformer(), mockScope.NewTestScope()) - GlobalMock := mocket.Catcher.Reset() - - query := GlobalMock.NewMock() - fakeResponse := query.WithQuery( - `DELETE FROM "workflow_attributes" WHERE ("workflow_attributes"."project" = ?) AND ` + - `("workflow_attributes"."domain" = ?) AND ("workflow_attributes"."workflow" = ?) AND ` + - `("workflow_attributes"."resource" = ?)`) - - err := workflowRepo.Delete(context.Background(), "project", "domain", "workflow", "resource") - assert.Nil(t, err) - assert.True(t, fakeResponse.Triggered) -} diff --git a/flyteadmin/pkg/repositories/interfaces/project_attributes_repo.go b/flyteadmin/pkg/repositories/interfaces/project_attributes_repo.go deleted file mode 100644 index fd77e31c9..000000000 --- a/flyteadmin/pkg/repositories/interfaces/project_attributes_repo.go +++ /dev/null @@ -1,16 +0,0 @@ -package interfaces - -import ( - "context" - - "github.com/lyft/flyteadmin/pkg/repositories/models" -) - -type ProjectAttributesRepoInterface interface { - // Inserts or updates an existing ProjectAttributes model into the database store. - CreateOrUpdate(ctx context.Context, input models.ProjectAttributes) error - // Returns a matching ProjectAttributes model when it exists. - Get(ctx context.Context, project, resource string) (models.ProjectAttributes, error) - // Deletes a matching ProjectAttributes model when it exists. - Delete(ctx context.Context, project, resource string) error -} diff --git a/flyteadmin/pkg/repositories/interfaces/project_domain_attributes_repo.go b/flyteadmin/pkg/repositories/interfaces/project_domain_attributes_repo.go deleted file mode 100644 index e5e34451d..000000000 --- a/flyteadmin/pkg/repositories/interfaces/project_domain_attributes_repo.go +++ /dev/null @@ -1,16 +0,0 @@ -package interfaces - -import ( - "context" - - "github.com/lyft/flyteadmin/pkg/repositories/models" -) - -type ProjectDomainAttributesRepoInterface interface { - // Inserts or updates an existing ProjectDomainAttributes model into the database store. - CreateOrUpdate(ctx context.Context, input models.ProjectDomainAttributes) error - // Returns a matching ProjectDomainAttributes model when it exists. - Get(ctx context.Context, project, domain, resource string) (models.ProjectDomainAttributes, error) - // Deletes a matching ProjectDomainAttributes model when it exists. - Delete(ctx context.Context, project, domain, resource string) error -} diff --git a/flyteadmin/pkg/repositories/interfaces/resource_repo.go b/flyteadmin/pkg/repositories/interfaces/resource_repo.go new file mode 100644 index 000000000..5380c8664 --- /dev/null +++ b/flyteadmin/pkg/repositories/interfaces/resource_repo.go @@ -0,0 +1,26 @@ +package interfaces + +import ( + "context" + + "github.com/lyft/flyteadmin/pkg/repositories/models" +) + +type ResourceRepoInterface interface { + // Inserts or updates an existing Type model into the database store. + CreateOrUpdate(ctx context.Context, input models.Resource) error + // Returns a matching Type model based on hierarchical resolution. + Get(ctx context.Context, ID ResourceID) (models.Resource, error) + // Returns a matching Type model. + GetRaw(ctx context.Context, ID ResourceID) (models.Resource, error) + // Deletes a matching Type model when it exists. + Delete(ctx context.Context, ID ResourceID) error +} + +type ResourceID struct { + Project string + Domain string + Workflow string + LaunchPlan string + ResourceType string +} diff --git a/flyteadmin/pkg/repositories/interfaces/workflow_attributes_repo.go b/flyteadmin/pkg/repositories/interfaces/workflow_attributes_repo.go deleted file mode 100644 index fed09ed48..000000000 --- a/flyteadmin/pkg/repositories/interfaces/workflow_attributes_repo.go +++ /dev/null @@ -1,16 +0,0 @@ -package interfaces - -import ( - "context" - - "github.com/lyft/flyteadmin/pkg/repositories/models" -) - -type WorkflowAttributesRepoInterface interface { - // Inserts or updates an existing WorkflowAttributes model into the database store. - CreateOrUpdate(ctx context.Context, input models.WorkflowAttributes) error - // Returns a matching WorkflowAttributes model when it exists. - Get(ctx context.Context, project, domain, workflow, resource string) (models.WorkflowAttributes, error) - // Deletes a matching ProjectDomainAttributes model when it exists. - Delete(ctx context.Context, project, domain, workflow, resource string) error -} diff --git a/flyteadmin/pkg/repositories/mocks/project_attributes_repo.go b/flyteadmin/pkg/repositories/mocks/project_attributes_repo.go deleted file mode 100644 index 24d91fbfc..000000000 --- a/flyteadmin/pkg/repositories/mocks/project_attributes_repo.go +++ /dev/null @@ -1,44 +0,0 @@ -package mocks - -import ( - "context" - - "github.com/lyft/flyteadmin/pkg/repositories/interfaces" - "github.com/lyft/flyteadmin/pkg/repositories/models" -) - -type CreateOrUpdateProjectAttributesFunction func(ctx context.Context, input models.ProjectAttributes) error -type GetProjectAttributesFunction func(ctx context.Context, project, resource string) (models.ProjectAttributes, error) -type DeleteProjectAttributesFunction func(ctx context.Context, project, resource string) error - -type MockProjectAttributesRepo struct { - CreateOrUpdateFunction CreateOrUpdateProjectAttributesFunction - GetFunction GetProjectAttributesFunction - DeleteFunction DeleteProjectAttributesFunction -} - -func (r *MockProjectAttributesRepo) CreateOrUpdate(ctx context.Context, input models.ProjectAttributes) error { - if r.CreateOrUpdateFunction != nil { - return r.CreateOrUpdateFunction(ctx, input) - } - return nil -} - -func (r *MockProjectAttributesRepo) Get(ctx context.Context, project, resource string) ( - models.ProjectAttributes, error) { - if r.GetFunction != nil { - return r.GetFunction(ctx, project, resource) - } - return models.ProjectAttributes{}, nil -} - -func (r *MockProjectAttributesRepo) Delete(ctx context.Context, project, resource string) error { - if r.DeleteFunction != nil { - return r.DeleteFunction(ctx, project, resource) - } - return nil -} - -func NewMockProjectAttributesRepo() interfaces.ProjectAttributesRepoInterface { - return &MockProjectAttributesRepo{} -} diff --git a/flyteadmin/pkg/repositories/mocks/project_domain_attributes_repo.go b/flyteadmin/pkg/repositories/mocks/project_domain_attributes_repo.go deleted file mode 100644 index 8c39a80ea..000000000 --- a/flyteadmin/pkg/repositories/mocks/project_domain_attributes_repo.go +++ /dev/null @@ -1,44 +0,0 @@ -package mocks - -import ( - "context" - - "github.com/lyft/flyteadmin/pkg/repositories/interfaces" - "github.com/lyft/flyteadmin/pkg/repositories/models" -) - -type CreateOrUpdateProjectDomainAttributesFunction func(ctx context.Context, input models.ProjectDomainAttributes) error -type GetProjectDomainAttributesFunction func(ctx context.Context, project, domain, resource string) (models.ProjectDomainAttributes, error) -type DeleteProjectDomainAttributesFunction func(ctx context.Context, project, domain, resource string) error - -type MockProjectDomainAttributesRepo struct { - CreateOrUpdateFunction CreateOrUpdateProjectDomainAttributesFunction - GetFunction GetProjectDomainAttributesFunction - DeleteFunction DeleteProjectDomainAttributesFunction -} - -func (r *MockProjectDomainAttributesRepo) CreateOrUpdate(ctx context.Context, input models.ProjectDomainAttributes) error { - if r.CreateOrUpdateFunction != nil { - return r.CreateOrUpdateFunction(ctx, input) - } - return nil -} - -func (r *MockProjectDomainAttributesRepo) Get(ctx context.Context, project, domain, resource string) ( - models.ProjectDomainAttributes, error) { - if r.GetFunction != nil { - return r.GetFunction(ctx, project, domain, resource) - } - return models.ProjectDomainAttributes{}, nil -} - -func (r *MockProjectDomainAttributesRepo) Delete(ctx context.Context, project, domain, resource string) error { - if r.DeleteFunction != nil { - return r.DeleteFunction(ctx, project, domain, resource) - } - return nil -} - -func NewMockProjectDomainAttributesRepo() interfaces.ProjectDomainAttributesRepoInterface { - return &MockProjectDomainAttributesRepo{} -} diff --git a/flyteadmin/pkg/repositories/mocks/repository.go b/flyteadmin/pkg/repositories/mocks/repository.go index 403dcbdf4..4ec3a23fa 100644 --- a/flyteadmin/pkg/repositories/mocks/repository.go +++ b/flyteadmin/pkg/repositories/mocks/repository.go @@ -6,17 +6,15 @@ import ( ) type MockRepository struct { - taskRepo interfaces.TaskRepoInterface - workflowRepo interfaces.WorkflowRepoInterface - launchPlanRepo interfaces.LaunchPlanRepoInterface - executionRepo interfaces.ExecutionRepoInterface - nodeExecutionRepo interfaces.NodeExecutionRepoInterface - projectRepo interfaces.ProjectRepoInterface - projectAttributesRepo interfaces.ProjectAttributesRepoInterface - projectDomainAttributesRepo interfaces.ProjectDomainAttributesRepoInterface - workflowAttributesRepo interfaces.WorkflowAttributesRepoInterface - taskExecutionRepo interfaces.TaskExecutionRepoInterface - namedEntityRepo interfaces.NamedEntityRepoInterface + taskRepo interfaces.TaskRepoInterface + workflowRepo interfaces.WorkflowRepoInterface + launchPlanRepo interfaces.LaunchPlanRepoInterface + executionRepo interfaces.ExecutionRepoInterface + nodeExecutionRepo interfaces.NodeExecutionRepoInterface + projectRepo interfaces.ProjectRepoInterface + resourceRepo interfaces.ResourceRepoInterface + taskExecutionRepo interfaces.TaskExecutionRepoInterface + namedEntityRepo interfaces.NamedEntityRepoInterface } func (r *MockRepository) TaskRepo() interfaces.TaskRepoInterface { @@ -43,16 +41,8 @@ func (r *MockRepository) ProjectRepo() interfaces.ProjectRepoInterface { return r.projectRepo } -func (r *MockRepository) ProjectDomainAttributesRepo() interfaces.ProjectDomainAttributesRepoInterface { - return r.projectDomainAttributesRepo -} - -func (r *MockRepository) WorkflowAttributesRepo() interfaces.WorkflowAttributesRepoInterface { - return r.workflowAttributesRepo -} - -func (r *MockRepository) ProjectAttributesRepo() interfaces.ProjectAttributesRepoInterface { - return r.projectAttributesRepo +func (r *MockRepository) ResourceRepo() interfaces.ResourceRepoInterface { + return r.resourceRepo } func (r *MockRepository) TaskExecutionRepo() interfaces.TaskExecutionRepoInterface { @@ -65,16 +55,14 @@ func (r *MockRepository) NamedEntityRepo() interfaces.NamedEntityRepoInterface { func NewMockRepository() repositories.RepositoryInterface { return &MockRepository{ - taskRepo: NewMockTaskRepo(), - workflowRepo: NewMockWorkflowRepo(), - launchPlanRepo: NewMockLaunchPlanRepo(), - executionRepo: NewMockExecutionRepo(), - nodeExecutionRepo: NewMockNodeExecutionRepo(), - projectRepo: NewMockProjectRepo(), - projectAttributesRepo: NewMockProjectAttributesRepo(), - projectDomainAttributesRepo: NewMockProjectDomainAttributesRepo(), - workflowAttributesRepo: NewMockWorkflowAttributesRepo(), - taskExecutionRepo: NewMockTaskExecutionRepo(), - namedEntityRepo: NewMockNamedEntityRepo(), + taskRepo: NewMockTaskRepo(), + workflowRepo: NewMockWorkflowRepo(), + launchPlanRepo: NewMockLaunchPlanRepo(), + executionRepo: NewMockExecutionRepo(), + nodeExecutionRepo: NewMockNodeExecutionRepo(), + projectRepo: NewMockProjectRepo(), + resourceRepo: NewMockResourceRepo(), + taskExecutionRepo: NewMockTaskExecutionRepo(), + namedEntityRepo: NewMockNamedEntityRepo(), } } diff --git a/flyteadmin/pkg/repositories/mocks/resource.go b/flyteadmin/pkg/repositories/mocks/resource.go new file mode 100644 index 000000000..bfacaf39b --- /dev/null +++ b/flyteadmin/pkg/repositories/mocks/resource.go @@ -0,0 +1,53 @@ +package mocks + +import ( + "context" + + "github.com/lyft/flyteadmin/pkg/repositories/interfaces" + "github.com/lyft/flyteadmin/pkg/repositories/models" +) + +type CreateOrUpdateResourceFunction func(ctx context.Context, input models.Resource) error +type GetResourceFunction func(ctx context.Context, ID interfaces.ResourceID) ( + models.Resource, error) +type DeleteResourceFunction func(ctx context.Context, ID interfaces.ResourceID) error + +type MockResourceRepo struct { + CreateOrUpdateFunction CreateOrUpdateResourceFunction + GetFunction GetResourceFunction + DeleteFunction DeleteResourceFunction +} + +func (r *MockResourceRepo) CreateOrUpdate(ctx context.Context, input models.Resource) error { + if r.CreateOrUpdateFunction != nil { + return r.CreateOrUpdateFunction(ctx, input) + } + return nil +} + +func (r *MockResourceRepo) Get(ctx context.Context, ID interfaces.ResourceID) ( + models.Resource, error) { + if r.GetFunction != nil { + return r.GetFunction(ctx, ID) + } + return models.Resource{}, nil +} + +func (r *MockResourceRepo) GetRaw(ctx context.Context, ID interfaces.ResourceID) ( + models.Resource, error) { + if r.GetFunction != nil { + return r.GetFunction(ctx, ID) + } + return models.Resource{}, nil +} + +func (r *MockResourceRepo) Delete(ctx context.Context, ID interfaces.ResourceID) error { + if r.DeleteFunction != nil { + return r.DeleteFunction(ctx, ID) + } + return nil +} + +func NewMockResourceRepo() interfaces.ResourceRepoInterface { + return &MockResourceRepo{} +} diff --git a/flyteadmin/pkg/repositories/mocks/workflow_attributes.go b/flyteadmin/pkg/repositories/mocks/workflow_attributes.go deleted file mode 100644 index 8e8cf0fb5..000000000 --- a/flyteadmin/pkg/repositories/mocks/workflow_attributes.go +++ /dev/null @@ -1,45 +0,0 @@ -package mocks - -import ( - "context" - - "github.com/lyft/flyteadmin/pkg/repositories/interfaces" - "github.com/lyft/flyteadmin/pkg/repositories/models" -) - -type CreateOrUpdateWorkflowAttributesFunction func(ctx context.Context, input models.WorkflowAttributes) error -type GetWorkflowAttributesFunction func(ctx context.Context, project, domain, workflow, resource string) ( - models.WorkflowAttributes, error) -type DeleteWorkflowAttributesFunction func(ctx context.Context, project, domain, workflow, resource string) error - -type MockWorkflowAttributesRepo struct { - CreateOrUpdateFunction CreateOrUpdateWorkflowAttributesFunction - GetFunction GetWorkflowAttributesFunction - DeleteFunction DeleteWorkflowAttributesFunction -} - -func (r *MockWorkflowAttributesRepo) CreateOrUpdate(ctx context.Context, input models.WorkflowAttributes) error { - if r.CreateOrUpdateFunction != nil { - return r.CreateOrUpdateFunction(ctx, input) - } - return nil -} - -func (r *MockWorkflowAttributesRepo) Get(ctx context.Context, project, domain, workflow, resource string) ( - models.WorkflowAttributes, error) { - if r.GetFunction != nil { - return r.GetFunction(ctx, project, domain, workflow, resource) - } - return models.WorkflowAttributes{}, nil -} - -func (r *MockWorkflowAttributesRepo) Delete(ctx context.Context, project, domain, workflow, resource string) error { - if r.DeleteFunction != nil { - return r.DeleteFunction(ctx, project, domain, workflow, resource) - } - return nil -} - -func NewMockWorkflowAttributesRepo() interfaces.WorkflowAttributesRepoInterface { - return &MockWorkflowAttributesRepo{} -} diff --git a/flyteadmin/pkg/repositories/models/project_attributes.go b/flyteadmin/pkg/repositories/models/project_attributes.go deleted file mode 100644 index e9a5cc87d..000000000 --- a/flyteadmin/pkg/repositories/models/project_attributes.go +++ /dev/null @@ -1,10 +0,0 @@ -package models - -// Represents project-domain customizable configuration. -type ProjectAttributes struct { - BaseModel - Project string `gorm:"primary_key"` - Resource string `gorm:"primary_key"` - // Serialized flyteidl.admin.MatchingAttributes. - Attributes []byte -} diff --git a/flyteadmin/pkg/repositories/models/project_domain_attributes.go b/flyteadmin/pkg/repositories/models/project_domain_attributes.go deleted file mode 100644 index 8674121c7..000000000 --- a/flyteadmin/pkg/repositories/models/project_domain_attributes.go +++ /dev/null @@ -1,11 +0,0 @@ -package models - -// Represents project-domain customizable configuration. -type ProjectDomainAttributes struct { - BaseModel - Project string `gorm:"primary_key"` - Domain string `gorm:"primary_key"` - Resource string `gorm:"primary_key"` - // Serialized flyteidl.admin.MatchingAttributes. - Attributes []byte -} diff --git a/flyteadmin/pkg/repositories/models/resource.go b/flyteadmin/pkg/repositories/models/resource.go new file mode 100644 index 000000000..74e0429d7 --- /dev/null +++ b/flyteadmin/pkg/repositories/models/resource.go @@ -0,0 +1,29 @@ +package models + +import "time" + +type ResourcePriority int32 + +const ( + ResourcePriorityDomainLevel ResourcePriority = 1 + ResourcePriorityProjectDomainLevel ResourcePriority = 10 + ResourcePriorityWorkflowLevel ResourcePriority = 100 + ResourcePriorityLaunchPlanLevel ResourcePriority = 1000 +) + +// Represents Flyte resources repository. +// In this model, the combination of (Project, Domain, Workflow, LaunchPlan, ResourceType) is unique +type Resource struct { + ID int64 `gorm:"AUTO_INCREMENT;column:id;primary_key"` + CreatedAt time.Time + UpdatedAt time.Time + DeletedAt *time.Time `sql:"index"` + Project string `gorm:"unique_index:resource_idx"` + Domain string `gorm:"unique_index:resource_idx"` + Workflow string `gorm:"unique_index:resource_idx"` + LaunchPlan string `gorm:"unique_index:resource_idx"` + ResourceType string `gorm:"unique_index:resource_idx"` + Priority ResourcePriority + // Serialized flyteidl.admin.MatchingAttributes. + Attributes []byte +} diff --git a/flyteadmin/pkg/repositories/models/workflow_attributes.go b/flyteadmin/pkg/repositories/models/workflow_attributes.go deleted file mode 100644 index 4d9afd504..000000000 --- a/flyteadmin/pkg/repositories/models/workflow_attributes.go +++ /dev/null @@ -1,12 +0,0 @@ -package models - -// Represents project-domain customizable configuration. -type WorkflowAttributes struct { - BaseModel - Project string `gorm:"primary_key"` - Domain string `gorm:"primary_key"` - Workflow string `gorm:"primary_key"` - Resource string `gorm:"primary_key"` - // Serialized flyteidl.admin.MatchingAttributes. - Attributes []byte -} diff --git a/flyteadmin/pkg/repositories/postgres_repo.go b/flyteadmin/pkg/repositories/postgres_repo.go index f7efc94d5..d2606aee5 100644 --- a/flyteadmin/pkg/repositories/postgres_repo.go +++ b/flyteadmin/pkg/repositories/postgres_repo.go @@ -9,17 +9,15 @@ import ( ) type PostgresRepo struct { - executionRepo interfaces.ExecutionRepoInterface - namedEntityRepo interfaces.NamedEntityRepoInterface - launchPlanRepo interfaces.LaunchPlanRepoInterface - projectRepo interfaces.ProjectRepoInterface - projectAttributesRepo interfaces.ProjectAttributesRepoInterface - projectDomainAttributesRepo interfaces.ProjectDomainAttributesRepoInterface - nodeExecutionRepo interfaces.NodeExecutionRepoInterface - taskRepo interfaces.TaskRepoInterface - taskExecutionRepo interfaces.TaskExecutionRepoInterface - workflowRepo interfaces.WorkflowRepoInterface - workflowAttributesRepo interfaces.WorkflowAttributesRepoInterface + executionRepo interfaces.ExecutionRepoInterface + namedEntityRepo interfaces.NamedEntityRepoInterface + launchPlanRepo interfaces.LaunchPlanRepoInterface + projectRepo interfaces.ProjectRepoInterface + nodeExecutionRepo interfaces.NodeExecutionRepoInterface + taskRepo interfaces.TaskRepoInterface + taskExecutionRepo interfaces.TaskExecutionRepoInterface + workflowRepo interfaces.WorkflowRepoInterface + resourceRepo interfaces.ResourceRepoInterface } func (p *PostgresRepo) ExecutionRepo() interfaces.ExecutionRepoInterface { @@ -38,14 +36,6 @@ func (p *PostgresRepo) ProjectRepo() interfaces.ProjectRepoInterface { return p.projectRepo } -func (p *PostgresRepo) ProjectAttributesRepo() interfaces.ProjectAttributesRepoInterface { - return p.projectAttributesRepo -} - -func (p *PostgresRepo) ProjectDomainAttributesRepo() interfaces.ProjectDomainAttributesRepoInterface { - return p.projectDomainAttributesRepo -} - func (p *PostgresRepo) NodeExecutionRepo() interfaces.NodeExecutionRepoInterface { return p.nodeExecutionRepo } @@ -62,22 +52,20 @@ func (p *PostgresRepo) WorkflowRepo() interfaces.WorkflowRepoInterface { return p.workflowRepo } -func (p *PostgresRepo) WorkflowAttributesRepo() interfaces.WorkflowAttributesRepoInterface { - return p.workflowAttributesRepo +func (p *PostgresRepo) ResourceRepo() interfaces.ResourceRepoInterface { + return p.resourceRepo } func NewPostgresRepo(db *gorm.DB, errorTransformer errors.ErrorTransformer, scope promutils.Scope) RepositoryInterface { return &PostgresRepo{ - executionRepo: gormimpl.NewExecutionRepo(db, errorTransformer, scope.NewSubScope("executions")), - launchPlanRepo: gormimpl.NewLaunchPlanRepo(db, errorTransformer, scope.NewSubScope("launch_plans")), - projectRepo: gormimpl.NewProjectRepo(db, errorTransformer, scope.NewSubScope("project")), - projectAttributesRepo: gormimpl.NewProjectAttributesRepo(db, errorTransformer, scope.NewSubScope("project_attrs")), - projectDomainAttributesRepo: gormimpl.NewProjectDomainAttributesRepo(db, errorTransformer, scope.NewSubScope("project_domain_attrs")), - namedEntityRepo: gormimpl.NewNamedEntityRepo(db, errorTransformer, scope.NewSubScope("named_entity")), - nodeExecutionRepo: gormimpl.NewNodeExecutionRepo(db, errorTransformer, scope.NewSubScope("node_executions")), - taskRepo: gormimpl.NewTaskRepo(db, errorTransformer, scope.NewSubScope("tasks")), - taskExecutionRepo: gormimpl.NewTaskExecutionRepo(db, errorTransformer, scope.NewSubScope("task_executions")), - workflowRepo: gormimpl.NewWorkflowRepo(db, errorTransformer, scope.NewSubScope("workflows")), - workflowAttributesRepo: gormimpl.NewWorkflowAttributesRepo(db, errorTransformer, scope.NewSubScope("workflow_attrs")), + executionRepo: gormimpl.NewExecutionRepo(db, errorTransformer, scope.NewSubScope("executions")), + launchPlanRepo: gormimpl.NewLaunchPlanRepo(db, errorTransformer, scope.NewSubScope("launch_plans")), + projectRepo: gormimpl.NewProjectRepo(db, errorTransformer, scope.NewSubScope("project")), + namedEntityRepo: gormimpl.NewNamedEntityRepo(db, errorTransformer, scope.NewSubScope("named_entity")), + nodeExecutionRepo: gormimpl.NewNodeExecutionRepo(db, errorTransformer, scope.NewSubScope("node_executions")), + taskRepo: gormimpl.NewTaskRepo(db, errorTransformer, scope.NewSubScope("tasks")), + taskExecutionRepo: gormimpl.NewTaskExecutionRepo(db, errorTransformer, scope.NewSubScope("task_executions")), + workflowRepo: gormimpl.NewWorkflowRepo(db, errorTransformer, scope.NewSubScope("workflows")), + resourceRepo: gormimpl.NewResourceRepo(db, errorTransformer, scope.NewSubScope("resources")), } } diff --git a/flyteadmin/pkg/repositories/transformers/project_attributes.go b/flyteadmin/pkg/repositories/transformers/project_attributes.go deleted file mode 100644 index b628ba318..000000000 --- a/flyteadmin/pkg/repositories/transformers/project_attributes.go +++ /dev/null @@ -1,35 +0,0 @@ -package transformers - -import ( - "github.com/golang/protobuf/proto" - - "github.com/lyft/flyteadmin/pkg/errors" - "github.com/lyft/flyteadmin/pkg/repositories/models" - "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" - "google.golang.org/grpc/codes" -) - -func ToProjectAttributesModel(attributes admin.ProjectAttributes, resource admin.MatchableResource) (models.ProjectAttributes, error) { - attributeBytes, err := proto.Marshal(attributes.MatchingAttributes) - if err != nil { - return models.ProjectAttributes{}, err - } - return models.ProjectAttributes{ - Project: attributes.Project, - Resource: resource.String(), - Attributes: attributeBytes, - }, nil -} - -func FromProjectAttributesModel(model models.ProjectAttributes) (admin.ProjectAttributes, error) { - var attributes admin.MatchingAttributes - err := proto.Unmarshal(model.Attributes, &attributes) - if err != nil { - return admin.ProjectAttributes{}, errors.NewFlyteAdminErrorf( - codes.Internal, "Failed to decode project domain resource projectDomainAttributes with err: %v", err) - } - return admin.ProjectAttributes{ - Project: model.Project, - MatchingAttributes: &attributes, - }, nil -} diff --git a/flyteadmin/pkg/repositories/transformers/project_attributes_test.go b/flyteadmin/pkg/repositories/transformers/project_attributes_test.go deleted file mode 100644 index 65a78395a..000000000 --- a/flyteadmin/pkg/repositories/transformers/project_attributes_test.go +++ /dev/null @@ -1,63 +0,0 @@ -package transformers - -import ( - "testing" - - "github.com/golang/protobuf/proto" - - "github.com/lyft/flyteadmin/pkg/errors" - "github.com/lyft/flyteadmin/pkg/repositories/models" - "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" - "google.golang.org/grpc/codes" - - "github.com/stretchr/testify/assert" -) - -var matchingTaskResourceAttributes = &admin.MatchingAttributes{ - Target: &admin.MatchingAttributes_TaskResourceAttributes{ - TaskResourceAttributes: &admin.TaskResourceAttributes{ - Defaults: &admin.TaskResourceSpec{ - Cpu: "1", - }, - }, - }, -} - -var projectAttributes = admin.ProjectAttributes{ - Project: "project", - MatchingAttributes: matchingTaskResourceAttributes, -} - -var marshalledAttributes, _ = proto.Marshal(matchingTaskResourceAttributes) - -func TestToProjectAttributesModel(t *testing.T) { - model, err := ToProjectAttributesModel(projectAttributes, admin.MatchableResource_TASK_RESOURCE) - assert.Nil(t, err) - assert.EqualValues(t, models.ProjectAttributes{ - Project: "project", - Resource: admin.MatchableResource_TASK_RESOURCE.String(), - Attributes: marshalledAttributes, - }, model) -} - -func TestFromProjectAttributesModel(t *testing.T) { - model := models.ProjectAttributes{ - Project: "project", - Resource: admin.MatchableResource_TASK_RESOURCE.String(), - Attributes: marshalledAttributes, - } - unmarshalledAttributes, err := FromProjectAttributesModel(model) - assert.Nil(t, err) - assert.True(t, proto.Equal(&projectAttributes, &unmarshalledAttributes)) -} - -func TestFromProjectAttributesModel_InvalidResourceAttributes(t *testing.T) { - model := models.ProjectAttributes{ - Project: "project", - Resource: admin.MatchableResource_TASK_RESOURCE.String(), - Attributes: []byte("i'm invalid!"), - } - _, err := FromProjectAttributesModel(model) - assert.NotNil(t, err) - assert.Equal(t, codes.Internal, err.(errors.FlyteAdminError).Code()) -} diff --git a/flyteadmin/pkg/repositories/transformers/project_domain_attributes.go b/flyteadmin/pkg/repositories/transformers/project_domain_attributes.go deleted file mode 100644 index ccedf335c..000000000 --- a/flyteadmin/pkg/repositories/transformers/project_domain_attributes.go +++ /dev/null @@ -1,37 +0,0 @@ -package transformers - -import ( - "github.com/golang/protobuf/proto" - - "github.com/lyft/flyteadmin/pkg/errors" - "github.com/lyft/flyteadmin/pkg/repositories/models" - "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" - "google.golang.org/grpc/codes" -) - -func ToProjectDomainAttributesModel(attributes admin.ProjectDomainAttributes, resource admin.MatchableResource) (models.ProjectDomainAttributes, error) { - attributeBytes, err := proto.Marshal(attributes.MatchingAttributes) - if err != nil { - return models.ProjectDomainAttributes{}, err - } - return models.ProjectDomainAttributes{ - Project: attributes.Project, - Domain: attributes.Domain, - Resource: resource.String(), - Attributes: attributeBytes, - }, nil -} - -func FromProjectDomainAttributesModel(model models.ProjectDomainAttributes) (admin.ProjectDomainAttributes, error) { - var attributes admin.MatchingAttributes - err := proto.Unmarshal(model.Attributes, &attributes) - if err != nil { - return admin.ProjectDomainAttributes{}, errors.NewFlyteAdminErrorf( - codes.Internal, "Failed to decode project domain resource projectDomainAttributes with err: %v", err) - } - return admin.ProjectDomainAttributes{ - Project: model.Project, - Domain: model.Domain, - MatchingAttributes: &attributes, - }, nil -} diff --git a/flyteadmin/pkg/repositories/transformers/project_domain_attributes_test.go b/flyteadmin/pkg/repositories/transformers/project_domain_attributes_test.go deleted file mode 100644 index f4c42cd92..000000000 --- a/flyteadmin/pkg/repositories/transformers/project_domain_attributes_test.go +++ /dev/null @@ -1,68 +0,0 @@ -package transformers - -import ( - "testing" - - "github.com/golang/protobuf/proto" - - "github.com/lyft/flyteadmin/pkg/errors" - "github.com/lyft/flyteadmin/pkg/repositories/models" - "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" - "google.golang.org/grpc/codes" - - "github.com/stretchr/testify/assert" -) - -var matchingExecutionQueueAttributes = &admin.MatchingAttributes{ - Target: &admin.MatchingAttributes_ExecutionQueueAttributes{ - ExecutionQueueAttributes: &admin.ExecutionQueueAttributes{ - Tags: []string{ - "foo", - }, - }, - }, -} - -var projectDomainAttributes = admin.ProjectDomainAttributes{ - Project: "project", - Domain: "domain", - MatchingAttributes: matchingExecutionQueueAttributes, -} - -var marshalledExecutionQueueAttributes, _ = proto.Marshal(matchingExecutionQueueAttributes) - -func TestToProjectDomainAttributesModel(t *testing.T) { - - model, err := ToProjectDomainAttributesModel(projectDomainAttributes, admin.MatchableResource_EXECUTION_QUEUE) - assert.Nil(t, err) - assert.EqualValues(t, models.ProjectDomainAttributes{ - Project: "project", - Domain: "domain", - Resource: admin.MatchableResource_EXECUTION_QUEUE.String(), - Attributes: marshalledExecutionQueueAttributes, - }, model) -} - -func TestFromProjectDomainAttributesModel(t *testing.T) { - model := models.ProjectDomainAttributes{ - Project: "project", - Domain: "domain", - Resource: admin.MatchableResource_EXECUTION_QUEUE.String(), - Attributes: marshalledExecutionQueueAttributes, - } - unmarshalledAttributes, err := FromProjectDomainAttributesModel(model) - assert.Nil(t, err) - assert.True(t, proto.Equal(&projectDomainAttributes, &unmarshalledAttributes)) -} - -func TestFromProjectDomainAttributesModel_InvalidResourceAttributes(t *testing.T) { - model := models.ProjectDomainAttributes{ - Project: "project", - Domain: "domain", - Resource: admin.MatchableResource_EXECUTION_QUEUE.String(), - Attributes: []byte("i'm invalid!"), - } - _, err := FromProjectDomainAttributesModel(model) - assert.NotNil(t, err) - assert.Equal(t, codes.Internal, err.(errors.FlyteAdminError).Code()) -} diff --git a/flyteadmin/pkg/repositories/transformers/resource.go b/flyteadmin/pkg/repositories/transformers/resource.go new file mode 100644 index 000000000..3212e9584 --- /dev/null +++ b/flyteadmin/pkg/repositories/transformers/resource.go @@ -0,0 +1,68 @@ +package transformers + +import ( + "github.com/golang/protobuf/proto" + + "github.com/lyft/flyteadmin/pkg/errors" + "github.com/lyft/flyteadmin/pkg/repositories/models" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" + "google.golang.org/grpc/codes" +) + +func WorkflowAttributesToResourceModel(attributes admin.WorkflowAttributes, resource admin.MatchableResource) (models.Resource, error) { + attributeBytes, err := proto.Marshal(attributes.MatchingAttributes) + if err != nil { + return models.Resource{}, err + } + return models.Resource{ + Project: attributes.Project, + Domain: attributes.Domain, + Workflow: attributes.Workflow, + ResourceType: resource.String(), + Priority: models.ResourcePriorityWorkflowLevel, + Attributes: attributeBytes, + }, nil +} + +func FromResourceModelToWorkflowAttributes(model models.Resource) (admin.WorkflowAttributes, error) { + var attributes admin.MatchingAttributes + err := proto.Unmarshal(model.Attributes, &attributes) + if err != nil { + return admin.WorkflowAttributes{}, errors.NewFlyteAdminErrorf( + codes.Internal, "Failed to decode project domain resource projectDomainAttributes with err: %v", err) + } + return admin.WorkflowAttributes{ + Project: model.Project, + Domain: model.Domain, + Workflow: model.Workflow, + MatchingAttributes: &attributes, + }, nil +} + +func ProjectDomainAttributesToResourceModel(attributes admin.ProjectDomainAttributes, resource admin.MatchableResource) (models.Resource, error) { + attributeBytes, err := proto.Marshal(attributes.MatchingAttributes) + if err != nil { + return models.Resource{}, err + } + return models.Resource{ + Project: attributes.Project, + Domain: attributes.Domain, + ResourceType: resource.String(), + Priority: models.ResourcePriorityProjectDomainLevel, + Attributes: attributeBytes, + }, nil +} + +func FromResourceModelToProjectDomainAttributes(model models.Resource) (admin.ProjectDomainAttributes, error) { + var attributes admin.MatchingAttributes + err := proto.Unmarshal(model.Attributes, &attributes) + if err != nil { + return admin.ProjectDomainAttributes{}, errors.NewFlyteAdminErrorf( + codes.Internal, "Failed to decode project domain resource projectDomainAttributes with err: %v", err) + } + return admin.ProjectDomainAttributes{ + Project: model.Project, + Domain: model.Domain, + MatchingAttributes: &attributes, + }, nil +} diff --git a/flyteadmin/pkg/repositories/transformers/resource_test.go b/flyteadmin/pkg/repositories/transformers/resource_test.go new file mode 100644 index 000000000..7d0fbd30e --- /dev/null +++ b/flyteadmin/pkg/repositories/transformers/resource_test.go @@ -0,0 +1,127 @@ +package transformers + +import ( + "testing" + + "github.com/golang/protobuf/proto" + + "github.com/lyft/flyteadmin/pkg/errors" + "github.com/lyft/flyteadmin/pkg/repositories/models" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" + "google.golang.org/grpc/codes" + + "github.com/stretchr/testify/assert" +) + +var matchingClusterResourceAttributes = &admin.MatchingAttributes{ + Target: &admin.MatchingAttributes_ClusterResourceAttributes{ + ClusterResourceAttributes: &admin.ClusterResourceAttributes{ + Attributes: map[string]string{ + "foo": "bar", + }, + }, + }, +} + +var workflowAttributes = admin.WorkflowAttributes{ + Project: "project", + Domain: "domain", + Workflow: "workflow", + MatchingAttributes: matchingClusterResourceAttributes, +} + +var marshalledClusterResourceAttributes, _ = proto.Marshal(matchingClusterResourceAttributes) + +var matchingExecutionQueueAttributes = &admin.MatchingAttributes{ + Target: &admin.MatchingAttributes_ExecutionQueueAttributes{ + ExecutionQueueAttributes: &admin.ExecutionQueueAttributes{ + Tags: []string{ + "foo", + }, + }, + }, +} + +var projectDomainAttributes = admin.ProjectDomainAttributes{ + Project: "project", + Domain: "domain", + MatchingAttributes: matchingExecutionQueueAttributes, +} + +var marshalledExecutionQueueAttributes, _ = proto.Marshal(matchingExecutionQueueAttributes) + +func TestToProjectDomainAttributesModel(t *testing.T) { + + model, err := ProjectDomainAttributesToResourceModel(projectDomainAttributes, admin.MatchableResource_EXECUTION_QUEUE) + assert.Nil(t, err) + assert.EqualValues(t, models.Resource{ + Project: "project", + Domain: "domain", + ResourceType: admin.MatchableResource_EXECUTION_QUEUE.String(), + Priority: models.ResourcePriorityProjectDomainLevel, + Attributes: marshalledExecutionQueueAttributes, + }, model) +} + +func TestFromProjectDomainAttributesModel(t *testing.T) { + model := models.Resource{ + Project: "project", + Domain: "domain", + ResourceType: admin.MatchableResource_EXECUTION_QUEUE.String(), + Attributes: marshalledExecutionQueueAttributes, + } + unmarshalledAttributes, err := FromResourceModelToProjectDomainAttributes(model) + assert.Nil(t, err) + assert.True(t, proto.Equal(&projectDomainAttributes, &unmarshalledAttributes)) +} + +func TestFromProjectDomainAttributesModel_InvalidResourceAttributes(t *testing.T) { + model := models.Resource{ + Project: "project", + Domain: "domain", + ResourceType: admin.MatchableResource_EXECUTION_QUEUE.String(), + Attributes: []byte("i'm invalid!"), + } + _, err := FromResourceModelToProjectDomainAttributes(model) + assert.NotNil(t, err) + assert.Equal(t, codes.Internal, err.(errors.FlyteAdminError).Code()) +} + +func TestToWorkflowAttributesModel(t *testing.T) { + model, err := WorkflowAttributesToResourceModel(workflowAttributes, admin.MatchableResource_EXECUTION_QUEUE) + assert.Nil(t, err) + assert.EqualValues(t, models.Resource{ + Project: "project", + Domain: "domain", + Workflow: "workflow", + ResourceType: admin.MatchableResource_EXECUTION_QUEUE.String(), + Priority: models.ResourcePriorityWorkflowLevel, + Attributes: marshalledClusterResourceAttributes, + }, model) +} + +func TestFromWorkflowAttributesModel(t *testing.T) { + model := models.Resource{ + Project: "project", + Domain: "domain", + Workflow: "workflow", + ResourceType: admin.MatchableResource_EXECUTION_QUEUE.String(), + Attributes: marshalledClusterResourceAttributes, + } + unmarshalledAttributes, err := FromResourceModelToWorkflowAttributes(model) + assert.Nil(t, err) + assert.True(t, proto.Equal(&workflowAttributes, &unmarshalledAttributes)) +} + +func TestFromWorkflowAttributesModel_InvalidResourceAttributes(t *testing.T) { + model := models.Resource{ + Project: "project", + Domain: "domain", + Workflow: "workflow", + ResourceType: admin.MatchableResource_EXECUTION_QUEUE.String(), + Attributes: []byte("i'm invalid!"), + } + _, err := FromResourceModelToWorkflowAttributes(model) + assert.NotNil(t, err) + assert.Equal(t, codes.Internal, err.(errors.FlyteAdminError).Code()) +} diff --git a/flyteadmin/pkg/repositories/transformers/workflow_attributes.go b/flyteadmin/pkg/repositories/transformers/workflow_attributes.go deleted file mode 100644 index 34758c22e..000000000 --- a/flyteadmin/pkg/repositories/transformers/workflow_attributes.go +++ /dev/null @@ -1,39 +0,0 @@ -package transformers - -import ( - "github.com/golang/protobuf/proto" - - "github.com/lyft/flyteadmin/pkg/errors" - "github.com/lyft/flyteadmin/pkg/repositories/models" - "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" - "google.golang.org/grpc/codes" -) - -func ToWorkflowAttributesModel(attributes admin.WorkflowAttributes, resource admin.MatchableResource) (models.WorkflowAttributes, error) { - attributeBytes, err := proto.Marshal(attributes.MatchingAttributes) - if err != nil { - return models.WorkflowAttributes{}, err - } - return models.WorkflowAttributes{ - Project: attributes.Project, - Domain: attributes.Domain, - Workflow: attributes.Workflow, - Resource: resource.String(), - Attributes: attributeBytes, - }, nil -} - -func FromWorkflowAttributesModel(model models.WorkflowAttributes) (admin.WorkflowAttributes, error) { - var attributes admin.MatchingAttributes - err := proto.Unmarshal(model.Attributes, &attributes) - if err != nil { - return admin.WorkflowAttributes{}, errors.NewFlyteAdminErrorf( - codes.Internal, "Failed to decode project domain resource projectDomainAttributes with err: %v", err) - } - return admin.WorkflowAttributes{ - Project: model.Project, - Domain: model.Domain, - Workflow: model.Workflow, - MatchingAttributes: &attributes, - }, nil -} diff --git a/flyteadmin/pkg/repositories/transformers/workflow_attributes_test.go b/flyteadmin/pkg/repositories/transformers/workflow_attributes_test.go deleted file mode 100644 index b23b6f188..000000000 --- a/flyteadmin/pkg/repositories/transformers/workflow_attributes_test.go +++ /dev/null @@ -1,71 +0,0 @@ -package transformers - -import ( - "testing" - - "github.com/golang/protobuf/proto" - - "github.com/lyft/flyteadmin/pkg/errors" - "github.com/lyft/flyteadmin/pkg/repositories/models" - "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" - "google.golang.org/grpc/codes" - - "github.com/stretchr/testify/assert" -) - -var matchingClusterResourceAttributes = &admin.MatchingAttributes{ - Target: &admin.MatchingAttributes_ClusterResourceAttributes{ - ClusterResourceAttributes: &admin.ClusterResourceAttributes{ - Attributes: map[string]string{ - "foo": "bar", - }, - }, - }, -} - -var workflowAttributes = admin.WorkflowAttributes{ - Project: "project", - Domain: "domain", - Workflow: "workflow", - MatchingAttributes: matchingClusterResourceAttributes, -} - -var marshalledClusterResourceAttributes, _ = proto.Marshal(matchingClusterResourceAttributes) - -func TestToWorkflowAttributesModel(t *testing.T) { - model, err := ToWorkflowAttributesModel(workflowAttributes, admin.MatchableResource_EXECUTION_QUEUE) - assert.Nil(t, err) - assert.EqualValues(t, models.WorkflowAttributes{ - Project: "project", - Domain: "domain", - Workflow: "workflow", - Resource: admin.MatchableResource_EXECUTION_QUEUE.String(), - Attributes: marshalledClusterResourceAttributes, - }, model) -} - -func TestFromWorkflowAttributesModel(t *testing.T) { - model := models.WorkflowAttributes{ - Project: "project", - Domain: "domain", - Workflow: "workflow", - Resource: admin.MatchableResource_EXECUTION_QUEUE.String(), - Attributes: marshalledClusterResourceAttributes, - } - unmarshalledAttributes, err := FromWorkflowAttributesModel(model) - assert.Nil(t, err) - assert.True(t, proto.Equal(&workflowAttributes, &unmarshalledAttributes)) -} - -func TestFromWorkflowAttributesModel_InvalidResourceAttributes(t *testing.T) { - model := models.WorkflowAttributes{ - Project: "project", - Domain: "domain", - Workflow: "workflow", - Resource: admin.MatchableResource_EXECUTION_QUEUE.String(), - Attributes: []byte("i'm invalid!"), - } - _, err := FromWorkflowAttributesModel(model) - assert.NotNil(t, err) - assert.Equal(t, codes.Internal, err.(errors.FlyteAdminError).Code()) -} diff --git a/flyteadmin/pkg/resourcematching/overrides.go b/flyteadmin/pkg/resourcematching/overrides.go deleted file mode 100644 index 06aba4f0f..000000000 --- a/flyteadmin/pkg/resourcematching/overrides.go +++ /dev/null @@ -1,77 +0,0 @@ -package resourcematching - -import ( - "context" - - "github.com/lyft/flyteadmin/pkg/errors" - "github.com/lyft/flyteadmin/pkg/repositories" - "github.com/lyft/flyteadmin/pkg/repositories/transformers" - "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" - "google.golang.org/grpc/codes" -) - -type GetOverrideValuesInput struct { - Db repositories.RepositoryInterface - Project string - Domain string - Workflow string - Resource admin.MatchableResource -} - -func isNotFoundErr(err error) bool { - return err.(errors.FlyteAdminError) != nil && err.(errors.FlyteAdminError).Code() == codes.NotFound -} - -func GetOverrideValuesToApply(ctx context.Context, input GetOverrideValuesInput) ( - *admin.MatchingAttributes, error) { - if len(input.Project) == 0 || len(input.Domain) == 0 { - return nil, errors.NewFlyteAdminErrorf( - codes.InvalidArgument, "Invalid overrides values request configuration: [%+v]", input) - } - if len(input.Workflow) > 0 { - // Only the workflow input argument is optional - workflowAttributesModel, err := input.Db.WorkflowAttributesRepo().Get( - ctx, input.Project, input.Domain, input.Workflow, input.Resource.String()) - if err != nil && !isNotFoundErr(err) { - // Not found is fine, since not every workflow will necessarily have resource overrides. - // Any other error should be bubbled back up. - return nil, err - } else if err == nil { - workflowAttributes, err := transformers.FromWorkflowAttributesModel(workflowAttributesModel) - if err != nil { - return nil, err - } - return workflowAttributes.MatchingAttributes, nil - } - } - - projectDomainAttributesModel, err := input.Db.ProjectDomainAttributesRepo().Get( - ctx, input.Project, input.Domain, input.Resource.String()) - if err != nil && !isNotFoundErr(err) { - // Not found is fine, since not every project+domain will necessarily have resource overrides. - // Any other error should be bubbled back up. - return nil, err - } else if err == nil { - projectDomainAttributes, err := transformers.FromProjectDomainAttributesModel(projectDomainAttributesModel) - if err != nil { - return nil, err - } - return projectDomainAttributes.MatchingAttributes, nil - } - - projectAttributesModel, err := input.Db.ProjectAttributesRepo().Get(ctx, input.Project, input.Resource.String()) - if err != nil && !isNotFoundErr(err) { - // Not found is fine, since not every project will necessarily have resource overrides. - // Any other error should be bubbled back up. - return nil, err - } else if err == nil { - projectAttributes, err := transformers.FromProjectAttributesModel(projectAttributesModel) - if err != nil { - return nil, err - } - return projectAttributes.MatchingAttributes, nil - } - - // If we've made it this far then there are no matching overrides. - return nil, nil -} diff --git a/flyteadmin/pkg/resourcematching/overrides_test.go b/flyteadmin/pkg/resourcematching/overrides_test.go deleted file mode 100644 index 0fa8c038b..000000000 --- a/flyteadmin/pkg/resourcematching/overrides_test.go +++ /dev/null @@ -1,164 +0,0 @@ -package resourcematching - -import ( - "context" - "fmt" - "testing" - - "github.com/golang/protobuf/proto" - "github.com/lyft/flyteadmin/pkg/errors" - "github.com/lyft/flyteadmin/pkg/repositories/mocks" - "github.com/lyft/flyteadmin/pkg/repositories/models" - "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" - "github.com/stretchr/testify/assert" - "google.golang.org/grpc/codes" -) - -const testProject = "project" -const testDomain = "domain" -const testWorkflow = "workflow" - -func TestIsNotFoundErr(t *testing.T) { - isNotFound := errors.NewFlyteAdminError(codes.NotFound, "foo") - assert.True(t, isNotFoundErr(isNotFound)) - - invalidArgs := errors.NewFlyteAdminErrorf(codes.InvalidArgument, "bar") - assert.False(t, isNotFoundErr(invalidArgs)) -} - -func TestGetOverrideValuesToApply(t *testing.T) { - db := mocks.NewMockRepository() - matchingWorkflowAttributes := &admin.MatchingAttributes{ - Target: &admin.MatchingAttributes_ExecutionQueueAttributes{ - ExecutionQueueAttributes: &admin.ExecutionQueueAttributes{ - Tags: []string{"attr3"}, - }, - }, - } - db.WorkflowAttributesRepo().(*mocks.MockWorkflowAttributesRepo).GetFunction = func( - ctx context.Context, project, domain, workflow, resource string) ( - models.WorkflowAttributes, error) { - if project == testProject && domain == testDomain && workflow == testWorkflow && - resource == admin.MatchableResource_EXECUTION_QUEUE.String() { - - marshalledMatchingAttributes, _ := proto.Marshal(matchingWorkflowAttributes) - return models.WorkflowAttributes{ - Project: project, - Domain: domain, - Workflow: workflow, - Resource: resource, - Attributes: marshalledMatchingAttributes, - }, nil - } - if workflow == "error" { - return models.WorkflowAttributes{}, errors.NewFlyteAdminErrorf(codes.InvalidArgument, "bar") - } - return models.WorkflowAttributes{}, errors.NewFlyteAdminError(codes.NotFound, "foo") - } - matchingProjectDomainAttributes := &admin.MatchingAttributes{ - Target: &admin.MatchingAttributes_ExecutionQueueAttributes{ - ExecutionQueueAttributes: &admin.ExecutionQueueAttributes{ - Tags: []string{"attr2"}, - }, - }, - } - db.ProjectDomainAttributesRepo().(*mocks.MockProjectDomainAttributesRepo).GetFunction = func( - ctx context.Context, project, domain, resource string) (models.ProjectDomainAttributes, error) { - if project == testProject && domain == testDomain && resource == admin.MatchableResource_EXECUTION_QUEUE.String() { - marshalledMatchingAttributes, _ := proto.Marshal(matchingProjectDomainAttributes) - return models.ProjectDomainAttributes{ - Project: project, - Domain: domain, - Resource: resource, - Attributes: marshalledMatchingAttributes, - }, nil - } - return models.ProjectDomainAttributes{}, errors.NewFlyteAdminError(codes.NotFound, "foo") - } - matchingProjectAttributes := &admin.MatchingAttributes{ - Target: &admin.MatchingAttributes_ExecutionQueueAttributes{ - ExecutionQueueAttributes: &admin.ExecutionQueueAttributes{ - Tags: []string{"attr1"}, - }, - }, - } - db.ProjectAttributesRepo().(*mocks.MockProjectAttributesRepo).GetFunction = func( - ctx context.Context, project, resource string) (models.ProjectAttributes, error) { - if project == testProject && resource == admin.MatchableResource_EXECUTION_QUEUE.String() { - marshalledMatchingAttributes, _ := proto.Marshal(matchingProjectAttributes) - return models.ProjectAttributes{ - Project: project, - Resource: resource, - Attributes: marshalledMatchingAttributes, - }, nil - } - return models.ProjectAttributes{}, errors.NewFlyteAdminError(codes.NotFound, "foo") - } - - testCases := []struct { - input GetOverrideValuesInput - expectedMatchingAttributes *admin.MatchingAttributes - expectedErr error - }{ - { - GetOverrideValuesInput{ - Db: db, - Project: "project", - Domain: "domain", - Workflow: "workflow", - Resource: admin.MatchableResource_EXECUTION_QUEUE, - }, - matchingWorkflowAttributes, - nil, - }, - { - GetOverrideValuesInput{ - Db: db, - Project: "project", - Domain: "domain", - Workflow: "workflow2", - Resource: admin.MatchableResource_EXECUTION_QUEUE, - }, - matchingProjectDomainAttributes, - nil, - }, - { - GetOverrideValuesInput{ - Db: db, - Project: "project", - Domain: "domain2", - Workflow: "workflow", - Resource: admin.MatchableResource_EXECUTION_QUEUE, - }, - matchingProjectAttributes, - nil, - }, - { - GetOverrideValuesInput{ - Db: db, - Project: "project2", - Domain: "domain", - Workflow: "workflow", - Resource: admin.MatchableResource_EXECUTION_QUEUE, - }, - nil, - nil, - }, - {GetOverrideValuesInput{ - Db: db, - Project: "project", - Domain: "domain", - Workflow: "error", - Resource: admin.MatchableResource_EXECUTION_QUEUE, - }, - nil, - errors.NewFlyteAdminErrorf(codes.InvalidArgument, "bar"), - }, - } - for _, tc := range testCases { - matchingAttributes, err := GetOverrideValuesToApply(context.Background(), tc.input) - assert.True(t, proto.Equal(tc.expectedMatchingAttributes, matchingAttributes), - fmt.Sprintf("invalid value for [%+v]", tc.input)) - assert.EqualValues(t, tc.expectedErr, err) - } -} diff --git a/flyteadmin/pkg/rpc/adminservice/attributes.go b/flyteadmin/pkg/rpc/adminservice/attributes.go index c37414415..af19df4a3 100644 --- a/flyteadmin/pkg/rpc/adminservice/attributes.go +++ b/flyteadmin/pkg/rpc/adminservice/attributes.go @@ -22,7 +22,7 @@ func (m *AdminService) UpdateWorkflowAttributes(ctx context.Context, request *ad var response *admin.WorkflowAttributesUpdateResponse var err error m.Metrics.workflowAttributesEndpointMetrics.update.Time(func() { - response, err = m.WorkflowAttributesManager.UpdateWorkflowAttributes(ctx, *request) + response, err = m.ResourceManager.UpdateWorkflowAttributes(ctx, *request) }) audit.NewLogBuilder().WithAuthenticatedCtx(ctx).WithRequest( "UpdateWorkflowAttributes", @@ -51,7 +51,7 @@ func (m *AdminService) GetWorkflowAttributes(ctx context.Context, request *admin var response *admin.WorkflowAttributesGetResponse var err error m.Metrics.workflowAttributesEndpointMetrics.get.Time(func() { - response, err = m.WorkflowAttributesManager.GetWorkflowAttributes(ctx, *request) + response, err = m.ResourceManager.GetWorkflowAttributes(ctx, *request) }) audit.NewLogBuilder().WithAuthenticatedCtx(ctx).WithRequest( "GetWorkflowAttributes", @@ -80,7 +80,7 @@ func (m *AdminService) DeleteWorkflowAttributes(ctx context.Context, request *ad var response *admin.WorkflowAttributesDeleteResponse var err error m.Metrics.workflowAttributesEndpointMetrics.delete.Time(func() { - response, err = m.WorkflowAttributesManager.DeleteWorkflowAttributes(ctx, *request) + response, err = m.ResourceManager.DeleteWorkflowAttributes(ctx, *request) }) audit.NewLogBuilder().WithAuthenticatedCtx(ctx).WithRequest( "DeleteWorkflowAttributes", @@ -109,7 +109,7 @@ func (m *AdminService) UpdateProjectDomainAttributes(ctx context.Context, reques var response *admin.ProjectDomainAttributesUpdateResponse var err error m.Metrics.projectDomainAttributesEndpointMetrics.update.Time(func() { - response, err = m.ProjectDomainAttributesManager.UpdateProjectDomainAttributes(ctx, *request) + response, err = m.ResourceManager.UpdateProjectDomainAttributes(ctx, *request) }) audit.NewLogBuilder().WithAuthenticatedCtx(ctx).WithRequest( "UpdateProjectDomainAttributes", @@ -137,7 +137,7 @@ func (m *AdminService) GetProjectDomainAttributes(ctx context.Context, request * var response *admin.ProjectDomainAttributesGetResponse var err error m.Metrics.workflowAttributesEndpointMetrics.get.Time(func() { - response, err = m.ProjectDomainAttributesManager.GetProjectDomainAttributes(ctx, *request) + response, err = m.ResourceManager.GetProjectDomainAttributes(ctx, *request) }) audit.NewLogBuilder().WithAuthenticatedCtx(ctx).WithRequest( "GetProjectDomainAttributes", @@ -165,7 +165,7 @@ func (m *AdminService) DeleteProjectDomainAttributes(ctx context.Context, reques var response *admin.ProjectDomainAttributesDeleteResponse var err error m.Metrics.workflowAttributesEndpointMetrics.delete.Time(func() { - response, err = m.ProjectDomainAttributesManager.DeleteProjectDomainAttributes(ctx, *request) + response, err = m.ResourceManager.DeleteProjectDomainAttributes(ctx, *request) }) audit.NewLogBuilder().WithAuthenticatedCtx(ctx).WithRequest( "DeleteProjectDomainAttributes", @@ -182,84 +182,3 @@ func (m *AdminService) DeleteProjectDomainAttributes(ctx context.Context, reques return response, nil } - -func (m *AdminService) UpdateProjectAttributes(ctx context.Context, request *admin.ProjectAttributesUpdateRequest) ( - *admin.ProjectAttributesUpdateResponse, error) { - defer m.interceptPanic(ctx, request) - requestedAt := time.Now() - if request == nil { - return nil, status.Errorf(codes.InvalidArgument, "Incorrect request, nil requests not allowed") - } - var response *admin.ProjectAttributesUpdateResponse - var err error - m.Metrics.projectAttributesEndpointMetrics.update.Time(func() { - response, err = m.ProjectAttributesManager.UpdateProjectAttributes(ctx, *request) - }) - if err != nil { - return nil, util.TransformAndRecordError(err, &m.Metrics.projectAttributesEndpointMetrics.update) - } - audit.NewLogBuilder().WithAuthenticatedCtx(ctx).WithRequest( - "UpdateProjectAttributes", - map[string]string{ - audit.Project: request.Attributes.Project, - }, - audit.ReadWrite, - requestedAt, - ).WithResponse(time.Now(), err).Log(ctx) - - return response, nil -} - -func (m *AdminService) GetProjectAttributes(ctx context.Context, request *admin.ProjectAttributesGetRequest) ( - *admin.ProjectAttributesGetResponse, error) { - defer m.interceptPanic(ctx, request) - requestedAt := time.Now() - if request == nil { - return nil, status.Errorf(codes.InvalidArgument, "Incorrect request, nil requests not allowed") - } - var response *admin.ProjectAttributesGetResponse - var err error - m.Metrics.workflowAttributesEndpointMetrics.get.Time(func() { - response, err = m.ProjectAttributesManager.GetProjectAttributes(ctx, *request) - }) - audit.NewLogBuilder().WithAuthenticatedCtx(ctx).WithRequest( - "GetProjectAttributes", - map[string]string{ - audit.Project: request.Project, - }, - audit.ReadOnly, - requestedAt, - ).WithResponse(time.Now(), err).Log(ctx) - if err != nil { - return nil, util.TransformAndRecordError(err, &m.Metrics.workflowAttributesEndpointMetrics.get) - } - - return response, nil -} - -func (m *AdminService) DeleteProjectAttributes(ctx context.Context, request *admin.ProjectAttributesDeleteRequest) ( - *admin.ProjectAttributesDeleteResponse, error) { - defer m.interceptPanic(ctx, request) - requestedAt := time.Now() - if request == nil { - return nil, status.Errorf(codes.InvalidArgument, "Incorrect request, nil requests not allowed") - } - var response *admin.ProjectAttributesDeleteResponse - var err error - m.Metrics.workflowAttributesEndpointMetrics.delete.Time(func() { - response, err = m.ProjectAttributesManager.DeleteProjectAttributes(ctx, *request) - }) - audit.NewLogBuilder().WithAuthenticatedCtx(ctx).WithRequest( - "DeleteProjectAttributes", - map[string]string{ - audit.Project: request.Project, - }, - audit.ReadWrite, - requestedAt, - ).WithResponse(time.Now(), err).Log(ctx) - if err != nil { - return nil, util.TransformAndRecordError(err, &m.Metrics.workflowAttributesEndpointMetrics.delete) - } - - return response, nil -} diff --git a/flyteadmin/pkg/rpc/adminservice/base.go b/flyteadmin/pkg/rpc/adminservice/base.go index a20d83586..b689f3078 100644 --- a/flyteadmin/pkg/rpc/adminservice/base.go +++ b/flyteadmin/pkg/rpc/adminservice/base.go @@ -5,6 +5,8 @@ import ( "fmt" "runtime/debug" + "github.com/lyft/flyteadmin/pkg/manager/impl/resources" + "github.com/golang/protobuf/proto" "github.com/lyft/flyteadmin/pkg/async/notifications" "github.com/lyft/flyteadmin/pkg/async/schedule" @@ -23,18 +25,16 @@ import ( ) type AdminService struct { - TaskManager interfaces.TaskInterface - WorkflowManager interfaces.WorkflowInterface - LaunchPlanManager interfaces.LaunchPlanInterface - ExecutionManager interfaces.ExecutionInterface - NodeExecutionManager interfaces.NodeExecutionInterface - TaskExecutionManager interfaces.TaskExecutionInterface - ProjectManager interfaces.ProjectInterface - ProjectAttributesManager interfaces.ProjectAttributesInterface - ProjectDomainAttributesManager interfaces.ProjectDomainAttributesInterface - WorkflowAttributesManager interfaces.WorkflowAttributesInterface - NamedEntityManager interfaces.NamedEntityInterface - Metrics AdminMetrics + TaskManager interfaces.TaskInterface + WorkflowManager interfaces.WorkflowInterface + LaunchPlanManager interfaces.LaunchPlanInterface + ExecutionManager interfaces.ExecutionInterface + NodeExecutionManager interfaces.NodeExecutionInterface + TaskExecutionManager interfaces.TaskExecutionInterface + ProjectManager interfaces.ProjectInterface + ResourceManager interfaces.ResourceInterface + NamedEntityManager interfaces.NamedEntityInterface + Metrics AdminMetrics } // Intercepts all admin requests to handle panics during execution. @@ -161,10 +161,8 @@ func NewAdminServer(kubeConfig, master string) *AdminService { db, adminScope.NewSubScope("node_execution_manager"), urlData), TaskExecutionManager: manager.NewTaskExecutionManager( db, adminScope.NewSubScope("task_execution_manager"), urlData), - ProjectManager: manager.NewProjectManager(db, configuration), - ProjectAttributesManager: manager.NewProjectAttributesManager(db), - ProjectDomainAttributesManager: manager.NewProjectDomainAttributesManager(db), - WorkflowAttributesManager: manager.NewWorkflowAttributesManager(db), - Metrics: InitMetrics(adminScope), + ProjectManager: manager.NewProjectManager(db, configuration), + ResourceManager: resources.NewResourceManager(db), + Metrics: InitMetrics(adminScope), } } diff --git a/flyteadmin/pkg/rpc/adminservice/tests/project_domain_test.go b/flyteadmin/pkg/rpc/adminservice/tests/project_domain_test.go index 11b39df16..71ad10df7 100644 --- a/flyteadmin/pkg/rpc/adminservice/tests/project_domain_test.go +++ b/flyteadmin/pkg/rpc/adminservice/tests/project_domain_test.go @@ -12,7 +12,7 @@ import ( func TestUpdateProjectDomain(t *testing.T) { ctx := context.Background() - mockProjectDomainManager := mocks.MockProjectDomainAttributesManager{} + mockProjectDomainManager := mocks.MockResourceManager{} var updateCalled bool mockProjectDomainManager.SetUpdateProjectDomainAttributes( func(ctx context.Context, @@ -22,7 +22,7 @@ func TestUpdateProjectDomain(t *testing.T) { }, ) mockServer := NewMockAdminServer(NewMockAdminServerInput{ - projectDomainAttributesManager: &mockProjectDomainManager, + resourceManager: &mockProjectDomainManager, }) resp, err := mockServer.UpdateProjectDomainAttributes(ctx, &admin.ProjectDomainAttributesUpdateRequest{ diff --git a/flyteadmin/pkg/rpc/adminservice/tests/util.go b/flyteadmin/pkg/rpc/adminservice/tests/util.go index 36fefe74c..544a3ef95 100644 --- a/flyteadmin/pkg/rpc/adminservice/tests/util.go +++ b/flyteadmin/pkg/rpc/adminservice/tests/util.go @@ -7,27 +7,27 @@ import ( ) type NewMockAdminServerInput struct { - executionManager *mocks.MockExecutionManager - launchPlanManager *mocks.MockLaunchPlanManager - nodeExecutionManager *mocks.MockNodeExecutionManager - projectManager *mocks.MockProjectManager - projectDomainAttributesManager *mocks.MockProjectDomainAttributesManager - taskManager *mocks.MockTaskManager - workflowManager *mocks.MockWorkflowManager - taskExecutionManager *mocks.MockTaskExecutionManager + executionManager *mocks.MockExecutionManager + launchPlanManager *mocks.MockLaunchPlanManager + nodeExecutionManager *mocks.MockNodeExecutionManager + projectManager *mocks.MockProjectManager + resourceManager *mocks.MockResourceManager + taskManager *mocks.MockTaskManager + workflowManager *mocks.MockWorkflowManager + taskExecutionManager *mocks.MockTaskExecutionManager } func NewMockAdminServer(input NewMockAdminServerInput) *adminservice.AdminService { var testScope = mockScope.NewTestScope() return &adminservice.AdminService{ - ExecutionManager: input.executionManager, - LaunchPlanManager: input.launchPlanManager, - NodeExecutionManager: input.nodeExecutionManager, - TaskManager: input.taskManager, - ProjectManager: input.projectManager, - ProjectDomainAttributesManager: input.projectDomainAttributesManager, - WorkflowManager: input.workflowManager, - TaskExecutionManager: input.taskExecutionManager, - Metrics: adminservice.InitMetrics(testScope), + ExecutionManager: input.executionManager, + LaunchPlanManager: input.launchPlanManager, + NodeExecutionManager: input.nodeExecutionManager, + TaskManager: input.taskManager, + ProjectManager: input.projectManager, + ResourceManager: input.resourceManager, + WorkflowManager: input.workflowManager, + TaskExecutionManager: input.taskExecutionManager, + Metrics: adminservice.InitMetrics(testScope), } } diff --git a/flyteadmin/tests/attributes_test.go b/flyteadmin/tests/attributes_test.go index 842a5860f..b75891fc2 100644 --- a/flyteadmin/tests/attributes_test.go +++ b/flyteadmin/tests/attributes_test.go @@ -24,58 +24,13 @@ var matchingAttributes = &admin.MatchingAttributes{ }, } -func TestProjectAttributes(t *testing.T) { - ctx := context.Background() - - db := databaseConfig.OpenDbConnection(databaseConfig.NewPostgresConfigProvider(getLocalDbConfig(), adminScope)) - truncateTableForTesting(db, "project_attributes") - db.Close() - - client, conn := GetTestAdminServiceClient() - defer conn.Close() - - req := admin.ProjectAttributesUpdateRequest{ - Attributes: &admin.ProjectAttributes{ - Project: "admintests", - MatchingAttributes: matchingAttributes, - }, - } - - _, err := client.UpdateProjectAttributes(ctx, &req) - assert.Nil(t, err) - - response, err := client.GetProjectAttributes(ctx, &admin.ProjectAttributesGetRequest{ - Project: "admintests", - ResourceType: admin.MatchableResource_TASK_RESOURCE, - }) - assert.Nil(t, err) - assert.True(t, proto.Equal(&admin.ProjectAttributesGetResponse{ - Attributes: &admin.ProjectAttributes{ - Project: "admintests", - MatchingAttributes: matchingAttributes, - }, - }, response)) - - _, err = client.DeleteProjectAttributes(ctx, &admin.ProjectAttributesDeleteRequest{ - Project: "admintests", - ResourceType: admin.MatchableResource_TASK_RESOURCE, - }) - assert.Nil(t, err) - - _, err = client.GetProjectAttributes(ctx, &admin.ProjectAttributesGetRequest{ - Project: "admintests", - ResourceType: admin.MatchableResource_TASK_RESOURCE, - }) - assert.EqualError(t, err, "rpc error: code = NotFound desc = entry not found") -} - func TestUpdateProjectDomainAttributes(t *testing.T) { ctx := context.Background() client, conn := GetTestAdminServiceClient() defer conn.Close() - db := databaseConfig.OpenDbConnection(databaseConfig.NewPostgresConfigProvider(getLocalDbConfig(), adminScope)) - truncateTableForTesting(db, "project_domain_attributes") + db := databaseConfig.OpenDbConnection(databaseConfig.NewPostgresConfigProvider(getDbConfig(), adminScope)) + truncateTableForTesting(db, "resources") db.Close() req := admin.ProjectDomainAttributesUpdateRequest{ @@ -103,6 +58,23 @@ func TestUpdateProjectDomainAttributes(t *testing.T) { }, }, response)) + workflowResponse, err := client.GetWorkflowAttributes(ctx, &admin.WorkflowAttributesGetRequest{ + Project: "admintests", + Domain: "development", + Workflow: "workflow", + ResourceType: admin.MatchableResource_TASK_RESOURCE, + }) + assert.Nil(t, err) + // Testing that if overrides are not set at workflow level, the one from Project-Domain is returned + assert.True(t, proto.Equal(&admin.WorkflowAttributesGetResponse{ + Attributes: &admin.WorkflowAttributes{ + Project: "admintests", + Domain: "development", + Workflow: "", + MatchingAttributes: matchingAttributes, + }, + }, workflowResponse)) + _, err = client.DeleteProjectDomainAttributes(ctx, &admin.ProjectDomainAttributesDeleteRequest{ Project: "admintests", Domain: "development", @@ -110,12 +82,13 @@ func TestUpdateProjectDomainAttributes(t *testing.T) { }) assert.Nil(t, err) - _, err = client.GetProjectDomainAttributes(ctx, &admin.ProjectDomainAttributesGetRequest{ + response, err = client.GetProjectDomainAttributes(ctx, &admin.ProjectDomainAttributesGetRequest{ Project: "admintests", Domain: "development", ResourceType: admin.MatchableResource_TASK_RESOURCE, }) - assert.EqualError(t, err, "rpc error: code = NotFound desc = entry not found") + assert.Nil(t, response) + assert.EqualError(t, err, "rpc error: code = NotFound desc = {Project:admintests Domain:development Workflow: LaunchPlan: ResourceType:TASK_RESOURCE}") } func TestUpdateWorkflowAttributes(t *testing.T) { @@ -123,8 +96,8 @@ func TestUpdateWorkflowAttributes(t *testing.T) { client, conn := GetTestAdminServiceClient() defer conn.Close() - db := databaseConfig.OpenDbConnection(databaseConfig.NewPostgresConfigProvider(getLocalDbConfig(), adminScope)) - truncateTableForTesting(db, "workflow_attributes") + db := databaseConfig.OpenDbConnection(databaseConfig.NewPostgresConfigProvider(getDbConfig(), adminScope)) + truncateTableForTesting(db, "resources") db.Close() req := admin.WorkflowAttributesUpdateRequest{ @@ -169,5 +142,5 @@ func TestUpdateWorkflowAttributes(t *testing.T) { Workflow: "workflow", ResourceType: admin.MatchableResource_TASK_RESOURCE, }) - assert.EqualError(t, err, "rpc error: code = NotFound desc = entry not found") + assert.EqualError(t, err, "rpc error: code = NotFound desc = {Project:admintests Domain:development Workflow:workflow LaunchPlan: ResourceType:TASK_RESOURCE}") } diff --git a/flyteadmin/tests/bootstrap.go b/flyteadmin/tests/bootstrap.go index 559b70ba0..a1b015403 100644 --- a/flyteadmin/tests/bootstrap.go +++ b/flyteadmin/tests/bootstrap.go @@ -53,6 +53,7 @@ func truncateAllTablesForTestingOnly() { TruncateNodeExecutions := fmt.Sprintf("TRUNCATE TABLE node_executions;") TruncateNodeExecutionEvents := fmt.Sprintf("TRUNCATE TABLE node_execution_events;") TruncateTaskExecutions := fmt.Sprintf("TRUNCATE TABLE task_executions;") + TruncateResources := fmt.Sprintf("TRUNCATE TABLE resources;") db := database_config.OpenDbConnection(database_config.NewPostgresConfigProvider(getDbConfig(), adminScope)) defer db.Close() db.Exec(TruncateTasks) @@ -64,6 +65,7 @@ func truncateAllTablesForTestingOnly() { db.Exec(TruncateNodeExecutions) db.Exec(TruncateNodeExecutionEvents) db.Exec(TruncateTaskExecutions) + db.Exec(TruncateResources) } func populateWorkflowExecutionForTestingOnly(project, domain, name string) {