Skip to content

Commit

Permalink
Reformat code with black
Browse files Browse the repository at this point in the history
  • Loading branch information
Sebastian Stawicki committed May 18, 2021
1 parent e600275 commit d35d48c
Showing 1 changed file with 24 additions and 15 deletions.
39 changes: 24 additions & 15 deletions servers/tfserving_proxy/TfServingProxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,21 +21,24 @@ class TfServingProxy(object):
"""

def __init__(
self,
rest_endpoint=None,
grpc_endpoint=None,
model_name=None,
signature_name=None,
model_input=None,
model_output=None):
self,
rest_endpoint=None,
grpc_endpoint=None,
model_name=None,
signature_name=None,
model_input=None,
model_output=None,
):
log.debug("rest_endpoint:", rest_endpoint)
log.debug("grpc_endpoint:", grpc_endpoint)

# grpc
max_msg = 1000000000
options = [('grpc.max_message_length', max_msg),
('grpc.max_send_message_length', max_msg),
('grpc.max_receive_message_length', max_msg)]
options = [
("grpc.max_message_length", max_msg),
("grpc.max_send_message_length", max_msg),
("grpc.max_receive_message_length", max_msg),
]
channel = grpc.insecure_channel(grpc_endpoint, options)
self.stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)

Expand All @@ -57,7 +60,9 @@ def predict_grpc(self, request):
log.debug("Preprocessing contents for predict function")
request_data_type = request.WhichOneof("data_oneof")
default_data_type = request.data.WhichOneof("data_oneof")
log.debug(f"Request data type: {request_data_type}, Default data type: {default_data_type}")
log.debug(
f"Request data type: {request_data_type}, Default data type: {default_data_type}"
)

if request_data_type not in ["data", "customData"]:
raise Exception("strData, binData and jsonData not supported.")
Expand All @@ -79,9 +84,8 @@ def _predict_grpc_data(self, request, default_data_type):
else:
data_arr = grpc_datadef_to_array(request.data)
tfrequest.inputs[self.model_input].CopyFrom(
tf.make_tensor_proto(
data_arr.tolist(),
shape=data_arr.shape))
tf.make_tensor_proto(data_arr.tolist(), shape=data_arr.shape)
)

# handle prediction
tfresponse = self._handle_grpc_prediction(tfrequest)
Expand Down Expand Up @@ -153,7 +157,12 @@ def predict(self, X, features_names=[]):
result = numpy.expand_dims(result, axis=0)
return result
else:
log.warning("Error from server: " + str(response) + " content: " + str(response.text))
log.warning(
"Error from server: "
+ str(response)
+ " content: "
+ str(response.text)
)
try:
return response.json()
except ValueError:
Expand Down

0 comments on commit d35d48c

Please sign in to comment.