Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change flytekit Pytorch, TFJob and MPI plugins to use new kubeflow config #1627

Merged
merged 12 commits into from
May 31, 2023
14 changes: 13 additions & 1 deletion plugins/flytekit-kf-tensorflow/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,16 @@ To install the plugin, run the following command:
pip install flytekitplugins-kftensorflow
```

_Example coming soon!_
## Upgrade TensorFlow Plugin
The TensorFlow plugin's task config has been updated from v0 to v1 to enable more configuration options (per-replica images, resources, and restart policies). To migrate from v0 to v1, make the following changes:
1. Update flytepropeller to a version that supports version 1 of the kubeflow TFJob task type
yubofredwang marked this conversation as resolved.
Show resolved Hide resolved
2. Update flytekit and flytekitplugins-kftensorflow to a release that includes the v1 plugin
yubofredwang marked this conversation as resolved.
Show resolved Hide resolved
3. Update your code from:
```python
task_config=TfJob(num_workers=10, num_ps_replicas=1, num_chief_replicas=1),
```
to:
```python
task_config=TfJob(worker=Worker(replicas=10), ps=PS(replicas=1), chief=Chief(replicas=1)),
yubofredwang marked this conversation as resolved.
Show resolved Hide resolved
```
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,4 @@
TfJob
"""

from .task import TfJob
from .task import PS, Chief, RunPolicy, TfJob, Worker, RestartPolicy, CleanPodPolicy

This file was deleted.

122 changes: 101 additions & 21 deletions plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,37 +2,83 @@
This Plugin adds the capability of running distributed tensorflow training to Flyte using backend plugins, natively on
Kubernetes. It leverages `TF Job <https://github.com/kubeflow/tf-operator>`_ Plugin from kubeflow.
"""
from dataclasses import dataclass
from typing import Any, Callable, Dict
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Callable, Dict, Optional, Union

from google.protobuf.json_format import MessageToDict

from flytekit import PythonFunctionTask
from flytekit import PythonFunctionTask, Resources
from flytekit.configuration import SerializationSettings
from flytekit.core.resources import convert_resources_to_resource_model
from flytekit.extend import TaskPlugins
from flyteidl.plugins.kubeflow import tensorflow_pb2 as tensorflow_task
from flyteidl.plugins.kubeflow import common_pb2 as kubeflow_common

from .models import TensorFlowJob
# NOTE: no @dataclass here — applying @dataclass to an Enum subclass generates
# __eq__/__repr__ over an empty field tuple, which shadows Enum's identity-based
# equality and makes every member compare equal to every other.
class RestartPolicy(Enum):
    """
    RestartPolicy describes how the replicas should be restarted.

    Values map directly onto the flyteidl kubeflow restart-policy constants.
    """

    ALWAYS = kubeflow_common.RESTART_POLICY_ALWAYS
    FAILURE = kubeflow_common.RESTART_POLICY_ON_FAILURE
    NEVER = kubeflow_common.RESTART_POLICY_NEVER

# NOTE: no @dataclass here — applying @dataclass to an Enum subclass generates
# __eq__/__repr__ over an empty field tuple, which shadows Enum's identity-based
# equality and makes every member compare equal to every other.
class CleanPodPolicy(Enum):
    """
    CleanPodPolicy describes how to deal with pods when the job is finished.

    Values map directly onto the flyteidl kubeflow clean-pod-policy constants.
    """

    NONE = kubeflow_common.CLEANPOD_POLICY_NONE
    ALL = kubeflow_common.CLEANPOD_POLICY_ALL
    RUNNING = kubeflow_common.CLEANPOD_POLICY_RUNNING

num_ps_replicas: Number of Parameter server replicas to use
@dataclass
class RunPolicy:
    """
    RunPolicy describes job-level execution policies applied by the kubeflow operator.

    Args:
        clean_pod_policy: How pods are cleaned up after the job finishes (see CleanPodPolicy).
        ttl_seconds_after_finished: Seconds a finished job is retained before automatic cleanup.
        active_deadline_seconds: Maximum wall-clock time the job may remain active before termination.
        backoff_limit: Number of retries before the job is marked as failed.
    """

    # Fix: was annotated `CleanPodPolicy = None`, which is not a valid type for a None default.
    clean_pod_policy: Optional[CleanPodPolicy] = None
    ttl_seconds_after_finished: Optional[int] = None
    active_deadline_seconds: Optional[int] = None
    backoff_limit: Optional[int] = None

num_chief_replicas: Number of chief replicas to use

"""
@dataclass
class Chief:
    """Configuration for the TFJob chief replica group."""

    # Container image override for chief pods; None presumably falls back to the task's default image — TODO confirm.
    image: Optional[str] = None
    # Pod resource requests/limits, converted via convert_resources_to_resource_model at serialization time.
    requests: Optional[Resources] = None
    limits: Optional[Resources] = None
    # Defaults to 0 (no chief). NOTE(review): differs from Worker's default of 1 — confirm intended.
    replicas: Optional[int] = 0
    # Restart behavior for chief pods (see RestartPolicy); None leaves the operator default.
    restart_policy: Optional[RestartPolicy] = None

num_workers: int
num_ps_replicas: int
num_chief_replicas: int

