Skip to content

Commit

Permalink
lint and fmt
Browse files Browse the repository at this point in the history
Signed-off-by: Yubo Wang <[email protected]>
  • Loading branch information
Yubo Wang committed May 17, 2023
1 parent 61214e2 commit d14b602
Show file tree
Hide file tree
Showing 11 changed files with 128 additions and 81 deletions.
2 changes: 1 addition & 1 deletion plugins/flytekit-kf-mpi/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,4 +71,4 @@ To migrate from v0 to v1, change the following:
to
```
task_config=MPIJob(worker=Worker(replicas=10)),
```
```
2 changes: 1 addition & 1 deletion plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,4 @@
MPIJob
"""

from .task import HorovodJob, MPIJob, Worker, Launcher, CleanPodPolicy, RunPolicy, RestartPolicy
from .task import CleanPodPolicy, HorovodJob, Launcher, MPIJob, RestartPolicy, RunPolicy, Worker
59 changes: 39 additions & 20 deletions plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,17 @@
"""
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Callable, Dict, Optional, Union, List
from typing import Any, Callable, Dict, List, Optional, Union

from flyteidl.plugins.kubeflow import common_pb2 as kubeflow_common
from flyteidl.plugins.kubeflow import mpi_pb2 as mpi_task
from google.protobuf.json_format import MessageToDict

from flytekit import PythonFunctionTask, Resources
from flytekit.core.resources import convert_resources_to_resource_model
from flytekit.configuration import SerializationSettings
from flytekit.core.resources import convert_resources_to_resource_model
from flytekit.extend import TaskPlugins
from flyteidl.plugins.kubeflow import common_pb2 as kubeflow_common


@dataclass
class RestartPolicy(Enum):
Expand All @@ -25,6 +26,7 @@ class RestartPolicy(Enum):
FAILURE = kubeflow_common.RESTART_POLICY_ON_FAILURE
NEVER = kubeflow_common.RESTART_POLICY_NEVER


@dataclass
class CleanPodPolicy(Enum):
"""
Expand All @@ -35,6 +37,7 @@ class CleanPodPolicy(Enum):
ALL = kubeflow_common.CLEANPOD_POLICY_ALL
RUNNING = kubeflow_common.CLEANPOD_POLICY_RUNNING


@dataclass
class RunPolicy:
"""
Expand All @@ -47,37 +50,41 @@ class RunPolicy:
where restartPolicy is OnFailure or Always.
backoff_limit (int): Number of retries before marking this job as failed.
"""

clean_pod_policy: CleanPodPolicy = None
ttl_seconds_after_finished: Optional[int] = None
active_deadline_seconds: Optional[int] = None
backoff_limit: Optional[int] = None


@dataclass
class Worker:
    """
    Worker replica configuration. Worker command can be customized. If not specified, the worker will use
    default command generated by the mpi operator.
    """

    # Custom command for the worker replicas; None leaves the choice to the MPI operator (see docstring).
    command: Optional[List[str]] = None
    # Container image for the worker replicas.
    # NOTE(review): presumably falls back to the task's own image when None -- confirm.
    image: Optional[str] = None
    # Resource requests and limits applied to the worker replica spec.
    requests: Optional[Resources] = None
    limits: Optional[Resources] = None
    # Number of worker replicas. MPIFunctionTask.__init__ raises ValueError when this is set together
    # with the deprecated `num_workers`, or when neither of the two is set.
    replicas: Optional[int] = None
    # Restart behaviour for the worker replicas; None is passed through as "not set".
    restart_policy: Optional[RestartPolicy] = None


@dataclass
class Launcher:
    """
    Launcher replica configuration. Launcher command can be customized. If not specified, the launcher will use
    the command specified in the task signature.
    """

    # Custom command for the launcher replicas; None keeps the command derived from the task (see docstring).
    command: Optional[List[str]] = None
    # Container image for the launcher replicas.
    # NOTE(review): presumably falls back to the task's own image when None -- confirm.
    image: Optional[str] = None
    # Resource requests and limits applied to the launcher replica spec.
    requests: Optional[Resources] = None
    limits: Optional[Resources] = None
    # Number of launcher replicas. MPIFunctionTask.__init__ raises ValueError when this is set together
    # with the deprecated `num_launcher_replicas`, or when neither of the two is set.
    replicas: Optional[int] = None
    # Restart behaviour for the launcher replicas; None is passed through as "not set".
    restart_policy: Optional[RestartPolicy] = None


@dataclass
Expand All @@ -94,6 +101,7 @@ class MPIJob(object):
num_launcher_replicas: [DEPRECATED] The number of launcher server replicas to use. This argument is deprecated.
num_workers: [DEPRECATED] The number of worker replicas to spawn in the cluster for this job
"""

launcher: Launcher = field(default_factory=lambda: Launcher())
worker: Worker = field(default_factory=lambda: Worker())
run_policy: Optional[RunPolicy] = field(default_factory=lambda: None)
Expand Down Expand Up @@ -133,21 +141,31 @@ class MPIFunctionTask(PythonFunctionTask[MPIJob]):

def __init__(self, task_config: MPIJob, task_function: Callable, **kwargs):
    """
    Validate the replica settings of ``task_config`` and initialise the task with the MPI task type.

    Worker and launcher replica counts can each be given either through the deprecated top-level
    field (``num_workers`` / ``num_launcher_replicas``) or through the nested replica config
    (``worker.replicas`` / ``launcher.replicas``), but not through both.

    Args:
        task_config: The MPIJob configuration to validate and attach to the task.
        task_function: The user function wrapped by this task.
        **kwargs: Forwarded unchanged to the parent constructor.

    Raises:
        ValueError: If a replica count is given through both the deprecated and the nested field,
            or if it is missing from both.
    """
    if task_config.num_workers and task_config.worker.replicas:
        raise ValueError(
            "Cannot specify both `num_workers` and `worker.replicas`. "
            "Please use `worker.replicas` as `num_workers` is deprecated."
        )
    if task_config.num_workers is None and task_config.worker.replicas is None:
        raise ValueError(
            "Must specify either `num_workers` or `worker.replicas`. "
            "Please use `worker.replicas` as `num_workers` is deprecated."
        )
    # The launcher messages must name `num_launcher_replicas`, the field actually being checked.
    if task_config.num_launcher_replicas and task_config.launcher.replicas:
        raise ValueError(
            "Cannot specify both `num_launcher_replicas` and `launcher.replicas`. "
            "Please use `launcher.replicas` as `num_launcher_replicas` is deprecated."
        )
    if task_config.num_launcher_replicas is None and task_config.launcher.replicas is None:
        raise ValueError(
            "Must specify either `num_launcher_replicas` or `launcher.replicas`. "
            "Please use `launcher.replicas` as `num_launcher_replicas` is deprecated."
        )
    super().__init__(
        task_config=task_config,
        task_function=task_function,
        task_type=self._MPI_JOB_TASK_TYPE,
        **kwargs,
    )

def _convert_replica_spec(self, replica_config: Union[Launcher, Worker]) -> mpi_task.DistributedMPITrainingReplicaSpec:

def _convert_replica_spec(
self, replica_config: Union[Launcher, Worker]
) -> mpi_task.DistributedMPITrainingReplicaSpec:
resources = convert_resources_to_resource_model(requests=replica_config.requests, limits=replica_config.limits)
return mpi_task.DistributedMPITrainingReplicaSpec(
command=replica_config.command,
Expand All @@ -156,15 +174,15 @@ def _convert_replica_spec(self, replica_config: Union[Launcher, Worker]) -> mpi_
resources=resources.to_flyte_idl() if resources else None,
restart_policy=replica_config.restart_policy.value if replica_config.restart_policy else None,
)

def _convert_run_policy(self, run_policy: RunPolicy) -> kubeflow_common.RunPolicy:
    """
    Translate the plugin-level ``RunPolicy`` into a ``kubeflow_common.RunPolicy`` message.

    Args:
        run_policy: The run policy taken from the task config.

    Returns:
        A ``kubeflow_common.RunPolicy`` carrying the same four settings.
    """
    return kubeflow_common.RunPolicy(
        # The enum member wraps the IDL constant in `.value`; an unset policy is passed as None.
        clean_pod_policy=run_policy.clean_pod_policy.value if run_policy.clean_pod_policy else None,
        ttl_seconds_after_finished=run_policy.ttl_seconds_after_finished,
        active_deadline_seconds=run_policy.active_deadline_seconds,
        # Must come from `backoff_limit` itself; reading `active_deadline_seconds` here would
        # silently replace the configured retry count with the deadline.
        backoff_limit=run_policy.backoff_limit,
    )

def _get_base_command(self, settings: SerializationSettings) -> List[str]:
    """
    Return the command that the parent class's ``get_command`` builds for ``settings``.
    """
    parent_command = super().get_command(settings)
    return parent_command

Expand All @@ -181,13 +199,13 @@ def get_command(self, settings: SerializationSettings) -> List[str]:

def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
worker = self._convert_replica_spec(self.task_config.worker)
if (self.task_config.num_workers):
if self.task_config.num_workers:
worker.replicas = self.task_config.num_workers

launcher = self._convert_replica_spec(self.task_config.launcher)
if (self.task_config.num_launcher_replicas):
launcher = self._convert_replica_spec(self.task_config.launcher)
if self.task_config.num_launcher_replicas:
launcher.replicas = self.task_config.num_launcher_replicas

run_policy = self._convert_run_policy(self.task_config.run_policy) if self.task_config.run_policy else None
mpi_job = mpi_task.DistributedMPITrainingTask(
worker_replicas=worker,
Expand Down Expand Up @@ -227,6 +245,7 @@ class HorovodJob(object):
num_launcher_replicas: Optional[int] = None
num_workers: Optional[int] = None


class HorovodFunctionTask(MPIFunctionTask):
"""
For more info, check out https://github.com/horovod/horovod
Expand Down
23 changes: 15 additions & 8 deletions plugins/flytekit-kf-mpi/tests/test_mpi_task.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import pytest
from flytekitplugins.kfmpi import HorovodJob, MPIJob, Worker, Launcher, RunPolicy, RestartPolicy, CleanPodPolicy
from flytekitplugins.kfmpi import CleanPodPolicy, HorovodJob, Launcher, MPIJob, RestartPolicy, RunPolicy, Worker
from flytekitplugins.kfmpi.task import MPIFunctionTask

from flytekit import Resources, task
from flytekit.configuration import Image, ImageConfig, SerializationSettings


@pytest.fixture
def serialization_settings() -> SerializationSettings:
default_img = Image(name="default", fqn="test", tag="tag")
Expand Down Expand Up @@ -32,9 +33,11 @@ def my_mpi_task(x: int, y: str) -> int:

assert my_mpi_task.task_config is not None

default_img = Image(name="default", fqn="test", tag="tag")

assert my_mpi_task.get_custom(serialization_settings) == {'launcherReplicas': {'replicas': 10, 'resources': {}}, 'workerReplicas': {'replicas': 10, 'resources': {}}, "slots": 1}
assert my_mpi_task.get_custom(serialization_settings) == {
"launcherReplicas": {"replicas": 10, "resources": {}},
"workerReplicas": {"replicas": 10, "resources": {}},
"slots": 1,
}
assert my_mpi_task.task_type == "mpi"


Expand All @@ -59,7 +62,9 @@ def my_mpi_task(x: int, y: str) -> int:
assert my_mpi_task.task_type == "mpi"
assert my_mpi_task.resources.limits == Resources()
assert my_mpi_task.resources.requests == Resources(cpu="1")
assert ' '.join(my_mpi_task.get_command(serialization_settings)).startswith(' '.join(MPIFunctionTask._MPI_BASE_COMMAND + ["-np", "1"]))
assert " ".join(my_mpi_task.get_command(serialization_settings)).startswith(
" ".join(MPIFunctionTask._MPI_BASE_COMMAND + ["-np", "1"])
)

expected_dict = {
"launcherReplicas": {
Expand Down Expand Up @@ -111,7 +116,9 @@ def my_mpi_task(x: int, y: str) -> int:
assert my_mpi_task.task_type == "mpi"
assert my_mpi_task.resources.limits == Resources()
assert my_mpi_task.resources.requests == Resources(cpu="1")
assert ' '.join(my_mpi_task.get_command(serialization_settings)).startswith(' '.join(MPIFunctionTask._MPI_BASE_COMMAND + ["-np", "1"]))
assert " ".join(my_mpi_task.get_command(serialization_settings)).startswith(
" ".join(MPIFunctionTask._MPI_BASE_COMMAND + ["-np", "1"])
)

expected_custom_dict = {
"launcherReplicas": {
Expand All @@ -137,7 +144,7 @@ def my_mpi_task(x: int, y: str) -> int:
},
},
"slots": 2,
"runPolicy": {'cleanPodPolicy': 'CLEANPOD_POLICY_ALL'},
"runPolicy": {"cleanPodPolicy": "CLEANPOD_POLICY_ALL"},
}
assert my_mpi_task.get_custom(serialization_settings) == expected_custom_dict

Expand All @@ -164,7 +171,7 @@ def my_horovod_task():
...

cmd = my_horovod_task.get_command(serialization_settings)
assert "horovodrun" in cmd
assert "horovodrun" in cmd
assert "--verbose" not in cmd
assert "--log-level" in cmd
assert "INFO" in cmd
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@
Elastic
"""

from .task import Elastic, PyTorch, Worker, Master, CleanPodPolicy, RunPolicy, RestartPolicy
from .task import CleanPodPolicy, Elastic, Master, PyTorch, RestartPolicy, RunPolicy, Worker
37 changes: 24 additions & 13 deletions plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,28 +3,30 @@
Kubernetes. It leverages `Pytorch Job <https://github.com/kubeflow/pytorch-operator>`_ Plugin from kubeflow.
"""
import os
from enum import Enum
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Callable, Dict, Optional, Union

import cloudpickle
from flyteidl.plugins.kubeflow import pytorch_pb2 as pytorch_task
from flyteidl.plugins.kubeflow import common_pb2 as kubeflow_common
from flyteidl.plugins.kubeflow import pytorch_pb2 as pytorch_task
from google.protobuf.json_format import MessageToDict

import flytekit
from flytekit import PythonFunctionTask, Resources
from flytekit.core.resources import convert_resources_to_resource_model
from flytekit.configuration import SerializationSettings
from flytekit.core.resources import convert_resources_to_resource_model
from flytekit.extend import IgnoreOutputs, TaskPlugins

TORCH_IMPORT_ERROR_MESSAGE = "PyTorch is not installed. Please install `flytekitplugins-kfpytorch['elastic']`."


class RestartPolicy(Enum):
    """
    RestartPolicy describes how the replicas should be restarted
    """

    # NOTE: intentionally a plain Enum. Decorating an Enum with @dataclass is unsupported: the
    # generated field-less __eq__ makes every member compare equal to every other member and
    # removes hashing, so the decorator must not be applied here.
    ALWAYS = kubeflow_common.RESTART_POLICY_ALWAYS
    FAILURE = kubeflow_common.RESTART_POLICY_ON_FAILURE
    NEVER = kubeflow_common.RESTART_POLICY_NEVER
Expand All @@ -35,6 +37,7 @@ class CleanPodPolicy(Enum):
"""
CleanPodPolicy describes how to deal with pods when the job is finished.
"""

NONE = kubeflow_common.CLEANPOD_POLICY_NONE
ALL = kubeflow_common.CLEANPOD_POLICY_ALL
RUNNING = kubeflow_common.CLEANPOD_POLICY_RUNNING
Expand All @@ -52,6 +55,7 @@ class RunPolicy:
where restartPolicy is OnFailure or Always.
backoff_limit (int): Number of retries before marking this job as failed.
"""

