Skip to content

Commit

Permalink
Validate k8s pods resource (flyteorg#356)
Browse files Browse the repository at this point in the history
* Validate k8s pods resource

Signed-off-by: Kevin Su <[email protected]>

* Updated tests

Signed-off-by: Kevin Su <[email protected]>

* Address comment

Signed-off-by: Kevin Su <[email protected]>
  • Loading branch information
pingsutw authored Mar 2, 2022
1 parent a5014ae commit 1b7960f
Show file tree
Hide file tree
Showing 3 changed files with 121 additions and 16 deletions.
65 changes: 52 additions & 13 deletions flyteadmin/pkg/manager/impl/validation/task_validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@ package validation

import (
"context"
"strings"

repositoryInterfaces "github.com/flyteorg/flyteadmin/pkg/repositories/interfaces"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/utils"

"github.com/flyteorg/flyteadmin/pkg/common"
"github.com/flyteorg/flyteadmin/pkg/errors"
Expand All @@ -14,18 +16,14 @@ import (
"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin"
"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
"github.com/flyteorg/flytestdlib/logger"
corev1 "k8s.io/api/core/v1"

"google.golang.org/grpc/codes"
"k8s.io/apimachinery/pkg/api/resource"
)

// whitelistedTaskErr is an InvalidArgument error stating that a task type must be whitelisted before use.
var whitelistedTaskErr = errors.NewFlyteAdminErrorf(codes.InvalidArgument, "task type must be whitelisted before use")

// Sidecar tasks do not necessarily define a primary container for execution and are excluded from container validation.
// The map is keyed by the task type string; a true value marks that type as containerless.
var containerlessTaskTypes = map[string]bool{
"sidecar": true,
}

// This is called for a task with a non-nil container.
func validateContainer(task core.TaskTemplate, taskConfig runtime.TaskResourceConfiguration) error {
if err := ValidateEmptyStringField(task.GetContainer().Image, shared.Image); err != nil {
Expand All @@ -44,6 +42,32 @@ func validateContainer(task core.TaskTemplate, taskConfig runtime.TaskResourceCo
return nil
}

// validateK8sPod checks the resource requests and limits of every container declared in a K8sPod
// task target. Callers must only invoke it for a task whose k8s pod target is non-nil.
//
// It returns an InvalidArgument error when the pod spec is missing, the unmarshalling error when the
// pod spec cannot be decoded into a corev1.PodSpec, and otherwise the first error reported by
// validateResource for a container (nil when every container passes).
func validateK8sPod(task core.TaskTemplate, taskConfig runtime.TaskResourceConfiguration) error {
	rawPodSpec := task.GetK8SPod().PodSpec
	if rawPodSpec == nil {
		return errors.NewFlyteAdminErrorf(codes.InvalidArgument,
			"invalid TaskSpecification, pod tasks should specify their target as a K8sPod with a defined pod spec")
	}

	ctx := context.Background()

	var podSpec corev1.PodSpec
	if err := utils.UnmarshalStructToObj(rawPodSpec, &podSpec); err != nil {
		logger.Debugf(ctx, "failed to unmarshal k8s podspec [%+v]: %v", rawPodSpec, err)
		return err
	}

	// The platform limits are the same for every container, so resolve them once up front.
	platformLimits := taskResourceSetToMap(taskConfig.GetLimits())
	for i := range podSpec.Containers {
		requirements := podSpec.Containers[i].Resources
		requests := resourceListToQuantity(requirements.Requests)
		limits := resourceListToQuantity(requirements.Limits)
		if err := validateResource(task.Id, requests, limits, platformLimits); err != nil {
			logger.Debugf(ctx, "encountered errors validating task resources for [%+v]: %v", task.Id, err)
			return err
		}
	}

	return nil
}

func validateRuntimeMetadata(metadata core.RuntimeMetadata) error {
if err := ValidateEmptyStringField(metadata.Version, shared.RuntimeVersion); err != nil {
return err
Expand All @@ -53,6 +77,7 @@ func validateRuntimeMetadata(metadata core.RuntimeMetadata) error {

func validateTaskTemplate(taskID core.Identifier, task core.TaskTemplate,
taskConfig runtime.TaskResourceConfiguration, whitelistConfig runtime.WhitelistConfiguration) error {

if err := ValidateEmptyStringField(task.Type, shared.Type); err != nil {
return err
}
Expand All @@ -71,13 +96,13 @@ func validateTaskTemplate(taskID core.Identifier, task core.TaskTemplate,
// The actual interface proto has nothing to validate.
return shared.GetMissingArgumentError(shared.TypedInterface)
}
if containerlessTaskTypes[task.Type] {
// Nothing left to validate
return nil
}

if task.GetContainer() != nil {
return validateContainer(task, taskConfig)
}
if task.GetK8SPod() != nil {
return validateK8sPod(task, taskConfig)
}
return nil
}

Expand Down Expand Up @@ -138,6 +163,15 @@ func isWholeNumber(quantity resource.Quantity) bool {
return quantity.MilliValue()%1000 == 0
}

// resourceListToQuantity converts a Kubernetes resource list into a map keyed by the flyteidl
// resource name. A Kubernetes name is matched by upper-casing it and looking it up in
// core.Resources_ResourceName_value.
//
// Names with no entry in that table are skipped. Converting a failed lookup directly would file
// every unrecognized resource under the enum's zero value, where unrelated quantities overwrite one
// another in whatever order the map happens to be iterated.
func resourceListToQuantity(resources corev1.ResourceList) map[core.Resources_ResourceName]resource.Quantity {
	quantities := make(map[core.Resources_ResourceName]resource.Quantity, len(resources))
	for name, quantity := range resources {
		value, ok := core.Resources_ResourceName_value[strings.ToUpper(name.String())]
		if !ok {
			// Not a resource the flyteidl enum knows about; there is nothing to validate it against.
			continue
		}
		quantities[core.Resources_ResourceName(value)] = quantity
	}
	return quantities
}

func requestedResourcesToQuantity(
identifier *core.Identifier, resources []*core.Resources_ResourceEntry) (
map[core.Resources_ResourceName]resource.Quantity, error) {
Expand Down Expand Up @@ -188,6 +222,12 @@ func validateTaskResources(

platformTaskResourceLimits := taskResourceSetToMap(taskResourceLimits)

return validateResource(identifier, requestedResourceDefaults, requestedResourceLimits, platformTaskResourceLimits)
}

func validateResource(identifier *core.Identifier, requestedResourceDefaults,
requestedResourceLimits map[core.Resources_ResourceName]resource.Quantity,
platformTaskResourceLimits map[core.Resources_ResourceName]*resource.Quantity) error {
for resourceName, defaultQuantity := range requestedResourceDefaults {
switch resourceName {
case core.Resources_CPU:
Expand All @@ -197,7 +237,7 @@ func validateTaskResources(
case core.Resources_MEMORY:
limitQuantity, ok := requestedResourceLimits[resourceName]
if ok && limitQuantity.Value() < defaultQuantity.Value() {
// Only assert the requested limit is greater than than the requested default when the limit is actually set
// Only assert the requested limit is greater than the requested default when the limit is actually set
return errors.NewFlyteAdminErrorf(codes.InvalidArgument,
"Requested %v default [%v] is greater than the limit [%v]."+
" Please fix your configuration", resourceName, defaultQuantity.String(), limitQuantity.String())
Expand All @@ -213,7 +253,7 @@ func validateTaskResources(
if platformLimitOk && defaultQuantity.Value() > platformTaskResourceLimits[resourceName].Value() {
// Also check that the requested limit is less than the platform task limit.
return errors.NewFlyteAdminErrorf(codes.InvalidArgument,
"Requested %v default [%v] is greater than current limit set in the platform configuration"+
"Requested %v default [%v] is greater than current limit set in the platform configuration"+
" [%v]. Please contact Flyte Admins to change these limits or consult the configuration",
resourceName, defaultQuantity.String(), platformTaskResourceLimits[resourceName].String())
}
Expand All @@ -227,13 +267,12 @@ func validateTaskResources(
platformLimit, platformLimitOk := platformTaskResourceLimits[resourceName]
if platformLimitOk && defaultQuantity.Value() > platformLimit.Value() {
return errors.NewFlyteAdminErrorf(codes.InvalidArgument,
"Requested %v default [%v] is greater than current limit set in the platform configuration"+
"Requested %v default [%v] is greater than current limit set in the platform configuration"+
" [%v]. Please contact Flyte Admins to change these limits or consult the configuration",
resourceName, defaultQuantity.String(), platformLimit.String())
}
}
}

return nil
}

Expand Down
71 changes: 68 additions & 3 deletions flyteadmin/pkg/manager/impl/validation/task_validator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,14 @@ package validation

import (
"context"
"encoding/json"
"errors"
"testing"

"google.golang.org/protobuf/types/known/structpb"

corev1 "k8s.io/api/core/v1"

"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
"k8s.io/apimachinery/pkg/api/resource"

Expand All @@ -28,6 +33,54 @@ func getMockTaskConfigProvider() runtimeInterfaces.TaskResourceConfiguration {
var mockWhitelistConfigProvider = runtimeMocks.NewMockWhitelistConfiguration()
var taskApplicationConfigProvider = testutils.GetApplicationConfigWithDefaultDomains()

// TestValidateTask drives ValidateTask through container and k8s pod targets: a container CPU
// request above the platform limit, a pod target without a pod spec, a pod whose container requests
// too much CPU, and finally a pod whose CPU request is accepted.
func TestValidateTask(t *testing.T) {
	const cpuOverLimitMsg = "Requested CPU default [1536Mi] is greater than current limit set in the platform configuration [200m]. Please contact Flyte Admins to change these limits or consult the configuration"

	request := testutils.GetValidTaskRequest()

	// validate runs ValidateTask against the current state of request with the shared mock providers.
	validate := func() error {
		return ValidateTask(context.Background(), request, testutils.GetRepoWithDefaultProject(),
			getMockTaskConfigProvider(), mockWhitelistConfigProvider, taskApplicationConfigProvider)
	}

	// podTargetWithCPU builds a K8sPod target holding a single container that requests cpu.
	podTargetWithCPU := func(cpu string) *core.TaskTemplate_K8SPod {
		spec := &corev1.PodSpec{
			Containers: []corev1.Container{{
				Resources: corev1.ResourceRequirements{
					Requests: corev1.ResourceList{corev1.ResourceCPU: resource.MustParse(cpu)},
				},
			}},
		}
		return &core.TaskTemplate_K8SPod{K8SPod: &core.K8SPod{PodSpec: transformStructToStructPB(t, spec)}}
	}

	// Container target whose CPU request exceeds the platform limit.
	request.Spec.Template.GetContainer().Resources = &core.Resources{
		Requests: []*core.Resources_ResourceEntry{
			{Name: core.Resources_CPU, Value: "1.5Gi"},
			{Name: core.Resources_MEMORY, Value: "200m"},
		},
	}
	assert.EqualError(t, validate(), cpuOverLimitMsg)

	// Pod target with no pod spec at all.
	request.Spec.Template.Target = &core.TaskTemplate_K8SPod{K8SPod: &core.K8SPod{}}
	assert.EqualError(t, validate(), "invalid TaskSpecification, pod tasks should specify their target as a K8sPod with a defined pod spec")

	// Pod target whose container CPU request exceeds the platform limit.
	request.Spec.Template.Target = podTargetWithCPU("1.5Gi")
	assert.EqualError(t, validate(), cpuOverLimitMsg)

	// Pod target whose container CPU request is accepted.
	request.Spec.Template.Target = podTargetWithCPU("200m")
	assert.Nil(t, validate())
}

// transformStructToStructPB converts obj into a protobuf Struct by marshalling it to JSON, decoding
// that JSON into a generic map, and handing the map to structpb.NewStruct. A failure in any step is
// reported on t as an assertion failure.
func transformStructToStructPB(t *testing.T, obj interface{}) *structpb.Struct {
	// Mark this as a helper so the test log points at the calling test rather than at this function.
	t.Helper()

	data, err := json.Marshal(obj)
	assert.NoError(t, err)

	podSpecMap := make(map[string]interface{})
	err = json.Unmarshal(data, &podSpecMap)
	assert.NoError(t, err)

	s, err := structpb.NewStruct(podSpecMap)
	assert.NoError(t, err)
	return s
}

func TestValidateTaskEmptyProject(t *testing.T) {
request := testutils.GetValidTaskRequest()
request.Id.Project = ""
Expand Down Expand Up @@ -217,6 +270,18 @@ func TestAddResourceEntryToMap(t *testing.T) {
assert.Equal(t, val, int64(104857600), "Existing values in the resource entry map should not be overwritten")
}

// TestResourceListToQuantity verifies that a Kubernetes CPU entry is returned under
// core.Resources_CPU with its quantity unchanged, for both a binary-suffixed quantity and a plain
// integer quantity.
func TestResourceListToQuantity(t *testing.T) {
	suffixed := resourceListToQuantity(corev1.ResourceList{corev1.ResourceCPU: resource.MustParse("100Mi")})
	// Map elements are not addressable, so copy the quantity out before calling Value on it.
	suffixedCPU := suffixed[core.Resources_CPU]
	// assert.Equal takes (expected, actual); the reverse order mislabels the values in failure output.
	assert.Equal(t, int64(104857600), suffixedCPU.Value())

	// This case also goes through the CPU key: the list is built with corev1.ResourceCPU.
	whole := resourceListToQuantity(corev1.ResourceList{corev1.ResourceCPU: resource.MustParse("2")})
	wholeCPU := whole[core.Resources_CPU]
	assert.Equal(t, int64(2), wholeCPU.Value())
}

func TestRequestedResourcesToQuantity(t *testing.T) {
resources, err := requestedResourcesToQuantity(&core.Identifier{}, []*core.Resources_ResourceEntry{
{
Expand Down Expand Up @@ -357,7 +422,7 @@ func TestValidateTaskResources_DefaultGreaterThanConfig(t *testing.T) {
Value: "1.5Gi",
},
}, []*core.Resources_ResourceEntry{})
assert.EqualError(t, err, "Requested CPU default [1536Mi] is greater than current limit set in the platform configuration [1Gi]. Please contact Flyte Admins to change these limits or consult the configuration")
assert.EqualError(t, err, "Requested CPU default [1536Mi] is greater than current limit set in the platform configuration [1Gi]. Please contact Flyte Admins to change these limits or consult the configuration")
}

func TestValidateTaskResources_GPULimitNotEqualToRequested(t *testing.T) {
Expand Down Expand Up @@ -396,7 +461,7 @@ func TestValidateTaskResources_GPULimitGreaterThanConfig(t *testing.T) {
Value: "2",
},
})
assert.EqualError(t, err, "Requested GPU default [2] is greater than current limit set in the platform configuration [1]. Please contact Flyte Admins to change these limits or consult the configuration")
assert.EqualError(t, err, "Requested GPU default [2] is greater than current limit set in the platform configuration [1]. Please contact Flyte Admins to change these limits or consult the configuration")
}

func TestValidateTaskResources_GPUDefaultGreaterThanConfig(t *testing.T) {
Expand All @@ -411,7 +476,7 @@ func TestValidateTaskResources_GPUDefaultGreaterThanConfig(t *testing.T) {
Value: "2",
},
}, []*core.Resources_ResourceEntry{})
assert.EqualError(t, err, "Requested GPU default [2] is greater than current limit set in the platform configuration [1]. Please contact Flyte Admins to change these limits or consult the configuration")
assert.EqualError(t, err, "Requested GPU default [2] is greater than current limit set in the platform configuration [1]. Please contact Flyte Admins to change these limits or consult the configuration")
}

func TestIsWholeNumber(t *testing.T) {
Expand Down
1 change: 1 addition & 0 deletions flyteadmin/pkg/repositories/database_integration_test.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
//go:build integration
// +build integration

package repositories
Expand Down

0 comments on commit 1b7960f

Please sign in to comment.