-
Notifications
You must be signed in to change notification settings - Fork 710
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Parameter Server: Run TF server by default #36
Changes from 16 commits
95ddfa8
69ac7cc
c9f2d66
de8f8f4
8082db3
7fa80cb
8b233ef
d5907df
3f4f176
c104f0e
188e5cd
c6f047d
7176834
c8f00b5
9f24996
9a62ee9
2c32e0b
bb3fa53
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,163 @@ | ||
#!/usr/bin/python | ||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
""" | ||
TODO: Once grpc_tensorflow_server.py is included in the tensorflow | ||
docker image, we should use it instead. | ||
|
||
Python-based TensorFlow GRPC server. | ||
|
||
Takes input arguments cluster_spec, job_name and task_id, and starts a blocking | ||
TensorFlow GRPC server. | ||
|
||
Usage: | ||
grpc_tensorflow_server.py --cluster_spec=SPEC --job_name=NAME --task_id=ID | ||
|
||
Where: | ||
SPEC is <JOB>(,<JOB>)* | ||
JOB is <NAME>|<HOST:PORT>(;<HOST:PORT>)* | ||
NAME is a valid job name ([a-z][0-9a-z]*) | ||
HOST is a hostname or IP address | ||
PORT is a port number | ||
""" | ||
|
||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
|
||
import argparse | ||
import sys | ||
|
||
from tensorflow.core.protobuf import config_pb2 | ||
from tensorflow.core.protobuf import tensorflow_server_pb2 | ||
from tensorflow.python.platform import app | ||
from tensorflow.python.training import server_lib | ||
|
||
|
||
def parse_cluster_spec(cluster_spec, cluster, verbose=False):
  """Parse a cluster_spec string and inject its jobs into a cluster protobuf.

  Each job string is validated in full before anything is added to `cluster`,
  so a malformed job never leaves an empty or half-filled job entry behind.
  Jobs that precede the malformed one have already been added by the time the
  error is raised.

  Args:
    cluster_spec: cluster specification string, e.g.,
      "local|localhost:2222;localhost:2223"
    cluster: cluster protobuf. Only `cluster.job.add()` is used; the returned
      job must accept a `name` attribute and index assignment on `tasks`.
    verbose: If verbose logging is requested.

  Raises:
    ValueError: if the cluster_spec string is empty (or None) or invalid.
  """
  # Check emptiness before splitting so that None/"" is reported as a
  # ValueError rather than failing inside str.split.
  if not cluster_spec:
    raise ValueError("Empty cluster_spec string")

  for job_string in cluster_spec.split(","):
    if job_string.count("|") != 1:
      raise ValueError("Not exactly one instance of '|' in cluster_spec")

    # Exactly one '|' is guaranteed above, so this unpacks into two parts.
    job_name, task_list = job_string.split("|")
    if not job_name:
      raise ValueError("Empty job_name in cluster_spec")

    job_tasks = task_list.split(";")
    for i, task in enumerate(job_tasks):
      if not task:
        raise ValueError("Empty task string at position %d" % i)

    # The job string is known to be well-formed; only now touch the protobuf.
    job_def = cluster.job.add()
    job_def.name = job_name

    if verbose:
      print("Added job named \"%s\"" % job_name)

    for i, task in enumerate(job_tasks):
      job_def.tasks[i] = task

      if verbose:
        print(" Added task \"%s\" to job \"%s\"" % (task, job_name))
|
||
|
||
def main(unused_args):
  """Build a ServerDef from the parsed FLAGS and run a blocking GRPC server.

  Args:
    unused_args: ignored; the body never reads it.

  Raises:
    ValueError: if FLAGS.job_name is empty or FLAGS.task_id is negative
      (errors raised by parse_cluster_spec propagate unchanged).
  """
  # Protobuf ServerDef describing this server; the protocol is fixed to grpc.
  server_def = tensorflow_server_pb2.ServerDef(protocol="grpc")

  # Populate the cluster section first, so a bad spec is reported before the
  # job name and task index are checked.
  parse_cluster_spec(FLAGS.cluster_spec, server_def.cluster, FLAGS.verbose)

  # Job name must be non-empty.
  job_name = FLAGS.job_name
  if not job_name:
    raise ValueError("Empty job_name")
  server_def.job_name = job_name

  # Task index must be non-negative.
  task_id = FLAGS.task_id
  if task_id < 0:
    raise ValueError("Invalid task_id: %d" % task_id)
  server_def.task_index = task_id

  # Session config carrying the requested per-process GPU memory fraction.
  gpu_options = config_pb2.GPUOptions(
      per_process_gpu_memory_fraction=FLAGS.gpu_memory_fraction)
  session_config = config_pb2.ConfigProto(gpu_options=gpu_options)

  # Create the GRPC server instance and block on it.
  # join() is blocking, unlike start()
  grpc_server = server_lib.Server(server_def, config=session_config)
  grpc_server.join()
