Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce pod_spec_from_resources()ray helper function #2943

Merged
merged 25 commits into from
Dec 13, 2024
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 52 additions & 3 deletions flytekit/core/resources.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from dataclasses import dataclass
from typing import List, Optional, Union
from dataclasses import dataclass, fields
from typing import Any, List, Optional, Union

from kubernetes.client import V1Container, V1PodSpec, V1ResourceRequirements
from mashumaro.mixins.json import DataClassJSONMixin

from flytekit.models import task as task_models
Expand Down Expand Up @@ -73,7 +74,10 @@ def _convert_resources_to_resource_entries(resources: Resources) -> List[_Resour
resource_entries.append(_ResourceEntry(name=_ResourceName.GPU, value=str(resources.gpu)))
if resources.ephemeral_storage is not None:
resource_entries.append(
_ResourceEntry(name=_ResourceName.EPHEMERAL_STORAGE, value=str(resources.ephemeral_storage))
_ResourceEntry(
name=_ResourceName.EPHEMERAL_STORAGE,
value=str(resources.ephemeral_storage),
)
)
return resource_entries

Expand All @@ -96,3 +100,48 @@ def convert_resources_to_resource_model(
if limits is not None:
limit_entries = _convert_resources_to_resource_entries(limits)
return task_models.Resources(requests=request_entries, limits=limit_entries)


def construct_k8s_pod_spec_from_resources(
k8s_pod_name: str,
requests: Optional[Resources],
limits: Optional[Resources],
) -> dict[str, Any]:
def _construct_k8s_pods_resources(resources: Optional[Resources], k8s_gpu_resource_key: str = "nvidia.com/gpu"):
if resources is None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using other gpus is going to be hard, even if we push this parameter to the outer function (i.e. construct_k8s_pod_spec_from_resources).

return None

resources_map = {
"cpu": "cpu",
"mem": "memory",
"gpu": k8s_gpu_resource_key,
"ephemeral_storage": "ephemeral-storage",
}

k8s_pod_resources = {}

for resource in fields(resources):
resource_value = getattr(resources, resource.name)
if resource_value is not None:
k8s_pod_resources[resources_map[resource.name]] = resource_value

return k8s_pod_resources

requests = _construct_k8s_pods_resources(resources=requests)
limits = _construct_k8s_pods_resources(resources=limits)
requests = requests or limits
limits = limits or requests

k8s_pod = V1PodSpec(
containers=[
V1Container(
name=k8s_pod_name,
resources=V1ResourceRequirements(
requests=requests,
limits=limits,
),
)
]
)

return k8s_pod.to_dict()
28 changes: 23 additions & 5 deletions plugins/flytekit-ray/flytekitplugins/ray/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@

from flyteidl.plugins import ray_pb2 as _ray_pb2

from flytekit.core.resources import Resources, construct_k8s_pod_spec_from_resources
from flytekit.models import common as _common
from flytekit.models.task import K8sPod
from flytekit.models.task import K8sObjectMetadata, K8sPod


class WorkerGroupSpec(_common.FlyteIdlEntity):
Expand All @@ -14,14 +15,22 @@ def __init__(
min_replicas: typing.Optional[int] = None,
max_replicas: typing.Optional[int] = None,
ray_start_params: typing.Optional[typing.Dict[str, str]] = None,
k8s_pod: typing.Optional[K8sPod] = None,
eapolinario marked this conversation as resolved.
Show resolved Hide resolved
requests: typing.Optional[Resources] = None,
limits: typing.Optional[Resources] = None,
):
self._group_name = group_name
self._replicas = replicas
self._max_replicas = max(replicas, max_replicas) if max_replicas is not None else replicas
self._min_replicas = min(replicas, min_replicas) if min_replicas is not None else replicas
self._ray_start_params = ray_start_params
self._k8s_pod = k8s_pod
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should keep this as part of the interface and build helper functions that construct valid pod specs instead (as mentioned in the original flyte PR). This is going to help in the other problem we're having with passing the gpu resource name around (in other words, gpu can be an argument of one of the helper function that builds pod specs).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I get what you are saying. So we want users to construct the pod specs themself like calling construct_k8s_pod_spec_from_resources() or specifying pod templates in user code?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would make the method name simple, maybe pod from resources

self._requests = requests
self._limits = limits
self._k8s_pod = K8sPod(
metadata=K8sObjectMetadata(),
pod_spec=construct_k8s_pod_spec_from_resources(
k8s_pod_name="ray-worker", requests=self._requests, limits=self._limits
),
)

@property
def group_name(self):
Expand Down Expand Up @@ -104,10 +113,19 @@ class HeadGroupSpec(_common.FlyteIdlEntity):
def __init__(
self,
ray_start_params: typing.Optional[typing.Dict[str, str]] = None,
k8s_pod: typing.Optional[K8sPod] = None,
requests: typing.Optional[Resources] = None,
limits: typing.Optional[Resources] = None,
eapolinario marked this conversation as resolved.
Show resolved Hide resolved
):
self._ray_start_params = ray_start_params
self._k8s_pod = k8s_pod
self._requests = requests
self._limits = limits

self._k8s_pod = K8sPod(
metadata=K8sObjectMetadata(),
pod_spec=construct_k8s_pod_spec_from_resources(
k8s_pod_name="ray-head", requests=self._requests, limits=self._limits
),
)

@property
def ray_start_params(self):
Expand Down
17 changes: 12 additions & 5 deletions plugins/flytekit-ray/flytekitplugins/ray/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,17 @@
from flytekit.configuration import SerializationSettings
from flytekit.core.context_manager import ExecutionParameters, FlyteContextManager
from flytekit.core.python_function_task import PythonFunctionTask
from flytekit.core.resources import Resources
from flytekit.extend import TaskPlugins
from flytekit.models.task import K8sPod

ray = lazy_module("ray")


@dataclass
class HeadNodeConfig:
ray_start_params: typing.Optional[typing.Dict[str, str]] = None
k8s_pod: typing.Optional[K8sPod] = None
requests: typing.Optional[Resources] = None
limits: typing.Optional[Resources] = None


@dataclass
Expand All @@ -37,7 +38,8 @@ class WorkerNodeConfig:
min_replicas: typing.Optional[int] = None
max_replicas: typing.Optional[int] = None
ray_start_params: typing.Optional[typing.Dict[str, str]] = None
k8s_pod: typing.Optional[K8sPod] = None
requests: typing.Optional[Resources] = None
limits: typing.Optional[Resources] = None


@dataclass
Expand Down Expand Up @@ -92,7 +94,11 @@ def get_custom(self, settings: SerializationSettings) -> Optional[Dict[str, Any]
ray_job = RayJob(
ray_cluster=RayCluster(
head_group_spec=(
HeadGroupSpec(cfg.head_node_config.ray_start_params, cfg.head_node_config.k8s_pod)
HeadGroupSpec(
cfg.head_node_config.ray_start_params,
cfg.head_node_config.requests,
cfg.head_node_config.limits,
)
if cfg.head_node_config
else None
),
Expand All @@ -103,7 +109,8 @@ def get_custom(self, settings: SerializationSettings) -> Optional[Dict[str, Any]
c.min_replicas,
c.max_replicas,
c.ray_start_params,
c.k8s_pod,
c.requests,
c.limits,
)
for c in cfg.worker_node_config
],
Expand Down
37 changes: 32 additions & 5 deletions plugins/flytekit-ray/tests/test_ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,31 @@
import ray
import yaml
from flytekitplugins.ray import HeadNodeConfig
from flytekitplugins.ray.models import RayCluster, RayJob, WorkerGroupSpec, HeadGroupSpec
from flytekitplugins.ray.models import (
RayCluster,
RayJob,
WorkerGroupSpec,
HeadGroupSpec,
)
from flytekitplugins.ray.task import RayJobConfig, WorkerNodeConfig
from google.protobuf.json_format import MessageToDict
from flytekit.models.task import K8sPod
from flytekit.core.resources import Resources

from flytekit import PythonFunctionTask, task
from flytekit.configuration import Image, ImageConfig, SerializationSettings

config = RayJobConfig(
worker_node_config=[WorkerNodeConfig(group_name="test_group", replicas=3, min_replicas=0, max_replicas=10, k8s_pod=K8sPod(pod_spec={"str": "worker", "int": 1}))],
head_node_config=HeadNodeConfig(k8s_pod=K8sPod(pod_spec={"str": "head", "int": 2})),
worker_node_config=[
WorkerNodeConfig(
group_name="test_group",
replicas=3,
min_replicas=0,
max_replicas=10,
requests=Resources(cpu=2, mem="2Gi"),
limits=Resources(cpu=2, mem="4Gi"),
)
],
head_node_config=HeadNodeConfig(requests=Resources(cpu=2)),
runtime_env={"pip": ["numpy"]},
enable_autoscaling=True,
shutdown_after_job_finishes=True,
Expand Down Expand Up @@ -44,7 +58,20 @@ def t1(a: int) -> str:
)

ray_job_pb = RayJob(
ray_cluster=RayCluster(worker_group_spec=[WorkerGroupSpec(group_name="test_group", replicas=3, min_replicas=0, max_replicas=10, k8s_pod=K8sPod(pod_spec={"str": "worker", "int": 1}))], head_group_spec=HeadGroupSpec(k8s_pod=K8sPod(pod_spec={"str": "head", "int": 2})), enable_autoscaling=True),
ray_cluster=RayCluster(
worker_group_spec=[
WorkerGroupSpec(
group_name="test_group",
replicas=3,
min_replicas=0,
max_replicas=10,
requests=Resources(cpu=2, mem="2Gi"),
limits=Resources(cpu=2, mem="4Gi"),
)
],
head_group_spec=HeadGroupSpec(requests=Resources(cpu=2)),
enable_autoscaling=True,
),
runtime_env=base64.b64encode(json.dumps({"pip": ["numpy"]}).encode()).decode(),
runtime_env_yaml=yaml.dump({"pip": ["numpy"]}),
shutdown_after_job_finishes=True,
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ dependencies = [
"jsonlines",
"jsonpickle",
"keyring>=18.0.1",
"kubernetes>=12.0.1",
"markdown-it-py",
"marshmallow-enum",
"marshmallow-jsonschema>=0.12.0",
Expand Down
56 changes: 55 additions & 1 deletion tests/flytekit/unit/core/test_resources.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
from typing import Dict

import pytest
from kubernetes.client import V1Container, V1PodSpec, V1ResourceRequirements

import flytekit.models.task as _task_models
from flytekit import Resources
from flytekit.core.resources import convert_resources_to_resource_model
from flytekit.core.resources import (
construct_k8s_pod_spec_from_resources,
convert_resources_to_resource_model,
)

_ResourceName = _task_models.Resources.ResourceName

Expand Down Expand Up @@ -101,3 +105,53 @@ def test_resources_round_trip():
json_str = original.to_json()
result = Resources.from_json(json_str)
assert original == result


def test_construct_k8s_pod_spec_from_resources_requests_limits_set():
requests = Resources(cpu="1", mem="1Gi", gpu="1", ephemeral_storage="1Gi")
limits = Resources(cpu="4", mem="2Gi", gpu="1", ephemeral_storage="1Gi")
k8s_pod_name = "foo"

expected_pod_spec = V1PodSpec(
containers=[
V1Container(
name=k8s_pod_name,
resources=V1ResourceRequirements(
requests={
"cpu": "1",
"memory": "1Gi",
"nvidia.com/gpu": "1",
"ephemeral-storage": "1Gi",
},
limits={
"cpu": "4",
"memory": "2Gi",
"nvidia.com/gpu": "1",
"ephemeral-storage": "1Gi",
},
),
)
]
)
pod_spec = construct_k8s_pod_spec_from_resources(k8s_pod_name=k8s_pod_name, requests=requests, limits=limits)
assert expected_pod_spec == V1PodSpec(**pod_spec)


def test_construct_k8s_pod_spec_from_resources_requests_set():
requests = Resources(cpu="1", mem="1Gi")
limits = None
k8s_pod_name = "foo"

expected_pod_spec = V1PodSpec(
containers=[
V1Container(
name=k8s_pod_name,
resources=V1ResourceRequirements(
requests={"cpu": "1", "memory": "1Gi"},
limits={"cpu": "1", "memory": "1Gi"},
),
)
]
)
pod_spec = construct_k8s_pod_spec_from_resources(k8s_pod_name=k8s_pod_name, requests=requests, limits=limits)
assert expected_pod_spec == V1PodSpec(**pod_spec)
Loading