Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
sakoush committed Oct 8, 2021
1 parent cac87fc commit 0659c47
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions runtimes/alibi-explain/mlserver_alibi_explain/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@

from mlserver.types import ResponseOutput, InferenceResponse, InferenceRequest


EXPLAINER_TYPE_TAG = "explainer_type"

_MAX_RETRY_ATTEMPT = 3

_ANCHOR_IMAGE_TAG = "anchor_image"
_ANCHOR_TEXT_TAG = "anchor_text"
_INTEGRATED_GRADIENTS_TAG = "integrated_gradients"
Expand All @@ -33,6 +38,7 @@
ENV_PREFIX_ALIBI_EXPLAIN_SETTINGS = "MLSERVER_MODEL_ALIBI_EXPLAIN_"
EXPLAIN_PARAMETERS_TAG = "explain_parameters"


class ExplainerEnum(str, Enum):
anchor_image = _ANCHOR_IMAGE_TAG
anchor_text = _ANCHOR_TEXT_TAG
Expand Down Expand Up @@ -62,7 +68,7 @@ def convert_from_bytes(output: ResponseOutput, ty: Optional[Type]) -> Any:
return literal_eval(py_str)


@retry(stop=stop_after_attempt(3))
@retry(stop=stop_after_attempt(_MAX_RETRY_ATTEMPT))
def remote_predict(v2_payload: InferenceRequest, predictor_url: str) -> InferenceResponse:
response_raw = requests.post(predictor_url, json=v2_payload.dict())
if response_raw.status_code != 200:
Expand Down Expand Up @@ -111,6 +117,3 @@ def import_and_get_class(class_path: str) -> type:
last_dot = class_path.rfind(".")
klass = getattr(import_module(class_path[:last_dot]), class_path[last_dot + 1:])
return klass


EXPLAINER_TYPE_TAG = "explainer_type"

0 comments on commit 0659c47

Please sign in to comment.