From 80f4d4ac3ef61f84bebeb5c9bbf8da3d760e4cab Mon Sep 17 00:00:00 2001 From: pat-alt Date: Wed, 3 Jan 2024 14:27:20 +0100 Subject: [PATCH 1/2] that should do it --- .github/workflows/CI.yml | 2 +- Project.toml | 4 ++-- src/CounterfactualExplations/countefactuals.jl | 10 +++++----- src/CounterfactualExplations/data.jl | 2 +- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 6301c7e..32dad55 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -21,7 +21,7 @@ jobs: - '1.7' - '1.8' - '1.9' - - 'nightly' + - '1.10' os: - ubuntu-latest arch: diff --git a/Project.toml b/Project.toml index c368817..60bdf30 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "TaijaPlotting" uuid = "bd7198b4-c7d6-400c-9bab-9a24614b0240" authors = ["Patrick Altmeyer"] -version = "1.0.5" +version = "1.0.6" [deps] CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597" @@ -34,7 +34,7 @@ MultivariateStats = "0.9, 0.10" NaturalSort = "1" NearestNeighborModels = "0.2" Plots = "1" -julia = "1.7, 1.8, 1.9" +julia = "1.7, 1.8, 1.9, 1.10" [extras] Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/src/CounterfactualExplations/countefactuals.jl b/src/CounterfactualExplations/countefactuals.jl index 91de5f4..a87d682 100644 --- a/src/CounterfactualExplations/countefactuals.jl +++ b/src/CounterfactualExplations/countefactuals.jl @@ -103,21 +103,21 @@ end plot_state( ce::CounterfactualExplanation, t::Int, - final_sate::Bool; + final_state::Bool; kwargs... ) Helper function that plots a single step of the counterfactual path. """ -function plot_state(ce::CounterfactualExplanation, t::Int, final_sate::Bool; kwargs...) +function plot_state(ce::CounterfactualExplanation, t::Int, final_state::Bool; kwargs...) args = PlotIngredients(; kwargs...) - x1 = [args.path_embedded[1,t]] - x2 = [args.path_embedded[2,t]] + x1 = args.path_embedded[1,t,:] + x2 = args.path_embedded[2,t,:] y = args.path_labels[t] _c = CategoricalArrays.levelcode.(y) n_ = ce.num_counterfactuals label_ = reshape(["C$i" for i in 1:n_], 1, n_) - if !final_sate + if !final_state scatter!(args.p1, x1, x2; group=y, colour=_c, ms=5, label="") else scatter!(args.p1, x1, x2; group=y, colour=_c, ms=10, label="") diff --git a/src/CounterfactualExplations/data.jl b/src/CounterfactualExplations/data.jl index c29362a..ae96509 100644 --- a/src/CounterfactualExplations/data.jl +++ b/src/CounterfactualExplations/data.jl @@ -22,7 +22,7 @@ function embed(data::CounterfactualData, X::AbstractArray=nothing; dim_red::Symb end # Transforming: - X = typeof(X) <: Vector{<:Matrix} ? hcat(X...) : X + X = typeof(X) <: Vector{<:Matrix} ? stack(X, dims=2) : X if !isnothing(tfn) && !isnothing(X) X = MultivariateStats.predict(tfn, X) else From f5a11ef507648408ced24d7e3668ec95f9c6021b Mon Sep 17 00:00:00 2001 From: pat-alt Date: Wed, 3 Jan 2024 14:49:52 +0100 Subject: [PATCH 2/2] added MLUtils for legacy julia versions --- Project.toml | 2 ++ src/CounterfactualExplations/CounterfactualExplanations.jl | 2 +- .../{countefactuals.jl => counterfactuals.jl} | 2 ++ 3 files changed, 5 insertions(+), 1 deletion(-) rename src/CounterfactualExplations/{countefactuals.jl => counterfactuals.jl} (99%) diff --git a/Project.toml b/Project.toml index 60bdf30..ff6d8eb 100644 --- a/Project.toml +++ b/Project.toml @@ -13,6 +13,7 @@ Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" LaplaceRedux = "c52c1a26-f7c5-402b-80be-ba1e638ad478" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" +MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" ManifoldLearning = "06eb3307-b2af-5a2a-abea-d33192699d32" MultivariateStats = "6f286f6a-111f-5878-ab1e-185364afe411" NaturalSort = "c020b1a1-e9b0-503a-9c33-f039bfc54a85" @@ -30,6 +31,7 @@ LaplaceRedux = "0.1" LinearAlgebra = "1.7, 1.8, 1.9" MLJBase = "0.21, 0.22, 1" ManifoldLearning = "0.9" +MLUtils = "0.4" MultivariateStats = "0.9, 0.10" NaturalSort = "1" NearestNeighborModels = "0.2" diff --git a/src/CounterfactualExplations/CounterfactualExplanations.jl b/src/CounterfactualExplations/CounterfactualExplanations.jl index f43f2b9..30d2b54 100644 --- a/src/CounterfactualExplations/CounterfactualExplanations.jl +++ b/src/CounterfactualExplations/CounterfactualExplanations.jl @@ -4,4 +4,4 @@ using CounterfactualExplanations.Models include("data.jl") include("models.jl") -include("countefactuals.jl") \ No newline at end of file +include("counterfactuals.jl") \ No newline at end of file diff --git a/src/CounterfactualExplations/countefactuals.jl b/src/CounterfactualExplations/counterfactuals.jl similarity index 99% rename from src/CounterfactualExplations/countefactuals.jl rename to src/CounterfactualExplations/counterfactuals.jl index a87d682..b6eabc6 100644 --- a/src/CounterfactualExplations/countefactuals.jl +++ b/src/CounterfactualExplations/counterfactuals.jl @@ -1,3 +1,5 @@ +using MLUtils: stack + """ Plots.plot( ce::CounterfactualExplanation;