@dataclass
class PS:
    """Configuration for the TFJob parameter-server (PS) replica group."""

    # Container image override for PS pods; None presumably falls back to the task's default image — TODO confirm.
    image: Optional[str] = None
    # Pod resource requests/limits, converted via convert_resources_to_resource_model at serialization time.
    requests: Optional[Resources] = None
    limits: Optional[Resources] = None
    # Defaults to None (no PS replicas requested). NOTE(review): Chief defaults to 0 — confirm the asymmetry is intended.
    replicas: Optional[int] = None
    # Restart behavior for PS pods (see RestartPolicy); None leaves the operator default.
    restart_policy: Optional[RestartPolicy] = None


@dataclass
class Worker:
    """Configuration for the TFJob worker replica group."""

    # Container image override for worker pods; None presumably falls back to the task's default image — TODO confirm.
    image: Optional[str] = None
    # Pod resource requests/limits, converted via convert_resources_to_resource_model at serialization time.
    requests: Optional[Resources] = None
    limits: Optional[Resources] = None
    # Defaults to a single worker.
    replicas: Optional[int] = 1
    # Restart behavior for worker pods (see RestartPolicy); None leaves the operator default.
    restart_policy: Optional[RestartPolicy] = None


@dataclass
class TfJob:
    """
    Configuration for an executable TFJob. Use this to run distributed TensorFlow
    training on Kubernetes via the kubeflow training operator.

    Args:
        chief: Configuration for the chief replica group.
        ps: Configuration for the parameter-server replica group.
        worker: Configuration for the worker replica group.
        run_policy: Optional job-level run policy (cleanup, TTL, deadline, backoff).
        num_workers: [deprecated v0 field] When set, overrides worker.replicas.
        num_ps_replicas: [deprecated v0 field] When set, overrides ps.replicas.
        num_chief_replicas: [deprecated v0 field] When set, overrides chief.replicas.
    """

    # default_factory takes the class directly — no need to wrap in a lambda.
    chief: Chief = field(default_factory=Chief)
    ps: PS = field(default_factory=PS)
    worker: Worker = field(default_factory=Worker)
    # A plain None default is equivalent to (and simpler than) field(default_factory=lambda: None).
    run_policy: Optional[RunPolicy] = None
    num_workers: Optional[int] = None
    num_ps_replicas: Optional[int] = None
    num_chief_replicas: Optional[int] = None


class TensorflowFunctionTask(PythonFunctionTask[TfJob]):
Expand All @@ -48,16 +94,50 @@ def __init__(self, task_config: TfJob, task_function: Callable, **kwargs):
task_type=self._TF_JOB_TASK_TYPE,
task_config=task_config,
task_function=task_function,
task_type_version=1,
**kwargs,
)

def _convert_replica_spec(self, replica_config: Union[Chief, PS, Worker]) -> tensorflow_task.DistributedTensorflowTrainingReplicaSpec:
    """Translate a replica-group config (Chief/PS/Worker) into its IDL replica spec."""
    res_model = convert_resources_to_resource_model(requests=replica_config.requests, limits=replica_config.limits)
    idl_resources = res_model.to_flyte_idl() if res_model else None
    policy = replica_config.restart_policy
    return tensorflow_task.DistributedTensorflowTrainingReplicaSpec(
        replicas=replica_config.replicas,
        image=replica_config.image,
        resources=idl_resources,
        restart_policy=policy.value if policy else None,
    )


def _convert_run_policy(self, run_policy: RunPolicy) -> kubeflow_common.RunPolicy:
    """Translate the user-facing RunPolicy dataclass into the kubeflow IDL RunPolicy message."""
    return kubeflow_common.RunPolicy(
        clean_pod_policy=run_policy.clean_pod_policy.value if run_policy.clean_pod_policy else None,
        ttl_seconds_after_finished=run_policy.ttl_seconds_after_finished,
        active_deadline_seconds=run_policy.active_deadline_seconds,
        # Fix: was copy-pasted as run_policy.active_deadline_seconds, silently dropping the user's backoff limit.
        backoff_limit=run_policy.backoff_limit,
    )

def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
    """
    Serialize the task config into a DistributedTensorflowTrainingTask proto, returned as a dict.

    The deprecated v0 fields (num_chief_replicas / num_workers / num_ps_replicas), when set,
    override the replica counts from the v1 Chief/Worker/PS configs for backward compatibility.
    """
    chief = self._convert_replica_spec(self.task_config.chief)
    if self.task_config.num_chief_replicas:
        chief.replicas = self.task_config.num_chief_replicas

    worker = self._convert_replica_spec(self.task_config.worker)
    if self.task_config.num_workers:
        worker.replicas = self.task_config.num_workers

    ps = self._convert_replica_spec(self.task_config.ps)
    if self.task_config.num_ps_replicas:
        ps.replicas = self.task_config.num_ps_replicas

    run_policy = self._convert_run_policy(self.task_config.run_policy) if self.task_config.run_policy else None
    training_task = tensorflow_task.DistributedTensorflowTrainingTask(
        chief_replicas=chief,
        worker_replicas=worker,
        ps_replicas=ps,
        run_policy=run_policy,
    )
    return MessageToDict(training_task)


# Register the Tensorflow Plugin into the flytekit core plugin system
Expand Down
2 changes: 1 addition & 1 deletion plugins/flytekit-kf-tensorflow/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
microlib_name = f"flytekitplugins-{PLUGIN_NAME}"

# TODO: Requirements are missing, add them back in later.
plugin_requires = ["flytekit>=1.3.0b2,<2.0.0"]
plugin_requires = []

__version__ = "0.0.0+develop"

Expand Down
Loading