Skip to content

Commit

Permalink
upgrade tensorflow plugin to v1
Browse files Browse the repository at this point in the history
Signed-off-by: Yubo Wang <[email protected]>
  • Loading branch information
Yubo Wang committed May 13, 2023
1 parent dab1eed commit 2df576f
Show file tree
Hide file tree
Showing 13 changed files with 531 additions and 70 deletions.
14 changes: 13 additions & 1 deletion plugins/flytekit-kf-tensorflow/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,16 @@ To install the plugin, run the following command:
pip install flytekitplugins-kftensorflow
```

_Example coming soon!_
## Upgrade TensorFlow Plugin
The TensorFlow plugin has been upgraded from v0 to v1 to enable more configuration options.
To migrate from v0 to v1, make the following changes:
1. Update flytepropeller to v
2. Update flytekit version to v
3. Update your code from:
```
task_config=TfJob(num_workers=10, num_ps_replicas=1, num_chief_replicas=1),
```
to:
```
task_config=TfJob(worker=Worker(replicas=10), ps=PS(replicas=1), chief=Chief(replicas=1)),
```
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,5 @@
TfJob
"""

from .task import TfJob
from .models import CleanPodPolicy, RestartPolicy
from .task import PS, Chief, RunPolicy, TfJob, Worker

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from .chief import Chief
from .ps import PS
from .restart_policy import RestartPolicy
from .run_policy import CleanPodPolicy, RunPolicy
from .tensorflow_job import TensorFlowJob
from .worker import Worker
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from typing import Optional

from flyteidl.plugins.kubeflow import tensorflow_pb2 as tensorflow_task

from flytekit.models import common, task

from .restart_policy import RestartPolicy


class Chief(common.FlyteIdlEntity):
    """
    Settings for the chief replica group of a TFJob.

    :param replicas: Number of chief replicas; only 0 or 1 is accepted. With 0, one worker
        is elected as chief (per the validation message below).
    :param image: Optional image for the pods of this group.
    :param resources: Optional resources for the pods of this group.
    :param restart_policy: Optional restart policy for the pods of this group.
    :raises ValueError: If ``replicas`` is neither 0 nor 1.
    """

    def __init__(
        self,
        replicas: int,
        image: Optional[str] = None,
        resources: Optional[task.Resources] = None,
        restart_policy: Optional[RestartPolicy] = None,
    ):
        # Validate first so that no state is stored for an invalid configuration.
        if replicas not in (0, 1):
            raise ValueError(
                f"TFJob chief group needs to have either one replica or no replica(one worker will be elected as chief), but {replicas} have been specified."
            )
        self._restart_policy = restart_policy
        self._resources = resources
        self._image = image
        self._replicas = replicas

    @property
    def replicas(self) -> Optional[int]:
        """Replica count given at construction (0 or 1)."""
        return self._replicas

    @property
    def image(self) -> Optional[str]:
        """Image override for the group's pods, or ``None``."""
        return self._image

    @property
    def resources(self) -> Optional[task.Resources]:
        """Resource override for the group's pods, or ``None``."""
        return self._resources

    @property
    def restart_policy(self) -> Optional[RestartPolicy]:
        """Restart policy for the group's pods, or ``None``."""
        return self._restart_policy

    def to_flyte_idl(self) -> tensorflow_task.DistributedTensorflowTrainingReplicaSpec:
        """Build the replica-spec IDL message; unset optional fields are passed as ``None``."""
        resources_idl = None
        if self.resources:
            resources_idl = self.resources.to_flyte_idl()

        policy_value = None
        if self.restart_policy:
            policy_value = self.restart_policy.value

        return tensorflow_task.DistributedTensorflowTrainingReplicaSpec(
            replicas=self.replicas,
            image=self.image,
            resources=resources_idl,
            restart_policy=policy_value,
        )
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from enum import Enum
from typing import Optional

from flyteidl.plugins.kubeflow import tensorflow_pb2 as tensorflow_task

from flytekit.models import common, task

from .restart_policy import RestartPolicy


class PS(common.FlyteIdlEntity):
    """
    Settings for the ps replica group of a TFJob.

    :param replicas: Number of replicas in the group. Required; this constructor applies
        no default and performs no range validation.
    :param image: Optional image for the pods of this group.
    :param resources: Optional resources for the pods of this group.
    :param restart_policy: Optional restart policy for the pods of this group.
    """

    def __init__(
        self,
        replicas: int,
        image: Optional[str] = None,
        resources: Optional[task.Resources] = None,
        restart_policy: Optional[RestartPolicy] = None,
    ):
        self._restart_policy = restart_policy
        self._resources = resources
        self._image = image
        self._replicas = replicas

    @property
    def replicas(self) -> Optional[int]:
        """Replica count given at construction."""
        return self._replicas

    @property
    def image(self) -> Optional[str]:
        """Image override for the group's pods, or ``None``."""
        return self._image

    @property
    def resources(self) -> Optional[task.Resources]:
        """Resource override for the group's pods, or ``None``."""
        return self._resources

    @property
    def restart_policy(self) -> Optional[RestartPolicy]:
        """Restart policy for the group's pods, or ``None``."""
        return self._restart_policy

    def to_flyte_idl(self) -> tensorflow_task.DistributedTensorflowTrainingReplicaSpec:
        """Build the replica-spec IDL message; unset optional fields are passed as ``None``."""
        resources_idl = None
        if self.resources:
            resources_idl = self.resources.to_flyte_idl()

        policy_value = None
        if self.restart_policy:
            policy_value = self.restart_policy.value

        return tensorflow_task.DistributedTensorflowTrainingReplicaSpec(
            replicas=self.replicas,
            image=self.image,
            resources=resources_idl,
            restart_policy=policy_value,
        )
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from dataclasses import dataclass
from enum import Enum

from flyteidl.plugins.kubeflow.common_pb2 import RESTART_POLICY_ALWAYS, RESTART_POLICY_NEVER, RESTART_POLICY_ON_FAILURE


class RestartPolicy(Enum):
    """
    RestartPolicy describes how the replicas should be restarted.

    Each member's ``value`` is the corresponding flyteidl ``RESTART_POLICY_*`` constant.
    """

    # NOTE: this must stay a plain Enum. Decorating an Enum with @dataclass generates an
    # __eq__ based on an empty field tuple, which makes every member compare equal to
    # every other member and also makes the members unhashable.
    ALWAYS = RESTART_POLICY_ALWAYS
    FAILURE = RESTART_POLICY_ON_FAILURE
    NEVER = RESTART_POLICY_NEVER
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
from dataclasses import dataclass
from enum import Enum
from typing import Optional

from flyteidl.plugins.kubeflow import common_pb2 as kubeflow_common
from flyteidl.plugins.kubeflow.common_pb2 import CLEANPOD_POLICY_ALL, CLEANPOD_POLICY_NONE, CLEANPOD_POLICY_RUNNING

from flytekit.models import common


class CleanPodPolicy(Enum):
    """
    CleanPodPolicy describes how to deal with pods when the job is finished.

    Each member's ``value`` is the corresponding flyteidl ``CLEANPOD_POLICY_*`` constant.
    """

    # NOTE: this must stay a plain Enum. Decorating an Enum with @dataclass generates an
    # __eq__ based on an empty field tuple, which makes every member compare equal to
    # every other member and also makes the members unhashable.
    NONE = CLEANPOD_POLICY_NONE
    ALL = CLEANPOD_POLICY_ALL
    RUNNING = CLEANPOD_POLICY_RUNNING


class RunPolicy(common.FlyteIdlEntity):
    """
    Job-level run policy, serialized to the kubeflow ``RunPolicy`` IDL message.

    Every setting is optional; a value of ``None`` is passed through to the IDL
    constructor unchanged.

    :param clean_pod_policy: Optional policy describing how to deal with pods when the job is finished.
    :param ttl_seconds_after_finished: Optional integer forwarded to the IDL field of the same name
        (presumably the lifetime of a finished job in seconds -- confirm against the kubeflow docs).
    :param active_deadline_seconds: Optional integer forwarded to the IDL field of the same name
        (presumably the maximum active duration of the job in seconds -- confirm against the kubeflow docs).
    :param backoff_limit: Optional integer forwarded to the IDL field of the same name
        (presumably the number of retries before the job is marked failed -- confirm against the kubeflow docs).
    """

    def __init__(
        self,
        clean_pod_policy: Optional[CleanPodPolicy] = None,
        ttl_seconds_after_finished: Optional[int] = None,
        active_deadline_seconds: Optional[int] = None,
        backoff_limit: Optional[int] = None,
    ):
        self._clean_pod_policy = clean_pod_policy
        self._ttl_seconds_after_finished = ttl_seconds_after_finished
        self._active_deadline_seconds = active_deadline_seconds
        self._backoff_limit = backoff_limit

    @property
    def clean_pod_policy(self) -> Optional[CleanPodPolicy]:
        """Clean-pod policy given at construction, or ``None``."""
        return self._clean_pod_policy

    @property
    def ttl_seconds_after_finished(self) -> Optional[int]:
        """``ttl_seconds_after_finished`` given at construction, or ``None``."""
        return self._ttl_seconds_after_finished

    @property
    def active_deadline_seconds(self) -> Optional[int]:
        """``active_deadline_seconds`` given at construction, or ``None``."""
        return self._active_deadline_seconds

    @property
    def backoff_limit(self) -> Optional[int]:
        """``backoff_limit`` given at construction, or ``None``."""
        return self._backoff_limit

    def to_flyte_idl(self) -> kubeflow_common.RunPolicy:
        """Build the ``RunPolicy`` IDL message from the stored settings."""
        return kubeflow_common.RunPolicy(
            # The enum member is unwrapped to its underlying IDL constant.
            clean_pod_policy=self._clean_pod_policy.value if self._clean_pod_policy else None,
            ttl_seconds_after_finished=self._ttl_seconds_after_finished,
            active_deadline_seconds=self._active_deadline_seconds,
            backoff_limit=self._backoff_limit,
        )
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from dataclasses import dataclass
from enum import Enum
from typing import Optional

from flyteidl.plugins.kubeflow import tensorflow_pb2 as tensorflow_task
from flytekitplugins.kftensorflow.models import PS, Chief, RunPolicy, Worker

from flytekit.models import common


class TensorFlowJob(common.FlyteIdlEntity):
    """
    Model of a distributed TensorFlow training task.

    Bundles the chief, ps and worker replica groups with an optional run policy and
    serializes them into a ``DistributedTensorflowTrainingTask`` IDL message.

    :param chief: Chief replica group settings.
    :param ps: Ps replica group settings.
    :param worker: Worker replica group settings.
    :param run_policy: Optional job-level run policy.
    """

    def __init__(self, chief: Chief, ps: PS, worker: Worker, run_policy: Optional[RunPolicy] = None):
        self._run_policy = run_policy
        self._worker = worker
        self._ps = ps
        self._chief = chief

    @property
    def chief(self) -> Chief:
        """Chief replica group settings."""
        return self._chief

    @property
    def ps(self) -> PS:
        """Ps replica group settings."""
        return self._ps

    @property
    def worker(self) -> Worker:
        """Worker replica group settings."""
        return self._worker

    @property
    def run_policy(self) -> Optional[RunPolicy]:
        """Run policy, or ``None`` when none was supplied."""
        return self._run_policy

    def to_flyte_idl(self) -> tensorflow_task.DistributedTensorflowTrainingTask:
        """Convert every replica group (and the run policy, if any) to its IDL form."""
        chief_idl = self.chief.to_flyte_idl()
        worker_idl = self.worker.to_flyte_idl()
        ps_idl = self.ps.to_flyte_idl()

        run_policy_idl = None
        if self.run_policy:
            run_policy_idl = self.run_policy.to_flyte_idl()

        return tensorflow_task.DistributedTensorflowTrainingTask(
            chief_replicas=chief_idl,
            worker_replicas=worker_idl,
            ps_replicas=ps_idl,
            run_policy=run_policy_idl,
        )
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from enum import Enum
from typing import Optional

from flyteidl.plugins.kubeflow import tensorflow_pb2 as tensorflow_task

from flytekit.models import common, task

from .restart_policy import RestartPolicy


class Worker(common.FlyteIdlEntity):
    """
    Settings for the worker replica group of a TFJob.

    :param replicas: Number of replicas in the group; must be at least 1.
    :param image: Optional image for the pods of this group.
    :param resources: Optional resources for the pods of this group.
    :param restart_policy: Optional restart policy for the pods of this group.
    :raises ValueError: If ``replicas`` is less than 1.
    """

    def __init__(
        self,
        replicas: int,
        image: Optional[str] = None,
        resources: Optional[task.Resources] = None,
        restart_policy: Optional[RestartPolicy] = None,
    ):
        # Validate first so that no state is stored for an invalid configuration.
        if replicas < 1:
            raise ValueError(
                f"TFJob worker replica needs to have at least one worker, but {replicas} have been specified."
            )
        self._restart_policy = restart_policy
        self._resources = resources
        self._image = image
        self._replicas = replicas

    @property
    def replicas(self) -> Optional[int]:
        """Replica count given at construction (at least 1)."""
        return self._replicas

    @property
    def image(self) -> Optional[str]:
        """Image override for the group's pods, or ``None``."""
        return self._image

    @property
    def resources(self) -> Optional[task.Resources]:
        """Resource override for the group's pods, or ``None``."""
        return self._resources

    @property
    def restart_policy(self) -> Optional[RestartPolicy]:
        """Restart policy for the group's pods, or ``None``."""
        return self._restart_policy

    def to_flyte_idl(self) -> tensorflow_task.DistributedTensorflowTrainingReplicaSpec:
        """Build the replica-spec IDL message; unset optional fields are passed as ``None``."""
        resources_idl = None
        if self.resources:
            resources_idl = self.resources.to_flyte_idl()

        policy_value = None
        if self.restart_policy:
            policy_value = self.restart_policy.value

        return tensorflow_task.DistributedTensorflowTrainingReplicaSpec(
            replicas=self.replicas,
            image=self.image,
            resources=resources_idl,
            restart_policy=policy_value,
        )
Loading

0 comments on commit 2df576f

Please sign in to comment.