diff --git a/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/__init__.py b/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/__init__.py index 069be25991..16623fb79c 100644 --- a/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/__init__.py +++ b/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/__init__.py @@ -10,5 +10,4 @@ TfJob """ -from .models import CleanPodPolicy, RestartPolicy -from .task import PS, Chief, RunPolicy, TfJob, Worker +from .task import PS, Chief, RunPolicy, TfJob, Worker, RestartPolicy, CleanPodPolicy diff --git a/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/models/__init__.py b/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/models/__init__.py deleted file mode 100644 index 0b1e528b5b..0000000000 --- a/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/models/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -from .chief import Chief -from .ps import PS -from .restart_policy import RestartPolicy -from .run_policy import CleanPodPolicy, RunPolicy -from .tensorflow_job import TensorFlowJob -from .worker import Worker diff --git a/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/models/chief.py b/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/models/chief.py deleted file mode 100644 index 5ec9e9b14e..0000000000 --- a/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/models/chief.py +++ /dev/null @@ -1,58 +0,0 @@ -from typing import Optional - -from flyteidl.plugins.kubeflow import tensorflow_pb2 as tensorflow_task - -from flytekit.models import common, task - -from .restart_policy import RestartPolicy - - -class Chief(common.FlyteIdlEntity): - """ - Configuration for a chief replica group in a TFJob. - - :param replicas: Number of replicas in the group. This should be 1 or 0. If 0, the chief will be elected from the worker group. 
- :param image: Optional image to use for the pods of the group - :param resources: Optional resources to use for the pods of the group - :param restart_policy: Optional restart policy to use for the pods of the group - """ - - def __init__( - self, - replicas: int, - image: Optional[str] = None, - resources: Optional[task.Resources] = None, - restart_policy: Optional[RestartPolicy] = None, - ): - if replicas != 0 and replicas != 1: - raise ValueError( - f"TFJob chief group needs to have either one replica or no replica(one worker will be elected as chief), but {replicas} have been specified." - ) - self._replicas = replicas - self._image = image - self._resources = resources - self._restart_policy = restart_policy - - @property - def image(self) -> Optional[str]: - return self._image - - @property - def resources(self) -> Optional[task.Resources]: - return self._resources - - @property - def replicas(self) -> Optional[int]: - return self._replicas - - @property - def restart_policy(self) -> Optional[RestartPolicy]: - return self._restart_policy - - def to_flyte_idl(self) -> tensorflow_task.DistributedTensorflowTrainingReplicaSpec: - return tensorflow_task.DistributedTensorflowTrainingReplicaSpec( - replicas=self.replicas, - image=self.image, - resources=self.resources.to_flyte_idl() if self.resources else None, - restart_policy=self.restart_policy.value if self.restart_policy else None, - ) diff --git a/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/models/ps.py b/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/models/ps.py deleted file mode 100644 index e54f464591..0000000000 --- a/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/models/ps.py +++ /dev/null @@ -1,54 +0,0 @@ -from typing import Optional - -from flyteidl.plugins.kubeflow import tensorflow_pb2 as tensorflow_task - -from flytekit.models import common, task - -from .restart_policy import RestartPolicy - - -class PS(common.FlyteIdlEntity): - """ - Configuration for a 
ps replica group in a TFJob. - - :param replicas: Number of replicas in the group. Default is 0. - :param image: Optional image to use for the pods of the group - :param resources: Optional resources to use for the pods of the group - :param restart_policy: Optional restart policy to use for the pods of the group - """ - - def __init__( - self, - replicas: int, - image: Optional[str] = None, - resources: Optional[task.Resources] = None, - restart_policy: Optional[RestartPolicy] = None, - ): - self._replicas = replicas - self._image = image - self._resources = resources - self._restart_policy = restart_policy - - @property - def image(self) -> Optional[str]: - return self._image - - @property - def resources(self) -> Optional[task.Resources]: - return self._resources - - @property - def replicas(self) -> Optional[int]: - return self._replicas - - @property - def restart_policy(self) -> Optional[RestartPolicy]: - return self._restart_policy - - def to_flyte_idl(self) -> tensorflow_task.DistributedTensorflowTrainingReplicaSpec: - return tensorflow_task.DistributedTensorflowTrainingReplicaSpec( - replicas=self.replicas, - image=self.image, - resources=self.resources.to_flyte_idl() if self.resources else None, - restart_policy=self.restart_policy.value if self.restart_policy else None, - ) diff --git a/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/models/restart_policy.py b/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/models/restart_policy.py deleted file mode 100644 index eb8cd011b8..0000000000 --- a/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/models/restart_policy.py +++ /dev/null @@ -1,15 +0,0 @@ -from dataclasses import dataclass -from enum import Enum - -from flyteidl.plugins.kubeflow.common_pb2 import RESTART_POLICY_ALWAYS, RESTART_POLICY_NEVER, RESTART_POLICY_ON_FAILURE - - -@dataclass -class RestartPolicy(Enum): - """ - RestartPolicy describes how the replicas should be restarted - """ - - ALWAYS = 
RESTART_POLICY_ALWAYS - FAILURE = RESTART_POLICY_ON_FAILURE - NEVER = RESTART_POLICY_NEVER diff --git a/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/models/run_policy.py b/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/models/run_policy.py deleted file mode 100644 index fd2aab8031..0000000000 --- a/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/models/run_policy.py +++ /dev/null @@ -1,62 +0,0 @@ -from dataclasses import dataclass -from enum import Enum -from typing import Optional - -from flyteidl.plugins.kubeflow import common_pb2 as kubeflow_common -from flyteidl.plugins.kubeflow.common_pb2 import CLEANPOD_POLICY_ALL, CLEANPOD_POLICY_NONE, CLEANPOD_POLICY_RUNNING - -from flytekit.models import common - - -@dataclass -class CleanPodPolicy(Enum): - """ - CleanPodPolicy describes how to deal with pods when the job is finished. - """ - - NONE = CLEANPOD_POLICY_NONE - ALL = CLEANPOD_POLICY_ALL - RUNNING = CLEANPOD_POLICY_RUNNING - - -class RunPolicy(common.FlyteIdlEntity): - """ - RunPolicy encapsulates various runtime policies of the distributed training job, - for example how to clean up resources and how long the job can stay active. 
- """ - - def __init__( - self, - clean_pod_policy: Optional[CleanPodPolicy], - ttl_seconds_after_finished: Optional[int], - active_deadline_seconds: Optional[int], - backoff_limit: Optional[int], - ): - self._clean_pod_policy = clean_pod_policy - self._ttl_seconds_after_finished = ttl_seconds_after_finished - self._active_deadline_seconds = active_deadline_seconds - self._backoff_limit = backoff_limit - - @property - def clean_pod_policy(self) -> Optional[CleanPodPolicy]: - return self._clean_pod_policy - - @property - def ttl_seconds_after_finished(self) -> Optional[int]: - return self._ttl_seconds_after_finished - - @property - def active_deadline_seconds(self) -> Optional[int]: - return self._active_deadline_seconds - - @property - def backoff_limit(self) -> Optional[int]: - return self._backoff_limit - - def to_flyte_idl(self) -> kubeflow_common.RunPolicy: - return kubeflow_common.RunPolicy( - clean_pod_policy=self._clean_pod_policy.value if self._clean_pod_policy else None, - ttl_seconds_after_finished=self._ttl_seconds_after_finished, - active_deadline_seconds=self._active_deadline_seconds, - backoff_limit=self._backoff_limit, - ) diff --git a/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/models/tensorflow_job.py b/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/models/tensorflow_job.py deleted file mode 100644 index 33e8e503e0..0000000000 --- a/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/models/tensorflow_job.py +++ /dev/null @@ -1,39 +0,0 @@ -from typing import Optional - -from flyteidl.plugins.kubeflow import tensorflow_pb2 as tensorflow_task -from flytekitplugins.kftensorflow.models import PS, Chief, RunPolicy, Worker - -from flytekit.models import common - - -class TensorFlowJob(common.FlyteIdlEntity): - def __init__(self, chief: Chief, ps: PS, worker: Worker, run_policy: Optional[RunPolicy] = None): - self._chief = chief - self._ps = ps - self._worker = worker - self._run_policy = run_policy - - @property - def 
worker(self): - return self._worker - - @property - def ps(self): - return self._ps - - @property - def chief(self): - return self._chief - - @property - def run_policy(self): - return self._run_policy - - def to_flyte_idl(self) -> tensorflow_task.DistributedTensorflowTrainingTask: - training_task = tensorflow_task.DistributedTensorflowTrainingTask( - chief_replicas=self.chief.to_flyte_idl(), - worker_replicas=self.worker.to_flyte_idl(), - ps_replicas=self.ps.to_flyte_idl(), - run_policy=self.run_policy.to_flyte_idl() if self.run_policy else None, - ) - return training_task diff --git a/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/models/worker.py b/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/models/worker.py deleted file mode 100644 index e68dcf0e02..0000000000 --- a/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/models/worker.py +++ /dev/null @@ -1,58 +0,0 @@ -from typing import Optional - -from flyteidl.plugins.kubeflow import tensorflow_pb2 as tensorflow_task - -from flytekit.models import common, task - -from .restart_policy import RestartPolicy - - -class Worker(common.FlyteIdlEntity): - """ - Configuration for a worker replica group in a TFJob. - - :param replicas: Number of replicas in the group. Minimum is 1. - :param image: Optional image to use for the pods of the group - :param resources: Optional resources to use for the pods of the group - :param restart_policy: Optional restart policy to use for the pods of the group - """ - - def __init__( - self, - replicas: int, - image: Optional[str] = None, - resources: Optional[task.Resources] = None, - restart_policy: Optional[RestartPolicy] = None, - ): - if replicas < 1: - raise ValueError( - f"TFJob worker replica needs to have at least one worker, but {replicas} have been specified." 
- ) - self._replicas = replicas - self._image = image - self._resources = resources - self._restart_policy = restart_policy - - @property - def image(self) -> Optional[str]: - return self._image - - @property - def resources(self) -> Optional[task.Resources]: - return self._resources - - @property - def replicas(self) -> Optional[int]: - return self._replicas - - @property - def restart_policy(self) -> Optional[RestartPolicy]: - return self._restart_policy - - def to_flyte_idl(self) -> tensorflow_task.DistributedTensorflowTrainingReplicaSpec: - return tensorflow_task.DistributedTensorflowTrainingReplicaSpec( - replicas=self.replicas, - image=self.image, - resources=self.resources.to_flyte_idl() if self.resources else None, - restart_policy=self.restart_policy.value if self.restart_policy else None, - ) diff --git a/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/task.py b/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/task.py index 5e3bf56d34..c911371465 100644 --- a/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/task.py +++ b/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/task.py @@ -3,20 +3,41 @@ Kubernetes. It leverages `TF Job `_ Plugin from kubeflow. 
class RestartPolicy(Enum):
    """
    RestartPolicy describes how the replicas should be restarted.

    This is intentionally a plain ``Enum`` and NOT a ``@dataclass``: decorating
    an Enum with ``dataclass`` generates a field-less ``__eq__`` (all members
    compare equal to each other) and sets ``__hash__`` to ``None`` (members
    become unhashable). Only ``.value`` is read when building the IDL message.
    """

    ALWAYS = kubeflow_common.RESTART_POLICY_ALWAYS
    FAILURE = kubeflow_common.RESTART_POLICY_ON_FAILURE
    NEVER = kubeflow_common.RESTART_POLICY_NEVER


class CleanPodPolicy(Enum):
    """
    CleanPodPolicy describes how to deal with pods when the job is finished.

    Plain ``Enum`` for the same reason as ``RestartPolicy``: ``@dataclass`` on
    an Enum breaks member equality and hashing.
    """

    NONE = kubeflow_common.CLEANPOD_POLICY_NONE
    ALL = kubeflow_common.CLEANPOD_POLICY_ALL
    RUNNING = kubeflow_common.CLEANPOD_POLICY_RUNNING


@dataclass
class RunPolicy:
    """
    RunPolicy encapsulates various runtime policies of the distributed training job,
    for example how to clean up resources and how long the job can stay active.

    Every field is optional; a field left as ``None`` is passed through as
    ``None`` when the policy is converted to its IDL message.
    """

    # How pods are cleaned up once the job finishes; None means "not specified".
    clean_pod_policy: Optional[CleanPodPolicy] = None
    ttl_seconds_after_finished: Optional[int] = None
    active_deadline_seconds: Optional[int] = None
    backoff_limit: Optional[int] = None
def _convert_replica_spec(
    self, replica_config: Union[Chief, PS, Worker]
) -> tensorflow_task.DistributedTensorflowTrainingReplicaSpec:
    """
    Convert a user-facing replica group config into its IDL replica spec.

    :param replica_config: The ``Chief``, ``PS`` or ``Worker`` config to convert.
    :return: A ``DistributedTensorflowTrainingReplicaSpec`` carrying the group's
        replicas, image, resources and restart policy. ``resources`` and
        ``restart_policy`` are passed as ``None`` when not configured.
    """
    resources = convert_resources_to_resource_model(
        requests=replica_config.requests,
        limits=replica_config.limits,
    )
    restart_policy = replica_config.restart_policy
    return tensorflow_task.DistributedTensorflowTrainingReplicaSpec(
        replicas=replica_config.replicas,
        image=replica_config.image,
        resources=resources.to_flyte_idl() if resources else None,
        # The IDL field takes the raw enum number, not the Python Enum member.
        restart_policy=restart_policy.value if restart_policy else None,
    )


def _convert_run_policy(self, run_policy: RunPolicy) -> kubeflow_common.RunPolicy:
    """
    Convert a user-facing ``RunPolicy`` into its IDL ``RunPolicy`` message.

    :param run_policy: The run policy configured on the task.
    :return: The IDL message with each field copied from ``run_policy``;
        ``clean_pod_policy`` is passed as its enum value, or ``None`` if unset.
    """
    clean_pod_policy = run_policy.clean_pod_policy
    return kubeflow_common.RunPolicy(
        clean_pod_policy=clean_pod_policy.value if clean_pod_policy else None,
        ttl_seconds_after_finished=run_policy.ttl_seconds_after_finished,
        active_deadline_seconds=run_policy.active_deadline_seconds,
        # Fixed: this previously forwarded active_deadline_seconds, so the
        # user's backoff_limit was silently dropped and replaced by the deadline.
        backoff_limit=run_policy.backoff_limit,
    )
def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
    """
    Build the plugin-specific ``custom`` payload for this task.

    Each replica group (chief, worker, ps) is converted to its IDL replica
    spec. A truthy ``num_chief_replicas`` / ``num_workers`` /
    ``num_ps_replicas`` on the task config then overrides that group's
    replica count; ``None`` and ``0`` leave the group's own value in place.

    :param settings: Serialization settings (not read by this method).
    :return: The ``DistributedTensorflowTrainingTask`` message as a dict.
    """
    config = self.task_config

    chief_spec = self._convert_replica_spec(config.chief)
    worker_spec = self._convert_replica_spec(config.worker)
    ps_spec = self._convert_replica_spec(config.ps)

    # Pair every spec with its top-level shortcut so the override rule is
    # written once instead of three times.
    overrides = (
        (chief_spec, config.num_chief_replicas),
        (worker_spec, config.num_workers),
        (ps_spec, config.num_ps_replicas),
    )
    for spec, replica_count in overrides:
        if replica_count:
            spec.replicas = replica_count

    if config.run_policy:
        policy = self._convert_run_policy(config.run_policy)
    else:
        policy = None

    return MessageToDict(
        tensorflow_task.DistributedTensorflowTrainingTask(
            chief_replicas=chief_spec,
            worker_replicas=worker_spec,
            ps_replicas=ps_spec,
            run_policy=policy,
        )
    )
version="version", + env={"FOO": "baz"}, + image_config=ImageConfig(default_image=default_img, images=[default_img]), + ) + + expected_dict = { + "chiefReplicas": { + "replicas": 1, + "resources": {}, + }, + "workerReplicas": { + "replicas": 10, + "resources": {}, + }, + "psReplicas": { + "replicas": 1, + "resources": {}, + }, + } + assert my_tensorflow_task.get_custom(settings) == expected_dict + assert my_tensorflow_task.resources.limits == Resources() + assert my_tensorflow_task.resources.requests == Resources(cpu="1") + assert my_tensorflow_task.task_type == "tensorflow" \ No newline at end of file