-
Notifications
You must be signed in to change notification settings - Fork 301
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: Yubo Wang <[email protected]>
- Loading branch information
Yubo Wang
committed
May 13, 2023
1 parent
dab1eed
commit 2df576f
Showing
13 changed files
with
531 additions
and
70 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
35 changes: 0 additions & 35 deletions
35
plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/models.py
This file was deleted.
Oops, something went wrong.
6 changes: 6 additions & 0 deletions
6
plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/models/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
# Public models of the kftensorflow plugin, re-exported at package level.
from .chief import Chief
from .ps import PS
from .restart_policy import RestartPolicy
from .run_policy import CleanPodPolicy, RunPolicy

# Import order matters here: tensorflow_job.py does
# `from flytekitplugins.kftensorflow.models import PS, Chief, RunPolicy, Worker`
# while this package is still being initialised, so Worker must already be bound
# on the package before tensorflow_job is imported. Otherwise importing the
# package fails with "cannot import name 'Worker' from partially initialized module".
from .worker import Worker  # isort: skip
from .tensorflow_job import TensorFlowJob  # isort: skip
58 changes: 58 additions & 0 deletions
58
plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/models/chief.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
from typing import Optional | ||
|
||
from flyteidl.plugins.kubeflow import tensorflow_pb2 as tensorflow_task | ||
|
||
from flytekit.models import common, task | ||
|
||
from .restart_policy import RestartPolicy | ||
|
||
|
||
class Chief(common.FlyteIdlEntity):
    """
    Model of the chief replica group of a TFJob.

    :param replicas: Replica count for the group. Only 0 or 1 is accepted; with 0, one worker is
        elected as chief (per the validation message below).
    :param image: Optional image for the pods of the group
    :param resources: Optional resources for the pods of the group
    :param restart_policy: Optional restart policy for the pods of the group
    :raises ValueError: if ``replicas`` is neither 0 nor 1
    """

    def __init__(
        self,
        replicas: int,
        image: Optional[str] = None,
        resources: Optional[task.Resources] = None,
        restart_policy: Optional[RestartPolicy] = None,
    ):
        # Reject invalid replica counts before any state is stored.
        if replicas not in (0, 1):
            raise ValueError(
                f"TFJob chief group needs to have either one replica or no replica(one worker will be elected as chief), but {replicas} have been specified."
            )
        self._restart_policy = restart_policy
        self._resources = resources
        self._image = image
        self._replicas = replicas

    @property
    def replicas(self) -> Optional[int]:
        """Replica count given at construction (validated to be 0 or 1)."""
        return self._replicas

    @property
    def image(self) -> Optional[str]:
        """Image given at construction, or None."""
        return self._image

    @property
    def resources(self) -> Optional[task.Resources]:
        """Resources given at construction, or None."""
        return self._resources

    @property
    def restart_policy(self) -> Optional[RestartPolicy]:
        """Restart policy given at construction, or None."""
        return self._restart_policy

    def to_flyte_idl(self) -> tensorflow_task.DistributedTensorflowTrainingReplicaSpec:
        """Build the replica-spec message from this group's settings."""
        # Falsy resources / restart policy are passed on as None.
        resources_idl = self.resources.to_flyte_idl() if self.resources else None
        policy_value = self.restart_policy.value if self.restart_policy else None
        return tensorflow_task.DistributedTensorflowTrainingReplicaSpec(
            replicas=self.replicas,
            image=self.image,
            resources=resources_idl,
            restart_policy=policy_value,
        )
55 changes: 55 additions & 0 deletions
55
plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/models/ps.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
from enum import Enum | ||
from typing import Optional | ||
|
||
from flyteidl.plugins.kubeflow import tensorflow_pb2 as tensorflow_task | ||
|
||
from flytekit.models import common, task | ||
|
||
from .restart_policy import RestartPolicy | ||
|
||
|
||
class PS(common.FlyteIdlEntity):
    """
    Model of the ps replica group of a TFJob.

    :param replicas: Number of replicas in the group (required; no default is applied and no
        range check is performed here)
    :param image: Optional image for the pods of the group
    :param resources: Optional resources for the pods of the group
    :param restart_policy: Optional restart policy for the pods of the group
    """

    def __init__(
        self,
        replicas: int,
        image: Optional[str] = None,
        resources: Optional[task.Resources] = None,
        restart_policy: Optional[RestartPolicy] = None,
    ):
        self._restart_policy = restart_policy
        self._resources = resources
        self._image = image
        self._replicas = replicas

    @property
    def replicas(self) -> Optional[int]:
        """Replica count given at construction."""
        return self._replicas

    @property
    def image(self) -> Optional[str]:
        """Image given at construction, or None."""
        return self._image

    @property
    def resources(self) -> Optional[task.Resources]:
        """Resources given at construction, or None."""
        return self._resources

    @property
    def restart_policy(self) -> Optional[RestartPolicy]:
        """Restart policy given at construction, or None."""
        return self._restart_policy

    def to_flyte_idl(self) -> tensorflow_task.DistributedTensorflowTrainingReplicaSpec:
        """Build the replica-spec message from this group's settings."""
        # Falsy resources / restart policy are passed on as None.
        resources_idl = self.resources.to_flyte_idl() if self.resources else None
        policy_value = self.restart_policy.value if self.restart_policy else None
        return tensorflow_task.DistributedTensorflowTrainingReplicaSpec(
            replicas=self.replicas,
            image=self.image,
            resources=resources_idl,
            restart_policy=policy_value,
        )
