Skip to content

Commit

Permalink
Merge pull request #14 from JuliaTrustworthyAI/13-plotce-not-working-…
Browse files Browse the repository at this point in the history
…if-num_counterfactuals-1

Fixes bug
  • Loading branch information
pat-alt authored Jan 3, 2024
2 parents 4301cec + f5a11ef commit 155b130
Show file tree
Hide file tree
Showing 5 changed files with 14 additions and 10 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ jobs:
- '1.7'
- '1.8'
- '1.9'
- 'nightly'
- '1.10'
os:
- ubuntu-latest
arch:
Expand Down
6 changes: 4 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "TaijaPlotting"
uuid = "bd7198b4-c7d6-400c-9bab-9a24614b0240"
authors = ["Patrick Altmeyer"]
version = "1.0.5"
version = "1.0.6"

[deps]
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
Expand All @@ -13,6 +13,7 @@ Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
LaplaceRedux = "c52c1a26-f7c5-402b-80be-ba1e638ad478"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
ManifoldLearning = "06eb3307-b2af-5a2a-abea-d33192699d32"
MultivariateStats = "6f286f6a-111f-5878-ab1e-185364afe411"
NaturalSort = "c020b1a1-e9b0-503a-9c33-f039bfc54a85"
Expand All @@ -30,11 +31,12 @@ LaplaceRedux = "0.1"
LinearAlgebra = "1.7, 1.8, 1.9"
MLJBase = "0.21, 0.22, 1"
ManifoldLearning = "0.9"
MLUtils = "0.4"
MultivariateStats = "0.9, 0.10"
NaturalSort = "1"
NearestNeighborModels = "0.2"
Plots = "1"
julia = "1.7, 1.8, 1.9"
julia = "1.7, 1.8, 1.9, 1.10"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Expand Down
2 changes: 1 addition & 1 deletion src/CounterfactualExplations/CounterfactualExplanations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@ using CounterfactualExplanations.Models

include("data.jl")
include("models.jl")
include("countefactuals.jl")
include("counterfactuals.jl")
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
using MLUtils: stack

"""
Plots.plot(
ce::CounterfactualExplanation;
Expand Down Expand Up @@ -103,21 +105,21 @@ end
plot_state(
ce::CounterfactualExplanation,
t::Int,
final_sate::Bool;
final_state::Bool;
kwargs...
)
Helper function that plots a single step of the counterfactual path.
"""
function plot_state(ce::CounterfactualExplanation, t::Int, final_sate::Bool; kwargs...)
function plot_state(ce::CounterfactualExplanation, t::Int, final_state::Bool; kwargs...)
args = PlotIngredients(; kwargs...)
x1 = [args.path_embedded[1,t]]
x2 = [args.path_embedded[2,t]]
x1 = args.path_embedded[1,t,:]
x2 = args.path_embedded[2,t,:]
y = args.path_labels[t]
_c = CategoricalArrays.levelcode.(y)
n_ = ce.num_counterfactuals
label_ = reshape(["C$i" for i in 1:n_], 1, n_)
if !final_sate
if !final_state
scatter!(args.p1, x1, x2; group=y, colour=_c, ms=5, label="")
else
scatter!(args.p1, x1, x2; group=y, colour=_c, ms=10, label="")
Expand Down
2 changes: 1 addition & 1 deletion src/CounterfactualExplations/data.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ function embed(data::CounterfactualData, X::AbstractArray=nothing; dim_red::Symb
end

# Transforming:
X = typeof(X) <: Vector{<:Matrix} ? hcat(X...) : X
X = typeof(X) <: Vector{<:Matrix} ? stack(X, dims=2) : X
if !isnothing(tfn) && !isnothing(X)
X = MultivariateStats.predict(tfn, X)
else
Expand Down

2 comments on commit 155b130

@pat-alt
Copy link
Member Author

@pat-alt pat-alt commented on 155b130 Jan 3, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register

Release notes:

Fixes a bug for plotting counterfactuals explanations with multiple counterfactuals.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/98124

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v1.0.6 -m "<description of version>" 155b1300a86190415c10844ea6916414e095758d
git push origin v1.0.6

Please sign in to comment.