diff --git a/plugins/flytekit-kf-mpi/README.md b/plugins/flytekit-kf-mpi/README.md index 35c9444c42..db475868eb 100644 --- a/plugins/flytekit-kf-mpi/README.md +++ b/plugins/flytekit-kf-mpi/README.md @@ -8,4 +8,67 @@ To install the plugin, run the following command: pip install flytekitplugins-kfmpi ``` -_Example coming soon!_ +## Code Example +MPI usage: +```python + @task( + task_config=MPIJob( + launcher=Launcher( + replicas=1, + ), + worker=Worker( + replicas=5, + requests=Resources(cpu="2", mem="2Gi"), + limits=Resources(cpu="4", mem="2Gi"), + ), + slots=2, + ), + cache=True, + requests=Resources(cpu="1"), + cache_version="1", + ) + def my_mpi_task(x: int, y: str) -> int: + return x +``` + + +Horovod Usage: +You can override the command of a replica group by: +```python + @task( + task_config=HorovodJob( + launcher=Launcher( + replicas=1, + requests=Resources(cpu="1"), + limits=Resources(cpu="2"), + ), + worker=Worker( + replicas=1, + command=["/usr/sbin/sshd", "-De", "-f", "/home/jobuser/.sshd_config"], + restart_policy=RestartPolicy.NEVER, + ), + slots=2, + verbose=False, + log_level="INFO", + ), + ) + def my_horovod_task(): + ... +``` + + + + +## Upgrade MPI Plugin from V0 to V1 +MPI plugin is now updated from v0 to v1 to enable more configuration options. +To migrate from v0 to v1, change the following: +1. Update flytepropeller to v1.6.0 +2. Update flytekit version to v1.6.2 +3. 
Update your code from: +``` + task_config=MPIJob(num_workers=10), +``` +to +``` + task_config=MPIJob(worker=Worker(replicas=10)), +``` diff --git a/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/__init__.py b/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/__init__.py index df5c74288e..7d2107c8ae 100644 --- a/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/__init__.py +++ b/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/__init__.py @@ -10,4 +10,4 @@ MPIJob """ -from .task import HorovodJob, MPIJob +from .task import CleanPodPolicy, HorovodJob, Launcher, MPIJob, RestartPolicy, RunPolicy, Worker diff --git a/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py b/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py index e1c1be0a03..20179c7376 100644 --- a/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py +++ b/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py @@ -2,62 +2,89 @@ This Plugin adds the capability of running distributed MPI training to Flyte using backend plugins, natively on Kubernetes. It leverages `MPI Job `_ Plugin from kubeflow. 
""" -from dataclasses import dataclass -from typing import Any, Callable, Dict, List +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, Callable, Dict, List, Optional, Union -from flyteidl.plugins import mpi_pb2 as _mpi_task +from flyteidl.plugins.kubeflow import common_pb2 as kubeflow_common +from flyteidl.plugins.kubeflow import mpi_pb2 as mpi_task from google.protobuf.json_format import MessageToDict -from flytekit import PythonFunctionTask +from flytekit import PythonFunctionTask, Resources from flytekit.configuration import SerializationSettings +from flytekit.core.resources import convert_resources_to_resource_model from flytekit.extend import TaskPlugins -from flytekit.models import common as _common -class MPIJobModel(_common.FlyteIdlEntity): - """Model definition for MPI the plugin +@dataclass +class RestartPolicy(Enum): + """ + RestartPolicy describes how the replicas should be restarted + """ - Args: - num_workers: integer determining the number of worker replicas spawned in the cluster for this job - (in addition to 1 master). + ALWAYS = kubeflow_common.RESTART_POLICY_ALWAYS + FAILURE = kubeflow_common.RESTART_POLICY_ON_FAILURE + NEVER = kubeflow_common.RESTART_POLICY_NEVER + + +@dataclass +class CleanPodPolicy(Enum): + """ + CleanPodPolicy describes how to deal with pods when the job is finished. + """ - num_launcher_replicas: Number of launcher server replicas to use + NONE = kubeflow_common.CLEANPOD_POLICY_NONE + ALL = kubeflow_common.CLEANPOD_POLICY_ALL + RUNNING = kubeflow_common.CLEANPOD_POLICY_RUNNING - slots: Number of slots per worker used in hostfile - .. note:: - Please use resources=Resources(cpu="1"...) to specify per worker resource +@dataclass +class RunPolicy: + """ + RunPolicy describes some policy to apply to the execution of a kubeflow job. + Args: + clean_pod_policy: Defines the policy for cleaning up pods after the PyTorchJob completes. Default to None. 
+ ttl_seconds_after_finished (int): Defines the TTL for cleaning up finished PyTorchJobs. + active_deadline_seconds (int): Specifies the duration (in seconds) since startTime during which the job. + can remain active before it is terminated. Must be a positive integer. This setting applies only to pods. + where restartPolicy is OnFailure or Always. + backoff_limit (int): Number of retries before marking this job as failed. """ - def __init__(self, num_workers, num_launcher_replicas, slots): - self._num_workers = num_workers - self._num_launcher_replicas = num_launcher_replicas - self._slots = slots + clean_pod_policy: CleanPodPolicy = None + ttl_seconds_after_finished: Optional[int] = None + active_deadline_seconds: Optional[int] = None + backoff_limit: Optional[int] = None - @property - def num_workers(self): - return self._num_workers - @property - def num_launcher_replicas(self): - return self._num_launcher_replicas +@dataclass +class Worker: + """ + Worker replica configuration. Worker command can be customized. If not specified, the worker will use + default command generated by the mpi operator. + """ - @property - def slots(self): - return self._slots + command: Optional[List[str]] = None + image: Optional[str] = None + requests: Optional[Resources] = None + limits: Optional[Resources] = None + replicas: Optional[int] = None + restart_policy: Optional[RestartPolicy] = None - def to_flyte_idl(self): - return _mpi_task.DistributedMPITrainingTask( - num_workers=self.num_workers, num_launcher_replicas=self.num_launcher_replicas, slots=self.slots - ) - @classmethod - def from_flyte_idl(cls, pb2_object): - return cls( - num_workers=pb2_object.num_workers, - num_launcher_replicas=pb2_object.num_launcher_replicas, - slots=pb2_object.slots, - ) +@dataclass +class Launcher: + """ + Launcher replica configuration. Launcher command can be customized. If not specified, the launcher will use + the command specified in the task signature. 
+ """ + + command: Optional[List[str]] = None + image: Optional[str] = None + requests: Optional[Resources] = None + limits: Optional[Resources] = None + replicas: Optional[int] = None + restart_policy: Optional[RestartPolicy] = None @dataclass @@ -67,18 +94,21 @@ class MPIJob(object): to run distributed training on k8s with MPI Args: - num_workers: integer determining the number of worker replicas spawned in the cluster for this job - (in addition to 1 master). - - num_launcher_replicas: Number of launcher server replicas to use - - slots: Number of slots per worker used in hostfile - + launcher: Configuration for the launcher replica group. + worker: Configuration for the worker replica group. + run_policy: Configuration for the run policy. + slots: The number of slots per worker used in the hostfile. + num_launcher_replicas: [DEPRECATED] The number of launcher server replicas to use. This argument is deprecated. + num_workers: [DEPRECATED] The number of worker replicas to spawn in the cluster for this job """ - slots: int - num_launcher_replicas: int = 1 - num_workers: int = 1 + launcher: Launcher = field(default_factory=lambda: Launcher()) + worker: Worker = field(default_factory=lambda: Worker()) + run_policy: Optional[RunPolicy] = field(default_factory=lambda: None) + slots: int = 1 + # Support v0 config for backwards compatibility + num_launcher_replicas: Optional[int] = None + num_workers: Optional[int] = None class MPIFunctionTask(PythonFunctionTask[MPIJob]): @@ -110,6 +140,22 @@ class MPIFunctionTask(PythonFunctionTask[MPIJob]): ] def __init__(self, task_config: MPIJob, task_function: Callable, **kwargs): + if task_config.num_workers and task_config.worker.replicas: + raise ValueError( + "Cannot specify both `num_workers` and `worker.replicas`. Please use `worker.replicas` as `num_workers` is depreacated." 
+            ) +        if task_config.num_workers is None and task_config.worker.replicas is None: +            raise ValueError( +                "Must specify either `num_workers` or `worker.replicas`. Please use `worker.replicas` as `num_workers` is deprecated." +            ) +        if task_config.num_launcher_replicas and task_config.launcher.replicas: +            raise ValueError( +                "Cannot specify both `num_launcher_replicas` and `launcher.replicas`. Please use `launcher.replicas` as `num_launcher_replicas` is deprecated." +            ) +        if task_config.num_launcher_replicas is None and task_config.launcher.replicas is None: +            raise ValueError( +                "Must specify either `num_launcher_replicas` or `launcher.replicas`. Please use `launcher.replicas` as `num_launcher_replicas` is deprecated." +            )         super().__init__(             task_config=task_config,             task_function=task_function, @@ -117,27 +163,87 @@ def __init__(self, task_config: MPIJob, task_function: Callable, **kwargs):             **kwargs,         )  +    def _convert_replica_spec( +        self, replica_config: Union[Launcher, Worker] +    ) -> mpi_task.DistributedMPITrainingReplicaSpec: +        resources = convert_resources_to_resource_model(requests=replica_config.requests, limits=replica_config.limits) +        return mpi_task.DistributedMPITrainingReplicaSpec( +            command=replica_config.command, +            replicas=replica_config.replicas, +            image=replica_config.image, +            resources=resources.to_flyte_idl() if resources else None, +            restart_policy=replica_config.restart_policy.value if replica_config.restart_policy else None, +        ) + +    def _convert_run_policy(self, run_policy: RunPolicy) -> kubeflow_common.RunPolicy: +        return kubeflow_common.RunPolicy( +            clean_pod_policy=run_policy.clean_pod_policy.value if run_policy.clean_pod_policy else None, +            ttl_seconds_after_finished=run_policy.ttl_seconds_after_finished, +            active_deadline_seconds=run_policy.active_deadline_seconds, +            backoff_limit=run_policy.backoff_limit, +        ) + +    def _get_base_command(self, settings: SerializationSettings) -> List[str]: +        return super().get_command(settings) + def get_command(self,
settings: SerializationSettings) -> List[str]: - cmd = super().get_command(settings) - num_procs = self.task_config.num_workers * self.task_config.slots + cmd = self._get_base_command(settings) + if self.task_config.num_workers: + num_workers = self.task_config.num_workers + else: + num_workers = self.task_config.worker.replicas + num_procs = num_workers * self.task_config.slots mpi_cmd = self._MPI_BASE_COMMAND + ["-np", f"{num_procs}"] + ["python", settings.entrypoint_settings.path] + cmd # the hostfile is set automatically by MPIOperator using env variable OMPI_MCA_orte_default_hostfile return mpi_cmd def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: - job = MPIJobModel( - num_workers=self.task_config.num_workers, - num_launcher_replicas=self.task_config.num_launcher_replicas, + worker = self._convert_replica_spec(self.task_config.worker) + if self.task_config.num_workers: + worker.replicas = self.task_config.num_workers + + launcher = self._convert_replica_spec(self.task_config.launcher) + if self.task_config.num_launcher_replicas: + launcher.replicas = self.task_config.num_launcher_replicas + + run_policy = self._convert_run_policy(self.task_config.run_policy) if self.task_config.run_policy else None + mpi_job = mpi_task.DistributedMPITrainingTask( + worker_replicas=worker, + launcher_replicas=launcher, slots=self.task_config.slots, + run_policy=run_policy, ) - return MessageToDict(job.to_flyte_idl()) + return MessageToDict(mpi_job) @dataclass class HorovodJob(object): - slots: int - num_launcher_replicas: int = 1 - num_workers: int = 1 + """ + Configuration for an executable `Horovod Job using MPI operator`_. Use this + to run distributed training on k8s with MPI. For more info, check out Running Horovod`_. + + Args: + worker: Worker configuration for the job. + launcher: Launcher configuration for the job. + run_policy: Configuration for the run policy. + slots: Number of slots per worker used in hostfile (default: 1). 
+ verbose: Optional flag indicating whether to enable verbose logging (default: False). + log_level: Optional string specifying the log level (default: "INFO"). + discovery_script_path: Path to the discovery script used for host discovery (default: "/etc/mpi/discover_hosts.sh"). + num_launcher_replicas: [DEPRECATED] The number of launcher server replicas to use. This argument is deprecated. Please use launcher.replicas instead. + num_workers: [DEPRECATED] The number of worker replicas to spawn in the cluster for this job. Please use worker.replicas instead. + """ + + worker: Worker = field(default_factory=lambda: Worker()) + launcher: Launcher = field(default_factory=lambda: Launcher()) + run_policy: Optional[RunPolicy] = field(default_factory=lambda: None) + slots: int = 1 + verbose: Optional[bool] = False + log_level: Optional[str] = "INFO" + discovery_script_path: Optional[str] = "/etc/mpi/discover_hosts.sh" + # Support v0 config for backwards compatibility + num_launcher_replicas: Optional[int] = None + num_workers: Optional[int] = None class HorovodFunctionTask(MPIFunctionTask): @@ -146,11 +252,8 @@ class HorovodFunctionTask(MPIFunctionTask): """ # Customize your setup here. Please ensure the cmd, path, volume, etc are available in the pod. 
- ssh_command = "/usr/sbin/sshd -De -f /home/jobuser/.sshd_config" - discovery_script_path = "/etc/mpi/discover_hosts.sh" def __init__(self, task_config: HorovodJob, task_function: Callable, **kwargs): - super().__init__( task_config=task_config, task_function=task_function, @@ -158,23 +261,21 @@ def __init__(self, task_config: HorovodJob, task_function: Callable, **kwargs): ) def get_command(self, settings: SerializationSettings) -> List[str]: - cmd = super().get_command(settings) + cmd = self._get_base_command(settings) mpi_cmd = self._get_horovod_prefix() + cmd return mpi_cmd - def get_config(self, settings: SerializationSettings) -> Dict[str, str]: - config = super().get_config(settings) - return {**config, "worker_spec_command": self.ssh_command} - def _get_horovod_prefix(self) -> List[str]: - np = self.task_config.num_workers * self.task_config.slots + np = self.task_config.worker.replicas * self.task_config.slots + verbose = "--verbose" if self.task_config.verbose is True else "" + log_level = self.task_config.log_level base_cmd = [ "horovodrun", "-np", f"{np}", - "--verbose", + f"{verbose}", "--log-level", - "INFO", + f"{log_level}", "--network-interface", "eth0", "--min-np", @@ -184,7 +285,7 @@ def _get_horovod_prefix(self) -> List[str]: "--slots-per-host", f"{self.task_config.slots}", "--host-discovery-script", - self.discovery_script_path, + self.task_config.discovery_script_path, ] return base_cmd diff --git a/plugins/flytekit-kf-mpi/requirements.txt b/plugins/flytekit-kf-mpi/requirements.txt index bbf6b17f0d..1a9b0fcb54 100644 --- a/plugins/flytekit-kf-mpi/requirements.txt +++ b/plugins/flytekit-kf-mpi/requirements.txt @@ -1,89 +1,183 @@ # -# This file is autogenerated by pip-compile with python 3.7 -# To update, run: +# This file is autogenerated by pip-compile with Python 3.8 +# by the following command: # # pip-compile requirements.in # -e file:.#egg=flytekitplugins-kfmpi # via -r requirements.in -arrow==1.2.2 +adlfs==2023.4.0 + # via flytekit 
+aiobotocore==2.5.1 + # via s3fs +aiohttp==3.8.4 + # via + # adlfs + # aiobotocore + # gcsfs + # s3fs +aioitertools==0.11.0 + # via aiobotocore +aiosignal==1.3.1 + # via aiohttp +arrow==1.2.3 # via jinja2-time +async-timeout==4.0.2 + # via aiohttp +attrs==23.1.0 + # via aiohttp +azure-core==1.27.1 + # via + # adlfs + # azure-identity + # azure-storage-blob +azure-datalake-store==0.0.53 + # via adlfs +azure-identity==1.13.0 + # via adlfs +azure-storage-blob==12.16.0 + # via adlfs binaryornot==0.4.4 # via cookiecutter -certifi==2022.6.15 - # via requests +botocore==1.29.161 + # via aiobotocore +cachetools==5.3.1 + # via google-auth +certifi==2023.5.7 + # via + # kubernetes + # requests cffi==1.15.1 - # via cryptography -chardet==5.0.0 + # via + # azure-datalake-store + # cryptography +chardet==5.1.0 # via binaryornot -charset-normalizer==2.1.0 - # via requests +charset-normalizer==3.1.0 + # via + # aiohttp + # requests click==8.1.3 # via # cookiecutter # flytekit -cloudpickle==2.1.0 + # rich-click +cloudpickle==2.2.1 # via flytekit cookiecutter==2.1.1 # via flytekit -croniter==1.3.5 +croniter==1.4.1 # via flytekit -cryptography==37.0.4 +cryptography==41.0.1 # via + # azure-identity + # azure-storage-blob + # msal + # pyjwt # pyopenssl - # secretstorage -dataclasses-json==0.5.7 +dataclasses-json==0.5.8 # via flytekit decorator==5.1.1 - # via retry -deprecated==1.2.13 + # via gcsfs +deprecated==1.2.14 # via flytekit -diskcache==5.4.0 +diskcache==5.6.1 # via flytekit -docker==5.0.3 +docker==6.1.3 # via flytekit docker-image-py==0.1.12 # via flytekit -docstring-parser==0.14.1 +docstring-parser==0.15 # via flytekit -flyteidl==1.2.9 +flyteidl==1.2.11 # via # flytekit # flytekitplugins-kfmpi -flytekit==1.2.7 +flytekit==1.2.13 # via flytekitplugins-kfmpi -googleapis-common-protos==1.56.3 +frozenlist==1.3.3 + # via + # aiohttp + # aiosignal +fsspec==2023.6.0 + # via + # adlfs + # flytekit + # gcsfs + # s3fs +gcsfs==2023.6.0 + # via flytekit +gitdb==4.0.10 + # via gitpython 
+gitpython==3.1.31 + # via flytekit +google-api-core==2.11.1 + # via + # google-cloud-core + # google-cloud-storage +google-auth==2.21.0 + # via + # gcsfs + # google-api-core + # google-auth-oauthlib + # google-cloud-core + # google-cloud-storage + # kubernetes +google-auth-oauthlib==1.0.0 + # via gcsfs +google-cloud-core==2.3.2 + # via google-cloud-storage +google-cloud-storage==2.10.0 + # via gcsfs +google-crc32c==1.5.0 + # via google-resumable-media +google-resumable-media==2.5.0 + # via google-cloud-storage +googleapis-common-protos==1.59.1 # via # flyteidl + # flytekit + # google-api-core # grpcio-status -grpcio==1.47.0 +grpcio==1.48.2 # via # flytekit # grpcio-status -grpcio-status==1.47.0 +grpcio-status==1.48.2 # via flytekit -idna==3.3 - # via requests -importlib-metadata==4.12.0 +idna==3.4 # via - # click - # flytekit - # keyring -jeepney==0.8.0 + # requests + # yarl +importlib-metadata==6.7.0 # via + # flytekit # keyring - # secretstorage +importlib-resources==5.12.0 + # via keyring +isodate==0.6.1 + # via azure-storage-blob +jaraco-classes==3.2.3 + # via keyring jinja2==3.1.2 # via # cookiecutter # jinja2-time jinja2-time==0.2.0 # via cookiecutter -keyring==23.6.0 +jmespath==1.0.1 + # via botocore +joblib==1.3.1 + # via flytekit +keyring==24.2.0 + # via flytekit +kubernetes==26.1.0 # via flytekit -markupsafe==2.1.1 +markdown-it-py==3.0.0 + # via rich +markupsafe==2.1.3 # via jinja2 -marshmallow==3.17.0 +marshmallow==3.19.0 # via # dataclasses-json # marshmallow-enum @@ -92,51 +186,81 @@ marshmallow-enum==1.5.1 # via dataclasses-json marshmallow-jsonschema==0.13.0 # via flytekit -mypy-extensions==0.4.3 +mdurl==0.1.2 + # via markdown-it-py +more-itertools==9.1.0 + # via jaraco-classes +msal==1.22.0 + # via + # azure-datalake-store + # azure-identity + # msal-extensions +msal-extensions==1.0.0 + # via azure-identity +multidict==6.0.4 + # via + # aiohttp + # yarl +mypy-extensions==1.0.0 # via typing-inspect -natsort==8.1.0 +natsort==8.4.0 # via flytekit 
-numpy==1.21.6 +numpy==1.23.5 # via # flytekit # pandas # pyarrow -packaging==21.3 - # via marshmallow -pandas==1.3.5 +oauthlib==3.2.2 + # via requests-oauthlib +packaging==23.1 + # via + # docker + # marshmallow +pandas==1.5.3 # via flytekit -protobuf==3.20.2 +portalocker==2.7.0 + # via msal-extensions +protobuf==3.20.3 # via # flyteidl # flytekit + # google-api-core # googleapis-common-protos # grpcio-status # protoc-gen-swagger protoc-gen-swagger==0.1.0 # via flyteidl -py==1.11.0 - # via retry -pyarrow==6.0.1 +pyarrow==10.0.1 # via flytekit +pyasn1==0.5.0 + # via + # pyasn1-modules + # rsa +pyasn1-modules==0.3.0 + # via google-auth pycparser==2.21 # via cffi -pyopenssl==22.0.0 +pygments==2.15.1 + # via rich +pyjwt[crypto]==2.7.0 + # via msal +pyopenssl==23.2.0 # via flytekit -pyparsing==3.0.9 - # via packaging python-dateutil==2.8.2 # via # arrow + # botocore # croniter # flytekit + # kubernetes # pandas -python-json-logger==2.0.2 +python-json-logger==2.0.7 # via flytekit -python-slugify==6.1.2 +python-slugify==8.0.1 # via cookiecutter pytimeparse==1.1.8 # via flytekit -pytz==2022.1 +pytz==2023.3 # via # flytekit # pandas @@ -144,53 +268,95 @@ pyyaml==6.0 # via # cookiecutter # flytekit -regex==2022.6.2 + # kubernetes + # responses +regex==2023.6.3 # via docker-image-py -requests==2.28.1 +requests==2.31.0 # via + # azure-core + # azure-datalake-store # cookiecutter # docker # flytekit + # gcsfs + # google-api-core + # google-cloud-storage + # kubernetes + # msal + # requests-oauthlib # responses -responses==0.21.0 +requests-oauthlib==1.3.1 + # via + # google-auth-oauthlib + # kubernetes +responses==0.23.1 # via flytekit -retry==0.9.2 +rich==13.4.2 + # via + # flytekit + # rich-click +rich-click==1.6.1 # via flytekit -secretstorage==3.3.2 - # via keyring -singledispatchmethod==1.0 +rsa==4.9 + # via google-auth +s3fs==2023.6.0 # via flytekit six==1.16.0 # via + # azure-core + # azure-identity + # google-auth # grpcio + # isodate + # kubernetes # python-dateutil 
+smmap==5.0.0 + # via gitdb sortedcontainers==2.4.0 # via flytekit statsd==3.3.0 # via flytekit text-unidecode==1.3 # via python-slugify -typing-extensions==4.3.0 +types-pyyaml==6.0.12.10 + # via responses +typing-extensions==4.7.0 # via - # arrow + # aioitertools + # azure-core + # azure-storage-blob # flytekit - # importlib-metadata - # responses + # rich # typing-inspect -typing-inspect==0.7.1 +typing-inspect==0.9.0 # via dataclasses-json -urllib3==1.26.9 +urllib3==1.26.16 # via + # botocore + # docker # flytekit + # google-auth + # kubernetes # requests # responses -websocket-client==1.3.3 - # via docker -wheel==0.38.0 +websocket-client==1.6.1 + # via + # docker + # kubernetes +wheel==0.40.0 # via flytekit -wrapt==1.14.1 +wrapt==1.15.0 # via + # aiobotocore # deprecated # flytekit -zipp==3.8.0 - # via importlib-metadata +yarl==1.9.2 + # via aiohttp +zipp==3.15.0 + # via + # importlib-metadata + # importlib-resources + +# The following packages are considered to be unsafe in a requirements file: +# setuptools diff --git a/plugins/flytekit-kf-mpi/setup.py b/plugins/flytekit-kf-mpi/setup.py index b62edfe56f..30c82f6d58 100644 --- a/plugins/flytekit-kf-mpi/setup.py +++ b/plugins/flytekit-kf-mpi/setup.py @@ -4,7 +4,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" -plugin_requires = ["flytekit>=1.1.0b0,<1.3.0,<2.0.0", "flyteidl>=0.21.4"] +plugin_requires = ["flytekit>=1.1.0b0,<1.3.0,<2.0.0", "flyteidl>=1.2.11,<1.3.0"] __version__ = "0.0.0+develop" @@ -18,12 +18,11 @@ packages=[f"flytekitplugins.{PLUGIN_NAME}"], install_requires=plugin_requires, license="apache2", - python_requires=">=3.6", + python_requires=">=3.8", classifiers=[ "Intended Audience :: Science/Research", "Intended Audience :: Developers", "License :: OSI Approved :: Apache Software License", - "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", diff --git 
a/plugins/flytekit-kf-mpi/tests/test_mpi_task.py b/plugins/flytekit-kf-mpi/tests/test_mpi_task.py index 7732d520c2..f6eb2655f6 100644 --- a/plugins/flytekit-kf-mpi/tests/test_mpi_task.py +++ b/plugins/flytekit-kf-mpi/tests/test_mpi_task.py @@ -1,22 +1,25 @@ -from flytekitplugins.kfmpi.task import HorovodJob, MPIJob, MPIJobModel +import pytest +from flytekitplugins.kfmpi import CleanPodPolicy, HorovodJob, Launcher, MPIJob, RestartPolicy, RunPolicy, Worker +from flytekitplugins.kfmpi.task import MPIFunctionTask from flytekit import Resources, task from flytekit.configuration import Image, ImageConfig, SerializationSettings -def test_mpi_model_task(): - job = MPIJobModel( - num_workers=1, - num_launcher_replicas=1, - slots=1, +@pytest.fixture +def serialization_settings() -> SerializationSettings: + default_img = Image(name="default", fqn="test", tag="tag") + settings = SerializationSettings( + project="project", + domain="domain", + version="version", + env={"FOO": "baz"}, + image_config=ImageConfig(default_image=default_img, images=[default_img]), ) - assert job.num_workers == 1 - assert job.num_launcher_replicas == 1 - assert job.slots == 1 - assert job.from_flyte_idl(job.to_flyte_idl()) + return settings -def test_mpi_task(): +def test_mpi_task(serialization_settings: SerializationSettings): @task( task_config=MPIJob(num_workers=10, num_launcher_replicas=10, slots=1), requests=Resources(cpu="1"), @@ -30,37 +33,165 @@ def my_mpi_task(x: int, y: str) -> int: assert my_mpi_task.task_config is not None - default_img = Image(name="default", fqn="test", tag="tag") - settings = SerializationSettings( - project="project", - domain="domain", - version="version", - env={"FOO": "baz"}, - image_config=ImageConfig(default_image=default_img, images=[default_img]), + assert my_mpi_task.get_custom(serialization_settings) == { + "launcherReplicas": {"replicas": 10, "resources": {}}, + "workerReplicas": {"replicas": 10, "resources": {}}, + "slots": 1, + } + assert 
my_mpi_task.task_type == "mpi" + + +def test_mpi_task_with_default_config(serialization_settings: SerializationSettings): + task_config = MPIJob( + worker=Worker(replicas=1), + launcher=Launcher(replicas=1), + ) + + @task( + task_config=task_config, + cache=True, + requests=Resources(cpu="1"), + cache_version="1", + ) + def my_mpi_task(x: int, y: str) -> int: + return x + + assert my_mpi_task(x=10, y="hello") == 10 + + assert my_mpi_task.task_config is not None + assert my_mpi_task.task_type == "mpi" + assert my_mpi_task.resources.limits == Resources() + assert my_mpi_task.resources.requests == Resources(cpu="1") + assert " ".join(my_mpi_task.get_command(serialization_settings)).startswith( + " ".join(MPIFunctionTask._MPI_BASE_COMMAND + ["-np", "1"]) + ) + + expected_dict = { + "launcherReplicas": { + "replicas": 1, + "resources": {}, + }, + "workerReplicas": { + "replicas": 1, + "resources": {}, + }, + "slots": 1, + } + assert my_mpi_task.get_custom(serialization_settings) == expected_dict + + +def test_mpi_task_with_custom_config(serialization_settings: SerializationSettings): + task_config = MPIJob( + launcher=Launcher( + replicas=1, + requests=Resources(cpu="1"), + limits=Resources(cpu="2"), + image="launcher:latest", + ), + worker=Worker( + replicas=5, + requests=Resources(cpu="2", mem="2Gi"), + limits=Resources(cpu="4", mem="2Gi"), + image="worker:latest", + restart_policy=RestartPolicy.NEVER, + ), + run_policy=RunPolicy( + clean_pod_policy=CleanPodPolicy.ALL, + ), + slots=2, ) - assert my_mpi_task.get_custom(settings) == {"numLauncherReplicas": 10, "numWorkers": 10, "slots": 1} + @task( + task_config=task_config, + cache=True, + requests=Resources(cpu="1"), + cache_version="1", + ) + def my_mpi_task(x: int, y: str) -> int: + return x + + assert my_mpi_task(x=10, y="hello") == 10 + + assert my_mpi_task.task_config is not None assert my_mpi_task.task_type == "mpi" + assert my_mpi_task.resources.limits == Resources() + assert my_mpi_task.resources.requests == 
Resources(cpu="1") + assert " ".join(my_mpi_task.get_command(serialization_settings)).startswith( + " ".join(MPIFunctionTask._MPI_BASE_COMMAND + ["-np", "1"]) + ) + + expected_custom_dict = { + "launcherReplicas": { + "replicas": 1, + "image": "launcher:latest", + "resources": { + "requests": [{"name": "CPU", "value": "1"}], + "limits": [{"name": "CPU", "value": "2"}], + }, + }, + "workerReplicas": { + "replicas": 5, + "image": "worker:latest", + "resources": { + "requests": [ + {"name": "CPU", "value": "2"}, + {"name": "MEMORY", "value": "2Gi"}, + ], + "limits": [ + {"name": "CPU", "value": "4"}, + {"name": "MEMORY", "value": "2Gi"}, + ], + }, + }, + "slots": 2, + "runPolicy": {"cleanPodPolicy": "CLEANPOD_POLICY_ALL"}, + } + assert my_mpi_task.get_custom(serialization_settings) == expected_custom_dict -def test_horovod_task(): +def test_horovod_task(serialization_settings): @task( - task_config=HorovodJob(num_workers=5, num_launcher_replicas=5, slots=1), + task_config=HorovodJob( + launcher=Launcher( + replicas=1, + requests=Resources(cpu="1"), + limits=Resources(cpu="2"), + ), + worker=Worker( + replicas=1, + command=["/usr/sbin/sshd", "-De", "-f", "/home/jobuser/.sshd_config"], + restart_policy=RestartPolicy.NEVER, + ), + slots=2, + verbose=False, + log_level="INFO", + ), ) def my_horovod_task(): ... 
- default_img = Image(name="default", fqn="test", tag="tag") - settings = SerializationSettings( - project="project", - domain="domain", - version="version", - env={"FOO": "baz"}, - image_config=ImageConfig(default_image=default_img, images=[default_img]), - ) - cmd = my_horovod_task.get_command(settings) + cmd = my_horovod_task.get_command(serialization_settings) assert "horovodrun" in cmd - config = my_horovod_task.get_config(settings) - assert "/usr/sbin/sshd" in config["worker_spec_command"] - custom = my_horovod_task.get_custom(settings) - assert isinstance(custom, dict) is True + assert "--verbose" not in cmd + assert "--log-level" in cmd + assert "INFO" in cmd + expected_dict = { + "launcherReplicas": { + "replicas": 1, + "resources": { + "requests": [ + {"name": "CPU", "value": "1"}, + ], + "limits": [ + {"name": "CPU", "value": "2"}, + ], + }, + }, + "workerReplicas": { + "replicas": 1, + "resources": {}, + "command": ["/usr/sbin/sshd", "-De", "-f", "/home/jobuser/.sshd_config"], + }, + "slots": 2, + } + assert my_horovod_task.get_custom(serialization_settings) == expected_dict diff --git a/plugins/flytekit-kf-pytorch/README.md b/plugins/flytekit-kf-pytorch/README.md index 7de27502bf..c1516d3248 100644 --- a/plugins/flytekit-kf-pytorch/README.md +++ b/plugins/flytekit-kf-pytorch/README.md @@ -14,3 +14,42 @@ pip install flytekitplugins-kfpytorch To set up PyTorch operator in the Flyte deployment's backend, follow the [PyTorch Operator Setup](https://docs.flyte.org/en/latest/deployment/plugin_setup/pytorch_operator.html) guide. An [example](https://docs.flyte.org/projects/cookbook/en/latest/auto/integrations/kubernetes/kfpytorch/pytorch_mnist.html#sphx-glr-auto-integrations-kubernetes-kfpytorch-pytorch-mnist-py) showcasing PyTorch operator can be found in the documentation. 
+ +## Code Example +```python +from flytekitplugins.kfpytorch import PyTorch, Worker, Master, RestartPolicy, RunPolicy, CleanPodPolicy + +@task( + task_config = PyTorch( + worker=Worker( + replicas=5, + requests=Resources(cpu="2", mem="2Gi"), + limits=Resources(cpu="4", mem="2Gi"), + image="worker:latest", + restart_policy=RestartPolicy.FAILURE, + ), + master=Master( + restart_policy=RestartPolicy.ALWAYS, + ), + ) + image="test_image", + resources=Resources(cpu="1", mem="1Gi"), +) +def pytorch_job(): + ... +``` + + +## Upgrade Pytorch Plugin from V0 to V1 +Pytorch plugin is now updated from v0 to v1 to enable more configuration options. +To migrate from v0 to v1, change the following: +1. Update flytepropeller to v1.6.0 +2. Update flytekit version to v1.6.2 +3. Update your code from: + ``` + task_config=Pytorch(num_workers=10), + ``` + to: + ``` + task_config=PyTorch(worker=Worker(replicas=10)), + ``` diff --git a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/__init__.py b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/__init__.py index cb9add7302..d56e1f83d9 100644 --- a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/__init__.py +++ b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/__init__.py @@ -11,4 +11,4 @@ Elastic """ -from .task import Elastic, PyTorch +from .task import CleanPodPolicy, Elastic, Master, PyTorch, RestartPolicy, RunPolicy, Worker diff --git a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py index aea2c9a2e6..6625263db1 100644 --- a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py +++ b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py @@ -3,34 +3,104 @@ Kubernetes. It leverages `Pytorch Job `_ Plugin from kubeflow. 
""" import os -from dataclasses import dataclass +from dataclasses import dataclass, field +from enum import Enum from typing import Any, Callable, Dict, Optional, Union import cloudpickle -from flyteidl.plugins.pytorch_pb2 import DistributedPyTorchTrainingTask +from flyteidl.plugins.kubeflow import common_pb2 as kubeflow_common +from flyteidl.plugins.kubeflow import pytorch_pb2 as pytorch_task from google.protobuf.json_format import MessageToDict import flytekit -from flytekit import PythonFunctionTask +from flytekit import PythonFunctionTask, Resources from flytekit.configuration import SerializationSettings +from flytekit.core.resources import convert_resources_to_resource_model from flytekit.extend import IgnoreOutputs, TaskPlugins TORCH_IMPORT_ERROR_MESSAGE = "PyTorch is not installed. Please install `flytekitplugins-kfpytorch['elastic']`." @dataclass -class PyTorch(object): +class RestartPolicy(Enum): + """ + RestartPolicy describes how the replicas should be restarted """ - Configuration for an executable `Pytorch Job `_. Use this - to run distributed pytorch training on k8s + ALWAYS = kubeflow_common.RESTART_POLICY_ALWAYS + FAILURE = kubeflow_common.RESTART_POLICY_ON_FAILURE + NEVER = kubeflow_common.RESTART_POLICY_NEVER + + +@dataclass +class CleanPodPolicy(Enum): + """ + CleanPodPolicy describes how to deal with pods when the job is finished. + """ + + NONE = kubeflow_common.CLEANPOD_POLICY_NONE + ALL = kubeflow_common.CLEANPOD_POLICY_ALL + RUNNING = kubeflow_common.CLEANPOD_POLICY_RUNNING + + +@dataclass +class RunPolicy: + """ + RunPolicy describes some policy to apply to the execution of a kubeflow job. Args: - num_workers: integer determining the number of worker replicas spawned in the cluster for this job - (in addition to 1 master). + clean_pod_policy (int): Defines the policy for cleaning up pods after the PyTorchJob completes. Default to None. + ttl_seconds_after_finished (int): Defines the TTL for cleaning up finished PyTorchJobs. 
+ active_deadline_seconds (int): Specifies the duration (in seconds) since startTime during which the job. + can remain active before it is terminated. Must be a positive integer. This setting applies only to pods. + where restartPolicy is OnFailure or Always. + backoff_limit (int): Number of retries before marking this job as failed. + """ + + clean_pod_policy: CleanPodPolicy = None + ttl_seconds_after_finished: Optional[int] = None + active_deadline_seconds: Optional[int] = None + backoff_limit: Optional[int] = None + + +@dataclass +class Worker: + image: Optional[str] = None + requests: Optional[Resources] = None + limits: Optional[Resources] = None + replicas: Optional[int] = None + restart_policy: Optional[RestartPolicy] = None + +@dataclass +class Master: + """ + Configuration for master replica group. Master should always have 1 replica, so we don't need a `replicas` field + """ + + image: Optional[str] = None + requests: Optional[Resources] = None + limits: Optional[Resources] = None + restart_policy: Optional[RestartPolicy] = None + + +@dataclass +class PyTorch(object): + """ + Configuration for an executable `PyTorch Job `_. Use this + to run distributed PyTorch training on Kubernetes. + + Args: + master: Configuration for the master replica group. + worker: Configuration for the worker replica group. + run_policy: Configuration for the run policy. + num_workers: [DEPRECATED] This argument is deprecated. Use `worker.replicas` instead. 
""" - num_workers: int + master: Master = field(default_factory=lambda: Master()) + worker: Worker = field(default_factory=lambda: Worker()) + run_policy: Optional[RunPolicy] = field(default_factory=lambda: None) + # Support v0 config for backwards compatibility + num_workers: Optional[int] = None @dataclass @@ -67,6 +137,14 @@ class PyTorchFunctionTask(PythonFunctionTask[PyTorch]): _PYTORCH_TASK_TYPE = "pytorch" def __init__(self, task_config: PyTorch, task_function: Callable, **kwargs): + if task_config.num_workers and task_config.worker.replicas: + raise ValueError( + "Cannot specify both `num_workers` and `worker.replicas`. Please use `worker.replicas` as `num_workers` is depreacated." + ) + if task_config.num_workers is None and task_config.worker.replicas is None: + raise ValueError( + "Must specify either `num_workers` or `worker.replicas`. Please use `worker.replicas` as `num_workers` is depreacated." + ) super().__init__( task_config, task_function, @@ -74,9 +152,42 @@ def __init__(self, task_config: PyTorch, task_function: Callable, **kwargs): **kwargs, ) + def _convert_replica_spec( + self, replica_config: Union[Master, Worker] + ) -> pytorch_task.DistributedPyTorchTrainingReplicaSpec: + resources = convert_resources_to_resource_model(requests=replica_config.requests, limits=replica_config.limits) + replicas = 1 + # Master should always have 1 replica + if not isinstance(replica_config, Master): + replicas = replica_config.replicas + return pytorch_task.DistributedPyTorchTrainingReplicaSpec( + replicas=replicas, + image=replica_config.image, + resources=resources.to_flyte_idl() if resources else None, + restart_policy=replica_config.restart_policy.value if replica_config.restart_policy else None, + ) + + def _convert_run_policy(self, run_policy: RunPolicy) -> kubeflow_common.RunPolicy: + return kubeflow_common.RunPolicy( + clean_pod_policy=run_policy.clean_pod_policy.value if run_policy.clean_pod_policy else None, + 
ttl_seconds_after_finished=run_policy.ttl_seconds_after_finished, + active_deadline_seconds=run_policy.active_deadline_seconds, + backoff_limit=run_policy.active_deadline_seconds, + ) + def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: - job = DistributedPyTorchTrainingTask(workers=self.task_config.num_workers) - return MessageToDict(job) + worker = self._convert_replica_spec(self.task_config.worker) + # support v0 config for backwards compatibility + if self.task_config.num_workers: + worker.replicas = self.task_config.num_workers + + run_policy = self._convert_run_policy(self.task_config.run_policy) if self.task_config.run_policy else None + pytorch_job = pytorch_task.DistributedPyTorchTrainingTask( + worker_replicas=worker, + master_replicas=self._convert_replica_spec(self.task_config.master), + run_policy=run_policy, + ) + return MessageToDict(pytorch_job) # Register the Pytorch Plugin into the flytekit core plugin system @@ -236,8 +347,10 @@ def get_custom(self, settings: SerializationSettings) -> Optional[Dict[str, Any] nproc_per_node=self.task_config.nproc_per_node, max_restarts=self.task_config.max_restarts, ) - job = DistributedPyTorchTrainingTask( - workers=self.max_nodes, + job = pytorch_task.DistributedPyTorchTrainingTask( + worker_replicas=pytorch_task.DistributedPyTorchTrainingReplicaSpec( + replicas=self.max_nodes, + ), elastic_config=elastic_config, ) return MessageToDict(job) diff --git a/plugins/flytekit-kf-pytorch/requirements.txt b/plugins/flytekit-kf-pytorch/requirements.txt index 9cbdbad5c6..3bdac2d459 100644 --- a/plugins/flytekit-kf-pytorch/requirements.txt +++ b/plugins/flytekit-kf-pytorch/requirements.txt @@ -1,72 +1,145 @@ # -# This file is autogenerated by pip-compile with Python 3.9 +# This file is autogenerated by pip-compile with Python 3.8 # by the following command: # # pip-compile requirements.in # -e file:.#egg=flytekitplugins-kfpytorch # via -r requirements.in +adlfs==2023.4.0 + # via flytekit 
+aiobotocore==2.5.1 + # via s3fs +aiohttp==3.8.4 + # via + # adlfs + # aiobotocore + # gcsfs + # s3fs +aioitertools==0.11.0 + # via aiobotocore +aiosignal==1.3.1 + # via aiohttp arrow==1.2.3 # via jinja2-time +async-timeout==4.0.2 + # via aiohttp +attrs==23.1.0 + # via aiohttp +azure-core==1.27.1 + # via + # adlfs + # azure-identity + # azure-storage-blob +azure-datalake-store==0.0.53 + # via adlfs +azure-identity==1.13.0 + # via adlfs +azure-storage-blob==12.16.0 + # via adlfs binaryornot==0.4.4 # via cookiecutter -cachetools==5.3.0 +botocore==1.29.161 + # via aiobotocore +cachetools==5.3.1 # via google-auth -certifi==2022.12.7 +certifi==2023.5.7 # via # kubernetes # requests cffi==1.15.1 - # via cryptography + # via + # azure-datalake-store + # cryptography chardet==5.1.0 # via binaryornot charset-normalizer==3.1.0 - # via requests + # via + # aiohttp + # requests click==8.1.3 # via # cookiecutter # flytekit + # rich-click cloudpickle==2.2.1 # via # flytekit # flytekitplugins-kfpytorch cookiecutter==2.1.1 # via flytekit -croniter==1.3.14 +croniter==1.4.1 # via flytekit -cryptography==40.0.2 +cryptography==41.0.1 # via + # azure-identity + # azure-storage-blob + # msal + # pyjwt # pyopenssl - # secretstorage -dataclasses-json==0.5.7 +dataclasses-json==0.5.8 # via flytekit decorator==5.1.1 - # via retry -deprecated==1.2.13 + # via gcsfs +deprecated==1.2.14 # via flytekit diskcache==5.6.1 # via flytekit -docker==6.0.1 +docker==6.1.3 # via flytekit docker-image-py==0.1.12 # via flytekit docstring-parser==0.15 # via flytekit -flyteidl==1.2.10 +flyteidl==1.2.11 # via # flytekit # flytekitplugins-kfpytorch -flytekit==1.2.9 +flytekit==1.2.13 # via flytekitplugins-kfpytorch +frozenlist==1.3.3 + # via + # aiohttp + # aiosignal +fsspec==2023.6.0 + # via + # adlfs + # flytekit + # gcsfs + # s3fs +gcsfs==2023.6.0 + # via flytekit gitdb==4.0.10 # via gitpython gitpython==3.1.31 # via flytekit -google-auth==2.17.3 - # via kubernetes -googleapis-common-protos==1.59.0 
+google-api-core==2.11.1 + # via + # google-cloud-core + # google-cloud-storage +google-auth==2.21.0 + # via + # gcsfs + # google-api-core + # google-auth-oauthlib + # google-cloud-core + # google-cloud-storage + # kubernetes +google-auth-oauthlib==1.0.0 + # via gcsfs +google-cloud-core==2.3.2 + # via google-cloud-storage +google-cloud-storage==2.10.0 + # via gcsfs +google-crc32c==1.5.0 + # via google-resumable-media +google-resumable-media==2.5.0 + # via google-cloud-storage +googleapis-common-protos==1.59.1 # via # flyteidl + # flytekit + # google-api-core # grpcio-status grpcio==1.48.2 # via @@ -75,30 +148,36 @@ grpcio==1.48.2 grpcio-status==1.48.2 # via flytekit idna==3.4 - # via requests -importlib-metadata==6.6.0 + # via + # requests + # yarl +importlib-metadata==6.7.0 # via # flytekit # keyring +importlib-resources==5.12.0 + # via keyring +isodate==0.6.1 + # via azure-storage-blob jaraco-classes==3.2.3 # via keyring -jeepney==0.8.0 - # via - # keyring - # secretstorage jinja2==3.1.2 # via # cookiecutter # jinja2-time jinja2-time==0.2.0 # via cookiecutter -joblib==1.2.0 +jmespath==1.0.1 + # via botocore +joblib==1.3.1 # via flytekit -keyring==23.13.1 +keyring==24.2.0 # via flytekit kubernetes==26.1.0 # via flytekit -markupsafe==2.1.2 +markdown-it-py==3.0.0 + # via rich +markupsafe==2.1.3 # via jinja2 marshmallow==3.19.0 # via @@ -109,11 +188,24 @@ marshmallow-enum==1.5.1 # via dataclasses-json marshmallow-jsonschema==0.13.0 # via flytekit -more-itertools==9.0.0 +mdurl==0.1.2 + # via markdown-it-py +more-itertools==9.1.0 # via jaraco-classes +msal==1.22.0 + # via + # azure-datalake-store + # azure-identity + # msal-extensions +msal-extensions==1.0.0 + # via azure-identity +multidict==6.0.4 + # via + # aiohttp + # yarl mypy-extensions==1.0.0 # via typing-inspect -natsort==8.2.0 +natsort==8.4.0 # via flytekit numpy==1.23.5 # via @@ -126,19 +218,20 @@ packaging==23.1 # via # docker # marshmallow -pandas==1.3.5 +pandas==1.5.3 # via flytekit +portalocker==2.7.0 + # 
via msal-extensions protobuf==3.20.3 # via # flyteidl # flytekit + # google-api-core # googleapis-common-protos # grpcio-status # protoc-gen-swagger protoc-gen-swagger==0.1.0 # via flyteidl -py==1.11.0 - # via retry pyarrow==10.0.1 # via flytekit pyasn1==0.5.0 @@ -149,18 +242,23 @@ pyasn1-modules==0.3.0 # via google-auth pycparser==2.21 # via cffi -pyopenssl==23.1.1 +pygments==2.15.1 + # via rich +pyjwt[crypto]==2.7.0 + # via msal +pyopenssl==23.2.0 # via flytekit python-dateutil==2.8.2 # via # arrow + # botocore # croniter # flytekit # kubernetes # pandas python-json-logger==2.0.7 # via flytekit -python-slugify==8.0.0 +python-slugify==8.0.1 # via cookiecutter pytimeparse==1.1.8 # via flytekit @@ -174,30 +272,45 @@ pyyaml==6.0 # flytekit # kubernetes # responses -regex==2023.3.23 +regex==2023.6.3 # via docker-image-py -requests==2.28.2 +requests==2.31.0 # via + # azure-core + # azure-datalake-store # cookiecutter # docker # flytekit + # gcsfs + # google-api-core + # google-cloud-storage # kubernetes + # msal # requests-oauthlib # responses requests-oauthlib==1.3.1 - # via kubernetes + # via + # google-auth-oauthlib + # kubernetes responses==0.23.1 # via flytekit -retry==0.9.2 +rich==13.4.2 + # via + # flytekit + # rich-click +rich-click==1.6.1 # via flytekit rsa==4.9 # via google-auth -secretstorage==3.3.3 - # via keyring +s3fs==2023.6.0 + # via flytekit six==1.16.0 # via + # azure-core + # azure-identity # google-auth # grpcio + # isodate # kubernetes # python-dateutil smmap==5.0.0 @@ -208,33 +321,44 @@ statsd==3.3.0 # via flytekit text-unidecode==1.3 # via python-slugify -types-pyyaml==6.0.12.9 +types-pyyaml==6.0.12.10 # via responses -typing-extensions==4.5.0 +typing-extensions==4.7.0 # via + # aioitertools + # azure-core + # azure-storage-blob # flytekit + # rich # typing-inspect -typing-inspect==0.8.0 +typing-inspect==0.9.0 # via dataclasses-json -urllib3==1.26.14 +urllib3==1.26.16 # via + # botocore # docker # flytekit + # google-auth # kubernetes # requests 
# responses -websocket-client==1.5.1 +websocket-client==1.6.1 # via # docker # kubernetes wheel==0.40.0 # via flytekit -wrapt==1.14.1 +wrapt==1.15.0 # via + # aiobotocore # deprecated # flytekit +yarl==1.9.2 + # via aiohttp zipp==3.15.0 - # via importlib-metadata + # via + # importlib-metadata + # importlib-resources # The following packages are considered to be unsafe in a requirements file: # setuptools diff --git a/plugins/flytekit-kf-pytorch/setup.py b/plugins/flytekit-kf-pytorch/setup.py index a207b9381e..4249de60b6 100644 --- a/plugins/flytekit-kf-pytorch/setup.py +++ b/plugins/flytekit-kf-pytorch/setup.py @@ -4,7 +4,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" -plugin_requires = ["cloudpickle", "flytekit>=1.1.0b0,<1.3.0,<2.0.0", "flyteidl>=1.2.10,<1.3.0"] +plugin_requires = ["cloudpickle", "flytekit>=1.1.0b0,<1.3.0,<2.0.0", "flyteidl>=1.2.11,<1.3.0"] __version__ = "0.0.0+develop" diff --git a/plugins/flytekit-kf-pytorch/tests/test_pytorch_task.py b/plugins/flytekit-kf-pytorch/tests/test_pytorch_task.py index 00eb6c0953..ecdf9e375c 100644 --- a/plugins/flytekit-kf-pytorch/tests/test_pytorch_task.py +++ b/plugins/flytekit-kf-pytorch/tests/test_pytorch_task.py @@ -1,10 +1,24 @@ -from flytekitplugins.kfpytorch.task import PyTorch +import pytest +from flytekitplugins.kfpytorch.task import Master, PyTorch, RestartPolicy, Worker from flytekit import Resources, task from flytekit.configuration import Image, ImageConfig, SerializationSettings -def test_pytorch_task(): +@pytest.fixture +def serialization_settings() -> SerializationSettings: + default_img = Image(name="default", fqn="test", tag="tag") + settings = SerializationSettings( + project="project", + domain="domain", + version="version", + env={"FOO": "baz"}, + image_config=ImageConfig(default_image=default_img, images=[default_img]), + ) + return settings + + +def test_pytorch_task(serialization_settings: SerializationSettings): @task( task_config=PyTorch(num_workers=10), cache=True, @@ -18,16 +32,97 
@@ def my_pytorch_task(x: int, y: str) -> int: assert my_pytorch_task.task_config is not None - default_img = Image(name="default", fqn="test", tag="tag") - settings = SerializationSettings( - project="project", - domain="domain", - version="version", - env={"FOO": "baz"}, - image_config=ImageConfig(default_image=default_img, images=[default_img]), + assert my_pytorch_task.get_custom(serialization_settings) == { + "workerReplicas": {"replicas": 10, "resources": {}}, + "masterReplicas": {"replicas": 1, "resources": {}}, + } + assert my_pytorch_task.resources.limits == Resources() + assert my_pytorch_task.resources.requests == Resources(cpu="1") + assert my_pytorch_task.task_type == "pytorch" + + +def test_pytorch_task_with_default_config(serialization_settings: SerializationSettings): + task_config = PyTorch(worker=Worker(replicas=1)) + + @task( + task_config=task_config, + cache=True, + requests=Resources(cpu="1"), + cache_version="1", ) + def my_pytorch_task(x: int, y: str) -> int: + return x - assert my_pytorch_task.get_custom(settings) == {"workers": 10} + assert my_pytorch_task(x=10, y="hello") == 10 + + assert my_pytorch_task.task_config is not None + assert my_pytorch_task.task_type == "pytorch" assert my_pytorch_task.resources.limits == Resources() assert my_pytorch_task.resources.requests == Resources(cpu="1") + + expected_dict = { + "masterReplicas": { + "replicas": 1, + "resources": {}, + }, + "workerReplicas": { + "replicas": 1, + "resources": {}, + }, + } + assert my_pytorch_task.get_custom(serialization_settings) == expected_dict + + +def test_pytorch_task_with_custom_config(serialization_settings: SerializationSettings): + task_config = PyTorch( + worker=Worker( + replicas=5, + requests=Resources(cpu="2", mem="2Gi"), + limits=Resources(cpu="4", mem="2Gi"), + image="worker:latest", + restart_policy=RestartPolicy.FAILURE, + ), + master=Master( + restart_policy=RestartPolicy.ALWAYS, + ), + ) + + @task( + task_config=task_config, + cache=True, + 
requests=Resources(cpu="1"), + cache_version="1", + ) + def my_pytorch_task(x: int, y: str) -> int: + return x + + assert my_pytorch_task(x=10, y="hello") == 10 + + assert my_pytorch_task.task_config is not None assert my_pytorch_task.task_type == "pytorch" + assert my_pytorch_task.resources.limits == Resources() + assert my_pytorch_task.resources.requests == Resources(cpu="1") + + expected_custom_dict = { + "workerReplicas": { + "replicas": 5, + "image": "worker:latest", + "resources": { + "requests": [ + {"name": "CPU", "value": "2"}, + {"name": "MEMORY", "value": "2Gi"}, + ], + "limits": [ + {"name": "CPU", "value": "4"}, + {"name": "MEMORY", "value": "2Gi"}, + ], + }, + "restartPolicy": "RESTART_POLICY_ON_FAILURE", + }, + "masterReplicas": { + "resources": {}, + "replicas": 1, + "restartPolicy": "RESTART_POLICY_ALWAYS", + }, + } + assert my_pytorch_task.get_custom(serialization_settings) == expected_custom_dict diff --git a/plugins/flytekit-kf-tensorflow/README.md b/plugins/flytekit-kf-tensorflow/README.md index 9e4c26fa70..d059624f03 100644 --- a/plugins/flytekit-kf-tensorflow/README.md +++ b/plugins/flytekit-kf-tensorflow/README.md @@ -8,4 +8,47 @@ To install the plugin, run the following command: pip install flytekitplugins-kftensorflow ``` -_Example coming soon!_ +## Code Example +To build a TFJob with: +10 workers with restart policy as failed and 2 CPU and 2Gi Memory +1 ps replica with resources the same as task defined resources +1 chief replica with resources the same as task defined resources and restart policy as always +run policy as clean up all pods after job is finished. 
+ +You code: +```python +from flytekitplugins.kftensorflow import PS, Chief, CleanPodPolicy, RestartPolicy, RunPolicy, TfJob, Worker + +@task( + task_config=TfJob( + worker=Worker( + replicas=5, + requests=Resources(cpu="2", mem="2Gi"), + limits=Resources(cpu="2", mem="2Gi"), + restart_policy=RestartPolicy.FAILURE, + ), + ps=PS(replicas=1), + chief=Chief(replicas=1, restart_policy=RestartPolicy.ALWAYS), + run_policy=RunPolicy(clean_pod_policy=CleanPodPolicy.RUNNING), + ), + image="test_image", + resources=Resources(cpu="1", mem="1Gi"), +) +def tf_job(): + ... +``` + + +## Upgrade TensorFlow Plugin from V0 to V1 +Tensorflow plugin is now updated from v0 to v1 to enable more configuration options. +To migrate from v0 to v1, change the following: +1. Update flytepropeller to v1.6.0 +2. Update flytekit version to v1.6.2 +3. Update your code from: + ``` + task_config=TfJob(num_workers=10, num_ps_replicas=1, num_chief_replicas=1), + ``` + to: + ``` + task_config=TfJob(worker=Worker(replicas=10), ps=PS(replicas=1), chief=Chief(replicas=1)), + ``` diff --git a/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/__init__.py b/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/__init__.py index 02dec6cc7d..81a4cbc248 100644 --- a/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/__init__.py +++ b/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/__init__.py @@ -10,4 +10,4 @@ TfJob """ -from .task import TfJob +from .task import PS, Chief, CleanPodPolicy, RestartPolicy, RunPolicy, TfJob, Worker diff --git a/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/models.py b/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/models.py deleted file mode 100644 index 87d7bb7b90..0000000000 --- a/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/models.py +++ /dev/null @@ -1,35 +0,0 @@ -from flyteidl.plugins import tensorflow_pb2 as _tensorflow_task - -from flytekit.models import common as _common - - -class 
TensorFlowJob(_common.FlyteIdlEntity): - def __init__(self, workers_count, ps_replicas_count, chief_replicas_count): - self._workers_count = workers_count - self._ps_replicas_count = ps_replicas_count - self._chief_replicas_count = chief_replicas_count - - @property - def workers_count(self): - return self._workers_count - - @property - def ps_replicas_count(self): - return self._ps_replicas_count - - @property - def chief_replicas_count(self): - return self._chief_replicas_count - - def to_flyte_idl(self): - return _tensorflow_task.DistributedTensorflowTrainingTask( - workers=self.workers_count, ps_replicas=self.ps_replicas_count, chief_replicas=self.chief_replicas_count - ) - - @classmethod - def from_flyte_idl(cls, pb2_object): - return cls( - workers_count=pb2_object.workers, - ps_replicas_count=pb2_object.ps_replicas, - chief_replicas_count=pb2_object.chief_replicas, - ) diff --git a/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/task.py b/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/task.py index 03855e3095..bd6a97a293 100644 --- a/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/task.py +++ b/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/task.py @@ -2,37 +2,113 @@ This Plugin adds the capability of running distributed tensorflow training to Flyte using backend plugins, natively on Kubernetes. It leverages `TF Job `_ Plugin from kubeflow. 
""" -from dataclasses import dataclass -from typing import Any, Callable, Dict +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, Callable, Dict, Optional, Union +from flyteidl.plugins.kubeflow import common_pb2 as kubeflow_common +from flyteidl.plugins.kubeflow import tensorflow_pb2 as tensorflow_task from google.protobuf.json_format import MessageToDict -from flytekit import PythonFunctionTask +from flytekit import PythonFunctionTask, Resources from flytekit.configuration import SerializationSettings +from flytekit.core.resources import convert_resources_to_resource_model from flytekit.extend import TaskPlugins -from .models import TensorFlowJob + +@dataclass +class RestartPolicy(Enum): + """ + RestartPolicy describes how the replicas should be restarted + """ + + ALWAYS = kubeflow_common.RESTART_POLICY_ALWAYS + FAILURE = kubeflow_common.RESTART_POLICY_ON_FAILURE + NEVER = kubeflow_common.RESTART_POLICY_NEVER @dataclass -class TfJob(object): +class CleanPodPolicy(Enum): + """ + CleanPodPolicy describes how to deal with pods when the job is finished. """ - Configuration for an executable `TF Job `_. Use this - to run distributed tensorflow training on k8s (with parameter server) + + NONE = kubeflow_common.CLEANPOD_POLICY_NONE + ALL = kubeflow_common.CLEANPOD_POLICY_ALL + RUNNING = kubeflow_common.CLEANPOD_POLICY_RUNNING + + +@dataclass +class RunPolicy: + """ + RunPolicy describes a set of policies to apply to the execution of a Kubeflow job. Args: - num_workers: integer determining the number of worker replicas spawned in the cluster for this job - (in addition to 1 master). + clean_pod_policy: The policy for cleaning up pods after the job completes. Defaults to None. + ttl_seconds_after_finished: The time-to-live (TTL) in seconds for cleaning up finished jobs. + active_deadline_seconds: The duration (in seconds) since startTime during which the job can remain + active before it is terminated. Must be a positive integer. 
This setting applies only to pods + where restartPolicy is OnFailure or Always. + backoff_limit: The number of retries before marking this job as failed. + """ + + clean_pod_policy: CleanPodPolicy = None + ttl_seconds_after_finished: Optional[int] = None + active_deadline_seconds: Optional[int] = None + backoff_limit: Optional[int] = None + + +@dataclass +class Chief: + image: Optional[str] = None + requests: Optional[Resources] = None + limits: Optional[Resources] = None + replicas: Optional[int] = None + restart_policy: Optional[RestartPolicy] = None + + +@dataclass +class PS: + image: Optional[str] = None + requests: Optional[Resources] = None + limits: Optional[Resources] = None + replicas: Optional[int] = None + restart_policy: Optional[RestartPolicy] = None - num_ps_replicas: Number of Parameter server replicas to use - num_chief_replicas: Number of chief replicas to use +@dataclass +class Worker: + image: Optional[str] = None + requests: Optional[Resources] = None + limits: Optional[Resources] = None + replicas: Optional[int] = None + restart_policy: Optional[RestartPolicy] = None + +@dataclass +class TfJob: """ + Configuration for an executable `TensorFlow Job `_. Use this + to run distributed TensorFlow training on Kubernetes. - num_workers: int - num_ps_replicas: int - num_chief_replicas: int + Args: + chief: Configuration for the chief replica group. + ps: Configuration for the parameter server (PS) replica group. + worker: Configuration for the worker replica group. + run_policy: Configuration for the run policy. + num_workers: [DEPRECATED] This argument is deprecated. Use `worker.replicas` instead. + num_ps_replicas: [DEPRECATED] This argument is deprecated. Use `ps.replicas` instead. + num_chief_replicas: [DEPRECATED] This argument is deprecated. Use `chief.replicas` instead. 
+ """ + + chief: Chief = field(default_factory=lambda: Chief()) + ps: PS = field(default_factory=lambda: PS()) + worker: Worker = field(default_factory=lambda: Worker()) + run_policy: Optional[RunPolicy] = field(default_factory=lambda: None) + # Support v0 config for backwards compatibility + num_workers: Optional[int] = None + num_ps_replicas: Optional[int] = None + num_chief_replicas: Optional[int] = None class TensorflowFunctionTask(PythonFunctionTask[TfJob]): @@ -44,20 +120,79 @@ class TensorflowFunctionTask(PythonFunctionTask[TfJob]): _TF_JOB_TASK_TYPE = "tensorflow" def __init__(self, task_config: TfJob, task_function: Callable, **kwargs): + if task_config.num_workers and task_config.worker.replicas: + raise ValueError( + "Cannot specify both `num_workers` and `worker.replicas`. Please use `worker.replicas` as `num_workers` is depreacated." + ) + if task_config.num_workers is None and task_config.worker.replicas is None: + raise ValueError( + "Must specify either `num_workers` or `worker.replicas`. Please use `worker.replicas` as `num_workers` is depreacated." + ) + if task_config.num_chief_replicas and task_config.chief.replicas: + raise ValueError( + "Cannot specify both `num_workers` and `chief.replicas`. Please use `chief.replicas` as `num_chief_replicas` is depreacated." + ) + if task_config.num_chief_replicas is None and task_config.chief.replicas is None: + raise ValueError( + "Must specify either `num_workers` or `chief.replicas`. Please use `chief.replicas` as `num_chief_replicas` is depreacated." + ) + if task_config.num_ps_replicas and task_config.ps.replicas: + raise ValueError( + "Cannot specify both `num_workers` and `ps.replicas`. Please use `ps.replicas` as `num_ps_replicas` is depreacated." + ) + if task_config.num_ps_replicas is None and task_config.ps.replicas is None: + raise ValueError( + "Must specify either `num_workers` or `ps.replicas`. Please use `ps.replicas` as `num_ps_replicas` is depreacated." 
+ ) super().__init__( task_type=self._TF_JOB_TASK_TYPE, task_config=task_config, task_function=task_function, + task_type_version=1, **kwargs, ) + def _convert_replica_spec( + self, replica_config: Union[Chief, PS, Worker] + ) -> tensorflow_task.DistributedTensorflowTrainingReplicaSpec: + resources = convert_resources_to_resource_model(requests=replica_config.requests, limits=replica_config.limits) + return tensorflow_task.DistributedTensorflowTrainingReplicaSpec( + replicas=replica_config.replicas, + image=replica_config.image, + resources=resources.to_flyte_idl() if resources else None, + restart_policy=replica_config.restart_policy.value if replica_config.restart_policy else None, + ) + + def _convert_run_policy(self, run_policy: RunPolicy) -> kubeflow_common.RunPolicy: + return kubeflow_common.RunPolicy( + clean_pod_policy=run_policy.clean_pod_policy.value if run_policy.clean_pod_policy.value else None, + ttl_seconds_after_finished=run_policy.ttl_seconds_after_finished, + active_deadline_seconds=run_policy.active_deadline_seconds, + backoff_limit=run_policy.active_deadline_seconds, + ) + def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: - job = TensorFlowJob( - workers_count=self.task_config.num_workers, - ps_replicas_count=self.task_config.num_ps_replicas, - chief_replicas_count=self.task_config.num_chief_replicas, + chief = self._convert_replica_spec(self.task_config.chief) + if self.task_config.num_chief_replicas: + chief.replicas = self.task_config.num_chief_replicas + + worker = self._convert_replica_spec(self.task_config.worker) + if self.task_config.num_workers: + worker.replicas = self.task_config.num_workers + + ps = self._convert_replica_spec(self.task_config.ps) + if self.task_config.num_ps_replicas: + ps.replicas = self.task_config.num_ps_replicas + + run_policy = self._convert_run_policy(self.task_config.run_policy) if self.task_config.run_policy else None + training_task = tensorflow_task.DistributedTensorflowTrainingTask( + 
chief_replicas=chief, + worker_replicas=worker, + ps_replicas=ps, + run_policy=run_policy, ) - return MessageToDict(job.to_flyte_idl()) + + return MessageToDict(training_task) # Register the Tensorflow Plugin into the flytekit core plugin system diff --git a/plugins/flytekit-kf-tensorflow/requirements.txt b/plugins/flytekit-kf-tensorflow/requirements.txt index 60f87a8ac1..70e3e1a7b2 100644 --- a/plugins/flytekit-kf-tensorflow/requirements.txt +++ b/plugins/flytekit-kf-tensorflow/requirements.txt @@ -1,58 +1,143 @@ # -# This file is autogenerated by pip-compile with Python 3.7 +# This file is autogenerated by pip-compile with Python 3.8 # by the following command: # # pip-compile requirements.in # -e file:.#egg=flytekitplugins-kftensorflow # via -r requirements.in +adlfs==2023.4.0 + # via flytekit +aiobotocore==2.5.1 + # via s3fs +aiohttp==3.8.4 + # via + # adlfs + # aiobotocore + # gcsfs + # s3fs +aioitertools==0.11.0 + # via aiobotocore +aiosignal==1.3.1 + # via aiohttp arrow==1.2.3 # via jinja2-time +async-timeout==4.0.2 + # via aiohttp +attrs==23.1.0 + # via aiohttp +azure-core==1.27.1 + # via + # adlfs + # azure-identity + # azure-storage-blob +azure-datalake-store==0.0.53 + # via adlfs +azure-identity==1.13.0 + # via adlfs +azure-storage-blob==12.16.0 + # via adlfs binaryornot==0.4.4 # via cookiecutter -certifi==2022.12.7 - # via requests +botocore==1.29.161 + # via aiobotocore +cachetools==5.3.1 + # via google-auth +certifi==2023.5.7 + # via + # kubernetes + # requests cffi==1.15.1 - # via cryptography + # via + # azure-datalake-store + # cryptography chardet==5.1.0 # via binaryornot -charset-normalizer==3.0.1 - # via requests +charset-normalizer==3.1.0 + # via + # aiohttp + # requests click==8.1.3 # via # cookiecutter # flytekit + # rich-click cloudpickle==2.2.1 # via flytekit cookiecutter==2.1.1 # via flytekit -croniter==1.3.8 +croniter==1.4.1 # via flytekit -cryptography==39.0.1 +cryptography==41.0.1 # via + # azure-identity + # azure-storage-blob + # msal 
+ # pyjwt # pyopenssl - # secretstorage -dataclasses-json==0.5.7 +dataclasses-json==0.5.8 # via flytekit decorator==5.1.1 - # via retry -deprecated==1.2.13 + # via gcsfs +deprecated==1.2.14 # via flytekit -diskcache==5.4.0 +diskcache==5.6.1 # via flytekit -docker==6.0.1 +docker==6.1.3 # via flytekit docker-image-py==0.1.12 # via flytekit docstring-parser==0.15 # via flytekit -flyteidl==1.2.9 - # via flytekit -flytekit==1.2.7 +flyteidl==1.2.11 + # via + # flytekit + # flytekitplugins-kftensorflow +flytekit==1.2.13 # via flytekitplugins-kftensorflow -googleapis-common-protos==1.58.0 +frozenlist==1.3.3 + # via + # aiohttp + # aiosignal +fsspec==2023.6.0 + # via + # adlfs + # flytekit + # gcsfs + # s3fs +gcsfs==2023.6.0 + # via flytekit +gitdb==4.0.10 + # via gitpython +gitpython==3.1.31 + # via flytekit +google-api-core==2.11.1 + # via + # google-cloud-core + # google-cloud-storage +google-auth==2.21.0 + # via + # gcsfs + # google-api-core + # google-auth-oauthlib + # google-cloud-core + # google-cloud-storage + # kubernetes +google-auth-oauthlib==1.0.0 + # via gcsfs +google-cloud-core==2.3.2 + # via google-cloud-storage +google-cloud-storage==2.10.0 + # via gcsfs +google-crc32c==1.5.0 + # via google-resumable-media +google-resumable-media==2.5.0 + # via google-cloud-storage +googleapis-common-protos==1.59.1 # via # flyteidl + # flytekit + # google-api-core # grpcio-status grpcio==1.48.2 # via @@ -61,31 +146,36 @@ grpcio==1.48.2 grpcio-status==1.48.2 # via flytekit idna==3.4 - # via requests -importlib-metadata==6.0.0 # via - # click + # requests + # yarl +importlib-metadata==6.7.0 + # via # flytekit # keyring importlib-resources==5.12.0 # via keyring +isodate==0.6.1 + # via azure-storage-blob jaraco-classes==3.2.3 # via keyring -jeepney==0.8.0 - # via - # keyring - # secretstorage jinja2==3.1.2 # via # cookiecutter # jinja2-time jinja2-time==0.2.0 # via cookiecutter -joblib==1.2.0 +jmespath==1.0.1 + # via botocore +joblib==1.3.1 + # via flytekit +keyring==24.2.0 # 
via flytekit -keyring==23.13.1 +kubernetes==26.1.0 # via flytekit -markupsafe==2.1.2 +markdown-it-py==3.0.0 + # via rich +markupsafe==2.1.3 # via jinja2 marshmallow==3.19.0 # via @@ -96,53 +186,81 @@ marshmallow-enum==1.5.1 # via dataclasses-json marshmallow-jsonschema==0.13.0 # via flytekit -more-itertools==9.0.0 +mdurl==0.1.2 + # via markdown-it-py +more-itertools==9.1.0 # via jaraco-classes +msal==1.22.0 + # via + # azure-datalake-store + # azure-identity + # msal-extensions +msal-extensions==1.0.0 + # via azure-identity +multidict==6.0.4 + # via + # aiohttp + # yarl mypy-extensions==1.0.0 # via typing-inspect -natsort==8.2.0 +natsort==8.4.0 # via flytekit -numpy==1.21.6 +numpy==1.23.5 # via # flytekit # pandas # pyarrow -packaging==23.0 +oauthlib==3.2.2 + # via requests-oauthlib +packaging==23.1 # via # docker # marshmallow -pandas==1.3.5 +pandas==1.5.3 # via flytekit +portalocker==2.7.0 + # via msal-extensions protobuf==3.20.3 # via # flyteidl # flytekit + # google-api-core # googleapis-common-protos # grpcio-status # protoc-gen-swagger protoc-gen-swagger==0.1.0 # via flyteidl -py==1.11.0 - # via retry pyarrow==10.0.1 # via flytekit +pyasn1==0.5.0 + # via + # pyasn1-modules + # rsa +pyasn1-modules==0.3.0 + # via google-auth pycparser==2.21 # via cffi -pyopenssl==23.0.0 +pygments==2.15.1 + # via rich +pyjwt[crypto]==2.7.0 + # via msal +pyopenssl==23.2.0 # via flytekit python-dateutil==2.8.2 # via # arrow + # botocore # croniter # flytekit + # kubernetes # pandas python-json-logger==2.0.7 # via flytekit -python-slugify==8.0.0 +python-slugify==8.0.1 # via cookiecutter pytimeparse==1.1.8 # via flytekit -pytz==2022.7.1 +pytz==2023.3 # via # flytekit # pandas @@ -150,60 +268,95 @@ pyyaml==6.0 # via # cookiecutter # flytekit -regex==2022.10.31 + # kubernetes + # responses +regex==2023.6.3 # via docker-image-py -requests==2.28.2 +requests==2.31.0 # via + # azure-core + # azure-datalake-store # cookiecutter # docker # flytekit + # gcsfs + # google-api-core + # 
google-cloud-storage + # kubernetes + # msal + # requests-oauthlib # responses -responses==0.22.0 +requests-oauthlib==1.3.1 + # via + # google-auth-oauthlib + # kubernetes +responses==0.23.1 # via flytekit -retry==0.9.2 +rich==13.4.2 + # via + # flytekit + # rich-click +rich-click==1.6.1 # via flytekit -secretstorage==3.3.3 - # via keyring -singledispatchmethod==1.0 +rsa==4.9 + # via google-auth +s3fs==2023.6.0 # via flytekit six==1.16.0 # via + # azure-core + # azure-identity + # google-auth # grpcio + # isodate + # kubernetes # python-dateutil +smmap==5.0.0 + # via gitdb sortedcontainers==2.4.0 # via flytekit statsd==3.3.0 # via flytekit text-unidecode==1.3 # via python-slugify -toml==0.10.2 +types-pyyaml==6.0.12.10 # via responses -types-toml==0.10.8.5 - # via responses -typing-extensions==4.5.0 +typing-extensions==4.7.0 # via - # arrow + # aioitertools + # azure-core + # azure-storage-blob # flytekit - # importlib-metadata - # responses + # rich # typing-inspect -typing-inspect==0.8.0 +typing-inspect==0.9.0 # via dataclasses-json -urllib3==1.26.14 +urllib3==1.26.16 # via + # botocore # docker # flytekit + # google-auth + # kubernetes # requests # responses -websocket-client==1.5.1 - # via docker -wheel==0.38.4 +websocket-client==1.6.1 + # via + # docker + # kubernetes +wheel==0.40.0 # via flytekit -wrapt==1.14.1 +wrapt==1.15.0 # via + # aiobotocore # deprecated # flytekit -zipp==3.14.0 +yarl==1.9.2 + # via aiohttp +zipp==3.15.0 # via # importlib-metadata # importlib-resources + +# The following packages are considered to be unsafe in a requirements file: +# setuptools diff --git a/plugins/flytekit-kf-tensorflow/setup.py b/plugins/flytekit-kf-tensorflow/setup.py index c35f11f12a..ff945e221c 100644 --- a/plugins/flytekit-kf-tensorflow/setup.py +++ b/plugins/flytekit-kf-tensorflow/setup.py @@ -4,8 +4,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" -# TODO: Requirements are missing, add them back in later. 
-plugin_requires = ["flytekit>=1.1.0b0,<1.3.0,<2.0.0"] +plugin_requires = ["flytekit>=1.1.0b0,<1.3.0,<2.0.0", "flyteidl>=1.2.11,<1.3.0"] __version__ = "0.0.0+develop" diff --git a/plugins/flytekit-kf-tensorflow/tests/test_tensorflow_task.py b/plugins/flytekit-kf-tensorflow/tests/test_tensorflow_task.py index 2bcfcda550..d863d3fdc4 100644 --- a/plugins/flytekit-kf-tensorflow/tests/test_tensorflow_task.py +++ b/plugins/flytekit-kf-tensorflow/tests/test_tensorflow_task.py @@ -1,9 +1,173 @@ -from flytekitplugins.kftensorflow import TfJob +import pytest +from flytekitplugins.kftensorflow import PS, Chief, CleanPodPolicy, RestartPolicy, RunPolicy, TfJob, Worker from flytekit import Resources, task from flytekit.configuration import Image, ImageConfig, SerializationSettings +@pytest.fixture +def serialization_settings() -> SerializationSettings: + default_img = Image(name="default", fqn="test", tag="tag") + settings = SerializationSettings( + project="project", + domain="domain", + version="version", + env={"FOO": "baz"}, + image_config=ImageConfig(default_image=default_img, images=[default_img]), + ) + return settings + + +def test_tensorflow_task_with_default_config(serialization_settings: SerializationSettings): + task_config = TfJob( + worker=Worker(replicas=1), + chief=Chief(replicas=0), + ps=PS(replicas=0), + ) + + @task( + task_config=task_config, + cache=True, + requests=Resources(cpu="1"), + cache_version="1", + ) + def my_tensorflow_task(x: int, y: str) -> int: + return x + + assert my_tensorflow_task(x=10, y="hello") == 10 + + assert my_tensorflow_task.task_config is not None + assert my_tensorflow_task.task_type == "tensorflow" + assert my_tensorflow_task.resources.limits == Resources() + assert my_tensorflow_task.resources.requests == Resources(cpu="1") + + expected_dict = { + "chiefReplicas": { + "resources": {}, + }, + "workerReplicas": { + "replicas": 1, + "resources": {}, + }, + "psReplicas": { + "resources": {}, + }, + } + assert 
my_tensorflow_task.get_custom(serialization_settings) == expected_dict + + +def test_tensorflow_task_with_custom_config(serialization_settings: SerializationSettings): + task_config = TfJob( + chief=Chief( + replicas=1, + requests=Resources(cpu="1"), + limits=Resources(cpu="2"), + image="chief:latest", + ), + worker=Worker( + replicas=5, + requests=Resources(cpu="2", mem="2Gi"), + limits=Resources(cpu="4", mem="2Gi"), + image="worker:latest", + restart_policy=RestartPolicy.FAILURE, + ), + ps=PS( + replicas=2, + restart_policy=RestartPolicy.ALWAYS, + ), + ) + + @task( + task_config=task_config, + cache=True, + requests=Resources(cpu="1"), + cache_version="1", + ) + def my_tensorflow_task(x: int, y: str) -> int: + return x + + assert my_tensorflow_task(x=10, y="hello") == 10 + + assert my_tensorflow_task.task_config is not None + assert my_tensorflow_task.task_type == "tensorflow" + assert my_tensorflow_task.resources.limits == Resources() + assert my_tensorflow_task.resources.requests == Resources(cpu="1") + + expected_custom_dict = { + "chiefReplicas": { + "replicas": 1, + "image": "chief:latest", + "resources": { + "requests": [{"name": "CPU", "value": "1"}], + "limits": [{"name": "CPU", "value": "2"}], + }, + }, + "workerReplicas": { + "replicas": 5, + "image": "worker:latest", + "resources": { + "requests": [ + {"name": "CPU", "value": "2"}, + {"name": "MEMORY", "value": "2Gi"}, + ], + "limits": [ + {"name": "CPU", "value": "4"}, + {"name": "MEMORY", "value": "2Gi"}, + ], + }, + "restartPolicy": "RESTART_POLICY_ON_FAILURE", + }, + "psReplicas": { + "resources": {}, + "replicas": 2, + "restartPolicy": "RESTART_POLICY_ALWAYS", + }, + } + assert my_tensorflow_task.get_custom(serialization_settings) == expected_custom_dict + + +def test_tensorflow_task_with_run_policy(serialization_settings: SerializationSettings): + task_config = TfJob( + worker=Worker(replicas=1), + ps=PS(replicas=0), + chief=Chief(replicas=0), + 
run_policy=RunPolicy(clean_pod_policy=CleanPodPolicy.RUNNING), + ) + + @task( + task_config=task_config, + cache=True, + requests=Resources(cpu="1"), + cache_version="1", + ) + def my_tensorflow_task(x: int, y: str) -> int: + return x + + assert my_tensorflow_task(x=10, y="hello") == 10 + + assert my_tensorflow_task.task_config is not None + assert my_tensorflow_task.task_type == "tensorflow" + assert my_tensorflow_task.resources.limits == Resources() + assert my_tensorflow_task.resources.requests == Resources(cpu="1") + + expected_dict = { + "chiefReplicas": { + "resources": {}, + }, + "workerReplicas": { + "replicas": 1, + "resources": {}, + }, + "psReplicas": { + "resources": {}, + }, + "runPolicy": { + "cleanPodPolicy": "CLEANPOD_POLICY_RUNNING", + }, + } + assert my_tensorflow_task.get_custom(serialization_settings) == expected_dict + + def test_tensorflow_task(): @task( task_config=TfJob(num_workers=10, num_ps_replicas=1, num_chief_replicas=1), @@ -27,7 +191,21 @@ def my_tensorflow_task(x: int, y: str) -> int: image_config=ImageConfig(default_image=default_img, images=[default_img]), ) - assert my_tensorflow_task.get_custom(settings) == {"workers": 10, "psReplicas": 1, "chiefReplicas": 1} + expected_dict = { + "chiefReplicas": { + "replicas": 1, + "resources": {}, + }, + "workerReplicas": { + "replicas": 10, + "resources": {}, + }, + "psReplicas": { + "replicas": 1, + "resources": {}, + }, + } + assert my_tensorflow_task.get_custom(settings) == expected_dict assert my_tensorflow_task.resources.limits == Resources() assert my_tensorflow_task.resources.requests == Resources(cpu="1") assert my_tensorflow_task.task_type == "tensorflow"