From e26fa93d78773010e3e0b0ab2e5a805a1de9d714 Mon Sep 17 00:00:00 2001 From: Katrina Rogan Date: Fri, 8 Oct 2021 14:39:11 -0700 Subject: [PATCH] #minor Simplify and revisit task resource assignment (#215) --- go/tasks/config_load_test.go | 8 +- .../pluginmachinery/core/exec_metadata.go | 1 + .../core/mocks/task_execution_metadata.go | 38 ++- .../pluginmachinery/flytek8s/config/config.go | 19 +- .../flytek8s/config/k8spluginconfig_flags.go | 2 - .../config/k8spluginconfig_flags_test.go | 28 -- .../flytek8s/container_helper.go | 207 +++++++++--- .../flytek8s/container_helper_test.go | 307 +++++++++++++++--- .../pluginmachinery/flytek8s/pod_helper.go | 4 +- .../flytek8s/pod_helper_test.go | 9 +- .../resourcecustomizationmode_enumer.go | 51 +++ .../plugins/array/awsbatch/launcher_test.go | 1 + .../plugins/array/awsbatch/transformer.go | 10 +- .../array/awsbatch/transformer_test.go | 1 + go/tasks/plugins/array/k8s/monitor_test.go | 1 + go/tasks/plugins/array/k8s/transformer.go | 2 +- .../plugins/array/k8s/transformer_test.go | 1 + .../plugins/k8s/container/container_test.go | 1 + .../k8s/kfoperators/pytorch/pytorch_test.go | 18 +- .../kfoperators/tensorflow/tensorflow_test.go | 1 + go/tasks/plugins/k8s/sidecar/sidecar.go | 4 +- go/tasks/plugins/k8s/sidecar/sidecar_test.go | 35 +- tests/end_to_end.go | 1 + 23 files changed, 602 insertions(+), 148 deletions(-) create mode 100644 go/tasks/pluginmachinery/flytek8s/resourcecustomizationmode_enumer.go diff --git a/go/tasks/config_load_test.go b/go/tasks/config_load_test.go index d5ac340ec..ac4039857 100755 --- a/go/tasks/config_load_test.go +++ b/go/tasks/config_load_test.go @@ -4,6 +4,8 @@ import ( "context" "testing" + "k8s.io/apimachinery/pkg/api/resource" + sagemakerConfig "github.com/flyteorg/flyteplugins/go/tasks/plugins/k8s/sagemaker/config" "github.com/flyteorg/flytestdlib/config" @@ -68,8 +70,10 @@ func TestLoadConfig(t *testing.T) { assert.Equal(t, []v1.Toleration{tolGPU}, k8sConfig.ResourceTolerations[v1.ResourceName("nvidia.com/gpu")]) assert.Equal(t, []v1.Toleration{tolStorage}, k8sConfig.ResourceTolerations[v1.ResourceStorage]) - assert.Equal(t, "1000m", k8sConfig.DefaultCPURequest) - assert.Equal(t, "1024Mi", k8sConfig.DefaultMemoryRequest) + expectedCPU := resource.MustParse("1000m") + assert.True(t, expectedCPU.Equal(k8sConfig.DefaultCPURequest)) + expectedMemory := resource.MustParse("1024Mi") + assert.True(t, expectedMemory.Equal(k8sConfig.DefaultMemoryRequest)) assert.Equal(t, map[string]string{"x/interruptible": "true"}, k8sConfig.InterruptibleNodeSelector) assert.Equal(t, "x/flyte", k8sConfig.InterruptibleTolerations[0].Key) assert.Equal(t, "interruptible", k8sConfig.InterruptibleTolerations[0].Value) diff --git a/go/tasks/pluginmachinery/core/exec_metadata.go b/go/tasks/pluginmachinery/core/exec_metadata.go index ccde9c2c7..9419bd12a 100644 --- a/go/tasks/pluginmachinery/core/exec_metadata.go +++ b/go/tasks/pluginmachinery/core/exec_metadata.go @@ -35,4 +35,5 @@ type TaskExecutionMetadata interface { GetK8sServiceAccount() string GetSecurityContext() core.SecurityContext IsInterruptible() bool + GetPlatformResources() *v1.ResourceRequirements } diff --git a/go/tasks/pluginmachinery/core/mocks/task_execution_metadata.go b/go/tasks/pluginmachinery/core/mocks/task_execution_metadata.go index 7ec516337..28c05c274 100644 --- a/go/tasks/pluginmachinery/core/mocks/task_execution_metadata.go +++ b/go/tasks/pluginmachinery/core/mocks/task_execution_metadata.go @@ -3,8 +3,10 @@ package mocks import ( - flyteidlcore "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" core "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" + corev1 "k8s.io/api/core/v1" + + flyteidlcore "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" mock "github.com/stretchr/testify/mock" @@ -280,6 +282,40 @@ func (_m *TaskExecutionMetadata) GetOwnerReference() v1.OwnerReference { return r0 } +type TaskExecutionMetadata_GetPlatformResources struct { + *mock.Call +} + +func (_m TaskExecutionMetadata_GetPlatformResources) Return(_a0 *corev1.ResourceRequirements) *TaskExecutionMetadata_GetPlatformResources { + return &TaskExecutionMetadata_GetPlatformResources{Call: _m.Call.Return(_a0)} +} + +func (_m *TaskExecutionMetadata) OnGetPlatformResources() *TaskExecutionMetadata_GetPlatformResources { + c := _m.On("GetPlatformResources") + return &TaskExecutionMetadata_GetPlatformResources{Call: c} +} + +func (_m *TaskExecutionMetadata) OnGetPlatformResourcesMatch(matchers ...interface{}) *TaskExecutionMetadata_GetPlatformResources { + c := _m.On("GetPlatformResources", matchers...) + return &TaskExecutionMetadata_GetPlatformResources{Call: c} +} + +// GetPlatformResources provides a mock function with given fields: +func (_m *TaskExecutionMetadata) GetPlatformResources() *corev1.ResourceRequirements { + ret := _m.Called() + + var r0 *corev1.ResourceRequirements + if rf, ok := ret.Get(0).(func() *corev1.ResourceRequirements); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*corev1.ResourceRequirements) + } + } + + return r0 +} + type TaskExecutionMetadata_GetSecurityContext struct { *mock.Call } diff --git a/go/tasks/pluginmachinery/flytek8s/config/config.go b/go/tasks/pluginmachinery/flytek8s/config/config.go index 264e178f8..045a15f9a 100755 --- a/go/tasks/pluginmachinery/flytek8s/config/config.go +++ b/go/tasks/pluginmachinery/flytek8s/config/config.go @@ -8,6 +8,8 @@ package config import ( "time" + "k8s.io/apimachinery/pkg/api/resource" + config2 "github.com/flyteorg/flytestdlib/config" v1 "k8s.io/api/core/v1" @@ -17,8 +19,13 @@ import ( //go:generate pflags K8sPluginConfig --default-var=defaultK8sConfig const k8sPluginConfigSectionKey = "k8s" -const defaultCPURequest = "1000m" -const defaultMemoryRequest = "1024Mi" + +// ResourceNvidiaGPU is the name of the Nvidia GPU resource. +// Copied from: k8s.io/autoscaler/cluster-autoscaler/utils/gpu/gpu.go +const ResourceNvidiaGPU v1.ResourceName = "nvidia.com/gpu" + +var defaultCPURequest = resource.MustParse("1000m") +var defaultMemoryRequest = resource.MustParse("1024Mi") var ( defaultK8sConfig = K8sPluginConfig{ @@ -43,6 +50,7 @@ var ( CreateContainerErrorGracePeriod: config2.Duration{ Duration: time.Minute * 3, }, + GpuResourceName: ResourceNvidiaGPU, } // K8sPluginConfigSection provides a singular top level config section for all plugins. @@ -69,9 +77,9 @@ type K8sPluginConfig struct { DefaultEnvVarsFromEnv map[string]string `json:"default-env-vars-from-env" pflag:"-,Additional environment variable that should be injected into every resource"` // default cpu requests for a container - DefaultCPURequest string `json:"default-cpus" pflag:",Defines a default value for cpu for containers if not specified."` + DefaultCPURequest resource.Quantity `json:"default-cpus" pflag:",Defines a default value for cpu for containers if not specified."` // default memory requests for a container - DefaultMemoryRequest string `json:"default-memory" pflag:",Defines a default value for memory for containers if not specified."` + DefaultMemoryRequest resource.Quantity `json:"default-memory" pflag:",Defines a default value for memory for containers if not specified."` // Default Tolerations that will be added to every Pod that is created by Flyte. These can be used in heterogenous clusters, where one wishes to keep all pods created by Flyte on a separate // set of nodes. @@ -118,6 +126,9 @@ type K8sPluginConfig struct { // error persists past this grace period, it will be inferred to be a permanent // one, and the corresponding task marked as failed CreateContainerErrorGracePeriod config2.Duration `json:"create-container-error-grace-period" pflag:"-,Time to wait for transient CreateContainerError errors to be resolved."` + + // The name of the GPU resource to use when the task resource requests GPUs. + GpuResourceName v1.ResourceName `json:"gpu-resource-name" pflag:",The name of the GPU resource to use when the task resource requests GPUs."` } type FlyteCoPilotConfig struct { diff --git a/go/tasks/pluginmachinery/flytek8s/config/k8spluginconfig_flags.go b/go/tasks/pluginmachinery/flytek8s/config/k8spluginconfig_flags.go index 63e8fbe7a..95b483fa1 100755 --- a/go/tasks/pluginmachinery/flytek8s/config/k8spluginconfig_flags.go +++ b/go/tasks/pluginmachinery/flytek8s/config/k8spluginconfig_flags.go @@ -51,8 +51,6 @@ func (K8sPluginConfig) mustMarshalJSON(v json.Marshaler) string { func (cfg K8sPluginConfig) GetPFlagSet(prefix string) *pflag.FlagSet { cmdFlags := pflag.NewFlagSet("K8sPluginConfig", pflag.ExitOnError) cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "inject-finalizer"), defaultK8sConfig.InjectFinalizer, "Instructs the plugin to inject a finalizer on startTask and remove it on task termination.") - cmdFlags.String(fmt.Sprintf("%v%v", prefix, "default-cpus"), defaultK8sConfig.DefaultCPURequest, "Defines a default value for cpu for containers if not specified.") - cmdFlags.String(fmt.Sprintf("%v%v", prefix, "default-memory"), defaultK8sConfig.DefaultMemoryRequest, "Defines a default value for memory for containers if not specified.") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "scheduler-name"), defaultK8sConfig.SchedulerName, "Defines scheduler name.") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "co-pilot.name"), defaultK8sConfig.CoPilot.NamePrefix, "Flyte co-pilot sidecar container name prefix. (additional bits will be added after this)") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "co-pilot.image"), defaultK8sConfig.CoPilot.Image, "Flyte co-pilot Docker Image FQN") diff --git a/go/tasks/pluginmachinery/flytek8s/config/k8spluginconfig_flags_test.go b/go/tasks/pluginmachinery/flytek8s/config/k8spluginconfig_flags_test.go index a852526de..10ade00fd 100755 --- a/go/tasks/pluginmachinery/flytek8s/config/k8spluginconfig_flags_test.go +++ b/go/tasks/pluginmachinery/flytek8s/config/k8spluginconfig_flags_test.go @@ -113,34 +113,6 @@ func TestK8sPluginConfig_SetFlags(t *testing.T) { } }) }) - t.Run("Test_default-cpus", func(t *testing.T) { - - t.Run("Override", func(t *testing.T) { - testValue := "1" - - cmdFlags.Set("default-cpus", testValue) - if vString, err := cmdFlags.GetString("default-cpus"); err == nil { - testDecodeJson_K8sPluginConfig(t, fmt.Sprintf("%v", vString), &actual.DefaultCPURequest) - - } else { - assert.FailNow(t, err.Error()) - } - }) - }) - t.Run("Test_default-memory", func(t *testing.T) { - - t.Run("Override", func(t *testing.T) { - testValue := "1" - - cmdFlags.Set("default-memory", testValue) - if vString, err := cmdFlags.GetString("default-memory"); err == nil { - testDecodeJson_K8sPluginConfig(t, fmt.Sprintf("%v", vString), &actual.DefaultMemoryRequest) - - } else { - assert.FailNow(t, err.Error()) - } - }) - }) t.Run("Test_scheduler-name", func(t *testing.T) { t.Run("Override", func(t *testing.T) { diff --git a/go/tasks/pluginmachinery/flytek8s/container_helper.go b/go/tasks/pluginmachinery/flytek8s/container_helper.go index a22b4bf8f..77ab805b1 100755 --- a/go/tasks/pluginmachinery/flytek8s/container_helper.go +++ b/go/tasks/pluginmachinery/flytek8s/container_helper.go @@ -3,11 +3,12 @@ package flytek8s import ( "context" + "github.com/flyteorg/flytestdlib/logger" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/template" "k8s.io/apimachinery/pkg/util/validation" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" - "github.com/flyteorg/flytestdlib/logger" v1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/resource" "k8s.io/apimachinery/pkg/util/rand" @@ -22,6 +23,10 @@ const resourceGPU = "gpu" // Copied from: k8s.io/autoscaler/cluster-autoscaler/utils/gpu/gpu.go const ResourceNvidiaGPU = "nvidia.com/gpu" +// Specifies whether resource resolution should assign unset resource requests or limits from platform defaults +// or existing container values. +const assignIfUnset = true + func MergeResources(in v1.ResourceRequirements, out *v1.ResourceRequirements) { if out.Limits == nil { out.Limits = in.Limits @@ -39,52 +44,120 @@ func MergeResources(in v1.ResourceRequirements, out *v1.ResourceRequirements) { } } -func ApplyResourceOverrides(ctx context.Context, resources v1.ResourceRequirements) *v1.ResourceRequirements { - // set memory and cpu to default if not provided by user. - if len(resources.Requests) == 0 { - resources.Requests = make(v1.ResourceList) +type ResourceRequirement struct { + Request resource.Quantity + Limit resource.Quantity +} + +func resolvePlatformDefaults(platformResources v1.ResourceRequirements, configCPU, configMemory resource.Quantity) v1.ResourceRequirements { + if len(platformResources.Requests) == 0 { + platformResources.Requests = make(v1.ResourceList) } - if len(resources.Limits) == 0 { - resources.Limits = make(v1.ResourceList) + if _, ok := platformResources.Requests[v1.ResourceCPU]; !ok { + platformResources.Requests[v1.ResourceCPU] = configCPU } - if _, found := resources.Requests[v1.ResourceCPU]; !found { - // use cpu limit if set else default to config - if _, limitSet := resources.Limits[v1.ResourceCPU]; limitSet { - resources.Requests[v1.ResourceCPU] = resources.Limits[v1.ResourceCPU] - } else { - resources.Requests[v1.ResourceCPU] = resource.MustParse(config.GetK8sPluginConfig().DefaultCPURequest) - } + if _, ok := platformResources.Requests[v1.ResourceMemory]; !ok { + platformResources.Requests[v1.ResourceMemory] = configMemory + } + + if len(platformResources.Limits) == 0 { + platformResources.Limits = make(v1.ResourceList) } - if _, found := resources.Requests[v1.ResourceMemory]; !found { - // use memory limit if set else default to config - if _, limitSet := resources.Limits[v1.ResourceMemory]; limitSet { - resources.Requests[v1.ResourceMemory] = resources.Limits[v1.ResourceMemory] + return platformResources +} + +// AdjustOrDefaultResource validates resources conform to platform limits and assigns defaults for Request and Limit values by +// using the Request when the Limit is unset, and vice versa. +func AdjustOrDefaultResource(request, limit, platformDefault, platformLimit resource.Quantity) ResourceRequirement { + if request.IsZero() { + if !limit.IsZero() { + request = limit } else { - resources.Requests[v1.ResourceMemory] = resource.MustParse(config.GetK8sPluginConfig().DefaultMemoryRequest) + request = platformDefault } } - if _, found := resources.Limits[v1.ResourceCPU]; !found { - logger.Infof(ctx, "found cpu limit missing, setting limit to the requested value %v", resources.Requests[v1.ResourceCPU]) - resources.Limits[v1.ResourceCPU] = resources.Requests[v1.ResourceCPU] + if limit.IsZero() { + limit = request } - if _, found := resources.Limits[v1.ResourceMemory]; !found { - logger.Infof(ctx, "found memory limit missing, setting limit to the requested value %v", resources.Requests[v1.ResourceMemory]) - resources.Limits[v1.ResourceMemory] = resources.Requests[v1.ResourceMemory] + return ensureResourceRange(request, limit, platformLimit) +} + +func ensureResourceLimit(value, limit resource.Quantity) resource.Quantity { + if value.IsZero() || limit.IsZero() { + return value } - // Ephemeral storage resources aren't required but if one of requests or limits is set and the other isn't, we'll - // just use the same values. - if _, requested := resources.Requests[v1.ResourceEphemeralStorage]; !requested { - if _, limitSet := resources.Limits[v1.ResourceEphemeralStorage]; limitSet { - resources.Requests[v1.ResourceEphemeralStorage] = resources.Limits[v1.ResourceEphemeralStorage] - } - } else if _, limitSet := resources.Limits[v1.ResourceEphemeralStorage]; !limitSet { - resources.Limits[v1.ResourceEphemeralStorage] = resources.Requests[v1.ResourceEphemeralStorage] + if value.Cmp(limit) == 1 { + return limit + } + + return value +} + +// ensureResourceRange doesn't assign resources unless they need to be adjusted downwards +func ensureResourceRange(request, limit, platformLimit resource.Quantity) ResourceRequirement { + // Ensure request is < platformLimit + request = ensureResourceLimit(request, platformLimit) + // Ensure limit is < platformLimit + limit = ensureResourceLimit(limit, platformLimit) + // Ensure request is < limit + request = ensureResourceLimit(request, limit) + + return ResourceRequirement{ + Request: request, + Limit: limit, + } +} + +func adjustResourceRequirement(resourceName v1.ResourceName, resourceRequirements, + platformResources v1.ResourceRequirements, assignIfUnset bool) { + + var resourceValue ResourceRequirement + if assignIfUnset { + resourceValue = AdjustOrDefaultResource(resourceRequirements.Requests[resourceName], + resourceRequirements.Limits[resourceName], platformResources.Requests[resourceName], + platformResources.Limits[resourceName]) + } else { + resourceValue = ensureResourceRange(resourceRequirements.Requests[resourceName], + resourceRequirements.Limits[resourceName], platformResources.Limits[resourceName]) + } + + resourceRequirements.Requests[resourceName] = resourceValue.Request + resourceRequirements.Limits[resourceName] = resourceValue.Limit +} + +// ApplyResourceOverrides handles resource resolution, allocation and validation. Primarily, it ensures that container +// resources do not exceed defined platformResource limits and in the case of assignIfUnset, ensures that limits and +// requests are sensibly set for resources of all types. +// Furthermore, this function handles some clean-up such as converting GPU resources to the recognized Nvidia gpu +// resource name and deleting unsupported Storage-type resources. +func ApplyResourceOverrides(resources, platformResources v1.ResourceRequirements, assignIfUnset bool) v1.ResourceRequirements { + if len(resources.Requests) == 0 { + resources.Requests = make(v1.ResourceList) + } + + if len(resources.Limits) == 0 { + resources.Limits = make(v1.ResourceList) + } + + // As a fallback, in the case the Flyte workflow object does not have platformResource defaults set, the defaults + // come from the plugin config. + platformResources = resolvePlatformDefaults(platformResources, config.GetK8sPluginConfig().DefaultCPURequest, + config.GetK8sPluginConfig().DefaultMemoryRequest) + + adjustResourceRequirement(v1.ResourceCPU, resources, platformResources, assignIfUnset) + adjustResourceRequirement(v1.ResourceMemory, resources, platformResources, assignIfUnset) + + _, ephemeralStorageRequested := resources.Requests[v1.ResourceEphemeralStorage] + _, ephemeralStorageLimited := resources.Limits[v1.ResourceEphemeralStorage] + + if ephemeralStorageRequested || ephemeralStorageLimited { + adjustResourceRequirement(v1.ResourceEphemeralStorage, resources, platformResources, assignIfUnset) } // TODO: Make configurable. 1/15/2019 Flyte Cluster doesn't support setting storage requests/limits. @@ -92,21 +165,36 @@ func ApplyResourceOverrides(ctx context.Context, resources v1.ResourceRequiremen delete(resources.Requests, v1.ResourceStorage) delete(resources.Limits, v1.ResourceStorage) + gpuResourceName := config.GetK8sPluginConfig().GpuResourceName + shouldAdjustGPU := false + _, gpuRequested := resources.Requests[gpuResourceName] + _, gpuLimited := resources.Limits[gpuResourceName] + if gpuRequested || gpuLimited { + shouldAdjustGPU = true + } + // Override GPU if res, found := resources.Requests[resourceGPU]; found { - resources.Requests[ResourceNvidiaGPU] = res + resources.Requests[gpuResourceName] = res delete(resources.Requests, resourceGPU) + shouldAdjustGPU = true } + if res, found := resources.Limits[resourceGPU]; found { - resources.Limits[ResourceNvidiaGPU] = res - delete(resources.Requests, resourceGPU) + resources.Limits[gpuResourceName] = res + delete(resources.Limits, resourceGPU) + shouldAdjustGPU = true } - return &resources + if shouldAdjustGPU { + adjustResourceRequirement(gpuResourceName, resources, platformResources, assignIfUnset) + } + + return resources } -// Transforms a task template target of type core.Container into a bare-bones kubernetes container, which can be further -// modified with flyte-specific customizations specified by various static and run-time attributes. +// ToK8sContainer transforms a task template target of type core.Container into a bare-bones kubernetes container, which +// can be further modified with flyte-specific customizations specified by various static and run-time attributes. func ToK8sContainer(ctx context.Context, taskContainer *core.Container, iFace *core.TypedInterface, parameters template.Parameters) (*v1.Container, error) { // Perform preliminary validations if parameters.TaskExecMetadata.GetOverrides() == nil { @@ -135,16 +223,24 @@ func ToK8sContainer(ctx context.Context, taskContainer *core.Container, iFace *c return container, nil } +//go:generate enumer -type=ResourceCustomizationMode -trimprefix=ResourceCustomizationMode + type ResourceCustomizationMode int const ( - AssignResources ResourceCustomizationMode = iota - MergeExistingResources - LeaveResourcesUnmodified + // ResourceCustomizationModeAssignResources is used for container tasks where resources are validated and assigned if necessary. + ResourceCustomizationModeAssignResources ResourceCustomizationMode = iota + // ResourceCustomizationModeMergeExistingResources is used for primary containers in pod tasks where container requests and limits are + // merged, validated and assigned if necessary. + ResourceCustomizationModeMergeExistingResources + // ResourceCustomizationModeEnsureExistingResourcesInRange is used for secondary containers in pod tasks where requests and limits are only + // adjusted if needed (downwards). + ResourceCustomizationModeEnsureExistingResourcesInRange ) -// Takes a container definition which specifies how to run a Flyte task and fills in templated command and argument -// values, updates resources and decorates environment variables with platform and task-specific customizations. +// AddFlyteCustomizationsToContainer takes a container definition which specifies how to run a Flyte task and fills in +// templated command and argument values, updates resources and decorates environment variables with platform and +// task-specific customizations. func AddFlyteCustomizationsToContainer(ctx context.Context, parameters template.Parameters, mode ResourceCustomizationMode, container *v1.Container) error { modifiedCommand, err := template.Render(ctx, container.Command, parameters) @@ -163,16 +259,25 @@ func AddFlyteCustomizationsToContainer(ctx context.Context, parameters template. if parameters.TaskExecMetadata.GetOverrides() != nil && parameters.TaskExecMetadata.GetOverrides().GetResources() != nil { res := parameters.TaskExecMetadata.GetOverrides().GetResources() + platformResources := parameters.TaskExecMetadata.GetPlatformResources() + if platformResources == nil { + platformResources = &v1.ResourceRequirements{} + } + + logger.Infof(ctx, "ApplyResourceOverrides with Resources [%v], Platform Resources [%v] and Container"+ + " Resources [%v] with mode [%v]", res, platformResources, container.Resources, mode) + switch mode { - case AssignResources: - if res = ApplyResourceOverrides(ctx, *res); res != nil { - container.Resources = *res - } - case MergeExistingResources: + case ResourceCustomizationModeAssignResources: + container.Resources = ApplyResourceOverrides(*res, *platformResources, assignIfUnset) + case ResourceCustomizationModeMergeExistingResources: MergeResources(*res, &container.Resources) - container.Resources = *ApplyResourceOverrides(ctx, container.Resources) - case LeaveResourcesUnmodified: + container.Resources = ApplyResourceOverrides(container.Resources, *platformResources, assignIfUnset) + case ResourceCustomizationModeEnsureExistingResourcesInRange: + container.Resources = ApplyResourceOverrides(container.Resources, *platformResources, !assignIfUnset) } + + logger.Infof(ctx, "Adjusted container resources [%v]", container.Resources) } return nil } diff --git a/go/tasks/pluginmachinery/flytek8s/container_helper_test.go b/go/tasks/pluginmachinery/flytek8s/container_helper_test.go index fc05e144a..9dde7a751 100755 --- a/go/tasks/pluginmachinery/flytek8s/container_helper_test.go +++ b/go/tasks/pluginmachinery/flytek8s/container_helper_test.go @@ -17,104 +17,198 @@ import ( "k8s.io/apimachinery/pkg/api/resource" ) +var zeroQuantity = resource.MustParse("0") + +func TestAssignResource(t *testing.T) { + t.Run("Leave valid requests and limits unchanged", func(t *testing.T) { + res := AdjustOrDefaultResource( + resource.MustParse("1"), resource.MustParse("2"), + resource.MustParse("10"), resource.MustParse("20")) + assert.True(t, res.Request.Equal(resource.MustParse("1"))) + assert.True(t, res.Limit.Equal(resource.MustParse("2"))) + }) + t.Run("Assign unset Request from Limit", func(t *testing.T) { + res := AdjustOrDefaultResource( + zeroQuantity, resource.MustParse("2"), + resource.MustParse("10"), resource.MustParse("20")) + assert.True(t, res.Request.Equal(resource.MustParse("2"))) + assert.True(t, res.Limit.Equal(resource.MustParse("2"))) + }) + t.Run("Assign unset Limit from Request", func(t *testing.T) { + res := AdjustOrDefaultResource( + resource.MustParse("2"), zeroQuantity, + resource.MustParse("10"), resource.MustParse("20")) + assert.Equal(t, resource.MustParse("2"), res.Request) + assert.Equal(t, resource.MustParse("2"), res.Limit) + }) + t.Run("Assign from platform defaults", func(t *testing.T) { + res := AdjustOrDefaultResource( + zeroQuantity, zeroQuantity, + resource.MustParse("10"), resource.MustParse("20")) + assert.Equal(t, resource.MustParse("10"), res.Request) + assert.Equal(t, resource.MustParse("10"), res.Limit) + }) + t.Run("Adjust Limit when Request > Limit", func(t *testing.T) { + res := AdjustOrDefaultResource( + resource.MustParse("10"), resource.MustParse("2"), + resource.MustParse("10"), resource.MustParse("20")) + assert.Equal(t, resource.MustParse("2"), res.Request) + assert.Equal(t, resource.MustParse("2"), res.Limit) + }) + t.Run("Adjust Limit > platformLimit", func(t *testing.T) { + res := AdjustOrDefaultResource( + resource.MustParse("1"), resource.MustParse("40"), + resource.MustParse("10"), resource.MustParse("20")) + assert.True(t, res.Request.Equal(resource.MustParse("1"))) + assert.True(t, res.Limit.Equal(resource.MustParse("20"))) + }) + t.Run("Adjust Request, Limit > platformLimit", func(t *testing.T) { + res := AdjustOrDefaultResource( + resource.MustParse("40"), resource.MustParse("50"), + resource.MustParse("10"), resource.MustParse("20")) + assert.True(t, res.Request.Equal(resource.MustParse("20"))) + assert.True(t, res.Limit.Equal(resource.MustParse("20"))) + }) +} + +func TestValidateResource(t *testing.T) { + platformLimit := resource.MustParse("5") + t.Run("adjust when Request > Limit", func(t *testing.T) { + res := ensureResourceRange(resource.MustParse("4"), resource.MustParse("3"), platformLimit) + assert.True(t, res.Request.Equal(resource.MustParse("3"))) + assert.True(t, res.Limit.Equal(resource.MustParse("3"))) + }) + t.Run("adjust when Request > platformLimit", func(t *testing.T) { + res := ensureResourceRange(resource.MustParse("6"), platformLimit, platformLimit) + assert.True(t, res.Request.Equal(platformLimit)) + assert.True(t, res.Limit.Equal(platformLimit)) + }) + t.Run("adjust when Limit > platformLimit", func(t *testing.T) { + res := ensureResourceRange(resource.MustParse("4"), resource.MustParse("6"), platformLimit) + assert.True(t, res.Request.Equal(resource.MustParse("4"))) + assert.True(t, res.Limit.Equal(platformLimit)) + }) + t.Run("nothing to do", func(t *testing.T) { + res := ensureResourceRange(resource.MustParse("1"), resource.MustParse("2"), platformLimit) + assert.True(t, res.Request.Equal(resource.MustParse("1"))) + assert.True(t, res.Limit.Equal(resource.MustParse("2"))) + }) +} + func TestApplyResourceOverrides_OverrideCpu(t *testing.T) { + platformRequirements := v1.ResourceRequirements{ + Requests: v1.ResourceList{ + v1.ResourceCPU: resource.MustParse("3"), + }, + Limits: v1.ResourceList{ + v1.ResourceCPU: resource.MustParse("10"), + }, + } cpuRequest := resource.MustParse("1") - overrides := ApplyResourceOverrides(context.Background(), v1.ResourceRequirements{ + overrides := ApplyResourceOverrides(v1.ResourceRequirements{ Requests: v1.ResourceList{ v1.ResourceCPU: cpuRequest, }, - }) + }, platformRequirements, assignIfUnset) assert.EqualValues(t, cpuRequest, overrides.Requests[v1.ResourceCPU]) assert.EqualValues(t, cpuRequest, overrides.Limits[v1.ResourceCPU]) cpuLimit := resource.MustParse("2") - overrides = ApplyResourceOverrides(context.Background(), v1.ResourceRequirements{ + overrides = ApplyResourceOverrides(v1.ResourceRequirements{ Requests: v1.ResourceList{ v1.ResourceCPU: cpuRequest, }, Limits: v1.ResourceList{ v1.ResourceCPU: cpuLimit, }, - }) + }, platformRequirements, assignIfUnset) assert.EqualValues(t, cpuRequest, overrides.Requests[v1.ResourceCPU]) assert.EqualValues(t, cpuLimit, overrides.Limits[v1.ResourceCPU]) - // request equals limit if not set - overrides = ApplyResourceOverrides(context.Background(), v1.ResourceRequirements{ + // Request equals Limit if not set + overrides = ApplyResourceOverrides(v1.ResourceRequirements{ Limits: v1.ResourceList{ v1.ResourceCPU: cpuLimit, }, - }) + }, platformRequirements, assignIfUnset) assert.EqualValues(t, cpuLimit, overrides.Requests[v1.ResourceCPU]) assert.EqualValues(t, cpuLimit, overrides.Limits[v1.ResourceCPU]) } func TestApplyResourceOverrides_OverrideMemory(t *testing.T) { memoryRequest := resource.MustParse("1") - overrides := ApplyResourceOverrides(context.Background(), v1.ResourceRequirements{ + platformRequirements := v1.ResourceRequirements{ + Requests: v1.ResourceList{ + v1.ResourceMemory: resource.MustParse("3"), + }, + Limits: v1.ResourceList{ + v1.ResourceMemory: resource.MustParse("10"), + }, + } + overrides := ApplyResourceOverrides(v1.ResourceRequirements{ Requests: v1.ResourceList{ v1.ResourceMemory: memoryRequest, }, - }) + }, platformRequirements, assignIfUnset) assert.EqualValues(t, memoryRequest, overrides.Requests[v1.ResourceMemory]) assert.EqualValues(t, memoryRequest, overrides.Limits[v1.ResourceMemory]) memoryLimit := resource.MustParse("2") - overrides = ApplyResourceOverrides(context.Background(), v1.ResourceRequirements{ + overrides = ApplyResourceOverrides(v1.ResourceRequirements{ Requests: v1.ResourceList{ v1.ResourceMemory: memoryRequest, }, Limits: v1.ResourceList{ v1.ResourceMemory: memoryLimit, }, - }) + }, platformRequirements, assignIfUnset) assert.EqualValues(t, memoryRequest, overrides.Requests[v1.ResourceMemory]) assert.EqualValues(t, memoryLimit, overrides.Limits[v1.ResourceMemory]) - // request equals limit if not set - overrides = ApplyResourceOverrides(context.Background(), v1.ResourceRequirements{ + // Request equals Limit if not set + overrides = ApplyResourceOverrides(v1.ResourceRequirements{ Limits: v1.ResourceList{ v1.ResourceMemory: memoryLimit, }, - }) + }, platformRequirements, assignIfUnset) assert.EqualValues(t, memoryLimit, overrides.Requests[v1.ResourceMemory]) assert.EqualValues(t, memoryLimit, overrides.Limits[v1.ResourceMemory]) } func TestApplyResourceOverrides_OverrideEphemeralStorage(t *testing.T) { ephemeralStorageRequest := resource.MustParse("1") - overrides := ApplyResourceOverrides(context.Background(), v1.ResourceRequirements{ + overrides := ApplyResourceOverrides(v1.ResourceRequirements{ Requests: v1.ResourceList{ v1.ResourceEphemeralStorage: ephemeralStorageRequest, }, - }) + }, v1.ResourceRequirements{}, assignIfUnset) assert.EqualValues(t, ephemeralStorageRequest, overrides.Requests[v1.ResourceEphemeralStorage]) assert.EqualValues(t, ephemeralStorageRequest, overrides.Limits[v1.ResourceEphemeralStorage]) ephemeralStorageLimit := resource.MustParse("2") - overrides = ApplyResourceOverrides(context.Background(), v1.ResourceRequirements{ + overrides = ApplyResourceOverrides(v1.ResourceRequirements{ Requests: v1.ResourceList{ v1.ResourceEphemeralStorage: ephemeralStorageRequest, }, Limits: v1.ResourceList{ v1.ResourceEphemeralStorage: ephemeralStorageLimit, }, - }) + }, v1.ResourceRequirements{}, assignIfUnset) assert.EqualValues(t, ephemeralStorageRequest, overrides.Requests[v1.ResourceEphemeralStorage]) assert.EqualValues(t, ephemeralStorageLimit, overrides.Limits[v1.ResourceEphemeralStorage]) - // request equals limit if not set - overrides = ApplyResourceOverrides(context.Background(), v1.ResourceRequirements{ + // Request equals Limit if not set + overrides = ApplyResourceOverrides(v1.ResourceRequirements{ Limits: v1.ResourceList{ v1.ResourceEphemeralStorage: ephemeralStorageLimit, }, - }) + }, v1.ResourceRequirements{}, assignIfUnset) assert.EqualValues(t, ephemeralStorageLimit, overrides.Requests[v1.ResourceEphemeralStorage]) } func TestApplyResourceOverrides_RemoveStorage(t *testing.T) { requestedResourceQuantity := resource.MustParse("1") - overrides := ApplyResourceOverrides(context.Background(), v1.ResourceRequirements{ + overrides := ApplyResourceOverrides(v1.ResourceRequirements{ Requests: v1.ResourceList{ v1.ResourceStorage: requestedResourceQuantity, v1.ResourceMemory: requestedResourceQuantity, @@ -126,7 +220,7 @@ func TestApplyResourceOverrides_RemoveStorage(t *testing.T) { v1.ResourceMemory: requestedResourceQuantity, v1.ResourceEphemeralStorage: requestedResourceQuantity, }, - }) + }, v1.ResourceRequirements{}, assignIfUnset) assert.EqualValues(t, v1.ResourceList{ v1.ResourceMemory: requestedResourceQuantity, v1.ResourceCPU: requestedResourceQuantity, @@ -142,18 +236,18 @@ func TestApplyResourceOverrides_RemoveStorage(t *testing.T) { func TestApplyResourceOverrides_OverrideGpu(t *testing.T) { gpuRequest := resource.MustParse("1") - overrides := ApplyResourceOverrides(context.Background(), v1.ResourceRequirements{ + overrides := ApplyResourceOverrides(v1.ResourceRequirements{ Requests: v1.ResourceList{ resourceGPU: gpuRequest, }, - }) + }, v1.ResourceRequirements{}, assignIfUnset) assert.EqualValues(t, gpuRequest, overrides.Requests[ResourceNvidiaGPU]) - overrides = ApplyResourceOverrides(context.Background(), v1.ResourceRequirements{ + overrides = ApplyResourceOverrides(v1.ResourceRequirements{ Limits: v1.ResourceList{ resourceGPU: gpuRequest, }, - }) + }, v1.ResourceRequirements{}, assignIfUnset) assert.EqualValues(t, gpuRequest, overrides.Limits[ResourceNvidiaGPU]) } @@ -298,7 +392,7 @@ func TestToK8sContainer(t *testing.T) { assert.Nil(t, errs) } -func TestAddFlyteCustomizationsToContainer(t *testing.T) { +func getTemplateParametersForTest(resourceRequirements, platformResources *v1.ResourceRequirements) template.Parameters { mockTaskExecMetadata := mocks.TaskExecutionMetadata{} mockTaskExecutionID := mocks.TaskExecutionID{} mockTaskExecutionID.OnGetGeneratedName().Return("gen_name") @@ -323,15 +417,9 @@ func TestAddFlyteCustomizationsToContainer(t *testing.T) { mockTaskExecMetadata.OnGetTaskExecutionID().Return(&mockTaskExecutionID) mockOverrides := mocks.TaskOverrides{} - mockOverrides.OnGetResources().Return(&v1.ResourceRequirements{ - Requests: v1.ResourceList{ - v1.ResourceEphemeralStorage: resource.MustParse("1024Mi"), - }, - Limits: v1.ResourceList{ - v1.ResourceEphemeralStorage: resource.MustParse("2048Mi"), - }, - }) + mockOverrides.OnGetResources().Return(resourceRequirements) mockTaskExecMetadata.OnGetOverrides().Return(&mockOverrides) + mockTaskExecMetadata.OnGetPlatformResources().Return(platformResources) mockInputReader := mocks2.InputReader{} mockInputPath := storage.DataReference("s3://input/path") @@ -344,11 +432,22 @@ func TestAddFlyteCustomizationsToContainer(t *testing.T) { mockOutputPath.OnGetRawOutputPrefix().Return(mockOutputPathPrefix) mockOutputPath.OnGetOutputPrefixPath().Return(mockOutputPathPrefix) - templateParameters := template.Parameters{ + return template.Parameters{ TaskExecMetadata: &mockTaskExecMetadata, Inputs: &mockInputReader, OutputPath: &mockOutputPath, } +} + +func TestAddFlyteCustomizationsToContainer(t *testing.T) { + templateParameters := getTemplateParametersForTest(&v1.ResourceRequirements{ + Requests: v1.ResourceList{ + v1.ResourceEphemeralStorage: resource.MustParse("1024Mi"), + }, + Limits: v1.ResourceList{ + v1.ResourceEphemeralStorage: resource.MustParse("2048Mi"), + }, + }, nil) container := &v1.Container{ Command: []string{ "{{ .Input }}", @@ -357,7 +456,7 @@ func TestAddFlyteCustomizationsToContainer(t *testing.T) { "{{ .OutputPrefix }}", }, } - err := AddFlyteCustomizationsToContainer(context.TODO(), templateParameters, AssignResources, container) + err := AddFlyteCustomizationsToContainer(context.TODO(), templateParameters, ResourceCustomizationModeAssignResources, container) assert.NoError(t, err) assert.EqualValues(t, container.Args, []string{"s3://output/path"}) assert.EqualValues(t, container.Command, []string{"s3://input/path"}) @@ -365,3 +464,135 @@ func TestAddFlyteCustomizationsToContainer(t *testing.T) { assert.Len(t, container.Resources.Requests, 3) assert.Len(t, container.Env, 12) } + +func TestAddFlyteCustomizationsToContainer_Resources(t *testing.T) { + container := &v1.Container{ + Command: []string{ + "{{ .Input }}", + }, + Args: []string{ + "{{ .OutputPrefix }}", + }, + Resources: v1.ResourceRequirements{ + Requests: v1.ResourceList{ + v1.ResourceCPU: resource.MustParse("1"), + }, + Limits: v1.ResourceList{ + v1.ResourceCPU: resource.MustParse("10"), + }, + }, + } + + t.Run("merge requests/limits for pod tasks - primary container", func(t *testing.T) { + templateParameters := getTemplateParametersForTest(&v1.ResourceRequirements{ + Requests: v1.ResourceList{ + v1.ResourceMemory: resource.MustParse("2"), + }, + }, &v1.ResourceRequirements{ + Requests: v1.ResourceList{ + v1.ResourceMemory: resource.MustParse("2"), + }, + Limits: v1.ResourceList{ + v1.ResourceMemory: resource.MustParse("20"), + }, + }) + err := AddFlyteCustomizationsToContainer(context.TODO(), templateParameters, ResourceCustomizationModeMergeExistingResources, container) + assert.NoError(t, err) + assert.True(t, container.Resources.Requests.Cpu().Equal(resource.MustParse("1"))) + assert.True(t, container.Resources.Limits.Cpu().Equal(resource.MustParse("10"))) + assert.True(t, container.Resources.Requests.Memory().Equal(resource.MustParse("2"))) + assert.True(t, container.Resources.Limits.Memory().Equal(resource.MustParse("2"))) + }) + t.Run("enforce merge requests/limits for pod tasks - values from task overrides", func(t *testing.T) { + templateParameters := getTemplateParametersForTest(&v1.ResourceRequirements{ + Requests: v1.ResourceList{ + v1.ResourceMemory: resource.MustParse("2"), + }, + Limits: v1.ResourceList{ + v1.ResourceMemory: resource.MustParse("200"), + }, + }, &v1.ResourceRequirements{ + Requests: v1.ResourceList{ + v1.ResourceMemory: resource.MustParse("2"), + }, + Limits: v1.ResourceList{ + v1.ResourceMemory: resource.MustParse("20"), + }, + }) + err := AddFlyteCustomizationsToContainer(context.TODO(), templateParameters, ResourceCustomizationModeMergeExistingResources, container) + assert.NoError(t, err) + assert.True(t, container.Resources.Requests.Cpu().Equal(resource.MustParse("1"))) + assert.True(t, container.Resources.Limits.Cpu().Equal(resource.MustParse("10"))) + assert.True(t, container.Resources.Requests.Memory().Equal(resource.MustParse("2"))) + assert.True(t, container.Resources.Limits.Memory().Equal(resource.MustParse("20"))) + }) + t.Run("enforce requests/limits for pod tasks - values from container", func(t *testing.T) { + container := &v1.Container{ + Command: []string{ + "{{ .Input }}", + }, + Args: []string{ + "{{ .OutputPrefix }}", + }, + Resources: v1.ResourceRequirements{ + Requests: v1.ResourceList{ + v1.ResourceCPU: resource.MustParse("100"), + }, + Limits: v1.ResourceList{ + v1.ResourceCPU: resource.MustParse("100"), + }, + }, + } + + templateParameters := getTemplateParametersForTest(&v1.ResourceRequirements{}, &v1.ResourceRequirements{ + Requests: v1.ResourceList{ + v1.ResourceCPU: resource.MustParse("1"), + v1.ResourceMemory: resource.MustParse("2"), + }, + Limits: v1.ResourceList{ + v1.ResourceCPU: resource.MustParse("10"), + v1.ResourceMemory: resource.MustParse("20"), + }, + }) + err := AddFlyteCustomizationsToContainer(context.TODO(), templateParameters, ResourceCustomizationModeMergeExistingResources, container) + assert.NoError(t, err) + assert.True(t, container.Resources.Requests.Cpu().Equal(resource.MustParse("10"))) + assert.True(t, container.Resources.Limits.Cpu().Equal(resource.MustParse("10"))) + assert.True(t, container.Resources.Requests.Memory().Equal(resource.MustParse("2"))) + assert.True(t, container.Resources.Limits.Memory().Equal(resource.MustParse("2"))) + }) +} + +func TestAddFlyteCustomizationsToContainer_ValidateExistingResources(t *testing.T) { + container := &v1.Container{ + Command: []string{ + "{{ .Input }}", + }, + Args: []string{ + "{{ .OutputPrefix }}", + }, + Resources: v1.ResourceRequirements{ + Requests: v1.ResourceList{ + v1.ResourceCPU: resource.MustParse("100"), + }, + Limits: v1.ResourceList{ + v1.ResourceCPU: resource.MustParse("200"), + }, + }, + } + templateParameters := getTemplateParametersForTest(&v1.ResourceRequirements{}, &v1.ResourceRequirements{ + Requests: v1.ResourceList{ + v1.ResourceCPU: resource.MustParse("1"), + v1.ResourceMemory: resource.MustParse("2"), + }, + Limits: v1.ResourceList{ + v1.ResourceCPU: resource.MustParse("10"), + v1.ResourceMemory: resource.MustParse("20"), + }, + }) + err := AddFlyteCustomizationsToContainer(context.TODO(), templateParameters, ResourceCustomizationModeEnsureExistingResourcesInRange, container) + assert.NoError(t, err) + + assert.True(t, container.Resources.Requests.Cpu().Equal(resource.MustParse("10"))) + assert.True(t, container.Resources.Limits.Cpu().Equal(resource.MustParse("10"))) +} diff --git a/go/tasks/pluginmachinery/flytek8s/pod_helper.go b/go/tasks/pluginmachinery/flytek8s/pod_helper.go index 9188083ea..cff8b11d9 100755 --- a/go/tasks/pluginmachinery/flytek8s/pod_helper.go +++ b/go/tasks/pluginmachinery/flytek8s/pod_helper.go @@ -113,7 +113,7 @@ func ToK8sPodSpecWithInterruptible(ctx context.Context, tCtx pluginsCore.TaskExe if err != nil { return nil, err } - err = AddFlyteCustomizationsToContainer(ctx, templateParameters, AssignResources, c) + err = AddFlyteCustomizationsToContainer(ctx, templateParameters, ResourceCustomizationModeAssignResources, c) if err != nil { return nil, err } @@ -161,7 +161,7 @@ func BuildIdentityPod() *v1.Pod { // The failure transitions from ErrImagePull -> ImagePullBackoff // Case II: Not enough resources are available. This is tricky. It could be that the total number of // resources requested is beyond the capability of the system. for this we will rely on configuration -// and hence input gates. We should not allow bad requests that request for large number of resource through. +// and hence input gates. We should not allow bad requests that Request for large number of resource through. // In the case it makes through, we will fail after timeout func DemystifyPending(status v1.PodStatus) (pluginsCore.PhaseInfo, error) { // Search over the difference conditions in the status object. Note that the 'Pending' this function is diff --git a/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go b/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go index 261e63f63..3520600a9 100755 --- a/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go +++ b/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go @@ -55,6 +55,7 @@ func dummyTaskExecutionMetadata(resources *v1.ResourceRequirements) pluginsCore. to.On("GetResources").Return(resources) taskExecutionMetadata.On("GetOverrides").Return(to) taskExecutionMetadata.On("IsInterruptible").Return(true) + taskExecutionMetadata.OnGetPlatformResources().Return(&v1.ResourceRequirements{}) return taskExecutionMetadata } @@ -404,8 +405,8 @@ func TestToK8sPod(t *testing.T) { v1.ResourceStorage: {tolStorage}, ResourceNvidiaGPU: {tolGPU}, }, - DefaultCPURequest: "1024m", - DefaultMemoryRequest: "1024Mi", + DefaultCPURequest: resource.MustParse("1024m"), + DefaultMemoryRequest: resource.MustParse("1024Mi"), })) op := &pluginsIOMock.OutputFilePaths{} @@ -471,8 +472,8 @@ func TestToK8sPod(t *testing.T) { "nodeId": "123", }, SchedulerName: "myScheduler", - DefaultCPURequest: "1024m", - DefaultMemoryRequest: "1024Mi", + DefaultCPURequest: resource.MustParse("1024m"), + DefaultMemoryRequest: resource.MustParse("1024Mi"), })) p, err := ToK8sPodSpec(ctx, x) diff --git a/go/tasks/pluginmachinery/flytek8s/resourcecustomizationmode_enumer.go b/go/tasks/pluginmachinery/flytek8s/resourcecustomizationmode_enumer.go new file mode 100644 index 000000000..c01befae4 --- /dev/null +++ b/go/tasks/pluginmachinery/flytek8s/resourcecustomizationmode_enumer.go @@ -0,0 +1,51 @@ +// Code generated by "enumer -type=ResourceCustomizationMode -trimprefix=ResourceCustomizationMode"; DO NOT EDIT. + +// +package flytek8s + +import ( + "fmt" +) + +const _ResourceCustomizationModeName = "AssignResourcesMergeExistingResourcesEnsureExistingResourcesInRange" + +var _ResourceCustomizationModeIndex = [...]uint8{0, 15, 37, 67} + +func (i ResourceCustomizationMode) String() string { + if i < 0 || i >= ResourceCustomizationMode(len(_ResourceCustomizationModeIndex)-1) { + return fmt.Sprintf("ResourceCustomizationMode(%d)", i) + } + return _ResourceCustomizationModeName[_ResourceCustomizationModeIndex[i]:_ResourceCustomizationModeIndex[i+1]] +} + +var _ResourceCustomizationModeValues = []ResourceCustomizationMode{0, 1, 2} + +var _ResourceCustomizationModeNameToValueMap = map[string]ResourceCustomizationMode{ + _ResourceCustomizationModeName[0:15]: 0, + _ResourceCustomizationModeName[15:37]: 1, + _ResourceCustomizationModeName[37:67]: 2, +} + +// ResourceCustomizationModeString retrieves an enum value from the enum constants string name. +// Throws an error if the param is not part of the enum. +func ResourceCustomizationModeString(s string) (ResourceCustomizationMode, error) { + if val, ok := _ResourceCustomizationModeNameToValueMap[s]; ok { + return val, nil + } + return 0, fmt.Errorf("%s does not belong to ResourceCustomizationMode values", s) +} + +// ResourceCustomizationModeValues returns all values of the enum +func ResourceCustomizationModeValues() []ResourceCustomizationMode { + return _ResourceCustomizationModeValues +} + +// IsAResourceCustomizationMode returns "true" if the value is listed in the enum definition. "false" otherwise +func (i ResourceCustomizationMode) IsAResourceCustomizationMode() bool { + for _, v := range _ResourceCustomizationModeValues { + if i == v { + return true + } + } + return false +} diff --git a/go/tasks/plugins/array/awsbatch/launcher_test.go b/go/tasks/plugins/array/awsbatch/launcher_test.go index 89ae694ce..5d1c20224 100644 --- a/go/tasks/plugins/array/awsbatch/launcher_test.go +++ b/go/tasks/plugins/array/awsbatch/launcher_test.go @@ -75,6 +75,7 @@ func TestLaunchSubTasks(t *testing.T) { tMeta := &mocks.TaskExecutionMetadata{} tMeta.OnGetTaskExecutionID().Return(tID) tMeta.OnGetOverrides().Return(overrides) + tMeta.OnGetPlatformResources().Return(&v1.ResourceRequirements{}) ow := &mocks3.OutputWriter{} ow.OnGetOutputPrefixPath().Return("/prefix/") diff --git a/go/tasks/plugins/array/awsbatch/transformer.go b/go/tasks/plugins/array/awsbatch/transformer.go index 81dd119d7..e68ada543 100644 --- a/go/tasks/plugins/array/awsbatch/transformer.go +++ b/go/tasks/plugins/array/awsbatch/transformer.go @@ -25,6 +25,8 @@ const ( arrayJobIDFormatter = "%v:%v" ) +const assignResources = true + // Note that Name is not set on the result object. // It's up to the caller to set the Name before creating the object in K8s. func FlyteTaskToBatchInput(ctx context.Context, tCtx pluginCore.TaskExecutionContext, jobDefinition string, cfg *config2.Config) ( @@ -77,14 +79,18 @@ func FlyteTaskToBatchInput(ctx context.Context, tCtx pluginCore.TaskExecutionCon envVars := getEnvVarsForTask(ctx, tCtx.TaskExecutionMetadata().GetTaskExecutionID(), taskTemplate.GetContainer().GetEnv(), cfg.DefaultEnvVars) res := tCtx.TaskExecutionMetadata().GetOverrides().GetResources() - res = flytek8s.ApplyResourceOverrides(ctx, *res) + platformResources := tCtx.TaskExecutionMetadata().GetPlatformResources() + if platformResources == nil { + platformResources = &v1.ResourceRequirements{} + } + resources := flytek8s.ApplyResourceOverrides(*res, *platformResources, assignResources) return &batch.SubmitJobInput{ JobName: refStr(tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName()), JobDefinition: refStr(jobDefinition), JobQueue: refStr(jobConfig.DynamicTaskQueue), RetryStrategy: toRetryStrategy(ctx, toBackoffLimit(taskTemplate.Metadata), cfg.MinRetries, cfg.MaxRetries), - ContainerOverrides: toContainerOverrides(ctx, append(cmd, args...), res, envVars), + ContainerOverrides: toContainerOverrides(ctx, append(cmd, args...), &resources, envVars), Timeout: toTimeout(taskTemplate.Metadata.GetTimeout(), cfg.DefaultTimeOut.Duration), }, nil } diff --git a/go/tasks/plugins/array/awsbatch/transformer_test.go b/go/tasks/plugins/array/awsbatch/transformer_test.go index 853eb1527..82bf5b78d 100644 --- a/go/tasks/plugins/array/awsbatch/transformer_test.go +++ b/go/tasks/plugins/array/awsbatch/transformer_test.go @@ -164,6 +164,7 @@ func TestArrayJobToBatchInput(t *testing.T) { tMetadata.OnGetOwnerReference().Return(v1.OwnerReference{Name: "x"}) tMetadata.OnGetTaskExecutionID().Return(id) tMetadata.OnGetOverrides().Return(to) + tMetadata.OnGetPlatformResources().Return(&v12.ResourceRequirements{}) ir := &mocks2.InputReader{} ir.OnGetInputPath().Return("inputs.pb") diff --git a/go/tasks/plugins/array/k8s/monitor_test.go b/go/tasks/plugins/array/k8s/monitor_test.go index 80c4f01c4..b648239c4 100644 --- a/go/tasks/plugins/array/k8s/monitor_test.go +++ b/go/tasks/plugins/array/k8s/monitor_test.go @@ -79,6 +79,7 @@ func getMockTaskExecutionContext(ctx context.Context) *mocks.TaskExecutionContex tMeta.OnGetLabels().Return(nil) tMeta.OnGetAnnotations().Return(nil) tMeta.OnGetOwnerReference().Return(v12.OwnerReference{}) + tMeta.OnGetPlatformResources().Return(&v1.ResourceRequirements{}) ow := &mocks2.OutputWriter{} ow.OnGetOutputPrefixPath().Return("/prefix/") diff --git a/go/tasks/plugins/array/k8s/transformer.go b/go/tasks/plugins/array/k8s/transformer.go index 46237d898..849d5f98e 100644 --- a/go/tasks/plugins/array/k8s/transformer.go +++ b/go/tasks/plugins/array/k8s/transformer.go @@ -166,7 +166,7 @@ func FlyteArrayJobToK8sPodTemplate(ctx context.Context, tCtx core.TaskExecutionC Task: tCtx.TaskReader(), } err = flytek8s.AddFlyteCustomizationsToContainer( - ctx, templateParameters, flytek8s.MergeExistingResources, &pod.Spec.Containers[containerIndex]) + ctx, templateParameters, flytek8s.ResourceCustomizationModeMergeExistingResources, &pod.Spec.Containers[containerIndex]) if err != nil { return v1.Pod{}, nil, err } diff --git a/go/tasks/plugins/array/k8s/transformer_test.go b/go/tasks/plugins/array/k8s/transformer_test.go index 7b167a362..822e3c544 100644 --- a/go/tasks/plugins/array/k8s/transformer_test.go +++ b/go/tasks/plugins/array/k8s/transformer_test.go @@ -179,6 +179,7 @@ func TestFlyteArrayJobToK8sPodTemplate(t *testing.T) { }, }) tMeta.OnGetOverrides().Return(&mockResourceOverrides) + tMeta.OnGetPlatformResources().Return(&v1.ResourceRequirements{}) tID := &mocks.TaskExecutionID{} tID.OnGetID().Return(core.TaskExecutionIdentifier{ NodeExecutionId: &core.NodeExecutionIdentifier{ diff --git a/go/tasks/plugins/k8s/container/container_test.go b/go/tasks/plugins/k8s/container/container_test.go index 20a01d8a6..919292f44 100755 --- a/go/tasks/plugins/k8s/container/container_test.go +++ b/go/tasks/plugins/k8s/container/container_test.go @@ -48,6 +48,7 @@ func dummyContainerTaskMetadata(resources *v1.ResourceRequirements) pluginsCore. Namespace: "test-namespace", Name: "test-owner-name", }) + taskMetadata.OnGetPlatformResources().Return(&v1.ResourceRequirements{}) tID := &pluginsCoreMock.TaskExecutionID{} tID.On("GetID").Return(core.TaskExecutionIdentifier{ diff --git a/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go b/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go index 7f3d458bb..0d5f62d0a 100644 --- a/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go +++ b/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go @@ -130,7 +130,18 @@ func dummyPytorchTaskContext(taskTemplate *core.TaskTemplate) pluginsCore.TaskEx tID.OnGetGeneratedName().Return("some-acceptable-name") resources := &mocks.TaskOverrides{} - resources.OnGetResources().Return(resourceRequirements) + resources.OnGetResources().Return(&corev1.ResourceRequirements{ + Limits: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("1000m"), + corev1.ResourceMemory: resource.MustParse("1Gi"), + flytek8s.ResourceNvidiaGPU: resource.MustParse("1"), + }, + Requests: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("100m"), + corev1.ResourceMemory: resource.MustParse("512Mi"), + flytek8s.ResourceNvidiaGPU: resource.MustParse("1"), + }, + }) taskExecutionMetadata := &mocks.TaskExecutionMetadata{} taskExecutionMetadata.OnGetTaskExecutionID().Return(tID) @@ -144,6 +155,7 @@ func dummyPytorchTaskContext(taskTemplate *core.TaskTemplate) pluginsCore.TaskEx taskExecutionMetadata.OnIsInterruptible().Return(true) taskExecutionMetadata.OnGetOverrides().Return(resources) taskExecutionMetadata.OnGetK8sServiceAccount().Return(serviceAccount) + taskExecutionMetadata.OnGetPlatformResources().Return(&corev1.ResourceRequirements{}) taskCtx.OnTaskExecutionMetadata().Return(taskExecutionMetadata) return taskCtx } @@ -292,8 +304,8 @@ func TestBuildResourcePytorch(t *testing.T) { hasContainerWithDefaultPytorchName = true } - assert.Equal(t, resourceRequirements.Requests, container.Resources.Requests) - assert.Equal(t, resourceRequirements.Limits, container.Resources.Limits) + assert.Equal(t, resourceRequirements.Requests, container.Resources.Requests, fmt.Sprintf(" container.Resources.Requests [%+v]", container.Resources.Requests.Cpu().String())) + assert.Equal(t, resourceRequirements.Limits, container.Resources.Limits, fmt.Sprintf(" container.Resources.Limits [%+v]", container.Resources.Limits.Cpu().String())) } assert.True(t, hasContainerWithDefaultPytorchName) diff --git a/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go b/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go index 1c60386a8..b59b75858 100644 --- a/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go +++ b/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go @@ -146,6 +146,7 @@ func dummyTensorFlowTaskContext(taskTemplate *core.TaskTemplate) pluginsCore.Tas taskExecutionMetadata.OnIsInterruptible().Return(true) taskExecutionMetadata.OnGetOverrides().Return(resources) taskExecutionMetadata.OnGetK8sServiceAccount().Return(serviceAccount) + taskExecutionMetadata.OnGetPlatformResources().Return(&corev1.ResourceRequirements{}) taskCtx.OnTaskExecutionMetadata().Return(taskExecutionMetadata) return taskCtx } diff --git a/go/tasks/plugins/k8s/sidecar/sidecar.go b/go/tasks/plugins/k8s/sidecar/sidecar.go index d37a0620b..fda1bd502 100755 --- a/go/tasks/plugins/k8s/sidecar/sidecar.go +++ b/go/tasks/plugins/k8s/sidecar/sidecar.go @@ -35,10 +35,10 @@ func validateAndFinalizePod( resReqs := make([]k8sv1.ResourceRequirements, 0, len(pod.Spec.Containers)) for index, container := range pod.Spec.Containers { - var resourceMode = flytek8s.LeaveResourcesUnmodified + var resourceMode = flytek8s.ResourceCustomizationModeEnsureExistingResourcesInRange if container.Name == primaryContainerName { hasPrimaryContainer = true - resourceMode = flytek8s.MergeExistingResources + resourceMode = flytek8s.ResourceCustomizationModeMergeExistingResources } templateParameters := template.Parameters{ TaskExecMetadata: taskCtx.TaskExecutionMetadata(), diff --git a/go/tasks/plugins/k8s/sidecar/sidecar_test.go b/go/tasks/plugins/k8s/sidecar/sidecar_test.go index 09515d474..a33b52bb1 100755 --- a/go/tasks/plugins/k8s/sidecar/sidecar_test.go +++ b/go/tasks/plugins/k8s/sidecar/sidecar_test.go @@ -73,6 +73,7 @@ func dummyContainerTaskMetadata(resources *v1.ResourceRequirements) pluginsCore. Namespace: "test-namespace", Name: "test-owner-name", }) + taskMetadata.OnGetPlatformResources().Return(&v1.ResourceRequirements{}) tID := &pluginsCoreMock.TaskExecutionID{} tID.On("GetID").Return(core.TaskExecutionIdentifier{ @@ -128,10 +129,12 @@ func getPodSpec() v1.PodSpec { Limits: v1.ResourceList{ "cpu": resource.MustParse("2"), "memory": resource.MustParse("200Mi"), + "gpu": resource.MustParse("1"), }, Requests: v1.ResourceList{ "cpu": resource.MustParse("1"), "memory": resource.MustParse("100Mi"), + "gpu": resource.MustParse("1"), }, }, VolumeMounts: []v1.VolumeMount{ @@ -142,6 +145,14 @@ func getPodSpec() v1.PodSpec { }, { Name: "secondary container", + Resources: v1.ResourceRequirements{ + Limits: v1.ResourceList{ + "gpu": resource.MustParse("2"), + }, + Requests: v1.ResourceList{ + "gpu": resource.MustParse("2"), + }, + }, }, }, Volumes: []v1.Volume{ @@ -223,8 +234,9 @@ func TestBuildSidecarResource_TaskType2(t *testing.T) { v1.ResourceStorage: {tolStorage}, ResourceNvidiaGPU: {tolGPU}, }, - DefaultCPURequest: "1024m", - DefaultMemoryRequest: "1024Mi", + DefaultCPURequest: resource.MustParse("1024m"), + DefaultMemoryRequest: resource.MustParse("1024Mi"), + GpuResourceName: ResourceNvidiaGPU, })) handler := &sidecarResourceHandler{} taskCtx := getDummySidecarTaskContext(&task, resourceRequirements) @@ -258,6 +270,13 @@ func TestBuildSidecarResource_TaskType2(t *testing.T) { assert.Equal(t, expectedMemLimit.Value(), res.(*v1.Pod).Spec.Containers[0].Resources.Limits.Memory().Value()) expectedEphemeralStorageLimit := resource.MustParse("100M") assert.Equal(t, expectedEphemeralStorageLimit.Value(), res.(*v1.Pod).Spec.Containers[0].Resources.Limits.StorageEphemeral().Value()) + + expectedGPURes := resource.MustParse("1") + assert.Equal(t, expectedGPURes, res.(*v1.Pod).Spec.Containers[0].Resources.Requests[ResourceNvidiaGPU]) + assert.Equal(t, expectedGPURes, res.(*v1.Pod).Spec.Containers[0].Resources.Limits[ResourceNvidiaGPU]) + expectedGPURes = resource.MustParse("2") + assert.Equal(t, expectedGPURes, res.(*v1.Pod).Spec.Containers[1].Resources.Requests[ResourceNvidiaGPU]) + assert.Equal(t, expectedGPURes, res.(*v1.Pod).Spec.Containers[1].Resources.Limits[ResourceNvidiaGPU]) } func TestBuildSidecarResource_TaskType2_Invalid_Spec(t *testing.T) { @@ -325,8 +344,8 @@ func TestBuildSidecarResource_TaskType1(t *testing.T) { v1.ResourceStorage: {tolStorage}, ResourceNvidiaGPU: {tolGPU}, }, - DefaultCPURequest: "1024m", - DefaultMemoryRequest: "1024Mi", + DefaultCPURequest: resource.MustParse("1024m"), + DefaultMemoryRequest: resource.MustParse("1024Mi"), })) handler := &sidecarResourceHandler{} taskCtx := getDummySidecarTaskContext(&task, resourceRequirements) @@ -390,8 +409,8 @@ func TestBuildSideResource_TaskType1_InvalidSpec(t *testing.T) { v1.ResourceStorage: {}, ResourceNvidiaGPU: {}, }, - DefaultCPURequest: "1024m", - DefaultMemoryRequest: "1024Mi", + DefaultCPURequest: resource.MustParse("1024m"), + DefaultMemoryRequest: resource.MustParse("1024Mi"), })) handler := &sidecarResourceHandler{} taskCtx := getDummySidecarTaskContext(&task, resourceRequirements) @@ -442,8 +461,8 @@ func TestBuildSidecarResource(t *testing.T) { v1.ResourceStorage: {tolStorage}, ResourceNvidiaGPU: {tolGPU}, }, - DefaultCPURequest: "1024m", - DefaultMemoryRequest: "1024Mi", + DefaultCPURequest: resource.MustParse("1024m"), + DefaultMemoryRequest: resource.MustParse("1024Mi"), })) handler := &sidecarResourceHandler{} taskCtx := getDummySidecarTaskContext(&task, resourceRequirements) diff --git a/tests/end_to_end.go b/tests/end_to_end.go index ee855724f..b1e6c621a 100644 --- a/tests/end_to_end.go +++ b/tests/end_to_end.go @@ -171,6 +171,7 @@ func RunPluginEndToEndTest(t *testing.T, executor pluginCore.Plugin, template *i Namespace: "fake-development", Name: execID, }) + tMeta.OnGetPlatformResources().Return(&v1.ResourceRequirements{}) catClient := &catalogMocks.Client{} catData := sync.Map{}