diff --git a/.github/workflows/pythonbuild.yml b/.github/workflows/pythonbuild.yml index ca6c85bb3e..c6d8424108 100644 --- a/.github/workflows/pythonbuild.yml +++ b/.github/workflows/pythonbuild.yml @@ -80,6 +80,7 @@ jobs: - flytekit-pandera - flytekit-papermill - flytekit-polars + - flytekit-ray - flytekit-snowflake - flytekit-spark - flytekit-sqlalchemy diff --git a/flytekit/core/python_function_task.py b/flytekit/core/python_function_task.py index 84a8eaedef..42c6fc643c 100644 --- a/flytekit/core/python_function_task.py +++ b/flytekit/core/python_function_task.py @@ -109,7 +109,7 @@ def __init__( can be used to inject some client side variables only. Prefer using ExecutionParams :param Optional[ExecutionBehavior] execution_mode: Defines how the execution should behave, for example executing normally or specially handling a dynamic case. - :param Optional[TaskResolverMixin] task_type: String task type to be associated with this Task + :param str task_type: String task type to be associated with this Task """ if task_function is None: raise ValueError("TaskFunction is a required parameter for PythonFunctionTask") diff --git a/plugins/flytekit-ray/README.md b/plugins/flytekit-ray/README.md new file mode 100644 index 0000000000..f7db403a6c --- /dev/null +++ b/plugins/flytekit-ray/README.md @@ -0,0 +1,9 @@ +# Flytekit Ray Plugin + +Flyte backend can be connected with Ray. Once enabled, it allows you to run flyte task on Ray cluster + +To install the plugin, run the following command: + +```bash +pip install flytekitplugins-ray +``` diff --git a/plugins/flytekit-ray/flytekitplugins/ray/__init__.py b/plugins/flytekit-ray/flytekitplugins/ray/__init__.py new file mode 100644 index 0000000000..44543df900 --- /dev/null +++ b/plugins/flytekit-ray/flytekitplugins/ray/__init__.py @@ -0,0 +1,13 @@ +""" +.. currentmodule:: flytekitplugins.ray + +This package contains things that are useful when extending Flytekit. + +.. autosummary:: + :template: custom.rst + :toctree: generated/ + + RayConfig +""" + +from .task import HeadNodeConfig, RayJobConfig, WorkerNodeConfig diff --git a/plugins/flytekit-ray/flytekitplugins/ray/models.py b/plugins/flytekit-ray/flytekitplugins/ray/models.py new file mode 100644 index 0000000000..080f1239b4 --- /dev/null +++ b/plugins/flytekit-ray/flytekitplugins/ray/models.py @@ -0,0 +1,204 @@ +import typing + +from flyteidl.plugins import ray_pb2 as _ray_pb2 + +from flytekit.models import common as _common + + +class WorkerGroupSpec(_common.FlyteIdlEntity): + def __init__( + self, + group_name: str, + replicas: int, + min_replicas: typing.Optional[int] = 0, + max_replicas: typing.Optional[int] = None, + ray_start_params: typing.Optional[typing.Dict[str, str]] = None, + ): + self._group_name = group_name + self._replicas = replicas + self._min_replicas = min_replicas + self._max_replicas = max_replicas if max_replicas else replicas + self._ray_start_params = ray_start_params + + @property + def group_name(self): + """ + Group name of the current worker group. + :rtype: str + """ + return self._group_name + + @property + def replicas(self): + """ + Desired replicas of the worker group. + :rtype: int + """ + return self._replicas + + @property + def min_replicas(self): + """ + Min replicas of the worker group. + :rtype: int + """ + return self._min_replicas + + @property + def max_replicas(self): + """ + Max replicas of the worker group. + :rtype: int + """ + return self._max_replicas + + @property + def ray_start_params(self): + """ + The ray start params of worker node group. + :rtype: typing.Dict[str, str] + """ + return self._ray_start_params + + def to_flyte_idl(self): + """ + :rtype: flyteidl.plugins._ray_pb2.WorkerGroupSpec + """ + return _ray_pb2.WorkerGroupSpec( + group_name=self.group_name, + replicas=self.replicas, + min_replicas=self.min_replicas, + max_replicas=self.max_replicas, + ray_start_params=self.ray_start_params, + ) + + @classmethod + def from_flyte_idl(cls, proto): + """ + :param flyteidl.plugins._ray_pb2.WorkerGroupSpec proto: + :rtype: WorkerGroupSpec + """ + return cls( + group_name=proto.group_name, + replicas=proto.replicas, + min_replicas=proto.min_replicas, + max_replicas=proto.max_replicas, + ray_start_params=proto.ray_start_params, + ) + + +class HeadGroupSpec(_common.FlyteIdlEntity): + def __init__( + self, + ray_start_params: typing.Optional[typing.Dict[str, str]] = None, + ): + self._ray_start_params = ray_start_params + + @property + def ray_start_params(self): + """ + The ray start params of worker node group. + :rtype: typing.Dict[str, str] + """ + return self._ray_start_params + + def to_flyte_idl(self): + """ + :rtype: flyteidl.plugins._ray_pb2.HeadGroupSpec + """ + return _ray_pb2.HeadGroupSpec( + ray_start_params=self.ray_start_params if self.ray_start_params else {}, + ) + + @classmethod + def from_flyte_idl(cls, proto): + """ + :param flyteidl.plugins._ray_pb2.HeadGroupSpec proto: + :rtype: HeadGroupSpec + """ + return cls( + ray_start_params=proto.ray_start_params, + ) + + +class RayCluster(_common.FlyteIdlEntity): + """ + Define RayCluster spec that will be used by KubeRay to launch the cluster. + """ + + def __init__( + self, worker_group_spec: typing.List[WorkerGroupSpec], head_group_spec: typing.Optional[HeadGroupSpec] = None + ): + self._head_group_spec = head_group_spec + self._worker_group_spec = worker_group_spec + + @property + def head_group_spec(self) -> HeadGroupSpec: + """ + The head group configuration. + :rtype: HeadGroupSpec + """ + return self._head_group_spec + + @property + def worker_group_spec(self) -> typing.List[WorkerGroupSpec]: + """ + The worker group configurations. + :rtype: typing.List[WorkerGroupSpec] + """ + return self._worker_group_spec + + def to_flyte_idl(self) -> _ray_pb2.RayCluster: + """ + :rtype: flyteidl.plugins._ray_pb2.RayCluster + """ + return _ray_pb2.RayCluster( + head_group_spec=self.head_group_spec.to_flyte_idl() if self.head_group_spec else None, + worker_group_spec=[wg.to_flyte_idl() for wg in self.worker_group_spec], + ) + + @classmethod + def from_flyte_idl(cls, proto): + """ + :param flyteidl.plugins._ray_pb2.RayCluster proto: + :rtype: RayCluster + """ + return cls( + head_group_spec=HeadGroupSpec.from_flyte_idl(proto.head_group_spec) if proto.head_group_spec else None, + worker_group_spec=[WorkerGroupSpec.from_flyte_idl(wg) for wg in proto.worker_group_spec], + ) + + +class RayJob(_common.FlyteIdlEntity): + """ + Models _ray_pb2.RayJob + """ + + def __init__( + self, + ray_cluster: RayCluster, + runtime_env: typing.Optional[str], + ): + self._ray_cluster = ray_cluster + self._runtime_env = runtime_env + + @property + def ray_cluster(self) -> RayCluster: + return self._ray_cluster + + @property + def runtime_env(self) -> typing.Optional[str]: + return self._runtime_env + + def to_flyte_idl(self) -> _ray_pb2.RayJob: + return _ray_pb2.RayJob( + ray_cluster=self.ray_cluster.to_flyte_idl(), + runtime_env=self.runtime_env, + ) + + @classmethod + def from_flyte_idl(cls, proto: _ray_pb2.RayJob): + return cls( + ray_cluster=RayCluster.from_flyte_idl(proto.ray_cluster) if proto.ray_cluster else None, + runtime_env=proto.runtime_env, + ) diff --git a/plugins/flytekit-ray/flytekitplugins/ray/task.py b/plugins/flytekit-ray/flytekitplugins/ray/task.py new file mode 100644 index 0000000000..09e21966d7 --- /dev/null +++ b/plugins/flytekit-ray/flytekitplugins/ray/task.py @@ -0,0 +1,76 @@ +import base64 +import json +import typing +from dataclasses import dataclass +from typing import Any, Callable, Dict, Optional + +import ray +from flytekitplugins.ray.models import HeadGroupSpec, RayCluster, RayJob, WorkerGroupSpec +from google.protobuf.json_format import MessageToDict + +from flytekit.configuration import SerializationSettings +from flytekit.core.context_manager import ExecutionParameters +from flytekit.core.python_function_task import PythonFunctionTask +from flytekit.extend import TaskPlugins + + +@dataclass +class HeadNodeConfig: + ray_start_params: typing.Optional[typing.Dict[str, str]] = None + + +@dataclass +class WorkerNodeConfig: + group_name: str + replicas: int + min_replicas: typing.Optional[int] = None + max_replicas: typing.Optional[int] = None + ray_start_params: typing.Optional[typing.Dict[str, str]] = None + + +@dataclass +class RayJobConfig: + worker_node_config: typing.List[WorkerNodeConfig] + head_node_config: typing.Optional[HeadNodeConfig] = None + runtime_env: typing.Optional[dict] = None + address: typing.Optional[str] = None + + +class RayFunctionTask(PythonFunctionTask): + """ + Actual Plugin that transforms the local python code for execution within Ray job. + """ + + _RAY_TASK_TYPE = "ray" + + def __init__(self, task_config: RayJobConfig, task_function: Callable, **kwargs): + super().__init__(task_config=task_config, task_type=self._RAY_TASK_TYPE, task_function=task_function, **kwargs) + self._task_config = task_config + + def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters: + ray.init(address=self._task_config.address) + return user_params + + def post_execute(self, user_params: ExecutionParameters, rval: Any) -> Any: + ray.shutdown() + return rval + + def get_custom(self, settings: SerializationSettings) -> Optional[Dict[str, Any]]: + cfg = self._task_config + + ray_job = RayJob( + ray_cluster=RayCluster( + head_group_spec=HeadGroupSpec(cfg.head_node_config.ray_start_params) if cfg.head_node_config else None, + worker_group_spec=[ + WorkerGroupSpec(c.group_name, c.replicas, c.min_replicas, c.max_replicas, c.ray_start_params) + for c in cfg.worker_node_config + ], + ), + # Use base64 to encode runtime_env dict and convert it to byte string + runtime_env=base64.b64encode(json.dumps(cfg.runtime_env).encode()).decode(), + ) + return MessageToDict(ray_job.to_flyte_idl()) + + +# Inject the Ray plugin into flytekits dynamic plugin loading system +TaskPlugins.register_pythontask_plugin(RayJobConfig, RayFunctionTask) diff --git a/plugins/flytekit-ray/requirements.in b/plugins/flytekit-ray/requirements.in new file mode 100644 index 0000000000..a657d0cce3 --- /dev/null +++ b/plugins/flytekit-ray/requirements.in @@ -0,0 +1,2 @@ +. +-e file:.#egg=flytekitplugins-ray diff --git a/plugins/flytekit-ray/requirements.txt b/plugins/flytekit-ray/requirements.txt new file mode 100644 index 0000000000..080bff46b7 --- /dev/null +++ b/plugins/flytekit-ray/requirements.txt @@ -0,0 +1,218 @@ +# +# This file is autogenerated by pip-compile with python 3.9 +# To update, run: +# +# pip-compile requirements.in +# +-e file:.#egg=flytekitplugins-ray + # via -r requirements.in +aiosignal==1.2.0 + # via ray +arrow==1.2.2 + # via jinja2-time +attrs==21.4.0 + # via + # jsonschema + # ray +binaryornot==0.4.4 + # via cookiecutter +certifi==2022.5.18.1 + # via requests +cffi==1.15.0 + # via cryptography +chardet==4.0.0 + # via binaryornot +charset-normalizer==2.0.12 + # via requests +click==8.0.4 + # via + # cookiecutter + # flytekit + # ray +cloudpickle==2.1.0 + # via flytekit +cookiecutter==1.7.3 + # via flytekit +croniter==1.3.5 + # via flytekit +cryptography==37.0.2 + # via pyopenssl +dataclasses-json==0.5.7 + # via flytekit +decorator==5.1.1 + # via retry +deprecated==1.2.13 + # via flytekit +diskcache==5.4.0 + # via flytekit +distlib==0.3.4 + # via virtualenv +docker==5.0.3 + # via flytekit +docker-image-py==0.1.12 + # via flytekit +docstring-parser==0.14.1 + # via flytekit +filelock==3.7.1 + # via + # ray + # virtualenv +flyteidl==1.1.10 + # via + # flytekit + # flytekitplugins-ray +flytekit==1.1.0 + # via flytekitplugins-ray +frozenlist==1.3.0 + # via + # aiosignal + # ray +googleapis-common-protos==1.56.1 + # via + # flyteidl + # grpcio-status +grpcio==1.43.0 + # via + # flytekit + # grpcio-status + # ray +grpcio-status==1.43.0 + # via flytekit +idna==3.3 + # via requests +importlib-metadata==4.11.3 + # via + # flytekit + # keyring +jinja2==3.1.2 + # via + # cookiecutter + # jinja2-time +jinja2-time==0.2.0 + # via cookiecutter +jsonschema==4.6.1 + # via ray +keyring==23.5.0 + # via flytekit +markupsafe==2.1.1 + # via jinja2 +marshmallow==3.15.0 + # via + # dataclasses-json + # marshmallow-enum + # marshmallow-jsonschema +marshmallow-enum==1.5.1 + # via dataclasses-json +marshmallow-jsonschema==0.13.0 + # via flytekit +msgpack==1.0.4 + # via ray +mypy-extensions==0.4.3 + # via typing-inspect +natsort==8.1.0 + # via flytekit +numpy==1.21.6 + # via + # pandas + # pyarrow + # ray +packaging==21.3 + # via marshmallow +pandas==1.3.5 + # via flytekit +platformdirs==2.5.2 + # via virtualenv +poyo==0.5.0 + # via cookiecutter +protobuf==3.20.1 + # via + # flyteidl + # flytekit + # googleapis-common-protos + # grpcio-status + # protoc-gen-swagger + # ray +protoc-gen-swagger==0.1.0 + # via flyteidl +py==1.11.0 + # via retry +pyarrow==6.0.1 + # via flytekit +pycparser==2.21 + # via cffi +pyopenssl==22.0.0 + # via flytekit +pyparsing==3.0.9 + # via packaging +pyrsistent==0.18.1 + # via jsonschema +python-dateutil==2.8.2 + # via + # arrow + # croniter + # flytekit + # pandas +python-json-logger==2.0.2 + # via flytekit +python-slugify==6.1.2 + # via cookiecutter +pytimeparse==1.1.8 + # via flytekit +pytz==2022.1 + # via + # flytekit + # pandas +pyyaml==6.0 + # via + # flytekit + # ray +ray==1.13.0 + # via flytekitplugins-ray +regex==2022.4.24 + # via docker-image-py +requests==2.27.1 + # via + # cookiecutter + # docker + # flytekit + # ray + # responses +responses==0.20.0 + # via flytekit +retry==0.9.2 + # via flytekit +six==1.16.0 + # via + # cookiecutter + # grpcio + # python-dateutil + # virtualenv +sortedcontainers==2.4.0 + # via flytekit +statsd==3.3.0 + # via flytekit +text-unidecode==1.3 + # via python-slugify +typing-extensions==4.2.0 + # via + # flytekit + # typing-inspect +typing-inspect==0.7.1 + # via dataclasses-json +urllib3==1.26.9 + # via + # flytekit + # requests + # responses +virtualenv==20.15.1 + # via ray +websocket-client==1.3.2 + # via docker +wheel==0.37.1 + # via flytekit +wrapt==1.14.1 + # via + # deprecated + # flytekit +zipp==3.8.0 + # via importlib-metadata diff --git a/plugins/flytekit-ray/setup.py b/plugins/flytekit-ray/setup.py new file mode 100644 index 0000000000..3c73e5cf27 --- /dev/null +++ b/plugins/flytekit-ray/setup.py @@ -0,0 +1,36 @@ +from setuptools import setup + +PLUGIN_NAME = "ray" + +microlib_name = f"flytekitplugins-{PLUGIN_NAME}" + +plugin_requires = ["ray", "flytekit>=1.1.0b0,<1.2.0", "flyteidl>=1.1.10"] + +__version__ = "0.0.0+develop" + +setup( + name=microlib_name, + version=__version__, + author="flyteorg", + author_email="admin@flyte.org", + description="This package holds the Ray plugins for flytekit", + namespace_packages=["flytekitplugins"], + packages=[f"flytekitplugins.{PLUGIN_NAME}"], + install_requires=plugin_requires, + license="apache2", + python_requires=">=3.7", + classifiers=[ + "Intended Audience :: Science/Research", + "Intended Audience :: Developers", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software Development", + "Topic :: Software Development :: Libraries", + "Topic :: Software Development :: Libraries :: Python Modules", + ], +) diff --git a/plugins/flytekit-ray/tests/__init__.py b/plugins/flytekit-ray/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/plugins/flytekit-ray/tests/test_ray.py b/plugins/flytekit-ray/tests/test_ray.py new file mode 100644 index 0000000000..8bcebf7937 --- /dev/null +++ b/plugins/flytekit-ray/tests/test_ray.py @@ -0,0 +1,68 @@ +import base64 +import json + +import ray +from flytekitplugins.ray.models import RayCluster, RayJob, WorkerGroupSpec +from flytekitplugins.ray.task import RayJobConfig, WorkerNodeConfig +from google.protobuf.json_format import MessageToDict + +from flytekit import PythonFunctionTask, task +from flytekit.configuration import Image, ImageConfig, SerializationSettings + +config = RayJobConfig( + worker_node_config=[WorkerNodeConfig(group_name="test_group", replicas=3)], + runtime_env={"pip": ["numpy"]}, +) + + +def test_ray_task(): + @task(task_config=config) + def t1(a: int) -> str: + assert ray.is_initialized() + inc = a + 2 + return str(inc) + + assert t1.task_config is not None + assert t1.task_config == config + assert t1.task_type == "ray" + assert isinstance(t1, PythonFunctionTask) + + default_img = Image(name="default", fqn="test", tag="tag") + settings = SerializationSettings( + project="proj", + domain="dom", + version="123", + image_config=ImageConfig(default_image=default_img, images=[default_img]), + env={}, + ) + + ray_job_pb = RayJob( + ray_cluster=RayCluster(worker_group_spec=[WorkerGroupSpec("test_group", 3)]), + runtime_env=base64.b64encode(json.dumps({"pip": ["numpy"]}).encode()).decode(), + ).to_flyte_idl() + + assert t1.get_custom(settings) == MessageToDict(ray_job_pb) + + assert t1.get_command(settings) == [ + "pyflyte-execute", + "--inputs", + "{{.input}}", + "--output-prefix", + "{{.outputPrefix}}", + "--raw-output-data-prefix", + "{{.rawOutputDataPrefix}}", + "--checkpoint-path", + "{{.checkpointOutputPrefix}}", + "--prev-checkpoint", + "{{.prevCheckpointPrefix}}", + "--resolver", + "flytekit.core.python_auto_container.default_task_resolver", + "--", + "task-module", + "tests.test_ray", + "task-name", + "t1", + ] + + assert t1(a=3) == "5" + assert not ray.is_initialized() diff --git a/plugins/setup.py b/plugins/setup.py index 8f3cc5c299..39163aa5df 100644 --- a/plugins/setup.py +++ b/plugins/setup.py @@ -23,6 +23,7 @@ "flytekitplugins-modin": "flytekit-modin", "flytekitplugins-pandera": "flytekit-pandera", "flytekitplugins-papermill": "flytekit-papermill", + "flytekitplugins-ray": "flytekit-ray", "flytekitplugins-snowflake": "flytekit-snowflake", "flytekitplugins-spark": "flytekit-spark", "flytekitplugins-sqlalchemy": "flytekit-sqlalchemy",