-
Notifications
You must be signed in to change notification settings - Fork 2
/
generate_assets.jl
41 lines (35 loc) · 1.47 KB
/
generate_assets.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
using ExplainableAI
using Metalhead # pre-trained vision models
using HTTP, FileIO, ImageMagick # load image from URL
# Load the pre-trained VGG-16 network (its `layers` chain), strip the trailing
# softmax from the output, then canonize the result before building analyzers.
model = canonize(strip_softmax(VGG(16; pretrain=true).layers))
# Load input: fetch the example castle photo and turn it into a model-ready batch.
url = HTTP.URI(
    "https://raw.githubusercontent.com/adrhill/ExplainableAI.jl/gh-pages/assets/heatmaps/castle.jpg",
)
img = load(url)
# Preprocess, then reshape to WHCN (width, height, channels, batch) layout;
# the trailing `:` lets Julia infer the batch dimension.
input = reshape(preprocess_imagenet(img), 224, 224, 3, :)
# Run XAI methods.
# Each entry pairs an asset name with a constructor `model -> analyzer`.
# NOTE: deliberately NOT called `methods` — that would shadow `Base.methods`, and
# assigning it fails if `Base.methods` was already used in the session (e.g. when
# this script is `include`d from a REPL).
# A vector of pairs (rather than a `Dict`) keeps the generation order deterministic.
analyzers = [
    "InputTimesGradient" => InputTimesGradient,
    "Gradient" => Gradient,
    "SmoothGrad" => SmoothGrad,
    "IntegratedGradients" => IntegratedGradients,
    "LRP" => LRP,
    "LRPEpsilonGammaBox" => m -> LRP(m, EpsilonGammaBox(-3.0f0, 3.0f0)),
    "LRPEpsilonPlus" => m -> LRP(m, EpsilonPlus()),
    "LRPEpsilonAlpha2Beta1" => m -> LRP(m, EpsilonAlpha2Beta1()),
    "LRPEpsilonPlusFlat" => m -> LRP(m, EpsilonPlusFlat()),
    "LRPEpsilonAlpha2Beta1Flat" => m -> LRP(m, EpsilonAlpha2Beta1Flat()),
]

# For every analyzer, write two heatmaps to the current working directory:
# `castle_<name>.png` and `streetsign_<name>.png`.
for (name, make_analyzer) in analyzers
    @info "Generating $name assets..."
    analyzer = make_analyzer(model)
    # Max activated neuron corresponds to "castle"
    save("castle_$name.png", heatmap(input, analyzer))
    # Output neuron 920 corresponds to "street sign"
    save("streetsign_$name.png", heatmap(input, analyzer, 920))
end