|
||
|
||
if __name__ == "__main__":
  # Command-line entry point: define the flags, parse the ones we know into
  # the module-level FLAGS, and hand everything else to app.run.
  parser = argparse.ArgumentParser()
  # Custom "bool" type: only the string "true" (case-insensitive) is True.
  parser.register("type", "bool", lambda v: v.lower() == "true")
  parser.add_argument(
      "--cluster_spec",
      type=str,
      default="",
      # The stray double-quote characters that used to appear in this help
      # text have been removed; the wording is otherwise unchanged.
      help="""\
      Cluster spec: SPEC. SPEC is <JOB>(,<JOB>)*, JOB is
      <NAME>|<HOST:PORT>(;<HOST:PORT>)*, NAME is a valid job name
      ([a-z][0-9a-z]*), HOST is a hostname or IP address, PORT is a
      port number. E.g., local|localhost:2222;localhost:2223,
      ps|ps0:2222;ps1:2222\
      """
  )
  parser.add_argument(
      "--job_name",
      type=str,
      default="",
      help="Job name: e.g., local"
  )
  parser.add_argument(
      "--task_id",
      type=int,
      default=0,
      help="Task index, e.g., 0"
  )
  parser.add_argument(
      "--gpu_memory_fraction",
      type=float,
      default=1.0,
      help="Fraction of GPU memory allocated",)
  parser.add_argument(
      "--verbose",
      type="bool",
      # nargs="?" with const=True lets a bare --verbose switch it on, while
      # --verbose=<value> goes through the "bool" type registered above.
      nargs="?",
      const=True,
      default=False,
      help="Verbose mode"
  )

  # Unrecognized arguments are forwarded to app.run after the program name.
  FLAGS, unparsed = parser.parse_known_args()
  app.run(main=main, argv=[sys.argv[0]] + unparsed)
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,24 +6,25 @@ import ( | |
"errors" | ||
"fmt" | ||
"io" | ||
"k8s.io/client-go/kubernetes" | ||
"github.com/jlewi/mlkube.io/pkg/spec" | ||
"github.com/jlewi/mlkube.io/pkg/trainer" | ||
"github.com/jlewi/mlkube.io/pkg/util/k8sutil" | ||
"net/http" | ||
"reflect" | ||
"sync" | ||
"time" | ||
|
||
apierrors "k8s.io/apimachinery/pkg/api/errors" | ||
apiextensionsclient "k8s.io/apiextensions-apiserver/pkg/client/clientset/clientset" | ||
v1beta1 "k8s.io/apiextensions-apiserver/pkg/apis/apiextensions/v1beta1" | ||
"github.com/jlewi/mlkube.io/pkg/spec" | ||
"github.com/jlewi/mlkube.io/pkg/trainer" | ||
"github.com/jlewi/mlkube.io/pkg/util/k8sutil" | ||
"k8s.io/client-go/kubernetes" | ||
|
||
log "github.com/golang/glog" | ||
"github.com/jlewi/mlkube.io/pkg/util" | ||
v1beta1 "k8s.io/apiextensions-apiserver/pkg/apis/apiextensions/v1beta1" | ||
apiextensionsclient "k8s.io/apiextensions-apiserver/pkg/client/clientset/clientset" | ||
apierrors "k8s.io/apimachinery/pkg/api/errors" | ||
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" | ||
kwatch "k8s.io/apimachinery/pkg/watch" | ||
k8sErrors "k8s.io/apimachinery/pkg/util/errors" | ||
"k8s.io/apimachinery/pkg/util/wait" | ||
"github.com/jlewi/mlkube.io/pkg/util" | ||
kwatch "k8s.io/apimachinery/pkg/watch" | ||
) | ||
|
||
var ( | ||
|
@@ -66,7 +67,7 @@ func New(kubeCli kubernetes.Interface, apiCli apiextensionsclient.Interface, tfJ | |
return &Controller{ | ||
Namespace: ns, | ||
KubeCli: kubeCli, | ||
ApiCli: apiCli, | ||
ApiCli: apiCli, | ||
TfJobClient: tfJobClient, | ||
// TODO(jlewi)): What to do about cluster.Cluster? | ||
jobs: make(map[string]*trainer.TrainingJob), | ||
|
@@ -147,7 +148,7 @@ func (c *Controller) handleTfJobEvent(event *Event) error { | |
//NewJob(kubeCli kubernetes.Interface, job spec.TfJob, stopC <-chan struct{}, wg *sync.WaitGroup) | ||
|
||
c.stopChMap[clus.Metadata.Name] = stopC | ||
c.jobs[clus.Metadata.Namespace + "-" + clus.Metadata.Name] = nc | ||
c.jobs[clus.Metadata.Namespace+"-"+clus.Metadata.Name] = nc | ||
There was a problem hiding this comment. Choose a reason for hiding this comment The reason will be displayed to describe this comment to others. Learn more. nit: I think there should be spaces after the plus signs. Maybe run gofmt? There was a problem hiding this comment. Choose a reason for hiding this comment The reason will be displayed to describe this comment to others. Learn more. If I add these spaces back and run gofmt, it will remove them. |
||
c.jobRVs[clus.Metadata.Name] = clus.Metadata.ResourceVersion | ||
|
||
//case kwatch.Modified: | ||
|
@@ -158,10 +159,10 @@ func (c *Controller) handleTfJobEvent(event *Event) error { | |
// c.jobRVs[clus.Metadata.Name] = clus.Metadata.ResourceVersion | ||
// | ||
case kwatch.Deleted: | ||
if _, ok := c.jobs[clus.Metadata.Namespace + "-" + clus.Metadata.Name]; !ok { | ||
if _, ok := c.jobs[clus.Metadata.Namespace+"-"+clus.Metadata.Name]; !ok { | ||
return fmt.Errorf("unsafe state. TfJob was never created but we received event (%s)", event.Type) | ||
} | ||
c.jobs[clus.Metadata.Namespace + "-" + clus.Metadata.Name].Delete() | ||
c.jobs[clus.Metadata.Namespace+"-"+clus.Metadata.Name].Delete() | ||
delete(c.jobs, clus.Metadata.Name) | ||
delete(c.jobRVs, clus.Metadata.Name) | ||
} | ||
|
@@ -193,7 +194,7 @@ func (c *Controller) findAllTfJobs() (string, error) { | |
continue | ||
} | ||
c.stopChMap[clus.Metadata.Name] = stopC | ||
c.jobs[clus.Metadata.Namespace + "-" + clus.Metadata.Name] = nc | ||
c.jobs[clus.Metadata.Namespace+"-"+clus.Metadata.Name] = nc | ||
c.jobRVs[clus.Metadata.Name] = clus.Metadata.ResourceVersion | ||
} | ||
|
||
|
@@ -237,16 +238,16 @@ func (c *Controller) createCRD() error { | |
Name: spec.CRDName(), | ||
}, | ||
Spec: v1beta1.CustomResourceDefinitionSpec{ | ||
Group: spec.CRDGroup, | ||
Group: spec.CRDGroup, | ||
Version: spec.CRDVersion, | ||
Scope: v1beta1.NamespaceScoped, | ||
Names: v1beta1.CustomResourceDefinitionNames{ | ||
Plural: spec.CRDKindPlural, | ||
// TODO(jlewi): Do we want to set the singular name? | ||
// Kind is the serialized kind of the resource. It is normally CamelCase and singular. | ||
Kind: reflect.TypeOf(spec.TfJob{}).Name(), | ||
}, | ||
Scope: v1beta1.NamespaceScoped, | ||
Names: v1beta1.CustomResourceDefinitionNames{ | ||
Plural: spec.CRDKindPlural, | ||
// TODO(jlewi): Do we want to set the singular name? | ||
// Kind is the serialized kind of the resource. It is normally CamelCase and singular. | ||
Kind: reflect.TypeOf(spec.TfJob{}).Name(), | ||
}, | ||
}, | ||
} | ||
|
||
_, err := c.ApiCli.ApiextensionsV1beta1().CustomResourceDefinitions().Create(crd) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The if args.should_push statement should be inside the else block. We only push the image if we aren't using GCB.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed.