From 57a59d6bbf306b9e140c552d13532104dc7fc719 Mon Sep 17 00:00:00 2001 From: Hamed Date: Wed, 9 Jan 2019 19:01:43 +0000 Subject: [PATCH] avoid launching both chief and worker jobs when running a non distributed job (#50) --- fairing/training/kubeflow/decorators.py | 3 ++- fairing/training/kubeflow/deployment.py | 7 ++++--- setup.py | 2 +- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/fairing/training/kubeflow/decorators.py b/fairing/training/kubeflow/decorators.py index c394449a..995757d9 100644 --- a/fairing/training/kubeflow/decorators.py +++ b/fairing/training/kubeflow/decorators.py @@ -25,6 +25,7 @@ def __init__(self, worker_count=0, ps_count=0, namespace=None): super(DistributedTraining, self).__init__(namespace) self.distribution = { 'Worker': worker_count, - 'PS': ps_count + 'PS': ps_count, + 'Chief': 1 } diff --git a/fairing/training/kubeflow/deployment.py b/fairing/training/kubeflow/deployment.py index 47d66f0e..dbde3fca 100644 --- a/fairing/training/kubeflow/deployment.py +++ b/fairing/training/kubeflow/deployment.py @@ -20,17 +20,18 @@ def generate_job(self, pod_template_spec): worker_replica_spec['template'] = pod_template_spec ps_replica_spec = {} - ps_replica_spec['replicas'] = self.distribution['PS'] + ps_replica_spec['replicas'] = self.distribution.get('PS', 0) ps_replica_spec['template'] = pod_template_spec chief_replica_spec = {} - chief_replica_spec['replicas'] = 1 + chief_replica_spec['replicas'] = self.distribution.get('Chief', 0) chief_replica_spec['template'] = pod_template_spec spec = {} spec['tfReplicaSpecs'] = {} - spec['tfReplicaSpecs']['Chief'] = chief_replica_spec spec['tfReplicaSpecs']['Worker'] = worker_replica_spec + if chief_replica_spec['replicas'] > 0: + spec['tfReplicaSpecs']['Chief'] = chief_replica_spec if ps_replica_spec['replicas'] > 0: spec['tfReplicaSpecs']['PS'] = ps_replica_spec diff --git a/setup.py b/setup.py index a86d4c87..63064c61 100644 --- a/setup.py +++ b/setup.py @@ -29,7 +29,7 @@ 'notebook==5.6.0', 'jupyter==1.0.0', 'numpy==1.15.0', - 'kubernetes==6.0.0', + 'kubernetes==8.0.1', 'future==0.17.1', 'six==1.11.0', 'httplib2==0.12.0',