15 changes: 15 additions & 0 deletions
15
plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/models/restart_policy.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
from dataclasses import dataclass | ||
from enum import Enum | ||
|
||
from flyteidl.plugins.kubeflow.common_pb2 import RESTART_POLICY_ALWAYS, RESTART_POLICY_NEVER, RESTART_POLICY_ON_FAILURE | ||
|
||
|
||
class RestartPolicy(Enum):
    """
    RestartPolicy describes how the replicas should be restarted
    """

    # NOTE: intentionally a plain Enum, not a @dataclass. With no annotated fields,
    # @dataclass generates an __eq__ that compares empty field tuples (so every member
    # compares equal to every other) and sets __hash__ to None (so members cannot be
    # used as dict keys or set elements).
    ALWAYS = RESTART_POLICY_ALWAYS
    FAILURE = RESTART_POLICY_ON_FAILURE
    NEVER = RESTART_POLICY_NEVER
65 changes: 65 additions & 0 deletions
65
plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/models/run_policy.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
from dataclasses import dataclass | ||
from enum import Enum | ||
from typing import Optional | ||
|
||
from flyteidl.plugins.kubeflow import common_pb2 as kubeflow_common | ||
from flyteidl.plugins.kubeflow.common_pb2 import CLEANPOD_POLICY_ALL, CLEANPOD_POLICY_NONE, CLEANPOD_POLICY_RUNNING | ||
|
||
from flytekit.models import common | ||
|
||
|
||
class CleanPodPolicy(Enum):
    """
    CleanPodPolicy describes how to deal with pods when the job is finished.
    """

    # NOTE: intentionally a plain Enum, not a @dataclass. With no annotated fields,
    # @dataclass generates an __eq__ that compares empty field tuples (so every member
    # compares equal to every other) and sets __hash__ to None (so members cannot be
    # used as dict keys or set elements).
    NONE = CLEANPOD_POLICY_NONE
    ALL = CLEANPOD_POLICY_ALL
    RUNNING = CLEANPOD_POLICY_RUNNING
|
||
|
||
class RunPolicy(common.FlyteIdlEntity):
    """
    Run policy settings of a TFJob, serialised to the kubeflow ``RunPolicy`` message.

    Every setting is optional; a setting left as None is passed to the message as None.

    :param clean_pod_policy: Optional policy describing how to deal with pods when the job is finished
    :param ttl_seconds_after_finished: Optional value for the message field of the same name
        (by its name, a time-to-live in seconds after the job finishes -- confirm against flyteidl)
    :param active_deadline_seconds: Optional value for the message field of the same name
        (by its name, a deadline in seconds for the running job -- confirm against flyteidl)
    :param backoff_limit: Optional value for the message field of the same name
        (by its name, a retry limit -- confirm against flyteidl)
    """

    def __init__(
        self,
        clean_pod_policy: Optional[CleanPodPolicy] = None,
        ttl_seconds_after_finished: Optional[int] = None,
        active_deadline_seconds: Optional[int] = None,
        backoff_limit: Optional[int] = None,
    ):
        self._clean_pod_policy = clean_pod_policy
        self._ttl_seconds_after_finished = ttl_seconds_after_finished
        self._active_deadline_seconds = active_deadline_seconds
        self._backoff_limit = backoff_limit

    @property
    def clean_pod_policy(self) -> Optional[CleanPodPolicy]:
        """Clean-pod policy given at construction, or None."""
        return self._clean_pod_policy

    @property
    def ttl_seconds_after_finished(self) -> Optional[int]:
        """``ttl_seconds_after_finished`` given at construction, or None."""
        return self._ttl_seconds_after_finished

    @property
    def active_deadline_seconds(self) -> Optional[int]:
        """``active_deadline_seconds`` given at construction, or None."""
        return self._active_deadline_seconds

    @property
    def backoff_limit(self) -> Optional[int]:
        """``backoff_limit`` given at construction, or None."""
        return self._backoff_limit

    def to_flyte_idl(self) -> kubeflow_common.RunPolicy:
        """Build the kubeflow ``RunPolicy`` message from the stored settings."""
        # The enum member is unwrapped to its underlying value; a missing policy stays None.
        return kubeflow_common.RunPolicy(
            clean_pod_policy=self.clean_pod_policy.value if self.clean_pod_policy else None,
            ttl_seconds_after_finished=self.ttl_seconds_after_finished,
            active_deadline_seconds=self.active_deadline_seconds,
            backoff_limit=self.backoff_limit,
        )
