From f5a2abe43e2fe016dc0397a32751fc479ae35142 Mon Sep 17 00:00:00 2001
From: alexbao
Date: Fri, 14 Jan 2022 06:00:26 -0800
Subject: [PATCH] initial commit for e2e pipeline demo (#1)

Co-authored-by: alex.bao
---
 README.md                       |   6 +-
 demo/README.md                  |  29 ++++
 demo/client.py                  |  47 ++++++
 demo/generate_proto.sh          |   7 +
 demo/go_gateway/mobius.pb.go    | 262 ++++++++++++++++++++++++++++++++
 demo/go_gateway/mobius.pb.gw.go | 120 +++++++++++++++
 demo/network.py                 |  48 ++++++
 demo/online_learning.py         |  87 +++++++++++
 demo/proto/__init__.py          |   0
 demo/proto/mobius.proto         |  20 +++
 demo/proto/mobius_pb2.py        | 143 +++++++++++++++++
 demo/proto/mobius_pb2_grpc.py   |  63 ++++++++
 demo/publish_msg.py             |  45 ++++++
 demo/server_gateway.go          |  41 +++++
 demo/server_rpc.py              |  97 ++++++++++++
 demo/train_local.py             |  35 +++++
 demo/util.py                    |  14 ++
 17 files changed, 1063 insertions(+), 1 deletion(-)
 create mode 100644 demo/README.md
 create mode 100644 demo/client.py
 create mode 100644 demo/generate_proto.sh
 create mode 100644 demo/go_gateway/mobius.pb.go
 create mode 100644 demo/go_gateway/mobius.pb.gw.go
 create mode 100644 demo/network.py
 create mode 100644 demo/online_learning.py
 create mode 100644 demo/proto/__init__.py
 create mode 100644 demo/proto/mobius.proto
 create mode 100644 demo/proto/mobius_pb2.py
 create mode 100644 demo/proto/mobius_pb2_grpc.py
 create mode 100644 demo/publish_msg.py
 create mode 100644 demo/server_gateway.go
 create mode 100644 demo/server_rpc.py
 create mode 100644 demo/train_local.py
 create mode 100644 demo/util.py

diff --git a/README.md b/README.md
index 1770b8b4..0d3216fe 100644
--- a/README.md
+++ b/README.md
@@ -1 +1,5 @@
-# mobius
\ No newline at end of file
+# Mobius online learning
+
+## Code Editor
+We recommend Atom (https://atom.io/) as the code editor.
+Recommended plugins: atom-beautify, python-indent, auto-indent, and vim-mode-plus (for Vim users).
diff --git a/demo/README.md b/demo/README.md
new file mode 100644
index 00000000..194f8125
--- /dev/null
+++ b/demo/README.md
@@ -0,0 +1,29 @@
+## Prerequisites
+- Install the additional packages needed to run the demo code (the scripts import
+ray, tensorflow, kafka-python, grpcio, and numpy; generating the protos also
+requires grpcio-tools).
+- Disable Ray's main thread check manually. It is located around line 1038 of
+python2.7/site-packages/ray/worker.py (the exact path depends on your OS and
+Python version). There are ongoing efforts to support multithreading natively,
+so this workaround should eventually become unnecessary.
+
+Generate the proto files
+> $ bash generate_proto.sh
+
+Start the local ZooKeeper service
+> $ zkServer start
+
+Start the Kafka service
+> $ kafka-server-start /usr/local/etc/kafka/server.properties
+
+Start the RPC server
+> $ python server_rpc.py --kafka_server=localhost:9092 --kafka_topic=kafkaesque
+
+Start the gateway server (the easiest integration path, since grpc-gateway generates most of the code).
+Only required if you'd like to send requests over HTTP, e.g. with PostMate.
+> $ go run server_gateway.go
+
+Stream the training data
+> $ python publish_msg.py --kafka_server=localhost:9092 --kafka_topic=kafkaesque --num_to_send=10000 --sleep_every_N=100 --target_a=1 --target_b=3
+
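+The gateway can also be exercised directly over HTTP as a quick sanity check. The
+sketch below assumes the gateway from the previous step is listening on its default
+port 8080 and uses the /v1/infer route defined in proto/mobius.proto; the call
+should return a JSON body like {"y": [...]}.
+> $ curl -X POST -H "Content-Type: application/json" -d '{"x": [0.5, 1.5]}' http://localhost:8080/v1/infer
+
+Run the client script to send inference requests and compute the mean squared error.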
+> $ python client.py --num_batch=100 --target_a=1 --target_b=3 diff --git a/demo/client.py b/demo/client.py new file mode 100644 index 00000000..c5db827e --- /dev/null +++ b/demo/client.py @@ -0,0 +1,47 @@ +import argparse +import grpc +import numpy +import time +import util + +from proto import mobius_pb2, mobius_pb2_grpc + +parser = argparse.ArgumentParser(description="Online learning service") + +parser.add_argument("--rpc_endpoint", default="localhost:50051", type=str, + help="RPC endpoint") +parser.add_argument("--target_a", default=0.1, type=float, + help="Target 'a' to simulate for y = ax + b") +parser.add_argument("--target_b", default=0.3, type=float, + help="Target 'b' to simulate for y = ax + b") +parser.add_argument("--batch_size", default=100, type=int, + help="Batch size for inference") +parser.add_argument("--num_batch", default=100, type=int, + help="Number of batches to send") +parser.add_argument("--batch_interval_seconds", default=1, type=int, + help="Interval seconds between batches") + + +if __name__ == "__main__": + args = parser.parse_args() + # open a gRPC channel + channel = grpc.insecure_channel(args.rpc_endpoint) + # create a stub (client) + stub = mobius_pb2_grpc.MobiusStub(channel) + + for i in range(args.num_batch): + x, y = util.generate_linear_x_y_data( + args.batch_size, args.target_a, args.target_b, + util.now_millis() / 1000) + # create a valid request message + request = mobius_pb2.InferRequest(x=x) + # make the call + response = stub.Infer(request) + + A = numpy.array(y) + B = numpy.array(response.y) + mean_error = ((A - B) ** 2).mean() + # print response + print 'batch', i, 'mean squared error', mean_error + + time.sleep(args.batch_interval_seconds) diff --git a/demo/generate_proto.sh b/demo/generate_proto.sh new file mode 100644 index 00000000..f3dc2121 --- /dev/null +++ b/demo/generate_proto.sh @@ -0,0 +1,7 @@ +# Generate pb file for Python RPC +python -m grpc_tools.protoc -Iproto -I/usr/local/include -I$GOPATH/src/github.com/grpc-ecosystem/grpc-gateway/third_party/googleapis --python_out=proto --grpc_python_out=proto mobius.proto + +# Generate pb file for Go gateway +protoc -I/usr/local/include -Iproto -I$GOPATH/src -I$GOPATH/src/github.com/grpc-ecosystem/grpc-gateway/third_party/googleapis --grpc-gateway_out=logtostderr=true:go_gateway mobius.proto +# Generate pb file for Go RPC +protoc -I/usr/local/include -Iproto -I$GOPATH/src/github.com/grpc-ecosystem/grpc-gateway/third_party/googleapis --go_out=plugins=grpc:go_gateway mobius.proto diff --git a/demo/go_gateway/mobius.pb.go b/demo/go_gateway/mobius.pb.go new file mode 100644 index 00000000..cb147e47 --- /dev/null +++ b/demo/go_gateway/mobius.pb.go @@ -0,0 +1,262 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// source: mobius.proto + +package mobius + +import proto "github.com/golang/protobuf/proto" +import fmt "fmt" +import math "math" +import _ "google.golang.org/genproto/googleapis/api/annotations" + +import ( + context "golang.org/x/net/context" + grpc "google.golang.org/grpc" +) + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. 
+const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package + +type InferRequest struct { + X []float32 `protobuf:"fixed32,1,rep,packed,name=x" json:"x,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *InferRequest) Reset() { *m = InferRequest{} } +func (m *InferRequest) String() string { return proto.CompactTextString(m) } +func (*InferRequest) ProtoMessage() {} +func (*InferRequest) Descriptor() ([]byte, []int) { + return fileDescriptor_mobius_49b73b6599d79e00, []int{0} +} +func (m *InferRequest) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_InferRequest.Unmarshal(m, b) +} +func (m *InferRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_InferRequest.Marshal(b, m, deterministic) +} +func (dst *InferRequest) XXX_Merge(src proto.Message) { + xxx_messageInfo_InferRequest.Merge(dst, src) +} +func (m *InferRequest) XXX_Size() int { + return xxx_messageInfo_InferRequest.Size(m) +} +func (m *InferRequest) XXX_DiscardUnknown() { + xxx_messageInfo_InferRequest.DiscardUnknown(m) +} + +var xxx_messageInfo_InferRequest proto.InternalMessageInfo + +func (m *InferRequest) GetX() []float32 { + if m != nil { + return m.X + } + return nil +} + +type InferResponse struct { + Y []float32 `protobuf:"fixed32,1,rep,packed,name=y" json:"y,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *InferResponse) Reset() { *m = InferResponse{} } +func (m *InferResponse) String() string { return proto.CompactTextString(m) } +func (*InferResponse) ProtoMessage() {} +func (*InferResponse) Descriptor() ([]byte, []int) { + return fileDescriptor_mobius_49b73b6599d79e00, []int{1} +} +func (m *InferResponse) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_InferResponse.Unmarshal(m, b) +} +func (m *InferResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_InferResponse.Marshal(b, m, deterministic) +} +func (dst *InferResponse) XXX_Merge(src proto.Message) { + xxx_messageInfo_InferResponse.Merge(dst, src) +} +func (m *InferResponse) XXX_Size() int { + return xxx_messageInfo_InferResponse.Size(m) +} +func (m *InferResponse) XXX_DiscardUnknown() { + xxx_messageInfo_InferResponse.DiscardUnknown(m) +} + +var xxx_messageInfo_InferResponse proto.InternalMessageInfo + +func (m *InferResponse) GetY() []float32 { + if m != nil { + return m.Y + } + return nil +} + +func init() { + proto.RegisterType((*InferRequest)(nil), "InferRequest") + proto.RegisterType((*InferResponse)(nil), "InferResponse") +} + +// Reference imports to suppress errors if they are not otherwise used. +var _ context.Context +var _ grpc.ClientConn + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +const _ = grpc.SupportPackageIsVersion4 + +// MobiusClient is the client API for Mobius service. +// +// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://godoc.org/google.golang.org/grpc#ClientConn.NewStream. 
+type MobiusClient interface { + Infer(ctx context.Context, in *InferRequest, opts ...grpc.CallOption) (*InferResponse, error) + InferStream(ctx context.Context, opts ...grpc.CallOption) (Mobius_InferStreamClient, error) +} + +type mobiusClient struct { + cc *grpc.ClientConn +} + +func NewMobiusClient(cc *grpc.ClientConn) MobiusClient { + return &mobiusClient{cc} +} + +func (c *mobiusClient) Infer(ctx context.Context, in *InferRequest, opts ...grpc.CallOption) (*InferResponse, error) { + out := new(InferResponse) + err := c.cc.Invoke(ctx, "/Mobius/Infer", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *mobiusClient) InferStream(ctx context.Context, opts ...grpc.CallOption) (Mobius_InferStreamClient, error) { + stream, err := c.cc.NewStream(ctx, &_Mobius_serviceDesc.Streams[0], "/Mobius/InferStream", opts...) + if err != nil { + return nil, err + } + x := &mobiusInferStreamClient{stream} + return x, nil +} + +type Mobius_InferStreamClient interface { + Send(*InferRequest) error + Recv() (*InferResponse, error) + grpc.ClientStream +} + +type mobiusInferStreamClient struct { + grpc.ClientStream +} + +func (x *mobiusInferStreamClient) Send(m *InferRequest) error { + return x.ClientStream.SendMsg(m) +} + +func (x *mobiusInferStreamClient) Recv() (*InferResponse, error) { + m := new(InferResponse) + if err := x.ClientStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + +// MobiusServer is the server API for Mobius service. +type MobiusServer interface { + Infer(context.Context, *InferRequest) (*InferResponse, error) + InferStream(Mobius_InferStreamServer) error +} + +func RegisterMobiusServer(s *grpc.Server, srv MobiusServer) { + s.RegisterService(&_Mobius_serviceDesc, srv) +} + +func _Mobius_Infer_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(InferRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(MobiusServer).Infer(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/Mobius/Infer", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(MobiusServer).Infer(ctx, req.(*InferRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _Mobius_InferStream_Handler(srv interface{}, stream grpc.ServerStream) error { + return srv.(MobiusServer).InferStream(&mobiusInferStreamServer{stream}) +} + +type Mobius_InferStreamServer interface { + Send(*InferResponse) error + Recv() (*InferRequest, error) + grpc.ServerStream +} + +type mobiusInferStreamServer struct { + grpc.ServerStream +} + +func (x *mobiusInferStreamServer) Send(m *InferResponse) error { + return x.ServerStream.SendMsg(m) +} + +func (x *mobiusInferStreamServer) Recv() (*InferRequest, error) { + m := new(InferRequest) + if err := x.ServerStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + +var _Mobius_serviceDesc = grpc.ServiceDesc{ + ServiceName: "Mobius", + HandlerType: (*MobiusServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "Infer", + Handler: _Mobius_Infer_Handler, + }, + }, + Streams: []grpc.StreamDesc{ + { + StreamName: "InferStream", + Handler: _Mobius_InferStream_Handler, + ServerStreams: true, + ClientStreams: true, + }, + }, + Metadata: "mobius.proto", +} + +func init() { proto.RegisterFile("mobius.proto", fileDescriptor_mobius_49b73b6599d79e00) } + +var fileDescriptor_mobius_49b73b6599d79e00 = []byte{ 
+ // 179 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xe2, 0xc9, 0xcd, 0x4f, 0xca, + 0x2c, 0x2d, 0xd6, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x97, 0x92, 0x49, 0xcf, 0xcf, 0x4f, 0xcf, 0x49, + 0xd5, 0x4f, 0x2c, 0xc8, 0xd4, 0x4f, 0xcc, 0xcb, 0xcb, 0x2f, 0x49, 0x2c, 0xc9, 0xcc, 0xcf, 0x83, + 0xca, 0x2a, 0xc9, 0x70, 0xf1, 0x78, 0xe6, 0xa5, 0xa5, 0x16, 0x05, 0xa5, 0x16, 0x96, 0xa6, 0x16, + 0x97, 0x08, 0xf1, 0x70, 0x31, 0x56, 0x48, 0x30, 0x2a, 0x30, 0x6b, 0x30, 0x05, 0x31, 0x56, 0x28, + 0xc9, 0x72, 0xf1, 0x42, 0x65, 0x8b, 0x0b, 0xf2, 0xf3, 0x8a, 0x53, 0x41, 0xd2, 0x95, 0x30, 0xe9, + 0x4a, 0xa3, 0x2a, 0x2e, 0x36, 0x5f, 0xb0, 0x55, 0x42, 0x36, 0x5c, 0xac, 0x60, 0x85, 0x42, 0xbc, + 0x7a, 0xc8, 0xc6, 0x49, 0xf1, 0xe9, 0xa1, 0xe8, 0x57, 0x12, 0x69, 0xba, 0xfc, 0x64, 0x32, 0x13, + 0x9f, 0x12, 0xa7, 0x7e, 0x99, 0xa1, 0x7e, 0x26, 0x48, 0xca, 0x8a, 0x51, 0x4b, 0xc8, 0x88, 0x8b, + 0x1b, 0xac, 0x2c, 0xb8, 0xa4, 0x28, 0x35, 0x31, 0x97, 0x90, 0x19, 0x0c, 0x1a, 0x8c, 0x06, 0x8c, + 0x49, 0x6c, 0x60, 0xf7, 0x1b, 0x03, 0x02, 0x00, 0x00, 0xff, 0xff, 0xea, 0x81, 0xbe, 0xd3, 0xed, + 0x00, 0x00, 0x00, +} diff --git a/demo/go_gateway/mobius.pb.gw.go b/demo/go_gateway/mobius.pb.gw.go new file mode 100644 index 00000000..6a4371ce --- /dev/null +++ b/demo/go_gateway/mobius.pb.gw.go @@ -0,0 +1,120 @@ +// Code generated by protoc-gen-grpc-gateway. DO NOT EDIT. +// source: mobius.proto + +/* +Package mobius is a reverse proxy. + +It translates gRPC into RESTful JSON APIs. +*/ +package mobius + +import ( + "io" + "net/http" + + "github.com/golang/protobuf/proto" + "github.com/grpc-ecosystem/grpc-gateway/runtime" + "github.com/grpc-ecosystem/grpc-gateway/utilities" + "golang.org/x/net/context" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/grpclog" + "google.golang.org/grpc/status" +) + +var _ codes.Code +var _ io.Reader +var _ status.Status +var _ = runtime.String +var _ = utilities.NewDoubleArray + +func request_Mobius_Infer_0(ctx context.Context, marshaler runtime.Marshaler, client MobiusClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { + var protoReq InferRequest + var metadata runtime.ServerMetadata + + if err := marshaler.NewDecoder(req.Body).Decode(&protoReq); err != nil && err != io.EOF { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) + } + + msg, err := client.Infer(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD)) + return msg, metadata, err + +} + +// RegisterMobiusHandlerFromEndpoint is same as RegisterMobiusHandler but +// automatically dials to "endpoint" and closes the connection when "ctx" gets done. +func RegisterMobiusHandlerFromEndpoint(ctx context.Context, mux *runtime.ServeMux, endpoint string, opts []grpc.DialOption) (err error) { + conn, err := grpc.Dial(endpoint, opts...) + if err != nil { + return err + } + defer func() { + if err != nil { + if cerr := conn.Close(); cerr != nil { + grpclog.Printf("Failed to close conn to %s: %v", endpoint, cerr) + } + return + } + go func() { + <-ctx.Done() + if cerr := conn.Close(); cerr != nil { + grpclog.Printf("Failed to close conn to %s: %v", endpoint, cerr) + } + }() + }() + + return RegisterMobiusHandler(ctx, mux, conn) +} + +// RegisterMobiusHandler registers the http handlers for service Mobius to "mux". +// The handlers forward requests to the grpc endpoint over "conn". 
+func RegisterMobiusHandler(ctx context.Context, mux *runtime.ServeMux, conn *grpc.ClientConn) error { + return RegisterMobiusHandlerClient(ctx, mux, NewMobiusClient(conn)) +} + +// RegisterMobiusHandler registers the http handlers for service Mobius to "mux". +// The handlers forward requests to the grpc endpoint over the given implementation of "MobiusClient". +// Note: the gRPC framework executes interceptors within the gRPC handler. If the passed in "MobiusClient" +// doesn't go through the normal gRPC flow (creating a gRPC client etc.) then it will be up to the passed in +// "MobiusClient" to call the correct interceptors. +func RegisterMobiusHandlerClient(ctx context.Context, mux *runtime.ServeMux, client MobiusClient) error { + + mux.Handle("POST", pattern_Mobius_Infer_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + if cn, ok := w.(http.CloseNotifier); ok { + go func(done <-chan struct{}, closed <-chan bool) { + select { + case <-done: + case <-closed: + cancel() + } + }(ctx.Done(), cn.CloseNotify()) + } + inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) + rctx, err := runtime.AnnotateContext(ctx, mux, req) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + resp, md, err := request_Mobius_Infer_0(rctx, inboundMarshaler, client, req, pathParams) + ctx = runtime.NewServerMetadataContext(ctx, md) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + + forward_Mobius_Infer_0(ctx, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) + + }) + + return nil +} + +var ( + pattern_Mobius_Infer_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1}, []string{"v1", "infer"}, "")) +) + +var ( + forward_Mobius_Infer_0 = runtime.ForwardResponseMessage +) diff --git a/demo/network.py b/demo/network.py new file mode 100644 index 00000000..8a0f7a05 --- /dev/null +++ b/demo/network.py @@ -0,0 +1,48 @@ +import tensorflow as tf +import ray + + +class SimpleNetwork(object): + def __init__(self): + # Seed TensorFlow to make the script deterministic. + tf.set_random_seed(0) + # Define the inputs. + self.x_data = tf.placeholder(tf.float32, [None, 1]) + self.y_data = tf.placeholder(tf.float32, [None, 1]) + + # Define the weights and computation. + w = tf.Variable(tf.random_uniform([1], -1.0, 1.0)) + b = tf.Variable(tf.zeros([1])) + self.y = w * self.x_data + b + # Define the loss. + self.loss = tf.reduce_mean(tf.square(self.y - self.y_data)) + optimizer = tf.train.GradientDescentOptimizer(0.2) + self.grads = optimizer.compute_gradients(self.loss) + self.train = optimizer.apply_gradients(self.grads) + # Define the weight initializer and session. + init = tf.global_variables_initializer() + self.sess = tf.Session() + # Additional code for setting and getting the weights + self.variables = ray.experimental.TensorFlowVariables( + self.loss, self.sess) + # Return all of the data needed to use the network. + self.sess.run(init) + + # Define a remote function that trains the network for one step and + # returns the new weights. + def step(self, weights, x, y): + # Set the weights in the network. + self.set_weights(weights) + # Do one step of training. + self.sess.run(self.train, {self.x_data: x, self.y_data: y}) + # Return the new weights. 
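+        # (The weights here are the numpy arrays tracked by
+        # ray.experimental.TensorFlowVariables, so they can be shipped between
+        # the train and infer actors through Ray's object store.)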
+ return self.variables.get_weights() + + def set_weights(self, weights): + return self.variables.set_weights(weights) + + def get_weights(self): + return self.variables.get_weights() + + def inference(self, x): + return self.sess.run(self.y, {self.x_data: x}) diff --git a/demo/online_learning.py b/demo/online_learning.py new file mode 100644 index 00000000..e7dd96aa --- /dev/null +++ b/demo/online_learning.py @@ -0,0 +1,87 @@ +from kafka import KafkaConsumer +import multiprocessing +import ray +import util + + +def train(train_actor, infer_actor, kafka_server, kafka_topic, batch_size, + max_batch_wait_millis, verbose): + weights = ray.get(train_actor.get_weights.remote()) + + consumer = KafkaConsumer(bootstrap_servers=kafka_server) + consumer.subscribe([kafka_topic]) + iteration = 0 + + while 1: + start = util.now_millis() + x_data = [] + y_data = [] + + while len(x_data) < batch_size: + max_wait = max_batch_wait_millis - (util.now_millis() - start) + if max_wait <= 0: + break + msg = consumer.poll(timeout_ms=max_wait) + + if len(msg) > 0: + for _, records in msg.iteritems(): + for record in records: + if not record.value: + continue + + parts = record.value.split() + if len(parts) != 5: + continue + + x, y = float(parts[3]), float(parts[4]) + x_data.append([x]) + y_data.append([y]) + + if len(x_data) > 0: + weights_id = ray.put(weights) + new_weights_id = train_actor.step.remote( + weights_id, x_data, y_data) + weights = ray.get(new_weights_id) + + iteration += 1 + if iteration % 10 == 0: + millis = int(parts[0]) + delay = util.now_millis() - millis + print("Training latency: {}".format(delay)) + print("Training iteration {}: weights are {}".format( + iteration, weights)) + + # TODO: move this to separate module and decouple from the + # training process. + if verbose: + print("Update weights for infer") + infer_actor.set_weights.remote(new_weights_id) + + +class Driver(object): + def __init__(self, remote_network, kafka_server, kafka_topic, batch_size, + max_batch_wait_millis, verbose): + self._network = remote_network + self._train_actor = remote_network.remote() + self._infer_actor = remote_network.remote() + self._kafka_server = kafka_server + self._kafka_topic = kafka_topic + self._batch_size = batch_size + self._max_batch_wait_millis = max_batch_wait_millis + self._verbose = verbose + + def start(self): + # Start training process + # TODO: load model from persistent storage. + # TODO: switch model with separate process. + # self._train_process = multiprocessing.Process( + # . . target=train, + # . . 
.args=(self._train_actor, self._infer_actor, kafka_server, + # kafka_topic, batch_size, max_batch_wait_millis)) + # self._train_process.start() + train(self._train_actor, self._infer_actor, self._kafka_server, + self._kafka_topic, self._batch_size, self._max_batch_wait_millis, + self._verbose) + + def infer(self, x_data): + return ray.get(self._infer_actor.inference.remote(x_data)) diff --git a/demo/proto/__init__.py b/demo/proto/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/demo/proto/mobius.proto b/demo/proto/mobius.proto new file mode 100644 index 00000000..66ddc6f1 --- /dev/null +++ b/demo/proto/mobius.proto @@ -0,0 +1,20 @@ +syntax = "proto3"; +import "google/api/annotations.proto"; + +message InferRequest { + repeated float x = 1; +} +message InferResponse { + repeated float y = 1; +} + +service Mobius { + rpc Infer(InferRequest) returns (InferResponse) { + option (google.api.http) = { + post: "/v1/infer" + body: "*" + }; + } + + rpc InferStream(stream InferRequest) returns (stream InferResponse) {} +} diff --git a/demo/proto/mobius_pb2.py b/demo/proto/mobius_pb2.py new file mode 100644 index 00000000..330fe4c9 --- /dev/null +++ b/demo/proto/mobius_pb2.py @@ -0,0 +1,143 @@ +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: mobius.proto + +import sys +_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from google.protobuf import reflection as _reflection +from google.protobuf import symbol_database as _symbol_database +from google.protobuf import descriptor_pb2 +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +from google.api import annotations_pb2 as google_dot_api_dot_annotations__pb2 + + +DESCRIPTOR = _descriptor.FileDescriptor( + name='mobius.proto', + package='', + syntax='proto3', + serialized_pb=_b('\n\x0cmobius.proto\x1a\x1cgoogle/api/annotations.proto\"\x19\n\x0cInferRequest\x12\t\n\x01x\x18\x01 \x03(\x02\"\x1a\n\rInferResponse\x12\t\n\x01y\x18\x01 \x03(\x02\x32z\n\x06Mobius\x12<\n\x05Infer\x12\r.InferRequest\x1a\x0e.InferResponse\"\x14\x82\xd3\xe4\x93\x02\x0e\"\t/v1/infer:\x01*\x12\x32\n\x0bInferStream\x12\r.InferRequest\x1a\x0e.InferResponse\"\x00(\x01\x30\x01\x62\x06proto3') + , + dependencies=[google_dot_api_dot_annotations__pb2.DESCRIPTOR,]) + + + + +_INFERREQUEST = _descriptor.Descriptor( + name='InferRequest', + full_name='InferRequest', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='x', full_name='InferRequest.x', index=0, + number=1, type=2, cpp_type=6, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=46, + serialized_end=71, +) + + +_INFERRESPONSE = _descriptor.Descriptor( + name='InferResponse', + full_name='InferResponse', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='y', full_name='InferResponse.y', index=0, + number=1, type=2, cpp_type=6, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, 
file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=73, + serialized_end=99, +) + +DESCRIPTOR.message_types_by_name['InferRequest'] = _INFERREQUEST +DESCRIPTOR.message_types_by_name['InferResponse'] = _INFERRESPONSE +_sym_db.RegisterFileDescriptor(DESCRIPTOR) + +InferRequest = _reflection.GeneratedProtocolMessageType('InferRequest', (_message.Message,), dict( + DESCRIPTOR = _INFERREQUEST, + __module__ = 'mobius_pb2' + # @@protoc_insertion_point(class_scope:InferRequest) + )) +_sym_db.RegisterMessage(InferRequest) + +InferResponse = _reflection.GeneratedProtocolMessageType('InferResponse', (_message.Message,), dict( + DESCRIPTOR = _INFERRESPONSE, + __module__ = 'mobius_pb2' + # @@protoc_insertion_point(class_scope:InferResponse) + )) +_sym_db.RegisterMessage(InferResponse) + + + +_MOBIUS = _descriptor.ServiceDescriptor( + name='Mobius', + full_name='Mobius', + file=DESCRIPTOR, + index=0, + options=None, + serialized_start=101, + serialized_end=223, + methods=[ + _descriptor.MethodDescriptor( + name='Infer', + full_name='Mobius.Infer', + index=0, + containing_service=None, + input_type=_INFERREQUEST, + output_type=_INFERRESPONSE, + options=_descriptor._ParseOptions(descriptor_pb2.MethodOptions(), _b('\202\323\344\223\002\016\"\t/v1/infer:\001*')), + ), + _descriptor.MethodDescriptor( + name='InferStream', + full_name='Mobius.InferStream', + index=1, + containing_service=None, + input_type=_INFERREQUEST, + output_type=_INFERRESPONSE, + options=None, + ), +]) +_sym_db.RegisterServiceDescriptor(_MOBIUS) + +DESCRIPTOR.services_by_name['Mobius'] = _MOBIUS + +# @@protoc_insertion_point(module_scope) diff --git a/demo/proto/mobius_pb2_grpc.py b/demo/proto/mobius_pb2_grpc.py new file mode 100644 index 00000000..ca3b81f3 --- /dev/null +++ b/demo/proto/mobius_pb2_grpc.py @@ -0,0 +1,63 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +import grpc + +import mobius_pb2 as mobius__pb2 + + +class MobiusStub(object): + # missing associated documentation comment in .proto file + pass + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. 
+ """ + self.Infer = channel.unary_unary( + '/Mobius/Infer', + request_serializer=mobius__pb2.InferRequest.SerializeToString, + response_deserializer=mobius__pb2.InferResponse.FromString, + ) + self.InferStream = channel.stream_stream( + '/Mobius/InferStream', + request_serializer=mobius__pb2.InferRequest.SerializeToString, + response_deserializer=mobius__pb2.InferResponse.FromString, + ) + + +class MobiusServicer(object): + # missing associated documentation comment in .proto file + pass + + def Infer(self, request, context): + # missing associated documentation comment in .proto file + pass + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def InferStream(self, request_iterator, context): + # missing associated documentation comment in .proto file + pass + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + +def add_MobiusServicer_to_server(servicer, server): + rpc_method_handlers = { + 'Infer': grpc.unary_unary_rpc_method_handler( + servicer.Infer, + request_deserializer=mobius__pb2.InferRequest.FromString, + response_serializer=mobius__pb2.InferResponse.SerializeToString, + ), + 'InferStream': grpc.stream_stream_rpc_method_handler( + servicer.InferStream, + request_deserializer=mobius__pb2.InferRequest.FromString, + response_serializer=mobius__pb2.InferResponse.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + 'Mobius', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) diff --git a/demo/publish_msg.py b/demo/publish_msg.py new file mode 100644 index 00000000..b07a374c --- /dev/null +++ b/demo/publish_msg.py @@ -0,0 +1,45 @@ +import argparse +import math +import time +import util + +from datetime import datetime +from kafka import KafkaClient +from kafka.producer import SimpleProducer + +parser = argparse.ArgumentParser(description="Online learning service") + +parser.add_argument("--kafka_server", default="localhost:9092", type=str, + help="Kafka server") +parser.add_argument("--kafka_topic", default="kafkaesque", type=str, + help="Kafka topic") +parser.add_argument("--num_to_send", default=100000, type=int, + help="Number of messages to send", ) +parser.add_argument("--sleep_every_N", default=1000, type=int, + help="Number of messages to send per second") +parser.add_argument("--target_a", default=0.1, type=float, + help="Target 'a' to simulate for y = ax + b") +parser.add_argument("--target_b", default=0.3, type=float, + help="Target 'b' to simulate for y = ax + b") + + +def publish_training_data(kafka_server, kafka_topic, num_to_send, sleep_every_N, target_a, target_b): + kafka = KafkaClient(kafka_server) + producer = SimpleProducer(kafka) + + for c in range(num_to_send): + x, y = util.generate_linear_x_y_data(1, target_a, target_b, c) + + # print "%d %.20f %.20f" % (int(round(time.time() * 1000)), x[0], y[0]) + producer.send_messages( + kafka_topic, + "%d %.5f %.5f %.20f %.20f" % (util.now_millis(), target_a, target_b, x[0], y[0])) + if (c + 1) % sleep_every_N == 0: + time.sleep(1) + + +if __name__ == "__main__": + args = parser.parse_args() + publish_training_data(args.kafka_server, args.kafka_topic, + args.num_to_send, args.sleep_every_N, + args.target_a, args.target_b) diff --git a/demo/server_gateway.go b/demo/server_gateway.go new file mode 100644 index 00000000..5459c84c --- /dev/null +++ 
b/demo/server_gateway.go @@ -0,0 +1,41 @@ +package main + +import ( + "flag" + "net/http" + + "github.com/golang/glog" + "github.com/grpc-ecosystem/grpc-gateway/runtime" + "golang.org/x/net/context" + "google.golang.org/grpc" + + gw "github.com/alexbao/mobius/demo/go_gateway" +) + +var ( + rpcEndpoint = flag.String("rpc_endpoint", "localhost:50051", "endpoint of YourService") +) + +func run() error { + ctx := context.Background() + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + mux := runtime.NewServeMux() + opts := []grpc.DialOption{grpc.WithInsecure()} + err := gw.RegisterMobiusHandlerFromEndpoint(ctx, mux, *rpcEndpoint, opts) + if err != nil { + return err + } + + return http.ListenAndServe(":8080", mux) +} + +func main() { + flag.Parse() + defer glog.Flush() + + if err := run(); err != nil { + glog.Fatal(err) + } +} diff --git a/demo/server_rpc.py b/demo/server_rpc.py new file mode 100644 index 00000000..925cc88f --- /dev/null +++ b/demo/server_rpc.py @@ -0,0 +1,97 @@ +import argparse +import grpc +import ray +import time + +from concurrent import futures + +# import the generated classes +import proto +from proto import mobius_pb2, mobius_pb2_grpc + +import network +import online_learning +import util + +parser = argparse.ArgumentParser(description="Online learning service") + +parser.add_argument("--port", default=50051, type=int, + help="Port for RPC service") +parser.add_argument("--redis_address", default=None, type=str, + help="Redis service address") +parser.add_argument("--kafka_server", default="localhost:9092", type=str, + help="Kafka server") +parser.add_argument("--kafka_topic", default="kafkaesque", type=str, + help="Kafka topic") +parser.add_argument("--batch_size", default=100, type=int, + help="Batch size") +parser.add_argument("--max_batch_wait_millis", default=100, type=int, + help="Max wait millis per batch") +parser.add_argument("--verbose", default="False", type=str, + help="detailed message if true") + + +# create a class to define the server functions +# derived from mobius_pb2_grpc.MobiusServicer +class MobiusServicer(mobius_pb2_grpc.MobiusServicer): + def __init__(self, driver, verbose): + self._driver = driver + self._verbose = verbose + + # mobius.Infer is exposed here + # the request and response are of the data types + # generated as mobius_pb2.InferRequest and mobius_pb2.InferResponse. + def Infer(self, request, context): + x_data = [[x] for x in request.x] + if self._verbose: + print x_data + y_data = self._driver.infer(x_data) + if self._verbose: + print y_data + + response = mobius_pb2.InferResponse() + for ele in y_data: + response.y.extend(ele) + if self._verbose: + print response + return response + + +if __name__ == "__main__": + args = parser.parse_args() + if args.redis_address: + ray.init(redis_address=args.redis_address) + else: + ray.init() + + verbose = "true" == args.verbose.lower() + print verbose + remote_network = ray.remote(network.SimpleNetwork) + driver = online_learning.Driver( + remote_network, args.kafka_server, args.kafka_topic, + args.batch_size, args.max_batch_wait_millis, verbose) + train_actor = remote_network.remote() + + # create a gRPC server + server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) + + # use the generated function `add_MobiusServicer_to_server` + # to add the defined class to the created server + mobius_pb2_grpc.add_MobiusServicer_to_server( + MobiusServicer(driver, verbose), server) + + # listen on port 50051 + print('Starting server. Listening on port %d.' 
% args.port) + server.add_insecure_port('[::]:%d' % args.port) + server.start() + + print('Starting training') + driver.start() + + # since server.start() will not block, + # a sleep-loop is added to keep alive + try: + while True: + time.sleep(86400) + except KeyboardInterrupt: + server.stop(0) diff --git a/demo/train_local.py b/demo/train_local.py new file mode 100644 index 00000000..94d7515f --- /dev/null +++ b/demo/train_local.py @@ -0,0 +1,35 @@ +import argparse +import network +import online_learning +import ray +import util + +from kafka import KafkaConsumer + +parser = argparse.ArgumentParser(description="Online learning service") + +parser = argparse.ArgumentParser(description="Online learning service") + +parser.add_argument("--kafka_server", default="localhost:9092", type=str, + help="Kafka server") +parser.add_argument("--kafka_topic", default="kafkaesque", type=str, + help="Kafka topic") +parser.add_argument("--batch_size", default=100, type=int, + help="Batch size") +parser.add_argument("--max_batch_wait_millis", default=100, type=int, + help="Max wait millis per batch") +parser.add_argument("--verbose", default="False", type=str, + help="detailed message if true") + + +if __name__ == "__main__": + ray.init() + args = parser.parse_args() + remote_network = ray.remote(network.SimpleNetwork) + train_actor = remote_network.remote() + infer_actor = remote_network.remote() + verbose = "true" == args.verbose.lower() + online_learning.train( + train_actor, infer_actor, + args.kafka_server, args.kafka_topic, + args.batch_size, args.max_batch_wait_millis, verbose) diff --git a/demo/util.py b/demo/util.py new file mode 100644 index 00000000..53d896d1 --- /dev/null +++ b/demo/util.py @@ -0,0 +1,14 @@ +import numpy +import time + + +def now_millis(): + return int(round(time.time() * 1000)) + + +def generate_linear_x_y_data(num_data, a=0.1, b=0.3, seed=0): + # Seed numpy to make the script deterministic. + numpy.random.seed(seed) + x = numpy.random.rand(num_data) + y = x * a + b + return x, y