Merge pull request #3209 from qed-software/tfserving_proxy_custom_data
Tfserving proxy custom data
ukclivecox authored and RafalSkolasinski committed May 20, 2021
1 parent 501732c commit fd4527e
Showing 1 changed file with 89 additions and 46 deletions.
135 changes: 89 additions & 46 deletions servers/tfserving_proxy/TfServingProxy.py
import json
import logging

import grpc
import numpy
import requests
import tensorflow as tf

from google.protobuf.any_pb2 import Any
from seldon_core.proto import prediction_pb2
from seldon_core.utils import grpc_datadef_to_array
from tensorflow.python.saved_model import signature_constants
from tensorflow_serving.apis import predict_pb2
from tensorflow_serving.apis import prediction_service_pb2_grpc

log = logging.getLogger()


class TfServingProxy(object):
"""
A basic tensorflow serving proxy
"""

    def __init__(
        self,
        rest_endpoint=None,
        grpc_endpoint=None,
        model_name=None,
        signature_name=None,
        model_input=None,
        model_output=None,
    ):
        log.debug(f"rest_endpoint: {rest_endpoint}")
        log.debug(f"grpc_endpoint: {grpc_endpoint}")

        # grpc
        max_msg = 1000000000
        options = [
            ("grpc.max_message_length", max_msg),
            ("grpc.max_send_message_length", max_msg),
            ("grpc.max_receive_message_length", max_msg),
        ]
        channel = grpc.insecure_channel(grpc_endpoint, options)
        self.stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)

        # rest
        self.rest_endpoint = rest_endpoint + "/v1/models/" + model_name + ":predict"
        self.model_name = model_name
        if signature_name is None:
            self.signature_name = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
        else:
            self.signature_name = signature_name
        self.model_input = model_input
        self.model_output = model_output

    def predict_grpc(self, request):
        """
        predict_grpc is called only when there is a gRPC call to the server;
        in that case the request is sent on to the TFServing server directly.
        """
        log.debug("Preprocessing contents for predict function")
        request_data_type = request.WhichOneof("data_oneof")
        default_data_type = request.data.WhichOneof("data_oneof")
        log.debug(
            f"Request data type: {request_data_type}, Default data type: {default_data_type}"
        )

        if request_data_type not in ["data", "customData"]:
            raise Exception("strData, binData and jsonData not supported.")

        if request_data_type == "data":
            result = self._predict_grpc_data(request, default_data_type)
        else:
            result = self._predict_grpc_custom_data(request)

        return result

    def _predict_grpc_data(self, request, default_data_type):
        tfrequest = predict_pb2.PredictRequest()
        tfrequest.model_spec.name = self.model_name
        tfrequest.model_spec.signature_name = self.signature_name

        # handle input
        if default_data_type == "tftensor":
            # For the gRPC case, if we have a TFTensor message we can pass it directly
            tfrequest.inputs[self.model_input].CopyFrom(request.data.tftensor)
        else:
            data_arr = grpc_datadef_to_array(request.data)
            tfrequest.inputs[self.model_input].CopyFrom(
                tf.make_tensor_proto(data_arr.tolist(), shape=data_arr.shape)
            )

        # handle prediction
        tfresponse = self._handle_grpc_prediction(tfrequest)

        # handle result
        datadef = prediction_pb2.DefaultData(
            tftensor=tfresponse.outputs[self.model_output]
        )

        return prediction_pb2.SeldonMessage(data=datadef)

    def _predict_grpc_custom_data(self, request):
        tfrequest = predict_pb2.PredictRequest()

        # handle input
        #
        # Unpack custom data into tfrequest, taking the raw inputs prepared by the
        # user. This supports the case where the model's input is not a single
        # tftensor but a map of tensors, as defined in predict.proto:
        #   PredictRequest.inputs: map<string, TensorProto>
        request.customData.Unpack(tfrequest)

        # handle prediction
        tfresponse = self._handle_grpc_prediction(tfrequest)

        # handle result
        #
        # Pack tfresponse into the SeldonMessage's custom data, letting the user
        # handle the raw outputs. This supports the case where the model's output
        # is not a single tftensor but a map of tensors, as defined in predict.proto:
        #   PredictResponse.outputs: map<string, TensorProto>
        custom_data = Any()
        custom_data.Pack(tfresponse)
        return prediction_pb2.SeldonMessage(customData=custom_data)

    def _handle_grpc_prediction(self, tfrequest):
        tfrequest.model_spec.name = self.model_name
        tfrequest.model_spec.signature_name = self.signature_name
        tfresponse = self.stub.Predict(tfrequest)
        return tfresponse

    def predict(self, X, features_names=[]):
        """
@@ -99,7 +138,7 @@ def predict(self, X, features_names=[]):
            data = X
        else:
            log.debug(f"Data Request: {X}")
            data = {"instances": X.tolist()}
            if self.signature_name is not None:
                data["signature_name"] = self.signature_name

@@ -118,9 +157,13 @@ def predict(self, X, features_names=[]):
                result = numpy.expand_dims(result, axis=0)
            return result
        else:
            log.warning(
                "Error from server: "
                + str(response)
                + " content: "
                + str(response.text)
            )
            try:
                return response.json()
            except ValueError:
                return response.text
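
For reference, here is a minimal client-side sketch of the customData path added above. It is not part of this commit; the tensor names "image" and "mask" and the helper function names are illustrative assumptions. It shows how a caller could pack a raw PredictRequest carrying a map of named input tensors into SeldonMessage.customData, and unpack the raw PredictResponse from the reply.

import tensorflow as tf
from google.protobuf.any_pb2 import Any
from seldon_core.proto import prediction_pb2
from tensorflow_serving.apis import predict_pb2


def build_custom_data_message():
    # A map of named input tensors, which the single-tensor `data` field
    # of SeldonMessage cannot express (tensor names are illustrative).
    tfrequest = predict_pb2.PredictRequest()
    tfrequest.inputs["image"].CopyFrom(
        tf.make_tensor_proto([[0.1, 0.2]], shape=[1, 2])
    )
    tfrequest.inputs["mask"].CopyFrom(
        tf.make_tensor_proto([[1.0, 0.0]], shape=[1, 2])
    )

    # Pack the raw request into customData; TfServingProxy Unpack()s it and
    # forwards it to TFServing unchanged.
    custom_data = Any()
    custom_data.Pack(tfrequest)
    return prediction_pb2.SeldonMessage(customData=custom_data)


def read_custom_data_response(response):
    # Unpack the raw PredictResponse that the proxy Pack()ed into the returned
    # SeldonMessage, exposing every named output tensor.
    tfresponse = predict_pb2.PredictResponse()
    response.customData.Unpack(tfresponse)
    return {name: tf.make_ndarray(proto) for name, proto in tfresponse.outputs.items()}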
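
Likewise, a sketch of the kind of REST request that predict() builds, assuming a hypothetical TFServing instance at localhost:8501 serving a model named "mymodel"; the payload follows the standard TFServing REST predict API.

import requests

url = "http://localhost:8501" + "/v1/models/" + "mymodel" + ":predict"
payload = {"instances": [[0.1, 0.2]], "signature_name": "serving_default"}

response = requests.post(url, json=payload)
if response.status_code == 200:
    # TFServing returns {"predictions": [...]} on success.
    print(response.json()["predictions"])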
