Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Kubeflow TensorFlow Training Operator Add Evaluator #1870

Merged
merged 31 commits into from
Nov 8, 2023
Merged
Show file tree
Hide file tree
Changes from 30 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
540eb5d
Merge branch 'master' of https://github.com/Future-Outlier/flytekit
Aug 12, 2023
984d44a
Merge branch 'master' of https://github.com/Future-Outlier/flytekit
Aug 16, 2023
3d936fc
Merge branch 'master' of https://github.com/Future-Outlier/flytekit
Aug 30, 2023
931533a
Merge branch 'master' of https://github.com/Future-Outlier/flytekit
Sep 7, 2023
bd5dbd7
Merge branch 'master' of https://github.com/Future-Outlier/flytekit
Sep 13, 2023
aeb5ea1
Merge branch 'master' of https://github.com/Future-Outlier/flytekit
Sep 13, 2023
54d2ddf
Merge branch 'flyteorg:master' into master
Future-Outlier Sep 17, 2023
541edc6
Merge branch 'flyteorg:master' into master
Future-Outlier Sep 19, 2023
75573ab
Merge branch 'flyteorg:master' into master
Future-Outlier Sep 23, 2023
c0139db
Merge pull request #19 from flyteorg/master
Future-Outlier Sep 28, 2023
24df7b3
Merge branch 'flyteorg:master' into master
Future-Outlier Oct 3, 2023
563ca22
Merge branch 'flyteorg:master' into master
Future-Outlier Oct 4, 2023
730e24a
kf-tensorflow-operator-evaluator
Oct 4, 2023
919d71b
give evaluator default value
Oct 6, 2023
5fa4f18
Merge branch 'flyteorg:master' into master
Future-Outlier Oct 6, 2023
7466bb0
Merge branch 'flyteorg:master' into kf-operator-evaluator
Future-Outlier Oct 6, 2023
5f1183b
Merge branch 'flyteorg:master' into master
Future-Outlier Oct 6, 2023
b33a7e4
update plugin_requires for ImageSpec test
Oct 6, 2023
e7d4d61
Merge branch 'kf-operator-evaluator' of https://github.com/Future-Out…
Oct 6, 2023
53981db
Merge branch 'master' of https://github.com/Future-Outlier/flytekit i…
Oct 6, 2023
c310a23
update sortedcontainers
Oct 6, 2023
ee498f9
test sortedcontainers
Oct 6, 2023
f4b03c4
update sortedcontainers in setup.py
Oct 6, 2023
203d2d4
Merge branch 'flyteorg:master' into master
Future-Outlier Oct 7, 2023
9302d9c
Merge branch 'master' of https://github.com/Future-Outlier/flytekit i…
Oct 7, 2023
98ddd54
update idl version
Oct 7, 2023
fbfea59
make evaluator default 0 replicas
Oct 7, 2023
75885ac
update flytekit dependency
Oct 8, 2023
f66f78d
Merge branch 'master' of https://github.com/Future-Outlier/flytekit i…
Nov 5, 2023
0f6f69f
update version
Nov 5, 2023
7a0f729
update idl requirements version
Nov 6, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,4 @@
TfJob
"""

from .task import PS, Chief, CleanPodPolicy, RestartPolicy, RunPolicy, TfJob, Worker
from .task import PS, Chief, CleanPodPolicy, Evaluator, RestartPolicy, RunPolicy, TfJob, Worker
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,15 @@ class Worker:
restart_policy: Optional[RestartPolicy] = None


@dataclass
class Evaluator:
image: Optional[str] = None
requests: Optional[Resources] = None
limits: Optional[Resources] = None
replicas: int = 0
restart_policy: Optional[RestartPolicy] = None


@dataclass
class TfJob:
"""
Expand All @@ -95,6 +104,7 @@ class TfJob:
chief: Configuration for the chief replica group.
ps: Configuration for the parameter server (PS) replica group.
worker: Configuration for the worker replica group.
evaluator: Configuration for the evaluator replica group.
run_policy: Configuration for the run policy.
num_workers: [DEPRECATED] This argument is deprecated. Use `worker.replicas` instead.
num_ps_replicas: [DEPRECATED] This argument is deprecated. Use `ps.replicas` instead.
Expand All @@ -104,11 +114,13 @@ class TfJob:
chief: Chief = field(default_factory=lambda: Chief())
ps: PS = field(default_factory=lambda: PS())
worker: Worker = field(default_factory=lambda: Worker())
evaluator: Evaluator = field(default_factory=lambda: Evaluator())
run_policy: Optional[RunPolicy] = field(default_factory=lambda: None)
# Support v0 config for backwards compatibility
num_workers: Optional[int] = None
num_ps_replicas: Optional[int] = None
num_chief_replicas: Optional[int] = None
num_evaluator_replicas: Optional[int] = None


