attribution_methods.py
import torch
import captum
import captum.attr


def get_attributor(model, attributor_name, only_positive=False, binarize=False,
                   interpolate=False, interpolate_dims=(224, 224), batch_mode=False):
    """Factory that builds the requested attributor ("BCos", "GradCam", or "IxG") for `model`."""
    attributor_map = {
        "BCos": BCosAttributor,
        "GradCam": GradCamAttributor,
        "IxG": IxGAttributor
    }
    return attributor_map[attributor_name](model, only_positive, binarize,
                                           interpolate, interpolate_dims, batch_mode)
class AttributorBase:
    def __init__(self, model, only_positive=False, binarize=False, interpolate=False,
                 interpolate_dims=(224, 224), batch_mode=False):
        super().__init__()
        self.model = model
        self.only_positive = only_positive
        self.binarize = binarize
        self.interpolate = interpolate
        self.interpolate_dims = interpolate_dims
        self.batch_mode = batch_mode

    def __call__(self, feature, output, class_idx=None, img_idx=None, classes=None):
        # Batch mode attributes every image w.r.t. its class in `classes`;
        # single mode attributes one image (`img_idx`) w.r.t. `class_idx`.
        if self.batch_mode:
            return self._call_batch_mode(feature, output, classes)
        return self._call_single(feature, output, class_idx, img_idx)

    def _call_batch_mode(self, feature, output, classes):
        raise NotImplementedError

    def _call_single(self, feature, output, class_idx, img_idx):
        raise NotImplementedError

    def check_interpolate(self, attributions):
        # Optionally upsample the attribution maps to `interpolate_dims`.
        if self.interpolate:
            return captum.attr.LayerAttribution.interpolate(
                attributions, interpolate_dims=self.interpolate_dims,
                interpolate_mode="bilinear")
        return attributions

    def check_binarize(self, attributions):
        # Despite the name, this normalizes each map by its maximum absolute
        # value (into [-1, 1]); all-zero maps are left unchanged to avoid
        # dividing by zero.
        if self.binarize:
            attr_max = attributions.abs().amax(dim=(1, 2, 3), keepdim=True)
            attributions = torch.where(
                attr_max == 0, attributions, attributions / attr_max)
        return attributions

    def check_only_positive(self, attributions):
        # Optionally discard negative evidence.
        if self.only_positive:
            return attributions.clamp(min=0)
        return attributions

    def apply_post_processing(self, attributions):
        attributions = self.check_only_positive(attributions)
        attributions = self.check_binarize(attributions)
        attributions = self.check_interpolate(attributions)
        return attributions
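

# --- Illustrative sketch (an assumption, not part of the original file): how
# the three post-processing flags compose. `_demo_post_processing` is a
# hypothetical helper; `model=None` is a stand-in, since apply_post_processing
# never touches the model.
def _demo_post_processing():
    base = AttributorBase(model=None, only_positive=True, binarize=True,
                          interpolate=True, interpolate_dims=(8, 8))
    dummy = torch.randn(2, 1, 4, 4)          # (B, 1, H, W) attribution maps
    out = base.apply_post_processing(dummy)  # clamp -> normalize -> upsample
    assert out.shape == (2, 1, 8, 8)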
class BCosAttributor(AttributorBase):
    def __init__(self, model, only_positive=False, binarize=False, interpolate=False,
                 interpolate_dims=(224, 224), batch_mode=False):
        super().__init__(model=model, only_positive=only_positive, binarize=binarize,
                         interpolate=interpolate, interpolate_dims=interpolate_dims,
                         batch_mode=batch_mode)

    def _call_batch_mode(self, feature, output, classes):
        # Gradient of each image's target-class logit w.r.t. the input, taken
        # inside the B-cos model's explanation mode, then contracted with the
        # input and summed over channels.
        target_outputs = torch.gather(output, 1, classes.unsqueeze(-1))
        with self.model.explanation_mode():
            grads = torch.autograd.grad(torch.unbind(target_outputs), feature,
                                        create_graph=True, retain_graph=True)[0]
        attributions = (grads * feature).sum(dim=1, keepdim=True)
        return self.apply_post_processing(attributions)

    def _call_single(self, feature, output, class_idx, img_idx):
        with self.model.explanation_mode():
            grads = torch.autograd.grad(output[img_idx, class_idx], feature,
                                        create_graph=True, retain_graph=True)[0]
        attributions = (grads[img_idx] * feature[img_idx]
                        ).sum(dim=0, keepdim=True).unsqueeze(0)
        return self.apply_post_processing(attributions)
class GradCamAttributor(AttributorBase):
    def __init__(self, model, only_positive=False, binarize=False, interpolate=False,
                 interpolate_dims=(224, 224), batch_mode=False):
        super().__init__(model=model, only_positive=only_positive, binarize=binarize,
                         interpolate=interpolate, interpolate_dims=interpolate_dims,
                         batch_mode=batch_mode)

    def _call_batch_mode(self, feature, output, classes):
        # Grad-CAM: channel weights are the spatially averaged gradients of the
        # target logits w.r.t. the feature maps; ReLU keeps positive evidence.
        target_outputs = torch.gather(output, 1, classes.unsqueeze(-1))
        grads = torch.autograd.grad(torch.unbind(target_outputs), feature,
                                    create_graph=True, retain_graph=True)[0]
        grads = grads.mean(dim=(2, 3), keepdim=True)
        prods = grads * feature
        attributions = torch.nn.functional.relu(
            torch.sum(prods, dim=1, keepdim=True))
        return self.apply_post_processing(attributions)

    def _call_single(self, feature, output, class_idx, img_idx):
        grads = torch.autograd.grad(output[img_idx, class_idx], feature,
                                    create_graph=True, retain_graph=True)[0]
        grads = grads.mean(dim=(2, 3), keepdim=True)
        prods = grads[img_idx] * feature[img_idx]
        attributions = torch.nn.functional.relu(
            torch.sum(prods, dim=0, keepdim=True)).unsqueeze(0)
        return self.apply_post_processing(attributions)
class IxGAttributor(AttributorBase):
    def __init__(self, model, only_positive=False, binarize=False, interpolate=False,
                 interpolate_dims=(224, 224), batch_mode=False):
        super().__init__(model=model, only_positive=only_positive, binarize=binarize,
                         interpolate=interpolate, interpolate_dims=interpolate_dims,
                         batch_mode=batch_mode)

    def _call_batch_mode(self, feature, output, classes):
        # Input x Gradient: elementwise product of the input with the gradient
        # of the target logits, summed over channels.
        target_outputs = torch.gather(output, 1, classes.unsqueeze(-1))
        grads = torch.autograd.grad(torch.unbind(target_outputs), feature,
                                    create_graph=True, retain_graph=True)[0]
        attributions = (grads * feature).sum(dim=1, keepdim=True)
        return self.apply_post_processing(attributions)

    def _call_single(self, feature, output, class_idx, img_idx):
        grads = torch.autograd.grad(output[img_idx, class_idx], feature,
                                    create_graph=True, retain_graph=True)[0]
        attributions = (grads[img_idx] * feature[img_idx]
                        ).sum(dim=0, keepdim=True).unsqueeze(0)
        return self.apply_post_processing(attributions)
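

# --- Illustrative usage sketch (an assumption, not part of the original file).
# End-to-end run with the IxG attributor in batch mode; the toy ConvNet below
# is a hypothetical stand-in for a real classifier. BCos would additionally
# require a model exposing an `explanation_mode()` context manager, and
# GradCam expects `feature` to be an intermediate feature map rather than the
# raw input.
if __name__ == "__main__":
    import torch.nn as nn

    toy_model = nn.Sequential(
        nn.Conv2d(3, 8, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.AdaptiveAvgPool2d(1),
        nn.Flatten(),
        nn.Linear(8, 10),
    )
    attributor = get_attributor(
        toy_model, "IxG", only_positive=True, binarize=True,
        interpolate=True, interpolate_dims=(224, 224), batch_mode=True)

    feature = torch.randn(2, 3, 64, 64, requires_grad=True)  # input batch
    output = toy_model(feature)                              # (2, 10) logits
    classes = torch.tensor([1, 7])                           # target classes
    attrs = attributor(feature, output, classes=classes)
    print(attrs.shape)  # torch.Size([2, 1, 224, 224])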