From d14b602f961337ab91e0845f456363885f4f8dc5 Mon Sep 17 00:00:00 2001 From: Yubo Wang Date: Tue, 16 May 2023 22:06:12 -0700 Subject: [PATCH] lint and fmt Signed-off-by: Yubo Wang --- plugins/flytekit-kf-mpi/README.md | 2 +- .../flytekitplugins/kfmpi/__init__.py | 2 +- .../flytekitplugins/kfmpi/task.py | 59 ++++++++++++------- .../flytekit-kf-mpi/tests/test_mpi_task.py | 23 +++++--- .../flytekitplugins/kfpytorch/__init__.py | 2 +- .../flytekitplugins/kfpytorch/task.py | 37 ++++++++---- .../tests/test_pytorch_task.py | 22 ++----- plugins/flytekit-kf-tensorflow/README.md | 2 +- .../flytekitplugins/kftensorflow/__init__.py | 2 +- .../flytekitplugins/kftensorflow/task.py | 53 +++++++++++------ .../tests/test_tensorflow_task.py | 5 +- 11 files changed, 128 insertions(+), 81 deletions(-) diff --git a/plugins/flytekit-kf-mpi/README.md b/plugins/flytekit-kf-mpi/README.md index f083f2f0a2..db475868eb 100644 --- a/plugins/flytekit-kf-mpi/README.md +++ b/plugins/flytekit-kf-mpi/README.md @@ -71,4 +71,4 @@ To migrate from v0 to v1, change the following: to ``` task_config=MPIJob(worker=Worker(replicas=10)), -``` \ No newline at end of file +``` diff --git a/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/__init__.py b/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/__init__.py index 873d3d47bd..7d2107c8ae 100644 --- a/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/__init__.py +++ b/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/__init__.py @@ -10,4 +10,4 @@ MPIJob """ -from .task import HorovodJob, MPIJob, Worker, Launcher, CleanPodPolicy, RunPolicy, RestartPolicy +from .task import CleanPodPolicy, HorovodJob, Launcher, MPIJob, RestartPolicy, RunPolicy, Worker diff --git a/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py b/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py index 005c0da828..20179c7376 100644 --- a/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py +++ b/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py @@ -4,16 +4,17 @@ """ from dataclasses import dataclass, field from enum import Enum -from typing import Any, Callable, Dict, Optional, Union, List +from typing import Any, Callable, Dict, List, Optional, Union +from flyteidl.plugins.kubeflow import common_pb2 as kubeflow_common from flyteidl.plugins.kubeflow import mpi_pb2 as mpi_task from google.protobuf.json_format import MessageToDict from flytekit import PythonFunctionTask, Resources -from flytekit.core.resources import convert_resources_to_resource_model from flytekit.configuration import SerializationSettings +from flytekit.core.resources import convert_resources_to_resource_model from flytekit.extend import TaskPlugins -from flyteidl.plugins.kubeflow import common_pb2 as kubeflow_common + @dataclass class RestartPolicy(Enum): @@ -25,6 +26,7 @@ class RestartPolicy(Enum): FAILURE = kubeflow_common.RESTART_POLICY_ON_FAILURE NEVER = kubeflow_common.RESTART_POLICY_NEVER + @dataclass class CleanPodPolicy(Enum): """ @@ -35,6 +37,7 @@ class CleanPodPolicy(Enum): ALL = kubeflow_common.CLEANPOD_POLICY_ALL RUNNING = kubeflow_common.CLEANPOD_POLICY_RUNNING + @dataclass class RunPolicy: """ @@ -47,37 +50,41 @@ class RunPolicy: where restartPolicy is OnFailure or Always. backoff_limit (int): Number of retries before marking this job as failed. """ + clean_pod_policy: CleanPodPolicy = None ttl_seconds_after_finished: Optional[int] = None active_deadline_seconds: Optional[int] = None backoff_limit: Optional[int] = None + @dataclass class Worker: """ Worker replica configuration. Worker command can be customized. If not specified, the worker will use default command generated by the mpi operator. """ - command : Optional[List[str]] = None + + command: Optional[List[str]] = None image: Optional[str] = None requests: Optional[Resources] = None limits: Optional[Resources] = None replicas: Optional[int] = None restart_policy: Optional[RestartPolicy] = None - - + + @dataclass class Launcher: """ Launcher replica configuration. Launcher command can be customized. If not specified, the launcher will use the command specified in the task signature. """ - command : Optional[List[str]] = None + + command: Optional[List[str]] = None image: Optional[str] = None requests: Optional[Resources] = None limits: Optional[Resources] = None replicas: Optional[int] = None - restart_policy: Optional[RestartPolicy] = None + restart_policy: Optional[RestartPolicy] = None @dataclass @@ -94,6 +101,7 @@ class MPIJob(object): num_launcher_replicas: [DEPRECATED] The number of launcher server replicas to use. This argument is deprecated. num_workers: [DEPRECATED] The number of worker replicas to spawn in the cluster for this job """ + launcher: Launcher = field(default_factory=lambda: Launcher()) worker: Worker = field(default_factory=lambda: Worker()) run_policy: Optional[RunPolicy] = field(default_factory=lambda: None) @@ -133,21 +141,31 @@ class MPIFunctionTask(PythonFunctionTask[MPIJob]): def __init__(self, task_config: MPIJob, task_function: Callable, **kwargs): if task_config.num_workers and task_config.worker.replicas: - raise ValueError("Cannot specify both `num_workers` and `worker.replicas`. Please use `worker.replicas` as `num_workers` is depreacated.") + raise ValueError( + "Cannot specify both `num_workers` and `worker.replicas`. Please use `worker.replicas` as `num_workers` is depreacated." + ) if task_config.num_workers is None and task_config.worker.replicas is None: - raise ValueError("Must specify either `num_workers` or `worker.replicas`. Please use `worker.replicas` as `num_workers` is depreacated.") + raise ValueError( + "Must specify either `num_workers` or `worker.replicas`. Please use `worker.replicas` as `num_workers` is depreacated." + ) if task_config.num_launcher_replicas and task_config.launcher.replicas: - raise ValueError("Cannot specify both `num_workers` and `launcher.replicas`. Please use `launcher.replicas` as `num_launcher_replicas` is depreacated.") + raise ValueError( + "Cannot specify both `num_workers` and `launcher.replicas`. Please use `launcher.replicas` as `num_launcher_replicas` is depreacated." + ) if task_config.num_launcher_replicas is None and task_config.launcher.replicas is None: - raise ValueError("Must specify either `num_workers` or `launcher.replicas`. Please use `launcher.replicas` as `num_launcher_replicas` is depreacated.") + raise ValueError( + "Must specify either `num_workers` or `launcher.replicas`. Please use `launcher.replicas` as `num_launcher_replicas` is depreacated." + ) super().__init__( task_config=task_config, task_function=task_function, task_type=self._MPI_JOB_TASK_TYPE, **kwargs, ) - - def _convert_replica_spec(self, replica_config: Union[Launcher, Worker]) -> mpi_task.DistributedMPITrainingReplicaSpec: + + def _convert_replica_spec( + self, replica_config: Union[Launcher, Worker] + ) -> mpi_task.DistributedMPITrainingReplicaSpec: resources = convert_resources_to_resource_model(requests=replica_config.requests, limits=replica_config.limits) return mpi_task.DistributedMPITrainingReplicaSpec( command=replica_config.command, @@ -156,7 +174,7 @@ def _convert_replica_spec(self, replica_config: Union[Launcher, Worker]) -> mpi_ resources=resources.to_flyte_idl() if resources else None, restart_policy=replica_config.restart_policy.value if replica_config.restart_policy else None, ) - + def _convert_run_policy(self, run_policy: RunPolicy) -> kubeflow_common.RunPolicy: return kubeflow_common.RunPolicy( clean_pod_policy=run_policy.clean_pod_policy.value if run_policy.clean_pod_policy else None, @@ -164,7 +182,7 @@ def _convert_run_policy(self, run_policy: RunPolicy) -> kubeflow_common.RunPolic active_deadline_seconds=run_policy.active_deadline_seconds, backoff_limit=run_policy.active_deadline_seconds, ) - + def _get_base_command(self, settings: SerializationSettings) -> List[str]: return super().get_command(settings) @@ -181,13 +199,13 @@ def get_command(self, settings: SerializationSettings) -> List[str]: def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: worker = self._convert_replica_spec(self.task_config.worker) - if (self.task_config.num_workers): + if self.task_config.num_workers: worker.replicas = self.task_config.num_workers - launcher = self._convert_replica_spec(self.task_config.launcher) - if (self.task_config.num_launcher_replicas): + launcher = self._convert_replica_spec(self.task_config.launcher) + if self.task_config.num_launcher_replicas: launcher.replicas = self.task_config.num_launcher_replicas - + run_policy = self._convert_run_policy(self.task_config.run_policy) if self.task_config.run_policy else None mpi_job = mpi_task.DistributedMPITrainingTask( worker_replicas=worker, @@ -227,6 +245,7 @@ class HorovodJob(object): num_launcher_replicas: Optional[int] = None num_workers: Optional[int] = None + class HorovodFunctionTask(MPIFunctionTask): """ For more info, check out https://github.com/horovod/horovod diff --git a/plugins/flytekit-kf-mpi/tests/test_mpi_task.py b/plugins/flytekit-kf-mpi/tests/test_mpi_task.py index 3641e46019..f6eb2655f6 100644 --- a/plugins/flytekit-kf-mpi/tests/test_mpi_task.py +++ b/plugins/flytekit-kf-mpi/tests/test_mpi_task.py @@ -1,10 +1,11 @@ import pytest -from flytekitplugins.kfmpi import HorovodJob, MPIJob, Worker, Launcher, RunPolicy, RestartPolicy, CleanPodPolicy +from flytekitplugins.kfmpi import CleanPodPolicy, HorovodJob, Launcher, MPIJob, RestartPolicy, RunPolicy, Worker from flytekitplugins.kfmpi.task import MPIFunctionTask from flytekit import Resources, task from flytekit.configuration import Image, ImageConfig, SerializationSettings + @pytest.fixture def serialization_settings() -> SerializationSettings: default_img = Image(name="default", fqn="test", tag="tag") @@ -32,9 +33,11 @@ def my_mpi_task(x: int, y: str) -> int: assert my_mpi_task.task_config is not None - default_img = Image(name="default", fqn="test", tag="tag") - - assert my_mpi_task.get_custom(serialization_settings) == {'launcherReplicas': {'replicas': 10, 'resources': {}}, 'workerReplicas': {'replicas': 10, 'resources': {}}, "slots": 1} + assert my_mpi_task.get_custom(serialization_settings) == { + "launcherReplicas": {"replicas": 10, "resources": {}}, + "workerReplicas": {"replicas": 10, "resources": {}}, + "slots": 1, + } assert my_mpi_task.task_type == "mpi" @@ -59,7 +62,9 @@ def my_mpi_task(x: int, y: str) -> int: assert my_mpi_task.task_type == "mpi" assert my_mpi_task.resources.limits == Resources() assert my_mpi_task.resources.requests == Resources(cpu="1") - assert ' '.join(my_mpi_task.get_command(serialization_settings)).startswith(' '.join(MPIFunctionTask._MPI_BASE_COMMAND + ["-np", "1"])) + assert " ".join(my_mpi_task.get_command(serialization_settings)).startswith( + " ".join(MPIFunctionTask._MPI_BASE_COMMAND + ["-np", "1"]) + ) expected_dict = { "launcherReplicas": { @@ -111,7 +116,9 @@ def my_mpi_task(x: int, y: str) -> int: assert my_mpi_task.task_type == "mpi" assert my_mpi_task.resources.limits == Resources() assert my_mpi_task.resources.requests == Resources(cpu="1") - assert ' '.join(my_mpi_task.get_command(serialization_settings)).startswith(' '.join(MPIFunctionTask._MPI_BASE_COMMAND + ["-np", "1"])) + assert " ".join(my_mpi_task.get_command(serialization_settings)).startswith( + " ".join(MPIFunctionTask._MPI_BASE_COMMAND + ["-np", "1"]) + ) expected_custom_dict = { "launcherReplicas": { @@ -137,7 +144,7 @@ def my_mpi_task(x: int, y: str) -> int: }, }, "slots": 2, - "runPolicy": {'cleanPodPolicy': 'CLEANPOD_POLICY_ALL'}, + "runPolicy": {"cleanPodPolicy": "CLEANPOD_POLICY_ALL"}, } assert my_mpi_task.get_custom(serialization_settings) == expected_custom_dict @@ -164,7 +171,7 @@ def my_horovod_task(): ... cmd = my_horovod_task.get_command(serialization_settings) - assert "horovodrun" in cmd + assert "horovodrun" in cmd assert "--verbose" not in cmd assert "--log-level" in cmd assert "INFO" in cmd diff --git a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/__init__.py b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/__init__.py index 66e8a5c2fd..d56e1f83d9 100644 --- a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/__init__.py +++ b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/__init__.py @@ -11,4 +11,4 @@ Elastic """ -from .task import Elastic, PyTorch, Worker, Master, CleanPodPolicy, RunPolicy, RestartPolicy +from .task import CleanPodPolicy, Elastic, Master, PyTorch, RestartPolicy, RunPolicy, Worker diff --git a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py index a9a0732d28..6625263db1 100644 --- a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py +++ b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py @@ -3,28 +3,30 @@ Kubernetes. It leverages `Pytorch Job `_ Plugin from kubeflow. """ import os -from enum import Enum from dataclasses import dataclass, field +from enum import Enum from typing import Any, Callable, Dict, Optional, Union import cloudpickle -from flyteidl.plugins.kubeflow import pytorch_pb2 as pytorch_task from flyteidl.plugins.kubeflow import common_pb2 as kubeflow_common +from flyteidl.plugins.kubeflow import pytorch_pb2 as pytorch_task from google.protobuf.json_format import MessageToDict import flytekit from flytekit import PythonFunctionTask, Resources -from flytekit.core.resources import convert_resources_to_resource_model from flytekit.configuration import SerializationSettings +from flytekit.core.resources import convert_resources_to_resource_model from flytekit.extend import IgnoreOutputs, TaskPlugins TORCH_IMPORT_ERROR_MESSAGE = "PyTorch is not installed. Please install `flytekitplugins-kfpytorch['elastic']`." + @dataclass class RestartPolicy(Enum): """ RestartPolicy describes how the replicas should be restarted """ + ALWAYS = kubeflow_common.RESTART_POLICY_ALWAYS FAILURE = kubeflow_common.RESTART_POLICY_ON_FAILURE NEVER = kubeflow_common.RESTART_POLICY_NEVER @@ -35,6 +37,7 @@ class CleanPodPolicy(Enum): """ CleanPodPolicy describes how to deal with pods when the job is finished. """ + NONE = kubeflow_common.CLEANPOD_POLICY_NONE ALL = kubeflow_common.CLEANPOD_POLICY_ALL RUNNING = kubeflow_common.CLEANPOD_POLICY_RUNNING @@ -52,6 +55,7 @@ class RunPolicy: where restartPolicy is OnFailure or Always. backoff_limit (int): Number of retries before marking this job as failed. """ + clean_pod_policy: CleanPodPolicy = None ttl_seconds_after_finished: Optional[int] = None active_deadline_seconds: Optional[int] = None @@ -72,6 +76,7 @@ class Master: """ Configuration for master replica group. Master should always have 1 replica, so we don't need a `replicas` field """ + image: Optional[str] = None requests: Optional[Resources] = None limits: Optional[Resources] = None @@ -90,6 +95,7 @@ class PyTorch(object): run_policy: Configuration for the run policy. num_workers: [DEPRECATED] This argument is deprecated. Use `worker.replicas` instead. """ + master: Master = field(default_factory=lambda: Master()) worker: Worker = field(default_factory=lambda: Worker()) run_policy: Optional[RunPolicy] = field(default_factory=lambda: None) @@ -97,7 +103,6 @@ class PyTorch(object): num_workers: Optional[int] = None - @dataclass class Elastic(object): """ @@ -133,17 +138,23 @@ class PyTorchFunctionTask(PythonFunctionTask[PyTorch]): def __init__(self, task_config: PyTorch, task_function: Callable, **kwargs): if task_config.num_workers and task_config.worker.replicas: - raise ValueError("Cannot specify both `num_workers` and `worker.replicas`. Please use `worker.replicas` as `num_workers` is depreacated.") + raise ValueError( + "Cannot specify both `num_workers` and `worker.replicas`. Please use `worker.replicas` as `num_workers` is depreacated." + ) if task_config.num_workers is None and task_config.worker.replicas is None: - raise ValueError("Must specify either `num_workers` or `worker.replicas`. Please use `worker.replicas` as `num_workers` is depreacated.") + raise ValueError( + "Must specify either `num_workers` or `worker.replicas`. Please use `worker.replicas` as `num_workers` is depreacated." + ) super().__init__( task_config, task_function, task_type=self._PYTORCH_TASK_TYPE, **kwargs, - ) - - def _convert_replica_spec(self, replica_config: Union[Master, Worker]) -> pytorch_task.DistributedPyTorchTrainingReplicaSpec: + ) + + def _convert_replica_spec( + self, replica_config: Union[Master, Worker] + ) -> pytorch_task.DistributedPyTorchTrainingReplicaSpec: resources = convert_resources_to_resource_model(requests=replica_config.requests, limits=replica_config.limits) replicas = 1 # Master should always have 1 replica @@ -155,7 +166,7 @@ def _convert_replica_spec(self, replica_config: Union[Master, Worker]) -> pytorc resources=resources.to_flyte_idl() if resources else None, restart_policy=replica_config.restart_policy.value if replica_config.restart_policy else None, ) - + def _convert_run_policy(self, run_policy: RunPolicy) -> kubeflow_common.RunPolicy: return kubeflow_common.RunPolicy( clean_pod_policy=run_policy.clean_pod_policy.value if run_policy.clean_pod_policy else None, @@ -163,13 +174,13 @@ def _convert_run_policy(self, run_policy: RunPolicy) -> kubeflow_common.RunPolic active_deadline_seconds=run_policy.active_deadline_seconds, backoff_limit=run_policy.active_deadline_seconds, ) - + def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: worker = self._convert_replica_spec(self.task_config.worker) # support v0 config for backwards compatibility - if (self.task_config.num_workers): + if self.task_config.num_workers: worker.replicas = self.task_config.num_workers - + run_policy = self._convert_run_policy(self.task_config.run_policy) if self.task_config.run_policy else None pytorch_job = pytorch_task.DistributedPyTorchTrainingTask( worker_replicas=worker, diff --git a/plugins/flytekit-kf-pytorch/tests/test_pytorch_task.py b/plugins/flytekit-kf-pytorch/tests/test_pytorch_task.py index b828caf8d0..ecdf9e375c 100644 --- a/plugins/flytekit-kf-pytorch/tests/test_pytorch_task.py +++ b/plugins/flytekit-kf-pytorch/tests/test_pytorch_task.py @@ -1,9 +1,10 @@ import pytest +from flytekitplugins.kfpytorch.task import Master, PyTorch, RestartPolicy, Worker -from flytekitplugins.kfpytorch.task import PyTorch, Worker, Master, RestartPolicy from flytekit import Resources, task from flytekit.configuration import Image, ImageConfig, SerializationSettings + @pytest.fixture def serialization_settings() -> SerializationSettings: default_img = Image(name="default", fqn="test", tag="tag") @@ -31,26 +32,15 @@ def my_pytorch_task(x: int, y: str) -> int: assert my_pytorch_task.task_config is not None - default_img = Image(name="default", fqn="test", tag="tag") - settings = serialization_settings - - assert my_pytorch_task.get_custom(settings) == { - "workerReplicas": { - "replicas": 10, - "resources": {} - }, - "masterReplicas": { - "replicas": 1, - "resources": {} - }, + assert my_pytorch_task.get_custom(serialization_settings) == { + "workerReplicas": {"replicas": 10, "resources": {}}, + "masterReplicas": {"replicas": 1, "resources": {}}, } assert my_pytorch_task.resources.limits == Resources() assert my_pytorch_task.resources.requests == Resources(cpu="1") assert my_pytorch_task.task_type == "pytorch" - - def test_pytorch_task_with_default_config(serialization_settings: SerializationSettings): task_config = PyTorch(worker=Worker(replicas=1)) @@ -135,4 +125,4 @@ def my_pytorch_task(x: int, y: str) -> int: "restartPolicy": "RESTART_POLICY_ALWAYS", }, } - assert my_pytorch_task.get_custom(serialization_settings) == expected_custom_dict \ No newline at end of file + assert my_pytorch_task.get_custom(serialization_settings) == expected_custom_dict diff --git a/plugins/flytekit-kf-tensorflow/README.md b/plugins/flytekit-kf-tensorflow/README.md index bc9d886659..d059624f03 100644 --- a/plugins/flytekit-kf-tensorflow/README.md +++ b/plugins/flytekit-kf-tensorflow/README.md @@ -9,7 +9,7 @@ pip install flytekitplugins-kftensorflow ``` ## Code Example -To build a TFJob with: +To build a TFJob with: 10 workers with restart policy as failed and 2 CPU and 2Gi Memory 1 ps replica with resources the same as task defined resources 1 chief replica with resources the same as task defined resources and restart policy as always diff --git a/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/__init__.py b/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/__init__.py index 16623fb79c..81a4cbc248 100644 --- a/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/__init__.py +++ b/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/__init__.py @@ -10,4 +10,4 @@ TfJob """ -from .task import PS, Chief, RunPolicy, TfJob, Worker, RestartPolicy, CleanPodPolicy +from .task import PS, Chief, CleanPodPolicy, RestartPolicy, RunPolicy, TfJob, Worker diff --git a/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/task.py b/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/task.py index 84f1da51d6..bd6a97a293 100644 --- a/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/task.py +++ b/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/task.py @@ -6,14 +6,15 @@ from enum import Enum from typing import Any, Callable, Dict, Optional, Union +from flyteidl.plugins.kubeflow import common_pb2 as kubeflow_common +from flyteidl.plugins.kubeflow import tensorflow_pb2 as tensorflow_task from google.protobuf.json_format import MessageToDict from flytekit import PythonFunctionTask, Resources from flytekit.configuration import SerializationSettings from flytekit.core.resources import convert_resources_to_resource_model from flytekit.extend import TaskPlugins -from flyteidl.plugins.kubeflow import tensorflow_pb2 as tensorflow_task -from flyteidl.plugins.kubeflow import common_pb2 as kubeflow_common + @dataclass class RestartPolicy(Enum): @@ -25,6 +26,7 @@ class RestartPolicy(Enum): FAILURE = kubeflow_common.RESTART_POLICY_ON_FAILURE NEVER = kubeflow_common.RESTART_POLICY_NEVER + @dataclass class CleanPodPolicy(Enum): """ @@ -35,6 +37,7 @@ class CleanPodPolicy(Enum): ALL = kubeflow_common.CLEANPOD_POLICY_ALL RUNNING = kubeflow_common.CLEANPOD_POLICY_RUNNING + @dataclass class RunPolicy: """ @@ -48,6 +51,7 @@ class RunPolicy: where restartPolicy is OnFailure or Always. backoff_limit: The number of retries before marking this job as failed. """ + clean_pod_policy: CleanPodPolicy = None ttl_seconds_after_finished: Optional[int] = None active_deadline_seconds: Optional[int] = None @@ -96,6 +100,7 @@ class TfJob: num_ps_replicas: [DEPRECATED] This argument is deprecated. Use `ps.replicas` instead. num_chief_replicas: [DEPRECATED] This argument is deprecated. Use `chief.replicas` instead. """ + chief: Chief = field(default_factory=lambda: Chief()) ps: PS = field(default_factory=lambda: PS()) worker: Worker = field(default_factory=lambda: Worker()) @@ -116,17 +121,29 @@ class TensorflowFunctionTask(PythonFunctionTask[TfJob]): def __init__(self, task_config: TfJob, task_function: Callable, **kwargs): if task_config.num_workers and task_config.worker.replicas: - raise ValueError("Cannot specify both `num_workers` and `worker.replicas`. Please use `worker.replicas` as `num_workers` is depreacated.") + raise ValueError( + "Cannot specify both `num_workers` and `worker.replicas`. Please use `worker.replicas` as `num_workers` is depreacated." + ) if task_config.num_workers is None and task_config.worker.replicas is None: - raise ValueError("Must specify either `num_workers` or `worker.replicas`. Please use `worker.replicas` as `num_workers` is depreacated.") + raise ValueError( + "Must specify either `num_workers` or `worker.replicas`. Please use `worker.replicas` as `num_workers` is depreacated." + ) if task_config.num_chief_replicas and task_config.chief.replicas: - raise ValueError("Cannot specify both `num_workers` and `chief.replicas`. Please use `chief.replicas` as `num_chief_replicas` is depreacated.") + raise ValueError( + "Cannot specify both `num_workers` and `chief.replicas`. Please use `chief.replicas` as `num_chief_replicas` is depreacated." + ) if task_config.num_chief_replicas is None and task_config.chief.replicas is None: - raise ValueError("Must specify either `num_workers` or `chief.replicas`. Please use `chief.replicas` as `num_chief_replicas` is depreacated.") + raise ValueError( + "Must specify either `num_workers` or `chief.replicas`. Please use `chief.replicas` as `num_chief_replicas` is depreacated." + ) if task_config.num_ps_replicas and task_config.ps.replicas: - raise ValueError("Cannot specify both `num_workers` and `ps.replicas`. Please use `ps.replicas` as `num_ps_replicas` is depreacated.") + raise ValueError( + "Cannot specify both `num_workers` and `ps.replicas`. Please use `ps.replicas` as `num_ps_replicas` is depreacated." + ) if task_config.num_ps_replicas is None and task_config.ps.replicas is None: - raise ValueError("Must specify either `num_workers` or `ps.replicas`. Please use `ps.replicas` as `num_ps_replicas` is depreacated.") + raise ValueError( + "Must specify either `num_workers` or `ps.replicas`. Please use `ps.replicas` as `num_ps_replicas` is depreacated." + ) super().__init__( task_type=self._TF_JOB_TASK_TYPE, task_config=task_config, @@ -134,8 +151,10 @@ def __init__(self, task_config: TfJob, task_function: Callable, **kwargs): task_type_version=1, **kwargs, ) - - def _convert_replica_spec(self, replica_config: Union[Chief, PS, Worker]) -> tensorflow_task.DistributedTensorflowTrainingReplicaSpec: + + def _convert_replica_spec( + self, replica_config: Union[Chief, PS, Worker] + ) -> tensorflow_task.DistributedTensorflowTrainingReplicaSpec: resources = convert_resources_to_resource_model(requests=replica_config.requests, limits=replica_config.limits) return tensorflow_task.DistributedTensorflowTrainingReplicaSpec( replicas=replica_config.replicas, @@ -143,7 +162,7 @@ def _convert_replica_spec(self, replica_config: Union[Chief, PS, Worker]) -> ten resources=resources.to_flyte_idl() if resources else None, restart_policy=replica_config.restart_policy.value if replica_config.restart_policy else None, ) - + def _convert_run_policy(self, run_policy: RunPolicy) -> kubeflow_common.RunPolicy: return kubeflow_common.RunPolicy( clean_pod_policy=run_policy.clean_pod_policy.value if run_policy.clean_pod_policy.value else None, @@ -154,17 +173,17 @@ def _convert_run_policy(self, run_policy: RunPolicy) -> kubeflow_common.RunPolic def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: chief = self._convert_replica_spec(self.task_config.chief) - if (self.task_config.num_chief_replicas): + if self.task_config.num_chief_replicas: chief.replicas = self.task_config.num_chief_replicas worker = self._convert_replica_spec(self.task_config.worker) - if (self.task_config.num_workers): + if self.task_config.num_workers: worker.replicas = self.task_config.num_workers - + ps = self._convert_replica_spec(self.task_config.ps) - if (self.task_config.num_ps_replicas): + if self.task_config.num_ps_replicas: ps.replicas = self.task_config.num_ps_replicas - + run_policy = self._convert_run_policy(self.task_config.run_policy) if self.task_config.run_policy else None training_task = tensorflow_task.DistributedTensorflowTrainingTask( chief_replicas=chief, @@ -172,7 +191,7 @@ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: ps_replicas=ps, run_policy=run_policy, ) - + return MessageToDict(training_task) diff --git a/plugins/flytekit-kf-tensorflow/tests/test_tensorflow_task.py b/plugins/flytekit-kf-tensorflow/tests/test_tensorflow_task.py index 2790fc4a9b..d863d3fdc4 100644 --- a/plugins/flytekit-kf-tensorflow/tests/test_tensorflow_task.py +++ b/plugins/flytekit-kf-tensorflow/tests/test_tensorflow_task.py @@ -131,7 +131,7 @@ def test_tensorflow_task_with_run_policy(serialization_settings: SerializationSe worker=Worker(replicas=1), ps=PS(replicas=0), chief=Chief(replicas=0), - run_policy=RunPolicy(clean_pod_policy=CleanPodPolicy.RUNNING) + run_policy=RunPolicy(clean_pod_policy=CleanPodPolicy.RUNNING), ) @task( @@ -167,6 +167,7 @@ def my_tensorflow_task(x: int, y: str) -> int: } assert my_tensorflow_task.get_custom(serialization_settings) == expected_dict + def test_tensorflow_task(): @task( task_config=TfJob(num_workers=10, num_ps_replicas=1, num_chief_replicas=1), @@ -207,4 +208,4 @@ def my_tensorflow_task(x: int, y: str) -> int: assert my_tensorflow_task.get_custom(settings) == expected_dict assert my_tensorflow_task.resources.limits == Resources() assert my_tensorflow_task.resources.requests == Resources(cpu="1") - assert my_tensorflow_task.task_type == "tensorflow" \ No newline at end of file + assert my_tensorflow_task.task_type == "tensorflow"