Skip to content

Commit

Permalink
put construct_k8s_pod_spec_from_resources into core/resources.py
Browse files Browse the repository at this point in the history
  • Loading branch information
fiedlerNr9 committed Nov 20, 2024
1 parent 887de26 commit 2778db2
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 57 deletions.
68 changes: 62 additions & 6 deletions flytekit/core/resources.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from dataclasses import dataclass
from typing import List, Optional, Union
from dataclasses import dataclass, fields
from typing import List, Optional, Union, Any

from mashumaro.mixins.json import DataClassJSONMixin

from flytekit.models import task as task_models
from kubernetes.client import V1PodSpec, V1Container, V1ResourceRequirements


@dataclass
Expand Down Expand Up @@ -66,14 +67,23 @@ class ResourceSpec(DataClassJSONMixin):
def _convert_resources_to_resource_entries(resources: Resources) -> List[_ResourceEntry]: # type: ignore
resource_entries = []
if resources.cpu is not None:
resource_entries.append(_ResourceEntry(name=_ResourceName.CPU, value=str(resources.cpu)))
resource_entries.append(
_ResourceEntry(name=_ResourceName.CPU, value=str(resources.cpu))
)
if resources.mem is not None:
resource_entries.append(_ResourceEntry(name=_ResourceName.MEMORY, value=str(resources.mem)))
resource_entries.append(
_ResourceEntry(name=_ResourceName.MEMORY, value=str(resources.mem))
)
if resources.gpu is not None:
resource_entries.append(_ResourceEntry(name=_ResourceName.GPU, value=str(resources.gpu)))
resource_entries.append(
_ResourceEntry(name=_ResourceName.GPU, value=str(resources.gpu))
)
if resources.ephemeral_storage is not None:
resource_entries.append(
_ResourceEntry(name=_ResourceName.EPHEMERAL_STORAGE, value=str(resources.ephemeral_storage))
_ResourceEntry(
name=_ResourceName.EPHEMERAL_STORAGE,
value=str(resources.ephemeral_storage),
)
)
return resource_entries

Expand All @@ -96,3 +106,49 @@ def convert_resources_to_resource_model(
if limits is not None:
limit_entries = _convert_resources_to_resource_entries(limits)
return task_models.Resources(requests=request_entries, limits=limit_entries)


def construct_k8s_pod_spec_from_resources(
k8s_pod_name: str,
requests: Optional[Resources],
limits: Optional[Resources],
) -> dict[str, Any]:

def construct_k8s_pods_resources(resources: Optional[Resources]):
if resources is None:
return None

resources_map = {
"cpu": "cpu",
"mem": "memory",
"gpu": "nvidia.com/gpu",
"ephemeral_storage": "ephemeral-storage",
}

k8s_pod_resources = {}

for resource in fields(resources):
resource_value = getattr(resources, resource.name)
if resource_value is not None:
k8s_pod_resources[resources_map[resource.name]] = resource_value

return k8s_pod_resources

requests = construct_k8s_pods_resources(resources=requests)
limits = construct_k8s_pods_resources(resources=limits)
requests = requests or limits
limits = limits or requests

k8s_pod = V1PodSpec(
containers=[
V1Container(
name=k8s_pod_name,
resources=V1ResourceRequirements(
requests=requests,
limits=limits,
),
)
]
)

return k8s_pod.to_dict()
54 changes: 3 additions & 51 deletions plugins/flytekit-ray/flytekitplugins/ray/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,57 +4,10 @@

from flytekit.models import common as _common
from flytekit.models.task import K8sPod, K8sObjectMetadata
from flytekit.core.resources import Resources
from flytekit.core.resources import Resources, construct_k8s_pod_spec_from_resources
from kubernetes.client import V1PodSpec, V1Container, V1ResourceRequirements


def construct_k8s_pod_spec(
k8s_pod_name: str,
requests: typing.Optional[Resources],
limits: typing.Optional[Resources],
) -> dict[str, typing.Any]:

def construct_k8s_pods_resources(resources: typing.Optional[Resources]):
if resources is None:
return None

resources_map = {
"cpu": "cpu",
"mem": "memory",
"gpu": "nvidia.com/gpu",
"ephemeral_storage": "ephemeral-storage",
}

k8s_pod_resources = {}

for resource in fields(resources):
resource_value = getattr(resources, resource.name)
if resource_value is not None:
k8s_pod_resources[resources_map[resource.name]] = resource_value

print(k8s_pod_resources)
return k8s_pod_resources

requests = construct_k8s_pods_resources(resources=requests)
limits = construct_k8s_pods_resources(resources=limits)
requests = requests or limits
limits = limits or requests

k8s_pod = V1PodSpec(
containers=[
V1Container(
name=k8s_pod_name,
resources=V1ResourceRequirements(
requests=requests,
limits=limits,
),
)
]
)

return k8s_pod.to_dict()


class WorkerGroupSpec(_common.FlyteIdlEntity):
def __init__(
self,
Expand All @@ -79,11 +32,10 @@ def __init__(
self._limits = limits
self._k8s_pod = K8sPod(
metadata=K8sObjectMetadata(),
pod_spec=construct_k8s_pod_spec(
pod_spec=construct_k8s_pod_spec_from_resources(
k8s_pod_name="ray-worker", requests=self._requests, limits=self._limits
),
)
# exit(0)

@property
def group_name(self):
Expand Down Expand Up @@ -179,7 +131,7 @@ def __init__(

self._k8s_pod = K8sPod(
metadata=K8sObjectMetadata(),
pod_spec=construct_k8s_pod_spec(
pod_spec=construct_k8s_pod_spec_from_resources(
k8s_pod_name="ray-head", requests=self._requests, limits=self._limits
),
)
Expand Down

0 comments on commit 2778db2

Please sign in to comment.