diff --git a/flytekit/models/task.py b/flytekit/models/task.py index d2d55fa759..ba15b3a3c8 100644 --- a/flytekit/models/task.py +++ b/flytekit/models/task.py @@ -6,9 +6,9 @@ from flyteidl.core import compiler_pb2 as _compiler from flyteidl.core import literals_pb2 as _literals_pb2 from flyteidl.core import tasks_pb2 as _core_task +from flyteidl.plugins import mpi_pb2 as _mpi_task from flyteidl.plugins import pytorch_pb2 as _pytorch_task from flyteidl.plugins import spark_pb2 as _spark_task -from flyteidl.plugins import mpi_pb2 as _mpi_task from flyteidl.plugins import tensorflow_pb2 as _tensorflow_task from google.protobuf import json_format as _json_format from google.protobuf import struct_pb2 as _struct @@ -1150,6 +1150,7 @@ def from_flyte_idl(cls, pb2_object): chief_replicas_count=pb2_object.chief_replicas, ) + class MPIJob(_common.FlyteIdlEntity): def __init__(self, num_workers, num_launcher_replicas, slots): self._num_workers = num_workers @@ -1179,4 +1180,4 @@ def from_flyte_idl(cls, pb2_object): num_workers=pb2_object.num_workers, num_launcher_replicas=pb2_object.num_launcher_replicas, slots=pb2_object.slots, - ) \ No newline at end of file + ) diff --git a/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py b/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py index 840183fcd1..e51495fdc8 100644 --- a/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py +++ b/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py @@ -3,11 +3,11 @@ Kubernetes. It leverages `TF Job `_ Plugin from kubeflow. """ from dataclasses import dataclass -from typing import Any, Callable, Dict, Optional, List +from typing import Any, Callable, Dict, List from google.protobuf.json_format import MessageToDict -from flytekit import PythonFunctionTask, Resources +from flytekit import PythonFunctionTask from flytekit.extend import SerializationSettings, TaskPlugins from flytekit.models import task as model @@ -25,19 +25,11 @@ class MPIJob(object): num_launcher_replicas: Number of launcher server replicas to use slots: Number of slots per worker used in hostfile - - per_replica_requests: [optional] lower-bound resources for each replica spawned for this job - (i.e. both for (main)master and workers). Default is set by platform-level configuration. - - per_replica_limits: [optional] upper-bound resources for each replica spawned for this job. If not specified - the scheduled resource may not have all the resources """ slots: int num_launcher_replicas: int = 1 num_workers: int = 1 - per_replica_requests: Optional[Resources] = None - per_replica_limits: Optional[Resources] = None class MPIFunctionTask(PythonFunctionTask[MPIJob]): @@ -50,7 +42,7 @@ class MPIFunctionTask(PythonFunctionTask[MPIJob]): _MPI_BASE_COMMAND = [ "mpirun", "--allow-run-as-root", - "bind-to", + "-bind-to", "none", "-map-by", "slot", @@ -73,14 +65,18 @@ def __init__(self, task_config: MPIJob, task_function: Callable, **kwargs): task_config=task_config, task_function=task_function, task_type=self._MPI_JOB_TASK_TYPE, - **{**kwargs, "requests": task_config.per_replica_requests, "limits": task_config.per_replica_limits} + **kwargs, ) def get_command(self, settings: SerializationSettings) -> List[str]: cmd = super().get_command(settings) num_procs = self.task_config.num_workers * self.task_config.slots - mpi_cmd = self._MPI_BASE_COMMAND + ["-np", f"{num_procs}"] + ["python", - settings.entrypoint_settings.path] + cmd[1:] + mpi_cmd = ( + self._MPI_BASE_COMMAND + + ["-np", f"{num_procs}"] + + ["python", settings.entrypoint_settings.path, "pyflyte-execute"] + + cmd[1:] + ) # the hostfile is set automatically by MPIOperator using env variable OMPI_MCA_orte_default_hostfile return mpi_cmd diff --git a/plugins/flytekit-kf-mpi/setup.py b/plugins/flytekit-kf-mpi/setup.py index aa3b50544b..ae6e689e7d 100644 --- a/plugins/flytekit-kf-mpi/setup.py +++ b/plugins/flytekit-kf-mpi/setup.py @@ -4,7 +4,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" -plugin_requires = [] +plugin_requires = ["flytekit>=0.16.0b0,<1.0.0"] __version__ = "0.0.0+develop" diff --git a/plugins/flytekit-kf-mpi/tests/test_mpi_task.py b/plugins/flytekit-kf-mpi/tests/test_mpi_task.py index 7b0a81ee7f..03895b72e2 100644 --- a/plugins/flytekit-kf-mpi/tests/test_mpi_task.py +++ b/plugins/flytekit-kf-mpi/tests/test_mpi_task.py @@ -6,7 +6,12 @@ def test_mpi_task(): - @task(task_config=MPIJob(num_workers=10, num_launcher_replicas=10, slots=1, per_replica_requests=Resources(cpu="1")), cache=True, cache_version="1") + @task( + task_config=MPIJob(num_workers=10, num_launcher_replicas=10, slots=1), + requests=Resources(cpu="1"), + cache=True, + cache_version="1", + ) def my_mpi_task(x: int, y: str) -> int: return x @@ -24,7 +29,7 @@ def my_mpi_task(x: int, y: str) -> int: entrypoint_settings=EntrypointSettings(path="/etc/my-entrypoint", command="my-entrypoint"), ) - assert my_mpi_task.get_custom(settings) == {"workers": 10} + assert my_mpi_task.get_custom(settings) == {"numLauncherReplicas": 10, "numWorkers": 10, "slots": 1} assert my_mpi_task.resources.limits == Resources() assert my_mpi_task.resources.requests == Resources(cpu="1") assert my_mpi_task.task_type == "mpi" diff --git a/setup.py b/setup.py index b851901c25..115e8a3ef9 100644 --- a/setup.py +++ b/setup.py @@ -64,6 +64,7 @@ ] }, install_requires=[ + "flyteidl>=0.21.4,<0.22.0", "wheel>=0.30.0,<1.0.0", "pandas>=1.0.0,<2.0.0", "pyarrow>=2.0.0,<4.0.0",