Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Parameter Server: Run TF server by default #36

Merged
merged 18 commits into from
Oct 19, 2017
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions cmd/tf_operator/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@ import (

"io/ioutil"

"github.com/jlewi/mlkube.io/pkg/spec"
"k8s.io/client-go/kubernetes"
"k8s.io/client-go/pkg/api/v1"
"k8s.io/client-go/tools/record"
"github.com/jlewi/mlkube.io/pkg/spec"
)

var (
Expand All @@ -36,6 +36,7 @@ var (
chaosLevel int
controllerConfigFile string
printVersion bool
grpcServerFile string
)

var (
Expand All @@ -52,7 +53,6 @@ func init() {
flag.BoolVar(&printVersion, "version", false, "Show version and quit")
flag.DurationVar(&gcInterval, "gc-interval", 10*time.Minute, "GC interval")
flag.StringVar(&controllerConfigFile, "controller_config_file", "", "Path to file containing the controller config.")

flag.Parse()

// Workaround for watching TPR resource.
Expand Down Expand Up @@ -84,6 +84,7 @@ func init() {
} else {
log.Info("No controller_config_file provided; using empty config.")
}

}

func main() {
Expand Down
4 changes: 3 additions & 1 deletion examples/tf_job.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,6 @@ spec:
containers:
- image: gcr.io/tf-on-k8s-dogfood/tf_sample:dc944ff
name: tensorflow
restartPolicy: OnFailure
restartPolicy: OnFailure
- replicas: 2
tfReplicaType: PS
2 changes: 1 addition & 1 deletion glide.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

163 changes: 163 additions & 0 deletions grpc_tensorflow_server/grpc_tensorflow_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
#!/usr/bin/python
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
TODO: Once grpc_tensorflow_server.py is included in tensorflow
docker image we should use it instead"

Python-based TensorFlow GRPC server.

Takes input arguments cluster_spec, job_name and task_id, and start a blocking
TensorFlow GRPC server.

Usage:
grpc_tensorflow_server.py --cluster_spec=SPEC --job_name=NAME --task_id=ID

Where:
SPEC is <JOB>(,<JOB>)*
JOB is <NAME>|<HOST:PORT>(;<HOST:PORT>)*
NAME is a valid job name ([a-z][0-9a-z]*)
HOST is a hostname or IP address
PORT is a port number
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import sys

from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import tensorflow_server_pb2
from tensorflow.python.platform import app
from tensorflow.python.training import server_lib


def parse_cluster_spec(cluster_spec, cluster, verbose=False):
"""Parse content of cluster_spec string and inject info into cluster protobuf.

Args:
cluster_spec: cluster specification string, e.g.,
"local|localhost:2222;localhost:2223"
cluster: cluster protobuf.
verbose: If verbose logging is requested.

Raises:
ValueError: if the cluster_spec string is invalid.
"""

job_strings = cluster_spec.split(",")

if not cluster_spec:
raise ValueError("Empty cluster_spec string")

for job_string in job_strings:
job_def = cluster.job.add()

if job_string.count("|") != 1:
raise ValueError("Not exactly one instance of '|' in cluster_spec")

job_name = job_string.split("|")[0]

if not job_name:
raise ValueError("Empty job_name in cluster_spec")

job_def.name = job_name

if verbose:
print("Added job named \"%s\"" % job_name)

job_tasks = job_string.split("|")[1].split(";")
for i in range(len(job_tasks)):
if not job_tasks[i]:
raise ValueError("Empty task string at position %d" % i)

job_def.tasks[i] = job_tasks[i]

if verbose:
print(" Added task \"%s\" to job \"%s\"" % (job_tasks[i], job_name))


def main(unused_args):
# Create Protobuf ServerDef
server_def = tensorflow_server_pb2.ServerDef(protocol="grpc")

# Cluster info
parse_cluster_spec(FLAGS.cluster_spec, server_def.cluster, FLAGS.verbose)

# Job name
if not FLAGS.job_name:
raise ValueError("Empty job_name")
server_def.job_name = FLAGS.job_name

# Task index
if FLAGS.task_id < 0:
raise ValueError("Invalid task_id: %d" % FLAGS.task_id)
server_def.task_index = FLAGS.task_id

config = config_pb2.ConfigProto(gpu_options=config_pb2.GPUOptions(
per_process_gpu_memory_fraction=FLAGS.gpu_memory_fraction))

# Create GRPC Server instance
server = server_lib.Server(server_def, config=config)

# join() is blocking, unlike start()
server.join()


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.register("type", "bool", lambda v: v.lower() == "true")
parser.add_argument(
"--cluster_spec",
type=str,
default="",
help="""\
Cluster spec: SPEC. SPEC is <JOB>(,<JOB>)*," JOB is
<NAME>|<HOST:PORT>(;<HOST:PORT>)*," NAME is a valid job name
([a-z][0-9a-z]*)," HOST is a hostname or IP address," PORT is a
port number." E.g., local|localhost:2222;localhost:2223,
ps|ps0:2222;ps1:2222\
"""
)
parser.add_argument(
"--job_name",
type=str,
default="",
help="Job name: e.g., local"
)
parser.add_argument(
"--task_id",
type=int,
default=0,
help="Task index, e.g., 0"
)
parser.add_argument(
"--gpu_memory_fraction",
type=float,
default=1.0,
help="Fraction of GPU memory allocated",)
parser.add_argument(
"--verbose",
type="bool",
nargs="?",
const=True,
default=False,
help="Verbose mode"
)

FLAGS, unparsed = parser.parse_known_args()
app.run(main=main, argv=[sys.argv[0]] + unparsed)
1 change: 1 addition & 0 deletions images/tf_operator/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ RUN mkdir -p /opt/mlkube
RUN mkdir -p /opt/mlkube/test
COPY tf_operator /opt/mlkube
COPY e2e /opt/mlkube/test
COPY grpc_tensorflow_server.py /opt/mlkube/grpc_tensorflow_server/grpc_tensorflow_server.py
RUN chmod a+x /opt/mlkube/tf_operator
RUN chmod a+x /opt/mlkube/test/e2e

Expand Down
12 changes: 10 additions & 2 deletions images/tf_operator/build_and_push.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@

def GetGitHash():
# The image tag is based on the githash.
git_hash = subprocess.check_output(["git", "rev-parse", "--short", "HEAD"])
git_hash = subprocess.check_output(["git", "rev-parse", "--short", "HEAD"]).decode("utf-8")
git_hash=git_hash.strip()

modified_files = subprocess.check_output(["git", "ls-files", "--modified"])
untracked_files = subprocess.check_output(
["git", "ls-files", "--others", "--exclude-standard"])
Expand All @@ -23,7 +24,8 @@ def GetGitHash():
sha = hashlib.sha256()
sha.update(diff)
diffhash = sha.hexdigest()[0:7]
git_hash = "{0}-dirty-{1}".format(git_hash, diffhash)
git_hash = "{0}-dirty-{1}".format(git_hash, diffhash)

return git_hash

def run(command, cwd=None):
Expand Down Expand Up @@ -57,6 +59,9 @@ def run(command, cwd=None):
help="Use Google Container Builder to build the image.")
parser.add_argument("--no-gcb", dest="use_gcb", action="store_false",
help="Use Docker to build the image.")
parser.add_argument("--no-push", dest="should_push", action="store_false",
help="Do not push the image once build is finished.")

parser.set_defaults(use_gcb=False)

args = parser.parse_args()
Expand Down Expand Up @@ -90,6 +95,7 @@ def run(command, cwd=None):
"images/tf_operator/Dockerfile",
os.path.join(go_path, "bin/tf_operator"),
os.path.join(go_path, "bin/e2e"),
"grpc_tensorflow_server/grpc_tensorflow_server.py"
]

for s in sources:
Expand All @@ -112,6 +118,8 @@ def run(command, cwd=None):
else:
run(["docker", "build", "-t", image, context_dir])
logging.info("Built image: %s", image)

if args.should_push:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The if args.should_push statement should be inside the else block. We only push the image if we aren't using GCB.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed,

run(["gcloud", "docker", "--", "push", image])
logging.info("Pushed image: %s", image)

Expand Down
45 changes: 23 additions & 22 deletions pkg/controller/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,25 @@ import (
"errors"
"fmt"
"io"
"k8s.io/client-go/kubernetes"
"github.com/jlewi/mlkube.io/pkg/spec"
"github.com/jlewi/mlkube.io/pkg/trainer"
"github.com/jlewi/mlkube.io/pkg/util/k8sutil"
"net/http"
"reflect"
"sync"
"time"

apierrors "k8s.io/apimachinery/pkg/api/errors"
apiextensionsclient "k8s.io/apiextensions-apiserver/pkg/client/clientset/clientset"
v1beta1 "k8s.io/apiextensions-apiserver/pkg/apis/apiextensions/v1beta1"
"github.com/jlewi/mlkube.io/pkg/spec"
"github.com/jlewi/mlkube.io/pkg/trainer"
"github.com/jlewi/mlkube.io/pkg/util/k8sutil"
"k8s.io/client-go/kubernetes"

log "github.com/golang/glog"
"github.com/jlewi/mlkube.io/pkg/util"
v1beta1 "k8s.io/apiextensions-apiserver/pkg/apis/apiextensions/v1beta1"
apiextensionsclient "k8s.io/apiextensions-apiserver/pkg/client/clientset/clientset"
apierrors "k8s.io/apimachinery/pkg/api/errors"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
kwatch "k8s.io/apimachinery/pkg/watch"
k8sErrors "k8s.io/apimachinery/pkg/util/errors"
"k8s.io/apimachinery/pkg/util/wait"
"github.com/jlewi/mlkube.io/pkg/util"
kwatch "k8s.io/apimachinery/pkg/watch"
)

var (
Expand Down Expand Up @@ -66,7 +67,7 @@ func New(kubeCli kubernetes.Interface, apiCli apiextensionsclient.Interface, tfJ
return &Controller{
Namespace: ns,
KubeCli: kubeCli,
ApiCli: apiCli,
ApiCli: apiCli,
TfJobClient: tfJobClient,
// TODO(jlewi)): What to do about cluster.Cluster?
jobs: make(map[string]*trainer.TrainingJob),
Expand Down Expand Up @@ -147,7 +148,7 @@ func (c *Controller) handleTfJobEvent(event *Event) error {
//NewJob(kubeCli kubernetes.Interface, job spec.TfJob, stopC <-chan struct{}, wg *sync.WaitGroup)

c.stopChMap[clus.Metadata.Name] = stopC
c.jobs[clus.Metadata.Namespace + "-" + clus.Metadata.Name] = nc
c.jobs[clus.Metadata.Namespace+"-"+clus.Metadata.Name] = nc
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I think there should be spaces after the plus signs. Maybe run gofmt?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I add this spaces back and run gofmt, it will remove them.
I am running go 1.9.1, could you have a different go version with different gofmt rules?

c.jobRVs[clus.Metadata.Name] = clus.Metadata.ResourceVersion

//case kwatch.Modified:
Expand All @@ -158,10 +159,10 @@ func (c *Controller) handleTfJobEvent(event *Event) error {
// c.jobRVs[clus.Metadata.Name] = clus.Metadata.ResourceVersion
//
case kwatch.Deleted:
if _, ok := c.jobs[clus.Metadata.Namespace + "-" + clus.Metadata.Name]; !ok {
if _, ok := c.jobs[clus.Metadata.Namespace+"-"+clus.Metadata.Name]; !ok {
return fmt.Errorf("unsafe state. TfJob was never created but we received event (%s)", event.Type)
}
c.jobs[clus.Metadata.Namespace + "-" + clus.Metadata.Name].Delete()
c.jobs[clus.Metadata.Namespace+"-"+clus.Metadata.Name].Delete()
delete(c.jobs, clus.Metadata.Name)
delete(c.jobRVs, clus.Metadata.Name)
}
Expand Down Expand Up @@ -193,7 +194,7 @@ func (c *Controller) findAllTfJobs() (string, error) {
continue
}
c.stopChMap[clus.Metadata.Name] = stopC
c.jobs[clus.Metadata.Namespace + "-" + clus.Metadata.Name] = nc
c.jobs[clus.Metadata.Namespace+"-"+clus.Metadata.Name] = nc
c.jobRVs[clus.Metadata.Name] = clus.Metadata.ResourceVersion
}

Expand Down Expand Up @@ -237,16 +238,16 @@ func (c *Controller) createCRD() error {
Name: spec.CRDName(),
},
Spec: v1beta1.CustomResourceDefinitionSpec{
Group: spec.CRDGroup,
Group: spec.CRDGroup,
Version: spec.CRDVersion,
Scope: v1beta1.NamespaceScoped,
Names: v1beta1.CustomResourceDefinitionNames{
Plural: spec.CRDKindPlural,
// TODO(jlewi): Do we want to set the singular name?
// Kind is the serialized kind of the resource. It is normally CamelCase and singular.
Kind: reflect.TypeOf(spec.TfJob{}).Name(),
},
Scope: v1beta1.NamespaceScoped,
Names: v1beta1.CustomResourceDefinitionNames{
Plural: spec.CRDKindPlural,
// TODO(jlewi): Do we want to set the singular name?
// Kind is the serialized kind of the resource. It is normally CamelCase and singular.
Kind: reflect.TypeOf(spec.TfJob{}).Name(),
},
},
}

_, err := c.ApiCli.ApiextensionsV1beta1().CustomResourceDefinitions().Create(crd)
Expand Down
3 changes: 3 additions & 0 deletions pkg/spec/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ type ControllerConfig struct {
// This should match the value specified as a container limit.
// e.g. alpha.kubernetes.io/nvidia-gpu
Accelerators map[string]AcceleratorConfig

// Path to the file containing the grpc server source
GrpcServerFilePath string
}

// AcceleratorVolume represents a host path that must be mounted into
Expand Down
Loading