diff --git a/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/__init__.py b/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/__init__.py index 81a4cbc248..fd12b7192f 100644 --- a/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/__init__.py +++ b/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/__init__.py @@ -10,4 +10,4 @@ TfJob """ -from .task import PS, Chief, CleanPodPolicy, RestartPolicy, RunPolicy, TfJob, Worker +from .task import PS, Chief, CleanPodPolicy, Evaluator, RestartPolicy, RunPolicy, TfJob, Worker diff --git a/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/task.py b/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/task.py index 43a6ad55c3..7be1f7d030 100644 --- a/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/task.py +++ b/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/task.py @@ -85,6 +85,15 @@ class Worker: restart_policy: Optional[RestartPolicy] = None +@dataclass +class Evaluator: + image: Optional[str] = None + requests: Optional[Resources] = None + limits: Optional[Resources] = None + replicas: int = 0 + restart_policy: Optional[RestartPolicy] = None + + @dataclass class TfJob: """ @@ -95,6 +104,7 @@ class TfJob: chief: Configuration for the chief replica group. ps: Configuration for the parameter server (PS) replica group. worker: Configuration for the worker replica group. + evaluator: Configuration for the evaluator replica group. run_policy: Configuration for the run policy. num_workers: [DEPRECATED] This argument is deprecated. Use `worker.replicas` instead. num_ps_replicas: [DEPRECATED] This argument is deprecated. Use `ps.replicas` instead. @@ -104,11 +114,13 @@ class TfJob: chief: Chief = field(default_factory=lambda: Chief()) ps: PS = field(default_factory=lambda: PS()) worker: Worker = field(default_factory=lambda: Worker()) + evaluator: Evaluator = field(default_factory=lambda: Evaluator()) run_policy: Optional[RunPolicy] = field(default_factory=lambda: None) # Support v0 config for backwards compatibility num_workers: Optional[int] = None num_ps_replicas: Optional[int] = None num_chief_replicas: Optional[int] = None + num_evaluator_replicas: Optional[int] = None class TensorflowFunctionTask(PythonFunctionTask[TfJob]): @@ -130,19 +142,23 @@ def __init__(self, task_config: TfJob, task_function: Callable, **kwargs): ) if task_config.num_chief_replicas and task_config.chief.replicas: raise ValueError( - "Cannot specify both `num_workers` and `chief.replicas`. Please use `chief.replicas` as `num_chief_replicas` is depreacated." + "Cannot specify both `num_chief_replicas` and `chief.replicas`. Please use `chief.replicas` as `num_chief_replicas` is depreacated." ) if task_config.num_chief_replicas is None and task_config.chief.replicas is None: raise ValueError( - "Must specify either `num_workers` or `chief.replicas`. Please use `chief.replicas` as `num_chief_replicas` is depreacated." + "Must specify either `num_chief_replicas` or `chief.replicas`. Please use `chief.replicas` as `num_chief_replicas` is depreacated." ) if task_config.num_ps_replicas and task_config.ps.replicas: raise ValueError( - "Cannot specify both `num_workers` and `ps.replicas`. Please use `ps.replicas` as `num_ps_replicas` is depreacated." + "Cannot specify both `num_ps_replicas` and `ps.replicas`. Please use `ps.replicas` as `num_ps_replicas` is depreacated." ) if task_config.num_ps_replicas is None and task_config.ps.replicas is None: raise ValueError( - "Must specify either `num_workers` or `ps.replicas`. Please use `ps.replicas` as `num_ps_replicas` is depreacated." + "Must specify either `num_ps_replicas` or `ps.replicas`. Please use `ps.replicas` as `num_ps_replicas` is depreacated." + ) + if task_config.num_evaluator_replicas and task_config.evaluator.replicas > 0: + raise ValueError( + "Cannot specify both `num_evaluator_replicas` and `evaluator.replicas`. Please use `evaluator.replicas` as `num_evaluator_replicas` is depreacated." ) super().__init__( task_type=self._TF_JOB_TASK_TYPE, @@ -153,7 +169,7 @@ def __init__(self, task_config: TfJob, task_function: Callable, **kwargs): ) def _convert_replica_spec( - self, replica_config: Union[Chief, PS, Worker] + self, replica_config: Union[Chief, PS, Worker, Evaluator] ) -> tensorflow_task.DistributedTensorflowTrainingReplicaSpec: resources = convert_resources_to_resource_model(requests=replica_config.requests, limits=replica_config.limits) return tensorflow_task.DistributedTensorflowTrainingReplicaSpec( @@ -184,11 +200,16 @@ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: if self.task_config.num_ps_replicas: ps.replicas = self.task_config.num_ps_replicas + evaluator = self._convert_replica_spec(self.task_config.evaluator) + if self.task_config.num_evaluator_replicas: + evaluator.replicas = self.task_config.num_evaluator_replicas + run_policy = self._convert_run_policy(self.task_config.run_policy) if self.task_config.run_policy else None training_task = tensorflow_task.DistributedTensorflowTrainingTask( chief_replicas=chief, worker_replicas=worker, ps_replicas=ps, + evaluator_replicas=evaluator, run_policy=run_policy, ) diff --git a/plugins/flytekit-kf-tensorflow/requirements.txt b/plugins/flytekit-kf-tensorflow/requirements.txt index 8f67a26831..96552b075f 100644 --- a/plugins/flytekit-kf-tensorflow/requirements.txt +++ b/plugins/flytekit-kf-tensorflow/requirements.txt @@ -1,5 +1,5 @@ # -# This file is autogenerated by pip-compile with Python 3.8 +# This file is autogenerated by pip-compile with Python 3.9 # by the following command: # # pip-compile requirements.in @@ -75,6 +75,7 @@ cryptography==39.0.2 # msal # pyjwt # pyopenssl + # secretstorage dataclasses-json==0.5.7 # via flytekit decorator==5.1.1 @@ -89,8 +90,10 @@ docker-image-py==0.1.12 # via flytekit docstring-parser==0.15 # via flytekit -flyteidl==1.5.5 - # via flytekit +flyteidl==1.10.0 + # via + # flytekit + # flytekitplugins-kftensorflow flytekit==1.6.1 # via flytekitplugins-kftensorflow frozenlist==1.3.3 @@ -151,12 +154,14 @@ importlib-metadata==6.1.0 # via # flytekit # keyring -importlib-resources==5.12.0 - # via keyring isodate==0.6.1 # via azure-storage-blob jaraco-classes==3.2.3 # via keyring +jeepney==0.8.0 + # via + # keyring + # secretstorage jinja2==3.1.2 # via # cookiecutter @@ -240,7 +245,9 @@ pycparser==2.21 pygments==2.15.1 # via rich pyjwt[crypto]==2.7.0 - # via msal + # via + # msal + # pyjwt pyopenssl==23.0.0 # via flytekit python-dateutil==2.8.2 @@ -299,6 +306,8 @@ rsa==4.9 # via google-auth s3fs==2023.5.0 # via flytekit +secretstorage==3.3.3 + # via keyring six==1.16.0 # via # azure-core @@ -323,7 +332,6 @@ typing-extensions==4.5.0 # azure-core # azure-storage-blob # flytekit - # rich # typing-inspect typing-inspect==0.8.0 # via dataclasses-json @@ -350,9 +358,7 @@ wrapt==1.15.0 yarl==1.9.2 # via aiohttp zipp==3.15.0 - # via - # importlib-metadata - # importlib-resources + # via importlib-metadata # The following packages are considered to be unsafe in a requirements file: # setuptools diff --git a/plugins/flytekit-kf-tensorflow/setup.py b/plugins/flytekit-kf-tensorflow/setup.py index 79c1ade31d..25ffe19eec 100644 --- a/plugins/flytekit-kf-tensorflow/setup.py +++ b/plugins/flytekit-kf-tensorflow/setup.py @@ -4,7 +4,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" -plugin_requires = ["flytekit>=1.6.1"] +plugin_requires = ["flyteidl>=1.10.0", "flytekit>=1.6.1"] __version__ = "0.0.0+develop" diff --git a/plugins/flytekit-kf-tensorflow/tests/test_tensorflow_task.py b/plugins/flytekit-kf-tensorflow/tests/test_tensorflow_task.py index a6acae1760..0ae32439d7 100644 --- a/plugins/flytekit-kf-tensorflow/tests/test_tensorflow_task.py +++ b/plugins/flytekit-kf-tensorflow/tests/test_tensorflow_task.py @@ -1,5 +1,5 @@ import pytest -from flytekitplugins.kftensorflow import PS, Chief, CleanPodPolicy, RestartPolicy, RunPolicy, TfJob, Worker +from flytekitplugins.kftensorflow import PS, Chief, CleanPodPolicy, Evaluator, RestartPolicy, RunPolicy, TfJob, Worker from flytekit import Resources, task from flytekit.configuration import Image, ImageConfig, SerializationSettings @@ -23,6 +23,7 @@ def test_tensorflow_task_with_default_config(serialization_settings: Serializati worker=Worker(replicas=1), chief=Chief(replicas=0), ps=PS(replicas=0), + evaluator=Evaluator(replicas=0), ) @task( @@ -52,6 +53,9 @@ def my_tensorflow_task(x: int, y: str) -> int: "psReplicas": { "resources": {}, }, + "evaluatorReplicas": { + "resources": {}, + }, } assert my_tensorflow_task.get_custom(serialization_settings) == expected_dict @@ -75,6 +79,13 @@ def test_tensorflow_task_with_custom_config(serialization_settings: Serializatio replicas=2, restart_policy=RestartPolicy.ALWAYS, ), + evaluator=Evaluator( + replicas=5, + requests=Resources(cpu="2", mem="2Gi"), + limits=Resources(cpu="4", mem="2Gi"), + image="evaluator:latest", + restart_policy=RestartPolicy.FAILURE, + ), ) @task( @@ -122,7 +133,23 @@ def my_tensorflow_task(x: int, y: str) -> int: "replicas": 2, "restartPolicy": "RESTART_POLICY_ALWAYS", }, + "evaluatorReplicas": { + "replicas": 5, + "image": "evaluator:latest", + "resources": { + "requests": [ + {"name": "CPU", "value": "2"}, + {"name": "MEMORY", "value": "2Gi"}, + ], + "limits": [ + {"name": "CPU", "value": "4"}, + {"name": "MEMORY", "value": "2Gi"}, + ], + }, + "restartPolicy": "RESTART_POLICY_ON_FAILURE", + }, } + assert my_tensorflow_task.get_custom(serialization_settings) == expected_custom_dict @@ -131,6 +158,7 @@ def test_tensorflow_task_with_run_policy(serialization_settings: SerializationSe worker=Worker(replicas=1), ps=PS(replicas=0), chief=Chief(replicas=0), + evaluator=Evaluator(replicas=0), run_policy=RunPolicy( clean_pod_policy=CleanPodPolicy.RUNNING, backoff_limit=5, @@ -166,6 +194,9 @@ def my_tensorflow_task(x: int, y: str) -> int: "psReplicas": { "resources": {}, }, + "evaluatorReplicas": { + "resources": {}, + }, "runPolicy": { "cleanPodPolicy": "CLEANPOD_POLICY_RUNNING", "backoffLimit": 5, @@ -173,12 +204,13 @@ def my_tensorflow_task(x: int, y: str) -> int: "ttlSecondsAfterFinished": 100, }, } + assert my_tensorflow_task.get_custom(serialization_settings) == expected_dict def test_tensorflow_task(): @task( - task_config=TfJob(num_workers=10, num_ps_replicas=1, num_chief_replicas=1), + task_config=TfJob(num_workers=10, num_ps_replicas=1, num_chief_replicas=1, num_evaluator_replicas=1), cache=True, requests=Resources(cpu="1"), cache_version="1", @@ -212,7 +244,12 @@ def my_tensorflow_task(x: int, y: str) -> int: "replicas": 1, "resources": {}, }, + "evaluatorReplicas": { + "replicas": 1, + "resources": {}, + }, } + assert my_tensorflow_task.get_custom(settings) == expected_dict assert my_tensorflow_task.resources.limits == Resources() assert my_tensorflow_task.resources.requests == Resources(cpu="1")