41 changes: 41 additions & 0 deletions
41
plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/models/tensorflow_job.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
from dataclasses import dataclass | ||
from enum import Enum | ||
from typing import Optional | ||
|
||
from flyteidl.plugins.kubeflow import tensorflow_pb2 as tensorflow_task | ||
from flytekitplugins.kftensorflow.models import PS, Chief, RunPolicy, Worker | ||
|
||
from flytekit.models import common | ||
|
||
|
||
class TensorFlowJob(common.FlyteIdlEntity):
    """
    Model of a distributed TensorFlow training job made of chief, ps and worker replica groups.

    :param chief: Chief replica group
    :param ps: PS replica group
    :param worker: Worker replica group
    :param run_policy: Optional run policy for the job
    """

    def __init__(self, chief: Chief, ps: PS, worker: Worker, run_policy: Optional[RunPolicy] = None):
        self._run_policy = run_policy
        self._worker = worker
        self._ps = ps
        self._chief = chief

    @property
    def chief(self):
        """Chief replica group given at construction."""
        return self._chief

    @property
    def ps(self):
        """PS replica group given at construction."""
        return self._ps

    @property
    def worker(self):
        """Worker replica group given at construction."""
        return self._worker

    @property
    def run_policy(self):
        """Run policy given at construction, or None."""
        return self._run_policy

    def to_flyte_idl(self) -> tensorflow_task.DistributedTensorflowTrainingTask:
        """Build the training-task message from the replica groups and the optional run policy."""
        # Groups are serialised in the same order as the original: chief, worker, ps, run policy.
        chief_spec = self.chief.to_flyte_idl()
        worker_spec = self.worker.to_flyte_idl()
        ps_spec = self.ps.to_flyte_idl()
        policy_msg = self.run_policy.to_flyte_idl() if self.run_policy else None
        return tensorflow_task.DistributedTensorflowTrainingTask(
            chief_replicas=chief_spec,
            worker_replicas=worker_spec,
            ps_replicas=ps_spec,
            run_policy=policy_msg,
        )
59 changes: 59 additions & 0 deletions
59
plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/models/worker.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
from enum import Enum | ||
from typing import Optional | ||
|
||
from flyteidl.plugins.kubeflow import tensorflow_pb2 as tensorflow_task | ||
|
||
from flytekit.models import common, task | ||
|
||
from .restart_policy import RestartPolicy | ||
|
||
|
||
class Worker(common.FlyteIdlEntity):
    """
    Model of the worker replica group of a TFJob.

    :param replicas: Number of replicas in the group; must be at least 1
    :param image: Optional image for the pods of the group
    :param resources: Optional resources for the pods of the group
    :param restart_policy: Optional restart policy for the pods of the group
    :raises ValueError: if ``replicas`` is smaller than 1
    """

    def __init__(
        self,
        replicas: int,
        image: Optional[str] = None,
        resources: Optional[task.Resources] = None,
        restart_policy: Optional[RestartPolicy] = None,
    ):
        # Reject an empty worker group before any state is stored.
        if replicas < 1:
            raise ValueError(
                f"TFJob worker replica needs to have at least one worker, but {replicas} have been specified."
            )
        self._restart_policy = restart_policy
        self._resources = resources
        self._image = image
        self._replicas = replicas

    @property
    def replicas(self) -> Optional[int]:
        """Replica count given at construction (validated to be >= 1)."""
        return self._replicas

    @property
    def image(self) -> Optional[str]:
        """Image given at construction, or None."""
        return self._image

    @property
    def resources(self) -> Optional[task.Resources]:
        """Resources given at construction, or None."""
        return self._resources

    @property
    def restart_policy(self) -> Optional[RestartPolicy]:
        """Restart policy given at construction, or None."""
        return self._restart_policy

    def to_flyte_idl(self) -> tensorflow_task.DistributedTensorflowTrainingReplicaSpec:
        """Build the replica-spec message from this group's settings."""
        # Falsy resources / restart policy are passed on as None.
        resources_idl = self.resources.to_flyte_idl() if self.resources else None
        policy_value = self.restart_policy.value if self.restart_policy else None
        return tensorflow_task.DistributedTensorflowTrainingReplicaSpec(
            replicas=self.replicas,
            image=self.image,
            resources=resources_idl,
            restart_policy=policy_value,
        )
Oops, something went wrong.