diff --git a/fairing/builders/docker_builder.py b/fairing/builders/docker_builder.py index e7abf072..1b68ea12 100644 --- a/fairing/builders/docker_builder.py +++ b/fairing/builders/docker_builder.py @@ -25,24 +25,29 @@ class DockerBuilder(BuilderInterface): """A builder using the local Docker client""" - def __init__(self, - repository, - image_name=DEFAULT_IMAGE_NAME, - image_tag=None, - base_image=None, - dockerfile_path=None): + def __init__(self, + repository, + image_name=DEFAULT_IMAGE_NAME, + image_tag=None, + base_image=None, + dockerfile_path=None, + image_pull_policy='Always', + restart_policy='Never'): self.repository = repository self.image_name = image_name self.base_image = base_image self.dockerfile_path = dockerfile_path + self.image_pull_policy = image_pull_policy + self.restart_policy = restart_policy if image_tag is None: self.image_tag = utils.get_unique_tag() else: self.image_tag = image_tag + self.full_image_name = utils.get_image_full_name( - self.repository, + self.repository, self.image_name, self.image_tag ) @@ -55,8 +60,9 @@ def generate_pod_spec(self): containers=[client.V1Container( name='model', image=self.full_image_name, + image_pull_policy=self.image_pull_policy, )], - restart_policy='Never' + restart_policy=self.restart_policy ) def execute(self): diff --git a/fairing/training/kubeflow/decorators.py b/fairing/training/kubeflow/decorators.py index 995757d9..f145b47f 100644 --- a/fairing/training/kubeflow/decorators.py +++ b/fairing/training/kubeflow/decorators.py @@ -1,11 +1,12 @@ from ..native import decorators from .deployment import KubeflowDeployment - + + class Training(decorators.Training): - def __init__(self, namespace=None): - super(Training, self).__init__(namespace) + def __init__(self, namespace=None, job_name=None): + super(Training, self).__init__(namespace, job_name) self.distribution = { 'Worker': 1 } @@ -13,6 +14,7 @@ def __init__(self, namespace=None): def _deploy(self, user_object): deployment = KubeflowDeployment( self.namespace, + self.job_name, self.runs, self.distribution) deployment.execute() @@ -20,9 +22,9 @@ def _deploy(self, user_object): class DistributedTraining(Training): - def __init__(self, worker_count=0, ps_count=0, namespace=None): + def __init__(self, worker_count=0, ps_count=0, namespace=None, job_name=None): # By default we set worker to 0 as we always add a Chief - super(DistributedTraining, self).__init__(namespace) + super(DistributedTraining, self).__init__(namespace, job_name) self.distribution = { 'Worker': worker_count, 'PS': ps_count, diff --git a/fairing/training/kubeflow/deployment.py b/fairing/training/kubeflow/deployment.py index dbde3fca..3b772e0c 100644 --- a/fairing/training/kubeflow/deployment.py +++ b/fairing/training/kubeflow/deployment.py @@ -2,10 +2,11 @@ from ..native import deployment + class KubeflowDeployment(deployment.NativeDeployment): - def __init__(self, namespace, runs, distribution): - super(KubeflowDeployment, self).__init__(namespace, runs) + def __init__(self, namespace, job_name, runs, distribution): + super(KubeflowDeployment, self).__init__(namespace, job_name, runs) self.distribution = distribution def deploy(self): @@ -50,4 +51,4 @@ def set_container_name(self, pod_template_spec): def get_logs(self): selector='tf-replica-index=0,tf-replica-type=worker' - self.backend.log(self.name, self.namespace, selector) \ No newline at end of file + self.backend.log(self.name, self.namespace, selector) diff --git a/fairing/training/native/decorators.py b/fairing/training/native/decorators.py index 5bd7f15b..cc7c2e1e 100644 --- a/fairing/training/native/decorators.py +++ b/fairing/training/native/decorators.py @@ -14,6 +14,7 @@ logger = logging.getLogger(__name__) + class Training(base.TrainingDecoratorInterface): """A simple Kubernetes training. @@ -21,7 +22,8 @@ class Training(base.TrainingDecoratorInterface): namespace {string} -- (optional) here the training should be deployed """ - def __init__(self, namespace=None): + def __init__(self, namespace=None, job_name=None): + self.job_name = job_name self.namespace = namespace self.runs = 1 @@ -35,7 +37,7 @@ def _train(self, user_object): runtime.execute(user_object) def _deploy(self, user_object): - deployment = NativeDeployment(self.namespace, self.runs) + deployment = NativeDeployment(self.namespace, self.job_name, self.runs) deployment.execute() diff --git a/fairing/training/native/deployment.py b/fairing/training/native/deployment.py index d2e971c4..47efdd83 100644 --- a/fairing/training/native/deployment.py +++ b/fairing/training/native/deployment.py @@ -10,12 +10,12 @@ from fairing import config from fairing import utils -from fairing.training import base from fairing.backend import kubernetes logger = logging.getLogger(__name__) DEFAULT_JOB_NAME = 'fairing-job' + class NativeDeployment(object): """Handle all the k8s' template building for a training Attributes: @@ -25,14 +25,17 @@ class NativeDeployment(object): will generate multiple jobs. """ - def __init__(self, namespace, runs): + def __init__(self, namespace, job_name, runs): if namespace is None: self.namespace = utils.get_default_target_namespace() else: self.namespace = namespace # Used as pod and job name - self.name = "{}-{}".format(DEFAULT_JOB_NAME, utils.get_unique_tag()) + if job_name is None: + job_name = DEFAULT_JOB_NAME + + self.name = "{}-{}".format(job_name, utils.get_unique_tag()) self.job_spec = None self.runs = runs