clean_pod_policy: CleanPodPolicy = None
ttl_seconds_after_finished: Optional[int] = None
active_deadline_seconds: Optional[int] = None
Expand All @@ -72,6 +76,7 @@ class Master:
"""
Configuration for master replica group. Master should always have 1 replica, so we don't need a `replicas` field
"""

image: Optional[str] = None
requests: Optional[Resources] = None
limits: Optional[Resources] = None
Expand All @@ -90,14 +95,14 @@ class PyTorch(object):
run_policy: Configuration for the run policy.
num_workers: [DEPRECATED] This argument is deprecated. Use `worker.replicas` instead.
"""

master: Master = field(default_factory=lambda: Master())
worker: Worker = field(default_factory=lambda: Worker())
run_policy: Optional[RunPolicy] = field(default_factory=lambda: None)
# Support v0 config for backwards compatibility
num_workers: Optional[int] = None



@dataclass
class Elastic(object):
"""
Expand Down Expand Up @@ -133,17 +138,23 @@ class PyTorchFunctionTask(PythonFunctionTask[PyTorch]):

def __init__(self, task_config: PyTorch, task_function: Callable, **kwargs):
    """
    Validate the worker replica settings of ``task_config`` and initialise the task with the PyTorch task type.

    The worker replica count can be given either through the deprecated ``num_workers`` field or
    through ``worker.replicas``, but not through both.

    Args:
        task_config: The PyTorch configuration to validate and attach to the task.
        task_function: The user function wrapped by this task.
        **kwargs: Forwarded unchanged to the parent constructor.

    Raises:
        ValueError: If the worker replica count is given through both fields, or through neither.
    """
    if task_config.num_workers and task_config.worker.replicas:
        raise ValueError(
            "Cannot specify both `num_workers` and `worker.replicas`. "
            "Please use `worker.replicas` as `num_workers` is deprecated."
        )
    if task_config.num_workers is None and task_config.worker.replicas is None:
        raise ValueError(
            "Must specify either `num_workers` or `worker.replicas`. "
            "Please use `worker.replicas` as `num_workers` is deprecated."
        )
    super().__init__(
        task_config,
        task_function,
        task_type=self._PYTORCH_TASK_TYPE,
        **kwargs,
    )

