diff --git a/flyteadmin/pkg/manager/impl/execution_manager.go b/flyteadmin/pkg/manager/impl/execution_manager.go index 2e33fb7585..a2ac00ce61 100644 --- a/flyteadmin/pkg/manager/impl/execution_manager.go +++ b/flyteadmin/pkg/manager/impl/execution_manager.go @@ -14,8 +14,6 @@ import ( "github.com/flyteorg/flyteadmin/auth" - "k8s.io/apimachinery/pkg/api/resource" - "github.com/flyteorg/flyteadmin/pkg/manager/impl/resources" dataInterfaces "github.com/flyteorg/flyteadmin/pkg/data/interfaces" @@ -186,39 +184,6 @@ func (m *ExecutionManager) addPluginOverrides(ctx context.Context, executionID * return nil, nil } -type completeTaskResources struct { - Defaults runtimeInterfaces.TaskResourceSet - Limits runtimeInterfaces.TaskResourceSet -} - -func getTaskResourcesAsSet(ctx context.Context, identifier *core.Identifier, - resourceEntries []*core.Resources_ResourceEntry, resourceName string) runtimeInterfaces.TaskResourceSet { - - result := runtimeInterfaces.TaskResourceSet{} - for _, entry := range resourceEntries { - switch entry.Name { - case core.Resources_CPU: - result.CPU = parseQuantityNoError(ctx, identifier.String(), fmt.Sprintf("%v.cpu", resourceName), entry.Value) - case core.Resources_MEMORY: - result.Memory = parseQuantityNoError(ctx, identifier.String(), fmt.Sprintf("%v.memory", resourceName), entry.Value) - case core.Resources_EPHEMERAL_STORAGE: - result.EphemeralStorage = parseQuantityNoError(ctx, identifier.String(), - fmt.Sprintf("%v.ephemeral storage", resourceName), entry.Value) - case core.Resources_GPU: - result.GPU = parseQuantityNoError(ctx, identifier.String(), "gpu", entry.Value) - } - } - - return result -} - -func getCompleteTaskResourceRequirements(ctx context.Context, identifier *core.Identifier, task *core.CompiledTask) completeTaskResources { - return completeTaskResources{ - Defaults: getTaskResourcesAsSet(ctx, identifier, task.GetTemplate().GetContainer().Resources.Requests, "requests"), - Limits: getTaskResourcesAsSet(ctx, identifier, task.GetTemplate().GetContainer().Resources.Limits, "limits"), - } -} - // TODO: Delete this code usage after the flyte v0.17.0 release // Assumes input contains a compiled task with a valid container resource execConfig. // @@ -254,7 +219,7 @@ func (m *ExecutionManager) setCompiledTaskDefaults(ctx context.Context, task *co // The IDL representation for container-type tasks represents resources as a list with string quantities. // In order to easily reason about them we convert them to a set where we can O(1) fetch specific resources (e.g. CPU) // and represent them as comparable quantities rather than strings. - taskResourceRequirements := getCompleteTaskResourceRequirements(ctx, task.Template.Id, task) + taskResourceRequirements := util.GetCompleteTaskResourceRequirements(ctx, task.Template.Id, task) cpu := flytek8s.AdjustOrDefaultResource(taskResourceRequirements.Defaults.CPU, taskResourceRequirements.Limits.CPU, platformTaskResources.Defaults.CPU, platformTaskResources.Limits.CPU) @@ -334,68 +299,6 @@ func (m *ExecutionManager) setCompiledTaskDefaults(ctx context.Context, task *co } } -func parseQuantityNoError(ctx context.Context, ownerID, name, value string) resource.Quantity { - q, err := resource.ParseQuantity(value) - if err != nil { - logger.Infof(ctx, "Failed to parse owner's [%s] resource [%s]'s value [%s] with err: %v", ownerID, name, value, err) - } - - return q -} - -func fromAdminProtoTaskResourceSpec(ctx context.Context, spec *admin.TaskResourceSpec) runtimeInterfaces.TaskResourceSet { - result := runtimeInterfaces.TaskResourceSet{} - if len(spec.Cpu) > 0 { - result.CPU = parseQuantityNoError(ctx, "project", "cpu", spec.Cpu) - } - - if len(spec.Memory) > 0 { - result.Memory = parseQuantityNoError(ctx, "project", "memory", spec.Memory) - } - - if len(spec.Storage) > 0 { - result.Storage = parseQuantityNoError(ctx, "project", "storage", spec.Storage) - } - - if len(spec.EphemeralStorage) > 0 { - result.EphemeralStorage = parseQuantityNoError(ctx, "project", "ephemeral storage", spec.EphemeralStorage) - } - - if len(spec.Gpu) > 0 { - result.GPU = parseQuantityNoError(ctx, "project", "gpu", spec.Gpu) - } - - return result -} - -func (m *ExecutionManager) getTaskResources(ctx context.Context, workflow *core.Identifier) workflowengineInterfaces.TaskResources { - resource, err := m.resourceManager.GetResource(ctx, interfaces.ResourceRequest{ - Project: workflow.Project, - Domain: workflow.Domain, - Workflow: workflow.Name, - ResourceType: admin.MatchableResource_TASK_RESOURCE, - }) - - if err != nil { - logger.Warningf(ctx, "Failed to fetch override values when assigning task resource default values for [%+v]: %v", - workflow, err) - } - - logger.Debugf(ctx, "Assigning task requested resources for [%+v]", workflow) - var taskResourceAttributes = workflowengineInterfaces.TaskResources{} - if resource != nil && resource.Attributes != nil && resource.Attributes.GetTaskResourceAttributes() != nil { - taskResourceAttributes.Defaults = fromAdminProtoTaskResourceSpec(ctx, resource.Attributes.GetTaskResourceAttributes().Defaults) - taskResourceAttributes.Limits = fromAdminProtoTaskResourceSpec(ctx, resource.Attributes.GetTaskResourceAttributes().Limits) - } else { - taskResourceAttributes = workflowengineInterfaces.TaskResources{ - Defaults: m.config.TaskResourceConfiguration().GetDefaults(), - Limits: m.config.TaskResourceConfiguration().GetLimits(), - } - } - - return taskResourceAttributes -} - // Fetches inherited execution metadata including the parent node execution db model id and the source execution model id // as well as sets request spec metadata with the inherited principal and adjusted nesting data. func (m *ExecutionManager) getInheritedExecMetadata(ctx context.Context, requestSpec *admin.ExecutionSpec, @@ -612,7 +515,7 @@ func (m *ExecutionManager) launchSingleTaskExecution( } // Dynamically assign task resource defaults. - platformTaskResources := m.getTaskResources(ctx, workflow.Id) + platformTaskResources := util.GetTaskResources(ctx, workflow.Id, m.resourceManager, m.config.TaskResourceConfiguration()) for _, t := range workflow.Closure.CompiledWorkflow.Tasks { m.setCompiledTaskDefaults(ctx, t, platformTaskResources) } @@ -863,8 +766,8 @@ func (m *ExecutionManager) launchExecutionAndPrepareModel( return nil, nil, err } - platformTaskResources := m.getTaskResources(ctx, workflow.Id) // Dynamically assign task resource defaults. + platformTaskResources := util.GetTaskResources(ctx, workflow.Id, m.resourceManager, m.config.TaskResourceConfiguration()) for _, task := range workflow.Closure.CompiledWorkflow.Tasks { m.setCompiledTaskDefaults(ctx, task, platformTaskResources) } diff --git a/flyteadmin/pkg/manager/impl/execution_manager_test.go b/flyteadmin/pkg/manager/impl/execution_manager_test.go index 2ccd7b7c8a..e8d16d348a 100644 --- a/flyteadmin/pkg/manager/impl/execution_manager_test.go +++ b/flyteadmin/pkg/manager/impl/execution_manager_test.go @@ -3725,90 +3725,6 @@ func TestListExecutions_LegacyModel(t *testing.T) { assert.Empty(t, executionList.Token) } -func TestGetTaskResourcesAsSet(t *testing.T) { - taskResources := getTaskResourcesAsSet(context.TODO(), &core.Identifier{}, []*core.Resources_ResourceEntry{ - { - Name: core.Resources_CPU, - Value: "100", - }, - { - Name: core.Resources_MEMORY, - Value: "200", - }, - { - Name: core.Resources_EPHEMERAL_STORAGE, - Value: "300", - }, - { - Name: core.Resources_GPU, - Value: "400", - }, - }, "request") - assert.True(t, taskResources.CPU.Equal(resource.MustParse("100"))) - assert.True(t, taskResources.Memory.Equal(resource.MustParse("200"))) - assert.True(t, taskResources.EphemeralStorage.Equal(resource.MustParse("300"))) - assert.True(t, taskResources.GPU.Equal(resource.MustParse("400"))) -} - -func TestGetCompleteTaskResourceRequirements(t *testing.T) { - taskResources := getCompleteTaskResourceRequirements(context.TODO(), &core.Identifier{}, &core.CompiledTask{ - Template: &core.TaskTemplate{ - Target: &core.TaskTemplate_Container{ - Container: &core.Container{ - Resources: &core.Resources{ - Requests: []*core.Resources_ResourceEntry{ - { - Name: core.Resources_CPU, - Value: "100", - }, - { - Name: core.Resources_MEMORY, - Value: "200", - }, - { - Name: core.Resources_EPHEMERAL_STORAGE, - Value: "300", - }, - { - Name: core.Resources_GPU, - Value: "400", - }, - }, - Limits: []*core.Resources_ResourceEntry{ - { - Name: core.Resources_CPU, - Value: "200", - }, - { - Name: core.Resources_MEMORY, - Value: "400", - }, - { - Name: core.Resources_EPHEMERAL_STORAGE, - Value: "600", - }, - { - Name: core.Resources_GPU, - Value: "800", - }, - }, - }, - }, - }, - }, - }) - - assert.True(t, taskResources.Defaults.CPU.Equal(resource.MustParse("100"))) - assert.True(t, taskResources.Defaults.Memory.Equal(resource.MustParse("200"))) - assert.True(t, taskResources.Defaults.EphemeralStorage.Equal(resource.MustParse("300"))) - assert.True(t, taskResources.Defaults.GPU.Equal(resource.MustParse("400"))) - - assert.True(t, taskResources.Limits.CPU.Equal(resource.MustParse("200"))) - assert.True(t, taskResources.Limits.Memory.Equal(resource.MustParse("400"))) - assert.True(t, taskResources.Limits.EphemeralStorage.Equal(resource.MustParse("600"))) - assert.True(t, taskResources.Limits.GPU.Equal(resource.MustParse("800"))) -} - func TestSetDefaults(t *testing.T) { task := &core.CompiledTask{ Template: &core.TaskTemplate{ @@ -5379,122 +5295,6 @@ func TestResolvePermissions(t *testing.T) { }) } -func TestGetTaskResources(t *testing.T) { - taskConfig := runtimeMocks.MockTaskResourceConfiguration{} - taskConfig.Defaults = runtimeInterfaces.TaskResourceSet{ - CPU: resource.MustParse("200m"), - GPU: resource.MustParse("8"), - Memory: resource.MustParse("200Gi"), - EphemeralStorage: resource.MustParse("500Mi"), - Storage: resource.MustParse("400Mi"), - } - taskConfig.Limits = runtimeInterfaces.TaskResourceSet{ - CPU: resource.MustParse("300m"), - GPU: resource.MustParse("8"), - Memory: resource.MustParse("500Gi"), - EphemeralStorage: resource.MustParse("501Mi"), - Storage: resource.MustParse("450Mi"), - } - mockConfig := runtimeMocks.NewMockConfigurationProvider( - testutils.GetApplicationConfigWithDefaultDomains(), nil, nil, &taskConfig, - runtimeMocks.NewMockWhitelistConfiguration(), nil) - - t.Run("use runtime application values", func(t *testing.T) { - r := plugins.NewRegistry() - r.RegisterDefault(plugins.PluginIDWorkflowExecutor, &defaultTestExecutor) - execManager := NewExecutionManager(repositoryMocks.NewMockRepository(), r, mockConfig, getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockScope.NewTestScope(), &mockPublisher, mockExecutionRemoteURL, nil, nil, nil, nil, &eventWriterMocks.WorkflowExecutionEventWriter{}) - taskResourceAttrs := execManager.(*ExecutionManager).getTaskResources(context.TODO(), &workflowIdentifier) - assert.EqualValues(t, taskResourceAttrs, workflowengineInterfaces.TaskResources{ - Defaults: runtimeInterfaces.TaskResourceSet{ - CPU: resource.MustParse("200m"), - GPU: resource.MustParse("8"), - Memory: resource.MustParse("200Gi"), - EphemeralStorage: resource.MustParse("500Mi"), - Storage: resource.MustParse("400Mi"), - }, - Limits: runtimeInterfaces.TaskResourceSet{ - CPU: resource.MustParse("300m"), - GPU: resource.MustParse("8"), - Memory: resource.MustParse("500Gi"), - EphemeralStorage: resource.MustParse("501Mi"), - Storage: resource.MustParse("450Mi"), - }, - }) - }) - t.Run("use specific overrides", func(t *testing.T) { - resourceManager := managerMocks.MockResourceManager{} - resourceManager.GetResourceFunc = func(ctx context.Context, - request managerInterfaces.ResourceRequest) (*managerInterfaces.ResourceResponse, error) { - assert.EqualValues(t, request, managerInterfaces.ResourceRequest{ - Project: workflowIdentifier.Project, - Domain: workflowIdentifier.Domain, - Workflow: workflowIdentifier.Name, - ResourceType: admin.MatchableResource_TASK_RESOURCE, - }) - return &managerInterfaces.ResourceResponse{ - Attributes: &admin.MatchingAttributes{ - Target: &admin.MatchingAttributes_TaskResourceAttributes{ - TaskResourceAttributes: &admin.TaskResourceAttributes{ - Defaults: &admin.TaskResourceSpec{ - Cpu: "1200m", - Gpu: "18", - Memory: "1200Gi", - EphemeralStorage: "1500Mi", - Storage: "1400Mi", - }, - Limits: &admin.TaskResourceSpec{ - Cpu: "300m", - Gpu: "8", - Memory: "500Gi", - EphemeralStorage: "501Mi", - Storage: "450Mi", - }, - }, - }, - }, - }, nil - } - executionManager := ExecutionManager{ - resourceManager: &resourceManager, - config: mockConfig, - } - taskResourceAttrs := executionManager.getTaskResources(context.TODO(), &workflowIdentifier) - assert.EqualValues(t, taskResourceAttrs, workflowengineInterfaces.TaskResources{ - Defaults: runtimeInterfaces.TaskResourceSet{ - CPU: resource.MustParse("1200m"), - GPU: resource.MustParse("18"), - Memory: resource.MustParse("1200Gi"), - EphemeralStorage: resource.MustParse("1500Mi"), - Storage: resource.MustParse("1400Mi"), - }, - Limits: runtimeInterfaces.TaskResourceSet{ - CPU: resource.MustParse("300m"), - GPU: resource.MustParse("8"), - Memory: resource.MustParse("500Gi"), - EphemeralStorage: resource.MustParse("501Mi"), - Storage: resource.MustParse("450Mi"), - }, - }) - }) -} - -func TestFromAdminProtoTaskResourceSpec(t *testing.T) { - taskResourceSet := fromAdminProtoTaskResourceSpec(context.TODO(), &admin.TaskResourceSpec{ - Cpu: "1", - Memory: "100", - Storage: "200", - EphemeralStorage: "300", - Gpu: "2", - }) - assert.EqualValues(t, runtimeInterfaces.TaskResourceSet{ - CPU: resource.MustParse("1"), - Memory: resource.MustParse("100"), - Storage: resource.MustParse("200"), - EphemeralStorage: resource.MustParse("300"), - GPU: resource.MustParse("2"), - }, taskResourceSet) -} - func TestAddStateFilter(t *testing.T) { t.Run("empty filters", func(t *testing.T) { var filters []common.InlineFilter diff --git a/flyteadmin/pkg/manager/impl/task_manager.go b/flyteadmin/pkg/manager/impl/task_manager.go index 50aa719936..b4346fcd9a 100644 --- a/flyteadmin/pkg/manager/impl/task_manager.go +++ b/flyteadmin/pkg/manager/impl/task_manager.go @@ -20,6 +20,7 @@ import ( "github.com/flyteorg/flyteadmin/pkg/common" "github.com/flyteorg/flyteadmin/pkg/errors" + "github.com/flyteorg/flyteadmin/pkg/manager/impl/resources" "github.com/flyteorg/flyteadmin/pkg/manager/impl/util" "github.com/flyteorg/flyteadmin/pkg/manager/impl/validation" "github.com/flyteorg/flyteadmin/pkg/manager/interfaces" @@ -38,10 +39,11 @@ type taskMetrics struct { } type TaskManager struct { - db repoInterfaces.Repository - config runtimeInterfaces.Configuration - compiler workflowengine.Compiler - metrics taskMetrics + db repoInterfaces.Repository + config runtimeInterfaces.Configuration + compiler workflowengine.Compiler + metrics taskMetrics + resourceManager interfaces.ResourceInterface } func getTaskContext(ctx context.Context, identifier *core.Identifier) context.Context { @@ -62,7 +64,8 @@ func setDefaults(request admin.TaskCreateRequest) (admin.TaskCreateRequest, erro func (t *TaskManager) CreateTask( ctx context.Context, request admin.TaskCreateRequest) (*admin.TaskCreateResponse, error) { - if err := validation.ValidateTask(ctx, request, t.db, t.config.TaskResourceConfiguration(), + platformTaskResources := util.GetTaskResources(ctx, request.Id, t.resourceManager, t.config.TaskResourceConfiguration()) + if err := validation.ValidateTask(ctx, request, t.db, platformTaskResources, t.config.WhitelistConfiguration(), t.config.ApplicationConfiguration()); err != nil { logger.Debugf(ctx, "Task [%+v] failed validation with err: %v", request.Id, err) return nil, err @@ -269,10 +272,12 @@ func NewTaskManager( ClosureSizeBytes: scope.MustNewSummary("closure_size_bytes", "size in bytes of serialized task closure"), Registered: labeled.NewCounter("num_registered", "count of registered tasks", scope), } + resourceManager := resources.NewResourceManager(db, config.ApplicationConfiguration()) return &TaskManager{ - db: db, - config: config, - compiler: compiler, - metrics: metrics, + db: db, + config: config, + compiler: compiler, + metrics: metrics, + resourceManager: resourceManager, } } diff --git a/flyteadmin/pkg/manager/impl/util/resources.go b/flyteadmin/pkg/manager/impl/util/resources.go new file mode 100644 index 0000000000..f096957231 --- /dev/null +++ b/flyteadmin/pkg/manager/impl/util/resources.go @@ -0,0 +1,120 @@ +package util + +import ( + "context" + "fmt" + + "github.com/flyteorg/flyteadmin/pkg/manager/interfaces" + runtimeInterfaces "github.com/flyteorg/flyteadmin/pkg/runtime/interfaces" + workflowengineInterfaces "github.com/flyteorg/flyteadmin/pkg/workflowengine/interfaces" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + "github.com/flyteorg/flytestdlib/logger" + "k8s.io/apimachinery/pkg/api/resource" +) + +// parseQuantityNoError parses the k8s defined resource quantity gracefully masking errors. +func parseQuantityNoError(ctx context.Context, ownerID, name, value string) resource.Quantity { + q, err := resource.ParseQuantity(value) + if err != nil { + logger.Infof(ctx, "Failed to parse owner's [%s] resource [%s]'s value [%s] with err: %v", ownerID, name, value, err) + } + + return q +} + +// getTaskResourcesAsSet converts a list of flyteidl `ResourceEntry` messages into a singular `TaskResourceSet`. +func getTaskResourcesAsSet(ctx context.Context, identifier *core.Identifier, + resourceEntries []*core.Resources_ResourceEntry, resourceName string) runtimeInterfaces.TaskResourceSet { + + result := runtimeInterfaces.TaskResourceSet{} + for _, entry := range resourceEntries { + switch entry.Name { + case core.Resources_CPU: + result.CPU = parseQuantityNoError(ctx, identifier.String(), fmt.Sprintf("%v.cpu", resourceName), entry.Value) + case core.Resources_MEMORY: + result.Memory = parseQuantityNoError(ctx, identifier.String(), fmt.Sprintf("%v.memory", resourceName), entry.Value) + case core.Resources_EPHEMERAL_STORAGE: + result.EphemeralStorage = parseQuantityNoError(ctx, identifier.String(), + fmt.Sprintf("%v.ephemeral storage", resourceName), entry.Value) + case core.Resources_GPU: + result.GPU = parseQuantityNoError(ctx, identifier.String(), "gpu", entry.Value) + } + } + + return result +} + +// GetCompleteTaskResourceRequirements parses the resource requests and limits from the `TaskTemplate` Container. +func GetCompleteTaskResourceRequirements(ctx context.Context, identifier *core.Identifier, task *core.CompiledTask) workflowengineInterfaces.TaskResources { + return workflowengineInterfaces.TaskResources{ + Defaults: getTaskResourcesAsSet(ctx, identifier, task.GetTemplate().GetContainer().Resources.Requests, "requests"), + Limits: getTaskResourcesAsSet(ctx, identifier, task.GetTemplate().GetContainer().Resources.Limits, "limits"), + } +} + +// fromAdminProtoTaskResourceSpec parses the flyteidl `TaskResourceSpec` message into a `TaskResourceSet`. +func fromAdminProtoTaskResourceSpec(ctx context.Context, spec *admin.TaskResourceSpec) runtimeInterfaces.TaskResourceSet { + result := runtimeInterfaces.TaskResourceSet{} + if len(spec.Cpu) > 0 { + result.CPU = parseQuantityNoError(ctx, "project", "cpu", spec.Cpu) + } + + if len(spec.Memory) > 0 { + result.Memory = parseQuantityNoError(ctx, "project", "memory", spec.Memory) + } + + if len(spec.Storage) > 0 { + result.Storage = parseQuantityNoError(ctx, "project", "storage", spec.Storage) + } + + if len(spec.EphemeralStorage) > 0 { + result.EphemeralStorage = parseQuantityNoError(ctx, "project", "ephemeral storage", spec.EphemeralStorage) + } + + if len(spec.Gpu) > 0 { + result.GPU = parseQuantityNoError(ctx, "project", "gpu", spec.Gpu) + } + + return result +} + +// GetTaskResources returns the most specific default and limit task resources for the specified id. This first checks +// if there is a matchable resource(s) defined, and uses the highest priority one, otherwise it falls back to using the +// flyteadmin default configured values. +func GetTaskResources(ctx context.Context, id *core.Identifier, resourceManager interfaces.ResourceInterface, + taskResourceConfig runtimeInterfaces.TaskResourceConfiguration) workflowengineInterfaces.TaskResources { + + request := interfaces.ResourceRequest{ + ResourceType: admin.MatchableResource_TASK_RESOURCE, + } + if id != nil && len(id.Project) > 0 { + request.Project = id.Project + } + if id != nil && len(id.Domain) > 0 { + request.Domain = id.Domain + } + if id != nil && id.ResourceType == core.ResourceType_WORKFLOW && len(id.Name) > 0 { + request.Workflow = id.Name + } + + resource, err := resourceManager.GetResource(ctx, request) + if err != nil { + logger.Warningf(ctx, "Failed to fetch override values when assigning task resource default values for [%+v]: %v", + id, err) + } + + logger.Debugf(ctx, "Assigning task requested resources for [%+v]", id) + var taskResourceAttributes = workflowengineInterfaces.TaskResources{} + if resource != nil && resource.Attributes != nil && resource.Attributes.GetTaskResourceAttributes() != nil { + taskResourceAttributes.Defaults = fromAdminProtoTaskResourceSpec(ctx, resource.Attributes.GetTaskResourceAttributes().Defaults) + taskResourceAttributes.Limits = fromAdminProtoTaskResourceSpec(ctx, resource.Attributes.GetTaskResourceAttributes().Limits) + } else { + taskResourceAttributes = workflowengineInterfaces.TaskResources{ + Defaults: taskResourceConfig.GetDefaults(), + Limits: taskResourceConfig.GetLimits(), + } + } + + return taskResourceAttributes +} diff --git a/flyteadmin/pkg/manager/impl/util/resources_test.go b/flyteadmin/pkg/manager/impl/util/resources_test.go new file mode 100644 index 0000000000..f4180c4b52 --- /dev/null +++ b/flyteadmin/pkg/manager/impl/util/resources_test.go @@ -0,0 +1,229 @@ +package util + +import ( + "context" + "testing" + + managerInterfaces "github.com/flyteorg/flyteadmin/pkg/manager/interfaces" + managerMocks "github.com/flyteorg/flyteadmin/pkg/manager/mocks" + runtimeInterfaces "github.com/flyteorg/flyteadmin/pkg/runtime/interfaces" + runtimeMocks "github.com/flyteorg/flyteadmin/pkg/runtime/mocks" + workflowengineInterfaces "github.com/flyteorg/flyteadmin/pkg/workflowengine/interfaces" + + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + + "github.com/stretchr/testify/assert" + + "k8s.io/apimachinery/pkg/api/resource" +) + +var workflowIdentifier = core.Identifier{ + ResourceType: core.ResourceType_WORKFLOW, + Project: "project", + Domain: "domain", + Name: "name", + Version: "version", +} + +func TestGetTaskResources(t *testing.T) { + taskConfig := runtimeMocks.MockTaskResourceConfiguration{} + taskConfig.Defaults = runtimeInterfaces.TaskResourceSet{ + CPU: resource.MustParse("200m"), + GPU: resource.MustParse("8"), + Memory: resource.MustParse("200Gi"), + EphemeralStorage: resource.MustParse("500Mi"), + Storage: resource.MustParse("400Mi"), + } + taskConfig.Limits = runtimeInterfaces.TaskResourceSet{ + CPU: resource.MustParse("300m"), + GPU: resource.MustParse("8"), + Memory: resource.MustParse("500Gi"), + EphemeralStorage: resource.MustParse("501Mi"), + Storage: resource.MustParse("450Mi"), + } + + t.Run("use runtime application values", func(t *testing.T) { + resourceManager := managerMocks.MockResourceManager{} + resourceManager.GetResourceFunc = func(ctx context.Context, + request managerInterfaces.ResourceRequest) (*managerInterfaces.ResourceResponse, error) { + assert.EqualValues(t, request, managerInterfaces.ResourceRequest{ + Project: workflowIdentifier.Project, + Domain: workflowIdentifier.Domain, + Workflow: workflowIdentifier.Name, + ResourceType: admin.MatchableResource_TASK_RESOURCE, + }) + return &managerInterfaces.ResourceResponse{}, nil + } + + taskResourceAttrs := GetTaskResources(context.TODO(), &workflowIdentifier, &resourceManager, &taskConfig) + assert.EqualValues(t, taskResourceAttrs, workflowengineInterfaces.TaskResources{ + Defaults: runtimeInterfaces.TaskResourceSet{ + CPU: resource.MustParse("200m"), + GPU: resource.MustParse("8"), + Memory: resource.MustParse("200Gi"), + EphemeralStorage: resource.MustParse("500Mi"), + Storage: resource.MustParse("400Mi"), + }, + Limits: runtimeInterfaces.TaskResourceSet{ + CPU: resource.MustParse("300m"), + GPU: resource.MustParse("8"), + Memory: resource.MustParse("500Gi"), + EphemeralStorage: resource.MustParse("501Mi"), + Storage: resource.MustParse("450Mi"), + }, + }) + }) + t.Run("use specific overrides", func(t *testing.T) { + resourceManager := managerMocks.MockResourceManager{} + resourceManager.GetResourceFunc = func(ctx context.Context, + request managerInterfaces.ResourceRequest) (*managerInterfaces.ResourceResponse, error) { + assert.EqualValues(t, request, managerInterfaces.ResourceRequest{ + Project: workflowIdentifier.Project, + Domain: workflowIdentifier.Domain, + Workflow: workflowIdentifier.Name, + ResourceType: admin.MatchableResource_TASK_RESOURCE, + }) + return &managerInterfaces.ResourceResponse{ + Attributes: &admin.MatchingAttributes{ + Target: &admin.MatchingAttributes_TaskResourceAttributes{ + TaskResourceAttributes: &admin.TaskResourceAttributes{ + Defaults: &admin.TaskResourceSpec{ + Cpu: "1200m", + Gpu: "18", + Memory: "1200Gi", + EphemeralStorage: "1500Mi", + Storage: "1400Mi", + }, + Limits: &admin.TaskResourceSpec{ + Cpu: "300m", + Gpu: "8", + Memory: "500Gi", + EphemeralStorage: "501Mi", + Storage: "450Mi", + }, + }, + }, + }, + }, nil + } + taskResourceAttrs := GetTaskResources(context.TODO(), &workflowIdentifier, &resourceManager, &taskConfig) + assert.EqualValues(t, taskResourceAttrs, workflowengineInterfaces.TaskResources{ + Defaults: runtimeInterfaces.TaskResourceSet{ + CPU: resource.MustParse("1200m"), + GPU: resource.MustParse("18"), + Memory: resource.MustParse("1200Gi"), + EphemeralStorage: resource.MustParse("1500Mi"), + Storage: resource.MustParse("1400Mi"), + }, + Limits: runtimeInterfaces.TaskResourceSet{ + CPU: resource.MustParse("300m"), + GPU: resource.MustParse("8"), + Memory: resource.MustParse("500Gi"), + EphemeralStorage: resource.MustParse("501Mi"), + Storage: resource.MustParse("450Mi"), + }, + }) + }) +} + +func TestFromAdminProtoTaskResourceSpec(t *testing.T) { + taskResourceSet := fromAdminProtoTaskResourceSpec(context.TODO(), &admin.TaskResourceSpec{ + Cpu: "1", + Memory: "100", + Storage: "200", + EphemeralStorage: "300", + Gpu: "2", + }) + assert.EqualValues(t, runtimeInterfaces.TaskResourceSet{ + CPU: resource.MustParse("1"), + Memory: resource.MustParse("100"), + Storage: resource.MustParse("200"), + EphemeralStorage: resource.MustParse("300"), + GPU: resource.MustParse("2"), + }, taskResourceSet) +} + +func TestGetTaskResourcesAsSet(t *testing.T) { + taskResources := getTaskResourcesAsSet(context.TODO(), &core.Identifier{}, []*core.Resources_ResourceEntry{ + { + Name: core.Resources_CPU, + Value: "100", + }, + { + Name: core.Resources_MEMORY, + Value: "200", + }, + { + Name: core.Resources_EPHEMERAL_STORAGE, + Value: "300", + }, + { + Name: core.Resources_GPU, + Value: "400", + }, + }, "request") + assert.True(t, taskResources.CPU.Equal(resource.MustParse("100"))) + assert.True(t, taskResources.Memory.Equal(resource.MustParse("200"))) + assert.True(t, taskResources.EphemeralStorage.Equal(resource.MustParse("300"))) + assert.True(t, taskResources.GPU.Equal(resource.MustParse("400"))) +} + +func TestGetCompleteTaskResourceRequirements(t *testing.T) { + taskResources := GetCompleteTaskResourceRequirements(context.TODO(), &core.Identifier{}, &core.CompiledTask{ + Template: &core.TaskTemplate{ + Target: &core.TaskTemplate_Container{ + Container: &core.Container{ + Resources: &core.Resources{ + Requests: []*core.Resources_ResourceEntry{ + { + Name: core.Resources_CPU, + Value: "100", + }, + { + Name: core.Resources_MEMORY, + Value: "200", + }, + { + Name: core.Resources_EPHEMERAL_STORAGE, + Value: "300", + }, + { + Name: core.Resources_GPU, + Value: "400", + }, + }, + Limits: []*core.Resources_ResourceEntry{ + { + Name: core.Resources_CPU, + Value: "200", + }, + { + Name: core.Resources_MEMORY, + Value: "400", + }, + { + Name: core.Resources_EPHEMERAL_STORAGE, + Value: "600", + }, + { + Name: core.Resources_GPU, + Value: "800", + }, + }, + }, + }, + }, + }, + }) + + assert.True(t, taskResources.Defaults.CPU.Equal(resource.MustParse("100"))) + assert.True(t, taskResources.Defaults.Memory.Equal(resource.MustParse("200"))) + assert.True(t, taskResources.Defaults.EphemeralStorage.Equal(resource.MustParse("300"))) + assert.True(t, taskResources.Defaults.GPU.Equal(resource.MustParse("400"))) + + assert.True(t, taskResources.Limits.CPU.Equal(resource.MustParse("200"))) + assert.True(t, taskResources.Limits.Memory.Equal(resource.MustParse("400"))) + assert.True(t, taskResources.Limits.EphemeralStorage.Equal(resource.MustParse("600"))) + assert.True(t, taskResources.Limits.GPU.Equal(resource.MustParse("800"))) +} diff --git a/flyteadmin/pkg/manager/impl/validation/task_validator.go b/flyteadmin/pkg/manager/impl/validation/task_validator.go index c1b440b83b..c8625ec4bd 100644 --- a/flyteadmin/pkg/manager/impl/validation/task_validator.go +++ b/flyteadmin/pkg/manager/impl/validation/task_validator.go @@ -13,6 +13,7 @@ import ( "github.com/flyteorg/flyteadmin/pkg/manager/impl/shared" runtime "github.com/flyteorg/flyteadmin/pkg/runtime/interfaces" runtimeInterfaces "github.com/flyteorg/flyteadmin/pkg/runtime/interfaces" + workflowengineInterfaces "github.com/flyteorg/flyteadmin/pkg/workflowengine/interfaces" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flytestdlib/logger" @@ -25,7 +26,7 @@ import ( var whitelistedTaskErr = errors.NewFlyteAdminErrorf(codes.InvalidArgument, "task type must be whitelisted before use") // This is called for a task with a non-nil container. -func validateContainer(task core.TaskTemplate, taskConfig runtime.TaskResourceConfiguration) error { +func validateContainer(task core.TaskTemplate, platformTaskResources workflowengineInterfaces.TaskResources) error { if err := ValidateEmptyStringField(task.GetContainer().Image, shared.Image); err != nil { return err } @@ -33,7 +34,7 @@ func validateContainer(task core.TaskTemplate, taskConfig runtime.TaskResourceCo if task.GetContainer().Resources == nil { return nil } - if err := validateTaskResources(task.Id, taskConfig.GetLimits(), task.GetContainer().Resources.Requests, + if err := validateTaskResources(task.Id, platformTaskResources.Limits, task.GetContainer().Resources.Requests, task.GetContainer().Resources.Limits); err != nil { logger.Debugf(context.Background(), "encountered errors validating task resources for [%+v]: %v", task.Id, err) @@ -43,7 +44,7 @@ func validateContainer(task core.TaskTemplate, taskConfig runtime.TaskResourceCo } // This is called for a task with a non-nil k8s pod. -func validateK8sPod(task core.TaskTemplate, taskConfig runtime.TaskResourceConfiguration) error { +func validateK8sPod(task core.TaskTemplate, platformTaskResources workflowengineInterfaces.TaskResources) error { if task.GetK8SPod().PodSpec == nil { return errors.NewFlyteAdminErrorf(codes.InvalidArgument, "invalid TaskSpecification, pod tasks should specify their target as a K8sPod with a defined pod spec") @@ -54,7 +55,7 @@ func validateK8sPod(task core.TaskTemplate, taskConfig runtime.TaskResourceConfi task.GetK8SPod().PodSpec, err) return err } - platformTaskResourceLimits := taskResourceSetToMap(taskConfig.GetLimits()) + platformTaskResourceLimits := taskResourceSetToMap(platformTaskResources.Limits) for _, container := range podSpec.Containers { err := validateResource(task.Id, resourceListToQuantity(container.Resources.Requests), resourceListToQuantity(container.Resources.Limits), platformTaskResourceLimits) @@ -76,7 +77,7 @@ func validateRuntimeMetadata(metadata core.RuntimeMetadata) error { } func validateTaskTemplate(taskID core.Identifier, task core.TaskTemplate, - taskConfig runtime.TaskResourceConfiguration, whitelistConfig runtime.WhitelistConfiguration) error { + platformTaskResources workflowengineInterfaces.TaskResources, whitelistConfig runtime.WhitelistConfiguration) error { if err := ValidateEmptyStringField(task.Type, shared.Type); err != nil { return err @@ -98,17 +99,17 @@ func validateTaskTemplate(taskID core.Identifier, task core.TaskTemplate, } if task.GetContainer() != nil { - return validateContainer(task, taskConfig) + return validateContainer(task, platformTaskResources) } if task.GetK8SPod() != nil { - return validateK8sPod(task, taskConfig) + return validateK8sPod(task, platformTaskResources) } return nil } func ValidateTask( ctx context.Context, request admin.TaskCreateRequest, db repositoryInterfaces.Repository, - taskConfig runtime.TaskResourceConfiguration, whitelistConfig runtime.WhitelistConfiguration, + platformTaskResources workflowengineInterfaces.TaskResources, whitelistConfig runtime.WhitelistConfiguration, applicationConfig runtime.ApplicationConfiguration) error { if err := ValidateIdentifier(request.Id, common.Task); err != nil { return err @@ -119,7 +120,7 @@ func ValidateTask( if request.Spec == nil || request.Spec.Template == nil { return shared.GetMissingArgumentError(shared.Spec) } - return validateTaskTemplate(*request.Id, *request.Spec.Template, taskConfig, whitelistConfig) + return validateTaskTemplate(*request.Id, *request.Spec.Template, platformTaskResources, whitelistConfig) } func taskResourceSetToMap( diff --git a/flyteadmin/pkg/manager/impl/validation/task_validator_test.go b/flyteadmin/pkg/manager/impl/validation/task_validator_test.go index c221fb5eaa..78ec0309ce 100644 --- a/flyteadmin/pkg/manager/impl/validation/task_validator_test.go +++ b/flyteadmin/pkg/manager/impl/validation/task_validator_test.go @@ -6,28 +6,28 @@ import ( "errors" "testing" - "google.golang.org/protobuf/types/known/structpb" - - corev1 "k8s.io/api/core/v1" - - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" - "k8s.io/apimachinery/pkg/api/resource" - "github.com/flyteorg/flyteadmin/pkg/manager/impl/testutils" runtimeInterfaces "github.com/flyteorg/flyteadmin/pkg/runtime/interfaces" runtimeMocks "github.com/flyteorg/flyteadmin/pkg/runtime/mocks" + workflowengineInterfaces "github.com/flyteorg/flyteadmin/pkg/workflowengine/interfaces" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + "github.com/stretchr/testify/assert" + + "google.golang.org/protobuf/types/known/structpb" + + corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/resource" ) -func getMockTaskConfigProvider() runtimeInterfaces.TaskResourceConfiguration { - var taskConfig = runtimeMocks.MockTaskResourceConfiguration{} - taskConfig.Limits = runtimeInterfaces.TaskResourceSet{ - Memory: resource.MustParse("500Mi"), - CPU: resource.MustParse("200m"), - GPU: resource.MustParse("8"), +func getMockTaskResources() workflowengineInterfaces.TaskResources { + return workflowengineInterfaces.TaskResources{ + Limits: runtimeInterfaces.TaskResourceSet{ + Memory: resource.MustParse("500Mi"), + CPU: resource.MustParse("200m"), + GPU: resource.MustParse("8"), + }, } - - return &taskConfig } var mockWhitelistConfigProvider = runtimeMocks.NewMockWhitelistConfiguration() @@ -47,26 +47,26 @@ func TestValidateTask(t *testing.T) { } request.Spec.Template.GetContainer().Resources = &core.Resources{Requests: resources} err := ValidateTask(context.Background(), request, testutils.GetRepoWithDefaultProject(), - getMockTaskConfigProvider(), mockWhitelistConfigProvider, taskApplicationConfigProvider) + getMockTaskResources(), mockWhitelistConfigProvider, taskApplicationConfigProvider) assert.EqualError(t, err, "Requested CPU default [1536Mi] is greater than current limit set in the platform configuration [200m]. Please contact Flyte Admins to change these limits or consult the configuration") request.Spec.Template.Target = &core.TaskTemplate_K8SPod{K8SPod: &core.K8SPod{}} err = ValidateTask(context.Background(), request, testutils.GetRepoWithDefaultProject(), - getMockTaskConfigProvider(), mockWhitelistConfigProvider, taskApplicationConfigProvider) + getMockTaskResources(), mockWhitelistConfigProvider, taskApplicationConfigProvider) assert.EqualError(t, err, "invalid TaskSpecification, pod tasks should specify their target as a K8sPod with a defined pod spec") resourceList := corev1.ResourceList{corev1.ResourceCPU: resource.MustParse("1.5Gi")} podSpec := &corev1.PodSpec{Containers: []corev1.Container{{Resources: corev1.ResourceRequirements{Requests: resourceList}}}} request.Spec.Template.Target = &core.TaskTemplate_K8SPod{K8SPod: &core.K8SPod{PodSpec: transformStructToStructPB(t, podSpec)}} err = ValidateTask(context.Background(), request, testutils.GetRepoWithDefaultProject(), - getMockTaskConfigProvider(), mockWhitelistConfigProvider, taskApplicationConfigProvider) + getMockTaskResources(), mockWhitelistConfigProvider, taskApplicationConfigProvider) assert.EqualError(t, err, "Requested CPU default [1536Mi] is greater than current limit set in the platform configuration [200m]. Please contact Flyte Admins to change these limits or consult the configuration") resourceList = corev1.ResourceList{corev1.ResourceCPU: resource.MustParse("200m")} podSpec = &corev1.PodSpec{Containers: []corev1.Container{{Resources: corev1.ResourceRequirements{Requests: resourceList}}}} request.Spec.Template.Target = &core.TaskTemplate_K8SPod{K8SPod: &core.K8SPod{PodSpec: transformStructToStructPB(t, podSpec)}} err = ValidateTask(context.Background(), request, testutils.GetRepoWithDefaultProject(), - getMockTaskConfigProvider(), mockWhitelistConfigProvider, taskApplicationConfigProvider) + getMockTaskResources(), mockWhitelistConfigProvider, taskApplicationConfigProvider) assert.Nil(t, err) } @@ -85,14 +85,14 @@ func TestValidateTaskEmptyProject(t *testing.T) { request := testutils.GetValidTaskRequest() request.Id.Project = "" err := ValidateTask(context.Background(), request, testutils.GetRepoWithDefaultProject(), - getMockTaskConfigProvider(), mockWhitelistConfigProvider, taskApplicationConfigProvider) + getMockTaskResources(), mockWhitelistConfigProvider, taskApplicationConfigProvider) assert.EqualError(t, err, "missing project") } func TestValidateTaskInvalidProjectAndDomain(t *testing.T) { request := testutils.GetValidTaskRequest() err := ValidateTask(context.Background(), request, testutils.GetRepoWithDefaultProjectAndErr(errors.New("foo")), - getMockTaskConfigProvider(), mockWhitelistConfigProvider, taskApplicationConfigProvider) + getMockTaskResources(), mockWhitelistConfigProvider, taskApplicationConfigProvider) assert.EqualError(t, err, "failed to validate that project [project] and domain [domain] are registered, err: [foo]") } @@ -100,7 +100,7 @@ func TestValidateTaskEmptyDomain(t *testing.T) { request := testutils.GetValidTaskRequest() request.Id.Domain = "" err := ValidateTask(context.Background(), request, testutils.GetRepoWithDefaultProject(), - getMockTaskConfigProvider(), mockWhitelistConfigProvider, taskApplicationConfigProvider) + getMockTaskResources(), mockWhitelistConfigProvider, taskApplicationConfigProvider) assert.EqualError(t, err, "missing domain") } @@ -108,7 +108,7 @@ func TestValidateTaskEmptyName(t *testing.T) { request := testutils.GetValidTaskRequest() request.Id.Name = "" err := ValidateTask(context.Background(), request, testutils.GetRepoWithDefaultProject(), - getMockTaskConfigProvider(), mockWhitelistConfigProvider, taskApplicationConfigProvider) + getMockTaskResources(), mockWhitelistConfigProvider, taskApplicationConfigProvider) assert.EqualError(t, err, "missing name") } @@ -116,7 +116,7 @@ func TestValidateTaskEmptyVersion(t *testing.T) { request := testutils.GetValidTaskRequest() request.Id.Version = "" err := ValidateTask(context.Background(), request, testutils.GetRepoWithDefaultProject(), - getMockTaskConfigProvider(), mockWhitelistConfigProvider, taskApplicationConfigProvider) + getMockTaskResources(), mockWhitelistConfigProvider, taskApplicationConfigProvider) assert.EqualError(t, err, "missing version") } @@ -124,7 +124,7 @@ func TestValidateTaskEmptyType(t *testing.T) { request := testutils.GetValidTaskRequest() request.Spec.Template.Type = "" err := ValidateTask(context.Background(), request, testutils.GetRepoWithDefaultProject(), - getMockTaskConfigProvider(), mockWhitelistConfigProvider, taskApplicationConfigProvider) + getMockTaskResources(), mockWhitelistConfigProvider, taskApplicationConfigProvider) assert.EqualError(t, err, "missing type") } @@ -132,7 +132,7 @@ func TestValidateTaskEmptyMetadata(t *testing.T) { request := testutils.GetValidTaskRequest() request.Spec.Template.Metadata = nil err := ValidateTask(context.Background(), request, testutils.GetRepoWithDefaultProject(), - getMockTaskConfigProvider(), mockWhitelistConfigProvider, taskApplicationConfigProvider) + getMockTaskResources(), mockWhitelistConfigProvider, taskApplicationConfigProvider) assert.EqualError(t, err, "missing metadata") } @@ -140,7 +140,7 @@ func TestValidateTaskEmptyRuntimeVersion(t *testing.T) { request := testutils.GetValidTaskRequest() request.Spec.Template.Metadata.Runtime.Version = "" err := ValidateTask(context.Background(), request, testutils.GetRepoWithDefaultProject(), - getMockTaskConfigProvider(), mockWhitelistConfigProvider, taskApplicationConfigProvider) + getMockTaskResources(), mockWhitelistConfigProvider, taskApplicationConfigProvider) assert.EqualError(t, err, "missing runtime version") } @@ -148,7 +148,7 @@ func TestValidateTaskEmptyTypedInterface(t *testing.T) { request := testutils.GetValidTaskRequest() request.Spec.Template.Interface = nil err := ValidateTask(context.Background(), request, testutils.GetRepoWithDefaultProject(), - getMockTaskConfigProvider(), mockWhitelistConfigProvider, taskApplicationConfigProvider) + getMockTaskResources(), mockWhitelistConfigProvider, taskApplicationConfigProvider) assert.EqualError(t, err, "missing typed interface") } @@ -156,7 +156,7 @@ func TestValidateTaskEmptyContainer(t *testing.T) { request := testutils.GetValidTaskRequest() request.Spec.Template.Target = nil err := ValidateTask(context.Background(), request, testutils.GetRepoWithDefaultProject(), - getMockTaskConfigProvider(), mockWhitelistConfigProvider, taskApplicationConfigProvider) + getMockTaskResources(), mockWhitelistConfigProvider, taskApplicationConfigProvider) assert.Nil(t, err) } @@ -164,7 +164,7 @@ func TestValidateTaskEmptyImage(t *testing.T) { request := testutils.GetValidTaskRequest() request.Spec.Template.GetContainer().Image = "" err := ValidateTask(context.Background(), request, testutils.GetRepoWithDefaultProject(), - getMockTaskConfigProvider(), mockWhitelistConfigProvider, taskApplicationConfigProvider) + getMockTaskResources(), mockWhitelistConfigProvider, taskApplicationConfigProvider) assert.EqualError(t, err, "missing image") }