Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make TextHeatmaps and VisionHeatmaps strong dependencies #4

Merged
merged 3 commits into from
Nov 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 1 addition & 9 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,20 +1,12 @@
name = "XAIBase"
uuid = "9b48221d-a747-4c1b-9860-46a1d8ba24a7"
authors = ["Adrian Hill <[email protected]>"]
version = "1.0.0"
version = "2.0.0-DEV"

[deps]
TextHeatmaps = "2dd6718a-6083-4824-b9f7-90e4a57f72d2"
VisionHeatmaps = "27106da1-f8bc-4ca8-8c66-9b8289f1e035"

[weakdeps]
TextHeatmaps = "2dd6718a-6083-4824-b9f7-90e4a57f72d2"
VisionHeatmaps = "27106da1-f8bc-4ca8-8c66-9b8289f1e035"

[extensions]
XAITextHeatmapsExt = "TextHeatmaps"
XAIVisionHeatmapsExt = "VisionHeatmaps"

[compat]
TextHeatmaps = "1.1"
VisionHeatmaps = "1.1"
Expand Down
43 changes: 0 additions & 43 deletions ext/XAITextHeatmapsExt.jl

This file was deleted.

26 changes: 0 additions & 26 deletions ext/XAIVisionHeatmapsExt.jl

This file was deleted.

19 changes: 6 additions & 13 deletions src/XAIBase.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
module XAIBase

using TextHeatmaps
using VisionHeatmaps

# Abstract super type of all XAI methods.
# Is expected that all methods are callable types that return an `Explanation`:
#
Expand All @@ -20,22 +23,12 @@ include("explanation.jl")
# which in turn calls `(method)(input, output_selector)`.
include("analyze.jl")

# Heatmapping presets used in package extensions.
# To keep dependencies minimal, heatmapping functionality is loaded via package extensions
# on VisionHeatmaps.jl and TextHeatmaps.jl in the `/ext` folder.
# Shared heatmapping default settings are defined in this file.
include("heatmap_presets.jl")
# Heatmapping for vision and NLP tasks.
include("heatmaps.jl")

export AbstractXAIMethod
export AbstractNeuronSelector
export Explanation
export analyze

# Package extension backwards compatibility with Julia 1.6.
# For Julia 1.6, VisionHeatmaps.jl and TextHeatmaps.jl are treated as normal dependencies and always loaded.
# https://pkgdocs.julialang.org/v1/creating-packages/#Transition-from-normal-dependency-to-extension
if !isdefined(Base, :get_extension)
include("../ext/XAITextHeatmapsExt.jl")
include("../ext/XAIVisionHeatmapsExt.jl")
end
export heatmap, textheatmap
end #module
12 changes: 0 additions & 12 deletions src/heatmap_presets.jl

This file was deleted.

119 changes: 119 additions & 0 deletions src/heatmaps.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
struct HeatmapConfig
colorscheme::Symbol
reduce::Symbol
rangescale::Symbol
end

const DEFAULT_COLORSCHEME = :seismic
const DEFAULT_REDUCE = :sum
const DEFAULT_RANGESCALE = :centered
const DEFAULT_HEATMAP_PRESET = HeatmapConfig(
DEFAULT_COLORSCHEME, DEFAULT_REDUCE, DEFAULT_RANGESCALE
)

const HEATMAP_PRESETS = Dict{Symbol,HeatmapConfig}(
:LRP => HeatmapConfig(:seismic, :sum, :centered), # attribution
:CRP => HeatmapConfig(:seismic, :sum, :centered), # attribution
:InputTimesGradient => HeatmapConfig(:seismic, :sum, :centered), # attribution
:Gradient => HeatmapConfig(:grays, :norm, :extrema), # sensitivity
)

# Select HeatmapConfig preset based on analyzer
function get_heatmapping_config(analyzer::Symbol)
return get(HEATMAP_PRESETS, analyzer, DEFAULT_HEATMAP_PRESET)
end

