diff --git a/flyteadmin/pkg/manager/impl/validation/task_validator.go b/flyteadmin/pkg/manager/impl/validation/task_validator.go index c7b92ceb3c..c1b440b83b 100644 --- a/flyteadmin/pkg/manager/impl/validation/task_validator.go +++ b/flyteadmin/pkg/manager/impl/validation/task_validator.go @@ -3,8 +3,10 @@ package validation import ( "context" + "strings" repositoryInterfaces "github.com/flyteorg/flyteadmin/pkg/repositories/interfaces" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/utils" "github.com/flyteorg/flyteadmin/pkg/common" "github.com/flyteorg/flyteadmin/pkg/errors" @@ -14,6 +16,7 @@ import ( "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flytestdlib/logger" + corev1 "k8s.io/api/core/v1" "google.golang.org/grpc/codes" "k8s.io/apimachinery/pkg/api/resource" @@ -21,11 +24,6 @@ import ( var whitelistedTaskErr = errors.NewFlyteAdminErrorf(codes.InvalidArgument, "task type must be whitelisted before use") -// Sidecar tasks do not necessarily define a primary container for execution and are excluded from container validation. -var containerlessTaskTypes = map[string]bool{ - "sidecar": true, -} - // This is called for a task with a non-nil container. func validateContainer(task core.TaskTemplate, taskConfig runtime.TaskResourceConfiguration) error { if err := ValidateEmptyStringField(task.GetContainer().Image, shared.Image); err != nil { @@ -44,6 +42,32 @@ func validateContainer(task core.TaskTemplate, taskConfig runtime.TaskResourceCo return nil } +// This is called for a task with a non-nil k8s pod. +func validateK8sPod(task core.TaskTemplate, taskConfig runtime.TaskResourceConfiguration) error { + if task.GetK8SPod().PodSpec == nil { + return errors.NewFlyteAdminErrorf(codes.InvalidArgument, + "invalid TaskSpecification, pod tasks should specify their target as a K8sPod with a defined pod spec") + } + var podSpec corev1.PodSpec + if err := utils.UnmarshalStructToObj(task.GetK8SPod().PodSpec, &podSpec); err != nil { + logger.Debugf(context.Background(), "failed to unmarshal k8s podspec [%+v]: %v", + task.GetK8SPod().PodSpec, err) + return err + } + platformTaskResourceLimits := taskResourceSetToMap(taskConfig.GetLimits()) + for _, container := range podSpec.Containers { + err := validateResource(task.Id, resourceListToQuantity(container.Resources.Requests), + resourceListToQuantity(container.Resources.Limits), platformTaskResourceLimits) + if err != nil { + logger.Debugf(context.Background(), "encountered errors validating task resources for [%+v]: %v", + task.Id, err) + return err + } + } + + return nil +} + func validateRuntimeMetadata(metadata core.RuntimeMetadata) error { if err := ValidateEmptyStringField(metadata.Version, shared.RuntimeVersion); err != nil { return err @@ -53,6 +77,7 @@ func validateRuntimeMetadata(metadata core.RuntimeMetadata) error { func validateTaskTemplate(taskID core.Identifier, task core.TaskTemplate, taskConfig runtime.TaskResourceConfiguration, whitelistConfig runtime.WhitelistConfiguration) error { + if err := ValidateEmptyStringField(task.Type, shared.Type); err != nil { return err } @@ -71,13 +96,13 @@ func validateTaskTemplate(taskID core.Identifier, task core.TaskTemplate, // The actual interface proto has nothing to validate. return shared.GetMissingArgumentError(shared.TypedInterface) } - if containerlessTaskTypes[task.Type] { - // Nothing left to validate - return nil - } + if task.GetContainer() != nil { return validateContainer(task, taskConfig) } + if task.GetK8SPod() != nil { + return validateK8sPod(task, taskConfig) + } return nil } @@ -138,6 +163,15 @@ func isWholeNumber(quantity resource.Quantity) bool { return quantity.MilliValue()%1000 == 0 } +func resourceListToQuantity(resources corev1.ResourceList) map[core.Resources_ResourceName]resource.Quantity { + var requestedToQuantity = make(map[core.Resources_ResourceName]resource.Quantity) + for name, quantity := range resources { + resourceName := core.Resources_ResourceName(core.Resources_ResourceName_value[strings.ToUpper(name.String())]) + requestedToQuantity[resourceName] = quantity + } + return requestedToQuantity +} + func requestedResourcesToQuantity( identifier *core.Identifier, resources []*core.Resources_ResourceEntry) ( map[core.Resources_ResourceName]resource.Quantity, error) { @@ -188,6 +222,12 @@ func validateTaskResources( platformTaskResourceLimits := taskResourceSetToMap(taskResourceLimits) + return validateResource(identifier, requestedResourceDefaults, requestedResourceLimits, platformTaskResourceLimits) +} + +func validateResource(identifier *core.Identifier, requestedResourceDefaults, + requestedResourceLimits map[core.Resources_ResourceName]resource.Quantity, + platformTaskResourceLimits map[core.Resources_ResourceName]*resource.Quantity) error { for resourceName, defaultQuantity := range requestedResourceDefaults { switch resourceName { case core.Resources_CPU: @@ -197,7 +237,7 @@ func validateTaskResources( case core.Resources_MEMORY: limitQuantity, ok := requestedResourceLimits[resourceName] if ok && limitQuantity.Value() < defaultQuantity.Value() { - // Only assert the requested limit is greater than than the requested default when the limit is actually set + // Only assert the requested limit is greater than the requested default when the limit is actually set return errors.NewFlyteAdminErrorf(codes.InvalidArgument, "Requested %v default [%v] is greater than the limit [%v]."+ " Please fix your configuration", resourceName, defaultQuantity.String(), limitQuantity.String()) @@ -213,7 +253,7 @@ func validateTaskResources( if platformLimitOk && defaultQuantity.Value() > platformTaskResourceLimits[resourceName].Value() { // Also check that the requested limit is less than the platform task limit. return errors.NewFlyteAdminErrorf(codes.InvalidArgument, - "Requested %v default [%v] is greater than current limit set in the platform configuration"+ + "Requested %v default [%v] is greater than current limit set in the platform configuration"+ " [%v]. Please contact Flyte Admins to change these limits or consult the configuration", resourceName, defaultQuantity.String(), platformTaskResourceLimits[resourceName].String()) } @@ -227,13 +267,12 @@ func validateTaskResources( platformLimit, platformLimitOk := platformTaskResourceLimits[resourceName] if platformLimitOk && defaultQuantity.Value() > platformLimit.Value() { return errors.NewFlyteAdminErrorf(codes.InvalidArgument, - "Requested %v default [%v] is greater than current limit set in the platform configuration"+ + "Requested %v default [%v] is greater than current limit set in the platform configuration"+ " [%v]. Please contact Flyte Admins to change these limits or consult the configuration", resourceName, defaultQuantity.String(), platformLimit.String()) } } } - return nil } diff --git a/flyteadmin/pkg/manager/impl/validation/task_validator_test.go b/flyteadmin/pkg/manager/impl/validation/task_validator_test.go index 177f93aab2..c221fb5eaa 100644 --- a/flyteadmin/pkg/manager/impl/validation/task_validator_test.go +++ b/flyteadmin/pkg/manager/impl/validation/task_validator_test.go @@ -2,9 +2,14 @@ package validation import ( "context" + "encoding/json" "errors" "testing" + "google.golang.org/protobuf/types/known/structpb" + + corev1 "k8s.io/api/core/v1" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "k8s.io/apimachinery/pkg/api/resource" @@ -28,6 +33,54 @@ func getMockTaskConfigProvider() runtimeInterfaces.TaskResourceConfiguration { var mockWhitelistConfigProvider = runtimeMocks.NewMockWhitelistConfiguration() var taskApplicationConfigProvider = testutils.GetApplicationConfigWithDefaultDomains() +func TestValidateTask(t *testing.T) { + request := testutils.GetValidTaskRequest() + resources := []*core.Resources_ResourceEntry{ + { + Name: core.Resources_CPU, + Value: "1.5Gi", + }, + { + Name: core.Resources_MEMORY, + Value: "200m", + }, + } + request.Spec.Template.GetContainer().Resources = &core.Resources{Requests: resources} + err := ValidateTask(context.Background(), request, testutils.GetRepoWithDefaultProject(), + getMockTaskConfigProvider(), mockWhitelistConfigProvider, taskApplicationConfigProvider) + assert.EqualError(t, err, "Requested CPU default [1536Mi] is greater than current limit set in the platform configuration [200m]. Please contact Flyte Admins to change these limits or consult the configuration") + + request.Spec.Template.Target = &core.TaskTemplate_K8SPod{K8SPod: &core.K8SPod{}} + err = ValidateTask(context.Background(), request, testutils.GetRepoWithDefaultProject(), + getMockTaskConfigProvider(), mockWhitelistConfigProvider, taskApplicationConfigProvider) + assert.EqualError(t, err, "invalid TaskSpecification, pod tasks should specify their target as a K8sPod with a defined pod spec") + + resourceList := corev1.ResourceList{corev1.ResourceCPU: resource.MustParse("1.5Gi")} + podSpec := &corev1.PodSpec{Containers: []corev1.Container{{Resources: corev1.ResourceRequirements{Requests: resourceList}}}} + request.Spec.Template.Target = &core.TaskTemplate_K8SPod{K8SPod: &core.K8SPod{PodSpec: transformStructToStructPB(t, podSpec)}} + err = ValidateTask(context.Background(), request, testutils.GetRepoWithDefaultProject(), + getMockTaskConfigProvider(), mockWhitelistConfigProvider, taskApplicationConfigProvider) + assert.EqualError(t, err, "Requested CPU default [1536Mi] is greater than current limit set in the platform configuration [200m]. Please contact Flyte Admins to change these limits or consult the configuration") + + resourceList = corev1.ResourceList{corev1.ResourceCPU: resource.MustParse("200m")} + podSpec = &corev1.PodSpec{Containers: []corev1.Container{{Resources: corev1.ResourceRequirements{Requests: resourceList}}}} + request.Spec.Template.Target = &core.TaskTemplate_K8SPod{K8SPod: &core.K8SPod{PodSpec: transformStructToStructPB(t, podSpec)}} + err = ValidateTask(context.Background(), request, testutils.GetRepoWithDefaultProject(), + getMockTaskConfigProvider(), mockWhitelistConfigProvider, taskApplicationConfigProvider) + assert.Nil(t, err) +} + +func transformStructToStructPB(t *testing.T, obj interface{}) *structpb.Struct { + data, err := json.Marshal(obj) + assert.Nil(t, err) + podSpecMap := make(map[string]interface{}) + err = json.Unmarshal(data, &podSpecMap) + assert.Nil(t, err) + s, err := structpb.NewStruct(podSpecMap) + assert.Nil(t, err) + return s +} + func TestValidateTaskEmptyProject(t *testing.T) { request := testutils.GetValidTaskRequest() request.Id.Project = "" @@ -217,6 +270,18 @@ func TestAddResourceEntryToMap(t *testing.T) { assert.Equal(t, val, int64(104857600), "Existing values in the resource entry map should not be overwritten") } +func TestResourceListToQuantity(t *testing.T) { + cpuResources := resourceListToQuantity(corev1.ResourceList{corev1.ResourceCPU: resource.MustParse("100Mi")}) + cpuQuantity := cpuResources[core.Resources_CPU] + val := cpuQuantity.Value() + assert.Equal(t, val, int64(104857600)) + + gpuResources := resourceListToQuantity(corev1.ResourceList{corev1.ResourceCPU: resource.MustParse("2")}) + gpuQuantity := gpuResources[core.Resources_CPU] + val = gpuQuantity.Value() + assert.Equal(t, val, int64(2)) +} + func TestRequestedResourcesToQuantity(t *testing.T) { resources, err := requestedResourcesToQuantity(&core.Identifier{}, []*core.Resources_ResourceEntry{ { @@ -357,7 +422,7 @@ func TestValidateTaskResources_DefaultGreaterThanConfig(t *testing.T) { Value: "1.5Gi", }, }, []*core.Resources_ResourceEntry{}) - assert.EqualError(t, err, "Requested CPU default [1536Mi] is greater than current limit set in the platform configuration [1Gi]. Please contact Flyte Admins to change these limits or consult the configuration") + assert.EqualError(t, err, "Requested CPU default [1536Mi] is greater than current limit set in the platform configuration [1Gi]. Please contact Flyte Admins to change these limits or consult the configuration") } func TestValidateTaskResources_GPULimitNotEqualToRequested(t *testing.T) { @@ -396,7 +461,7 @@ func TestValidateTaskResources_GPULimitGreaterThanConfig(t *testing.T) { Value: "2", }, }) - assert.EqualError(t, err, "Requested GPU default [2] is greater than current limit set in the platform configuration [1]. Please contact Flyte Admins to change these limits or consult the configuration") + assert.EqualError(t, err, "Requested GPU default [2] is greater than current limit set in the platform configuration [1]. Please contact Flyte Admins to change these limits or consult the configuration") } func TestValidateTaskResources_GPUDefaultGreaterThanConfig(t *testing.T) { @@ -411,7 +476,7 @@ func TestValidateTaskResources_GPUDefaultGreaterThanConfig(t *testing.T) { Value: "2", }, }, []*core.Resources_ResourceEntry{}) - assert.EqualError(t, err, "Requested GPU default [2] is greater than current limit set in the platform configuration [1]. Please contact Flyte Admins to change these limits or consult the configuration") + assert.EqualError(t, err, "Requested GPU default [2] is greater than current limit set in the platform configuration [1]. Please contact Flyte Admins to change these limits or consult the configuration") } func TestIsWholeNumber(t *testing.T) { diff --git a/flyteadmin/pkg/repositories/database_integration_test.go b/flyteadmin/pkg/repositories/database_integration_test.go index 397b5df608..26bef38258 100644 --- a/flyteadmin/pkg/repositories/database_integration_test.go +++ b/flyteadmin/pkg/repositories/database_integration_test.go @@ -1,3 +1,4 @@ +//go:build integration // +build integration package repositories