def _convert_replica_spec(self, replica_config: Union[Master, Worker]) -> pytorch_task.DistributedPyTorchTrainingReplicaSpec:
)

def _convert_replica_spec(
self, replica_config: Union[Master, Worker]
) -> pytorch_task.DistributedPyTorchTrainingReplicaSpec:
resources = convert_resources_to_resource_model(requests=replica_config.requests, limits=replica_config.limits)
replicas = 1
# Master should always have 1 replica
Expand All @@ -155,21 +166,21 @@ def _convert_replica_spec(self, replica_config: Union[Master, Worker]) -> pytorc
resources=resources.to_flyte_idl() if resources else None,
restart_policy=replica_config.restart_policy.value if replica_config.restart_policy else None,
)

def _convert_run_policy(self, run_policy: RunPolicy) -> kubeflow_common.RunPolicy:
    """
    Translate the plugin-level ``RunPolicy`` into a ``kubeflow_common.RunPolicy`` message.

    Args:
        run_policy: The run policy taken from the task config.

    Returns:
        A ``kubeflow_common.RunPolicy`` carrying the same four settings.
    """
    return kubeflow_common.RunPolicy(
        # The enum member wraps the IDL constant in `.value`; an unset policy is passed as None.
        clean_pod_policy=run_policy.clean_pod_policy.value if run_policy.clean_pod_policy else None,
        ttl_seconds_after_finished=run_policy.ttl_seconds_after_finished,
        active_deadline_seconds=run_policy.active_deadline_seconds,
        # Must come from `backoff_limit` itself; reading `active_deadline_seconds` here would
        # silently replace the configured retry count with the deadline.
        backoff_limit=run_policy.backoff_limit,
    )

def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
worker = self._convert_replica_spec(self.task_config.worker)
# support v0 config for backwards compatibility
if (self.task_config.num_workers):
if self.task_config.num_workers:
worker.replicas = self.task_config.num_workers

run_policy = self._convert_run_policy(self.task_config.run_policy) if self.task_config.run_policy else None
pytorch_job = pytorch_task.DistributedPyTorchTrainingTask(
worker_replicas=worker,
Expand Down
Loading

0 comments on commit d14b602

Please sign in to comment.