class TensorflowFunctionTask(PythonFunctionTask[TfJob]):
Expand All @@ -130,19 +142,23 @@ def __init__(self, task_config: TfJob, task_function: Callable, **kwargs):
)
if task_config.num_chief_replicas and task_config.chief.replicas:
raise ValueError(
"Cannot specify both `num_workers` and `chief.replicas`. Please use `chief.replicas` as `num_chief_replicas` is depreacated."
"Cannot specify both `num_chief_replicas` and `chief.replicas`. Please use `chief.replicas` as `num_chief_replicas` is depreacated."
)
if task_config.num_chief_replicas is None and task_config.chief.replicas is None:
raise ValueError(
"Must specify either `num_workers` or `chief.replicas`. Please use `chief.replicas` as `num_chief_replicas` is depreacated."
"Must specify either `num_chief_replicas` or `chief.replicas`. Please use `chief.replicas` as `num_chief_replicas` is depreacated."
)
if task_config.num_ps_replicas and task_config.ps.replicas:
raise ValueError(
"Cannot specify both `num_workers` and `ps.replicas`. Please use `ps.replicas` as `num_ps_replicas` is depreacated."
"Cannot specify both `num_ps_replicas` and `ps.replicas`. Please use `ps.replicas` as `num_ps_replicas` is depreacated."
)
if task_config.num_ps_replicas is None and task_config.ps.replicas is None:
raise ValueError(
"Must specify either `num_workers` or `ps.replicas`. Please use `ps.replicas` as `num_ps_replicas` is depreacated."
"Must specify either `num_ps_replicas` or `ps.replicas`. Please use `ps.replicas` as `num_ps_replicas` is depreacated."
)
if task_config.num_evaluator_replicas and task_config.evaluator.replicas > 0:
raise ValueError(
"Cannot specify both `num_evaluator_replicas` and `evaluator.replicas`. Please use `evaluator.replicas` as `num_evaluator_replicas` is depreacated."
)
super().__init__(
task_type=self._TF_JOB_TASK_TYPE,
Expand All @@ -153,7 +169,7 @@ def __init__(self, task_config: TfJob, task_function: Callable, **kwargs):
)

