wrapper.py
import grpc
from concurrent import futures
from flask import jsonify, Flask, send_from_directory, request
from flask_cors import CORS
import logging
from seldon_core.utils import json_to_seldon_message, seldon_message_to_json, json_to_feedback, json_to_seldon_messages
from seldon_core.flask_utils import get_request
import seldon_core.seldon_methods
from seldon_core.flask_utils import SeldonMicroserviceException, ANNOTATION_GRPC_MAX_MSG_SIZE
from seldon_core.proto import prediction_pb2_grpc
import os
logger = logging.getLogger(__name__)

PRED_UNIT_ID = os.environ.get("PREDICTIVE_UNIT_ID", "0")


def get_rest_microservice(user_model):
    app = Flask(__name__, static_url_path='')
    CORS(app)

    if hasattr(user_model, 'model_error_handler'):
        logger.info('Registering the custom error handler...')
        app.register_blueprint(user_model.model_error_handler)

    @app.errorhandler(SeldonMicroserviceException)
    def handle_invalid_usage(error):
        response = jsonify(error.to_dict())
        logger.error("%s", error.to_dict())
        response.status_code = error.status_code
        return response

    @app.route("/seldon.json", methods=["GET"])
    def openAPI():
        return send_from_directory('', "openapi/seldon.json")

    @app.route("/predict", methods=["GET", "POST"])
    def Predict():
        requestJson = get_request()
        logger.debug("REST Request: %s", request)
        response = seldon_core.seldon_methods.predict(user_model, requestJson)
        logger.debug("REST Response: %s", response)
        return jsonify(response)

    @app.route("/send-feedback", methods=["GET", "POST"])
    def SendFeedback():
        requestJson = get_request()
        logger.debug("REST Request: %s", request)
        requestProto = json_to_feedback(requestJson)
        logger.debug("Proto Request: %s", requestProto)
        responseProto = seldon_core.seldon_methods.send_feedback(user_model, requestProto, PRED_UNIT_ID)
        jsonDict = seldon_message_to_json(responseProto)
        return jsonify(jsonDict)

    @app.route("/transform-input", methods=["GET", "POST"])
    def TransformInput():
        requestJson = get_request()
        logger.debug("REST Request: %s", request)
        response = seldon_core.seldon_methods.transform_input(
            user_model, requestJson)
        logger.debug("REST Response: %s", response)
        return jsonify(response)

    @app.route("/transform-output", methods=["GET", "POST"])
    def TransformOutput():
        requestJson = get_request()
        logger.debug("REST Request: %s", request)
        response = seldon_core.seldon_methods.transform_output(
            user_model, requestJson)
        logger.debug("REST Response: %s", response)
        return jsonify(response)

    @app.route("/route", methods=["GET", "POST"])
    def Route():
        requestJson = get_request()
        logger.debug("REST Request: %s", request)
        response = seldon_core.seldon_methods.route(
            user_model, requestJson)
        logger.debug("REST Response: %s", response)
        return jsonify(response)

    @app.route("/aggregate", methods=["GET", "POST"])
    def Aggregate():
        requestJson = get_request()
        logger.debug("REST Request: %s", request)
        response = seldon_core.seldon_methods.aggregate(
            user_model, requestJson)
        logger.debug("REST Response: %s", response)
        return jsonify(response)

    return app
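# ---------------------------------------------------------------------------
# Usage sketch (added commentary, not part of the original module): the Flask
# app returned by get_rest_microservice() is normally launched by Seldon's
# microservice entrypoint, but it can also be served directly for local
# testing. The UserModel class here is a hypothetical example component that
# follows the predict() contract expected by seldon_core.seldon_methods.
#
#     class UserModel:
#         def predict(self, X, feature_names=None):
#             return X
#
#     app = get_rest_microservice(UserModel())
#     app.run(host="0.0.0.0", port=5000)
# ---------------------------------------------------------------------------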
# ----------------------------
# GRPC
# ----------------------------
class SeldonModelGRPC(object):
    def __init__(self, user_model):
        self.user_model = user_model

    def Predict(self, request_grpc, context):
        return seldon_core.seldon_methods.predict(self.user_model, request_grpc)

    def SendFeedback(self, feedback_grpc, context):
        return seldon_core.seldon_methods.send_feedback(self.user_model, feedback_grpc, PRED_UNIT_ID)

    def TransformInput(self, request_grpc, context):
        return seldon_core.seldon_methods.transform_input(self.user_model, request_grpc)

    def TransformOutput(self, request_grpc, context):
        return seldon_core.seldon_methods.transform_output(self.user_model, request_grpc)

    def Route(self, request_grpc, context):
        return seldon_core.seldon_methods.route(self.user_model, request_grpc)

    def Aggregate(self, request_grpc, context):
        return seldon_core.seldon_methods.aggregate(self.user_model, request_grpc)


def get_grpc_server(user_model, annotations={}, trace_interceptor=None):
    seldon_model = SeldonModelGRPC(user_model)
    options = []
    if ANNOTATION_GRPC_MAX_MSG_SIZE in annotations:
        max_msg = int(annotations[ANNOTATION_GRPC_MAX_MSG_SIZE])
        logger.info(
            "Setting grpc max message and receive length to %d", max_msg)
        options.append(('grpc.max_message_length', max_msg))
        options.append(('grpc.max_send_message_length', max_msg))
        options.append(('grpc.max_receive_message_length', max_msg))

    server = grpc.server(futures.ThreadPoolExecutor(
        max_workers=10), options=options)

    if trace_interceptor:
        from grpc_opentracing.grpcext import intercept_server
        server = intercept_server(server, trace_interceptor)

    prediction_pb2_grpc.add_GenericServicer_to_server(seldon_model, server)
    prediction_pb2_grpc.add_ModelServicer_to_server(seldon_model, server)

    return server