Skip to content

Commit

Permalink
Xgb bugfix (#267)
Browse files Browse the repository at this point in the history
* fixes #250 and extents tests around xgb models

* integrates intercept of xgb models directly into values of the TreeModel
  • Loading branch information
mmschlk authored Nov 7, 2024
1 parent 0b02599 commit bdf9ce9
Show file tree
Hide file tree
Showing 8 changed files with 219 additions and 20 deletions.
9 changes: 9 additions & 0 deletions shapiq/explainer/tree/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,10 @@ def __post_init__(self) -> None:
self.leaf_mask = np.asarray(self.children_left == -1)
# sanitize features
self.features = np.where(self.leaf_mask, -2, self.features)
self.features = self.features.astype(int) # make features integer type
# sanitize thresholds
self.thresholds = np.where(self.leaf_mask, np.nan, self.thresholds)
# self.thresholds = np.round(self.thresholds, 4) # round thresholds
# setup empty prediction
if self.empty_prediction is None:
self.compute_empty_prediction()
Expand Down Expand Up @@ -118,6 +120,13 @@ def __post_init__(self) -> None:
# setup new feature mapping
if self.feature_map_internal_original is None:
self.feature_map_internal_original = {i: i for i in unique_features}
# flatten values if necessary
if self.values.ndim > 1:
if self.values.shape[1] != 1:
raise ValueError("Values array has more than one column.")
self.values = self.values.flatten()
# set all values of non leaf nodes to zero
self.values[~self.leaf_mask] = 0

def reduce_feature_complexity(self) -> None:
"""Reduces the feature complexity of the tree model.
Expand Down
46 changes: 34 additions & 12 deletions shapiq/explainer/tree/conversion/xgboost.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Functions for converting xgboost decision trees to the format used by
shapiq."""

import warnings
from typing import Optional

import numpy as np
Expand All @@ -25,20 +26,23 @@ def convert_xgboost_booster(
Returns:
The converted xgboost booster.
"""
try:
intercept = tree_booster.base_score
if intercept is None:
intercept = float(tree_booster.intercept_[0])
tree_booster = tree_booster.get_booster()
except AttributeError:
intercept = 0.0
warnings.warn(
"The model does not have a valid base score. Setting the intercept to 0.0."
"Baseline values of the interaction models might be different."
)

# https://github.com/shap/shap/blob/77e92c3c110e816b768a0ec2acfbf4cc08ee13db/shap/explainers/_tree.py#L1992
scaling = 1.0
booster_df = tree_booster.trees_to_dataframe()
output_type = "raw"
if len(booster_df["Tree"].unique()) > tree_booster.num_boosted_rounds():
# choose only trees for the selected class (xgboost grows n_estimators*n_class trees)
# approximation for the number of classes in xgboost
n_class = int(len(booster_df["Tree"].unique()) / tree_booster.num_boosted_rounds())
if class_label is None:
class_label = 0
idc = booster_df["Tree"] % n_class == class_label
booster_df = booster_df.loc[idc, :]

#
if tree_booster.feature_names:
feature_names = tree_booster.feature_names
else:
Expand All @@ -52,14 +56,29 @@ def convert_xgboost_booster(
booster_df.loc[:, "Feature"] = booster_df.loc[:, "Feature"].replace(
convert_feature_str_to_int
)

if len(booster_df["Tree"].unique()) > tree_booster.num_boosted_rounds():
# choose only trees for the selected class (xgboost grows n_estimators*n_class trees)
# approximation for the number of classes in xgboost
n_class = int(len(booster_df["Tree"].unique()) / tree_booster.num_boosted_rounds())
if class_label is None:
class_label = 0
idc = booster_df["Tree"] % n_class == class_label
booster_df = booster_df.loc[idc, :]

n_trees = len(booster_df["Tree"].unique())
intercept /= n_trees
return [
_convert_xgboost_tree_as_df(tree_df=tree_df, output_type=output_type, scaling=scaling)
_convert_xgboost_tree_as_df(
tree_df=tree_df, intercept=intercept, output_type=output_type, scaling=scaling
)
for _, tree_df in booster_df.groupby("Tree")
]


def _convert_xgboost_tree_as_df(
tree_df: Model,
intercept: float,
output_type: str,
scaling: float = 1.0,
) -> TreeModel:
Expand All @@ -77,7 +96,8 @@ def _convert_xgboost_tree_as_df(

# pandas can't chill https://stackoverflow.com/q/77900971
with pd.option_context("future.no_silent_downcasting", True):
return TreeModel(
values = tree_df["Gain"].values * scaling + intercept # add intercept to all values
tree_model = TreeModel(
children_left=tree_df["Yes"]
.replace(convert_node_str_to_int)
.infer_objects(copy=False)
Expand All @@ -92,8 +112,10 @@ def _convert_xgboost_tree_as_df(
.values,
features=tree_df["Feature"].values,
thresholds=tree_df["Split"].values,
values=tree_df["Gain"].values * scaling, # values in non-leaf nodes are not used
values=values, # values in non-leaf nodes are not used
node_sample_weight=tree_df["Cover"].values,
empty_prediction=None,
original_output_type=output_type,
)

return tree_model
20 changes: 15 additions & 5 deletions shapiq/explainer/tree/explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,7 @@ def __init__(
self._treeshapiq_explainers: list[TreeSHAPIQ] = [
TreeSHAPIQ(model=_tree, max_order=self._max_order, index=index) for _tree in self._trees
]

# TODO: for the current implementation this is correct for other trees this may vary
self.baseline_value = sum(
[treeshapiq.empty_prediction for treeshapiq in self._treeshapiq_explainers]
)
self.baseline_value = self._compute_baseline_value()

def explain(self, x: np.ndarray) -> InteractionValues:
# run treeshapiq for all trees
Expand All @@ -90,3 +86,17 @@ def explain(self, x: np.ndarray) -> InteractionValues:
for i in range(1, len(interaction_values)):
final_explanation += interaction_values[i]
return final_explanation

def _compute_baseline_value(self) -> float:
"""Computes the baseline value for the explainer.
The baseline value is the sum of the empty predictions of all trees in the ensemble.
Returns:
The baseline value for the explainer.
"""

baseline_value = sum(
[treeshapiq.empty_prediction for treeshapiq in self._treeshapiq_explainers]
)
return baseline_value
6 changes: 5 additions & 1 deletion shapiq/explainer/tree/treeshapiq.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,13 @@ def __init__(
self._edge_tree: EdgeTree = copy.deepcopy(edge_tree)

# compute the empty prediction
self.empty_prediction: float = float(
computed_empty_prediction = float(
np.sum(self._edge_tree.empty_predictions[self._tree.leaf_mask])
)
tree_empty_prediction = self._tree.empty_prediction
if tree_empty_prediction is None:
tree_empty_prediction = computed_empty_prediction
self.empty_prediction: float = tree_empty_prediction

# stores the interaction scores up to a given order
self.subset_ancestors_store: dict = {}
Expand Down
4 changes: 2 additions & 2 deletions shapiq/explainer/tree/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ def validate_tree_model(
# tree model (is already in the correct format)
if type(model).__name__ == "TreeModel":
tree_model = model
elif isinstance(model, list) and all([type(m).__name__ == "TreeModel" for m in model]):
tree_model = model
# dict as model is parsed to TreeModel (the dict needs to have the correct format and names)
elif type(model).__name__ == "dict":
tree_model = TreeModel(**model)
Expand Down Expand Up @@ -73,8 +75,6 @@ def validate_tree_model(
elif safe_isinstance(model, "xgboost.sklearn.XGBRegressor") or safe_isinstance(
model, "xgboost.sklearn.XGBClassifier"
):
tree_model = convert_xgboost_booster(model.get_booster(), class_label=class_label)
elif safe_isinstance(model, "xgboost.core.Booster"):
tree_model = convert_xgboost_booster(model, class_label=class_label)
# unsupported model
else:
Expand Down
3 changes: 3 additions & 0 deletions shapiq/explainer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,9 @@ def get_predict_function_and_model_type(model, model_class):
elif isinstance(model, tree.TreeModel): # test scenario
_predict_function = model.compute_empty_prediction
_model_type = "tree"
elif isinstance(model, list) and all([isinstance(m, tree.TreeModel) for m in model]):
_predict_function = model[0].compute_empty_prediction
_model_type = "tree"
elif _predict_function is None:
raise TypeError(
f"`model` is of unsupported type: {model_class}.\n"
Expand Down
30 changes: 30 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,17 @@ def rf_clf_model() -> RandomForestClassifier:
return model


@pytest.fixture
def xgb_reg_model():
"""Return a simple xgboost regression model."""
from xgboost import XGBRegressor

X, y = make_regression(n_samples=100, n_features=7, random_state=42)
model = XGBRegressor(random_state=42, n_estimators=3)
model.fit(X, y)
return model


@pytest.fixture
def rf_clf_binary_model() -> RandomForestClassifier:
"""Return a simple random forest model."""
Expand All @@ -102,6 +113,25 @@ def rf_clf_binary_model() -> RandomForestClassifier:
return model


@pytest.fixture
def xgb_clf_model():
"""Return a simple xgboost classification model."""
from xgboost import XGBClassifier

X, y = make_classification(
n_samples=100,
n_features=7,
random_state=42,
n_classes=3,
n_informative=7,
n_repeated=0,
n_redundant=0,
)
model = XGBClassifier(random_state=42, n_estimators=3)
model.fit(X, y)
return model


@pytest.fixture
def background_reg_data() -> np.ndarray:
"""Return a simple background dataset."""
Expand Down
121 changes: 121 additions & 0 deletions tests/tests_explainer/tests_tree_explainer/test_tree_explainer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""This test module contains all tests for the tree explainer module of the shapiq package."""

import copy

import numpy as np
import pytest

Expand Down Expand Up @@ -121,3 +123,122 @@ def test_against_shap_implementation():

with pytest.warns(UserWarning):
_ = TreeExplainer(model=tree_model, max_order=2, min_order=1, index="SV")


def test_xgboost_reg(xgb_reg_model, background_reg_data):
"""Tests the shapiq implementation of TreeSHAP agains SHAP's implementation for XGBoost."""

explanation_instance = 0

# the following code is used to get the shap values from the SHAP implementation
# import shap
# explainer_shap = shap.TreeExplainer(model=xgb_reg_model)
# x_explain_shap = background_reg_data[explanation_instance].reshape(1, -1)
# sv_shap = explainer_shap.shap_values(x_explain_shap)[0]
sv_shap = [-2.555832, 28.50987, 1.7708225, -7.8653603, 10.7955885, -0.1877861, 4.549199]
sv_shap = np.asarray(sv_shap)

# compute with shapiq
explainer_shapiq = TreeExplainer(model=xgb_reg_model, max_order=1, index="SV")
x_explain_shapiq = background_reg_data[explanation_instance]
sv_shapiq = explainer_shapiq.explain(x=x_explain_shapiq)
sv_shapiq_values = sv_shapiq.get_n_order_values(1)
baseline_shapiq = sv_shapiq.baseline_value

assert np.allclose(sv_shap, sv_shapiq_values, rtol=1e-5)

# get prediction of the model
prediction = xgb_reg_model.predict(x_explain_shapiq.reshape(1, -1))
assert prediction == pytest.approx(baseline_shapiq + np.sum(sv_shapiq_values), rel=1e-5)


def test_xgboost_clf(xgb_clf_model, background_clf_data):
"""Tests the shapiq implementation of TreeSHAP agains SHAP's implementation for XGBoost."""

explanation_instance = 1
class_label = 1

# the following code is used to get the shap values from the SHAP implementation
# import shap
# model_copy = copy.deepcopy(xgb_clf_model)
# explainer_shap = shap.TreeExplainer(model=model_copy)
# baseline_shap = float(explainer_shap.expected_value[class_label])
# print(baseline_shap)
# x_explain_shap = copy.deepcopy(background_clf_data[explanation_instance].reshape(1, -1))
# sv_shap_all_classes = explainer_shap.shap_values(x_explain_shap)
# sv_shap = sv_shap_all_classes[0][:, class_label]
# print(sv_shap)
sv = [-0.00545454, -0.15837783, -0.17675081, -0.24213657, 0.00247543, 0.00988865, -0.01564346]
sv_shap = np.array(sv)

# compute with shapiq
explainer_shapiq = TreeExplainer(
model=xgb_clf_model, max_order=1, index="SV", class_label=class_label
)
x_explain_shapiq = copy.deepcopy(background_clf_data[explanation_instance])
sv_shapiq = explainer_shapiq.explain(x=x_explain_shapiq)
sv_shapiq_values = sv_shapiq.get_n_order_values(1)
baseline_shapiq = sv_shapiq.baseline_value

# assert baseline_shap == pytest.approx(baseline_shapiq, rel=1e-4)
assert np.allclose(sv_shap, sv_shapiq_values, rtol=1e-5)

# get prediction of the model (as the log odds)
prediction = xgb_clf_model.predict(x_explain_shapiq.reshape(1, -1), output_margin=True)[0][
class_label
]
assert prediction == pytest.approx(baseline_shapiq + np.sum(sv_shapiq_values), rel=1e-5)


def test_xgboost_shap_error(xgb_clf_model, background_clf_data):
"""Tests for the strange behavior of SHAP's XGBoost implementation.
The test is used to show that the shapiq implementation is correct and the SHAP implementation
is doing something weird. For some instances (e.g. the one used in this test) the SHAP values
are different from the shapiq values. However, when we round the `thresholds` of the xgboost
trees in shapiq, then the computed explanations match. This is a strange behavior as rounding
the thresholds makes the model less true to the original model but only then the explanations
match.
"""

explanation_instance = 0
class_label = 1

# get the shap explanations (the following code is used to get SVs from SHAP)
# import shap
# model_copy = copy.deepcopy(xgb_clf_model)
# explainer_shap = shap.TreeExplainer(model=model_copy)
# baseline_shap = float(explainer_shap.expected_value[class_label])
# x_explain_shap = copy.deepcopy(background_clf_data[explanation_instance].reshape(1, -1))
# sv_shap_all_classes = explainer_shap.shap_values(x_explain_shap)
# sv_shap = sv_shap_all_classes[0][:, class_label]
# print(sv_shap)
# print(baseline_shap)
sv = [-0.00163636, 0.05099502, -0.13182959, -0.44538185, 0.00428653, -0.04872373, -0.01370917]
sv_shap = np.array(sv)

# setup shapiq TreeSHAP
explainer_shapiq = TreeExplainer(
model=xgb_clf_model, max_order=1, index="SV", class_label=class_label
)
x_explain_shapiq = copy.deepcopy(background_clf_data[explanation_instance])
sv_shapiq = explainer_shapiq.explain(x=x_explain_shapiq)
sv_shapiq_values = sv_shapiq.get_n_order_values(1)

# the SHAP sv values should be different from the shapiq values
assert not np.allclose(sv_shap, sv_shapiq_values, rtol=1e-5)

# when we round the model thresholds of the xgb model (thresholds decide weather a feature is
# used or not) -> then suddenly the shap and shapiq values are the same, which points to the
# fact that the shapiq implementation is correct
explainer_shapiq_rounded = TreeExplainer(
model=xgb_clf_model, max_order=1, index="SV", class_label=class_label
)
for tree_explainer in explainer_shapiq_rounded._treeshapiq_explainers:
tree_explainer._tree.thresholds = np.round(tree_explainer._tree.thresholds, 4)
x_explain_shapiq_rounded = copy.deepcopy(background_clf_data[explanation_instance])
sv_shapiq_rounded = explainer_shapiq_rounded.explain(x=x_explain_shapiq_rounded)
sv_shapiq_rounded_values = sv_shapiq_rounded.get_n_order_values(1)

# now the values surprisingly are the same
assert np.allclose(sv_shap, sv_shapiq_rounded_values, rtol=1e-5)

0 comments on commit bdf9ce9

Please sign in to comment.