Skip to content
This repository has been archived by the owner on Aug 17, 2023. It is now read-only.

Commit

Permalink
Add more optional parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
hamedhsn committed Jan 10, 2019
1 parent 57a59d6 commit b83a739
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 20 deletions.
22 changes: 14 additions & 8 deletions fairing/builders/docker_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,24 +25,29 @@
class DockerBuilder(BuilderInterface):
"""A builder using the local Docker client"""

def __init__(self,
repository,
image_name=DEFAULT_IMAGE_NAME,
image_tag=None,
base_image=None,
dockerfile_path=None):
def __init__(self,
repository,
image_name=DEFAULT_IMAGE_NAME,
image_tag=None,
base_image=None,
dockerfile_path=None,
image_pull_policy='Always',
restart_policy='Never'):

self.repository = repository
self.image_name = image_name
self.base_image = base_image
self.dockerfile_path = dockerfile_path
self.image_pull_policy = image_pull_policy
self.restart_policy = restart_policy

if image_tag is None:
self.image_tag = utils.get_unique_tag()
else:
self.image_tag = image_tag

self.full_image_name = utils.get_image_full_name(
self.repository,
self.repository,
self.image_name,
self.image_tag
)
Expand All @@ -55,8 +60,9 @@ def generate_pod_spec(self):
containers=[client.V1Container(
name='model',
image=self.full_image_name,
image_pull_policy=self.image_pull_policy,
)],
restart_policy='Never'
restart_policy=self.restart_policy
)

def execute(self):
Expand Down
12 changes: 7 additions & 5 deletions fairing/training/kubeflow/decorators.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,30 @@

from ..native import decorators
from .deployment import KubeflowDeployment



class Training(decorators.Training):

def __init__(self, namespace=None):
super(Training, self).__init__(namespace)
def __init__(self, namespace=None, job_name=None):
super(Training, self).__init__(namespace, job_name)
self.distribution = {
'Worker': 1
}

def _deploy(self, user_object):
deployment = KubeflowDeployment(
self.namespace,
self.job_name,
self.runs,
self.distribution)
deployment.execute()


class DistributedTraining(Training):

def __init__(self, worker_count=0, ps_count=0, namespace=None):
def __init__(self, worker_count=0, ps_count=0, namespace=None, job_name=None):
# By default we set worker to 0 as we always add a Chief
super(DistributedTraining, self).__init__(namespace)
super(DistributedTraining, self).__init__(namespace, job_name)
self.distribution = {
'Worker': worker_count,
'PS': ps_count,
Expand Down
7 changes: 4 additions & 3 deletions fairing/training/kubeflow/deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@

from ..native import deployment


class KubeflowDeployment(deployment.NativeDeployment):

def __init__(self, namespace, runs, distribution):
super(KubeflowDeployment, self).__init__(namespace, runs)
def __init__(self, namespace, job_name, runs, distribution):
super(KubeflowDeployment, self).__init__(namespace, job_name, runs)
self.distribution = distribution

def deploy(self):
Expand Down Expand Up @@ -50,4 +51,4 @@ def set_container_name(self, pod_template_spec):

def get_logs(self):
selector='tf-replica-index=0,tf-replica-type=worker'
self.backend.log(self.name, self.namespace, selector)
self.backend.log(self.name, self.namespace, selector)
6 changes: 4 additions & 2 deletions fairing/training/native/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,16 @@

logger = logging.getLogger(__name__)


class Training(base.TrainingDecoratorInterface):
"""A simple Kubernetes training.
Args:
namespace {string} -- (optional) here the training should be deployed
"""

def __init__(self, namespace=None):
def __init__(self, namespace=None, job_name=None):
self.job_name = job_name
self.namespace = namespace
self.runs = 1

Expand All @@ -35,7 +37,7 @@ def _train(self, user_object):
runtime.execute(user_object)

def _deploy(self, user_object):
deployment = NativeDeployment(self.namespace, self.runs)
deployment = NativeDeployment(self.namespace, self.job_name, self.runs)
deployment.execute()


Expand Down
7 changes: 5 additions & 2 deletions fairing/training/native/deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,17 @@ class NativeDeployment(object):
will generate multiple jobs.
"""

def __init__(self, namespace, runs):
def __init__(self, namespace, job_name, runs):
if namespace is None:
self.namespace = utils.get_default_target_namespace()
else:
self.namespace = namespace

# Used as pod and job name
self.name = "{}-{}".format(DEFAULT_JOB_NAME, utils.get_unique_tag())
if job_name is None:
job_name = DEFAULT_JOB_NAME

self.name = "{}-{}".format(job_name, utils.get_unique_tag())
self.job_spec = None
self.runs = runs

Expand Down

0 comments on commit b83a739

Please sign in to comment.