# Override HeatmapConfig preset with keyword arguments
function get_heatmapping_config(expl::Explanation; kwargs...)
c = get_heatmapping_config(expl.analyzer)

colorscheme = get(kwargs, :colorscheme, c.colorscheme)
rangescale = get(kwargs, :rangescale, c.rangescale)
reduce = get(kwargs, :reduce, c.reduce)
return HeatmapConfig(colorscheme, reduce, rangescale)
end

#=================#
# Vision Heatmaps #
#=================#

"""
heatmap(explanation)

Visualize `Explanation` from XAIBase as a vision heatmap.
Assumes WHCN convention (width, height, channels, batchsize) for `explanation.val`.

## Keyword arguments
- `colorscheme::Union{ColorScheme,Symbol}`: Color scheme from ColorSchemes.jl.
Defaults to `seismic`.
- `reduce::Symbol`: Selects how color channels are reduced to a single number to apply a color scheme.
The following methods can be selected, which are then applied over the color channels
for each "pixel" in the array:
- `:sum`: sum up color channels
- `:norm`: compute 2-norm over the color channels
- `:maxabs`: compute `maximum(abs, x)` over the color channels
Defaults to `:$DEFAULT_REDUCE`.
- `rangescale::Symbol`: Selects how the color channel reduced heatmap is normalized
before the color scheme is applied. Can be either `:extrema` or `:centered`.
Defaults to `:$DEFAULT_RANGESCALE`.
- `process_batch::Bool`: When heatmapping a batch, setting `process_batch=true`
will apply the `rangescale` normalization to the entire batch
instead of computing it individually for each sample in the batch.
Defaults to `false`.
- `permute::Bool`: Whether to flip W&H input channels. Default is `true`.
- `unpack_singleton::Bool`: If false, `heatmap` will always return a vector of images.
When heatmapping a batch with a single sample, setting `unpack_singleton=true`
will unpack the singleton vector and directly return the image. Defaults to `true`.
"""
function heatmap(expl::Explanation; kwargs...)
c = get_heatmapping_config(expl; kwargs...)
return VisionHeatmaps.heatmap(
expl.val;
colorscheme=c.colorscheme,
reduce=c.reduce,
rangescale=c.rangescale,
kwargs...,
)
end

#===============#
# Text Heatmaps #
#===============#

"""
textheatmap(explanation, text)

Visualize `Explanation` from XAIBase as text heatmap.
Text should be a vector containing vectors of strings, one for each input in the batched explanation.

## Keyword arguments
- `colorscheme::Union{ColorScheme,Symbol}`: color scheme from ColorSchemes.jl.
Defaults to `:$DEFAULT_COLORSCHEME`.
- `rangescale::Symbol`: selects how the color channel reduced heatmap is normalized
before the color scheme is applied. Can be either `:extrema` or `:centered`.
Defaults to `:$DEFAULT_RANGESCALE` for use with the default color scheme `:$DEFAULT_COLORSCHEME`.
"""
function textheatmap(
expl::Explanation, texts::AbstractVector{<:AbstractVector{<:AbstractString}}; kwargs...
)
ndims(expl.val) != 2 && throw(
ArgumentError(
"To heatmap text, `explanation.val` must be 2D array of shape `(input_length, batchsize)`. Got array of shape $(size(x)) instead.",
),
)
batchsize = size(expl.val, 2)
textsize = length(texts)
batchsize != textsize && throw(
ArgumentError("Batchsize $batchsize doesn't match number of texts $textsize.")
)

c = get_heatmapping_config(expl; kwargs...)
return [
TextHeatmaps.heatmap(v, t; colorscheme=c.colorscheme, rangescale=c.rangescale) for
(v, t) in zip(eachcol(expl.val), texts)
]
end

function textheatmap(expl::Explanation, text::AbstractVector{<:AbstractString}; kwargs...)
return textheatmap(expl, [text]; kwargs...)
end
1 change: 0 additions & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,3 @@ VisionHeatmaps = "27106da1-f8bc-4ca8-8c66-9b8289f1e035"

