From 2778db206bbea478908c4e529dcb63cd438b6065 Mon Sep 17 00:00:00 2001 From: Jan Fiedler Date: Wed, 20 Nov 2024 09:32:55 -0800 Subject: [PATCH] put construct_k8s_pod_spec_from_resources into core/resources.py --- flytekit/core/resources.py | 68 +++++++++++++++++-- .../flytekitplugins/ray/models.py | 54 +-------------- 2 files changed, 65 insertions(+), 57 deletions(-) diff --git a/flytekit/core/resources.py b/flytekit/core/resources.py index 8a99dbf2ea..e1b241b64d 100644 --- a/flytekit/core/resources.py +++ b/flytekit/core/resources.py @@ -1,9 +1,10 @@ -from dataclasses import dataclass -from typing import List, Optional, Union +from dataclasses import dataclass, fields +from typing import List, Optional, Union, Any from mashumaro.mixins.json import DataClassJSONMixin from flytekit.models import task as task_models +from kubernetes.client import V1PodSpec, V1Container, V1ResourceRequirements @dataclass @@ -66,14 +67,23 @@ class ResourceSpec(DataClassJSONMixin): def _convert_resources_to_resource_entries(resources: Resources) -> List[_ResourceEntry]: # type: ignore resource_entries = [] if resources.cpu is not None: - resource_entries.append(_ResourceEntry(name=_ResourceName.CPU, value=str(resources.cpu))) + resource_entries.append( + _ResourceEntry(name=_ResourceName.CPU, value=str(resources.cpu)) + ) if resources.mem is not None: - resource_entries.append(_ResourceEntry(name=_ResourceName.MEMORY, value=str(resources.mem))) + resource_entries.append( + _ResourceEntry(name=_ResourceName.MEMORY, value=str(resources.mem)) + ) if resources.gpu is not None: - resource_entries.append(_ResourceEntry(name=_ResourceName.GPU, value=str(resources.gpu))) + resource_entries.append( + _ResourceEntry(name=_ResourceName.GPU, value=str(resources.gpu)) + ) if resources.ephemeral_storage is not None: resource_entries.append( - _ResourceEntry(name=_ResourceName.EPHEMERAL_STORAGE, value=str(resources.ephemeral_storage)) + _ResourceEntry( + name=_ResourceName.EPHEMERAL_STORAGE, + value=str(resources.ephemeral_storage), + ) ) return resource_entries @@ -96,3 +106,49 @@ def convert_resources_to_resource_model( if limits is not None: limit_entries = _convert_resources_to_resource_entries(limits) return task_models.Resources(requests=request_entries, limits=limit_entries) + + +def construct_k8s_pod_spec_from_resources( + k8s_pod_name: str, + requests: Optional[Resources], + limits: Optional[Resources], +) -> dict[str, Any]: + + def construct_k8s_pods_resources(resources: Optional[Resources]): + if resources is None: + return None + + resources_map = { + "cpu": "cpu", + "mem": "memory", + "gpu": "nvidia.com/gpu", + "ephemeral_storage": "ephemeral-storage", + } + + k8s_pod_resources = {} + + for resource in fields(resources): + resource_value = getattr(resources, resource.name) + if resource_value is not None: + k8s_pod_resources[resources_map[resource.name]] = resource_value + + return k8s_pod_resources + + requests = construct_k8s_pods_resources(resources=requests) + limits = construct_k8s_pods_resources(resources=limits) + requests = requests or limits + limits = limits or requests + + k8s_pod = V1PodSpec( + containers=[ + V1Container( + name=k8s_pod_name, + resources=V1ResourceRequirements( + requests=requests, + limits=limits, + ), + ) + ] + ) + + return k8s_pod.to_dict() diff --git a/plugins/flytekit-ray/flytekitplugins/ray/models.py b/plugins/flytekit-ray/flytekitplugins/ray/models.py index 3c009dee5a..1d4d790b4e 100644 --- a/plugins/flytekit-ray/flytekitplugins/ray/models.py +++ b/plugins/flytekit-ray/flytekitplugins/ray/models.py @@ -4,57 +4,10 @@ from flytekit.models import common as _common from flytekit.models.task import K8sPod, K8sObjectMetadata -from flytekit.core.resources import Resources +from flytekit.core.resources import Resources, construct_k8s_pod_spec_from_resources from kubernetes.client import V1PodSpec, V1Container, V1ResourceRequirements -def construct_k8s_pod_spec( - k8s_pod_name: str, - requests: typing.Optional[Resources], - limits: typing.Optional[Resources], -) -> dict[str, typing.Any]: - - def construct_k8s_pods_resources(resources: typing.Optional[Resources]): - if resources is None: - return None - - resources_map = { - "cpu": "cpu", - "mem": "memory", - "gpu": "nvidia.com/gpu", - "ephemeral_storage": "ephemeral-storage", - } - - k8s_pod_resources = {} - - for resource in fields(resources): - resource_value = getattr(resources, resource.name) - if resource_value is not None: - k8s_pod_resources[resources_map[resource.name]] = resource_value - - print(k8s_pod_resources) - return k8s_pod_resources - - requests = construct_k8s_pods_resources(resources=requests) - limits = construct_k8s_pods_resources(resources=limits) - requests = requests or limits - limits = limits or requests - - k8s_pod = V1PodSpec( - containers=[ - V1Container( - name=k8s_pod_name, - resources=V1ResourceRequirements( - requests=requests, - limits=limits, - ), - ) - ] - ) - - return k8s_pod.to_dict() - - class WorkerGroupSpec(_common.FlyteIdlEntity): def __init__( self, @@ -79,11 +32,10 @@ def __init__( self._limits = limits self._k8s_pod = K8sPod( metadata=K8sObjectMetadata(), - pod_spec=construct_k8s_pod_spec( + pod_spec=construct_k8s_pod_spec_from_resources( k8s_pod_name="ray-worker", requests=self._requests, limits=self._limits ), ) - # exit(0) @property def group_name(self): @@ -179,7 +131,7 @@ def __init__( self._k8s_pod = K8sPod( metadata=K8sObjectMetadata(), - pod_spec=construct_k8s_pod_spec( + pod_spec=construct_k8s_pod_spec_from_resources( k8s_pod_name="ray-head", requests=self._requests, limits=self._limits ), )