def _convert_replica_spec(
self, replica_config: Union[Chief, PS, Worker]
self, replica_config: Union[Chief, PS, Worker, Evaluator]
) -> tensorflow_task.DistributedTensorflowTrainingReplicaSpec:
resources = convert_resources_to_resource_model(requests=replica_config.requests, limits=replica_config.limits)
return tensorflow_task.DistributedTensorflowTrainingReplicaSpec(
Expand Down Expand Up @@ -184,11 +200,16 @@ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
if self.task_config.num_ps_replicas:
ps.replicas = self.task_config.num_ps_replicas

evaluator = self._convert_replica_spec(self.task_config.evaluator)
if self.task_config.num_evaluator_replicas:
evaluator.replicas = self.task_config.num_evaluator_replicas

run_policy = self._convert_run_policy(self.task_config.run_policy) if self.task_config.run_policy else None
training_task = tensorflow_task.DistributedTensorflowTrainingTask(
chief_replicas=chief,
worker_replicas=worker,
ps_replicas=ps,
evaluator_replicas=evaluator,
run_policy=run_policy,
)

Expand Down
2 changes: 1 addition & 1 deletion plugins/flytekit-kf-tensorflow/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

microlib_name = f"flytekitplugins-{PLUGIN_NAME}"

plugin_requires = ["flytekit>=1.6.1"]
plugin_requires = ["flyteidl>=1.10.0", "flytekit>=1.6.1"]

__version__ = "0.0.0+develop"

Expand Down
41 changes: 39 additions & 2 deletions plugins/flytekit-kf-tensorflow/tests/test_tensorflow_task.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import pytest
from flytekitplugins.kftensorflow import PS, Chief, CleanPodPolicy, RestartPolicy, RunPolicy, TfJob, Worker
from flytekitplugins.kftensorflow import PS, Chief, CleanPodPolicy, Evaluator, RestartPolicy, RunPolicy, TfJob, Worker

from flytekit import Resources, task
from flytekit.configuration import Image, ImageConfig, SerializationSettings
Expand All @@ -23,6 +23,7 @@ def test_tensorflow_task_with_default_config(serialization_settings: Serializati
worker=Worker(replicas=1),
chief=Chief(replicas=0),
ps=PS(replicas=0),
evaluator=Evaluator(replicas=0),
)

@task(
Expand Down Expand Up @@ -52,6 +53,9 @@ def my_tensorflow_task(x: int, y: str) -> int:
"psReplicas": {
"resources": {},
},
"evaluatorReplicas": {
"resources": {},
},
}
assert my_tensorflow_task.get_custom(serialization_settings) == expected_dict

Expand All @@ -75,6 +79,13 @@ def test_tensorflow_task_with_custom_config(serialization_settings: Serializatio
replicas=2,
restart_policy=RestartPolicy.ALWAYS,
),
evaluator=Evaluator(
replicas=5,
requests=Resources(cpu="2", mem="2Gi"),
limits=Resources(cpu="4", mem="2Gi"),
image="evaluator:latest",
restart_policy=RestartPolicy.FAILURE,
),
)

@task(
Expand Down Expand Up @@ -122,7 +133,23 @@ def my_tensorflow_task(x: int, y: str) -> int:
"replicas": 2,
"restartPolicy": "RESTART_POLICY_ALWAYS",
},
"evaluatorReplicas": {
"replicas": 5,
"image": "evaluator:latest",
"resources": {
"requests": [
{"name": "CPU", "value": "2"},
{"name": "MEMORY", "value": "2Gi"},
],
"limits": [
{"name": "CPU", "value": "4"},
{"name": "MEMORY", "value": "2Gi"},
],
},
"restartPolicy": "RESTART_POLICY_ON_FAILURE",
},
}

assert my_tensorflow_task.get_custom(serialization_settings) == expected_custom_dict


Expand All @@ -131,6 +158,7 @@ def test_tensorflow_task_with_run_policy(serialization_settings: SerializationSe
worker=Worker(replicas=1),
ps=PS(replicas=0),
chief=Chief(replicas=0),
evaluator=Evaluator(replicas=0),
run_policy=RunPolicy(
clean_pod_policy=CleanPodPolicy.RUNNING,
backoff_limit=5,
Expand Down Expand Up @@ -166,19 +194,23 @@ def my_tensorflow_task(x: int, y: str) -> int:
"psReplicas": {
"resources": {},
},
"evaluatorReplicas": {
"resources": {},
},
"runPolicy": {
"cleanPodPolicy": "CLEANPOD_POLICY_RUNNING",
"backoffLimit": 5,
"activeDeadlineSeconds": 100,
"ttlSecondsAfterFinished": 100,
},
}

assert my_tensorflow_task.get_custom(serialization_settings) == expected_dict


def test_tensorflow_task():
@task(
task_config=TfJob(num_workers=10, num_ps_replicas=1, num_chief_replicas=1),
task_config=TfJob(num_workers=10, num_ps_replicas=1, num_chief_replicas=1, num_evaluator_replicas=1),
cache=True,
requests=Resources(cpu="1"),
cache_version="1",
Expand Down Expand Up @@ -212,7 +244,12 @@ def my_tensorflow_task(x: int, y: str) -> int:
"replicas": 1,
"resources": {},
},
"evaluatorReplicas": {
"replicas": 1,
"resources": {},
},
}

assert my_tensorflow_task.get_custom(settings) == expected_dict
assert my_tensorflow_task.resources.limits == Resources()
assert my_tensorflow_task.resources.requests == Resources(cpu="1")
Expand Down
Loading