From e5b5d8fdaa20802c99831c7ff7aef360d33360d8 Mon Sep 17 00:00:00 2001 From: Adrian Hill Date: Fri, 17 Nov 2023 18:20:48 +0100 Subject: [PATCH] Fix keyword argument `add_batch_dim` --- src/XAIBase.jl | 3 +++ src/analyze.jl | 2 +- src/compat.jl | 18 ++++++++++++++++++ src/utils.jl | 21 +++++++++++++++++++++ test/test_api.jl | 9 +++++++++ 5 files changed, 52 insertions(+), 1 deletion(-) create mode 100644 src/compat.jl create mode 100644 src/utils.jl diff --git a/src/XAIBase.jl b/src/XAIBase.jl index 6a810c8..e41078f 100644 --- a/src/XAIBase.jl +++ b/src/XAIBase.jl @@ -3,6 +3,9 @@ module XAIBase using TextHeatmaps using VisionHeatmaps +include("compat.jl") +include("utils.jl") + # Abstract super type of all XAI methods. # Is expected that all methods are callable types that return an `Explanation`: # diff --git a/src/analyze.jl b/src/analyze.jl index c456284..46c40fd 100644 --- a/src/analyze.jl +++ b/src/analyze.jl @@ -1,7 +1,7 @@ const BATCHDIM_MISSING = ArgumentError( """The input is a 1D vector and therefore missing the required batch dimension. - Call `analyze` with the keyword argument `add_batch_dim=false`.""" + Call `analyze` with the keyword argument `add_batch_dim=true`.""" ) """ diff --git a/src/compat.jl b/src/compat.jl new file mode 100644 index 0000000..b6f3da8 --- /dev/null +++ b/src/compat.jl @@ -0,0 +1,18 @@ +# https://github.com/JuliaLang/julia/pull/39794 +if VERSION < v"1.7.0-DEV.793" + export Returns + + struct Returns{V} <: Function + value::V + Returns{V}(value) where {V} = new{V}(value) + Returns(value) = new{Core.Typeof(value)}(value) + end + + (obj::Returns)(args...; kw...) = obj.value + function Base.show(io::IO, obj::Returns) + show(io, typeof(obj)) + print(io, "(") + show(io, obj.value) + return print(io, ")") + end +end diff --git a/src/utils.jl b/src/utils.jl new file mode 100644 index 0000000..42ec023 --- /dev/null +++ b/src/utils.jl @@ -0,0 +1,21 @@ +""" + batch_dim_view(A) + +Return a view onto the array `A` that contains an extra singleton batch dimension at the end. +This avoids allocating a new array. + +## Example +```juliarepl +julia> A = [1 2; 3 4] +2×2 Matrix{Int64}: + 1 2 + 3 4 + +julia> batch_dim_view(A) +2×2×1 view(::Array{Int64, 3}, 1:2, 1:2, :) with eltype Int64: +[:, :, 1] = + 1 2 + 3 4 +``` +""" +batch_dim_view(A::AbstractArray{T,N}) where {T,N} = view(A, ntuple(Returns(:), N + 1)...) diff --git a/test/test_api.jl b/test/test_api.jl index ffbc62d..761013c 100644 --- a/test/test_api.jl +++ b/test/test_api.jl @@ -20,6 +20,11 @@ expl = analyze(input, analyzer) expl = analyzer(input) @test expl.val == val +# Max activation + add_batch_dim +input_vec = [1, 2, 3] +expl = analyzer(input_vec; add_batch_dim=true) +@test expl.val == val[:, 1:1] + # Ouput selection output_neuron = 2 val = [2 30; 4 25; 6 20] @@ -29,3 +34,7 @@ expl = analyze(input, analyzer, output_neuron) @test isnothing(expl.extras) expl = analyzer(input, output_neuron) @test expl.val == val + +# Ouput selection + add_batch_dim +expl = analyzer(input_vec, output_neuron; add_batch_dim=true) +@test expl.val == val[:, 1:1]