diff --git a/components/alibi-explain-server/alibiexplainer/anchor_images.py b/components/alibi-explain-server/alibiexplainer/anchor_images.py index 179e8c1f16..937ba63e48 100644 --- a/components/alibi-explain-server/alibiexplainer/anchor_images.py +++ b/components/alibi-explain-server/alibiexplainer/anchor_images.py @@ -25,7 +25,11 @@ import numpy as np from alibi.api.interfaces import Explanation -from alibiexplainer.constants import SELDON_LOGLEVEL +from alibiexplainer.constants import ( + EXPLAIN_RANDOM_SEED, + EXPLAIN_RANDOM_SEED_VALUE, + SELDON_LOGLEVEL, +) from alibiexplainer.explainer_wrapper import ExplainerWrapper logging.basicConfig(level=SELDON_LOGLEVEL) @@ -38,9 +42,13 @@ def __init__( if explainer is None: raise Exception("Anchor images requires a built explainer") self.anchors_image = explainer + if EXPLAIN_RANDOM_SEED == "True" and EXPLAIN_RANDOM_SEED_VALUE.isdigit(): + self.seed = int(EXPLAIN_RANDOM_SEED_VALUE) self.kwargs = kwargs def explain(self, inputs: List) -> Explanation: + if self.seed: + np.random.seed(self.seed) arr = np.array(inputs) logging.info("Calling explain on image of shape %s", (arr.shape,)) logging.info("anchor image call with %s", self.kwargs) diff --git a/components/alibi-explain-server/alibiexplainer/anchor_tabular.py b/components/alibi-explain-server/alibiexplainer/anchor_tabular.py index 05370794cb..9907a5d0c3 100644 --- a/components/alibi-explain-server/alibiexplainer/anchor_tabular.py +++ b/components/alibi-explain-server/alibiexplainer/anchor_tabular.py @@ -25,7 +25,11 @@ import numpy as np from alibi.api.interfaces import Explanation -from alibiexplainer.constants import SELDON_LOGLEVEL +from alibiexplainer.constants import ( + EXPLAIN_RANDOM_SEED, + EXPLAIN_RANDOM_SEED_VALUE, + SELDON_LOGLEVEL, +) from alibiexplainer.explainer_wrapper import ExplainerWrapper logging.basicConfig(level=SELDON_LOGLEVEL) @@ -37,9 +41,13 @@ def __init__(self, explainer=Optional[alibi.explainers.AnchorTabular], **kwargs) raise Exception("Anchor images requires a built explainer") self.anchors_tabular: alibi.explainers.AnchorTabular = explainer self.anchors_tabular = explainer + if EXPLAIN_RANDOM_SEED == "True" and EXPLAIN_RANDOM_SEED_VALUE.isdigit(): + self.seed = int(EXPLAIN_RANDOM_SEED_VALUE) self.kwargs = kwargs def explain(self, inputs: List) -> Explanation: + if self.seed: + np.random.seed(self.seed) arr = np.array(inputs) # We assume the input has batch dimension # but Alibi explainers presently assume no batch diff --git a/components/alibi-explain-server/alibiexplainer/anchor_text.py b/components/alibi-explain-server/alibiexplainer/anchor_text.py index 1be0de8f1c..cea1b9e7b0 100644 --- a/components/alibi-explain-server/alibiexplainer/anchor_text.py +++ b/components/alibi-explain-server/alibiexplainer/anchor_text.py @@ -22,11 +22,16 @@ from typing import Callable, List, Optional import alibi +import numpy as np import spacy from alibi.api.interfaces import Explanation from alibi.utils.download import spacy_model -from alibiexplainer.constants import SELDON_LOGLEVEL +from alibiexplainer.constants import ( + EXPLAIN_RANDOM_SEED, + EXPLAIN_RANDOM_SEED_VALUE, + SELDON_LOGLEVEL, +) from alibiexplainer.explainer_wrapper import ExplainerWrapper logging.basicConfig(level=SELDON_LOGLEVEL) @@ -41,6 +46,8 @@ def __init__( **kwargs ): self.predict_fn = predict_fn + if EXPLAIN_RANDOM_SEED == "True" and EXPLAIN_RANDOM_SEED_VALUE.isdigit(): + self.seed = int(EXPLAIN_RANDOM_SEED_VALUE) self.kwargs = kwargs logging.info("Anchor Text args %s", self.kwargs) if explainer is None: @@ -55,5 +62,7 @@ def __init__( self.anchors_text = explainer def explain(self, inputs: List) -> Explanation: + if self.seed: + np.random.seed(self.seed) anchor_exp = self.anchors_text.explain(inputs[0], **self.kwargs) return anchor_exp diff --git a/components/alibi-explain-server/alibiexplainer/constants.py b/components/alibi-explain-server/alibiexplainer/constants.py index 98e8c651c7..c2dfc25318 100644 --- a/components/alibi-explain-server/alibiexplainer/constants.py +++ b/components/alibi-explain-server/alibiexplainer/constants.py @@ -1,3 +1,5 @@ import os SELDON_LOGLEVEL = os.environ.get("SELDON_LOGLEVEL", "INFO").upper() +EXPLAIN_RANDOM_SEED = os.environ.get("EXPLAIN_RANDOM_SEED", "True") +EXPLAIN_RANDOM_SEED_VALUE = os.environ.get("EXPLAIN_RANDOM_SEED_VALUE", 0)