Skip to content

Commit

Permalink
code review suggestions
Browse files Browse the repository at this point in the history
  • Loading branch information
RafalSkolasinski authored and seldondev committed Sep 18, 2020
1 parent 8be23a8 commit 31a5fad
Show file tree
Hide file tree
Showing 4 changed files with 152 additions and 171 deletions.
107 changes: 32 additions & 75 deletions python/seldon_core/seldon_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
)
from seldon_core.user_model import (
INCLUDE_METRICS_IN_CLIENT_RESPONSE,
SeldonPrediction,
ClientResponse,
client_predict,
client_aggregate,
client_route,
Expand Down Expand Up @@ -65,19 +65,6 @@ def handle_raw_custom_metrics(
seldon_metrics.update(metrics, method)


def extract_runtime_data(
client_response: Union[np.ndarray, List, str, bytes, SeldonPrediction]
) -> Tuple[Union[np.ndarray, List, str, bytes], List[Dict], Dict]:
"""Extracts runtime data from client response."""
if not isinstance(client_response, SeldonPrediction):
return client_response, [], {}

metrics = client_response.metrics if client_response.metrics is not None else []
tags = client_response.tags if client_response.tags is not None else {}

return client_response.data, metrics, tags


def predict(
user_model: Any,
request: Union[prediction_pb2.SeldonMessage, List, Dict],
Expand Down Expand Up @@ -118,12 +105,11 @@ def predict(

if is_proto:
(features, meta, datadef, data_type) = extract_request_parts(request)
client_response = client_predict(

client_response, runtime_metrics, runtime_tags = client_predict(
user_model, features, datadef.names, meta=meta
)
client_response, runtime_metrics, runtime_tags = extract_runtime_data(
client_response
)

metrics = client_custom_metrics(
user_model, seldon_metrics, PREDICT_METRIC_METHOD_TAG, runtime_metrics,
)
Expand All @@ -134,12 +120,11 @@ def predict(
else:
(features, meta, datadef, data_type) = extract_request_parts_json(request)
class_names = datadef["names"] if datadef and "names" in datadef else []
client_response = client_predict(

client_response, runtime_metrics, runtime_tags = client_predict(
user_model, features, class_names, meta=meta
)
client_response, runtime_metrics, runtime_tags = extract_runtime_data(
client_response
)

metrics = client_custom_metrics(
user_model, seldon_metrics, PREDICT_METRIC_METHOD_TAG, runtime_metrics,
)
Expand Down Expand Up @@ -197,12 +182,11 @@ def send_feedback(
request
)
routing = request.response.meta.routing.get(predictive_unit_id)
client_response = client_send_feedback(

client_response, runtime_metrics, runtime_tags = client_send_feedback(
user_model, features, datadef_request.names, reward, truth, routing
)
client_response, runtime_metrics, runtime_tags = extract_runtime_data(
client_response
)

metrics = client_custom_metrics(
user_model, seldon_metrics, FEEDBACK_METRIC_METHOD_TAG, runtime_metrics,
)
Expand Down Expand Up @@ -270,12 +254,11 @@ def transform_input(

if is_proto:
(features, meta, datadef, data_type) = extract_request_parts(request)
client_response = client_transform_input(

client_response, runtime_metrics, runtime_tags = client_transform_input(
user_model, features, datadef.names, meta=meta
)
client_response, runtime_metrics, runtime_tags = extract_runtime_data(
client_response
)

metrics = client_custom_metrics(
user_model,
seldon_metrics,
Expand All @@ -289,12 +272,11 @@ def transform_input(
else:
(features, meta, datadef, data_type) = extract_request_parts_json(request)
class_names = datadef["names"] if datadef and "names" in datadef else []
client_response = client_transform_input(

client_response, runtime_metrics, runtime_tags = client_transform_input(
user_model, features, class_names, meta=meta
)
client_response, runtime_metrics, runtime_tags = extract_runtime_data(
client_response
)

metrics = client_custom_metrics(
user_model,
seldon_metrics,
Expand Down Expand Up @@ -354,12 +336,11 @@ def transform_output(

if is_proto:
(features, meta, datadef, data_type) = extract_request_parts(request)
client_response = client_transform_output(

client_response, runtime_metrics, runtime_tags = client_transform_output(
user_model, features, datadef.names, meta=meta
)
client_response, runtime_metrics, runtime_tags = extract_runtime_data(
client_response
)

metrics = client_custom_metrics(
user_model,
seldon_metrics,
Expand All @@ -379,12 +360,11 @@ def transform_output(
else:
(features, meta, datadef, data_type) = extract_request_parts_json(request)
class_names = datadef["names"] if datadef and "names" in datadef else []
client_response = client_transform_output(

client_response, runtime_metrics, runtime_tags = client_transform_output(
user_model, features, class_names, meta=meta
)
client_response, runtime_metrics, runtime_tags = extract_runtime_data(
client_response
)

metrics = client_custom_metrics(
user_model,
seldon_metrics,
Expand All @@ -403,7 +383,6 @@ def route(
seldon_metrics: SeldonMetrics,
) -> Union[prediction_pb2.SeldonMessage, List, Dict]:
"""
Parameters
----------
user_model
Expand All @@ -412,10 +391,8 @@ def route(
A SelodonMessage proto
seldon_metrics
A SeldonMetrics instance
Returns
-------
"""
is_proto = isinstance(request, prediction_pb2.SeldonMessage)

Expand All @@ -441,55 +418,35 @@ def route(
client_response = client_route(
user_model, features, datadef.names, meta=meta
)
client_response, runtime_metrics, runtime_tags = extract_runtime_data(
client_response
)
if not isinstance(client_response, int):
raise SeldonMicroserviceException(
"Routing response must be int but got " + str(client_response)
)

client_response_arr = np.array([[client_response]])

metrics = client_custom_metrics(
user_model, seldon_metrics, ROUTER_METRIC_METHOD_TAG, runtime_metrics,
user_model, seldon_metrics, ROUTER_METRIC_METHOD_TAG
)

return construct_response(
user_model,
False,
request,
client_response_arr,
None,
metrics,
runtime_tags,
user_model, False, request, client_response_arr, None, metrics
)
else:
(features, meta, datadef, data_type) = extract_request_parts_json(request)
class_names = datadef["names"] if datadef and "names" in datadef else []
client_response = client_route(user_model, features, class_names, meta=meta)
client_response, runtime_metrics, runtime_tags = extract_runtime_data(
client_response
)
if not isinstance(client_response, int):
raise SeldonMicroserviceException(
"Routing response must be int but got " + str(client_response)
)

client_response_arr = np.array([[client_response]])

metrics = client_custom_metrics(
user_model, seldon_metrics, ROUTER_METRIC_METHOD_TAG, runtime_metrics
user_model, seldon_metrics, ROUTER_METRIC_METHOD_TAG
)

return construct_response_json(
user_model,
False,
request,
client_response_arr,
None,
metrics,
runtime_tags,
user_model, False, request, client_response_arr, None, metrics
)


Expand Down Expand Up @@ -561,10 +518,10 @@ def merge_metrics(meta_list, custom_metrics):
names_list.append(datadef.names)
meta_list.append(meta)

client_response = client_aggregate(user_model, features_list, names_list)
client_response, runtime_metrics, runtime_tags = extract_runtime_data(
client_response
client_response, runtime_metrics, runtime_tags = client_aggregate(
user_model, features_list, names_list
)

metrics = client_custom_metrics(
user_model,
seldon_metrics,
Expand Down Expand Up @@ -604,10 +561,10 @@ def merge_metrics(meta_list, custom_metrics):
names_list.append(class_names)
meta_list.append(meta)

client_response = client_aggregate(user_model, features_list, names_list)
client_response, runtime_metrics, runtime_tags = extract_runtime_data(
client_response
client_response, runtime_metrics, runtime_tags = client_aggregate(
user_model, features_list, names_list
)

metrics = client_custom_metrics(
user_model,
seldon_metrics,
Expand Down
Loading

0 comments on commit 31a5fad

Please sign in to comment.