Skip to content

Commit

Permalink
Return closest not predicted class for trust scores (#67)
Browse files Browse the repository at this point in the history
* add closest class to trust scores

* add not pred class trustscores

* update trust score test
  • Loading branch information
arnaudvl authored and jklaise committed May 7, 2019
1 parent 8423049 commit 18c83ab
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 37 deletions.
4 changes: 2 additions & 2 deletions alibi/confidence/tests/test_trustscore.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ def test_trustscore(filter_type):
# test one-hot encoding of Y vs. class labels
ts = TrustScore()
ts.fit(X_train, Y_train, classes=3)
score_class = ts.score(X_test, Y_pred)
score_class, _ = ts.score(X_test, Y_pred)
ts = TrustScore()
ts.fit(X_train, to_categorical(Y_train), classes=3)
score_ohe = ts.score(X_test, Y_pred_proba)
score_ohe, _ = ts.score(X_test, Y_pred_proba)
assert (score_class != score_ohe).astype(int).sum() == 0
9 changes: 6 additions & 3 deletions alibi/confidence/trustscore.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,8 @@ def fit(self, X: np.ndarray, Y: np.ndarray, classes: int = None) -> None:

self.kdtrees[c] = KDTree(X_fit, leaf_size=self.leaf_size, metric=self.metric) # build KDTree for class c

def score(self, X: np.ndarray, Y: np.ndarray, k: int = 2, dist_type: str = 'point') -> np.ndarray:
def score(self, X: np.ndarray, Y: np.ndarray, k: int = 2, dist_type: str = 'point') \
-> Tuple[np.ndarray, np.ndarray]:
"""
Calculate trust scores = ratio of distance to closest class other than the
predicted class to distance to predicted class.
Expand All @@ -158,7 +159,7 @@ def score(self, X: np.ndarray, Y: np.ndarray, k: int = 2, dist_type: str = 'poin
Returns
-------
Batch with trust scores.
Batch with trust scores and the closest not predicted class.
"""
# make sure Y represents predicted classes, not probabilities
if len(Y.shape) > 1:
Expand All @@ -184,4 +185,6 @@ def score(self, X: np.ndarray, Y: np.ndarray, k: int = 2, dist_type: str = 'poin
d_to_pred = d[range(d.shape[0]), Y]
d_to_closest_not_pred = np.where(sorted_d[:, 0] != d_to_pred, sorted_d[:, 0], sorted_d[:, 1])
trust_score = d_to_closest_not_pred / (d_to_pred + self.eps)
return trust_score
# closest not predicted class
class_closest_not_pred = np.where(d == d_to_closest_not_pred.reshape(-1, 1))[1]
return trust_score, class_closest_not_pred
14 changes: 6 additions & 8 deletions doc/source/methods/TrustScores.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -98,21 +98,19 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"The trust scores are simply calculated through the `score` method:\n",
"The trust scores are simply calculated through the `score` method. `score` also returns the class labels of the closest not predicted class as a numpy array:\n",
"\n",
"```python\n",
"score = ts.score(X_test, \n",
" y_pred, \n",
" k=2,\n",
" dist_type='point')\n",
"score, closest_class = ts.score(X_test, \n",
" y_pred, \n",
" k=2,\n",
" dist_type='point')\n",
"```\n",
"\n",
"*y_pred* can again be represented using both OHE or via class labels.\n",
"\n",
"* `k`: $k$th nearest neighbor used to compute distance to for each class.\n",
"* `dist_type`: similar to the filtering step, we can compute the distance to each class either to the $k$-th nearest point (*point*) or by using the average distance from the 1st to the $k$th nearest point (*mean*).\n",
"\n",
"The trust scores for each instance in the test set are returned as a numpy array."
"* `dist_type`: similar to the filtering step, we can compute the distance to each class either to the $k$-th nearest point (*point*) or by using the average distance from the 1st to the $k$th nearest point (*mean*)."
]
},
{
Expand Down
34 changes: 23 additions & 11 deletions examples/trustscore_iris.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -95,12 +95,21 @@
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Predicted class: [2 2 2 2 2 2 2 2 2]\n"
]
}
],
"source": [
"np.random.seed(0)\n",
"clf = LogisticRegression(solver='liblinear', multi_class='auto')\n",
"clf.fit(X_train, y_train)\n",
"y_pred = clf.predict(X_test)"
"y_pred = clf.predict(X_test)\n",
"print('Predicted class: {}'.format(y_pred))"
]
},
{
Expand Down Expand Up @@ -159,7 +168,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Since the trust score is the ratio between the distance of the test instance to the nearest class different from the predicted class and the distance to the predicted class, higher scores correspond to more trustworthy predictions. A score of 1 would mean that the distance to the predicted class is the same as to another class."
"Since the trust score is the ratio between the distance of the test instance to the nearest class different from the predicted class and the distance to the predicted class, higher scores correspond to more trustworthy predictions. A score of 1 would mean that the distance to the predicted class is the same as to another class. The `score` method returns arrays with both the trust scores and the class labels of the closest not predicted class."
]
},
{
Expand All @@ -171,18 +180,21 @@
"name": "stdout",
"output_type": "stream",
"text": [
"[2.574271277538439 2.1630334957870114 3.1629405367742223\n",
"Trust scores: [2.574271277538439 2.1630334957870114 3.1629405367742223\n",
" 2.7258494544157927 2.541748027539072 1.402878283257114 1.941073062524019\n",
" 2.0601725424359296 2.1781121494573514]\n"
" 2.0601725424359296 2.1781121494573514]\n",
"\n",
"Closest not predicted class: [1 1 1 1 1 1 1 1 1]\n"
]
}
],
"source": [
"score = ts.score(X_test, \n",
" y_pred, \n",
" k=2, # kth nearest neighbor used to compute distances for each class\n",
" dist_type='point') # 'point' or 'mean' distance option\n",
"print(score)"
"score, closest_class = ts.score(X_test, \n",
" y_pred, k=2, # kth nearest neighbor used \n",
" # to compute distances for each class\n",
" dist_type='point') # 'point' or 'mean' distance option\n",
"print('Trust scores: {}'.format(score))\n",
"print('\\nClosest not predicted class: {}'.format(closest_class))"
]
},
{
Expand Down Expand Up @@ -302,7 +314,7 @@
" # calculate trust scores\n",
" ts = TrustScore()\n",
" ts.fit(X_train, y_train, classes=classes)\n",
" scores = ts.score(X_test, y_pred)\n",
" scores, _ = ts.score(X_test, y_pred)\n",
" final_curves.append(scores) # contains prediction probabilities and trust scores\n",
" # check where prediction probabilities and trust scores are above a certain percentage level\n",
" for p, perc in enumerate(percentiles):\n",
Expand Down
Loading

0 comments on commit 18c83ab

Please sign in to comment.