[compat]
Aqua = "0.8"
TextHeatmaps = "1.0.0"
12 changes: 6 additions & 6 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@ using Aqua
@info "Testing neuron selection..."
include("test_neuron_selection.jl")
end
@testset "VisionHeatmaps" begin
@info "Testing VisionHeatmaps extension..."
include("test_heatmaps_vision.jl")
@testset "Vision heatmaps" begin
@info "Testing vision heatmaps..."
include("test_heatmap.jl")
end
@testset "TextHeatmaps" begin
@info "Testing TextHeatmaps extension..."
include("test_heatmaps_text.jl")
@testset "Text heatmaps" begin
@info "Testing text heatmaps..."
include("test_textheatmap.jl")
end
end
10 changes: 4 additions & 6 deletions test/test_heatmaps_vision.jl → test/test_heatmap.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
using VisionHeatmaps

# NOTE: Heatmapping assumes Flux's WHCN convention (width, height, color channels, batch size).
# Single input
shape = (2, 2, 3, 1)
Expand All @@ -17,16 +15,16 @@ reducers = [:sum, :maxabs, :norm]
rangescales = [:extrema, :centered]
for r in reducers
for n in rangescales
local h = VisionHeatmaps.heatmap(expl; reduce=r, rangescale=n)
local h = heatmap(expl; reduce=r, rangescale=n)
@test_reference "references/heatmaps/reduce_$(r)_rangescale_$(n).txt" h

local h = VisionHeatmaps.heatmap(expl_batch; reduce=r, rangescale=n)
local h = heatmap(expl_batch; reduce=r, rangescale=n)
@test_reference "references/heatmaps/reduce_$(r)_rangescale_$(n).txt" h[1]
@test_reference "references/heatmaps/reduce_$(r)_rangescale_$(n)2.txt" h[2]
end
end

h1 = VisionHeatmaps.heatmap(expl_batch)
h2 = VisionHeatmaps.heatmap(expl_batch; process_batch=true)
h1 = heatmap(expl_batch)
h2 = heatmap(expl_batch; process_batch=true)
@test_reference "references/heatmaps/process_batch_false.txt" h1
@test_reference "references/heatmaps/process_batch_true.txt" h2
12 changes: 5 additions & 7 deletions test/test_heatmaps_text.jl → test/test_textheatmap.jl
Original file line number Diff line number Diff line change
@@ -1,22 +1,20 @@
using TextHeatmaps

val = output = [1 6; 2 5; 3 4]
text = [["Test", "Text", "Heatmap"], ["another", "dummy", "input"]]
neuron_selection = [CartesianIndex(1, 2), CartesianIndex(3, 4)] # irrelevant
expl = Explanation(val, output, neuron_selection, :Gradient)
h = TextHeatmaps.heatmap(expl, text)
h = textheatmap(expl, text)
@test_reference "references/Gradient1.txt" repr("text/plain", h[1])
@test_reference "references/Gradient2.txt" repr("text/plain", h[2])

expl = Explanation(val[:, 1:1], output[:, 1:1], neuron_selection[1], :Gradient)
h = TextHeatmaps.heatmap(expl, text[1])
@test_reference "references/Gradient1.txt" repr("text/plain", h)
h = textheatmap(expl, text[1])
@test_reference "references/Gradient1.txt" repr("text/plain", only(h))

expl = Explanation(val, output, neuron_selection, :LRP)
h = TextHeatmaps.heatmap(expl, text)
h = textheatmap(expl, text)
@test_reference "references/LRP1.txt" repr("text/plain", h[1])
@test_reference "references/LRP2.txt" repr("text/plain", h[2])

h = TextHeatmaps.heatmap(expl, text; rangescale=:extrema)
h = textheatmap(expl, text; rangescale=:extrema)
@test_reference "references/LRP1_extrema.txt" repr("text/plain", h[1])
@test_reference "references/LRP2_extrema.txt" repr("text/plain", h[2])
Loading