Skip to content

Commit

Permalink
Avoid breaking changes (#6)
Browse files Browse the repository at this point in the history
* Avoid new `Explanation` field from PR #5 being breaking

* Undo separation between `heatmap` and `textheatmap` from PR #4
  • Loading branch information
adrhill authored Nov 17, 2023
1 parent 5658de9 commit 49243e3
Show file tree
Hide file tree
Showing 8 changed files with 37 additions and 13 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "XAIBase"
uuid = "9b48221d-a747-4c1b-9860-46a1d8ba24a7"
authors = ["Adrian Hill <[email protected]>"]
version = "2.0.0-DEV"
version = "1.1.0-DEV"

[deps]
TextHeatmaps = "2dd6718a-6083-4824-b9f7-90e4a57f72d2"
Expand Down
5 changes: 4 additions & 1 deletion src/XAIBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,12 @@ include("analyze.jl")
# Heatmapping for vision and NLP tasks.
include("heatmaps.jl")

# To be removed in next breaking release:
include("deprecated.jl")

export AbstractXAIMethod
export AbstractNeuronSelector
export Explanation
export analyze
export heatmap, textheatmap
export heatmap
end #module
10 changes: 10 additions & 0 deletions src/deprecated.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
function Explanation(val, output, output_selection, analyzer::Symbol)
@warn "Creating an Explanation without a heatmap style is being deprecated. Defaulting to `heatmap=:attribution`."
return Explanation(val, output, output_selection, analyzer, :attribution, nothing)
end
function Explanation(
val, output, output_selection, analyzer::Symbol, extras::Union{Nothing,NamedTuple}
)
@warn "Creating an Explanation without a heatmap style is being deprecated. Defaulting to `heatmap=:attribution`."
return Explanation(val, output, output_selection, analyzer, :attribution, extras)
end
2 changes: 1 addition & 1 deletion src/explanation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,6 @@ struct Explanation{V,O,S,E<:Union{Nothing,NamedTuple}}
heatmap::Symbol
extras::E
end
function Explanation(val, output, output_selection, analyzer, heatmap)
function Explanation(val, output, output_selection, analyzer::Symbol, heatmap::Symbol)
return Explanation(val, output, output_selection, analyzer, heatmap, nothing)
end
12 changes: 6 additions & 6 deletions src/heatmaps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ Visualize `Explanation` from XAIBase as a vision heatmap.
Assumes WHCN convention (width, height, channels, batchsize) for `explanation.val`.
## Keyword arguments
- `colorscheme::Union{ColorScheme,Symbol}`: Color scheme from ColorSchemes.jl.
Defaults to `seismic`.
- `colorscheme::Union{ColorScheme,Symbol}`: color scheme from ColorSchemes.jl.
Defaults to `:$DEFAULT_COLORSCHEME`.
- `reduce::Symbol`: Selects how color channels are reduced to a single number to apply a color scheme.
The following methods can be selected, which are then applied over the color channels
for each "pixel" in the array:
Expand Down Expand Up @@ -80,7 +80,7 @@ end
#===============#

"""
textheatmap(explanation, text)
heatmap(explanation, text)
Visualize `Explanation` from XAIBase as text heatmap.
Text should be a vector containing vectors of strings, one for each input in the batched explanation.
Expand All @@ -92,7 +92,7 @@ Text should be a vector containing vectors of strings, one for each input in the
before the color scheme is applied. Can be either `:extrema` or `:centered`.
Defaults to `:$DEFAULT_RANGESCALE` for use with the default color scheme `:$DEFAULT_COLORSCHEME`.
"""
function textheatmap(
function heatmap(
expl::Explanation, texts::AbstractVector{<:AbstractVector{<:AbstractString}}; kwargs...
)
ndims(expl.val) != 2 && throw(
Expand All @@ -113,6 +113,6 @@ function textheatmap(
]
end

function textheatmap(expl::Explanation, text::AbstractVector{<:AbstractString}; kwargs...)
return textheatmap(expl, [text]; kwargs...)
function heatmap(expl::Explanation, text::AbstractVector{<:AbstractString}; kwargs...)
return heatmap(expl, [text]; kwargs...)
end
5 changes: 5 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,9 @@ using Aqua
@info "Testing text heatmaps..."
include("test_textheatmap.jl")
end
@testset "Deprecations" begin
# To be removed in next breaking release
@info "Testing deprecations..."
include("test_deprecated.jl")
end
end
6 changes: 6 additions & 0 deletions test/test_deprecated.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Test deprecation warnings
shape = (2, 2, 3, 1)
val = output = reshape(collect(Float32, 1:prod(shape)), shape)
neuron_selection = [CartesianIndex(1, 2)] # irrelevant
@test_logs (:warn,) expl = Explanation(val, output, [neuron_selection], :LRP)
@test_logs (:warn,) expl = Explanation(val, output, [neuron_selection], :LRP, nothing)
8 changes: 4 additions & 4 deletions test/test_textheatmap.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,21 @@ val = output = [1 6; 2 5; 3 4]
text = [["Test", "Text", "Heatmap"], ["another", "dummy", "input"]]
neuron_selection = [CartesianIndex(1, 2), CartesianIndex(3, 4)] # irrelevant
expl = Explanation(val, output, neuron_selection, :Gradient, :sensitivity)
h = textheatmap(expl, text)
h = heatmap(expl, text)
@test_reference "references/Gradient1.txt" repr("text/plain", h[1])
@test_reference "references/Gradient2.txt" repr("text/plain", h[2])

expl = Explanation(
val[:, 1:1], output[:, 1:1], neuron_selection[1], :Gradient, :sensitivity
)
h = textheatmap(expl, text[1])
h = heatmap(expl, text[1])
@test_reference "references/Gradient1.txt" repr("text/plain", only(h))

expl = Explanation(val, output, neuron_selection, :LRP, :attribution)
h = textheatmap(expl, text)
h = heatmap(expl, text)
@test_reference "references/LRP1.txt" repr("text/plain", h[1])
@test_reference "references/LRP2.txt" repr("text/plain", h[2])

h = textheatmap(expl, text; rangescale=:extrema)
h = heatmap(expl, text; rangescale=:extrema)
@test_reference "references/LRP1_extrema.txt" repr("text/plain", h[1])
@test_reference "references/LRP2_extrema.txt" repr("text/plain", h[2])

0 comments on commit 49243e3

Please sign in to comment.