Added tests for GAR with explainability
Signed-off-by: Stefano Savare <[email protected]>
deatinor committed Apr 23, 2020
1 parent 583819a commit 7bffc5f
Showing 8 changed files with 77 additions and 19 deletions.
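The headline change is a new test that fits each GAR-family forecaster with an explainer attached and checks the per-estimator explanations (see gtime/forecasting/tests/test_gar.py below). In usage terms, the behaviour under test looks roughly like the following sketch; the import path and the data preparation are illustrative assumptions, not code from this commit.

# Hedged usage sketch, assuming GAR is importable from gtime.forecasting and
# that X / y are feature and horizon-target DataFrames; the data below is
# purely illustrative.
import numpy as np
import pandas as pd
from sklearn.linear_model import LinearRegression
from gtime.forecasting import GAR

idx = pd.period_range(start="2020-01-01", periods=30, freq="D")
X = pd.DataFrame(np.random.random((30, 3)), index=idx, columns=["a", "b", "c"])
y = pd.concat({f"y_{k}": X["a"].shift(-k) for k in range(1, 4)}, axis=1)

X_train, y_train = X.iloc[:-3], y.iloc[:-3]  # rows where all targets are known
X_test = X.iloc[-3:-2]                       # a single row to predict

model = GAR(LinearRegression(), explainer_type="shap")
model.fit(X_train, y_train)
model.predict(X_test)

# One fitted estimator per prediction horizon; after predict, each estimator's
# explainer holds one explanation (a feature -> contribution dict) per row.
assert len(model.estimators_) == y_train.shape[1]
for estimator in model.estimators_:
    print(estimator.explainer_.explanations_[0])
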
1 change: 1 addition & 0 deletions gtime/explainability/explainer.py
@@ -60,6 +60,7 @@ class _LimeExplainer(_RegressorExplainer):
>>> explainer.explanations_[0]
{'d': -0.10406889434277307, 'c': 0.07973507022816899, 'b': 0.02312395991550859, 'a': 0.006403509251399996, 'e': 0.006272607738125953}
"""

def fit(
self, model: RegressorMixin, X: np.ndarray, feature_names: List[str] = None
):
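The `_LimeExplainer` docstring above ends with an `explanations_` lookup; pieced together with the `fit` signature shown in this hunk, direct use of an explainer presumably looks like the sketch below. The no-argument constructor and the `predict` step are assumptions, not shown in this diff.

# Hedged sketch of the explainer-level interface; only fit() and explanations_
# appear in this hunk, the rest is assumed.
import numpy as np
from sklearn.linear_model import LinearRegression
from gtime.explainability.explainer import _ShapExplainer

X = np.random.random((50, 5))
y = np.random.random(50)
model = LinearRegression().fit(X, y)

explainer = _ShapExplainer()
explainer.fit(model, X, feature_names=["a", "b", "c", "d", "e"])
explainer.predict(X[:1, :])
print(explainer.explanations_[0])  # {'a': ..., 'b': ..., ...} as in the docstring
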
4 changes: 3 additions & 1 deletion gtime/forecasting/gar.py
@@ -13,7 +13,9 @@
from gtime.regressors.multi_output import MultiFeatureMultiOutputRegressor


def initialize_estimator(estimator: RegressorMixin, explainer_type: Optional[str]) -> RegressorMixin:
def initialize_estimator(
estimator: RegressorMixin, explainer_type: Optional[str]
) -> RegressorMixin:
if explainer_type is None:
return estimator
else:
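The collapsed body of the `else` branch is not shown here; judging from `test_initialize_estimator_explainable` further down in this commit, it presumably wraps the base estimator in an `ExplainableRegressor`. A sketch of the inferred behaviour (import path assumed from the file layout):

# Inferred sketch of initialize_estimator; not the verbatim implementation.
from typing import Optional

from sklearn.base import RegressorMixin

from gtime.regressors.explainable import ExplainableRegressor


def initialize_estimator(
    estimator: RegressorMixin, explainer_type: Optional[str]
) -> RegressorMixin:
    if explainer_type is None:
        return estimator  # plain estimator, no explanations collected
    else:
        return ExplainableRegressor(estimator, explainer_type)
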
31 changes: 30 additions & 1 deletion gtime/forecasting/tests/test_gar.py
@@ -1,3 +1,4 @@
import itertools
import random
from typing import List

@@ -47,6 +48,34 @@
)


forecasters = [GAR, GARFF, MultiFeatureGAR]
explainers = [
"shap",
] # "lime"] for speed reason


@pytest.mark.parametrize(
"forecaster,explainer", itertools.product(forecasters, explainers)
)
@given(
X_y=X_y_matrices(
horizon=4,
df_transformer=df_transformer,
min_length=10,
allow_nan_infinity=False,
)
)
def test_predict_has_explainers(forecaster, explainer, X_y):
X, y = X_y
X_train, y_train, X_test, y_test = FeatureSplitter().transform(X, y)
model = forecaster(LinearRegression(), explainer_type=explainer)
model.fit(X_train, y_train)
model.predict(X_test.iloc[:1, :])
assert len(model.estimators_) == y_test.shape[1]
for estimator in model.estimators_:
assert len(estimator.explainer_.explanations_) == 1


@pytest.fixture
def time_series():
testing.N, testing.K = 200, 1
@@ -131,7 +160,7 @@ def test_initialize_estimator(estimator):

@given(models())
def test_initialize_estimator_explainable(estimator):
explainable_estimator = initialize_estimator(estimator, explainer_type='shap')
explainable_estimator = initialize_estimator(estimator, explainer_type="shap")
assert isinstance(explainable_estimator, ExplainableRegressor)
assert isinstance(explainable_estimator.explainer, _ShapExplainer)

8 changes: 5 additions & 3 deletions gtime/regressors/explainable.py
@@ -53,8 +53,8 @@ def __init__(self, estimator: RegressorMixin, explainer_type: str):
self.explainer = self._initialize_explainer()

def _check_estimator(self, estimator: RegressorMixin) -> RegressorMixin:
if not hasattr(estimator, 'fit') or not hasattr(estimator, 'predict'):
raise TypeError(f'Estimator not compatible: {estimator}')
if not hasattr(estimator, "fit") or not hasattr(estimator, "predict"):
raise TypeError(f"Estimator not compatible: {estimator}")
return estimator

def _initialize_explainer(self) -> Union[_LimeExplainer, _ShapExplainer]:
@@ -82,7 +82,9 @@ def fit(self, X: np.ndarray, y: np.ndarray, feature_names: List[str] = None):
Fitted `ExplainableRegressor`
"""
self.estimator_ = self.estimator.fit(X, y)
self.explainer_ = self.explainer.fit(self.estimator_, X, feature_names=feature_names)
self.explainer_ = self.explainer.fit(
self.estimator_, X, feature_names=feature_names
)
return self

def predict(self, X: np.ndarray):
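For context, the same `ExplainableRegressor` can be driven directly, which is what the tests in the next file do. A minimal sketch, assuming the import path mirrors the file location and using synthetic numpy data:

# Minimal standalone sketch; the data and the "shap" choice are illustrative.
import numpy as np
from sklearn.linear_model import LinearRegression
from gtime.regressors.explainable import ExplainableRegressor

X = np.random.random((100, 5))
y = np.random.random(100)

regressor = ExplainableRegressor(LinearRegression(), explainer_type="shap")
predictions = regressor.fit(X, y).predict(X[:1, :])

# After predict, the fitted explainer exposes one feature-contribution dict
# per predicted row, mirroring the assertions in the tests below.
print(predictions.shape)
print(regressor.explainer_.explanations_[0])
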
5 changes: 2 additions & 3 deletions gtime/regressors/tests/test_explainable.py
@@ -39,7 +39,7 @@ def test_constructor(self, estimator, explainer_type):
@given(estimator=regressors())
def test_constructor_bad_explainer(self, estimator):
with pytest.raises(ValueError):
ExplainableRegressor(estimator, 'bad')
ExplainableRegressor(estimator, "bad")

@pytest.mark.parametrize("explainer_type", ["lime", "shap"])
@given(bad_estimator=bad_regressors())
@@ -84,12 +84,11 @@ def test_fit_values(self, estimator, explainer_type, X_y):
)
def test_predict_values(self, estimator, explainer_type, X_y):
X, y = X_y
X_test = X[:2, :]
X_test = X[:1, :]
regressor = ExplainableRegressor(estimator, explainer_type)
regressor_predictions = regressor.fit(X, y).predict(X_test)

cloned_estimator = clone(estimator)
estimator_predictions = cloned_estimator.fit(X, y).predict(X_test)

assert regressor_predictions.shape == estimator_predictions.shape

38 changes: 32 additions & 6 deletions gtime/regressors/tests/test_multi_output.py
@@ -68,7 +68,11 @@ def test_constructor(self, estimator):

@given(
data=data(),
X_y=numpy_X_y_matrices(X_y_shapes=shape_X_y_matrices(y_as_vector=False), min_value=-10000, max_value=10000),
X_y=numpy_X_y_matrices(
X_y_shapes=shape_X_y_matrices(y_as_vector=False),
min_value=-10000,
max_value=10000,
),
)
def test_fit_bad_y(self, data, estimator, X_y):
X, y = X_y
@@ -84,7 +88,13 @@ def test_fit_bad_y(self, data, estimator, X_y):
X, y, target_to_features_dict=target_to_feature_dict
)

@given(X_y=numpy_X_y_matrices(X_y_shapes=shape_X_y_matrices(y_as_vector=False), min_value=-10000, max_value=10000))
@given(
X_y=numpy_X_y_matrices(
X_y_shapes=shape_X_y_matrices(y_as_vector=False),
min_value=-10000,
max_value=10000,
)
)
def test_fit_as_multi_output_regressor_if_target_to_feature_none(
self, estimator, X_y
):
@@ -110,7 +120,11 @@ def test_error_predict_with_no_fit(self, estimator, X):

@given(
data=data(),
X_y=numpy_X_y_matrices(X_y_shapes=shape_X_y_matrices(y_as_vector=False), min_value=-10000, max_value=10000),
X_y=numpy_X_y_matrices(
X_y_shapes=shape_X_y_matrices(y_as_vector=False),
min_value=-10000,
max_value=10000,
),
)
def test_fit_target_to_feature_dict_working(self, data, X_y, estimator):
X, y = X_y
@@ -126,7 +140,11 @@ def test_fit_target_to_feature_dict_working(self, data, X_y, estimator):

@given(
data=data(),
X_y=numpy_X_y_matrices(X_y_shapes=shape_X_y_matrices(y_as_vector=False), min_value=-10000, max_value=10000),
X_y=numpy_X_y_matrices(
X_y_shapes=shape_X_y_matrices(y_as_vector=False),
min_value=-10000,
max_value=10000,
),
)
def test_fit_target_to_feature_dict_consistent(self, data, X_y, estimator):
X, y = X_y
@@ -147,7 +165,11 @@ def test_fit_target_to_feature_dict_consistent(self, data, X_y, estimator):

@given(
data=data(),
X_y=numpy_X_y_matrices(X_y_shapes=shape_X_y_matrices(y_as_vector=False), min_value=-10000, max_value=10000),
X_y=numpy_X_y_matrices(
X_y_shapes=shape_X_y_matrices(y_as_vector=False),
min_value=-10000,
max_value=10000,
),
)
def test_predict_target_to_feature_dict(self, data, X_y, estimator):
X, y = X_y
@@ -165,7 +187,11 @@ def test_predict_target_to_feature_dict(self, data, X_y, estimator):

@given(
data=data(),
X_y=numpy_X_y_matrices(X_y_shapes=shape_X_y_matrices(y_as_vector=False), min_value=-10000, max_value=10000),
X_y=numpy_X_y_matrices(
X_y_shapes=shape_X_y_matrices(y_as_vector=False),
min_value=-10000,
max_value=10000,
),
)
def test_error_predict_target_to_feature_dict_wrong_X_shape(
self, data, X_y, estimator
4 changes: 1 addition & 3 deletions gtime/utils/hypothesis/general_strategies.py
@@ -22,9 +22,7 @@ def ordered_pair(min_value: int, max_value: int):


def shape_vector(min_shape=30, max_shape=200):
return tuples(
integers(min_shape, max_shape)
)
return tuples(integers(min_shape, max_shape))


def shape_matrix(min_shape_0=30, max_shape_0=200, min_shape_1=5, max_shape_1=10):
5 changes: 3 additions & 2 deletions requirements.txt
@@ -1,6 +1,7 @@
-lime
 pandas>=0.25.3
 workalendar>=7.1.1
 scipy>=0.17.0
 scikit-learn>=0.22.0
-matplotlib>=3.1.0
+matplotlib>=3.1.0
+lime>=0.2.0.0
+shap>=0.35
