Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ray Task Support #1093

Merged
merged 31 commits into from
Aug 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/pythonbuild.yml
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ jobs:
- flytekit-pandera
- flytekit-papermill
- flytekit-polars
- flytekit-ray
- flytekit-snowflake
- flytekit-spark
- flytekit-sqlalchemy
Expand Down
2 changes: 1 addition & 1 deletion flytekit/core/python_function_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def __init__(
can be used to inject some client side variables only. Prefer using ExecutionParams
:param Optional[ExecutionBehavior] execution_mode: Defines how the execution should behave, for example
executing normally or specially handling a dynamic case.
:param Optional[TaskResolverMixin] task_type: String task type to be associated with this Task
:param str task_type: String task type to be associated with this Task
"""
if task_function is None:
raise ValueError("TaskFunction is a required parameter for PythonFunctionTask")
Expand Down
9 changes: 9 additions & 0 deletions plugins/flytekit-ray/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Flytekit Ray Plugin

Flyte backend can be connected with Ray. Once enabled, it allows you to run flyte task on Ray cluster

To install the plugin, run the following command:

```bash
pip install flytekitplugins-ray
```
13 changes: 13 additions & 0 deletions plugins/flytekit-ray/flytekitplugins/ray/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
"""
.. currentmodule:: flytekitplugins.ray

This package contains things that are useful when extending Flytekit.

.. autosummary::
:template: custom.rst
:toctree: generated/

RayConfig
"""

from .task import HeadNodeConfig, RayJobConfig, WorkerNodeConfig
204 changes: 204 additions & 0 deletions plugins/flytekit-ray/flytekitplugins/ray/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
import typing

from flyteidl.plugins import ray_pb2 as _ray_pb2

from flytekit.models import common as _common


class WorkerGroupSpec(_common.FlyteIdlEntity):
def __init__(
self,
group_name: str,
replicas: int,
min_replicas: typing.Optional[int] = 0,
max_replicas: typing.Optional[int] = None,
ray_start_params: typing.Optional[typing.Dict[str, str]] = None,
):
self._group_name = group_name
self._replicas = replicas
self._min_replicas = min_replicas
self._max_replicas = max_replicas if max_replicas else replicas
self._ray_start_params = ray_start_params

@property
def group_name(self):
"""
Group name of the current worker group.
:rtype: str
"""
return self._group_name

@property
def replicas(self):
"""
Desired replicas of the worker group.
:rtype: int
"""
return self._replicas

@property
def min_replicas(self):
"""
Min replicas of the worker group.
:rtype: int
"""
return self._min_replicas

@property
def max_replicas(self):
"""
Max replicas of the worker group.
:rtype: int
"""
return self._max_replicas

@property
def ray_start_params(self):
"""
The ray start params of worker node group.
:rtype: typing.Dict[str, str]
"""
return self._ray_start_params

def to_flyte_idl(self):
"""
:rtype: flyteidl.plugins._ray_pb2.WorkerGroupSpec
"""
return _ray_pb2.WorkerGroupSpec(
group_name=self.group_name,
replicas=self.replicas,
min_replicas=self.min_replicas,
max_replicas=self.max_replicas,
ray_start_params=self.ray_start_params,
)

@classmethod
def from_flyte_idl(cls, proto):
"""
:param flyteidl.plugins._ray_pb2.WorkerGroupSpec proto:
:rtype: WorkerGroupSpec
"""
return cls(
group_name=proto.group_name,
replicas=proto.replicas,
min_replicas=proto.min_replicas,
max_replicas=proto.max_replicas,
ray_start_params=proto.ray_start_params,
)


class HeadGroupSpec(_common.FlyteIdlEntity):
def __init__(
self,
ray_start_params: typing.Optional[typing.Dict[str, str]] = None,
):
self._ray_start_params = ray_start_params

@property
def ray_start_params(self):
"""
The ray start params of worker node group.
:rtype: typing.Dict[str, str]
"""
return self._ray_start_params

def to_flyte_idl(self):
"""
:rtype: flyteidl.plugins._ray_pb2.HeadGroupSpec
"""
return _ray_pb2.HeadGroupSpec(
ray_start_params=self.ray_start_params if self.ray_start_params else {},
)

@classmethod
def from_flyte_idl(cls, proto):
"""
:param flyteidl.plugins._ray_pb2.HeadGroupSpec proto:
:rtype: HeadGroupSpec
"""
return cls(
ray_start_params=proto.ray_start_params,
)


class RayCluster(_common.FlyteIdlEntity):
"""
Define RayCluster spec that will be used by KubeRay to launch the cluster.
"""

def __init__(
self, worker_group_spec: typing.List[WorkerGroupSpec], head_group_spec: typing.Optional[HeadGroupSpec] = None
):
self._head_group_spec = head_group_spec
self._worker_group_spec = worker_group_spec

@property
def head_group_spec(self) -> HeadGroupSpec:
"""
The head group configuration.
:rtype: HeadGroupSpec
"""
return self._head_group_spec

@property
def worker_group_spec(self) -> typing.List[WorkerGroupSpec]:
"""
The worker group configurations.
:rtype: typing.List[WorkerGroupSpec]
"""
return self._worker_group_spec

def to_flyte_idl(self) -> _ray_pb2.RayCluster:
"""
:rtype: flyteidl.plugins._ray_pb2.RayCluster
"""
return _ray_pb2.RayCluster(
head_group_spec=self.head_group_spec.to_flyte_idl() if self.head_group_spec else None,
worker_group_spec=[wg.to_flyte_idl() for wg in self.worker_group_spec],
)

@classmethod
def from_flyte_idl(cls, proto):
"""
:param flyteidl.plugins._ray_pb2.RayCluster proto:
:rtype: RayCluster
"""
return cls(
head_group_spec=HeadGroupSpec.from_flyte_idl(proto.head_group_spec) if proto.head_group_spec else None,
worker_group_spec=[WorkerGroupSpec.from_flyte_idl(wg) for wg in proto.worker_group_spec],
)


class RayJob(_common.FlyteIdlEntity):
"""
Models _ray_pb2.RayJob
"""

def __init__(
self,
ray_cluster: RayCluster,
runtime_env: typing.Optional[str],
):
self._ray_cluster = ray_cluster
self._runtime_env = runtime_env

@property
def ray_cluster(self) -> RayCluster:
return self._ray_cluster

@property
def runtime_env(self) -> typing.Optional[str]:
return self._runtime_env

def to_flyte_idl(self) -> _ray_pb2.RayJob:
return _ray_pb2.RayJob(
ray_cluster=self.ray_cluster.to_flyte_idl(),
runtime_env=self.runtime_env,
)

@classmethod
def from_flyte_idl(cls, proto: _ray_pb2.RayJob):
return cls(
ray_cluster=RayCluster.from_flyte_idl(proto.ray_cluster) if proto.ray_cluster else None,
runtime_env=proto.runtime_env,
)
76 changes: 76 additions & 0 deletions plugins/flytekit-ray/flytekitplugins/ray/task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import base64
import json
import typing
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional

import ray
from flytekitplugins.ray.models import HeadGroupSpec, RayCluster, RayJob, WorkerGroupSpec
from google.protobuf.json_format import MessageToDict

from flytekit.configuration import SerializationSettings
from flytekit.core.context_manager import ExecutionParameters
from flytekit.core.python_function_task import PythonFunctionTask
from flytekit.extend import TaskPlugins


@dataclass
class HeadNodeConfig:
ray_start_params: typing.Optional[typing.Dict[str, str]] = None


@dataclass
class WorkerNodeConfig:
group_name: str
replicas: int
min_replicas: typing.Optional[int] = None
max_replicas: typing.Optional[int] = None
ray_start_params: typing.Optional[typing.Dict[str, str]] = None


@dataclass
class RayJobConfig:
worker_node_config: typing.List[WorkerNodeConfig]
head_node_config: typing.Optional[HeadNodeConfig] = None
runtime_env: typing.Optional[dict] = None
address: typing.Optional[str] = None


class RayFunctionTask(PythonFunctionTask):
"""
Actual Plugin that transforms the local python code for execution within Ray job.
"""

_RAY_TASK_TYPE = "ray"

def __init__(self, task_config: RayJobConfig, task_function: Callable, **kwargs):
super().__init__(task_config=task_config, task_type=self._RAY_TASK_TYPE, task_function=task_function, **kwargs)
self._task_config = task_config

def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters:
ray.init(address=self._task_config.address)
return user_params

def post_execute(self, user_params: ExecutionParameters, rval: Any) -> Any:
ray.shutdown()
return rval

def get_custom(self, settings: SerializationSettings) -> Optional[Dict[str, Any]]:
cfg = self._task_config

ray_job = RayJob(
ray_cluster=RayCluster(
head_group_spec=HeadGroupSpec(cfg.head_node_config.ray_start_params) if cfg.head_node_config else None,
worker_group_spec=[
WorkerGroupSpec(c.group_name, c.replicas, c.min_replicas, c.max_replicas, c.ray_start_params)
for c in cfg.worker_node_config
],
),
# Use base64 to encode runtime_env dict and convert it to byte string
runtime_env=base64.b64encode(json.dumps(cfg.runtime_env).encode()).decode(),
)
return MessageToDict(ray_job.to_flyte_idl())


# Inject the Ray plugin into flytekits dynamic plugin loading system
TaskPlugins.register_pythontask_plugin(RayJobConfig, RayFunctionTask)
2 changes: 2 additions & 0 deletions plugins/flytekit-ray/requirements.in
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
.
-e file:.#egg=flytekitplugins-ray
Loading