Skip to content
This repository has been archived by the owner on Aug 17, 2023. It is now read-only.

Adding more optional parameters #52

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 14 additions & 8 deletions fairing/builders/docker_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,24 +25,29 @@
class DockerBuilder(BuilderInterface):
"""A builder using the local Docker client"""

def __init__(self,
repository,
image_name=DEFAULT_IMAGE_NAME,
image_tag=None,
base_image=None,
dockerfile_path=None):
def __init__(self,
repository,
image_name=DEFAULT_IMAGE_NAME,
image_tag=None,
base_image=None,
dockerfile_path=None,
image_pull_policy='Always',
restart_policy='Never'):

self.repository = repository
self.image_name = image_name
self.base_image = base_image
self.dockerfile_path = dockerfile_path
self.image_pull_policy = image_pull_policy
self.restart_policy = restart_policy

if image_tag is None:
self.image_tag = utils.get_unique_tag()
else:
self.image_tag = image_tag

self.full_image_name = utils.get_image_full_name(
self.repository,
self.repository,
self.image_name,
self.image_tag
)
Expand All @@ -55,8 +60,9 @@ def generate_pod_spec(self):
containers=[client.V1Container(
name='model',
image=self.full_image_name,
image_pull_policy=self.image_pull_policy,
)],
restart_policy='Never'
restart_policy=self.restart_policy
)

def execute(self):
Expand Down
12 changes: 7 additions & 5 deletions fairing/training/kubeflow/decorators.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,30 @@

from ..native import decorators
from .deployment import KubeflowDeployment



class Training(decorators.Training):

def __init__(self, namespace=None):
super(Training, self).__init__(namespace)
def __init__(self, namespace=None, job_name=None):
super(Training, self).__init__(namespace, job_name)
self.distribution = {
'Worker': 1
}

def _deploy(self, user_object):
deployment = KubeflowDeployment(
self.namespace,
self.job_name,
self.runs,
self.distribution)
deployment.execute()


class DistributedTraining(Training):

def __init__(self, worker_count=0, ps_count=0, namespace=None):
def __init__(self, worker_count=0, ps_count=0, namespace=None, job_name=None):
# By default we set worker to 0 as we always add a Chief
super(DistributedTraining, self).__init__(namespace)
super(DistributedTraining, self).__init__(namespace, job_name)
self.distribution = {
'Worker': worker_count,
'PS': ps_count,
Expand Down
7 changes: 4 additions & 3 deletions fairing/training/kubeflow/deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@

from ..native import deployment


class KubeflowDeployment(deployment.NativeDeployment):

def __init__(self, namespace, runs, distribution):
super(KubeflowDeployment, self).__init__(namespace, runs)
def __init__(self, namespace, job_name, runs, distribution):
super(KubeflowDeployment, self).__init__(namespace, job_name, runs)
self.distribution = distribution

def deploy(self):
Expand Down Expand Up @@ -50,4 +51,4 @@ def set_container_name(self, pod_template_spec):

def get_logs(self):
selector='tf-replica-index=0,tf-replica-type=worker'
self.backend.log(self.name, self.namespace, selector)
self.backend.log(self.name, self.namespace, selector)
6 changes: 4 additions & 2 deletions fairing/training/native/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,16 @@

logger = logging.getLogger(__name__)


class Training(base.TrainingDecoratorInterface):
"""A simple Kubernetes training.

Args:
namespace {string} -- (optional) here the training should be deployed
"""

def __init__(self, namespace=None):
def __init__(self, namespace=None, job_name=None):
self.job_name = job_name
self.namespace = namespace
self.runs = 1

Expand All @@ -35,7 +37,7 @@ def _train(self, user_object):
runtime.execute(user_object)

def _deploy(self, user_object):
deployment = NativeDeployment(self.namespace, self.runs)
deployment = NativeDeployment(self.namespace, self.job_name, self.runs)
deployment.execute()


Expand Down
9 changes: 6 additions & 3 deletions fairing/training/native/deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@

from fairing import config
from fairing import utils
from fairing.training import base
from fairing.backend import kubernetes

logger = logging.getLogger(__name__)
DEFAULT_JOB_NAME = 'fairing-job'


class NativeDeployment(object):
"""Handle all the k8s' template building for a training
Attributes:
Expand All @@ -25,14 +25,17 @@ class NativeDeployment(object):
will generate multiple jobs.
"""

def __init__(self, namespace, runs):
def __init__(self, namespace, job_name, runs):
if namespace is None:
self.namespace = utils.get_default_target_namespace()
else:
self.namespace = namespace

# Used as pod and job name
self.name = "{}-{}".format(DEFAULT_JOB_NAME, utils.get_unique_tag())
if job_name is None:
job_name = DEFAULT_JOB_NAME

self.name = "{}-{}".format(job_name, utils.get_unique_tag())
self.job_spec = None
self.runs = runs

Expand Down